当前位置: 首页 > article >正文

第一个机器学习应用:鸢尾花分类

目录

1. 特征数据与标签数据

2. 训练数据与测试数据

3. 构建模型机器学习模型

4. 预测与评估

4.1 预测

4.2 评估 

5. 学习小结


准备:采用Scikit- learn中鸢尾花数据集,完成一个简单的机器学习应用,构建第一个机器学习模型。

已知:这些花已经被植物学专家鉴定为三个类别:setosa、versicolor、virginica

问题:要在多个选项(3个鸢尾花类别)中预测其中一个鸢尾花的品种,这是一个分类问题。可能的输出叫作类别。数据集中每朵鸢尾花都属于三个类别之一,所以这是一个三分类问题。单个数据点(一朵鸢尾花)的预期输出是这朵花的品种。对于一个数据点来说,它的品种叫作标签。

关键函数:load_iris() 载入iris数据集

1. 特征数据与标签数据

#鸢尾花⚜️分类
from sklearn.datasets import load_iris
iris_dataset = load_iris()
# 返回一个bunch对象,它直接继承自dict类,与字典类似,由键值对组成。
# 同样可以使用bunch.keys(),bunch.values(),bunch.items()等方法
print(iris_dataset.keys())
# out1:dict_keys(['data', 'target', 'frame', 'target_names', 'DESCR', 
# 'feature_names', 'filename', 'data_module'])

#鸢尾花的data属性:iris鸢尾花数据集内包含3类分别为setosa、versicolor、virginica
#共150条记录,每类各50个数据,每条记录都有4项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度
#可以通过这4个特征预测鸢尾花卉属于哪一品种


print(type(iris_dataset['data']))
# out2:<class 'numpy.ndarray'>
# 输出表示iris数据集的类型为ndarray类型

print(iris_dataset['data'].shape)
# out3:(150, 4)
# 查看了数据集的形状,表示iris数据集有150朵花,每朵花有4列,即4个特征,
# 机器学习的个体叫作样本,其属性叫作特征,data数组的形状是样本数乘以特征数。

print(iris_dataset['target'].shape)
# out4:(150,)
# 数据集'target'表示标签,即鸢尾花的类别,其数值为150,表明有150个标签。

print(iris_dataset['target_names'])
# out5:['setosa' 'versicolor' 'virginica'] 输出鸢尾花标签的名称
# target_names为'setosa' 'versicolor' 'virginica'

print(iris_dataset['target'])
#out6:[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
#0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
#2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#2 2]
#数字的代表含义由iris['target_names']数组给出:0代表setosa,1代表versicolor,
# 2

print(iris_dataset['data'][:6])

# out6:[5.1 3.5 1.4 0.2]
# [4.9 3.  1.4 0.2]
# [4.7 3.2 1.3 0.2]
# [4.6 3.1 1.5 0.2]
# [5.  3.6 1.4 0.2]
# [5.4 3.9 1.7 0.4]]

#显示前6行的特征数据,每一列为鸢尾花的1个特征数据,4列分别为鸢尾花的花瓣的长度、宽度,花萼的长度、宽度。

2. 训练数据与测试数据

鸢尾花分为两个子集:

训练集:用于训练模型。

测试集:用于测试训练后模型的性能。

训练集数据用于算法的学习,构建模型。

机器学习将训练好的模型应用于新的数据,判断这个训练的模型是否可用,需要有评估模型性能的方法,故将测试集数据用于评估模型的性能。

# 利用scikit- learn中的train_test_split函数可以将数据划分为训练集和测试集。
from sklearn.model_selection import train_test_split #train_test_split可以返回四个数据集

X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],random_state=0)
# 一般会随机选取一个random_state的值作为参数
print(X_train[:5])

print("X_train:{}".format(X_train[:10]))
print("y_train:{}".format(y_train[:10]))

