理解 Transformer 中的编码器-解码器注意力层(Encoder-Decoder Attention Layer)
12. Encoder-Decoder Attention Layer
理解 Transformer 中的编码器-解码器注意力层(Encoder-Decoder Attention Layer)
在 Transformer 模型中,编码器-解码器注意力层的作用就像翻译过程中译员不断参考原文来生成准确的译文。这个过程在机器翻译任务中尤为重要,因为解码器需要时刻“回头”检查编码器生成的信息,从而生成正确的目标语句。
编码器-解码器注意力层的作用是什么?
在翻译任务中,解码器在生成每个词时都需要参考原句的上下文。这个层的作用就是帮助解码器根据编码器的结果找到与当前词最相关的原句信息。
可以把它想象成翻译人员在翻译过程中不断看回原文。例如在翻译“我在吃苹果”时,解码器在生成每个中文词(如“我”或“吃”)时,都需要回看编码器的输出,看它对应的是英文原文中的哪个词。
- 编码器:负责处理输入句子的所有词,将它们变成包含信息的向量表示。
- 编码器-解码器注意力层:解码每个词时,会查看编码器的结果,找出当前词最相关的信息。
工作原理:如何进行计算?
编码器-解码器注意力层的计算过程与自注意力机制类似,但来源略有不同:
- Query:来自解码器当前步骤的向量。
- Key 和 Value:来自编码器处理后的所有词的向量。
公式如下:
Attention ( Q , K , V ) = softmax ( Q ⋅ K T d k ) ⋅ V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V Attention(Q,K,V)=softmax(dkQ⋅KT)⋅V
- Query:解码器的输出,用来询问“我当前的翻译需要关注什么?”。
- Key 和 Value:来自编码器的输出,表示输入句子的全部信息。
- Softmax:将注意力权重转化为概率分布,让模型专注于当前需要的关键词。
类比:在课本中查找答案
想象你在做测验时允许参考课本。当遇到问题“地球为什么会有四季?”时,你不需要翻阅整本书,而会找到与“地理、季节”相关的章节。这种查找过程就像是解码器根据编码器的输出选择性参考信息:
- 课本的内容(Key 和 Value):是编码器处理后的输入信息。
- 你的问题(Query):是解码器当前翻译词的需求。
- 查阅过程:通过计算相似度,找出最相关的信息并重点关注。
代码实现:编码器-解码器注意力层
以下代码展示了如何使用 PyTorch 实现编码器-解码器注意力层:
import torch
import torch.nn as nn
class EncoderDecoderAttention(nn.Module):
def __init__(self, embed_size, heads):
super(EncoderDecoderAttention, self).__init__()
self.heads = heads
self.embed_size = embed_size
self.head_dim = embed_size // heads
assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"
self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
self.fc_out = nn.Linear(self.embed_size, self.embed_size)
def forward(self, query, key, value, mask):
N = query.shape[0]
query_len, key_len = query.shape[1], key.shape[1]
queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)
keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
values = self.values(value).view(N, key_len, self.heads, self.head_dim)
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) # (N, heads, query_len, key_len)
if mask is not None:
energy = energy.masked_fill(mask == 0, float("-1e20"))
attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
N, query_len, self.embed_size
)
out = self.fc_out(out)
return out
# 测试编码器-解码器注意力层
embed_size = 256
heads = 8
attention_layer = EncoderDecoderAttention(embed_size, heads)
query = torch.rand(64, 10, embed_size)
key = torch.rand(64, 15, embed_size)
value = torch.rand(64, 15, embed_size)
mask = None
out = attention_layer(query, key, value, mask)
print(out.shape) # 输出应为 (64, 10, 256)
代码解析
- EncoderDecoderAttention 类:定义了编码器-解码器注意力层,接受来自解码器的 Query 和编码器的 Key、Value。
- forward() 方法:计算 Query 和 Key 的相似度,使用 Softmax 得到注意力权重,再加权 Value。
- einsum 函数:高效地进行矩阵乘法,生成注意力权重和加权输出。
总结
编码器-解码器注意力层在 Transformer 中的核心作用是让解码器在生成每个词时都能“回看”编码器的输出,从而确保输出内容准确、连贯。这种结构在机器翻译中尤为关键,因为它帮助模型根据输入语句的上下文生成合理的输出。
希望这篇文章帮你更好地理解了编码器-解码器注意力层的工作原理和代码实现!如果有疑问,欢迎留言讨论!