当前位置: 首页 > article >正文

24.11.21深度学习

Tensor常见操作

获取元素值

我们可以把单个元素tensor转换为Python数值,这是非常常用的操作

import torch

# 标量
t1 = torch.tensor(99)
print(t1.item())

# 一阶
t2 = torch.tensor([99])
print(t2.item())

注意:

  • 和Tensor的维度没有关系,都可以取出来!

  • 如果有多个元素则报错;

元素值运算

常见的加减乘除次方取反开方等各种操作,带有_的方法则会替换原始值

 

import torch

def test1():
    # 设置随机种子
    torch.manual_seed(666)
    t1 = torch.randint(1,10,(3,3))
    print(t1)
    print("--------")
    
    # 不修改原数据的操作
    print(t1.add(10))  # 加 add
    print(t1.sub(10))  # 减 sub
    print(t1.mul(2))  # 乘 mul
    print(t1.div(2))  # 除 div
    print(t1)
    
    # 修改原数据的操作
    print("------------")
    t1 = t1.float()
    t1.add_(10)# 加 add
    print(t1)
    t1.sub_(10)# 减 sub
    print(t1)
    t1.mul_(2)# 乘 mul
    print(t1)
    t1.div_(2)# 除 div
    print(t1)
    
    
if __name__ == '__main__':
    test1()

阿达玛积

阿达玛积指的是矩阵对应位置的元素相乘,可以使用mul函数或者*来实现;

Tensor相乘

点积运算将两个向量映射为一个标量,是向量之间的基本操作。

点积运算要求如果第一个矩阵的shape是 (N, M),那么第二个矩阵 shape必须是 (M, P),最后两个矩阵点积运算的shape为 (N, P)。

使用@或者matmul完成Tensor的乘法。

mm方法也可以用于矩阵相乘 但是只能用于2维矩阵即:m*k和k*n 得到m*n 的矩阵

 

import torch


def test1():
    # 阿达玛积指的是矩阵对应位置的元素相乘
    # 可以使用mul函数或者*来实现
    t1 = torch.tensor([[1, 2],
                       [3, 4]])
    t2 = torch.tensor([[1, 2],
                       [3, 4]])
    t3 = t1 * t2
    print(t3)


def test2():
    # @  , matmul, mm 点乘运算
    # 其中mm方法也可以用于矩阵相乘(仅限于二维矩阵相乘)
    x1 = torch.tensor([[1, 2],
                       [3, 4]])
    x2 = torch.tensor([[1, 2],
                       [3, 4]])
    x3 = x1.matmul(x2)
    x4 = x1 @ x2
    x5 = x1.mm(x2)
    print(x3)
    print(x4)
    print(x5)


if __name__ == '__main__':
    # test1()
    test2()

索引操作

掌握张量的花式索引在处理复杂数据时非常有用。花式索引可以让你灵活地访问、修改张量中的特定元素或子集,从而简化代码并提高操作效率。

简单索引

索引,就是根据指定的下标选取数据。

列表索引

使用list批量的制定要索引的元素位置~此时注意list的维度!

布尔索引

根据条件选择张量中的元素。

索引赋值

使用花式索引轻松进行批量元素值修改

import torch


def test1():
    # 简单索引
    x = torch.randint(1, 10, (5, 5, 3))
    print(x)
    print("形状:", x.shape)
    # 行索引
    print(x[1])
    # 列索引
    print(x[:, 2])
    # 固定位置索引:两种方式都行
    print(x[1, 2, 2])
    print(x[1, 2, 2].item())

    # 根据范围取
    print(x[0:2])  # 取索引为0和1的数据
    print(x[0:2, 0:2, 0:2])  # 第一二个矩阵的每一个矩阵的第一二行的第一二列

    print('--------')
    # 列表索引
    # x[n ,m ]  n和m位置所占的位置是张量的第一维和第二维,里面放列表,可以选特定的索引(index)
    print(x[[1, 3], 1])  # 第一维选索引1和3,在此基础上,第二维选索引为1
    print(x[[1, 3]])  # 索引为1和3的矩阵
    print(x[[1, 3], [0, 2]])  # 索引为1和3的矩阵的 分别对应索引为0和2的行
    print(x[[1, 3], [0, 2], [1, 2]])  # 索引为1和3的矩阵 的分别对应索引为0和2的行 的分别对应索引为1和2的列


