当前位置: 首页 > article >正文

dl学习笔记(8):fashion-mnist

过完年懒羊羊也要复工了,这一节的内容不多,我们接着上次的fashion-mnist数据集。

首先第一步就是导入数据集,由于这个数据集很有名,是深度学习的常见入门数据集,所以可以在库里面导入。由于是图像数据集所以,被存放在视觉模块里面。

import torchvision
import torchvision.transforms as transforms
mnist = torchvision.datasets.FashionMNIST(
    root=r'E:\桌面\深度学习课件\lesson 11\MINST-FASHION'
   , train=True
   , download=False
   , transform=transforms.ToTensor())

下面我们来解释一下这几个参数:

1)root指定数据集存储的本地路径,如果路径不存在,且 download=True,PyTorch会自动创建该路径并下载数据。如果路径已存在且包含数据集文件,则直接加载本地数据。

2)train决定加载的是训练集还是测试集。

  • train=True:加载训练集(60,000张图片)

  • train=False:加载测试集(10,000张图片)

3)download:控制是否从网络下载数据集。

  • download=True:如果本地路径 root 中不存在数据集,则自动下载。

  • download=False:不下载,直接加载本地数据(需确保本地路径已存在数据集)。

4)transform:定义数据预处理操作。

  • ToTensor() 将PIL图像或NumPy数组转换为PyTorch张量(Tensor),并自动进行以下操作:

    将图像数据范围从 [0, 255] 缩放到 [0, 1]。调整张量维度为 [C, H, W](通道、高度、宽度),例如FashionMNIST是灰度图,因此 C=1
  • 如果需要对数据做进一步处理(如归一化),可以组合多个变换

运行结果如上,下一步可以查看属性信息。

这里的size含义就是有六万张图片,每张都是28*28的像素,需要注意的是这里省略了颜色通道,由于该数据集是灰度图片所以这里默认是1。

我们可以通过targets来查看标签,再通过unique来获得标签的唯一值,可以看到是一个多分类任务,总共十个类别。我们还可以通过classes来查看每个数字对应的具体衣服的类别是什么。

下一步我们通过索引来具体看看里面存储的是什么:

图片有点长,如果我们仔细看的话,前面全是图片像素点的张量,最后有一个不起眼的9就是这张图片的标签,所以我们可以通过[0][0]来索引张量,下面我们来展示出来这张图片。

我们将像素部分的张量传入,由于这里是tensor结构,所以我们需要最后转化成numpy才行。

再展示一张:

由于前面已经看过标签和样本已经打包在一起了,所以这里我们不需要使用之前学的dataset的打包功能了,只需要dataloader的分批次。

最后我们开始完整的建模之前我们先复习一下上次说过的完整流程:
1)设置步长 ,动量值 ,迭代次数 ,batch_size等信息,(如果需要)设置初始权重
2)导入数据,将数据切分成batches
3)定义神经网络架构
4)定义损失函数 ,如果需要的话,将损失函数调整成凸函数,以便求解最小值
5)定义所使用的优化算法
6)开始在epoches和batch上循环,执行优化算法:
6.1)调整数据结构,确定数据能够在神经网络、损失函数和优化算法中顺利运行
6.2)完成向前传播,计算初始损失
6.3)利用反向传播,在损失函数上求偏导数
6.4)迭代当前权重
6.5)清空本轮梯度
6.6)完成模型进度与效果监控
7)输出结果

按照惯例首先还是先导入库,下面是所有用到的库

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

1)确定超参数

lr = 0.1
gamma = 0.7
epochs = 5
bs = 128

2)导入数据,将数据切分成batches

batcheddata = DataLoader(mnist,batch_size = bs,shuffle = True)

我们可以通过查看shape属性来看结果是否符合要求:

3)定义神经网络架构

先定义输入输出神经元个数:

input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique())

定义架构:

def fit(net, batchdata, lr=0.01, epochs=5, gamma=0):
    criterion = nn.NLLLoss()  # 定义损失函数
    opt = optim.SGD(net.parameters(), lr=lr, momentum=gamma)  # 定义优化算法
    
    for epoch in range(epochs):
        net.train()  # 设置模型为训练模式
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (x, y) in enumerate(batchdata):
            y = y.view(x.shape[0])  # 确保y是一个一维的张量
            opt.zero_grad()  # 清除之前的梯度
            
            sigma = net(x)  # 前向传播
            loss = criterion(sigma, y)  # 计算损失
            
            loss.backward()  # 反向传播
            opt.step()  # 更新参数
            
            # 计算损失
            running_loss += loss.item()
            
            # 计算准确率
            _, predicted = torch.max(sigma, 1)  # 获取模型的预测
            total += y.size(0)
            correct += (predicted == y).sum().item()
        
        # 输出每个epoch的平均损失和准确率
        avg_loss = running_loss / len(batchdata)
        accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