# X_train:[[5.9 3.  4.2 1.5]
# [5.8 2.6 4.  1.2]
# [6.8 3.  5.5 2.1]
# [4.7 3.2 1.3 0.2]
# [6.9 3.1 5.1 2.3]
# [5.  3.5 1.6 0.6]
# [5.4 3.7 1.5 0.2]
# [5.  2.  3.5 1. ]
# [6.5 3.  5.5 1.8]
# [6.7 3.3 5.7 2.5]]
# y_train:[1 1 2 0 2 0 0 1 2 2]

3. 构建模型机器学习模型

K近邻算法的核心思想是:依据统计学的理论看它所处的位置特征,衡量它周围邻居的数量,而把它归为(或分配)到数量更多的那一类。

在Scikit- learn中,K近邻分类算法是在neighbors模块的KNeighborClassifier类中实现的。需要将这个类实例化为一个对象,才能使用这个模型。注意使用时需要设置模型的参数。

from sklearn.neighbors import KNeighborsClassifier
# 实例化对象,“n_neighbors=1”表示选取一个邻居,knn.fit()表示训练模型
knn=KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train,y_train)

4. 预测与评估

4.1 预测

问题:发现一朵鸢尾花,花萼长5cm、宽2.9cm,花瓣长1cm、宽0.2cm,这朵花属于哪个品种?

# 将输入数据放在一个numpy数组中,使用predict方法进行预测

import numpy as np
Xnew = np.array([[5, 2.9, 1, 0.2]])
prediction = knn.predict(Xnew)
print("Prediction:{}".format(prediction))
print("Predictiontargetname:{}".format(iris_dataset['target_names'][prediction]))

# out8:
# Prediction:[0]
# Predictiontargetname:['setosa']

ValueError: Expected 2D array, got 1D array instead: array=[5. 2.9 1. 0.2]. Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

包错原因:Xnew = np.array([5, 2.9, 1, 0.2])

订正:Xnew = np.array([[5, 2.9, 1, 0.2]])

注意:

在建立一个数组的时候要将其转换为二维数组的一行,因为该函数的输入数据必须满足之前提到的约定:二维数组。

4.2 评估 

用到之前创建的测试集。

通过计算精度来衡量模型的性能,精度就是品种预测正确的花所占的比例。

print("Test set score:{}".format(knn.score(X_test,y_test)))

# out9: Test set score:0.9736842105263158

5. 学习小结

习题见上一篇博客!


(2023年4月19日 23:40首次发布) 


http://www.kler.cn/a/13822.html

相关文章:

  • 哪款无线洗地机最好用?好用的无线洗地机分享
  • 无线洗地机哪款性价比高?高性价比的洗地机分享
  • OSCP-Exfiltrated(Subrion、exiftool提权)
  • 功能安全ISO26262 道路车辆 功能安全审核及评估方法第3部分:软件层面
  • springcloudfeign原理和流程
  • OpenAI-ChatGPT最新官方接口《从0到1生产最佳实例》全网最详细中英文实用指南和教程,助你零基础快速轻松掌握全新技术(十一)(附源码)
  • 设备树常用of操作函数
  • UE4: Niagara系统实现雨天效果,并跟随人物移动
  • hadoop之MapReduce框架原理
  • Java基于POI动态合并单元格
  • 大语言模型-中文Langchain
  • ElasticSearch索引文档写入和近实时搜索
  • 86页2023年新型智慧城市顶层设计规划解决方案(ppt可编辑)
  • 火车站闸机web3d数字展示平台全方位动态呈现设备细节
  • MIT6.824 Lecture18 Fork Consistency
  • 赛题解析 | kaggle百万奖金新赛--图书墨水检测大赛
  • Zimbra 远程代码执行漏洞(CVE-2019-9670)漏洞分析
  • 非计算机专业如何转行成为程序员?我用亲身经历教你用这三种方法
  • 【LeetCode】数据结构题解(3)[查找链表中倒数第k个节点]
  • Mongo集群化部署+高可用架构