当前位置: 首页 > article >正文

机器翻译之多头注意力(MultiAttentionn)在Seq2Seq的应用

目录

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

2.2验证上述封装的代码 

2.3 创建 添加了Bahdanau的decoder 

 2.4训练

 2.5预测

3.知识点个人理解 


 

1.多头注意力(MultiAttentionn)的理念图

2.代码实现 

2.1创建多头注意力函数 

class MultiHeadAttention(nn.Module):
    #初始化属性和方法
    def __init__(self, query_size, key_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        """
        query_size_size: query_size的特征数features
        key_size: key_size的特征数features
        value_size: value_size的特征数features
        num_hiddens:隐藏层的神经元的数量
        num_heads:多头注意力的header的数量
        dropout: 释放模型需要计算的参数的比例
        bias=False:没有偏差
        **kwargs : 不定长度的关键字参数
        """
        super().__init__(**kwargs)
        #接收参数
        self.num_heads = num_heads
        #初始化注意力,    #使用DotProductAttention时, keys与 values具有相同的长度, 经过decoder,他们长度相同
        self.attention = dltools.DotProductAttention(dropout)
        #初始化四个w模型参数
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        def transpose_qkv(X, num_heads):
            """实现queries, keys, values的数据维度转化"""
            #输入的X的shape=(batch_size, 查询数/键值对数量, num_hiddens)
            #这里,不能直接用reshape,需要索引维度,防止数据不能一一对应
            X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)   #将原维度的num_hiddens拆分成num_heads, -1,  -1相当于num_hiddens/num_heads的数值
            X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, num_size, 查询数/键值对数量, num_hiddens/num_heads)
            return X.reshape(-1, X.shape[2], X.shape[3])  #X的shape=(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)

        def transpose_outputs(X, num_heads):
            """逆转transpose_qkv的操作"""
            #此时数据的X的shape =(batch_size*num_heads, 查询数/键值对数量, num_hiddens/num_heads)
            X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])  #X的shape=(batch_size, num_heads, 查询数/键值对数量, num_hiddens/num_heads)
            X = X.permute(0, 2, 1, 3)  #X的shape=(batch_size, 查询数/键值对数量, num_heads,  num_hiddens/num_heads)
            return X.reshape(X.shape[0], X.shape[1], -1)  #X的shape还原了=(batch_size, 查询数/键值对数, num_hiddens)

        #queries, keys, values,传入的shape=(batch_size, 查询数/键值对数, num_hiddens)
        #获取转换维度之后的queries, keys, values,
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        #若valid_len不为空,存在
        if valid_lens is not None:
            #将valid_lens重复数据self.num_heads次,在0维度上
            valid_lens = torch.repeat_interleave(valid_lens, repeats = self.num_heads, dim=0)
        #若为空,什么都不做,跳出if判断,继续执行其他代码

        #通过注意力函数获取输出outputs
        #outputs的shape = (batch_size*num_heads, 查询的个数, num_hiddens/num_heads)
        outputs = self.attention(queries, keys, values, valid_lens)

        #逆转outputs的维度
        outputs_concat = transpose_outputs(outputs, self.num_heads)

        return self.W_o(outputs_concat)

2.2验证上述封装的代码 

#假设变量
num_hiddens, num_heads, dropout = 100, 5, 0.2
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
attention.eval()  #需要预测,加上
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
#假设变量
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens = 6, torch.tensor([3, 2])

X = torch.ones((batch_size, num_queries, num_hiddens))  #shape(2,4,100)
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))  #shape(2,6,100) 

attention(X, Y, Y, valid_lens).shape

torch.Size([2, 4, 100])

2.3 创建 添加了Bahdanau的decoder 

# 添加Bahdanau的decoder
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens

            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)

            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)

            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)

            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            

        outputs = self.dense(torch.cat(outputs, dim=0))

        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

 2.4训练

# 训练
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)

decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)

net = dltools.EncoderDecoder(encoder, decoder)

dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

 2.5预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')

go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('trouvez tom .', []), bleu 0.000
i'm home . => ('je suis chez moi .', []), bleu 1.000

3.知识点个人理解 

 


http://www.kler.cn/news/317913.html

相关文章:

  • 如何使用Spring Cloud Gateway搭建网关系统
  • 怎么录制游戏视频?精选5款游戏录屏软件
  • 电源芯片测试系统如何完成欠压关断/欠压关断滞后?
  • 某花顺爬虫逆向分析
  • Leetcode 543. 124. 二叉树的直径 树形dp C++实现
  • 根据[国家统计局最新行政区规划]数据库代码
  • 研1日记15
  • 快速了解使用路由器
  • openssl-AES-128-CTR加解密char型数组分析
  • 代码随想录算法训练营||二叉树
  • 背景图鼠标放上去切换图片过渡效果
  • PHPMailer低版本用法(实例)
  • 深入解析Linux驱动开发中的I2C时序及I2C高频面试题
  • 前端vue-ref与document.querySelector的对比
  • 2024年9月24日---关于MyBatis框架(3)
  • Linux使用Clash,clash-for-linux
  • OpenCV多通道图像混合(六)
  • 【Linux 从基础到进阶】 QEMU 虚拟化配置与优化
  • OpenAI最新GPT-o1-preview测评
  • 关于事务的一些梳理
  • Springboot+Shiro+Mybatis+mysql实现权限安全认证
  • 深入解析:高性能 SSE 服务器的设计与实现
  • linux中crontab工具详解
  • React-Native 中使用 react-native-image-crop-picker 在华为手机上不能正常使用拍照功能
  • SQL常用技巧总结
  • ​‌GAS系统​
  • 【Kubernetes】常见面试题汇总(三十六)
  • OMRON欧姆龙通讯模块CI541V1
  • 网络安全:构建数字世界的坚固防线
  • MVCC机制解析:提升数据库并发性能的关键