当前位置: 首页 > article >正文

深度学习中注意力机制介绍及seq2seq案例

一. 注意力机制介绍

普通机器翻译

图中表示的是一个中文到英文的翻译:欢迎 来 北京 → welcome to BeiJing。编码器首先处理中文输入"欢迎 来 北京",通过GRU模型获得每个时间步的输出张量,最后将它们拼接(按位相加)成一个中间语义张量c, 或者C=hn;接着解码器将使用这个中间语义张量c以及每一个时间步的隐层张量, 逐个生成对应的翻译语言

  • 问题1:如果翻译的句子很长很复杂,比如直接一篇文章输进去,模型的计算量很大,并且模型的准确率下降严重。

  • 问题2:在翻译时,可能在不同的语境下,同一个词具有不同的含义,但是网络对这些词向量并没有区分度,没有考虑词与词之间的相关性,导致翻译效果比较差。

注意力机制

让机器注意到每个词向量之间的相关性,有侧重地进行翻译,模拟人类理解的过程。

注意力机制实现原理

原理

通俗来讲就是对于模型的每一个输入项,可能是图片中的不同部分,或者是语句中的某个单词分配一个权重,这个权重的大小就代表了我们希望模型对该部分一个关注程度。这样一来,通过权重大小来模拟人在处理信息的注意力的侧重,有效的提高了模型的性能,并且一定程度上降低了计算量。

分类

深度学习中的注意力机制通常可分为三类: 软注意(全局注意)、硬注意(局部注意)和自注意(内注意)

软注意力机制

