当前位置: 首页 > article >正文

NLP从零开始------17.文本中阶处理之序列到序列模型(2)

3. 学习

        序列到序列模型可以看成一种条件语言模型,以源句x为条件计算目标句的条件概率该条件概率通过概率乘法公式分解为从左到右每个词的条件概率之积:

        P(y_{1:T}|x)=P(y_{1}|x)P(y_{2}|y_{1},x)P(y_{3}|y_{1},y_{2},x) \cdots P(y_{ \tau }|y_{1}, \cdots ,y_{T-1},x)

        序列到序列模型的监督学习需要使用平行语料,其中每个数据点都包含一对源句和目标句。以中译英机器翻译为例,平行语料的每个数据点就是一句中文句子和对应的一句英文句子。机器翻译领域较为有名的平行语料库来自机器翻译研讨会( workshop on machine translation, WMT), 其中的语料来自新闻、维基百科、小说等各种领域。给定平行语料中的每个数据点, 我们希望最大化条件似然, 即最小化以下损失函数:
                                ​​​​​​​        ​​​​​​​        J=- \sum \limits _{t=1}^{T} \log P(y^ \ast |y_{1}, \cdots ,y_{t-1}^ \ast ,x)
        其中, y* 表示平行语料中源句x对应的目标句。

        训练序列到序列模型的常用方法为教师强制( teacher forcing), 即使用真实的目标序列作为解码器的输入,而不是像解码时那样使用解码器每一步的预测作为下一步的输入。教师强制会使训练过程更稳定且收敛更快,但是也会产生所谓曝光偏差( exposurebias) 的不利影响,即模型在训练时只见过正确输入,因而当解码时前置步骤出现了不正确的预测时模型后续的预测都会变得不准确。
        下面以机器翻译(中译英) 为例展示如何训练序列到序列模型。这里使用的是中英文 Books数据,其中中文标题来源于前面所使用的数据集, 英文标题是使用已训练好的机器翻译模型从中文标题翻译而得,因此该数据并不保证准确性,仅用于演示。
        首先需要对源语言和目标语言分别建立索引, 并记录词频。

SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "<sos>", 1: "<eos>"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
            
    def sent2ids(self, sent):
        return [self.word2index[word] for word in sent.split(' ')]
    
    def ids2sent(self, ids):
        return ' '.join([self.index2word[idx] for idx in ids])

import unicodedata
import string
import re
import random

# 文件使用unicode编码,我们将unicode转为ASCII,转为小写,并修改标点
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    # 在标点前插入空格
    s = re.sub(r"([,.!?])", r" \1", s)
    return s.strip()
# 读取文件,一共有两个文件,两个文件的同一行对应一对源语言和目标语言句子
def readLangs(lang1, lang2):
    # 读取文件,分句
    lines1 = open(f'{lang1}.txt', encoding='utf-8').read()\
        .strip().split('\n')
    lines2 = open(f'{lang2}.txt', encoding='utf-8').read()\
        .strip().split('\n')
    print(len(lines1), len(lines2))
    
    # 规范化
    lines1 = [normalizeString(s) for s in lines1]
    lines2 = [normalizeString(s) for s in lines2]
    if lang1 == 'zh':
        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]
    if lang2 == 'zh':
        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]
    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]

    input_lang = Lang(lang1)
    output_lang = Lang(lang2)
    return input_lang, output_lang, pairs
# 为了快速训练,过滤掉一些过长的句子
MAX_LENGTH = 30

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def prepareData(lang1, lang2):
    input_lang, output_lang, pairs = readLangs(lang1, lang2)
    print(f"读取 {len(pairs)} 对序列")
    pairs = filterPairs(pairs)
    print(f"过滤后剩余 {len(pairs)} 对序列")
    print("统计词数")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('zh', 'en')
