11.21Pytorch_属性常见转换操作
四、Tensor常见属性
张量有device、dtype、shape等常见属性,知道这些属性对我们认识Tensor很有帮助。
1. 获取属性
掌握代码调试就是掌握了一切~
import torch
def test001():
data = torch.tensor([1, 2, 3])
print(data.dtype, data.device, data.shape)
if __name__ == "__main__":
test001()
2. 切换设备
默认在cpu上运行,可以显式的切换到GPU:不同设备上的数据是不能相互运算的。
import torch
def test001():
data = torch.tensor([1, 2, 3])
print(data.dtype, data.device, data.shape)
# 把数据切换到GPU进行运算
device = "cuda" if torch.cuda.is_available() else "cpu"
data = data.to(device)
print(data.device)
if __name__ == "__main__":
test001()
或者使用cuda进行切换:
data = data.cuda()
当然也可以直接创建在GPU上:
# 直接在GPU上创建张量
data = torch.tensor([1, 2, 3], device='cuda')
print(data.device)
3. 类型转换
在训练模型或推理时,类型转换也是张量的基本操作,是需要掌握的。
import torch
def test001():
data = torch.tensor([1, 2, 3])
print(data.dtype) # torch.int64
# 1. 使用type进行类型转换
data = data.type(torch.float32)
print(data.dtype) # float32
data = data.type(torch.float16)
print(data.dtype) # float16
# 2. 使用类型方法
data = data.float()
print(data.dtype) # float32
data = data.half()
print(data.dtype) # float16
data = data.double()
print(data.dtype) # float64
data = data.long()
print(data.dtype) # int64
if __name__ == "__main__":
test001()
只是看着多~
五、Tensor数据转换
1. Tensor与Numpy
Tensor和Numpy都是常见数据格式,惹不起~
1.1 张量转Numpy
此时分内存共享和内存不共享~
1.1.1 浅拷贝
调用numpy()方法可以把Tensor转换为Numpy,此时内存是共享的。
import torch
def test003():
# 1. 张量转numpy
data_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
data_numpy = data_tensor.numpy()
print(type(data_tensor), type(data_numpy))
# 2. 他们内存是共享的
data_numpy[0, 0] = 100
print(data_tensor, data_numpy)
if __name__ == "__main__":
test003()
1.1.2 深拷贝
使用copy()方法可以避免内存共享:
import torch
def test003():
# 1. 张量转numpy
data_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 2. 使用copy()避免内存共享
data_numpy = data_tensor.numpy().copy()
print(type(data_tensor), type(data_numpy))
# 3. 此时他们内存是不共享的
data_numpy[0, 0] = 100
print(data_tensor, data_numpy)
if __name__ == "__main__":
test003()
1.2 Numpy转张量
也可以分为内存共享和不共享~
1.2.1 浅拷贝
from_numpy方法转Tensor默认是内存共享的
import numpy as np
import torch
def test006():
# 1. numpy转张量
data_numpy = np.array([[1, 2, 3], [4, 5, 6]])
data_tensor = torch.from_numpy(data_numpy)
print(type(data_tensor), type(data_numpy))
# 2. 他们内存是共享的
data_tensor[0, 0] = 100
print(data_tensor, data_numpy)
if __name__ == "__main__":
test006()
1.2.2 深拷贝
使用传统的torch.tensor()则内存是不共享的~
import numpy as np
import torch
def test006():
# 1. numpy转张量
data_numpy = np.array([[1, 2, 3], [4, 5, 6]])
data_tensor = torch.tensor(data_numpy)
print(type(data_tensor), type(data_numpy))
# 2. 内存是不共享的
data_tensor[0, 0] = 100
print(data_tensor, data_numpy)
if __name__ == "__main__":
test006()
2. Tensor与图像
图像是我们视觉处理中最常见的数据,惹不起…
2.1 图片转Tensor
import torch
from PIL import Image
from torchvision import transforms
def test001():
imgpath = r"./105429.jpg"
# 1. 读取图片
img = Image.open(imgpath)
# 使用transforms.ToTensor()将图片转换为张量
transform = transforms.ToTensor()
img_tensor = transform(img)
print(img_tensor)
if __name__ == "__main__":
test001()
2.2 Tensor转图片
import torch
from PIL import Image
from torchvision import transforms
def test002():
# 1. 随机一个数据表示图片
img_tensor = torch.randn(3, 224, 224)
# 2. 创建一个transforms
transform = transforms.ToPILImage()
# 3. 转换为图片
img = transform(img_tensor)
img.show()
# 4. 保存图片
img.save("./test.jpg")
if __name__ == "__main__":
test002()
3. PyTorch图像处理
通过一个Demo加深对Torch的API理解和使用
import torch
from PIL import Image
from torchvision import transforms
def test003():
# 指定读取的文件路径
imgpath = r"./105429.jpg"
# 加载图片
img = Image.open(imgpath)
# 图像转为Tensor
transform = transforms.ToTensor()
img_tensor = transform(img)
# 去掉透明度值
print(img_tensor.shape)
# 检查CUDA是否可用并将tensor移至CUDA
if torch.cuda.is_available():
img_tensor = img_tensor.cuda()
print(img_tensor.device)
# 修改每个像素值
img_tensor += 0.2
# 将tensor移回CPU并转换回PIL图像
img_tensor = img_tensor.cpu()
transform = transforms.ToPILImage()
img = transform(img_tensor)
# 保存图像
img.save("./ok.png")
if __name__ == "__main__":
test003()
效果:
六、Tensor常见操作
在深度学习中,Tensor是一种多维数组,用于存储和操作数据,我们需要掌握张量各种运算。
1. 获取元素值
我们可以把单个元素tensor转换为Python数值,这是非常常用的操作
import torch
def test002():
data = torch.tensor([18])
print(data.item())
pass
if __name__ == "__main__":
test002()
注意:
- 和Tensor的维度没有关系,都可以取出来!
- 如果有多个元素则报错;
2. 元素值运算
常见的加减乘除次方取反开方等各种操作,带有_的方法则会替换原始值。
import torch
def test001():
data = torch.randint(0, 10, (2, 3))
print(data)
# 元素级别的加减乘除:不修改原始值
print(data.add(1))
print(data.sub(1))
print(data.mul(2))
print(data.div(3))
print(data.pow(2))
# 元素级别的加减乘除:修改原始值
data = data.float()
data.add_(1)
data.sub_(1)
data.mul_(2)
data.div_(3.0)
data.pow_(2)
print(data)
if __name__ == "__main__":
test001()
3. 阿达玛积
阿达玛积指的是矩阵对应位置的元素相乘,可以使用mul函数或者*来实现;
import torch
def test001():
data1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
data2 = torch.tensor([[2, 3, 4], [2, 2, 3]])
print(data1 * data2)
def test002():
data1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
data2 = torch.tensor([[2, 3, 4], [2, 2, 3]])
print(data1.mul(data2))
if __name__ == "__main__":
test001()
test002()
4. Tensor相乘
点积运算将两个向量映射为一个标量,是向量之间的基本操作。
点积运算要求如果第一个矩阵的shape是 (N, M),那么第二个矩阵 shape必须是 (M, P),最后两个矩阵点积运算的shape为 (N, P)。
使用@或者matmul完成Tensor的乘法。
mm方法也可以用于矩阵相乘 但是只能用于2维矩阵即: m ∗ k m*k m∗k和 k ∗ n k*n k∗n 得到 m ∗ n m*n m∗n 的矩阵
import torch
def test006():
data1 = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
data2 = torch.tensor([
[3, 2],
[2, 3],
[5, 3]
])
print(data1 @ data2)
print(data1.matmul(data2))
print(data1.mm(data2))
if __name__ == "__main__":
test006()
5. 索引操作
掌握张量的花式索引在处理复杂数据时非常有用。花式索引可以让你灵活地访问、修改张量中的特定元素或子集,从而简化代码并提高操作效率。
5.1 简单索引
索引,就是根据指定的下标选取数据。
import torch
def test006():
data = torch.randint(0, 10, (3, 4))
print(data)
# 1. 行索引
print("行索引:", data[0])
# 2. 列索引
print("列索引:", data[:, 0])
# 3. 固定位置索引:2种方式都行
print("索引:", data[0, 0], data[0][0])
if __name__ == "__main__":
test006()
5.2 列表索引
使用list批量的制定要索引的元素位置~此时注意list的维度!
import torch
def test008():
data = torch.randint(0, 10, (3, 4))
print(data)
# 1. 使用列表进行索引:(0, 0), (1, 1), (2, 1)
print("列表索引:", data[[0, 1, 2], [0, 1, 1]])
# 2. 行级别的列表索引
print("行级别列表索引:", data[[[2], [1]], [0, 1, 2]])
if __name__ == "__main__":
test008()
5.3 布尔索引
根据条件选择张量中的元素。
import torch
def test009():
tensor = torch.tensor([1, 2, 3, 4, 5])
mask = tensor > 3
print(mask)
print(tensor[mask]) # 输出: tensor([4, 5])
if __name__ == "__main__":
test009()
行级别的条件索引
import torch
def test100():
data = torch.randint(0, 10, (3, 4))
print(data)
# 1. 索引第3个元素大于3的所有行
print(data[data[:, 2] > 3])
# 2. 索引第3行 值大于3 的所有的元素 所在的列
print(data[:, data[2] > 3])
# 3. 第二列是偶数, 且第一列大于6的行
print(data[(data[:, 1] % 2 == 0) & (data[:, 0] > 6)])
if __name__ == "__main__":
test100()
5.4 索引赋值
使用花式索引轻松进行批量元素值修改~
import torch
def test666():
data = torch.eye(4)
print(data)
# 赋值
data[:, 1:-1] = 0
print(data)
if __name__ == "__main__":
test666()
6. 张量拼接
在 PyTorch 中,cat 和 stack 是两个用于拼接张量的常用操作,但它们的使用方式和结果略有不同:
- cat:在现有维度上拼接,不会增加新维度。
- stack:在新维度上堆叠,会增加一个维度。
6.1 torch.cat
元素级别的
torch.cat(concatenate 的缩写)用于沿现有维度拼接张量。换句话说,它在现有的维度上将多个张量连接在一起。
import torch
def test001():
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 1. 在指定的维度上进行拼接:0
print(torch.cat([tensor1, tensor2], dim=0))
# 输出:
# tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])
# 2. 在指定的维度上进行拼接:1
print(torch.cat([tensor1, tensor2], dim=1))
# 输出:
# tensor([[ 1, 2, 3, 7, 8, 9],
# [ 4, 5, 6, 10, 11, 12]])
if __name__ == "__main__":
test001()
注意:要拼接的张量在除了指定拼接的维度之外的所有维度上的大小必须相同。
6.2 torch.stack
张量级别的
torch.stack 用于在新维度上拼接张量。换句话说,它会增加一个新的维度,然后沿指定维度堆叠张量。
import torch
def test002():
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 1. 沿新创建的第0维度堆叠:从第一层开始一人出一个数据 堆叠
print(torch.stack([tensor1, tensor2], dim=0))
# 输出:
# tensor([[[ 1, 2, 3],
# [ 4, 5, 6]],
# [[ 7, 8, 9],
# [10, 11, 12]]])
# 2. 沿新创建的第1维度堆叠:从第二层开始一人出一个数据 堆叠
print(torch.stack([tensor1, tensor2], dim=1))
# 输出:
# tensor([[[ 1, 2, 3],
# [ 7, 8, 9]],
# [[ 4, 5, 6],
# [10, 11, 12]]])
# 2. 沿新创建的第2维度堆叠:从第三层开始一人出一个数据 堆叠
print(torch.stack([tensor1, tensor2], dim=2))
if __name__ == "__main__":
test002()
注意:要堆叠的张量必须具有相同的形状。 技巧:堆叠指一人出一个交替添加 拼接指一人出完下个人在出完
7. 形状操作
在 PyTorch 中,张量的形状操作是非常重要的,因为它允许你灵活地调整张量的维度和结构,以适应不同的计算需求。
7.1 reshape
可以用于将张量转换为不同的形状,但要确保转换后的形状与原始形状具有相同的元素数量。
import torch
def test001():
data = torch.randint(0, 10, (4, 3))
print(data)
# 1. 使用reshape改变形状
data = data.reshape(2, 2, 3)
print(data)
# 2. 使用-1表示自动计算
data = data.reshape(2, -1)
print(data)
if __name__ == "__main__":
test001()
7.2 view
view进行形状变换的特征:
- 张量在内存中是连续的;
- 返回的是原始张量视图,不重新分配内存,效率更高;
- 如果张量在内存中不连续,view 将无法执行,并抛出错误。
7.2.1 内存连续性
我们在进行变形或转置操作时,很容易造成内存的不连续性。
import torch
def test001():
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("正常情况下的张量:", tensor.is_contiguous())
# 对张量进行转置操作
tensor = tensor.t()
print("转置操作的张量:", tensor.is_contiguous())
print(tensor)
# 此时使用view进行变形操作
tensor = tensor.view(2, -1)
print(tensor)
if __name__ == "__main__":
test001()
执行结果:
正常情况下的张量: True
转置操作的张量: False
tensor([[1, 4],
[2, 5],
[3, 6]])
Traceback (most recent call last):
File "e:\01.深度学习\01.参考代码\14.PyTorch.内存连续性.py", line 20, in <module>
test001()
File "e:\01.深度学习\01.参考代码\14.PyTorch.内存连续性.py", line 13, in test001
tensor = tensor.view(2, -1)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
7.2.2 和reshape比较
view:高效,但需要张量在内存中是连续的;
reshape:更灵活,但涉及内存复制;
7.2.3 view变形操作
import torch
def test002():
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 将 2x3 的张量转换为 3x2
reshaped_tensor = tensor.view(3, 2)
print(reshaped_tensor)
# 自动推断一个维度
reshaped_tensor = tensor.view(-1, 2)
print(reshaped_tensor)
if __name__ == "__main__":
test002()
7.3 transpose
transpose 用于交换张量的两个维度,注意,是2个维度,它返回的是原张量的视图。
import torch
def test003():
data = torch.randint(0, 10, (3, 4, 5))
print(data, data.shape)
# 使用transpose进行形状变换
transpose_data = data.transpose(0, 1)
print(transpose_data, transpose_data.shape)
if __name__ == "__main__":
test003()
7.4 permute
permute 用于改变张量的所有维度顺序。与 transpose 类似,但它可以交换多个维度。
import torch
def test004():
data = torch.randint(0, 10, (3, 4, 5))
print(data, data.shape)
# 使用permute进行多维度形状变换
permute_data = data.permute(1, 2, 0)
print(permute_data, permute_data.shape)
if __name__ == "__main__":
test004()
7.5 flatten
flatten 用于将张量展平为一维向量。
tensor.flatten(start_dim=0, end_dim=-1)
- start_dim:从哪个维度开始展平。
- end_dim:在哪个维度结束展平。默认值为
-1
,表示展平到最后一个维度。
import torch
def test005():
data = torch.randint(0, 10, (3, 4, 5))
# 展平
flatten_data = data.flatten(1, -1)
print(flatten_data)
if __name__ == "__main__":
test005()
7.6 升维和降维
在后续的网络学习中,升维和降维是常用操作,需要掌握。
-
unsqueeze:用于在指定位置插入一个大小为 1 的新维度。
-
squeeze:用于移除所有大小为 1 的维度,或者移除指定维度的大小为 1 的维度。
7.6.1 squeeze降维
import torch
def test006():
data = torch.randint(0, 10, (1, 4, 5, 1))
print(data, data.shape)
# 进行降维操作
data = data.squeeze(0).squeeze(-1)
print(data.shape)
if __name__ == "__main__":
test006()
7.6.2 unsqueeze升维
import torch
def test007():
data = torch.randint(0, 10, (32, 32, 3))
print(data.shape)
# 升维操作
data = data.unsqueeze(0)
print(data.shape)
if __name__ == "__main__":
test007()
8. 张量分割
可以按照指定的大小或者块数进行分割。
import torch
def test001():
# 创建一个张量
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])
# 分割成3块
print(torch.chunk(x, 3))
# 按照每块大小为4进行分割
print(torch.split(x, 4))
if __name__ == "__main__":
test001()
9. 广播机制
广播机制允许在对不同形状的张量进行计算,而无需显式地调整它们的形状。广播机制通过自动扩展较小维度的张量,使其与较大维度的张量兼容,从而实现按元素计算。
9.1 广播机制规则
广播机制需要遵循以下规则:
- 每个张量的维度至少为1
- 满足右对齐
9.2 广播案例
1D和2D张量广播
import torch
def test006():
data1d = torch.tensor([1, 2, 3])
data2d = torch.tensor([[4], [2], [3]])
print(data1d.shape, data2d.shape)
# 进行计算:会自动进行广播机制
print(data1d + data2d)
if __name__ == "__main__":
test006()
输出:
torch.Size([3]) torch.Size([3, 1])
tensor([[5, 6, 7],
[3, 4, 5],
[4, 5, 6]])
2D 和 3D 张量广播
广播机制会根据需要对两个张量进行形状扩展,以确保它们的形状对齐,从而能够进行逐元素运算。广播是双向奔赴的。
import torch
def test001():
# 2D 张量
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 3D 张量
b = torch.tensor([[[2, 3, 4]], [[5, 6, 7]]])
print(a.shape, b.shape)
# 进行运算
result = a + b
print(result, result.shape)
if __name__ == "__main__":
test001()
执行结果:
torch.Size([2, 3]) torch.Size([2, 1, 3])
tensor([[[ 3, 5, 7],
[ 6, 8, 10]],
[[ 6, 8, 10],
[ 9, 11, 13]]]) torch.Size([2, 2, 3])
最终参与运算的a和b形式如下:
# 2D 张量
a = torch.tensor([[[1, 2, 3], [4, 5, 6]],[[1, 2, 3], [4, 5, 6]]])
# 3D 张量
b = torch.tensor([[[2, 3, 4], [2, 3, 4]], [[5, 6, 7], [5, 6, 7]]])