当前位置: 首页 > article >正文

浅谈Beam Search

什么是 Beam Search?

Beam Search 是一种启发式搜索算法,常用于序列生成任务(如机器翻译、文本生成、语音识别等)。它在每一步生成时,保留当前最优的 ( k ) 个候选序列(( k ) 为 beam width),而不是像贪心搜索那样只保留一个最优解。通过这种方式,它能在一定程度上避免局部最优,同时减少计算量。


Beam Search 的关键步骤

  1. 初始化:从起始符号开始,生成所有可能的候选。
  2. 扩展:对每个候选序列,生成下一步的所有可能扩展。
  3. 评分:使用模型(如语言模型)为每个扩展序列打分。
  4. 剪枝:保留得分最高的 ( k ) 个序列,其余剪枝。
  5. 重复:重复扩展、评分和剪枝,直到生成结束符号或达到最大长度。
  6. 输出:最终选择得分最高的序列作为输出。

使用 PyTorch 实现 Beam Search

以下是一个简单的 Beam Search 实现,用于生成序列。假设我们有一个语言模型,可以预测下一个词的概率分布。

import torch
import torch.nn.functional as F

# 假设的语言模型(简单示例)
class SimpleLanguageModel(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(SimpleLanguageModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.rnn = torch.nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.fc = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden):
        x = self.embedding(x)
        output, hidden = self.rnn(x, hidden)
        logits = self.fc(output)
        return logits, hidden

# Beam Search 实现
def beam_search(model, start_token, beam_width, max_len, vocab_size, device):
    # 初始化
    sequences = [[start_token]]  # 初始序列
    scores = [0.0]  # 初始得分

    for _ in range(max_len):
        all_candidates = []
        for i in range(len(sequences)):
            seq = sequences[i]
            score = scores[i]

            # 将序列转换为模型输入
            input_seq = torch.tensor([seq], dtype=torch.long).to(device)
            hidden = None  # 假设初始隐藏状态为 None

            # 获取模型输出
            with torch.no_grad():
                logits, hidden = model(input_seq, hidden)
                next_token_probs = F.log_softmax(logits[:, -1, :], dim=-1)

            # 取 top-k 个候选
            top_k_probs, top_k_tokens = torch.topk(next_token_probs, beam_width)
            for j in range(beam_width):
                candidate_seq = seq + [top_k_tokens[0][j].item()]
                candidate_score = score + top_k_probs[0][j].item()
                all_candidates.append((candidate_seq, candidate_score))

        # 按得分排序,保留 top-k 个候选
        ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)
        sequences = [seq for seq, score in ordered[:beam_width]]
        scores = [score for seq, score in ordered[:beam_width]]

    # 返回得分最高的序列
    return sequences[0]

# 参数设置
vocab_size = 10000  # 词汇表大小
hidden_size = 128   # 隐藏层大小
beam_width = 3      # Beam Width
max_len = 10        # 最大生成长度
start_token = 0     # 起始 token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 初始化模型
model = SimpleLanguageModel(vocab_size, hidden_size).to(device)

# 运行 Beam Search
generated_sequence = beam_search(model, start_token, beam_width, max_len, vocab_size, device)
print("Generated Sequence:", generated_sequence)

代码说明

  1. SimpleLanguageModel:一个简单的语言模型,包含嵌入层、GRU 和全连接层。
  2. beam_search:实现 Beam Search 算法,逐步生成序列。
  3. 参数
    • beam_width:控制每一步保留的候选序列数量。
    • max_len:生成序列的最大长度。
    • start_token:序列的起始 token。
  4. 输出:生成的序列。

示例输出

Generated Sequence: [0, 42, 15, 7, 23, 56, 12, 8, 34, 9]

总结

Beam Search 是一种高效的序列生成算法,通过保留多个候选序列,能够在保证生成质量的同时减少计算量。以上代码展示了如何使用 PyTorch 实现一个简单的 Beam Search。


http://www.kler.cn/a/464177.html

相关文章:

  • 友元和运算符重载
  • Django 中数据库迁移命令
  • 个人交友系统|Java|SSM|JSP|
  • 景区自助售卡机与定点酒店的合作双赢之策-景区酒店方案
  • Java编程规约:集合处理
  • Flutter-插件 scroll-to-index 实现 listView 滚动到指定索引位置
  • “混合双打”二维数组展平的有效方案(Python)
  • 【SqlSugar雪花ID常见问题】.NET开源ORM框架 SqlSugar 系列
  • requests请求带cookie
  • 深入理解Java Map集合
  • 逻辑回归(Logistic Regression)深度解析
  • 在Swagger(现称为OpenAPI)中各类@api之间的区别
  • k8s系列--docker拉取镜像导入k8s的containerd中
  • HTML——56.表单发送
  • 从零开始学桶排序:Java 示例与优化建议
  • 2025.01.02 一月 | 充分地接受生活本身
  • python中常用的内置函数介绍
  • Java开发工具-Jar命令
  • 面试经典问题 —— 链表之返回倒数第k个节点(经典的双指针问题)
  • RK3568适配美格(MEIG-SLM3XX)4G模块
  • JavaWeb开发(五)Servlet-ServletContext
  • 大数据-266 实时数仓 - Canal 对接 Kafka 客户端测试
  • 数字图像总复习
  • ubuntu切换到root用户
  • 【C++动态规划】2088. 统计农场中肥沃金字塔的数目|2104
  • C++11右值与列表初始化