print(random.choice(pairs))
2157 2157
读取 2157 对序列
过滤后剩余 2003 对序列
统计词数
zh 1368
en 3287
['鸿 衣 赋 古 风 创 意 造 型 与 摄 影 集', "♪ the ancient wind's creative form and photo collection ♪"]

        为了便于训练,对每一对源-目标句子需要准备一个源张量(源句子的词元索引)和一个目标张量(目标句子的词元索引)。在两个句子的末尾会添加“\<eos\>”。

def get_train_data():
    input_lang, output_lang, pairs = prepareData('zh', 'en')
    train_data = []
    for idx, (src_sent, tgt_sent) in enumerate(pairs):
        src_ids = input_lang.sent2ids(src_sent)
        tgt_ids = output_lang.sent2ids(tgt_sent)
        # 添加<eos>
        src_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        train_data.append([src_ids, tgt_ids])
    return input_lang, output_lang, train_data
        
input_lang, output_lang, train_data = get_train_data()
2157 2157
读取 2157 对序列
过滤后剩余 2003 对序列
统计词数
zh 1368
en 3287

        接下来是训练代码。

from tqdm import trange
import matplotlib.pyplot as plt
from torch.optim import Adam
import numpy as np

# 训练序列到序列模型
def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\
        learning_rate=1e-3):
    # 准备模型和优化器
    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    encoder.train()
    decoder.train()
    encoder.zero_grad()
    decoder.zero_grad()

    step_losses = []
    plot_losses = []
    with trange(n_epochs, desc='epoch', ncols=60) as pbar:
        for epoch in pbar:
            np.random.shuffle(train_data)
            for step, data in enumerate(train_data):
                # 将源序列和目标序列转为 1 * seq_len 的tensor
                # 这里为了简单实现,采用了批次大小为1,
                # 当批次大小大于1时,编码器需要进行填充
                # 并且返回最后一个非填充词的隐状态,
                # 解码也需要进行相应的处理
                input_ids, target_ids = data
                input_tensor, target_tensor = \
                    torch.tensor(input_ids).unsqueeze(0),\
                    torch.tensor(target_ids).unsqueeze(0)

                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()

                encoder_outputs, encoder_hidden = encoder(input_tensor)
                # 输入目标序列用于teacher forcing训练
                decoder_outputs, _, _ = decoder(encoder_outputs,\
                    encoder_hidden, target_tensor)

                loss = criterion(
                    decoder_outputs.view(-1, decoder_outputs.size(-1)),
                    target_tensor.view(-1)
                )
                pbar.set_description(f'epoch-{epoch}, '+\
                    f'loss={loss.item():.4f}')
                step_losses.append(loss.item())
                # 实际训练批次为1,训练损失波动过大
                # 将多步损失求平均可以得到更平滑的训练曲线,便于观察
                plot_losses.append(np.mean(step_losses[-32:]))
                loss.backward()

                encoder_optimizer.step()
                decoder_optimizer.step()

    plot_losses = np.array(plot_losses)
    plt.plot(range(len(plot_losses)), plot_losses)
    plt.xlabel('training step')
    plt.ylabel('loss')
    plt.show()

    
hidden_size = 128
n_epochs = 20
learning_rate = 1e-3

encoder = RNNEncoder(input_lang.n_words, hidden_size)
decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)

