当前位置: 首页 > article >正文

【Pytorch框架】无中生有,从0到1使用Dataset类处理MNIST数据集

文章目录

  • 一、Pytorch下载
  • 二、MNIST数据集下载
  • 三、自定义Dataset类处理MNIST数据集

一、Pytorch下载

Pytorch框架以包的类型存在,但是又不同于其他包。


这里只介绍通过anaconda安装pytorch,因为安装并不是这篇博文的重点,详细的安装介绍可以参考 pytorch安装介绍。

🔔目前pytorch框架只支持 CPU版CUDA版,而CUDA目前只有NVIDIA显卡支持,所以没有NVIDIA显卡支持的请安装CPU版。

1、首先进入 pytorch官网,往下翻找到:
在这里插入图片描述
2、选择稳定版,操作系统根据自己的来,Linux系统选择Linux,Windows系统选择Windows,package这里我们使用的是anaconda安装,所以选择conda,语言不多说哈,版本根据上面说的,有NVIDIA显卡的选择CUDA版,没有的选择CPU版。
在这里插入图片描述
3、复制下面的安装语句,比如我这里是下面这个,粘贴到anaconda的命令行,至于anaconda的哪个环境可以自己选择,回车就可以自动下载啦😆

conda install pytorch torchvision torchaudio cpuonly -c pytorch

在这里插入图片描述

二、MNIST数据集下载

💥MNIST数据集是一个经典的机器学习和计算机视觉数据集,用于手写数字识别的训练和测试,内容包含70000张手写数字的灰度图像,其中60000张用于训练,10000张用于测试。每张图像的大小为28x28像素,表示手写数字0-9。

👇下载方式一:通过pytorch下载

pytorch中内置有MNIST数据集,下载非常方便。

import torch
from torchvision import datasets, transforms

# 定义数据转换:将图像转换为张量
transform = transforms.Compose([
    transforms.ToTensor(),
])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform)

🔔 datasets.MNIST中root代表数据集路径,train为True表示数据集为训练数据集,否则为测试数据集,download为True表示当路径中没有MNIST数据集时,自动下载,并将数据集保存在root路径中,transform表示数据转换器。

👇下载方式二:通过百度网盘下载

😙资源来自 详解MNIST数据集下载、解析及显示的Python实现

  • 原始格式数据下载
链接:https://pan.baidu.com/s/1jAPlVKLYamJn6I63GD6HDg?pwd=azq2 
提取码:azq2 

在这里插入图片描述

  • JPEG格式数据下载
链接:https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n 
提取码:xl8n 

在这里插入图片描述

🌟上述两种格式下载后均是一个文件,这里我将 training类型test类型 分成了两个文件夹,图片格式为JEPG格式。

链接: https://pan.baidu.com/s/1i-hXHMBq1-dWKvZXhUYoAQ?pwd=6666 
提取码: 6666 

三、自定义Dataset类处理MNIST数据集

☀️自定义Dataset类需要继承torch.utils.data中的Dataset类,并重写__getitem__方法,使用PIL包下的Image类处理图片,os包读取数据集路径

from torch.utils.data import Dataset
from PIL import Image
import os

class MyDataset(Dataset):

在MyDataset类中有两个方法:__init__方法和__getitem__方法。

  • 在__init__方法中,传入数据集路径参数。对MyDataset类进行初始化,通过os.listdir方法将 数据集路径 对应的文件处理为列表,列表中存储每一张图片的完整名称(例如:test_0_7.jpg,其中test代表属于测试集数据,若为training则代表训练集数据;0代表图片的索引;7为label,表示图片所描述的数字)。
  • 在__getitem__方法中,传入索引参数,返回索引对应图片的JpegImageFile对象及对应的label标签。使用Image.open方法可将图片路径转换为JpegImageFile格式。
