【Pytorch框架】无中生有,从0到1使用Dataset类处理MNIST数据集
文章目录
- 一、Pytorch下载
- 二、MNIST数据集下载
- 三、自定义Dataset类处理MNIST数据集
一、Pytorch下载
Pytorch框架以包的类型存在,但是又不同于其他包。
这里只介绍通过anaconda安装pytorch,因为安装并不是这篇博文的重点,详细的安装介绍可以参考 pytorch安装介绍。
🔔目前pytorch框架只支持 CPU版 和 CUDA版,而CUDA目前只有NVIDIA显卡支持,所以没有NVIDIA显卡支持的请安装CPU版。
1、首先进入 pytorch官网,往下翻找到:
2、选择稳定版,操作系统根据自己的来,Linux系统选择Linux,Windows系统选择Windows,package这里我们使用的是anaconda安装,所以选择conda,语言不多说哈,版本根据上面说的,有NVIDIA显卡的选择CUDA版,没有的选择CPU版。
3、复制下面的安装语句,比如我这里是下面这个,粘贴到anaconda的命令行,至于anaconda的哪个环境可以自己选择,回车就可以自动下载啦😆
conda install pytorch torchvision torchaudio cpuonly -c pytorch
二、MNIST数据集下载
💥MNIST数据集是一个经典的机器学习和计算机视觉数据集,用于手写数字识别的训练和测试,内容包含70000张手写数字的灰度图像,其中60000张用于训练,10000张用于测试。每张图像的大小为28x28像素,表示手写数字0-9。
👇下载方式一:通过pytorch下载
pytorch中内置有MNIST数据集,下载非常方便。
import torch
from torchvision import datasets, transforms
# 定义数据转换:将图像转换为张量
transform = transforms.Compose([
transforms.ToTensor(),
])
# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)
🔔 datasets.MNIST中root代表数据集路径,train为True表示数据集为训练数据集,否则为测试数据集,download为True表示当路径中没有MNIST数据集时,自动下载,并将数据集保存在root路径中,transform表示数据转换器。
👇下载方式二:通过百度网盘下载
😙资源来自 详解MNIST数据集下载、解析及显示的Python实现
- 原始格式数据下载
链接:https://pan.baidu.com/s/1jAPlVKLYamJn6I63GD6HDg?pwd=azq2
提取码:azq2
- JPEG格式数据下载
链接:https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n
提取码:xl8n
🌟上述两种格式下载后均是一个文件,这里我将 training类型 和 test类型 分成了两个文件夹,图片格式为JEPG格式。
链接: https://pan.baidu.com/s/1i-hXHMBq1-dWKvZXhUYoAQ?pwd=6666
提取码: 6666
三、自定义Dataset类处理MNIST数据集
☀️自定义Dataset类需要继承torch.utils.data中的Dataset类,并重写__getitem__方法,使用PIL包下的Image类处理图片,os包读取数据集路径
from torch.utils.data import Dataset
from PIL import Image
import os
class MyDataset(Dataset):
在MyDataset类中有两个方法:__init__方法和__getitem__方法。
- 在__init__方法中,传入数据集路径参数。对MyDataset类进行初始化,通过os.listdir方法将 数据集路径 对应的文件处理为列表,列表中存储每一张图片的完整名称(例如:test_0_7.jpg,其中test代表属于测试集数据,若为training则代表训练集数据;0代表图片的索引;7为label,表示图片所描述的数字)。
- 在__getitem__方法中,传入索引参数,返回索引对应图片的JpegImageFile对象及对应的label标签。使用Image.open方法可将图片路径转换为JpegImageFile格式。
# 定义MyDataset数据集处理类,继承于Dataset,重写__getitem__方法
class MyDataset(Dataset):
def __init__(self, root_dir):
#数据集路径
self.root_dir = root_dir
#通过listdir函数将数据集中的图片转化成列表,列表中存储图片的完整名称,例如 test_0_7.jpg
self.img_paths = os.listdir(self.root_dir)
def __getitem__(self, idx):
#获取下标为idx的图片路径
img_path = self.img_paths[idx]
#os.path.join函数可将两个路径拼接起来,Image.open函数将路径对应的图片打开为JpegImageFile格式
img = Image.open(os.path.join(self.root_dir,img_path))
#去除名称中的后缀,获得图片的label
img_name = img_path.split('.')[0]
label = img_name.split('_')[-1]
#将图片路径和label返回
return img, label
def __len__(self):
return len(self.img_paths)
🌟 其中的__len__方法可以返回img_paths列表的长度。
🍁测试完整代码:
from torch.utils.data import Dataset
from PIL import Image
import os
# 定义MyDataset数据集处理类,继承于Dataset,重写__getitem__方法
class MyDataset(Dataset):
def __init__(self, root_dir):
#数据集路径
self.root_dir = root_dir
#通过listdir函数将数据集中的图片转化成列表,列表中存储图片的完整名称,例如 test_0_7.jpg
self.img_paths = os.listdir(self.root_dir)
def __getitem__(self, idx):
#获取下标为idx的图片路径
img_path = self.img_paths[idx]
#os.path.join函数可将两个路径拼接起来,Image.open函数将路径对应的图片打开为JpegImageFile格式
img = Image.open(os.path.join(self.root_dir,img_path))
#去除名称中的后缀,获得图片的label
img_name = img_path.split('.')[0]
label = img_name.split('_')[-1]
#将图片路径和label返回
return img, label
def __len__(self):
return len(self.img_paths)
if __name__ == '__main__':
#定义数据集路径
train_dataset_path = "./mnist-20/training"
test_dataset_path = "./mnist-20/test"
#创建MyDataset类的对象
train_dataset = MyDataset(train_dataset_path)
test_dataset = MyDataset(test_dataset_path)
#得到__getitem__函数返回的变量
img,label = train_dataset[0]
#展示图片
img.show()