当前位置: 首页 > article >正文

PyTorch核心概念:从梯度、计算图到连续性的全面解析(三)

文章目录

  • Contiguous vs Non-Contiguous Tensor
    • Tensor and View
    • Strides
    • 非连续数据结构:Transpose( )
    • 在 PyTorch 中检查Contiguous and Non-Contiguous
      • 将不连续张量(或视图)转换为连续张量
      • view() 和 reshape() 之间的区别
      • 总结
  • 参考文献

Contiguous vs Non-Contiguous Tensor

Tensor and View

View使用与原始张量相同的数据块,只是“view”其维度的方式不同
视图只不过是解释原始张量维度的另一种方法,而无需在内存中进行物理复制。例如,我们有一个 1x12 张量,即 [1,2,3,4,5,6,7,8,9,10,11,12],然后使用 .view(4,3) 来改变形状将张量转换为 4x3 结构

x = torch.arange(1,13)
print(x)
>> tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

x = torch.arange(1,13)
y = x.view(4,3)
print(y)
>>
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

如果更改原始张量 x 中的数据,它也会反映在视图张量 y 中,因为视图张量 y 不是创建原始张量 x 的另一个副本,而是从与原始张量相同的内存地址读取数据X。反之亦然,视图张量中的值的更改将同时更改原始张量中的值,因为视图张量及其原始张量共享同一块内存块

x = torch.arange(1,13)
y = x.view(4,3)
x[0] = 100
print(y)
>> 
tensor([[100,   2,   3],
        [  4,   5,   6],
        [  7,   8,   9],
        [ 10,  11,  12]])
        
x = torch.arange(1,13)
y = x.view(4,3)
y[-1,-1] = 1000
print(x)
>> tensor([   1,    2,    3,    4,    5,    6,    7,    8,    9,   10,   11, 1000])

可以以连续的方式查看不同维度的数据序列
一维张量A中的元素数量为T,经过view()处理之后的张量B,shape为(K,M,N),则需满足 K × M × N = T K\times M\times N=T K×M×N=T

Strides

在这里插入图片描述

# x is a contiguous data. Recall that view() doesn't change data arrangement in the original 1D tensor
x = torch.arange(1,13).view(6,2)
x
>>
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
        
# Check stride
x.stride()
>> (2, 1)

步长 (2, 1) 告诉我们:我们需要跨过 1 个(维度 0)数字才能到达沿轴 0 的下一个数字,并且需要跨过 2 个(维度 1)数字才能到达沿轴 1 的下一个数字

y = torch.arange(0,11).view(2,2,3)
y
>>
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],
        [[ 6,  7,  8],
         [ 9, 10, 11]]])
         
# Check stride
y.stride()
>> (6, 3, 1)

检索一维张量中 (A, B, C) 位置的公式如下: A × 6 + B × 3 + C × 1 A \times 6 + B \times 3 + C \times 1 A×6+B×3+C×1

非连续数据结构:Transpose( )

首先,Transpose(axis1, axis2) 只是“swapping the way axis1 and axis2 strides”
在这里插入图片描述

# Initiate a contiguous tensor
x = torch.arange(0,12).view(2,2,3)
x
>>
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],
        [[ 6,  7,  8],
         [ 9, 10, 11]]])
         
x.stride()
>> (6,3,1)

# Now let's transpose axis 0 and 1, and see how the strides swap
y = x.transpose(0,2)
y
>>
tensor([[[ 0,  6],
         [ 3,  9]],
        [[ 1,  7],
         [ 4, 10]],
        [[ 2,  8],
         [ 5, 11]]])
         
y.stride()
>> (1,3,6)

y 是 x.transpose(0,2),它交换 x 张量在轴 0 和轴 2 上的stride,因此 y 的stride是 (1,3,6)。这意味着我们需要跳转 6 个数字才能获取第 0 轴的下一个数字,跳转 3 个数字才能获取第 1 轴的下一个数字,跳转 1 个数字才能获取第 2 轴的下一个数字(stride公式: A × 1 + B × 3 + C × 6 A \times 1+ B \times 3+C \times 6 A×1+B×3+C×6)
transpose的不同之处在于:现在数据序列不再遵循连续的顺序。它不会从最内层维度逐一填充顺序数据,填满后跳转到下一个维度。现在它在最里面的维度跳跃了6个数字,所以它不是连续的
transpose( ) 具有不连续的数据结构,但仍然是视图而不是副本 ⇒ \Rightarrow 它是一个不连续的“视图”,改变了原始数据的stride方式

# Change the value in a transpose tensor y
x = torch.arange(0,12).view(2,6)
y = x.transpose(0,1)
y[0,0] = 100
y
>>
tensor([[100,   6],
        [  1,   7],
        [  2,   8],
        [  3,   9],
        [  4,  10],
        [  5,  11]])
# Check the original tensor x
x
>>
tensor([[100,   1,   2,   3,   4,   5],
        [  6,   7,   8,   9,  10,  11]])

在 PyTorch 中检查Contiguous and Non-Contiguous

使用PyTorch中的 .is_contigious() 检查张量是否连续

x = torch.arange(0,12).view(2,6)
x.is_contiguous()
>> True

y = x.transpose(0,1)
y.is_contiguous()
>> False

将不连续张量(或视图)转换为连续张量

使用PyTorch中的 .contigious() 将不连续的张量转换成连续的张量

z = y.contiguous()
z.is_contiguous()
>> TRUE

** .contigious() 复制原始的“non-contiguous”张量,然后按照连续顺序将其保存到新的内存块中**

