当前位置: 首页 > article >正文

相互作用先验下的 3D 分子生成扩散模型 - IPDiff 评测

IPDiff 是一个基于蛋白质-配体相互作用先验引导的扩散模型,首次把配体-靶标蛋白相互作用引入到扩散模型的扩散和采样过程中,用于蛋白质(口袋)特异性的三维分子生成。

本文将对 IPDiff 实际的分子生成能力进行评测。

一、背景介绍

IPDiff 来源于清华大学深圳国际研究生院的杨文明教授和鹏城实验室的王宇研究员为通讯作者的文章:《Protein-Ligand Interaction Prior for Binding-aware 3D Molecule Diffusion Models》。文章链接: https://openreview.net/forum?id=qH9nrMNTIW。该文章于 2024 年 1 月 16 日发表在 ICLR 2024 上。

当前基于结构的分子生成的扩散模型主要在逆过程时考虑蛋白质-配体相互作用信息,但忽略了前向过程中的相互作用。前向和逆过程的不一致可能会削弱生成分子对靶标蛋白的结合亲和力。

针对这一问题,IPDiff 在扩散和采样过程中引入几何蛋白质-配体相互作用。具体来说,首先通过结合亲和力信号进行监督训练,预训练一个蛋白质-配体先验网络(prior network)。接着,利用先验网络:(1)将目标蛋白质和配体之间的相互作用整合到前向过程中,以调整分子扩散轨迹;(2)增强结合感知的分子采样过程。在 CrossDocked 2020 数据集上的评估结果表明,IPDiff 生成的分子不仅具有更真实的 3D 结构,更好的结合亲和力,还具有良好的药理学属性。

总的来说,IPDiff 在扩散和采样过程中都考虑蛋白质-配体相互作用,在蛋白质-配体的先验模型引导下,能够生成在 3D 结构和结合亲和力表现更加好的分子, 同时保持良好的药理学属性。

二、模型介绍

IPDiff 模型是  Interaction Prior-guided Diffusion model 的简称,是一种新颖的基于相互作用先验引导的 3D 分子生成扩散模型,根据蛋白口袋-配体相互作用先验自适应地调整扩散轨迹。

作者发现,当前的方法在前向过程和逆过程利用靶标蛋白和生成的分子配体之间的相互作用方面存在差异,这可能限制了扩散模型在 SBDD 任务中的表现。在前向过程中,注入噪声的方式对于所有具有不同靶点蛋白的训练样本的注入过程都是相同的。这样就忽略了不同训练样本之间结合位点的差异,所有分子在前向过程中以相同的方式受到扰动。在逆过程中,为了生成能够与特定受体结合的配体分子,会考虑结合位点的差异。这种差异引入了偏差,阻碍了扩散模型完全捕捉口袋和配体分子之间的相互作用,而这种分子间相互作用对于口袋-配体结合是十分关键的。

为消除这种差异,作者提出了一种新颖的基于相互作用先验引导的扩散模型(IPDiff)。如下图所示,在 IPDiff 模型中,蛋白口袋-配体的相互作用通过一个先验的预训练的网络(IPNet)捕捉,该网络以结合亲和力信号监督训练。然后,设计了一个可学习的适配器(adapter),明确地将口袋-配体相互作用纳入前向和逆过程中的所有时间步 t,以对扩散模型的扩散过程进行结合感知(binding-aware)的轨迹调整(先验偏移,prior-shifting)。这两个过程(扩散模型和 IPNet)在训练中联合优化,理论上 IPDiff 能够比现有的分子扩散模型获得更好的表现。此外,在逆过程中,由学习到的相互作用先验网络驱动的估计前一步的蛋白质-配体复合物条件,以便增强去噪过程(先验调节 prior-conditioning)。

为了验证 IPDiff 的性能,在 CrossDocked 2020 数据集上做了一系列评估实验。实验结果表明,IPDiff 能够生成结合亲和力更好。且保持良好药理学属性的分子,优于现有的基于扩散的分子生成模型。

总的来说,IPDiff 的主要贡献是:

(1)提出了一种新颖的三维分子生成的扩散模型,在前向和逆过程中均考虑口袋-配体之间分子内和分子间的相互作用。

(2)提出了先验偏移(prior-shifting),通过基于口袋结合位点和相应配体分子之间的相互作用调整前向的扩散轨迹。

(3)设计了先验调节(prior-conditioning),通过将配体分子的去噪过程与之前估计的蛋白质-配体相互作用条件化来增强逆过程。

(4)IPDiff 在 CrossDocked 2020 基准上达到了 SOTA 性能,生成分子的平均 Vina score 低至 -6.42,同时保持较为合适的分子属性。

2.1 模型框架

作者设计一个先验网络 IPNet,用于从三维结构和化学性质的角度捕捉结合口袋和配体之间的相互作用,IPNet 利用结合亲和力信号进行预训练。IPNet 由 SE(3)-等变神经网络和跨注意力层(cross-attention layers)组成。一个层数较少的全连接的 SE(3)-等变神经网络分别应用于图表示的蛋白质和配体分子,用来学习分子内相互作用。另一个层数较少的全连接 SE(3)-等变神经网络则应用于图表示的蛋白-配体复合物,以建模分子间的相互作用。从 IPNet 可以获得蛋白和小分子的表征向量(嵌入向量),在扩散模型中,嵌入向量经过 MLP 的转化,添加与时间 t 的系数,被当作与原子坐标、原子类型类似的特征参与到扩散模型的向前和去噪过程中。

IPDiff 将预训练的 IPNet 作为相互作用先验(嵌入向量)作为条件,以促进结合感知的配体扩散过程。IPDiff 中提出了两种机制:先验偏移(prior-shifting)和 先验条件(prior-conditioning),以充分利用扩散框架中前向和反向过程中的蛋白质-分子相互作用。

先验偏移(prior-shifting)基于由 IPNet 模型化的蛋白质-分子相互作用嵌入向量,调整配体分子在扩散过程中的位置轨迹。IPDiff 将配体分子和给定的口袋输入到预训练的先验网络 IPNet 中,以提取蛋白质和分子的相互作用表征(嵌入向量)。然后引入一层可学习的 MLP 来生成基于相互作用的,在每一个扩散的时间步 t 中调整配体分子的位置偏移。通过这种方式,IPDiff 将分子扩散过程与分子采样过程中与对齐,以便在利用靶标蛋白质信息时优化扩散轨迹,进而根据蛋白质-分子相互作用优化生成的分子。对应公式:

先验条件(prior-conditioning)可以最大化利用预训练的 IPNet 中的蛋白质-配体相互作用先验。在分子采样过程中基于先前估计的蛋白质-分子复合物作为条件化,从而促进结合感知的分子生成。对应公式:

IPDiff 的整体框架如下图所示,预训练的 IPNet 在训练和采样过程中都被冻结,用于提供相互作用的先验。分子 M_0  和  \hat{M}_{0|t} 分别用于正向和反向过程中提取相互作用的先验。\hat{M}_{0|t} 是在采样过程中在时间步 t 估计的分子(因为真实的分子 M_0 无法估计),F 是相互作用的表示,S 表示分子的位置偏移。

2.2 数据集和基线模型

为了全面建模蛋白质-配体相互作用,我们利用了 PDBbind v2016 数据集中的蛋白质-配体对(复合物)及其结合亲和性信号对 IPNet 进行预训练。PDBbind v2016 数据集包括 3767 个训练复合物(训练集)和 290 个测试复合物(测试集),通常用于结合亲和性预测任务。在分子生成任务中,按照之前的工作(AR、Pocket2mol 和 TargetDiff),作者在 CrossDocked 2020 数据集上训练和评估 IPDiff。数据的准备和划分和 AR 一致,其中 2250 万个对接的结合复合物被精炼为高质量的对接姿势(对接姿势与真实结构之间的 RMSD < 1 Å)和多样化的蛋白质(序列相似性 < 30%)。具体来说,我们使用了 10 万个蛋白质-配体对进行训练,并使用 100 个蛋白质进行测试。

在对比研究中,我们将我们的模型与五种最近的 SBDD (Structure-based drug design)代表性方法进行了对比分析。LiGAN 是一种 CVAE 模型(conditional VAE 模型),在蛋白质-配体结构的原子密度网格表示上进行训练。AR 和 Pocket2Mol 是在蛋白质口袋和之前生成的原子条件下以自回归方式生成 3D 分子的模型。TargetDiff 和 DecomposeDiff 是最近的两种最先进的扩散方法,分别以非自回归方式生成原子坐标和原子类型。

2.3 评价指标

作者从三个角度评估生成的分子配体:分子结构、靶点结合亲和力和分子性质。为了评估生成的分子在分子结构方面的表现,另外计算了生成分子与参考分子之间原子/键距离的经验分布的 Jensen-Shannon 散度(JSD)。

在之前的工作 (AR , LiGAN, TargetDiff) 中,利用 AutoDock Vina 计算了与结合相关的指标的均值和中位数,包括 Vina 分数(Vina Score)、局部结构最小化之后的 Vina 打分(Vina Min)、生成分子的重新对接打分(Vina Dock)和 生成分子比参考分子亲和力高的百分比(High Affinity)。Vina Score 直接根据生成的 3D 分子估算结合亲和力;Vina Min 在估算前执行局部结构最小化;Vina Dock 涉及额外的重新对接过程,反映出最佳可能的结合亲和力;High Affinity 则衡量生成的分子中有多少在每个测试蛋白质中比参考分子具有更好的结合能力。

此外,根据 AR 和 LiGAN 使用 QED、SA 和 多样性(Diversity)作为评估分子性质的指标。QED 是一种简单的药物相似性定量估算,结合了几种理想的分子性质;SA(合成可及性)是对合成配体难度的估算;多样性计算为给定口袋生成的所有配体之间的平均成对差异性。为了公平比较,所有采样和评估过程均按照 TargetDiff 的标准进行。

2.4 模型性能

作者对比了 IPDiff 和其他代表性方法生成的分子结构。生成分子的所有原子对之间的距离分布如下图所示,IPDiff 和参考分子的差异维持在非常小的水平(JSD : 0.08)。表中展示了不同方法生成的分子的键分布,并与相应的参考经验分布进行比较。“-”、“=”和“:”分别表示单键、双键和芳香键。结果显示,IPDiff 在主要键类型上的表现优于所有其他方法,展示了 IPDiff 在生成稳定分子结构方面的能力。

下表展示了两类 SBDD 方法(非扩散方法和基于扩散的方法)在结合亲和力和分子性质方面和 IPDiff 的比较。可以看出,IPDiff 在结合相关指标上显著优于非扩散基线方法。值得注意的是,IPDiff 在平均 Vina Score、Vina Min 和 Vina Dock 方面分别比强大的自回归方法 Pocket2Mol 高出 24.9%、16.0%和 19.9%。与最新的基于扩散的方法 DecompDiff 相比,IPDiff 不仅将结合相关指标平均 Vina Score、Vina Min 和 Vina Dock 分别提高了 13.2%、5.8%和 2.1%,还将性质相关指标平均 QED 和平均多样性分别提高了 15.6%和 8.8%。在 High Affinity 方面,平均 69.5% 的 IPDiff 分子表现出比参考分子更好的结合亲和力,这显著优于其他基线方法。这些提升表明,IPDiff 能够有效利用来自 IPNet 的蛋白质-配体相互作用先验,从而生成结合亲和力和分子性质都得到改善的分子。

从上表中可以看出,以往方法在结合相关指标和性质相关指标 QED 之间需要权衡。DecompDiff 在结合相关指标上表现优于 AR 和 Pocket2Mol,但在 QED 得分上落后于它们。相比之下,IPDiff 不仅在结合相关指标上达到了最新水平,还保持了与 Pocket2Mol 相当的 QED 得分,取得了比 DecompDiff 更好的平衡。然而,IPDiff 对 QED 和 SA 的重视程度较低,因为QED 和 SA 在实际药物发现场景中通常作为粗略筛选指标使用,只要在合理范围内即可。

下图展示了一些生成的配体分子及其性质。在这些案例上,IPDiff 模型生成的分子结构有效,且具有合理的结合 pose。但仔细一看,分子似乎有很多并环和大环。

