当前位置: 首页 > article >正文

人工智能安全——大语言模型遗忘学习(LLM unlearning)与多目标优化算法

参考论文:

Multi-Objective Large Language Model Unlearning(ICASSP-2025)

背景

在人工智能安全问题日益严峻的背景下,减少大语言模型(LLM)的不良行为引起了极大的关注。采用Machine Unlearning是一种简单易用的方法,其目的是在不从头完全重新训练LLM的情况下有效消除LLM的不良行为。这里分享一篇信号处理与应用的顶级会议ICASSP-2025的论文:Multi-Objective Large Language Model Unlearning。文章发表了一种全新的“几何法”多目标优化,成功在保护LLM可用性的同时促进LLM对有害信息的遗忘。

挑战

采用传统Machine unlearning来做LLM unlearning会面临下述挑战:

  • 梯度爆炸

  • 灾难性遗忘

下面分别展开分析:

1. 梯度爆炸问题

梯度上升法(Gradient Ascent, GA)是Machine Unlearning中的一种经典方法。我在另一篇博客(人工智能安全与隐私——联邦遗忘学习)中也提到过,GA通过反转梯度,使得训练的时候不再是追求“loss下降”,而变成了“让loss上升”,从而降低模型在需要遗忘的数据上的精度,达到遗忘的效果。以CrossEntropy loss为例,其公式为:

 

其中C表示类别数; y_{i,c}是binary indicator(数值为0或1),其实就是将label转成one-hot编码后的第 c 个元素的值。 p_{i,c}表示模型对应的预测概率。

然而,采用梯度上升法来unlearn的时候,会让p_{i,c}趋于0,因此不可避免地会出现梯度爆炸问题,如下图蓝色虚线所示。

 

为此,文章通过对CrossEntropy loss中的概率p_{i,c}进行反转来解决这一问题,所提出的Unlearning CrossEntropy loss(UCE)表达式如下:

 

如上图橙色线所示,它不再需要梯度上升,而是用梯度下降来实现unlearning,并且有下界,不会影响收敛性。

2. 灾难性遗忘

在 LLM 忘却中,目标是双重的:不仅要忘掉目标忘记数据,即降低模型的性能,而且还要保持下游任务的模型效用(即性能)。前人通过引入KL散度来降低unlearning前后模型的output的差异,并用加权聚合的方法将KL loss与unlearning的loss给聚合成单目标。但显然,这两个目标经常会相互冲突,加权聚合并不能同时降低这两个loss。为此,文章将LLM unlearning视作一个多目标优化问题:

 

其中\mathcal{L}_{fgt}表示unlearning loss(这里叫forget loss):

 

\mathcal{L}_{KL}表示KL loss:

 

\mathcal{L}_{rt}表示模型在下游任务数据上的loss:

 

文章画了一个对应的梯度的冲突示意图:

 

由此可见,如果直接像前人那样加权聚合,那么得到的\bar d将会与其中的某些梯度相冲突(夹角大于90°),不再满足梯度下降,因此会出现两种情况:要么unlearning效果不稳定,甚至是没法unlearn,要么就破坏了模型的可用性。而文章主张计算一个公共的下降方向(common descent direction)来更新模型,以便既促成unlearning,又保护LLM可用性。

对偶法多目标优化

文章的目标是计算一个common descent direction来实现 \mathcal{L}_{fgt}\mathcal{L}_{KL},以及\mathcal{L}_{rt}的同时降低。跟传统的Multiple Gradient Descent Algorithm (MGDA)不同的是,文章提出了一种全新的“对偶法多梯度下降”的方法来求解。

如上图所示,粉色的三角锥表示所有满足common descent的向量组成的区域。只要计算出这个三角锥的“边”,那么这些边的任意凸组合(convex combination)都落在粉色的三角锥内部。在线性代数中,这个锥形区域被称为dual space:

 

记向量的维度为D,有m个向量g_1, g_2, \cdots, g_m并且向量组的秩为m\leq D,那么组成dual space的“边”向量满足这样的特征:

任意一条边都落在其中m-1个向量的法平面的交线上,并且与剩余的那1个向量的夹角小于90°,因此,文章提出的方法就是说,把g_1, g_2, \cdots, g_m的每个向量,分别投影到剩余m-1个向量的normal space上,就得到dual space的一条边了。

具体地,记\mathcal{L}_{fgt}对应的梯度g_{fgt}的对应的dual space的边为g_{fgt}^*,那么它由如下公式计算得到:

 

其中A表示将g_{fgt}, g_{KL}, g_{rt}拼成的矩阵。

类似地可以计算出g_{KL}^*g_{rt}^*

最终,文章取三者的平均值得到最终的模型更新方向:d^t = -\frac{1}{3}(g_{fgt}^*+g_{KL}^*+g_{rt}^*)。它满足对三个loss而言都梯度下降,从而达到既unlearning又保护模型效用的目的。

然后模型\theta的更新公式为:\theta^{t+1}=\theta^t+\eta^td^t

实验

文章用LLama3-8b模型在PKU-SAFERLHF数据集上进行测试,指标大概有这么几个:

模型的output是harmful的概率(Harmful rate, HR)、模型output的毒性(Toxicity),以及Obscenity;对于模型的可用性,用了Fluency的指标。实验结果如下:

 

后记

这篇文章提出的基于dual space的多目标优化方法让我耳目一新。相比于我之前看到过的MGDA方法,它不需要像MGDA那样把多目标优化问题转化成另一个优化问题去求解,而是直接通过矩阵运算来得到一个common descent direction,这是从未有过的。我觉得这篇论文已经超越了LLM unlearning这个topic了,更值得放在多目标优化领域进行深入探讨。最后希望这篇文章在帮助自己记录学习点滴之余,也能帮助大家!


http://www.kler.cn/a/467638.html

相关文章:

  • Java中的CAS操作是什么?它如何实现无锁编程?
  • 【Vim Masterclass 笔记05】第 4 章:Vim 的帮助系统与同步练习(L14+L15+L16)
  • 轻量级通信协议 JSON-RPC 2.0 详解
  • 前端学习-操作元素属性(二十三)
  • JavaScript系列(8)-- Array高级操作
  • CoppeliaSim和Python进行无人机联合仿真
  • 32单片机从入门到精通之软件编程——中断处理(九)
  • Spring Boot 3 实现 MySQL 主从数据库之间的数据同步
  • 手搓人工神经网络
  • Introducing Optimization
  • 基于生成式对抗网络(GAN)的前沿研究与应用
  • 单片机-独立按键矩阵按键实验
  • [Qt] 输入控件 | Line | Text | Combo | Spin | Date | Dial | Slider
  • python基于diagrams库绘制系统架构图
  • 基于Redis有序集合实现滑动窗口限流
  • 【C#特性整理】C#特性及语法基础
  • C# 找出给定三角形的所有角度(Find all angles of a given triangle)
  • 银行系统安全用电解决方案
  • Day29:continue 语句
  • 什么是.net framework,什么是.net core,什么是.net5~8,版本对应关系
  • linux 系统配置ip
  • Linux 内核中网络接口的创建与管理
  • win11 vs2022 opencv 4.10使用vs Image Watch插件实时可视化内存mat对象
  • 洛谷P5318 【深基18.例3】查找文献(c嘎嘎)
  • 常见的框架漏洞
  • 【OceanBase】使用 Superset 连接 OceanBase 数据库并进行数据可视化分析