使用PyTorch实现逻辑回归:从训练到模型保存与加载
1. 引入必要的库
首先,需要引入必要的库。PyTorch用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
2. 加载自定义数据集
有一个CSV文件custom_dataset.csv
,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。
# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')
# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32) # 特征
y = data.iloc[:, -1].values.astype(np.float32) # 标签
# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)
3. 创建数据集和数据加载器
使用PyTorch的TensorDataset
和DataLoader
来创建数据集和数据加载器。
# 创建数据集和数据加载器
dataset = TensorDataset(torch.tensor(X), torch.tensor(y))
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
4. 定义逻辑回归模型
使用PyTorch的nn.Module
来定义逻辑回归模型。
class LogisticRegression(nn.Module):
def __init__(self, input_dim):
super(LogisticRegression, self).__init__()
self.linear = nn.Linear(input_dim, 1)
def forward(self, x):
outputs = torch.sigmoid(self.linear(x))
return outputs
# 初始化模型
input_dim = X.shape[1]
model = LogisticRegression(input_dim)
5. 训练模型
定义损失函数和优化器,然后训练模型。
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
for inputs, labels in train_loader:
# 前向传播
outputs = model(inputs)
loss = criterion(outputs.flatten(), labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
6. 保存模型
训练完成后,可以使用PyTorch的torch.save
函数来保存模型。
# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')
7. 加载模型并进行预测
在需要时,可以使用torch.load
函数加载模型,并进行预测。
# 加载模型
model = LogisticRegression(input_dim)
model.load_state_dict(torch.load('logistic_regression_model.pth'))
model.eval()
# 进行预测
with torch.no_grad():
sample_inputs = torch.tensor(X[:5]).float() # 示例输入
predictions = model(sample_inputs)
predicted_labels = (predictions.flatten() > 0.5).int()
print("Predicted Labels:", predicted_labels.numpy())