当前位置: 首页 > article >正文

einops测试

文章目录

  • 1. einops
  • 2. code
  • 3. pytorch

1. einops

einops 主要是通过爱因斯坦标记法来处理张量矩阵的库,让矩阵处理上非常简单。

  • conda :
conda install conda-forge::einops
  • python:

2. code

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    run_code = 0
    x = torch.arange(96).reshape((2, 3, 4, 4)).to(torch.float32)
    print(f"x.shape={x.shape}")
    print(f"x=\n{x}")

    # 1. 转置
    x_torch_trans = x.transpose(1, 2)
    x_einops_trans = rearrange(x, 'b i w h -> b w i h')
    x_check_trans = torch.allclose(x_torch_trans, x_einops_trans)
    print(f"x_torch_trans is {x_check_trans} same with x_einops_trans")

    # 2. 变形
    x_torch_reshape = x.reshape(6, 4, 4)
    x_einops_reshape = rearrange(x, 'b i w h -> (b i) w h')
    x_check_reshape = torch.allclose(x_torch_reshape, x_einops_reshape)
    print(f"x_einops_reshape is {x_check_reshape} same with x_check_reshape")

    # 3. image2patch
    image2patch = rearrange(x, 'b i (h1 p1) (w1 p2) -> b i (h1 w1) p1 p2', p1=2, p2=2)
    print(f"image2patch.shape={image2patch.shape}")
    print(f"image2patch=\n{image2patch}")
    image2patch2 = rearrange(image2patch, 'b i j h w -> b (i j) h w')
    print(f"image2patch2.shape={image2patch2.shape}")
    print(f"image2patch2=\n{image2patch2}")
    y = torch.arange(24).reshape((2, 3, 4)).to(torch.float32)
    y_einops_mean = reduce(y, 'b h w -> b h', 'mean')
    print(f"y=\n{y}")
    print(f"y_einops_mean=\n{y_einops_mean}")
    y_tensor = torch.arange(24).reshape(2, 2, 2, 3)
    y_list = [y_tensor, y_tensor, y_tensor]
    y_output = rearrange(y_list, 'n b i h w -> n b i h w')
    print(f"y_tensor=\n{y_tensor}")
    print(f"y_output=\n{y_output}")
    z_tensor = torch.arange(12).reshape(2, 2, 3).to(torch.float32)
    z_tensor_1 = rearrange(z_tensor, 'b h w -> b h w 1')
    print(f"z_tensor=\n{z_tensor}")
    print(f"z_tensor_1=\n{z_tensor_1}")
    z_tensor_2 = repeat(z_tensor_1, 'b h w 1 -> b h w 2')
    print(f"z_tensor_2=\n{z_tensor_2}")
    z_tensor_repeat = repeat(z_tensor, 'b h w -> b (2 h) (2 w)')
    print(f"z_tensor_repeat=\n{z_tensor_repeat}")
  • python:
x.shape=torch.Size([2, 3, 4, 4])
x=
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]],

         [[16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [24., 25., 26., 27.],
          [28., 29., 30., 31.]],

         [[32., 33., 34., 35.],
          [36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]],


        [[[48., 49., 50., 51.],
          [52., 53., 54., 55.],
          [56., 57., 58., 59.],
          [60., 61., 62., 63.]],

         [[64., 65., 66., 67.],
          [68., 69., 70., 71.],
          [72., 73., 74., 75.],
          [76., 77., 78., 79.]],

         [[80., 81., 82., 83.],
          [84., 85., 86., 87.],
          [88., 89., 90., 91.],
          [92., 93., 94., 95.]]]])
x_torch_trans is True same with x_einops_trans
x_einops_reshape is True same with x_check_reshape
image2patch.shape=torch.Size([2, 3, 4, 2, 2])
image2patch=
tensor([[[[[ 0.,  1.],
           [ 4.,  5.]],

          [[ 2.,  3.],
           [ 6.,  7.]],

          [[ 8.,  9.],
           [12., 13.]],

          [[10., 11.],
           [14., 15.]]],


         [[[16., 17.],
           [20., 21.]],

          [[18., 19.],
           [22., 23.]],

          [[24., 25.],
           [28., 29.]],

          [[26., 27.],
           [30., 31.]]],


         [[[32., 33.],
           [36., 37.]],

          [[34., 35.],
           [38., 39.]],

          [[40., 41.],
           [44., 45.]],

          [[42., 43.],
           [46., 47.]]]],



        [[[[48., 49.],
           [52., 53.]],

          [[50., 51.],
           [54., 55.]],

          [[56., 57.],
           [60., 61.]],

          [[58., 59.],
           [62., 63.]]],


         [[[64., 65.],
           [68., 69.]],

          [[66., 67.],
           [70., 71.]],

          [[72., 73.],
           [76., 77.]],

          [[74., 75.],
           [78., 79.]]],


         [[[80., 81.],
           [84., 85.]],

          [[82., 83.],
           [86., 87.]],

          [[88., 89.],
           [92., 93.]],

          [[90., 91.],
           [94., 95.]]]]])
