【Python实现连续学习算法】Python实现连续学习Baseline 及经典算法EWC
Python实现连续学习Baseline 及经典算法EWC
1 连续学习概念及灾难性遗忘
连续学习(Continual Learning)是一种模拟人类学习过程的机器学习方法,它旨在让模型在面对多个任务时能够连续学习,而不会遗忘已学到的知识。然而,大多数深度学习模型在连续学习多个任务时会出现“灾难性遗忘”(Catastrophic Forgetting)现象。灾难性遗忘指模型在学习新任务时会大幅度遗忘之前学到的任务知识,这是因为模型参数在新任务的训练过程中被完全覆盖。
解决灾难性遗忘问题是连续学习研究的核心。目前已有多种方法被提出,包括正则化方法、回放、架构等等的方法,其中EWC(Elastic Weight Consolidation)是一种经典的正则化方法。
2 PermutedMNIST数据集及模型
PermutedMNIST是连续学习领域的一种经典测试数据集。它通过对MNIST数据集中的像素进行随机置换生成不同的任务。每个任务都是一个由置换规则决定的分类问题,但所有任务共享相同的标签空间。
对于模型的选择,通常采用简单的全连接神经网络。网络结构可以包含若干个隐藏层,每个隐藏层具有一定数量的神经元,并使用ReLU作为激活函数。网络的输出层与标签类别数一致。
模型在训练每个任务时需要调整参数,研究灾难性遗忘问题的严重程度,并在引入算法时测试其对连续学习能力的改善效果。
import random
import torch
from torchvision import datasets
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
# Set seeds for Python, PyTorch and NumPy so runs are reproducible.
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
# Ensure deterministic behavior: force deterministic cuDNN kernels and
# disable the autotuner (which may otherwise pick non-deterministic ones).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class PermutedMNIST(datasets.MNIST):
    """MNIST variant whose pixels are shuffled by a fixed permutation.

    Each distinct permutation defines one continual-learning task; every
    task shares the same 10-class label space.
    """

    def __init__(self, root="./data/mnist", train=True, permute_idx=None):
        super(PermutedMNIST, self).__init__(root, train, download=True)
        assert permute_idx is not None and len(permute_idx) == 28 * 28
        # Flatten each image, apply the pixel permutation, scale to [0, 1].
        # Fix: the original duplicated this statement in identical
        # train/else branches; one unconditional statement suffices.
        self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                 for img in self.data])

    def __getitem__(self, index):
        # Fix: use the canonical `self.targets` attribute; the deprecated
        # `train_labels` / `test_labels` aliases were removed from recent
        # torchvision releases and would raise AttributeError there.
        img, target = self.data[index], self.targets[index]
        return img.view(1, 28, 28), target

    def get_sample(self, sample_size):
        """Return `sample_size` images drawn with a fixed seed (reproducible)."""
        random.seed(2024)
        sample_idx = random.sample(range(len(self)), sample_size)
        return [img.view(1, 28, 28) for img in self.data[sample_idx]]
def worker_init_fn(worker_id):
    """Seed the Python and NumPy RNGs of a DataLoader worker.

    Offsetting a fixed base seed by the worker id keeps every worker's
    random stream reproducible across runs.
    """
    seed = 2024 + worker_id
    random.seed(seed)
    np.random.seed(seed)
def get_permute_mnist(num_task, batch_size):
    """Build train/test DataLoaders for `num_task` PermutedMNIST tasks.

    Task i's pixel permutation comes from shuffling under a fixed seed, so
    repeated runs generate the same task sequence.  Built datasets are
    cached to disk and reloaded on later runs.

    Returns:
        (train_loader, test_loader): two dicts mapping task id -> DataLoader.
    """
    random.seed(2024)
    train_loader = {}
    test_loader = {}
    root_dir = './data/permuted_mnist'
    os.makedirs(root_dir, exist_ok=True)
    for i in range(num_task):
        # A fresh random permutation of the 784 pixel positions defines task i.
        permute_idx = list(range(28 * 28))
        random.shuffle(permute_idx)
        train_dataset_path = os.path.join(root_dir, f'train_dataset_{i}.pt')
        test_dataset_path = os.path.join(root_dir, f'test_dataset_{i}.pt')
        if os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path):
            # NOTE(review): a cached dataset is used as-is and the
            # `permute_idx` generated above is ignored.  This only stays
            # consistent because the seed (hence the permutation sequence)
            # is fixed — delete the cache if the seed ever changes.
            train_dataset = torch.load(train_dataset_path)
            test_dataset = torch.load(test_dataset_path)
        else:
            train_dataset = PermutedMNIST(train=True, permute_idx=permute_idx)
            test_dataset = PermutedMNIST(train=False, permute_idx=permute_idx)
            # Whole dataset objects are pickled; loading them later requires
            # the PermutedMNIST class to be importable.
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)
        train_loader[i] = DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     # num_workers=1,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)
        test_loader[i] = DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    # num_workers=1,
                                    worker_init_fn=worker_init_fn,
                                    pin_memory=True)
    return train_loader, test_loader
