当前位置: 首页 > article >正文

torch.argsort 函数介绍

torch.argsort 是 PyTorch 中用于返回张量沿指定维度的排序索引的函数。它不会直接对张量的值进行排序,而是返回一个与张量相同形状的索引张量,指示出原张量中每个元素的排序顺序。

函数签名

torch.argsort(input, dim=-1, descending=False)

参数

  • input:输入的张量。
  • dim:指定进行排序的维度,默认为最后一个维度(-1)。
  • descending:如果为 True,则返回按降序排列的索引;如果为 False(默认),则按升序排列。

返回

返回与输入张量 input 形状相同的张量,其中每个元素为排序后的索引值。

使用场景

  • 需要知道张量中元素的排序顺序而不是排序后的实际值时,可以使用 argsort
  • 可以用于生成一种排序的掩码,比如对某些值进行有序操作。

示例代码

1. 基本用法
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行升序排序,并返回排序后的索引
sorted_indices = torch.argsort(x)
print(sorted_indices)  # 输出:tensor([1, 3, 0, 2])

解释:torch.argsort 返回的是按升序排列元素的索引。索引 1 对应值 1.2(最小),依次类推。

2. 在多维张量上使用
import torch

# 创建一个 2D 张量
x = torch.tensor([[4, 1, 3],
                  [2, 8, 5]])

# 在每一行上对张量进行升序排序,dim=1 表示对行操作
sorted_indices = torch.argsort(x, dim=1)
print(sorted_indices)

# 在每一列上对张量进行升序排序,dim=0 表示对列操作
sorted_indices_col = torch.argsort(x, dim=0)
print(sorted_indices_col)

输出

tensor([[1, 2, 0],
        [0, 2, 1]])

tensor([[1, 0, 0],
        [0, 1, 1]])

解释:第一个例子返回的是按行升序排列的索引,第二个例子返回的是按列升序排列的索引。

3. 降序排序
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行降序排序,并返回排序后的索引
sorted_indices = torch.argsort(x, descending=True)
print(sorted_indices)  # 输出:tensor([2, 0, 3, 1])

解释:这里使用了 descending=True,所以 torch.argsort 按降序返回索引,索引 2 对应值 4.8(最大),依次类推。

4. 与 gather 配合使用

你可以使用 torch.argsort 来获取排序后的索引,并结合 torch.gather 获取排序后的值:

import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 获取排序后的索引
sorted_indices = torch.argsort(x)

# 根据索引获取排序后的值
sorted_x = torch.gather(x, 0, sorted_indices)
print(sorted_x)  # 输出:tensor([1.2000, 2.9000, 3.5000, 4.8000])

解释:通过 torch.gather 函数,可以根据排序后的索引来重新排列原张量的值。

总结

torch.argsort 函数非常适合需要知道张量元素排序顺序的场景。它返回的是元素在升序或降序排序中的索引,而不是排序后的实际值。结合其他 PyTorch 函数(如 gather),可以灵活实现很多数据处理操作。


http://www.kler.cn/news/363292.html

相关文章:

  • 网络安全领域推荐证书介绍及备考指南
  • NoSQL 数据库 Redis
  • MongoDB Shell 基本命令(三)生成学生脚本信息和简单查询
  • 搬砖14、Python网络编程入门
  • Java工具类--截至2024常用http工具类分享
  • 科研进展 | RSE:全波形高光谱激光雷达数据Rclonte系列处理算法一
  • 【AIGC】ChatGPT提示词Prompt高效编写模式:Self-ask Prompt、ReACT与Reflexion
  • Java面向对象编程进阶(一)
  • 对于MacOSX开启稳定SMB分享的研究
  • 讲个故事:关于一次接口性能优化的心里路程
  • 数据库运行原理图
  • JVM锁升级的过程(偏向锁,轻量级锁,重量级锁)
  • String的运算符重载
  • C#中跨线程调用的方法一点总结
  • 数据仓库基础概念
  • 新探索研究生英语读写教程pdf答案(基础级)
  • 阿里云服务器如何安装宝塔
  • jupyter DatabaseError: database disk image is malformed 解决办法
  • 分布式检测线路、精准定位故障:输电线路故障定位监测系统
  • 国外电商系统开发-运维系统Docker镜像
  • 构建高效在线教育平台:Spring Boot的力量
  • iPhone图片/照片/视频复制到win10系统的简单方法 - 照片导出
  • Vue 3中的国际化(i18n)实践指南
  • C语言中MySQL库函数
  • MySQL【知识改变命运】09
  • 携程后端JAVA面试汇总