def tset2():
    # 布尔运算
    torch.manual_seed(666)
    x = torch.randint(1, 10, (5, 5))
    print(x)
    x2 = x > 5
    print(x2)
    x3 = x[x2]
    print(x3)


def test3():
    torch.manual_seed(666)
    x = torch.randint(1, 10, (5, 5))
    print(x)
    print(x[x[:, 4] > 1])  # x[:, 4] > 1意思是所有行里面选出索引为4的元素,在矩阵里可以表示的是第五列的所有元素

def test4():
    torch.manual_seed(66)
    x = torch.randint(1, 10, (5, 5))
    print(x)
    # 读取指定位置的数
    x2 = x[1, 1]
    print(x2)
    # 把刚刚那个位置的数修改为100
    x[1,1] = 100
    print(x)
    # 将第一维的所有,且第二维的索引为3的元素修改为200
    # 在矩阵中可以看成将所有行中,选中索引为3的列修改为200
    # 即第四列的所有元素修改为200
    x[:,3] = 200
    print(x)
    

if __name__ == '__main__':
    # test1()
    # tset2()
    test3()
    # test4()

张量拼接

在 PyTorch 中,cat 和 stack 是两个用于拼接张量的常用操作,但它们的使用方式和结果略有不同:

  • cat:在现有维度上拼接,不会增加新维度。

  • stack:在新维度上堆叠,会增加一个维度。

torch.cat

元素级别的

torch.cat(concatenate 的缩写)用于沿现有维度拼接张量。换句话说,它在现有的维度上将多个张量连接在一起。

 注意:要拼接的张量在除了指定拼接的维度之外的所有维度上的大小必须相同。

torch.stack

张量级别的

torch.stack 用于在新维度上拼接张量。换句话说,它会增加一个新的维度,然后沿指定维度堆叠张量。

注意:要堆叠的张量必须具有相同的形状。 技巧:堆叠指一人出一个交替添加 拼接指一人出完下个人在出完  

import torch
from PIL import Image
from torchvision import transforms

def test1():
    torch.manual_seed(66)
    x = torch.randint(1,10,(3,3))
    y = torch.randint(1,10,(3,3))
    print(x)
    print(y)
    z = torch.cat([x,y],dim=1)  # dim 指定维度来进行堆叠
    print(z)
    
    
def test2():
    torch.manual_seed(66)
    x = torch.randint(1,10,(3,3))
    y = torch.randint(1,10,(3,3))
    print(x,x.shape)
    print(y,y.shape)
    # torch.stack 用于在新维度上拼接张量。换句话说,它会增加一个新的维度,然后沿指定维度堆叠张量
    z = torch.stack([x,y],dim=0)
    print(z,z.shape)
    
def test3():
    # 加载本地图片
    img_pil = Image.open("./data/1.png")
    # 转换为张量
    transfer = transforms.ToTensor()
    img_tensor = transfer(img_pil)
    print(img_tensor,img_tensor.shape)
    res = torch.stack([img_tensor[0],img_tensor[1],img_tensor[2]],dim=2)
    print(res,res.shape)
if __name__ == '__main__':
    # test1()
    test2()
    # test3()

形状操作

在 PyTorch 中,张量的形状操作是非常重要的,因为它允许你灵活地调整张量的维度和结构,以适应不同的计算需求。

 reshape

可以用于将张量转换为不同的形状,但要确保转换后的形状与原始形状具有相同的元素数量。

view

view进行形状变换的特征:

  • 张量在内存中是连续的;

  • 返回的是原始张量视图,不重新分配内存,效率更高;

  • 如果张量在内存中不连续,view 将无法执行,并抛出错误。

 内存连续性

我们在进行变形或转置操作时,很容易造成内存的不连续性。

 

和reshape比较

view:高效,但需要张量在内存中是连续的;

reshape:更灵活,但涉及内存复制;

view变形操作

transpose

transpose 用于交换张量的两个维度,注意,是2个维度,它返回的是原张量的视图。

permute

permute 用于改变张量的所有维度顺序。与 transpose 类似,但它可以交换多个维度。

flatten

flatten 用于将张量展平为一维向量。

tensor.flatten(start_dim=0, end_dim=-1)
  • start_dim:从哪个维度开始展平。

  • end_dim:在哪个维度结束展平。默认值为 -1,表示展平到最后一个维度。

 

升维和降维

在后续的网络学习中,升维和降维是常用操作,需要掌握。

  • unsqueeze:用于在指定位置插入一个大小为 1 的新维度。

  • squeeze:用于移除所有大小为 1 的维度,或者移除指定维度的大小为 1 的维度。

 