# This is contiguous
x = torch.arange(1,13).view(2,3,2)
x.stride()
>> (6, 2, 1)

# This is non-contiguous
y = x.transpose(0,1)
y.stride()
>> (2, 6, 1)

# This is a converted contiguous tensor with new stride
z = y.contiguous()
z.stride()
>> (4, 2, 1)

print(z.shape)
>> (3, 2, 2)

# The stride across the first dimension is 2*2
# The stride across the second dimension is 2*1
# The stride across the third dimension is 1
(4, 2, 1)=>(2*2, 2*1, 1)

用来区分张量/视图是否连续的一种方法是观察stride中的 ( A , B , C ) (A, B, C) (A,B,C) 是否满足 A > B > C A > B > C A>B>C。如果不满足,则意味着至少有一个维度正在跳过的距离比其上方的维度更长,这使得它不连续
我们还可以观察转换后的连续张量 z 如何以新的顺序存储数据

# y is a non-contiguous 'view' (remember view uses the original chunk of data in memory, but its strides implies 'non-contiguous', (2,6,1).
y.storage()
>>
 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 
# Z is a 'contiguous' tensor (not a view, but a new copy of the original data. Notice the order of the data is different). It strides implies 'contiguous', (4,2,1)
z.storage()
>>
 1
 2
 7
 8
 3
 4
 9
 10
 5
 6
 11
 12

view() 和 reshape() 之间的区别

虽然这两个函数都可以改变张量的维度,但两者之间的主要区别是:

  1. view():不复制原始张量,使用与原始张量相同的数据块,仅适用于连续数据
  2. reshape():当数据连续时,尽可能返回视图;当数据不连续时,则将数据复制到连续的数据块中,作为副本,它会占用内存空间,而且新张量的变化不会影响原始张量中的原始数值
# When data is contiguous
x = torch.arange(1,13)
x
>> tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

# Reshape returns a view with the new dimension
y = x.reshape(4,3)
y
>>
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
        
# How do we know it's a view? Because the element change in new tensor y would affect the value in x, and vice versa
y[0,0] = 100
y
>>
tensor([[100,   2,   3],
        [  4,   5,   6],
        [  7,   8,   9],
        [ 10,  11,  12]])
        
print(x)
>>
tensor([100,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12])

接下来,让我们看看 reshape() 如何处理非连续数据:

# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
        
# Reshape() works fine on a non-contiguous data
y = x.reshape(4,3)
y
>>
tensor([[ 1,  3,  5],
        [ 7,  9, 11],
        [ 2,  4,  6],
        [ 8, 10, 12]])
        
# Change an element in y
y[0,0] = 100
y
>>
tensor([[100,   3,   5],
        [  7,   9,  11],
        [  2,   4,   6],
        [  8,  10,  12]])
        
# Check the original tensor, and nothing was changed
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])

最后,让我们看看 view() 是否可以处理非连续数据。No, it can’t!

# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
        
# Try to use view on the non-contiguous data
y = x.view(4,3)
y
>>
-------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
----> 1 y = x.view(4,3)
      2 y
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

总结

  1. view”使用与原始张量相同的内存块,因此该内存块中的任何更改都会影响所有视图以及与其关联的原始张量
  2. 视图可以是连续的或不连续的。一个不连续的张量视图可以转换为连续的张量视图,并且会复制不连续的视图张量到新的内存空间中,因此数据将不再与原始数据块关联
  3. stride位置公式:给定一个stride ( A , B , C ) (A,B,C) (ABC),索引 ( j , k , v ) (j, k, v) (j,k,v) 在 1D 数据数组中的位置为 ( A × j + B × k + C × v ) (A \times j + B \times k + C \times v) (A×j+B×k+C×v)
  4. view()reshape() 之间的区别:view() 不能应用于 '非连续的张量/视图,它返回一个视图;reshape() 可以应用于“连续”和“非连续”张量/视图

参考文献

1、Contiguous vs Non-Contiguous Tensor / View — Understanding view(), reshape(), transpose()


http://www.kler.cn/a/381580.html

相关文章:

  • 使用 AMD GPU 的 ChatGLM-6B 双语语言模型
  • NAT实验
  • Flutter 正在切换成 Monorepo 和支持 workspaces
  • Linux云计算 |【第五阶段】CLOUD-DAY6
  • 一二三应用开发平台自定义查询设计与实现系列3——通用化重构
  • (七)JavaWeb后端开发——Maven
  • 【STM32】通过 DWT 实现毫秒级延时
  • 【Linux】IPC进程间通信System V:并发编程实战指南(二)
  • xcode更新完最新版本无法运行调试
  • Postman断言与依赖接口测试详解
  • 人工智能AI 产品经理与传统产品经理工作到底有什么不同?非常详细收藏我这一篇就够了
  • kubernetes部署rancher无法查看pod日志及通过execute shell进入pod解决办法
  • 【Android Wi-Fi 操作命令指南】
  • pdf添加目录标签python(手动配置)
  • 【大数据学习 | kafka】producer之拦截器,序列化器与分区器
  • 数论——约数(完整版)
  • 动态避障-图扑自动寻路 3D 可视化
  • 使用Python简单实现客户端界面
  • 数据结构(8.7_2)——败者树
  • 苹果iOS 18.4将允许欧盟地区的iPhone用户设置默认地图和翻译应用
  • Excel 个人时间管理工具
  • 一文带您了解SonarScanner的原理和使用方法(包括maven构建和命令行执行)
  • 面试题:Vue生命周期
  • 【python】OpenCV—Connected Components
  • sheng的学习笔记-tidb框架原理
  • angular实现dialog弹窗