当前位置: 首页 > article >正文

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

# 数据准备
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])
class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

x_train = np.concatenate((class1_points, class2_points), axis=0)
y_train = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))))

x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)

# 设置随机种子
seed = 42
torch.manual_seed(seed)

# 定义模型
class LogisticRegreModel(nn.Module):
    def __init__(self):
        super(LogisticRegreModel, self).__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x

model = LogisticRegreModel()

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)

# 训练模型
epochs = 1000
for epoch in range(1, epochs + 1):
    y_pred = model(x_train_tensor)
    loss = criterion(y_pred, y_train_tensor.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0 or epoch == 1:
        print(f"epoch: {epoch}, loss: {loss.item()}")

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 加载模型
model = LogisticRegreModel()
model.load_state_dict(torch.load('model.pth'))
# 设置模型为评估模式
model.eval()

# 进行预测
with torch.no_grad():
    y_pred = model(x_train_tensor)
    y_pred_class = (y_pred > 0.5).float().squeeze()

# 计算精确度、召回率和F1分数
precision = precision_score(y_train_tensor.numpy(), y_pred_class.numpy())
recall = recall_score(y_train_tensor.numpy(), y_pred_class.numpy())
f1 = f1_score(y_train_tensor.numpy(), y_pred_class.numpy())

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")


http://www.kler.cn/a/528702.html

相关文章:

  • jstat命令详解
  • DeepSeek 遭 DDoS 攻击背后:DDoS 攻击的 “千层套路” 与安全防御 “金钟罩”
  • 智慧园区如何利用智能化手段提升居民幸福感与环境可持续性
  • Autogen_core源码:_agent_instantiation.py
  • 【Java异步编程】CompletableFuture基础(1):创建不同线程的子任务、子任务链式调用与异常处理
  • 360嵌入式开发面试题及参考答案
  • Java小白入门教程:LinkedList
  • 车载以太网---数据链路层
  • Spark SQL读写Hive Table部署
  • SQL入门到精通 理论+实战 -- 在 MySQL 中学习SQL语言
  • 10:预处理
  • C++,vector:动态数组的原理、使用与极致优化
  • 回溯算法理论基础
  • 递归练习七(floodfill 算法)
  • C#属性和字段(访问修饰符)
  • 代码随想录-训练营-day17
  • 自制虚拟机(C/C++)(二、分析引导扇区,虚拟机读二进制文件img软盘)
  • 代码随想录算法训练营第四十二天-动态规划-股票-188.买卖股票的最佳时机IV
  • JVM运行时数据区域-附面试题
  • 笔记:同步电机调试时电角度校正方法说明
  • Python GIL(全局解释器锁)机制对多线程性能影响的深度分析
  • 《逆向工程核心原理》第三~五章知识整理
  • MATLAB实现多种群遗传算法
  • SQLAlchemy通用分页函数实现:支持搜索、排序和动态页码导航
  • 可视化相机pose colmap形式的相机内参外参
  • MySQL各种日志详解