当前位置: 首页 > article >正文

Python深度学习实战-基于Sequential方法搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能

  1. 第一步:导入模块:import tensorflow as tf

  2. 第二步:制定输入网络的训练集和测试集

  3. 第三步:搭建网络结构:tf.keras.models.Sequential()

  4. 第四步:配置训练方法:model.compile():

  5. 第五步:执行训练过程:model.fit():

  6. 第六步:打印网络结构:model.summary()

  7. 第七步:执行验证过程:model.evaluate()

实现代码

import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 数据预处理
scaler = StandardScaler()
X = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(X.shape[1],)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(len(set(y)), activation='softmax')
])

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(X_train, y_train, epochs=10, batch_size=32)
model.summary()
# 评估模型
test_loss, test_accuracy = model.evaluate(X_test, y_test)

实现效果

本人读研期间发表5篇SCI数据挖掘相关论文,现在某研究院从事数据挖掘相关科研工作,对数据挖掘有一定认知和理解,会结合自身科研实践经历不定期分享关于python、机器学习、深度学习基础知识与案例。

致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

邀请三个朋友关注本订阅号V:数据杂坛,即可在后台联系我获取相关数据集和源码,送有关数据分析、数据挖掘、机器学习、深度学习相关的电子书籍。


http://www.kler.cn/a/107344.html

相关文章:

  • SHA-256哈希函数
  • 高效稳定!新加坡服务器托管方案助力企业全球化布局
  • Android 10 默认授权安装app运行时权限(去掉运行时所有权限授权弹窗)
  • 大数据新视界 -- 大数据大厂之 Impala 性能飞跃:动态分区调整的策略与方法(上)(21 / 30)
  • 使用etl工具kettle的日常踩坑梳理之二、从Hadoop中导出数据
  • 继承和多态(上)
  • 基于GPIO子系统编写LED驱动,编写应用程序进行测试设置定时器,5秒钟打印一次hello world
  • 软考 系统架构设计师系列知识点之设计模式(4)
  • GoLong的学习之路(十四)语法之标准库 time(时间包)的使用
  • MySQL语言分类
  • 论文阅读 - Hidden messages: mapping nations’ media campaigns
  • Android原生项目集成uniMPSDK(Uniapp)遇到的报错总结
  • 在 macOS 上的多个 PHP 版本之间切换
  • 李沐——论文阅读——VIT(VIsionTransformer)
  • 使用Gateway解决跨域问题时配置文件不生效的情况之一
  • CTF-php特性绕过
  • 一次不接受官方建议导致的事故
  • 软考高项-计算题(3)
  • 【LeetCode】5. 最长回文子串
  • 10月28日,每日信息差
  • HarmonyOS开发:探索组件化模式开发
  • Flink CDC 2.0 主要是借鉴 DBLog 算法
  • PostgreSQL basebackup备份和恢复
  • 闲聊一下写技术博客的一些感想
  • Dijkstra算法基础详解,附有练习题
  • OpenAI大模型项目计划表(InsCode AI 创作助手)