import torch


def test1():
    t1 = torch.randint(1, 10, (4, 3))
    print(t1)
    t2 = torch.reshape(t1, (2, 6))
    print(t2)
    x3 = torch.reshape(t1, (2, 2, 3))
    print(x3)
    # x4 = torch.reshape(t1,(3,5))
    # print(x4)
    # 改变形状之后数量不能改变

    # 填入-1可以自动匹配合适的参数
    x5 = torch.reshape(t1, (-1, 6))
    print(x5)
    x6 = torch.reshape(t1, (2, 2, -1))  # 可以且只有一个维度可以填-1
    # x6 = torch.reshape(t1,(2,-1,-1))  这样写就不可以
    print(x6)
    
def test2():
    # 连续性的可以view
    t1 = torch.randint(1, 10, (4, 3))
    print(t1)
    t2 = t1.view((2,6))  # 改变形状,由于没有改变原t1中的数据内存空间
    # 所以它改变形状比 reshape 快
    print(t2)
    
    t3 = torch.randint(1, 10, (4, 3))
    # t4 = torch.reshape(t3,(3,4))
    # t5 = t4.t()  # 转置之后不连续了,不能 view
    # t6 = t5.view((2,6))
    # print(t6)

def test3():
    t1 = torch.randint(1, 10, (4, 3))
    print(t1)
    t2 = torch.transpose(t1,0,1)
    print(t2,t2.shape)


def test4():
    t1 = torch.randint(0, 255, (3, 512, 360))
    print(t1)  # (C,h,w)
    t2 = t1.permute(1,2,0)  # (h,w,C)
    print(t2,t2.shape)


def test5():
    t1 = torch.randint(0, 255, (3, 4))
    print(t1)
    t2 = t1.flatten()
    print(t2)
    
    t3 = torch.randint(0, 255, (3, 4, 2, 2))
    t4 = t3.flatten(start_dim=1,end_dim=2) # 从第二维度开始展平,从第三维度就不展平了
    print(t3)
    print(t4)
    
def test6():
    # 数据升维
    t1 = torch.randint(0, 255, (1,3,4,1))
    print(t1)
    t2 = t1.unsqueeze(0)
    print(t2)
    print(t2.shape)
    
    # 降维操作
    t2 = t1.squeeze(0)
    print(t2)
    print(t2.shape)
    
    
def test8():
    x=torch.randint(0,255,(21,4))
    x2=torch.split(x,2)#每个tensor有2行
    print(x2)
    x3=torch.chunk(x,2)
    print(x3)
    
    
    
if __name__ == '__main__':
    # test1()
    # test2()
    # test3()
    # test4()
    # test5()
    # test6()
    test8()


http://www.kler.cn/a/409943.html

相关文章:

  • MySQL中的ROW_NUMBER窗口函数简单了解下
  • 安宝特方案 | AR助力紧急救援,科技守卫生命每一刻!
  • jupyter notebook的 markdown相关技巧
  • 用 Python 从零开始创建神经网络(九):反向传播(Backpropagation)
  • C0034.在Ubuntu中安装的Qt路径
  • AIGC学习笔记(6)——AI大模型开发工程师
  • .NET Core发布网站报错 HTTP Error 500.31
  • 视频分析设备平台EasyCVR视频设备轨迹回放平台与应急布控球的视频监控方案
  • 嵌入式硬件杂谈(六)充电器原理 线性电源 开关电源 反激电源原理
  • 论文阅读:A fast, scalable and versatile tool for analysis of single-cell omics data
  • node.js nvm 安装和使用
  • 前端面试笔试(五)
  • 网络安全等级保护测评机构管理办法(全文)
  • 【前端学习笔记】Web API——BOM与DOM
  • Python 版本的 2024详细代码
  • AI安全:从现实关切到未来展望
  • Jmeter中的监听器
  • 信息收集(1)
  • WPF中的Button按钮中的PreviewMouseLeftButtonDown事件和MouseLeftButtonDown的区别
  • ThingsBoard规则链节点:AWS SNS 节点详解
  • C语言:树
  • tableau练习-制作30个图表
  • elementui el-input修改字体样式
  • Excel把其中一张工作表导出成一个新的文件
  • 1-golang_org_x_crypto_bcrypt测试 --go开源库测试
  • 论文笔记 SliceGPT: Compress Large Language Models By Deleting Rows And Columns