当前位置: 首页 > article >正文

PyTorch 张量的常用 API

以下内容是《龙良曲 PyTorch 入门到实战》的学习笔记。

PyTorch 张量的 API 和 Numpy 数组的 API 有很多相似的地方。


当然先从 import torch 开始

import torch

1. 创建张量 1

# 创建 0 维张量(标量)
a = torch.tensor(2)
a, a.shape, a.dim()
(tensor(2), torch.Size([]), 0)
# 创建 1 维张量
b = torch.tensor([12])
b, b.shape, b.dim()
(tensor([1, 2]), torch.Size([2]), 1)

torch.Size([2]):[2] 表示张量只有一个维度,且该维度的大小为 2,意味着张量是一个长度为 2 的一维向量。

# 创建 2 维张量
c = torch.rand(23)  # torch.rand 是从均匀分布 [0, 1) 中随机采样
c, c.shape, c.dim()
(tensor([[0.9734, 0.0049, 0.7768],
         [0.2023, 0.0527, 0.7201]]),
 torch.Size([2, 3]),
 2)
# 创建 3 维张量
d = torch.rand(234)
d, d.shape, d.dim()
(tensor([[[0.0913, 0.0158, 0.8031, 0.9192],
          [0.6724, 0.0897, 0.9926, 0.9988],
          [0.8425, 0.8657, 0.5040, 0.8000]],
 
         [[0.4834, 0.4194, 0.6343, 0.7109],
          [0.8020, 0.1699, 0.5067, 0.7636],
          [0.7913, 0.9784, 0.3104, 0.0883]]]),
 torch.Size([2, 3, 4]),
 3)
# 创建 4 维张量
e = torch.rand(2345)
e, e.shape, e.dim()
(tensor([[[[0.9313, 0.5396, 0.4871, 0.3988, 0.2495],
           [0.6972, 0.4542, 0.9189, 0.0674, 0.9499],
           [0.0531, 0.6770, 0.8236, 0.6928, 0.2244],
           [0.2111, 0.3731, 0.9198, 0.1803, 0.2736]],
 
          [[0.3402, 0.2752, 0.5428, 0.6273, 0.5970],
           [0.7716, 0.3818, 0.7193, 0.3282, 0.6251],
           [0.5292, 0.0292, 0.7328, 0.7275, 0.6477],
           [0.4814, 0.9318, 0.1992, 0.9758, 0.5151]],
 
          [[0.3928, 0.5228, 0.8950, 0.8177, 0.0072],
           [0.1537, 0.9172, 0.2723, 0.1919, 0.8734],
           [0.0637, 0.6485, 0.6036, 0.4787, 0.7961],
           [0.7418, 0.6782, 0.8170, 0.2056, 0.4170]]],
 
 
         [[[0.4684, 0.4747, 0.0683, 0.5093, 0.6679],
           [0.3876, 0.2639, 0.9735, 0.5651, 0.8922],
           [0.5401, 0.4198, 0.8448, 0.0081, 0.2305],
           [0.8506, 0.3580, 0.9158, 0.9479, 0.7021]],
 
          [[0.4257, 0.1313, 0.4063, 0.3148, 0.9047],
           [0.0025, 0.7115, 0.6387, 0.4969, 0.8807],
           [0.4448, 0.3831, 0.5796, 0.3017, 0.2502],
           [0.9328, 0.0923, 0.3293, 0.4689, 0.5546]],
 
          [[0.4217, 0.7608, 0.5278, 0.5567, 0.8842],
           [0.3066, 0.5278, 0.5351, 0.2689, 0.4789],
           [0.9639, 0.9956, 0.3730, 0.7439, 0.2094],
           [0.1112, 0.3619, 0.7006, 0.0206, 0.5256]]]]),
 torch.Size([2, 3, 4, 5]),
 4)
e.numel()  # number of element, 2*3*4*5=120
120
# 设置张量的默认数据类型
# torch.set_default_tensor_type(torch.DoubleTensor)

2. 创建张量 2

2.1 从分布中随机采样

# 从均匀分布 [0, 1) 中随机采样
a = torch.rand(23)  
a
tensor([[0.8131, 0.0279, 0.8225],
        [0.5625, 0.9970, 0.6894]])
