当前位置: 首页 > article >正文

卷积神经网络 (CNN)

代码功能

网络结构:
卷积层:
两个卷积层,每个卷积层后接 ReLU 激活函数。
最大池化层用于降低维度。
全连接层:
使用一个隐藏层(128 个神经元)和一个输出层(10 类分类任务)。
数据集:
使用 PyTorch 内置的 MNIST 数据集,其中包含手写数字的灰度图像。
训练过程:
使用交叉熵损失函数 (CrossEntropyLoss)。
优化器为 Adam,学习率设为 0.001。
每轮训练输出损失。
测试与可视化:
测试模型在测试集上的准确率。
可视化 6 张测试样本的预测结果与真实标签。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 1. 定义卷积神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),  # 卷积层 1
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 最大池化层
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 卷积层 2
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 最大池化层
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),  # 全连接层 1
            nn.ReLU(),
            nn.Linear(128, 10)  # 全连接层 2 (10 类分类)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

# 2. 加载 MNIST 数据集
def load_data(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 标准化
    ])
    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

# 3. 训练 CNN
def train_cnn(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU(如适用)

            # 前向传播
            outputs = model(images)
            loss = criterion(outputs, labels)

            # 反向传播与优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")

# 4. 测试 CNN
def test_cnn(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU(如适用)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

# 5. 可视化测试结果
def visualize_predictions(model, test_loader):
    model.eval()
    images, labels = next(iter(test_loader))
    images, labels = images.cuda(), labels.cuda()
    outputs = model(images)
    _, predicted = torch.max(outputs, 1)

    # 绘制图像与预测结果
    images, labels, predicted = images.cpu(), labels.cpu(), predicted.cpu()
    plt.figure(figsize=(12, 8))
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.imshow(images[i].squeeze(), cmap='gray')
        plt.title(f"True: {labels[i]}, Pred: {predicted[i]}")
        plt.axis('off')
    plt.show()

# 主程序
if __name__ == "__main__":
    # 加载数据
    train_loader, test_loader = load_data()

    # 初始化网络、损失函数和优化器
    model = CNN().cuda()  # 将模型移动到 GPU(如适用)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # 训练和测试模型
    train_cnn(model, train_loader, criterion, optimizer, epochs=5)
    test_cnn(model, test_loader)

    # 可视化部分测试结果
    visualize_predictions(model, test_loader)


http://www.kler.cn/a/400972.html

相关文章:

  • React解决保存less文件后会自动生成css文件的方法
  • 3D Streaming 在线互动展示系统:NVIDIA RTX 4090 加速实时渲染行业数字化转型
  • Figma中文网:UI设计师的新资源宝库
  • Springboot 整合 Java DL4J 打造金融风险评估系统
  • OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
  • leetcode 面试150之 Z 字形变换
  • java设计模式之 - 适配器模式
  • AVL树(c++版)
  • 农业银行手机银行系统介绍
  • 苍穹外卖 软件开发流程
  • Elasticsearch实战应用-dsl语句
  • fpga spi回环
  • 【每日题解】3239. 最少翻转次数使二进制矩阵回文 I
  • Dolby TrueHD和Dolby Digital Plus (E-AC-3)编码介绍
  • MaaS模型即服务的优势与发展前景
  • 百度世界大会2024,展现科技改变生活的力量
  • Vue 生成二维码
  • 革命性AI搜索引擎!ChatGPT最新功能发布,无广告更智能!
  • Leetcode 回文数
  • 游戏引擎学习第17天
  • 使用合适的Prompt充分利用ChatGPT的能力
  • 如何用WordPress和Shopify提升SEO表现?
  • labview中连接sql server数据库查询语句
  • 一次需升级系统的wxpython安装(macOS M1)
  • Macmini中普通鼠标与TrackPad联动问题解决
  • openGauss 6.0.0单机部署(企业版)