AF3 block_diag函数解读
AlphaFold3 msa_pairing 模块的 block_diag 函数的作用是创建一个块对角矩阵(block diagonal matrix),即将多个 np.ndarray
沿主对角线排列,并用 pad_value
填充非对角区域。它的行为类似于 scipy.linalg.block_diag
,但额外提供了非对角区域填充值的功能。
源代码:
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
源码解读:
1. 参数
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
*arrs
: 变长参数,表示多个输入的 NumPy 数组(矩阵)。pad_value
: 用于填充非对角元素的值,默认是0.0
。
2. 生成单位矩阵(相同形状,全 1)
ones_arrs = [np.ones_like(x) for x in arrs]
ones_arrs
是一个列表,包含与arrs
形状相同的全1
矩阵。- 例如,如果
arrs
包含:
A = np.array([[1, 2],
[3, 4]])
B = np.array([[5, 6, 7],
[8, 9, 10],
[11, 12, 13]])
那么 ones_arrs
生成:
ones_A = np.array([[1, 1],
[1, 1]])
ones_B = np.array([[1, 1, 1],
[1, 1, 1],
[1, 1, 1]])
3. 计算“非对角区域掩码”
off_diag_mask = 1.0 - scipy.linalg.block_diag(*ones_arrs)
这里使用 scipy.linalg.block_diag(*ones_arrs)
生成一个由 ones_arrs
组成的块对角矩阵:
block_diag(ones_A, ones_B) =
[[1, 1, 0, 0, 0]
[1, 1, 0, 0, 0]
[0, 0, 1, 1, 1]
[0, 0, 1, 1, 1]
[0, 0, 1, 1, 1]]
1.0 - block_diag(ones_A, ones_B)
会将 1
变为 0
,0
变为 1
,得到:
off_diag_mask =
[[0, 0, 1, 1, 1]
[0, 0, 1, 1, 1]
[1, 1, 0, 0, 0]
[1, 1, 0, 0, 0]
[1, 1, 0, 0, 0]]
4. 生成块对角矩阵
diag = scipy.linalg.block_diag(*arrs)
这里直接调用 scipy.linalg.block_diag(*arrs)
,生成:
diag =
[[ 1, 2, 0, 0, 0]
[ 3, 4, 0, 0, 0]
[ 0, 0, 5, 6, 7]
[ 0, 0, 8, 9, 10]
[ 0, 0, 11, 12, 13]]
5. 处理填充值
diag += (off_diag_mask * pad_value).astype(diag.dtype)
off_diag_mask * pad_value
会在非对角区域填充 pad_value
,假设 pad_value=99
:
[[ 0, 0, 99, 99, 99]
[ 0, 0, 99, 99, 99]
[99, 99, 0, 0, 0]
[99, 99, 0, 0, 0]
[99, 99, 0, 0, 0]]
diag += ...
之后:
总结
函数作用
- 创建块对角矩阵,将多个数组沿主对角线排列。
- 可自定义填充值
pad_value
,用于填充非对角区域。
关键步骤
- 构造与输入矩阵形状相同的全 1 矩阵,用于计算非对角区域掩码。
- 利用
scipy.linalg.block_diag
生成块对角矩阵。 - 创建非对角区域掩码
off_diag_mask
,用于识别非对角区域。 - 将
pad_value
填充到非对角区域。
应用场景
- 处理 多序列比对(MSA) 数据,使不同链的 MSA 不会互相混合,而是独立存在于块对角区域。
- 用于 神经网络输入预处理,保持不同输入块的独立性。
- 在 线性代数和数值计算 中用于处理块状矩阵运算。
这样可以在 AlphaFold3 的 MSA
处理中 确保不同蛋白链的 MSA 数据不会干扰,并可用于深度学习的输入对齐!