【自然语言处理(NLP)】多头注意力(Multi - Head Attention)原理及代码实现
文章目录
- 介绍
- 多头注意力
- 原理
- 代码实现
- 导包
- 多头注意力结构
- qkv转换
- output转换
- 构建注意力模块
- 添加Bahdanau的decoder
- 训练
- 预测
个人主页:道友老李
欢迎加入社区:道友老李的学习社区
介绍
**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。
NLP的任务可以分为多个层次,包括但不限于:
- 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
- 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
- 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
- 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
- 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
- 机器翻译:将一种自然语言转换为另一种自然语言。
- 问答系统:构建可以回答用户问题的系统。
- 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
- 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
- 语音识别:将人类的语音转换为计算机可读的文字格式。
NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。
多头注意力
原理
多头注意力机制(Multi - Head Attention)的结构示意图:
多头注意力机制首先将查询(Query)、键(Key)、值(Value)分别通过多个全连接层进行线性变换,得到多个不同的表示。然后,对这些不同的表示分别进行注意力计算。最后,将各个注意力的结果进行连结(Concatenate),再通过一个全连接层得到最终输出。
这种机制允许模型在不同的表示子空间中并行地关注输入序列的不同部分,能够捕捉到更丰富的语义信息,广泛应用于Transformer等模型架构中,在自然语言处理、计算机视觉等领域有重要应用。
模型计算方式:
在该表达式中,
h
i
h_i
hi 是注意力机制计算得到的输出,
f
f
f 一般表示注意力计算函数(如缩放点积注意力等),
W
q
i
W_q^i
Wqi、
W
k
i
W_k^i
Wki、
W
v
i
W_v^i
Wvi 分别是针对查询(query)、键(key)、值(value)的可学习权重矩阵,
q
q
q、
k
k
k、
v
v
v 分别为查询向量、键向量、值向量 ,
R
n
\mathbb{R}^n
Rn 表示输出
h
i
h_i
hi 处于
n
n
n 维实数空间。它表达了在注意力计算中,通过对查询、键、值进行线性变换后再经过注意力计算函数得到输出的过程。
矩阵运算表达式:
表达式中 [ h 1 ⋮ h n ] \begin{bmatrix}h_1\\ \vdots \\h_n\end{bmatrix} h1⋮hn 是一个由 h 1 h_1 h1 到 h n h_n hn 构成的列向量,这些 h i h_i hi 通常可以是注意力机制等模块的输出。 W o W_o Wo 是一个可学习的权重矩阵,其维度为 R p × n \mathbb{R}^{p\times n} Rp×n ,这里 p p p 是输出维度相关参数, n n n 是输入向量的长度(即 h i h_i hi 的数量)。该表达式表示对由 h i h_i hi 组成的向量进行线性变换,常用于深度学习模型(如Transformer等)的后处理阶段,对前面模块输出进行进一步的特征变换或整合。
代码实现
导包
import math
import torch
from torch import nn
import dltools
多头注意力结构
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.attention = dltools.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# queries, keys, values 传入的形状: (batch_size, 查询熟练或者键值对数量, num_hiddens)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
# print('queries:', queries.shape)
# print('keys:', keys.shape)
# print('values:', values.shape)
if valid_lens is not None:
valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
# output shape: (batch_size * num_heads, 查询的个数, num_hiddens/num_heads)
output = self.attention(queries, keys, values, valid_lens)
# print('output:', output.shape)
output_concat = transpose_output(output, self.num_heads)
# print('output_concat:', output_concat.shape)
return self.W_o(output_concat)
qkv转换
def transpose_qkv(X, num_heads):
# 输入X的shape: (batch_size, 查询数/键值对数, num_hiddens)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
X = X.permute(0, 2, 1, 3) # batch_size, num_heads, 查询数/ 键值对数, num_hiddens/num_heads
# 这里是把batch_size和num_heads合并在一起了.
return X.reshape(-1, X.shape[2], X.shape[3]) # batch_size * num_heads, 查询/键值对数, num_hiddens/ num_heads
output转换
def transpose_output(X, num_heads):
# 逆转transpose_qkv的操作
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
构建注意力模块
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.2)
attention.eval()
添加Bahdanau的decoder
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_heads, num_layers, dropout=0, **kwargs):
super().__init__(**kwargs)
self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, dropout)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
self.dense = nn.Linear(num_hiddens, vocab_size)
def init_state(self, enc_outputs, enc_valid_lens, *args):
# outputs : (batch_size, num_steps, num_hiddens)
# hidden_state: (num_layers, batch_size, num_hiddens)
outputs, hidden_state = enc_outputs
return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
def forward(self, X, state):
# enc_outputs (batch_size, num_steps, num_hiddens)
# hidden_state: (num_layers, batch_size, num_hiddens)
enc_outputs, hidden_state, enc_valid_lens = state
# X : (batch_size, num_steps, vocab_size)
X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
X = X.permute(1, 0, 2)
outputs, self._attention_weights = [], []
for x in X:
query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens
# print('query:', query.shape) # 4, 1, 16
context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
# print('context: ', context.shape)
x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
# print('x: ', x.shape)
out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
# print('out:', out.shape)
# print('hidden_state:', hidden_state.shape)
outputs.append(out)
self._attention_weights.append(self.attention_weights)
# print('---------------------------------')
outputs = self.dense(torch.cat(outputs, dim=0))
# print('解码器最终输出形状: ', outputs.shape)
return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
@property
def attention_weights(self):
return self._attention_weights
训练
embed_size, num_hiddens, num_layers, dropout = 32, 100, 2, 0.1
batch_size, num_steps, num_heads = 64, 10, 5
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()
train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqMultiHeadAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
预测
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est paresseux .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000