当前位置: 首页 > article >正文

【论文复现】ViT:对图片进行分类

在这里插入图片描述

📝个人主页🌹:Eternity._
🌹🌹期待您的关注 🌹🌹

在这里插入图片描述
在这里插入图片描述

❀ ViT:对图片进行分类

  • 概述
  • 模型结构
    • 模型总体框架
    • Patch_embed
    • Transformer Encoder
    • MLP Head
  • 演示效果
  • 核心逻辑
  • 部署方式
  • 参考文献

概述


Transformer架构虽然已经成为自然语言处理任务的标准,但是它在计算机视觉的应用仍然有限,先前的视觉任务中,注意力大多与卷积结合使用。ViT模型的出现,证明了对CNN的依赖是不必要的,直接应用于图像补丁序列的纯Transformer架构可以在图像分类任务中表现良好。

本文所涉及的所有资源的获取方式:这里

模型结构


模型总体框架


在这里插入图片描述
上述是ViT模型的基本框架,可以大致分为三个主要部分

  • Patch_embed(将图片分成一系列的patches)
  • Transformer Encoder(建模不同序列之间的相关性)
  • MLP Head(用于最终的分类结构)

Patch_embed


在标准的Transformer模块中,输入的格式为二维矩阵 [num_token,token_dim] ,但对于图像数据而言,其输入数据的格式为[H,W,C] 的三维矩阵,明显不是Transformer架构需要的。所以需要Patch_embed结构将其转换为Transformer架构的输入。

针对于ViT-B/16而言,将输入图片(224x224)按照大小为(16x16) 的Patch进行划分,生成196个Patch。此时通过线性映射将每个Patch映射到一个长度为768 (16x16x3) 一维向量中。这一步可以通过卷积核大小为16x16,步距为 16 的卷积来实现。最后将长宽进行展平,则得到Transformer需要的输入格式。具体的维度变换如下所示:
[224,224,3] -> [14,14,768] -> [196,768]

在输入到Transformer Encoder之前还需要加上 [class]token z 0 0 z_0^0 z00 = x c l a s s x_{class} xclass,它在Transformer 编码器 z L 0 z_L^0 zL0,输出处的状态用作图像表示 y , 在预训练和微调过程中, z L 0 z_L^0 zL0处都具有一个分类头。

同时需要将Position Embeddin[197,768]叠加(add)到上述的token上.

在这里插入图片描述
如上图所示,第一行第一列的位置编码上与其自身的余弦相似度最高,其次是与第一行和第一列的余弦相似度更高,这符合常理。

Transformer Encoder


Transformer Encoder 本身是堆叠Encoder Block L 次,ViT-B/16是12次。主要有以下几部分组成:

  • Layer Norm: 针对NLP领域提出,因为在RNN这类时序网络中,时序的长度并不一定是一个定值,Layer Norm在每个样本的每个特征维度上进行归一化,使得每个特征的均值为0,方差为1,从而有助于提高模型的训练效果和泛化能力。
  • Multi-head Attention: 使用多头注意力机制能够联合来自不同head部分学习到的信息。
  • MLP Block:由全连接+GELU激活函数+Dropout组成,在ViT-B/16的模型结构中,第一个全连接层将输入节点的个数翻4倍,第二个全连接层键还原节点的个数。

MLP Head


通过Transfomer Encoder后输入的shape和输出的shape保持不变,由于我们只需要分类信息,因此只需要提取[class]token 的结果 z L 0 z_L^0 zL0 ,之后通过MLP Head得到最后的分类结果。

模型的公式如下,其中E表示token的个数
在这里插入图片描述

演示效果


可视化输入图片的形式
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可视化模型运行结果
在这里插入图片描述

核心逻辑


对输入图片进行分块处理

class PatchEmbed(nn.Module):
    def __init__(self,img_size=224,patch_size=16,in_c=3,embed_dim=768,norm_layer=None):
        super(PatchEmbed,self).__init__()
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size[0]//patch_size[0])*(img_size[1]//patch_size[1])
        self.proj = nn.Conv2d(in_c,embed_dim,kernel_size=patch_size,
                               stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim) if norm_layer else nn.Identity()
    
    def forward(self,x):
        # 首先需要判断输入图片的大小符合我们的预期
        B,C,H,W=x.shape
        assert H==self.img_size[0] and W==self.img_size[1],\
            f"input image{H}x{W} does not model {self.img_size[0]}x{self.img_size[1]}"
        # [N,in_c,H,W]->[N,embed_dim,H//16,W//16]->[N,embed_dim,H//16*W//16]
        x = self.proj(x).flatten(2).transpose(1,2)
        
        x = self.norm(x)
        
        return x

多头注意力机制

class Attention(nn.Module):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0,
                 proj_drop_ratio=0):
        super(Attention,self).__init__()
        self.head_dim = dim//num_heads
        self.num_heads = num_heads
        self.dim = dim
        self.scale = qk_scale or self.head_dim**(0.5)
        
        
        self.qkv = nn.Linear(dim,dim*3,bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim,dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)
        
    def forward(self,x):
        # [batch_size,num_patches+class_token,channel:HxW]
        B,N,C = x.shape
        
        # 将其进行投影,也就是多头自注意力机制所说的矩阵相乘
        # reshape [B,N,C]->[B,N,3,heads,head_dim]->[3,B,heads,N,head_dim]
        qkv = self.qkv(x).reshape(B,N,3,self.num_heads,self.head_dim).permute(2,0,3,1,4)
        # [B,heads,N,head_dim]
        q,k,v=qkv[0],qkv[1],qkv[2]
        # [B,heads,N,N]
        attn = (q@k.transpose(-2,-1))*self.scale
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # [B,heads,N,head_dim]->[B,N,heads,head_dim]->[B,N,heads*head_dim]
        # x = attn@v.permute(0,2,1,3).flatten(2)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        
        x = self.proj_drop(self.proj(x))
        return x

