AOT源码解析4.5-AOT整体结构
论文阅读
- paper
- github
- 论文阅读笔记
- AOT源码解析1-数据集处理
- AOT源码解析2-encoder+decoder
- AOT源码解析3-模型训练
- AOT源码解析4.1-model主体
- AOT源码解析4.2-model主体
- AOT源码解析4.3-model主体
- AOT源码解析4.4-model主体
- AOT源码解析4.5-model主体
4.1~4.4小节详细讲解了ref_imgs相关的操作步骤,中间的结构诸如位置编码和ID编码生成、LSTT结构、Encoder和Decoder结构等都已详细讲解,接下来这些结构的细节不再赘述。
4、整体AOT结构
4.1 输入数据处理
图1:如图所示,上半部分是输入图像的处理,下半部分是mask的处理,最后得到enc_embs、masks和one-hot-masks
在这里,我们回顾AOT模型是如何使用Encoder和one-hot-mask模块处理输入图像和mask的。
-
encoder处理输入图像
输入数据的shape为[20,3,465,465],在这里batch_size为4.因此代表输入数据包含4个batch的数据,每个batch包含5张图像。
将输入数据输给mobilenetv2网络,提取四个不同比例的中间输出特征图,并将这些特征图分成五块。那么这五块的每一块都代表一张图像的特征。
如图1右上角所示,按照图像进行对分块的特征图进行分类,每一个图像都包含四个特征图。在这里我们将最后一个输出特征图命名为关键特征图,后续的LSTT模块只使用关键特征图。 -
处理mask
one-hot-mask的操作就是将mask的每一个对象和背景进行分离,并赋予它们一个one-hot编码。
对于原始mask和进行了one-hot编码的mask,我们都将它们按照图像进行分块,每一块代表不同图像的掩码情况。
4.2 添加参考图像信息
图2 ,如图所示是参考图像的特征添加。
在这里我们先回顾以下添加参考图像的步骤
-
1)位置编码和ID编码*
使用关键特征图生成位置编码,使用one-hot-mask生成ID编码。这两个编码添加到特征图的方式都是直接相加。 -
2)使用LSTT模块
LSTT模块就是进行了三次Multi-Head Attention结构,不同的是KQV的输入。2.1)self-attention
KQ:关键特征与位置编码融合后的特征图作为KQ特征图
V:关键特征图
输出:ref_self_embs2.2)long-term attention
KQ:ref_self_embs
V:ref_self_embs结合ID编码
长期记忆:K(ref_self_embs),V(ref_self_embs结合ID编码)2.3)short-term attention
KQ:ref_self_embs分块
V:ref_self_embs结合ID编码后,再分块
短期记忆:K(ref_self_embs分块),V(ref_self_embs结合ID编码后,再分块)2.4)LSTT的输出
self-attention的输出+long-term attention的输出+short-term attention的输出 -
3)使用Decoder得到预测结果
将关键特征图和其余特征图进行逐步融合、上采样后,得到最终的预测mask -
4)计算损失
根据预测mask和真实mask计算损失,值得注意的是,这里只关注偏差前15%的像素损失。
4.3 传播预测图像
预测其他帧的步骤与”添加参考图像信息“这一步相似相似。使用到的模块(诸如LSTT等)已经在4.1~4.4中详细阐述了,因此接下来不再赘述。
图3,如图所示,是预测其余图像的步骤
AOT模型在预测其余帧时,步骤如下:
-
1)位置编码和ID编码*
使用关键特征图生成位置编码,使用one-hot-mask生成ID编码。这两个编码添加到特征图的方式都是直接相加。 -
2)LSTT模块
LSTT模块就是进行了三次Multi-Head Attention结构,不同的是KQV的输入。2.1)self-attention
KQ:关键特征与位置编码融合后的特征图作为KQ特征图
V:关键特征图
输出:curr_self_embs2.2)long-term attention
Q:curr_self_embs
K:ref_img中得到的长期记忆,即ref_self_embs
V:ref_img中得到的长期记忆,即ref_self_embs结合ID编码
长期记忆:K(ref_self_embs),V(ref_self_embs结合ID编码)
长期记忆一直使用结合了ref_imgs的长期记忆,短期记忆会不断更新
2.3)short-term attention
Q:curr_self_embs分块
K:前一帧图像的slef-attention模块的输出,并进行分块
V:前一帧图像的slef-attention模块的输出结合ID编码后,再分块 -
3)更新短期记忆
本帧图像的slef-attention输出并分块+本帧图像的slef-attention模块的输出结合ID编码后,再分块
本次更新的短期记忆将用于下一帧图像的LSTT模块计算。
可以发现,本文传播预测mask时,采取的是transformer+记忆机制的方式来进行时序信息传播的。共有三个Multi-Head-Attention结构。
- 本帧特征
第一个Multi-Head-Attention结构的输入是本帧图像的关键特征图,得到本帧图像的信息- 长期记忆
长期记忆是参考图像和参考图像与对应掩码的结合。使用第二个Multi-Head-Attention结构找到本帧图像与长期记忆的相关性。
在这里,长期记忆是视频分割模块的核心和唯一分割标准- 短期记忆
这里的短期记忆是前一帧图像和前一帧图像与前一个掩码结合的组合,使用第三个Multi-Head-Attention结构计算本帧图像与前一帧图像的区别。在这里短期记忆辅助本帧图像进行分割。
- 4)decoder得到预测mask并计算loss
这个和参考图像一致,没什么好说的。
至此,AOT的结构已经讲解完毕、