train_seq2seq_mt(train_data, encoder, decoder, n_epochs, learning_rate)
epoch-19, loss=0.0047: 100%|█| 20/20 [46:11<00:00, 138.55s/i

        上面实现的基于循环神经网络和基于 Transformer的编码器和解码器具有相似的接口,大家可以尝试更换编码器和解码器,训练基于 Transformer的序列到序列模型, 此处不再重复展示。

4. 解码

        这里介绍主流的贪心解码和束搜索解码方法。

4.1 贪心解码

        在解码过程中,需要根据解码器所计算的词的概率分布一步步(自回归)地生成词。理想情况下我们希望找到概率最大的目标句子argmax _{1:T}P(y_{1:T}|x),但这无疑是很困难的, 因为所有可能的序列数量呈指数级增长, 并且由于词之间不存在条件独立性,因此不存在可求解的多项式复杂度算法。一个很简单的近似解决方式是贪心解码,即每步取概率最大的词。 然而,这种方式存在所谓错误累积问题,如下面这个例子所示。
        输入: I love Natural Language Processing。
        解码第1步: 我_。
        解码第2步: 我爱_。
        解码第3步: 我爱天然_。
        解码第4步: 我爱天然语_。
        解码第5步: 我爱天然语加工。

        在解码第3步,模型错误地将“ Natural”翻译成了“天然”,而在贪心解码中一旦模型输出一个词,就再也无法回滚和修改。更糟糕的是,“天然”这个词会被作为后续解码的条件,从而有可能让模型产生新的错误,然后这些错误又会引发更多错误,最终导致输出低质量的目标序列。

        下面的代码演示如何使用贪心解码对模型进行验证。评估与训练类似,但是评估时不提供目标句子作为输入, 因此需要将解码器每一步的输出作为下一步的输入, 当预测到“< eos>”时停止。我们也可以存储解码器的注意力输出以用于分析和展示。

def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        # 将源序列转为 1 * seq_length 的tensor
        input_ids = input_lang.sent2ids(sentence)
        input_tensor = torch.tensor(input_ids).unsqueeze(0)
        
        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = \
            decoder(encoder_outputs, encoder_hidden)
        
        # 取出每一步预测概率最大的词
        _, topi = decoder_outputs.topk(1)
        
        decoded_ids = []
        for idx in topi.squeeze():
            if idx.item() == EOS_token:
                break
            decoded_ids.append(idx.item())
    return output_lang.ids2sent(decoded_ids), decoder_attn
            
encoder.eval()
decoder.eval()
for i in range(5):
    pair = random.choice(pairs)
    print('input:', pair[0])
    print('target:', pair[1])
    output_sentence, _ = greedy_decode(encoder, decoder, pair[0],
        input_lang, output_lang)
    print('pred:', output_sentence)
    print('')
input: 商 业 分 析 方 法 与 案 例 超 越 报 表 的 商 业 智 能 ( 第 2 版 )
target: business analysis methods and cases business intelligence beyond reporting (version 2)
pred: business analysis , methods and cases business intelligence beyond reporting (version 2)

input: 精 解 w i n d o w s 1 0
target: precision windows10
pred: precision windows10 precision windows10 2016 precision 2016 solidworks 2018 chinese precision windows10 precision windows10 2016 solidworks 2018 chinese precision windows10 precision windows10 2016 solidworks 2018 chinese precision windows10 precision windows10

input: k a f k a 入 门 与 实 践
target: kafka introduction and practice
pred: kafka introduction and practice

input: 长 期 价 值 投 资 如 何 稳 健 地 积 累 财 富 ( 签 名 版 )
target: long-term value investment how to build wealth safely (signed version)
pred: long-term value investment how to build wealth safely (signed version)

input: j r o c k i t 权 威 指 南 深 入 理 解 j v m
target: jrockit's authoritative guide to an in-depth understanding of jvm .
pred: jrockit's authoritative guide to an in-depth understanding of jvm .

        在这个演示中,训练数据太少, 模型也很小, 所以贪心解码的效果不太好。

4.2 束搜索解码

        束搜索解码可以缓解贪心解码的问题。在束搜索解码中,每一步都会保留k个优选的候选结果,其中k被称为束宽。具体而言,在每一步,我们会将上一步保留的k个候选结果中的每一个作为条件,生成k个当前步骤概率最大的词, 从而得到k²个新的候选结果,再从中优选k个予以保留。候选结果之间的比较是基于当前已解码序列的概率对数:
        ​​​​​​​        ​​​​​​score(y_{1}, \cdots ,y_{t})= \log P(y_{1}, \cdots ,y_{t}|x)= \sum \limits _{i=1}^{t} \log P(y_{i}|y_{1}, \cdots ,y_{i-1},x)

        束搜索解码如何判断终止条件呢? 一旦贪心解码解码出终止符“< eos>”就终止解码。然而在束搜索解码中,不同的解码序列可能会在不同的时刻输出终止符“<cos>”,因此, 当一个解码序列预测了终止符“< eos>”时并不会终止整个解码过程,而只是终止这一个解码序列并继续其他解码序列, 直到满足以下两个条件之一:
        解码达到了时间步上线 T;
        已经有n个解码序列终止。
        这里T和n均为预定义的超参数。解码终止后,我们得到了最多n个已终止的解码序列,如何从中选择最终输出的目标序列呢?一个很直接的想法是选择概率最高的解码序列。但需要注意的是,由于越长的句子需要对更多的词的概率求积,因此概率往往越低,这导致单纯依据序列概率会趋向于选择更短的句子,但很多情况下短句并不一定是最好的选择。为了缓解这个问题,可以使用词的平均对数概率来选择最终输出的目标序列:
         ​​​​​​​        ​​​​​​​        ​​​​​​​        score(y_{1}, \cdots ,y_{t})= \frac {1}{t} \sum \limits _{i=1}^{t} \log P(y_{i}|y_{1}, \cdots ,y_{i-1},x)

        虽然束搜索解码仍然无法保证最终预测是最优解,甚至无法保证一定优于贪心解码, 但是它的效果好于贪心解码,因为它考虑了更多可能的目标序列。贪心解码其实可以看作k=1的束搜索解码。
        接下来使用束搜索解码来验证模型。

# 定义容器类用于管理所有的候选结果
class BeamHypotheses:
    def __init__(self, num_beams, max_length):
        self.max_length = max_length
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        return len(self.beams)
    
    # 添加一个候选结果,更新最差得分
    def add(self, sum_logprobs, hyp, hidden):
        score = sum_logprobs / max(len(hyp), 1)
        if len(self) < self.num_beams or score > self.worst_score:
            # 可更新的情况:数量未饱和或超过最差得分
            self.beams.append((score, hyp, hidden))
            if len(self) > self.num_beams:
                # 数量饱和需要删掉一个最差的
                sorted_scores = sorted([(s, idx) for idx,\
                    (s, _, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)
    
    # 取出一个未停止的候选结果,第一个返回值表示是否成功取出,
    # 如成功,则第二个值为目标候选结果
    def pop(self):
        if len(self) == 0:
            return False, None
        for i, (s, hyp, hid) in enumerate(self.beams):
            # 未停止的候选结果需满足:长度小于最大解码长度;不以<eos>结束
            if len(hyp) < self.max_length and (len(hyp) == 0\
                    or hyp[-1] != EOS_token):
                del self.beams[i]
                if len(self) > 0:
                    sorted_scores = sorted([(s, idx) for idx,\
                        (s, _, _) in enumerate(self.beams)])
                    self.worst_score = sorted_scores[0][0]
                else:
                    self.worst_score = 1e9
                return True, (s, hyp, hid)
        return False, None
    
    # 取出分数最高的候选结果,第一个返回值表示是否成功取出,
    # 如成功,则第二个值为目标候选结果
    def pop_best(self):
        if len(self) == 0:
            return False, None
        sorted_scores = sorted([(s, idx) for idx, (s, _, _)\
            in enumerate(self.beams)])
        return True, self.beams[sorted_scores[-1][1]]


def beam_search_decode(encoder, decoder, sentence, input_lang,
        output_lang, num_beams=3):
    with torch.no_grad():
        # 将源序列转为 1 * seq_length 的tensor
        input_ids = input_lang.sent2ids(sentence)
        input_tensor = torch.tensor(input_ids).unsqueeze(0)

        # 在容器中插入一个空的候选结果
        encoder_outputs, encoder_hidden = encoder(input_tensor)
        init_hyp = []
        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)
        hypotheses.add(0, init_hyp, encoder_hidden)

        while True:
            # 每次取出一个未停止的候选结果
            flag, item = hypotheses.pop()
            if not flag:
                break
                
            score, hyp, decoder_hidden = item
            
            # 当前解码器输入
            if len(hyp) > 0:
                decoder_input = torch.empty(1, 1,\
                    dtype=torch.long).fill_(hyp[-1])
            else:
                decoder_input = torch.empty(1, 1,\
                    dtype=torch.long).fill_(SOS_token)

            # 解码一步
            decoder_output, decoder_hidden, _ = decoder.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )

            # 从输出分布中取出前k个结果
            topk_values, topk_ids = decoder_output.topk(num_beams)
            # 生成并添加新的候选结果到容器
            for logp, token_id in zip(topk_values.squeeze(),\
                    topk_ids.squeeze()):
                sum_logprobs = score * len(hyp) + logp.item()
                new_hyp = hyp + [token_id.item()]
                hypotheses.add(sum_logprobs, new_hyp, decoder_hidden)

        flag, item = hypotheses.pop_best()
        if flag:
            hyp = item[1]
            if hyp[-1] == EOS_token:
                del hyp[-1]
            return output_lang.ids2sent(hyp)
        else:
            return ''

encoder.eval()
decoder.eval()
for i in range(5):
    pair = random.choice(pairs)
    print('input:', pair[0])
    print('target:', pair[1])
    output_sentence = beam_search_decode(encoder, decoder,\
        pair[0], input_lang, output_lang)
    print('pred:', output_sentence)
    print('')
input: h 5 和 w e b g l 3 d 开 发 实 战 详 解
target: elaboration of the h5 and webgl 3d development
pred: elaboration of the h5 and webgl 3d development

input: 跨 境 电 子 商 务 英 语 ( 音 频 指 导 版 )
target: cross-border e-commerce english (audio guide version)
pred: cross-border e-commerce english (audio guide version)

input: 采 购 与 供 应 商 管 理 常 用 制 度 与 表 格 范 例
target: examples of common systems and forms for procurement and vendor management
pred: examples of common systems and forms for procurement and vendor management

input: 大 学 生 职 业 生 涯 规 划 与 就 业 创 业 指 导 ( 微 课 版 )
target: career planning and entrepreneurship guidance for university students (micro-curricular version)
pred: career planning and entrepreneurship guidance for university students (micro-curricular version)

input: p y t h o n 数 据 挖 掘 入 门 与 实 践 第 2 版
target: python data digging introduction and practice 2nd edition
pred: python data digging introduction and practice 2nd edition

        可以看到束搜素解码在这个演示中的效果想比贪心解码有所改善。


http://www.kler.cn/a/290726.html

相关文章:

  • Element-ui table组件:单元格未溢出,悬浮出现popover提示框
  • QT-------认识QT
  • ADB 上传文件并使用脚本监控上传百分比
  • 【Python运维】自动化备份与恢复系统的实现:Python脚本实战
  • 『大模型笔记』评估大型语言模型的指标:ELO评分,BLEU,困惑度和交叉熵介绍以及举例解释
  • Nmap基础入门及常用命令汇总
  • Draw.io for Mac/Win:免费且强大的流程图绘制工具
  • 数据库和MySQL
  • 网络协议--HTTP 和 HTTPS 的区别
  • 设计模式 —— 单例模式
  • 惠中科技PV-Wiper全自动光伏组件清洁系统:智能清洁赋能光伏产业
  • 日系编曲:日系钢琴写作思路 双手思维 双手编写思路 双手合并 琶音 刮奏 颤音 震音
  • 点云帧间位姿矩阵的预测和误差计算
  • [Meachines] [Medium] Bitlab 标签自动填充登录+GitLab+Docker横向+Postgresql+逆向工程
  • Spring AOP(下)原理
  • JMeter 接口自动化测试:以搜索功能为例的实现思路详解
  • vue + Lodop 制作可视化设计页面 实现打印设计功能(三)
  • 服务器文件权限限制写入
  • Ribbon 源码分析【Ribbon 负载均衡】
  • go 开发小技巧
  • 9.4日常记录
  • Git+word记笔记
  • DriveLM的baseline复现
  • 关于edge浏览器登陆CSDN安全验证不跳出验证码
  • 『 Linux 』简单TCP英译汉程序
  • 【Webpack】基本使用方法