浅谈Beam Search
什么是 Beam Search?
Beam Search 是一种启发式搜索算法,常用于序列生成任务(如机器翻译、文本生成、语音识别等)。它在每一步生成时,保留当前最优的 ( k ) 个候选序列(( k ) 为 beam width),而不是像贪心搜索那样只保留一个最优解。通过这种方式,它能在一定程度上避免局部最优,同时减少计算量。
Beam Search 的关键步骤
- 初始化:从起始符号开始,生成所有可能的候选。
- 扩展:对每个候选序列,生成下一步的所有可能扩展。
- 评分:使用模型(如语言模型)为每个扩展序列打分。
- 剪枝:保留得分最高的 ( k ) 个序列,其余剪枝。
- 重复:重复扩展、评分和剪枝,直到生成结束符号或达到最大长度。
- 输出:最终选择得分最高的序列作为输出。
使用 PyTorch 实现 Beam Search
以下是一个简单的 Beam Search 实现,用于生成序列。假设我们有一个语言模型,可以预测下一个词的概率分布。
import torch
import torch.nn.functional as F
# 假设的语言模型(简单示例)
class SimpleLanguageModel(torch.nn.Module):
def __init__(self, vocab_size, hidden_size):
super(SimpleLanguageModel, self).__init__()
self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
self.rnn = torch.nn.GRU(hidden_size, hidden_size, batch_first=True)
self.fc = torch.nn.Linear(hidden_size, vocab_size)
def forward(self, x, hidden):
x = self.embedding(x)
output, hidden = self.rnn(x, hidden)
logits = self.fc(output)
return logits, hidden
# Beam Search 实现
def beam_search(model, start_token, beam_width, max_len, vocab_size, device):
# 初始化
sequences = [[start_token]] # 初始序列
scores = [0.0] # 初始得分
for _ in range(max_len):
all_candidates = []
for i in range(len(sequences)):
seq = sequences[i]
score = scores[i]
# 将序列转换为模型输入
input_seq = torch.tensor([seq], dtype=torch.long).to(device)
hidden = None # 假设初始隐藏状态为 None
# 获取模型输出
with torch.no_grad():
logits, hidden = model(input_seq, hidden)
next_token_probs = F.log_softmax(logits[:, -1, :], dim=-1)
# 取 top-k 个候选
top_k_probs, top_k_tokens = torch.topk(next_token_probs, beam_width)
for j in range(beam_width):
candidate_seq = seq + [top_k_tokens[0][j].item()]
candidate_score = score + top_k_probs[0][j].item()
all_candidates.append((candidate_seq, candidate_score))
# 按得分排序,保留 top-k 个候选
ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
sequences = [seq for seq, score in ordered[:beam_width]]
scores = [score for seq, score in ordered[:beam_width]]
# 返回得分最高的序列
return sequences[0]
# 参数设置
vocab_size = 10000 # 词汇表大小
hidden_size = 128 # 隐藏层大小
beam_width = 3 # Beam Width
max_len = 10 # 最大生成长度
start_token = 0 # 起始 token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型
model = SimpleLanguageModel(vocab_size, hidden_size).to(device)
# 运行 Beam Search
generated_sequence = beam_search(model, start_token, beam_width, max_len, vocab_size, device)
print("Generated Sequence:", generated_sequence)
代码说明
- SimpleLanguageModel:一个简单的语言模型,包含嵌入层、GRU 和全连接层。
- beam_search:实现 Beam Search 算法,逐步生成序列。
- 参数:
beam_width
:控制每一步保留的候选序列数量。max_len
:生成序列的最大长度。start_token
:序列的起始 token。
- 输出:生成的序列。
示例输出
Generated Sequence: [0, 42, 15, 7, 23, 56, 12, 8, 34, 9]
总结
Beam Search 是一种高效的序列生成算法,通过保留多个候选序列,能够在保证生成质量的同时减少计算量。以上代码展示了如何使用 PyTorch 实现一个简单的 Beam Search。