当前位置: 首页 > article >正文

Vision Transformer(vit)的主干

图解:

代码:

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
#输入图像的大小,通常是 224 或其他标准尺寸
            patch_size (int, tuple): patch size
#每个块(patch)的大小,例如 16x16
            in_c (int): number of input channels
#输入图像的通道数,RGB 图像是 3
            num_classes (int): number of classes for classification head
#最终分类的类别数,默认 1000 类
            embed_dim (int): embedding dimension
#嵌入维度,即每个 patch 被映射到的向量的维度,默认是 768
            depth (int): depth of transformer
#Transformer 的深度,即堆叠的块(Block)数量。
            num_heads (int): number of attention heads
#注意力头的数量,默认设为 12
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
# MLP 隐藏层的维度与嵌入维度的比例。
            qkv_bias (bool): enable bias for qkv if True
#是否为 QKV(查询、键、值)矩阵添加偏置
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
#如果设定,将会覆盖默认的 qk 缩放因子
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
#如果设置了这个值,将会有一个表示层(pre-logits)
            distilled (bool): model includes a distillation token and head as in DeiT models
#vit中可以不管这个参数
            drop_ratio (float): dropout rate
# Dropout 的比例
            attn_drop_ratio (float): attention dropout rate
#注意力层的 Dropout 比例
            drop_path_ratio (float): stochastic depth rate
#droppath比例
            embed_layer (nn.Module): patch embedding layer
#用于嵌入图像的层,默认使用 PatchEmbed
            norm_layer: (nn.Module): normalization layer
#正则化层,通常是 LayerNorm
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
# 与 embed_dim 保持一致,表示嵌入的维度。
        self.num_tokens = 2 if distilled else 1
#不管distilled所以distilled=1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
#使用 LayerNorm作为默认的规范化层
        act_layer = act_layer or nn.GELU
#默认使用 GELU 作为激活函数

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
#Embedding层结构
        num_patches = self.patch_embed.num_patches
#patches的个数

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#这是用于分类的分类标记(Class Token),它是一个可学习的参数,初始值为零
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
#不管distilled所以self.dist_token=None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
#位置编码(Position Embedding)
        self.pos_drop = nn.Dropout(p=drop_ratio)
#位置编码后的 Dropout 操作

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
#用于控制每个 Block 的 DropPath 比例
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
#使用 Block 类构建了Transformer的主体部分,包括注意力和MLP层,并使用残差连接和 DropPath 
        self.norm = norm_layer(embed_dim)
#最后的归一化层,用于 Transformer 输出的处理

        # Representation layer
        if representation_size and not distilled:
#设置了 representation_size则会增加一个表示层 pre_logits,not distilled=true
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
#pre_logits层结构一个全连接和tanh激活函数
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
#distilled为none不用管

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)
#权重初始化

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接

        x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropout
        x = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理
        x = self.norm(x)
#进行归一化
#vit中self.dist_token is None所以模型只有分类标记 (class token)。
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
#首先获取 Transformer 的特征输出
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
#self.head_dist为none只看head层就是最后的全连接层输出为num_classes
            x = self.head(x)
        return x

 操作:

代码:

        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层

操作:

代码:

    # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接

操作:

代码:

x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropout

操作:

代码:

        x = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理
        x = self.norm(x)
#进行归一化

操作:

代码:

#vit中self.dist_token is None所以模型只有分类标记 (class token)。
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
        else:
            return x[:, 0], x[:, 1]

操作:

代码:

#self.head_dist为none只看head层就是最后的全连接层输出为num_classes
            x = self.head(x)

分类标记 (Class Token):

是一种特殊的 输入 token,在 Transformer 模型中被用来聚合全局特征。

它在模型中起到了类似于 CNN 中全局池化 (Global Pooling) 的作用,负责从所有 patch 的信息中提取一个全局表示。

这个 token 的输出向量被用作分类任务的特征输入,之后会被送入分类头 (classifier head) 进行最终的类别预测。

embedding层:

Vision Transformer(vit)的Embedding层结构-CSDN博客

Multi-Head Self-Attention:

Vision Transformer(vit)的Multi-Head Self-Attention(多头注意力机制)结构-CSDN博客

MLP模块:

Vision Transformer(vit)的MLP模块-CSDN博客

Encoder block:

Vision Transformer(vit)的Encoder层结构-CSDN博客

详解:Vision Transformer详解-CSDN博客


http://www.kler.cn/a/415457.html

相关文章:

  • openGauss你计算的表大小,有包含toast表么?
  • 【前端开发】小程序无感登录验证
  • 单片机知识总结(完整)
  • 什么是串联谐振
  • 多线服务器和BGP服务器有什么区别
  • 26页PDF | 数据中台能力框架及评估体系解读(限免下载)
  • CSS学习记录02
  • AI开发:逻辑回归 - 实战演练- 垃圾邮件的识别(一)
  • SpringMVC跨域问题解决方案
  • BUUCTF—Reverse—GXYCTF2019-luck_guy(9)
  • 003 MATLAB基础计算
  • Cesium 当前位置矩阵的获取
  • 深入探索 Java 中的 Spring 框架
  • ORACLE之DBA常用数据库查询
  • openGauss你计算的表大小,有包含toast表么?
  • ArcGIS pro中的回归分析浅析(加更)关于广义线性回归工具的补充内容
  • 2.安装docker、docker compose
  • 使用Native AOT发布C# dll 提供给C++调用
  • c++趣味编程玩转物联网:树莓派Pico控制 LED点阵屏
  • 11.25.2024刷华为OD
  • 【动态规划】完全背包问题应用
  • 淘宝Vision Pro:革新购物体验的沉浸式未来
  • QML 之 画布元素学习
  • 51单片机从入门到精通:理论与实践指南常用资源篇(五)
  • 提升数据分析效率:Excel Power Query和Power Pivot的妙用
  • 获取字 short WORD 上指定的位是否有效