当前位置: 首页 > article >正文

【自然语言处理(NLP)】Bahdanau 注意力(Bahdanau Attention)原理及代码实现

文章目录

  • 介绍
  • Bahdanau 注意力(Bahdanau Attention)
    • 原理
      • 公式含义
      • 计算过程
      • 编码器部分
      • 注意力机制部分
      • 解码器部分
    • 计算过程
    • 代码实现
      • 导包
      • 定义注意力解码器
      • 添加Bahdanau的decoder
      • 训练
      • 评估指标 bleu
      • 开始预测

个人主页:道友老李
欢迎加入社区:道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

Bahdanau 注意力(Bahdanau Attention)

Bahdanau注意力(Bahdanau Attention)是自然语言处理中一种经典的注意力机制。

在传统的编码器 - 解码器架构(如基于RNN的架构)中,编码器将整个输入序列编码为一个固定长度的向量,解码器依赖该向量生成输出。当输入序列较长时,这种固定长度向量难以存储所有重要信息,导致性能下降。Bahdanau注意力机制通过让解码器在生成每个输出时,动态关注输入序列不同部分,解决此问题。

原理

允许解码器在生成输出时,根据当前状态,从编码器的隐藏状态序列中选择性聚焦,获取与当前生成任务最相关信息,而非仅依赖单一固定向量。

Bahdanau 注意力机制中计算上下文向量的公式:
在这里插入图片描述

公式含义

  • c t c_t ct 表示在解码器的时间步 t t t 时得到的上下文向量,它综合了编码器隐藏状态序列中的信息,用于辅助解码器在该时间步生成输出。
  • T T T 是编码器的时间步总数,意味着要考虑编码器所有时间步的隐藏状态。
  • α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st1,hi) 是注意力权重,它表示在解码器时间步 t − 1 t - 1 t1 的隐藏状态 s t − 1 s_{t - 1} st1 条件下,对编码器第 i i i 个时间步隐藏状态 h i h_i hi 的关注程度。这个权重是通过一个特定的计算(通常涉及一个小型神经网络来计算相似度等)得到,并经过softmax函数归一化,取值范围在 0 0 0 1 1 1 之间,且 ∑ i = 1 T α ( s t − 1 , h i ) = 1 \sum_{i = 1}^{T}\alpha(s_{t - 1}, h_i)=1 i=1Tα(st1,hi)=1
  • h i h_i hi 是编码器在第 i i i 个时间步的隐藏状态,它包含了输入序列在该时间步及之前的信息。

计算过程

  • 首先,根据解码器上一个时间步的隐藏状态 s t − 1 s_{t - 1} st1 和编码器所有时间步的隐藏状态 h i h_i hi i i i 1 1 1 T T T),计算出每个 h i h_i hi 对应的注意力权重 α ( s t − 1 , h i ) \alpha(s_{t - 1}, h_i) α(st1,hi)
  • 然后,将这些注意力权重分别与对应的编码器隐藏状态 h i h_i hi 相乘,并对所有时间步的乘积结果进行求和,就得到了当前解码器时间步 t t t 的上下文向量 c t c_t ct

这个上下文向量 c t c_t ct 后续会与解码器当前时间步 t t t 的隐藏状态等信息结合,用于生成当前时间步的输出,比如在机器翻译任务中预测目标语言的下一个单词。

一个带有Bahdanau注意力的循环神经网络编码器-解码器模型:
在这里插入图片描述

编码器部分

  • 嵌入层:将源序列(如源语言句子中的单词)从离散的符号转换为低维、连续的向量表示,即词嵌入,便于模型后续处理,同时捕捉单词语义关系。
  • 循环层:一般由RNN、LSTM或GRU等单元构成。按顺序处理嵌入层输出的向量序列,每个时间步结合当前输入和上一时刻隐藏状态更新隐藏状态,逐步将源序列信息编码到隐藏状态中,最终输出包含源序列语义信息的隐藏状态序列。

注意力机制部分

