当前位置: 首页 > article >正文

PyTorch学习笔记之基础函数篇(十五)

文章目录

  • 数值比较运算
    • 8.1 torch.equal()函数
    • 8.2 torch.ge()函数
    • 8.3 torch.gt()函数
    • 8.4 torch.le()函数
    • 8.5 torch.lt()函数
    • 8.6 torch.ne()函数
    • 8.7 torch.sort()函数
    • 8.8 torch.topk()函数

数值比较运算

8.1 torch.equal()函数

torch.equal(tensor1, tensor2) -> bool

这个函数接受两个张量(tensor1 和 tensor2)作为参数,并返回一个布尔值(bool),表示这两个张量是否相等。

  • tensor1 (Tensor): 第一个要比较的张量。
  • tensor2 (Tensor): 第二个要比较的张量。

torch.equal 是 PyTorch 中的一个函数,用于比较两个张量(tensors)是否具有相同的形状和元素值。如果两个张量满足以下条件,则 torch.equal 返回 True:

  • 两个张量具有相同的形状(shape)。
  • 两个张量中的每个元素都相等。

这个函数在比较两个张量是否完全相同时非常有用。

示例:

import torch

# 创建两个形状和值都相同的张量
tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([1.0, 2.0, 3.0])

# 使用 torch.equal 检查它们是否相等
result = torch.equal(tensor1, tensor2)
print(result)  # 输出: True

# 创建两个形状相同但值不同的张量
tensor3 = torch.tensor([1.0, 2.0, 4.0])

# 再次使用 torch.equal 检查
result = torch.equal(tensor1, tensor3)
print(result)  # 输出: False

请注意,torch.equal 不仅比较张量的形状,还比较它们的元素值。这与 torch.eq 不同,后者仅比较两个张量中对应位置的元素值是否相等,并返回一个布尔值张量。

另外,如果两个张量都是稀疏的(即它们使用稀疏格式存储),则 torch.equal 还会比较它们的稀疏索引和值是否相同。

使用 torch.equal 时,请确保比较的张量在设备(CPU 或 GPU)和数据类型(如 float32、float64 等)上都是一致的,否则比较可能会失败或产生意外的结果。

8.2 torch.ge()函数

在PyTorch中,torch.ge() 函数用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的元素是否满足第一个张量中的元素大于或等于第二个张量中的元素。

函数签名如下:

torch.ge(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])

# 使用 torch.ge 进行逐元素比较
result = torch.ge(tensor1, tensor2)

print(result)
# 输出: tensor([ True,  True, False])
# 因为 1 >= 1, 2 >= 2, 3 < 4

# 使用标量进行比较
scalar = 2
result_scalar = torch.ge(tensor1, scalar)

print(result_scalar)
# 输出: tensor([False,  True,  True])
# 因为 1 < 2, 2 >= 2, 3 >= 2

请注意,torch.ge() 返回的是一个布尔值张量,其形状与输入张量相同。如果比较操作涉及到不同类型的张量(例如,一个张量是浮点数类型,另一个是整数类型),则可能需要先将它们转换为相同的类型,以避免可能的类型不匹配错误。

此外,还可以使用张量的比较运算符 >= 来进行相同的操作,这在语法上可能更简洁:

result = tensor1 >= tensor2

这行代码与 torch.ge(tensor1, tensor2) 等效。

8.3 torch.gt()函数

torch.gt() 是 PyTorch 中的一个函数,用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否大于第二个张量中的元素。

函数签名如下:

torch.gt(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])

# 使用 torch.gt 进行逐元素比较
result = torch.gt(tensor1, tensor2)

print(result)
# 输出: tensor([False, False,  True])
# 因为 1 不大于 1, 2 不大于 2, 3 大于 4(这里的比较是错误的,应该是 3 < 4)

# 修正比较操作
result_corrected = torch.gt(tensor1, 2)

print(result_corrected)
# 输出: tensor([False, False,  True])
# 因为 1 不大于 2, 2 不大于 2, 3 大于 2

# 使用标量进行比较
scalar = 2
result_scalar = torch.gt(tensor1, scalar)

print(result_scalar)
# 输出: tensor([False, False,  True])
# 因为 1 不大于 2, 2 不大于 2, 3 大于 2

请注意,torch.gt() 返回的是一个布尔值张量,其形状与输入张量相同。在这个例子中,tensor1 和 tensor2 的比较结果是一个包含三个元素的布尔值张量,表示 tensor1 中每个元素是否大于 tensor2 中对应位置的元素。

另外,与 torch.ge() 类似,你也可以使用张量的比较运算符 > 来进行相同的操作:

result = tensor1 > tensor2

这行代码与 torch.gt(tensor1, tensor2) 等效。

8.4 torch.le()函数

在PyTorch中,torch.le() 函数用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否小于或等于第二个张量中的元素。

函数签名如下:

torch.le(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])

# 使用 torch.le 进行逐元素比较
result = torch.le(tensor1, tensor2)

