当前位置: 首页 > article >正文

attention 注意力机制 学习笔记-GPT2

注意力机制

这可能是比较核心的地方了。

gpt2 是一个decoder-only模型,也就是仅仅使用decoder层而没有encoder层。

decoder层中使用了masked-attention 来进行注意力计算。在看代码之前,先了解attention-forward的相关背景知识。

在普通的self-attention 中,对于一个长为T的句子,对其中第t个单词。需要计算t和句子中所有T个单词的注意力。也就是使用词t的Q向量 q t q_t qt 和 T中的所有单词的key向量 k j , 0 < = j < = T k_j, 0<=j<=T kj,0<=j<=T相乘。得到词t和句子中其他单词的注意力得分。

在这里插入图片描述

于是对于词t和当前句子S, 得到了注意力得分向量,而后对该向量使用softmax. 标准化的同时得到softmax后的注意力得分。

然后使用 每个词对应的值向量与注意力得分相乘之后再求和
( v 1 , v 2 , . . . , v T ) ( s c o r e t 1 s c o r e t 2 . . . s c o r e t T ) = o u t t (v_1, v_2, ..., v_T) \begin{pmatrix}score_{t1}\\score_{t2}\\... \\score_{tT}\end{pmatrix} = out_t (v1,v2,...,vT) scoret1scoret2...scoretT =outt
这里要注意, s o c r e t i socre_{ti} socreti 是一个标量值,但是 v t v_t vt 是 一个向量,长度和词嵌入向量长度相同,相加时,对每个向量位置元素对应相加。

在这里插入图片描述

对于masked-attention呢,实际上就是计算注意力得分时候,对第t个单词,仅仅计算0到t单词的注意力得分,t~T 部分的注意力得分不计算,计算softmaxs时t之后的部分以初值0代替。

在这里插入图片描述

在这里插入图片描述

multi-head attention

前面了解了attention基本知识,就很好理解多头注意力了。多头注意力实际上就是将单个Q,K,V向量,分裂为多个头,然后和self-attention一样流程计算每个头的注意力,最后得到一个输出向量,然后将多个头的输出向量拼接到一起,得到最后的输出结果。

在这里插入图片描述

比如,原本的一个向量长度为 l e n g t h Q = = l e n g t h K = = l e n g t h V = = 168 length_Q == length_K == length_V == 168 lengthQ==lengthK==lengthV==168 分裂为12个注意力头之后,每个注意力头的QKV向量长度为 l e n g t h Q i = = l e n g t h K i = = l e n g t h V i = 64 , i ∈ [ 0 , 12 ] length_{Q_i} == length_{K_i} == length_{V_i} = 64, i \in [0,12] lengthQi==lengthKi==lengthVi=64,i[0,12]

然后和分裂的self-attention一样,对每个词t的第i个头的Q向量 Q t i Q_{t_i} Qti,与其他词的第i个头的K向量 K j i , 0 < = j < = t , i ∈ [ 0 , 12 ] K_{j_i}, 0<=j<=t, i\in[0,12] Kji,0<=j<=t,i[0,12] 内积,得到注意力得分。

而后和self-attention一样的,每一个注意力头的Value向量和该头的注意力得分相乘,得到该注意力头的结果。

对于12个头长度为64的attention,最后得到12个64长的注意力结果

再将其拼接,得到长为768的注意attention forward结果,和单个注意力头但是长为768的attention结果相同。

在这里插入图片描述


http://www.kler.cn/a/393487.html

相关文章:

  • Could not initialize class sun.awt.X11FontManager
  • Python数据类型(一):bool布尔类型
  • 【练习案例】30个 CSS Javascript 加载器动画效果
  • 【嵌入式开发】单片机CAN配置详解
  • BERT配置详解1:构建强大的自然语言处理模型
  • Java项目实战II基于微信小程序的个人行政复议在线预约系统微信小程序(开发文档+数据库+源码)
  • python---基础语法
  • 【HarmonyOS】Install Failed: error: failed to install bundle.code:9568289
  • CCF认证-202403-04 | 十滴水
  • 人工智能(AI)和机器学习(ML)技术学习流程
  • python 同时控制多部手机
  • 华纳云:数据库一般购买什么服务器好?有哪些建议
  • Flink_DataStreamAPI_输出算子Sink
  • 现代无线通信接收机架构:超外差、零中频与低中频的比较分析
  • 人机界面与人们常说的“触摸屏”有什么区别?这下终于清楚了
  • Java反序列化之CommonsCollections2链的学习
  • golang go语言 组建微服务架构详解 - 代码基于开源框架grpc+nacos服务管理配置平台
  • 详解基于C#开发Windows API的SendMessage方法的鼠标键盘消息发送
  • 时序预测 | 改进图卷积+informer时间序列预测,pytorch架构
  • FPGA实现PCIE3.0视频采集转SDI输出,基于XDMA+GS2971架构,提供工程源码和技术支持
  • ASR+LLM+TTS在新能源汽车中的实战
  • 安装luasocket模块时提示“sudo: luarocks:找不到命令“问题,该如何解决?
  • SDL读取PCM音频
  • Docker在微服务架构中的最佳实践
  • 云速搭助力用友 BIP 平台快速接入阿里云产品
  • 计算机网络(8)数据链路层之子层