# 从 [1, 10) 中随机采样整数
torch.randint(110, [3,4])  
tensor([[4, 4, 1],
        [5, 9, 6]])
# 从标准正态分布 N(0, 1) 中随机采样
torch.randn(34)  # 3 行 4 列
tensor([[-0.6407, -0.8793, -0.4336, -0.3790],
        [ 0.7999, -0.9402, -1.5473,  0.3704],
        [-0.3192, -0.6442,  1.2120, -0.2569]])
# 从正态分布中随机采样
torch.normal(mean=0, std=4, size=(23))
tensor([[ 3.7403,  2.7976, -3.9352],
        [-3.7567,  0.0156, -0.4714]])

2.2 使用特定值填充张量

# 使用任意值填充张量
torch.full([2,3], 6)  # 2 行 3 列,填充值为 6
tensor([[6, 6, 6],
        [6, 6, 6]])
# 使用 1 填充张量
torch.ones(23)
tensor([[1., 1., 1.],
        [1., 1., 1.]])
# 使用 0 填充张量
torch.zeros(23)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
# 使用 1 填充张量对角线
torch.eye(34)  # 使用 1 填充 3 行 4 列张量对角线
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.]])

2.3 生成序列张量

# 生成整数序列张量
torch.arange(0102)  # 采样 [0,10) 之间的整数,步长为 2,默认为 1
tensor([0, 2, 4, 6, 8])
# 生成等间隔序列张量
torch.linspace(010, steps=4)  # 起始值为 0, 结束值为 10, steps 指定生成的样本数量为 4
tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
# 生成等间隔序列张量,并使用指定底数做指数运算
torch.logspace(010, steps=11, base=2)  # 起始值为 0, 结束值为 10, base 指定底数为 2,默认为 10
tensor([1.0000e+00, 2.0000e+00, 4.0000e+00, 8.0000e+00, 1.6000e+01, 3.2000e+01,
        6.4000e+01, 1.2800e+02, 2.5600e+02, 5.1200e+02, 1.0240e+03])

2.4 生成随机索引种子张量

idx = torch.randperm(10)  # 范围为 [0, 10)
idx
tensor([1, 4, 0, 5, 9, 7, 3, 6, 2, 8])

3. 索引与切片

a = torch.rand(34185)
b = torch.randn(34)

3.1 直接索引

a[0].shape
torch.Size([4, 18, 5])
a[00].shape
torch.Size([18, 5])
a[0024].shape
torch.Size([])

3.2 切片

a.shape
torch.Size([3, 4, 18, 5])
a[:2].shape
torch.Size([2, 4, 18, 5])
a[:2, :1, :, :].shape
torch.Size([2, 1, 18, 5])
a[:21:, :, :].shape
torch.Size([2, 3, 18, 5])
a[:2-1:, :, :].shape
torch.Size([2, 1, 18, 5])
a[:, :-1, :, :].shape
torch.Size([3, 3, 18, 5])

3.3 间隔采样

a.shape
torch.Size([3, 4, 18, 5])

[n : m : k]代表的是在 [n: m) 这一段采样,每 k 个采样一次。 k 代表的是间隔,间隔可正可负,正值代表正向挑取,负值代表反向挑取。 当k为正的时候起始索引应该小于结束索引;当k为负的时候起始索引应该大于结束索引,因为在倒序来看,首先是索引值大的被取到,然后才是索引值小的。

a[:, :, 0:18:2, :].shape
torch.Size([3, 4, 9, 5])
a[:, :, ::2, :].shape  # ::2 表示从全序列采样,每 2 个采样一次
torch.Size([3, 4, 9, 5])

3.4 根据特定索引采样

a.shape
torch.Size([3, 4, 18, 5])

a.index_select(2, torch.tensor([2,5,7,15])).shape 对第 3 维的 [2,5,7,15] 进行采样,[2,5,7,15] 必须是一维 tensor,列表形式会报错。

a.index_select(2, torch.tensor([2,5,7,15])).shape
torch.Size([3, 4, 4, 5])

3.5 根据掩码采样

b
tensor([[ 0.9935, -0.0745, -0.4147, -0.0534],
        [-0.7544, -0.3848,  0.2596,  0.0135],
        [-1.8242,  0.1000,  0.3048,  0.0734]])
