机器翻译之多头注意力(Multi-Head Attention)在Seq2Seq中的应用
目录
1.多头注意力(Multi-Head Attention)的原理图
2.代码实现
2.1创建多头注意力函数
2.2验证上述封装的代码
2.3 创建添加了多头注意力的decoder
2.4训练
2.5预测
3.知识点个人理解
1.多头注意力(Multi-Head Attention)的原理图
2.代码实现
2.1创建多头注意力函数
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on dot-product attention.

    Queries, keys and values are linearly projected to ``num_hiddens``
    features, split into ``num_heads`` heads of ``num_hiddens // num_heads``
    features each and attended per head with ``dltools.DotProductAttention``.
    The per-head outputs are concatenated again and passed through the output
    projection ``W_o``.
    """

    def __init__(self, query_size, key_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        """
        Args:
            query_size: number of features of each query.
            key_size: number of features of each key.
            value_size: number of features of each value.
            num_hiddens: output size of the Q/K/V/O projections; must be
                divisible by ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability handed to the dot-product attention
                (the probability of zeroing elements during training).
            bias: whether the four linear projections use a bias term
                (default: no bias).
            **kwargs: forwarded to ``nn.Module.__init__``.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide
                ``num_hiddens``.
        """
        super().__init__(**kwargs)
        # Every head receives num_hiddens // num_heads features. An uneven
        # split would otherwise only fail later, inside reshape() in forward().
        if num_heads <= 0 or num_hiddens % num_heads != 0:
            raise ValueError(
                f"num_hiddens ({num_hiddens}) must be divisible by a positive "
                f"num_heads ({num_heads})")
        self.num_heads = num_heads
        # Dot-product scores need queries and keys with the same feature size;
        # W_q and W_k below both project to num_hiddens.
        self.attention = dltools.DotProductAttention(dropout)
        # The four learnable projections: queries, keys, values, output.
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    @staticmethod
    def _transpose_qkv(X, num_heads):
        """Split the feature axis into heads and fold heads into the batch axis.

        (batch_size, seq_len, num_hiddens)
        -> (batch_size * num_heads, seq_len, num_hiddens / num_heads)
        """
        # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
        X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
        # (batch_size, num_heads, seq_len, num_hiddens / num_heads). Moving the
        # head axis next to the batch axis is what keeps each head's features
        # together when the two axes are merged below.
        X = X.permute(0, 2, 1, 3)
        # Merged index is batch * num_heads + head (batch-major, head-minor).
        return X.reshape(-1, X.shape[2], X.shape[3])

    @staticmethod
    def _transpose_outputs(X, num_heads):
        """Inverse of ``_transpose_qkv``: concatenate the heads back together.

        (batch_size * num_heads, seq_len, num_hiddens / num_heads)
        -> (batch_size, seq_len, num_hiddens)
        """
        # (batch_size, num_heads, seq_len, num_hiddens / num_heads)
        X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
        # (batch_size, seq_len, num_heads, num_hiddens / num_heads)
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

    def forward(self, queries, keys, values, valid_lens):
        """Attend ``queries`` over ``keys``/``values`` with all heads in parallel.

        Args:
            queries: (batch_size, num_queries, query_size).
            keys: (batch_size, num_kvpairs, key_size).
            values: (batch_size, num_kvpairs, value_size).
            valid_lens: None, or a tensor whose dim 0 is batch_size; it is
                passed (per head) to the dot-product attention.

        Returns:
            Tensor of shape (batch_size, num_queries, num_hiddens).
        """
        queries = self._transpose_qkv(self.W_q(queries), self.num_heads)
        keys = self._transpose_qkv(self.W_k(keys), self.num_heads)
        values = self._transpose_qkv(self.W_v(values), self.num_heads)
        if valid_lens is not None:
            # Each sequence now occupies num_heads consecutive rows of dim 0,
            # so repeat its valid length the same way (element-wise repeat).
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)
        # (batch_size * num_heads, num_queries, num_hiddens / num_heads)
        outputs = self.attention(queries, keys, values, valid_lens)
        # (batch_size, num_queries, num_hiddens)
        outputs_concat = self._transpose_outputs(outputs, self.num_heads)
        return self.W_o(outputs_concat)
2.2验证上述封装的代码
# Toy hyper-parameters for a quick sanity check of MultiHeadAttention.
num_hiddens = 100
num_heads = 5
dropout = 0.2
# Queries, keys and values all carry num_hiddens features in this check.
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, dropout)
# eval() puts the module (and its Dropout) into inference mode before use.
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
# Toy inputs: 2 sequences, 4 queries and 6 key-value pairs per sequence.
batch_size = 2
num_queries = 4
num_kvpairs = 6
# Valid key-value lengths per sequence, forwarded to the attention.
valid_lens = torch.tensor([3, 2])
X = torch.ones(batch_size, num_queries, num_hiddens)  # (2, 4, 100)
Y = torch.ones(batch_size, num_kvpairs, num_hiddens)  # (2, 6, 100)
# Keys and values are the same tensor here; expect (2, 4, 100) out.
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])
2.3 创建添加了多头注意力的decoder
# Seq2seq GRU decoder with multi-head attention (not Bahdanau/additive attention).
class Seq2SeqMultiHeadAttentionDecoder(dltools.AttentionDecoder):
    """GRU decoder that attends over the encoder outputs at every time step.

    At each step the last GRU layer's hidden state is the query. The
    multi-head attention context is concatenated with the embedded input
    token and fed to the GRU.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_heads,
                 num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        # Query, key, value sizes are all num_hiddens (GRU state / encoder outputs).
        self.attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                            num_hiddens, num_heads, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # GRU input = attention context (num_hiddens) + embedded token (embed_size).
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        # Filled by forward(); initialised so the property works before the first call.
        self._attention_weights = []

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        """Build the decoder state from the encoder's (outputs, hidden_state).

        Returns:
            (batch-first encoder outputs, hidden_state, enc_valid_lens).
        """
        # outputs: time-major (num_steps, batch_size, num_hiddens); permute(1, 0, 2)
        # makes it batch-first so it matches the (batch_size, 1, num_hiddens)
        # queries built in forward().
        # hidden_state: (num_layers, batch_size, num_hiddens), used as the GRU's h0.
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)

    def forward(self, X, state):
        """Decode a batch of target-token indices one time step at a time.

        Args:
            X: token indices of shape (batch_size, num_steps).
            state: (enc_outputs, hidden_state, enc_valid_lens) as produced by
                init_state() or a previous forward() call.

        Returns:
            Logits of shape (batch_size, num_steps, vocab_size) and the updated
            state [enc_outputs, hidden_state, enc_valid_lens].
        """
        # enc_outputs: (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # (batch_size, num_steps) -> (num_steps, batch_size, embed_size) to loop over time.
        X = self.embedding(X).permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        for x in X:
            # Query = last GRU layer's hidden state: (batch_size, 1, num_hiddens).
            query = torch.unsqueeze(hidden_state[-1], dim=1)
            # context: (batch_size, 1, num_hiddens)
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # (batch_size, 1, num_hiddens + embed_size)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # The GRU is time-major: feed (1, batch_size, num_hiddens + embed_size).
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            # Record this step's weights from the inner dot-product attention.
            # (Appending self.attention_weights would append this list to itself.)
            # NOTE(review): assumes dltools.DotProductAttention stores its last
            # weights in .attention_weights, presumably shaped
            # (batch_size * num_heads, 1, num_kvpairs); getattr guards against
            # the attribute being absent.
            self._attention_weights.append(
                getattr(self.attention.attention, 'attention_weights', None))
        # (num_steps, batch_size, vocab_size)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]

    @property
    def attention_weights(self):
        """Per-time-step attention weights recorded by the latest forward() call."""
        return self._attention_weights
2.4训练
# Training: hyper-parameters, data, model, then dltools' seq2seq training loop.
embed_size = 32
num_hiddens = 100
num_layers = 2
dropout = 0.1
batch_size = 64
num_steps = 10
# num_hiddens must be divisible by num_heads (100 / 5 = 20 features per head).
num_heads = 5
lr = 0.005
num_epochs = 200
device = dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)

encoder = dltools.Seq2SeqEncoder(
    len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqMultiHeadAttentionDecoder(
    len(tgt_vocab), embed_size, num_hiddens, num_heads, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)

dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
2.5预测
# Translate a few English sentences and score each against its French reference.
engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
for eng, fra in zip(engs, fras):
    # predict_seq2seq returns a pair; element 0 is the translated sentence.
    translation = dltools.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device)
    score = dltools.bleu(translation[0], fra, k=2)
    print(f"{eng} => {translation}, bleu {score:.3f}")
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('trouvez tom .', []), bleu 0.000
i'm home . => ('je suis chez moi .', []), bleu 1.000
3.知识点个人理解