print(result)
# 输出: tensor([ True,  True,  True])
# 因为 1 <= 1, 2 <= 2, 3 <= 4

# 使用标量进行比较
scalar = 2
result_scalar = torch.le(tensor1, scalar)

print(result_scalar)
# 输出: tensor([ True,  True, False])
# 因为 1 <= 2, 2 <= 2, 3 > 2

在这个例子中,torch.le(tensor1, tensor2) 返回一个布尔值张量,其中每个元素对应于 tensor1 和 tensor2 中相应位置的元素比较结果。同样,torch.le(tensor1, scalar) 将 tensor1 中的每个元素与标量 scalar 进行比较。

与 torch.ge() 类似,你也可以使用张量的比较运算符 <= 来进行相同的操作:

result = tensor1 <= tensor2

这行代码与 torch.le(tensor1, tensor2) 等效。

8.5 torch.lt()函数

在PyTorch中,torch.lt() 函数用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否小于第二个张量中的元素。

函数签名如下:

torch.lt(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])

# 使用 torch.lt 进行逐元素比较
result = torch.lt(tensor1, tensor2)

print(result)
# 输出: tensor([False, False,  True])
# 因为 1 不小于 1, 2 不小于 2, 3 小于 4

# 使用标量进行比较
scalar = 2
result_scalar = torch.lt(tensor1, scalar)

print(result_scalar)
# 输出: tensor([ True,  False,  False])
# 因为 1 小于 2, 2 不小于 2, 3 不小于 2

在这个例子中,torch.lt(tensor1, tensor2) 返回一个布尔值张量,表示 tensor1 中每个元素是否小于 tensor2 中对应位置的元素。同样地,torch.lt(tensor1, scalar) 将 tensor1 中的每个元素与标量 scalar 进行比较。

你也可以使用张量的比较运算符 < 来执行相同的操作:

result = tensor1 < tensor2

这行代码与 torch.lt(tensor1, tensor2) 等效。

8.6 torch.ne()函数

torch.ne() 是 PyTorch 中的一个函数,用于逐元素地比较两个张量(tensors),并返回一个布尔值张量(boolean tensor),其中每个元素表示对应位置上的第一个张量中的元素是否不等于第二个张量中的元素。

函数签名如下:

torch.ne(input, other, *, out=None) -> Tensor
  • input (Tensor): 要比较的第一个张量。
  • other (Tensor or scalar): 要比较的第二个张量或标量。
  • out (Tensor, optional): 输出张量。

如果 other 是一个标量,那么它会与 input 张量中的每个元素进行比较。如果 other 是一个张量,那么它必须具有与 input 相同的形状,以便进行逐元素的比较。

示例:

import torch

# 创建两个张量
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 2, 4])

# 使用 torch.ne 进行逐元素比较
result = torch.ne(tensor1, tensor2)

print(result)
# 输出: tensor([False, False,  True])
# 因为 1 等于 1, 2 等于 2, 3 不等于 4

# 使用标量进行比较
scalar = 2
result_scalar = torch.ne(tensor1, scalar)

print(result_scalar)
# 输出: tensor([ True,  True, False])
# 因为 1 不等于 2, 2 不等于 2, 3 不等于 2

在这个例子中,torch.ne(tensor1, tensor2) 返回一个布尔值张量,其中每个元素对应于 tensor1 和 tensor2 中相应位置的元素比较结果。同样,torch.ne(tensor1, scalar) 将 tensor1 中的每个元素与标量 scalar 进行比较。

与 torch.ne() 类似,你也可以使用张量的比较运算符 != 来进行相同的操作:

result = tensor1 != tensor2

这行代码与 torch.ne(tensor1, tensor2) 等效。

8.7 torch.sort()函数

torch.sort() 是 PyTorch 中的一个函数,用于对张量(tensor)进行排序。它返回排序后的张量以及原始张量中元素的索引。

函数签名如下:

torch.sort(input, dim=None, descending=False, out=None) -> (Tensor, Tensor)
  • input (Tensor): 要排序的张量。
  • dim (int, optional): 沿着哪个维度进行排序。默认为 None,表示对整个张量进行排序。
  • descending (bool, optional): 是否按降序排序。默认为 False,即按升序排序。
  • out (tuple[Tensor, Tensor], optional): 可选的输出张量元组,用于存放排序结果和索引。

torch.sort() 返回一个元组,其中包含两个张量:

  • 排序后的张量(Tensor):沿着指定维度对输入张量进行排序后的结果。
  • 索引张量(Tensor):原始张量中元素的索引,按照排序后的顺序排列。

示例:

import torch

# 创建一个张量
x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5])

# 对张量进行排序
sorted_tensor, sorted_indices = torch.sort(x)

print("Sorted tensor:", sorted_tensor)
# 输出: Sorted tensor: tensor([1, 1, 2, 3, 3, 4, 5, 5, 5, 6, 9])

