将图像增广应用于Mnist数据集
将图像增广应用于Mnist数据集
不用到cifar-10的原因是要下载好久。。我就直接用在Mnist上了,先学会用
首先我们得了解一下图像增广的基本内容,这是我的一张猫图片,以下为先导入需要的包和展示图片
import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d2l.set_figsize()
img = Image.open('cat.png')
d2l.plt.imshow(img)
之后呢,我们先定义几个函数,以后方便调用,第一个函数show_images,他是用来展示多张图片的
def show_images(imgs, num_rows, num_cols, scale=2):
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize = figsize)
for i in range(num_rows):
for j in range(num_cols):
axes[i][j].imshow(imgs[i * num_cols + j])
axes[i][j].axes.get_xaxis().set_visible(False)
axes[i][j].axes.get_yaxis().set_visible(False)
return axes
然后将图像展示函数和图像增广函数结合起来展示,也用一个函数来集成
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
Y = [aug(img) for _ in range(num_rows * num_cols)]
show_images(Y, num_rows, num_cols, scale)
接下来,就可以开始我们的图像增广之路啦
左右翻转
torchvision.transforms.RandomHorizontalFlip()这个函数有百分之五十的概率实现左右翻转
apply(img, torchvision.transforms.RandomHorizontalFlip()) # torchvision.transforms.RandomHorizontalFlip() 百分之50的概率左右翻转
上下翻转
torchvision.transforms.RandomVerticalFlip() 百分之50的概率上下翻转
随机裁剪
随机裁剪出一块面积为原面积10%100%的区域,且该区域的宽和高之比随机取自0.52,然后将该区域的宽高缩放到200像素
shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)
自然,我们也可以变换颜色,有亮度(brightness),对比度(contrast),饱和度(saturation),色调(hue)
我就直接一起写了,也可以只变单个
0.5的意思是比如对于亮度来说,他会在50%的范围内随机选择,即亮度为原来的0.5~1.5
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, saturation=0.5, contrast=0.5)
apply(img, color_aug)
那么当然,我们也可以把上述的那些进行叠加
用到torchvision.transforms.Compose
augs = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])
apply(img, augs)
之后呢,就可以用增广后的图像进行训练啦,这里给大家一个例子用Resnet18进行训练Mnist数据集,Resnet18就不带着大家写了,直接调用别人写好的函数,写网络并不是本节的重点,如果以后有时间或者大家有需要我可以再来写~
(为什么是Mnist数据集,其实他在Mnist数据集上的效果并没有很明显,比较比较简单,最好是在cifar上,但是cifar要下太久了,懒,大家可以在cifar上测一下)
先写两个augs,训练集我就将他随机翻转,测试集就不动了
flip_aug = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor() # 记得转换成tensor 以便训练
])
no_aug = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
然后是load_mnist,就加一个transform就好啦,集成到一个函数里
def load_mnist(is_train, augs, batch_size, root="~/Datasets/MNIST"):
dataset = torchvision.datasets.MNIST(train=is_train, root=root, transform=augs, download=True)
return DataLoader(dataset, batch_size = batch_size, shuffle=is_train)
再之后就是模型的训练了,这个大家应该都写腻了,我也不多说什么了,反正就是模型前向传播+反向传播,然后再记录点值
def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
net = net.to(device)
print("training on ", device)
batch_count = 0
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
for X, y in train_iter:
X = X.to(device)
y = y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
optimizer.zero_grad()
l.backward()
optimizer.step()
train_l_sum += l.cpu().item()
train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
n += y.shape[0]
batch_count += 1
test_acc = d2l.evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
再最后,就定义一个函数,把前面的都用上啦!
def train_with_data_aug(train_augs, test_augs, lr=0.001):
batch_size, net = 256, d2l.resnet18(output=10, in_channels=1)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
loss = torch.nn.CrossEntropyLoss()
train_iter = load_mnist(True, train_augs, batch_size)
test_iter = load_mnist(False, test_augs, batch_size)
train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)
值得注意的是,这边调用别人的d2l.resnet18,要注意in_channels=1记得写,他默认是3通道的,改成1通道对于我们的mnist,如果你要是cifar-10就不用变了,把in_channel=1给删掉就好~,至此,调用我们的函数就行
训练还是很快滴