python的负数索引理解
在 Python 中,负数索引用于从序列(如列表、元组或张量)的末尾开始计数。负数索引的理解方式如下:
-1
表示序列的最后一个元素。-2
表示序列的倒数第二个元素。- 以此类推。
例子:
假设我们有一个列表 list
:
lst = [10, 20, 30, 40, 50]
使用负数索引可以访问列表中的元素:
lst[-1]
返回50
,即列表的最后一个元素。lst[-2]
返回40
,即列表的倒数第二个元素。lst[-3]
返回30
,即列表的倒数第三个元素。
在 Layer Normalization 的实现中,负数索引用于计算需要进行归一化的维度索引。
假设 self.normalized_shape
是 [3, 4]
,则 len(self.normalized_shape)
是 2
。
因此,range(len(self.normalized_shape))
生成 [0, 1]
。
对于每个 i
,计算 -(i+1)
:
- 当
i = 0
时,-(i+1) = -1
,表示最后一个维度。 - 当
i = 1
时,-(i+1) = -2
,表示倒数第二个维度。
因此,dims
的值为 [-1, -2]
,表示需要在最后两个维度上进行归一化。
结合实例讲解
假设我们有一个输入张量 x
,其形状为 [2, 3, 4]
,即批量大小为 2
,通道数为 3
,每个通道有 4
个元素。我们希望在通道和空间维度上进行归一化。
import torch
x = torch.randn(2, 3, 4) # 输入张量的形状为 [2, 3, 4]
normalized_shape = [3, 4]
# 计算需要进行 LN 的维度索引 dims
dims = [-(i+1) for i in range(len(normalized_shape))]
print(dims) # 输出 [-1, -2]
# 计算特征图对应维度的均值和方差
mean = x.mean(dim=dims, keepdims=True)
mean_x2 = (x**2).mean(dim=dims, keepdims=True)
var = mean_x2 - mean**2
# 对输入 x 进行归一化
x_norm = (x - mean) / torch.sqrt(var + 1e-5)
print(x_norm)
这个例子中:
dims
的值为[-1, -2]
,表示需要在最后两个维度上进行归一化。mean
和var
分别是特征图对应维度的均值和方差。x_norm
是归一化后的张量。
通过使用负数索引,我们可以方便地指定需要进行归一化的维度,而不需要显式地计算维度的索引。