当前位置: 首页 > article >正文

在 CIFAR10 数据集上训练 Vision Transformer (ViT)

点击下方卡片,关注“小白玩转Python”公众号

132651974e52cbd68faee572e49704f8.png在这篇简短的文章中,我将构建一个简单的 ViT 并将其训练在 CIFAR 数据集上。

训练循环

我们从训练 CIFAR 数据集上的模型的样板代码开始。我们选择批量大小为64,以在性能和 GPU 资源之间取得平衡。我们将使用 Adam 优化器,并将学习率设置为0.001。与 CNN 相比,ViT 收敛得更慢,所以我们可能需要更多的训练周期。此外,根据我的经验,ViT 对超参数很敏感。一些超参数会使模型崩溃并迅速达到零梯度,模型的参数将不再更新。因此,您必须测试与模型大小和形状本身以及训练超参数相关的不同超参数。

transform_train = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


train_set = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
test_set = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)


train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=64)
n_epochs = 100
lr = 0.0001


optimizer = Adam(model.parameters(), lr=lr)
criterion = CrossEntropyLoss()


for epoch in range(n_epochs):
    train_loss = 0.0
    for i,batch in enumerate(train_loader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat, _ = model(x)
        loss = criterion(y_hat, y)


        batch_loss = loss.detach().cpu().item()
        train_loss += batch_loss / len(train_loader)


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        if i%100==0:
          print(f"Batch {i}/{len(train_loader)} loss: {batch_loss:.03f}")
    
    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.03f}")

构建 ViT

如果您熟悉注意力和transforms块,ViT 架构就很容易理解。简而言之,我们将使用 Pytorch 提供的多头注意力,视觉transforms的第一部分是将图像分割成相同大小的块。如您所知,transforms作用于标记,而不是像在 CNN 中那样卷积特征。在我们的例子中,图像块充当标记。

有很多方法可以对图像进行分块。有些人手动进行,这不符合 Python 的风格。其他人使用卷积。还有些人使用 Pytorch 提供的张量操作工具。我们将使用 Pytorch nn 模块提供的 `unfold` 层作为我们 `Patcher` 模块的核心。

该模块作用于形状为 (N, 3, 32, 32) 的张量。其中 N 是每批图像的数量。3 是通道数,因为我们处理的是 RGB 图像。32 是图像的大小,因为我们处理的是 CIFAR10 数据集。我们可以测试我们的模块,以确保它将上述形状转换为分块张量。新张量的形状取决于补丁大小。如果我们选择补丁大小为4,输出形状将是 (N, 64, 3, 4, 4),其中 64 是每张图像的补丁数量。

class Patcher(nn.Module):
  def __init__(self, patch_size):
    super(Patcher, self).__init__()


    self.patch_size=patch_size


    self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)


  def forward(self, images):
    batch_size, channels, height, width = images.shape


    patch_height, patch_width = [self.patch_size, self.patch_size]
    assert height % patch_height == 0 and width % patch_width == 0, "Height and width must be divisible by the patch size."


    patches = self.unfold(images) #bs (cxpxp) N
    patches = patches.view(batch_size, channels, patch_height, patch_width, -1).permute(0, 4, 1, 2, 3) # bs N C P P


    return patches
x = torch.rand((10, 3, 32, 32))
x = Patcher(patch_size=4)(x)
x.shape
# torch.Size([10, 64, 3, 4, 4])

在语言处理中,标记通过词嵌入投影到 `d` 维向量中。这个超参数 `d` 是transforms模型的特征,选择合适的维度大小对于模型的转换很重要。太大,模型会崩溃。太小,模型将无法很好地训练。因此,到目前为止,我们的 ViT 模块形状将如下所示:

class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100):
    super().__init__()


    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim


    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)


    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)


  def forward(self, x):


    x = self.patcher(x)


    x = x.flatten(start_dim=2)


    x = self.linear_projector(x)


    return x

我们将图像 (N, 3, 32, 32) 分割成大小为4的补丁 (N, 64, 3, 4, 4),然后我们将它们展平为 (N, 64, 3*4*4=48)。之后,我们使用 Pytorch 的 Linear 模块将它们投影到大小为 (N, 64, 100)。

即使在将输入喂入transforms块之后,整个模块的输出大小也将是 (N, n_patches, model_dim)。现在我们有很多投影和关注的补丁,应该使用哪个补丁进行预测?一种常见的方法是计算所有补丁的平均值,然后使用平均向量进行预测。然而,对于transforms,现在正在广泛使用另一种技巧。那就是添加一个 [cls] 一个新的标记到输入中。辅助标记最终将用于预测。它将作用于模型对整个图像的理解。该标记只是一个大小为 (1, model_dim) 的参数向量。现在,整个模块的输出将是 (N, n_patches+1, model_dim)。

class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim


    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)


    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)


    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, self.model_dim))




  def forward(self, x):


    x = self.patcher(x)


    x = x.flatten(start_dim=2)


    x = self.linear_projector(x)


    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)


    return x

