当前位置: 首页 > article >正文

灾难性遗忘(catastrophic forgetting)学习笔记

  1. 深度学习在结构与参数两方面都植入了灾难性遗忘的基因:
    1. 深度学习的结构一旦确定,在训练过程中很难调整。神经网络的结构直接决定学习模型的容量。固定结构的神经网络意味着模型的容量也是有限的,在容量有限的情况下,神经网络为了学习一个新的任务,就必须擦除旧有的知识。
    2. 深度学习的隐含层的神经元是全局的,单个神经元的细小变化能够同时影响整个网络的输出结果。另外,所有前馈网络的参数与输入的每个维度都相连,新数据很大可能改变网络中所有的参数。我们知道,对于本身结构就已经固定的神经网络,参数是关于知识的唯一变化量。如果变化的参数中包含与历史知识相关性很大的参数,那么最终的效果就是,新知识覆盖了旧的知识。
  2. 解决大致有四种方法:
    1. 利用新数据训练的同时,不断用包含历史数据相关的信息刺激神经元,形成一种竞争,从而使历史知识相关的重要神经元的参数尽可能少的受影响,同时也保证了新知识能够被学习;通常称为Self-refreshing Memory Approaches
    2. 在开始训练新数据前,利用旧网络对新数据进行预测得到虚拟的训练数据【可以看作是旧网络的一个回忆】,目标函数中包含新旧网络的参数约束,每训练一个新数据,利用所有的虚拟数据约束旧参数,抑制遗忘;这类方法被称为知识蒸馏法(wsy:感觉这个可行性更强一些)
    3. 从另一个角度来约束参数的变化,文中认为参数是一个概率分布,只要在这个分布的核心地带,对于该任务就是可行的,不同的任务对应不同的概率分布,如果能找到两个分布重叠的部分,并将参数约束到这个区域,那么这一参数不就可以对这些任务都有效吗,这类方法被称之为Transfer Techniques法
    4. 第四类,其它方法,例如保留所有的历史数据,研究评判重要数据的技术,只保留那些重要的,信息量大的数据。这只是保留所有历史数据的一个改进版本,只要评判方法合理,肯定也能缓解灾难遗忘问题,
  3. 蒸馏神经网络:
    1. 一般简单网络需要面对更加具体的任务,是要被应用的。针对它要应用的任务,我们会有一些数据(数据量远比训练复杂网络时的数据量小),首先利用复杂网络预测这些数据的输出。现在我们有两套训练简单网络的数据:a)状态-真实输出(one-hot label);b)状态-复杂网络预测的概率输出(连续值)。先利用数据集b训练简单网络至稳定,然后利用数据集a继续训练。与标准步骤多出的部分就是先利用复杂网络预测的数据输出来引导简单网络,由于预测的输出值是概率值,连续的,因此用这类数据使网络更易收敛。此时,简单网络已经快速的学到一些粗糙的知识,在此基础上利用真实数据集继续训练就快得多了。
    2. 蒸馏神经网络的核心依据有两点:
      1.  训练完成的神经网络包含历史数据的输出分布信息;
      2. 神经网络具有相似的输入会得到相似的输出的特点。
  4. 要保证旧任务不遗忘,就需要旧任务相关的数据能够在新数据训练时不断刺激强化神经元,抑制遗忘发生。我们的高追求不允许采用记录历史数据的低级方式。旧任务的参数包含历史数据的分布信息,我们可以利用旧任务的参数生成一些虚拟的数据,这些数据相当于分布的采样。在训练新数据时,充分考虑这些虚拟数据。
  5. 在这个新的网络中,我们希望对于原来的任务其输出能和原来的网络的输出接近。采用上面的“回忆刺激神经元抑制遗忘”的想法,我们首先得产生这样的虚拟数据。还记得上面介绍的蒸馏神经网络吗?它直接将新数据输入到训练好的复杂网络中得到输出,并将输入—复杂网络预测的输出对组成新的数据集。我们也采用这样的做法产生需要的虚拟数据,只需将复杂网络替换成旧网络。与原始蒸馏神经网络的目的不一样,那里是加速简单网络的训练与稳定性,此处是用来缓解网络对新任务学习的灾难性遗忘问题。

http://www.kler.cn/a/1456.html

相关文章:

  • SpringBoot环境和Maven配置
  • 【Ubuntu】 Ubuntu22.04搭建NFS服务
  • 基于springboot的网上商城购物系统
  • 【网络协议】静态路由详解
  • 结构化日志和集中日志服务
  • 根据docker file 编译镜像
  • Linux中的标准IO【上】
  • FPGA纯verilog实现RIFFA的PCIE通信,提供工程源码和软件驱动
  • C++ 手撸简易服务器(完善版本)
  • 写CSDN博客两年半的收获--总结篇
  • Python3实现AI版贪吃蛇
  • AI_Papers周刊:第六期
  • Java面向对象:接口的学习
  • Vue学习 -- 如何用Axios发送请求(get post)Promise对象 跨域请求问题
  • 使用QT C++编写一个带有菜单和工具条的文本编辑器
  • QT串口助手开发3串口开发
  • C语言实例:字符转换为 ASCII 码,如何计算两个数的商,如何比较两个数的大小,如何交换两个数的值
  • VR全景城市,用720全景树立城市形象,打造3D可视化智慧城市
  • java-day01
  • 《Linux的权限》
  • 考研408每周一题(2019 41)
  • 嵌入式学习笔记——STM32的时钟树
  • 基于 Apache Flink 的实时计算数据流业务引擎在京东零售的实践和落地
  • 软件测试面试找工作你必须知道的面试技巧(帮助超过100人成功通过面试)
  • 【React】React入门--生命周期
  • 网络作业2【计算机网络】