当前位置: 首页 > article >正文

【nlp】知识蒸馏Distilling

一、知识蒸馏介绍

1. 什么是知识蒸馏?

在这里插入图片描述
知识蒸馏(Knowledge Distillation) 是一种用于模型压缩的技术,通过让小模型(称为学生模型,student model)从大模型(称为教师模型,teacher model)中学习,从而提高小模型的性能,同时保留大模型的一部分知识。知识蒸馏常用于深度学习中,以减少计算资源和内存需求,使得模型可以在资源受限的设备上运行,比如移动设备和嵌入式系统。

2. 轻量化网络方式有哪些?

1. 压缩已训练好的模型

  • 知识蒸馏:将大模型的知识传递给小模型,通过模仿大模型的输出提高小模型的性能。
  • 权值量化:将浮点数表示的权重转换为低精度整数(如 INT8)表示,减少模型体积和计算量。
  • 权重剪枝:移除不重要的权重或神经元,减少参数量和计算开销。
  • 通道剪枝:剪掉卷积层的某些通道,降低卷积计算的复杂性。
  • 注意力迁移:通过让小模型学习大模型的注意力机制,使其更好地关注重要的特征。

2. 直接训练轻量化网络

  • SqueezeNet:使用较少的参数量进行等效卷积操作。
  • MobileNetv1/v2/v3:引入深度可分离卷积(depthwise separable convolution)和倒残差结构,显著减少计算量。
  • MnasNet:通过神经架构搜索(NAS)设计的轻量化网络。
  • ShuffleNet:通过通道洗牌来优化组卷积的性能。
  • Xception:一种极度优化的深度可分离卷积网络。
  • EfficientNet:通过复合缩放(compound scaling)策略优化网络深度、宽度和分辨率。
  • EfficientDet:专门针对目标检测任务的轻量化网络,基于 EfficientNet 设计。

3. 加速卷积运算

  • im2col + GEMM:通过将卷积运算转换为矩阵乘法(General Matrix Multiplication)来加速计算。
  • Winograd 算法:用于减少卷积计算中的乘法操作,提升速度。
  • 低秩分解:将卷积核进行分解,减少参数量和计算量。

4. 硬件部署

  • TensorRT:NVIDIA 的深度学习推理库,通过优化模型来加速推理。
  • Jetson:NVIDIA 的嵌入式 AI 计算平台,适合低功耗场景。
  • TensorFlow-Slim:TensorFlow 中的轻量化网络构建工具,用于快速构建轻量模型。
  • OpenVINO:Intel 的推理工具套件,专注于边缘设备上的高效推理。
  • FPGA 集成电路:通过定制的集成电路实现高效的并行化计算,加速推理。

这些技术方法组合使用,可以在保持模型性能的同时大幅减少计算资源和存储需求,适合资源受限的应用场景如移动设备和嵌入式系统。

3. 软标签 vs 硬标签

  • 硬标签(hard targets):通常是训练数据的真实标签,通常采用 one-hot 编码。例如,图片分类任务中,图片所属的正确类别的概率为 1,其他类别的概率为 0。

  • 软标签(soft targets):通过教师模型的输出概率分布得到的标签。与硬标签不同,软标签是一个概率分布,包含了教师模型在所有类别上的预测概率。即使是错误类别,教师模型也会分配一个非零的概率。这些概率可以反映类别之间的相似性。

例如,对于一张图片,教师模型可能给出以下预测分布:

类别 A:70%
类别 B:20%
类别 C:5%
类别 D:5%

这表示该图片最有可能属于类别 A,但类别 B 也有一定的可能性。这样的概率分布提供了比硬标签(如 100% 属于类别 A)更多的细粒度信息。

在这里插入图片描述

4. 蒸馏温度 T T T

知识蒸馏中的温度作用

在标准的分类任务中,模型输出的是每个类别的预测概率,这些概率通常通过 Softmax 函数计算得到。Softmax 函数的定义如下:

P ( y = i ) = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) P(y=i) = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} P(y=i)=jexp(zj/T)exp(zi/T)

