论文笔记:Gradient Episodic Memory for Continual Learning
1. Contribution
- 提出了一组指标来评估模型在连续数据上的学习情况。这些指标不仅通过测试准确性来表征模型,还通过其跨任务迁移知识的能力来表征模型。
- 针对持续学习任务,提出了GEM模型(Gradient Episodic Memory),它可以减少遗忘,同时允许将知识有益地迁移到以前的任务中。
- 在MNIST和CIFAR-100数据集上做了实验
2. Challenge/ Issues
独立且同分布的(iid假设),意思是:每张图片都是随机挑选的,彼此独立,没有关联
- 非独立同分布的数据(Non-iid input data):监督学习通过多次重复数据来让模型记住东西(不断地让模型重复学习一些图片:苹果、香蕉、橘子等很多水果),而人类学习是通过一次性有顺序地学习。(比如教一个小朋友时,可能不会重复展示同样的苹果或香蕉图片,你是按顺序给他看新的例子。小朋友也不会记住每一张图片,只能记住一部分。更重要的是,现实中的学习任务会不断变化——今天你可能在教水果,明天可能就开始教动物。)
- 灾难性遗忘(Catastrophic forgetting):如果直接用监督学习的方法去解决类似人类的学习任务(比如教不同的东西、不断变化的任务),那问题就来了。因为监督学习依赖于iid假设,认为所有数据都是来自同一分布,并且可以重复多次学习。但是在人类学习的场景中,iid假设不成立,数据不是独立的,也不是来自相同的任务分布。这样一来,监督学习的模型可能会因为新数据而“遗忘”旧的数据和知识,这就是所谓的“灾难性遗忘”。
- 知识迁移:当连续的任务相关时,存在迁移学习的机会,这可以转化为更快地学习新任务以及在旧任务上的性能提升。
3. 论文设定的scenario
在训练时,仅以三元组 (xi, ti, yi) 的形式向模型提供一个示例,model 永远不会两次经历相同的示例,并且任务按顺序流式传输。不需要对任务强加任何顺序,因为未来的任务可能与过去的任务同时发生。
3.1 数据三元组的构成:
每个数据样本由三部分组成:
x:特征向量,比如一张图片。
t:任务描述符,告诉我们当前任务是什么(例如“识别水果”或者“识别动物”)。
y:目标向量,即标签,表示图片的具体内容(如“苹果”或者“猫”)。
模型的目标是学会如何根据给定的图片 x 和任务描述符 t,预测出正确的标签 y。
3.2 局部iid(locally iid):
“局部iid”是指在某个特定的任务中(比如在任务 t 中),数据是独立同分布(iid)的,也就是说,在某个任务的学习阶段内,数据可以随机地、不相关地抽取出来。简单来说,局部iid意味着:
- 对于某一个任务 t 来说,图片和标签的关系是随机的,没有顺序影响。例如,如果任务是“识别水果”,那么苹果和香蕉图片的顺序是随机的,彼此之间没有相关性。
虽然在每个任务中,数据是随机独立的(iid),但在不同任务之间,数据不是随机的。例如,模型可能会先连续看到许多水果图片,然后才切换到动物识别任务,这使得任务间数据的顺序不是随机的。
3.3 学习目标
目标是学习一个预测模型 f,这个模型可以根据任意的图片 x 和任务描述符 t,预测出图片的标签 y。关键是,模型不仅需要识别当前任务的数据(如正在学习中的任务),还要记住以前学过的任务,甚至能够处理未来可能遇到的新任务。
例如:
- 如果模型已经学过了识别水果(任务 t1)和识别动物(任务 t2),那么即使任务结束后,模型仍然能够根据给定的任务描述符 t1 识别苹果,或者根据 t2 识别猫。
- 而且,模型还应该能处理未来可能出现的新任务。
4. Technical contribution
4.1 introduce a set of metrics
- Average Accuracy:这个指标是模型在所有任务上的平均表现。你教完机器人识别水果和动物后,测试它在所有任务上的表现,看它识别每个任务(如水果或动物)的正确率,然后取平均值。
- Backward transfer (BWT):这是评估模型在学习新任务后,对之前任务表现的影响。具体来说,如果学完了任务2,测试它在任务1上的表现,看有没有提高或下降。
- Forward transfer (FWT):这是评估模型在学习新任务时,之前学到的任务对新任务的影响。具体来说,在学任务2时,看看任务1是否帮助了任务2的学习。
Ri,j 表示模型在学完任务 ti 后,在任务 tj 上的测试准确率。
4.2 key point of GEM model
- 记忆分配策略: GEM 有一个总的记忆空间 M,用于存储多个任务的样本。如果任务总数 T 是已知的,那么每个任务可以分配 m = M / T 个记忆位置。如果任务数量不确定,GEM 可以随着新任务的到来动态减少每个任务的分配空间。
- 损失函数和约束条件:每当模型学习一个新任务时,GEM 会通过一个损失函数来调整模型的参数。然而,与传统方法不同的是,GEM 在更新参数时会限制过去任务的损失不增加,即:模型在新任务上的学习不能损害它对旧任务的表现。
- 使用梯度约束来防止遗忘:GEM 的一个核心思想是通过梯度投影来避免遗忘。当模型更新参数时,它会计算一个梯度。如果这个梯度方向会导致旧任务的损失增加,GEM 就会通过调整这个梯度方向,保证不会增加旧任务的损失。
- 解决梯度约束问题:如果梯度违反了约束,GEM 会通过求解一个二次规划问题(Quadratic Programming来找到一个新的梯度,使得在满足约束的前提下尽可能接近原始梯度。
4.3 GEM的训练伪代码
训练过程
- 初始化:
- Mt:每个任务的记忆被初始化为空集,之后会逐步添加样本。
- R:用于记录每个任务的测试准确率,初始化为全零矩阵。
- 任务循环:
- 对每个任务
t
:- 遍历训练数据集中的每个样本
(x, y)
:- 将样本加入到记忆中:将当前样本加入任务
t
的记忆中。 - 计算梯度:
- g:计算当前任务的损失函数梯度,表示当前任务在当前样本下的梯度。
- gk:对所有以前的任务
k < t
,计算其在记忆中的梯度。
- 梯度投影:使用梯度投影方法(公式 11),将当前梯度
g
调整为满足所有之前任务的约束条件,生成新的梯度̃g
。 - 更新参数: 根据投影后的梯度
̃g
更新模型参数
- 将样本加入到记忆中:将当前样本加入任务
- 评估任务:在每次任务学习完成后,评估模型在所有任务上的表现,并记录在矩阵 R 中
- 遍历训练数据集中的每个样本
- 返回训练结果:最终返回更新后的模型和记录的任务表现矩阵 R。
评估过程
- 任务循环:对每个任务 k,计算其测试准确率:
- 遍历任务 k 的所有测试样本 (x, y)。
- 计算准确率:累积预测结果的准确率。
- 平均准确率:最终将累积的准确率除以该任务的测试样本数量,得到任务 k 的准确率。
- 返回准确率向量:最终返回每个任务的测试准确率 r。