向量元素间是否相等mask矩阵
文章目录
- 1. 描述
- 2. pytorch代码
1. 描述
给定一个向量a_vector,生成一个mask矩阵用来判断向量元素间是否相等
假设有一个向量a_vector[1,2,1,4]此时就两个1相等,所以生成一个mask矩阵用来判断两个元素是否相等
2. pytorch代码
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.set_printoptions(precision=3, sci_mode=False)
if __name__ == "__main__":
run_code = 0
a_vector = torch.tensor([1, 2, 1, 4]).reshape((-1, 1))
print(f"a_vector=\n{a_vector}")
a_matrix = a_vector - a_vector.T
print(f"a_matrix=\n{a_matrix}")
a_mask = a_matrix == 0
print(f"a_mask=\n{a_mask}")
- 结果:
a_vector=
tensor([[1],
[2],
[1],
[4]])
a_matrix=
tensor([[ 0, -1, 0, -3],
[ 1, 0, 1, -2],
[ 0, -1, 0, -3],
[ 3, 2, 3, 0]])
a_mask=
tensor([[ True, False, True, False],
[False, True, False, False],
[ True, False, True, False],
[False, False, False, True]])