pytorch torch.triu函数介绍
torch.triu
是 PyTorch 提供的一个函数,用于生成矩阵的上三角部分。它的名称来源于 "triangular upper"(上三角形),作用是将矩阵的下三角部分置为零,只保留对角线及其上方的元素。
函数签名
torch.triu(input, diagonal=0) → Tensor
参数
input
: 输入的张量,一般是一个二维矩阵(Tensor
)。diagonal
: 对角线的偏移量,默认值为0
。- 当
diagonal=0
时,保留主对角线及其上方的元素。 - 当
diagonal>0
时,向上偏移保留的对角线。偏移的值决定从上三角的第几行开始保留。 - 当
diagonal<0
时,向下偏移保留的对角线,即包括主对角线下方的部分。
- 当
返回值
返回一个与 input
形状相同的张量,但下三角部分的值被置为零。
示例
import torch
# 创建一个 3x3 的张量
A = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 获取上三角部分
upper_triangular = torch.triu(A)
print(upper_triangular)
在这个例子中,torch.triu
保留了矩阵 A
的主对角线及其上方的元素,而将下方的元素置为零。
偏移对角线示例
如果我们设置 diagonal
为 1
,则只保留主对角线上方的元素:
upper_triangular = torch.triu(A, diagonal=1)
print(upper_triangular)
输出为:
tensor([[0, 2, 3],
[0, 0, 6],
[0, 0, 0]])
应用场景
- 矩阵运算:
torch.triu
在需要使用上三角矩阵进行特定计算时很有用,比如 Cholesky 分解、图卷积中的邻接矩阵处理。 - 屏蔽下三角部分: 在一些序列处理任务中,常用上三角掩码来忽略无关的值,例如在自注意力机制中用来避免提前看到未来的序列。