当前位置: 首页 > article >正文

论文阅读(二十五):PVTv2: Improved Baselines with Pyramid Vision Transformer

文章目录

  • 1.回顾PVT
  • 2.重叠块嵌入
  • 3.移除固定位置编码
  • 4.线性SRA
  • 5.版本结构


  论文:PVTv2: Improved Baselines with Pyramid Vision Transformer
  代码:Github

1.回顾PVT

  在提出PVT之前,Vision Transformer领域输出的特征图和输入大小基本保持一致,不发生尺度的变化。在PVT中提出将多个Transformer模块进行叠加,同时在每个模块内部 Attention 机构进行特征提取的大小变化。整体结构如下:
在这里插入图片描述
共采用了四个Transformer编码器进行叠加,每个阶段只有参数不同结构都一样。并且为了不盲目叠加编码器、适应尺度变化,作者提出可更改大小的注意力模块(SRA:Spatial-Reduction Attention,空间缩减注意力机制):
在这里插入图片描述
在这里插入图片描述
可理解为将重塑K、V的形状,由原先的 ( H W ) × C (HW)×C (HW)×C缩减为 ( H W ) R 2 × ( R 2 C ) \frac{(HW)}{R^2}×(R^2C) R2(HW)×(R2C)。经过此操作后特征图的大小不断变化,形成多尺度 PVT。各PVT版本的结构如下:
在这里插入图片描述
  但PVT也存在如下问题:

  • (1)与ViT类似,在处理高分辨率输入时,计算复杂度高。
  • (2)PVT将图像视为不重叠的小块序列,在一定程度上失去了图像的局部连续性。
  • (3)PVT中位置编码大小固定,这对任意大小的图像处理不灵活。

为解决这些问题而提出了PVT-v2。

2.重叠块嵌入

在这里插入图片描述
  在原始PVT及ViT中均采用刚好切分的方式,这使得图像边界部分的信息无法得到完整解读。同时,也丧失了这些分块的局部连续性。因此,PVT2中将原始图像进行零填充,并使用 s t r i d e stride stride小于 k e r n e l _ s i z e kernel\_size kernel_size的重叠分块操作完成嵌入(上图中红色边框内的虚线即为重叠部分)。代码如下:

#原始PVT直接进行分块
self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=4, stride=4)

#PVTv2包含零填充+重叠分块
self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=7,stride=4,padding=(3, 3))

此时PVT2与PVT在 p a t c h _ e m b e d i n g patch\_embeding patch_embeding部分的输入、输出大小相同,即,输入尺寸为 ( 1 , 3 , 224 , 2240 ) (1,3,224,2240) (1,3,224,2240),输出尺寸仍为 ( b a t c h _ s i z e , c h a n n a l , 56 , 56 ) (batch\_size,channal,56,56) (batch_size,channal,56,56)。但PVT2的嵌入包含了每个图像patch和周围相邻图像patch的相关信息。实际代码:

class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        assert max(patch_size) > stride, "Set larger patch_size than stride"
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, patch_size,
            stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)
 
    def forward(self, x):  # (1,3,224,224)
        x = self.proj(x)  # (1,64,56,56)
        x = x.permute(0, 2, 3, 1)  # (1,56,56,64)
        x = self.norm(x)
        return x

其中, p a t c h _ s i z e = k e r n e l _ s i z e = 2 ∗ s t r i d e − 1 , p a d d i n g _ s i z e = s t r i d e − 1 patch\_size=kernel\_size=2*stride-1,padding\_size=stride-1 patch_size=kernel_size=2stride1,padding_size=stride1

3.移除固定位置编码

在这里插入图片描述
  在MLP层中两个FC之间加入了3×3的卷积以移除固定大小的位置编码,意思是,使用零填充给特征图外面加一圈padding,而卷积层可以根据特征图外圈的0学习到特征图的轮廓信息,换句话说可以学习到一些绝对位置信息,因此可以用DW卷积来建模位置信息并且去掉位置编码以减少计算量。实现代码:

class DWConv(nn.Module):
    def __init__(self, dim=768):
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

class Mlp(nn.Module):
    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W) #这里这里
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

4.线性SRA

  PVT中使用SRA结构降低分辨率,而PVTv2中使用池化+卷积操作实现,以降低计算量。
在这里插入图片描述

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else: #使用线性 SRA
            self.pool = nn.AdaptiveAvgPool2d(7) #加了一层池化
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) #卷积核为 1
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU() #激活函数
        self.apply(self._init_weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear: #这里是原版 SRA
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        else: #是线性 SRA
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

5.版本结构

  PVTv2各版本结构如下:
在这里插入图片描述


http://www.kler.cn/news/368716.html

相关文章:

  • ReactNative JSI(4)
  • 亿家旺生鲜云订单零售系统的设计与实现小程序ssm+论文源码调试讲解
  • 什么是AI神经网络?
  • 重学SpringBoot3-Spring WebFlux之SSE服务器发送事件
  • 循序渐进丨openGauss / MogDB 数据库内存占用相关SQL
  • android——渐变色
  • SASS转换成CSS步骤
  • 宝塔如何部署Django项目(前后端分离篇)
  • Three.js 使用着色器 实现跳动的心
  • WebView渲染异常导致闪退解决方案
  • 若依学习 后端传过来的数据在控制台打印为空
  • iPhone当U盘使用的方法 - iTunes共享文件夹无法复制到电脑怎么办 - 如何100%写入读出
  • 解决pycharm无法添加conda环境的问题【Conda Environment下没有Existing environment】
  • 机器学习在智能水泥基复合材料中的应用与实践
  • 部署 Traefik 实现 dashboard 与 原生Ingress使用 CRD IngressRoute使用
  • 大语言模型参数传递、model 构建与tokenizer构建(基于llama3模型)
  • 关于洛谷中XJS-SINGA科技站点 系统讨论团队的一些介绍
  • 【网络】:网络基础
  • 地球Online生存天数计算器(java小案例)
  • GPU的使用寿命可能只有1~3年
  • 基于去哪儿旅游出行服务平台旅游推荐网站【源码+安装+讲解+售后+文档】
  • Linux 重启命令全解析:深入理解与应用指南
  • 51单片机完全学习——红外遥控
  • LeetCode——最小差值
  • RTMP视频推流EasyDSS平台重装服务器系统后无法启动是什么原因?
  • [LeetCode] 47. 全排列Ⅱ