当前位置: 首页 > article >正文

理解 Transformer 中的编码器-解码器注意力层(Encoder-Decoder Attention Layer)

12. Encoder-Decoder Attention Layer

理解 Transformer 中的编码器-解码器注意力层(Encoder-Decoder Attention Layer)

在 Transformer 模型中,编码器-解码器注意力层的作用就像翻译过程中译员不断参考原文来生成准确的译文。这个过程在机器翻译任务中尤为重要,因为解码器需要时刻“回头”检查编码器生成的信息,从而生成正确的目标语句。


编码器-解码器注意力层的作用是什么?

在翻译任务中,解码器在生成每个词时都需要参考原句的上下文。这个层的作用就是帮助解码器根据编码器的结果找到与当前词最相关的原句信息。

可以把它想象成翻译人员在翻译过程中不断看回原文。例如在翻译“我在吃苹果”时,解码器在生成每个中文词(如“我”或“吃”)时,都需要回看编码器的输出,看它对应的是英文原文中的哪个词。

  • 编码器:负责处理输入句子的所有词,将它们变成包含信息的向量表示。
  • 编码器-解码器注意力层:解码每个词时,会查看编码器的结果,找出当前词最相关的信息。

工作原理:如何进行计算?

编码器-解码器注意力层的计算过程与自注意力机制类似,但来源略有不同:

  • Query:来自解码器当前步骤的向量。
  • Key 和 Value:来自编码器处理后的所有词的向量。

公式如下:

Attention ( Q , K , V ) = softmax ( Q ⋅ K T d k ) ⋅ V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right) \cdot V Attention(Q,K,V)=softmax(dk QKT)V

  • Query:解码器的输出,用来询问“我当前的翻译需要关注什么?”。
  • Key 和 Value:来自编码器的输出,表示输入句子的全部信息。
  • Softmax:将注意力权重转化为概率分布,让模型专注于当前需要的关键词。

类比:在课本中查找答案

想象你在做测验时允许参考课本。当遇到问题“地球为什么会有四季?”时,你不需要翻阅整本书,而会找到与“地理、季节”相关的章节。这种查找过程就像是解码器根据编码器的输出选择性参考信息:

  • 课本的内容(Key 和 Value):是编码器处理后的输入信息。
  • 你的问题(Query):是解码器当前翻译词的需求。
  • 查阅过程:通过计算相似度,找出最相关的信息并重点关注。

代码实现:编码器-解码器注意力层

以下代码展示了如何使用 PyTorch 实现编码器-解码器注意力层:

import torch
import torch.nn as nn

class EncoderDecoderAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(EncoderDecoderAttention, self).__init__()
        self.heads = heads
        self.embed_size = embed_size
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"

        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size)

    def forward(self, query, key, value, mask):
        N = query.shape[0]
        query_len, key_len = query.shape[1], key.shape[1]

        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
        values = self.values(value).view(N, key_len, self.heads, self.head_dim)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.embed_size
        )

        out = self.fc_out(out)
        return out

# 测试编码器-解码器注意力层
embed_size = 256
heads = 8
attention_layer = EncoderDecoderAttention(embed_size, heads)

query = torch.rand(64, 10, embed_size)
key = torch.rand(64, 15, embed_size)
value = torch.rand(64, 15, embed_size)
mask = None

out = attention_layer(query, key, value, mask)
print(out.shape)  # 输出应为 (64, 10, 256)


代码解析

  1. EncoderDecoderAttention 类:定义了编码器-解码器注意力层,接受来自解码器的 Query 和编码器的 Key、Value。
  2. forward() 方法:计算 Query 和 Key 的相似度,使用 Softmax 得到注意力权重,再加权 Value。
  3. einsum 函数:高效地进行矩阵乘法,生成注意力权重和加权输出。

总结

编码器-解码器注意力层在 Transformer 中的核心作用是让解码器在生成每个词时都能“回看”编码器的输出,从而确保输出内容准确、连贯。这种结构在机器翻译中尤为关键,因为它帮助模型根据输入语句的上下文生成合理的输出。

希望这篇文章帮你更好地理解了编码器-解码器注意力层的工作原理和代码实现!如果有疑问,欢迎留言讨论!


http://www.kler.cn/a/382362.html

相关文章:

  • Synchronized锁、锁的四种状态、锁的升级(偏向锁,轻量级锁,重量级锁)
  • 书生大模型第三关Git 基础知识
  • [论文][环境]3DGS+Colmap环境搭建_WSL2_Ubuntu22.04 - 副本
  • spark的RDD分区的设定规则
  • 简介Voronoi图Voronoi Diagrams
  • 商品信息的修改、删除功能
  • 【测试语言篇一】Python进阶篇:内置容器数据类型
  • 24年配置CUDA12.4,Pytorch2.5.1,CUDAnn9.5运行环境
  • 【C++】踏上C++学习之旅(五):auto、范围for以及nullptr的精彩时刻(C++11)
  • 【LeetCode热题100】哈希表
  • 【大模型LLM面试合集】大语言模型架构_bert细节
  • [ DOS 命令基础 3 ] DOS 命令详解-文件操作相关命令
  • 三周精通FastAPI:27 使用使用SQLModel操作SQL (关系型) 数据库
  • 视图-数据库sqlserver
  • jmeter 性能测试步骤是什么?
  • 代码随想录训练营Day18 | 77. 组合 - 216.组合总和III - 17.电话号码的字母组合
  • Qml组件之Text
  • DGL库之dgl.function.u_mul_e(代替dgl.function.src_mul_edge)
  • 模拟实现strcat函数
  • 线程池核心参数有哪些
  • Vue 组件传递数据-Props(六)
  • Vue+Springboot 前后端分离项目如何部署?
  • 【FPGA】Verilog:理解德摩根第一定律: ( ̅A + ̅B) = ̅A x ̅B
  • 爬虫下载网页文夹
  • 【C++刷题】力扣-#697-数组的度
  • 【人工智能】Transformers之Pipeline(二十二):零样本文本分类(zero-shot-classification)