mask = b.ge(0.5)  # great equal 为大于等于,取 b 中大于等于 0.5 的元素
mask
tensor([[ True, False, False, False],
        [False, False, False, False],
        [False, False, False, False]])
torch.masked_select(b, mask)
tensor([0.9935])

3.6 打平之后采样

b
tensor([[ 0.9935, -0.0745, -0.4147, -0.0534],
        [-0.7544, -0.3848,  0.2596,  0.0135],
        [-1.8242,  0.1000,  0.3048,  0.0734]])
torch.take(b, torch.tensor([031011]))
tensor([ 0.9935, -0.0534,  0.3048,  0.0734])

3.7 ... 全取

a.shape
torch.Size([3, 4, 18, 5])

... 取连续区间内的所有维度的所有内容

a[2, ...].shape  # 第 1 维取的是索引为 2 的内容,其余维度的内容全取
torch.Size([4, 18, 5])
a[:, 3, ...].shape
torch.Size([3, 18, 5])
a[..., :3].shape
torch.Size([3, 4, 18, 3])
a[2:, ..., :3].shape
torch.Size([1, 4, 18, 3])

4. 维度变换

a = torch.rand(412828)

4.1 view() 和 reshape():重塑数据 shape,两者完全等同

a.view(41*28*28)  # 变换成 2 维
tensor([[0.5880, 0.9894, 0.6273,  ..., 0.1558, 0.7502, 0.1608],
        [0.6493, 0.7738, 0.4530,  ..., 0.5287, 0.9526, 0.9557],
        [0.4089, 0.2363, 0.8786,  ..., 0.8130, 0.5146, 0.0035],
        [0.2904, 0.7985, 0.6169,  ..., 0.9979, 0.8781, 0.7591]])
a.reshape(41*28*28)
tensor([[0.5880, 0.9894, 0.6273,  ..., 0.1558, 0.7502, 0.1608],
        [0.6493, 0.7738, 0.4530,  ..., 0.5287, 0.9526, 0.9557],
        [0.4089, 0.2363, 0.8786,  ..., 0.8130, 0.5146, 0.0035],
        [0.2904, 0.7985, 0.6169,  ..., 0.9979, 0.8781, 0.7591]])

4.2 unsqueeze(): 增加维度

unsqueeze() 增加维度不会改变数据,只是增加一个维度,这个维度的具体含义由自己定义。

a.shape
torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape  # 在第一维之前增加一个维度
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape  # 在最后一维之后增加一个维度
torch.Size([4, 1, 28, 28, 1])
a.unsqueeze(-5).shape
torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(4).shape
torch.Size([4, 1, 28, 28, 1])

4.3 squeeze(): 减少维度

只能去掉 size 为 1 的维度,对 size 不为 1 的维度执行此操作,原数据保持不变。 () 内无参数时,去掉所有 size 为 1 的维度。

a.shape
torch.Size([4, 1, 28, 28])
a.squeeze().shape
torch.Size([4, 28, 28])
a.squeeze(1).shape  
torch.Size([4, 28, 28])
a.squeeze(0).shape  
torch.Size([4, 1, 28, 28])
a.squeeze(-3).shape  
torch.Size([4, 28, 28])

4.4 expand(): 增加某一维度的 size

只能增加 size 为 1 的维度的 size,增加 size 不为 1 的维度的 size 会报错。 注意:增加某一维度的 size 是对既有数据的复制! -1 表示此维度的 size 保持不变。

a.shape, a.numel()
(torch.Size([4, 1, 28, 28]), 3136)
a.expand(422828).shape, a.expand(422828).numel()
(torch.Size([4, 2, 28, 28]), 6272)
a.expand(-12-1-1).shape
torch.Size([4, 2, 28, 28])

4.5 使用示例:不同维度数据相加

# 创建一个包含 32 个数据的一维张量
b = torch.rand(32)

# 创建一个四维张量,其中第二维的 size 为 32
c = torch.rand(4321515)

# 使用 unsqueeze() 增加 b 的维度
b = b.unsqueeze(0).unsqueeze(2).unsqueeze(3)

# 使用 expand() 增加 b 的维度的 size
b = b.expand(4321515)

# 将 b 和 c 相加
(b + c).shape
torch.Size([4, 32, 15, 15])

