AF3 TriangleMultiplicativeUpdate Class: Code Walkthrough
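AlphaFold3's TriangleMultiplicativeUpdate class inherits from BaseTriangleMultiplicativeUpdate. It has two subclasses, TriangleMultiplicationOutgoing and TriangleMultiplicationIncoming, which differ only in the value of the _outgoing argument passed to __init__; they represent the triangle multiplicative update using "outgoing" and "incoming" edges, respectively.
TriangleMultiplicativeUpdate implements the triangle multiplicative updates of Algorithms 11 and 12. In structural bioinformatics, for example in the AlphaFold protein structure prediction model, these algorithms are used to update the representation of amino acid residue pairs.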
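The contraction at the heart of these two algorithms can be sketched as follows. This is a minimal NumPy illustration, not the class's actual implementation: the function name triangle_product and its outgoing flag are hypothetical, and the layer normalization, gating, and linear projections that surround the contraction in the real module are omitted.

```python
import numpy as np


def triangle_product(a, b, outgoing=True):
    """Core contraction of the triangle multiplicative update (illustrative only).

    a, b: arrays of shape [N, N, C], standing in for the two projections of
          the pair representation.
    Returns an [N, N, C] array.
    """
    if outgoing:
        # "Outgoing" edges (Algorithm 11): x[i, j] = sum_k a[i, k] * b[j, k]
        return np.einsum("ikc,jkc->ijc", a, b)
    # "Incoming" edges (Algorithm 12): x[i, j] = sum_k a[k, i] * b[k, j]
    return np.einsum("kic,kjc->ijc", a, b)


rng = np.random.default_rng(0)
n, c = 4, 3  # toy sizes chosen for the example
a = rng.standard_normal((n, n, c))
b = rng.standard_normal((n, n, c))

out = triangle_product(a, b, outgoing=True)
inc = triangle_product(a, b, outgoing=False)

# The incoming variant equals the outgoing variant applied to inputs whose
# two residue axes have been swapped.
a_t, b_t = a.transpose(1, 0, 2), b.transpose(1, 0, 2)
assert np.allclose(inc, triangle_product(a_t, b_t, outgoing=True))
```

The two variants differ only in which residue axis is summed over, which is consistent with the two subclasses differing in a single boolean flag.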
Source code:
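from typing import Optional

import torch

# Linear and BaseTriangleMultiplicativeUpdate are defined elsewhere in the same codebase.


class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
    """
    Implements Algorithms 11 and 12.
    """
    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                If True, use "outgoing" edges; otherwise use "incoming" edges
        """
        super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
                                                           c_hidden=c_hidden,
                                                           _outgoing=_outgoing)

        # Projection (_p) and gating (_g) layers for the two operands a and b
        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
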
    def _inference_forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_chunk_size: Optional[int] = None,
        with_add: bool = True,
    ):
        """
        Args:
            z:
                A [*, N, N, C_z] pair representation
            mask:
                A [*, N, N] pair mask
            inplace_chunk_size:
                Size of chunks used in the main computation. Increase to trade
                memory for speed.
            with_add:
                If True, z is overwritten with (z + update). Otherwise, it is
                overwritten with (update).
        Returns:
            A reference to the overwritten z

        More memory-efficient, inference-only version of the forward function.
        Uses in-place operations, fusion of the addition that happens after
        this module in the Evoformer, a smidge of recomputation, and
        a cache of overwritten values to lower peak memory consumption of this
        module from 5x the size of the input tensor z to 2.5x its size. Useful
        for inference on extremely long sequences.

        It works as follows. We will make reference to variables used in the
        default forward implementation below. Naively, triangle multiplication
        attention requires the manifestation of 5 tensors the size of z:
        1) z, the "square" input tensor, 2) a, the first projection of z,
        3) b, the second projection of z, 4) g, a z-sized mask, and 5) a
        z-sized tensor for intermediate computations. For large N, this is
        prohibitively expensive; for N=4000, for example, z is more than 8GB
        alone. To avoid this problem, we compute b, g, and all intermediate
        tensors in small chunks, noting that the chunks required to compute a
        chunk of the output depend only on the tensor a and corresponding
        vertical and horizontal chunks of z. This suggests an algorithm that