class MLP(nn.Module):
    """Fully connected classifier for PermutedMNIST.

    Three ReLU hidden layers followed by a single shared linear output head
    (one head for all tasks).
    """

    def __init__(self, input_size=28 * 28, num_classes_per_task=10, hidden_size=[400, 400, 400]):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        # Class counters for the shared output head.
        self.total_classes = num_classes_per_task
        self.num_classes_per_task = num_classes_per_task
        # Layers (attribute names kept stable for checkpoint compatibility).
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.fc_before_last = nn.Linear(hidden_size[1], hidden_size[2])
        self.fc_out = nn.Linear(hidden_size[2], self.total_classes)

    def forward(self, input, task_id=-1):
        # `task_id` is accepted for interface compatibility but unused here:
        # every task shares the same output head.
        hidden = input
        for layer in (self.fc1, self.fc2, self.fc_before_last):
            hidden = F.relu(layer(hidden))
        return self.fc_out(hidden)
3 Baseline代码
没有任何连续学习算法的Baseline代码实现仅仅是将任务逐个训练。具体过程为:依次加载每个任务的数据集,独立训练模型,而不考虑模型对前一个任务的记忆能力。
class Baseline:
    """Sequential fine-tuning baseline.

    Trains one task after another with no anti-forgetting mechanism, so the
    measured accuracies expose catastrophic forgetting.
    """

    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2, neurons=0):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.neurons = neurons
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28
        # One shared model and loss across every task.
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        # Per-task DataLoaders, keyed by task id.
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

    def _prepare_batch(self, images, labels):
        """Flatten the images and move the batch to the training device."""
        flat = images.view(-1, self.input_size).to(self.device, non_blocking=True)
        return flat, labels.to(self.device, non_blocking=True)

    def evaluate(self, test_loader, task_id):
        """Return the model's accuracy (percent) on one task's test set."""
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = self._prepare_batch(images, labels)
                predicted = torch.argmax(self.model(images, task_id), dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        return 100.0 * correct / total

    def train_task(self, train_loader, optimizer, task_id):
        """Run one epoch of plain cross-entropy training on a single task."""
        self.model.train()
        for images, labels in train_loader:
            images, labels = self._prepare_batch(images, labels)
            optimizer.zero_grad()
            loss = self.criterion(self.model(images, task_id), labels)
            loss.backward()
            optimizer.step()

    def run(self):
        """Train the tasks in order; after each, report accuracy on all seen tasks."""
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            # Fresh optimizer for every task.
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for _ in range(self.epochs):
                self.train_task(train_loader, optimizer, task_id)
            task_acc = [self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                        for eval_task_id in range(task_id + 1)]
            mean_avg = np.round(np.mean(task_acc), 2)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")
if __name__ == '__main__':
    # Demo: run the no-regularization baseline on 3 PermutedMNIST tasks.
    print('Baseline'+"=" * 50)
    # Re-seed everything so the run is reproducible regardless of prior RNG use.
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    baseline = Baseline(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
    baseline.run()
Baseline==================================================
Task 0: Task Acc = [96.78],AVG=96.78
Task 1: Task Acc = [85.19, 97.0],AVG=91.1
Task 2: Task Acc = [52.66, 89.14, 97.27],AVG=79.69
Task AVG Acc: [96.78, 91.1, 79.69],AVG = 89.19
可以看到模型在学习新任务后,旧任务的准确率在下降,在学习完Task2后,第一个任务的准确率只有52.66,第二个任务的准确率只有89.14。
4 EWC算法
4.1 算法原理
论文《Overcoming catastrophic forgetting in neural networks》的EWC(Elastic Weight Consolidation)通过引入正则化项,保护与之前任务相关的重要参数,以减缓灾难性遗忘现象。其核心思想是利用任务训练完成后的参数重要性来约束模型的优化过程。
EWC假设某些参数对之前任务非常重要,改变这些参数会显著降低模型在旧任务上的性能。因此,EWC通过增加以下正则化项来保护这些参数:
$L_{EWC} = L_{new} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2$
其中:
- $L_{new}$ 是新任务的损失函数;
- $\theta_i$ 是模型当前的参数;
- $\theta_i^*$ 是旧任务的最优参数;
- $F_i$ 是Fisher信息矩阵(对角近似),用于衡量每个参数的重要性;
- $\lambda$ 是一个超参数,控制正则化项的权重。
通过在损失函数中引入这一正则化项,EWC能够在训练新任务时有效保护旧任务的重要参数,从而缓解灾难性遗忘问题。
4.2 代码实现
EWC算法的实现包括以下几个关键步骤:
- 在旧任务训练结束后,保存模型参数和计算Fisher信息矩阵;
- 在训练新任务时,将正则化项加入损失函数;
class EWC:
    """Elastic Weight Consolidation (Kirkpatrick et al., 2017).

    After finishing each task, stores a snapshot of the model parameters and
    accumulates a diagonal Fisher-information estimate of per-parameter
    importance.  While training later tasks, a quadratic penalty anchors
    important parameters near their old values, mitigating catastrophic
    forgetting.
    """

    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28
        # One shared model and loss across every task.
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        # Fix: the original allocated torch.cuda.amp.GradScaler() here but
        # never used it (the API is also deprecated); it has been removed.
        self.importance_dict = {}   # param name -> accumulated Fisher diagonal
        self.previous_params = {}   # param name -> snapshot after the last task
        self.lambda_ = 10000        # strength of the EWC penalty
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

    def evaluate(self, test_loader, task_id):
        """Return the model's accuracy (percent) on one task's test set."""
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.view(-1, self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        return 100.0 * correct / total

    def train_task(self, train_loader, optimizer, task_id):
        """Run one training epoch; tasks after the first add the EWC penalty."""
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1, self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            if task_id > 0:
                loss = self.ewc_multi_objective_loss(outputs, labels)
            else:
                loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def ewc_compute_importance(self, data_loader, task_id):
        """Estimate the diagonal Fisher information for every tracked parameter.

        Importance is the squared gradient of the label loss, averaged over
        all batches of `data_loader`.
        """
        importance_dict = {name: torch.zeros_like(param, device=self.device)
                           for name, param in self.model.named_parameters()
                           if 'task' not in name}
        self.model.eval()
        num_batches = len(data_loader)
        for images, labels in data_loader:
            images = images.view(-1, self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            self.model.zero_grad()
            outputs = self.model(images, task_id=task_id)
            # Fix: reuse self.criterion instead of constructing a fresh
            # nn.CrossEntropyLoss on every batch.
            loss = self.criterion(outputs, labels)
            loss.backward()
            for name, param in self.model.named_parameters():
                # Fix: also guard against param.grad being None (parameters
                # untouched by this loss would otherwise crash the update).
                if name in importance_dict and param.requires_grad and param.grad is not None:
                    importance_dict[name] += param.grad ** 2 / num_batches
        return importance_dict

    def update(self, dataset, task_id):
        """Consolidate a finished task: add its Fisher importances and snapshot params."""
        importance_dict = self.ewc_compute_importance(dataset, task_id)
        for name in importance_dict:
            if name in self.importance_dict:
                self.importance_dict[name] += importance_dict[name]
            else:
                self.importance_dict[name] = importance_dict[name]
        for name, param in self.model.named_parameters():
            self.previous_params[name] = param.clone().detach()

    def ewc_multi_objective_loss(self, outputs, labels):
        """Cross-entropy plus the lambda-weighted EWC penalty
        sum_i F_i * (theta_i - theta_i^*)^2 over protected parameters."""
        regularization_loss = 0.0
        for name, param in self.model.named_parameters():
            if 'task' not in name and name in self.importance_dict and name in self.previous_params:
                importance = self.importance_dict[name]
                previous_param = self.previous_params[name]
                regularization_loss += (importance * (param - previous_param).pow(2)).sum()
        loss = self.criterion(outputs, labels)
        total_loss = loss + self.lambda_ * regularization_loss
        return total_loss

    def run(self):
        """Train the tasks in order with EWC; report accuracy on all seen tasks."""
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader, optimizer, task_id)
            # Consolidate knowledge of the task just learned.
            self.update(train_loader, task_id)
            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg},")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")
if __name__ == '__main__':
    # Demo: run EWC on 5 PermutedMNIST tasks.
    print('EWC'+"=" * 50)
    # Reset every random seed before the run so results are reproducible.
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    ewc = EWC(num_classes_per_task=10, num_tasks=5, batch_size=256, epochs=2)
    ewc.run()
EWC==================================================
Task 0: Task Acc = [96.78],AVG=96.78,
Task 1: Task Acc = [95.47, 96.65],AVG=96.06,
Task 2: Task Acc = [90.9, 95.02, 96.28],AVG=94.07,
Task AVG Acc: [96.78, 96.06, 94.07],AVG = 95.63666666666666
在学习完每个任务后,旧任务的准确率只是轻微的下降,说明该算法有效的缓解了灾难性遗忘。