IPDiff 的主要假设是,在前向和反向过程引入三维蛋白质-分子相互作用的先验知识有助于提高训练和采样的效率,从而在结合相关和性质相关的指标上提升分子生成性能。作者做了消融实验验证,如下表。

表中,自条件(self-conditioning,在分子生成的去噪过程中,没有使用前一步中的蛋白和小分子相互作用信息)机制无法提高生成性能,因为前一个时间步的估计分子不包括自我精炼所需的蛋白质-配体相互作用信息。相比之下,先验条件机制(proir-conditioning,即 IPNet)通过引入来自预训练的 IPNet 的信息性蛋白质-配体相互作用知识,显著提升了结合相关和性质相关的指标。此外,先验位移(prior-shifting)在结合相关指标上也有显著的提升,这表明先验位移能够有效帮助 IPDiff 生成紧密结合给定蛋白质口袋的配体分子。请注意,先验位移对性质相关指标 (QED 和 SA) 的贡献不大。这是因为先验位移仅应用于分子原子的位置,而性质相关的指标与蛋白质-配体对的几何结构关系较大。此外,同时在 IPDiff 中使用先验条件和先验位移,能够在结合相关和性质相关指标上实现最佳性能,进一步证明了这两种机制的有效性。

三、IPDiff 评测

3.1 安装环境

复制代码项目:

git clone https://github.com/YangLing0818/IPDiff.git

项目提供了项目运行的环境配置文件 IPDiff.yml,配置文件中部分内容展示如下:

name: ipdiff
channels:
  - pytorch
  - pyg
  - conda-forge
  - defaults
dependencies:
  - _ipython_minor_entry_point=8.7.0=h3b92ee0_0
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - absl-py=1.1.0=pyhd8ed1ab_0

使用提供的 IPDiff.yml 创建 IPDiff  环境,命令如下:

conda env create -f IPDiff.yml

但是安装过程中的依赖库版本有冲突,报错如下:

LibMambaUnsatisfiableError: Encountered problems while solving:
  - package cryptography-37.0.1-py38h9ce1e76_0 requires openssl <1.1.2a, but none of the providers can be installed

Could not solve for environment specs
The following packages are incompatible
├─ cryptography ==37.0.1 py38h9ce1e76_0 is installable and it requires
│  └─ openssl <1.1.2a , which can be installed;
└─ openssl ==3.1.0 hd590300_2 is not installable because it conflicts with any installable versions previously reported.

IPDiff.yml 中 cryptography ==37.0.1 指定的版本需要 openssl <1.1.2a,配置文件中指定了 openssl ==3.1.0,有版本冲突。把 IPDiff.yml 中第 42 行中,cryptography 指定的版本信息删除,让 conda 自动寻找匹配的版本.

cryptography=37.0.1=py38h9ce1e76_0

改为

cryptography

修改之后,继续安装环境

conda env create -f IPDiff.yml

但仍报错 autodocktools 的安装错误,如下:

Collecting package metadata (repodata.json): done
Solving environment: done

Downloading and Extracting Packages:

Preparing transaction: done
Verifying transaction: done
Executing transaction: / By downloading and using the CUDA Toolkit conda packages, you accept the terms and conditions of the CUDA End User License Agreement (EULA): https://docs.nvidia.com/cuda/eula/index.html

\ 

    Installed package of scikit-learn can be accelerated using scikit-learn-intelex.
    More details are available here: https://intel.github.io/scikit-learn-intelex

    For example:

        $ conda install scikit-learn-intelex
        $ python -m sklearnex my_application.py

    

done
Installing pip dependencies: | Ran pip subprocess with arguments:
['/workspace/anaconda3/envs/IPDiff/bin/python', '-m', 'pip', 'install', '-U', '-r', '/workspace/GuanXL/projects/IPDiff/condaenv.dja95gv6.requirements.txt', '--exists-action=b']
Pip subprocess output:
Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple

Pip subprocess error:
ERROR: Could not find a version that satisfies the requirement autodocktools-py3==0+unknown (from versions: none)
ERROR: No matching distribution found for autodocktools-py3==0+unknown

failed

CondaEnvException: Pip failed

通过 pip 安装的 autodocktools-py3 无法找到合适的版本,进入虚拟环境手动安装

conda activate IPDiff
python -m pip install git+https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3

autodocktools-py3 安装成功,打印下面内容:

Collecting git+https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3
  Cloning https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3 to /tmp/pip-req-build-tls85yic
  Running command git clone -q https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3 /tmp/pip-req-build-tls85yic
  Resolved https://github.com/Valdes-Tresanco-MS/AutoDockTools_py3 to commit a62a5d98116f1590183b58a9ad732b997cf2579c
Building wheels for collected packages: AutoDockTools-py3
  Building wheel for AutoDockTools-py3 (setup.py) ... done
  Created wheel for AutoDockTools-py3: filename=AutoDockTools_py3-1.5.7.post1+12.ga62a5d9-py3-none-any.whl size=984166 sha256=fbd2e81612b011bec6ebd5831bedb2467ec4715ddc9abed0f76db013cdd8b44e
  Stored in directory: /tmp/pip-ephem-wheel-cache-dleaadmf/wheels/88/95/bf/1ff0c68c3d3146de99a9b93ad95be4738164f5980b5a415661
Successfully built AutoDockTools-py3
Installing collected packages: AutoDockTools-py3
Successfully installed AutoDockTools-py3-1.5.7.post1+12.ga62a5d9

剩余的依赖库,也通过下面命令安装:

pip install meeko==0.1.dev3 pdb2pqr==3.6.1 vina==1.2.2 propka==3.5.0 mmcif-pdbx==2.0.1

安装过程输出如下:

Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple
Collecting meeko==0.1.dev3
  Downloading https://mirrors.cloud.tencent.com/pypi/packages/89/84/827a0ea853ba0970f0287eb09ba7d07186590a510d10a8b63a2d884ac447/meeko-0.1.dev3-py2.py3-none-any.whl (45 kB)
     |████████████████████████████████| 45 kB 4.0 MB/s 
Collecting pdb2pqr==3.6.1
  Downloading https://mirrors.cloud.tencent.com/pypi/packages/3a/7c/3bbf1f414f70cbb14b1ee4da74061ae5a0323cd4cb037a76642fa23d2e2f/pdb2pqr-3.6.1-py2.py3-none-any.whl (208 kB)
     |████████████████████████████████| 208 kB 2.4 MB/s 