软注意机制(Soft/Global Attention: 对每个输入项的分配的权重为0-1之间,也就是某些部分关注的多一点,某些部分关注的少一点,因为对大部分信息都有考虑,但考虑程度不一样,所以相对来说计算量比较大

硬注意力机制

硬注意机制(Hard/Local Attention,[了解即可]): 对每个输入项分配的权重非0即1,和软注意不同,硬注意机制只考虑那部分需要关注,哪部分不关注,也就是直接舍弃掉一些不相关项。优势在于可以减少一定的时间和计算成本,但有可能丢失掉一些本应该注意的信息

自注意力机制

自注意力机制( Self/Intra Attention): 对每个输入项分配的权重取决于输入项之间的相互作用,即通过输入项内部的"表决"来决定应该关注哪些输入项。和前两种相比,在处理很长的输入时,具有并行计算的优势

soft_attention

注意力的作用

比如机器翻译任务,输入source为:Tom chase Jerry,输出target为:“汤姆”,“追逐”,“杰瑞”。在翻译“Jerry”这个中文单词的时候,普通Encoder-Decoder框架中,source里的每个单词对翻译目标单词“杰瑞”贡献是相同的,很明显这里不太合理,显然“Jerry”对于翻译成“杰瑞”更重要。

如果引入Attention模型,在生成“杰瑞”的时候,应该体现出英文单词对于翻译当前中文单词不同的影响程度,比如给出类似下面一个概率分布值:(Tom,0.3)(Chase,0.2) (Jerry,0.5).每个英文单词的概率代表了翻译当前单词“杰瑞”时,注意力分配模型分配给不同英文单词的注意力大小。

基于上述例子所示, 对于target中任意一个单词都应该有对应的source中的单词的注意力分配概率.而且,由于注意力模型的加入,原来在生成target单词时候的中间语义C就不再是固定的,而是会根据注意力概率变化的C

中间语义C

CTom=g(0.6∗f2(Tom),0.2∗f2(Chase),0.2∗f2(Jerry))

CChase=g(0.2∗f2(Tom),0.7∗f2(Chase),0.1∗f2(Jerry))

CJerry=g(0.3∗f2(Tom),0.2∗f2(Chase),0.5∗f2(Jerry))

生成目标单词

y1=f1(C1)

y2=f1(C2,y1)

y3=f1(C3,y1,y2)

注意力概率分布计算

上图中h_i表示Source中单词j对应的隐层节点状态h_j,H_i表示Target中单词i的隐层节点状态,注意力计算的是Target中单词i对Source中每个单词对齐可能性, Q*K^T,即F(h_j,H_i-1),而函数F可以用不同的方法,然后函数F的输出经过softmax进行归一化就得到了注意力分配概率分布。

上面就是经典的Soft Attention模型的基本思想,区别只是函数F会有所不同。

本质思想

其实Attention机制可以看作,Target中每个单词是对Source每个单词的加权求和,而权重是Source中每个单词对Target中每个单词的重要程度。因此,Attention的本质思想会表示成下图:

将Source中的构成元素看作是一系列的数据对,给定Target中的某个元素Query,通过计算Query和各个Key的相似性或者相关性,即权重系数;然后对Value进行加权求和,并得到最终的Attention数值。将本质思想表示成公式如下:

深度学习中的注意力机制中提到:Source 中的 Key 和 Value 合二为一,指向的是同一个东西,也即输入句子中每个单词对应的语义编码,所以可能不容易看出这种能够体现本质思想的结构。因此,Attention计算转换为下面3个阶段。

输入由三部分构成:Query、Key和Value。其中,(Key, Value)是具有相互关联的KV对,Query是输入的“问题”,Attention可以将Query转化为与Query最相关的向量表示。

第一步:Query和Key进行相似度计算,得到Attention Score;

第二步:对Attention Score进行Softmax归一化,得到权值矩阵;

第三步:权重矩阵与Value进行加权求和计算。

hard_attention

根据注意力分布选择输入向量中的一个作为输出。这里有两种选择方式:

  1. 选择注意力分布中,分数最大的那一项对应的输入向量作为Attention机制的输出。

  2. 根据注意力分布进行随机采样,采样结果作为Attention机制的输出。

硬性注意力通过以上两种方式选择Attention的输出,这会使得最终的损失函数与注意力分布之间的函数关系不可导,导致无法使用反向传播算法训练模型,硬性注意力通常需要使用强化学习来进行训练。因此,一般深度学习算法会使用软性注意力的方式进行计算,

self_attention

Self Attention,指的是Source内部元素之间或者Target内部元素之间发生的Attention机制,也可以理解为Target=Source这种特殊情况下的注意力机制。当然,具体的计算过程仍然是一样的,只是计算对象发生了变化而已。(q=k=v)

将输入信息embedding后乘上不同的权重矩阵得到QKV, 即QKV不相等, 且Q也是来自于source

二. 注意力机制介绍2

seq2seq

注意力机制规则

它需要三个指定的输入Q(query), K(key), V(value), 然后通过计算公式得到注意力的结果, 这个结果代表query在key和value作用下的注意力表示. 当输入的Q=K=V时(来源相同并不是值相等), 称作自注意力计算规则;当Q、K、V不相等时称为一般注意力计算规则

版本1

  • 查询张量Q: 解码器每一步输出或者是当前输入的x

  • 键张量K: 编码部分每个时间步的结果组合而成

  • 值张量V:编码部分每个时间步的结果组合而成

版本2

  • 查询张量Q: 解码器每一步的输出(s1通过线性层后面的结果, 再次转为词向量)或者是当前输入的x

  • 键张量K: 解码器上一步的隐藏层输出

  • 值张量V:编码部分每个时间步输出结果组合而成

版本解码对比

  1. 采用自回归机制,比如:输入“go”来预测“welcome”,输入“welcome”来预测"to",输入“to”来预测“Beijing”。在输入“welcome”来预测"to"解码中,可使用注意力机制

  2. 查询张量Q:一般可以是“welcome”词嵌入层以后的结果,查询张量Q为生成谁就是谁的查询张量(比如这里为了生成“to”,则查询张量就是“to”的查询张量,请仔细体会这一点)

  3. 键向量K:一般可以是上一个时间步的隐藏层输出

  4. 值向量V:一般可以是编码部分每个时间步的结果组合而成

  5. 查询张量Q来生成“to”,去检索“to”单词和“欢迎”、“来”、“北京”三个单词的权重分布,注意力结果表示(用权重分布 乘以内容V)

注意力机制计算规则

版本1

将Q与K的转置做点积运算, 然后除以一个缩放系数, 再使用softmax处理获得结果最后与V做张量乘法.

Attention(Q,K,V)=Softmax(Q⋅K^T/√dk)⋅V

★版本2

将Q,K进行纵轴拼接, 做一次线性变化, 再使用softmax处理获得结果最后与V做张量乘法.

Attention(Q,K,V)=Softmax(Linear([Q,K]))⋅V

bmm计算

import torch
​
​
# 1. torch.bmm(), 不支持广播
# 如果参数1形状是(b × n × m),
# 参数2形状是(b × m × p),
# 则输出为(b × n × p)
input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
print(res.size())   # torch.Size([10, 3, 5])
​
​
# 2. torch.mm(), 不支持广播
res = torch.mm(input[0], mat2[0])
print(res.size())   # torch.Size([3, 5])
​
​
# 3. torch.matmul(), 支持广播, input(10, 3, 4) mat2[0].shape=(4, 5)或者=(1, 4, 5)
mat3 = torch.randn(1, 4, 5)
res1 = torch.matmul(input, mat2[0])
res2 = torch.matmul(input, mat3)
print(res1.size())  # torch.Size([10, 3, 5])
print(res2.size())  # torch.Size([10, 3, 5])
​
​
# 4. * or torch.mul()哈达玛积, 支持广播, input(10, 3, 4) mat2[0].shape=(4, 5)或者=(1, 4, 5)
a = torch.rand(3, 4)
b = torch.rand(3, 4)
print(a * b)
print(torch.mul(a, b))
print(torch.mul(a, b).shape)

注意力机制定义

注意力机制是注意力计算规则能够应用的深度学习网络的载体, 同时包括一些必要的全连接层以及相关张量处理, 使其与应用网络融为一体. 使用自注意力计算规则的注意力机制称为自注意力机制.

  1. rnn等循环神经网络,随着时间步的增长,前面单词的特征会遗忘,造成对句子特征提取不充分

  2. rnn等循环神经网络是一个时间步一个时间步的提取序列特征,效率低下

  3. 研究者开始思考,能不能对32个单词(序列)同时提取事物特征,而且还是并行的,所以引入注意力机制!

注意力机制作用

  • 解码器端的注意力机制: 能够根据模型目标有效的聚焦编码器的输出结果, 当其作为解码器的输入时提升效果. 改善以往编码器输出是单一定长张量(中间语义C), 无法存储过多信息的情况.

  • 编码器端的注意力机制: 主要解决表征问题, 相当于特征提取过程, 得到输入的注意力表示. 一般使用自注意力(self-attention).

注意力机制实现步骤

步骤

  • 第一步: 根据注意力计算规则, 对Q,K,V进行相应的计算.

  • 第二步: 根据第一步采用的计算方法, 如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接, 如果是转置点积, 一般是自注意力, Q与V相同, 则不需要进行与Q的拼接.

  • 第三步: 最后为了使整个attention机制按照指定尺寸输出, 使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示.

代码演示

# 任务描述:
# 有QKV:v是内容比如32个单词,每个单词64个特征,k是32个单词的索引,q是查询张量
# 我们的任务:输入查询张量q,通过注意力机制来计算如下信息:
# 1、查询张量q的注意力权重分布:查询张量q和其他32个单词相关性(相识度)
# 2、查询张量q的结果表示:有一个普通的q升级成一个更强大q;用q和v做bmm运算
# 3 注意:查询张量q查询的目标是谁,就是谁的查询张量。
#   eg:比如查询张量q是来查询单词"我",则q就是我的查询张量
​
import torch
import torch.nn as nn
import torch.nn.functional as F
​
​
# MyAtt类实现思路分析
# 1 init函数 (self, query_size, key_size, value_size1, value_size2, output_size)
# 准备2个线性层 注意力权重分布self.attn 注意力结果表示按照指定维度进行输出层 self.attn_combine
# 2 forward(self, Q, K, V):
# 求查询张量q的注意力权重分布, attn_weights[1,32]
# 求查询张量q的注意力结果表示 bmm运算, attn_applied[1,1,64]
# q 与 attn_applied 融合,再按照指定维度输出 output[1,1,32]
# 返回注意力结果表示output:[1,1,32], 注意力权重分布attn_weights:[1,32]
​
class MyAtt(nn.Module):
    #                   32          32          32              64      32
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        super(MyAtt, self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size
​
        # 线性层1 注意力权重分布
        self.attn = nn.Linear(self.query_size + self.key_size, self.value_size1)
​
        # 线性层2 注意力结果表示按照指定维度输出层 self.attn_combine
        self.attn_combine = nn.Linear(self.query_size + self.value_size2, output_size)
​
    def forward(self, Q, K, V):
        # 1 求查询张量q的注意力权重分布, attn_weights[1,32]
        # [1,1,32],[1,1,32]--> [1,32],[1,32]->[1,64]
        # [1,64] --> [1,32]
        # tmp1 = torch.cat( (Q[0], K[0]), dim=1)
        # tmp2 = self.attn(tmp1)
        # tmp3 = F.softmax(tmp2, dim=1)
        attn_weights = F.softmax(self.attn(torch.cat((Q[0], K[0]), dim=-1)), dim=-1)
​
        # 2 求查询张量q的结果表示 bmm运算, attn_applied[1,1,64]
        # [1,1,32] * [1,32,64] ---> [1,1,64]
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)
​
        # 3 q 与 attn_applied 融合,再按照指定维度输出 output[1,1,64]
        # 3-1 q与结果表示拼接 [1,32],[1,64] ---> [1,96]
        output = torch.cat((Q[0], attn_applied[0]), dim=-1)
        # 3-2 shape [1,96] ---> [1,32]
        output = self.attn_combine(output).unsqueeze(0)
​
        # 4 返回注意力结果表示output:[1,1,32], 注意力权重分布attn_weights:[1,32]
        return output, attn_weights
​
​
if __name__ == '__main__':
    query_size = 32
    key_size = 32
    value_size1 = 32  # 32个单词
    value_size2 = 64  # 64个特征
    output_size = 32
​
    Q = torch.randn(1, 1, 32)
    K = torch.randn(1, 1, 32)
    V = torch.randn(1, 32, 64)
    # V = torch.randn(1, value_size1, value_size2)
​
    # 1 实例化注意力类 对象
    myattobj = MyAtt(query_size, key_size, value_size1, value_size2, output_size)
​
    # 2 把QKV数据扔给注意机制,求查询张量q的注意力结果表示、注意力权重分布
    output, attn_weights = myattobj(Q, K, V)
    print('查询张量q的注意力结果表示output--->', output.shape, output)
    print('查询张量q的注意力权重分布attn_weights--->', attn_weights.shape, attn_weights)
​

三. 英译法案例

数据预处理

SOS表示开始

EOS表示结束

UNK表示未登录词, 即训练数据中没有的词, 用于解决oov(out of vocab)问题

导包及数据清洗

# 用于正则表达式
import re
# 用于构建网络结构和函数的torch工具包
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# torch中预定义的优化方法工具包
import torch.optim as optim
import time
# 用于随机生成数据
import random
import matplotlib.pyplot as plt
​
# 定义设备选择
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
​
# 定义开始标志
SOS_TOKEN = 0
# 定义结束标志
EOS_TOKEN = 1
# 最大句子长度
MAX_LENGTH = 10
# data路径
data_path = "./data/eng-fra-v2.txt"
​
​
# 定义数据清洗函数
def normalizeString(s):
    S = s.lower().strip()
    S = re.sub(r"([.!?])", r" \1", S)   # 符号也是token, 不能和单词混在一起
    S = re.sub(r"[^a-zA-Z.!?]+", r" ", S)
    return S

构建文本词典

# 2. 读取文档数据
def my_getdata():
    # 1. 读取数据
    with open(data_path, 'r', encoding='utf-8') as fr:
        content = fr.read()
    lines = content.strip().split('\n')
​
    # 2. 构建语言对pair
    my_pair = [[normalizeString(s) for s in l.split('\t')] for l in lines]
​
    # 3. 构建词典 word2idx, 用于编码
    # 3.1 初始化eng词典 fra词典
    # UNK 表示未登录词(预测时不在词表内的词) 用来解决oov问题
    english_word2idx = {'SOS': 0, 'EOS': 1, 'UNK': 2}
    french_word2idx = {'SOS': 0, 'EOS': 1, 'UNK': 2}
​
    # 3.2 遍历my_pair
    for pair in my_pair:
        for word in pair[0].split(' '):
            if word not in english_word2idx:
                english_word2idx[word] = len(english_word2idx)
​
        for word in pair[1].split(' '):
            if word not in french_word2idx:
                french_word2idx[word] = len(french_word2idx)
​
    # 4. 构建idx2word, 用于解码
    english_index2word = {v: k for k, v in english_word2idx.items()}
    french_index2word = {v: k for k, v in french_word2idx.items()}
​
    # 5. 返回数据
    return english_word2idx, english_index2word, len(english_word2idx), french_word2idx, french_index2word, len(
        french_word2idx), my_pair
​
​
english_word2idx, english_index2word, english_word2idx_n, french_word2idx, french_index2word, french_word2idx_n, my_pair = my_getdata()
​

构建数据源对象

# 3. 构建数据集
class MyDataset(Dataset):
    def __init__(self, my_pair):
        self.my_pair = my_pair
        self.n_samples = len(my_pair)
​
    def __len__(self):
        return self.n_samples
​
    def __getitem__(self, index):   # __getitem__方法会调用全局变量
        # 修正索引
        index = min(max(0, index), self.n_samples - 1)
​
        # 在样本数据中获取第idx条数据数据的x和y
        x = self.my_pair[index][0]
        y = self.my_pair[index][1]
​
        # 将词转为词表中索引, 并转化为张量
        # (虽然english_word2ind并没有显示传入, 但是在主函数中调用my_getdata方法并返回了词典, 自定义的数据集类内可以访问全局变量)
        x = [english_word2idx.get(word, english_word2idx_n['UNK']) for word in x.split(" ")]
        # x可以不加EOS_TOKEN标志
        # x.append(EOS_TOKEN)
        tensor_x = torch.tensor(x, dtype=torch.long, device=device)
        y = [french_word2idx.get(word, english_word2idx_n['UNK']) for word in y.split(" ")]
        y.append(EOS_TOKEN)
        tensor_y = torch.tensor(y, dtype=torch.long, device=device)
​
        return tensor_x, tensor_y

构建数据加载器

# 4. 构建数据加载器
def get_dataloader():
    my_dataset = MyDataset(my_pair)
    my_dataloader = DataLoader(dataset=my_dataset, batch_size=1, shuffle=True)
    return my_dataloader

GRU编码器和解码器

无注意力编码器

# 5. 构建编码器GRU
class EncoderGRU(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(EncoderGRU, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
​
        self.embed = nn.Embedding(self.vocab_size, self.hidden_size)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)
​
    def forward(self, input):
        out = self.embed(input)
        out, hn = self.gru(out)
        return out, hn

自注意力编码器

# _5. 构建编码器GRU, 自注意力版本
class SelfAttenEncoderGRU(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(SelfAttenEncoderGRU, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
​
        self.embed = nn.Embedding(self.vocab_size, self.hidden_size)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)
​
        self.l1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.l2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.l3 = nn.Linear(self.hidden_size, self.hidden_size)
​
    def forward(self, input):
        out = self.embed(input)
​
        q = self.l1(out)
        k = self.l2(out)
        v = self.l3(out)
​
        out1 = torch.matmul(q, k.transpose(1, 2))
        out2 = out1 / self.hidden_size**0.5
        out3 = F.softmax(out2, dim=-1)
        out = torch.matmul(out3, v)
​
        out, hn = self.gru(out)
        return out, hn
    
# 6. 测试编码器GRU
def ceshi_encoder():
    mydataloader = get_dataloader()
    my_encoder = SelfAttenEncoderGRU(english_word2idx_n, 256).to(device)
    x, y = next(iter(mydataloader))
    out, hn = my_encoder(x)
    print(x.shape)
    print(y.shape)
    print(out.shape)
    print(hn.shape)
​

无注意力解码器

# 7. 构建解码器GRU, 无注意力机制
class DecoderGRU(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(DecoderGRU, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size

        self.embed = nn.Embedding(self.vocab_size, self.hidden_size)

        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)

        self.out = nn.Linear(self.hidden_size, self.vocab_size)

        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input, hn):   # input.shape(1, 1), 只有一个token
        output = self.embed(input)  # output.shape(1, 1, 256)
        output = F.relu(output)
        output, hn = self.gru(output, hn)
        output = self.out(output[0])    # 得到的output.shape(1, 4345)法语词表大小
        output = self.softmax(output)
        return output, hn
    
    
# 8. 测试逐token解码, 没有注意力机制
def ceshi_decoder():
    mydataloader = get_dataloader()
    my_encoder = SelfAttenEncoderGRU(english_word2idx_n, 256).to(device)
    my_decoder = DecoderGRU(french_word2idx_n, 256).to(device)
    x, y = next(iter(mydataloader))
    print(x.shape)
    print(y.shape)
    out, hn = my_encoder(x)

    # 逐token解码
    for i in range(y.shape[1]):     # y.shape=(1, sql_len)
        temp = y[0][i].view(1, -1)
        out, hn = my_decoder(temp, hn)
        print('out.shape: ', out.shape)

有注意力解码器

# 9. 构建解码器GRU, 有注意力机制
class AttnDecoderGRU(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderGRU, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        # 词嵌入层
        self.embed = nn.Embedding(self.vocab_size, self.hidden_size)

        # 随机失活层
        self.dropout = nn.Dropout(self.dropout_p)

        # 注意力层, 两个线性层
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)  # 计算注意力权重
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)  # 计算注意力表示, 作为GRU输入

        # GRU层
        self.gru = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True)

        # 输出层
        self.out = nn.Linear(self.hidden_size, self.vocab_size)

        # softmax层
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input, hn, encoder_outputs):
        # input.shape(1, 1)
        # hn.shape(1, 1, 256)
        # encoder_outputs.shape=(1, 10, 256)
        input_x = self.embed(input)
        input_x = self.dropout(input_x)

        # 拼接词嵌入后的x和hn, 计算注意力权重, attn_weight.shape=(1, 1, 10)
        attn_weight = F.softmax(self.attn(torch.cat((input_x[0], hn[0]), dim=-1)), dim=-1)

        # 计算注意力表示, attn_applied.shape=(1, 1, 256), 计算 hn和encoder_outputs的点积, 注意力权重乘以V
        attn_applied = torch.bmm(attn_weight.unsqueeze(0), encoder_outputs.unsqueeze(0))  # encoder_outputs(1, 10, 256)

        # 融合注意力表示和词嵌入后的x, combine_output.shape=(1, 1, 256)
        combine_output = self.attn_combine(torch.cat((input_x, attn_applied), dim=-1))

        output = F.relu(combine_output)

        output, hn = self.gru(output, hn)

        output = self.out(output[0])

        output = self.softmax(output)
        return output, hn, attn_weight

