
[Machine Unlearning in Python] Reproducing Selective Forgetting, a CVPR 2020 Algorithm


1 Algorithm

  • Golatkar, A., Achille, A., & Soatto, S. (2020). Eternal sunshine of the spotless net: Selective forgetting in deep networks. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 9304–9312.

  • The algorithm proposed in this paper is abbreviated SF (Selective Forgetting). The name reflects its core goal: selectively removing the information of a specific data subset from the weights of a deep neural network without affecting performance on the remaining data. The algorithm rests on a quadratic approximation of the loss and the stability of gradient descent, and proceeds as follows (a self-contained toy sketch of the update appears after the list):

  • 1 Compute the Hessian matrices

    • Compute the Hessian of the full dataset $\mathcal{D}$: $A = \nabla^2 L_{\mathcal{D}}(w)$
    • Compute the Hessian of the retain set $\mathcal{D}_r$: $B = \nabla^2 L_{\mathcal{D}_r}(w)$
  • 2 Compute the gradient directions

    • Full dataset $\mathcal{D}$: $d = A^{-1} \nabla_w L_{\mathcal{D}}(w)$
    • Retain set $\mathcal{D}_r$: $d_r = B^{-1} \nabla_w L_{\mathcal{D}_r}(w)$
  • 3 Construct the forgetting function

    • Following its definition, construct $h(w)$:
      $h(w) = w + e^{-Bt} e^{At} d + e^{-Bt} (d - d_r) - d_r.$
  • 4 Add noise

    • Draw Gaussian noise $n \sim \mathcal{N}(0, \Sigma)$, where $\Sigma = \sqrt{\lambda \sigma_h^2}\, B^{-1/2}$
    • Add the noise to the forgetting function to obtain the final weight update:
      $S(w) = h(w) + n.$
  • 5 Update the weights

    • Use $S(w)$ as the new weights, so that the network behaves as if it had never seen the forget set $\mathcal{D}_f$
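To make steps 1–5 concrete, here is a minimal toy sketch of the update on a two-parameter quadratic problem. All quantities (A, B, the gradients, t, lambda_param, sigma_h) are made-up illustrative values, not numbers from the paper:

import torch

torch.manual_seed(0)
w = torch.randn(2)                                  # current weights

# Illustrative quadratic-approximation quantities (assumed values)
A = torch.tensor([[2.0, 0.3], [0.3, 1.5]])          # Hessian on the full dataset D
B = torch.tensor([[1.8, 0.2], [0.2, 1.4]])          # Hessian on the retain set D_r
grad_D = torch.tensor([0.5, -0.2])                  # gradient of L_D at w
grad_Dr = torch.tensor([0.4, -0.1])                 # gradient of L_{D_r} at w

# Steps 1-2: Newton directions d = A^{-1} grad_D and d_r = B^{-1} grad_Dr
d = torch.linalg.solve(A, grad_D)
d_r = torch.linalg.solve(B, grad_Dr)

# Step 3: forgetting function h(w), using matrix exponentials
t = 1.0
h = (w + torch.matrix_exp(-B * t) @ torch.matrix_exp(A * t) @ d
       + torch.matrix_exp(-B * t) @ (d - d_r) - d_r)

# Step 4: noise with covariance sqrt(lambda * sigma_h^2) * B^{-1/2},
# i.e. n = (lambda * sigma_h^2)^{1/4} * B^{-1/4} z with z ~ N(0, I)
lambda_param, sigma_h = 0.1, 0.1
evals, evecs = torch.linalg.eigh(B)                 # B is symmetric positive definite
B_inv_quarter = evecs @ torch.diag(evals ** -0.25) @ evecs.T
n = (lambda_param * sigma_h ** 2) ** 0.25 * (B_inv_quarter @ torch.randn(2))

# Step 5: scrubbed weights
S = h + n
print("scrubbed weights:", S)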

2 Implementation

Utility functions

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.amp import autocast, GradScaler  
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

warnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"
# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Three-layer fully connected network
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Load the MNIST dataset
def load_MNIST_data(batch_size,forgotten_classes,ratio):
    transform = transforms.Compose([transforms.ToTensor()])
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    
    # Split out the forget / retain subsets (generate_subset_by_ratio is sketched below)
    forgotten_train_data, _ = generate_subset_by_ratio(train_data, forgotten_classes, ratio)
    retain_train_data, _ = generate_subset_by_ratio(train_data, [i for i in range(10) if i not in forgotten_classes])

    forgotten_train_loader= DataLoader(forgotten_train_data, batch_size=batch_size, shuffle=True)
    retain_train_loader= DataLoader(retain_train_data, batch_size=batch_size, shuffle=True)

    return train_loader, test_loader, retain_train_loader, forgotten_train_loader
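
# NOTE: generate_subset_by_ratio is called above but its definition was not
# included in the original post. The following is a minimal sketch under the
# assumption that it returns (subset of the given classes, subset of the rest)
# and that `ratio` controls the fraction of each selected class that is kept.
def generate_subset_by_ratio(dataset, classes, ratio=1.0):
    per_class = {c: [] for c in classes}
    other_indices = []
    # torchvision MNIST/CIFAR datasets keep their labels in .targets
    for idx, label in enumerate(dataset.targets):
        label = int(label)
        if label in per_class:
            per_class[label].append(idx)
        else:
            other_indices.append(idx)
    selected_indices = []
    for idxs in per_class.values():
        keep = int(len(idxs) * ratio)  # keep the first `ratio` fraction of each class
        selected_indices.extend(idxs[:keep])
        other_indices.extend(idxs[keep:])
    return Subset(dataset, selected_indices), Subset(dataset, other_indices)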

# worker_init_fn seeds each DataLoader worker for reproducibility
def worker_init_fn(worker_id):
    random.seed(2024 + worker_id)
    np.random.seed(2024 + worker_id)
def get_transforms():
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize to [-1, 1]
    ])
    
    return train_transform, test_transform
# Model training function
def train_model(model, train_loader, criterion, optimizer, scheduler=None, use_fp16=False):
    scaler = GradScaler("cuda", enabled=use_fp16)  # loss scaling for mixed-precision training
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass (mixed precision when use_fp16 is True)
        with autocast(device_type="cuda", enabled=use_fp16):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Backward pass and optimizer step
        optimizer.zero_grad()
        if use_fp16:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        running_loss += loss.item()
    if scheduler is not None:
        scheduler.step()  # step the learning-rate scheduler once per epoch

    print(f"Loss: {running_loss/len(train_loader):.4f}")
# Model evaluation (accuracy on retained and forgotten classes)
def test_model(model, test_loader, forgotten_classes=[0]):
    """
    测试模型的性能,计算总准确率、遗忘类别准确率和保留类别准确率。

    :param model: 要测试的模型
    :param test_loader: 测试数据加载器
    :param forgotten_classes: 需要遗忘的类别列表
    :return: overall_accuracy, forgotten_accuracy, retained_accuracy
    """
    model.eval()
    correct = 0
    total = 0
    forgotten_correct = 0
    forgotten_total = 0
    retained_correct = 0
    retained_total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)

            # Overall accuracy
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Accuracy on the forgotten classes
            mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=device))
            forgotten_total += mask_forgotten.sum().item()
            forgotten_correct += (predicted[mask_forgotten] == labels[mask_forgotten]).sum().item()

            # Accuracy on the retained (non-forgotten) classes
            mask_retained = ~mask_forgotten
            retained_total += mask_retained.sum().item()
            retained_correct += (predicted[mask_retained] == labels[mask_retained]).sum().item()

    overall_accuracy = correct / total
    forgotten_accuracy = forgotten_correct / forgotten_total if forgotten_total > 0 else 0
    retained_accuracy = retained_correct / retained_total if retained_total > 0 else 0

    return round(overall_accuracy, 4), round(forgotten_accuracy, 4), round(retained_accuracy, 4)
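
For reference, a typical way to pretrain the MLP with these utilities before running the unlearning step looks like the sketch below; the epoch count and learning rate are illustrative choices, not tuned values from the post:

model = MLP().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(
    batch_size=256, forgotten_classes=[0], ratio=1)
for epoch in range(5):
    print(f"Epoch {epoch + 1}")
    train_model(model, train_loader, criterion, optimizer)
print(test_model(model, test_loader, forgotten_classes=[0]))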


Main program

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from models.Base import test_model, load_MNIST_data, load_CIFAR100_data, init_model  # the utilities above (plus load_CIFAR100_data / init_model, not shown) saved as models/Base.py

# Machine-unlearning class
class OptimalQuadraticForgetter:
    def __init__(self, model, lambda_param, sigma_h, forget_threshold):
        self.model = model
        self.lambda_param = lambda_param
        self.sigma_h = sigma_h
        self.forget_threshold = forget_threshold  # forgetting threshold (kept for the interface; unused below)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hessian_D = None   # Hessian (diagonal approximation) over the full dataset
        self.hessian_Dr = None  # Hessian (diagonal approximation) over the retain set
        self.grad_D = None      # gradient over the full dataset
        self.grad_Dr = None     # gradient over the retain set

    # Estimate the Hessians (diagonal approximation) and gradients
    def compute_hessian_and_grad(self, dataloader_D, dataloader_Dr):
        self.model.eval()
        self.hessian_D = {}
        self.hessian_Dr = {}
        self.grad_D = {}
        self.grad_Dr = {}

        # Hessian and gradient over the full dataset D
        for data, target in dataloader_D:
            data, target = data.to(self.device), target.to(self.device)
            self.model.zero_grad()
            output = self.model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()

            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    if name not in self.hessian_D:
                        self.hessian_D[name] = torch.zeros_like(param.grad)
                        self.grad_D[name] = torch.zeros_like(param.grad)
                    # Squared gradients as a diagonal (Fisher-style) stand-in
                    # for the per-parameter Hessian
                    self.hessian_D[name] += param.grad.detach() ** 2
                    self.grad_D[name] += param.grad.detach()

        # Hessian and gradient over the retain set D_r
        for data, target in dataloader_Dr:
            data, target = data.to(self.device), target.to(self.device)
            self.model.zero_grad()
            output = self.model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()

            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    if name not in self.hessian_Dr:
                        self.hessian_Dr[name] = torch.zeros_like(param.grad)
                        self.grad_Dr[name] = torch.zeros_like(param.grad)
                    self.hessian_Dr[name] += param.grad.detach() ** 2
                    self.grad_Dr[name] += param.grad.detach()

        # Average over batches (cross_entropy already averages within each batch)
        for key in self.hessian_D:
            self.hessian_D[key] /= len(dataloader_D)
            self.grad_D[key] /= len(dataloader_D)
        for key in self.hessian_Dr:
            self.hessian_Dr[key] /= len(dataloader_Dr)
            self.grad_Dr[key] /= len(dataloader_Dr)

    # Selective forgetting: scrub the weights using the Hessians and gradients
    def scrub_weights(self):
        self.model.eval()

        for name, param in self.model.named_parameters():
            if name in self.hessian_D and param.requires_grad:
                hessian_Dr = self.hessian_Dr[name]
                grad_D = self.grad_D[name]
                grad_Dr = self.grad_Dr[name]

                # Forgetting function h(w): one Newton step toward the retain-set
                # optimum. Under the diagonal Hessian approximation the inverse
                # becomes an elementwise division (eps guards against zeros).
                eps = 1e-8
                h_w = param.data + (grad_Dr - grad_D) / (hessian_Dr + eps)

                # Noise with elementwise std (lambda * sigma_h^2)^{1/4} * B^{-1/4},
                # matching Sigma = sqrt(lambda * sigma_h^2) * B^{-1/2} from step 4
                noise_std = (self.lambda_param * self.sigma_h ** 2) ** 0.25
                noise = noise_std * (hessian_Dr + eps) ** -0.25 * torch.randn_like(param.data)
                param.data = h_w + noise  # scrubbed weights S(w)

        return self.model

# Top-level helper: perform selective forgetting
def optimal_quadratic_forgetting(model, dataloader_D, dataloader_Dr, lambda_param, sigma_h, forget_threshold):
    # Build the forgetter
    forgetter = OptimalQuadraticForgetter(model, lambda_param, sigma_h, forget_threshold)

    # Estimate the Hessians (diagonal approximation) and gradients
    forgetter.compute_hessian_and_grad(dataloader_D, dataloader_Dr)

    # Scrub the weights
    model_after = forgetter.scrub_weights()

    return model_after

def main():
    # Hyperparameters
    batch_size = 256
    forgotten_classes = [0]
    ratio = 1
    model_name = "ResNet18"

    # Load data
    if model_name == "MLP":
        train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)
    elif model_name == "ResNet18":
        train_loader, test_loader, retain_loader, forget_loader = load_CIFAR100_data(batch_size, forgotten_classes, ratio)

    model_before = init_model(model_name, train_loader)

    # Accuracy of the initial model before unlearning
    overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader, forgotten_classes)

    # Run the unlearning step
    print("Running optimal_quadratic_forgetting...")
    model_after = optimal_quadratic_forgetting(model_before, train_loader, retain_loader, lambda_param=0.1, sigma_h=0.1, forget_threshold=1e-05)

    # Evaluate the scrubbed model
    overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader, forgotten_classes)

    # Report accuracy before and after unlearning
    print(f"Forgotten-class accuracy before unlearning: {100 * forgotten_acc_before:.2f}%")
    print(f"Forgotten-class accuracy after unlearning: {100 * forgotten_acc_after:.2f}%")
    print(f"Retained-class accuracy before unlearning: {100 * retained_acc_before:.2f}%")
    print(f"Retained-class accuracy after unlearning: {100 * retained_acc_after:.2f}%")

if __name__ == "__main__":
    main()

3 Summary

Noise perturbation is one of the core mechanisms of forgetting, but injecting noise can hurt the model's overall performance, especially on the data that should be retained, so its magnitude must be tuned carefully: noise that is too large degrades the model, while noise that is too small may fail to remove the target data's information. A small grid search over sigma_h, as sketched below, is one practical way to pick this trade-off.
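A minimal sketch of such a search, reusing model_before and the loaders and functions defined above (the candidate values for sigma_h are arbitrary):

from copy import deepcopy

for sigma_h in [0.01, 0.05, 0.1, 0.5]:
    model_copy = deepcopy(model_before)  # scrub a fresh copy for each setting
    scrubbed = optimal_quadratic_forgetting(model_copy, train_loader, retain_loader,
                                            lambda_param=0.1, sigma_h=sigma_h,
                                            forget_threshold=1e-05)
    overall, forgotten, retained = test_model(scrubbed, test_loader, forgotten_classes=[0])
    print(f"sigma_h={sigma_h}: forgotten_acc={forgotten:.4f}, retained_acc={retained:.4f}")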