Collecting vina==1.2.2
  Downloading https://mirrors.cloud.tencent.com/pypi/packages/18/38/d197002b15b4190d005da557c00dd96af320c0fb49c5b1a9130d01a90e1c/vina-1.2.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.0 MB)
     |████████████████████████████████| 7.0 MB 8.6 MB/s 
Collecting propka==3.5.0
  Downloading https://mirrors.cloud.tencent.com/pypi/packages/c4/e2/5c096dc02874a217b26d01cecb701278c5fbf847e473f4870a711563eb87/propka-3.5.0-py3-none-any.whl (98 kB)
     |████████████████████████████████| 98 kB 16.6 MB/s 
Collecting mmcif-pdbx==2.0.1
  Downloading https://mirrors.cloud.tencent.com/pypi/packages/03/8e/ff50191d2210faac7df6b2e66cf39e021bd7051bd9389b341b1d560cae49/mmcif_pdbx-2.0.1-py2.py3-none-any.whl (20 kB)
Requirement already satisfied: numpy>=1.18 in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from meeko==0.1.dev3) (1.22.3)
Requirement already satisfied: docutils<0.18 in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from pdb2pqr==3.6.1) (0.17.1)
Requirement already satisfied: requests in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from pdb2pqr==3.6.1) (2.27.1)
Requirement already satisfied: certifi>=2017.4.17 in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from requests->pdb2pqr==3.6.1) (2022.12.7)
Requirement already satisfied: idna<4,>=2.5 in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from requests->pdb2pqr==3.6.1) (3.3)
Requirement already satisfied: charset-normalizer~=2.0.0 in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from requests->pdb2pqr==3.6.1) (2.0.4)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /workspace/anaconda3/envs/IPDiff/lib/python3.8/site-packages (from requests->pdb2pqr==3.6.1) (1.26.15)
Installing collected packages: propka, mmcif-pdbx, vina, pdb2pqr, meeko
Successfully installed meeko-0.1.dev3 mmcif-pdbx-2.0.1 pdb2pqr-3.6.1 propka-3.5.0 vina-1.2.2

注:安装过程要注意 pyg 的版本,使用 2.0.4 。其他版本可能会在训练时出现问题。

3.2 分子生成案例测试

作者把训练好的模型保存在谷歌网盘,IPNet 已经在项目文件夹中(./pretrained_models/IPNet)。IPDiff 训练好的模型需要从谷歌网盘下载,谷歌网盘中的文件如下图所示,链接是 https://drive.google.com/drive/folders/1VaCvlRncFHQqvYV-FsUmxpoxIYRm2u_V ,下载完成之后放在 ./pretrained_models 文件夹中。

IPDiff 的数据预处理和数据集和 TargetDiff 一致,相关文件也保存在谷歌网盘中,链接为

https://drive.google.com/drive/folders/1j21cc7-97TedKh_El5E34yI8o5ckI7eK?usp=share_link ,内容具体如下图所示。下载好的 crossdocked_pocket10_pose_split.pt 和

crossdocked_v1.1_rmsd1.0_pocket10_processed_final.lmdb 放在 ./datasets 文件夹中。

由于项目脚本中使用的是绝对路径,所以项目下载之后需要修改路径。为方便后续评测,我们把项目中的所有绝对路径进行修改,我们都进行了修改修改后可以进行正常的分子生成。

3.2.1 内置案例分子生成

测试集中 index 为 0 的体系(PDB : 2Z3H, 即 ./datasets/test_set/BSD_ASPTE_1_130_0 的测试体系)的口袋以及原来小分子的展示如下图,我们把这个蛋白作为测试分子生成的内置案例。

分子采样的脚本是 sample_split.py,使用编号为 0 的蛋白进行分子生成,命令如下:

python sample_split.py \
  --start_index 0 \
  --end_index 0 \
  --batch_size 25 \
  --result_path ./result_0

--start_index 0 和 --end_index 0 指定分子生成的起始和结束编号,这里指定编号为 0 的蛋白,--result_path ./result_0 指定生成分子保存到 ./result_0 中。 

3090 显卡运行此分子生成的时间大约为 60 分钟。生成分子完成,打印信息如下:

{'model': {'checkpoint': './pretrained_models/pretrained_IPDiff.pt'}, 'sample': {'seed': 2024, 'num_samples': 100, 'num_steps': 1000, 'pos_only': False, 'center_pos_mode': 'protein', 'sample_num_atoms': 'prior'}}
Training Config: {'data': {'name': 'pl', 'path': './datasets/crossdocked_v1.1_rmsd1.0', 'split': './datasets/crossdocked_pocket10_pose_split.pt', 'transform': {'ligand_atom_mode': 'add_aromatic', 'random_rot': False}}, 'net_cond': {'ckpt_path': './pretrained_models/IPNet', 'hidden_dim': 128}, 'model': {'cond_dim': 128, 'model_mean_type': 'C0', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'v_beta_schedule': 'cosine', 'v_beta_s': 0.01, 'num_diffusion_timesteps': 1000, 'loss_v_weight': 100.0, 'sample_time_method': 'symmetric', 'time_emb_dim': 0, 'time_emb_mode': 'simple', 'center_pos_mode': 'protein', 'node_indicator': True, 'model_type': 'uni_o2', 'num_blocks': 1, 'num_layers': 9, 'hidden_dim': 128, 'n_heads': 16, 'edge_feat_dim': 4, 'num_r_gaussian': 20, 'knn': 32, 'num_node_types': 8, 'act_fn': 'relu', 'norm': True, 'cutoff_mode': 'knn', 'ew_net_type': 'global', 'num_x2h': 1, 'num_h2x': 1, 'r_max': 10.0, 'x2h_out_fc': False, 'sync_twoup': False}, 'train': {'seed': 2021, 'batch_size': 4, 'num_workers': 4, 'n_acc_batch': 1, 'max_iters': 1000000, 'val_freq': 5000, 'pos_noise_std': 0.1, 'max_grad_norm': 8.0, 'bond_loss_weight': 1.0, 'optimizer': {'type': 'adam', 'lr': 0.0005, 'weight_decay': 0, 'beta1': 0.95, 'beta2': 0.999}, 'scheduler': {'type': 'plateau', 'factor': 0.6, 'patience': 10, 'min_lr': 1e-06}}}
Successfully load the dataset (size: 100)!
Restored from ./pretrained_models/IPNet with 0 missing and 0 unexpected keys
Successfully load the model! ./pretrained_models/pretrained_IPDiff.pt
sampling: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:37<00:00,  1.22it/s]
sampling: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:40<00:00,  1.22it/s]
sampling: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:43<00:00,  1.21it/s]
sampling: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [13:42<00:00,  1.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [55:00<00:00, 825.10s/it]13:42<00:00,  1.22it/s]
Sample done!
sampled data_id:  0