print("Sorted indices:", sorted_indices)
# 输出: Sorted indices: tensor([1, 3, 7, 0, 10, 2, 4, 8, 9, 5, 6])

# 如果想按降序排序
sorted_tensor_desc, sorted_indices_desc = torch.sort(x, descending=True)

print("Sorted tensor (descending):", sorted_tensor_desc)
# 输出: Sorted tensor (descending): tensor([9, 6, 5, 5, 5, 4, 3, 3, 2, 1, 1])

print("Sorted indices (descending):", sorted_indices_desc)
# 输出: Sorted indices (descending): tensor([6, 5, 8, 9, 4, 2, 0, 10, 7, 3, 1])

在这个例子中,torch.sort(x) 返回了排序后的张量 sorted_tensor 和原始张量中元素的索引 sorted_indices。如果指定 descending=True,则会按降序排序。

8.8 torch.topk()函数

torch.topk() 是 PyTorch 中的一个函数,用于返回张量中每个指定维度上的前 k 个最大值(或最小值)及其索引。这个函数非常有用,特别是在你需要获取张量中的顶部元素或进行排序相关的操作时。

函数签名如下:

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, Tensor)
  • input (Tensor): 输入张量。
  • k (int): 要返回的最大(或最小)值的数量。
  • dim (int, optional): 在哪个维度上进行操作。默认是第一个维度(dim=0)。
  • largest (bool, optional): 是否返回最大的 k 个值。如果是 True,则返回最大的 k 个值;如果是 False,则返回最小的 k 个值。默认为 True。
  • sorted (bool, optional): 是否对结果进行排序。如果为 True,则返回的 k 个值会按照降序(对于 largest=True)或升序(对于 largest=False)排列。默认为 True。
  • out (tuple[Tensor, Tensor], optional): 可选的输出张量元组,用于存放结果和索引。

torch.topk() 返回一个元组,其中包含两个张量:

第一个张量包含每个指定维度上的前 k 个最大值(或最小值)。
第二个张量包含这些最大值(或最小值)在原始张量中的索引。

示例:

import torch

# 创建一个张量
x = torch.tensor([[ 3, 2, 1],
                  [ 2, 3, 1],
                  [ 1, 2, 3]])

# 获取每个行的前 2 个最大值及其索引
values, indices = torch.topk(x, k=2, dim=1, largest=True, sorted=True)

print("Top 2 values:", values)
# 输出: Top 2 values: tensor([[3, 2],
#                             [3, 2],
#                             [3, 2]])

print("Indices of top 2 values:", indices)
# 输出: Indices of top 2 values: tensor([[0, 1],
#                                        [1, 0],
#                                        [2, 1]])

# 获取每个列的最小值及其索引
values, indices = torch.topk(x, k=1, dim=0, largest=False, sorted=True)

print("Smallest value in each column:", values)
# 输出: Smallest value in each column: tensor([[1, 1, 1]])

print("Indices of smallest values in each column:", indices)
# 输出: Indices of smallest values in each column: tensor([[2, 2, 2]])

在这个例子中,我们首先使用 torch.topk() 获取了每个行(dim=1)的前两个最大值(largest=True)及其索引。然后,我们改变了维度和条件,获取了每个列(dim=0)的最小值(largest=False)及其索引。


http://www.kler.cn/a/273113.html

相关文章:

  • 通过不当变更导致 PostgreSQL 翻车的案例分析与防范
  • 关于wordpress instagram feed 插件 (现更名为Smash Balloon Social Photo Feed)
  • Java 实现接口幂等的九种方法:确保系统稳定性与数据一致性
  • Java面向对象 C语言字符串常量
  • 【.NET 8 实战--孢子记账--从单体到微服务】--简易权限--接口路径管理
  • 【MIT-OS6.S081笔记1】xv6环境搭建
  • C/C++:有助于define宏定义-原文替换的例题
  • 深入解析JVM加载机制
  • 解决:visio导出公式为pdf图片乱码问题
  • Python笔记四之协程
  • [ComfyUI报错信息] 节点错误归类及处理办法(最新完整版)
  • ThreadLocal-内存泄露问题
  • 【LeetCode热题100】104. 二叉树的最大深度(二叉树)
  • 二级Java程序题--03综合应用:源代码(01-42)
  • 利用自定义 URI Scheme 在 Android 应用中实现安全加密解密功能
  • 【React】Vite创建React+TS项目
  • 类和对象(1)
  • Centos8安装wdCP
  • MATLAB中如何导出EXE或DLL
  • 缺失的数字(c++题解)
  • 【python开发】并发编程(上)
  • 凝思操作系统离线安装mysql和node
  • python 调用redis创建查询key
  • YOLOv9改进策略:注意力机制 | 用于微小目标检测的上下文增强和特征细化网络ContextAggregation,助力小目标检测,暴力涨点
  • SWUST OJ 961: 进制转换问题
  • 网络管理基础