当前位置: 首页 > article >正文

pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill() 详解

pytorch小记(十二):pytorch中 masked_fill_() vs. masked_fill()详解

  • PyTorch `masked_fill_()` vs. `masked_fill()` 详解
    • 1️⃣ `masked_fill()` 和 `masked_fill_()` 的作用
    • 2️⃣ `masked_fill()` vs. `masked_fill_()` 示例
    • 3️⃣ 输出结果
    • 4️⃣ `masked_fill()` vs. `masked_fill_()` 区别
    • 5️⃣ `masked_fill()` 和 `masked_fill_()` 的实际应用
    • 6️⃣ `masked_fill()` 在 Transformer 自注意力中的应用
    • 7️⃣ `masked_fill_()` 在梯度计算中的应用
    • 8️⃣ 总结
      • 💡 实际应用


PyTorch masked_fill_() vs. masked_fill() 详解

在 PyTorch 中,masked_fill_()masked_fill() 主要用于 根据掩码(mask)填充张量(tensor)中的特定元素,但它们的关键区别在于 是否修改原张量(in-place 操作)


1️⃣ masked_fill()masked_fill_() 的作用

  • masked_fill(mask, value)
    • 不会修改原张量,而是返回一个新的张量。
  • masked_fill_(mask, value)
    • 会直接修改原张量(in-place 操作),节省内存。

两者的作用

  • mask 是一个 布尔张量True 代表需要填充的元素)。
  • value 是要填充的数值。

2️⃣ masked_fill() vs. masked_fill_() 示例

import torch

# 创建一个 3×3 的张量
tensor = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

# 创建一个掩码:True 代表要被替换的元素
mask = torch.tensor([
    [False, True, False],
    [True, False, False],
    [False, False, True]
])

print("原张量 tensor:\n", tensor)

# 使用 masked_fill()(不会修改原张量)
new_tensor = tensor.masked_fill(mask, -1)
print("\n新张量 new_tensor(使用 masked_fill()):\n", new_tensor)
print("\n原张量 tensor(未修改):\n", tensor)

# 使用 masked_fill_()(会修改原张量)
tensor.masked_fill_(mask, -1)
print("\n原张量 tensor(被修改):\n", tensor)

3️⃣ 输出结果

原张量 tensor:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

新张量 new_tensor(使用 masked_fill()):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

原张量 tensor(未修改):
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

原张量 tensor(被修改):
tensor([[ 1, -1,  3],
        [-1,  5,  6],
        [ 7,  8, -1]])

4️⃣ masked_fill() vs. masked_fill_() 区别

函数是否修改原张量?返回值
masked_fill(mask, value)❌ 不修改返回新的张量
masked_fill_(mask, value)✅ 直接修改修改后的原张量

总结

  • 如果你希望创建一个新张量,而不修改原数据,用 masked_fill()
  • 如果你希望节省内存并直接修改原张量,用 masked_fill_()

5️⃣ masked_fill()masked_fill_() 的实际应用

自然语言处理(NLP)深度学习模型 中,这两个函数经常用于 掩码(masking)操作,例如:

  • 屏蔽填充(Padding Mask):防止模型处理填充的 PAD 词(如 Transformer)。
  • 屏蔽未来信息(Future Mask):用于自回归模型(如 GPT),确保预测不会使用未来的信息。

6️⃣ masked_fill() 在 Transformer 自注意力中的应用

import torch

# 假设有一个 4×4 的注意力得分矩阵
attn_scores = torch.tensor([
    [0.5, 0.7, 0.8, 0.9],
    [0.6, 0.5, 0.4, 0.8],
    [0.2, 0.4, 0.5, 0.7],
    [0.3, 0.5, 0.6, 0.8]
])

# 创建一个掩码(模拟未来时间步的屏蔽)
mask = torch.tensor([
    [False, False, False, True],
    [False, False, True, True],
    [False, True, True, True],
    [True, True, True, True]
])

# 用 -inf 屏蔽掩码位置
masked_scores = attn_scores.masked_fill(mask, float('-inf'))
print("\n注意力得分(masked_fill()):\n", masked_scores)

示例输出:

注意力得分(masked_fill()):
tensor([[ 0.5000,  0.7000,  0.8000,    -inf],
        [ 0.6000,  0.5000,    -inf,    -inf],
        [ 0.2000,    -inf,    -inf,    -inf],
        [   -inf,    -inf,    -inf,    -inf]])

📌 解释

  • False 的位置保留原始数值。
  • True 的位置填充 -inf,在 softmax 计算时会被归零,不影响其他数值。

7️⃣ masked_fill_() 在梯度计算中的应用

在 PyTorch 训练过程中,如果你想直接修改梯度计算中的变量,可以使用 masked_fill_() 进行 in-place 操作

import torch

# 创建一个需要计算梯度的张量
x = torch.tensor([0.1, 0.2, 0.3, 0.4], requires_grad=True)

# 创建掩码
mask = torch.tensor([False, True, False, True])

# 直接修改 x
x.masked_fill_(mask, 0.0)

print("\n被修改后的 x(masked_fill_()):\n", x)

示例输出:

被修改后的 x(masked_fill_()):
tensor([0.1000, 0.0000, 0.3000, 0.0000], requires_grad=True)

📌 解释

  • 通过 masked_fill_() 直接在计算图中修改 x,避免创建新张量。

8️⃣ 总结

  • masked_fill()masked_fill_() 都用于按掩码填充张量中的特定元素
  • 主要区别
    • masked_fill() 不修改原张量,返回新的张量。
    • masked_fill_() 直接修改原张量(in-place 操作)。
  • 适用场景
    • masked_fill():适用于需要 保持原张量不变 的情况,如 Transformer 掩码处理
    • masked_fill_():适用于需要 节省内存直接修改张量 的情况,如 梯度计算

💡 实际应用

场景推荐使用
创建新张量,不修改原数据masked_fill()
直接修改原数据,减少内存占用masked_fill_()
Transformer 自注意力掩码masked_fill()
梯度计算,避免额外的计算图创建masked_fill_()

🚀 合理使用 masked_fill()masked_fill_(),可以优化你的 PyTorch 代码,提高计算效率! 🎯


http://www.kler.cn/a/591693.html

相关文章:

  • 《我的Python觉醒之路》之转型Python(十三)——控制流
  • Trae插件革命:用VSPlugin Helper实现VSCode市场插件全自动安装
  • RabbitMQ常见问题总结
  • Laravel框架下通过DB获取数据并转为数组的方法
  • 宝石PDF,全新 PC 版本,全部免费
  • CSS中z-index使用详情
  • C++和标准库速成(八)——指针、动态数组、const、constexpr和consteval
  • LeetCode 第14~16题
  • 第29周 面试题精讲(4)
  • KNN算法性能优化技巧与实战案例
  • vue网格布局--grid布局
  • 采用贝塞尔函数,进行恒定束宽波束形成算法
  • 如何用AI制作PPT,轻松实现高效演示
  • Node.js 和 Vite 配置文件中`__dirname`
  • Mysql如何解决幻读问题
  • 提示词prompt如何写
  • 【R语言】二项分布,正态分布,极大似然估计实现
  • 探索 Ollama Deep Researcher:本地网络研究助手的新选择
  • Git 本地常见快捷操作
  • 【html中文本超出元素的宽度后显示省略号...】