当前位置: 首页 > article >正文

自定义数据集,使用 PyTorch 框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

在本文中,我们将展示如何使用 NumPy 创建自定义数据集,利用 PyTorch 实现一个简单的逻辑回归模型,并在训练完成后保存该模型,最后加载模型并用它进行预测。

1. 创建自定义数据集

首先,我们使用 NumPy 创建一个简单的二分类数据集。假设我们的数据集包含两个特征。

import numpy as np

# 生成随机数据
np.random.seed(42)
X = np.random.randn(100, 2)  # 100个样本,2个特征
y = (X[:, 0] + X[:, 1] > 0).astype(int)  # 标签为1或0

# 打印数据集的前5个样本
print(X[:5], y[:5])

2. 构建逻辑回归模型

接下来,我们使用 PyTorch 来构建一个简单的逻辑回归模型。PyTorch 提供了 torch.nn.Module 类,能够轻松实现神经网络模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 转换 NumPy 数据为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)

# 创建数据集并加载
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 定义逻辑回归模型
class LogisticRegressionModel(nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(2, 1)  # 2个输入特征,1个输出

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

# 初始化模型
model = LogisticRegressionModel()

3. 训练模型

接下来,我们定义损失函数和优化器,并训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 1000
for epoch in range(epochs):
    for inputs, labels in dataloader:
        optimizer.zero_grad()  # 清空梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

4. 保存模型

模型训练完成后,我们可以将模型保存到文件中。PyTorch 提供了 torch.save() 方法来保存模型的状态字典。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
print("模型已保存!")

5. 加载模型并进行预测

我们可以加载保存的模型,并对新数据进行预测。

# 加载模型
loaded_model = LogisticRegressionModel()
loaded_model.load_state_dict(torch.load('logistic_regression_model.pth'))
loaded_model.eval()  # 切换到评估模式

# 进行预测
with torch.no_grad():
    test_data = torch.tensor([[1.5, -0.5]], dtype=torch.float32)
    prediction = loaded_model(test_data)
    print(f'预测值: {prediction.item():.4f}')

6. 总结

在这篇博客中,我们展示了如何使用 NumPy 创建一个简单的自定义数据集,并使用 PyTorch 实现一个逻辑回归模型。我们还展示了如何保存训练好的模型,并加载模型进行预测。通过保存和加载模型,我们可以在不同的时间或环境中重复使用已经训练好的模型,而不需要重新训练它。


http://www.kler.cn/a/523473.html

相关文章:

  • QT设置应用程序图标
  • pytorch实现半监督学习
  • 实验七 带函数查询和综合查询(2)
  • pytorch线性回归模型预测房价例子
  • Python学习之旅:进阶阶段(五)数据结构-双端队列(collections.deque)
  • 快速提升网站收录:内容创作的艺术
  • 重回C语言之老兵重装上阵(十四)C语言头文件详解
  • LeetCode热题100(八)—— 438.找到字符串中所有字母异位词
  • 危机13小时:追踪一场GitHub投毒事件
  • 实验二---二阶系统阶跃响应---自动控制原理实验课
  • wx043基于springboot+vue+uniapp的智慧物流小程序
  • Direct2D 极速教程(2) —— 画淳平
  • 求解旅行商问题的三种精确性建模方法,性能差距巨大
  • android主题设置为..DarkActionBar.Bridge时自定义DatePicker选中日期颜色
  • woocommerce独立站与wordpress独立站的最大区别是什么
  • 每日OJ_牛客_小红的子串_滑动窗口+前缀和_C++_Java
  • 实验二 数据库的附加/分离、导入/导出与备份/还原
  • XCTF Illusion wp
  • 【C++动态规划 状态压缩】2741. 特别的排列|2020
  • 在FreeBSD下安装Ollama并体验DeepSeek r1大模型
  • 循序渐进kubernetes-RBAC(Role-Based Access Control)
  • Vue.js Vuex 持久化存储(持久化插件)
  • 【Linux探索学习】第二十七弹——信号(一):Linux 信号基础详解
  • 小阿卡纳牌
  • Python 数据分析 - 初识 Pandas
  • Git Bash 配置 zsh