在指定的输出文件夹 ./result_0 中生成分子信息文件 result_0.pt,即./result_0/result_0.pt。在 IPDiff 生成和评价分子的过程中,并没有保存生成的分子构象,所以我们修改评价分子的脚本 eval_split.py,加入保存分子构象的代码。

修改后运行,评价生成分子的命令如下(--docking_mode none,对接有问题,暂时不对接)

python eval_split.py \
  --eval_start_index 0 \
  --eval_end_index 0 \
  --sample_path ./result_0  \
  --docking_mode none

返回结果如下,

Load generated data done! sample_id[0:0] examples for evaluation.
Eval: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.57s/it]
Evaluate done! 100 samples in total.
mol_stable:   0.3300
atm_stable:   0.9211
recon_success:        0.8800
eval_success: 0.8300
complete:     0.8300
JS bond distances of complete mols: 
JSD_6-6|4:    0.5751
JSD_6-6|1:    0.3309
JSD_6-8|1:    0.4317
JSD_6-7|1:    0.6215
JSD_6-8|2:    0.4572
JSD_6-6|2:    0.4721
JSD_6-7|4:    0.6795
JSD_6-7|2:    0.7596
JSD_CC_2A:    0.3805
JSD_All_12A:  0.1044
Atom type JS: 0.2071
Number of reconstructed mols: 88, complete mols: 83, evaluated mols: 83
QED:   Mean: 0.293 Median: 0.265
SA:    Mean: 0.527 Median: 0.530
ring size: 3 ratio: 0.000
ring size: 4 ratio: 0.120
ring size: 5 ratio: 0.470
ring size: 6 ratio: 0.578
ring size: 7 ratio: 0.349
ring size: 8 ratio: 0.060
ring size: 9 ratio: 0.012
Generated molecules saved as sdf format!

从返回结果可以看出,能够重构的生成分子有 83 个,QED 和 SA 均值分别为 0.265 和 0.530,生成分子中包含 五、六、七环的比例分别为 0.470、0.578 和 0.349。评估结果和生成的分子构象默认保存在 ./eval_results/ 中。因为所有评价结果会输出到 ./eval_results/,为防止后面的案例覆盖结果,我们把 ./eval_results 重命名为 ./eval_results_0。

上面的命令中,指定 --docking_mode none,生成分子没有和口袋对接计算打分。可以指定 --docking_mode vina_score 计算生成分子和口袋的对接打分,命令如下:

python eval_split.py \
  --eval_start_index 0 \ 
  --eval_end_index 0 \
  --sample_path ./result_0 \
  --docking_mode vina_score

评测过程报错如下:

[2024-08-22 02:43:55,718::evaluate::INFO] Load generated data done! sample_id[0:0] examples for evaluation.
Eval: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:40<00:00, 40.39s/it]
[2024-08-22 02:44:36,114::evaluate::INFO] Evaluate done! 100 samples in total.
[2024-08-22 02:44:36,114::evaluate::INFO] mol_stable:   0.3300
[2024-08-22 02:44:36,114::evaluate::INFO] atm_stable:   0.9211
[2024-08-22 02:44:36,114::evaluate::INFO] recon_success:        0.8800
[2024-08-22 02:44:36,115::evaluate::INFO] eval_success: 0.0000
[2024-08-22 02:44:36,115::evaluate::INFO] complete:     0.8300
[2024-08-22 02:44:36,115::evaluate::INFO] JS bond distances of complete mols: 
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-6|4:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-6|1:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-8|1:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-7|1:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-8|2:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-6|2:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-7|4:    None
[2024-08-22 02:44:36,115::evaluate::INFO] JSD_6-7|2:    None
/workspace/GuanXL/projects/IPDiff/utils/evaluation/eval_bond_length.py:29: RuntimeWarning: invalid value encountered in true_divide
  bin_counts = np.array(bin_counts) / np.sum(bin_counts)
[2024-08-22 02:44:36,117::evaluate::INFO] JSD_CC_2A:    nan
[2024-08-22 02:44:36,117::evaluate::INFO] JSD_All_12A:  nan
Traceback (most recent call last):
  File "eval_split.py", line 181, in <module>
    atom_type_js = eval_atom_type.eval_atom_type_distribution(success_atom_types)
  File "/workspace/GuanXL/projects/IPDiff/utils/evaluation/eval_atom_type.py", line 30, in eval_atom_type_distribution
    pred_atom_distribution[k] = pred_counter[k] / total_num_atoms
ZeroDivisionError: division by zero

这是因为在做 vina, qvina 打分或者对接时,需要蛋白结构,代码默认的蛋白结构保存在 ./datasets/crossdocked_v1.1_rmsd1.0(由 protein_root 参数指定)与真实的路径 ./datasets/test_set 不符。因此,运行评估脚本时,需要改成如下命令(添加 protein_root 参数),例如使用 qvina:

python eval_split.py   \
  --eval_start_index 0  \
  --eval_end_index 0   \
  --sample_path ./result_0    \
  --docking_mode qvina \
  --protein_root ./datasets/test_set

注:使用 qvina 打分需要另外安装 adt 的 conda 环境。请参照之前 TagMol 的测评文档。

运行输出示例,输出每一个分子的 qvina 打分:

[2024-08-28 00:39:51,799::evaluate::INFO] Load generated data done! sample_id[0:0] examples for evaluation.
Eval:   0%|                                                                                                                                                                      | 0/1 [00:00<?, ?it/s]Best affinity: -0.8
Best affinity: -5.1
Best affinity: -7.6
Best affinity: -3.3
... ...

