当前位置: 首页 > article >正文

【自然语言处理(NLP)】多头注意力(Multi - Head Attention)原理及代码实现

文章目录

  • 介绍
  • 多头注意力
    • 原理
    • 代码实现
      • 导包
      • 多头注意力结构
      • qkv转换
      • output转换
      • 构建注意力模块
      • 添加Bahdanau的decoder
      • 训练
      • 预测

个人主页:道友老李
欢迎加入社区:道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

多头注意力

原理

多头注意力机制(Multi - Head Attention)的结构示意图:
在这里插入图片描述

多头注意力机制首先将查询(Query)、键(Key)、值(Value)分别通过多个全连接层进行线性变换,得到多个不同的表示。然后,对这些不同的表示分别进行注意力计算。最后,将各个注意力的结果进行连结(Concatenate),再通过一个全连接层得到最终输出。

这种机制允许模型在不同的表示子空间中并行地关注输入序列的不同部分,能够捕捉到更丰富的语义信息,广泛应用于Transformer等模型架构中,在自然语言处理、计算机视觉等领域有重要应用。

模型计算方式:
在这里插入图片描述
在该表达式中, h i h_i hi 是注意力机制计算得到的输出, f f f 一般表示注意力计算函数(如缩放点积注意力等), W q i W_q^i Wqi W k i W_k^i Wki W v i W_v^i Wvi 分别是针对查询(query)、键(key)、值(value)的可学习权重矩阵, q q q k k k v v v 分别为查询向量、键向量、值向量 , R n \mathbb{R}^n Rn 表示输出 h i h_i hi 处于 n n n 维实数空间。它表达了在注意力计算中,通过对查询、键、值进行线性变换后再经过注意力计算函数得到输出的过程。

矩阵运算表达式:
在这里插入图片描述

表达式中 [ h 1 ⋮ h n ] \begin{bmatrix}h_1\\ \vdots \\h_n\end{bmatrix} h1hn 是一个由 h 1 h_1 h1 h n h_n hn 构成的列向量,这些 h i h_i hi 通常可以是注意力机制等模块的输出。 W o W_o Wo 是一个可学习的权重矩阵,其维度为 R p × n \mathbb{R}^{p\times n} Rp×n ,这里 p p p 是输出维度相关参数, n n n 是输入向量的长度(即 h i h_i hi 的数量)。该表达式表示对由 h i h_i hi 组成的向量进行线性变换,常用于深度学习模型(如Transformer等)的后处理阶段,对前面模块输出进行进一步的特征变换或整合。

代码实现

导包

import math
import torch
from torch import nn
import dltools

多头注意力结构

class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = dltools.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        
    def forward(self, queries, keys, values, valid_lens):
        # queries, keys, values 传入的形状: (batch_size, 查询熟练或者键值对数量, num_hiddens)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)
#         print('queries:', queries.shape)
#         print('keys:', keys.shape)
#         print('values:', values.shape)
        if valid_lens is not None:
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        
        # output shape: (batch_size * num_heads, 查询的个数, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)
#         print('output:', output.shape)
        output_concat = transpose_output(output, self.num_heads)
#         print('output_concat:', output_concat.shape)
        return self.W_o(output_concat)

qkv转换

def transpose_qkv(X, num_heads):
    # 输入X的shape: (batch_size, 查询数/键值对数, num_hiddens)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3) # batch_size, num_heads, 查询数/ 键值对数, num_hiddens/num_heads
    # 这里是把batch_size和num_heads合并在一起了. 
    return X.reshape(-1, X.shape[2], X.shape[3]) # batch_size * num_heads, 查询/键值对数, num_hiddens/ num_heads

output转换

def transpose_output(X, num_heads):
    # 逆转transpose_qkv的操作
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

构建注意力模块

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.2)
attention.eval()

在这里插入图片描述

添加Bahdanau的decoder

class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens
#             print('query:', query.shape) # 4, 1, 16
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
#             print('context: ', context.shape)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
#             print('x: ', x.shape)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
#             print('out:', out.shape)
#             print('hidden_state:', hidden_state.shape)
            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            
#         print('---------------------------------')
        outputs = self.dense(torch.cat(outputs, dim=0))
#         print('解码器最终输出形状: ', outputs.shape)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

训练

embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est paresseux .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

http://www.kler.cn/a/526920.html

相关文章:

  • 39【内存条与硬盘的架构逻辑】
  • 学习数据结构(5)单向链表的实现
  • 【懒删除堆】力扣2349. 设计数字容器系统
  • 【Block总结】DynamicFilter,动态滤波器降低计算复杂度,替换传统的MHSA|即插即用
  • python算法和数据结构刷题[2]:链表、队列、栈
  • zookeeper-3.8.3-基于ACL的访问控制
  • C++中实现全排列方法
  • 10.6 LangChain提示工程终极指南:从基础模板到动态生成的工业级实践
  • JAVA实战开源项目:在线文档管理系统(Vue+SpringBoot) 附源码
  • JavaScript图像处理,腐蚀算法和膨胀算法说明和作用介绍
  • 愿景:做机器视觉行业的颠覆者
  • 刷题记录 贪心算法-4:53. 最大子数组和
  • 从0开始使用面对对象C语言搭建一个基于OLED的图形显示框架(协议层封装)
  • 前端学习-事件解绑,mouseover和mouseenter的区别(二十九)
  • 【MySQL】MySQL客户端连接用 localhost和127.0.0.1的区别
  • SQLAlchemy 2.0的简单使用教程
  • 互斥锁/信号量实现5个线程同步
  • Redis|前言
  • FreeRTOS从入门到精通 第十六章(任务通知)
  • 玄武计划--干中学,知行合一
  • 全网首发,MacMiniA1347安装飞牛最新系统0.8.36,改造双盘位NAS,超详细.36,改造双盘位nas,超详细
  • Teleporters( Educational Codeforces Round 126 (Rated for Div. 2) )
  • 爬虫基础(六)代理简述
  • jvisualvm工具使用
  • 哈工大:屏蔽LLM检索头训练忠实性
  • 158页精品PPT | 机械行业数字化生产供应链产品解决方案