当前位置: 首页 > article >正文

预训练蛋白质语言模型ESM-2保姆级使用教程

  • ESM-2(Evolutionary Scale Modeling 2)当下最先进的预训练蛋白质语言模型之一, 由Facebook AI
    Research开发,最新版本使用了48层Transformer编码器架构,有150亿参数。

  • 蛋白质语言模型(PLM)可以用于理解和预测蛋白质序列的特性,包括它们的结构、功能等。

以下是使用ESM-2的详细教程:


# pip install esm
# pip install torch
# pip安装的esm不包括模型,size is small,but torch is big, more than 1G.
import torch
import esm
import os

我们先指定一个新的目录路径,然后写入环境变量字典,设置TORCH_HOME环境变量

  • 为了方便管理模型,我们可以设置TORCH_HOME环境变量,将模型下载到我们指定的目录;

  • 在运行esm.pretrained.esm2_t33_650M_UR50D()时,PyTorch会检查这个目录,如果模型esm2_t33_650M_UR50D已经存在,它会从那里加载模型,否则它会从网上下载模型并保存在这个目录下。

  • 例如:当环境变量设置为’D:\Desktop\model’时,模型的下载地址为’D:\Desktop\model\hub\checkpoints\esm2_t33_650M_UR50D.pt’

  • 设置 TORCH_HOME 环境变量后,所有 PyTorch 相关的库(比如 torch.hub 或 transformers)在下载模型和数据集时,都会使用这个目录作为下载位置。

  • 注:
    PyTorch和ESM都是Facebook的产品;
    os.environ 返回一个代表当前环境变量的字典对象。

new_dir = 'D:\Desktop\model'
os.environ['TORCH_HOME'] = new_dir

下载模型到我们上面指定的目录,或者从指定的目录加载模型;alphabet代表模型使用的字母表,它定义了模型能够处理的字符集合。

  • Size of esm2_t33_650M_UR50D is very big, about 2.4G, 这个模型使用了33层Transformer编码器架构,有650百万(6.5亿参数),使用UniRef50作为训练集;

  • 关于UniRef100、UniRef90和UniRef50的知识,请参考:https://pubmed.ncbi.nlm.nih.gov/17379688/

  • esm包含好几个版本的蛋白质语言预训练模型,可以通过esm.pretrained.xxx指定使用不同版本:esm.pretrained.esm2_t36_3B_UR50D…

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

构建数据格式转换器

batch_converter = alphabet.get_batch_converter()

将模型设置为评估模式,这会关闭dropout等训练特有的行为

  • 在神经网络中,“dropout” 是一种正则化技术,用于防止或减少模型的过拟合。Dropout通过在训练过程中随机"丢弃"(即暂时移除)网络中的一些神经元(包括它们所有的连接),来减少神经元之间复杂的共适应关系,从而促进模型的泛化能力。

  • 在训练神经网络时启用dropout,以减少过拟合。

  • 在模型评估或预测时禁用dropout,确保所有神经元都参与工作。

model.eval()

demo数据

data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE")
]

使用数据格式转换器将数据转换为模型可以理解/处理的tokens,并将这些tokens填充为相同的长度,然后批量计算每个序列的长度

  • batch_labels:存储每个序列的标签或ID,可能用于后续的监督学习任务;

  • batch_strs:存储原始的蛋白质序列字符串,可能用于调试或显示目的;

  • batch_tokens:是一个二维Tensor张量,一行代表一个蛋白质序列,存储转换后的tokens,这些tokens是原始序列中氨基酸的整数索引(0-20,20种氨基酸);

  • alphabet.padding_idx是一个填充矩阵;

  • batch_tokens != alphabet.padding_idx 生成一个布尔矩阵,实际有氨基酸的位置是True,填充的位置是False;

  • .sum(1) 对这个布尔矩阵沿着维度1求和,即生成每个序列的长度;

  • 这样做的目的是批量求每个序列的长度。

batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

提取蛋白质序列中每个氨基酸残基的 token_representations,获取模型对输入序列的深层次理解

  • with torch.no_grad(): 禁用PyTorch中的梯度计算。在模型评估或预测阶段,我们不需要进行反向传播,禁用梯度计算可以减少内存消耗并加速计算。使用with表明该操作是暂时的。

  • repr_layers=[33] 表示使用Transformer架构第33层的输出作为特征representation;

  • return_contacts=True 表示获取模型预测的氨基酸残基之间的接触图,这在蛋白质结构预测中是一个有用的特征;

  • results[“representations”][33] 表示从模型输出results中提取第33层的representations;

  • results是一个字典,"representations"是这个字典的一个键(key),该键对应的值(value)也是一个字典,存放着每一层的token_representations。字典中嵌套字典。

with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

生成每一个序列的 representations

  • token 0是序列开始标记, 所以第一个氨基酸残基是token 1;最后一个token是序列结束标记;
  • token_representations 是一个三维Tensor张量(可以理解为三维数组);
  • token_representations[i, 1:tokens_len-1] 使用切片取出每条蛋白质序列的token_representations(二维,根据氨基酸数量取行,列全取);
  • .mean(0) 使用每条蛋白质序列的二维token_representations的列平均值作为该蛋白质序列的representations(sequence_representations)
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):  # 例如,氨基酸token为65个,则batch_lens=67
    sequence_representations.append(token_representations[i, 1:tokens_len-1].mean(0))

生成使用无监督学习方法预测的蛋白质内部残基间的接触图

import matplotlib.pyplot as plt
for (ID, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):
    plt.matshow(attention_contacts[: tokens_len, : tokens_len])
    plt.title(ID)
    path = os.path.join(os.getcwd(), ID)
    plt.savefig(path, bbox_inches="tight")

参考:https://github.com/facebookresearch/esm


http://www.kler.cn/news/308883.html

相关文章:

  • C++设计模式(更新中)
  • 数据结构:(OJ141)环形列表
  • 李宏毅2023机器学习HW15-Few-shot Classification
  • 部分动态铜皮的孤岛无法删除。报错
  • Linux下的CAN通讯
  • 深度学习中实验、观察与思考的方法与技巧
  • JavaScript:驱动现代Web应用的关键引擎及其与HTML/CSS的集成
  • 数模原理精解【11】
  • el-table 如何实现行列转置?
  • C#读取应用配置的简单类
  • 软件测试工程师面试整理-常见面试问题
  • 后端Controller获取成功,但是前端报错404
  • etcd入门指南:分布式事务、分布式锁及核心API详解
  • 企业开发时,会使用sqlalchedmy来构建数据库 结构吗? 还是说直接写SQL 语句比较多?
  • 断电重启之后服务器都有哪些服务需要重启
  • 828华为云征文|docker部署kafka及ui搭建
  • VRRP 笔记
  • 认知小文3《打破桎梏,编程与人生的基本法则》
  • 抓机遇,创发展︱2025 第十二届广州国际汽车零部件加工技术及汽车模具展览会,零部件国产浪潮不可阻挡
  • Pillow:Python图像处理库详解
  • 计算机网络(网络层)
  • 系统架构设计师:系统质量属性与架构评估
  • 固态硬盘:量产、开卡、ROM短接是指什么?
  • 34.贪心算法1
  • 2024最新股票系统源码 附教程
  • Track 08:AIML
  • CTFHub技能树-信息泄露-HG泄漏
  • 医学数据分析实训 项目二 数据预处理作业
  • 在 React 中掌握 useImperativeHandle(使用 TypeScript)
  • visual prompt tuning和visual instruction tuning