模型训练

单批次训练

# 10. 模型训练
# 超参数设置
teacher_forcing_ratio = 0.5


# 训练函数
def train_pre_model(x, y, my_encoder, my_attn_decoder, encoder_optimizer, decoder_optimizer, cross_entropy_loss):
    # 计算encoder
    encoder_output, encoder_hn = my_encoder(x)

    # encoder_output_c(中间语义张量C, 或者是V)长度
    encoder_output_c = torch.zeros(MAX_LENGTH, my_encoder.hidden_size, device=device)
    # 逐token复制encoder_output, x.shape(batch_size, seq_len, hidden_size)
    for i in range(x.shape[1]):
        encoder_output_c[i] = encoder_output[0][i]
    # 解码器第一个时间步的h0
    decoder_hn = encoder_hn
    # 解码器第一个时间步的输入y(1, 1)
    input_y = torch.tensor([[SOS_TOKEN]], dtype=torch.long, device=device)

    # 初始化当前样本的损失
    my_loss = 0
    # 初始化teacher_forcing
    #
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    # 开始训练
    if use_teacher_forcing:  # 小于指定的teacher_forcing
        for i in range(y.shape[1]):
            output_y, decoder_hn, _ = my_attn_decoder(input_y, decoder_hn, encoder_output_c)
            target_y = y[0][i].view(1)
            my_loss += cross_entropy_loss(output_y, target_y)
            # 用真实值作为下次输入
            input_y = y[0][i].view(1, -1)
    else:
        for i in range(y.shape[1]):
            output_y, decoder_hn, _ = my_attn_decoder(input_y, decoder_hn, encoder_output_c)
            target_y = y[0][i].view(1)
            my_loss += cross_entropy_loss(output_y, target_y)

            topv, topi = output_y.topk(1)

            # 设置结束条件
            if topi.squeeze().item() == EOS_TOKEN:
                break

            input_y = topi.detach()  # 不使用梯度

    # 反向传播
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    my_loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return my_loss.item() / y.shape[1]

