当前位置: 首页 > article >正文

使用TensorFlow实现逻辑回归:从训练到模型保存与加载

1. 引入必要的库

首先,需要引入必要的库。TensorFlow用于构建和训练模型,pandas和numpy用于数据处理,matplotlib用于结果的可视化。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


2. 加载自定义数据集

假设有一个CSV文件custom_dataset.csv,其中包含特征(自变量)和标签(因变量)。使用pandas来加载数据,并进行预处理。

# 加载自定义数据集
data = pd.read_csv('custom_dataset.csv')

# 假设数据集中有多列特征和一个二分类标签
X = data.iloc[:, :-1].values.astype(np.float32)  # 特征
y = data.iloc[:, -1].values.astype(np.float32)   # 标签

# 将标签转换为0和1
y = np.where(y == 'positive', 1, 0)

3. 构建逻辑回归模型

使用TensorFlow的Keras接口来构建逻辑回归模型。

# 构建逻辑回归模型
model = Sequential([
    Dense(1, activation='sigmoid', input_shape=(X.shape[1],))
])

# 编译模型
model.compile(optimizer=SGD(learning_rate=0.01), loss='binary_crossentropy', metrics=['accuracy'])


4. 训练模型

使用自定义数据集训练模型。

# 训练模型
history = model.fit(X, y, epochs=100, batch_size=32, verbose=1)


5. 保存模型

训练完成后,可以使用TensorFlow的save方法保存模型。

# 保存模型
model.save('logistic_regression_model.h5')


6. 加载模型并进行预测

在需要时,可以使用TensorFlow的load_model方法加载模型,并进行预测。

# 加载模型
from tensorflow.keras.models import load_model

loaded_model = load_model('logistic_regression_model.h5')

# 进行预测
predictions = loaded_model.predict(X[:5])
predicted_labels = (predictions > 0.5).astype(int)

print("Predicted Labels:", predicted_labels.flatten())


7. 结果可视化

可以绘制训练过程中的损失和准确率变化曲线,以帮助理解模型的性能。

# 绘制训练和验证的损失曲线
plt.plot(history.history['loss'], label='Loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 绘制训练和验证的准确率曲线
plt.plot(history.history['accuracy'], label='Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


http://www.kler.cn/a/522004.html

相关文章:

  • [蓝桥杯 2014 省 AB] 蚂蚁感冒
  • Linux C++
  • 算法12(力扣739)-每日温度
  • 读书笔记:《华为突围ERP封锁全纪实》
  • 4.flask-SQLAlchemy,表Model定义、增删查改操作
  • 无耳科技 Solon v3.0.7 发布(2025农历新年版)
  • 信息学奥赛一本通 2110:【例5.1】素数环
  • 2025数学建模美赛|A题成品论文
  • 神经网络|(六)概率论基础知识-全概率公式
  • 爱快 IK-X9 吸顶AP 简单开箱评测和拆解,三频WiFi7,BE5000,2.5G网口
  • Continuous Batching 连续批处理
  • 基于ESP8266的多功能环境监测与反馈系统开发指南
  • 嵌入式C语言:结构体
  • KF-GINS 和 OB-GINS 的 Earth类 和 Rotation 类
  • 安卓日常问题杂谈(一)
  • Java-数据结构-二叉树习题(3)
  • 落地 基于特征的对象检测
  • leetcode 面试经典 150 题:简化路径
  • 鲁滨逊漂流记读后感
  • 【PySide6快速入门】QGridLayout 网格布局
  • 如何使用 DeepSeek API 结合 VSCode 提升开发效率
  • 深度学习笔记13-CIFAR彩色图片识别(Pytorch)
  • 供应链管理中的BOM 和 MRP 是什么,如何计算
  • 探索前端可观察性:如何使用Telemetry提高用户体验
  • 基于Java+Springboot+MySQL校园在线考试网站系统设计与实现
  • zyNo.19