当前位置: 首页 > article >正文

【模型学习之路】手写+分析Transformer

手写+分析transformer

目录

前言

positional encoding

注意力机制

多头注意力

高维度乘法

多头注意力机制

多头注意力层的实现

Encoder

FeedForwardNet

EncoderLayer

Encoder

Decoder

DecoderLayer

Decoder

组装Trasformer!

后话

测试一下

mask


前言

Attention is all you need!

读本文前,建议至少看懂【Transformer】10分钟学会Transformer | Pytorch代码讲解 | 代码可运行 - 知乎的图解部分。当然大佬可以忽视这一句话。

positional encoding

每个词可以由一个词向量表示,我们这里假设单词可以由长度为n的向量表示。那么,一个有m词的句子就可以用一个(m,n)的矩阵来表示。

为了方便展示,我们这里取m=3, n=4。

在transformer中,一个句子的每一个字(词)是并行计算的,所以我们在输入的时候需要提前引入位置信息。计算positional encoding,公式如下(这里的索引是从0开始的):

然后将两个矩阵相加

注意力机制

用A分别与  三个矩阵相乘,这三个矩阵是我们神经网络中要训练的参数,得到      三个矩阵。其中q=k

这里写成分块矩阵是因为,一行正好代表一个字(词),这样写可以方便看到注意力到底干了什么。红色的是维度信息。

​​​​​​​

之后:

做一个小小标准化,然后按行做一个softmax

将softmax输出的值继续使用:

​​​​​​​

总结得到我们在Attention时干的事情:

​​​​​​​

多头注意力

高维度乘法

下面提一下numpy和torch的高维度乘法。

numpy和torch的高维度乘法:只在最后两个维度做矩阵乘法,在前面的维度中还是按位相乘的逻辑(有广播性质)。

举栗子:

AB是两个矩阵,shape分别为(m,s)和(s,n),令

    ​​​​​​​

则:

​​​​​​​

说明一下, 这种写法本应该是分块矩阵,但在这里只是借用一下这个符号,这里表示4个2维的 在新的维度上拼接成了一个3维的张量,具体为啥,代码里面要用(狗头)。

只在最后两个维度做矩阵乘法,在前面的维度中还是按位相乘的逻辑,当前面的维度不匹配且符合广播条件时,就会广播:

推广到4维同理。

4维乘3维,发生广播:

​​​​​​​

4维乘2维,发生广播:

​​​​​​​

上代码!

import torch

a = torch.zeros((666,2,3))
b = torch.zeros((666,3,4))
print((a @ b).shape)   # 666, 2, 4

a = torch.zeros((2,3))
b = torch.zeros((666,3,4))
print((a @ b).shape)  # 666, 2, 4

a = torch.zeros((999,666,2,3))
b = torch.zeros((999,666,3,4))
print((a @ b).shape)  # 999, 666, 2, 4

a = torch.zeros((999,1,2,3))
b = torch.zeros((999,666,3,4))
print((a @ b).shape)  # 999, 666, 2, 4

a = torch.zeros((666,2,3))
b = torch.zeros((999,666,3,4))
print((a @ b).shape)   # 999, 666, 2, 4

a = torch.zeros((2,3))
b = torch.zeros((999,666,3,4))
print((a @ b).shape)   # 999, 666, 2, 4

多头注意力机制

之前,我们有这样的公式:

 

​​​​​​​

就像从一个卷积核向多个卷积核过渡一样,我们可以使用多个注意力头,结合之前提到的多维张量的乘法,得到下面这些式子,其中h是头的数量:

其他的和单头大差不差。

 

          

多头注意力层的实现

实现如下图(注意:nn.Linear在应对三维输入时,只会在最后一维运用线性变换,如shape为(p,q,m)的输入经过Linear(m,n)会变成(p,q,n) ,四维及以上同理)。

下图的B显然是Batch_size。

分别指用来生成 的矩阵,在这里显然三者都等于输入进来的 ,这里之所以分开写,是因为后面会出现三者不是同一个矩阵的情况。

  其实就是Linear层,这个容易设计。

个人其实感觉多头注意力:

本质上几乎就是“单个更长的注意力头” (事实上,代码里也是这么用的):

   ​​​​​​​

看了一下attention is all you need论文原文,应该是后者,当时写的时候是按前者写的,算了,差别不大。

在最后输出时,transformer还用到了resnet的思想,output加上了输入。

此外,我们在设计网络时,还要考虑一个叫掩码的东东,它的维度和attn(见下图)一样,在特定的场景(模型的具体应用场景)中使用,用来替换或“屏蔽”attn中特定的值,这里也简单实现一下。

先实现Attention函数,四个输入:Q K V mask

def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
              mask: torch.Tensor = None):
    k = K.size(-1)  # [B, h, m, k]
    scores = Q @ K.transpose(-2, -1) / np.sqrt(k)
    if mask is not None:
        scores = scores.masked_fill(mask, -1e9)
    attn = nn.Softmax(dim=-1)(scores)
    return attn @ V, attn

进一步,实现MultiHeadAttention类:

这里设置了n1和n2,不过在这一步,n1=n2=n,不过在后面会自然会出现两者不一样的情况。

class MultiHeadAttention(nn.Module):
    def __init__(self, n1, n2, h, k, v):
        # n1 表示encoder的, n2 表示decoder的
        super(MultiHeadAttention, self).__init__()
        self.n1 = n1
        self.n2 = n2
        self.h = h
        self.k = k
        self.v = v
        self.W_Q = nn.Linear(n2, h * k, bias=False)
        self.W_K = nn.Linear(n1, h * k, bias=False)
        self.W_V = nn.Linear(n1, h * v, bias=False)
        self.fc = nn.Linear(h * v, n2, bias=False)

    def forward(self, to_Q, to_K, to_V, mask=None):
        res = to_Q
        batch_size = to_Q.size(0)
        Q = self.W_Q(to_Q).view(batch_size, -1, self.h, self.k).transpose(1, 2)
        K = self.W_K(to_K).view(batch_size, -1, self.h, self.k).transpose(1, 2)
        V = self.W_V(to_V).view(batch_size, -1, self.h, self.v).transpose(1, 2)
        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.h, 1)
        out_put, attn = attention(Q, K, V, mask)
        out_put = out_put.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.v)
        out_put = self.fc(out_put) + res
        out_put = nn.LayerNorm(self.n2)(out_put)
        return out_put, attn

我们封装一下,完成Multi-HeadAttention和Add&Norm部分。

 

Encoder

FeedForwardNet

首先是一个前馈层,这个很简单,一笔带过

class FeedForwardNet(nn.Module):
    def __init__(self, n, d_ff):
        super(FeedForwardNet, self).__init__()
        self.n = n
        self.fc = nn.Sequential(
            nn.Linear(n, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, n, bias=False)
        )

    def forward(self, x):
        out = self.fc(x) + x
        out = nn.LayerNorm(self.n)(out)
        return out

EncoderLayer

然后就可以将我们实现的MultiHeadAttention与前馈层组合起来,形成一个EncoderLayer。

class EncoderLayer(nn.Module):
    def __init__(self, n, h, k, v, d_ff):
        super(EncoderLayer, self).__init__()
        self.d_ff = d_ff
        self.multi_head_attention = MultiHeadAttention(n, n, h, k, v)
        self.feed_forward_net = FeedForwardNet(n, d_ff)

    def forward(self, x, mask=None):
        out, attn = self.multi_head_attention(x, x, x, mask)
        out = self.feed_forward_net(out)
        return out, attn

Encoder

对了,差点忘记写positional encoding的代码了。

以及embedding的代码,这里就不解释embedding是啥意思了。

这里有一次广播,因为不同句子(B个句子(语义块),每个句子m个字,每个字表示为长为n的向量)对应的position矩阵肯定是一样的。

class PositionalEncoding(nn.Module):
    # 这里直接搬运那条知乎作者的代码
    def __init__(self, n, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pos_table = np.array([
            [pos / np.power(10000, 2 * i / n) for i in range(n)]
            if pos != 0 else np.zeros(n) for pos in range(max_len)])
        pos_table[1:, 0::2] = np.sin(pos_table[1:, 0::2])
        pos_table[1:, 1::2] = np.cos(pos_table[1:, 1::2])
        self.pos_table = torch.FloatTensor(pos_table)  # [m, n]

    def forward(self, x):  # x: [B, m, n]
        x += self.pos_table[:x.size(1), :]
        return self.dropout(x)
class Encoder(nn.Module):
    def __init__(self, vocab, n, h, k, v, d_ff, n_layers):
        super(Encoder, self).__init__()
        self.emb = nn.Embedding(vocab, n)  # [B, m] -> [B, m, vocab] -> [B, m, n]
        self.pos = PositionalEncoding(n)
        self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)
                                     for _ in range(n_layers)])

    def forward(self, x, mask=None):
        x = self.emb(x)
        x = self.pos(x)
        attn_lst = []
        for layer in self.layers:
            x, attn = layer(x, mask)
            attn_lst.append(attn)
        return x, attn_lst

Decoder

DecoderLayer

首先是单个decoder layer。

这里有两种MultiHeadAttention,一种输入全为decoder里的值,一种还要接收从endoder里出来的值。前者和encoder里面的一个意思,我们重点关注后者。

首先要知道一点,在NLP中,encoder和decoder输入的是不同的语言,也就是说,二者的m,n,vacab都是不一样的。在这里,encoder的表示为m,n。decoder的表示为M,N。这里只是为了写起来方便,代码中用m1 m2 n1 n2。

其中enc_attn里面长这样:

class DecoderLayer(nn.Module):
    def __init__(self, n1, n2, h, k, v, d_ff):
        super(DecoderLayer, self).__init__()
        self.enc_attn = MultiHeadAttention(n1, n2, h, k, v)
        self.dec_attn = MultiHeadAttention(n2, n2, h, k, v)
        self.ffn = FeedForwardNet(n2, d_ff)

    def forward(self, dec_in, enc_out, enc_mask=None, dec_mask=None):
        out, dec_attn = self.dec_attn(dec_in, dec_in, dec_in, dec_mask)
        out, enc_attn = self.enc_attn(out, enc_out, enc_out, enc_mask)
        out = self.ffn(out)
        return out, dec_attn, enc_attn