其中, z i z_i zi 是第 i i i 类的 logits(即模型输出的未归一化分数), T T T 是温度参数。当 T = 1 T=1 T=1 时,Softmax 函数表现为标准的形式。如果温度 T > 1 T > 1 T>1,Softmax 的输出将变得“更软”,即各类之间的概率分布更加均匀;如果 T < 1 T < 1 T<1,Softmax 输出将变得更加尖锐,接近 one-hot 分布。

在知识蒸馏过程中,使用较高的温度(通常 T > 1 T>1 T>1)可以使教师模型输出的概率分布变得更加平滑,突出各类别之间的相对差异,而不是仅仅关注最高的概率类别。学生模型可以从这种软标签中学习到更多关于各个类别之间关系的信息,而不仅仅是从硬标签中学到的正确类别。
在这里插入图片描述

举例说明

假设我们有一个图像分类任务,教师模型是一种复杂的深度卷积神经网络(例如 ResNet-50),而学生模型是一个较小的模型(例如一个简单的卷积神经网络)。在知识蒸馏过程中,我们用教师模型的输出作为指导,来帮助学生模型学习。

  1. 无温度调整的硬标签训练: 学生模型仅从每个输入样本的真实类别标签(即 one-hot 编码)中学习,这些标签并没有包含类别之间的相关性或其他信息。

  2. 知识蒸馏中的软标签: 使用知识蒸馏时,首先通过调整温度参数 T > 1 T>1 T>1,教师模型输出的类别分布会变得更平滑。例如,对于某个输入图像,教师模型可能预测类别 A 的概率是 0.9,类别 B 的概率是 0.05,类别 C 的概率是 0.03,类别 D 的概率是 0.02。在温度调整后(例如 T = 5 T=5 T=5),这个分布可能会变为类别 A 的概率是 0.4,类别 B 的概率是 0.3,类别 C 的概率是 0.2,类别 D 的概率是 0.1。这个平滑后的分布反映了不同类别之间的相似性。

  3. 学生模型学习: 学生模型从这个更平滑的概率分布中学习,不仅学到了类别 A 的重要性,还学习到了类别 B 和类别 C 与类别 A 的相关性。这样可以帮助学生模型更好地理解数据之间的模式,从而提高泛化性能。

温度选择

温度 T T T 的选择非常关键,它决定了知识蒸馏的效果。较高的温度使得概率分布更平滑(矮胖),能够传递更多的类别信息,但也可能导致过度平滑,使得学生模型难以捕捉有用的信息。通常需要通过实验来确定最适合的温度值。

二、知识蒸馏过程

1. 知识蒸馏的过程

在这里插入图片描述

1. 输入数据

设输入数据为 x x x,同时输入给教师模型和学生模型。

2. 教师模型输出(Teacher Model Output)

教师模型是一个较复杂的神经网络,其通过 softmax 函数生成软标签。softmax 函数使用温度参数 T T T 来控制输出概率的平滑度:

q i teacher = exp ⁡ ( z i teacher / T ) ∑ j exp ⁡ ( z j teacher / T ) q_i^{\text{teacher}} = \frac{\exp(z_i^{\text{teacher}} / T)}{\sum_j \exp(z_j^{\text{teacher}} / T)} qiteacher=jexp(zjteacher/T)exp(ziteacher/T)

其中:

  • q i teacher q_i^{\text{teacher}} qiteacher 是教师模型生成的类别 i i i 的概率。
  • z i teacher z_i^{\text{teacher}} ziteacher 是教师模型的第 i i i 类别的 logit 值。
  • T T T 是温度参数,当 T > 1 T > 1 T>1 时,输出概率分布更加平滑,有助于学生模型学习类别间的相似性。

3. 学生模型输出(Student Model Output)

学生模型是一个较小的模型,它通过学习教师模型的软标签和真实标签来提高性能。学生模型的输出也通过 softmax 函数生成。

软预测(Soft Predictions):

学生模型生成的软预测是通过与教师模型相同温度 T T T 的 softmax 函数计算的:

q i student = exp ⁡ ( z i student / T ) ∑ j exp ⁡ ( z j student / T ) q_i^{\text{student}} = \frac{\exp(z_i^{\text{student}} / T)}{\sum_j \exp(z_j^{\text{student}} / T)} qistudent=jexp(zjstudent/T)exp(zistudent/T)

