当前位置: 首页 > article >正文

beamsearch的计算过程和代码实现

Beam search(束搜索)是一种用于生成序列的搜索算法,常用于序列生成任务,例如机器翻译、语音识别和文本生成。它是一种启发式算法,旨在在生成序列时平衡搜索空间的广度和深度。

Beam search使用一个参数称为"beam width"(束宽度)来控制搜索的宽度,即在每个时间步骤选择保留的最有希望的候选项数量。在每个时间步骤,Beam search保留最有希望的K个候选项,其中K是束宽度。

下面是Beam search算法的详细步骤:

  1. 初始化:将初始输入作为序列的起始点,并将其放入候选项列表中。

  2. 生成候选项:对于每个候选项,使用模型(例如神经网络)生成下一个可能的元素或单词。

  3. 扩展候选项:对于每个候选项,将生成的元素添加到当前序列中,并计算相应的分数或概率。这些分数用于评估候选项的好坏。

  4. 剪枝:根据分数或概率对候选项进行排序,并选择当前分数最高的K个候选项,将其保留为下一步的候选项。

  5. 终止条件:如果生成的序列达到了预定的长度,或者满足特定的终止条件(例如遇到了终止标记),则停止搜索。

  6. 重复步骤2至5,直到到达终止条件。

  7. 返回结果:从最终的K个候选项中选择得分最高的序列作为最终的输出。

Beam search的优点是可以在生成序列时保持一定的多样性,因为它保留了多个候选项,并在每个时间步骤维护了一个较小的搜索空间。这有助于避免过于确定性的结果,并提供更多选择的可能性。

然而,Beam search也存在一些限制。它可能会陷入局部最优解,因为它只考虑了当前时间步骤的最有希望的候选项,并没有全局优化。此外,束宽度的选择也会影响结果,较小的束宽度可能会导致搜索空间不足,而较大的束宽度会增加计算成本。

算法关键点:在解码过程中,每次都挑选当前解码字的前k个最大概率的字符,第一轮可以得到k个结果,第二轮可以k2个结果,然后在这k2个结果中选择前k个最大概率的结果。依次类推...

具体步骤:

1.初始化Result列表用来存储每次得到的最大k个概率结果,初始化为[[list(),1]] 1为当前初始化的成绩

2.遍历解码长度S(解码出来S个字),

3.编历Result,用来为每个当前为止的最大k个结果解码出候选集

4.每个解码出的k个结果统一存储在Condidate列表中

5.按照成绩选取前k个作为Result,继续遍历,直到解码出S长度或者<eos>

from math import log
from numpy import array
from numpy import argmax

# 集束搜索
def beam_search_decoder(data, k):
	sequences = [[list(), 1.0]]#初始化存储最后结果的列表,存储k个
	# 遍历序列中的每一步
	for row in data:#序列的最大长度
		all_candidates = list()
		# 扩展每个候选项,即解码当前所得序列的下一个字
		for i in range(len(sequences)):
			seq, score = sequences[i]
			for j in range(len(row)):#计算每个词表中的字的成绩
				candidate = [seq + [j], score * -log(row[j])]
				all_candidates.append(candidate)
		# 根据分数排列所有候选项
		ordered = sorted(all_candidates, key=lambda tup:tup[1])
		# 选择k个最有可能的
		sequences = ordered[:k]
	return sequences

# 定义一个由10个单词组成的序列,单词来自于大小为5的词汇表
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 解码输出序列
result = beam_search_decoder(data, 3)
# 打印结果

for seq in result:
	print(seq)

http://www.kler.cn/a/229739.html

相关文章:

  • zookeeper监听机制(Watcher机制)
  • 基于 JavaEE 的影视创作论坛
  • 【算法与数据结构】—— 回文问题
  • spring cloud stream
  • Docker 可视化工具
  • 计算机网络-差错控制(纠错编码 海明码 纠错方法)
  • javascript实现深度拷贝
  • 体悟PyTorch的优雅
  • Java集合框架在数据处理中的应用场景
  • 6-2、T型加减速计算简化【51单片机+L298N步进电机系列教程】
  • SQL基础
  • Positive Technologies 帮助修复了流行的 Yealink 视频会议系统中的一个危险漏洞
  • 深度解析 Spring Security:身份验证、授权、OAuth2 和 JWT 身份验证的完整指南
  • 在idea中使用maven编译包,直接打包到远程环境上去了
  • 掌握Web服务器之王:Nginx 学习网站全攻略!
  • Unity3d Shader篇(三)— 片元半兰伯特着色器解析
  • Jupyter Notebook中的%matplotlib inline详解
  • python_蓝桥杯刷题记录_笔记_全AC代码_入门5
  • WebSocket相关问题
  • Linux进程信号处理:深入理解与应用(3)
  • 了解这六种最佳移动自动化测试工具吗?
  • 页面单跳转换率统计案例分析
  • Spring boot集成各种数据源操作数据库