当前位置: 首页 > article >正文

gMLP:Pay Attention to MLPs--模型代码讲解

gMLP模型代码讲解

  • Introduction
  • gMLP网络结构
    • Spatial Gating Unit (SGU)
  • code
    • gMLPBlock
    • Spatial Gating Unit

基于MLP-Mixer 的改进…

Introduction

总的来说,gMLP 在视觉和NLP领域的惊人有效性表明,自我注意并不是扩大机器学习模型的必要因素,尽管它根据任务的不同可以是一个有用的补充随着数据和计算量的增加。具有gMLP等更简单的空间交互机制的模型可以像变压器一样强大,分配给自我注意的能力可以被删除或大幅减少。

gMLP网络结构

gMLP 的输入仍为若干图像块(即将一张图像切割成若干图像块),输出为若干个向量(token)堆叠组成的矩阵,例如token的维度为L,个数为N,则输出为N ∗ L 的矩阵,通过池化等操作转换为最终的特征向量。
由若干个基本构成单元堆叠而成
在这里插入图片描述

设输入矩阵(即图中的Input Embeddedings)为 n ∗ d n∗d nd 的矩阵X , n为序列长度, d为特征维度,则gMLP的unit结构可以简化为 Z = δ ( X U ) Z ~ = s ( Z ) Y = δ ( Z ~ V ) + X Z=\delta (XU)\\ \tilde{Z} = s(Z)\\ Y=\delta(\tilde{Z}V)+X Z=δ(XU)Z~=s(Z)Y=δ(Z~V)+X
U , V U,V U,V为可学习的矩阵, δ \delta δ 为激活函数, s ( z ) s(z) s(z) 为图中的Spatial Gating Unit.

Spatial Gating Unit (SGU)

为了能有跨token的交互, s ( ⋅ ) s(\cdot) s() 操作须在空间维度。可以简单的使用线性映射表示: f W , b ( Z ) = W Z + b s ( Z ) = Z ⊙ f W , b ( Z ) f_{W,b}(Z)=WZ+b\\ s(Z)=Z⊙f_{W,b}(Z) fW,b(Z)=WZ+bs(Z)=ZfW,b(Z) Z Z Z n ∗ d n∗d nd 的矩阵,则 W W W n ∗ n n∗n nn 的矩阵,表示空间交互的映射参数,b 为n 维向量(WZ+b表示WZ的第一行元素与b的第一维元素相加),为了保证训练的稳定性,W 初始化值接近于0(貌似用[-1,1]的均匀分布初始化),b 的初始值为1,此时 f W , b ( Z ) ≈ 1 , s ( Z ) ≈ Z f_{W,b}(Z)\approx1,s(Z)\approx Z fW,b(Z)1,s(Z)Z,这种初始化确保了每个gMLP块在训练的早期阶段像一个常规的FFN,其中每个token 都被独立处理,并且只在学习过程中逐步跨token注入空间信息。

更进一步的作者发现将Z 沿着channel维度切割成 Z 1 , Z 2 Z_1,Z_2 Z1,Z2 ( Z 1 , Z 2 Z_1,Z_2 Z1,Z2的维度分别为 n ∗ d 1 , n ∗ d 2 , d 1 + d 2 = n n*d_1,n*d_2,d_1+d_2=n nd1,nd2,d1+d2=n)两个部分更为有效,此时s(Z)操作变为
s ( Z ) = Z 1 ⊙ f W , b ( Z 2 ) s(Z)=Z_1\odot f_{W,b}(Z_2) s(Z)=Z1fW,b(Z2)

code

先看整体结构,在整个gMLP结构中,gmlp代替self-attention设计了框架结构。每一个层级使用gMLPBlock作为一个block阶段。整个残差形式为gmlp(norm(x))+x.

