AF3 one_hot函数解读
AlphaFold3的one_hot(src.utils.tensor_utils模块中)one_hot
函数通过寻找输入值 x
在离散化范围 v_bins
中的最近值并生成相应的 one-hot 编码,提供了一种将连续数值映射到离散表示的通用方法。
源代码:
def one_hot(x, v_bins):
dtype = v_bins.dtype
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).to(dtype)
函数的作用
将任意连续数值张量 x
根据给定的离散化范围 v_bins
转换为 one-hot 编码。
参数
x
:输入张量,形状为[..., n]
,其中每个值都需要被映射到 one-hot 编码。v_bins
:离散化范围的张量(通常是一个 1D 张量,表示离散化的值)。
核心逻辑
1. 获取 v_bins
的数据类型
dtype = v_bins.dtype
v_bins
的数据类型会被保留并用于后续操作(确保输出结果与 v_bins
的数据类型一致)。
2. 调整 v_bins
的形状
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
- 将
v_bins
调整为与输入张量x
广播兼容的形状。 - 如果
x
的 shape 是[..., n]
,则reshaped_bins
的形状会变成:
(1, 1, ..., 1, len(v_bins))
注: 元组的乘法和加法。 元组之间可以进行加法操作,加法操作的实际效果是将两个元组中的元素合并起来,生成一个新的元组。元组与一个整数相乘时,会按照整数指定的次数复制该元组元素,生成一个新的元组。 例如 (1,) * 2
的结果是 (1, 1); (1, 1)
+ (2, )为(1, 1, 2)。
3. 计算差值
diffs = x[..., None] - reshaped_bins
- 为了将
x
的每个值与v_bins
中的所有值进行比较,将x
的最后一维扩展一个维度,形状变为[..., n, 1]
。 diffs
是每个x
的值与v_bins
中每个值的差值,形状为:
[..., n, len(v_bins)]
4. 找到最接近的索引
am = torch.argmin(torch.abs(diffs), dim=-1)
- 对每个
x
的值,找到与v_bins
中最接近的值的索引。 - 结果
am
是一个整数张量,形状为[..., n]
,表示每个值在v_bins
中的最近匹配索引。
5. 生成 one-hot 编码
return nn.functional.one_hot(am, num_classes=len(v_bins)).to(dtype)
-
使用
torch.nn.functional.one_hot
生成 one-hot 编码。 -
am
中的每个索引位置被转换为一个长度为len(v_bins)
的 one-hot 向量。 -
将生成的 one-hot 编码转换为
v_bins
的数据类型。
举例
输入
x = torch.tensor([1.2, 3.7, 2.5])
v_bins = torch.tensor([1.0, 2.0, 3.0, 4.0])
计算过程
diffs
:
diffs = x[..., None] - v_bins
[[ 0.2, -0.8, -1.8, -2.8], # 1.2 - [1.0, 2.0, 3.0, 4.0]
[ 2.7, 1.7, 0.7, -0.3], # 3.7 - [1.0, 2.0, 3.0, 4.0]
[ 1.5, 0.5, -0.5, -1.5]] # 2.5 - [1.0, 2.0, 3.0, 4.0]
am
(最小绝对值索引):
am = [0, 2, 1] # 对应 v_bins 的索引 [1.0, 3.0, 2.0]
one_hot(am)
:
one_hot = [[1, 0, 0, 0], # 1.0 -> [1, 0, 0, 0]
[0, 0, 1, 0], # 3.0 -> [0, 0, 1, 0]
[0, 1, 0, 0]] # 2.0 -> [0, 1, 0, 0]
总结
one_hot
函数通过寻找输入值 x
在离散化范围 v_bins
中的最近值并生成相应的 one-hot 编码,提供了一种将连续数值映射到离散表示的通用方法。这种方法在位置编码和离散化过程中非常有用。