当前位置: 首页 > article >正文

【论文阅读笔记】知识蒸馏带来的礼物:快速优化、网络最小化和迁移学习 | FSP

目录

一 方法

1 提出的蒸馏知识

2 蒸馏知识的数学表达

3 FSP矩阵的损失

4 学习过程

二 代码



论文题目:A Gift from Knowledge Distillation:  Fast Optimization, Network Minimization and Transfer Learning

论文地址:https://openaccess.thecvf.com/content_cvpr_2017/papers/Yim_A_Gift_From_CVPR_2017_paper.pdf

代码地址:GitCode - 全球开发者的开源社区,开源代码托管平台

KD ZOO:Knowledge-Distillation-Zoo/kd_losses at master · AberHu/Knowledge-Distillation-Zoo · GitHub

【摘要】介绍了一种新的知识迁移技术,即从一个预训练的深度神经网络( deep neural network,DNN )中提取知识并迁移到另一个DNN。由于DNN通过多层依次从输入空间映射到输出空间,我们定义了将要传输的知识以层间流的形式进行传输,通过计算两层特征之间的内积来计算。当我们将学生DNN和与学生DNN大小相同但没有教师网络训练的原始网络进行比较时,所提出的将蒸馏知识作为两层之间的流进行传输的方法表现出三个重要的现象( 1 )学习蒸馏知识的学生DNN比原始模型的优化速度要快得多( 2 )学生DNN优于原始DNN( 3 )学生DNN可以从在不同任务中训练的教师DNN中学习提取的知识,并且学生DNN的性能优于从零开始训练的原始DNN

本文的主要贡献如下:

提出了一种新颖的知识提取技术

这种方法对于快速优化是有用的。

使用提出的蒸馏知识来寻找初始权重可以提高小型网络的性能

即使学生DNN在与教师DNN不同的任务上进行训练,提出的蒸馏知识也提高了学生DNN的性能

图1。提出的迁移学习方法概念图。FSP矩阵表示从教师DNN中提取的知识,由两层特征生成。通过计算代表方向的内积来生成FSP矩阵,两层之间的流量可以用FSP矩阵来表示。

一 方法

提出的方法的主要思想是如何定义教师DNN的重要信息,并将提取的知识传递给另一个DNN。

介绍了我们在本研究中使用的有用的蒸馏知识。

介绍了我们提出的蒸馏知识的数学表达式。

基于精心设计的蒸馏知识,定义了损失项。

给出了学生DNN的整个学习过程。

1 提出的蒸馏知识

DNN逐层生成特征。更高层的特征更接近执行主要任务的有用特征。如果我们将DNN的输入视为问题,输出视为答案,那么我们可以将DNN中间产生的特征视为求解过程中的中间结果。遵循这一思路,Romero等[ 20 ]提出的知识迁移技术让学生DNN简单地模仿教师DNN的中间结果。然而,在DNN的情况下,有很多方法可以解决从输入生成输出的问题。从这个意义上说,模仿教师DNN的生成特征可以成为学生DNN的硬约束。

人的情况,教师针对一个问题讲解解题过程,学生学习解题程序的流程。学生DNN在输入特定问题时不一定要学习中间输出,但在遇到特定类型问题时可以学习求解方法。这样,我们认为演示问题的求解过程比讲授中间结果具有更好的推广性

2 蒸馏知识的数学表达

求解过程的流程可以通过两个中间结果之间的关系来定义。在DNN的情况下,这种关系可以通过两层特征之间的方向进行数学考虑。我们设计了FSP矩阵来表示求解过程的流程。其中h、w和m分别表示通道的高度、宽度和数目。计算FSP矩阵的公式见公式(1)

其中x和W分别表示输入图像和DNN的权重。我们使用CIFAR-10数据集训练了8,26,32层的残差网络。对于CIFAR-10数据集,残差网络中存在三个空间尺寸变化的点。我们选取了几个点来生成FSP矩阵,如图2所示。

图2。我们提出的方法的完整架构。教师和学生网络的层数可以改变。在保持相同空间尺寸的三个截面上提取FSP矩阵。我们提出的方法有两个阶段。在阶段1中,对学生网络进行训练,使学生网络和教师网络的FSP矩阵之间的距离最小。然后,将学生DNN的预训练权重用于第2阶段的初始权重第二阶段代表正常的训练过程

