当前位置: 首页 > article >正文

Distilling the Knowledge in a Neural Network(2015.5)(d补)


文章目录

  • Abstract
  • 1 Introduction
  • 2 Distillation
    • 2.1 Matching logits is a special case of distillation
  • Results

论文链接

Abstract

提高几乎所有机器学习算法性能的一种非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用整个模型集合进行预测是很麻烦的,而且可能计算成本太高,无法部署到大量用户,特别是如果单个模型是大型神经网络。Caruana和他的合作者[1]已经证明,可以将集成中的知识压缩到一个更容易部署的单一模型中,并且我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上取得了一些令人惊讶的结果,并且我们表明,通过将模型集合中的知识提取到单个模型中,我们可以显著改善大量使用的商业系统的声学模型。我们还介绍了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型所混淆的细粒度类。与混合专家不同,这些专家模型可以快速并行地训练

1 Introduction

许多昆虫都有一个幼虫形态,它最擅长从环境中获取能量和营养,而一个完全不同的成虫形态,它最擅长于旅行和繁殖的不同需求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大的、高度冗余的数据集中提取结构,但它不需要实时操作,它可以使用大量的计算量。然而,部署到大量用户时,对延迟和计算资源的要求要严格得多。与昆虫的类比表明,如果能更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可以是单独训练的模型的集合,也可以是使用dropout等非常强的正则化器训练的单个非常大的模型[9]。一旦繁琐的模型得到训练,我们就可以使用另一种训练,我们称之为**“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型中**。这种策略的一个版本已经由Rich Caruana和他的合作者开创[1]。在他们的重要论文中,他们令人信服地证明了由大量模型集合获得的知识可以转移到单个小模型中

一个概念上的障碍可能阻碍了对这种非常有前途的方法进行更多的研究,那就是我们倾向于用学习到的参数值来识别训练模型中的知识,这使得我们很难看到如何在保持相同知识的情况下改变模型的形式。将知识从任何特定实例中解放出来的更抽象的知识视图是,它是从输入向量到输出向量的学习映射。对于学习区分大量类别的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练模型为所有不正确答案分配概率,即使这些概率非常小,其中一些概率也比其他概率大得多。错误答案的相对概率告诉我们很多关于这个繁琐的模型是如何泛化的。例如,宝马的图像可能只有很小的机会被误认为是垃圾车,但这种错误仍然比将其误认为胡萝卜的可能性高很多倍。

将繁琐模型的泛化能力转移到小模型上的一个显而易见的方法是将繁琐模型产生的类概率作为训练小模型的“软目标”。对于这个迁移阶段,我们可以使用相同的训练集或单独的“迁移”集。当繁琐的模型是由许多简单模型组成的大集合时,我们可以使用单个预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,每个训练案例提供的信息量比困难目标大得多,训练案例之间的梯度方差也小得多,因此小模型通常可以在比原始繁琐模型少得多的数据上进行训练,并使用更高的学习率。

