文章目录
- 1 张量索引
- 1.1 简单行列索引和列表索引
- 1.2 布尔索引和多维索引
- 2 张量的形状操作
- 2.1 reshape函数
- 2.2 transpose和permute函数的使用
- 2.3 view和contiguous函数
- 2.4 squeeze和unsqueeze函数用法
- 2.5 张量更改形状小结
- 3 常见运算函数
1 张量索引
1.1 简单行列索引和列表索引
import torch
# 1. 简单行列索引
def test01():
# 固定随机数种子
torch.manual_seed(0)
data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 30)
# 1.1 获得指定的某行元素
# print(data[2])
# 1.2 获得指定的某个列的元素
# 逗号前面表示行, 逗号后面表示列
# 冒号表示所有行或者所有列
# print(data[:, :])
# 表示获得第3列的元素
print(data[:, 2])
# 获得指定位置的某个元素
print(data[1, 2], data[1][2])
# 表示先获得前三行,然后再获得第三列的数据
print(data[:3, 2])
# 表示获得前三行的前两列
print(data[:3, :2])
# 2. 列表索引
def test02():
# 固定随机数种子
torch.manual_seed(0)
data = torch.randint(0, 10, [4, 5])
print(data)
print('-' * 30)
# 如果索引的行列都是一个1维的列表,那么两个列表的长度必须相等
# print(data[[0, 1, 2], [2, 4]]) # 报错,索引位置都是一维,必须匹配
# 解决方法:如果不想前后维数一样,就采用二维数组
# 使用二维数组进行索引得到仍然为二维数组
print(data[[[0],[1],[2]],[2,4]])
# 1.表示获得 (0, 0)、(2, 1)、(3, 2) 三个位置的元素
# 使用一维数组进行索引,得到的是一维
print(data[[0, 2, 3], [0, 1, 2]])
# 2。表示获得 0、2、3 行的 0、1、2 列
# print(data[[[0], [2], [3]], [0, 1, 2]])
1.2 布尔索引和多维索引
import torch
def test01():
torch.manual_seed(0)
data = torch.randint(0, 10, [4, 5])
print(data)
print(data > 3)
print(data[data > 3])
print(data[:,1] > 6)
print(data[data[:, 1] > 6])
print(data[:, data[1] > 3])
def test02():
torch.manual_seed(0)
data = torch.randint(0, 10, [3, 4, 5])
print(data)
print('-' * 30)
print(data[0, :, :])
print('-' * 30)
print(data[:, 0, :])
print('-' * 30)
print(data[:, :, 0])
print('-' * 30)
2 张量的形状操作
2.1 reshape函数
- 保证张量元素个数不变的情况下改变张量的形状
- 在神经网络中,不同层中的数据形状不同
import torch
def test():
torch.manual_seed(0)
data = torch.randint(0, 10, [4, 5])
print(data.shape, data.shape[0], data.shape[1])
new_data = data.reshape(2, 10)
print(new_data)
new_data = data.reshape(5, -1)
print(new_data)
new_data = data.reshape(-1, 2)
print(new_data)
2.2 transpose和permute函数的使用
- reshape函数更改形状,reshape会重新计算张量的维度,有时候不需要重新计算张量的维度,只要调整张量维度的顺序即可,可以使用transpose函数和permute函数
- transpose函数每次只能交换两个维度
- permute函数可以一次交换多个维度
import torch
def test01():
torch.manual_seed(0)
data = torch.randint(0, 10, [3, 4, 5])
new_data = data.reshape(4, 3, 5)
print(new_data.shape)
new_data = torch.transpose(data, 0, 1)
print(new_data.shape)
new_data = torch.transpose(data, 0, 1)
new_data = torch.transpose(new_data, 1, 2)
print(new_data.shape)
def test02():
torch.manual_seed(0)
data = torch.randint(0, 10, [3, 4, 5])
new_data = torch.permute(data, [1, 2, 0])
print(new_data.shape)
2.3 view和contiguous函数
- view函数改变张量的形状,只能用于存储在整块内存中的张量,具有一定的局限性。
- pytorch中有些张量是由不同的数据块组成,并没有存储在整块的内存中,view函数无法对于这种张量进行变形处理
- 一个张量经过了transpose或者permute函数的处理之后,就无法使用view函数进行形状操作
- 先用contiguous将非连续内存空间转换为连续内存空间,然后再使用view函数进行更改张量形状
import torch
def test01():
data = torch.tensor([[10, 20, 30], [40, 50, 60]])
data = data.view(3, 2)
print(data.shape)
print(data.is_contiguous())
def test02():
data = torch.tensor([[10, 20, 30], [40, 50, 60]])
print('是否连续:', data.is_contiguous())
data = torch.transpose(data, 0, 1)
print('是否连续:', data.is_contiguous())
data = data.contiguous().view(2, 3)
print(data)
2.4 squeeze和unsqueeze函数用法
- squeeze函数可以将维度为1的维度进行删除
- unsqueeze函数给张量增加维度为1的维度
import torch
def test01():
data = torch.randint(0, 10, [1, 3, 1, 5])
print(data.shape)
new_data = data.squeeze(0)
print(new_data.shape)
new_data = data.squeeze(2)
print(new_data.shape)
def test02():
data = torch.randint(0, 10, [3, 5])
print(data.shape)
new_data = data.unsqueeze(-1)
print(new_data.shape)
2.5 张量更改形状小结
- reshape 函数可以在保证张量数据不变的前提下改变数据的维度.
- transpose 函数可以实现交换张量形状的指定维度, permute 可以一次交换更多的维度.
- view 函数也可以用于修改张量的形状, 但是它要求被转换的张量内存必须连续,所以一般配合 contiguous 函数使用.
- squeeze 和 unsqueeze 函数可以用来增加或者减少维度.
3 常见运算函数
- mean()
- sum()
- pow(n)
- sqrt()
- exp()
- log() - 以e为底的对数
- log2()
- log10()
import torch
def test01():
torch.manual_seed(0)
data = torch.randint(0, 10, [2, 3]).double()
print(data)
print(data.mean())
print(data.mean(dim=0))
print(data.mean(dim=1))
def test02():
torch.manual_seed(0)
data = torch.randint(0, 10, [2, 3]).double()
print(data.sum())
print(data.sum(dim=0))
print(data.sum(dim=1))
def test03():
torch.manual_seed(0)
data = torch.randint(0, 10, [2, 3]).double()
print(data)
data = data.pow(2)
print(data)
def test04():
torch.manual_seed(0)
data = torch.randint(0, 10, [2, 3]).double()
print(data)
data = data.sqrt()
print(data)
def test05():
torch.manual_seed(0)
data = torch.randint(0, 10, [2, 3]).double()
print(data)
data = data.exp()
print(data)
def test06():
torch.manual_seed(0)
data = torch.randint(0, 10, [2, 3]).double()
print(data)
data = data.log()
data = data.log2()
data = data.log10()
print(data)