当前位置: 首页 > article >正文

深度学习基础—Beam search集束搜索

引言

深度学习基础—Seq2Seq模型icon-default.png?t=O83Ahttps://blog.csdn.net/sniper_fandc/article/details/143781223?fromshare=blogdetail&sharetype=blogdetail&sharerId=143781223&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link

        上篇博客讲到,贪心算法在seq2seq模型中计算压力大,并且不一定能找到最好输出,为此,集束搜索算法就可以解决上述问题。

1.集束搜索算法

        假设有句子“Jane visite l'Afrique en Septembre.”(法语),我们希望翻译成英语:“Jane is visiting Africa in September.”(英语),词汇表有10000个词。要做的第一件事是把法语句子输入到编码器(上图绿色部分),解码器会输出第一个词的概率。

        在集束搜索算法中,有一个参数B(集束宽),表示每次考虑的概率最大的TopB个数的结果。假设B=3,如果第一个输出中概率最大的Top3的词是jane、in、september,那么就把这三个结果保存(贪心算法只会保存概率最大的第一个)。

        第二步,分别把选择的Top3的第一个输出输入到解码器的第二个单元中,即寻找在第一个词分别是in、jane、september情况下,第二个词的最大概率:

        每个softmax单元会输出10000个词的概率,一共有3种情况(第一个词是in、第一个词是jane、第一个词是september),因此第二个输出一共有3*10000=30000种输出,我们选择其中概率最大的Top3(集束宽B是3),假设选择了“in September”、“jane is”和“jane visits”并保存下来。

        第三步,分别把选择的Top3的第二个输出输入到解码器的第三个单元中,即寻找在前面的输出是in september、jane is、jane visits情况下,第三个词的最大概率:

        第三个单元还会输出3*10000个结果,我们仍然只选择概率最大的前三个并保存。最终解码器每输出一次增加一个单词,集束搜索算法最终会找到“Jane visits africa in september”这个句子,在句尾符号(上图编号8所示)终止生成。

        注意:观察集束搜索算法,可以发现其搜索树相当于每次生成3*10000个结果(除了第一次生成外),然后剪枝,只保留3种概率最大的结果,即节省了计算,有能保证最好的结果没有被网络丢掉(一般来说最好的结果和概率最大的结果概率相差不会很远,合理的选择集束宽B的大小可以保证找到最优的结果)

        如果集束宽等于1,只考虑1种可能结果,这实际上就变成了贪婪搜索算法。如果同时考虑多个可能的结果,比如3个、10个或者其他的个数,集束搜索通常会找到比贪婪搜索更好的输出结果。

2.改进算法

        集束搜索算法的优化目标是:

        这个其实就是解码器每一个单元输出的条件概率乘积。这些概率值通常远小于1,很多小于1的数相乘就会得到很小很小的数字,会造成数值下溢(numerical underflow),导致电脑的不能精确地存储浮点数。

        解决办法,取对数,优化目标变为:

        log函数严格单调递增,且更加平滑,(0,1)之间的数会被放缩到远离0的位置,从而不会出现数值下溢,如果log()最大了,那么概率P也是最大的。

        我们还可以对目标函数做归一化(归一化对数似然目标函数):

        其中,α是超参数,作用是让归一化更加柔和,取值[0,1],如果是1,说明完全用句子长度做归一化;如果取0,说明不进行归一化。由于目标函数的特性,句子越短概率越高,句子越长概率越低(条件概率的一个性质:事件越多,同时发生的可能性就越低)。因此优化目标会使概率最大化,从而不易于翻译长句子,因此就需要归一化,减少对长句子的惩罚力度。

        注意:如何选择集束宽B的大小?B越大,算法可选择的越多,结果越好,但计算更慢;B越小,算法选择越少,结果没那么好,但计算更快。一般选择B=10,B越大,模型的改善越小。从1到3或10,一般算法有很大的改善。但是当集束宽从1000增加到3000时,效果就没那么明显。

