基于提示驱动的潜在领域泛化的医学图像分类方法(Python实现代码和数据分析)
摘要
医学图像分析中的深度学习模型易受数据集伪影偏差、相机差异、成像设备差异等导致的分布偏移影响,导致在真实临床环境中诊断不可靠。领域泛化(Domain Generalization, DG)方法旨在通过多领域训练提升模型在未知领域的性能,但其依赖精确的领域标签,而医学数据通常缺乏此类标签。为此,我们提出一种无需领域标签的领域泛化框架——提示驱动的潜在领域泛化(Prompt-driven Latent Domain Generalization, PLDG)。该框架包含无监督领域发现与提示学习:首先通过聚类与偏差相关的风格特征生成伪领域标签,随后利用协作式领域提示引导视觉变换器(ViT)从多样化的潜在领域学习知识。通过领域提示生成器实现跨领域知识共享,并采用领域混合策略缓解伪标签噪声问题。在皮肤癌分类、糖尿病视网膜病变分类和组织病理学癌症检测等任务上的实验表明,PLDG无需领域标签即可达到或超越传统DG方法的性能。
关键词:领域泛化,提示学习,皮肤病学,皮肤癌,糖尿病视网膜病变
I. 引言
深度学习在医学图像分析中取得了显著进展,但其性能易受分布偏移的影响。例如,皮肤镜图像分类模型可能过度依赖标尺、凝胶气泡等伪影而非实际病灶特征;糖尿病视网膜病变(Diabetic Retinopathy, DR)分类模型可能过拟合特定相机的成像风格。此类偏差导致模型在真实临床场景中的泛化能力受限。传统领域泛化方法旨在通过多领域训练提升模型对未知领域的适应性,但其依赖预定义的领域标签。然而,医学数据中领域标签常面临以下挑战:
-
获取成本高:领域标签需人工标注,耗时费力;
-
定义模糊:医学图像的领域划分(如伪影类型、医院来源)缺乏统一标准,不同专家可能存在分歧;
-
任务依赖性:领域划分与下游任务强相关,难以跨任务迁移。
现有方法存在两大局限性:
-
数据集层面:依赖领域标签的假设不切实际;
-
算法层面:领域不变特征学习可能忽略对未知领域有用的信号,而集成学习方法未充分利用跨领域信息。
为此,我们提出潜在领域泛化(Latent Domain Generalization, LDG),通过无监督方式自动发现潜在领域并实现模型泛化。本文的核心贡献如下:
-
无需领域标签的框架:提出PLDG,通过聚类ViT浅层风格特征生成伪领域标签,结合提示学习实现跨领域知识迁移;
-
领域提示生成器:通过低秩分解促进领域提示间的知识共享;
-
领域混合策略:缓解伪标签噪声问题,增强决策边界灵活性;
-
广泛验证:在皮肤病变、DR分类、癌症检测及去偏任务中验证有效性,性能超越传统DG方法。
II. 相关工作
A. 领域泛化
传统方法包括:
-
领域对抗训练:如DANN通过对抗损失对齐特征分布;
-
统计对齐:如CORAL匹配二阶统计量;
-
元学习:通过模拟领域偏移优化模型鲁棒性。
近期研究表明,ViT因其对纹理偏差的弱敏感性,在DG任务中表现优于CNN。然而,现有方法仍依赖领域标签,且医学图像领域泛化研究较少。
B. 医学图像中的领域泛化
现有工作多依赖人工标注的伪影标签或数据集差异作为领域标签,但存在噪声和定义不准确问题。例如,Bissoto等人通过二元分类器标注皮肤数据集的伪影标签,但标注结果可能存在误差;Mohammad等人将不同DR数据集直接视为不同领域,忽略了数据集内部相机多样性。本文首次提出基于ViT风格特征的无监督领域发现方法,摆脱对预定义标签的依赖。
C. 提示学习
提示学习通过添加可学习向量适配预训练模型至下游任务。例如,VPT在ViT中插入可学习提示以微调模型;Doprompt为不同领域设计独立提示以捕获领域特定知识。与现有方法不同,PLDG引入领域提示生成器,通过共享提示与低秩分解实现跨领域协作学习。
III. 方法
A. 问题定义
B. 整体框架
PLDG框架如图1所示,包含以下步骤:
-
无监督领域发现:基于ViT浅层CLS令牌的风格特征聚类生成伪领域标签;
-
领域提示学习:通过领域提示生成器与混合策略优化模型,提升跨领域泛化能力。
C. 基于简约性偏差的伪领域标签聚类
深度学习模型存在简约性偏差(Simplicity Bias),即倾向于学习简单特征(如背景伪影)而非复杂语义特征。本文利用该特性,从ViT浅层(如第1层)提取CLS令牌风格特征,通过k-means聚类生成伪领域标签。风格特征对齐损失定义为:
D. 基于ViT的领域提示学习
2. 损失函数
IV. 实验结果
A. 实验设置
B. 对比实验
1. 皮肤癌分类(表I)
PLDG在Derm7pt_derm和PAD数据集上分别提升3.46%和14.18%,平均ROC-AUC达84.32%,优于DANN、CORAL等传统方法。
2. DR分类(表II)
PLDG平均准确率达75.6%,显著高于依赖领域标签的方法(如ERM++:72.1%),表明其在领域标签噪声场景下的优势。
3. 癌症检测(表III)
PLDG在Camelyon17-WILDS上准确率为89.7%,仅次于使用领域标签的EPVT(90.2%),验证其实际应用价值。
C. 消融实验(表IV、V)
逐步添加提示(P)、适配器(A)、混合(M)、生成器(G)组件,结果显示:
-
+P:平均ROC-AUC提升3.39%;
-
+P+A+M:进一步提升0.87%;
-
+P+A+M+G:最终提升1.26%,验证各模块的有效性。
D. 超参数分析(图4)
-
提示长度:4时性能最优;
-
聚类数:4时平均ROC-AUC最高,且对聚类数不敏感(2~5均表现良好)。
E. 领域提示权重分析(图5)
领域距离(Fr'echet距离)与提示权重呈负相关,表明模型能自适应关注与目标领域相似的源领域。
F. 聚类分析(图6、7)
-
ViT浅层(L1)CLS令牌聚类结果与类别标签无关(NMI=0.12),主要反映风格特征;
-
t-SNE可视化显示伪领域对应“墨水标记”、“暗角”、“深肤色”等医学相关偏差。
G. 去偏评估(图8)
在陷阱数据集中,PLDG在最高偏差等级(Bias=1)时ROC-AUC为68.5%,显著优于ERM(62.37%),表明其对分布偏移的鲁棒性。
V. 结论
本文提出PLDG框架,首次在医学图像分类中实现无需领域标签的潜在领域泛化。实验表明:
-
领域标签非必要:通过伪标签发现,PLDG性能媲美甚至超越传统DG方法;
-
跨领域知识共享:领域提示生成器有效促进知识迁移;
-
鲁棒性:领域混合策略缓解伪标签噪声,提升模型泛化能力。未来工作将扩展至多模态医学数据与实时部署场景。(代码QQandweichat)
![]() | ![]() |
参考文献
图1 传统领域泛化与潜在领域泛化对比
图2 PLDG算法流程
图3 领域提示生成器与混合策略示意图
图4 提示长度与聚类数对性能的影响
图5 领域提示权重与领域距离的关系
图6 伪领域标签与类别/领域标签的标准化互信息(NMI)
图7 伪领域标签的t-SNE可视化
图8 陷阱数据集去偏性能对比