
[Continual Learning Algorithms in Python] A Continual Learning Baseline and the Classic EWC Algorithm

1 Continual Learning and Catastrophic Forgetting

Continual learning is a machine learning paradigm modeled on how humans learn: a model should be able to learn a sequence of tasks without forgetting the knowledge it acquired earlier. Most deep learning models, however, suffer from "catastrophic forgetting" when trained on tasks sequentially: learning a new task drastically degrades performance on earlier ones, because the parameters learned for the old tasks are overwritten during training on the new one.

Mitigating catastrophic forgetting is the central problem of continual learning research. Many approaches have been proposed, broadly falling into regularization-based, replay-based, and architecture-based methods; EWC (Elastic Weight Consolidation) is a classic regularization-based method.

2 The PermutedMNIST Dataset and Model

PermutedMNIST is a classic benchmark in continual learning. It generates distinct tasks by applying a random permutation to the pixels of the MNIST images: each task is a classification problem defined by its permutation rule, and all tasks share the same label space.
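The task-generation mechanism can be sketched in a few lines (a standalone illustration with NumPy, independent of the full dataset class below):

```python
import numpy as np

rng = np.random.default_rng(0)

# A fixed permutation of the 784 pixel positions defines one task.
permute_idx = rng.permutation(28 * 28)

# Stand-in for a flattened MNIST image (values in [0, 1]).
img = rng.random(28 * 28)

# Applying the permutation only rearranges pixels; the label is unchanged,
# so every task shares the same 10-class label space.
task_img = img[permute_idx]

assert np.array_equal(np.sort(task_img), np.sort(img))  # same pixel values
assert not np.array_equal(task_img, img)                # different arrangement
```

Because the permutation destroys spatial structure, tasks look unrelated to a network even though they carry the same information, which is exactly what makes the benchmark useful for measuring forgetting.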

For the model, a simple fully connected network is the usual choice. The network has several hidden layers, each with a fixed number of neurons and ReLU activations, and an output layer with one unit per class.

As the model's parameters are updated on each task in turn, we can measure how severe catastrophic forgetting is, and then test how much a continual learning algorithm improves the model's ability to learn continually.

import random
import torch
from torchvision import datasets
import os
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import warnings
warnings.filterwarnings("ignore")
# Set seeds
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class PermutedMNIST(datasets.MNIST):
    def __init__(self, root="./data/mnist", train=True, permute_idx=None):
        super(PermutedMNIST, self).__init__(root, train, download=True)
        assert len(permute_idx) == 28 * 28
        # Apply the fixed pixel permutation once and normalize to [0, 1].
        self.data = torch.stack([img.float().view(-1)[permute_idx] / 255
                                 for img in self.data])

    def __getitem__(self, index):
        # self.targets covers both the train and test splits
        # (train_labels/test_labels are deprecated in torchvision).
        img, target = self.data[index], self.targets[index]
        return img.view(1, 28, 28), target

    def get_sample(self, sample_size):
        random.seed(2024)
        sample_idx = random.sample(range(len(self)), sample_size)
        return [img.view(1, 28, 28) for img in self.data[sample_idx]]

def worker_init_fn(worker_id):
    # Give each DataLoader worker its own deterministic seed.
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)

def get_permute_mnist(num_task, batch_size):
    random.seed(2024)
    train_loader = {}
    test_loader = {}
    root_dir = './data/permuted_mnist'
    os.makedirs(root_dir, exist_ok=True)

    for i in range(num_task):
        permute_idx = list(range(28 * 28))
        random.shuffle(permute_idx)

        train_dataset_path = os.path.join(root_dir, f'train_dataset_{i}.pt')
        test_dataset_path = os.path.join(root_dir, f'test_dataset_{i}.pt')

        if os.path.exists(train_dataset_path) and os.path.exists(test_dataset_path):
            train_dataset = torch.load(train_dataset_path)
            test_dataset = torch.load(test_dataset_path)
        else:
            train_dataset = PermutedMNIST(train=True, permute_idx=permute_idx)
            test_dataset = PermutedMNIST(train=False, permute_idx=permute_idx)
            torch.save(train_dataset, train_dataset_path)
            torch.save(test_dataset, test_dataset_path)

        train_loader[i] = DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     worker_init_fn=worker_init_fn,
                                     pin_memory=True)
        test_loader[i] = DataLoader(test_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    worker_init_fn=worker_init_fn,
                                    pin_memory=True)

    return train_loader, test_loader

class MLP(nn.Module):
    def __init__(self, input_size=28 * 28, num_classes_per_task=10, hidden_size=[400, 400, 400]):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        # Total number of output classes
        self.total_classes = num_classes_per_task
        self.num_classes_per_task = num_classes_per_task
        
        # Define the network layers
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.fc_before_last = nn.Linear(hidden_size[1], hidden_size[2])
        
        self.fc_out = nn.Linear(hidden_size[2], self.total_classes)
    
    def forward(self, input, task_id=-1):
        x = F.relu(self.fc1(input))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc_before_last(x))
        x = self.fc_out(x)
        return x

3 Baseline Code

The baseline, with no continual learning algorithm at all, simply trains the tasks one after another: it loads each task's dataset in turn and trains the model on it, doing nothing to preserve what was learned on earlier tasks.


class Baseline:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28

        # Initialize model
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()


        # Build the per-task data loaders
        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                # Move data to GPU in batches
                images = images.view(-1,self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        return 100.0 * correct / total


    def train_task(self, train_loader,optimizer, task_id):
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def run(self):
        all_avg_acc = []
        
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader,optimizer, task_id)
            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)

            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg}")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")

if __name__ == '__main__':
    print('Baseline'+"=" * 50)
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    baseline = Baseline(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
    baseline.run()

Baseline==================================================

Task 0: Task Acc = [96.78],AVG=96.78

Task 1: Task Acc = [85.19, 97.0],AVG=91.1

Task 2: Task Acc = [52.66, 89.14, 97.27],AVG=79.69

Task AVG Acc: [96.78, 91.1, 79.69],AVG = 89.19

As the model learns new tasks, accuracy on the old ones drops: after finishing Task 2, accuracy on the first task is down to 52.66% and on the second to 89.14%.
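The degree of forgetting can be read off these numbers directly. A minimal sketch (plain Python, using the accuracies printed above) computes, for each old task, the drop from its best accuracy to its final accuracy:

```python
# Accuracy of each seen task after each training stage (from the log above).
acc_after_stage = [
    [96.78],
    [85.19, 97.00],
    [52.66, 89.14, 97.27],
]

final = acc_after_stage[-1]
# Forgetting for task t: best accuracy it ever reached minus its final accuracy.
forgetting = [
    max(stage[t] for stage in acc_after_stage if len(stage) > t) - final[t]
    for t in range(len(final) - 1)
]
print([round(f, 2) for f in forgetting])  # Task 0 forgot 44.12 points, Task 1 forgot 7.86
```

This "best-minus-final" measure is a common way to summarize forgetting across a task sequence; the baseline's 44-point drop on Task 0 is what EWC aims to prevent.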

4 The EWC Algorithm

4.1 How It Works

EWC (Elastic Weight Consolidation), introduced in the paper "Overcoming catastrophic forgetting in neural networks", adds a regularization term that protects the parameters important to previous tasks, slowing catastrophic forgetting. The core idea is to estimate each parameter's importance after a task finishes training and use those estimates to constrain subsequent optimization.

EWC assumes that some parameters are critical to earlier tasks: changing them would sharply degrade performance on those tasks. It therefore protects them with the following regularization term:

$$L_{EWC} = L_{new} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_i^*)^2$$

where:

  • $L_{new}$ is the loss on the new task;
  • $\theta_i$ are the model's current parameters;
  • $\theta_i^*$ are the optimal parameters from the old task;
  • $F_i$ is the diagonal Fisher information, which measures each parameter's importance;
  • $\lambda$ is a hyperparameter controlling the weight of the regularization term.

By adding this term to the loss, EWC protects the parameters that matter for old tasks while training on new ones, thereby mitigating catastrophic forgetting.
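The penalty can be made concrete with a tiny numeric sketch (toy values, not actual model parameters; a diagonal Fisher approximation is assumed, as in the implementation below):

```python
theta      = [1.0, 2.0, 3.0]   # current parameters
theta_star = [1.0, 0.0, 3.5]   # parameters saved after the old task
fisher     = [0.0, 4.0, 2.0]   # diagonal Fisher estimates (importance)
lam = 10.0

# (lambda / 2) * sum_i F_i * (theta_i - theta_i*)^2
penalty = (lam / 2) * sum(f * (t - ts) ** 2
                          for f, t, ts in zip(fisher, theta, theta_star))
print(penalty)  # 5 * (0 + 16 + 0.5) = 82.5
```

Note how the first parameter, with zero Fisher value, is free to move, while the second, with the largest Fisher value, dominates the penalty: this is the "elastic" anchoring that keeps important weights near their old values.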

4.2 Implementation

The EWC implementation has two key steps:

  1. After training on a task, save the model parameters and compute the Fisher information matrix;
  2. When training the next task, add the regularization term to the loss.

class EWC:
    def __init__(self, num_classes_per_task=10, num_tasks=10, batch_size=256, epochs=2):
        self.num_classes_per_task = num_classes_per_task
        self.num_tasks = num_tasks
        self.batch_size = batch_size
        self.epochs = epochs
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_size = 28 * 28

        # Initialize model
        self.model = MLP(num_classes_per_task=self.num_classes_per_task).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.importance_dict = {}   # accumulated Fisher importance per parameter
        self.previous_params = {}   # parameter snapshot after the last task
        self.lambda_ = 10000

        self.train_loaders, self.test_loaders = get_permute_mnist(self.num_tasks, self.batch_size)

    def evaluate(self, test_loader, task_id):
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                # Move data to GPU in batches
                images = images.view(-1,self.input_size)
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                outputs = self.model(images, task_id)
                predicted = torch.argmax(outputs, dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

        return 100.0 * correct / total

    def train_task(self, train_loader,optimizer, task_id):
        self.model.train()
        for images, labels in train_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            optimizer.zero_grad()
            outputs = self.model(images, task_id)
            if task_id > 0:
                loss = self.ewc_multi_objective_loss(outputs, labels)
            else:
                loss = self.criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    def ewc_compute_importance(self, data_loader, task_id):
        importance_dict = {name: torch.zeros_like(param, device=self.device) for name, param in self.model.named_parameters() if 'task' not in name}
        self.model.eval()
        for images, labels in data_loader:
            images = images.view(-1,self.input_size)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            self.model.zero_grad()

            outputs = self.model(images, task_id=task_id)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            loss.backward()
            for name, param in self.model.named_parameters():
                if name in importance_dict and param.requires_grad:
                    importance_dict[name] += param.grad ** 2 / len(data_loader)
        return importance_dict
    
    def update(self, dataset, task_id):
        importance_dict = self.ewc_compute_importance(dataset, task_id)

        
        for name in importance_dict:
            if name in self.importance_dict:
                self.importance_dict[name] += importance_dict[name]
            else:
                self.importance_dict[name] = importance_dict[name]

        for name, param in self.model.named_parameters():
            self.previous_params[name] = param.clone().detach()

    def ewc_multi_objective_loss(self, outputs, labels):
        regularization_loss = 0.0
        for name, param in self.model.named_parameters():
            if 'task' not in name and name in self.importance_dict and name in self.previous_params:
                importance = self.importance_dict[name]
                previous_param = self.previous_params[name]
                regularization_loss += (importance * (param - previous_param).pow(2)).sum()
                
        loss = self.criterion(outputs, labels)
        # Note: the 1/2 factor from the formula is folded into lambda_.
        total_loss = loss + self.lambda_ * regularization_loss
        return total_loss

    def run(self):
        all_avg_acc = []
        for task_id in range(self.num_tasks):
            train_loader = self.train_loaders[task_id]
            self.model = self.model.to(self.device)
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3, weight_decay=1e-4)
            for epoch in range(self.epochs):
                self.train_task(train_loader,optimizer, task_id)
            self.update(train_loader, task_id)

            task_acc = []
            for eval_task_id in range(task_id + 1):
                accuracy = self.evaluate(self.test_loaders[eval_task_id], eval_task_id)
                task_acc.append(accuracy)
            mean_avg = np.round(np.mean(task_acc), 2)
            print(f"Task {task_id}: Task Acc = {task_acc},AVG={mean_avg},")
            all_avg_acc.append(mean_avg)
        avg_acc = np.mean(all_avg_acc)
        print(f"Task AVG Acc: {all_avg_acc},AVG = {avg_acc}")

if __name__ == '__main__':
    print('EWC'+"=" * 50)
    # Reset the random seeds before running
    random.seed(2024)
    torch.manual_seed(2024)
    np.random.seed(2024)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    ewc = EWC(num_classes_per_task=10, num_tasks=3, batch_size=256, epochs=2)
    ewc.run()

EWC==================================================

Task 0: Task Acc = [96.78],AVG=96.78,

Task 1: Task Acc = [95.47, 96.65],AVG=96.06,

Task 2: Task Acc = [90.9, 95.02, 96.28],AVG=94.07,

Task AVG Acc: [96.78, 96.06, 94.07],AVG = 95.63666666666666

After each new task is learned, accuracy on the old tasks drops only slightly, showing that EWC effectively mitigates catastrophic forgetting.

