论文120:Giga-SSL: Self-supervised learning for gigapixel images (2023, CVPR, 开源)
文章目录
- 1 要点
- 2 方法
- 2.1 算法设计
- 2.2 设计选择
1 要点
题目:用于千兆像素图像的自监督学习 (Giga-SSL: Self-Supervised Learning for Gigapixel Images)
代码:https://github.com/trislaz/gigassl
研究目的:
现有的WSI分类方法依赖于有限的标注数据集,这可能导致模型过拟合和性能不足。同时,大量未标注的WSI数据集的可用性不断增加,但这些数据在现有的自监督学习 (SSL) 框架下未能充分利用。因此,拟提出了一种在WSI标签下进行SSL的策略,以利用大量未标注的WSI数据并在小数据集上提高分类性能。
关键技术:
- Giga-SSL框架:
一个转为WSI设计的自监督学习框架,能够在不使用任何标注数据的情况下,仅使用WSI图像来学习包表示; - SparseConvMIL架构:
用于WSI分类的扩展,结合了ResNet网络作为区块嵌入器和池化函数,以及子流形卷积网络来处理稀疏数据; - 对比损失(Contrastive Loss):
在SSL使用,通过优化正样本对的相似度来训练模型;
数据集:
- TCGA
2 方法
2.1 算法设计
令 X X X表示一个WSI,算法的骨架是扩展的SparseConvMIL架构:
- 包含一个 ResNet网络
f
θ
f_θ
fθ (例如ResNet18),它在第四个残差块的开始处被切成两个连续的部分:
- 第一部分:实例 (tile) 嵌入器 e θ 1 e_{θ_1} eθ1,由 f θ f_θ fθ的从输入层到第四个残差块的部分组成;
- 第二部分:池化函数KaTeX parse error: Expected '}', got 'EOF' at end of input: p_{θ_2,由 f θ f_θ fθ的余下层组成,包括后续的残差块直到全连接层。该部分已经转换为子流形卷积网络,以便它可以处理稀疏数据。
- 对于任何图像
i
i
i,ResNet嵌入为:
f θ ( i ) = p θ 2 ( e θ 1 ( i ) ) ∈ R 512 f_θ(i) = p_{θ_2}(e_{θ_1}(i)) \in \mathbb{R}^{512} fθ(i)=pθ2(eθ1(i))∈R512
算法的训练过程包括6个顺序步骤,以提取WSI表示,如图1。
- 在实例级别设置两个WSI增强函数
t
1
t_1
t1和
t
2
t_2
t2,其图像增强域
A
A
A分别从颜色增强 (色彩抖动、灰度) 和几何增强 (翻转、旋转、缩放、模糊) 中采样:
- 从 X X X中为每个增强函数 t 1 t_1 t1和 t 2 t_2 t2抽取 T T T个实例,得到实例集合 { X 1 } \{X_1\} {X1}和 { X 2 } \{X_2\} {X2};
- 存储实例集合原采样位置的左上角像素坐标,以供进一步处理;
- t 1 t_1 t1应用于 { X 1 } \{X_1\} {X1}中的所有实例,得到增强实例 t 1 ( { X 1 } ) t_1(\{X_1\}) t1({X1}),同理可得 t 2 ( { X 2 } ) t_2(\{X_2\}) t2({X2});
- 实例嵌入:通过实例嵌入器 e θ 1 e_{θ_1} eθ1同时独立地将 t 1 ( { X 1 } ) t_1(\{X_1\}) t1({X1})和 t 2 ( { X 2 } ) t_2(\{X_2\}) t2({X2})中的每个实例向前传递。每张图像因此被转换为一个特征图,通过所有像素取平均,每个增强实例集合中的每个实例,都将获得一个大小为 F F F (对于ResNet18为256) 的实例嵌入;
- 构建稀疏图:按照SparseConvMIL框架,通过将
t
1
(
{
X
1
}
)
t_1(\{X_1\})
t1({X1})产生的每个嵌入分配到它们在步骤1中采样的原始实例的位置,构建一个稀疏图
S
1
S_1
S1 (下采样因子
d
=
224
d = 224
d=224)。类似地,从
t
2
(
{
X
2
}
)
t_2(\{X_2\})
t2({X2})的嵌入构建一个稀疏图
S
2
S_2
S2;
- 难点:之所以叫稀疏图,是因为抽取的实例的位置随机,结合它们的 x x x与 y y y坐标,自然就是稀疏分布的实例图;
-
S
1
S_1
S1和
S
2
S_2
S2被随机翻转、旋转,并独立地沿
x
x
x 轴和
y
y
y轴以
[
0.5
,
2
]
[0.5, 2]
[0.5,2]中均匀采样的因子进行缩放;
- 难点:这样做的目的是进一步增强稀疏图;
- 这也可以用来增强WSI,不过当前和后续的训练过程只需要增强的稀疏图;
- 将稀疏图嵌入到两个增强的WSI表示中:为了计算表示,对两个增强的稀疏图 S 1 S_1 S1和 S 2 S_2 S2应用 p θ 2 p_{θ_2} pθ2。在这个阶段,输入是 X X X的两个增强视图;
- 损失优化与SimCLR中的做法一样,最终将增强视图输入到投影器中,得到两个增强的投影,用它们来计算损失。我们通过优化对比损失NT-XENT来训练池化函数$p_{θ_2}的权重。给定一个由增强WSI
(
X
1
i
,
X
2
i
)
i
∈
B
(X^i_1, X^i_2)_{i∈B}
(X1i,X2i)i∈B组成的小批量
B
B
B,为WSI的正样本对设置损失函数如下:
ℓ i = − l o g exp ( sim ( X 1 i , X 2 i ) / τ ) ∑ x ∈ B ( 1 x ≠ X 1 i exp ( sim ( X 1 i , x ) / τ ) (1) \tag{1} ℓ_i = - log\frac{\exp(\text{sim}(X^i_1, X^i_2) / τ)}{∑_{x∈B}( \mathbf{1}_{x\neq X^i_1} \exp(\text{sim}(X^i_1, x) / τ)} ℓi=−log∑x∈B(1x=X1iexp(sim(X1i,x)/τ)exp(sim(X1i,X2i)/τ)(1)其中 τ τ τ是温度参数, 1 ⋅ 1{\cdot} 1⋅是指示函数。最终损失是这些项在整个视图中的平均值。
2.2 设计选择
- 选择基础CNN架构:
Giga-SSL理论上不依赖于ResNet架构。有许多好的架构可供选择,用于实例编码器和池化函数,包括不同架构的两部分。然而,池化函数必须能够处理稀疏数据,因为它处理增强的稀疏图; - 离线增强策略:
GigaSSL训练的一个关键计算瓶颈是在线计算一批B个WSI的实例嵌入,每个WSI由 T T T个实例组成。GPU内存限制对 B B B和 N t N_t Nt施加限制,这实际上限制了每批可以使用的总实例数。此外,已经证明在自然图像的SSL中需要大的批次大小以获得具有良好下游分类性能的表示。解决这些问题的一个策略是冻结实例编码器 e θ 1 e_{θ_1} eθ1并预计算每个WSI的随机采样和增强实例的嵌入,即基本上绕过步骤1和2:- 采样50个实例级增强函数 t k t_k tk;
- 对于每个 k k k,从WSI中随机抽取256个实例并用 t k t_k tk增强;
- 同时独立地将每个增强实例输入 e θ 1 e_{θ_1} eθ1并存储。这个过程得到 N ∗ 50 ∗ 256 N*50*256 N∗50∗256个实例嵌入,其中 N N N是Giga-SSL训练数据集中WSI的总数。
说人话:
- WSI很大,所有随机抽样 T T T个并增强,这样的结果被称为增强实例集合,也可以说是稀疏图。稀疏图将被进一步增强,这个过程相当于得到新的WSI;
- 新的WSI输入的预训练模型,后面就是自监督学习的过程了;