当前位置: 首页 > article >正文

模型压缩——如何进行知识蒸馏?

1.引言

目前的主流的大语言模型基本都是部署在云端,但由于安全隐私和延时问题,有越来越多的场景需要将AI模型部署在边缘设备上,例如个人电脑、智能手机、物联网设备。如何将大模型的能力迁移到小设备上,就成为一个重要的研究方向。

云端有着强大的计算资源,如下图中的NVIDIA A100显卡,能够提供高达19.5 TFLOPS的浮点运算能力,并配备高达80GB的内存。但边缘设备通常只有极为有限的计算能力和内存,它们无法处理像云端一样复杂的神经网络。
在这里插入图片描述

因此,需要压缩来减少模型的参数和计算需求。前面的剪枝、神经网络架构搜索都是为这个目标服务的。除这些技术之外,知识蒸馏(Knowledge Distillation, KD) 也是一种有效的模型压缩技术,本文将会来讨论如何使用知识蒸馏来压缩模型。

2.基本概念

知识蒸馏基本思想是:通过大模型(教师模型)的指导来训练小模型(学生模型),从而让小模型的性能接近大模型,

知识蒸馏的基本做法是:同时使用教师模型和学生模型对输入数据进行预测,然后对齐教师模型和学生模型的输出概率分布,让学生模型来学习教师模型的行为模式,从而达到减少模型大小并保留性能的目的。
在这里插入图片描述

  • 教师模型(Teacher Model):一个预先训练好的复杂模型,性能优异,提供软标签作为知识,用于指导学生模型的学习;
  • 学生模型(Student Model):需要被训练的较简单的模型,通过学习教师模型的知识,以实现接近教师模型性能的目的;
  • 硬标签(Hard Label):实际训练集中,为每个输入样本分配的真实类别标签,使用独热编码(One-Hot Label);
  • 软标签(Soft Label):教师模型输出的概率分布,可以通过温度将输出的概率分布软化,使学生模型能更好的学习概率间的相对关系。
  • 硬损失(Hard Loss):学生模型与真实标签之间的交叉熵损失,用于确保学生模型在标准分类任务上的性能表现。
  • 软损失(Soft Loss):学生模型与教师模型输出之间的损失,通过KL散席来衡量,用于帮助学生模型学习教师模型的预测分布;
  • 总损失(Total Loss):软损失和硬损失的加权和形成总损失,可通过调节权重 λ \lambda λ来控制学生模型在多大程度上依赖教师模型的指导。

3.模型和数据

3.1 模型封装

模型封装主要是确定教师模型和学生模型的结构。

教师模型会直接复用之前模型剪枝中定义的LeNet网络,而学生模型则会通过减少层数和通道数来定义一个半通道的LeNet网络。

首先,我们引进模型剪枝中已经编写过的数据、训练、评估相关的代码。

%run lenet.py

定义一个半通道数的LeNet网络作为学生模型结构,具体在如下方面进行了结构缩减:

  • 卷积层:由原来的两层缩减为一层,通道数由6减为3;
  • 线性层:由原来的3层缩减为一层,将池化层的输出直接用于类别预测;
class LeNetHalfChannel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5) # 3x24x24
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # 3x12x12
        self.fc1 = nn.Linear(in_features=3*12*12, out_features=num_classes)
        
    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        return x
3.2 模型实例化

实例化教师模型并加载已训练过的模型权重:

teacher_model = LeNet()
teacher_model.load_state_dict(torch.load("./checkpoint/model.pt"))
print("teacher_model params number:", count_parameters(teacher_model))
    teacher_model params number: 44426

实例化学生模型,并对比学生模型和教师模型的参数数量:

student_model = LeNetHalfChannel()
print("student_model params number:", count_parameters(student_model))
    student_model params number: 4408

学生模型参数量只有教师模型参数量的约1/10。

3.3 数据加载

数据部分直接用神经网络架构搜索一节中封装过的load_data方法来加载数据,得到训练集和测试集。

train_loader, test_loader = load_data("./data", batch_size=64)
len(train_loader), len(test_loader)
    (938, 157)

