模型压缩——如何进行知识蒸馏?
1.引言
目前的主流的大语言模型基本都是部署在云端,但由于安全隐私和延时问题,有越来越多的场景需要将AI模型部署在边缘设备上,例如个人电脑、智能手机、物联网设备。如何将大模型的能力迁移到小设备上,就成为一个重要的研究方向。
云端有着强大的计算资源,如下图中的NVIDIA A100显卡,能够提供高达19.5 TFLOPS的浮点运算能力,并配备高达80GB的内存。但边缘设备通常只有极为有限的计算能力和内存,它们无法处理像云端一样复杂的神经网络。
因此,需要压缩来减少模型的参数和计算需求。前面的剪枝、神经网络架构搜索都是为这个目标服务的。除这些技术之外,知识蒸馏(Knowledge Distillation, KD) 也是一种有效的模型压缩技术,本文将会来讨论如何使用知识蒸馏来压缩模型。
2.基本概念
知识蒸馏基本思想
是:通过大模型(教师模型)的指导来训练小模型(学生模型),从而让小模型的性能接近大模型,
知识蒸馏的基本做法
是:同时使用教师模型和学生模型对输入数据进行预测,然后对齐教师模型和学生模型的输出概率分布,让学生模型来学习教师模型的行为模式,从而达到减少模型大小并保留性能的目的。
- 教师模型(Teacher Model):一个预先训练好的复杂模型,性能优异,提供软标签作为
知识
,用于指导学生模型的学习; - 学生模型(Student Model):需要被训练的较简单的模型,通过学习教师模型的知识,以实现接近教师模型性能的目的;
- 硬标签(Hard Label):实际训练集中,为每个输入样本分配的真实类别标签,使用独热编码(One-Hot Label);
- 软标签(Soft Label):教师模型输出的概率分布,可以通过温度将输出的概率分布软化,使学生模型能更好的学习概率间的相对关系。
- 硬损失(Hard Loss):学生模型与真实标签之间的交叉熵损失,用于确保学生模型在标准分类任务上的性能表现。
- 软损失(Soft Loss):学生模型与教师模型输出之间的损失,通过KL散席来衡量,用于帮助学生模型学习教师模型的预测分布;
- 总损失(Total Loss):软损失和硬损失的加权和形成总损失,可通过调节权重 λ \lambda λ来控制学生模型在多大程度上依赖教师模型的指导。
3.模型和数据
3.1 模型封装
模型封装主要是确定教师模型和学生模型的结构。
教师模型会直接复用之前模型剪枝中定义的LeNet网络,而学生模型则会通过减少层数和通道数来定义一个半通道的LeNet网络。
首先,我们引进模型剪枝中已经编写过的数据、训练、评估相关的代码。
%run lenet.py
定义一个半通道数的LeNet网络作为学生模型结构,具体在如下方面进行了结构缩减:
- 卷积层:由原来的两层缩减为一层,通道数由6减为3;
- 线性层:由原来的3层缩减为一层,将池化层的输出直接用于类别预测;
class LeNetHalfChannel(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=5) # 3x24x24
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) # 3x12x12
self.fc1 = nn.Linear(in_features=3*12*12, out_features=num_classes)
def forward(self, x):
x = self.maxpool(F.relu(self.conv1(x)))
x = x.view(x.size()[0], -1)
x = F.relu(self.fc1(x))
return x
3.2 模型实例化
实例化教师模型并加载已训练过的模型权重:
teacher_model = LeNet()
teacher_model.load_state_dict(torch.load("./checkpoint/model.pt"))
print("teacher_model params number:", count_parameters(teacher_model))
teacher_model params number: 44426
实例化学生模型,并对比学生模型和教师模型的参数数量:
student_model = LeNetHalfChannel()
print("student_model params number:", count_parameters(student_model))
student_model params number: 4408
学生模型参数量只有教师模型参数量的约1/10。
3.3 数据加载
数据部分直接用神经网络架构搜索一节中封装过的load_data方法来加载数据,得到训练集和测试集。
train_loader, test_loader = load_data("./data", batch_size=64)
len(train_loader), len(test_loader)
(938, 157)
4.训练封装
4.1 蒸馏损失函数
首先,我们需要实现一个知识蒸馏的专用损失函数。假设已经通过教师模型和学生模型得两个logits。
logits_student = torch.tensor([[2.0, 1.0, 0.1], [1.0, 3.0, 0.2]])
logits_teacher = torch.tensor([[1.5, 0.5, 0.3], [0.8, 2.5, 0.5]])
第一步需要对两个logits分别执行softmax变换,得到概率分布的对数表示(假设温度为2.0):
temperature = 4.0
pred_student = F.log_softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
pred_student, pred_teacher
(tensor([[-0.8758, -1.1258, -1.3508],
[-1.2434, -0.7434, -1.4434]]),
tensor([[0.3969, 0.3091, 0.2940],
[0.2892, 0.4424, 0.2683]]))
注:logits_student 除以 temperature 是为了平滑分数,高温(大于1)会让输出概率更平滑,能让模型学习标签之间的相互联系,更有利于知识蒸馏; 而低温(小于等于1)会使得正负标签的概率分布比较尖锐,过于强调正确答案,而忽略了错误答案的信息。
举个例子,对一张图像进行分类识别,有马(正确标签)、驴、车三种可能,虽然驴和车都于错误标签,但驴显然比车更像马,识别为驴的概率应该大于车,平滑的概率分布更有利于保留这些信息。
接着,用KL散度计算学生模型与教师模型两者预测概率之间的差异。
KL 散度(Kullback-Leibler Divergence) 可以帮助我们了解一个分布在多大程度上与另一个分布不同,计算结果为0时表示两个分布完全一样,散度越大表示两个分布之间的差异越大。
kl_div = F.kl_div(pred_student, pred_teacher, reduction="none")
kl_div_sum = kl_div.sum(1) # 对每个样本的所有类别的 KL 散度值求和
kl_div_mean = kl_div_sum.mean() # 对所有样本 KL 散度求和的结果取平均值,即为散度损失。
kl_div_mean
tensor(0.0032)
将损失值乘以温度系数的平方,以补偿前面除以温度系数带来的缩放效应,确保损失值的量级合理。
loss_kd = kl_div_mean * temperature**2
loss_kd
tensor(0.0511)
将上面的实现过程封装为一个KL散度损失函数。
def kl_div(logits_student, logits_teacher, temperature):
log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
loss_kd *= temperature**2
return loss_kd
4.2 超参定义
定义训练需要的超参数、损失函数和优化器。硬损失函数使用交叉熵,软损失函数使用刚定义的KL散度损失函数。
lr = 0.01
momentum = 0.5
num_epochs = 5
temperature = 4 # 温度
device = 'cpu'
soft_loss_fn = kl_div
hard_loss_fn = nn.CrossEntropyLoss()
4.3 训练函数
定义单轮训练函数,大概实现思路如下:
- 将学生模型设为训练模式,表示只对学生模型进行训练;
- 分别对学生模型和教师模型进行前向传播,得到两个模型各自的预测结果stu_logits和teacher_logits;
- 分别计算硬损失(学生预测与真实标签)和软损失(学生预测与教师预测);
- 按照预先设定的比例因子对硬损失和软损失进行加权,得到综合损失total_loss;
- 对综合损失进行反向传播并更新模型参数。
def train_epoch(epoch, t_model, stu_model, dataloader, optimizer, alpha, device):
stu_model.train()
for inputs, targets in tqdm(dataloader, desc=f"training epoch: {epoch}"):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
stu_logits = stu_model(inputs)
hard_loss = hard_loss_fn(stu_logits, targets)
if alpha > 0.0:
with torch.no_grad():
teacher_logits = t_model(inputs)
# print("teacher_logits == None: ", teacher_logits == None)
soft_loss = soft_loss_fn(stu_logits, teacher_logits, temperature)
else:
soft_loss = 0.0
total_loss = alpha * soft_loss + (1-alpha) * hard_loss
total_loss.backward()
optimizer.step()
注:alpha参数除了可以作为加权的比例因子,也可以作为蒸馏的开关,当alpha=0时,教师模型将不纳入损失计算,相当于恢复到普通训练的模式。
下面定义一个多轮训练评估函数。不断的对学生模型进行迭代训练(num_epochs),在每轮训练后评估其性能,并通过性能对比来得到一个最优的模型状态。
def train(num_epochs, t_model, stu_model, train_loader, test_loader, optimizer, alpha, device):
best_accuracy = 0.0
best_checkpoint_state = None
for i in range(num_epochs):
train_epoch(i, t_model, stu_model, train_loader, optimizer, alpha, device)
accuracy = evaluate(stu_model, test_loader, device)
if accuracy > best_accuracy:
best_accuracy = accuracy
best_checkpoint_state = copy.deepcopy(stu_model.state_dict())
print(f"epoch:{i + 1:>2d}, accuracy:{accuracy:.2f}%, best_accuracy:{best_accuracy:.2f}%")
stu_model.load_state_dict(best_checkpoint_state)
model_accuracy = evaluate(stu_model, test_loader, device)
print(f"model accuracy: {model_accuracy:.2f}%")
return model_accuracy
5.训练过程
5.1 直接训练
先不引入教师模型(将软损失权重的比例因子alpha设为0),直接训练观察效果。
set_seed(0)
stu_model_1 = LeNetHalfChannel()
optimizer = torch.optim.SGD(stu_model_1.parameters(), lr=lr, momentum=momentum)
train(num_epochs, None, stu_model_1, train_loader, test_loader, optimizer, 0.0, device)
5.2 知识蒸馏训练
set_seed(0)
stu_model_2 = LeNetHalfChannel()
optimizer_2 = torch.optim.SGD(stu_model_2.parameters(), lr=lr, momentum=momentum)
train(num_epochs, teacher_model, stu_model_2, train_loader, test_loader, optimizer_2, 0.9, device)
对比上面的结果,同样的模型结构、训练数据,直接训练学生模型5轮下来只能到58.66%
的准确率,而知识蒸馏5轮训练下来却能达到96.11%
的准确率。
小结:本节从基本概念和相关术语开始,较为完整的介绍了知识蒸馏的基本思想和使用方法。在实践部分,我们引进了KL散度作为计算学生模型与教师模型之间损失的指标,随后以一个半通道数的LeNet模型为例,演示了直接训练和知识蒸馏两种方法在训练效果上的差异。
相关阅读
- 如何进行神经网络架构搜索?
- 如何进行模型剪枝