pytorch计算张量中三维向量的欧式距离
如果 X
是一个包含多个三维向量的张量,形状为 [b, n, 3]
,其中 b
是批次大小,n
是每个批次中的向量数量,那么可以使用类似的广播机制来计算同一批次内不同位置的三维向量之间的欧式距离。
以下是具体实现步骤:
-
扩展张量的维度:需要将
X
的维度扩展,以便能够利用广播机制计算每对向量之间的差值。 -
计算差值并求平方和:计算向量之间的差值,并对差值的平方求和。
-
计算欧式距离:对平方和取平方根,得到欧式距离。
import torch
# 假设 X 是形状为 [b, n, 3] 的张量,b 是批次大小,n 是向量的数量
b = 128
n = 100
X = torch.randn(b, n, 3) # 示例输入
# 第一步:扩展维度
X_expanded_1 = X.unsqueeze(2) # 形状为 [b, n, 1, 3]
X_expanded_2 = X.unsqueeze(1) # 形状为 [b, 1, n, 3]
# 第二步:计算每对向量之间的差值的平方和
dX = X_expanded_1 - X_expanded_2 # 形状为 [b, n, n, 3]
dX_squared_sum = torch.sum(dX**2, dim=3) # 形状为 [b, n, n]
# 第三步:计算欧式距离
distances = torch.sqrt(dX_squared_sum) # 形状为 [b, n, n]
# distances[k, i, j] 表示批次 k 中位置 i 和位置 j 之间的欧式距离
print(distances)
解释:
-
扩展维度:
X.unsqueeze(2)
将X
的形状从[b, n, 3]
变为[b, n, 1, 3]
,而X.unsqueeze(1)
将其形状变为[b, 1, n, 3]
。通过这种扩展,每个批次内的所有位置对可以使用广播机制进行差值计算。 -
计算差值并求平方和:
dX
是一个形状为[b, n, n, 3]
的张量,表示每个批次内的每对位置之间的差值。torch.sum(dX**2, dim=3)
对最后一个维度(即三维坐标的维度)求和,得到每对位置之间的平方距离,形状为[b, n, n]
。 -
计算欧式距离:最后,使用
torch.sqrt
对平方距离取平方根,得到最终的欧式距离矩阵distances
,其形状为[b, n, n]
,表示每个批次内所有位置对之间的欧式距离。
这个 distances
张量的形状为 [b, n, n]
,其中 distances[k, i, j]
表示批次 k
中位置 i
和位置 j
之间的欧式距离。