当前位置: 首页 > article >正文

【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

【Python实现机器遗忘算法】复现2021年顶会 AAAI算法Amnesiac Unlearning

在这里插入图片描述

1 算法原理

论文:Graves, L., Nagisetty, V., & Ganesh, V. (2021). Amnesiac machine learning. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 35, 11516–11524.

Amnesiac Unlearning(遗忘性遗忘) 是一种高效且精确的算法,旨在从已经训练好的神经网络模型中删除特定数据的学习信息,而不会显著影响模型在其他数据上的性能。该算法的核心思想是通过选择性撤销与敏感数据相关的参数更新来实现数据的“遗忘”。

1. 训练阶段:记录参数更新

在模型训练过程中,记录每个批次的参数更新以及哪些批次包含敏感数据。

  • 步骤
    1. 初始化模型参数:从随机初始化的参数 θ i n i t i a l \theta_{initial} θinitial 开始训练模型。
    2. 训练模型:使用标准训练方法(如随机梯度下降)对模型进行训练,训练过程分为多个 epoch,每个 epoch 包含多个批次(batches)。
    3. 记录参数更新
      • 对于每个批次 b b b,记录该批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b,其中 e e e 表示 epoch 编号, b b b 表示批次编号。
      • 同时,记录哪些批次包含敏感数据(即需要删除的数据)。可以将这些批次标记为 S B SB SB(Sensitive Batches)。
    4. 存储信息
      • 存储所有批次的参数更新 Δ θ e , b \Delta_{\theta_{e,b}} Δθe,b
      • 存储敏感数据批次的索引 S B SB SB

2. 数据删除阶段:选择性撤销参数更新

当收到数据删除请求时,撤销与敏感数据相关的参数更新。

  • 步骤

    1. 识别敏感数据批次:从存储的记录中提取包含敏感数据的批次索引 S B SB SB
    2. 撤销参数更新

    计算删除敏感数据后的模型参数 θ M \theta_{M} θM
    θ M ′ = θ M − ∑ s b ∈ S B Δ θ s b \theta_{M'} = \theta_{M} - \sum_{sb \in SB} \Delta_{\theta_{sb}} θM=θMsbSBΔθsb

    其中:

    • θ M \theta_{M} θM 是原始训练后的模型参数。
    • Δ θ s b \Delta_{\theta_{sb}} Δθsb 是敏感数据批次 s b sb sb 的参数更新。
    • 生成保护模型:使用更新后的参数 θ M ′ \theta_{M'} θM 作为新的模型参数。

3. 微调阶段(可选)

如果删除的批次较多,可能会对模型性能产生一定影响。此时可以通过少量微调来恢复模型性能。

  • 步骤
    1. 微调模型:使用删除敏感数据后的数据集对模型进行少量迭代训练。
    2. 恢复性能:通过微调,模型可以恢复在非敏感数据上的性能。

2 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
from models.Base import load_MNIST_data, test_model, device, MLP, load_CIFAR100_data, init_model

# AmnesiacForget类:封装撤销与敏感数据相关的参数更新
class AmnesiacForget:
    def __init__(self, model, all_data, epochs, learning_rate):
        """
        初始化 AmnesiacForget 类。
        
        :param model: 需要训练的模型。
        :param all_data: 训练数据集。
        :param epochs: 训练的总 epoch 数。
        :param learning_rate: 优化器的学习率。
        """
        self.model = model
        self.all_data = all_data
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.batch_updates = []  # 存储每个批次的参数更新值
        self.initial_params = {name: param.clone() for name, param in model.named_parameters()}  # 存储初始模型参数
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 设备选择(GPU 或 CPU)

    def train(self, forgotten_classes):
        """
        训练模型并记录每个批次的参数更新值。
        
        :param forgotten_classes: 需要遗忘的类别列表。
        :return: sensitive_batches: 包含敏感数据的批次索引。
        """
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)  # 使用 Adam 优化器
        self.model.train()  # 将模型设置为训练模式
        sensitive_batches = {}  # 记录每个 epoch 中包含敏感数据的批次索引

        # 训练过程
        for epoch in range(self.epochs):
            running_loss = 0.0
            sensitive_batches[epoch] = set()  # 每个 epoch 的敏感批次集

            for batch_idx, (images, labels) in enumerate(self.all_data):
                optimizer.zero_grad()  # 清空梯度
                images, labels = images.to(self.device), labels.to(self.device)  # 将数据移动到设备上

                # 前向传播和损失计算
                outputs = self.model(images)
                loss = nn.CrossEntropyLoss()(outputs, labels)

                # 反向传播计算梯度
                loss.backward()
                running_loss += loss.item()

                # 记录当前参数值
                current_params = {name: param.clone() for name, param in self.model.named_parameters()}

                # 更新参数
                optimizer.step()

                # 记录参数更新值(当前参数值 - 更新前的参数值)
                batch_update = {}
                for name, param in self.model.named_parameters():
                    if param.requires_grad:
                        batch_update[name] = param.data - current_params[name].data  # 记录参数更新值
                self.batch_updates.append(batch_update)

                # 记录包含敏感数据的批次索引
                if any(label.item() in forgotten_classes for label in labels):
                    sensitive_batches[epoch].add(batch_idx)

            print(f"Epoch [{epoch+1}/{self.epochs}], Loss: {running_loss/len(self.all_data):.4f}")

        return sensitive_batches

    def unlearn(self, sensitive_batches):
        """
        撤销与敏感数据相关的批次更新。
        
        :param sensitive_batches: 包含敏感数据的批次索引。
        :return: 更新后的模型。
        """
        # 计算非敏感批次的参数更新总和
        non_sensitive_updates = {name: torch.zeros_like(param) for name, param in self.model.named_parameters()}
        for batch_idx, batch_update in enumerate(self.batch_updates):
            if batch_idx not in {sb for epoch_batches in sensitive_batches.values() for sb in epoch_batches}:
                for name, update in batch_update.items():
                    non_sensitive_updates[name] += update

        # 更新模型参数:初始参数 + 非敏感批次的更新
        for name, param in self.model.named_parameters():
            param.data = self.initial_params[name].data + non_sensitive_updates[name]

        return self.model


# 全局函数:实现 Amnesiac Forget
def amnesiac_unlearning(model_before, test_loader, forgotten_classes, all_data, epochs=10, learning_rate=0.001):
    """
    执行 Amnesiac Unlearning:训练模型,记录参数更新,并撤销与敏感数据相关的更新。
    
    :param model_before: 遗忘前的模型。
    :param test_loader: 测试数据加载器。
    :param forgotten_classes: 需要遗忘的类别列表。
    :param all_data: 训练数据集。
    :param epochs: 训练的总 epoch 数(默认为 10)。
    :param learning_rate: 优化器的学习率(默认为 0.001)。
    
    :return: 遗忘后的模型。
    """
    # 模拟从头训练的过程,并记录批次更新的过程
    print("模拟重新训练过程,记录批次更新...")
    temp_model = MLP().to(device)  # 初始化一个新模型
    amnesiac_forget = AmnesiacForget(temp_model, all_data, epochs, learning_rate)  # 初始化 AmnesiacForget 类
    sensitive_batches = amnesiac_forget.train(forgotten_classes)  # 训练模型并记录敏感批次

    # 测试遗忘前的模型性能
    overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(amnesiac_forget.model, test_loader)
    print(f"全部准确率: {overall_acc_before:.2f}%, 保留准确率: {retained_acc_before:.2f}%, 遗忘准确率: {forgotten_acc_before:.2f}%")

    # 应用遗忘:撤销与敏感数据相关的批次更新
    model_after = amnesiac_forget.unlearn(sensitive_batches)

    return model_after


def main():
    # 超参数设置
    batch_size = 256
    forgotten_classes = [0]  # 需要遗忘的类别
    ratio = 1
    model_name = "ResNet18"  # 模型名称

    # 加载数据
    if model_name == "MLP":
        train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)
    elif model_name == "ResNet18":
        train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)

    # 初始化模型
    model_before = init_model(model_name, train_loader)

    # 在训练之前测试初始模型准确率
    overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)

    # 实现遗忘操作
    print("执行遗忘 Amnesiac...")
    model_after = amnesiac_unlearning(model_before, test_loader, forgotten_classes, train_loader, epochs=5, learning_rate=0.001)

    # 测试遗忘后的模型
    overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)

    # 输出遗忘前后的准确率变化
    print(f"Unlearning 前遗忘准确率: {100 * forgotten_acc_before:.2f}%")
    print(f"Unlearning 后遗忘准确率: {100 * forgotten_acc_after:.2f}%")
    print(f"Unlearning 前保留准确率: {100 * retained_acc_before:.2f}%")
    print(f"Unlearning 后保留准确率: {100 * retained_acc_after:.2f}%")


