当前位置: 首页 > article >正文

Tensor 基本操作4 理解 indexing,加减乘除和 broadcasting 运算 | PyTorch 深度学习实战

前一篇文章,Tensor 基本操作3 理解 shape, stride, storage, view,is_contiguous 和 reshape 操作 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

Tensor 基本使用

  • 索引 indexing
    • 示例代码
  • 加减乘除
    • 加法和减法
    • 乘法和除法
  • broadcasting 机制
  • 更多运算
  • Links

索引 indexing

Tensor 的索引类似于 Python List 的索引和分片。

比如一个 AxBxC 的三个维度的 Tensor a

a[第0维的分片, 第1维的分片, 第2维的分片]

分片的语法和 Python List 分片语法一致,开始:结束:步进

更多索引的高级语法介绍。

示例代码

    print("*" * 8, " a")
    a = torch.randn(5,4,3)
    print(a)

    print("*" * 8, " b")
    b = a[1,]     # 只要第 0 维的第一个成员
    print(b)

    print("*" * 8, " c")
    c = a[1:]   # 第 0 维从第一个成员开始都要,注意:这里索引从 0 开始
    print(c)

    print("*" * 8, " d")
    d = a[1:, 1] # 第 0 维从第一个成员开始都要,第二维只要第一个成员
    print(d)

Result

********  a
tensor([[[ 0.1874, -0.0980, -0.3815],
         [-0.8175,  1.5976, -1.4927],
         [-0.1507,  1.1806, -0.3685],
         [ 1.1583,  0.9419, -0.5540]],

        [[ 1.3078, -1.4250, -1.5981],
         [-0.0756,  2.0776,  0.7708],
         [ 1.6020, -1.9133,  1.2459],
         [-0.2817, -0.7238, -0.5413]],

        [[-0.8057, -0.4368, -1.2398],
         [ 0.8415,  1.7679,  0.6469],
         [ 0.7046, -0.4872,  1.1219],
         [-2.5866, -0.1263,  2.0684]],

        [[ 1.8756,  1.4231, -1.2082],
         [ 0.2111,  0.5244,  2.2242],
         [-0.9658, -1.3731, -0.9126],
         [-0.3850, -0.7273, -0.0519]],

        [[ 0.7949,  2.2807, -0.8793],
         [ 0.4037,  1.2422, -0.2393],
         [ 0.4786,  0.6107,  1.4225],
         [ 0.6104,  1.2682, -0.0801]]])
********  b = a[1,]
tensor([[ 1.3078, -1.4250, -1.5981],
        [-0.0756,  2.0776,  0.7708],
        [ 1.6020, -1.9133,  1.2459],
        [-0.2817, -0.7238, -0.5413]])
********  c = a[1:]
tensor([[[ 1.3078, -1.4250, -1.5981],
         [-0.0756,  2.0776,  0.7708],
         [ 1.6020, -1.9133,  1.2459],
         [-0.2817, -0.7238, -0.5413]],

        [[-0.8057, -0.4368, -1.2398],
         [ 0.8415,  1.7679,  0.6469],
         [ 0.7046, -0.4872,  1.1219],
         [-2.5866, -0.1263,  2.0684]],

        [[ 1.8756,  1.4231, -1.2082],
         [ 0.2111,  0.5244,  2.2242],
         [-0.9658, -1.3731, -0.9126],
         [-0.3850, -0.7273, -0.0519]],

        [[ 0.7949,  2.2807, -0.8793],
         [ 0.4037,  1.2422, -0.2393],
         [ 0.4786,  0.6107,  1.4225],
         [ 0.6104,  1.2682, -0.0801]]])
********  d = a[1:, 1]
tensor([[-0.0756,  2.0776,  0.7708],
        [ 0.8415,  1.7679,  0.6469],
        [ 0.2111,  0.5244,  2.2242],
        [ 0.4037,  1.2422, -0.2393]])

加减乘除

加法和减法

import torch
 
# 这两个Tensor加减乘除会对b自动进行Broadcasting
a = torch.rand(3, 4)
b = torch.rand(4)
 
c1 = a + b
c2 = torch.add(a, b)
print(c1.shape, c2.shape)
print(torch.all(torch.eq(c1, c2)))

乘法和除法

*, torch.mul, torch.mm, torch.matmul

参考: torch.Tensor的4种乘法

除法可以用乘法 API 完成。

broadcasting 机制

在 Tensor 的加减运算中,当两个 tensor 不能直接符合数学的运算规则时,PyTorch 会先尝试将 tensor 进行变换,再进行计算,这个变换的规则就是:broadcasting。
在这里插入图片描述

更多 broadcasting 机制的介绍。

更多运算

更多加法和其他运算,参考Pytorch Tensor基本数学运算:

  • 减法运算
  • 哈达玛积(对应元素相乘,也称为 element wise)
  • 除法运算
  • 幂运算
  • 开方运算
  • 指数与对数运算
  • 近似值运算
  • 裁剪运算

Links

  • Tensor Broadcasting under the hood
  • Mastering PyTorch Indexing: Simple Techniques with Practical Examples
  • torch.Tensor的4种乘法
  • Pytorch Tensor基本数学运算

http://www.kler.cn/a/516263.html

相关文章:

  • 【人工智能】深度卷积神经网络学习
  • 【数据库】详解MySQL数据库中索引的本质与底层原理
  • 代码随想录day16
  • 一键视频转文字/音频转文字,浏览器右键提取B站视频文案,不限时长免费无限次可用
  • CRM项目的开发与调试整体策略
  • Flutter鸿蒙化中的Plugin
  • SpringCloud系列教程:微服务的未来(十五)实现登录校验、网关传递用户、OpenFeign传递用户
  • (Java版本)基于JAVA的网络通讯系统设计与实现-毕业设计
  • 2018 秋招 百度二轮面试---血淋淋的经历写实
  • 重构(4)
  • ruoyi-vue-plus 引入 ShardingSphere-JDBC 实现分库分表
  • docker 部署.netcore应用优势在什么地方?
  • Linux下Ubuntun系统报错find_package(BLAS REQUIRED)找不到
  • 华为OD机试E卷 --树状结构查询--24年OD统一考试(Java JS Python C C++)
  • 概率密度函数(PDF)分布函数(CDF)——直方图累积直方图——直方图规定化的数学基础
  • 智源研究院与乐聚机器人成立具身智能联合实验室
  • 深度学习实战图像OCR识别
  • 【博客之星】2024年度创作成长总结 - 面朝大海 ,春暖花开!
  • STM32——LCD
  • Spring Boot中选择性加载Bean的几种方式