
AF3 AttentionPairBias Class: Source Code Walkthrough

AttentionPairBias is an attention module in AlphaFold3 that implements full self-attention over the single (token) representation, with an additive bias derived from the pair representation. It plays an important role in the AlphaFold3 architecture, especially in tasks involving protein sequences and spatial relationships between tokens.
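
Conceptually, the pair bias enters the attention computation as an extra additive term on the attention logits, alongside the usual padding mask. Below is a minimal, self-contained sketch of that idea; the function name and shapes are illustrative only and are not taken from the AF3 source:

import torch

def pair_biased_attention(q, k, v, pair_bias, mask_bias):
    # q, k, v:   (*, heads, n_tokens, dim_per_head)
    # pair_bias: (*, heads, n_tokens, n_tokens) -- projected per head from the pair representation
    # mask_bias: (*, 1, 1, n_tokens) -- large negative values at padded positions
    logits = torch.einsum('...hid,...hjd->...hij', q, k) / q.shape[-1] ** 0.5
    logits = logits + pair_bias + mask_bias  # additive biases on the attention logits
    weights = torch.softmax(logits, dim=-1)
    return torch.einsum('...hij,...hjd->...hid', weights, v)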

Source code:

import torch
import torch.nn as nn
from typing import Optional

from einops import rearrange

# AdaLN, Linear, LinearNoBias, LayerNorm, and Attention are the repository's own
# primitive layers; their exact import path depends on the codebase layout.


class AttentionPairBias(nn.Module):
    """Full self-attention with pair bias."""

    def __init__(
            self,
            dim: int,
            c_pair: int = 16,
            no_heads: int = 8,
            dropout: float = 0.0,
            input_gating: bool = True,
            residual: bool = True,
            inf: float = 1e8,
    ):
        """Initialize the AttentionPairBias module.
        Args:
            dim:
                Total dimension of the model.
            c_pair:
                The number of channels for the pair representation. Defaults to 16.
            no_heads:
                Number of parallel attention heads. Note that dim will be split across no_heads
                (i.e. each head will have dimension dim // no_heads).
            dropout:
                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
            residual:
                Whether the module is used as a residual block. Default: True. This affects the initialization
                of the final projection layer of the MHA attention.
            input_gating:
                Whether the single representation should be gated with another single-like representation using
                adaptive layer normalization. Default: True.
            inf:
                Large value used to build the additive mask bias for padded positions. Default: 1e8.
        """
        super().__init__()
        self.dim = dim
        self.c_pair = c_pair
        self.num_heads = no_heads
        self.dropout = dropout
        self.input_gating = input_gating
        self.inf = inf

        # Perform check for dimensionality
        assert dim % no_heads == 0, f"the model dimensionality ({dim}) should be divisible by the " \
                                    f"number of heads ({no_heads}) "
        # Projections
        self.input_proj = None
        self.output_proj_linear = None
        if input_gating:
            self.input_proj = AdaLN(dim)

            # Output projection from AdaLN
            self.output_proj_linear = Linear(dim, dim, init='gating')
            self.output_proj_linear.bias = nn.Parameter(torch.ones(dim) * -2.0)  # gate values will be ~0.11
        else:
            self.input_proj = LayerNorm(dim)

        # Attention
        self.attention = Attention(
            c_q=dim,
            c_k=dim,
            c_v=dim,
            c_hidden=dim // no_heads,
            no_heads=no_heads,
            gating=True,
            residual=residual,
            proj_q_w_bias=True,
        )

        # Pair bias
        self.proj_pair_bias = nn.Sequential(
            LayerNorm(self.c_pair),
            LinearNoBias(self.c_pair, self.num_heads, init='normal')
        )

    def _prep_biases(
            self,
            single_repr: torch.Tensor,  # (*, S, N, c_s)
            pair_repr: torch.Tensor,  # (*, N, N, c_z)
            mask: Optional[torch.Tensor] = None,  # (*, N)
    ):
        """Prepares the mask and pair biases in the shapes expected by the DS4Science attention.

        Expected shapes for the DS4Science kernel:
        # Q, K, V: [Batch, N_seq, N_res, Head, Dim]
        # res_mask: [Batch, N_seq, 1, 1, N_res]
        # pair_bias: [Batch, 1, Head, N_res, N_res]
        """
        # Compute the single mask
        n_seq, n_res, _ = single_repr.shape[-3:]
        if mask is None:
            # [*, N_seq, N_res]
            mask = single_repr.new_ones(
                single_repr.shape[:-3] + (n_seq, n_res),
            )
        else:
            # Expand mask by N_seq (or samples per trunk)
            new_shape = (mask.shape[:-1] + (n_seq, n_res))  # (*, N_seq, N_res)
            mask = mask.unsqueeze(-2).expand(new_shape)
            mask = mask.to(single_repr.dtype)
            
        # [*, N_seq, 1, 1, N_res]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # Project pair biases per head from pair representation
        pair_bias = self.proj_pair_bias(pair_repr)  # (bs, n_tokens, n_tokens, n_heads)
        pair_bias = rearrange(pair_bias, 'b i j h -> b h i j')  # (bs, h, n, n)
        pair_bias = pair_bias.unsqueeze(-4)
        return mask_bias, pair_bias

    def forward(
            self,
            single_repr: torch.Tensor,  # (*, S, N, c_s)
            single_proj: Optional[torch.Tensor] = None,  # (*, S, N, c_s)
            pair_repr: Optional[torch.Tensor] = None,  # (*, N, N, c_z)
            mask: Optional[torch.Tensor] = None,  # (*, N)
    ) -> torch.Tensor:
        """Full self-attention over the single representation with pair bias.

        NOTE: the article's listing is truncated at this point; the body below is a
        sketch reconstructed from the module's attributes, _prep_biases, and the
        AttentionPairBias pseudocode in the AF3 supplement, not the verbatim source.
        """
        # Input projection: adaptive LayerNorm conditioned on single_proj, or a plain LayerNorm
        if self.input_gating:
            a = self.input_proj(single_repr, single_proj)
        else:
            a = self.input_proj(single_repr)

        # Mask bias and per-head pair bias in the shapes expected by the attention kernel
        mask_bias, pair_bias = self._prep_biases(single_repr, pair_repr, mask)

        # Full self-attention with both biases added to the logits
        output = self.attention(q_x=a, kv_x=a, biases=[mask_bias, pair_bias])  # (*, S, N, c_s)

        # Adaptive output gate: sigmoid of a projection of single_proj, initialized near zero
        if self.input_gating:
            output = torch.sigmoid(self.output_proj_linear(single_proj)) * output

        return output
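
To make the interface concrete, here is a hypothetical smoke test. It assumes the forward signature sketched above (single_repr, single_proj, pair_repr, mask) and uses arbitrary dimensions; it is an illustration, not code from the AF3 repository:

# Hypothetical usage; dimensions chosen arbitrarily for illustration
batch, n_samples, n_tokens = 1, 2, 32
c_s, c_z = 128, 16

model = AttentionPairBias(dim=c_s, c_pair=c_z, no_heads=8)

single_repr = torch.randn(batch, n_samples, n_tokens, c_s)  # (*, S, N, c_s)
single_proj = torch.randn(batch, n_samples, n_tokens, c_s)  # conditioning signal for AdaLN and the output gate
pair_repr = torch.randn(batch, n_tokens, n_tokens, c_z)     # (*, N, N, c_z)
mask = torch.ones(batch, n_tokens)                          # (*, N), 1 = valid token

out = model(single_repr, single_proj=single_proj, pair_repr=pair_repr, mask=mask)
print(out.shape)  # expected: torch.Size([1, 2, 32, 128])

Note the gate bias of -2.0 set in __init__: sigmoid gate values start near 0.12, so when the module is used as a residual block it initially contributes only a small update to the single representation.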
