当前位置: 首页 > article >正文

haiku实现TemplatePairStack类

TemplatePairStack是实现蛋白质结构模版pair_act特征表示的类:
通过layer_stack.layer_stack(c.num_block)(block) 堆叠c.num_block(配置文件中为2)block 函数,每个block对输入pair_act 和 pair_mask执行计算流程:TriangleAttention —> dropout ->TriangleAttention —> dropout -> TriangleMultiplication —> dropout -> TriangleMultiplication —> dropout -> Transition

import haiku as hk


class TemplatePairStack(hk.Module):
  """Pair stack for the templates.

  Jumper et al. (2021) Suppl. Alg. 16 "TemplatePairStack"
  """

  def __init__(self, config, global_config, name='template_pair_stack'):
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, pair_act, pair_mask, is_training, safe_key=None):
    """Builds TemplatePairStack module.

    Arguments:
      pair_act: Pair activations for single template, shape [N_res, N_res, c_t].
      pair_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode.
      safe_key: Safe key object encapsulating the random number generation key.

    Returns:
      Updated pair_act, shape [N_res, N_res, c_t].
    """

    if safe_key is None:
      safe_key = prng.SafeKey(hk.next_rng_key())

    gc = self.global_config
    c = self.config

    if not c.num_block:
      return pair_act

    def block(x):
      """One block of the template pair stack."""
      pair_act, safe_key = x

      dropout_wrapper_fn = functools.partial(
          dropout_wrapper, is_training=is_training, global_config=gc)

      safe_key, *sub_keys = safe_key.split(6)
      sub_keys = iter(sub_keys)

      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_starting_node, gc,
                            name='triangle_attention_starting_node'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleAttention(c.triangle_attention_ending_node, gc,
                            name='triangle_attention_ending_node'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_outgoing, gc,
                                 name='triangle_multiplication_outgoing'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          TriangleMultiplication(c.triangle_multiplication_incoming, gc,
                                 name='triangle_multiplication_incoming'),
          pair_act,
          pair_mask,
          next(sub_keys))
      pair_act = dropout_wrapper_fn(
          Transition(c.pair_transition, gc, name='pair_transition'),
          pair_act,
          pair_mask,
          next(sub_keys))

      return pair_act, safe_key

    if gc.use_remat:
      block = hk.remat(block)

    res_stack = layer_stack.layer_stack(c.num_block)(block)
    pair_act, safe_key = res_stack((pair_act, safe_key))
    return pair_act


http://www.kler.cn/a/229140.html

相关文章:

  • 计算机组成原理(计算机系统3)--实验二:MIPS64乘法实现实验
  • Ansible自动化运维:基础与实践
  • 探索 Transformer²:大语言模型自适应的新突破
  • 【Vim Masterclass 笔记14】S07L29 + L30:练习课08 —— Vim 文本对象同步练习(含点评课内容)
  • 微信小程序订阅消息提醒-云函数
  • 编译pytorch——cuda-toolkit-nvcc
  • 【Linux】Linux权限(下)
  • 倒模专用制作耳机壳UV树脂:改性丙烯酸树脂
  • 保研机试算法训练个人记录笔记(三)
  • 代码随想录day18--二叉树的应用6
  • 速盾:cdn技术和原理是什么关系
  • transformers重要组件(模型与分词器)
  • 信息打点Day9
  • 02-Java抽象工厂模式 ( Abstract Factory Pattern )
  • 框架学习Maven
  • RabbitMQ-2.SpringAMQP
  • 算法专题:滑动窗口
  • 对于协同过滤算法我自己的一些总结和看法
  • 网易和腾讯面试题精选---性能和优化面试问题
  • Compose | UI组件(十三) | Navigation - 页面导航
  • thinkphp6入门(18)-- 中间件中除了handle函数,还可以有其它函数吗
  • 【LeetCode每日一题】2381. 字母移位 II2406. 将区间分为最少组数 (差分数组)
  • 如何在Windows系统上部署docker
  • PyTorch的 torch.unsqueeze() 和 torch.squeeze()方法详解
  • 黑豹程序员-封装组件-Vue3 setup方式子组件传值给父组件
  • 地下停车场智慧监查系统:科技让停车更智能