【医学影像 AI】EyeMoSt+:用于眼科疾病筛查的置信度感知多模态学习框架
【医学影像 AI】EyeMoS𝑡+:用于眼科疾病筛查的置信度感知多模态学习框架
- 0. 论文介绍
- 0.1 基本情况
- 0.2 主要贡献
- 0.3 下载与引用
- 0.4 摘要
- 1. 引言
- 2. 相关工作
- 2.1 眼科疾病筛查的多模态学习
- 2.2 不确定性估计
- 3. 本文的方法
- 3.1 每种模态的预测和不确定性估计
- 3.2 混合学生 t 分布的置信度融合
- 3.3 置信度感知的多模态排序
- 3.4 多模态学习过程
- 4. 试验
- 4.1 数据集和训练的信息
- 4.2 比较方法和指标
- 4.3 鲁棒性检验
- 4.4 分布外检测
- 4.5 缺失模态的试验
- 4.6 消融研究
- 5. 结论
- 6. Github 项目
- 6.1 训练
- 6.2 测试
- 6.3 核心程序代码(EyeMoSt/MedIA’24/train3_trans.py)
0. 论文介绍
0.1 基本情况
2024年5月,四川大学袁学东、汕头大学陈浩宇等在 Medical Image Analysis 上发表研究论文:Confidence-aware multi-modality learning for eye disease screening,提出一种用于眼科疾病筛查的置信度感知多模态学习框架 EyeMoS𝑡+。
EyeMoS𝑡+ 框架通过融合来自不同来源的信息来提高诊断的准确性,特别是针对糖尿病视网膜病变(DR)、糖尿病性黄斑水肿(DME)和年龄相关性黄斑变性(AMD)等眼科疾病。 通过融合不同模态的信息,并引入置信度感知机制,提高了眼科疾病筛查的准确性和鲁棒性。
0.2 主要贡献
- 本文提出了一种新的基于置信度的多模态眼病筛查方法。
- 本文提出MoSt来动态适应重尾数据分布并捕获特定模态的不确定性。
- 本文为多模态眼病筛查引入了一种创新的基于置信度的排名正则化术语。
- 本文在不同眼病的公共和内部数据集上进行了充分的实验,这些实验清楚地验证了正常和OOD测试样本的准确性、稳健性和可靠性。
0.3 下载与引用
论文下载地址:
https://doi.org/10.1016/j.media.2024.103214
https://arxiv.org/pdf/2405.18167
论文引用:
Ke Zou, Xuedong Yuan, Haoyu Chen, et al. Confidence-aware multi-modality learning for eye disease screening, Medical Image Analysis, Vol. 96, 2024
@InProceedings{uMedGround_Zou_2024,
author="Zou, Ke
and Lin, Tian
and Yuan, Xuedong
and Chen, Haoyu
and Shen, Xiaojing
and Wang, Meng
and Fu, Huazhu",
title="Reliable Multimodality Eye Disease Screening via Mixture of Student's t Distributions",
journal={arXiv preprint arXiv:2404.06798},
year={2024}
}
开源代码:
https://github.com/Cocofeat/EyeMoSt
0.4 摘要
多模态眼科图像分类在诊断眼部疾病中起着关键作用,因为它整合了来自不同来源的信息,以补充其各自的性能。然而,最近的改进主要集中在准确性上,往往忽视了对不同模式预测的信心和稳健性的重要性。
在本研究中,我们提出了一种用于眼部疾病筛查的新型多模态证据融合管道。它为每种模态提供了一个置信度度量,并使用多分布融合视角优雅地整合了多模态信息。
具体来说,我们的方法首先利用预训练模型上的正态逆伽玛先验分布来学习单模态的任意和认知不确定性。然后,将正态逆伽马分布分析为Student分布。此外,在置信度感知融合框架内,我们提出了Student分布的混合,以有效地整合不同的模态,赋予模型重尾特性,并增强其鲁棒性和可靠性。更重要的是,置信度感知的多模态排序正则化项使模型能够更合理地对有噪声的单模态和融合模态置信度进行排序,从而提高了可靠性和准确性。
在公共和内部数据集上的实验结果表明,我们的模型在鲁棒性方面表现出色,特别是在涉及高斯噪声和模态缺失条件的具有挑战性的场景中。此外,我们的模型对分布外数据表现出很强的泛化能力,突显了其作为多模式眼病筛查有前景的解决方案的潜力。
关键词:不确定性估计,眼病,多模态学习
1. 引言
糖尿病视网膜病变(DR)和糖尿病黄斑水肿(DME)是工作年龄人群永久性视力障碍的主要罪魁祸首[23]。年龄相关性黄斑变性(AMD)是全球失明的另一个主要原因,息肉状脉络膜血管病(PCV)是AMD的一种亚型,尤其在亚洲人中更为常见[26,8]。在计算机视觉的巨大发展[21,14,45]的推动下,在计算机辅助检测下对上述眼部疾病进行筛查和持续监测迫在眉睫。
视网膜眼底图像(Fundus)和光学相干断层扫描(OCT)是眼科疾病筛查的常见2D和3D成像技术。这促使研究人员将上述方法结合起来,以提高眼科疾病筛查的性能。毕竟,多模态学习通常比单模态学习提供更多的补充信息[54]。根据融合阶段,现有的多模态学习方法大致可分为早期、中期和晚期融合[3]。对于多模态眼科图像学习,最近的工作[52,16,24,40,48,15,36,6,25]主要集中在早期融合[16,24,40]和中间融合阶段[52,48,15,36,16,25]。先前的研究通常在融合过程中直接结合不同眼睛图像模态的特征。然而,这可能会导致从噪声模态中收集误判的特征,从而导致不正确的预测结果 y ^ F \hat{y}_F y^F,如图1(a)所示。为了应对这一挑战,我们在方法中利用不确定性估计,从单个模态分布的角度评估单模态的可靠性。如图1(b)所示,我们估计了单模态 { y ^ m , U m } m = 1 , 2 \{\hat{y}_m,U_m\}_{m=1,2} {y^m,Um}m=1,2 的预测和不确定性,然后利用置信度的分布融合来推导最终的预测及其不确定性 { y ^ F , U F } \{\hat{y}_F,U_F\} {y^F,UF}。这一努力对于确保临床安全性和可靠性至关重要,特别是在处理来自任何一种图像类型的干扰时,不确定性是整合多模态分布的可靠指标。
不确定性估计为模糊或不确定的网络预测提供了理由。特别是,当一个模型遇到以前从未见过的数据或被噪声污染的输入时,它可以用拼写“我不知道”来表达不确定性,并且可以量化这种不确定性的程度。如[19]所述,不确定性估计包括两种类型:任意不确定性和认知不确定性。任意不确定性是观测数据中固有的,源于基础过程中固有的随机性或可变性。相比之下,认知不确定性源于我们的知识或模型的局限性,表明可以通过额外的数据或改进的模型来减少或消除不确定性。目前的不确定性估计方法主要包括贝叶斯神经网络、深度集成(DE)、基于确定性的方法。贝叶斯神经网络通过将网络权重视为随机变量,使用拉普拉斯近似[30]、马尔可夫链蒙特卡洛[34]和变分推理技术[38]来学习网络权重的分布。受收敛挑战的影响,这些方法需要大量的计算,直到网络中的丢包在一定程度上得到缓解[19]。与其学习分布,估计不确定性的另一种简单方法是学习一组深度网络[22]。为了减轻计算复杂性和过度自信[42,47],许多基于确定性的方法[42,31,44,27]被设计为在网络的单次前向传递中直接输出不确定性。上述方法大多侧重于具有不确定性估计的单模态。如何意识到多模态不确定性并在原则上将其融合仍有待研究。最近,[46]提出了一种通过跨模态随机网络预测来估计不确定性的不确定性感知多模态学习器。值得注意的是,仅依靠交叉注意力进行多模态特征融合可能无法优化融合后的性能,特别是在存在噪声模态的情况下,如图1所示。不同的是,我们的方法为每种模态提供了置信度分数,并使用多分布融合的视角优雅地整合了多模态信息。
图1:眼科疾病筛查的多模态分类方法比较。(a) 传统的多模式眼病筛查。(b) 本文的置信度感知多模态学习用于眼部疾病筛查。
y
^
\hat y
y^ 和
U
U
U 分别表示预测及其不确定性。
我们在本文提出了 EyeMoSt+,这是一种新的具有置信度的多模态眼病筛查方法,旨在促进眼底和OCT模态的可靠融合。
我们的方法利用预训练模型上的正态逆Gamma(NIG)先验分布来学习单模态的任意和认知不确定性。通过将 NIG 先验解析求解为 Student 的 t 分布,我们将其转化为混合Student 的 t 分布融合问题。
为了赋予模型全局不确定性和鲁棒性,我们引入了一种用于混合 Student t(MoSt)分布的置信度融合策略。
为了防止在存在模态噪声的情况下全局置信度的上升,我们利用了一种新的基于置信度感知的排序正则化方法。
为了验证我们提出的方法的有效性,我们在三个数据集上进行了广泛的实验,涵盖了各种眼部疾病,如青光眼分级、AMD、PCV、DR和DME。这些实验强调了EyeMoSt+在眼部疾病多模式筛查中的可靠性和鲁棒性,突出了其在处理噪声输入、识别缺失模式和处理看不见的数据方面的有效性。综上所述,本文的贡献主要包括:
(1) 我们提出了一种新的具有置信度的多模态眼病筛查方法,该方法为具有可靠性和鲁棒性的分类提供了新的证据多模态范式。
(2) 为了整合不同的模态,设计了一种新的MoSt,可以动态地感知每种具有不确定性的模态的重尾和置信度,这有望提供显著的鲁棒性,并促进可靠的决策。
(3) 为了解决单模态和融合模态之间的置信关系,我们提出了一种新的用于多模态眼病筛查的置信感知排名正则化项。
(4) 我们在包括各种眼病的公共和内部数据集上进行了全面的实验,以彻底验证我们模型的准确性、鲁棒性和可靠性,包括其在分布外(OOD)测试样本上的性能,如高斯噪声、缺失模态和看不见数据的样本。
我们比较了与我们的研究直接相关的三种方法[2,41,56],概述如下:
(1) 与证据深度回归方法[2]相比,我们的工作将该框架扩展到眼科分类领域,引入了置信度感知证据多模态融合。利用[2]方法,我们采用证据先验分布NIG来表征不同模态的置信度(方程1至4)。然而,原始方法缺乏整合不同模态的预测和置信度的解决方案。为了解决这一局限性,我们提出了MoSt,这是一种动态融合方法,融合了来自不同模态的预测和不确定性(方程7)。此外,为了确保融合模态的置信度始终超过单模态的置信水平,我们引入了一种为多模态眼病筛查量身定制的新的置信度感知排名正则化项(式12)。
(2) 与[41]中融合涉及两个Student-t分布的情况相比,我们增强了这一过程,并提出了一个用于置信度感知的排名正则化项。基于[41],我们的置信感知融合策略(式7)是一个改进版本。我们假设多个 Student 的 t 分布在融合后仍然是近似的 Student t 分布,自由度
v
F
v_F
vF 和
∑
F
∑_F
∑F 与[41]一致。
v
F
v_F
vF 的计算引入了两种模态的置信度 C。此外,我们为不同的模态构建了一个证据先验分布 NIG,并将其转换为两个 Student 的 t 分布进行融合(式1)。此外,引入了一种新的用于多模态眼病筛查的置信度感知排名正则化项,以建立融合模态的置信度与每种单一模态的置信度之间的排名关系(式12)。
(3)与我们之前的会议版本[56]相比,我们进一步增强了 Student t 分布混合的融合过程(式7),引入了一种创新的置信感知多模态学习排名组件。我们的贡献还包括更稳健的验证、OOD数据检测和实际应用中的缺失模态实验。我们改进了 Student t 分布混合的融合模式,结合了式 10~12 和式18,为多模态眼病筛查提出了一个新的置信感知排名正则化项。此外,我们丰富了第4.3节和第4.4节中的稳健实验验证,并在不同场景下进行了OOD数据验证。最后,在第4.5节中添加了缺失模态实验,旨在全面验证所提出算法的鲁棒性和可靠性。
2. 相关工作
本节首先简要回顾了用于眼病筛查的多模态学习。然后,介绍了不同的不确定度量化方法。
2.1 眼科疾病筛查的多模态学习
根据不同阶段多模态融合的整合,现有的用于典型眼科疾病筛查的多模态图像方法可分为早期、中期和晚期融合的方法[3]。
早期的基于融合的方法直接在数据级别集成多种模态,通常是通过连接原始或预处理的多模态数据。Rodrigues等人[40]倾向于在早期融合阶段使用基于灰度和血管连接属性的互补特征。以下方法倾向于早期融合特殊的原始数据,而不是直接拼接多模态原始图像。Hua等人[16]在早期将预处理的眼底图像和宽场扫描源光学相干断层扫描血管造影(OCTA)相结合,然后提取代表性特征进行DR识别。Li等人[24]通过CycleGAN获得合成的FFA数据[55],然后将成对的FFA和Fundus数据输入卷积神经网络(CNN)。他们试图学习用于视网膜疾病诊断的模态不变特征和患者相似性特征。早期融合阶段的方法可以最大限度地保留原始图像信息,目前大多数人在疾病筛查的中间阶段进行多模态眼科图像融合。
中间融合策略允许在网络的不同中间层融合多种模态。Yoo等人[52]首先尝试在中间融合阶段从多模态图像中诊断AMD。他们使用预先训练的VGG-Net模型提取特征,然后将它们聚合在一起,通过随机森林分类器诊断AMD。与直接聚合预训练模型提取的特征不同,Wang等人[49,48]使用类激活映射训练了端到端的双流CNN,然后将来自Fundus和OCT流的信息连接起来。同样,Ou等人[36]和He等人[15]利用CBAM[50]和模态特异性注意机制提取了不同的模态特征,然后将它们连接起来,实现了视网膜图像分类的多模态融合。Cai等人[6]在早期和中期融合阶段很好地捕捉了嵌入在眼科图像中的领域特定特征,以实现分类。
上述早期和中期融合阶段的方法过于简单,缺乏利用眼底和OCT模态之间的互补信息。因此,Li等人[25]结合了网络多个维度的特征,并通过分层融合策略探索了它们之间的关系。然而,在融合的后期阶段,每种模态的特征和层次融合的特征只会被连接起来进行眼病分类。[9]使用超宽场彩色眼底摄影(UWF-CFP)成像和OCTA识别DR疾病,同时采用了多种混合策略来增强级联特征的泛化。在融合后期,应更加关注如何稳健可靠地组合这些多个模型的预测。
因此,在本文中,我们试图在多模态眼病筛查的后期融合阶段,重点研究基于不确定性估计的自适应融合。我们的目标是使用多分布融合方法整合来自各种模态的信息,特别强调学生的t分布融合。尽管之前在医学图像配准[11,39]和分割[35]中探索过这种方法,但它为推进眼部疾病筛查提供了有前景的途径。例如,MoSt算法中的像素相似性被引入到多模态医学图像的刚性配准中[11]。在此基础上,Ravikumar等人[39]提出了基于组的相似性配准,以增强对应性并更稳健地对齐形状。受上述方法的启发,我们提出了置信度感知融合框架内的MoSt,以有效地整合不同的模式,实现稳健可靠的眼部疾病诊断。
2.2 不确定性估计
不确定性量化提供了可靠的预测和置信水平,这对于推进可解释的深度神经网络(DNN)至关重要[1]。贝叶斯神经网络(BNNs)[4,32,17]通过学习确定性参数的分布来模拟不确定性。常用的技术包括拉普拉斯近似[30]、马尔可夫链蒙特卡罗[34]和变分推理[38]。尽管BNN对过拟合问题具有鲁棒性,但它们的计算量可能高得令人无法接受。为了解决这个问题,Kendall等人[19]引入了一种简单的方法,利用贝叶斯深度学习和蒙特卡洛Dropout(MCDO),在计算机视觉的背景下对任意和认知不确定性进行建模。DE[22]训练并集成了多深度学习模型来产生不确定性。然而,仍然存在一定的内存消耗和计算成本。
最近,基于确定性的方法[27,44]被设计为通过单次前向传递来估计不确定性,而不需要太多的采样和时间成本。Van Amersfoort等人[44]提出,基于构建径向基函数网络的思想,将测试样本和原型之间的距离作为确定性不确定性来测量。此外,[18]认为,多模态数据的不确定性估计仍然是一个挑战。他们引入了多模态神经过程,结合了几种创新和有原则的机制,旨在解决多模态数据的特定特征。[53]引入了质量感知多模态融合,以实现鲁棒的多模态融合。然而,他们没有明确描述每种模态的任意和认知不确定性,这可能会限制模型有效感知和区分数据质量的能力。
本文将我们的模型从深度证据回归[2]扩展到多模态融合的分类。我们专注于特定模态的不确定性估计,以及如何更自信、更可靠地融合多模态估计。
3. 本文的方法
本节介绍了 EyeMoSt+ 的总体框架,该框架有效地估计了每种模态的任意和认知不确定性,并原则上自适应地整合了融合模态。
如图2所示,我们首先使用预训练的CNN或变换器编码器来捕获不同的模态特征。然后,我们在训练好的网络后放置多个证据头,并对每种模态的高阶证据分布参数进行建模。为了合并这些预测分布,我们推导出了Student t(St)分布的正态逆Gamma分布。特别是,引入了St分布混合的置信度融合(MoSt),原则上整合了不同模态的St分布。此外,提出了一种新的置信感知排名正则化项来约束单峰和融合模态之间的置信关系。最后,我们详细阐述了模型证据获取的训练管道。
3.1 每种模态的预测和不确定性估计
给定多模态眼科数据集 D = { { x m i } m = 1 M } D=\{ \{ x^i_m \}^M_{m=1} \} D={{xmi}m=1M} 和相应的标签 y i y^i yi,直观的目标是学习一个可以对不同类别进行分类的函数。在眼科图像中,OCT和眼底是常见的成像方式。因此,这里 M M M=2, x 1 i x^i_1 x1i 和 x 2 i x^i_2 x2i 分别表示OCT和眼底输入模态数据。
我们首先加载预训练的基于2D CNN的[10]或基于 transformer 的[28]骨干编码器 Θ Θ Θ 和基于3D CNN的[7]或基于互感器的[13]骨干编码器 Φ Φ Φ,以识别特征级的信息量,分别定义为 Θ ( x 1 i ) Θ(x^i_1) Θ(x1i) 和 Φ ( x 2 i ) Φ(x^i_2) Φ(x2i) 。我们将深度证据回归模型[2]扩展到眼部疾病筛查的深度多模态证据分类。为此,为了模拟每种模态的不确定性,我们假设观测标签 y i y^i yi 来自高斯分布 N ( y i ∣ μ , σ 2 ) N(y^i|μ,σ^2) N(yi∣μ,σ2),其均值和方差由正态逆伽玛(Normal-Inverse-Gamma, NIG)先验分布决定。
其中KaTeX parse error: Double superscript at position 7: Γ^{−1}^̲ 是逆伽马分布。具体来说,多证据头(Multi-Evidential Heads)将放置在编码器Θ和Φ之后(如图2所示),编码器输出前述的 NIG参数pm=(γm,δm,αm,βm)。因此,Aleatoric Uncertainty(偶然不确定性)和Epistemic Uncertainty(认知不确定性)可以分别通过
E
[
σ
2
]
E[σ^2]
E[σ2]和
V
a
r
[
μ
]
Var[μ]
Var[μ]来估计:
之后,可以推导出Student的t预测分布,该分布由每种模态的先验似然和高斯似然的相互作用形成,由下式给出:
当在高斯似然函数上放置NIG证据时,存在如下解析解:
其中,
o
m
=
β
m
(
1
+
δ
m
)
δ
m
α
m
o_m = \frac{\beta _m(1+\delta_m)}{\delta_m \alpha_m}
om=δmαmβm(1+δm)。因此,这两种模态分布被转化为Student的 t分布:
3.2 混合学生 t 分布的置信度融合
然后,我们专注于融合来自不同模态的多个St分布。如何合理地将多个St整合为一个统一的St是关键问题。为此,联合分布模式可以表示为:
然后,联合t分布为:
为了保持闭合的Student t分布形式和融合模态的重尾特性,更新的参数由[41]给出。简单来说,我们首先将两个分布的自由度调整为一致。如[41]所述,自由度值越小,尾部越重,而方差值越大,尾部也越重。此外,考虑到Student t分布的方差公式(如方程9),重要的是要注意,随着 v v v 的增加,方差减小,表明置信度更高。我们假设多个 Student 的 t分布在融合后仍然是近似的 Student的 t分布。假设 v 1 v_1 v1 的自由度小于 v 2 v_2 v2,则融合的Student t分布 S t ( y i ; u F , ∑ F , v F ) St(y^i; u_F, ∑_F, v_F) St(yi;uF,∑F,vF) 将更新为:
其中C1和C2表示单模态分布的置信度,可以定义为:
因此,融合模态的预测和不确定性可以通过以下方式进行估计:
其中,
p
F
p_F
pF 是融合后 St 分布的参数,可以表示为
p
F
=
(
u
F
,
σ
F
,
v
F
)
p_F=(u_F, σ_F,v_F)
pF=(uF,σF,vF)。MoSt的置信度感知融合如图3所示。
3.3 置信度感知的多模态排序
在当前的多模态方法中,直接融合潜在损坏的模态是一种普遍的做法,导致模型可靠性和识别误差受损。准确定义每种模态的可靠性带来了挑战,特别是在处理同一样本的不同模态的不同置信水平时,尤其是在存在噪声数据的情况下。幸运的是,对置信度估计的监督可以作为一种可行的替代方案。从“信息的本质是消除不确定性(香农)”的信息学原理中汲取灵感[43],其中更多的信息意味着减少不确定性。也就是说,对于一个可靠的多模态分类器,整合多模态信息将消除信息中的不确定部分,从而得到更可靠的结果。基于这一假设,我们引入了一个基于排名的正则化项[29]。该术语限制了单模态和融合模态之间的关系,确保融合模态的置信水平始终超过每种单独模态的置信度。因此,该模型的可靠性得到了显著提高。
具体来说,我们首先直接最小化单模态和融合模态之间的置信度差异,如下所示:
其中
C
(
⋅
)
C(·)
C(⋅) 表示模态的置信度,由 softmax 层对数几率(logits)来计算[29]。尽管存在模态污染,但融合模型有时可以产生准确的预测。因此,我们只关注对正确预测的置信度进行正则化,同时避免对单个模态的置信度最小化。对于任何眼科模态图像,上述公式可以放宽如下:
针对每个样本,在任意模态和融合模态对上整合所提出的置信感知多模态排名损失,形式化如下:
所提出的置信感知多模态正则化项是通用的,可以作为额外的损失项无缝地合并到当前的证据深度学习框架中,以约束其置信度估计。图3②展示了置信度感知的多模态排序规则。这种集成增强了模型的可靠性,提高了性能的稳健性。
3.4 多模态学习过程
在证据学习框架下,我们期望为每种模态收集更多的证据,因此,所提出的模型有望最大化模型证据的似然函数。等效地,该模型有望最小化负对数似然函数,该函数可以表示为:
然后,为了适应分类任务,我们将交叉熵项
L
m
C
E
L^{CE}_m
LmCE 引入到方程(7)中:
其中
λ
m
λ_m
λm是平衡因子,设置为与[2]相同。
η
m
η_m
ηm是总体模型证据,可以表示为:
同样,对于融合模态,我们首先最大化模型证据的似然函数,如下所示:
然后,为了获得更好的分类性能,交叉熵项
L
m
C
E
L^{CE}_m
LmCE 也被引入到方程(16) 中,如下所示:
当
λ
F
λ_F
λF 作为平衡因子时,通过消融研究确定其最佳选择(第4.6节)。总之,多模态筛选的证据深度学习过程可以表示为:
λ
C
λ_C
λC 是一个关键的超参数,它控制着置信度感知多模态学习正则化的效力。根据[29]的见解和消融研究的结果(第4.6节),其值设置为10。本文主要讨论使用两种模式的眼部疾病:OCT和眼底成像。因此,参数M被设置为2。所提出的EyeMoSt+的过程如算法1所示。
4. 试验
4.1 数据集和训练的信息
实验数据集:
本文对EyeMoSt+模型在公共和私人数据集上的性能进行了全面评估:GAMMA数据集[51]、OLIVES[37]和为本研究开发的内部数据集。下面详细介绍了不同疾病的不同数据集。
- GAMMA青光眼识别数据集:为了评估我们提出的青光眼识别方法的有效性,我们在GAMMA数据集上评估了其性能[51]。该数据集包括100个配对病例,每个病例都被分配了三级青光眼分级。OCT和近红外眼底图像的原始图像大小为256×512×992和1956×1934,其中256是OCT切片的总数。有关原始数据集的更多详细信息,请参阅[51]。这些案例被仔细地分为训练和测试子集,分别包含80%和20%的案例。为了减轻附带因素对绩效评估的影响,采用了严格的五重交叉验证策略。
- 用于DR和DME筛查的 OLIVES数据集:随后使用OLIVES数据集验证了所提出的算法在识别DR和DME方面的有效性[37]。该数据集包括96名患者在多个周期内的配对OCT和近红外眼底图像,共产生3128对配对病例。更具体地说,它包括56名DME患者和40名DR患者,共有1837个DME样本和1291个DR样本。OCT和近红外眼底图像的原始图像大小为48×504×496和768×768,其中48是OCT切片的总数。关于原始数据集的更多细节可以在[37]中找到。为了确保可靠的实验,我们将数据集划分为训练、验证和测试子集,保持8:1:1的比例。
- AMD和PCV筛查的内部数据集:最后,我们的方法使用汕头大学汕头国际联合眼科中心获得的独家内部数据集进行了严格的测试,使用Topcon 3D OCT-2000作为OCT和眼底采集设备。该数据集包括149例AMD和178例PCV,共涉及327名患者。一些病例同时具有左眼和右眼图像,共产生604对OCT和眼底图像,包括265个AMD样本和341个PCV样本。OCT和眼底图像的原始图像大小为128×512×885和2100×2000,其中128是OCT切片的总数。遵循既定的实践,我们将患者队列分为培训、验证和测试子集,保持一致的8:1:1患者比例,以进行可靠的实验。
训练细节:
我们提出的方法在PyTorch中实现,并在NVIDIA GeForce RTX 3090上训练。Adam优化[20]用于优化总体参数,初始学习率为0.0001。纪元的最大值是100。GAMMA数据集中的原始图像大小与内部数据集中的图像大小相当。
对于使用GAMMA数据集的OOD相关实验(详见第4.4节),我们分别将OCT和Fundus图像的原始大小统一调整为128×256×128和256×256。鉴于OLIVES与GAMMA原始图像和内部数据集之间的尺寸差异,我们的目标是优化NVIDIA GeForce RTX 3090图形卡内存的利用率,同时最大限度地保留原始图像中的信息。因此,我们将OLIVES数据集的眼底和OCT图像大小调整为512×512和48×248×248。批量大小设置为16。
在以下所有涉及添加高斯噪声的实验中,我们应用了十倍随机添加策略来减轻随机因素导致的任何性能改进。
需要注意的是,根据所使用的编码器,我们提出的EyeMoSt+可分为两个版本,EyeMoSt+(CNN)和EyeMoSp+(Transformer)。因此,对于EyeMoSt+(Transformer)版本,为了符合预训练模型的输入要求[28,13],我们将所有数据集的眼底和OCT的输入大小调整为384×384和96×96×96。
4.2 比较方法和指标
比较方法:
我们比较了以下六种方法:对于不同的融合阶段策略,
- a)基于CNN的中间典型融合方法的B-CNN基线,
- b)基于变换器的中间典型整合方法的B-Transformer基线,
- c)早期融合策略的B-EF基线[16],
- d)中间融合方法和后期融合方法的M2LC[50],
- e)TMC[12]用作比较。
B-EF首先在数据级别进行集成,然后通过相同的MedicalNet[7]。B-CNN和B-transformer首先通过编码器提取特征(与我们相同),然后将它们的输出特征连接起来作为最终预测。
此外,我们比较了:
- f)SmartDSP[5]和
- g)EyeStar[51],后者在GAMMA数据集上排名第一和第三。
对于不确定性量化方法,
- h)MCDO[19]采用测试时间损失作为贝叶斯神经网络的近似值。
- i)DE[22]通过整合多个模型来量化不确定性。
对于所有基线和我们提出的EyeMoSt+,我们使用准确性(ACC)指标根据验证性能选择了最佳检查点进行测试。
绩效指标和评估:
在我们的评估中,我们采用了ACC和Kappa指标,为将我们的方法与其他现有方法进行比较提供了直观的基础。为了量化有序排序的有效性,我们使用了风险覆盖面积(AURC)。这使得我们能够全面了解我们的方法的风险估计如何与实际结果相一致。对于校准,我们采用预期校准误差(ECE)[32]度量。
4.3 鲁棒性检验
1) GAMMA数据集:
我们首先在表1中的GAMMA数据集中使用最先进的方法报告了我们的算法。我们在正常情况下对各种方法进行了性能比较,包括GAMMA挑战赛中排名第一的SmartDSP(第一名)和EyeStar(第三名)方法。根据表1 中正常条件下ACC、Kappa和ECE的三个指标,我们观察到我们提出的EyeMoSt+(CNN)方法取得了相当的性能,稳居第二位。
为了评估所提出方法的稳健性,我们在测试期间分别向眼底模态或OCT模态引入了σ=0.1或σ=0.3的高斯噪声。如表1所示,在Fundus模态中添加σ=0.1高斯噪声导致所有方法的性能显著降低。然而,我们的方法,EyeMoSt+(CNN)和EyeMoSt+(Transformer),仍然具有可比性,其中EyeMoSt+(Transformers)表现出最佳性能。当σ=0.3高斯噪声被添加到OCT模态中时,几乎所有方法都无效,我们提出的EyeMoSt+方法保持了很高的识别精度。在假设检验的背景下,我们计算了噪声条件下所有方法与我们提出的方法获得的最佳结果之间的显著性差异,如p值所示,如表1所示。根据表1中的p值,我们提出的方法与TMC[12]、EyeStar[51]和B-transformer相比表现出了区别。值得注意的是,与其余方法相比,它显示出明显的差异。
此外,在更一般的情况下,我们展示了眼底或OCT模态在不同噪声条件下(σ=0.1、0.2、0.3、0.4、0.5)的ACC和ECE指标,如图4所示。随着噪声的增加,我们提出的两种方法都表现出了最佳性能,突显了我们融合方法的有效性。具体而言,在眼底模式的噪声条件下,EyeMoSt+(Transformer)表现出卓越的性能,而在OCT模式的噪声情况下,EyeMoSt+(CNN)也表现出最佳性能。这种变化可归因于预训练编码器的差异。总体而言,EyeMoSt+(CNN)在正常和嘈杂的条件下表现最佳。
2) OLIVES和内部数据集:
我们在表2和表3中进一步将我们的算法与OLIVES和内部数据集上的不同方法进行了比较。我们在干净的多模态眼部数据下比较了这些方法。我们的方法在ACC和Kappa方面取得了有竞争力的结果。需要注意的是,在OLIVES数据集上,由于数据丰富且易于分类,大多数方法都能达到完美的准确率。
然后,为了验证我们模型的稳健性,我们在数据集上的Fundus或OCT模态(σ=0.5/0.3)中添加了高斯噪声。我们发现,当所有方法遇到受高斯噪声影响的Fundus模态时,它们都能在一定程度上保持其性能。然而,一旦高斯噪声被引入OCT模态,它们的性能就会受到明显影响,如表2和表3所示。具体而言,与早期融合B-EF[16]、中间融合方法B-CNN、B-Transformer和M2LC[50]相比,EyeMoSt[56]在噪声OCT和眼底模态中均显示出增强或保持的分类准确性。同时,我们提出的EyeMoSt+方法表现出最佳性能,在各种条件下始终排名第一或第二。
与晚期融合方法TMC[12]相比,我们的EyeMoSt+在正常条件下表现出相当的性能,在嘈杂的眼底或OCT模式下表现出卓越的性能。我们还计算了在OLIVES和内部数据集上,我们提出的方法与其他方法在噪声条件下的最优结果之间的显著性差异。表2和表3中的p值结果表明,与其他方法相比,存在显著差异。
更一般地说,我们在眼底或OCT模态中添加了不同的高斯噪声(σ=0.1、0.2、0.3、0.4、0.5),如图5(a)和(b)所示,以展示它们对ACC和Kappa指标的影响。从图5中可以得出同样的结论,我们的EyeMoSt+在嘈杂的眼底或OCT模式中都表现出了卓越的性能。此外,我们发现,与眼底模式相比,噪声OCT模式对性能的影响更大。基于上述实验,我们可以得出结论,我们的EyeMoSt+不受任何噪声模态的影响,并在正常条件下实现了相当的性能。这种弹性可以归因于置信度分布融合和多模态排名损失,这使得在噪声模态下能够进行鲁棒融合。图5(c)显示了内部数据集上原始和不同噪声与眼底或OCT模态的视觉比较。
3) 不确定性估计和推理时间:
为了进一步量化不确定性估计的可靠性,我们使用GAMMA数据集上的ECE指标比较了不同的不确定性估计算法[22,19,12]。如表1和图4所示,我们提出的算法在干净模态中表现出相当的性能,在有噪声的单模态情况下表现出更稳健的性能。图5(a-b)显示了ECE和AURC在OLIVES和内部数据集上的更多比较。从这一观察中观察到类似的实验结论。最后,我们比较了不同算法在不同数据集上单个测试样本的平均推理时间,如表1、表2和表3所示。如这些表所示,与没有不确定性估计的方法相比,我们的EyeMost+的运行时间几乎没有改善,但提供了更准确和更稳健的性能。与不确定性估计方法相比,我们的EyeMoSt+(CNN)的处理速度比其他不确定性估计法(MCDO[19]、DE[22]和TMC[12])更快,尽管它比之前的版本EyeMoSt[56]稍慢。
值得注意的是,与其他不确定性估计方法相比,EyeMoSt+(Transformer)实现了最佳的处理速度,这归因于输入图像大小的减小。总体而言,EyeMoSt+(CNN)实现了更稳健的精度和可靠的不确定性估计。因此,在以下部分中,我们将只介绍EyeMoSt+(CNN)的结果。
4.4 分布外检测
根据[33],OOD数据可分为两大类:移位样本,与分布内(ID)数据相比,其表现出视觉差异但语义相似;以及接近OOD样本,其具有感知相似性但相对于ID数据具有不同的语义。为了复制这些场景,我们引入了噪声来创建偏移样本,并在我们的实验设置中用GAMMA数据集的眼底图像替换内部眼底图像,以生成接近OOD的样本。为了推进我们在多模态眼科临床应用中对不确定性估计的追求,我们对眼部数据进行了单模态和多模态不确定性分析。
1) 移位眼睛数据的不确定性分析:在我们的第一次分析中,我们在OLIVES和内部数据集中的单模态数据(眼底或OCT)中引入了不同水平的高斯噪声,以模拟移位的OOD数据。没有噪声的原始样本被标记为分布(ID)数据。图6(a)显示了不确定性和OOD数据之间的显著相关性。单模态图像的不确定性随着噪声的增加而成比例增加。这一观察强调了不确定性作为评估单模态眼部数据可靠性的指标的作用。此外,我们还研究了引入高斯噪声前后单模态和融合模态的不确定性密度。图6(b)提供了一个示例,即在内部数据集上向眼底模态或OCT模态添加σ=0.3的噪声。引入噪声后,融合的不确定性增加,导致熵分布图向右移动。值得注意的是,融合模态的分布与无噪声模态的分布更为接近(图6(b)(2)-(3))。因此,我们提出的方法可以作为评估眼科多模态数据融合中模态可靠性的有价值的工具。
2) 近OOD眼睛数据的置信度分析:此外,我们用以前从未见过的GAMMA[51]数据集中的眼底模态图像替换了内部数据集中的视网膜模态,以模拟近OOD数据。内部数据集的眼底和OCT图像输入被视为ID数据。如表4所示,与分布数据相比,各种方法的性能在不同程度上下降,而我们的方法保持了稳健的性能。我们还分析了眼底模态和融合模态划分内外数据的置信度分布的变化。
如图7所示,我们观察到眼底模态和OOD数据融合模态的置信度降低。然而,与眼底模式相比,融合模式的置信度下降不太明显。因此,我们提出的方法可以作为一种有效的OOD检测器,促进多模态眼病筛查中可靠和稳健的决策。
4.5 缺失模态的试验
在现实世界的临床诊断中,包含配对眼底和OCT图像的数据集通常是有限的。因此,我们在具有缺失模态场景的内部数据集上将我们的方法与其他方法进行了比较。
为了模拟一种模态的缺失,我们将缺失模态的输入设置为0。表5显示,即使Fundus模态缺失,包括我们的算法在内的大多数算法也能保持一定水平的性能。相反,在没有OCT模态的情况下,大多数算法的性能都会显著下降。值得注意的是,我们提出的算法表现出了持续的性能。
这一观察表明,在没有其他模态的情况下,大多数算法倾向于过度依赖特定的模态,如OCT模态,导致性能自然下降。我们提出的算法通过评估每种模态的置信度和可靠性,动态地结合了更可靠的模态。通过这样做,我们的方法解决了其他算法中观察到的局限性,并在缺少模态的情况下表现出了稳健的性能。
4.6 消融研究
-
λ F λ_F λF 和 λ C λ_C λC 的超参数选择:
λ F λ_F λF是 L m N L L L^{NLL}_m LmNLL 损失和 L m C E L^{CE}_m LmCE 损失 之间的平衡因子。在下面的实验中,我们证明了用 EyeMoSt 中引入的证据分类器 L m C E L^{CE}_m LmCE 来增强训练目标的重要性。
λ F ∈ [ 0 , 1 ] λ_F∈[0,1] λF∈[0,1]表示 L m C E L^{CE}_m LmCE 损失的重要性。我们对内部数据集进行了参数验证。如表6所示,引入 L m C E L^{CE}_m LmCE 损失后性能有所提高,最佳值为0.5。
λ C λ_C λC 表示控制置信感知多模态学习正则化的关键超参数。我们在内部数据集上进行了参数选择实验。根据[29],我们探索了 λ C λ_C λC =0.1~15 的范围来评估其性能。此外,为了强调这个正则化项的鲁棒性,我们在Fundus模态中添加了高斯噪声(σ=0.3)。如表7所示, λ C λ_C λC 的最佳值被确定为10。 -
总体学习过程:
此外,我们对方程13进行消融实验,如表8所示。其中B是中间典型融合方法BCNN的基线。
B-CNN首先通过编码器提取特征(与我们相同),然后将它们的输出特征连接起来作为最终预测。 L m N I G L^{NIG}_m LmNIG表示在建立多NIG分布后直接进行成对融合。 -
单峰和多模:
最后,我们在内部数据集上对B-CNN的单峰变体和我们提出的方法进行了比较分析。
具体来说,我们研究了各种单模态场景,称为 uni-B,其中只有Fundus模态用于训练。如表9所示,结果表明,基本方法B-CNN在融合后最初可以达到与单个单模态uni-B方法相当的性能水平。然而,当暴露在噪声中时,它很容易性能下降,有时甚至不如单模态方法。相比之下,本文中描述的我们提出的方法 EyeMoSt+(CNN)在融合过程中采用了基于置信度的分布。
5. 结论
总之,EyeMoSt+是一种开创性的解决方案,旨在通过无缝融合眼底和OCT模式来彻底改变多模式眼部疾病筛查。
基于NIG先验分布的原理,我们利用了单模态数据中嵌入的任意和认知不确定性。更重要的是,提出了一种针对Student t分布混合的置信度融合方法,以建立一个稳健可靠的疾病筛查模型。此外,我们创新的基于置信度感知排名的正则化形式为融合完整性提供了一个新的视角,防止了在存在噪声模态的情况下结果的妥协。
通过对各种眼病数据集的严格验证,包括青光眼识别、AMD和PCV筛查,以及DR和DME识别,我们的方法的可靠性和稳健性得到了牢固的确立。特别值得注意的是它在处理噪声输入、识别缺失模式和处理看不见的数据方面的有效性。
今后,我们的注意力将集中在两条不同的途径上。(1)首先,我们试图超越成对模态融合,深入到综合多模态融合的领域。(2)其次,我们致力于探索我们强大的眼病多模式筛查框架的实际应用。这一战略举措在提高人工智能驱动的医疗决策的准确性和可靠性方面具有重大前景,这一前景与我们的总体目标产生了强烈的共鸣。
6. Github 项目
数据集:
OLIVES dataset
GAMMA dataset
6.1 训练
- 训练基线模型
Run the script “main_train2.sh python baseline_train3_trans.py” to train the baselines (change model_name& mode), models will be saved in folder results
main_train2.sh python baseline_train3_trans.py
- 训练本文的模型
Run the script “main_train2.sh python train3_trans.py” to train our model (change model_name), models will be saved in folder results
main_train2.sh python train3_trans.py
6.2 测试
- 训练基线模型
Run the script “main_train2.sh python baseline_train3_trans.py” to test our model (change model_name& mode)
main_train2.sh python baseline_train3_trans.py
- 测试本文的模型
Run the script “main_train2.sh python train3_trans.py” to test our model (change model_name& mode)
main_train2.sh python train3_trans.py
6.3 核心程序代码(EyeMoSt/MedIA’24/train3_trans.py)
import os
import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from model import TMC,EyeMost_Plus_transformer,EyeMost_Plus,EyeMost,EyeMost_prior
from sklearn.model_selection import KFold
from data import Multi_modal_data,GAMMA_dataset,OLIVES_dataset
import warnings
import matplotlib.pyplot as plt
import numpy as np
from metrics import cal_ece,cal_ece_our
from metrics2 import calc_aurc_eaurc,calc_nll_brier
from scipy.io import loadmat
from sklearn import metrics
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import cohen_kappa_score
import torch.nn as nn
import seaborn as sns
import torch.nn.functional as F
import math
warnings.filterwarnings("ignore")
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import time
from numpy import log, sqrt
from scipy.special import psi, beta
import logging
def log_args(log_file):
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter(
'%(asctime)s ===> %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
# args FileHandler to save log file
fh = logging.FileHandler(log_file)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
# args StreamHandler to print log to console
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)
# add the two Handler
logger.addHandler(ch)
logger.addHandler(fh)
def Uentropy(logits,c):
pc = F.softmax(logits, dim=1)
logits = F.log_softmax(logits, dim=1)
u_all = -pc * logits / math.log(c)
NU = torch.sum(u_all[:,1:u_all.shape[1]], dim=1)
return NU
def entropy(sigma_2,predict_id):
id = int(predict_id)
entropy_list = 0.5 * np.log(2 * np.pi * np.exp(1.) * (sigma_2.cpu().detach().float().numpy()))
# entropy_list = np.log(sigma_2.cpu().detach().float().numpy())
NU = entropy_list[0,id]
return NU
def tensor_id_entropy(logits,predict_id):
id = int(predict_id)
p = F.softmax(logits, dim=1)
logp = F.log_softmax(logits, dim=1)
plogp = p * logp
entropy = -plogp[0][id]
return entropy
def loss_plot(args,loss):
num = args.end_epochs
x = [i for i in range(num)]
plot_save_path = r'results/plot/'
if not os.path.exists(plot_save_path):
os.makedirs(plot_save_path)
save_loss = plot_save_path+str(args.model_name)+'_'+str(args.batch_size)+'_'+str(args.dataset)+'_'+str(args.end_epochs)+'_loss.jpg'
list_loss = list(loss)
plt.figure()
plt.plot(x,loss,label='loss')
plt.legend()
plt.savefig(save_loss)
def metrics_plot(arg,name,*args):
num = arg.end_epochs
names = name.split('&')
metrics_value = args
i=0
x = [i for i in range(num)]
plot_save_path = r'results/plot/'
if not os.path.exists(plot_save_path):
os.makedirs(plot_save_path)
save_metrics = plot_save_path + str(arg.model_name) + '_' + str(arg.batch_size) + '_' + str(arg.dataset) + '_' + str(arg.end_epochs) + '_'+name+'.jpg'
plt.figure()
for l in metrics_value:
plt.plot(x,l,label=str(names[i]))
#plt.scatter(x,l,label=str(l))
i+=1
plt.legend()
plt.savefig(save_metrics)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def train(epoch,train_loader,model):
model.train()
loss_meter = AverageMeter()
# loss_list = []
for batch_idx, (data, target) in enumerate(train_loader):
for v_num in range(len(data)):
data[v_num] = Variable(data[v_num].cuda())
target = Variable(target.long().cuda())
# target = Variable(np.array(target)).cuda())
# refresh the optimizer
optimizer.zero_grad()
evidences, evidence_a, loss, _ = model(data, target, epoch)
print("total loss %f"%loss)
# compute gradients and take step
loss.backward()
optimizer.step()
loss_meter.update(loss.item())
# for i in range(0,len(loss_meter)):
# loss_list = loss_list.append(loss_meter[i].avg)
return loss_meter
def val(current_epoch,val_loader,model,best_acc):
model.eval()
loss_meter = AverageMeter()
correct_num, data_num = 0, 0
for batch_idx, (data, target) in enumerate(val_loader):
for m_num in range(len(data)):
data[m_num] = Variable(data[m_num].float().cuda())
data_num += target.size(0)
with torch.no_grad():
target = Variable(target.long().cuda())
evidences, evidence_a, loss,_ = model(data, target, epoch)
_, predicted = torch.max(evidence_a.data, 1)
correct_num += (predicted == target).sum().item()
loss_meter.update(loss.item())
aver_acc = correct_num / data_num
print('====> acc: {:.4f}'.format(aver_acc))
if evidence_a.shape[1] >2:
if aver_acc > best_acc:
print('aver_acc:{} > best_acc:{}'.format(aver_acc, best_acc))
best_acc = aver_acc
print('===========>save best model!')
file_name = os.path.join(args.save_dir,
args.model_name + '_' + args.dataset + '_' + args.folder + '_best_epoch.pth')
torch.save({
'epoch': current_epoch,
'state_dict': model.state_dict(),
},
file_name)
return loss_meter.avg, best_acc
else:
if (current_epoch + 1) % int(args.end_epochs - 1) == 0 \
or (current_epoch + 1) % int(args.end_epochs - 2) == 0 \
or (current_epoch + 1) % int(args.end_epochs - 3) == 0:
file_name = os.path.join(args.save_dir,
args.model_name + '_' + args.dataset + '_' + args.folder + '_epoch_{}.pth'.format(
current_epoch))
torch.save({
'epoch': current_epoch,
'state_dict': model.state_dict(),
},
file_name)
if aver_acc > best_acc:
print('aver_acc:{} > best_acc:{}'.format(aver_acc, best_acc))
best_acc = aver_acc
print('===========>save best model!')
file_name = os.path.join(args.save_dir,
args.model_name + '_' + args.dataset + '_' + args.folder + '_best_epoch.pth')
torch.save({ 'epoch': current_epoch,'state_dict': model.state_dict()},file_name)
# if (current_epoch + 1) % int(args.end_epochs - 1) == 0 \
# or (current_epoch + 1) % int(args.end_epochs - 2) == 0 \
# or (current_epoch + 1) % int(args.end_epochs - 3) == 0:
# file_name = os.path.join(args.save_dir,
# args.model_name + '_' + args.dataset + '_' + args.folder + '_epoch_{}.pth'.format(
# current_epoch))
# torch.save({
# 'epoch': current_epoch,
# 'state_dict': model.state_dict(),
# },
# file_name)
# if aver_acc > best_acc:
# print('aver_acc:{} > best_acc:{}'.format(aver_acc, best_acc))
# best_acc = aver_acc
# print('===========>save best model!')
# file_name = os.path.join(args.save_dir,
# args.model_name + '_' + args.dataset + '_' + args.folder + '_best_epoch.pth')
# torch.save({ 'epoch': current_epoch,'state_dict': model.state_dict()},file_name)
return loss_meter.avg, best_acc
def test(args, test_loader,model,epoch):
if args.num_classes == 2:
if args.test_epoch > 200:
load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
args.save_dir,
args.model_name + '_' + args.dataset +'_'+ args.folder + '_best_epoch.pth')
else:
load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
args.save_dir,
args.model_name + '_' + args.dataset +'_'+ args.folder + '_epoch_{}.pth'.format(args.test_epoch))
else:
if args.test_epoch > 200:
load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
args.save_dir,
args.model_name + '_' + args.dataset +'_'+ args.folder + '_best_epoch.pth')
else:
load_file = os.path.join(os.path.abspath(os.path.dirname(__file__)),
args.save_dir,
args.model_name + '_' + args.dataset +'_'+ args.folder + '_epoch_{}.pth'.format(args.test_epoch))
if os.path.exists(load_file):
checkpoint = torch.load(load_file)
model.load_state_dict(checkpoint['state_dict'])
args.start_epoch = checkpoint['epoch']
print('Successfully load checkpoint {}'.format(
os.path.join(args.save_dir + '/' + args.model_name +'_'+args.dataset+ '_epoch_' + str(args.test_epoch))))
else:
print('There is no resume file to load!')
model.eval()
list_acc = []
Fundus_confi_list= []
OCT_confi_list= []
fusion_confi_list = []
correct_list = []
entropy_list =[]
ece_list = []
nll_list = []
brier_list = []
label_list = []
prediction_list = []
probability_list = []
evidence_list = []
one_hot_label_list = []
one_hot_probability_list = []
correct_num, data_num = 0, 0
for batch_idx, (data, target) in enumerate(test_loader):
for v_num in range(len(data)):
data[v_num] = Variable(data[v_num].float().cuda())
data_num += target.size(0)
with torch.no_grad():
target = Variable(target.long().cuda())
if args.model_name =="ResNet_TMC":
evidences, evidence_a, _, u_a= model(data, target, epoch)
else:
evidences, evidence_a, _, u_a, gamma, v, alpha, beta = model(data, target, epoch)
probability_fu_m = F.softmax(evidence_a)
confidence_fu_m, _ = torch.max(probability_fu_m, axis=1)
# probability_list.append(b_a.cpu().detach().float().numpy())
correct_pred, predicted = torch.max(evidence_a.data, 1)
correct_num += (predicted == target).sum().item()
correct = (predicted == target)
list_acc.append((predicted == target).sum().item())
probability = torch.softmax(evidence_a, dim=1).cpu().detach().float().numpy()
one_hot_label = F.one_hot(target, num_classes=args.num_classes).squeeze(dim=0).cpu().detach().float().numpy()
# NLL brier
nll, brier = calc_nll_brier(probability, evidence_a, target, one_hot_label)
nll_list.append(nll)
brier_list.append(brier)
prediction_list.append(predicted.cpu().detach().float().numpy())
label_list.append(target.cpu().detach().float().numpy())
correct_list.append(correct.cpu().detach().float().numpy())
one_hot_label_list.append(F.one_hot(target, num_classes=args.num_classes).squeeze(dim=0).cpu().detach().float().numpy())
probability_list.append(torch.softmax(evidence_a, dim=1).cpu().detach().float().numpy()[:,1])
evidence_list.append(evidence_a.cpu().detach().float().numpy())
one_hot_probability_list.append(torch.softmax(evidence_a, dim=1).data.squeeze(dim=0).cpu().detach().float().numpy())
# ece
ece_list.append(cal_ece_our(torch.squeeze(evidence_a), target))
# entropy
entropy_list.append(tensor_id_entropy(evidence_a,predicted.cpu().detach().float().numpy()))
if args.num_classes > 2:
epoch_auc = metrics.roc_auc_score(one_hot_label_list,one_hot_probability_list, multi_class='ovo')
else:
epoch_auc = metrics.roc_auc_score(label_list,probability_list)
# correct_list = (prediction_list==label_list)
aurc, eaurc = calc_aurc_eaurc(probability_list, correct_list)
# np.savez(r'./results/OLIVES_Fundus_noise01.npz', Fundus_confi_list=Fundus_confi_list,OCT_confi_list=OCT_confi_list,fusion_confi_list=fusion_confi_list,np_u_list=np_u_list, np_entropy_list=np_entropy_list,np_OCT_au_list=np_OCT_au_list,np_Fundus_au_list=np_Fundus_au_list,np_OCT_eu_list=np_OCT_eu_list,np_Fundus_eu_list=np_Fundus_eu_list, np_OCT_uall=np_OCT_uall,np_Fundus_uall=np_Fundus_uall)
np.savez(r'./results/' + args.dataset + "_" + "noise" + "_" + str(Condition_G_Variance[0])+".npz", Fundus_confi_list=Fundus_confi_list,OCT_confi_list=OCT_confi_list,fusion_confi_list=fusion_confi_list,np_u_list=np_u_list, np_entropy_list=np_entropy_list,np_OCT_au_list=np_OCT_au_list,np_Fundus_au_list=np_Fundus_au_list,np_OCT_eu_list=np_OCT_eu_list,np_Fundus_eu_list=np_Fundus_eu_list, np_OCT_uall=np_OCT_uall,np_Fundus_uall=np_Fundus_uall)
avg_acc = correct_num/data_num
avg_ece = sum(ece_list)/len(ece_list)
avg_nll = sum(nll_list)/len(nll_list)
avg_brier = sum(brier_list)/len(brier_list)
avg_kappa = cohen_kappa_score(prediction_list, label_list)
F1_Score = f1_score(y_true=label_list, y_pred=prediction_list, average='weighted')
Recall_Score = recall_score(y_true=label_list, y_pred=prediction_list, average='weighted')
if not os.path.exists(os.path.join(args.save_dir, "{}_{}_{}".format(args.model_name,args.dataset,args.folder))):
os.makedirs(os.path.join(args.save_dir, "{}_{}_{}".format(args.model_name,args.dataset,args.folder)))
with open(os.path.join(args.save_dir,"{}_{}_{}_Metric.txt".format(args.model_name,args.dataset,args.folder)),'w') as Txt:
Txt.write("Acc: {}, AUC: {}, AURC: {}, EAURC: {}, NLL: {}, BRIER: {}, F1_Score: {}, Recall_Score: {}, Kappa_Score: {}, ECE: {}\n".format(
round(avg_acc,6),round(epoch_auc,6),round(aurc,6),round(eaurc,6),round(avg_nll,6),round(avg_brier,6),round(F1_Score,6),round(Recall_Score,6),round(avg_kappa,6),round(avg_ece,6)
))
# print(
# "Acc: {:.4f}, AUC: {:.4f}, AURC: {:.4f}, EAURC: {:.4f}, NLL: {:.4f}, BRIER: {:.4f}, F1_Score: {:.4f}, Recall_Score: {:.4f}, kappa: {:.4f}, ECE: {:.4f}".format(
# avg_acc, epoch_auc, aurc, eaurc, avg_nll, avg_brier, F1_Score, Recall_Score, avg_kappa, avg_ece))
# print('====> mean_inc_u: {:.4f},mean_c_u: {:.4f}\n'.format(mean_inc_u, mean_c_u))
return avg_acc,epoch_auc,aurc, eaurc, avg_nll,avg_brier,F1_Score,Recall_Score,avg_kappa,avg_ece
if __name__ == "__main__":
# kkk= math.log(math.pi)
# filename = './datasets/handwritten_6views.mat'
# image = loadmat(filename)
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=16, metavar='N',
help='input batch size for training [default: 100]')
parser.add_argument('--start_epoch', type=int, default=1, metavar='N',
help='number of epochs to train [default: 500]')
parser.add_argument('--end_epochs', type=int, default=100, metavar='N',
help='number of epochs to train [default: 500]')
parser.add_argument('--test_epoch', type=int, default=98, metavar='N',
help='number of epochs to train [default: 500]')
parser.add_argument('--lambda_epochs', type=int, default=50, metavar='N',
help='gradually increase the value of lambda from 0 to 1')
parser.add_argument('--modal_number', type=int, default=2, metavar='N',
help='modalties number')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
help='learning rate') # 0.0001 ResNet_ECNP_ST_Beta_fusion_C_DF_CML_loss ResNet_ECNP_ST_Beta_fusion_C_DF_CML_loss_evi ResNet_ECNP_ST_Beta_fusion_C_DF_CML ResNet_ECNP_ST_Beta_fusion_C_DF_CML_2D ResNet_ECNP_ST_Beta_fusion_C_DF_CML_3D
parser.add_argument('--save_dir', default='./results', type=str) # ResNet_MMST /ResNet_MMST_Reg /ResNet_MMST_Reg_Ce/ResNet_MMST_Reg/ResNet_MMST_Reg_CML//ResNet_MMST_Reg_CMLU /ResNet_MMST_Reg_Ce/ ResNet_TMC / ResNet_ECNP / ResNet_ECNP_Beta_CML / ResNet_ECNP_Beta_Reg /ResNet_ECNP_Beta_Reg_Ce / ResNet_ECNP_Beta_Reg_CML / ResNet_ECNP_Beta_Reg_CMLU/ ResNet_MMST_Reg_Ce_CMLU/ ResNet_MMST_Reg_Ce_CML
parser.add_argument("--model_name", default="EyeMost_PLUS_transformer", type=str, help="ResNet_TMC/Base_transformer/ResNet_ECNP_ST_Beta_fusion3/ResNet_ECNP_ST_Beta_fusion_C_DF/ResNet_ECNP_ST_Beta_fusion_C_DF_CML/ResNet_ECNP_ST_Beta_fusion_C_DF_CML_lamdaC05/ResNet_ECNP_ST_Beta_fusion_C_DF_CML_lamdaC15")
parser.add_argument("--model_base", default="transformer", type=str, help="transformer/cnn")
parser.add_argument("--dataset", default="MGamma", type=str, help="MMOCTF/Gamma/MGamma/OLIVES")
# parser.add_argument('--num_classes', type=int, default=2, metavar='N',
# help='class number: MMOCTF: 2 /Gamma: 3')
parser.add_argument("--condition", default="normal", type=str, help="noise/normal")
parser.add_argument("--condition_name", default="Gaussian", type=str, help="Gaussian/SaltPepper/All")
parser.add_argument("--Condition_SP_Variance", default=0.005, type=int, help="Variance: 0.01/0.1")
parser.add_argument("--Condition_G_Variance", default=0.05, type=int, help="Variance: 15/1/0.1")
parser.add_argument("--folder", default="folder0", type=str, help="folder0/folder1/folder2/folder3/folder4")
parser.add_argument("--mode", default="test", type=str, help="test/train&test")
# -- for ECNP parameters
parser.add_argument('-rps', '--representation_size', type=int, default=128, help='Representation size for context')
parser.add_argument('-hs', '--hidden_size', type=int, default=128, help='Model hidden size')
parser.add_argument('-ev_dec_beta_min', '--ev_dec_beta_min', type=float, default=0.2,
help="EDL Decoder beta minimum value")
parser.add_argument('-ev_dec_alpha_max', '--ev_dec_alpha_max', type=float, default=20.0,
help="EDL output alpha maximum value")
parser.add_argument('-ev_dec_v_max', '--ev_dec_v_max', type=float, default=20.0, help="EDL output v maximum value")
parser.add_argument('-nig_nll_reg_coef', '--nig_nll_reg_coef', type=float, default=0.1,
help="EDL nll reg balancing factor")
parser.add_argument('-nig_nll_ker_reg_coef', '--nig_nll_ker_reg_coef', type=float, default=1.0,
help='EDL kernel reg balancing factor')
parser.add_argument('-ev_st_u_min', '--ev_st_u_min', type=float, default=0.0001,
help="EDL st output sigma minnum value")
parser.add_argument('-ev_st_sigma_min', '--ev_st_sigma_min', type=float, default=0.2,
help="EDL st output sigma minnum value")
parser.add_argument('-ev_st_v_max', '--ev_st_v_max', type=float, default=30.0, help="EDL output v maximum value")
Condition_G_Variance = [0,0.01, 0.03, 0.05, 0.07, 0.1,0.3,0.5] # OCT & GAMMA
seed_num = list(range(1,11))
condition_level = ['normal','noise']
args = parser.parse_args()
args.seed_idx = 11
if args.dataset =="MMOCTF":
args.data_path = '/data/zou_ke/projects_data/Multi-OF/2000/'
args.modalties_name = ["FUN", "OCT"]
args.num_classes = 2
args.dims = [[(128, 256, 128)], [(512, 512)]]
args.modalties = len(args.dims)
train_loader = torch.utils.data.DataLoader(
Multi_modal_data(args.data_path, args.modal_number,args.modalties_name, 'train',args.condition,args, folder=args.folder), batch_size=args.batch_size)
val_loader = torch.utils.data.DataLoader(
Multi_modal_data(args.data_path, args.modal_number, args.modalties_name, 'val',args.condition,args, folder=args.folder), batch_size=1)
test_loader = torch.utils.data.DataLoader(
Multi_modal_data(args.data_path, args.modal_number, args.modalties_name, 'test',args.condition,args, folder=args.folder), batch_size=1)
N_mini_batches = len(train_loader)
print('The number of training images = %d' % N_mini_batches)
elif args.dataset =="OLIVES":
args.data_path = '/data/zou_ke/projects_data/OLIVES/OLIVES/'
# args.data_path = '/data/zou_ke/projects_data/OLIVES2/OLIVES/'
# args.data_path = '/data/zou_ke/projects_data/OLIVES3/OLIVES/'
args.modalties_name = ["FUN", "OCT"]
args.num_classes = 2
args.dims = [[(48, 248, 248)], [(512, 512)]]
args.modalties = len(args.dims)
train_loader = torch.utils.data.DataLoader(
OLIVES_dataset(args.data_path, args.modal_number,args.modalties_name, 'train',args.condition,args, folder=args.folder), batch_size=args.batch_size)
val_loader = torch.utils.data.DataLoader(
OLIVES_dataset(args.data_path, args.modal_number, args.modalties_name, 'val',args.condition,args, folder=args.folder), batch_size=1)
test_loader = torch.utils.data.DataLoader(
OLIVES_dataset(args.data_path, args.modal_number, args.modalties_name, 'test',args.condition,args, folder=args.folder), batch_size=1)
N_mini_batches = len(train_loader)
print('The number of training images = %d' % N_mini_batches)
elif args.dataset =="MGamma":
args.modalties_name = ["FUN", "OCT"]
args.dims = [[(128, 256, 128)], [(512, 512)]]
args.num_classes = 3
args.modalties = len(args.dims)
args.base_path = '/data/zou_ke/projects_data/Multi-OF/Gamma/'
args.data_path = '/data/zou_ke/projects_data/Multi-OF/MGamma/'
filelists = os.listdir(args.data_path)
# kf = KFold(n_splits=5, shuffle=True, random_state=10)
kf = KFold(n_splits=5, shuffle=True, random_state=10)
y = kf.split(filelists)
count = 0
train_filelists = [[], [], [], [], []]
val_filelists = [[], [], [], [], []]
for tidx, vidx in y:
train_filelists[count], val_filelists[count] = np.array(filelists)[tidx], np.array(filelists)[vidx]
count = count + 1
f_folder = int(args.folder[-1])
train_dataset = GAMMA_dataset(args,dataset_root = args.data_path,
oct_img_size = args.dims[0],
fundus_img_size = args.dims[1],
mode = 'train',
label_file = args.base_path + 'glaucoma_grading_training_GT.xlsx',
filelists = np.array(train_filelists[f_folder]))
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size)
val_dataset = GAMMA_dataset(args,dataset_root=args.data_path,
oct_img_size = args.dims[0],
fundus_img_size = args.dims[1],
mode = 'val',
label_file = args.base_path + 'glaucoma_grading_training_GT.xlsx',
filelists = np.array(val_filelists[f_folder]),)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1)
test_dataset = val_dataset
test_loader = val_loader
else:
print('There is no this dataset name')
raise NameError
if args.model_name =="ResNet_TMC":
model = TMC(args.num_classes, args.modalties, args.dims, args.lambda_epochs)
elif args.model_name == "EyeMost_prior": # fusion rule for mean Our
model = EyeMost_prior(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs)
elif args.model_name == "EyeMost": # fusion rule for mean Our
model = EyeMost(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs)
elif args.model_name == "EyeMost_Plus": # lamdaC = 10
model = EyeMost_Plus(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs)
elif args.model_name == "EyeMost_Plus_transformer": # lamdaC = 10
model = EyeMost_Plus_transformer(args.num_classes, args.modalties, args.dims, args, args.lambda_epochs)
else:
print('There is no this model name')
raise NameError
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
# optimizer = optim.AdamW(model.parameters(), lr=1e-5, betas=(0.9, 0.999),
# weight_decay=0.0005)
model.cuda()
best_acc = 0
loss_list = []
acc_list = []
if args.mode =='train&test':
epoch = 0
for epoch in range(args.start_epoch, args.end_epochs + 1):
print('===========Train begining!===========')
print('Epoch {}/{}'.format(epoch, args.end_epochs - 1))
epoch_loss = train(epoch,train_loader,model)
print("epoch %d avg_loss:%0.3f" % (epoch, epoch_loss.avg))
val_loss, best_acc = val(epoch,val_loader,model,best_acc)
loss_list.append(epoch_loss.avg)
acc_list.append(best_acc)
loss_plot(args, loss_list)
metrics_plot(args, 'acc', acc_list)
test_acc,test_acclist = test(args,test_loader,model,epoch)
elif args.mode == 'test':
epoch = args.test_epoch
for i in range(len(Condition_G_Variance)):
args.Condition_G_Variance = Condition_G_Variance[i]
print("Gaussian noise: %f" % args.Condition_G_Variance)
acc_list,auc_list,aurc_list,eaurc_list,nll_list, brier_list,\
F1_list,Rec_list,kap_list,ECE_list = [],[],[],[],[],[],[],[],[],[]
for j in range(len(seed_num)):
# for j in range(1):
args.seed_idx = seed_num[j]
# print("seed_idx: %d" % args.seed_idx)
if args.dataset == "MMOCTF":
args.data_path = '/data/zou_ke/projects_data/Multi-OF/2000/'
args.modalties_name = ["FUN", "OCT"]
args.num_classes = 2
args.dims = [[(128, 256, 128)], [(512, 512)]]
args.modalties = len(args.dims)
test_loader = torch.utils.data.DataLoader(
Multi_modal_data(args.data_path, args.modal_number, args.modalties_name, 'test', args.condition,
args, folder=args.folder), batch_size=1)
N_mini_batches = len(test_loader)
# print('The number of testing images = %d' % N_mini_batches)
elif args.dataset == "OLIVES":
args.data_path = '/data/zou_ke/projects_data/OLIVES/OLIVES/'
# args.data_path = '/data/zou_ke/projects_data/OLIVES2/OLIVES/'
# args.data_path = '/data/zou_ke/projects_data/OLIVES3/OLIVES/'
args.modalties_name = ["FUN", "OCT"]
args.num_classes = 2
args.dims = [[(48, 248, 248)], [(512, 512)]]
args.modalties = len(args.dims)
test_loader = torch.utils.data.DataLoader(
OLIVES_dataset(args.data_path, args.modal_number, args.modalties_name, 'test', args.condition,
args, folder=args.folder), batch_size=1)
N_mini_batches = len(test_loader)
# print('The number of testing images = %d' % N_mini_batches)
elif args.dataset == "MGamma":
args.modalties_name = ["FUN", "OCT"]
args.dims = [[(128, 256, 128)], [(512, 512)]]
args.num_classes = 3
args.modalties = len(args.dims)
args.base_path = '/data/zou_ke/projects_data/Multi-OF/Gamma/'
args.data_path = '/data/zou_ke/projects_data/Multi-OF/MGamma/'
filelists = os.listdir(args.data_path)
kf = KFold(n_splits=5, shuffle=True, random_state=10)
y = kf.split(filelists)
count = 0
train_filelists = [[], [], [], [], []]
val_filelists = [[], [], [], [], []]
for tidx, vidx in y:
train_filelists[count], val_filelists[count] = np.array(filelists)[tidx], np.array(filelists)[
vidx]
count = count + 1
f_folder = int(args.folder[-1])
val_dataset = GAMMA_dataset(args, dataset_root=args.data_path,
oct_img_size=args.dims[0],
fundus_img_size=args.dims[1],
mode='val',
label_file=args.base_path + 'glaucoma_grading_training_GT.xlsx',
filelists=np.array(val_filelists[f_folder]), )
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1)
test_dataset = val_dataset
test_loader = val_loader
N_mini_batches = len(test_loader)
# print('The number of testing images = %d' % N_mini_batches)
else:
print('There is no this dataset name')
raise NameError
# start_time = time.time()
test_acc, test_auc, test_aurc, test_eaurc, test_nll, test_brier,test_F1, test_Rec, test_kappa, test_ece\
= test(args, test_loader, model, epoch)
# print('time cost:' + str(int(time.time()-start_time)) + 's')
acc_list.append(test_acc)
auc_list.append(test_auc)
aurc_list.append(test_aurc)
eaurc_list.append(test_eaurc)
nll_list.append(test_nll)
brier_list.append(test_brier)
F1_list.append(test_F1)
Rec_list.append(test_Rec)
kap_list.append(test_kappa)
ECE_list.append(test_ece)
acc_list_mean,acc_list_std = np.mean(acc_list),np.std(acc_list)
auc_list_mean,auc_list_std = np.mean(auc_list),np.std(auc_list)
aurc_list_mean,aurc_list_std = np.mean(aurc_list),np.std(aurc_list)
eaurc_list_mean,eaurc_list_std = np.mean(eaurc_list),np.std(eaurc_list)
nll_list_mean,nll_list_std = np.mean(nll_list),np.std(nll_list)
brier_list_mean,brier_list_std = np.mean(brier_list),np.std(brier_list)
F1_list_mean,F1_list_std = np.mean(F1_list),np.std(F1_list)
Rec_list_mean,Rec_list_std = np.mean(Rec_list),np.std(Rec_list)
kap_list_mean,kap_list_std = np.mean(kap_list),np.std(kap_list)
ECE_list_mean,ECE_list_std = np.mean(ECE_list),np.std(ECE_list)
print(
"Mean_Std_Acc: {:.4f} +- {:.4f}, Mean_Std_AUC: {:.4f} +- {:.4f},Mean_Std_AURC: {:.4f} +- {:.4f}, "
"Mean_Std_EAURC: {:.4f} +- {:.4f}, Mean_Std_nll: {:.4f} +- {:.4f}, Mean_Std_brier: {:.4f} +- {:.4f}, Mean_Std_F1_Score: {:.4f} +- "
"{:.4f}, Mean_Std_Recall_Score: {:.4f} +- {:.6f}, Mean_Std_kappa: {:.4f} +- {:.4f}, Mean_Std_ECE: {:.4f} +- {:.4f}".format(
acc_list_mean, acc_list_std, auc_list_mean,auc_list_std, aurc_list_mean, aurc_list_std, eaurc_list_mean, eaurc_list_std,
nll_list_mean, nll_list_std,brier_list_mean,brier_list_std,F1_list_mean, F1_list_std,Rec_list_mean,Rec_list_std,kap_list_mean,kap_list_std,ECE_list_mean,ECE_list_std))
logging.info(
"Mean_Std_Acc: {:.4f} +- {:.4f}, Mean_Std_AUC: {:.4f} +- {:.4f},Mean_Std_AURC: {:.4f} +- {:.4f}, "
"Mean_Std_EAURC: {:.4f} +- {:.4f}, Mean_Std_nll: {:.4f} +- {:.4f}, Mean_Std_brier: {:.4f} +- {:.4f}, Mean_Std_F1_Score: {:.4f} +- "
"{:.4f}, Mean_Std_Recall_Score: {:.4f} +- {:.4f}, Mean_Std_kappa: {:.4f} +- {:.4f}, Mean_Std_ECE: {:.4f} +- {:.4f}".format(
acc_list_mean, acc_list_std, auc_list_mean,auc_list_std, aurc_list_mean, aurc_list_std, eaurc_list_mean, eaurc_list_std,
nll_list_mean, nll_list_std,brier_list_mean,brier_list_std,F1_list_mean, F1_list_std,Rec_list_mean,Rec_list_std,kap_list_mean,kap_list_std,ECE_list_mean,ECE_list_std))
版权说明:
本文由 youcans@xidian 对论文 Confidence-aware multi-modality learning for eye disease screening 进行摘编和翻译。该论文版权属于原文期刊和作者,本译文只供研究学习使用。
youcans@xidian 作品,转载必须标注原文链接:
【医学影像 AI】EyeMoS𝑡+:用于眼科疾病筛查的置信度感知多模态学习框架https://youcans.blog.csdn.net/article/details/145340677
Crated:2025-02