当前位置: 首页 > article >正文

PyTorch实战-手写数字识别-MLP模型

1 需求

10分钟入门神经网络 PyTorch 手写数字识别_哔哩哔哩_bilibili

pytorch tutorial: PyTorch 手写数字识别 教程代码

从零设计并训练一个神经网络,你就能真正理解它了_哔哩哔哩_bilibili

https://github.com/xhh890921/mnist_network


2 接口


3 豆包生成代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Hyperparameters
batch_size = 128
learning_rate = 0.001
num_epochs = 10

# Preprocessing: ToTensor converts images to float tensors in [0, 1]; Normalize
# then standardizes with the widely used MNIST pixel mean (0.1307) and std (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST. download=True on both splits so the test split does not depend on
# the training-split call having already fetched the files; torchvision skips
# the download when the files are already present under root.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffle only the training data; evaluation order does not affect accuracy.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 定义 MLP 模型
class MLP(nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 256 -> 10.

    ``forward`` flattens its input to rows of 784 values, applies two
    ReLU-activated hidden layers, and returns raw logits (no softmax), which
    is what ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = x.view(-1, 784)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

# Instantiate the model
model = MLP()

# Loss and optimizer. CrossEntropyLoss takes raw logits, which MLP.forward returns.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
model.train()  # explicit, since the evaluation below switches to eval mode
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item()}')

# Evaluate the model on the test set
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for data, targets in test_loader:
        # Predicted class = argmax over the class dimension; this avoids the
        # discouraged `.data` attribute used by the original torch.max call.
        predicted = model(data).argmax(dim=1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    # Guard against an empty test loader instead of dividing by zero.
    accuracy = correct / total if total else 0.0
    print(f'Test Accuracy: {accuracy * 100:.2f}%')

3.2 另一种实现:四层全连接网络(附预测结果可视化)

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

class Net(torch.nn.Module):
    """Fully connected MNIST classifier: 784 -> 64 -> 64 -> 64 -> 10.

    ``forward`` expects already-flattened input (rows of 28 * 28 values) and
    returns per-class log-probabilities (``log_softmax`` over dim 1), which
    pairs with ``torch.nn.functional.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        activation = torch.nn.functional.relu
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = activation(hidden_layer(x))
        logits = self.fc4(x)
        return torch.nn.functional.log_softmax(logits, dim=1)


def get_data_loader(is_train, *, batch_size=15, root="", shuffle=True):
    """Build a ``DataLoader`` over the MNIST training or test split.

    Args:
        is_train: ``True`` for the training split, ``False`` for the test split.
        batch_size: samples per batch (default 15, as in the original script).
        root: directory passed to ``MNIST``; the default ``""`` keeps the
            original behavior of resolving the dataset path relative to the
            current working directory.
        shuffle: whether the loader reshuffles each epoch (default ``True``,
            as in the original script).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches. Images are only
        converted with ``ToTensor``; no normalization is applied.
    """
    # A single ToTensor transform needs no Compose wrapper.
    data_set = MNIST(root, train=is_train, transform=transforms.ToTensor(), download=True)
    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle)


def evaluate(test_data, net):
    """Return the fraction of samples in *test_data* that *net* classifies correctly.

    Args:
        test_data: iterable of ``(x, y)`` batches (e.g. a ``DataLoader``).
            Each ``x`` is reshaped to rows of 28 * 28 values before the forward
            pass; ``y`` holds the integer class labels.
        net: module mapping those rows to per-class scores of shape
            ``(batch, num_classes)``; the predicted class is the argmax.

    Returns:
        Accuracy in ``[0, 1]`` as a float, or ``0.0`` when *test_data* yields
        no samples (instead of raising ``ZeroDivisionError``).
    """
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for x, y in test_data:
            outputs = net(x.view(-1, 28 * 28))
            # Vectorized comparison over the whole batch instead of a
            # per-sample Python loop; argmax picks the first max on ties,
            # matching the original per-row torch.argmax.
            n_correct += (outputs.argmax(dim=1) == y).sum().item()
            n_total += outputs.size(0)
    return n_correct / n_total if n_total else 0.0


def main(num_epochs=2):
    """Train ``Net`` on MNIST, report test accuracy, and plot sample predictions.

    Prints the accuracy of the untrained network, then the test accuracy after
    each epoch. Finally it opens one figure for each of the first four test
    batches, showing that batch's first image with the predicted digit as the
    title.

    Args:
        num_epochs: number of passes over the training set (default 2, as in
            the original script).
    """
    from itertools import islice

    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()

    print("initial accuracy:", evaluate(test_data, net))
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(num_epochs):
        for (x, y) in train_data:
            optimizer.zero_grad()
            output = net(x.view(-1, 28 * 28))
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()
            optimizer.step()
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    # Inference only: no_grad avoids building an autograd graph per image.
    # islice stops after four batches without fetching a fifth one to break on.
    with torch.no_grad():
        for (n, (x, _)) in enumerate(islice(test_data, 4)):
            predict = torch.argmax(net(x[0].view(-1, 28 * 28)))
            plt.figure(n)
            plt.imshow(x[0].view(28, 28))
            plt.title("prediction: " + str(int(predict)))
    plt.show()


# Run the training / evaluation / plotting pipeline only when this file is
# executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

4 参考资料

PyTorch——手写数字识别_pytorch 手写数字-CSDN博客

Python :MNIST手写数据集识别 + 手写板程序 最详细,直接放心,大胆地抄!跑不通找我,我包教!_手写数字数据集-CSDN博客

Python人工智能--实现手写数字识别-CSDN博客


http://www.kler.cn/a/379526.html

相关文章:

  • selinux和防火墙
  • 相机硬触发
  • VisionPro —— CogPatInspectTool对比工具
  • RuoYi 样例框架运行步骤(测试项目自用,同学可自取)
  • 《Keras3 深度学习初探:开启Keras3 深度学习之旅》
  • Linux——Ubuntu环境C编程
  • Redis高级篇之缓存一致性详细教程
  • OpenEuler 使用ffmpeg x11grab捕获屏幕流,rtsp推流,并用vlc播放
  • 深入理解 Spring AOP:面向切面编程的原理与应用
  • LeetCode 0633.平方数之和:模拟
  • 【系统架构设计师】预测试卷一:综合知识(75道选择题)
  • Android Studio 安装过程
  • 虚拟化环境中的精简版 Android 操作系统 Microdroid
  • 【MATLAB源码-第286期】基于MATLAB的根升余弦脉冲整形对 BPSK 和 QPSK 调制的影响的对比仿真,输出功率谱,误码率曲线,星座图,眼图等.
  • 【初阶数据结构篇】链式结构二叉树(二叉链)的实现(感受递归暴力美学)
  • 金蝶云数据集成至MySQL的高效解决方案
  • 除了Vue CLI,还有哪些方式可以创建 Vue 项目?
  • Spring Boot 集成 Kafka
  • BERT的新闻标题生成
  • pip install -r requirements.txt下载速度慢
  • 跨越科技与文化的桥梁——ROSCon China 2024 即将盛大开幕
  • openstack之guardian介绍与实例创建过程
  • C语言实现力扣第31题:下一个排列
  • 重大917该如何复习?难度大不大?重点是啥?
  • Bacnet+springboot部署到linux后,无法检测到网络中的其他设备
  • 项目解决方案:跨不同的物理网络实现视频监控多画面的实时视频的顺畅访问