Understanding the AF3 distogram_loss function
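During training, the AlphaFold3 distogram loss measures the difference between the predicted distance distribution (represented by logits) and the true distance distribution. In protein structure prediction, a distogram gives, for every pair of residues, the probability that their distance falls into each of a set of distance intervals (bins). The loss uses cross-entropy to compare the predicted distribution with the true one, which is obtained by computing the Euclidean distance between the two residues and assigning it to its bin.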
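To make the idea concrete, here is a minimal NumPy sketch for a single structure with one representative coordinate per token. The helper names (log_softmax, distance_bins, toy_distogram_loss) are hypothetical, and the bin assignment and masked averaging follow the OpenFold-style AlphaFold2 distogram loss; that is an assumption for illustration, not a copy of the AF3 code. The defaults (0 to 32 Å, 64 bins) mirror the min_bin, max_bin and no_bins defaults of distogram_loss.

```python
import numpy as np

# Hypothetical helpers for illustration; binning and masked averaging are
# assumed to follow the OpenFold-style AF2 distogram loss, not the AF3 source.


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable log-softmax: subtract the max before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def distance_bins(coords, min_bin=0.0, max_bin=32.0, no_bins=64):
    """Map each pairwise distance of coords (n, 3) to a bin index in [0, no_bins - 1]."""
    # no_bins - 1 boundaries split the distance axis into no_bins bins;
    # squaring them allows comparison with squared distances (no sqrt needed)
    boundaries = np.linspace(min_bin, max_bin, no_bins - 1) ** 2
    d2 = ((coords[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1, keepdims=True)
    # Bin index = number of boundaries the distance exceeds;
    # anything beyond max_bin lands in the last bin
    return (d2 > boundaries).sum(axis=-1)


def toy_distogram_loss(logits, coords, mask, no_bins=64, eps=1e-6):
    """logits: (n, n, no_bins); coords: (n, 3); mask: (n,) of 0/1 floats."""
    true_bins = distance_bins(coords, no_bins=no_bins)
    labels = np.eye(no_bins)[true_bins]                 # one-hot targets (n, n, no_bins)
    per_pair = -(labels * log_softmax(logits)).sum(-1)  # cross-entropy per token pair
    pair_mask = mask[:, None] * mask[None, :]           # a pair counts only if both tokens are valid
    return (per_pair * pair_mask).sum() / (eps + pair_mask.sum())


# Usage: three tokens on a line, 3 Å and 10 Å from the first one
coords = np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [10.0, 0.0, 0.0]])
bins = distance_bins(coords)  # bins[0, 1] == 6, bins[0, 2] == 20
uniform_loss = toy_distogram_loss(np.zeros((3, 3, 64)), coords, np.ones(3))  # ≈ log(64) ≈ 4.159
peaked_loss = toy_distogram_loss(10.0 * np.eye(64)[bins], coords, np.ones(3))  # much smaller
```

With uniform logits every pair costs log(64), the cross-entropy of a uniform guess over 64 bins against a one-hot target; logits peaked on the correct bins push the loss toward zero.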
Source code:
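import torch
import torch.nn.functional as F
from torch import Tensor
# residue_constants (provides atom_order) comes from the project's own
# constants module; its import path is not shown here.


def softmax_cross_entropy(logits, labels):
    # Cross-entropy between softmax(logits) and the (one-hot) labels,
    # summed over the last (bin) dimension
    loss = -1 * torch.sum(
        labels * F.log_softmax(logits, dim=-1),
        dim=-1,
    )
    return loss


def distogram_loss(
    logits: Tensor,  # (bs, n_tokens, n_tokens, n_bins)
    all_atom_positions,  # (bs, n_tokens * 4, 3)
    token_mask,  # (bs, n_tokens)
    min_bin: float = 0.0,
    max_bin: float = 32.0,
    no_bins: int = 64,
    eps: float = 1e-6,
    **kwargs,
) -> Tensor:  # (bs,)
    # TODO: this is an inelegant implementation, integrate with the data pipeline
    batch_size, n_tokens = token_mask.shape
    # Compute pseudo beta and mask
    # (despite the name, the representative atom taken here is CA, not CB)
    all_atom_positions = all_atom_positions.reshape(batch_size, n_tokens, 4, 3)
    ca_pos = residue_constants.atom_order["CA"]
    pseudo_beta = all_atom_positions[..., ca_pos, :]  # (bs, n_tokens, 3)
    pseudo_beta_mask = token_mask  # (bs, n_tokens)
    # no_bins - 1 boundaries split [min_bin, max_bin] into no_bins bins;
    # they are squared so they can be compared directly with squared distances
    boundaries = torch.linspace(
        min_bin,
        max_bin,
        no_bins - 1,
        device=logits.device,
    )
    boundaries = boundaries ** 2
    # Squared pairwise distances, shape (bs, n_tokens, n_tokens, 1)
    dists = torch.sum(
        (pseudo_beta[..., :, None, :] - pseudo_beta[..., None, :, :]) ** 2,
        dim=-1,
        keepdim=True,
    )
    true_bi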