位于编码器和解码器之间,允许解码器在生成输出时,根据当前状态从编码器的隐藏状态序列中动态选择相关信息。它计算解码器当前隐藏状态与编码器各时间步隐藏状态的相关性,得到注意力权重,对编码器隐藏状态加权求和生成上下文向量,为解码器提供与当前生成任务相关的信息。

解码器部分

  • 嵌入层:与编码器的嵌入层类似,将目标序列(如目标语言句子中的单词)的离散符号转换为向量表示,不过针对目标语言。
  • 循环层:接收编码器输出的隐藏状态序列以及注意力机制生成的上下文向量,结合目标序列嵌入向量,按顺序处理并更新隐藏状态,生成目标序列下一个元素的预测。
  • 全连接层:对循环层输出进行处理,将隐藏状态映射到目标词汇表维度,经softmax函数计算词汇表中每个单词的概率分布,预测当前时间步最可能的输出单词。

该架构在机器翻译、文本摘要等序列到序列任务中应用广泛,注意力机制可有效解决长序列信息处理难题,提升模型性能。

计算过程

  1. 计算注意力分数:解码器在时间步 t t t的隐藏状态 h t d e c h_t^{dec} htdec作为查询(query),与编码器所有时间步的隐藏状态 h i e n c h_i^{enc} hienc i = 1 , ⋯   , T i = 1, \cdots, T i=1,,T T T T为编码器时间步数)计算注意力分数 e t , i e_{t,i} et,i,一般通过一个小型神经网络计算,如 e t , i = a ( h t d e c , h i e n c ) e_{t,i}=a(h_t^{dec}, h_i^{enc}) et,i=a(htdec,hienc) a a a是一个非线性函数。
  2. 归一化注意力分数:将注意力分数 e t , i e_{t,i} et,i通过softmax函数归一化,得到注意力权重 α t , i \alpha_{t,i} αt,i,即 α t , i = exp ⁡ ( e t , i ) ∑ j = 1 T exp ⁡ ( e t , j ) \alpha_{t,i}=\frac{\exp(e_{t,i})}{\sum_{j = 1}^{T}\exp(e_{t,j})} αt,i=j=1Texp(et,j)exp(et,i),表示编码器第 i i i个时间步对解码器当前时间步 t t t的重要程度。
  3. 计算上下文向量:根据注意力权重对编码器隐藏状态加权求和,得到上下文向量 c t c_t ct c t = ∑ i = 1 T α t , i h i e n c c_t=\sum_{i = 1}^{T}\alpha_{t,i}h_i^{enc} ct=i=1Tαt,ihienc,它包含了与当前生成任务相关的输入信息。
  4. 生成输出:上下文向量 c t c_t ct与解码器当前隐藏状态 h t d e c h_t^{dec} htdec结合,如拼接后输入到后续网络层,生成当前时间步的输出。

代码实现

导包

import torch
from torch import nn
import dltools

定义注意力解码器

class AttentionDecoder(dltools.Decoder):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    @property
    def attention_weights(self):
        raise NotImplementedError

添加Bahdanau的decoder

class Seq2SeqAttentionDecoder(AttentionDecoder):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.attention = dltools.AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, dropout)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)
        
    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # outputs : (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        outputs, hidden_state = enc_outputs
        return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens)
    
    def forward(self, X, state):
        # enc_outputs (batch_size, num_steps, num_hiddens)
        # hidden_state: (num_layers, batch_size, num_hiddens)
        enc_outputs, hidden_state, enc_valid_lens = state
        # X : (batch_size, num_steps, vocab_size)
        X = self.embedding(X) # X : (batch_size, num_steps, embed_size)
        X = X.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        
        for x in X:
            query = torch.unsqueeze(hidden_state[-1], dim=1) # batch_size, 1, num_hiddens
            # print('query:', query.shape) # 4, 1, 16
            context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
            # print('context: ', context.shape)
            x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1)
            # print('x: ', x.shape)
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            # print('out:', out.shape)
            # print('hidden_state:', hidden_state.shape)
            outputs.append(out)
            self._attention_weights.append(self.attention_weights)
            
        # print('---------------------------------')
        outputs = self.dense(torch.cat(outputs, dim=0))
        # print('解码器最终输出形状: ', outputs.shape)
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
    
    @property
    def attention_weights(self):
        return self._attention_weights

