当前位置: 首页 > article >正文

神经网络防“失忆“秘籍:弹性权重固化如何让AI学会“温故知新“

神经网络防"失忆"秘籍:弹性权重固化如何让AI学会"温故知新"

“就像学霸给重点笔记贴荧光标签,EWC给重要神经网络参数上锁”


一、核心公式对比表

公式名称数学表达式通俗解释类比场景文献
EWC主公式 L t o t a l = L n e w + λ 2 ∑ i F i ( θ i − θ o l d , i ) 2 L_{total} = L_{new} + \frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{old,i})^2 Ltotal=Lnew+2λiFi(θiθold,i)2给重要知识上锁给重点笔记贴荧光标签
贝叶斯推导式 log ⁡ p ( θ ∣ D ) ∝ log ⁡ p ( D B ∣ θ ) + log ⁡ p ( θ ∣ D A ) \log p(\theta|D) \propto \log p(D_B|\theta) + \log p(\theta|D_A) logp(θD)logp(DBθ)+logp(θDA)新旧知识平衡法则考试前既复习新题也温习旧题
费舍尔信息矩阵 F i = E [ ∇ θ 2 log ⁡ p ( y ∣ x , θ ) ] F_i = \mathbb{E}[\nabla_\theta^2 \log p(y|x,\theta)] Fi=E[θ2logp(yx,θ)]知识重要度评分卡根据笔记划重点的频率标记

二、公式详解与类比解释

公式1:EWC核心保护机制

L t o t a l = L n e w ( θ ) ⏟ 新任务 + λ 2 ∑ i F i ( θ i − θ o l d , i ) 2 ⏟ 旧知识防护罩 L_{total} = \underbrace{L_{new}(\theta)}_{\text{新任务}} + \underbrace{\frac{\lambda}{2} \sum_i F_i (\theta_i - \theta_{old,i})^2}_{\text{旧知识防护罩}} Ltotal=新任务 Lnew(θ)+旧知识防护罩 2λiFi(θiθold,i)2

参数数学符号类比解释作用原理
旧任务参数 θ o l d \theta_{old} θold学霸的旧笔记本知识基准锚点
重要性系数 F i F_i Fi荧光标签密度参数重要性量化
约束强度 λ \lambda λ胶水粘性系数平衡新旧知识权重

案例:在图像分类任务中,给识别"猫耳朵"的关键神经元增加3倍保护权重


公式2:费舍尔信息矩阵计算

F i = 1 N ∑ x , y ( ∂ log ⁡ p ( y ∥ x , θ ) ∂ θ i ) 2 F_i = \frac{1}{N} \sum_{x,y} \left( \frac{\partial \log p(y\|x,\theta)}{\partial \theta_i} \right)^2 Fi=N1x,y(θilogp(yx,θ))2
变量解读

  • x x x:输入数据(学生的练习题)
  • y y y:标签答案(标准答案)
  • log ⁡ p ( y ∥ x , θ ) \log p(y\|x,\theta) logp(yx,θ):答案正确率评分

类比解释

如同统计学生复习时翻看某页笔记的次数,翻看越频繁的页面(参数)获得越多荧光标签(高F值)


三、公式体系演进(关键推导步骤)

1. 贝叶斯推导路径

  1. 初始目标
    max ⁡ θ log ⁡ p ( θ ∥ D A , D B ) \max_\theta \log p(\theta\|D_A, D_B) θmaxlogp(θDA,DB)
  2. 任务分解
    ∝ log ⁡ p ( D B ∥ θ ) + log ⁡ p ( θ ∥ D A ) \propto \log p(D_B\|\theta) + \log p(\theta\|D_A) logp(DBθ)+logp(θDA)
  3. 拉普拉斯近似
    log ⁡ p ( θ ∥ D A ) ≈ − 1 2 ∑ i F i ( θ i − θ o l d , i ) 2 \log p(\theta\|D_A) \approx -\frac{1}{2} \sum_i F_i (\theta_i - \theta_{old,i})^2 logp(θDA)21iFi(θiθold,i)2

2. 方法对比

方法核心公式优势局限
L2正则 L = L n e w + λ ∣ θ − θ o l d ∣ 2 L = L_{new} + \lambda |\theta - \theta_{old}|^2 L=Lnew+λθθold2简单易实现无差别保护所有参数
EWC L = L n e w + λ 2 ∑ F i ( θ i − θ o l d , i ) 2 L = L_{new} + \frac{\lambda}{2} \sum F_i(\theta_i - \theta_{old,i})^2 L=Lnew+2λFi(θiθold,i)2智能参数保护需计算二阶导数
LwF L = L n e w + α D K L ( p o l d ∣ p n e w ) L = L_{new} + \alpha D_{KL}(p_{old}|p_{new}) L=Lnew+αDKL(poldpnew)保持输出分布稳定依赖旧模型推理