3 FSP矩阵的损失

为了帮助学生网络,我们从教师网络中传输提取的知识。如前所述,提取的知识以FSP矩阵的形式表示,其中包含求解过程的流程信息定义转移蒸馏知识任务的成本函数为公式(2)

其中,λ i和N分别表示每个损失项的权重和数据点的个数。我们假设整个损失条款同样重要。因此,我们对所有实验使用相同的λ i。

4 学习过程

我们的迁移方法使用教师网络生成的蒸馏知识。为了清楚地解释教师网络在我们的论文中代表什么,我们定义了两个条件。

第一教师网络应该通过一些数据集进行预训练。此数据集可以与学生网络将学习的数据集相同或不同。教师网络在迁移学习任务的情况下使用与学生网络不同的数据集

第二教师网络可以比学生网络更深或更浅

然而,我们考虑的是与学生网络相同或更深的教师网络。

学习过程包含两个阶段的训练。

最小化损失函数Lfsp,使学生网络的FSP矩阵与教师网络的FSP矩阵相似。

经过第一阶段的学生网络现在由第二阶段的主要任务损失进行训练

学习过程在算法1中解释如下。

【结论】提出了一种新颖的方法来从DNN生成蒸馏知识通过将提取的知识确定为由所提出的FSP矩阵计算的求解过程的流程,所提出的方法优于最先进的知识转移方法。我们在三个重要方面验证了我们提出的方法的有效性。所提出的方法可以更快地优化DNN,并生成更高水平的性能。此外,所提出的方法可用于迁移学习任务

二 代码

from __future__ import print_function

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class FSP(nn.Module):
    """A Gift from Knowledge Distillation:
    Fast Optimization, Network Minimization and Transfer Learning"""
    def __init__(self, s_shapes, t_shapes):
        super(FSP, self).__init__()
        assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
        s_c = [s[1] for s in s_shapes]
        t_c = [t[1] for t in t_shapes]
        if np.any(np.asarray(s_c) != np.asarray(t_c)):
            raise ValueError('num of channels not equal (error in FSP)')

    def forward(self, g_s, g_t):
        s_fsp = self.compute_fsp(g_s)
        t_fsp = self.compute_fsp(g_t)
        loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]
        return loss_group

    @staticmethod
    def compute_loss(s, t):
        return (s - t).pow(2).mean()

    @staticmethod
    def compute_fsp(g):
        fsp_list = []
        for i in range(len(g) - 1):
            bot, top = g[i], g[i + 1]
            b_H, t_H = bot.shape[2], top.shape[2]
            if b_H > t_H:
                bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))
            elif b_H < t_H:
                top = F.adaptive_avg_pool2d(top, (b_H, b_H))
            else:
                pass
            bot = bot.unsqueeze(1)
            top = top.unsqueeze(2)
            bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)
            top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)

            fsp = (bot * top).mean(-1)
            fsp_list.append(fsp)
        return fsp_list

至此,本文分享的内容就结束啦。


http://www.kler.cn/a/558563.html

相关文章:

  • 开源协议深度解析:理解MIT、GPL、Apache等常见许可证
  • 案例自定义tabBar
  • 鸿蒙开发深入浅出02(封装Axios请求、渲染Swiper)
  • 本地部署轻量级web开发框架Flask并实现无公网ip远程访问开发界面
  • Prompt-to-Prompt 进行图像编辑
  • forge-1.21.x模组开发(二)给物品添加功能
  • 高速差分信号的布线
  • 怎么合并主从分支,要注意什么
  • PHP二手车置换平台系统小程序源码
  • 【蓝桥】动态规划-多维dp-地图(带有转向次数限制)
  • stm32四种方式精密控制步进电机
  • 理解 “边缘计算“
  • 【C++】模版
  • 细说 Java 引用(强、软、弱、虚)和 GC 流程(二)
  • linux系统如何配置host.docker.internal
  • 关于GeoPandas库
  • 【Golang 面试题】每日 3 题(六十四)
  • CentOS-7-x86_64-Minimal-2009 免费下载与使用教程
  • 【C语言】第七期——字符数组、字符串、类型转换
  • 3D Gaussian Splatting(3DGS)的核心原理