当前位置: 首页 > article >正文

使用PyTorch实现逻辑回归:从训练到模型保存与加载

1. 引入必要的库

首先,需要引入必要的库。PyTorch用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


2. 加载自定义数据集

有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)


3. 创建数据集和数据加载器

使用PyTorch的TensorDatasetDataLoader来创建数据集和数据加载器。

# 创建数据集和数据加载器
dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)


4. 定义逻辑回归模型

使用PyTorch的nn.Module来定义逻辑回归模型。

class LogisticRegression(nn.Module):
    def __init__(self, input_dim):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        outputs = torch.sigmoid(self.linear(x))
        return outputs

# 初始化模型
input_dim = X.shape[1]
model = LogisticRegression(input_dim)

5. 训练模型

定义损失函数和优化器,然后训练模型。

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs.flatten(), labels)
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

6. 保存模型

训练完成后,可以使用PyTorch的torch.save函数来保存模型。

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')


7. 加载模型并进行预测

在需要时,可以使用torch.load函数加载模型,并进行预测。

# 加载模型
model = LogisticRegression(input_dim)
model.load_state_dict(torch.load('logistic_regression_model.pth'))
model.eval()

# 进行预测
with torch.no_grad():
    sample_inputs = torch.tensor(X[:5]).float()  # 示例输入
    predictions = model(sample_inputs)
    predicted_labels = (predictions.flatten() > 0.5).int()

print("Predicted Labels:", predicted_labels.numpy())


http://www.kler.cn/a/521110.html

相关文章:

  • TikTok 推出了一款 IDE,用于快速构建 AI 应用
  • Hystrix熔断器与异常处理的艺术
  • 【Linux】列出所有连接的 WiFi 网络的密码
  • 通过 NAudio 控制电脑操作系统音量
  • 音频 PCM 格式 - raw data
  • C++小病毒-1.0勒索(更新次数:2)
  • MySQL 8 不开通 CLONE 插件,建立主从关系
  • mybatis(78/134)
  • 使用QSqlQueryModel创建交替背景色的表格模型
  • 技术速递|.NET 9 中的 OpenAPI 文档生成
  • 【大数据】数据治理浅析
  • 想品客老师的第七天:闭包和作用域
  • 代码随想录算法训练营day30(补0123)
  • 基于 Ansible 的 Linux 服务器自动化运维实战
  • Java Web-Cookie与Session
  • 前端性能优化指标 - DCL(触发时机、脚本对 DCL 的影响、CSS 对 DCL 的影响)
  • RAG:实现基于本地知识库结合大模型生成(LangChain4j快速入门#1)
  • 【HarmonyOS之旅】基于ArkTS开发(三) -> 兼容JS的类Web开发(二)
  • ollama使用详解
  • JavaScript 的 Promise 对象和 Promise.all 方法的使用
  • 验证二叉搜索树(力扣98)
  • Pandas基础03(数据的合并操作/concat()/append()/merge())
  • 第五节 MATLAB命令
  • 【大厂面试AI算法题中的知识点】方向涉及:ML/DL/CV/NLP/大数据...本篇介绍Transformer相较于CNN的优缺点?
  • WPF基础 | WPF 常用控件实战:Button、TextBox 等的基础应用
  • 对比OpenAI的AI智能体Operator和智谱的GLM-PC,它们有哪些不同?