当前位置: 首页 > article >正文

六。自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

import torch
import torch.nn as nn
from torch.utils.data import Dataset
import numpy as np

# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = torch.from_numpy(x_data).float()
        self.y_data = torch.from_numpy(y_data).float()

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

# 逻辑回归模型
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, 1)  # 输出维度为1,因为这是二分类问题

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

# 创建数据集
x_data = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
y_data = np.array([[0], [0], [1], [1], [1]], dtype=np.float32)
dataset = CustomDataset(x_data, y_data)

# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

# 创建模型、损失函数和优化器
model = LogisticRegressionModel(input_dim=1)
criterion = nn.BCELoss()  # 二分类交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 100
for epoch in range(epochs):
    for x_batch, y_batch in dataloader:
        # 前向传播
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch.view(-1, 1))
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 保存模型
torch.save(model.state_dict(), 'logistic_regression_model.pth')

# 加载模型
model = LogisticRegressionModel(input_dim=1)
model.load_state_dict(torch.load('logistic_regression_model.pth'))

# 进行预测
with torch.no_grad():
    x_test = torch.tensor([[6], [7], [8]], dtype=torch.float32)
    y_pred = model(x_test)
    print('预测值:', y_pred)
 


http://www.kler.cn/a/535682.html

相关文章:

  • Vue el-input密码输入框 按住显示密码,松开显示*;阻止浏览器密码回填,自写密码输入框;校验输入非汉字内容;文本框聚焦到内容末尾;
  • 第二次连接k8s平台注意事项
  • go流程控制
  • 114,【6】攻防世界 web wzsc_文件上传
  • LPJ-GUESS模型入门(一)
  • PostgreSQL函数自动Commit/Rollback所带来的问题
  • 虚拟机报错:处理器未启用 NX/XD,而 VMware Workstation 要求启用 NX/XD.没有可打开电源的虚拟机
  • k8sollama部署deepseek-R1模型,内网无坑
  • 【Leetcode 热题 100】72. 编辑距离
  • 【2025最新计算机毕业设计】基于springboot智能教师评价系统【提供源码+答辩PPT+文档+项目部署】(高质量源码,可定制,提供文档,免费部署到本地)
  • 深入理解C++的new和delete
  • linux 使用docker安装 postgres 教程,踩坑实践
  • Unity3D 切线空间及其应用详解
  • springboot011-G县乡村生活垃圾治理问题中运输地图系统
  • 网络安全配置
  • 从源码到上线:AI在线教育系统开发全流程与网校APP构建指南
  • 使用 TensorRT 和 Python 实现高性能图像推理服务器
  • LeetCode 415
  • 无人机使用数传时的注意事项!
  • 尝试在Excel里调用硅基流动上的免费大语言模型
  • FFmpeg 头文件完美翻译之 libavdevice 模块
  • gc buffer busy acquire导致的重大数据库性能故障
  • Elasticsearch集群模式保姆级教程
  • MQTT:物联网时代的数据桥梁
  • 【Redisson分布式锁】基于redisson的分布式锁
  • 在 Vue Router 中,params和query的区别?