encoder = dltools.Seq2SeqEncoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
encoder.eval()
decoder = Seq2SeqAttentionDecoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)
decoder.eval()

# batch_size 4, num_steps 7
X = torch.zeros((4, 7), dtype=torch.long)
state = decoder.init_state(encoder(X), None)
output, state = decoder(X, state)
output.shape, len(state), state[0].shape, len(state[1]), state[1][0].shape
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
query: torch.Size([4, 1, 16])
context:  torch.Size([4, 1, 16])
x:  torch.Size([4, 1, 24])
out: torch.Size([1, 4, 16])
hidden_state: torch.Size([2, 4, 16])
---------------------------------
解码器最终输出形状:  torch.Size([7, 4, 10])
(torch.Size([4, 7, 10]), 3, torch.Size([4, 7, 16]), 2, torch.Size([4, 16]))

训练

执行训练前,将decoder中的print屏蔽掉!!

embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 200, dltools.try_gpu()

train_iter, src_vocab, tgt_vocab = dltools.load_data_nmt(batch_size, num_steps)
encoder = dltools.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = dltools.EncoderDecoder(encoder, decoder)
dltools.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

评估指标 bleu

def bleu(pred_seq, label_seq, k):
    print('pred_seq', pred_seq)
    print('label_seq:', label_seq)
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - (len_label / len_pred)))
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1
            
        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        score *=  math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))   
    return score

开始预测

engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation = dltools.predict_seq2seq(net, eng, src_vocab, tgt_vocab, num_steps, device)
    print(f'{eng} => {translation}, bleu {dltools.bleu(translation[0], fra, k=2):.3f}')
go . => ('va !', []), bleu 1.000
i lost . => ("j'ai perdu .", []), bleu 1.000
he's calm . => ('il est malade .', []), bleu 0.658
i'm home . => ('je suis chez moi .', []), bleu 1.000

http://www.kler.cn/a/534202.html

相关文章:

  • MyBatis XML文件配置
  • 使用SpringBoot发送邮件|解决了部署时连接超时的bug|网易163|2025
  • C#使用实体类Entity Framework Core操作mysql入门:从数据库反向生成模型2 处理连接字符串
  • Codeforces Round 1002 (Div. 2)(部分题解)
  • 为AI聊天工具添加一个知识系统 之85 详细设计之26 批流一体式 与数据提取器
  • 嵌入式知识点总结 操作系统 专题提升(四)-上下文
  • Day36-【13003】短文,数组的行主序方式,矩阵的压缩存储,对称、三角、稀疏矩阵和三元组线性表,广义表求长度、深度、表头、表尾等
  • 02、NodeJS学习笔记,第二节:express与中间件
  • Redis常见数据类型与编码方式
  • RabbitMQ 与 Kafka 的核心区别,如何选择合适的消息中间件?
  • 【LLM】为何DeepSeek 弃用MST却采用Rejection采样
  • 洛谷P2638 安全系统
  • 解锁.NET Fiddle:在线编程的神奇之旅
  • 【Elasticsearch】filter聚合
  • 信标链的基本概念
  • python基础入门:2.2运算符与表达式
  • 根据SQL导出三线表文档
  • 能否通过蓝牙建立TCP/IP连接来传输数据
  • js-对象-JSON
  • [LeetCode] 二叉树 I — 深度优先遍历(前中后序遍历) | 广度优先遍历(层序遍历):递归法迭代法
  • 微服务知识——微服务架构的演进过程
  • 【完整版】DeepSeek-R1大模型学习笔记(架构、训练、Infra)
  • Mybatis之常用动态Sql语句
  • 云原生周刊:K8s引领潮流
  • Android 中APK 体积优化的几种方法
  • 【科研】 -- 医学图像处理方向,常用期刊链接