# 定义MyDataset数据集处理类,继承于Dataset,重写__getitem__方法
class MyDataset(Dataset):
    def __init__(self, root_dir):
        #数据集路径
        self.root_dir = root_dir
        #通过listdir函数将数据集中的图片转化成列表,列表中存储图片的完整名称,例如 test_0_7.jpg
        self.img_paths = os.listdir(self.root_dir)

    def __getitem__(self, idx):
        #获取下标为idx的图片路径
        img_path = self.img_paths[idx]
        #os.path.join函数可将两个路径拼接起来,Image.open函数将路径对应的图片打开为JpegImageFile格式
        img = Image.open(os.path.join(self.root_dir,img_path))
        #去除名称中的后缀,获得图片的label
        img_name = img_path.split('.')[0]

        label = img_name.split('_')[-1]
        #将图片路径和label返回
        return img, label

    def __len__(self):
        return len(self.img_paths)

🌟 其中的__len__方法可以返回img_paths列表的长度。


🍁测试完整代码

from torch.utils.data import Dataset
from PIL import Image
import os

# 定义MyDataset数据集处理类,继承于Dataset,重写__getitem__方法
class MyDataset(Dataset):
    def __init__(self, root_dir):
        #数据集路径
        self.root_dir = root_dir
        #通过listdir函数将数据集中的图片转化成列表,列表中存储图片的完整名称,例如 test_0_7.jpg
        self.img_paths = os.listdir(self.root_dir)

    def __getitem__(self, idx):
        #获取下标为idx的图片路径
        img_path = self.img_paths[idx]
        #os.path.join函数可将两个路径拼接起来,Image.open函数将路径对应的图片打开为JpegImageFile格式
        img = Image.open(os.path.join(self.root_dir,img_path))
        #去除名称中的后缀,获得图片的label
        img_name = img_path.split('.')[0]

        label = img_name.split('_')[-1]
        #将图片路径和label返回
        return img, label

    def __len__(self):
        return len(self.img_paths)


if __name__ == '__main__':
    #定义数据集路径
    train_dataset_path = "./mnist-20/training"
    test_dataset_path = "./mnist-20/test"

    #创建MyDataset类的对象
    train_dataset = MyDataset(train_dataset_path)
    test_dataset = MyDataset(test_dataset_path)

	#得到__getitem__函数返回的变量
    img,label = train_dataset[0] 
    #展示图片
    img.show()

在这里插入图片描述


http://www.kler.cn/a/412364.html

相关文章:

  • PIMPL模式和D指针
  • HTML详解(1)
  • qt 读写文本、xml文件
  • burp功能介绍
  • 4.4 JMeter 请求参数类型详解
  • Springboot启动报错’javax.management.MBeanServer’ that could not be found.
  • 多线程1:基础概念、接口介绍、锁
  • 通俗理解人工智能、机器学习和深度学习的关系
  • 【carla生成车辆时遇到的问题】carla显示的坐标和carlaworld中提取的坐标y值相反
  • 前后端中Json数据的简单处理
  • Javaweb 前端 HTML css 案例 总结
  • 开发一个基于MACOS M1/2芯片的Android 12的模拟器
  • 基于STM32的智能风扇控制系统
  • digit_eye开发记录(2): Python读取MNIST数据集
  • 渗透测试笔记—window基础
  • 蓝桥杯每日真题 - 第24天
  • 27加餐篇:gRPC框架的优势与不足之处
  • Apache Zeppelin:一个基于Web的大数据可视化分析平台
  • 前端 设置 div 标签内子多个子 div 内容,在一行展示,并且可以字段自动换行
  • Flink 实现超速监控:从 Kafka 读取卡口数据写入 MySQL
  • 浏览器开发工具
  • java——SpringBoot中常用注解及其底层原理
  • SSM之AOP与事务
  • 缓存雪崩、击穿、穿透深度解析与实战应对
  • 使用OpenCV实现视频背景减除与目标检测
  • 【QT】背景,安装和介绍