RFdiffusion get_torsions函数解读
函数功能
get_torsions
函数根据输入的原子坐标(xyz_in
)和氨基酸序列(seq
),计算一组主链和侧链的扭转角(torsions
)。同时生成备用扭转角(torsions_alt
),用于表示可以镜像翻转的几何结构,并返回掩码(tors_mask)和是否平面化(tors_planar)的信息。
输入参数
xyz_in
: 原子坐标张量,形状[B, L, 27, 3]
,其中B
是批量大小,L
是残基数量,27 表示原子类型(如 N, CA, C, CB),3
是坐标。seq
: 氨基酸序列,形状[B, L]
,每个值对应氨基酸编号。torsion_indices
: 表示计算侧链扭转角所需的 4 个原子索引。torsion_can_flip
: 布尔数组,指示哪些扭转角可以翻转。ref_angles
: 理想化的参考角度,主要用于 CB 弯曲、CB 扭转和 CG 弯曲。mask_in
: 掩码(可选),用于屏蔽特定残基。
返回值
torsions
: 主链和侧链的扭转角张量,形状[B, L, 10, 2]
。torsions_alt
: 扭转角备用版本(翻转版),形状同上。tors_mask
: 扭转角的有效掩码。tors_planar
: 布尔掩码,指示哪些扭转角是平面的。
源代码:
def get_torsions(
xyz_in, seq, torsion_indices, torsion_can_flip, ref_angles, mask_in=None
):
B, L = xyz_in.shape[:2]
tors_mask = get_tor_mask(seq, torsion_indices, mask_in)
# torsions to restrain to 0 or 180degree
tors_planar = torch.zeros((B, L, 10), dtype=torch.bool, device=xyz_in.device)
tors_planar[:, :, 5] = seq == aa2num["TYR"] # TYR chi 3 should be planar
# idealize given xyz coordinates before computing torsion angles
xyz = xyz_in.clone()
Rs, Ts = rigid_from_3_points(xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :])
Nideal = torch.tensor([-0.5272, 1.3593, 0.000], device=xyz_in.device)
Cideal = torch.tensor([1.5233, 0.000, 0.000], device=xyz_in.device)
xyz[..., 0, :] = torch.einsum("brij,j->bri", Rs, Nideal) + Ts
xyz[..., 2, :] = torch.einsum("brij,j->bri", Rs, Cideal) + Ts
torsions = torch.zeros((B, L, 10, 2), device=xyz.device)
# avoid undefined angles for H generation
torsions[:, 0, 1, 0] = 1.0
torsions[:, -1, 0, 0] = 1.0
# omega
torsions[:, :-1, 0, :] = th_dih(
xyz[:, :-1, 1, :], xyz[:, :-1, 2, :], xyz[:, 1:, 0, :], xyz[:, 1:, 1, :]
)
# phi
torsions[:, 1:, 1, :] = th_dih(
xyz[:, :-1, 2, :], xyz[:, 1:, 0, :], xyz[:, 1:, 1, :], xyz[:, 1:, 2, :]
)
# psi
torsions[:, :, 2, :] = -1 * th_dih(
xyz[:, :, 0, :], xyz[:, :, 1, :], xyz[:, :, 2, :], xyz[:, :, 3, :]
)
# chis
ti0 = torch.gather(xyz, 2, torsion_indices[seq, :, 0, None].repeat(1, 1, 1, 3))
ti1 = torch.gather(xyz, 2, torsion_indices[seq, :, 1, None].repeat(1, 1, 1, 3))
ti2 = torch.gather(xyz, 2, torsion_indices[seq, :, 2, None].repeat(1, 1, 1, 3))
ti3 = torch.gather(xyz, 2, torsion_indices[seq, :, 3, None].repeat(1, 1, 1, 3))
torsions[:, :, 3:7, :] = th_dih(ti0, ti1, ti2, ti3)
# CB bend
NC = 0.5 * (xyz[:, :, 0, :3] + xyz[:, :, 2, :3])
CA = xyz[:, :, 1, :3]
CB = xyz[:, :, 4, :3]
t = th_ang_v(CB - CA, NC - CA)
t0 = ref_angles[seq][..., 0, :]
torsions[:, :, 7, :] = torch.stack(
(torch.sum(t * t0, dim=-1), t[..., 0] * t0[..., 1] - t