AF3 AtomAttentionDecoder Class: Source Code Walkthrough
AlphaFold3's AtomAttentionDecoder class
is designed to broadcast per-token representations back out to per-atom representations, while modeling atoms and their pairwise relationships through a cross-attention mechanism. This design captures the fine-grained atom-level interactions needed in biomolecular modeling.
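To make the broadcast step concrete, here is a minimal sketch of expanding per-token activations to per-atom activations using a per-atom token index. The tensor names (token_repr, tok_idx) and shapes are illustrative assumptions, not the repository's actual variables:

import torch

# Illustrative shapes: batch=1, n_tokens=4, n_atoms=10, c_token=8
token_repr = torch.randn(1, 4, 8)  # per-token activations [bs, n_tokens, c_token]
tok_idx = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 2, 3, 3]])  # token index of each atom [bs, n_atoms]

# Broadcast: every atom receives the activation of the token it belongs to
atom_repr = torch.gather(
    token_repr, 1,
    tok_idx.unsqueeze(-1).expand(-1, -1, token_repr.shape[-1])
)  # [bs, n_atoms, c_token]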
Source code:
import torch.nn as nn


class AtomAttentionDecoder(nn.Module):
    """AtomAttentionDecoder that broadcasts per-token activations to per-atom activations."""

    def __init__(
        self,
        c_token: int,
        c_atom: int = 128,
        c_atompair: int = 16,
        no_blocks: int = 3,
        no_heads: int = 8,
        dropout=0.0,
        n_queries: int = 32,
        n_keys: int = 128,
        clear_cache_between_blocks: bool = False
    ):
        """Initialize the AtomAttentionDecoder module.

        Args:
            c_token:
                The number of channels for the token representation.
            c_atom:
                The number of channels for the atom representation. Defaults to 128.
            c_atompair:
                The number of channels for the atom pair representation. Defaults to 16.
            no_blocks:
                Number of blocks.
            no_heads:
                Number of parallel attention heads. Note that c_atom will be split across no_heads
                (i.e. each head will have dimension c_atom // no_heads).
            dropout:
                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
            n_queries:
                The size of the atom window. Defaults to 32.
            n_keys:
                Number of atoms each atom attends to in local sequence space. Defaults to 128.
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation.
        """
        super().__init__()
        self.c_token = c_token
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.num_blocks = no_blocks
        self.num_heads = no_heads
        self.dropout = dropout
        self.n_queries = n_queries
        self.n_keys = n_keys
        self.clear_cache_between_blocks = clear_cache_between_blocks
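A hedged construction sketch follows, using only the constructor shown above; c_token=384 is an illustrative value, not a prescribed AlphaFold3 configuration, and the forward pass (which applies the windowed attention over n_queries and n_keys) is not part of this excerpt:

# Minimal usage sketch, assuming the class as defined above
decoder = AtomAttentionDecoder(
    c_token=384,   # channels of the incoming per-token representation (illustrative)
    c_atom=128,    # channels of the per-atom representation
    c_atompair=16, # channels of the atom-pair representation
    no_blocks=3,   # number of attention blocks
    no_heads=8,    # each head gets c_atom // no_heads = 16 channels
)
print(decoder.num_blocks, decoder.num_heads)  # 3 8

Note the default n_queries=32 and n_keys=128: each atom attends only within a local window rather than over all atoms, which keeps the cost of atom-level attention manageable for large biomolecules.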