硬预测(Hard Predictions):

学生模型还生成硬预测,即通过正常的 softmax(温度 T = 1 T = 1 T=1)生成的标准输出,用于匹配真实标签:

q i hard = exp ⁡ ( z i student ) ∑ j exp ⁡ ( z j student ) q_i^{\text{hard}} = \frac{\exp(z_i^{\text{student}})}{\sum_j \exp(z_j^{\text{student}})} qihard=jexp(zjstudent)exp(zistudent)

4. 损失函数

为了训练学生模型,我们引入两个损失函数:

4.1 蒸馏损失(Distillation Loss)

蒸馏损失用于衡量学生模型的软预测和教师模型的软标签之间的差异。它通过使用**Kullback-Leibler 散度(KL 散度)**来度量这两个概率分布之间的距离:

L distill = KL ( q teacher , q student ) = ∑ i q i teacher log ⁡ ( q i teacher q i student ) L_{\text{distill}} = \text{KL}(q^{\text{teacher}}, q^{\text{student}}) = \sum_i q_i^{\text{teacher}} \log\left(\frac{q_i^{\text{teacher}}}{q_i^{\text{student}}}\right) Ldistill=KL(qteacher,qstudent)=iqiteacherlog(qistudentqiteacher)

其中:

  • q teacher q^{\text{teacher}} qteacher 是教师模型生成的软标签。
  • q student q^{\text{student}} qstudent 是学生模型生成的软预测。
4.2 学生损失(Student Loss)

学生损失是学生模型的硬预测与真实标签(硬标签)之间的差异,通常使用交叉熵损失计算:

L student = − ∑ i y i log ⁡ ( q i hard ) L_{\text{student}} = - \sum_i y_i \log(q_i^{\text{hard}}) Lstudent=iyilog(qihard)

其中:

  • y i y_i yi 是真实标签的 one-hot 编码。
  • q i hard q_i^{\text{hard}} qihard 是学生模型对类别 i i i 的硬预测概率。

5. 总损失函数

最终的总损失函数是蒸馏损失学生损失的加权和:

L total = α L student + β L distill L_{\text{total}} = \alpha L_{\text{student}} + \beta L_{\text{distill}} Ltotal=αLstudent+βLdistill

其中:

  • α \alpha α β \beta β 是权重系数,控制学生损失和蒸馏损失的相对重要性。通常 α \alpha α 可以设置为 1, β \beta β 可以调整以控制蒸馏的影响。
  • 为了确保梯度的缩放一致性,蒸馏损失部分的梯度通常会乘以 T 2 T^2 T2,因为软标签的梯度会随温度 T T T 缩放。

6. 温度参数 T T T 的影响

温度参数 T T T 控制了 softmax 函数的平滑程度。较高的 T T T 会使教师模型的输出概率分布更加平滑,从而让学生模型能够学习到类别间的相对关系。这些信息可以帮助学生模型提高泛化能力。

  • T = 1 T = 1 T=1 时,softmax 输出接近 one-hot 编码,类别之间的相对信息较少。
  • T > 1 T > 1 T>1 时,类别之间的概率差异缩小,学生模型 可以从这些更平滑的概率中学习到更多的信息。

在这里插入图片描述

前边的两个图是训练过程,后边一个图是预测过程。

2. 知识蒸馏发展趋势

  1. 教学助长
  2. 助教、多个老师、多个同学
  3. 知识的表示(中间层)、数据集蒸馏、对比学习
  4. 多模态、知识图谱、预训练大模型的知识蒸馏

论文:
Attention Transfer
channel-wise knowledge distillation for dense prediction
contrastive representation Distillation
Distill BERT

