当前位置: 首页 > article >正文

使用 Numpy 自定义数据集,使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

1. 导入必要的库

首先,导入我们需要的库:Numpy、Pytorch 和相关工具包。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, recall_score, f1_score
2. 自定义数据集

使用 Numpy 创建一个简单的线性可分数据集,并将其转换为 Pytorch 张量。

# 创建数据集
X = np.random.rand(100, 2)  # 100 个样本,2 个特征
y = (X[:, 0] + X[:, 1] > 1).astype(int)  # 标签,若特征之和大于1则为 1,否则为 0

# 转换为 PyTorch 张量
X_train = torch.tensor(X, dtype=torch.float32)
y_train = torch.tensor(y, dtype=torch.long)
3. 定义逻辑回归模型

在 Pytorch 中定义一个简单的逻辑回归模型。

class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, 2)  # 二分类问题

    def forward(self, x):
        return self.linear(x)
4. 初始化模型、损失函数和优化器
# 初始化模型
model = LogisticRegressionModel(input_dim=2)

# 损失函数与优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)
5. 训练模型

训练模型并保存训练好的权重。

epochs = 100
for epoch in range(epochs):
    # 前向传播
    outputs = model(X_train)
    loss = criterion(outputs, y_train)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 20 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# 保存模型
torch.save(model.state_dict(), 'logistic_regression.pth')
6. 加载模型并进行预测

加载保存的模型并进行预测。

# 加载模型
model = LogisticRegressionModel(input_dim=2)
model.load_state_dict(torch.load('logistic_regression.pth'))
model.eval()  # 设为评估模式

# 预测
with torch.no_grad():
    y_pred = model(X_train)
    _, predicted = torch.max(y_pred, 1)
7. 计算精确度、召回率和 F1 分数

使用 sklearn 中的评估函数计算精确度、召回率和 F1 分数。

accuracy = accuracy_score(y_train, predicted)
recall = recall_score(y_train, predicted)
f1 = f1_score(y_train, predicted)

print(f"Accuracy: {accuracy:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")
8. 总结

这篇博客展示了如何使用 Numpy 自定义数据集,利用 Pytorch 框架实现逻辑回归模型,并进行训练。训练后的模型被保存,并在加载后进行预测,最后计算了精确度、召回率和 F1 分数。


http://www.kler.cn/a/529585.html

相关文章:

  • model calibration
  • 如何使用SliverList组件
  • p1044 栈
  • 【13】WLC HA介绍和配置
  • Python的那些事第六篇:从定义到应用,Python函数的奥秘
  • 推荐一款好用的翻译类浏览器扩展插件
  • C29.【C++ Cont】STL库:动态顺序表(vector容器)
  • LeetCode //C - 567. Permutation in String
  • IM 即时通讯系统-42-基于netty实现的IM服务端,提供客户端jar包,可集成自己的登录系统
  • 【Redis】Redis 经典面试题解析:深入理解 Redis 的核心概念与应用
  • java基础概念63-多线程
  • 【xdoj-离散线上练习】T251(C++)
  • AI技术路线(marked)
  • LeetCode 344: 反转字符串
  • Zabbix 推送告警 消息模板 美化(钉钉Webhook机器人、邮件)
  • 无人机飞手光伏吊运、电力巡检、农林植保技术详解
  • kamailio的kamctl的使用
  • [c语言日寄]C语言类型转换规则详解
  • ZYNQ-AXI DMA+AXI-S FIFO回环学习
  • DirectShow过滤器开发-读视频文件过滤器(再写)
  • 本地缓存~
  • 功防世界 Web_php_include
  • 理解红黑树
  • word2vec 实战应用介绍
  • Kotlin 协程 与 Java 虚拟线程对比测试(娱乐性质,请勿严谨看待本次测试)
  • VSCode设置内容字体大小