pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill() 详解
pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill()详解
- PyTorch `masked_fill_()` vs. `masked_fill()` 详解
- 1️⃣ `masked_fill()` 和 `masked_fill_()` 的作用
- 2️⃣ `masked_fill()` vs. `masked_fill_()` 示例
- 3️⃣ 输出结果
- 4️⃣ `masked_fill()` vs. `masked_fill_()` 区别
- 5️⃣ `masked_fill()` 和 `masked_fill_()` 的实际应用
- 6️⃣ `masked_fill()` 在 Transformer 自注意力中的应用
- 7️⃣ `masked_fill_()` 在梯度计算中的应用
- 8️⃣ 总结
- 💡 实际应用
PyTorch masked_fill_()
vs. masked_fill()
详解
在 PyTorch 中,masked_fill_()
和 masked_fill()
主要用于 根据掩码(mask)填充张量(tensor)中的特定元素,但它们的关键区别在于 是否修改原张量(in-place 操作)。
1️⃣ masked_fill()
和 masked_fill_()
的作用
masked_fill(mask, value)
- 不会修改原张量,而是返回一个新的张量。
masked_fill_(mask, value)
- 会直接修改原张量(in-place 操作),节省内存。
两者的作用:
mask
是一个 布尔张量(True
代表需要填充的元素)。value
是要填充的数值。
2️⃣ masked_fill()
vs. masked_fill_()
示例
import torch
# 创建一个 3×3 的张量
tensor = torch.tensor([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
])
# 创建一个掩码:True 代表要被替换的元素
mask = torch.tensor([
[False, True, False],
[True, False, False],
[False, False, True]
])
print("原张量 tensor:\n", tensor)
# 使用 masked_fill()(不会修改原张量)
new_tensor = tensor.masked_fill(mask, -1)
print("\n新张量 new_tensor(使用 masked_fill()):\n", new_tensor)
print("\n原张量 tensor(未修改):\n", tensor)
# 使用 masked_fill_()(会修改原张量)
tensor.masked_fill_(mask, -1)
print("\n原张量 tensor(被修改):\n", tensor)
3️⃣ 输出结果
原张量 tensor:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
新张量 new_tensor(使用 masked_fill()):
tensor([[ 1, -1, 3],
[-1, 5, 6],
[ 7, 8, -1]])
原张量 tensor(未修改):
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
原张量 tensor(被修改):
tensor([[ 1, -1, 3],
[-1, 5, 6],
[ 7, 8, -1]])
4️⃣ masked_fill()
vs. masked_fill_()
区别
函数 | 是否修改原张量? | 返回值 |
---|---|---|
masked_fill(mask, value) | ❌ 不修改 | 返回新的张量 |
masked_fill_(mask, value) | ✅ 直接修改 | 修改后的原张量 |
总结
- 如果你希望创建一个新张量,而不修改原数据,用
masked_fill()
。 - 如果你希望节省内存并直接修改原张量,用
masked_fill_()
。
5️⃣ masked_fill()
和 masked_fill_()
的实际应用
在 自然语言处理(NLP) 和 深度学习模型 中,这两个函数经常用于 掩码(masking)操作,例如:
- 屏蔽填充(Padding Mask):防止模型处理填充的
PAD
词(如 Transformer)。 - 屏蔽未来信息(Future Mask):用于自回归模型(如 GPT),确保预测不会使用未来的信息。
6️⃣ masked_fill()
在 Transformer 自注意力中的应用
import torch
# 假设有一个 4×4 的注意力得分矩阵
attn_scores = torch.tensor([
[0.5, 0.7, 0.8, 0.9],
[0.6, 0.5, 0.4, 0.8],
[0.2, 0.4, 0.5, 0.7],
[0.3, 0.5, 0.6, 0.8]
])
# 创建一个掩码(模拟未来时间步的屏蔽)
mask = torch.tensor([
[False, False, False, True],
[False, False, True, True],
[False, True, True, True],
[True, True, True, True]
])
# 用 -inf 屏蔽掩码位置
masked_scores = attn_scores.masked_fill(mask, float('-inf'))
print("\n注意力得分(masked_fill()):\n", masked_scores)
示例输出:
注意力得分(masked_fill()):
tensor([[ 0.5000, 0.7000, 0.8000, -inf],
[ 0.6000, 0.5000, -inf, -inf],
[ 0.2000, -inf, -inf, -inf],
[ -inf, -inf, -inf, -inf]])
📌 解释:
False
的位置保留原始数值。True
的位置填充-inf
,在 softmax 计算时会被归零,不影响其他数值。
7️⃣ masked_fill_()
在梯度计算中的应用
在 PyTorch 训练过程中,如果你想直接修改梯度计算中的变量,可以使用 masked_fill_()
进行 in-place 操作。
import torch
# 创建一个需要计算梯度的张量
x = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)
# 创建掩码
mask = torch.tensor([False, True, False, True])
# 直接修改 x
x.masked_fill_(mask, 0.0)
print("\n被修改后的 x(masked_fill_()):\n", x)
示例输出:
被修改后的 x(masked_fill_()):
tensor([0.1000, 0.0000, 0.3000, 0.0000], requires_grad=True)
📌 解释
- 通过
masked_fill_()
直接在计算图中修改x
,避免创建新张量。
8️⃣ 总结
masked_fill()
和masked_fill_()
都用于按掩码填充张量中的特定元素。- 主要区别:
masked_fill()
不修改原张量,返回新的张量。masked_fill_()
直接修改原张量(in-place 操作)。
- 适用场景:
masked_fill()
:适用于需要 保持原张量不变 的情况,如 Transformer 掩码处理。masked_fill_()
:适用于需要 节省内存 并 直接修改张量 的情况,如 梯度计算。
💡 实际应用
场景 | 推荐使用 |
---|---|
创建新张量,不修改原数据 | masked_fill() |
直接修改原数据,减少内存占用 | masked_fill_() |
Transformer 自注意力掩码 | masked_fill() |
梯度计算,避免额外的计算图创建 | masked_fill_() |
🚀 合理使用 masked_fill()
和 masked_fill_()
,可以优化你的 PyTorch 代码,提高计算效率! 🎯