在 CIFAR10 数据集上训练 Vision Transformer (ViT)
点击下方卡片,关注“小白玩转Python”公众号
在这篇简短的文章中,我将构建一个简单的 ViT 并将其训练在 CIFAR 数据集上。
训练循环
我们从训练 CIFAR 数据集上的模型的样板代码开始。我们选择批量大小为64,以在性能和 GPU 资源之间取得平衡。我们将使用 Adam 优化器,并将学习率设置为0.001。与 CNN 相比,ViT 收敛得更慢,所以我们可能需要更多的训练周期。此外,根据我的经验,ViT 对超参数很敏感。一些超参数会使模型崩溃并迅速达到零梯度,模型的参数将不再更新。因此,您必须测试与模型大小和形状本身以及训练超参数相关的不同超参数。
transform_train = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_set = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
test_set = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=64)
n_epochs = 100
lr = 0.0001
optimizer = Adam(model.parameters(), lr=lr)
criterion = CrossEntropyLoss()
for epoch in range(n_epochs):
train_loss = 0.0
for i,batch in enumerate(train_loader):
x, y = batch
x, y = x.to(device), y.to(device)
y_hat, _ = model(x)
loss = criterion(y_hat, y)
batch_loss = loss.detach().cpu().item()
train_loss += batch_loss / len(train_loader)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i%100==0:
print(f"Batch {i}/{len(train_loader)} loss: {batch_loss:.03f}")
print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.03f}")
构建 ViT
如果您熟悉注意力和transforms块,ViT 架构就很容易理解。简而言之,我们将使用 Pytorch 提供的多头注意力,视觉transforms的第一部分是将图像分割成相同大小的块。如您所知,transforms作用于标记,而不是像在 CNN 中那样卷积特征。在我们的例子中,图像块充当标记。
有很多方法可以对图像进行分块。有些人手动进行,这不符合 Python 的风格。其他人使用卷积。还有些人使用 Pytorch 提供的张量操作工具。我们将使用 Pytorch nn 模块提供的 `unfold` 层作为我们 `Patcher` 模块的核心。
该模块作用于形状为 (N, 3, 32, 32) 的张量。其中 N 是每批图像的数量。3 是通道数,因为我们处理的是 RGB 图像。32 是图像的大小,因为我们处理的是 CIFAR10 数据集。我们可以测试我们的模块,以确保它将上述形状转换为分块张量。新张量的形状取决于补丁大小。如果我们选择补丁大小为4,输出形状将是 (N, 64, 3, 4, 4),其中 64 是每张图像的补丁数量。
class Patcher(nn.Module):
def __init__(self, patch_size):
super(Patcher, self).__init__()
self.patch_size=patch_size
self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
def forward(self, images):
batch_size, channels, height, width = images.shape
patch_height, patch_width = [self.patch_size, self.patch_size]
assert height % patch_height == 0 and width % patch_width == 0, "Height and width must be divisible by the patch size."
patches = self.unfold(images) #bs (cxpxp) N
patches = patches.view(batch_size, channels, patch_height, patch_width, -1).permute(0, 4, 1, 2, 3) # bs N C P P
return patches
x = torch.rand((10, 3, 32, 32))
x = Patcher(patch_size=4)(x)
x.shape
# torch.Size([10, 64, 3, 4, 4])
在语言处理中,标记通过词嵌入投影到 `d` 维向量中。这个超参数 `d` 是transforms模型的特征,选择合适的维度大小对于模型的转换很重要。太大,模型会崩溃。太小,模型将无法很好地训练。因此,到目前为止,我们的 ViT 模块形状将如下所示:
class ViT_RGB(nn.Module):
def __init__(self, img_size, patch_size, model_dim= 100):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (self.img_size // self.patch_size) ** 2
self.model_dim = model_dim
# 1) Patching
self.patcher = Patcher(patch_size=self.patch_size)
# 2) Linear Prjection
self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
def forward(self, x):
x = self.patcher(x)
x = x.flatten(start_dim=2)
x = self.linear_projector(x)
return x
我们将图像 (N, 3, 32, 32) 分割成大小为4的补丁 (N, 64, 3, 4, 4),然后我们将它们展平为 (N, 64, 3*4*4=48)。之后,我们使用 Pytorch 的 Linear 模块将它们投影到大小为 (N, 64, 100)。
即使在将输入喂入transforms块之后,整个模块的输出大小也将是 (N, n_patches, model_dim)。现在我们有很多投影和关注的补丁,应该使用哪个补丁进行预测?一种常见的方法是计算所有补丁的平均值,然后使用平均向量进行预测。然而,对于transforms,现在正在广泛使用另一种技巧。那就是添加一个 [cls] 一个新的标记到输入中。辅助标记最终将用于预测。它将作用于模型对整个图像的理解。该标记只是一个大小为 (1, model_dim) 的参数向量。现在,整个模块的输出将是 (N, n_patches+1, model_dim)。
class ViT_RGB(nn.Module):
def __init__(self, img_size, patch_size, model_dim= 100):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (self.img_size // self.patch_size) ** 2
self.model_dim = model_dim
# 1) Patching
self.patcher = Patcher(patch_size=self.patch_size)
# 2) Linear Prjection
self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
# 3) Class Token
self.class_token = nn.Parameter(torch.rand(1, self.model_dim))
def forward(self, x):
x = self.patcher(x)
x = x.flatten(start_dim=2)
x = self.linear_projector(x)
batch_size = x.shape[0]
class_token = self.class_token.expand(batch_size, -1, -1)
x = torch.cat((class_token, x), dim=1)
return x
在添加了类标记之后,我们仍然需要添加位置编码部分。transforms操作在一系列标记上,它们对序列顺序视而不见。为了确保在训练中加入顺序,我们手动添加位置编码。因为我们处理的是大小为 model_dim 的向量,我们不能简单地添加顺序 [0, 1, 2, …],位置应该是模型固有的,这就是为什么我们使用所谓的位置编码。这个向量可以手动设置或训练。在我们的例子中,我们将简单地训练一个位置嵌入,它只是一个大小为 (1, n_patches+1, model_dim) 的向量。我们将这个向量添加到完整的补丁序列中,以及类标记。如前所述,为了计算模型的输出,我们简单地对嵌入的第一个标记(类标记)应用一个带有 SoftMax 层的 MLP,以获得类别的对数几率。
class ViT_RGB(nn.Module):
def __init__(self, img_size, patch_size, model_dim= 100,n_classes=10):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (self.img_size // self.patch_size) ** 2
self.model_dim = model_dim
# 1) Patching
self.patcher = Patcher(patch_size=self.patch_size)
# 2) Linear Prjection
self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
# 3) Class Token
self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim))
# 4) Positional Embedding
self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))
# 6) Classification MLP
self.mlp = nn.Sequential(
nn.Linear(self.model_dim, self.n_classes),
nn.Softmax(dim=-1)
)
def forward(self, x):
x = self.patcher(x)
x = x.flatten(start_dim=2)
x = self.linear_projector(x)
batch_size = x.shape[0]
class_token = self.class_token.expand(batch_size, -1, -1)
x = torch.cat((class_token, x), dim=1)
x = x + self.positional_embedding
latent = x[:, 0]
logits = self.mlp(latent)
return logits
transforms块
之前的代码没有包括非常重要的transforms块。transforms块是大小保持块,它们通过交叉组成序列的标记本身来丰富信息序列。transforms块的核心模块是注意力模块(同样,您可以查看我关于注意力的帖子)。为了使模型更丰富地处理信息,我们通常使用多头注意力。为了使模型吸收越来越抽象的信息,我们应用了几个transforms块。使用的头数和transforms块的数量是transforms模型的特征。我们称使用的transforms块数量为模型的 `depth`。
class TransformerBlock(nn.Module):
def __init__(self, model_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
super(TransformerBlock, self).__init__()
self.norm1 = nn.LayerNorm(model_dim)
self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout)
self.norm2 = nn.LayerNorm(model_dim)
# Feedforward network
self.mlp = nn.Sequential(
nn.Linear(model_dim, int(model_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(model_dim * mlp_ratio), model_dim),
nn.Dropout(dropout),
)
def forward(self, x):
# Self-attention
x = self.norm1(x)
attn_out, _ = self.attn(x, x, x)
x = x + attn_out
# Feedforward network
x = self.norm2(x)
mlp_out = self.mlp(x)
x = x + mlp_out
return x
class ViT_RGB(nn.Module):
def __init__(self, img_size, patch_size, model_dim= 100, num_heads=3, num_layers=2, n_classes=10):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.n_patches = (self.img_size // self.patch_size) ** 2
self.model_dim = model_dim
self.num_layers = num_layers
self.num_heads= num_heads
self.n_classes = n_classes
# 1) Patching
self.patcher = Patcher(patch_size=self.patch_size)
# 2) Linear Prjection
self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
# 3) Class Token
self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim))
# 4) Positional Embedding
self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))
# 5) Transformer blocks
self.blocks = nn.ModuleList([
TransformerBlock( self.model_dim, self.num_heads) for _ in range(num_layers)
])
# 6) Classification MLPk
self.mlp = nn.Sequential(
nn.Linear(self.model_dim, self.n_classes),
nn.Softmax(dim=-1)
)
def forward(self, x):
x = self.patcher(x)
x = x.flatten(start_dim=2)
x = self.linear_projector(x)
batch_size = x.shape[0]
class_token = self.class_token.expand(batch_size, -1, -1)
x = torch.cat((class_token, x), dim=1)
x = x + self.positional_embedding
for block in self.blocks:
x = block(x)
latent = x[:, 0]
logits = self.mlp(latent)
return logits
最后,我们为训练和测试准备好了模型,并放置了所有必要的组件。然而,在实践中,我无法通过在类标记上应用 MLP 层使模型收敛。我不确定为什么——如果你知道,请告诉我。相反,我在整个图像补丁的平均向量上应用了 MLP。
· END ·
🌟 想要变身计算机视觉小能手?快来「小白玩转Python」公众号!
回复“Python视觉实战项目”,解锁31个超有趣的视觉项目大礼包!🎁
本文仅供学习交流使用,如有侵权请联系作者删除