当前位置: 首页 > article >正文

AF3 one_hot函数解读

AlphaFold3的one_hot(src.utils.tensor_utils模块中)one_hot 函数通过寻找输入值 x 在离散化范围 v_bins 中的最近值并生成相应的 one-hot 编码,提供了一种将连续数值映射到离散表示的通用方法。

源代码:

def one_hot(x, v_bins):
    dtype = v_bins.dtype
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).to(dtype)

函数的作用

将任意连续数值张量 x 根据给定的离散化范围 v_bins 转换为 one-hot 编码。


参数

  • x:输入张量,形状为 [..., n],其中每个值都需要被映射到 one-hot 编码。
  • v_bins:离散化范围的张量(通常是一个 1D 张量,表示离散化的值)。

核心逻辑

1. 获取 v_bins 的数据类型
dtype = v_bins.dtype

v_bins 的数据类型会被保留并用于后续操作(确保输出结果与 v_bins 的数据类型一致)。

2. 调整 v_bins 的形状
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
  • 将 v_bins 调整为与输入张量 x 广播兼容的形状。
  • 如果 x 的 shape 是 [..., n],则 reshaped_bins 的形状会变成:
(1, 1, ..., 1, len(v_bins))

注: 元组的乘法和加法。 元组之间可以进行加法操作,加法操作的实际效果是将两个元组中的元素合并起来,生成一个新的元组。元组与一个整数相乘时,会按照整数指定的次数复制该元组元素,生成一个新的元组。 例如 (1,) * 2 的结果是 (1, 1);  (1, 1)+ (2, )为(1, 1, 2)。
 

3. 计算差值
diffs = x[..., None] - reshaped_bins
  • 为了将 x 的每个值与 v_bins 中的所有值进行比较,将 x 的最后一维扩展一个维度,形状变为 [..., n, 1]
  • diffs 是每个 x 的值与 v_bins 中每个值的差值,形状为:
[..., n, len(v_bins)]
4. 找到最接近的索引
am = torch.argmin(torch.abs(diffs), dim=-1)
  • 对每个 x 的值,找到与 v_bins 中最接近的值的索引。
  • 结果 am 是一个整数张量,形状为 [..., n],表示每个值在 v_bins 中的最近匹配索引。
5. 生成 one-hot 编码
return nn.functional.one_hot(am, num_classes=len(v_bins)).to(dtype)
  • 使用 torch.nn.functional.one_hot 生成 one-hot 编码。

  • am 中的每个索引位置被转换为一个长度为 len(v_bins) 的 one-hot 向量。

  • 将生成的 one-hot 编码转换为 v_bins 的数据类型。


举例

输入
x = torch.tensor([1.2, 3.7, 2.5])
v_bins = torch.tensor([1.0, 2.0, 3.0, 4.0])
计算过程
diffs
diffs = x[..., None] - v_bins
[[ 0.2, -0.8, -1.8, -2.8],  # 1.2 - [1.0, 2.0, 3.0, 4.0]
 [ 2.7,  1.7,  0.7, -0.3],  # 3.7 - [1.0, 2.0, 3.0, 4.0]
 [ 1.5,  0.5, -0.5, -1.5]]  # 2.5 - [1.0, 2.0, 3.0, 4.0]
am (最小绝对值索引)
am = [0, 2, 1]  # 对应 v_bins 的索引 [1.0, 3.0, 2.0]
one_hot(am)
one_hot = [[1, 0, 0, 0],  # 1.0 -> [1, 0, 0, 0]
           [0, 0, 1, 0],  # 3.0 -> [0, 0, 1, 0]
           [0, 1, 0, 0]]  # 2.0 -> [0, 1, 0, 0]

总结

one_hot 函数通过寻找输入值 x 在离散化范围 v_bins 中的最近值并生成相应的 one-hot 编码,提供了一种将连续数值映射到离散表示的通用方法。这种方法在位置编码和离散化过程中非常有用。


http://www.kler.cn/a/471760.html

相关文章:

  • RPC自定义协议
  • 搭建企业AI助理的创新应用与案例分析
  • 谷粒商城-高级篇完结-Sleuth+Zipkin 服务链路追踪
  • 搭建Hadoop分布式集群
  • 中学教资笔记1
  • 【微服务】6、限流 熔断
  • redis学习笔记(一)了解redis
  • 如何设置通过Visual Studio(VS)打开的C#项目工具集?
  • 书籍推荐:MySQL 是怎样运行的-从根上理解 MySQL
  • C# 之某度协议登录,JS逆向,手机号绑定,获取CK
  • Tableau数据可视化与仪表盘搭建-可视化原则及BI仪表盘搭建
  • 05容器篇(D2_集合 - D6_容器源码分析篇 - D1_ArrayList)
  • Flex布局的三个属性
  • 2025年1月4日蜻蜓q旗舰版st完整开源·包含前后端所有源文件·开源可商用可二开·优雅草科技·优雅草kir|优雅草星星|优雅草银满|优雅草undefined
  • WPF中RenderTargetBitmap问题解决
  • 服务器等保测评审计日志功能开启(auditd)和时间校准
  • 如何从串 ‘ 中国 +86‘ 中,获取到‘中国’:strip()、split()及正则表达式的使用
  • 通达信行情接口失效?C# 实现获取五档报价行情 GetSecurityQuotes
  • Ubuntu 安装 Java 1.8
  • Ruby语言的数据库编程