3.误差分析

        当我们的训练好的seq2seq模型在dev集(开发集)出现问题时,比如对于法语句子:“Jane visite l'Afrique en septembre”,集束搜索算法输出的翻译:“Jane visited Africa last September”(y^),而正确的人工翻译是:“Jane visits Africa in September”(y*),也就是不正确的翻译输出概率更大,说明模型出现了问题。模型主要由RNN网络(编码器和解码器)、集束搜索算法组成,那是哪一部分出现问题了呢?换句话说是RNN网络部分值得优化还是集束搜索算法值得优化?通常我们会用到误差分析来确定我们的优化方向:

        我们分别把两种情况的序列输入到解码器部分,计算输出的概率,得到如下两种情况:

        (1)集束搜索算法得到的y^的概率<人工翻译句子y*的概率:

        说明集束搜索算法得到的使概率最大化的句子(y^)并不是使概率最大化的句子(因为y*的概率更大),此时集束搜索算法效果更差,更值得优化集束搜索算法。可以调整集束宽度B的大小等等。

        (2)集束搜索算法得到的y^的概率>人工翻译句子y*的概率:

        说明集束搜索算法得到的使概率最大化的句子(y^)是使概率最大化的句子,但是y*显然是更好的句子,RNN应该输出的是y*的概率更大,此时RNN网络输出更差,更值得优化RNN网络。可以选择更优的结构、正则化、扩充数据集等等。

        RNN网络的优化技巧和神经网络的优化技巧一致,有关优化网络的技巧和误差分析见如下:

深度学习基础—参数调优icon-default.png?t=O83Ahttps://blog.csdn.net/sniper_fandc/article/details/141144823?fromshare=blogdetail&sharetype=blogdetail&sharerId=141144823&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link深度学习基础—正则化icon-default.png?t=O83Ahttps://blog.csdn.net/sniper_fandc/article/details/141176121?fromshare=blogdetail&sharetype=blogdetail&sharerId=141176121&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link深度学习基础—结构化机器学习项目icon-default.png?t=O83Ahttps://blog.csdn.net/sniper_fandc/article/details/141554291?fromshare=blogdetail&sharetype=blogdetail&sharerId=141554291&sharerefer=PC&sharesource=sniper_fandc&sharefrom=from_link

        现在,我们按照上述句子的处理思路,遍历开发集,寻找人工翻译句子和开发集之间的误差分布,统计情况(1)和(2)的比例,哪种情况的比例更高,就说明哪部分更值得优化。


http://www.kler.cn/a/396527.html

相关文章:

  • springboot接口返回数据给前端,BigDecimal为null但返回前端显示-1
  • Spring Boot3 实战案例合集上线了
  • 【Conda】Windows下conda的安装并在终端运行
  • 编译原理(手绘)
  • 如何在Debian系统里使用Redhat(CentOS)的方式配置网络
  • 2024 CCF中国开源大会“开源科学计算与系统建模openSCS”分论坛成功举办
  • 【原创】java+ssm+mysql物流信息网系统设计与实现
  • 木舟0基础学习Java的第三十三天(OA企业管理系统)
  • SpringBootCloud 服务注册中心Nacos对服务进行管理
  • 比特币前景再度不明,剧烈波动性恐即将回归
  • C/C++语言基础--initializer_list表达式、tuple元组、pair对组简介
  • vue2将webpack改为vite
  • 《Kotlin实战》-附录
  • 大数据实验9:Spark安装和编程实践
  • Jackson与GSON的深度对比
  • mybatis-plus: mapper-locations: “classpath*:/mapper/**/*.xml“配置!!!解释
  • 初学人工智不理解的名词3
  • 释放高级功能:Nexusflows Athene-V2-Agent在工具使用和代理用例方面超越 GPT-4o
  • 从电动汽车到车载充电器:LM317LBDR2G 线性稳压器在汽车中的多场景应用
  • springboot实现简单的数据查询接口(无实体类)
  • Java项目实战II基于微信小程序的订餐系统(开发文档+数据库+源码)
  • 本机ip地址和网络ip地址一样吗
  • AI服务器SAS硬盘汰换与数据抹除指南
  • HarmonyOS ArkUI(基于ArkTS) 开发布局 (中)
  • 基于STM32智能电流表
  • Python酷库之旅-第三方库Pandas(218)