最终完整的评估输出:

[2024-08-28 00:55:37,285::evaluate::INFO] Load generated data done! sample_id[0:0] examples for evaluation.
Eval: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [03:38<00:00, 218.09s/it]
[2024-08-28 00:59:15,376::evaluate::INFO] Evaluate done! 100 samples in total.
[2024-08-28 00:59:15,377::evaluate::INFO] mol_stable:   0.3600
[2024-08-28 00:59:15,378::evaluate::INFO] atm_stable:   0.9344
[2024-08-28 00:59:15,378::evaluate::INFO] recon_success:        0.8700
[2024-08-28 00:59:15,379::evaluate::INFO] eval_success: 0.8200
[2024-08-28 00:59:15,379::evaluate::INFO] complete:     0.8200
[2024-08-28 00:59:15,384::evaluate::INFO] JS bond distances of complete mols: 
[2024-08-28 00:59:15,385::evaluate::INFO] JSD_6-6|4:    0.5588
[2024-08-28 00:59:15,385::evaluate::INFO] JSD_6-6|1:    0.3260
[2024-08-28 00:59:15,385::evaluate::INFO] JSD_6-8|1:    0.4267
[2024-08-28 00:59:15,386::evaluate::INFO] JSD_6-7|1:    0.6060
[2024-08-28 00:59:15,386::evaluate::INFO] JSD_6-8|2:    0.4449
[2024-08-28 00:59:15,387::evaluate::INFO] JSD_6-6|2:    0.4806
[2024-08-28 00:59:15,387::evaluate::INFO] JSD_6-7|4:    0.6813
[2024-08-28 00:59:15,387::evaluate::INFO] JSD_6-7|2:    0.7584
[2024-08-28 00:59:15,418::evaluate::INFO] JSD_CC_2A:    0.3833
[2024-08-28 00:59:15,418::evaluate::INFO] JSD_All_12A:  0.1038
[2024-08-28 00:59:15,419::evaluate::INFO] Atom type JS: 0.2052
[2024-08-28 00:59:15,647::evaluate::INFO] Number of reconstructed mols: 87, complete mols: 82, evaluated mols: 82
[2024-08-28 00:59:15,648::evaluate::INFO] QED:   Mean: 0.293 Median: 0.256
[2024-08-28 00:59:15,649::evaluate::INFO] SA:    Mean: 0.525 Median: 0.525
[2024-08-28 00:59:15,649::evaluate::INFO] Vina Score:  Mean: -7.402 Median: -8.154
[2024-08-28 00:59:15,650::evaluate::INFO] Vina Min  :  Mean: -7.516 Median: -7.993
[2024-08-28 00:59:15,650::evaluate::INFO] ring size: 3 ratio: 0.000
[2024-08-28 00:59:15,651::evaluate::INFO] ring size: 4 ratio: 0.110
[2024-08-28 00:59:15,651::evaluate::INFO] ring size: 5 ratio: 0.463
[2024-08-28 00:59:15,652::evaluate::INFO] ring size: 6 ratio: 0.610
[2024-08-28 00:59:15,652::evaluate::INFO] ring size: 7 ratio: 0.354
[2024-08-28 00:59:15,652::evaluate::INFO] ring size: 8 ratio: 0.073
[2024-08-28 00:59:15,653::evaluate::INFO] ring size: 9 ratio: 0.012
Generated molecules saved as sdf format!

根据代码的设置,对生成分子的评估结果会放置在文件夹:./eval_results 中,其内容包括:

