topN 相似度 torch实现
目录
优化版,去重相似度
topN 欧式距离版
没有去重复,
优化版,去重相似度
import torch
import torch.nn.functional as F
torch.manual_seed(42)
# 假设 10 条数据,每条数据的特征维度是 128
data = torch.randn(10, 128)
# 计算所有数据对之间的余弦相似度
cosine_similarities = F.cosine_similarity(data.unsqueeze(0), data.unsqueeze(1), dim=2)
# 通过设置对角线为负无穷,排除自身相似度
cosine_similarities.fill_diagonal_(-float('inf'))
# 生成上三角掩码(i < j 的位置为True)
mask = torch.triu(torch.ones_like(cosine_similarities, dtype=torch.bool), diagonal=1)
# 过滤掉下三角和对角线,仅保留