AF3: A walkthrough of the make_fixed_size function
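In AlphaFold3's data_transforms module, the make_fixed_size function pads every feature tensor in the input protein feature dictionary protein to a fixed size. The target size is chosen per dimension: each dimension the shape schema marks as residues, MSA sequences, extra MSA sequences or templates is padded to its configured size. This gives all feature tensors consistent shapes, so they can be batched without shape-mismatch errors.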
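As a minimal sketch of the problem being solved, consider two proteins whose per-residue features have different lengths. The shapes and the target length of 8 are made up for illustration. torch.stack would reject the mismatched tensors, but after zero-padding both to a common residue count they batch cleanly:

```python
import torch
import torch.nn.functional as F

# Two hypothetical per-residue features: 5 and 7 residues, 3 channels each.
a = torch.ones(5, 3)
b = torch.ones(7, 3)

num_res = 8  # illustrative fixed residue count


def pad_residues(x: torch.Tensor, target: int) -> torch.Tensor:
    # F.pad takes (before, after) pairs starting from the LAST dimension:
    # (0, 0) leaves the channel dim alone, (0, target - n) pads dim 0 at the end.
    return F.pad(x, (0, 0, 0, target - x.shape[0]))


batch = torch.stack([pad_residues(a, num_res), pad_residues(b, num_res)])
print(batch.shape)  # torch.Size([2, 8, 3])
```

The padded positions are filled with zeros, which is why such pipelines usually carry a mask (for example a per-residue sequence mask) alongside the features.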
Source code:
```python
import itertools

import torch

from src.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ

# Note: curry1 is neither imported nor defined in this excerpt; it is assumed
# to be defined elsewhere in the same data_transforms module.


@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Guess at the MSA and sequence dimension to make fixed size."""
    # Map each symbolic dimension name to its target size.
    pad_size_map = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for k, v in protein.items():
        # Don't transfer this to the accelerator.
        if k == "extra_cluster_assignment":
            continue
        shape = list(v.shape)
        schema = shape_schema[k]
        msg = "Rank mismatch between shape and shape schema for"
        assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
        # Target size per dim: the mapped size if the schema entry is a known
        # placeholder with a nonzero size, otherwise the tensor's current size.
        pad_size = [
            pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
        ]
        # One (before, after) pair per dim; padding is only added at the end.
        padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
        # F.pad expects the pairs starting from the last dim, flattened.
        padding.reverse()
        padding = list(itertools.chain(*padding))
        if padding:
            protein[k] = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(protein[k], pad_size)

    return protein
```
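To see the function in action, here is a self-contained sketch. The source does not show the real values in src.config, so it assumes the dimension constants are simple placeholder strings. It also drops @curry1 so the function can be called directly with protein as its first argument. The feature names, shapes and schema entries are toy data chosen for illustration.

```python
import itertools

import torch

# Hypothetical stand-ins for the constants imported from src.config;
# any distinct hashable values work as dimension placeholders.
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"


def make_fixed_size(protein, shape_schema, msa_cluster_size, extra_msa_size,
                    num_res=0, num_templates=0):
    # Same logic as the function above, minus the @curry1 decorator.
    pad_size_map = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }
    for k, v in protein.items():
        if k == "extra_cluster_assignment":
            continue
        shape = list(v.shape)
        schema = shape_schema[k]
        assert len(shape) == len(schema), f"Rank mismatch for {k}"
        pad_size = [pad_size_map.get(s2, None) or s1
                    for s1, s2 in zip(shape, schema)]
        padding = [(0, p - shape[i]) for i, p in enumerate(pad_size)]
        padding.reverse()
        padding = list(itertools.chain(*padding))
        if padding:
            protein[k] = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(protein[k], pad_size)
    return protein


# Toy features: 10 residues, 4 MSA rows. A None schema entry keeps that dim as is.
protein = {
    "aatype": torch.zeros(10, dtype=torch.long),
    "msa": torch.ones(4, 10),
    "all_atom_positions": torch.zeros(10, 37, 3),
    "extra_cluster_assignment": torch.zeros(5),  # skipped, so no schema entry needed
}
shape_schema = {
    "aatype": [NUM_RES],
    "msa": [NUM_MSA_SEQ, NUM_RES],
    "all_atom_positions": [NUM_RES, None, None],
}

out = make_fixed_size(protein, shape_schema, 16, 32, num_res=12, num_templates=4)
for k, v in out.items():
    print(k, tuple(v.shape))
# aatype (12,)
# msa (16, 12)
# all_atom_positions (12, 37, 3)
# extra_cluster_assignment (5,)
```

The dictionary is modified in place and also returned. With the default num_res=0, the residue dimension would not be padded at all, because `0 or s1` falls back to the tensor's current size.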
Code walkthrough:
Function signature
```python
@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
```
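The @curry1 decorator on this signature lets a transform be configured first and applied to the data later. Its definition is not part of the excerpt. The sketch below assumes it binds every argument except the first and returns a function of that first argument, which is how OpenFold's data_transforms.py defines a helper of this name. Whether this AF3 codebase uses the identical definition is an assumption. add_offset is a hypothetical toy transform used only for illustration.

```python
from typing import Any, Callable


def curry1(f: Callable[..., Any]) -> Callable[..., Callable[[Any], Any]]:
    """Supply all arguments but the first; return a function of the first."""
    def fc(*args: Any, **kwargs: Any) -> Callable[[Any], Any]:
        # The returned lambda receives the first argument (e.g. `protein`) later.
        return lambda x: f(x, *args, **kwargs)
    return fc


@curry1
def add_offset(features: dict, key: str, offset: int = 0) -> dict:
    # Hypothetical toy transform with the same calling convention as make_fixed_size.
    features[key] = features[key] + offset
    return features


# Configure the transform first, apply it to the data later.
transform = add_offset("x", offset=10)
print(transform({"x": 1}))  # {'x': 11}

# Curried transforms compose naturally into a pipeline of one-argument functions.
pipeline = [add_offset("x", offset=1), add_offset("x", offset=2)]
data = {"x": 0}
for t in pipeline:
    data = t(data)
print(data)  # {'x': 3}
```

Under this definition, make_fixed_size(shape_schema, msa_cluster_size, extra_msa_size, ...) does not run the padding. It returns a function that takes protein, which is the shape a list of data transforms expects.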
@curry1: This is a decorator that partially binds the function's arguments. curry1 means that the function's first argument (protein)