当前位置: 首页 > article >正文

AI人工智能机器学习之神经网络

1、概要

  本篇学习AI人工智能机器学习之神经网络,以MLPClassifier和MLPRegressor为例,从代码层面讲述最常用的神经网络模型MLP。

2、神经网络 - 简介

在 Scikit-learn 中,神经网络是通过 sklearn.neural_network 模块提供的。最常用的神经网络模型是多层感知器(MLP,Multi-layer Perceptron),它可以用于分类和回归任务。

一些基本的概念

  • 多层感知器(MLP):一种前馈神经网络,由输入层、隐藏层和输出层组成。每一层的节点与下一层的节点是全连接的。
  • 激活函数:每个神经元通常会有一个激活函数,如 ReLU、Sigmoid 和 Tanh,用于引入非线性。
  • 损失函数:在训练过程中用来评估模型性能的函数,MLP 允许使用不同的损失函数,具体取决于任务(如分类或回归)。
  • 优化算法:用于更新网络权重的算法,最常用的是随机梯度下降(SGD)和 Adam。

本篇,以两个示例讲述神经网络MLP的使用方法:

  • 示例1:MLPClassifier对数据集进行分类
  • 示例2:MLPRegressor对数据进行回归

本篇相关资料代码参见:AI人工智能机器学习相关知识资源及使用的示例代码

3、神经网络

3.1、安装依赖

python安装机器学习库: pip install scikit-learn

3.2、示例1: MLPClassifier对数据集进行分类
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# 生成模拟数据集
X, y = make_classification(n_samples=100, n_features=3, n_informative=2, n_redundant=0, n_classes=2, random_state=42)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)


def test_MLPClassifier():
    # 创建神经网络分类器实例
    # hidden_layer_sizes参数,一个元组,定义隐藏层的结构,例如 (100, 50) 表示有两个隐藏层,第一层100神经元,第二层 50 个神经元。
    # max_iter参数,最大迭代次数
    # random_state:随机种子,使结果可重复
    model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=1000, random_state=42)

    # 训练模型
    model.fit(X_train, y_train)

    # 进行预测
    y_pred = model.predict(X_test)

    # 计算准确率
    accuracy = accuracy_score(y_test, y_pred)
    print(f"准确率为: {accuracy:.2f}")
    # 计算混淆矩阵
    print(confusion_matrix(y_test, y_pred))
    # 报告
    print(classification_report(y_test, y_pred))

test_MLPClassifier()

运行上述代码的输出:

准确率为: 0.93
[[10  1]
 [ 1 18]]
              precision    recall  f1-score   support

           0       0.91      0.91      0.91        11
           1       0.95      0.95      0.95        19

    accuracy                           0.93        30
   macro avg       0.93      0.93      0.93        30
weighted avg       0.93      0.93      0.93        30
3.3、示例2:MLPRegressor对数据进行回归
from sklearn.neural_network import MLPRegressor
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

# 创建回归数据集
X, y = make_regression(n_samples=100, n_features=1, noise=0.1, random_state=42)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

def test_MLPRegressor():
    # 创建和训练 MLPRegressor
    mlp_regressor = MLPRegressor(hidden_layer_sizes=(10,), max_iter=1000, random_state=42)
    mlp_regressor.fit(X_train, y_train)
    
    # 进行预测
    y_pred = mlp_regressor.predict(X_test)
    
    # 打印预测结果
    print("Predicted values:", y_pred)
    print("True values:", y_test)  

test_MLPRegressor()

运行上述代码的输出:

Predicted values: [-15.9535171   11.18083111   6.66419755  -6.93420128  -4.88154814
  -5.62929653  -7.89092833 -19.22598705   6.72784808   7.47032208
   8.14723448   2.80206811 -15.14571795  -8.72301651 -14.61559911
  -8.06564202   7.77080054   1.30566751   6.16147072   3.04358961]
True values: [-55.37503843  61.96236579  34.0206566  -16.26246864  -9.75232562
 -12.0363855  -19.53933098 -73.53859117  34.32170107  38.9917296
  42.89105035  14.96006767 -50.87199832 -22.1085758  -48.12392116
 -19.9786311   40.84203409  10.11000622  30.8780412   15.82045024]

4、 总结

本篇以MLPClassifier和MLPRegressor为例,从代码层面讲述最常用的神经网络模型MLP。虽然sklearn提供了接口来构建和训练神经网络模型,但是对于复杂的复杂的神经网络模型,推荐使用 TensorFlow 或 PyTorch 等库。


http://www.kler.cn/a/568042.html

相关文章:

  • springBoot连接远程Redis连接失败(已解决)
  • 最新Git入门到精通完整教程
  • Python办公自动化教程(008):设置excel单元格边框和背景颜色
  • Windows 11 下正确安装 Docker Desktop 到 D 盘的完整教程
  • EasyRTC嵌入式WebRTC技术与AI大模型结合:从ICE框架优化到AI推理
  • 基于 SSM+Vue的 车辆管理系统 系统的设计与实现
  • Brave 132 编译指南 Android 篇 - 配置编译环境 (五)
  • 从JSON过滤到编程范式:深入理解JavaScript数据操作
  • MySQL在线、离线安装
  • 蓝桥杯备考:DFS剪枝之数的划分
  • 机器学习数学基础:33.分半信度
  • 区块链的原理、技术与应用场景
  • 金融项目管理:合规性与风险管理的实战指南
  • C#上位机--关键字
  • 松灵机器人地盘 安装 ros 驱动 并且 发布ros 指令进行控制
  • [Windows] 批量为视频或者音频生成字幕 video subtitle master 1.5.2
  • 网络安全深度剖析
  • Tomcat 8 安装包下载
  • 2025影视站群程序实战:search聚合版/无缓存泛页面刷新不变
  • github上传代码(自用)