所有数据训练-绘图

# 训练主函数
def train_seq2seq(epochs=2, lr=1e-4):
    # 1. 获取数据迭代器
    my_dataloader = get_dataloader()

    # 2. 实例化编码器解码器
    my_encoder = SelfAttenEncoderGRU(english_word2idx_n, 256).to(device)
    my_decoder = AttnDecoderGRU(french_word2idx_n, 256).to(device)

    # 3. 优化器, 损失函数
    encoder_optimizer = optim.Adam(my_encoder.parameters(), lr=lr)
    decoder_optimizer = optim.Adam(my_decoder.parameters(), lr=lr)
    cross_entropy_loss = nn.NLLLoss()
    total_iter = 0
    print_interval_total, plot_interval_total = 0.0, 0.0
    # 4. 绘图的loss列表
    plot_loss_list = []
    # 5. 开始训练, 外循环
    for epoch in range(1, 1 + epochs):
        start_time = time.time()
        # 5.1 内循环
        for i, (x, y) in enumerate(my_dataloader, start=1):
            x, y = x.to(device), y.to(device)
            my_loss = train_pre_model(x, y, my_encoder, my_decoder, encoder_optimizer, decoder_optimizer,
                                      cross_entropy_loss)
            plot_interval_total += my_loss
            print_interval_total += my_loss
            total_iter += 1
            if i % 100 == 0:
                plot_loss_list.append(plot_interval_total / total_iter)
                # plot_interval_total = 0
            if i % 1000 == 0:
                print_loss_avg = print_interval_total / total_iter
                use_time = time.time() - start_time
                print(f'epoch: {epoch}, iter: {i}, loss: {print_loss_avg:.5f}, time: {use_time:.3f}s')
                # print_interval_total = 0
        print('*' * 50)
        print('epoch: %d, total_loss: %.5f, time: %.3f' % (
            epoch, plot_interval_total / len(my_dataloader), time.time() - start_time))
        print('*' * 50)
        torch.save(my_encoder.state_dict(), 'model/encoder_model%d.pth' % epoch)
        torch.save(my_decoder.state_dict(), 'model/decoder_model%d.pth' % epoch)
    # 绘图
    plt.figure()
    plt.plot(plot_loss_list)
    plt.savefig('picture/loss.png')
    plt.show()

    return plot_loss_list