.
|-- Generated_molecules.sdf
|-- log.txt
|-- metrics_-1_0-to-0.pt
`-- pair_dist_hist_0-to-0.png


0 directories, 4 files

接着使用原位打分计算生成分子的 vina_score,命令如下:

python eval_split.py   \
  --eval_start_index 0  \
  --eval_end_index 0   \
  --sample_path ./result_0    \
  --docking_mode vina_score \
  --protein_root ./datasets/test_set

评估输出如下:

d[0:0] examples for evaluation.
Evaluate done! 100 samples in total.
mol_stable:	0.3600
atm_stable:	0.9344
recon_success:	0.8700
eval_success:	0.8200
complete:	0.8200
JS bond distances of complete mols: 
JSD_6-6|4:	0.5588
JSD_6-6|1:	0.3260
JSD_6-8|1:	0.4267
JSD_6-7|1:	0.6060
JSD_6-8|2:	0.4449
JSD_6-6|2:	0.4806
JSD_6-7|4:	0.6813
JSD_6-7|2:	0.7584
JSD_CC_2A:	0.3833
JSD_All_12A:	0.1038
Atom type JS: 0.2052
Number of reconstructed mols: 87, complete mols: 82, evaluated mols: 82
QED:   Mean: 0.293 Median: 0.256
SA:    Mean: 0.525 Median: 0.525
Vina Score:  Mean: -7.402 Median: -8.154
Vina Min  :  Mean: -7.516 Median: -7.993
ring size: 3 ratio: 0.000
ring size: 4 ratio: 0.110
ring size: 5 ratio: 0.463
ring size: 6 ratio: 0.610
ring size: 7 ratio: 0.354
ring size: 8 ratio: 0.073
ring size: 9 ratio: 0.012

Generated molecules saved as sdf format!

从返回结果可以看出,能够重构的生成分子有 82 个,QED 和 SA 均值分别为 0.293 和 0.525,生成分子中包含 五、六、七环的比例分别为 0.463、0.610 和 0.354。评估结果和生成的分子构象默认保存在 ./eval_results/ 中。因为所有评价结果会输出到 ./eval_results/,为防止后面的案例覆盖结果,我们把 ./eval_results 重命名为 ./eval_results_pretrain。

在 Generated_molecules.sdf 提取打分最低的 6 个分子,其 2D 结构以及 vina score 打分如下图:

对应在口袋中的 pose 如下图:

所有生成的分子如下:

IPDiff_2z3h_outputs

3.2.2 自定义的测试案例

我们选择 3WZE 作为自己的测试案例,使用 PyMol 把 3WZE 的配体周围 10 Å 的范围作为口袋,保存为 ./3wze/pocket_3wze.pdb,3WZE 口袋与原来配体小分子的如下图:

项目提供的脚本中有个 sample_for_pocket.py 脚本,在说明文档中没有介绍,尝试用该脚本在口袋中采样分子,命令如下:

python scripts/sample_for_pocket.py \
    configs/sampling.yml \
    --pdb_path ./3wze/pocket_3wze.pdb \
    --result_path outputs_3wze

直接运行上述代码存在问题。

在修改 scripts/sample_for_pocket.py 的代码后,我们成功使得 IPDiff 可以针对某个口袋生成分子,再次运行上述命令,生成的分子保存在:./outputs_3wze。

运行如下命令,评估 PIDiff 在 3wze 体系上生成的分子:

python ./scripts/evaluate_for_pocket.py \
  ./outputs_3wze/sample.pt \
  --verbose True \
  --protein_path ./3wze/pocket_3wze.pdb \
  --docking_mode qvina \
  --exhaustiveness 16

运行输出:

docking_mode qvina   --exhaustiveness 16
[2024-08-29 00:20:16,240::evaluate::INFO] Load generated data done! 1 examples in total.
Eval:   0%|                                                                                                                                                                    | 0/1 [00:00<?, ?it/s]Best affinity: -8.0
Best affinity: -6.9
Best affinity: -8.6
... ...

测试结果:

[2024-08-29 00:43:37,083::evaluate::INFO] Evaluate done! 100 samples in total.
[2024-08-29 00:43:37,083::evaluate::INFO] mol_stable:   0.1100
[2024-08-29 00:43:37,083::evaluate::INFO] atm_stable:   0.8166
[2024-08-29 00:43:37,083::evaluate::INFO] recon_success:        0.9900
[2024-08-29 00:43:37,083::evaluate::INFO] eval_success: 0.8700
[2024-08-29 00:43:37,083::evaluate::INFO] complete:     0.8900
[2024-08-29 00:43:37,086::evaluate::INFO] JS bond distances of complete mols: 
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-6|4:    0.3223
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-6|1:    0.4692
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-8|1:    0.4902
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-7|1:    0.4319
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-8|2:    0.5569
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-6|2:    0.3729
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-7|4:    0.4013
[2024-08-29 00:43:37,086::evaluate::INFO] JSD_6-7|2:    0.4353
[2024-08-29 00:43:37,107::evaluate::INFO] JSD_CC_2A:    0.3532
[2024-08-29 00:43:37,107::evaluate::INFO] JSD_All_12A:  0.1011
[2024-08-29 00:43:37,107::evaluate::INFO] Atom type JS: 0.1784
[2024-08-29 00:43:37,339::evaluate::INFO] Number of reconstructed mols: 99, complete mols: 89, evaluated mols: 87
[2024-08-29 00:43:37,340::evaluate::INFO] QED:   Mean: 0.646 Median: 0.687
[2024-08-29 00:43:37,340::evaluate::INFO] SA:    Mean: 0.648 Median: 0.650
[2024-08-29 00:43:37,341::evaluate::INFO] Vina:  Mean: -6.806 Median: -7.300
[2024-08-29 00:43:37,341::evaluate::INFO] ring size: 3 ratio: 0.000
[2024-08-29 00:43:37,341::evaluate::INFO] ring size: 4 ratio: 0.092
[2024-08-29 00:43:37,342::evaluate::INFO] ring size: 5 ratio: 0.621
[2024-08-29 00:43:37,342::evaluate::INFO] ring size: 6 ratio: 0.920
[2024-08-29 00:43:37,342::evaluate::INFO] ring size: 7 ratio: 0.414
[2024-08-29 00:43:37,343::evaluate::INFO] ring size: 8 ratio: 0.046
[2024-08-29 00:43:37,343::evaluate::INFO] ring size: 9 ratio: 0.011

从返回结果可以看出,在生成的100个分子中,能够重构的生成分子有 99 个,QED 和 SA 均值分别为 0.646 和 0.648,qvina 打分均值为 -6.806,生成分子中包含 五、六、七环的比例分别为 0.621、0.920 和 0.414。与 TagMol 相比,IPDiff 的 qvina 和 QED 打分低了一些,同时,成 7 元环的比例更高。这些结果表明,IPDiff 生成的分子结合力低一些,类药性也差一些,整体性能不如TagMol。

评估结果和生成的分子构象默认保存在 ./outputs_3wze/eval_results 中。生成的分子也被保存为sdf格式,路径为:./outputs_3wze/eval_results/Generated_molecules.sdf。打分最低的 Top 3 分子的pose 如下图:

这个三个分子的打分分别是,-12.0,-12.0,-10.9,对应的2D分子结构如下图:

IPDiff 在 3wze 体系上生成的分子如下:

IPDiff_3wze_outputs

3.3 训练模型

IPDiff 在项目中已经提供了训练的配置文件(./configs/training.yml),具体配置如下:

data:
  name: pl
  path: ./datasets/crossdocked_v1.1_rmsd1.0
  split: ./datasets/crossdocked_pocket10_pose_split.pt
  transform:
    ligand_atom_mode: add_aromatic
    random_rot: False

net_cond:
  ckpt_path: ./pretrained_models/ipnet
  hidden_dim: 128

model:
  cond_dim: 128
  model_mean_type: C0  # ['noise', 'C0']
  beta_schedule: sigmoid
  beta_start: 1.e-7
  beta_end: 2.e-3
  v_beta_schedule: cosine
  v_beta_s: 0.01
  num_diffusion_timesteps: 1000
  loss_v_weight: 100.
  sample_time_method: symmetric  # ['importance', 'symmetric']

  time_emb_dim: 0
  time_emb_mode: simple
  center_pos_mode: protein

  node_indicator: True
  model_type: uni_o2
  num_blocks: 1
  num_layers: 9
  hidden_dim: 128
  n_heads: 16
  edge_feat_dim: 4  # edge type feat
  num_r_gaussian: 20
  knn: 32 # !
  num_node_types: 8
  act_fn: relu
  norm: True
  cutoff_mode: knn  # [radius, none]
  ew_net_type: global  # [r, m, none]
  num_x2h: 1
  num_h2x: 1
  r_max: 10.
  x2h_out_fc: False
  sync_twoup: False

train:
  seed: 2021
  batch_size: 4
  num_workers: 4
  n_acc_batch: 1
  max_iters: 1000000
  val_freq: 5000
  pos_noise_std: 0.1
  max_grad_norm: 8.0
  bond_loss_weight: 1.0
  optimizer:
    type: adam
    lr: 5.e-4
    weight_decay: 0
    beta1: 0.95
    beta2: 0.999
  scheduler:
    type: plateau
    factor: 0.6
    patience: 10
    min_lr: 1.e-6

我们重新训练模型,命令如下:

python train.py \
  --config ./configs/training.yml \
  --logdir ./logs

我们使用项目提供的配置文件重新训练模型,--logdir ./logs 指定训练过程记录在 ./logs 中。训练的最好的模型是 770000.pt。

接下来我们使用我们训练好的 IPDiff 模型进行分子生成,创建新的配置文件 ./configs/sampling_retrain.yml,具体配置如下:

model:
  checkpoint: ./logs/training_2024_08_22__03_06_49/checkpoints/770000.pt

sample:
  seed: 2024
  num_samples: 100
  num_steps: 1000
  pos_only: False
  center_pos_mode: protein
  sample_num_atoms: prior

使用重新训练的模型对编号为 0 的蛋白进行分子生成,命令如下:

python sample_split.py \
  --config ./configs/sampling_retrain.yml
  --start_index 0 \
  --end_index 0 \
  --batch_size 10 \
  --result_path ./result_0_retrain

在指定的输出文件夹 ./result_0_retrain 中生成分子信息文件 result_0.pt,即./result_0_retrain/result_0.pt。

接着评估生成分子,命令如下:

python eval_split.py \
  --eval_start_index 0 \
  --eval_end_index 0 \
  --sample_path ./result_0_retrain  \
  --docking_mode vina_score

返回结果如下,

Load generated data done! sample_id[0:0] examples for evaluation.
Evaluate done! 100 samples in total.
mol_stable:	0.4600
atm_stable:	0.7834
recon_success:	0.7200
eval_success:	0.6900
complete:	0.6900
JS bond distances of complete mols: 
JSD_6-6|4:	0.4711
JSD_6-6|1:	0.2778
JSD_6-8|1:	0.4026
JSD_6-7|1:	0.4505
JSD_6-8|2:	0.4334
JSD_6-6|2:	0.4661
JSD_6-7|4:	0.7283
JSD_6-7|2:	0.6630
JSD_CC_2A:	0.3630
JSD_All_12A:	0.1021
Atom type JS: 0.1596
Number of reconstructed mols: 72, complete mols: 69, evaluated mols: 69
QED:   Mean: 0.479 Median: 0.463
SA:    Mean: 0.528 Median: 0.540
Vina Score:  Mean: 0.182 Median: -3.867
Vina Min  :  Mean: -4.120 Median: -5.745
ring size: 3 ratio: 0.000
ring size: 4 ratio: 0.043
ring size: 5 ratio: 0.667
ring size: 6 ratio: 0.739
ring size: 7 ratio: 0.594
ring size: 8 ratio: 0.101
ring size: 9 ratio: 0.043
Generated molecules saved as sdf format!

从返回结果可以看出,能够重构的生成分子有 69 个,QED 和 SA 均值分别为 0.479 和 0.528,生成分子中包含 五、六、七环的比例分别为 0.667、0.739 和 0.594。评估结果和生成的分子构象默认保存在 ./eval_results/ 中,重命名为 ./eval_results_retrain。相较于使用预训练好的模型(QED 和 SA 均值分别为 0.293 和 0.525,生成分子中包含 五、六、七环的比例分别为 0.463、0.610 和 0.354),QED 打分更高,包含的环结构更多。貌似我们训练的模型结果更好。

在 Generated_molecules.sdf 提取打分最低的 6 个分子,其 2D 结构如下图:

注:作者并没有提供训练先验模型 IPNet的代码,只提供了一个预训练好的checkpoint。但是这个模型仅仅是一个 mse 的回归任务,自己写一个 train.py 应该不难。

完整的测评文档以及修改后的代码下载:

https://download.csdn.net/download/wufeil7/89717499​​​​​​​


http://www.kler.cn/a/301351.html

相关文章:

  • sharding-jdbc自定义分片算法,表对应关系存储在mysql中,缓存到redis或者本地
  • QT仿QQ聊天项目,第三节,实现聊天界面
  • 开源音乐分离器Audio Decomposition:可实现盲源音频分离,无需外部乐器分离库,从头开始制作。将音乐转换为五线谱的程序
  • Vue2+ElementUI:用计算属性实现搜索框功能
  • 另外一种缓冲式图片组件的用法
  • pytorch tensor在CPU和GPU之间转换,numpy之间的转换
  • Hbase的简单使用示例
  • 在 RT-Thread 上使用单色屏 UI 库 - U8G2
  • 【Shiro】Shiro 的学习教程(四)之 SpringBoot 集成 Shiro 原理
  • 海外云手机是否适合运营TikTok?
  • Kubernetes部署(haproxy+keepalived)高可用环境和办公网络打通
  • Java 21的Preferences API的笔记
  • 分布式中间件-几个常用的消息中间件
  • redis基本数据结构-hash
  • 数据分析-11-时间序列分析的概念任务和主要方法
  • 第R2周:LSTM-火灾温度预测
  • C语言——希尔排序
  • Qt什么时候触发paintEvent事件
  • 【论文笔记】NDT: Neural Data Transformers (NBDT, 2022)
  • 一些深度学习相关指令
  • 【Qt】按钮样式--按钮内部布局(调整按钮文本和图标放置在任意位置)
  • 上海亚商投顾:沪指探底回升 华为产业链午后爆发
  • 【深度学习讲解笔记】第1章-机器学习基础(3)
  • Oracle Data Guard:Oracle数据库的高可用性和灾难恢复解决方案
  • 最近试用了FunHPC-AI宝箱-ComfyUI-Plus,使用了dreamshaperXL全能模型,生成了几张国风图,效果真的让人惊叹!
  • 安装MongoDB