AF3 AttentionPairBias类源码解读
AttentionPairBias 是 AlphaFold3 中的一个注意力机制模块,用于实现带成对表示偏置(Pair Bias)的全自注意力(Full Self-Attention)。它在 AlphaFold3 的架构中发挥重要作用:它把成对表示(pair representation)投影为每个注意力头的偏置,并与注意力分数相加。这样,单一表示(single representation)在做自注意力时就能利用 token(残基)之间的成对关系信息。
源代码:
class AttentionPairBias(nn.Module):
"""Full self-attention with pair bias."""
def __init__(
    self,
    dim: int,
    c_pair: int = 16,
    no_heads: int = 8,
    dropout: float = 0.0,
    input_gating: bool = True,
    residual: bool = True,
    inf: float = 1e8,
):
    """Initialize the AttentionPairBias module.

    Args:
        dim:
            Total dimension of the model.
        c_pair:
            The number of channels for the pair representation. Defaults to 16.
        no_heads:
            Number of parallel attention heads. `dim` is split across the heads
            (i.e. each head has dimension dim // no_heads), so dim must be divisible
            by no_heads. Defaults to 8.
        dropout:
            Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
        input_gating:
            Whether the single representation should be gated with another single-like
            representation using adaptive layer normalization (AdaLN). When False, a plain
            LayerNorm is used as the input projection and no output gating projection is
            created. Default: True.
        residual:
            Whether the module is used as a residual block. Default: True. This affects the
            initialization of the final projection layer of the MHA attention.
        inf:
            Large constant stored as self.inf; _prep_biases multiplies it by (mask - 1), so
            masked-out positions receive an additive bias of -inf. Default: 1e8.
    """
    super().__init__()
    self.dim = dim
    self.c_pair = c_pair
    self.num_heads = no_heads
    # NOTE(review): not passed to Attention below; presumably consumed in forward — confirm.
    self.dropout = dropout
    self.input_gating = input_gating
    self.inf = inf

    # Every head gets dim // no_heads channels, so the split must be exact.
    assert dim % no_heads == 0, (
        f"the model dimensionality ({dim}) should be divisible by the "
        f"number of heads ({no_heads}) "
    )

    # Projections. output_proj_linear only exists with input gating; it stays None otherwise.
    self.output_proj_linear = None
    if input_gating:
        self.input_proj = AdaLN(dim)
        # Output projection from AdaLN.
        self.output_proj_linear = Linear(dim, dim, init='gating')
        # Start the gate mostly closed: sigmoid(-2) ≈ 0.12
        # (assuming the gate is applied through a sigmoid in forward — confirm).
        self.output_proj_linear.bias = nn.Parameter(torch.ones(dim) * -2.0)
    else:
        self.input_proj = LayerNorm(dim)

    # Multi-head attention over the single representation; c_hidden is dim // no_heads.
    self.attention = Attention(
        c_q=dim,
        c_k=dim,
        c_v=dim,
        c_hidden=dim // no_heads,
        no_heads=no_heads,
        gating=True,
        residual=residual,
        proj_q_w_bias=True,
    )

    # Pair bias: normalize the pair representation, then project c_pair -> no_heads channels
    # (one bias value per head for every token pair).
    self.proj_pair_bias = nn.Sequential(
        LayerNorm(self.c_pair),
        LinearNoBias(self.c_pair, self.num_heads, init='normal')
    )
def _prep_biases(
    self,
    single_repr: torch.Tensor,  # (*, S, N, c_s)
    pair_repr: torch.Tensor,  # (*, N, N, c_z)
    mask: Optional[torch.Tensor] = None,  # (*, N)
):
    """Prepare the mask bias and per-head pair bias in the layouts expected by the DS4Science attention.

    Expected shapes for the DS4Science kernel:
        # Q, K, V:   [Batch, N_seq, N_res, Head, Dim]
        # res_mask:  [Batch, N_seq, 1, 1, N_res]
        # pair_bias: [Batch, 1, Head, N_res, N_res]
    "Batch" may be any number of leading dimensions.

    Args:
        single_repr:
            Single representation of shape (*, N_seq, N_res, c_s). Only its shape, dtype and
            device are used here.
        pair_repr:
            Pair representation of shape (*, N_res, N_res, c_pair).
        mask:
            Optional token mask of shape (*, N_res). Entries equal to 1 get a zero bias and
            entries equal to 0 get -self.inf. When None, every position gets a zero bias.

    Returns:
        A tuple (mask_bias, pair_bias):
            mask_bias: (*, N_seq, 1, 1, N_res), in single_repr's dtype.
            pair_bias: (*, 1, no_heads, N_res, N_res), projected from pair_repr by
                self.proj_pair_bias.
    """
    n_seq, n_res, _ = single_repr.shape[-3:]
    if mask is None:
        # No mask given: keep every position. Shape (*, N_seq, N_res).
        mask = single_repr.new_ones(single_repr.shape[:-3] + (n_seq, n_res))
    else:
        # Broadcast the per-token mask over N_seq (samples per trunk):
        # (*, N_res) -> (*, 1, N_res) -> (*, N_seq, N_res).
        mask = mask.unsqueeze(-2).expand(mask.shape[:-1] + (n_seq, n_res))
    mask = mask.to(single_repr.dtype)

    # 0 where mask == 1, -self.inf where mask == 0; shape (*, N_seq, 1, 1, N_res).
    mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

    # Project pair biases per head from the pair representation: (*, N_res, N_res, H).
    pair_bias = self.proj_pair_bias(pair_repr)
    # Move the head axis in front of the two token axes: (*, H, N_res, N_res).
    # movedim works for any number of leading batch dims, unlike a fixed 'b i j h' pattern.
    pair_bias = torch.movedim(pair_bias, -1, -3)
    # Singleton N_seq axis so the bias broadcasts over samples: (*, 1, H, N_res, N_res).
    pair_bias = pair_bias.unsqueeze(-4)
    return mask_bias, pair_bias
def forward(
self,
singl