模型预测

# 11. 预测函数
def predict_seq2seq(x):
    tensor_x = [english_word2idx.get(word, english_word2idx['UNK']) for word in x.split(" ")]
    tensor_x = [tensor_x]
    tensor_x = torch.tensor(tensor_x, dtype=torch.long, device=device)
    with torch.no_grad():
        my_encoder = SelfAttenEncoderGRU(english_word2idx_n, 256).to(device)
        my_decoder = AttnDecoderGRU(french_word2idx_n, 256).to(device)
        my_encoder.load_state_dict(torch.load('model/encoder_model3.pth', weights_only=True))
        my_decoder.load_state_dict(torch.load('model/decoder_model3.pth', weights_only=True))

        encoder_output, encoder_hn = my_encoder(tensor_x)
        # 获取中间语义张量C
        # encoder_output_c(中间语义张量C, 或者是V)长度
        encoder_output_c = torch.zeros(MAX_LENGTH, my_encoder.hidden_size, device=device)
        # 逐token复制encoder_output, tensor_x.shape(batch_size, seq_len, hidden_size)
        for i in range(tensor_x.shape[1]):
            encoder_output_c[i] = encoder_output[0][i]

        # 解码器第一个时间步的h0
        decoder_hn = encoder_hn
        # 解码器第一个时间步的输入y(1, 1)
        input_y = torch.tensor([[SOS_TOKEN]], dtype=torch.long, device=device)

        # 存储预测结果
        french_words = []
        # 循环输出解码结果
        while True:
            output_y, decoder_hn, _ = my_decoder(input_y, decoder_hn, encoder_output_c)
            topv, topi = output_y.topk(1)
            if topi.squeeze().item() == EOS_TOKEN:
                break
            input_y = topi.detach()
            french_word = french_index2word[topi.item()]
            french_words.append(french_word)
        words = ' '.join(french_words)
        print('英语: ', x)
        print('法语:', words)
        
