当前位置: 首页 > article >正文

深度学习02-pytorch-04-张量的运算函数

在 PyTorch 中,张量(tensor)运算是核心操作之一,PyTorch 提供了丰富的函数来进行张量运算,包括数学运算、线性代数、索引操作等。以下是常见的张量运算函数及其用途:

1. 基本数学运算

  • 加法运算torch.add(a, b) 或者直接使用 +

    a = torch.tensor([1, 2])
    b = torch.tensor([3, 4])
    c = torch.add(a, b)  # [4, 6]
    # 或者
    c = a + b

  • 减法运算torch.sub(a, b) 或者直接使用 -

    c = torch.sub(a, b)  # [-2, -2]
    # 或者
    c = a - b

  • 乘法运算(逐元素)torch.mul(a, b) 或者直接使用 *

    c = torch.mul(a, b)  # [3, 8]
    # 或者
    c = a * b

  • 除法运算(逐元素)torch.div(a, b) 或者直接使用 /

    c = torch.div(a, b)  # [0.3333, 0.5]
    # 或者
    c = a / b

  • 指数运算torch.pow(a, b) 或者 a ** b

    c = torch.pow(a, b)  # a^b -> [1^3, 2^4] = [1, 16]
    # 或者
    c = a ** b

  • 求幂函数torch.sqrt(a)torch.exp(a)torch.log(a)

    c = torch.sqrt(torch.tensor([4.0, 9.0]))  # [2.0, 3.0]
    d = torch.exp(torch.tensor([1.0, 2.0]))  # e^1, e^2
    e = torch.log(torch.tensor([1.0, 2.0]))  # log(1), log(2)

2. 聚合操作

  • 求和torch.sum(tensor, dim=None)

    a = torch.tensor([[1, 2], [3, 4]])
    total_sum = torch.sum(a)  # 全局求和: 10
    row_sum = torch.sum(a, dim=1)  # 对每一行求和: [3, 7]

  • 求平均值torch.mean(tensor, dim=None)

    avg = torch.mean(a.float())  # 平均值: 2.5

  • 最大值torch.max(tensor)torch.max(tensor, dim)

    max_val = torch.max(a)  # 最大值: 4
    max_val_row, idx = torch.max(a, dim=1)  # 每一行的最大值: [2, 4]

  • 最小值torch.min(tensor)torch.min(tensor, dim)

    min_val = torch.min(a)  # 最小值: 1

  • 标准差torch.std(tensor)

    std = torch.std(a.float())  # 标准差

3. 线性代数运算

  • 矩阵乘法torch.mm(a, b) 或者使用 @

    a = torch.tensor([[1, 2], [3, 4]])
    b = torch.tensor([[5, 6], [7, 8]])
    c = torch.mm(a, b)  # 矩阵乘法
    # 或者
    c = a @ b

  • 矩阵转置torch.t(tensor) 或者使用 .T

    a_t = torch.t(a)  # 转置
    # 或者
    a_t = a.T

  • 矩阵求逆torch.inverse(tensor)

    a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    a_inv = torch.inverse(a)  # 求矩阵的逆

  • 行列式torch.det(tensor)

    det = torch.det(a)  # 计算行列式

  • 特征值和特征向量torch.eig(tensor, eigenvectors=True)

    e_vals, e_vecs = torch.eig(a, eigenvectors=True)  # 计算特征值和特征向量

4. 张量形状操作

  • 张量重塑torch.reshape(tensor, new_shape)tensor.view(new_shape)

    a = torch.tensor([[1, 2], [3, 4], [5, 6]])
    reshaped = torch.reshape(a, (2, 3))  # 改变形状为 (2, 3)

  • 张量扩展torch.unsqueeze(tensor, dim)torch.squeeze(tensor, dim)

    a = torch.tensor([1, 2, 3])
    unsqueezed = torch.unsqueeze(a, 0)  # 在第0维添加一个新维度 -> [[1, 2, 3]]
    squeezed = torch.squeeze(unsqueezed)  # 移除维度 -> [1, 2, 3]

  • 拼接张量torch.cat(tensors, dim)torch.stack(tensors, dim)

    a = torch.tensor([1, 2])
    b = torch.tensor([3, 4])
    concatenated = torch.cat((a, b), dim=0)  # 拼接 -> [1, 2, 3, 4]
    
    stacked = torch.stack((a, b), dim=0)  # 堆叠 -> [[1, 2], [3, 4]]