MLP 模块

class MLP(nn.Module):
    def __init__(self,in_features,hidden_features=None,out_features=None,act_layer=nn.GELU,drop=0.):
        super(MLP,self).__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features,hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features,out_features)
        self.drop = nn.Dropout(drop)
    
    # 根据流程图确定其中的结构,注意是先激活函数之后才是dropout操作   
    def forward(self,x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

Block 结构

class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 drop_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm
                 ):
        super(Block,self).__init__()
        
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim=dim,num_heads=num_heads,qkv_bias=qkv_bias,qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio,proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        # Linear都需要是int的类型数据
        self.mlp = MLP(dim,int(dim*mlp_ratio),dim,act_layer,drop_ratio)
        
        self.norm2 = norm_layer(dim)
        
    def forward(self,x):
        x = x+self.drop_path(self.attn(self.norm1(x)))
        x = x+self.drop_path(self.mlp(self.norm2(x)))
        
        return x

ViT 模块

class VisionTransformer(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_c=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 representation_size=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 embed_layer=PatchEmbed,
                 norm_layer=None,
                 act_layer=None,
                 ):
        super(VisionTransformer,self).__init__()
        # 首先需要进行初始化操作,还可以对权重进行初始化操作
        self.num_classes = num_classes
        self.embed_dim = self.num_features = embed_dim
        self.num_tokens = 1 
        act_layer = act_layer or nn.GELU
        norm_layer = norm_layer or partial(nn.LayerNorm,eps=1e-6)
        
        self.patch_embed = embed_layer(img_size=img_size,patch_size=patch_size,in_c=in_c,embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        
        # cls_token是针对于每个embed_dim确定一个class
        # pos_embed除了channel 还要针对于每一个patch确定结果
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1,num_patches+self.num_tokens,embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
        
        dpr = [x.item() for x in torch.linspace(0,drop_path_ratio,depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim,num_heads=num_heads,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,
                  drop_ratio=drop_ratio,attn_drop_ratio=attn_drop_ratio,drop_path_ratio=dpr[i],
                  norm_layer=norm_layer,act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        
        # Pre_logits layer 相当于多添加了一个全连接层
        if representation_size:
            self.has_logits=True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc',nn.Linear(embed_dim,representation_size))
                ('out',nn.Tanh())]
            ))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()
            
        self.head = nn.Linear(self.num_features,self.num_classes) if num_classes>0 else nn.Identity()
        
        # 开始对所有的权重进行初始化操作
        nn.init.trunc_normal_(self.pos_embed,std=0.02)
        nn.init.trunc_normal_(self.cls_token,std=0.02)
        self.apply(_init_vit_weights)
        
    def forward(self,x):
        B,C,H,W = x.shape
        #[B,C,H,W]->[B,N,H*W]
        x = self.patch_embed(x)
        # 每次都需要进行操作,所以不能对其本身进行expand操作
        cls_token = self.cls_token.expand(B,-1,-1)
        
        # 注意到后续一个是cat操作一个是add操作,且位置的先后关系
        x = torch.cat((cls_token,x),dim=1)
        
        # self.pos_embed中针对于一个batch值共享
        x = self.pos_drop(x+self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        
        
        x = self.pre_logits(x[:,0])
        x = self.head(x)
        return x  

部署方式


python 3.7.16

torch == 1.13.1
torchvision == 0.14.1
tqdm == 4.66.2
pillow == 9.5.0
matplotlib == 3.5.3

参考文献


论文下载地址
源码参考地址
参考博客地址


编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!

更多内容详见:这里


http://www.kler.cn/a/415211.html

相关文章:

  • 【案例学习】如何使用Minitab实现包装过程的自动化和改进
  • 单片机知识总结(完整)
  • HarmonyOS(60)性能优化之状态管理最佳实践
  • IDEA报错: java: JPS incremental annotation processing is disabled 解决
  • AI智算-正式上架GPU资源监控概览 Grafana Dashboard
  • C语言——海龟作图(对之前所有内容复习)
  • RHCE NFS
  • 网络连接设备与技术
  • VSCode修改资源管理器文件目录树缩进(VSCode目录结构、目录缩进、文件目录外观)workbench.tree.indent
  • AI开发:生成式对抗网络入门 模型训练和图像生成 -Python 机器学习
  • 《Python基础》之OS模块
  • 第04章_运算符(基础)
  • C# 解决【托管调试助手 “ContextSwitchDeadlock“:……】问题
  • 《代码随想录》刷题笔记——栈与队列篇【java实现】
  • 【力扣】389.找不同
  • SLAM算法融合处理多源信息实现定位和姿态估计,并最终完成路径规划、运动控制和避障与动态环境应对
  • 支持多种快充协议的取电芯片,支持最大功率140W
  • Python学习入门教程
  • 路径规划之启发式算法之一:A-Star(A*)算法
  • 第一周周总结
  • 大数据-237 离线数仓 - 广告业务 需求分析 ODS DWD UDF JSON 串解析
  • 深入探索Flax:一个用于构建神经网络的灵活和高效库
  • RBF神经网络预测结合NSGAII多目标优化
  • HTTP(网络)
  • 【LeetCode面试150】——141环形列表
  • milvus 通俗易懂原理