对于像MNIST这样的任务,繁琐的模型几乎总是产生非常高置信度的正确答案,关于学习函数的大部分信息存在于软目标中非常小的概率比率中。例如,一个2可能有10 −6的概率是3,10 −9的概率是7,而另一个版本可能是相反的。这是有价值的信息,它定义了数据的丰富相似性结构(即,它表示哪些2看起来像3,哪些看起来像7),但它在传递阶段对交叉熵成本函数的影响很小,因为概率非常接近于零。Caruana和他的合作者通过使用logits(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logits和小模型产生的logits之间的平方差。我们更通用的解决方案,称为“蒸馏”,是提高最终软最大值的温度,直到繁琐的模型产生合适的软目标集。然后,我们在训练小模型时使用相同的高温来匹配这些软目标。稍后我们将说明,匹配繁琐模型的对数实际上是蒸馏的一种特殊情况

用于训练小模型的转移集可以完全由未标记的数据组成[1],或者我们可以使用原始训练集。我们发现,使用原始的训练集效果很好,特别是如果我们在目标函数中添加一个小项,可以鼓励小模型预测真实目标,并匹配繁琐模型提供的软目标。通常,小模型不能完全匹配软目标,在正确答案的方向上犯错误是有帮助的。

2 Distillation

神经网络通常通过使用“softmax”输出层来产生类概率,通过将 zi 与其他 logits 进行比较,将每个类别的 logit z i 计算为概率 q i

T是温度,通常设为1。使用更高的T值会产生更柔和的类概率分布

在最简单的蒸馏形式中,通过在转移集上训练知识,并对转移集中的每个情况使用软目标分布,将知识转移到蒸馏模型中,该转移集中使用具有高温软最大值的繁琐模型产生的软目标分布。在训练蒸馏模型时使用相同的高温,但在训练完成后,它使用的温度为1。

当所有或部分转移集的正确标签已知时,可以通过训练蒸馏模型来产生正确的标签来显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,该交叉熵是在蒸馏模型的软最大值中使用与从繁琐模型生成软目标相同的高温来计算的。第二个目标函数是带有正确标签的交叉熵。这是在蒸馏模型的softmax中使用完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用相当低的权重通常可以获得最佳结果。由于软目标产生的梯度大小为1/ t2,因此在使用硬目标和软目标时,将它们乘以t2是很重要的。这确保了在使用元参数进行实验时,如果用于蒸馏的温度发生变化,则硬目标和软目标的相对贡献大致保持不变。

2.1 Matching logits is a special case of distillation

传递集中的每种情况相对于蒸馏模型的每个logit z i贡献了一个交叉熵梯度dC/dz i。如果繁琐模型的logits v i产生软目标概率p i,并且迁移训练在温度T下进行,则该梯度为:
如果温度比对数的大小高,我们可以近似:
如果我们现在假设对数在每个转移情况下都是Σj zj = Σj vj = 0,式3化简为:

所以在高温极限下,蒸馏等于最小化1/2(zi - vi) 2,只要对数分别为零。在较低的温度下,蒸馏很少注意匹配比平均值负得多的对数。这是潜在的优势,因为这些逻辑几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常负的对数可以传达关于繁琐模型所获得的知识的有用信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负对数可能是有帮助的。

Results

我们训练了10个独立的模型来预测P(h t |s t;θ),使用完全相同的架构和训练过程作为基线。用不同的初始参数值随机初始化模型,我们发现这在训练模型中产生了足够的多样性,使得集合的平均预测明显优于单个模型。我们已经探索了通过改变每个模型看到的数据集来增加模型的多样性,但是我们发现这不会显著改变我们的结果,所以我们选择了更简单的方法。对于蒸馏,我们尝试了[1,2,5,10]的温度,并对硬目标的交叉熵使用了0.5的相对权重,其中粗体表示表1中使用的最佳值
表1显示,实际上,我们的蒸馏方法能够从训练集中提取更多有用的信息,而不是简单地使用硬标签来训练单个模型。使用10个模型的集合所获得的帧分类精度提高的80%以上被转移到蒸馏模型上,这与我们在MNIST上的初步实验中观察到的改进相似。由于目标函数不匹配,集成对WER的最终目标(在23k个单词的测试集上)给出了较小的改进,但同样,集成实现的WER改进被转移到蒸馏模型上。

我们最近意识到通过匹配已经训练好的大型模型的类概率来学习小型声学模型的相关工作[8]。然而,他们使用大型未标记数据集在1的温度下进行蒸馏,他们的最佳蒸馏模型仅将小模型的错误率降低了28%,这是大模型和小模型在使用硬标签训练时错误率之间的差距


http://www.kler.cn/a/159928.html

相关文章:

  • ComfyUI-PromptOptimizer:文生图提示优化节点
  • 【Python】深入探讨Python中的单例模式:元类与装饰器实现方式分析与代码示例
  • QT:IconButton的动画效果
  • 春秋杯-WEB
  • 51c大模型~合集106
  • JWT在线解密/解码 - 加菲工具
  • ElasticSearch篇---第三篇
  • Leetcode—383.赎金信【简单】
  • Spring Cloud Gateway与spring-cloud-circuitbreaker集成与理解
  • 【IC前端虚拟项目】git和svn项目托管平台的简单使用说明
  • LeetCode Hot100 200.岛屿数量
  • Hadoop学习笔记(HDP)-Part.03 资源规划
  • 【Pytorch使用自制数据集,Dataloader】
  • 7.上传project到服务器及拉取服务器project到本地、更新代码冲突解决
  • Leetcode每日一题学习训练——Python3版(最小化旅行的价格总和)
  • Mac-idea快捷键操作
  • Android 横竖屏切换 窗口全屏
  • C++ 构造函数与析构函数
  • Python Flask 框架开发
  • K-Radar:适用于各种天气条件的自动驾驶4D雷达物体检测
  • 图形遍历效率低?试试 R 树
  • 【华为OD题库-043】二维伞的雨滴效应-java
  • 【C++】:set和map
  • PIKA,一个神奇的AI工具
  • 《LeetCode力扣练习》代码随想录——字符串(反转字符串---Java)
  • 学生上课睡觉老师的正确做法