当前位置: 首页 > article >正文

2、PyTorch张量的运算API(上)

1. 教学视频

2、PyTorch张量的运算API(上)

  • 因比较忙,暂时就做个过场吧。

2. Python代码

  • Python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :torch_learn2.py
# @Time      :2024/11/16 19:53
# @Author    :Jason Zhang

import torch

torch.manual_seed(12124)

if __name__ == "__main__":
    run_code = 0
    a = torch.rand([3, 2])
    a_chunk1, a_chunk2 = torch.chunk(a, chunks=2)
    a_chunk11, a_chunk21 = torch.chunk(a, chunks=2, dim=1)
    print(f"a=\n{a}")
    print(f"a_chunk1=\n{a_chunk1}")
    print(f"a_chunk2=\n{a_chunk2}")
    print(f"a_chunk11=\n{a_chunk11}")
    print(f"a_chunk21=\n{a_chunk21}")
    t = torch.tensor([[1, 2], [3, 4]])
    t_gather = torch.gather(t, 1, torch.tensor([[0, 1], [1, 0]]))
    print(f"t=\n{t}")
    print(f"t_gather=\n{t_gather}")
    reshape_12 = torch.arange(12).reshape((3, 4))
    print(f"reshape_12=\n{reshape_12}")
    reshape_11 = reshape_12.reshape((-1, 1))
    print(f"reshape_11=\n{reshape_11}")
    src = torch.arange(1, 11).reshape((2, 5))
    index = torch.tensor([[0, 1, 2, 0]])
    y = torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
    print(f"src=\n{src}")
    print(f"y=\n{y}")
    stack_a = torch.rand((3, 4))
    stack_b = torch.rand((3, 4))
    stack_ab = torch.stack((stack_a, stack_b))
    print(f"stack_a=\n{stack_a}")
    print(f"stack_b=\n{stack_b}")
    print(f"stack_ab=\n{stack_ab},shape={stack_ab.shape}")
    squeeze_1 = torch.rand((2, 1, 3))
    squeeze_2 = torch.squeeze(squeeze_1)
    print(f"squeeze_1.shape={squeeze_1.shape}")
    print(f"squeeze_2.shape={squeeze_2.shape}")
  • 结果:
a=
tensor([[0.5555, 0.0484],
        [0.3199, 0.2577],
        [0.8874, 0.6888]])
a_chunk1=
tensor([[0.5555, 0.0484],
        [0.3199, 0.2577]])
a_chunk2=
tensor([[0.8874, 0.6888]])
a_chunk11=
tensor([[0.5555],
        [0.3199],
        [0.8874]])
a_chunk21=
tensor([[0.0484],
        [0.2577],
        [0.6888]])
t=
tensor([[1, 2],
        [3, 4]])
t_gather=
tensor([[1, 2],
        [4, 3]])
reshape_12=
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
reshape_11=
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11]])
src=
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
y=
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
stack_a=
tensor([[0.2410, 0.9222, 0.5832, 0.3587],
        [0.9344, 0.3320, 0.3852, 0.3239],
        [0.7664, 0.9575, 0.2645, 0.5601]])
stack_b=
tensor([[0.4304, 0.7509, 0.3536, 0.7229],
        [0.9026, 0.0793, 0.3076, 0.3272],
        [0.4434, 0.2406, 0.7080, 0.9304]])
stack_ab=
tensor([[[0.2410, 0.9222, 0.5832, 0.3587],
         [0.9344, 0.3320, 0.3852, 0.3239],
         [0.7664, 0.9575, 0.2645, 0.5601]],

        [[0.4304, 0.7509, 0.3536, 0.7229],
         [0.9026, 0.0793, 0.3076, 0.3272],
         [0.4434, 0.2406, 0.7080, 0.9304]]]),shape=torch.Size([2, 3, 4])
squeeze_1.shape=torch.Size([2, 1, 3])
squeeze_2.shape=torch.Size([2, 3])

Process finished with exit code 0


http://www.kler.cn/a/401855.html

相关文章:

  • 2024年11月19日Github流行趋势
  • 某校园网登录界面前端加密绕过
  • SDIO 和MISC 什么关系
  • PH热榜 | 2024-11-19
  • 详细解读CNAS实验室认证
  • leetcode 扫描线专题 06-leetcode.836 rectangle-overlap 力扣.836 矩形重叠
  • 经验笔记:从生成 SSH 密钥到成功连接测试(以Gitee为例)
  • 微软Office 2021 24年11月授权版
  • c语言金典100题“从入门到放弃”10-15
  • Dubbo自定义扩展注册中心
  • Jav项目实战II基于微信小程序的助农扶贫的设计与实现(开发文档+数据库+源码)
  • 数据结构(二)线性表
  • 助力模型训练,深度学习的经典数据集介绍
  • Matplotlib | 理解直方图中bins表示的数据含义
  • WPF 中 MultiConverter ——XAML中复杂传参方式
  • 推荐一款UI/UX原型设计工具:Icons8 Lunacy
  • 【Rust 学习笔记】Rust 安装与 “Hello World” 程序介绍
  • qt中ctrl+鼠标左键无法进入
  • MFC图形函数学习09——画多边形函数
  • 【小程序】dialog组件
  • PHP批量操作加锁
  • CSP/信奥赛C++语法基础刷题训练(16):洛谷P5731:蛇形方阵
  • C++11——异常
  • 网络安全检测技术
  • python用哈希删除文件夹中重复的图片
  • linux配置动态ip