当前位置: 首页 > article >正文

pytorch计算张量中三维向量的欧式距离

如果 X 是一个包含多个三维向量的张量,形状为 [b, n, 3],其中 b 是批次大小,n 是每个批次中的向量数量,那么可以使用类似的广播机制来计算同一批次内不同位置的三维向量之间的欧式距离。

以下是具体实现步骤:

  1. 扩展张量的维度:需要将 X 的维度扩展,以便能够利用广播机制计算每对向量之间的差值。

  2. 计算差值并求平方和:计算向量之间的差值,并对差值的平方求和。

  3. 计算欧式距离:对平方和取平方根,得到欧式距离。

import torch

# 假设 X 是形状为 [b, n, 3] 的张量,b 是批次大小,n 是向量的数量
b = 128
n = 100
X = torch.randn(b, n, 3)  # 示例输入

# 第一步:扩展维度
X_expanded_1 = X.unsqueeze(2)  # 形状为 [b, n, 1, 3]
X_expanded_2 = X.unsqueeze(1)  # 形状为 [b, 1, n, 3]

# 第二步:计算每对向量之间的差值的平方和
dX = X_expanded_1 - X_expanded_2  # 形状为 [b, n, n, 3]
dX_squared_sum = torch.sum(dX**2, dim=3)  # 形状为 [b, n, n]

# 第三步:计算欧式距离
distances = torch.sqrt(dX_squared_sum)  # 形状为 [b, n, n]

# distances[k, i, j] 表示批次 k 中位置 i 和位置 j 之间的欧式距离

print(distances)

解释:

  1. 扩展维度X.unsqueeze(2) 将 X 的形状从 [b, n, 3] 变为 [b, n, 1, 3],而 X.unsqueeze(1) 将其形状变为 [b, 1, n, 3]。通过这种扩展,每个批次内的所有位置对可以使用广播机制进行差值计算。

  2. 计算差值并求平方和dX 是一个形状为 [b, n, n, 3] 的张量,表示每个批次内的每对位置之间的差值。torch.sum(dX**2, dim=3) 对最后一个维度(即三维坐标的维度)求和,得到每对位置之间的平方距离,形状为 [b, n, n]

  3. 计算欧式距离:最后,使用 torch.sqrt 对平方距离取平方根,得到最终的欧式距离矩阵 distances,其形状为 [b, n, n],表示每个批次内所有位置对之间的欧式距离。

这个 distances 张量的形状为 [b, n, n],其中 distances[k, i, j] 表示批次 k 中位置 i 和位置 j 之间的欧式距离。


http://www.kler.cn/a/292125.html

相关文章:

  • 超好用shell脚本NuShell mac安装
  • 【链路层】空口数据包详解(4):数据物理通道协议数据单元(PDU)
  • CSS Module:告别类名冲突,拥抱模块化样式(5)
  • vue 获取摄像头拍照,并旋转、裁剪生成新的图片
  • NodeJS 百度智能云文本转语音(实测)
  • 论文阅读 - Causally Regularized Learning with Agnostic Data Selection
  • WWDG—窗口看门狗
  • Claude的小白入门指南
  • 无人机之摄像头篇
  • Aspose.PDF功能演示:在 C# 中将 JPG 图像合并为 PDF
  • 网络压缩之网络剪枝(network pruning)
  • C#中ArrayList
  • 安卓aosp14上自由窗口划线边框Freeform Caption实战开发-千里马framework实战
  • HttpUtils工具类(三)OKHttpClient使用详细教程
  • vue项目中利用后端接口返回的视频地址获取第一帧作为数据封面展示
  • 【赵渝强老师】大数据主从架构的单点故障
  • HTTPS SEO优势
  • EmguCV学习笔记 VB.Net 第10章 人脸识别
  • elementplus组件库el-tree组件设置默认选中项setCheckKeys不生效
  • Mysql5.6.51修改密码
  • Ubuntu: 配置OpenCV环境
  • ITK-高斯滤波
  • 数据库中的逐行数据处理
  • Python知识点:如何使用Python实现自动驾驶模拟
  • TCP/IP网络编程:Linux实现的web服务器
  • NowCoder HJ50 四则运算