AF3 gdt function walkthrough
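The AlphaFold3 functions gdt, gdt_ts, and gdt_ha compute the Global Distance Test (GDT) score, which measures how accurate a predicted protein structure is. The score quantifies the similarity between the predicted structure (p1) and the true structure (p2): for each distance cutoff it takes the fraction of valid residues whose predicted position lies within that cutoff of the true position, then averages those fractions over all cutoffs. The code compares coordinates residue by residue and performs no superposition itself, so p1 and p2 presumably need to be aligned beforehand.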
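To make the averaging concrete, here is a minimal plain-Python sketch of the same idea, without PyTorch, batching, or a mask. The helper name gdt_fraction and the toy coordinates are hypothetical and chosen only for illustration; the cutoffs are the 1, 2, 4, and 8 Å values that gdt_ts passes to gdt.

```python
import math

def gdt_fraction(pred, true, cutoffs):
    """Average, over cutoffs, of the fraction of points within each cutoff."""
    # Euclidean distance between each pair of corresponding points
    dists = [math.dist(a, b) for a, b in zip(pred, true)]
    # Fraction of points at or below each cutoff
    per_cutoff = [sum(d <= c for d in dists) / len(dists) for c in cutoffs]
    # Mean over all cutoffs
    return sum(per_cutoff) / len(per_cutoff)

# Four residues displaced by 0.5, 1.5, 3.0 and 9.0 along the x axis
true = [(0.0, 0.0, 0.0)] * 4
pred = [(0.5, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0), (9.0, 0.0, 0.0)]

score = gdt_fraction(pred, true, [1.0, 2.0, 4.0, 8.0])
print(score)  # fractions 0.25, 0.5, 0.75, 0.75 average to 0.5625
```

The residue displaced by 9.0 falls outside every cutoff, so it lowers all four fractions; this is why a single badly placed region caps the score without driving it to zero.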
Source code:
import torch


def gdt(p1, p2, mask, cutoffs):
    """
    Calculate the Global Distance Test (GDT) score for protein structures.
    Args:
        p1 (torch.Tensor): Coordinates of the first structure [..., N, 3].
        p2 (torch.Tensor): Coordinates of the second structure [..., N, 3].
        mask (torch.Tensor): Mask for valid residues [..., N].
        cutoffs (list): List of distance cutoffs for GDT calculation.
    Returns:
        torch.Tensor: GDT score [...].
    """
    # Ensure inputs are float
    p1 = p1.float()
    p2 = p2.float()
    mask = mask.float()
    # Calculate number of valid residues per batch
    n = torch.sum(mask, dim=-1)
    # Calculate per-residue distances between corresponding coordinates
    distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))
    scores = []
    for c in cutoffs:
        # Calculate score for each cutoff, accounting for the mask
        # (1e-8 guards against division by zero when no residue is valid)
        score = torch.sum((distances <= c).float() * mask, dim=-1) / (n + 1e-8)
        scores.append(score)
    # Stack scores and average across cutoffs
    scores = torch.stack(scores, dim=-1)
    return torch.mean(scores, dim=-1)


def gdt_ts(p1, p2, mask):
    """
    Calculate the Global Distance Test Total Score (GDT_TS).
    Args:
        p1 (torch.Tensor): Coordinates of the first structure [..., N, 3].
        p2 (torch.Tensor): Coordinates of the second structure [..., N, 3].
        mask (torch.Tensor): Mask for valid residues [..., N].
    Returns:
        torch.Tensor: GDT_TS score [...].
    """
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    """
    Calculate the Global Distance Test High Accuracy (GDT_HA) score.
    Args:
        p1 (torch.Tensor): Coo