4.6 t(): 交换维度

.t() 只能用于二维张量交换维度,不能用于其他高维张量,返回输入张量的转置版本。

d = torch.rand(34)
d.t().shape
torch.Size([4, 3])

4.7 transpose(): 交换维度

transpose() 可以用于二维/多维张量的交换维度,返回输入张量的转置版本,但是只能实现两个维度的交换。

a.shape
torch.Size([4, 1, 28, 28])
# 交换第 2 和第 4 维度
a.transpose(13).shape
torch.Size([4, 28, 28, 1])

注意:transpose() 对张量进行了转置,数据的顺序发生了变化,所以当我们使用 view() 先压缩 shape 再恢复 shape 时,在恢复步骤中,必须先恢复为压缩之前的 shape,若恢复为其他 shape,会造成数据污染。这一点,我们要格外注意!!!

# 交换维度 size 之后,再使用 view() reshape,直接 reshape 为 a 的形状(错误的方法)
a1 = a.transpose(13).contiguous().view(41*28*28).view(412828)  # 在张量处于转置状态,数据的顺序发生了变化的情况下,直接 reshape 为 a 的形状是错误的,得到的结果不同于 a

# 交换维度 size 之后,再使用 view() reshape,先 reshape 恢复,再使用 transpose() 恢复为 a 的形状(正确的方法)
a2 = a.transpose(13).contiguous().view(41*28*28).view(428281).transpose(13)

contiguous() 的作用:数据转置之后,内存顺序被打乱了,contiguous() 可以将数据复制到一个新的内存位置,使内存顺序重新连续起来。

a1.shape, a2.shape
(torch.Size([4, 1, 28, 28]), torch.Size([4, 1, 28, 28]))
# 检查 a1, a2 与 a 是否相同
torch.all(torch.eq(a1, a)), torch.all(torch.eq(a2, a))
(tensor(False), tensor(True))

torch.eq(a1, a):这个函数比较 a1 中的每个元素与 a 中的每个元素,返回一个同样形状的布尔张量,其中元素为 True 表示 a1 中对应的元素等于 a,False 表示不等于。 torch.all(...):这个函数检查传入的布尔张量中的所有元素是否都是 True。如果是,返回 True;否则返回 False。

4.8 permute(): 交换维度 size

transpose() 只能实现两个维度的交换,permute() 可以实现张量多个维度的交换(可以理解为多次 transpose() 交换),返回输入张量的转置版本。

a.shape
torch.Size([4, 1, 28, 28])
# 将第二个维度移至最后面
# 使用 transpose()
A1 = a.transpose(12).transpose(23)

# 使用 permute()
A2 = a.permute(0231)

A1.shape, A2.shape, torch.all(torch.eq(A1, A2))
(torch.Size([4, 28, 28, 1]), torch.Size([4, 28, 28, 1]), tensor(True))

5. 合并与分割

5.1 cat()

cat:在某一维度上对两个张量进行合并,不创建新的维度,要求张量其他维度的 size 须相等。

a = torch.rand(4328)
b = torch.rand(5328)

torch.cat([a,b], dim=0).shape
torch.Size([9, 32, 8])

5.2 stack()

stack:对两个张量进行合并时创建一个新的维度,新维度的意义由自己指定,要求张量各个维度的 size 须相等。

a = torch.rand(4328)
b = torch.rand(4328)

torch.stack([a,b], dim=0).shape
torch.Size([2, 4, 32, 8])

5.3 split()

split:既可以指定拆分的 step 长度,也可以指定拆分得到的各块的数量。

a = torch.rand(6328)
# 指定 step=2
a1, a2, a3 = a.split(2, dim=0)
a1.shape, a2.shape, a3.shape
(torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([2, 32, 8]))
# 指定拆分得到的各块的数量
a1, a2, a3 = a.split([1,2,3], dim=0)
a1.shape, a2.shape, a3.shape
(torch.Size([1, 32, 8]), torch.Size([2, 32, 8]), torch.Size([3, 32, 8]))

5.4 chunk()

chunk:直接指定要拆分为几块,被拆分的维度的 size 须被指定的数整除。

a = torch.rand(6328)
a1, a2 = a.chunk(2, dim=0)  # 拆分为 2 块
a1.shape, a2.shape
(torch.Size([3, 32, 8]), torch.Size([3, 32, 8]))

