当前位置: 首页 > article >正文

将图像增广应用于Mnist数据集

将图像增广应用于Mnist数据集

不用到cifar-10的原因是要下载好久。。我就直接用在Mnist上了,先学会用

首先我们得了解一下图像增广的基本内容,这是我的一张猫图片,以下为先导入需要的包和展示图片

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d2l.set_figsize()
img = Image.open('cat.png')
d2l.plt.imshow(img)

在这里插入图片描述
之后呢,我们先定义几个函数,以后方便调用,第一个函数show_images,他是用来展示多张图片的

def show_images(imgs, num_rows, num_cols, scale=2):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize = figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j])
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    return axes

然后将图像展示函数和图像增广函数结合起来展示,也用一个函数来集成

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale)

接下来,就可以开始我们的图像增广之路啦

左右翻转

torchvision.transforms.RandomHorizontalFlip()这个函数有百分之五十的概率实现左右翻转

apply(img, torchvision.transforms.RandomHorizontalFlip()) # torchvision.transforms.RandomHorizontalFlip() 百分之50的概率左右翻转

在这里插入图片描述

上下翻转

torchvision.transforms.RandomVerticalFlip() 百分之50的概率上下翻转
在这里插入图片描述

随机裁剪

随机裁剪出一块面积为原面积10%100%的区域,且该区域的宽和高之比随机取自0.52,然后将该区域的宽高缩放到200像素

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

在这里插入图片描述
自然,我们也可以变换颜色,有亮度(brightness),对比度(contrast),饱和度(saturation),色调(hue)
我就直接一起写了,也可以只变单个
0.5的意思是比如对于亮度来说,他会在50%的范围内随机选择,即亮度为原来的0.5~1.5

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, saturation=0.5, contrast=0.5) 
apply(img, color_aug)

在这里插入图片描述
那么当然,我们也可以把上述的那些进行叠加
用到torchvision.transforms.Compose

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])
apply(img, augs)

在这里插入图片描述
之后呢,就可以用增广后的图像进行训练啦,这里给大家一个例子用Resnet18进行训练Mnist数据集,Resnet18就不带着大家写了,直接调用别人写好的函数,写网络并不是本节的重点,如果以后有时间或者大家有需要我可以再来写~
(为什么是Mnist数据集,其实他在Mnist数据集上的效果并没有很明显,比较比较简单,最好是在cifar上,但是cifar要下太久了,懒,大家可以在cifar上测一下)
先写两个augs,训练集我就将他随机翻转,测试集就不动了

flip_aug = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()           # 记得转换成tensor 以便训练
])
no_aug = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

然后是load_mnist,就加一个transform就好啦,集成到一个函数里

def load_mnist(is_train, augs, batch_size, root="~/Datasets/MNIST"):
    dataset = torchvision.datasets.MNIST(train=is_train, root=root, transform=augs, download=True)
    return DataLoader(dataset, batch_size = batch_size, shuffle=is_train)

再之后就是模型的训练了,这个大家应该都写腻了,我也不多说什么了,反正就是模型前向传播+反向传播,然后再记录点值

def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

再最后,就定义一个函数,把前面的都用上啦!

def train_with_data_aug(train_augs, test_augs, lr=0.001):
    batch_size, net = 256, d2l.resnet18(output=10, in_channels=1)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()
    train_iter = load_mnist(True, train_augs, batch_size)
    test_iter = load_mnist(False, test_augs, batch_size)
    train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

值得注意的是,这边调用别人的d2l.resnet18,要注意in_channels=1记得写,他默认是3通道的,改成1通道对于我们的mnist,如果你要是cifar-10就不用变了,把in_channel=1给删掉就好~,至此,调用我们的函数就行
在这里插入图片描述

训练还是很快滴


http://www.kler.cn/a/161211.html

相关文章:

  • Rust 图形界面开发——使用 GTK 创建跨平台 GUI
  • 【信息系统项目管理师】高分论文:论信息系统项目的沟通管理(银行绩效考核系统)
  • RT-Thread中堆和栈怎么跟单片机内存相联系
  • 【每日学点鸿蒙知识】ets匿名类、获取控件坐标、Web显示iframe标签、软键盘导致上移、改变Text的背景色
  • 马原复习笔记
  • Lombok是银弹?还是陷阱?
  • scp 指令详细介绍
  • activemq启动成功但web管理页面却无法访问
  • 多人聊天Java
  • 【前端架构】清洁前端架构
  • ubuntu22.04设置国内源
  • JAVA 企业面试题
  • inBuilder低代码平台新特性推荐-第十五期
  • Shopify 开源 WebAssembly 工具链 Ruvy
  • C++STL的string类(一)
  • mysql的几种索引
  • 在数字化转型大时代下,企业进行知识管理的重要性
  • 腾讯云轻量应用服务器怎么安装宝塔Linux面板?
  • js vue form表单层级过深,层级太深了,form检测不到form的变化
  • 关于FBPINN的讨论
  • 南京大学考研机试题DP
  • 【文末送书】Python OpenCV从入门到精通
  • Abaqus基础教程--胶合失效仿真
  • Leetcode—1038.从二叉搜索树到更大和树【中等】
  • MySQL 数据库如何实现 XA 规范?
  • 【重磅来袭!!!工程师必备初始化建工程软件】