4.训练封装

4.1 蒸馏损失函数

首先,我们需要实现一个知识蒸馏的专用损失函数。假设已经通过教师模型和学生模型得两个logits。

logits_student = torch.tensor([[2.0, 1.0, 0.1], [1.0, 3.0, 0.2]])
logits_teacher = torch.tensor([[1.5, 0.5, 0.3], [0.8, 2.5, 0.5]])

第一步需要对两个logits分别执行softmax变换,得到概率分布的对数表示(假设温度为2.0):

temperature = 4.0
pred_student = F.log_softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
pred_student, pred_teacher
    (tensor([[-0.8758, -1.1258, -1.3508],
             [-1.2434, -0.7434, -1.4434]]),
     tensor([[0.3969, 0.3091, 0.2940],
             [0.2892, 0.4424, 0.2683]]))

注:logits_student 除以 temperature 是为了平滑分数,高温(大于1)会让输出概率更平滑,能让模型学习标签之间的相互联系,更有利于知识蒸馏; 而低温(小于等于1)会使得正负标签的概率分布比较尖锐,过于强调正确答案,而忽略了错误答案的信息。

举个例子,对一张图像进行分类识别,有马(正确标签)、驴、车三种可能,虽然驴和车都于错误标签,但驴显然比车更像马,识别为驴的概率应该大于车,平滑的概率分布更有利于保留这些信息。

接着,用KL散度计算学生模型与教师模型两者预测概率之间的差异。

KL 散度(Kullback-Leibler Divergence) 可以帮助我们了解一个分布在多大程度上与另一个分布不同,计算结果为0时表示两个分布完全一样,散度越大表示两个分布之间的差异越大。

kl_div = F.kl_div(pred_student, pred_teacher, reduction="none")
kl_div_sum = kl_div.sum(1)  # 对每个样本的所有类别的 KL 散度值求和
kl_div_mean = kl_div_sum.mean() # 对所有样本 KL 散度求和的结果取平均值,即为散度损失。
kl_div_mean
    tensor(0.0032)

将损失值乘以温度系数的平方,以补偿前面除以温度系数带来的缩放效应,确保损失值的量级合理。

loss_kd = kl_div_mean * temperature**2
loss_kd
    tensor(0.0511)

将上面的实现过程封装为一个KL散度损失函数。

def kl_div(logits_student, logits_teacher, temperature):
    log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
    loss_kd *= temperature**2
    return loss_kd
4.2 超参定义

定义训练需要的超参数、损失函数和优化器。硬损失函数使用交叉熵,软损失函数使用刚定义的KL散度损失函数。

lr = 0.01
momentum = 0.5
num_epochs = 5
temperature = 4 # 温度
device = 'cpu'
soft_loss_fn = kl_div
hard_loss_fn = nn.CrossEntropyLoss() 
4.3 训练函数

定义单轮训练函数,大概实现思路如下:

  • 将学生模型设为训练模式,表示只对学生模型进行训练;
  • 分别对学生模型和教师模型进行前向传播,得到两个模型各自的预测结果stu_logits和teacher_logits;
  • 分别计算硬损失(学生预测与真实标签)和软损失(学生预测与教师预测);
  • 按照预先设定的比例因子对硬损失和软损失进行加权,得到综合损失total_loss;
  • 对综合损失进行反向传播并更新模型参数。