6. 数学运算

6.1 加减乘除

张量的加减除可以直接使用 + - / 乘法:* 表示两个张量的对位元素相乘(element-wise);@ 和 torch.matmul() 表示矩阵乘法。

a = torch.full([2,3], 2)
b = torch.full([2,3], 3)

a, b, a+b, b-a, b/a, a*b, a@b.t()
(tensor([[2, 2, 2],
         [2, 2, 2]]),
 tensor([[3, 3, 3],
         [3, 3, 3]]),
 tensor([[5, 5, 5],
         [5, 5, 5]]),
 tensor([[1, 1, 1],
         [1, 1, 1]]),
 tensor([[1.5000, 1.5000, 1.5000],
         [1.5000, 1.5000, 1.5000]]),
 tensor([[6, 6, 6],
         [6, 6, 6]]),
 tensor([[18, 18],
         [18, 18]]))

6.2 矩阵乘法

多维张量的矩阵乘法只看最后两个维度。 多维张量的矩阵乘法可以使用 @ 和 torch.matmul(),两者是等价的。

a = torch.rand(432864)
b = torch.rand(436432)

(a @ b).shape, torch.matmul(a, b).shape
(torch.Size([4, 3, 28, 32]), torch.Size([4, 3, 28, 32]))

6.3 幂运算和 log 运算

a = torch.full([2,3], 2)
a.pow(3), a**3
(tensor([[8, 8, 8],
         [8, 8, 8]]),
 tensor([[8, 8, 8],
         [8, 8, 8]]))
b = torch.exp(a)  # 以 e 为底数,张量 a 的元素为指数
b
tensor([[7.3891, 7.3891, 7.3891],
        [7.3891, 7.3891, 7.3891]])
# log 对数运算
torch.log(b)  # log 代表 ln,以 2 为底数可以写为 log2,以 10 为底数可以写为 log10
tensor([[2., 2., 2.],
        [2., 2., 2.]])

6.4 约数运算

# 向下取整,向上取整,返回浮点数的整数部分,返回浮点数的小数部分
a = torch.tensor(3.14)
a.floor(), a.ceil(), a.trunc(), a.frac()
(tensor(3.), tensor(4.), tensor(3.), tensor(0.1400))
# 四舍五入
a = torch.tensor(3.499)
b = torch.tensor(3.5)
a.round(), b.round()
(tensor(3.), tensor(4.))

本文由 mdnice 多平台发布


http://www.kler.cn/a/398282.html

相关文章:

  • [Android]相关属性功能的裁剪
  • VSCode插件
  • 小程序19-微信小程序的样式和组件介绍
  • 网络安全SQL初步注入2
  • 【一键整合包及教程】AI照片数字人工具EchoMimic技术解析
  • 力扣刷题日记之150.逆波兰表达式求值
  • Guava Cache
  • SQLI LABS | Less-51 GET-Error Based-ORDER BY CLAUSE-String-Stacked Injectiion
  • 图像分割——Hough变换检测法
  • C语言——判断是不是字母
  • YOLOv7-0.1部分代码阅读笔记-train.py
  • SQLite 安装指南
  • MAC上的Office三件套报53错误解决方案(随笔记)
  • 【MogDB】MogDB5.2.0重磅发布第八篇-支持PLSQL编译全局缓存
  • 如何在 Ubuntu 上安装 Mattermost 团队协作工具
  • 【ArcGIS微课1000例】0127:计算城市之间的距离
  • 9.2 使用haarcascade_frontalface_default.xml分类器检测视频中的人脸,并框出人脸位置。
  • 企业项目级IDEA设置类注释、方法注释模板(仅增加@author和@date)
  • 你的服务器缓存中毒过么?
  • Essential Cell Biology--Fifth Edition--Chapter one (8)
  • ssm126基于HTML5的出租车管理系统+jsp(论文+源码)_kaic
  • 牛客周赛第一题2024/11/17日
  • 深入理解Flutter生命周期函数之StatefulWidget(一)
  • 【Qt聊天室】客户端实现总结
  • 华为欧拉系统使用U盘制作引导安装华为欧拉操作系统
  • Kubernetes 10 问,测测你对 k8s 的理解程度