四、代码实战:MNIST/FashionMNIST增量学习

import torch
import torch.nn as nn

class EWC_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1,32,3), 
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(32*13*13,10)

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x.view(x.size(0),-1))

# EWC核心实现
class EWC_Regularizer:
    def __init__(self, model, dataloader, device):
        self.model = model
        self.params = {n:p.detach().clone() for n,p in model.named_parameters()}  # 旧参数快照
        self.fisher = {}
        
        # 计算Fisher信息矩阵
        for batch in dataloader:
            inputs, labels = batch
            outputs = model(inputs.to(device))
            loss = nn.CrossEntropyLoss()(outputs, labels.to(device))
            loss.backward()
            
            for n,p in model.named_parameters():
                if p.grad is not None:
                    self.fisher[n] = p.grad.data.pow(2).mean()  # 梯度平方均值

    def penalty(self, current_params):
        loss = 0
        for n,p in current_params.items():
            loss += (self.fisher[n] * (p - self.params[n]).pow(2)).sum()
        return loss

# 训练流程示例
device = torch.device('cuda')
model = EWC_CNN().to(device)
old_task_loader = ...  # 旧任务数据加载器
new_task_loader = ...  # 新任务数据加载器

# 第一阶段:训练旧任务
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
    for batch in old_task_loader:
        # 常规训练流程...
        
# 第二阶段:计算EWC约束
ewc = EWC_Regularizer(model, old_task_loader, device)

# 第三阶段:增量学习新任务
for epoch in range(10):
    for batch in new_task_loader:
        inputs, labels = batch
        outputs = model(inputs.to(device))
        ce_loss = nn.CrossEntropyLoss()(outputs, labels.to(device))
        ewc_loss = ewc.penalty(dict(model.named_parameters()))
        total_loss = ce_loss + 1000 * ewc_loss  # λ=1000
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

五、可视化解析

1. 参数空间分布

import matplotlib.pyplot as plt
import seaborn as sns

# 生成模拟数据
theta_old = np.random.normal(0,1,1000)  # 旧任务参数分布
theta_new = np.random.normal(3,1,1000)  # 新任务参数分布
theta_ewc = 0.3*theta_old + 0.7*theta_new  # EWC约束参数

# 可视化
plt.figure(figsize=(12,6))
sns.kdeplot(theta_old, label="Old Task", color='grey', linewidth=3)
sns.kdeplot(theta_new, label="New Task", color='gold', linewidth=3)
sns.kdeplot(theta_ewc, label="EWC Compromise", color='red', linestyle='--')
plt.title("Parameter Space Distribution")
plt.xlabel("Parameter Value"), plt.ylabel("Density")
plt.legend()
plt.show()

六、技术演进路线

阶段代表方法关键突破局限
1.0参数冻结物理隔离旧知识丧失模型扩展能力
2.0L2正则简单约束参数漂移无差别保护所有参数
3.0EWC智能参数重要性加权计算二阶导数开销大
4.0动态网络独立适配器模块模型体积膨胀

http://www.kler.cn/a/557138.html

相关文章:

  • 泥沙输送的DEM-CFD耦合案例
  • Kubernetes控制平面组件:APIServer 基于 Webhook Toeken令牌 的认证机制详解
  • 广度优先搜索详解--BFS--蒟蒻的学习之路
  • 沃丰科技大模型标杆案例|周大福集团统一大模型智能服务中心建设实践
  • 蓝桥杯备赛-基础训练(三)哈希表 day15
  • 【论文阅读笔记】知识蒸馏:一项调查 | CVPR 2021 | 近万字翻译+解释
  • 垃圾回收知识点
  • 力扣-回溯-78 子集
  • 微信小程序-二维码绘制
  • ncrfp:一种基于深度学习的端到端非编码RNA家族预测新方法
  • E - Palindromic Shortest Path【ABC394】
  • 洛谷 P4644 USACO05DEC Cleaning Shift/AT_dp_w Intervals 题解
  • 体育品牌排行榜前十名:MLB·棒球1号位
  • 蓝禾,oppo,游卡,汤臣倍健,康冠科技,作业帮,高途教育25届春招内推
  • QUdpSocket的readyRead信号只触发一次
  • python-leetcode-两两交换链表中的节点
  • Scrapy:Downloader下载器设计详解
  • Open WebUI 是什么
  • 常用Web性能指标
  • 游戏引擎学习第117天