当前位置: 首页 > article >正文

七。自定义数据集 使用tensorflow框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

import tensorflow as tf
import numpy as np

# 自定义数据集类
class CustomDataset(tf.data.Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = tf.convert_to_tensor(x_data, dtype=tf.float32)
        self.y_data = tf.convert_to_tensor(y_data, dtype=tf.float32)

    def __iter__(self):
        for i in range(len(self.x_data)):
            yield (self.x_data[i], self.y_data[i])

# 逻辑回归模型
class LogisticRegressionModel(tf.keras.Model):
    def __init__(self, input_dim):
        super(LogisticRegressionModel, self).__init__()
        self.linear = tf.keras.layers.Dense(1, input_shape=(input_dim,), activation='sigmoid')

    def call(self, x):
        return self.linear(x)

# 创建数据集
x_data = np.array([[1], [2], [3], [4], [5]], dtype=np.float32)
y_data = np.array([[0], [0], [1], [1], [1]], dtype=np.float32)
dataset = CustomDataset(x_data, y_data)

# 创建数据加载器
dataloader = dataset.batch(2).shuffle(100).repeat()

# 创建模型、损失函数和优化器
model = LogisticRegressionModel(input_dim=1)
loss_object = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# 训练模型
epochs = 100
for epoch in range(epochs):
    for x_batch, y_batch in dataloader:
        with tf.GradientTape() as tape:
            predictions = model(x_batch)
            loss = loss_object(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.numpy():.4f}')

# 保存模型
model.save('logistic_regression_model.h5')

# 加载模型
model = tf.keras.models.load_model('logistic_regression_model.h5')

# 进行预测
x_test = np.array([[6], [7], [8]], dtype=np.float32)
y_pred = model.predict(x_test)
print('预测值:', y_pred)
 


http://www.kler.cn/a/536020.html

相关文章:

  • 【分布式理论六】分布式调用(4):服务间的远程调用(RPC)
  • Google地图瓦片爬虫——进阶版
  • 【文件上传、秒传、分片上传、断点续传、重传】
  • 【Unity3D小功能】Unity3D中实现超炫按钮悬停效果
  • React组件开发技巧:如何优雅地传递Props?
  • Kafka 使用说明(kafka官方文档中文)
  • Mac: docker安装以后报错Command not found: docker
  • ctf网络安全大赛python ctf网络安全大赛
  • 本文主要详细讲解ArcGIS中的线、多线段和多边形的结构关系。
  • Kafka 可靠性探究—副本刨析
  • 关于maven的java面试题汇总
  • 1 Java 基础面试题(上)
  • 物联网实训室解决方案(2025年最新版)
  • BUU26 [极客大挑战 2019]HardSQL1
  • Electron学习笔记,用node程序备份数据库(2)
  • Github 2025-02-07Java开源项目日报 Top9
  • 二叉树实现(学习记录)
  • 神经辐射场(NeRF):从2D图像到3D场景的革命性重建
  • Java面试题——事务
  • 【论文翻译】DeepSeek-V3论文翻译——DeepSeek-V3 Technical Report——第一部分:引言与模型架构
  • windows10环境下的Deepseek本地部署及接口调用
  • 网络安全威胁框架与入侵分析模型概述
  • 【PostgreSQL内核学习 —— (WindowAgg(三))】
  • golang命令大全12--命令速查表
  • Vue学习综合案例(四)
  • Spring的三级缓存如何解决循环依赖问题