if __name__ == '__main__':
    train_seq2seq(epochs=5)

    """
    ['i m impressed with your french .', 'je suis impressionne par votre francais .'],
    ['i m more than a friend .', 'je suis plus qu une amie .'],
    ['she is beautiful like her mother .', 'elle est belle comme sa mere .']
    """

    predict_seq2seq('i m impressed with your french .')


http://www.kler.cn/a/428674.html

相关文章:

  • 数学基础 --线性代数之理解矩阵乘法
  • ConvBERT:通过基于跨度的动态卷积改进BERT
  • Linux下MySQL的简单使用
  • idea中远程调试中配置的参数说明
  • React 中hooks之useTransition使用总结
  • 云原生作业(四)
  • Matlab自学笔记四十四:使用dateshift函数生成日期时间型序列数据
  • 58 基于 单片机的温湿度、光照、电压、电流检测
  • 解决跨域问题方案
  • 高级java每日一道面试题-2024年12月05日-JVM篇-什么是空闲列表?
  • vue中this指针获取不到函数或数据
  • Vue 鼠标滚轮缩放图片的实现
  • Kubernetes 常用操作大全:全面掌握 K8s 基础与进阶命令
  • 基于 Spring Boot + Vue 的宠物领养系统设计与实现
  • Java 初学者的第一个 SpringBoot 登录系统
  • CT中的2D、MPR、VR渲染、高级临床功能
  • 鸿蒙技术分享:❓❓[鸿蒙应用开发]怎么更好的管理模块生命周期?
  • 论文研读|信息科技风险管理模型的主要内容、定位、目标企业、风险管理机制, 以及相应的风险评估流程和风险应对策略
  • Spring Boot中实现JPA多数据源配置指南
  • 再谈多重签名与 MPC
  • sed流编辑器
  • 渤海证券基于互联网环境的漏洞主动防护方案探索与实践
  • 3. React Hooks:为什么你应该使用它们?
  • 微搭低代码AI组件单词消消乐从0到1实践
  • ZOLOZ SMART AML:让复杂的反洗钱变得简单
  • 在Linux设置postgresql开机自启动,创建一个文件 postgresql-15.service