当前位置: 首页 > article >正文

关于函数torch.topk用法的思考

文章目录

      • 1. 沿着`dim=0`
      • 2. 沿着`dim=1`
      • 3. 沿着`dim=2`
      • 4. 总结

开始介绍之前先来点哲理性的思考,为什么函数 torch.topk,他的名字会叫 topk呢?

个人认为名称是来源于 “top k”,在这种情况下,它表示 “前 k 个最大值”。

假设我们有一个形状为 ( 2 , 3 , 4 ) (2, 3, 4) (2,3,4) 的三维张量 A A A,如下所示:

A = torch.tensor([[[ 1,  3,  5,  7],
                   [ 2,  4,  6,  8],
                   [ 9, 11, 13, 15]],
                  [[16, 18, 20, 22],
                   [17, 19, 21, 23],
                   [10, 12, 14, 24]]])

1. 沿着dim=0

沿着 dim=0(即在子矩阵之间进行比较):

k = 1
topk_values, topk_indices = torch.topk(A, k=k, dim=0)
A = torch.tensor([[[ 1,  3,  5,  7],
                   [ 2,  4,  6,  8],
                   [ 9, 11, 13, 15]],
                  [[16, 18, 20, 22],
                   [17, 19, 21, 23],
                   [10, 12, 14, 24]]])

沿着 dim=0(即在子矩阵之间进行比较):

topk_values = tensor([[[16, 18, 20, 22],
                        [17, 19, 21, 23],
                        [10, 12, 14, 24]]])

topk_indices = tensor([[[1, 1, 1, 1],
                         [1, 1, 1, 1],
                         [1, 1, 1, 1]]])

那此时我们令k = 3会发生什么?很显然,我们并没有三个子矩阵,所以此时程序会报错。给大家看一下程序的错误:

Traceback (most recent call last):
  File "E:\Learning_Material\Junior_Second_Semester\Adademic_Research\BusterNet_pytorch-master\test_2023_4_7.py", line 14, in <module>
    topk_values, topk_indices = torch.topk(A, k=k, dim=0)
RuntimeError: selected index k out of range

2. 沿着dim=1

沿着 dim=1(即在行之间进行比较):

k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=1)
topk_values = tensor([[[ 9, 11, 13, 15],
         			   [ 2,  4,  6,  8]],

        			   [[17, 19, 21, 24],
         				[16, 18, 20, 23]]])

topk_indices = tensor([[[2, 2, 2, 2],
         				[1, 1, 1, 1]],

        			   [[1, 1, 1, 2],
         				[0, 0, 0, 1]]])

在沿着行进行比较的情况下,比较是不会在子矩阵,也就是更高维度上发生的。仅仅在子矩阵的维度上比较一个子矩阵中最大的行数。

3. 沿着dim=2

沿着 dim=2(即在列之间进行比较):

k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=2)
topk_values = tensor([[[ 7,  5],
         			   [ 8,  6],
         			   [15, 13]],

        			  [[22, 20],
         		   	   [23, 21],
         			   [24, 14]]])

topk_indices = tensor([[[3, 2],
         				[3, 2],
         				[3, 2]],

        			   [[3, 2],
         				[3, 2],
         				[3, 2]]])

沿着dim=2比较,此时就不会牵扯到更高的两个维度,只会在最后一个维度之内进行排序比较

4. 总结

没有什么别的经验,希望对大家有用就好!


http://www.kler.cn/news/8021.html

相关文章:

  • axios 导出excel表格 文件流格式
  • Alibaba开源的Java诊断工具Arthas-实战
  • [ 应急响应基础篇 ] 解决远程登录权限不足的问题(后门账号添加远程桌面权限)
  • 华为OD机试-组合出合法最小数-2022Q4 A卷-Py/Java/JS
  • 深入浅出Kafka
  • Linux账号管理(用户{创建删除修改}用户组{创建删除修改}一般用户命令{id,finger,chfn,chsh})
  • 【JavaWeb】8—过滤器
  • 【架构师从零进阶】Java基础 练气期 Day1
  • 旅游心得Traveling Experience
  • 从零开始:如何集成美颜SDK到你的应用中
  • Redis常用命令以及如何在Java中操作Redis
  • springcloud——并发请求处理方案
  • 软件测试面试复盘:技术面没有难倒我,hr面被虐的体无完肤
  • 笔记-常见的动态内存错误
  • 收割offer疯狂涨了5K,自动化测试面试题整理大全,你能答上多少?
  • js设计模式——组合模式
  • RBF-UKF径向基神经网络结合无迹卡尔曼滤波估计锂离子电池SOC(附MATLAB代码)
  • 在cmd命令窗口安装Python模块
  • 入门力扣自学笔记257 C++ (题目编号:1041)
  • GuLi商城-SpringCloud-Gateway网关核心概念、测试API网关