3. 实现代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# 1. 定义教师模型(较大的模型)
class TeacherModel(nn.Module):
    def __init__(self):
        super(TeacherModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 1200)
        self.fc2 = nn.Linear(1200, 1200)
        self.fc3 = nn.Linear(1200, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 2. 定义学生模型(较小的模型)
class StudentModel(nn.Module):
    def __init__(self):
        super(StudentModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 800)
        self.fc2 = nn.Linear(800, 800)
        self.fc3 = nn.Linear(800, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 3. 定义蒸馏损失函数
def distillation_loss(student_outputs, teacher_outputs, labels, T, alpha):
    """
    :param student_outputs: 学生模型的输出
    :param teacher_outputs: 教师模型的输出
    :param labels: 真实标签
    :param T: 温度参数
    :param alpha: 学生损失与蒸馏损失的权重
    :return: 总损失
    """
    # 计算学生模型的硬标签损失(交叉熵损失)
    hard_loss = F.cross_entropy(student_outputs, labels)

    # 计算软标签损失(KL 散度)
    soft_student = F.log_softmax(student_outputs / T, dim=1)
    soft_teacher = F.softmax(teacher_outputs / T, dim=1)
    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T * T)

    # 总损失 = α * 硬损失 + (1 - α) * 软损失
    return alpha * hard_loss + (1 - alpha) * soft_loss


# 4. 数据加载
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


# 5. 训练过程
def train_student(teacher_model, student_model, train_loader, optimizer, T, alpha, epochs=5):
    teacher_model.eval()  # 教师模型是预训练的,设置为 eval 模式
    student_model.train()  # 学生模型将要训练

    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in train_loader:
            # images, labels = images.cuda(), labels.cuda()
            # 教师模型预测
            with torch.no_grad():
                teacher_outputs = teacher_model(images)

            # 学生模型预测
            student_outputs = student_model(images)

            # 计算蒸馏损失
            loss = distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)

            # 优化器更新
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")


# 6. 测试过程
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            # images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs, 1) # 返回每行(每个样本)的最大值和对应的索引
            total += labels.size(0) #更新样本计数
            correct += (predicted == labels).sum().item() # 更新正确预测计数

    print(f'Test Accuracy: {100 * correct / total:.2f}%')


# 7. 实例化模型并启动训练
# teacher_model = TeacherModel().cuda()
# student_model = StudentModel().cuda()
teacher_model = TeacherModel()
student_model = StudentModel()

# 假设教师模型已经预训练过
# 这里可以加载预训练的教师模型权重
# torch.load('teacher_model.pth', teacher_model)

optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# 设定温度 T 和 α 参数
T = 20  # 温度
alpha = 0.7  # 学生损失与蒸馏损失的权重

# 训练学生模型
train_student(teacher_model, student_model, train_loader, optimizer, T, alpha, epochs=5)

# 测试学生模型
test_model(student_model, test_loader)

# 保存学生模型权重
torch.save(student_model.state_dict(),'student_model.pth')
# 保存教师模型权重
torch.save(teacher_model.state_dict(), 'teacher_model.pth')


输出:

Epoch [1/5], Loss: 0.4672
Epoch [2/5], Loss: 0.3833
Epoch [3/5], Loss: 0.3671
Epoch [4/5], Loss: 0.3566
Epoch [5/5], Loss: 0.3509
Test Accuracy: 98.44%

Process finished with exit code 0

http://www.kler.cn/news/356819.html

相关文章:

  • Postman发送GET、POST请求
  • 【重学 MySQL】七十二、轻松掌握视图的创建与高效查看技巧
  • 网络爬虫自动化Selenium模拟用户操作
  • Python知识点:基于Python工具,如何使用Ethereum Tester进行智能合约测试
  • python中else使用汇总
  • docker启动MySQL容器失败原因排查记录
  • 力扣 142.环形链表Ⅱ【详细解释】
  • C#的自定义对话框和提示窗体 - 开源研究系列文章
  • Shell脚本:用户和用户组管理全面指南
  • 如何用代码将网页打开
  • Hbase安装及使用
  • OpenCV高级图形用户界面(6)获取指定窗口中图像的矩形区域函数getWindowImageRect()的使用
  • 业务逻辑漏洞之墨者学院靶场——身份认证失效
  • 【文化课学习笔记】【化学】选必三:同分异构体的书写
  • 初识Linux之指令(二)
  • 学习资料分享平台计算机毕设基于SpringBootSSM框架
  • 【经典卷积网络】(一)——LeNet-5
  • perl 给特定文件加上特定内容
  • DBeaver导出数据表结构和数据,导入到另一个环境数据库进行数据更新
  • Java中的equals()和hashCode()方法是如何工作的?