pytorch 张量的masked_fill函数介绍
torch.Tensor.masked_fill
是 PyTorch 中用于根据给定的掩码将张量中的特定元素替换为指定值的函数。这个函数可以用于在模型中屏蔽不需要的值,通常与掩码操作(如前向掩码、反向掩码等)结合使用。
函数签名
Tensor.masked_fill(mask, value)
参数
mask
:一个与输入张量相同形状的布尔型(bool
)张量或相同维度的整型张量,True
或非零的元素表示需要替换的元素位置。value
:需要替换为的值。当掩码张量中元素为True
或非零时,原张量中对应位置的元素会被替换为这个值。
返回
返回的是一个新张量,其中根据掩码 mask
对应的位置用 value
进行替换。
使用场景
- 遮蔽(屏蔽)无效值:在序列任务中,比如当处理不等长的输入时,可以使用
masked_fill
将填充的位置(比如PAD
标记)设置为一个极端的值,如负无穷大(-inf
),避免模型关注这些位置。 - 生成注意力掩码:在 Transformer 等模型中,使用掩码来确保某些位置不会被模型关注。