当前位置: 首页 > article >正文

0基础学习PyTorch——最小Demo

大纲

  • 环境准备
    • 安装依赖
  • 训练和推理
    • 训练
      • 生成数据
      • 加载数据
        • TensorDataset
        • DataLoader
      • 定义神经网络
      • 定义损失函数和优化器
      • 训练模型
    • 推理
  • 参考代码

PyTorch以其简洁直观的API、动态计算图和强大的社区支持,在学术界和工业界都享有极高的声誉,成为许多深度学习爱好者和专业人士的首选工具。本系列将更多从工程实践角度探索PyTorch的使用,而不是算法公式的讨论。

环境准备

使用《管理Python虚拟环境的脚本》中的脚本初始化一个虚拟环境。

source env.sh init

在这里插入图片描述

然后进入虚拟环境

source env.sh enter

在这里插入图片描述

安装依赖

source env.sh install pyyaml

在这里插入图片描述

source env.sh install torch

这个过程比较漫长,需要下载一个多G文件。
在这里插入图片描述

source env.sh install numpy

在这里插入图片描述

训练和推理

训练就是模型训练。我们可以认为知道系统的输入和输出(目标),猜测系统中的算法。在这里插入图片描述

推理则是使用模型,计算出对应的输出。

在这里插入图片描述
举一个例子,也是我们后面代码的例子。假如我们使用f(x)=2x+1计算一批随机数(输入)得到一批计算结果(目标),然后我们将这些数据交给一个模型训练器,可以得到一个模型。这个模型的计算结果(输出)应该非常近似于f(x)=2x+1。

训练

生成数据

数据生成不是必须的,因为我们从其他地方获取数据。为了让这个例子没有太多依赖项,我们就自己生成数据。
input_data 是100个随机数数组,它是模型训练的“输入”数据;target_data是对input_data使用f(x)=2x+1得到的一个数组,它是模型训练的“目标”数据。即我们需要模型要将“输入”尽量转换成接近的“目标”。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 生成一些随机数据
torch.manual_seed(0)
input_data = torch.randn(100, 1)  # 100个样本,每个样本有2个特征
target_data = 2 * input_data + 1  # 简单的线性关系

加载数据

# 创建数据加载器
dataset = TensorDataset(input_data, target_data)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
TensorDataset

TensorDataset 是一个简单的数据集封装类,用于将多个张量(tensors)包装在一起。它的主要作用是将特征和标签数据对齐,以便于后续的数据加载和处理。

主要功能:

  • 将多个张量封装成一个数据集。
  • 使得数据集可以通过索引访问。
DataLoader

DataLoader 是一个数据加载器类,用于将数据集分批次加载到模型中进行训练或评估。它的主要作用是提供一个迭代器,能够高效地加载数据,并支持多线程并行加载。

主要功能:

  • 将数据集分批次加载。
  • 支持多线程并行加载数据。
  • 支持数据的随机打乱(shuffle)。
  • 提供一个迭代器,方便在训练循环中使用。

定义神经网络

定义神经网络是深度学习模型开发的核心步骤之一。一个良好定义的神经网络可以有效地学习和泛化数据,从而在各种任务中取得优异的表现。
本文不过度讨论神经网络,只是抛砖引玉,让大家知道结构长什么样子。

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入和输出都是1维

    def forward(self, x):
        return self.linear(x)

定义损失函数和优化器

# 实例化模型、损失函数和优化器
model = SimpleNet()
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

损失函数用于衡量模型预测值与真实值之间的差异。它是模型优化的目标,模型训练的目的是最小化损失函数的值

优化器用于更新模型参数,以最小化损失函数。

训练模型

对于有限的数据,我们可以通过增加训练次数来优化模型。所以下面代码,我们对一个数据集进行了20次训练。

num_epochs = 20
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # 清空梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, targets)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}')

后面会单独开一篇文章来将训练过程。

推理

我们再生成一个输入数组,并计算期望的值。

# 生成一组测试输入数据
test_input = torch.randn(10, 1)  # 生成10个随机样本,每个样本1个特征
test_output = 2 * test_input + 1  # 期望的输出

然后用模型算出结果,并进行比较

# 推理
model.eval()
with torch.no_grad():
    output = model(test_input)
    for i in range(len(test_input)):
        print(f'''Test Input: {test_input[i].item()}, 
    Test Output: {output[i].item()}, 
    Actual Output: {test_output[i].item()}, 
    Diff: {output[i].item() - test_output[i].item()}, 
    Loss: {abs(output[i].item() - test_output[i].item()) / abs(test_output[i].item())* 100:.2f}%\n''')

在这里插入图片描述
在这里插入图片描述
我们看到,模型最后推理出的结果和我们的期望值误差在2%以内。

参考代码

https://github.com/f304646673/deeplearning/tree/main/mvp


http://www.kler.cn/news/315693.html

相关文章:

  • AI教你学Python 第17天 :小项目联系人管理系统
  • 小程序-模板与配置
  • 乱弹篇(53)丹桂未飘香
  • Excel常用函数大全
  • DAPP智能合约系统开发
  • 【计算机网络 - 基础问题】每日 3 题(十八)
  • SecureCRT下载
  • 如何在 MySQL Workbench 中修改表数据并保存??
  • 记一次Meilisearch轻量级搜索引擎使用
  • 蓝桥杯1.小蓝的漆房
  • 网络安全等保培训 ppt
  • 循环单链表来表示队列
  • Qt Debugging帮助文档
  • Packet Tracer - 配置编号的标准 IPv4 ACL(两篇)
  • jQuery 入口函数 详解
  • IT行业中的工作生活平衡探讨
  • Android 内核开发之—— repo 使用教程
  • Kotlin 函数和变量(四)
  • 双向链表:实现、操作与分析【算法 17】
  • 【Fastapi】参数获取,json和query
  • 深度学习02-pytorch-05-张量的索引操作
  • 2024 年最新前端ES-Module模块化、webpack打包工具详细教程(更新中)
  • Android 车载应用开发指南 - CarService 详解(下)
  • 在Spring Boot中实现多环境配置
  • 汽车总线之----FlexRay总线
  • LeetCode_sql_day31(1384.按年度列出销售总额)
  • C# 委托与事件 观察者模式
  • Java项目实战II基于Java+Spring Boot+MySQL的植物健康系统(开发文档+源码+数据库)
  • 设计模式之复合模式
  • 高级java每日一道面试题-2024年9月16日-框架篇-Spring MVC和Struts的区别是什么?