当前位置: 首页 > article >正文

Tensor 基本操作1 | PyTorch 深度学习实战

目录

    • 创建 Tensor
    • 常用操作
      • unsqueeze
      • squeeze
      • Softmax
        • 代码1
        • 代码2
        • 代码3
      • argmax
      • item

创建 Tensor

使用 Torch 接口创建 Tensor

import torch

参考:https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html

常用操作

unsqueeze

将多维数组解套,并嵌入新的一层维度。

    data = [[1, 2],[3, 4]]
    x_data = torch.tensor(data)
    print("x_data")
    print(x_data)

    x2_data = x_data.unsqueeze(-1)
    print("x_data>> unsqueeze -1")
    print(x2_data)

    x2_data = x_data.unsqueeze(0)
    print("x_data>> unsqueeze 0")
    print(x2_data)

    x2_data = x_data.unsqueeze(1)
    print("x_data>> unsqueeze 1")
    print(x2_data)

    x2_data = x_data.unsqueeze(2)
    print("x_data>> unsqueeze 2")
    print(x2_data)

结果:

x_data
tensor([[1, 2],
        [3, 4]])
x_data>> unsqueeze -1   # -1 代表最内层,将最内层的数用一个新的维度包起来
tensor([[[1],
         [2]],

        [[3],
         [4]]])
x_data>> unsqueeze 0 # 0 代表最外层,将原来的多维数组整个多套一层
tensor([[[1, 2],
         [3, 4]]])
x_data>> unsqueeze 1 # 代表原来第一维里的每个元素,套一层
tensor([[[1, 2]],

        [[3, 4]]])
x_data>> unsqueeze 2 # 代表原来第二维里的每个元素,套一层
tensor([[[1],        # 当前一共两维,所以效果和 -1 一样
         [2]],

        [[3],
         [4]]])

squeeze

去掉指定或全部的维度中只有一个元素的多维数组。

比如输入为 Ax1xBxCx1xD 维的数组,输出变成了 AxBxCxD 维的数组。

https://pytorch.org/docs/stable/generated/torch.squeeze.html
在这里插入图片描述

    data = [[1], [2],[3], [4]]
    x_data = torch.tensor(data)
    print("x_data")
    print(x_data)

    x2_data = x_data.squeeze()
    print("x_data>> squeeze")
    print(x2_data)

    x2_data = x_data.squeeze(1)
    print("x_data>> squeeze 1")
    print(x2_data)

结果:

x_data
tensor([[1],
        [2],
        [3],
        [4]])
x_data>> squeeze
tensor([1, 2, 3, 4])
x_data>> squeeze 1
tensor([1, 2, 3, 4])

Softmax

https://pytorch.org/docs/stable/generated/torch.softmax.html

归一化操作。
在这里插入图片描述

代码1
    data = torch.tensor([1,2,3], dtype=torch.float) # 维度 3; 注意,此处 dtype 是 int 或 long 接口报错
    x_data = torch.softmax(data, 0)
    print("x_data")
    print(x_data)

结果:

x_data
tensor([0.0900, 0.2447, 0.6652])  # 维度 3
代码2
    data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 维度 3x1
    x_data2 = torch.softmax(data, 0)
    print("x_data2")
    print(x_data2)

结果:

x_data2  # 维度 3x1
tensor([[0.0900],
        [0.2447],
        [0.6652]])
代码3
    data = torch.tensor([[1],[2],[3]], dtype=torch.float) # 维度 3x1
    x_data2 = torch.softmax(data, 1) # 沿着第一维求
    print("x_data2")
    print(x_data2)

结果:

x_data2
tensor([[1.],
        [1.],
        [1.]])

此时,每维都是 1 个元素,针对自身求 softmax,所以,结果是 1.

argmax

https://pytorch.org/docs/stable/generated/torch.argmax.html

返回一个多维数组的最大值的索引,如果是多维数组,则返回第一维的索引。

在这里插入图片描述

item

https://pytorch.org/docs/stable/generated/torch.Tensor.item.html
返回一个 Tensor 中携带的 Python Number 对象。该接口只对 Tensor 是一维的有效。

x = torch.tensor([1.0])
x.item()

http://www.kler.cn/a/511880.html

相关文章:

  • AI Agent:深度解析与未来展望
  • 服务器一次性部署One API + ChatGPT-Next-Web
  • 大语言模型的语境中“越狱”和思维链
  • JavaScript中提高效率的技巧一
  • 【PyCharm】连接 Git
  • Crewai + langchain 框架配置第三方(非原生/国产)大模型API
  • 【Rust自学】13.9. 使用闭包和迭代器改进IO项目
  • 无监督<视觉-语言>模型中的跨模态对齐
  • vue按照官网设置自动导入后ElMessageBox不生效问题
  • 从零开始:Spring Boot核心概念与架构解析
  • springboot迅捷外卖配送系统
  • STM32CubeIDE使用笔记(一)
  • 【Spring】原型 Bean 被固定
  • 【25】Word:林涵-科普文章❗
  • yum和vim的使用
  • 【Elasticsearch入门到落地】6、索引库的操作
  • Matlab自学笔记四十五:日期时间型和字符、字符串以及double型的相互转换方法
  • React 中hooks之 React useCallback使用方法总结
  • Java 基于微信小程序的原创音乐小程序设计与实现(附源码,部署,文档)
  • Centos7搭建PHP项目,环境(Apache+PHP7.4+Mysql5.7)
  • ubuntu系统文件查找、关键字搜索
  • 2024:成长、创作与平衡的年度全景回顾
  • RabbitMQ---事务及消息分发
  • 【Redis】5种基础数据结构介绍及应用
  • 【MCU】CH591用软件 I2C 出现的 bug
  • 我的创作纪念日——我与CSDN一起走过的365天