计算机视觉语义分割——Attention U-Net(Learning Where to Look for the Pancreas)
计算机视觉语义分割——Attention U-Net(Learning Where to Look for the Pancreas)
文章目录
- 计算机视觉语义分割——Attention U-Net(Learning Where to Look for the Pancreas)
- 摘要
- Abstract
- 一、Attention U-Net
- 1. 基本思想
- 2. Attention Gate模块
- 3. 软注意力与硬注意力
- 4. 实验
- 5. 代码实践
- 总结
摘要
本周学习了Attention U-Net模型,这是一种在U-Net基础上改进的语义分割模型,主要应用于医学影像分割任务。Attention U-Net通过引入注意力门(Attention Gate, AG)模块,自动学习目标的形状和尺寸,同时抑制无关区域,聚焦于显著特征。相比标准U-Net模型,Attention U-Net在分割性能、模型敏感度和准确率方面均有显著提升。AG模块采用加性注意力机制,能够有效突出感兴趣区域的特征响应,减少冗余信息,并且易于与不同的CNN模型集成。此外,学习了软注意力与硬注意力的原理及其在模型中的具体应用。最后,结合代码实践,进一步加深了对Attention U-Net网络架构与实现的理解。
Abstract
This week focused on the Attention U-Net model, an improved semantic segmentation architecture based on U-Net, primarily designed for medical image segmentation tasks. By introducing Attention Gate (AG) modules, the model automatically learns the shape and size of targets while suppressing irrelevant regions and focusing on salient features. Compared to the standard U-Net, the Attention U-Net achieves significant improvements in segmentation performance, sensitivity, and accuracy. The AG module employs an additive attention mechanism, effectively highlighting features of interest and reducing redundant information, while being easily integrable into various CNN models. Additionally, the principles of soft attention and hard attention were studied along with their specific applications in the model. Practical implementation through code further enhanced the understanding of the Attention U-Net architecture and its functionality.
一、Attention U-Net
Attention Unet是Unet的改进版本,主要应用于医学影像领域中的分割任务。Attention Unet是一种基于注意门(Attention Gate, AG)的模型,它会自动学习区分目标的外形和尺寸。这种有Attention Gate的模型在训练时会学会抑制不相关的区域,注重于有用的显著特征。Attention Gate很容易被整合进标准的CNN模型中,极少的额外计算量却能带来显著的模型敏感度和准确率的提高。作者利用Attention U-Net模型,在两个大型CT腹部数据集上进行了多类别的图像分割。实验结果表明,Attention Gate可以在保持计算效率的同时,持续提高U-Net在不同数据集和训练规模下的预测性能。
1. 基本思想
既然Attention U-Net是U-Net的改进,那么需要先简单回顾一下U-Net,来更好地对比。下面是U-Net的网络结构图:
从U-Net的结构图中看出,为了避免在Decoder解码器时丢失大量的细节,使用了Skip Connection跳跃连接,将Encoder编码器中提取的信息直接连接到Decoder对应的层上。但是,Encoder提取的low-level feature有很多的冗余信息,也就是提取的特征不太好,可能对后面并没有帮助,这就是U-Net网络存在的问题。
该问题可以通过在U-Net上加入注意力机制,来减少冗余的Skip Connection,这也是设计Attention Gate的动机。
对比基础的U-Net网络结构,我们可以发现在跳跃连接中都加入了Attention Gate模块。而原始U-Net只是单纯的把同层的下采样层的特征直接连接到上采样层中。Attention U-Net的主要贡献分为三点:
- 提出基于网格的AG,这会使注意力系数更能凸显局部区域的特征。
- 首次在医学图像的CNN中使用Soft Attention,该模块可以替代分类任务中的Hard Attention和器官定位任务中的定位模块。
- 将U-Net改进为Attention U-Net,增加了模型对前景像素的敏感度,并设计实验证明了这种改进是通用的。
2. Attention Gate模块
U-Net(FCN类的网络模型)中的卷积层会根据一层一层的局部信息来提取得到高维图像的表示
x
l
x^l
xl,最终在高维空间的离散像素会具有语义信息。通过U-Net这种顺序结构的处理,由巨大的感受野所提取的信息会影响着模型的预测结果。因此,特征图
x
l
x^l
xl是由
l
l
l层的输出依次通过一个线性变换(卷积操作)和一个非线性激活函数得到的。这个非线性激活函数一般选择为ReLU函数,即:
σ
1
(
x
i
,
c
l
)
=
max
(
0
,
x
i
,
c
l
)
{\sigma _1}(x_{i,c}^l) = \max (0,x_{i,c}^l)
σ1(xi,cl)=max(0,xi,cl),其中
i
i
i和
c
c
c分别代表着空间信息维度和通道信息维度。特征的激活可以写为:
x
c
l
=
σ
1
(
∑
c
′
∈
F
l
x
c
′
l
−
1
∗
k
c
′
,
c
)
x_c^l = {\sigma _1}\left( {\sum\nolimits_{c' \in {F_l}} {x_{c'}^{l - 1} * {k_{c',c}}} } \right)
xcl=σ1(∑c′∈Flxc′l−1∗kc′,c),这是通道维度上的写法,其中
∗
*
∗代表卷积操作,空间信息维度的下标
i
i
i为了简便而省略。因此,每个卷积层的操作过程可以写为函数:
f
(
x
l
;
Φ
l
)
=
x
(
l
+
1
)
f({x^l};{\Phi ^l}) = {x^{(l + 1)}}
f(xl;Φl)=x(l+1),
Φ
l
{\Phi ^l}
Φl是可以训练的卷积核参数。参数的学习是通过最小化训练目标(如交叉熵损失),并使用SGD随机梯度下降法进行学习的。本论文在基本思想的框架图中,基于U-Net体系构建了Attention U-Net,粗粒度的特征图会捕获上下文信息,并突出显示前景对象的类别和位置。随后,通过skip connections合并以多个比例提取的特征图,以合并粗粒度和细粒度的密集预测。
注意系数(Attention coefficients): α i ∈ [ 0 , 1 ] {\alpha _i} \in \left[ {0,1} \right] αi∈[0,1],是为了突出显著的图像区域和抑制任务无关的特征响应。Attention Gate的输出是输入的特征图与 α i {\alpha _i} αi做Element-wise乘法(对应元素逐个相乘),即: x ^ i , c l = x i , c l ⋅ α i l \hat x_{i,c}^l = x_{i,c}^l \cdot \alpha _i^l x^i,cl=xi,cl⋅αil 。在默认设置中,会为每个像素向量 x i l ∈ R F l x_i^l \in {R^{{F_l}}} xil∈RFl计算一个标量的注意值,其中 F l {{F_l}} Fl对应于第 l l l层中特征图的数量。如果有多个语义类别,则应当学习多维的注意系数。
在上图Attention Gate的具体框图中,门控向量 g i ∈ R F g {g_i} \in {R^{{F_g}}} gi∈RFg为每个像素 i i i确定焦点区域。门控向量包含上下文信息,以修剪低级特征响应。论文中选择加性注意力来获得门控系数,尽管这在计算上更昂贵,但从实验上看,它的性能比乘法注意力要高。加性注意力的计算公式如下所示:
注意这里要结合上面的结构来一起分析,其中 σ 1 {\sigma _1} σ1是ReLU激活函数, σ 2 {\sigma _2} σ2是Sigmoid激活函数。 W x ∈ R F l × F i n t {W_x} \in {R^{{F_l} \times {F_{{\mathop{\rm int}} }}}} Wx∈RFl×Fint, W g ∈ R F g × F i n t {W_g} \in {R^{{F_g} \times {F_{{\mathop{\rm int}} }}}} Wg∈RFg×Fint, ψ ∈ R F i n t × 1 \psi \in {R^{{F_{{\mathop{\rm int}} }} \times 1}} ψ∈RFint×1都是卷积操作, b ψ ∈ R {b_\psi } \in R bψ∈R, b g ∈ R F i n t {b_g} \in {R^{{F_{{\mathop{\rm int}} }}}} bg∈RFint为偏置项,分别对应 ψ \psi ψ, W g {W_g} Wg的卷积操作,而 W x {W_x} Wx无偏置项。 F i n t {{F_{{\mathop{\rm int}} }}} Fint一般比 F g {{F_g}} Fg和 F l {{F_l}} Fl要小。在图像标注和分类任务中一般都用Softmax函数,之所以 σ2使用Sigmoid函数,是因为顺序使用Softmax函数会输出较稀疏的激活响应,而用Sigmoid函数能够使训练更好的收敛。门控信号不是全局图像的表示矢量,而是在一定条件下部分图像空间信息的网格信号,每个skip connection的门控信号都会汇总来自多个成像比例的信息。
上图直观地表明了 x x x和 g g g的输入,注意系数(Attention coefficients): α i ∈ [ 0 , 1 ] {\alpha _i} \in \left[ {0,1} \right] αi∈[0,1]是通过 σ 2 {\sigma _2} σ2这个Sigmoid激活函数得到的。
举一个具体的例子来理解。
上图显示了不同训练时期(epochs:3、6、10、60、150)时的注意力系数,表明注意力越来越集中在感兴趣的部分。
3. 软注意力与硬注意力
软注意力(Soft Attention):软(确定性)注意力机制使用所有键的加权平均值来构建上下文向量。对于软注意力,注意力模块相对于输入是可微的,因此整个系统仍然可以通过标准的反向传播方法进行训练。软注意力数学描述如下:
其中 f ( q , k ) f\left( {q,k} \right) f(q,k)的有很多种计算方法,如下表所示:
硬注意力(Hard Attention):硬(随机)注意力中的上下文向量是根据随机采样的键计算的。硬注意力可以实现如下:
注:多项式分布是二项式分布的推广。二项式做n次伯努利实验,规定了每次试验的结果只有两个。如果现在还是做n次试验,只不过每次试验的结果可以有m个,且m个结果发生的概率互斥且和为1,则发生其中一个结果X次的概率就是多项分布。概率密度函数是:
两者的对比与改进方案:与软注意力模型相比,硬注意力模型的计算成本更低,因为它不需要每次都计算所有元素的注意力权重。 然而,在输入特征的每个位置做出艰难的决定会使模块不可微且难以优化,因此可以通过最大化近似变分下限或等效地通过 REINFORCE 来训练整个系统。 在此基础上,Luong 等人提出了机器翻译的全局注意力和局部注意力机制。 全局注意力类似于软注意力。 局部注意力可以看作是硬注意力和软注意力之间的有趣混合,其中一次只考虑源词的一个子集。局部注意力这种方法在计算上比全局注意力或软注意力更便宜。 同时,与硬注意力不同,这种方法几乎在任何地方都是可微的,从而更容易实现和训练。
4. 实验
AGs(Attention Gates)是模块化的,与应用类型无关; 因此,它可以很容易地适应分类和回归任务。为了证明其对图像分割的适用性,论文在具有挑战性的腹部CT多标签分割问题上评估Attention U-Net模型。特别是,由于形状变化和组织对比度差,胰腺边界描绘是一项艰巨的任务。Attention U-Net模型在分割性能,模型容量,计算时间和内存要求方面与标准3D U-Net进行了比较。
- 评估数据集:NIH-TCIA 和 这篇论文中的。
- 实施细节:有一个3D的模型,Adam,BN,deep-supervision和标准数据增强技术(仿射变换,轴向翻转,随机裁剪)。
- 注意力图分析:我们通常观察到AG最初具有均匀分布并且在所有位置,然后逐步更新和定位到目标器官边界。在较粗糙的尺度上,AG提供了粗略的器官轮廓,这些器官在更精细的分辨率下逐渐细化。 此外,通过在每个图像尺度上训练多个AG,我们观察到每个AG学习专注于器官的特定子集。
- 分割实验: 性能比U-Net高 2−3% 。
5. 代码实践
# 导入TensorFlow及其子模块
import tensorflow as tf
from tensorflow.keras import models, layers, regularizers
from tensorflow.keras import backend as K # Keras后端,用于底层操作
# 定义Dice系数指标函数(常用于图像分割任务评估)
def dice_coef(y_true, y_pred):
# 将真实标签和预测结果展平为一维向量
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
# 计算两个向量的交集(对应元素相乘后求和)
intersection = K.sum(y_true_f * y_pred_f)
# Dice系数公式:(2*交集 + 平滑项)/(真实标签总和 + 预测结果总和 + 平滑项)
# 平滑项(1.0)用于防止分母为零
return (2.0 * intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.0)
# 定义Jaccard系数(交并比)指标函数
def jacard_coef(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
# Jaccard系数公式:(交集 + 平滑项)/(并集 + 平滑项)
# 并集 = 真实标签总和 + 预测结果总和 - 交集
return (intersection + 1.0) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + 1.0)
# 定义基于Jaccard系数的损失函数
def jacard_coef_loss(y_true, y_pred):
# 返回负的Jaccard系数,因为损失函数需要最小化,而Jaccard系数需要最大化
return -jacard_coef(y_true, y_pred)
# 定义基于Dice系数的损失函数
def dice_coef_loss(y_true, y_pred):
# 返回负的Dice系数,原理同上
return -dice_coef(y_true, y_pred)
"""
代码特性说明:
1. 适用于图像分割任务,Dice和Jaccard系数能有效评估像素级预测精度
2. 使用K.flatten处理确保适用于不同形状的输出(如batch_size, height, width, channels)
3. 平滑项(+1.0)的作用:
- 防止分母为零的数学错误
- 当预测和标签均为全黑时仍能给出合理评估值
- 起到正则化作用,使指标对极端情况更鲁棒
4. 损失函数通过取负数将指标最大化问题转化为最小化问题,符合优化器的工作方式
"""
# 定义构建U-Net架构的各个模块
def conv_block(x, filter_size, size, dropout, batch_norm=False):
"""
卷积块:包含两个卷积层,可选批归一化和Dropout
参数:
x: 输入张量
filter_size: 卷积核尺寸(整数,如3表示3x3卷积)
size: 卷积核数量(输出通道数)
dropout: Dropout比率(0表示不使用)
batch_norm: 是否使用批归一化
返回:
经过卷积处理后的张量
"""
# 第一个卷积层
conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(x)
if batch_norm:
conv = layers.BatchNormalization(axis=3)(conv) # 沿通道轴做批归一化
conv = layers.Activation("relu")(conv) # ReLU激活
# 第二个卷积层
conv = layers.Conv2D(size, (filter_size, filter_size), padding="same")(conv)
if batch_norm:
conv = layers.BatchNormalization(axis=3)(conv)
conv = layers.Activation("relu")(conv)
# 应用Dropout(当比率大于0时)
if dropout > 0:
conv = layers.Dropout(dropout)(conv)
return conv
def repeat_elem(tensor, rep):
"""
张量元素重复:沿着最后一个轴重复张量元素
示例:
输入形状(None, 256,256,3),指定axis=3和rep=2时
输出形状(None, 256,256,6)
"""
return layers.Lambda(
lambda x, repnum: K.repeat_elements(x, repnum, axis=3), # 使用Keras后端函数
arguments={'repnum': rep} # 传入重复次数参数
)(tensor)
def gating_signal(input, out_size, batch_norm=False):
"""
生成门控信号:使用1x1卷积调整特征图维度,匹配上采样层尺寸
返回:
与上层特征图维度相同的门控特征图
"""
x = layers.Conv2D(out_size, (1, 1), padding='same')(input) # 1x1卷积调整通道数
if batch_norm:
x = layers.BatchNormalization()(x)
x = layers.Activation('relu')(x) # ReLU激活
return x
def attention_block(x, gating, inter_shape):
"""
注意力机制块:通过门控信号学习空间注意力权重
参数:
x: 跳跃连接的特征图
gating: 门控信号(来自深层网络)
inter_shape: 中间特征通道数
返回:
应用注意力权重后的特征图
"""
# 获取输入特征图尺寸
shape_x = K.int_shape(x)
shape_g = K.int_shape(gating)
# 将x的特征图下采样到门控信号尺寸
theta_x = layers.Conv2D(inter_shape, (2, 2), strides=(2, 2), padding='same')(x)
shape_theta_x = K.int_shape(theta_x)
# 调整门控信号通道数
phi_g = layers.Conv2D(inter_shape, (1, 1), padding='same')(gating)
# 上采样门控信号到theta_x的尺寸
upsample_g = layers.Conv2DTranspose(
inter_shape,
(3, 3),
strides=(shape_theta_x[1] // shape_g[1], shape_theta_x[2] // shape_g[2]),
padding='same'
)(phi_g)
# 合并特征
concat_xg = layers.add([upsample_g, theta_x])
act_xg = layers.Activation('relu')(concat_xg)
# 生成注意力权重
psi = layers.Conv2D(1, (1, 1), padding='same')(act_xg)
sigmoid_xg = layers.Activation('sigmoid')(psi)
# 上采样注意力权重到原始x的尺寸
shape_sigmoid = K.int_shape(sigmoid_xg)
upsample_psi = layers.UpSampling2D(
size=(shape_x[1] // shape_sigmoid[1], shape_x[2] // shape_sigmoid[2])
)(sigmoid_xg)
# 扩展注意力权重的通道数
upsample_psi = repeat_elem(upsample_psi, shape_x[3])
# 应用注意力权重
y = layers.multiply([upsample_psi, x])
# 最终卷积和批归一化
result = layers.Conv2D(shape_x[3], (1, 1), padding='same')(y)
result_bn = layers.BatchNormalization()(result)
return result_bn
# 定义Attention U-Net模型
def Attention_UNet(input_shape, NUM_CLASSES=1, dropout_rate=0.0, batch_norm=True):
'''
Attention UNet网络实现
参数:
input_shape: 输入张量的形状(高度, 宽度, 通道数)
NUM_CLASSES: 输出类别数(默认为1,二分类问题)
dropout_rate: Dropout层的丢弃率(0.0表示不使用)
batch_norm: 是否使用批量归一化(默认为True)
'''
# 网络结构超参数配置
FILTER_NUM = 64 # 第一层的基础卷积核数量
FILTER_SIZE = 3 # 卷积核尺寸
UP_SAMP_SIZE = 2 # 上采样比例
# 输入层(指定数据类型为float32)
inputs = layers.Input(input_shape, dtype=tf.float32)
# 下采样路径(编码器部分)
# 第一层:卷积块 + 最大池化
conv_128 = conv_block(inputs, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm) # 128x128 特征图
pool_64 = layers.MaxPooling2D(pool_size=(2,2))(conv_128) # 下采样到64x64
# 第二层:卷积块 + 最大池化
conv_64 = conv_block(pool_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm) # 64x64 特征图
pool_32 = layers.MaxPooling2D(pool_size=(2,2))(conv_64) # 下采样到32x32
# 第三层:卷积块 + 最大池化
conv_32 = conv_block(pool_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm) # 32x32 特征图
pool_16 = layers.MaxPooling2D(pool_size=(2,2))(conv_32) # 下采样到16x16
# 第四层:卷积块 + 最大池化
conv_16 = conv_block(pool_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm) # 16x16 特征图
pool_8 = layers.MaxPooling2D(pool_size=(2,2))(conv_16) # 下采样到8x8
# 第五层:底层卷积块(无池化)
conv_8 = conv_block(pool_8, FILTER_SIZE, 16*FILTER_NUM, dropout_rate, batch_norm) # 8x8 特征图
# 上采样路径(解码器部分,包含注意力门控)
# 第六层:生成门控信号 -> 注意力机制 -> 上采样 -> 特征拼接 -> 卷积块
gating_16 = gating_signal(conv_8, 8*FILTER_NUM, batch_norm) # 为16x16层生成门控信号
att_16 = attention_block(conv_16, gating_16, 8*FILTER_NUM) # 计算注意力权重
up_16 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(conv_8) # 上采样到16x16
up_16 = layers.concatenate([up_16, att_16], axis=3) # 拼接跳跃连接和注意力加权的特征
up_conv_16 = conv_block(up_16, FILTER_SIZE, 8*FILTER_NUM, dropout_rate, batch_norm)
# 第七层(结构同上)
gating_32 = gating_signal(up_conv_16, 4*FILTER_NUM, batch_norm)
att_32 = attention_block(conv_32, gating_32, 4*FILTER_NUM)
up_32 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(up_conv_16)
up_32 = layers.concatenate([up_32, att_32], axis=3)
up_conv_32 = conv_block(up_32, FILTER_SIZE, 4*FILTER_NUM, dropout_rate, batch_norm)
# 第八层(结构同上)
gating_64 = gating_signal(up_conv_32, 2*FILTER_NUM, batch_norm)
att_64 = attention_block(conv_64, gating_64, 2*FILTER_NUM)
up_64 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(up_conv_32)
up_64 = layers.concatenate([up_64, att_64], axis=3)
up_conv_64 = conv_block(up_64, FILTER_SIZE, 2*FILTER_NUM, dropout_rate, batch_norm)
# 第九层(结构同上)
gating_128 = gating_signal(up_conv_64, FILTER_NUM, batch_norm)
att_128 = attention_block(conv_128, gating_128, FILTER_NUM)
up_128 = layers.UpSampling2D(size=(UP_SAMP_SIZE, UP_SAMP_SIZE))(up_conv_64)
up_128 = layers.concatenate([up_128, att_128], axis=3)
up_conv_128 = conv_block(up_128, FILTER_SIZE, FILTER_NUM, dropout_rate, batch_norm)
# 输出层:1x1卷积 + 批量归一化 + 激活函数
conv_final = layers.Conv2D(NUM_CLASSES, kernel_size=(1,1))(up_conv_128) # 调整通道数为类别数
conv_final = layers.BatchNormalization(axis=3)(conv_final)
conv_final = layers.Activation('sigmoid')(conv_final) # 二分类使用sigmoid,多分类需改为softmax
# 构建并返回模型
model = models.Model(inputs, conv_final, name="Attention_UNet")
return model
总结
Attention U-Net通过在U-Net模型中引入注意力门模块,解决了传统跳跃连接中冗余信息的问题,显著提高了语义分割任务中的模型性能和敏感度。通过加性注意力机制,AG模块能够突出感兴趣区域的特征响应,同时抑制无关信息,使分割更加精准。此外,对软注意力和硬注意力的概念及其对模型优化的作用也得到了深入理解。结合代码实践,进一步掌握了Attention U-Net的网络架构及其实现方法。