5. 索引操作

  • 通过索引选择元素tensor[index]

    a = torch.tensor([[1, 2], [3, 4], [5, 6]])
    selected = a[0, 1]  # 选择第0行第1列的元素 -> 2

  • 高级索引tensor[range]、布尔索引等

    a = torch.tensor([1, 2, 3, 4, 5])
    selected = a[a > 3]  # 选择大于3的元素 -> [4, 5]

6. 随机数生成

  • 均匀分布随机数torch.rand(size)

    random_tensor = torch.rand(3, 3)  # 生成一个 3x3 的均匀分布随机张量

  • 正态分布随机数torch.randn(size)

    normal_random = torch.randn(3, 3)  # 生成一个 3x3 的正态分布随机张量

  • 指定范围的整数随机数torch.randint(low, high, size)

    randint_tensor = torch.randint(0, 10, (3, 3))  # 生成 0 到 10 之间的随机整数

7. 广播机制

  • 广播运算:当张量的形状不同,但维度兼容时,PyTorch 会自动应用广播机制扩展张量。

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([[1], [2], [3]])
    c = a + b  # 广播操作

8. 自动微分

  • 启用自动微分requires_grad=True

    x = torch.tensor(2.0, requires_grad=True)
    y = x ** 2
    y.backward()  # 计算梯度
    print(x.grad)  # 输出: 4.0

总结

PyTorch 中的张量运算函数非常丰富,从基本的数学运算到高级的线性代数操作、形状调整和随机数生成,这些函数让张量的处理非常灵活和高效。通过这些运算,你可以实现各种数值计算和深度学习模型的训练。


http://www.kler.cn/news/314722.html

相关文章:

  • Selenium4.0实现自动搜索功能
  • 链式前向星建图
  • 【MySQL】 索引
  • Facebook隐私设置指南:如何更好地保护个人信息
  • 【二十二】【QT开发应用】QScrollArea控件应用1,C++11 R原始字符串字面量
  • Oracle(139)如何创建和管理数据库用户?
  • 1.3 计算机网络的分类
  • Hadoop的一些高频面试题 --- hdfs、mapreduce以及yarn的面试题
  • tensorflow同步机制
  • EasyExcel根据模板生成excel文件【xls、xlsx】
  • 【乐企-业务篇】开票前置校验服务-规则链服务接口实现(发票基础信息校验)
  • 2.场景应用:接口关联,文件上传(Postman工具)
  • Shell篇之编写php启动脚本
  • [python]从零开始的PySide安装配置教程
  • JavaEE: 深入探索TCP网络编程的奇妙世界(三)
  • Python实现图形学曲线和曲面的Bezier曲线算法
  • 深度学习-生成式检索-论文速读-2024-09-14
  • 关于自动化测试的一点了解
  • 高效财税自动化软件的特点与优势
  • ChatGPT 为何将前端框架从 Next.js 更换为 Remix以及框架的选择
  • Java中List、ArrayList与顺序表
  • hackmyvm靶场--zon
  • Spring:项目中的统一异常处理和自定义异常
  • 通过Java设计模式提高业务流程灵活性的策略
  • 笔记:DrawingContext和GDI+对比简介
  • 【Python】探索 TensorFlow:构建强大的机器学习模型
  • PostgreSQL技术内幕11:PostgreSQL事务原理解析-MVCC
  • 区块链DAPP质押系统开发
  • python中迭代器和可迭代对象
  • 【系统架构设计师】特定领域软件架构(经典习题)