image2patch2.shape=torch.Size([2, 12, 2, 2])
image2patch2=
tensor([[[[ 0.,  1.],
          [ 4.,  5.]],

         [[ 2.,  3.],
          [ 6.,  7.]],

         [[ 8.,  9.],
          [12., 13.]],

         [[10., 11.],
          [14., 15.]],

         [[16., 17.],
          [20., 21.]],

         [[18., 19.],
          [22., 23.]],

         [[24., 25.],
          [28., 29.]],

         [[26., 27.],
          [30., 31.]],

         [[32., 33.],
          [36., 37.]],

         [[34., 35.],
          [38., 39.]],

         [[40., 41.],
          [44., 45.]],

         [[42., 43.],
          [46., 47.]]],


        [[[48., 49.],
          [52., 53.]],

         [[50., 51.],
          [54., 55.]],

         [[56., 57.],
          [60., 61.]],

         [[58., 59.],
          [62., 63.]],

         [[64., 65.],
          [68., 69.]],

         [[66., 67.],
          [70., 71.]],

         [[72., 73.],
          [76., 77.]],

         [[74., 75.],
          [78., 79.]],

         [[80., 81.],
          [84., 85.]],

         [[82., 83.],
          [86., 87.]],

         [[88., 89.],
          [92., 93.]],

         [[90., 91.],
          [94., 95.]]]])
y=
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
y_einops_mean=
tensor([[ 1.500,  5.500,  9.500],
        [13.500, 17.500, 21.500]])
y_tensor=
tensor([[[[ 0,  1,  2],
          [ 3,  4,  5]],

         [[ 6,  7,  8],
          [ 9, 10, 11]]],


        [[[12, 13, 14],
          [15, 16, 17]],

         [[18, 19, 20],
          [21, 22, 23]]]])
y_output=
tensor([[[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]],



        [[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]],



        [[[[ 0,  1,  2],
           [ 3,  4,  5]],

          [[ 6,  7,  8],
           [ 9, 10, 11]]],


         [[[12, 13, 14],
           [15, 16, 17]],

          [[18, 19, 20],
           [21, 22, 23]]]]])
z_tensor=
tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
z_tensor_1=
tensor([[[[ 0.],
          [ 1.],
          [ 2.]],

         [[ 3.],
          [ 4.],
          [ 5.]]],


        [[[ 6.],
          [ 7.],
          [ 8.]],

         [[ 9.],
          [10.],
          [11.]]]])
z_tensor_2=
tensor([[[[ 0.,  0.],
          [ 1.,  1.],
          [ 2.,  2.]],

         [[ 3.,  3.],
          [ 4.,  4.],
          [ 5.,  5.]]],


        [[[ 6.,  6.],
          [ 7.,  7.],
          [ 8.,  8.]],

         [[ 9.,  9.],
          [10., 10.],
          [11., 11.]]]])
z_tensor_repeat=
tensor([[[ 0.,  1.,  2.,  0.,  1.,  2.],
         [ 3.,  4.,  5.,  3.,  4.,  5.],
         [ 0.,  1.,  2.,  0.,  1.,  2.],
         [ 3.,  4.,  5.,  3.,  4.,  5.]],

        [[ 6.,  7.,  8.,  6.,  7.,  8.],
         [ 9., 10., 11.,  9., 10., 11.],
         [ 6.,  7.,  8.,  6.,  7.,  8.],
         [ 9., 10., 11.,  9., 10., 11.]]])

3. pytorch

在这里插入图片描述


http://www.kler.cn/a/553341.html

相关文章:

  • 电商分布式场景中如何保证数据库与缓存的一致性?实战方案与Java代码详解
  • rust 安全性
  • springsecurity自定义认证
  • 公司电脑监控软件一般有哪些——软件的类型分析与 WorkWin 的特性探究
  • [含文档+PPT+源码等]精品基于springboot实现的原生Andriod汽车后市场服务系统
  • VictoriaLogs Syslog日志收集存储系统部署
  • R软件用潜在类别混合模型LCM分析老年人抑郁数据轨迹多变量建模研究
  • 01数据准备 抓取图片 通过爬虫方式获取bing的关键词搜索图片
  • uniapp基于JSSDK 开发微信支付(php后端)
  • 4.从零开始学会Vue--{{组件通信}}
  • LED灯闪烁实验:Simulink应用层开发
  • 【Golang 面试题】每日 3 题(五十九)
  • JVM类加载过程详解:从字节码到内存的蜕变之旅
  • HBase简介
  • 微软的基本类库BCL
  • 【python】tkinter简要教程
  • springmvc(13/158)
  • Pytorch实现之统计全局信息的轻量级EGAN
  • 计算机视觉算法实战——图像合成(主页有源码)
  • PHP培训机构教务管理系统小程序源码