【连续学习之LwM算法】2019年CVPR顶会论文:Learning without memorizing
1 介绍
年份:2019
期刊: 2019CVPR
引用量:611
Dhar P, Singh R V, Peng K C, et al. Learning without memorizing[C]//Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019: 5138-5146.
本文提出的“Learning without Memorizing (LwM)”算法原理是通过结合知识蒸馏损失(LD)和注意力蒸馏损失(LAD)来实现增量学习,关键技术步骤是在训练过程中,利用注意力机制(Grad-CAM)生成的注意力图来保持对基础类别的知识,同时学习新类别,而无需存储任何基础类别的数据。本文算法属于基于正则化的算法。
2 创新点
- 无需存储数据的增量学习(Learning without Memorizing):
- 提出了一种新颖的增量学习方法,能够在不存储任何关于现有类别数据的情况下,保持对现有类别的知识,同时学习新类别。
- 注意力蒸馏损失(Attention Distillation Loss, LAD):
- 引入了一种基于注意力图的信息保持惩罚项LAD,通过惩罚分类器注意力图的变化来保留基础类别的信息。
- 注意力图的利用:
- 使用Grad-CAM技术生成注意力图,这些图能够更精确地编码模型的表示,并用于约束学生模型和教师模型之间的表示差异。
- 信息保持的全面提升:
- 通过结合LAD和LD,不仅考虑了类别分数的分布,还考虑了模型梯度流信息,从而更全面地保持了对基础类别的知识。
- 增量学习中的类特异性解释:
- 通过类特异性的注意力图,强化了教师模型和学生模型之间类特异性解释的一致性,这对于没有基础类别数据的增量学习设置尤为重要。
- 实验验证:
- 在多个数据集上进行了实验验证,包括iILSVRC-small、iCIFAR-100、Caltech-101和CUBS-200-2011,证明了LwM方法在增量学习中的有效性和优越性能。
3 算法
3.1 算法原理
- 增量学习(Incremental Learning, IL)背景:
- 增量学习的目标是在不遗忘已学习类别(基础类别)的情况下,使模型能够识别新的类别。这是通过在模型上逐步添加新类别并更新模型来实现的。
- 信息保持惩罚(Information Preserving Penalty, IPP):
- 在增量学习中,为了保持对基础类别的知识,需要一种机制来惩罚模型在新类别学习过程中对基础类别知识造成的遗忘。LwM通过IPP实现这一点。
- 注意力蒸馏损失(Attention Distillation Loss, LAD):
- LwM提出了一种新的IPP,即注意力蒸馏损失(LAD),它通过比较教师模型(Mt-1)和学生模型(Mt)生成的注意力图来工作。注意力图是通过Grad-CAM技术从模型的卷积层特征图中生成的,它们能够显示模型在做出预测时哪些区域是重要的。
- 类特异性注意力图:
- LwM使用类特异性的注意力图来强制教师模型和学生模型之间的类特异性解释保持一致。这意味着模型需要对特定类别的视觉特征保持一致的关注。
- L1距离约束:
- LAD通过计算教师模型和学生模型生成的注意力图之间的L1距离来约束它们的差异,迫使学生模型在学习新类别时保持与教师模型对基础类别的相似响应。
- 损失函数组合:
- LwM的总损失函数是分类损失(LC)、知识蒸馏损失(LD)和注意力蒸馏损失(LAD)的组合,即LLwM = LC + βLD + γLAD,其中β和γ是用于平衡两种损失的权重。
- 增量学习过程:
- 在每个增量步骤中,学生模型Mt被初始化为前一个教师模型Mt-1,并在新类别的数据上进行训练。同时,通过LAD和LD来惩罚Mt与Mt-1之间的差异,以保持对基础类别的知识。
- 无需存储基础类别数据:
- LwM的一个关键特点是在学习新类别时不需要存储任何关于基础类别的数据,这使得该方法在内存受限的边缘设备上特别有用。
3.2 算法步骤
图中展示了应用三种损失函数来训练学生模型$ M_t < f o n t s t y l e = " c o l o r : r g b ( 6 , 6 , 7 ) ; " > ,而教师模型 < / f o n t > <font style="color:rgb(6, 6, 7);">,而教师模型</font> <fontstyle="color:rgb(6,6,7);">,而教师模型</font> M_{t-1} $保持冻结状态。
算法步骤总结:
- 初始化教师模型$ M_0 :使用分类损失 :使用分类损失 :使用分类损失 LC 训练 训练 训练 M_0 在 在 在 N $个基础类别上。
- 对于每个增量步骤$ t = 1 到 到 到 k : a . 使用 : a. 使用 :a.使用 M_{t-1} 初始化学生模型 初始化学生模型 初始化学生模型 M_t 。 b . 将新类别的数据输入到 。 b. 将新类别的数据输入到 。b.将新类别的数据输入到 M_{t-1} 和 和 和 M_t 中。 c . 计算 中。 c. 计算 中。c.计算 M_t 和 和 和 M_{t-1} 的输出,包括类特异性注意力图和分数。 d . 应用信息保持惩罚( I P P ),包括注意力蒸馏损失( L A D )和知识蒸馏损失( L D )。 e . 对 的输出,包括类特异性注意力图和分数。 d. 应用信息保持惩罚(IPP),包括注意力蒸馏损失(LAD)和知识蒸馏损失(LD)。 e. 对 的输出,包括类特异性注意力图和分数。d.应用信息保持惩罚(IPP),包括注意力蒸馏损失(LAD)和知识蒸馏损失(LD)。e.对 M_t 应用分类损失 应用分类损失 应用分类损失 LC ,基于其对新类别的输出。 f . 联合应用分类损失和 I P P ,训练 ,基于其对新类别的输出。 f. 联合应用分类损失和IPP,训练 ,基于其对新类别的输出。f.联合应用分类损失和IPP,训练 M_t 进行多个周期。 g . 将 进行多个周期。 g. 将 进行多个周期。g.将 M_t $作为下一个增量步骤的教师模型。
- 在每个增量步骤后评估$ M_t $的性能。
损失函数计算公式:
总损失函数$ LLwM
是分类损失
是分类损失
是分类损失 LC
、知识蒸馏损失
、知识蒸馏损失
、知识蒸馏损失 LD
和注意力蒸馏损失
和注意力蒸馏损失
和注意力蒸馏损失 LAD $的组合:
$ LLwM = LC + \beta \cdot LD + \gamma \cdot LAD $
其中:
- $ LC 是分类损失,用于训练 是分类损失,用于训练 是分类损失,用于训练 M_t $学习新类别。
- $ LD 是知识蒸馏损失,用于保持 是知识蒸馏损失,用于保持 是知识蒸馏损失,用于保持 M_t 和 和 和 M_{t-1} $对基础类别预测的一致性:
$ LD(y, \hat{y}) = -\sum_{i=1}^{N} y_{0i} \cdot \log(\hat{y}_{0i}) $
其中$ y 和 和 和 \hat{y} 分别是 分别是 分别是 M_{t-1} 和 和 和 M_t 对基础类别的预测向量, 对基础类别的预测向量, 对基础类别的预测向量, y_{0i} = \sigma(y_i) 和 和 和 \hat{y}_{0i} = \sigma(\hat{y}_i) ( ( ( \sigma(\cdot) $是sigmoid激活函数)。
- $ LAD 是注意力蒸馏损失,用于保持 是注意力蒸馏损失,用于保持 是注意力蒸馏损失,用于保持 M_t 和 和 和 M_{t-1} $生成的注意力图的一致性:
$ LAD = \sum_{j=1}^{l} \left| \frac{Q_{In,b}{t-1,j}}{|Q_{In,b}{t-1}|2} - \frac{Q{In,b}{t,j}}{|Q_{In,b}{t}|_2} \right|_1 $
其中$ Q_{In,b}^{t-1} 和 和 和 Q_{In,b}^{t} 分别是 分别是 分别是 M_{t-1} 和 和 和 M_t 为输入图像 为输入图像 为输入图像 In 和基础类别 和基础类别 和基础类别 b 生成的向量化注意力图, 生成的向量化注意力图, 生成的向量化注意力图, l $是每个向量化注意力图的长度。
4 实验分析
展示了在不同的增量学习配置下,注意力图如何随着增量学习步骤的变化而变化。这些配置包括传统的分类训练(C)、LwF-MC方法,以及本文提出的LwM方法。
- 基线比较:
- LwM与LwF-MC作为基线进行比较,LwM在所有测试场景中均优于LwF-MC。
- 性能提升:
- 在iILSVRC-small数据集上,当类别数量达到40或更多时,LwM的性能比LwF-MC提高了30%以上。
- 在100个类别时,LwM的性能比LwF-MC提高了50%以上。
- 不同数据集上的表现:
- LwM在iILSVRC-small和iCIFAR-100数据集上的表现一致优于LwF-MC。
- LwM甚至在iILSVRC-small数据集上超过了iCaRL,尽管iCaRL在训练学生模型时有访问基础类别数据的优势。
- 增量批次大小的影响:
- 在iCIFAR-100数据集上,LwM在不同大小的增量批次(10、20、50类)中均优于LwF-MC。
- Caltech-101和CUBS-200-2011数据集:
- 在这两个数据集上,LwM通过每次增加10个类别的批次与微调(Finetuning)进行比较,显示出LwM在增量学习中的优势。
- LAD的有效性:
- 仅使用LC和LAD的组合在iILSVRC-small数据集上进行了测试,结果表明LAD有助于保持对基础类别的注意力,从而提高模型性能。
- 遗忘和惯性的量化:
- 通过比较不同配置下生成的注意力图,实验结果支持LwM在减少遗忘和提高模型对新类别的适应性方面的有效性。
- 定性结果:
- 通过图4展示的注意力图,LwM显示出在增量学习过程中更好地保持了对基础类别的注意力,与C和LwF-MC相比,LwM生成的注意力图更接近于初始教师模型M0的“理想”注意力图。
5 思考
(1)提出了一种基于注意力的损失函数应用到连续学习。
(2)只对比了一种算法,就发表了CVPR ,都没有对比多种的算法。