class gMLP(nn.Module):
    def __init__(
            self,
            *,
            ...
    ):
        super().__init__()
        dim_ff = dim * ff_mult
        self.seq_len = seq_len
        self.prob_survival = prob_survival

        self.to_embed = nn.Embedding(num_tokens, dim) if exists(num_tokens) else nn.Identity()

        self.layers = nn.ModuleList([Residual(PreNorm(dim, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, attn_dim = attn_dim, causal = causal, act = act))) for i in range(depth)])
        #  gmlp(norm(x))+x

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_tokens)
        ) if exists(num_tokens) else nn.Identity()

    def forward(self, x):
        x = self.to_embed(x)
        layers = self.layers if not self.training else dropout_layers(self.layers, self.prob_survival)
        out = nn.Sequential(*layers)(x)
        return self.to_logits(out)
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

gMLPBlock

class gMLPBlock(nn.Module):
    def __init__(
            self,
            *,
            dim,
            dim_ff,
            seq_len,
            attn_dim = None,
            causal = False,
            act = nn.Identity()
    ):
        super().__init__()
        self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_ff),
            nn.GELU()
        )
		# dim_ff = dim * ff_mult(4)
		# dim -> dim*4
        self.attn = Attention(dim, dim_ff // 2, attn_dim, causal) if exists(attn_dim) else None

        self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act)
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        gate_res = self.attn(x) if exists(self.attn) else None
		# 默认的attn是None,即不进行该操作
        x = self.proj_in(x)
        x = self.sgu(x, gate_res = gate_res)
        x = self.proj_out(x)
        return x

Spatial Gating Unit

class SpatialGatingUnit(nn.Module):
    def __init__(self, dim, dim_seq, causal = False, act = nn.Identity(), init_eps = 1e-3):
        super().__init__()
        dim_out = dim // 2
        self.causal = causal

        self.norm = nn.LayerNorm(dim_out)
        self.proj = nn.Conv1d(dim_seq, dim_seq, 1)

        self.act = act

        init_eps /= dim_seq
        nn.init.uniform_(self.proj.weight, -init_eps, init_eps)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x, gate_res = None):
        device, n = x.device, x.shape[1]

        res, gate = x.chunk(2, dim = -1)
        # self-atten 用的dim
        # sgu用的dim_ff = dim * ff_mult(4),即4倍
        # chunk之后,每个为2倍,用两倍的值进行attention
        gate = self.norm(gate)

        weight, bias = self.proj.weight, self.proj.bias
        if self.causal:
            ...

        gate = F.conv1d(gate, weight, bias)
		# 1d卷积混合w*h维度的信息,patch通道的混合
        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res

http://www.kler.cn/a/323359.html

相关文章:

  • PostgreSQL TRUNCATE TABLE
  • 8.C++面向对象5(实现一个较为完善的日期类)
  • Chrome 浏览器开启打印模式
  • 动态规划-完全背包问题——518.零钱兑换II
  • Vuex vs Pinia:新一代Vue状态管理方案对比
  • FRP 实现内网穿透
  • 数字通云平台智慧政务 login 存在登录绕过
  • Java | Leetcode Java题解之第435题无重叠区间
  • E9OA解决文档附件没有关联文档正文问题
  • 54K55LyB5p2l5a6i5pyN57O757uf token硬编码漏洞
  • Spring源码学习:SpringMVC(2)DispatcherServlet初始化【子容器9大组件】
  • 对于 Vue CLI 项目如何引入Echarts以及动态获取数据
  • 机器学习-SVM
  • xxl-job 适配达梦数据库
  • StarRocks Elasticsearch Catalog原理简析
  • 【机器学习】目标分类算法概述
  • UI设计师面试整理-作品集展示
  • 基于Hive和Hadoop的招聘分析系统
  • GUI-窗口,模态窗口,拖动窗口
  • centos离线安装nvm
  • 2024新版IDEA创建JSP项目
  • 查看和升级pytorch到指定版本
  • 如何让 Android 的前端页面像 iOS 一样“优雅”?
  • 从 ES5 到 ES14:深入解析 JavaScript 的演进与特性
  • 828华为云征文|部署去中心化网络的 AI 照片管理应用 PhotoPrism
  • 【教程】最新可用! 移动云手机开启Root权限方法