当前位置: 首页 > article >正文

【连续学习之ResCL算法】2020年AAAI会议论文:Residual continual learning

1 介绍

年份:2020

会议: AAAI

Lee J, Joo D, Hong H G, et al. Residual continual learning[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2020, 34(04): 4553-4560.

本文提出的算法是Residual Continual Learning (ResCL),其核心原理是通过在线性组合原始网络和微调网络的每层参数来重新参数化网络,以此在连续学习多个任务时防止灾难性遗忘现象,同时不需要任何源任务信息,保持网络大小不变。先第一次训练,采用LWF算法知识蒸馏到新模型,再合并LWF模型和旧模型后,然后采用的合并后的损失函数第二次训练合并后的模型。ResCL算法属于基于架构的算法,因为它通过调整网络架构来实现连续学习,同时它也利用了基于正则化的方法来优化组合网络的参数

2 创新点

  1. 残差学习重参数化:ResCL通过线性组合原始网络和微调网络的每层来重新参数化网络参数,这种残差学习的方法允许网络在连续学习多个任务时有效地控制源任务和目标任务之间的性能权衡。
  2. 无需源任务信息:ResCL不需要除了原始网络之外的任何源任务信息,这使得它在实际应用中更加实用,尤其是在源数据不可用或难以处理的情况下。
  3. 网络大小不变:与传统的网络扩展方法不同,ResCL在推理阶段不会增加网络的大小,这对于在资源受限的环境中部署深度学习模型尤为重要。
  4. 特殊权重衰减损失:ResCL引入了一种特殊的权重衰减损失,这种损失函数专门设计用于连续学习,能够有效地防止遗忘源任务知识。
  5. 适用于通用CNNs:该方法可以自然地应用于包括批量归一化(Batch Normalization)层在内的通用卷积神经网络。
  6. 公平比较的衡量标准:ResCL提出了两种衡量不同连续学习方法的公平标准,即最大可实现平均准确率和在所需目标准确率下的源准确率,这些标准不依赖于特定的超参数设置。
  7. 连续学习多个任务:ResCL不仅适用于两个任务的连续学习,还可以扩展到三个或更多任务的连续学习场景,展示了良好的扩展性和适应性。

3 算法

3.1 算法原理

  1. 重新参数化
    ResCL通过线性组合原始网络(源任务网络)和微调网络(目标任务网络)的每层参数来重新参数化网络。对于全连接层,如果 W s W_s Ws是源网络的权重矩阵, W t W_t Wt是微调网络的权重矩阵,那么组合层的输出可以表示为:

( 1 C o + α s ) ⊙ ( W s x ) + α t ⊙ ( W t x ) (1_{Co} + \alpha_s) \odot (W_s x) + \alpha_t \odot (W_t x) (1Co+αs)(Wsx)+αt(Wtx)

其中, 1 C o 1_{Co} 1Co是全1向量, ⊙ \odot 表示逐元素乘法, α s \alpha_s αs α t \alpha_t αt是组合参数向量。这些参数通过反向传播学习得到,允许网络在保留源任务知识的同时适应新任务。

  1. 特殊权重衰减损失
    为了防止遗忘源任务,ResCL引入了一种特殊的权重衰减损失。不同于常规的权重衰减损失,ResCL的权重衰减损失专门设计用于连续学习,其目的是保护原始网络权重不受目标任务影响。这种损失函数可以表示为:

λ ∣ ∣ ( α s , α t ) ∣ ∣ 1 \lambda ||(\alpha_s, \alpha_t)||_1 λ∣∣(αs,αt)1

其中, λ \lambda λ是权衡超参数,控制源任务和目标任务性能之间的平衡。

  1. 批量归一化层的处理
    考虑到现代卷积神经网络(CNN)中广泛使用的批量归一化(Batch Normalization, BN)层,ResCL也考虑了BN层的影响。每个子网络都有自己的BN层,以保持各自任务的分布信息。在推理阶段,BN层可以被视为确定性线性层,因此可以与卷积层合并。
  2. 训练过程
    ResCL的训练包括两个阶段:首先是微调阶段,微调网络 n e t t nett nett在目标任务上进行训练;其次是组合网络训练阶段,组合网络 n e t c netc netc在源任务和目标任务上进行训练,同时使用LwF损失和权重衰减损失来保持源任务性能。
  3. 性能权衡
    ResCL通过超参数 λ \lambda λ来控制源任务和目标任务之间的性能权衡。通过调整 λ \lambda λ,可以在保持源任务性能和提高目标任务性能之间找到平衡点。

3.2 算法步骤

  1. 初始训练(源任务训练)
    • 首先,有一个已经在源任务上训练好的网络,我们称之为源网络 n e t s ( ⋅ ; θ s ∗ ) nets(\cdot; \theta^*_s) nets(;θs)
  2. 微调(目标任务训练)
    • 接着,使用源网络的参数初始化一个新的网络 n e t t ( ⋅ ; θ t ) nett(\cdot; \theta_t) nett(;θt),并在目标任务数据上进行微调,得到微调后的网络参数 θ t ∗ \theta^*_t θt。这个过程类似于 LwF 中的第一步,即在目标任务上微调模型,但并不完全相同,因为 ResCL 的目标是重新参数化网络以保留源任务的知识。
  3. 合并网络
    • 然后,ResCL 将源网络和微调网络的每层参数线性组合,形成一个新的组合网络 n e t c netc netc。这个组合网络包含了可学习的组合参数 α s \alpha_s αs α t \alpha_t αt,它们决定了如何混合源网络和微调网络的特征。
  4. 第二次训练(组合网络训练)
    • 最后,对组合网络 n e t c netc netc进行训练。这个训练过程涉及到特殊的损失函数,包括:
      • LwF 损失:用于保持源任务的性能。
      • 知识蒸馏损失:用于使目标任务的输出接近微调网络的输出。
      • 权重衰减损失:用于防止遗忘,通过衰减组合参数 α s \alpha_s αs α t \alpha_t αt来保护源网络的特征。

4 实验结果

图1展示了ResCL方法的流程,其中包括原始网络的微调、组合层的引入以及使用LwF损失和衰减损失进行的持续学习,以保持源任务性能并适应目标任务。

图2展示了源网络和目标网络的预激活残差单元如何通过组合层在每个非线性激活之前合并,形成在推理阶段使用的等效组合网络,且该网络与原始网络具有相同的规模。

5 思考

(1)训练两次,第一次先蒸馏、第二次是将蒸馏后的模型与旧模型合并,合并后再重新训练一遍

(2)如何组合两个网络?

在Residual Continual Learning (ResCL)算法中,组合两个网络是通过线性组合每一层的输出来实现的。这个过程涉及到源网络(已经训练好的网络,记作 n e t s nets nets)和目标网络(在目标任务上微调后的网络,记作 n e t t nett nett)。

  1. 线性组合
    对于网络中的每个层(假设是全连接层或卷积层),ResCL算法将源网络和目标网络对应层的输出进行线性组合。对于全连接层,如果 W s W_s Ws是源网络层的权重矩阵, W t W_t Wt是目标网络层的权重矩阵,那么组合层的输出可以表示为:

( 1 C o + α s ) ⊙ ( W s x ) + α t ⊙ ( W t x ) (1_{Co} + \alpha_s) \odot (W_s x) + \alpha_t \odot (W_t x) (1Co+αs)(Wsx)+αt(Wtx)

其中, 1 C o 1_{Co} 1Co是一个全1向量,用于与 W s x W_s x Wsx相加, ⊙ \odot 表示逐元素乘法, α s \alpha_s αs α t \alpha_t αt是组合参数向量,它们是可学习的参数,用于控制源网络和目标网络输出的混合比例。

  1. 学习组合参数
    组合参数 α s \alpha_s αs α t \alpha_t αt通过反向传播和梯度下降进行学习和优化。这些参数允许网络在保留源任务知识的同时适应新任务。
  2. 保持网络大小不变
    由于这种线性组合是在特征层面进行的,最终的网络在推理阶段的规模不会增加。这是因为组合层可以被看作是重新参数化,而不是增加额外的层或参数。
  3. 处理非线性层
    对于非线性层(如ReLU激活层),这些层不能被包含在组合中,因为它们不是线性的。因此,组合层应该在非线性层之前应用,以确保网络的非线性特性得以保留。
  4. 批量归一化层的处理
    对于包含批量归一化(Batch Normalization, BN)层的网络,ResCL算法为每个任务保留了独立的BN层,并在组合网络中使用。这是因为BN层依赖于特定任务的统计信息,因此在连续学习中需要保持这些信息以避免遗忘。

(3)论文中的残差结构体现在哪?

残差结构主要体现在网络参数的重新参数化和权重衰减损失的设计上

  1. 线性组合的重新参数化
    ResCL算法通过线性组合原始网络(源网络)和微调网络(目标网络)的每层参数来实现残差学习。对于全连接层或卷积层,这种组合可以表示为:

( 1 C o + α s ) ⊙ ( W s x ) + α t ⊙ ( W t x ) (1_{Co} + \alpha_s) \odot (W_s x) + \alpha_t \odot (W_t x) (1Co+αs)(Wsx)+αt(Wtx)

其中, W s W_s Ws W t W_t Wt分别是源网络和目标网络的权重, α s \alpha_s αs α t \alpha_t αt是可学习的组合参数, 1 C o 1_{Co} 1Co是全1向量。这种结构允许网络学习源网络和目标网络之间的残差,即目标网络相对于源网络的变化。

  1. 残差学习的思想
    ResCL算法的设计理念与残差学习(Residual Learning)相似,残差学习通过将输入直接添加到网络的深层来促进深层网络的训练。在ResCL中,通过重新参数化,网络学习如何调整源网络的输出以适应目标任务,这可以看作是学习源网络输出到目标网络输出的“残差”。
  2. 权重衰减损失
    ResCL中的权重衰减损失( λ ∣ ∣ ( α s , α t ) ∣ ∣ 1 \lambda ||(\alpha_s, \alpha_t)||_1 λ∣∣(αs,αt)1)有助于控制组合参数,使得只有对目标任务必要的特征才会显著偏离源网络的特征。这种设计类似于在残差网络中通过权重衰减来促使残差部分的权重趋于零,从而保持网络的稀疏性和效率。

http://www.kler.cn/a/461135.html

相关文章:

  • 【AI大模型】深入GPT-2模型细节:揭秘其卓越性能的秘密
  • oceanbase集群访问异常问题处理
  • STM32--超声波模块(HC—SR04)(标准库+HAL库)
  • vue2+echarts实现水球+外层动效
  • 从单点 Redis 到 1 主 2 从 3 哨兵的架构演进之路
  • HTML——57. type和name属性
  • 离散数学 群(半群,群,交换群,循环群,对称群,置换群,置换,交代群,轮换)详细,复习笔记
  • LeetCode热题100-反转链表【JavaScript讲解】
  • 【每日学点鸿蒙知识】Json字典问题、高度变化问题、开放测试版本问题、动态库单架构选择、WebView和H5交互
  • 【每日学点鸿蒙知识】人脸活体检测、NodeController刷新、自动关闭输入框、Row设置中间最大宽、WebView单例
  • JavaWeb 开发进阶 - 数据库交互与框架应用
  • 五、Hadoop环境搭建之模板虚拟机准备
  • tomcat窗口闪退,以及在eclipse上面运行不出来
  • HTML5滑块(Slider)
  • 从家谱的层级结构 - 组合模式(Composite Pattern)
  • es单机安装脚本自动化
  • hive-sql 计算每年在校生人数
  • 写在2024的最后一天
  • 【浏览器】缓存
  • Android 检测设备是否 Root
  • 【数据结构】线性数据结构——栈
  • 本地部署Hello-Algo打造私人算法教练让算法学习告别网络限制
  • 解构大语言模型(LLM)
  • 如何免费解锁 IPhone 网络
  • 如何使用 ChatGPT Prompts 写学术论文?
  • 嵌入式单片机中SPI外设控制与实现