# Function: featurize
import torch
import numpy as np
import csv
import time
import os
import random
def featurize(batch, device):
B = len(batch)
lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32) #sum of chain seq lengths
L_max = max([len(b['seq']) for b in batch])
X = np.zeros([B, L_max, 4, 3])
residue_idx = -100*np.ones([B, L_max], dtype=np.int32) #residue idx with jumps across chains
chain_M = np.zeros([B, L_max], dtype=np.int32) #1.0 for the bits that need to be predicted, 0.0 for the bits that are given
mask_self = np.ones([B, L_max, L_max], dtype=np.int32) #for interface loss calculation - 0.0 for self interaction, 1.0 for other
chain_encoding_all = np.zeros([B, L_max], dtype=np.int32) #integer encoding for chains 0, 0, 0,...0, 1, 1,..., 1, 2, 2, 2...
S = np.zeros([B, L_max], dtype=np.int32) #sequence AAs integers
init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
extra_alphabet = [str(item) for item in list(np.arange(300))]
chain_letters = init_alphabet + extra_alphabet
for i, b in enumerate(batch):
masked_chains = b['masked_list']
visible_chains = b['visible_list']
all_chains = masked_chains + visible_chains
visible_temp_dict = {}
masked_temp_dict = {}
for step, letter in enumerate(all_chains):
chain_seq = b[f'seq_chain_{letter}']
if letter in visible_chains:
visible_temp_dict[letter] = chain_seq
elif letter in masked_chains:
masked_temp_dict[letter] = chain_seq
for km, vm in masked_temp_dict.items():
for kv, vv in visible_temp_dict.items():
if v