当前位置: 首页 > article >正文

深度学习 模型和代码

提供一个简单的深度学习模型(类似 DeepSeek 工作原理的简单示例,比如一个简单的神经网络实现手写数字识别,使用 PyTorch 框架)示例代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True,
                               download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False,
                              download=True, transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = x.view(-1, 32 * 7 * 7)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        return x


# 初始化模型、损失函数和优化器
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

# 在测试集上评估模型
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        test_loss += criterion(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

以上代码构建了一个简单的卷积神经网络用于 MNIST 手写数字识别,包含了数据加载、模型定义、训练和测试等流程。


http://www.kler.cn/a/584464.html

相关文章:

  • mysql进阶——数据类型一篇详解
  • 在 Linux 64 位系统上安装 Oracle 11g R2 数据库的完整指南
  • 2025-3-13 leetcode刷题情况(贪心算法--区间问题)
  • Prompt优化 COT/COD
  • 时间有限,如何精确设计测试用例?5种关键方法
  • pop_dialog_state(state: State)弹出对话栈并返回到主助手,让整个对话流程图可以明确追踪对话流,并将控制权委派给特定的子对话图。
  • 使用conda将python环境打包,移植到另一个linux服务器项目中
  • Matplotlib高阶技术全景解析
  • 【数据挖掘】知识蒸馏(Knowledge Distillation, KD)
  • kali linux 漏洞扫描
  • (每日一题) 力扣 179 最大数
  • 前端面试:如何实现预览 PDF 文件?
  • 基于深度学习的肺炎X光影像自动诊断系统实现,真实操作案例分享,值得学习!
  • 【文献阅读】SPRec:用自我博弈打破大语言模型推荐的“同质化”困境
  • 电子电气架构 --- 智能电动汽车概述
  • 塔能IVO-SCY智能机箱:点亮智慧城市的电力“智慧核芯”
  • python语言写的一款pdf转word、word转pdf的免费工具
  • 微店关键词搜索接口(micro.item_search)返回数据测试指南
  • Spring 注解解析
  • java: system类