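A Guide to Deploying the ESMC-600M Protein Language Model Locally
Preface
A previous post covered calling the ESMC-6B model through its web API, but getting a token approved is slow. A reader asked whether I could write a guide to deploying a small ESMC model locally, hence this post.
Local deployment is actually not complicated, and the official GitHub repository already explains it fairly clearly.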
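Procedure
Environment: CUDA 12.1, torch 2.2.1+cu121, esm 3.1.1
Full package list (this environment was also used for other tasks, so not everything here is needed; install the three packages above first, then add whichever libraries turn out to be missing):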
Package Version
------------------------ ------------
asttokens 3.0.0
attrs 24.3.0
biopython 1.84
biotite 0.41.2
Brotli 1.1.0
certifi 2024.12.14
charset-normalizer 3.4.0
cloudpathlib 0.20.0
decorator 5.1.1
einops 0.8.0
esm 3.1.1
executing 2.1.0
filelock 3.13.1
fsspec 2024.2.0
huggingface-hub 0.27.0
idna 3.10
ipython 8.30.0
jedi 0.19.2
Jinja2 3.1.3
joblib 1.4.2
MarkupSafe 2.1.5
matplotlib-inline 0.1.7
mpmath 1.3.0
msgpack 1.1.0
msgpack-numpy 0.4.8
networkx 3.2.1
numpy 1.26.3
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.1.105
nvidia-nvtx-cu12 12.1.105
packaging 24.2
pandas 2.2.3
parso 0.8.4
pexpect 4.9.0
pillow 10.2.0
pip 24.2
prompt_toolkit 3.0.48
ptyprocess 0.7.0
pure_eval 0.2.3
Pygments 2.18.0
python-dateutil 2.9.0.post0
pytz 2024.2
PyYAML 6.0.2
regex 2024.11.6
requests 2.32.3
safetensors 0.4.5
scikit-learn 1.6.0
scipy 1.14.1
setuptools 75.1.0
six 1.17.0
stack-data 0.6.3
sympy 1.13.1
tenacity 9.0.0
threadpoolctl 3.5.0
tokenizers 0.20.3
torch 2.2.1+cu121
torchdata 0.7.1
torchtext 0.17.1
torchvision 0.17.1+cu121
tqdm 4.67.1
traitlets 5.14.3
transformers 4.46.3
triton 2.2.0
typing_extensions 4.9.0
tzdata 2024.2
urllib3 2.2.3
wcwidth 0.2.13
wheel 0.44.0
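Before going further, it can help to confirm that the key packages match the versions listed above. The following is a minimal sketch using only the standard library; the helper names `installed_version` and `compare_versions` are made up for this illustration, and the expected versions are simply the ones from this post.

```python
from importlib import metadata

# Versions taken from the environment described in this post.
EXPECTED = {"torch": "2.2.1+cu121", "esm": "3.1.1"}


def installed_version(package):
    """Return the installed version of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def compare_versions(expected, lookup=installed_version):
    """Return {package: (expected, installed)} for every package that differs."""
    mismatches = {}
    for package, wanted in expected.items():
        found = lookup(package)
        if found != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, found) in compare_versions(EXPECTED).items():
        print(f"{package}: expected {wanted}, found {found or 'not installed'}")
```

A different version is not necessarily a problem; this only reports where the local environment differs from the one used here.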
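Download the ESMC-600M weights:
EvolutionaryScale/esmc-600m-2024-12 (Hugging Face repository, main branch)
After downloading, place the weights in this directory under the working directory: data/weights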
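A quick way to check that the weights ended up in the right place is to list the checkpoint files under data/weights. This is only a sketch: the function name `find_weight_files` is hypothetical, and the `.pth` suffix is an assumption about the checkpoint format, so adjust it if the downloaded file differs.

```python
from pathlib import Path


def find_weight_files(workdir=".", suffix=".pth"):
    """List checkpoint files under <workdir>/data/weights, sorted by name.

    The ".pth" default is an assumption about the checkpoint file format.
    """
    weights_dir = Path(workdir) / "data" / "weights"
    if not weights_dir.is_dir():
        return []
    return sorted(
        p for p in weights_dir.iterdir() if p.is_file() and p.suffix == suffix
    )


if __name__ == "__main__":
    files = find_weight_files()
    if files:
        for path in files:
            print("found", path)
    else:
        print("no checkpoint files found under data/weights")
```

Run it from the same working directory that the embedding script will be started from, since the path is relative.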
Code
This is close to the example in the official GitHub repository, with a few modifications.
import os
import pickle

import torch
from esm.models.esmc import ESMC
from esm.sdk.api import ESMProteinTensor, LogitsConfig

# Use the pre-downloaded weights (placed in data/weights)
os.environ["INFRA_PROVIDER"] = "True"
device = torch.device("cuda:0")
client = ESMC.from_pretrained("esmc_600m", device=device)

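# Read the protein sequence; adjust this to match your own data format.
# Assumes a two-line FASTA file: a header line, then the whole sequence on one line.
def read_seq(seqfilepath):
    with open(seqfilepath, "r") as f:
        f.readline()  # skip the header line
        seq = f.readline().strip()  # drop the trailing newline, which has no entry in the token table
    return seq
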
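# This reuses the token encoding reverse-engineered in the previous post;
# it can be replaced with ESM's built-in tokenizer
all_amino_acid_number = {'A':5, 'C':23,'D':13,'E':9, 'F':18,
                         'G':6, 'H':21,'I':12,'K':15,'L':4,
                         'M':20,'N':17,'P':14,'Q':16,'R':10,
                         'S':8, 'T':11,'V':7, 'W':22,'Y':19,
                         '_':32}

def esm_encoder_seq(seq, pad_len):
    # 0 = start token, 2 = end token, 1 = padding
    s = [0] + [all_amino_acid_number[x] for x in seq] + [2]
    # Pad after the end token; a no-op when pad_len == len(seq), as in this script
    s += [1] * (pad_len - len(seq))
    return torch.tensor(s)
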
def get_esm_embedding(seq):
    protein_tensor = ESMProteinTensor(sequence=esm_encoder_seq(seq, len(seq)).to(device))
    logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True, return_embeddings=True))
    esm_embedding = logits_output.embeddings
    assert isinstance(esm_embedding, torch.Tensor)
    return esm_embedding

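# The path here is arbitrary; change it as needed
seq_path = "seq.fasta"
seq = read_seq(seq_path)
print(seq)
# Compute the sequence embedding
emb = get_esm_embedding(seq)
with open("seq_emb.pkl", "wb") as f:
    # Move the tensor to the CPU so the file can be loaded on a machine without a GPU
    pickle.dump(emb.cpu(), f)
print(emb.shape)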
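
The read_seq helper above reads exactly two lines, so it only suits files with one record whose sequence sits on a single line. For FASTA files that wrap sequences across several lines or hold several records, a parser along the following lines could be swapped in. It is a rough sketch, not part of the esm package, and `parse_fasta` is a name chosen for this example.

```python
def parse_fasta(text):
    """Parse FASTA text into a list of (header, sequence) tuples.

    Sequence lines belonging to one record are joined and upper-cased;
    blank lines and any text before the first ">" header are ignored.
    """
    records = []
    header, chunks = None, []
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            # A new header closes the previous record, if there was one.
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        elif header is not None:
            chunks.append(line.upper())
    if header is not None:
        records.append((header, "".join(chunks)))
    return records


if __name__ == "__main__":
    demo = ">p1 demo\nMKT\nAYI\n>p2\nACDE\n"
    for name, sequence in parse_fasta(demo):
        print(name, sequence)
```

Each returned sequence can then be passed to get_esm_embedding one at a time.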
Running the embedding script on an arbitrary sequence produces a tensor of shape [1, sequence length + 2, 1152].
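
The "+ 2" in that shape comes from the start token (0) and end token (2) that esm_encoder_seq wraps around the residues. The sketch below reproduces the token layout with plain Python lists so it can be inspected without torch or a GPU. The IDs are copied from the table in this post; `encode_sequence` is a hypothetical name, and placing padding after the end token is a choice made for this sketch.

```python
# Special-token IDs as used in the script in this post.
CLS_ID, PAD_ID, EOS_ID = 0, 1, 2

# Residue IDs copied from the all_amino_acid_number table above.
AMINO_ACID_IDS = {
    "A": 5, "C": 23, "D": 13, "E": 9, "F": 18,
    "G": 6, "H": 21, "I": 12, "K": 15, "L": 4,
    "M": 20, "N": 17, "P": 14, "Q": 16, "R": 10,
    "S": 8, "T": 11, "V": 7, "W": 22, "Y": 19,
}


def encode_sequence(seq, pad_len=None):
    """Return token IDs: start token, residues, end token, then optional padding.

    pad_len counts residues, so the padded result has pad_len + 2 tokens.
    """
    ids = [CLS_ID] + [AMINO_ACID_IDS[aa] for aa in seq] + [EOS_ID]
    if pad_len is not None:
        ids += [PAD_ID] * max(0, pad_len + 2 - len(ids))
    return ids


if __name__ == "__main__":
    tokens = encode_sequence("MKT")
    print(tokens, len(tokens))
```

For a sequence of length n with no padding, the list has n + 2 entries, which matches the middle dimension of the embedding tensor.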