当前位置: 首页 > article >正文

神经网络入门实战:(十八)Argmax函数的详细介绍,可以用来计算模型训练准确率

Argmax函数介绍

在Python中,argmax 函数通常用于找出给定数组或列表中元素值最大的索引。

(一) Numpy 中的 Argmax 函数:

numpy.argmax 函数用于找出给定轴(axis)上最大值所在的索引。

示例:

import numpy as np
# 一维数组
arr = np.array([1, 3, 2, 5, 4])
index = np.argmax(arr)
print(index)  

# 二维数组
arr_2d = np.array([[1, 3, 2], [5, 4, 6]])
index_axis_0 = np.argmax(arr_2d, axis=0) # 沿着 列(axis=0) 找出最大值的索引
index_axis_1 = np.argmax(arr_2d, axis=1) # 沿着 行(axis=1) 找出最大值的索引
print(index_axis_0)  # 输出: [1 1 1]
print(index_axis_1)  # 输出: [1 2]
-------------------------------------------------------------------------------------------------------
# 运行结果:
3
[1 1 1]
[1 2]

(二) pytorch 中的 Argmax 函数:

  • 功能:torch.argmax 函数在 PyTorch 中可以用于找到给定张量(Tensor)中每一行或每一列的最大值的索引。

  • 用法:对于一个给定的张量,argmax 会返回沿着指定维度的最大值的索引。如果不指定维度,默认是在最后一个维度上操作

    列就是第一个维度,行就是第二个维度,第三个维度可以理解成厚度,以此类推…

  • 代码:

    torch.argmax(input, dim=None, keepdim=False)
    
    • input:输入的张量。

    • dim:指定在哪个维度上寻找最大值,0为一维,1为二维…如果为 None,则默认在最后一个维度上操作。

    • keepdim:如果为 True,则输出的张量将保持与输入张量相同的维度数,否则,输出张量将减少一个维度,默认是 False

    • 输出结果是 int 型,64位的电脑,输出默认就是 int64 型。

    • 简洁写法:

      # 假设有一个多维张量 some_tensor
      output = some_tensor.argmax(dim=1) # 逐行找出最大值
      

    注意事项:

    1. 数据类型argmax 返回的是 int 类型(64位电脑就是 int64 )的张量,因为索引值必须是整数。
    2. 多维张量dim 参数指定了在哪个维度上寻找最大值。如果 dim 为负值,则从最后一个维度开始计数(例如,dim=-1 等于默认行为)。
    3. 相同最大值:如果有多个相同的最大值,argmax 将返回 第一个 出现的最大值的索引。
    4. NaN 值:如果张量中包含 NaN 值,argmax 会忽略这些值,并在剩余的有效值中寻找最大值。
  • 示例:

    import torch
    
    x = torch.tensor([[1, 2, 3], [4, 5, 6]])
    result1 = torch.argmax(x, dim=0) # 在第一个维度(列)上寻找最大值的索引
    print(result1)
    result2 = torch.argmax(x, dim=1) # 在最后一个维度(行)上寻找最大值的索引
    print(result2)
    result3 = torch.argmax(x, dim=1, keepdim=True) # 指定keepdim=True
    print(result3) 
    -------------------------------------------------------------------------------------------------------
    # 运行结果:
    tensor([1, 1, 1])
    tensor([2, 2])
    tensor([[2],
            [2]])
    
  • 高级应用:

    可以在神经网络训练中,计算准确率。

    假设有一个 n 分类模型,输出结果为 n 个概率,分别对应输入图片被识别成这 n 个类别的概率(∈[0,1]);

    n=3 (这 3 种类别分别用 0、1、2 表示),现在有两张输入图片,输出为:

    outputs=[[0.1, 0.2, 0.8],
            [0.4, 0.6, 0.2]]
    

    假设两张图片分别为类别 0 和类别 2 (即 target=[2,2])。对 outputs 使用 torch.argmax 函数:

    result = torch.argmax(outputs, dim=1) # 在每行上寻找最大值的索引,因为 keepdim 默认为 False ,故输出是一维张量 
    # 输出结果为:
    tensor([2,1])
    

    因为 target=[2,2] ,而训练结果为 result=[2,1] ,所以第一个图片识别正确,第二个图片识别错误。

    正确率的计算 :

    # 前提是 target 和 result 都是张量
    print(target == result) # 直接比较
    A = (target == result).sum() #True默认为1,False默认位0
    print(A)
    accuracy = (A.item() / len(result)) * 100
    accuracy = roun
    d(accuracy, 2) # 2表示保留两位小数(四舍五入)
    print(f"正确率为:{accuracy}%")
    ---------------------------------------------------------------------------------------
    # 运行结果
    tensor([True,False])
    tensor(1)
    正确率为:50%
    

上一篇下一篇
神经网络入门实战(十七)待发布

http://www.kler.cn/a/428831.html

相关文章:

  • 【24】Word:小郑-准考证❗
  • 前沿技术趋势洞察与分析:探寻科技变革的多维密码
  • 四、华为交换机 STP
  • 安路FPGA开发工具TD:问题解决办法 及 Tips 总结
  • 安全测评主要标准
  • R数据分析:有调节的中介与有中介的调节的整体介绍
  • Java的Stream流:文件处理、排序与串并行流的全面指南
  • 智能方法求解-圆环内传感器节点最大最小距离分布
  • 后端返回前端的数据量过大解决方案
  • 最新基于R语言森林生态系统结构、功能与稳定性分析与可视化实践高级应用
  • 低级爬虫实现-记录HCIP云架构考试
  • 数字图像处理(15):图像平移
  • Fiddler 5.21.0 使用指南:过滤浏览器HTTP(S)流量下(四)
  • 基于gitlab API刷新MR的commit的指定status
  • 【Unity高级】如何动态调整物体透明度
  • Linux-Regmap实验
  • Go database/sql包源码分析
  • ShardingSphere 数据库中间件
  • k8s 为什么需要Pod?
  • 高级java每日一道面试题-2024年12月05日-JVM篇-什么是TLAB?
  • 计算机键盘的演变 | 键盘键名称及其功能 | 键盘指法
  • 软件无线电安全之GNU Radio基础(下)
  • 英文论文翻译成中文,怎样翻译更地道?
  • 【开源免费】基于Vue和SpringBoot的高校学科竞赛平台(附论文)
  • 普通算法——一维前缀和
  • k8s-Informer概要解析(2)