在添加了类标记之后,我们仍然需要添加位置编码部分。transforms操作在一系列标记上,它们对序列顺序视而不见。为了确保在训练中加入顺序,我们手动添加位置编码。因为我们处理的是大小为 model_dim 的向量,我们不能简单地添加顺序 [0, 1, 2, …],位置应该是模型固有的,这就是为什么我们使用所谓的位置编码。这个向量可以手动设置或训练。在我们的例子中,我们将简单地训练一个位置嵌入,它只是一个大小为 (1, n_patches+1, model_dim) 的向量。我们将这个向量添加到完整的补丁序列中,以及类标记。如前所述,为了计算模型的输出,我们简单地对嵌入的第一个标记(类标记)应用一个带有 SoftMax 层的 MLP,以获得类别的对数几率。

class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100,n_classes=10):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim


    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)


    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)


    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim))


    # 4) Positional Embedding
    self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))


    # 6) Classification MLP
    self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.n_classes),
            nn.Softmax(dim=-1)
      )


  def forward(self, x):


    x = self.patcher(x)


    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)


    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)


    x = x + self.positional_embedding


    latent = x[:, 0]
    logits = self.mlp(latent)


    return logits

transforms块

之前的代码没有包括非常重要的transforms块。transforms块是大小保持块,它们通过交叉组成序列的标记本身来丰富信息序列。transforms块的核心模块是注意力模块(同样,您可以查看我关于注意力的帖子)。为了使模型更丰富地处理信息,我们通常使用多头注意力。为了使模型吸收越来越抽象的信息,我们应用了几个transforms块。使用的头数和transforms块的数量是transforms模型的特征。我们称使用的transforms块数量为模型的 `depth`。

class TransformerBlock(nn.Module):
    def __init__(self, model_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(model_dim)


        # Feedforward network
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, int(model_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(model_dim * mlp_ratio), model_dim),
            nn.Dropout(dropout),
        )


    def forward(self, x):
        # Self-attention
        x = self.norm1(x)
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out


        # Feedforward network
        x = self.norm2(x)
        mlp_out = self.mlp(x)
        x = x + mlp_out


        return x
class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100, num_heads=3, num_layers=2, n_classes=10):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    self.num_layers = num_layers
    self.num_heads= num_heads
    self.n_classes = n_classes


    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)


    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)


    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim))


    # 4) Positional Embedding
    self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))


    # 5) Transformer blocks
    self.blocks = nn.ModuleList([
        TransformerBlock( self.model_dim,  self.num_heads) for _ in range(num_layers)
    ])


    # 6) Classification MLPk
    self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.n_classes),
            nn.Softmax(dim=-1)
        )


  def forward(self, x):


    x = self.patcher(x)


    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)


    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)


    x = x + self.positional_embedding


    for block in self.blocks:
      x = block(x)


    latent = x[:, 0]
    logits = self.mlp(latent)


    return logits

最后,我们为训练和测试准备好了模型,并放置了所有必要的组件。然而,在实践中,我无法通过在类标记上应用 MLP 层使模型收敛。我不确定为什么——如果你知道,请告诉我。相反,我在整个图像补丁的平均向量上应用了 MLP。

·  END  ·

🌟 想要变身计算机视觉小能手?快来「小白玩转Python」公众号!

回复Python视觉实战项目,解锁31个超有趣的视觉项目大礼包!🎁

87ee6e4d02529b4ce532d35024df7f2f.png

本文仅供学习交流使用,如有侵权请联系作者删除


http://www.kler.cn/a/387474.html

相关文章:

  • 接口测试Day09-数据库工具类封装
  • 使用葡萄城+vue实现Excel
  • 使用WebdriverIO和Appium测试App
  • 《机器学习》——贝叶斯算法
  • 学英语学Elasticsearch:04 Elastic integrations 工具箱实现对第三方数据源的采集、存储、可视化,开箱即用
  • 如何独立SDK模块到源码目录?
  • 解释一下Java中的异常处理机制
  • IDM扩展添加到Edge浏览器
  • 怎么给llama3.2-vision:90b模型进行量化剪枝蒸馏
  • 类加载的生命周期?
  • opencv实时弯道检测
  • 1.6K+ Star!Ichigo:一个开源的实时语音AI项目
  • 华为机试HJ29 字符串加解密
  • SDL打开YUV视频
  • AI和大模型技术在网络脆弱性扫描领域的最新进展与未来发展趋势
  • [C++ 核心编程]笔记 4.4.3 成员函数做友元
  • <<零基础C++第一期, C++入门基础>>
  • 打造完整 Transformer 编码器:逐步实现高效深度学习模块
  • 深度学习在大数据处理中的应用
  • 电子电气架构 --- 车载以太网架构安全性要求
  • Qt使用属性树(QtProPertyBrowser)时,引用报错#include “QtTreePropertyBrowser“解决方案
  • HDR视频技术之二:光电转换与 HDR 图像显示
  • python批量合并excel文件
  • 经典的ORACLE 11/12/19闪回操作
  • 前端vue3若依框架pnpm run dev启动报错
  • AI时代来临,什么是真正的大模型?【大模型扫盲系列】