def train_epoch(epoch, t_model, stu_model, dataloader, optimizer, alpha, device):
    stu_model.train()
    for inputs, targets in tqdm(dataloader, desc=f"training epoch: {epoch}"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        
        stu_logits = stu_model(inputs)
        hard_loss = hard_loss_fn(stu_logits, targets)
        if alpha > 0.0:
            with torch.no_grad():
                teacher_logits = t_model(inputs)
                # print("teacher_logits == None: ", teacher_logits == None)
            soft_loss = soft_loss_fn(stu_logits, teacher_logits, temperature)
        else:
            soft_loss = 0.0
        total_loss = alpha * soft_loss + (1-alpha) * hard_loss

        total_loss.backward()
        optimizer.step()

注:alpha参数除了可以作为加权的比例因子,也可以作为蒸馏的开关,当alpha=0时,教师模型将不纳入损失计算,相当于恢复到普通训练的模式。

下面定义一个多轮训练评估函数。不断的对学生模型进行迭代训练(num_epochs),在每轮训练后评估其性能,并通过性能对比来得到一个最优的模型状态。

def train(num_epochs, t_model, stu_model, train_loader, test_loader, optimizer, alpha, device):
    best_accuracy = 0.0
    best_checkpoint_state = None
    for i in range(num_epochs):
        train_epoch(i, t_model, stu_model, train_loader, optimizer, alpha, device)
        accuracy = evaluate(stu_model, test_loader, device)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_checkpoint_state = copy.deepcopy(stu_model.state_dict())
        print(f"epoch:{i + 1:>2d}, accuracy:{accuracy:.2f}%, best_accuracy:{best_accuracy:.2f}%")
    
    stu_model.load_state_dict(best_checkpoint_state)
    model_accuracy = evaluate(stu_model, test_loader, device)
    print(f"model accuracy: {model_accuracy:.2f}%")
    return model_accuracy

5.训练过程

5.1 直接训练

先不引入教师模型(将软损失权重的比例因子alpha设为0),直接训练观察效果。

set_seed(0)

stu_model_1 = LeNetHalfChannel()
optimizer = torch.optim.SGD(stu_model_1.parameters(),  lr=lr, momentum=momentum)
train(num_epochs, None, stu_model_1, train_loader, test_loader, optimizer, 0.0, device)

在这里插入图片描述

5.2 知识蒸馏训练
set_seed(0)

stu_model_2 = LeNetHalfChannel()
optimizer_2 = torch.optim.SGD(stu_model_2.parameters(),  lr=lr, momentum=momentum)
train(num_epochs, teacher_model, stu_model_2, train_loader, test_loader, optimizer_2, 0.9, device)

在这里插入图片描述

对比上面的结果,同样的模型结构、训练数据,直接训练学生模型5轮下来只能到58.66%的准确率,而知识蒸馏5轮训练下来却能达到96.11%的准确率。

小结:本节从基本概念和相关术语开始,较为完整的介绍了知识蒸馏的基本思想和使用方法。在实践部分,我们引进了KL散度作为计算学生模型与教师模型之间损失的指标,随后以一个半通道数的LeNet模型为例,演示了直接训练和知识蒸馏两种方法在训练效果上的差异。

相关阅读

  • 如何进行神经网络架构搜索?
  • 如何进行模型剪枝

http://www.kler.cn/a/408347.html

相关文章:

  • 搜索引擎中广泛使用的文档排序算法——BM25(Best Matching 25)
  • 深入浅出,快速安装并了解汇编语言
  • SpringBoot多文件上传
  • 《现代制造技术与装备》是什么级别的期刊?是正规期刊吗?能评职称吗?
  • VSCode【下载】【安装】【汉化】【配置C++环境】【运行调试】(Windows环境)
  • PyTorch图像预处理:计算均值和方差以实现标准化
  • kotlin 的循环
  • 【MySQL】开发技术深度探索:mysql数据库复合查询全面详解
  • Group Convolution(分组卷积)
  • 1123--collection接口,list接口,set接口
  • scau编译原理综合性实验
  • 【数据结构】链表重难点突破
  • CTF之密码学(键盘加密)
  • Linux(2)
  • 16.C++STL 3(string类的模拟,深浅拷贝问题)
  • 〔 MySQL 〕中三种重要的日志类型
  • Java网络编程 - cookiesession
  • Vulnhub靶场 Jangow: 1.0.1 练习
  • C语言超详细教程
  • 挂壁式空气净化器哪个品牌的质量好?排名top3优秀产品测评分析
  • 网络性能及IO性能测试工具
  • golang实现TCP服务器与客户端的断线自动重连功能
  • 优先算法 —— 双指针系列 - 复写零
  • 青训营刷题笔记17
  • [自动化]获取每次翻页后的页面 URL
  • Java核心特性解析:方法、Stream流、文件与IO详解