Decoder

然后做整个Decoder

class Decoder(nn.Module):
    def __init__(self, vocab, n, h, k, v, d_ff, n_layers):
        super(Decoder, self).__init__()
        self.emb = nn.Embedding(vocab, n)
        self.pos = PositionalEncoding(n)
        self.layers = nn.ModuleList([DecoderLayer(n, h, k, v, d_ff)
                                     for _ in range(n_layers)])

    def forward(self, x, enc_out, enc_mask=None, dec_mask=None):
        x = self.emb(x)
        x = self.pos(x)
        attn_lst = []
        for layer in self.layers:
            x, dec_attn, enc_attn = layer(x, enc_out, enc_mask, dec_mask)
            attn_lst.append((dec_attn, enc_attn))
        return x, attn_lst

组装Trasformer!

最后,组装!

class Transformer(nn.Module):
    def __init__(self, vocab1, vocab2, n1, n2, h, k, v, d_ff, n_layers,
                 enc_mask=None, dec_mask=None, enc_dec_mask=None):
        super(Transformer, self).__init__()
        self.enc_mask = enc_mask
        self.dec_mask = dec_mask
        self.enc_dec_mask = enc_dec_mask
        self.encoder = Encoder(vocab1, n1, h, k, v, d_ff, n_layers)
        self.decoder = Decoder(vocab2, n1, n2, h, k, v, d_ff, n_layers)
        self.out = nn.Linear(n2, vocab2, bias=False)  # [B, m, n] -> [B, m, vocab]
        
    def forward(self, enc_in, dec_in):
        enc_out, enc_attn_lst = self.encoder(enc_in, self.enc_mask)
        dec_out, dec_attn_lst = self.decoder(dec_in, enc_out, self.enc_dec_mask, self.dec_mask)
        out = self.out(dec_out)
        return out, enc_attn_lst, dec_attn_lst    

后话

测试一下

if __name__ == '__main__':
    test_enc_in = torch.randint(0, 50, (2, 10))  # [B, m1]
    test_dec_in = torch.randint(0, 50, (2, 12))  # [B, m2]

    model = Transformer(vocab1=100, vocab2=70, n1=32, n2=64, h=8, k=32, v=128, d_ff=128, n_layers=6,
                        enc_mask=None, dec_mask=None, enc_dec_mask=None)
    test_out, enc_attn_lst, dec_attn_lst = model(test_enc_in, test_dec_in)
    print(test_out.shape)  # [B, m2, vocab2]
    # torch.Size([2, 12, 70])

mask

最后聊聊mask。

mask是transformer的精髓之一,不同领域使用这个模型会用不一样的mask,这里的代码也是为mask提供了接口——三个位置的mask。

在NLP中,mask讲解可以参考这个手撕Transformer(二)| Transformer掩码机制的两个功能,三个位置的解析及其代码_transformer 掩码-CSDN博客

代码可以参考这个【深度学习】Transformer中的mask机制超详细讲解_transformer mask-CSDN博客



http://www.kler.cn/a/373057.html

相关文章:

  • Django 5 增删改查 小练习
  • 【论文阅读笔记】VLP: A Survey on Vision-language Pre-training
  • 提升网站速度与性能优化的有效策略与实践
  • 3D Gaussian Splatting代码详解(二):模型构建
  • 分别用webpack和vite注册全局组件
  • python之数据结构与算法(数据结构篇)-- 集合
  • 2024第二次随堂测验参考答案
  • 【C++】——高效构建与优化二叉搜索树
  • docker容器和宿主机端口映射
  • Linux 命令行学习:数据流控制、文本处理、文件管理与自动化脚本 (第二天)
  • Python Requests 的高级使用技巧:应对复杂 HTTP 请求场景
  • 《达梦》达梦数据库安装步骤(VMware16+麒麟 10+DM8)
  • 中小企业设备维护新策略:Spring Boot系统设计与实现
  • Tauri(一)——更适合 Web 开发人员的桌面应用开发解决方案 ✅
  • D365 FO开发参考
  • 应对市场变化与竞争对手挑战的策略
  • 分类算法——XGBoost 详解
  • Git 创建新的分支但清空提交记录
  • Linux 中,flock 对文件加锁
  • 智能听诊器:宠物健康监测的新纪元
  • [0260].第25节:锁的不同角度分类
  • 【简道云 -注册/登录安全分析报告】
  • 【STM32 Blue Pill编程实例】-I2C主从机通信(中断、DMA)
  • 1.STM32之定时器TIM---第一部分(基本定时器)(功能最强大结构最复杂的一个外设)(实验基本定时功能)-----定时器定时中断(利用内部时钟72M)
  • OpenCV视觉分析之目标跟踪(7)目标跟踪器类TrackerVit的使用
  • VueRouter引入步骤