当前位置: 首页 > article >正文

基于Python的自然语言处理系列(12):使用TorchText和LSTM进行序列到序列(seq2seq)翻译

        在前面的文章中,我们探索了如何使用卷积神经网络(CNN)进行文本分类。这次,我们将深入探讨一种经典的深度学习模型——序列到序列模型(seq2seq),并结合长短期记忆网络(LSTM)来处理序列生成任务。本篇将展示如何使用TorchText加载数据,并构建基于LSTM的seq2seq模型来进行德语到英语的翻译任务。

1. 序列到序列模型简介

        seq2seq模型最常用于解决将一个序列映射到另一个序列的问题,如机器翻译和文本摘要。它的核心思想是通过编码器-解码器的结构,将输入序列压缩为一个上下文向量,并依赖这个向量生成输出序列。这个模型的结构如下:

        在模型中,输入序列首先被嵌入为词向量,然后经过编码器(通常是RNN或者LSTM)来生成上下文向量,接着解码器根据上下文向量逐步生成目标序列。在训练中,解码器的输入可能会使用实际的目标序列(称为Teacher Forcing),也可能使用模型上一步的预测。

2. 数据加载与预处理

        首先,我们将使用TorchText加载德语到英语的翻译数据集,并使用spacy进行文本标记化和词汇构建。

from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer

SRC_LANGUAGE = 'de'
TRG_LANGUAGE = 'en'

# 加载Multi30k数据集
train = Multi30k(split=('train'), language_pair=(SRC_LANGUAGE, TRG_LANGUAGE))

# 使用spacy进行标记化
token_transform = {}
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TRG_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

# 词汇构建
from torchtext.vocab import build_vocab_from_iterator

def yield_tokens(data, language):
    for data_sample in data:
        yield token_transform[language](data_sample[language])

vocab_transform = {}
vocab_transform[SRC_LANGUAGE] = build_vocab_from_iterator(yield_tokens(train, SRC_LANGUAGE), min_freq=2)
vocab_transform[TRG_LANGUAGE] = build_vocab_from_iterator(yield_tokens(train, TRG_LANGUAGE), min_freq=2)

        我们定义了词汇表并准备了数据加载器,接下来我们将进行序列标记的数值化,并创建输入输出的张量。

3. 模型设计:基于LSTM的Encoder和Decoder

Encoder

        编码器负责将输入序列转换为上下文向量。我们将使用多层LSTM来实现,它能够捕捉输入序列中的长短期依赖关系。

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        embedded = self.dropout(self.embedding(src))
        outputs, (hidden, cell) = self.lstm(embedded)
        return hidden, cell

Decoder

        解码器使用编码器生成的上下文向量,并逐步生成目标序列。解码器的输出是每一步生成的单词。

class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hid_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        embedded = self.dropout(self.embedding(input))
        output, (hidden, cell) = self.lstm(embedded, (hidden, cell))
        prediction = self.fc_out(output.squeeze(0))
        return prediction, hidden, cell

Seq2Seq模型

        最后,我们将编码器和解码器整合在一个Seq2Seq模型中,控制输入和输出的流动。

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        trg_len = trg.shape[0]
        batch_size = trg.shape[1]
        trg_vocab_size = self.decoder.output_dim
        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size).to(self.device)
        hidden, cell = self.encoder(src)
        input = trg[0, :]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            input = trg[t] if teacher_force else output.argmax(1)
        return outputs

4. 模型训练与评估

        我们使用交叉熵损失函数,并定义训练和评估的函数。

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, (src, trg) in enumerate(iterator):
        src = src.to(device)
        trg = trg.to(device)
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, (src, trg) in enumerate(iterator):
            src = src.to(device)
            trg = trg.to(device)
            output = model(src, trg, 0)  # 关闭Teacher Forcing
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)

5. 测试与预测

        训练完成后,我们可以使用该模型进行翻译任务。以下是如何测试模型的简单示例。

model.eval()
with torch.no_grad():
    output = model(src, trg, 0)  # 关闭Teacher Forcing
    predicted_tokens = output.argmax(2)
    for token in predicted_tokens:
        print(vocab_transform[TRG_LANGUAGE].get_itos()[token.item()])

结语

        在本文中,我们详细介绍了如何使用TorchText和LSTM构建序列到序列翻译模型。通过这种结构,我们能够将一个序列映射为另一个序列,这在机器翻译、文本生成等任务中非常常见。

        不过,seq2seq模型在处理长序列时可能会遇到一些问题,尤其是在生成的序列与输入序列相差较远时。为了解决这个问题,下一篇文章中我们将引入另一种循环神经网络结构——门控循环单元(GRU),并探讨如何重复使用上下文向量和更有效的Teacher Forcing策略,从而提升模型在长序列生成任务中的表现。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.kler.cn/a/311583.html

相关文章:

  • 实例方法,类方法和静态方法的区别,举例
  • 电子取证小白教程
  • 【C++】异常处理机制(对运行时错误的处理)
  • 蓝桥杯 懒洋洋字符串--字符串读入
  • 劫持微信聊天记录并分析还原 —— 访问数据库并查看聊天记录(五)
  • 4-1-2.C# 数据容器 - List 扩展(List 注意事项、List 存储对象的特性、List 与数组的转换)
  • LVGL 控件之仪表盘(lv_meter)
  • Learn ComputeShader 15 Grass
  • 【JVM】垃圾回收
  • 派遣函数 - 缓冲区设备模拟文件读写
  • 服务器数据恢复—raid5阵列热备盘上线失败导致阵列崩溃的数据恢复案例
  • redis为什么不使用一致性hash
  • 向日葵好用吗?4款稳定的远程控制软件推荐。
  • 关于交叉编译移植到Debian开发板的一些随笔
  • gbase8s存储过程一些隐藏的错误写法
  • docker镜像源
  • info 命令:查看命令手册
  • 案例分析-Stream List 中取出值最大的前 5 个和最小的 5 个值
  • 动态内存
  • 7.Java高级编程 多线程
  • flutter Dio发送post请求
  • Linux: debug:内核log有乱码^@^@
  • Redis——分布式锁
  • JVM 虚拟机的编译器、类加载过程、类加载器有哪些?
  • Paragon NTFS for Mac和Tuxera NTFS for Mac,那么两种工具有什么区别呢?
  • python中的排序函数sorted