AF3 MSAWeightedAveragingNaive类解读
AlphaFold3 的 MSAWeightedAveragingNaive 类用于对多序列比对(MSA)表示做加权平均:先对偏置张量 b 沿残基维度做 softmax 归一化得到权重,再对每条 MSA 序列的值张量 v 按残基加权求和,最后用门控张量 g 经 sigmoid 后对结果逐元素门控。它是显式展开 v 的朴素(naive)实现。
源代码:
class MSAWeightedAveragingNaive(nn.Module):
    """Naive gated weighted average over residues, applied per MSA sequence.

    For every MSA sequence ``s`` and residue ``i``, the values ``v[s, j]`` of
    all residues ``j`` are averaged with weights ``softmax_j(b[i, j])`` (the
    same weights for every sequence), and the result is gated element-wise
    by ``sigmoid(g)``.

    "Naive": the product ``v * weights`` is materialised with shape
    ``(*, n_seq, n_res, n_res, no_heads, c_hidden)`` before the reduction,
    so memory grows with ``n_seq * n_res ** 2``.

    Args:
        no_heads: Number of attention heads.
        c_hidden: Per-head channel dimension.
    """

    def __init__(self, no_heads: int, c_hidden: int):
        super().__init__()
        self.no_heads = no_heads
        self.c_hidden = c_hidden
        # Normalises b over its second-to-last axis, i.e. the key residue j
        # of b: (*, res_i, res_j, heads).
        self.softmax = nn.Softmax(dim=-2)

    def forward(
        self,
        v: torch.Tensor,
        b: torch.Tensor,
        g: torch.Tensor,
        n_seq: int,
        n_res: int,
    ) -> torch.Tensor:
        """Return the gated, softmax(b)-weighted average of ``v`` over residues.

        Args:
            v: Values; must be expandable (after an unsqueeze at -4) to
                (*, n_seq, n_res, n_res, no_heads, c_hidden), i.e. expected
                shape (*, n_seq, n_res, no_heads, c_hidden).
            b: Unnormalised weights (logits), expected shape
                (*, n_res, n_res, no_heads); softmax is taken over the
                second residue axis.
            g: Gate logits, broadcastable to
                (*, n_seq, n_res, no_heads, c_hidden).
            n_seq: Number of MSA sequences.
            n_res: Number of residues.

        Returns:
            The gated average of shape (*, n_seq, n_res, no_heads, c_hidden)
            passed through ``flatten_final_dims(o, 2)`` — presumably giving
            (*, n_seq, n_res, no_heads * c_hidden); confirm against the helper.
        """
        # Add a query-residue axis i and broadcast v along it (expand makes a
        # view, no copy): v[..., s, i, j, h, c] = v_in[..., s, j, h, c].
        new_v_shape = v.shape[:-4] + (
            n_seq,
            n_res,
            n_res,
            self.no_heads,
            self.c_hidden,
        )
        v = v.unsqueeze(-4).expand(new_v_shape)  # (*, seq, res_i, res_j, heads, c_hidden)

        # Attention weights over j, shared by every MSA sequence.
        weights = self.softmax(b)  # (*, res_i, res_j, heads)
        weights = weights.unsqueeze(-4).unsqueeze(-1)  # (*, 1, res_i, res_j, heads, 1)

        # Weighted sum over j (dim -3), then element-wise sigmoid gating.
        o = torch.sigmoid(g) * torch.sum(v * weights, dim=-3)  # (*, seq, res, heads, c_hidden)
        o = flatten_final_dims(o, 2)
        return o