Understanding the sample_step Method of the RFdiffusion Sampler Class
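The sample_step method of the Sampler class takes the diffusion model's prediction at timestep t and produces the inputs for timestep t-1: the next 3D structure, the sequence, and other related features. It is one of the core steps of the diffusion sampling process.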
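To show where a step like this sits in the overall sampling procedure, here is a minimal toy sketch of a reverse-diffusion loop. It is not RFdiffusion's code: `toy_sample_step` and `run_reverse_diffusion` are hypothetical names, the "model" simply predicts an all-zero x0, and the update rule (move a fraction 1/t toward the prediction) is an assumption chosen for illustration. Only the keyword-only call style and the `t`, `x_t`, `final_step` argument names mirror the signature quoted in this article.

```python
def toy_sample_step(*, t, x_t, final_step):
    """Hypothetical stand-in for a sample_step-like method.

    Returns (px0, x_t_1): the predicted clean state and the state for step t-1.
    """
    # Toy "model": always predict that the clean structure x0 is all zeros.
    px0 = [0.0 for _ in x_t]
    if t == final_step:
        # On the last step, return the prediction itself.
        return px0, px0
    # Assumed update rule: move a fraction 1/t of the way from x_t toward px0.
    x_t_1 = [x + (p - x) / t for x, p in zip(x_t, px0)]
    return px0, x_t_1


def run_reverse_diffusion(x_T, T, final_step=1):
    """Iterate from timestep T down to final_step, one step per iteration."""
    x_t = x_T
    trajectory = [x_T]
    for t in range(T, final_step - 1, -1):
        _px0, x_t = toy_sample_step(t=t, x_t=x_t, final_step=final_step)
        trajectory.append(x_t)
    return x_t, trajectory


if __name__ == "__main__":
    final, traj = run_reverse_diffusion([1.0, -2.0], T=5)
    print(len(traj), final)
```

The point of the sketch is only the control flow: each call consumes the state at timestep t and returns the state that is fed back in at timestep t-1, which is the role the article attributes to sample_step.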
Source code:
def sample_step(self, *, t, x_t, seq_init, final_step):
    '''Generate the next pose that the model should be supplied at timestep t-1.
    Args:
        t (int): The timestep that has just been predicted
        seq_t (torch.tensor): (L,22) The sequence at the beginning of this timestep
        x_t (torch.tensor): (L,14,3) The residue positions at the beginning of this timestep
        seq_init (torch.tensor): (L,22) The initialized sequence used in updating the sequence.
    Returns:
        px0: (L,14,3) The model's prediction of x0.
        x_t_1: (L,14,3) The updated positions of the next step.
        seq_t_1: (L,22) The updated sequence of the next step.
        tors_t_1: (L, ?) The updated torsion angles of the next step.
        plddt: (L, 1) Predicted lDDT of x0.
    '''
    msa_masked, msa_full, seq_in, xt_in, idx_pdb, t1d, t2d, xyz_t, alpha_t = self._preprocess(
        seq_init, x_t, t)
    N,L = msa_masked.shape[:2]
    if self.symmetry is not None:
        idx_pdb, self.chain_idx = self.symmetry.res_idx_procesing(res_idx=idx_pdb)
    msa_prev = None
    pair_prev = None
    state_prev = None
    with torch.no_grad():
        msa_prev, pair_prev, px0, state_prev, alpha, logits, plddt = self.model(msa_masked,
                            msa_full,
                            seq_in,
                            xt_in,