4)实例化

torch.manual_seed(250)
net = model(in_features=input_, out_features=output_)
fit(net,batcheddata,lr=lr,epochs=epochs,gamma=gamma)

由于上面的代码都是前面的章节中已经提及过的,这里就不再重复了。

完整代码:

#完整代码
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

lr = 0.1
gamma = 0.7
epochs = 5
bs = 128

mnist = torchvision.datasets.FashionMNIST(
    root=r'E:\桌面\深度学习课件\lesson 11\MINST-FASHION'
   , train=True
   , download=False
   , transform=transforms.ToTensor())
batcheddata = DataLoader(mnist,batch_size = bs,shuffle = True)
input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique())
class model(nn.Module):
    def __init__(self,in_features=1,out_features=2):
        super().__init__()
        self.linear1 = nn.Linear(in_features,128,bias=False)
        self.output = nn.Linear(128,out_features,bias=False)
    def forward(self,x):
        x = x.view(-1,28*28)
        sigma1 = torch.relu(self.linear1(x))
        z2 = self.output(sigma1)
        sigma2 = F.log_softmax(z2,dim=1)
        return sigma2

def fit(net, batchdata, lr=0.01, epochs=5, gamma=0):
    criterion = nn.NLLLoss()  # 定义损失函数
    opt = optim.SGD(net.parameters(), lr=lr, momentum=gamma)  # 定义优化算法
    
    for epoch in range(epochs):
        net.train()  # 设置模型为训练模式
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (x, y) in enumerate(batchdata):
            y = y.view(x.shape[0])  # 确保y是一个一维的张量
            opt.zero_grad()  # 清除之前的梯度
            
            sigma = net(x)  # 前向传播
            loss = criterion(sigma, y)  # 计算损失
            
            loss.backward()  # 反向传播
            opt.step()  # 更新参数
            
            # 计算损失
            running_loss += loss.item()
            
            # 计算准确率
            _, predicted = torch.max(sigma, 1)  # 获取模型的预测
            total += y.size(0)
            correct += (predicted == y).sum().item()
        
        # 输出每个epoch的平均损失和准确率
        avg_loss = running_loss / len(batchdata)
        accuracy = 100 * correct / total
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
torch.manual_seed(250)
net = model(in_features=input_, out_features=output_)
fit(net,batcheddata,lr=lr,epochs=epochs,gamma=gamma)


http://www.kler.cn/a/533890.html

相关文章:

  • Java进阶(JVM调优)——阿里云的Arthas的使用 安装和使用 死锁查找案例,重新加载案例,慢调用分析
  • 宾馆民宿酒店住宿管理系统+小程序项目需求分析文档
  • CompletableFuture
  • 物业管理系统源码提升社区智能化管理效率与用户体验
  • 【LeetCode 刷题】回溯算法(4)-排列问题
  • JavaScript系列(54)--性能优化技术详解
  • maven本地打包依赖无法引用
  • 基于微信小程序的培训机构客户管理系统设计与实现(LW+源码+讲解)
  • 动态规划——斐波那契数列模型问题
  • Java 进阶 01 —— 5 分钟回顾一下 Java 基础知识
  • 【华为OD-E卷 - 107 连续出牌数量 100分(python、java、c++、js、c)】
  • 使用 Deepseek AI 制作视频的完整教程
  • 神经网络常见激活函数 2-tanh函数(双曲正切)
  • 63.网页请求与按钮禁用 C#例子 WPF例子
  • 低代码系统-产品架构案例介绍、蓝凌(十三)
  • 4.PPT:日月潭景点介绍【18】
  • Python爬虫实战:一键采集电商数据,掌握市场动态!
  • MySQL 索引原理
  • 昆工昆明理工大学材料25调剂名额
  • [CMake]报错: Qt requires a C++17 compiler
  • 【starrocks学习】之将starrocks表同步到hive
  • 机器学习8-卷积和卷积核
  • Java进阶,集合,Colllection,常见数据结构
  • Spring Boot Actuator与JMX集成实战
  • Java 面试合集(2024版)
  • java后端开发面试常问