AF3 PairStack类源码解读
PairStack 是 AlphaFold 的核心模块之一，用于对残基对（residue-residue pair）的特征张量 z 进行迭代更新。该模块结合了几何启发的操作（如三角乘法更新）和注意力机制（三角注意力，分为起始节点与终止节点两种），逐步建模蛋白质序列中残基之间的复杂关系。
源代码：
class PairStack(nn.Module):
def __init__(
    self,
    c_z: int,
    c_hidden_tri_mul: int = 128,
    c_hidden_pair_attn: int = 32,
    no_heads_tri_attn: int = 4,
    transition_n: int = 4,
    pair_dropout: float = 0.25,
    fuse_projection_weights: bool = False,
    inf: float = 1e8,
) -> None:
    """Construct the sub-modules of a single pair-stack block.

    Sub-modules are assigned in a fixed order: outgoing triangle
    multiplication, incoming triangle multiplication, starting-node
    triangle attention, ending-node triangle attention, transition,
    row-wise dropout, column-wise dropout. On an ``nn.Module`` the
    assignment order determines submodule registration order, so it is
    kept stable.

    Args:
        c_z: Pair-representation channel count; passed as the first
            argument to every triangle and transition sub-module.
        c_hidden_tri_mul: Hidden width given to both triangle
            multiplication modules.
        c_hidden_pair_attn: Hidden width given to both triangle
            attention modules.
        no_heads_tri_attn: Head count given to both triangle attention
            modules.
        transition_n: Second argument to ``Transition`` (presumably the
            hidden-layer expansion factor -- TODO confirm in ``Transition``).
        pair_dropout: Rate passed to both ``DropoutRowwise`` and
            ``DropoutColumnwise``.
        fuse_projection_weights: When true, the ``Fused*`` triangle
            multiplication variants are used instead of the plain ones.
        inf: Large constant forwarded to both triangle attention modules
            (presumably used for masking -- TODO confirm in those modules).
    """
    super().__init__()

    # Choose the (outgoing, incoming) triangle-multiplication classes once;
    # only the selected pair of names is ever looked up.
    mul_out_cls, mul_in_cls = (
        (FusedTriangleMultiplicationOutgoing, FusedTriangleMultiplicationIncoming)
        if fuse_projection_weights
        else (TriangleMultiplicationOutgoing, TriangleMultiplicationIncoming)
    )
    self.tri_mul_out = mul_out_cls(c_z, c_hidden_tri_mul)
    self.tri_mul_in = mul_in_cls(c_z, c_hidden_tri_mul)

    # Starting- and ending-node attention take identical positional args.
    attn_args = (c_z, c_hidden_pair_attn, no_heads_tri_attn)
    self.tri_att_start = TriangleAttentionStartingNode(*attn_args, inf=inf)
    self.tri_att_end = TriangleAttentionEndingNode(*attn_args, inf=inf)

    self.transition = Transition(c_z, transition_n)

    # One shared rate for the row-wise and column-wise dropout layers.
    self.dropout_row_layer = DropoutRowwise(pair_dropout)
    self.dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(