PyTorch核心概念:从梯度、计算图到连续性的全面解析(三)
文章目录
- Contiguous vs Non-Contiguous Tensor
- Tensor and View
- Strides
- 非连续数据结构:Transpose( )
- 在 PyTorch 中检查Contiguous and Non-Contiguous
- 将不连续张量(或视图)转换为连续张量
- view() 和 reshape() 之间的区别
- 总结
- 参考文献
Contiguous vs Non-Contiguous Tensor
Tensor and View
View使用与原始张量相同的数据块,只是“view”其维度的方式不同
视图只不过是解释原始张量维度的另一种方法,而无需在内存中进行物理复制。例如,我们有一个 1x12 张量,即 [1,2,3,4,5,6,7,8,9,10,11,12],然后使用 .view(4,3)
来改变形状将张量转换为 4x3 结构
x = torch.arange(1,13)
print(x)
>> tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
x = torch.arange(1,13)
y = x.view(4,3)
print(y)
>>
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
如果更改原始张量 x 中的数据,它也会反映在视图张量 y 中,因为视图张量 y 不是创建原始张量 x 的另一个副本,而是从与原始张量相同的内存地址读取数据X。反之亦然,视图张量中的值的更改将同时更改原始张量中的值,因为视图张量及其原始张量共享同一块内存块
x = torch.arange(1,13)
y = x.view(4,3)
x[0] = 100
print(y)
>>
tensor([[100, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[ 10, 11, 12]])
x = torch.arange(1,13)
y = x.view(4,3)
y[-1,-1] = 1000
print(x)
>> tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1000])
可以以连续的方式查看不同维度的数据序列
一维张量A中的元素数量为T,经过view()处理之后的张量B,shape为(K,M,N),则需满足
K
×
M
×
N
=
T
K\times M\times N=T
K×M×N=T
Strides
# x is a contiguous data. Recall that view() doesn't change data arrangement in the original 1D tensor
x = torch.arange(1,13).view(6,2)
x
>>
tensor([[ 1, 2],
[ 3, 4],
[ 5, 6],
[ 7, 8],
[ 9, 10],
[11, 12]])
# Check stride
x.stride()
>> (2, 1)
步长 (2, 1) 告诉我们:我们需要跨过 1 个(维度 0)数字才能到达沿轴 0 的下一个数字,并且需要跨过 2 个(维度 1)数字才能到达沿轴 1 的下一个数字
y = torch.arange(0,11).view(2,2,3)
y
>>
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
# Check stride
y.stride()
>> (6, 3, 1)
检索一维张量中 (A, B, C) 位置的公式如下: A × 6 + B × 3 + C × 1 A \times 6 + B \times 3 + C \times 1 A×6+B×3+C×1
非连续数据结构:Transpose( )
首先,Transpose(axis1, axis2) 只是“swapping the way axis1 and axis2 strides”
# Initiate a contiguous tensor
x = torch.arange(0,12).view(2,2,3)
x
>>
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
x.stride()
>> (6,3,1)
# Now let's transpose axis 0 and 1, and see how the strides swap
y = x.transpose(0,2)
y
>>
tensor([[[ 0, 6],
[ 3, 9]],
[[ 1, 7],
[ 4, 10]],
[[ 2, 8],
[ 5, 11]]])
y.stride()
>> (1,3,6)
y 是 x.transpose(0,2)
,它交换 x 张量在轴 0 和轴 2 上的stride,因此 y 的stride是 (1,3,6)。这意味着我们需要跳转 6 个数字才能获取第 0 轴的下一个数字,跳转 3 个数字才能获取第 1 轴的下一个数字,跳转 1 个数字才能获取第 2 轴的下一个数字(stride公式:
A
×
1
+
B
×
3
+
C
×
6
A \times 1+ B \times 3+C \times 6
A×1+B×3+C×6)
transpose的不同之处在于:现在数据序列不再遵循连续的顺序。它不会从最内层维度逐一填充顺序数据,填满后跳转到下一个维度。现在它在最里面的维度跳跃了6个数字,所以它不是连续的
transpose( )
具有不连续的数据结构,但仍然是视图而不是副本
⇒
\Rightarrow
⇒它是一个不连续的“视图”,改变了原始数据的stride方式
# Change the value in a transpose tensor y
x = torch.arange(0,12).view(2,6)
y = x.transpose(0,1)
y[0,0] = 100
y
>>
tensor([[100, 6],
[ 1, 7],
[ 2, 8],
[ 3, 9],
[ 4, 10],
[ 5, 11]])
# Check the original tensor x
x
>>
tensor([[100, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
在 PyTorch 中检查Contiguous and Non-Contiguous
使用PyTorch中的 .is_contigious()
检查张量是否连续
x = torch.arange(0,12).view(2,6)
x.is_contiguous()
>> True
y = x.transpose(0,1)
y.is_contiguous()
>> False
将不连续张量(或视图)转换为连续张量
使用PyTorch中的 .contigious()
将不连续的张量转换成连续的张量
z = y.contiguous()
z.is_contiguous()
>> TRUE
** .contigious()
复制原始的“non-contiguous”张量,然后按照连续顺序将其保存到新的内存块中**
# This is contiguous
x = torch.arange(1,13).view(2,3,2)
x.stride()
>> (6, 2, 1)
# This is non-contiguous
y = x.transpose(0,1)
y.stride()
>> (2, 6, 1)
# This is a converted contiguous tensor with new stride
z = y.contiguous()
z.stride()
>> (4, 2, 1)
print(z.shape)
>> (3, 2, 2)
# The stride across the first dimension is 2*2
# The stride across the second dimension is 2*1
# The stride across the third dimension is 1
(4, 2, 1)=>(2*2, 2*1, 1)
用来区分张量/视图是否连续的一种方法是观察stride中的
(
A
,
B
,
C
)
(A, B, C)
(A,B,C) 是否满足
A
>
B
>
C
A > B > C
A>B>C。如果不满足,则意味着至少有一个维度正在跳过的距离比其上方的维度更长,这使得它不连续
我们还可以观察转换后的连续张量 z 如何以新的顺序存储数据
# y is a non-contiguous 'view' (remember view uses the original chunk of data in memory, but its strides implies 'non-contiguous', (2,6,1).
y.storage()
>>
1
2
3
4
5
6
7
8
9
10
11
12
# Z is a 'contiguous' tensor (not a view, but a new copy of the original data. Notice the order of the data is different). It strides implies 'contiguous', (4,2,1)
z.storage()
>>
1
2
7
8
3
4
9
10
5
6
11
12
view() 和 reshape() 之间的区别
虽然这两个函数都可以改变张量的维度,但两者之间的主要区别是:
- view():不复制原始张量,使用与原始张量相同的数据块,仅适用于连续数据
- reshape():当数据连续时,尽可能返回视图;当数据不连续时,则将数据复制到连续的数据块中,作为副本,它会占用内存空间,而且新张量的变化不会影响原始张量中的原始数值
# When data is contiguous
x = torch.arange(1,13)
x
>> tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
# Reshape returns a view with the new dimension
y = x.reshape(4,3)
y
>>
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
# How do we know it's a view? Because the element change in new tensor y would affect the value in x, and vice versa
y[0,0] = 100
y
>>
tensor([[100, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[ 10, 11, 12]])
print(x)
>>
tensor([100, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
接下来,让我们看看 reshape() 如何处理非连续数据:
# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1, 3, 5, 7, 9, 11],
[ 2, 4, 6, 8, 10, 12]])
# Reshape() works fine on a non-contiguous data
y = x.reshape(4,3)
y
>>
tensor([[ 1, 3, 5],
[ 7, 9, 11],
[ 2, 4, 6],
[ 8, 10, 12]])
# Change an element in y
y[0,0] = 100
y
>>
tensor([[100, 3, 5],
[ 7, 9, 11],
[ 2, 4, 6],
[ 8, 10, 12]])
# Check the original tensor, and nothing was changed
x
>>
tensor([[ 1, 3, 5, 7, 9, 11],
[ 2, 4, 6, 8, 10, 12]])
最后,让我们看看 view() 是否可以处理非连续数据。No, it can’t!
# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1, 3, 5, 7, 9, 11],
[ 2, 4, 6, 8, 10, 12]])
# Try to use view on the non-contiguous data
y = x.view(4,3)
y
>>
-------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
----> 1 y = x.view(4,3)
2 y
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
总结
- view”使用与原始张量相同的内存块,因此该内存块中的任何更改都会影响所有视图以及与其关联的原始张量
- 视图可以是连续的或不连续的。一个不连续的张量视图可以转换为连续的张量视图,并且会复制不连续的视图张量到新的内存空间中,因此数据将不再与原始数据块关联
- stride位置公式:给定一个stride ( A , B , C ) (A,B,C) (A,B,C),索引 ( j , k , v ) (j, k, v) (j,k,v) 在 1D 数据数组中的位置为 ( A × j + B × k + C × v ) (A \times j + B \times k + C \times v) (A×j+B×k+C×v)
view()
和reshape()
之间的区别:view()
不能应用于 '非连续的张量/视图,它返回一个视图;reshape()
可以应用于“连续”和“非连续”张量/视图
参考文献
1、Contiguous vs Non-Contiguous Tensor / View — Understanding view(), reshape(), transpose()