if __name__ == "__main__":
    main()

3 总结

  • 高效性:只需撤销与敏感数据相关的参数更新,避免了从头训练模型的高成本。
  • 精确性:能够精确删除特定数据的学习信息,特别适合删除少量数据。
  • 存储成本:需要存储每个批次的参数更新,存储成本较高,但通常低于从头训练模型的成本。
  • 适用场景:适合删除少量数据(如单个样本或少量样本),而不适合删除大量数据(如整个类别)。

http://www.kler.cn/a/522799.html

相关文章:

  • 危机13小时:追踪一场GitHub投毒事件
  • 【MySQL】初始MySQL、库与表的操作
  • Linux C++
  • GPU上没程序在跑但是显存被占用
  • Tensor 基本操作2 理解 tensor.max 操作,沿着给定的 dim 是什么意思 | PyTorch 深度学习实战
  • 【Elasticsearch】Elasticsearch的查询
  • Node.js日志记录新篇章:morgan中间件的使用与优势
  • Fort Firewall:全方位守护网络安全
  • 数据结构与算法之数组: LeetCode 380. O(1) 时间插入、删除和获取随机元素 (Ts版)
  • TS开发的类型索引目录
  • kubernetes 核心技术-调度器
  • 公式与函数的应用
  • 【前端SEO】使用Vue.js + Nuxt 框架构建服务端渲染 (SSR) 应用满足SEO需求
  • 基于 PyTorch 的深度学习模型开发实战
  • 搭建 docxify 静态博客教程
  • 13、Java JDBC 编程:连接数据库的桥梁
  • Java并发编程实战:深入探索线程池与Executor框架
  • WordPress Web Directory Free插件本地包含漏洞复现(附脚本)(CVE-2024-3673)
  • 更换keil工程芯片到103c8t6(HAL库版本)
  • 豆包MarsCode:字符串字符类型排序问题
  • JS宏进阶:控件与事件
  • java:read weather info from openweathermap.org
  • 书生大模型实战营2
  • Semaphore 与 线程池 Executor 有什么区别?
  • 嵌入式知识点总结 Linux驱动 (三)-文件系统
  • Linux 35.6 + JetPack v5.1.4之编译器升级