
Using KNN to predict on the Iris dataset or a custom dataset

Creating a custom dataset:

import numpy as np

point1=[[7.7,6.1],[3.1,5.9],[8.6,8.8],[9.5,7.3],[3.9,7.4],[5.0,5.3],[1.0,7.3]]  # class 0
point2=[[0.2,2.2],[4.5,4.1],[0.5,1.1],[2.7,3.0],[4.7,0.2],[2.9,3.3],[7.3,7.9]]  # class 1
point3=[[9.2,0.7],[9.2,2.1],[7.3,4.5],[8.9,2.9],[9.5,3.7],[7.7,3.7],[9.4,2.4]]  # class 2
# Stack all points into one (21, 2) array
point_concat = np.concatenate((point1, point2, point3), axis=0)
# One label (0, 1 or 2) per point; the class-2 block is sized by len(point3)
point_concat_label = np.concatenate((np.zeros(len(point1)), np.ones(len(point2)), np.ones(len(point3)) + 1), axis=0)
print(point_concat_label)
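The label vector above is built by sizing each block by hand, so every block must use the length of its own list. A less error-prone alternative, sketched here as an assumption rather than the author's code, derives the counts from the lists themselves with `np.repeat`. It produces integer labels 0, 1, 2, which `KNeighborsClassifier` accepts just as well as floats:

```python
import numpy as np

point1 = [[7.7, 6.1], [3.1, 5.9], [8.6, 8.8], [9.5, 7.3], [3.9, 7.4], [5.0, 5.3], [1.0, 7.3]]
point2 = [[0.2, 2.2], [4.5, 4.1], [0.5, 1.1], [2.7, 3.0], [4.7, 0.2], [2.9, 3.3], [7.3, 7.9]]
point3 = [[9.2, 0.7], [9.2, 2.1], [7.3, 4.5], [8.9, 2.9], [9.5, 3.7], [7.7, 3.7], [9.4, 2.4]]
groups = [point1, point2, point3]

# All points in one (21, 2) array, in group order
point_concat = np.concatenate(groups, axis=0)
# Repeat label i once per point of group i; each count comes from its own group,
# so labels stay aligned with the points even if the groups differ in size
point_concat_label = np.repeat(np.arange(len(groups)), [len(g) for g in groups])
print(point_concat_label)  # [0 0 0 0 0 0 0 1 1 1 1 1 1 1 2 2 2 2 2 2 2]
```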

Then run KNN predictions on this dataset.
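Before the scikit-learn version, the idea behind KNN prediction can be sketched in a few lines of NumPy: measure the distance from a query point to every training point, keep the k nearest, and let their labels vote. The helper `knn_predict` below is a hypothetical brute-force illustration, not scikit-learn's implementation; the `algorithm='kd_tree'` option used in the complete code only changes how the nearest neighbors are searched for, and `p=2` selects the same Euclidean distance used here.

```python
import numpy as np

# Same toy data as above: three classes of 2-D points, 7 points each
point1 = [[7.7, 6.1], [3.1, 5.9], [8.6, 8.8], [9.5, 7.3], [3.9, 7.4], [5.0, 5.3], [1.0, 7.3]]
point2 = [[0.2, 2.2], [4.5, 4.1], [0.5, 1.1], [2.7, 3.0], [4.7, 0.2], [2.9, 3.3], [7.3, 7.9]]
point3 = [[9.2, 0.7], [9.2, 2.1], [7.3, 4.5], [8.9, 2.9], [9.5, 3.7], [7.7, 3.7], [9.4, 2.4]]
X_train = np.array(point1 + point2 + point3)
y_train = np.repeat([0, 1, 2], 7)


def knn_predict(X_train, y_train, X_query, k=3):
    """Brute-force KNN: Euclidean distance + majority vote (ties go to the smallest label)."""
    X_query = np.atleast_2d(np.asarray(X_query, dtype=float))
    # Pairwise Euclidean distances, shape (n_query, n_train)
    dists = np.linalg.norm(X_query[:, None, :] - X_train[None, :, :], axis=2)
    # Indices of the k closest training points for each query point
    nearest = np.argsort(dists, axis=1)[:, :k]
    votes = y_train[nearest]  # labels of those neighbors, shape (n_query, k)
    n_classes = y_train.max() + 1
    # Count votes per class; argmax returns the first (smallest) label on a tie
    return np.array([np.bincount(row, minlength=n_classes).argmax() for row in votes])


print(knn_predict(X_train, y_train, [[9.0, 2.0], [1.0, 1.0], [6.0, 8.0]]))  # [2 1 0]
```

For the query (6.0, 8.0), the single nearest point is the class-1 point (7.3, 7.9) sitting among class-0 points, but the next two neighbors, (3.9, 7.4) and (7.7, 6.1), are class 0, so the vote returns 0. This is why a k larger than 1 makes KNN less sensitive to isolated outlying points.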

Complete code:

from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import matplotlib.pyplot as plt

point1=[[7.7,6.1],[3.1,5.9],[8.6,8.8],[9.5,7.3],[3.9,7.4],[5.0,5.3],[1.0,7.3]]  # class 0
point2=[[0.2,2.2],[4.5,4.1],[0.5,1.1],[2.7,3.0],[4.7,0.2],[2.9,3.3],[7.3,7.9]]  # class 1
point3=[[9.2,0.7],[9.2,2.1],[7.3,4.5],[8.9,2.9],[9.5,3.7],[7.7,3.7],[9.4,2.4]]  # class 2
# Stack all points into one (21, 2) array
point_concat = np.concatenate((point1, point2, point3), axis=0)
# One label (0, 1 or 2) per point; the class-2 block is sized by len(point3)
point_concat_label = np.concatenate((np.zeros(len(point1)), np.ones(len(point2)), np.ones(len(point3)) + 1), axis=0)
print(point_concat_label)

# k = 3 neighbors, KD-tree neighbor search, p=2 -> Euclidean (Minkowski p=2) distance
n_neighbors = 3
knn = KNeighborsClassifier(n_neighbors=n_neighbors, algorithm='kd_tree', p=2)

knn.fit(point_concat, point_concat_label)

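# Build a 100 x 100 grid covering [0, 10] x [0, 10]
x1 = np.linspace(0, 10, 100)
y1 = np.linspace(0, 10, 100)
x_axis, y_axis = np.meshgrid(x1, y1)

# Flatten the grid into (10000, 2) query points and classify every one of them
xy_axis=np.c_[x_axis.ravel(),y_axis.ravel()]
knn_predict_result=knn.predict(xy_axis)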

fig=plt.figure(figsize=(5,5))
ax=fig.add_subplot(111)
# Draw the decision boundaries: reshape the grid predictions back to (100, 100)
ax.contour(x_axis,y_axis,knn_predict_result.reshape(x_axis.shape))

# Overlay the training points: class 0 red triangles, class 1 green stars, class 2 blue squares
ax.scatter(point_concat[point_concat_label == 0, 0], point_concat[point_concat_label == 0, 1],color='r', marker='^')
ax.scatter(point_concat[point_concat_label == 1, 0], point_concat[point_concat_label == 1, 1],color='g', marker='*')
ax.scatter(point_concat[point_concat_label == 2, 0], point_concat[point_concat_label == 2, 1],color='b', marker='s')
plt.show()

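Output: a contour plot of the KNN decision boundaries over [0, 10] × [0, 10], with the three classes drawn as red triangles, green stars and blue squares.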
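The decision-boundary plot works by classifying every point of a regular grid: `np.meshgrid` builds the grid coordinates, `np.c_[x_axis.ravel(), y_axis.ravel()]` turns them into an `(n_points, 2)` array that `predict` accepts, and `reshape(x_axis.shape)` puts the predicted labels back onto the grid for `contour`. A tiny 3 × 3 version makes the ordering visible; the `x > y` rule is a hypothetical stand-in for `knn.predict`.

```python
import numpy as np

# A 3 x 3 version of the 100 x 100 grid used in the complete code
x1 = np.linspace(0, 10, 3)  # [0., 5., 10.]
y1 = np.linspace(0, 10, 3)
x_axis, y_axis = np.meshgrid(x1, y1)  # both (3, 3); x varies along columns, y along rows

# One row per grid point, in row-major order: (0,0), (5,0), (10,0), (0,5), ...
xy_axis = np.c_[x_axis.ravel(), y_axis.ravel()]
print(xy_axis)

# Hypothetical stand-in for knn.predict: label 1 where x > y, else 0
labels = (xy_axis[:, 0] > xy_axis[:, 1]).astype(int)
# Reshaping restores the grid layout: Z[i, j] is the label of the point (x1[j], y1[i])
Z = labels.reshape(x_axis.shape)
print(Z)
```

Replacing `ax.contour` with `ax.contourf` in the script would fill each decision region with a color instead of drawing only the boundary lines.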

For the Iris dataset:

Complete code:

from sklearn.datasets import load_iris
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler,MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier


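# Load Iris into a DataFrame (four feature columns plus the class label) for plotting
iris = load_iris()
iris_data1 = pd.DataFrame(data=iris['data'], columns = ['Sepal_Length', 'Sepal_Width', 'Petal_Length', 'Petal_Width'])
iris_data1['target']=iris['target']

# Scatter plot of two feature columns, colored by class (no regression line)
def plot_iris(data,col1,col2):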
    sns.lmplot(x=col1,y=col2,data=data,hue='target',fit_reg=False)
    plt.title('data show')
    plt.xlabel(col1)
    plt.ylabel(col2)
    plt.show()

plot_iris(iris_data1,'Sepal_Length','Petal_Width')
# Hold out 30% of the 150 samples as the test set (fixed random_state for reproducibility)
x_train,x_test,y_train,y_test=train_test_split(iris['data'],iris['target'],test_size=0.3,random_state=42)
print("Training set features : \n", x_train)
print("Test set features : \n", x_test)
print("Training set targets : \n", y_train)
print("Test set targets : \n", y_test)

print("Training set feature shape : \n", x_train.shape)
print("Test set feature shape : \n", x_test.shape)
print("Training set target shape : \n", y_train.shape)
print("Test set target shape : \n", y_test.shape)

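# Alternative scaler (not used below): rescales each feature to [0, 1]
transfer=MinMaxScaler(feature_range=(0,1))

# Standardize to zero mean / unit variance. Fit on the training set only and reuse those
# statistics for the test set; the run whose output is shown below called
# fit_transform(x_test) instead, so re-running may give different predictions and accuracy.
transfer1=StandardScaler()
ret_train_data=transfer1.fit_transform(x_train)
ret_test_data=transfer1.transform(x_test)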

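# KNN with k = 5 and the default Euclidean distance, on the standardized features
n_neighbors = 5
knn = KNeighborsClassifier(n_neighbors=n_neighbors)
knn.fit(ret_train_data, y_train)

y_pre=knn.predict(ret_test_data)
print('Predictions:\n',y_pre)
print("Prediction == ground truth:\n",y_pre==y_test)
# score() returns the mean accuracy on the test set
score=knn.score(ret_test_data,y_test)
print(f'Accuracy: {score}')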
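Two details of the Iris script are worth making systematic: the scaler must be fitted on training data only, and `n_neighbors = 5` is a fixed guess. A minimal sketch, assuming 5-fold cross-validation over k = 1..15 is an acceptable way to pick k (both ranges are arbitrary choices, not from the original): wrapping `StandardScaler` and `KNeighborsClassifier` in a pipeline makes every fold refit the scaler on that fold's training part only, so no test statistics leak into the scaling.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

cv_scores = {}
for k in range(1, 16):
    # Scaler + KNN as one estimator: the scaler is refitted inside each CV fold
    model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=k))
    # Mean accuracy over 5 stratified folds
    cv_scores[k] = cross_val_score(model, X, y, cv=5).mean()

best_k = max(cv_scores, key=cv_scores.get)
print(f"best k = {best_k}, mean CV accuracy = {cv_scores[best_k]:.3f}")
```

The chosen pipeline can then be fitted on `x_train` and scored on `x_test` exactly like the bare `knn` object in the script.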

Results:

Training set features : 
 [[5.5 2.4 3.7 1. ]
 [6.3 2.8 5.1 1.5]
 [6.4 3.1 5.5 1.8]
 [6.6 3.  4.4 1.4]
 [7.2 3.6 6.1 2.5]
 [5.7 2.9 4.2 1.3]
 [7.6 3.  6.6 2.1]
 [5.6 3.  4.5 1.5]
 [5.1 3.5 1.4 0.2]
 [7.7 2.8 6.7 2. ]
 [5.8 2.7 4.1 1. ]
 [5.2 3.4 1.4 0.2]
 [5.  3.5 1.3 0.3]
 [5.1 3.8 1.9 0.4]
 [5.  2.  3.5 1. ]
 [6.3 2.7 4.9 1.8]
 [4.8 3.4 1.9 0.2]
 [5.  3.  1.6 0.2]
 [5.1 3.3 1.7 0.5]
 [5.6 2.7 4.2 1.3]
 [5.1 3.4 1.5 0.2]
 [5.7 3.  4.2 1.2]
 [7.7 3.8 6.7 2.2]
 [4.6 3.2 1.4 0.2]
 [6.2 2.9 4.3 1.3]
 [5.7 2.5 5.  2. ]
 [5.5 4.2 1.4 0.2]
 [6.  3.  4.8 1.8]
 [5.8 2.7 5.1 1.9]
 [6.  2.2 4.  1. ]
 [5.4 3.  4.5 1.5]
 [6.2 3.4 5.4 2.3]
 [5.5 2.3 4.  1.3]
 [5.4 3.9 1.7 0.4]
 [5.  2.3 3.3 1. ]
 [6.4 2.7 5.3 1.9]
 [5.  3.3 1.4 0.2]
 [5.  3.2 1.2 0.2]
 [5.5 2.4 3.8 1.1]
 [6.7 3.  5.  1.7]
 [4.9 3.1 1.5 0.2]
 [5.8 2.8 5.1 2.4]
 [5.  3.4 1.5 0.2]
 [5.  3.5 1.6 0.6]
 [5.9 3.2 4.8 1.8]
 [5.1 2.5 3.  1.1]
 [6.9 3.2 5.7 2.3]
 [6.  2.7 5.1 1.6]
 [6.1 2.6 5.6 1.4]
 [7.7 3.  6.1 2.3]
 [5.5 2.5 4.  1.3]
 [4.4 2.9 1.4 0.2]
 [4.3 3.  1.1 0.1]
 [6.  2.2 5.  1.5]
 [7.2 3.2 6.  1.8]
 [4.6 3.1 1.5 0.2]
 [5.1 3.5 1.4 0.3]
 [4.4 3.  1.3 0.2]
 [6.3 2.5 4.9 1.5]
 [6.3 3.4 5.6 2.4]
 [4.6 3.4 1.4 0.3]
 [6.8 3.  5.5 2.1]
 [6.3 3.3 6.  2.5]
 [4.7 3.2 1.3 0.2]
 [6.1 2.9 4.7 1.4]
 [6.5 2.8 4.6 1.5]
 [6.2 2.8 4.8 1.8]
 [7.  3.2 4.7 1.4]
 [6.4 3.2 5.3 2.3]
 [5.1 3.8 1.6 0.2]
 [6.9 3.1 5.4 2.1]
 [5.9 3.  4.2 1.5]
 [6.5 3.  5.2 2. ]
 [5.7 2.6 3.5 1. ]
 [5.2 2.7 3.9 1.4]
 [6.1 3.  4.6 1.4]
 [4.5 2.3 1.3 0.3]
 [6.6 2.9 4.6 1.3]
 [5.5 2.6 4.4 1.2]
 [5.3 3.7 1.5 0.2]
 [5.6 3.  4.1 1.3]
 [7.3 2.9 6.3 1.8]
 [6.7 3.3 5.7 2.1]
 [5.1 3.7 1.5 0.4]
 [4.9 2.4 3.3 1. ]
 [6.7 3.3 5.7 2.5]
 [7.2 3.  5.8 1.6]
 [4.9 3.6 1.4 0.1]
 [6.7 3.1 5.6 2.4]
 [4.9 3.  1.4 0.2]
 [6.9 3.1 4.9 1.5]
 [7.4 2.8 6.1 1.9]
 [6.3 2.9 5.6 1.8]
 [5.7 2.8 4.1 1.3]
 [6.5 3.  5.5 1.8]
 [6.3 2.3 4.4 1.3]
 [6.4 2.9 4.3 1.3]
 [5.6 2.8 4.9 2. ]
 [5.9 3.  5.1 1.8]
 [5.4 3.4 1.7 0.2]
 [6.1 2.8 4.  1.3]
 [4.9 2.5 4.5 1.7]
 [5.8 4.  1.2 0.2]
 [5.8 2.6 4.  1.2]
 [7.1 3.  5.9 2.1]]
Test set features : 
 [[6.1 2.8 4.7 1.2]
 [5.7 3.8 1.7 0.3]
 [7.7 2.6 6.9 2.3]
 [6.  2.9 4.5 1.5]
 [6.8 2.8 4.8 1.4]
 [5.4 3.4 1.5 0.4]
 [5.6 2.9 3.6 1.3]
 [6.9 3.1 5.1 2.3]
 [6.2 2.2 4.5 1.5]
 [5.8 2.7 3.9 1.2]
 [6.5 3.2 5.1 2. ]
 [4.8 3.  1.4 0.1]
 [5.5 3.5 1.3 0.2]
 [4.9 3.1 1.5 0.1]
 [5.1 3.8 1.5 0.3]
 [6.3 3.3 4.7 1.6]
 [6.5 3.  5.8 2.2]
 [5.6 2.5 3.9 1.1]
 [5.7 2.8 4.5 1.3]
 [6.4 2.8 5.6 2.2]
 [4.7 3.2 1.6 0.2]
 [6.1 3.  4.9 1.8]
 [5.  3.4 1.6 0.4]
 [6.4 2.8 5.6 2.1]
 [7.9 3.8 6.4 2. ]
 [6.7 3.  5.2 2.3]
 [6.7 2.5 5.8 1.8]
 [6.8 3.2 5.9 2.3]
 [4.8 3.  1.4 0.3]
 [4.8 3.1 1.6 0.2]
 [4.6 3.6 1.  0.2]
 [5.7 4.4 1.5 0.4]
 [6.7 3.1 4.4 1.4]
 [4.8 3.4 1.6 0.2]
 [4.4 3.2 1.3 0.2]
 [6.3 2.5 5.  1.9]
 [6.4 3.2 4.5 1.5]
 [5.2 3.5 1.5 0.2]
 [5.  3.6 1.4 0.2]
 [5.2 4.1 1.5 0.1]
 [5.8 2.7 5.1 1.9]
 [6.  3.4 4.5 1.6]
 [6.7 3.1 4.7 1.5]
 [5.4 3.9 1.3 0.4]
 [5.4 3.7 1.5 0.2]]
Training set targets : 
 [1 2 2 1 2 1 2 1 0 2 1 0 0 0 1 2 0 0 0 1 0 1 2 0 1 2 0 2 2 1 1 2 1 0 1 2 0
 0 1 1 0 2 0 0 1 1 2 1 2 2 1 0 0 2 2 0 0 0 1 2 0 2 2 0 1 1 2 1 2 0 2 1 2 1
 1 1 0 1 1 0 1 2 2 0 1 2 2 0 2 0 1 2 2 1 2 1 1 2 2 0 1 2 0 1 2]
Test set targets : 
 [1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0 1 0 0 2 1
 0 0 0 2 1 1 0 0]
Training set feature shape : 
 (105, 4)
Test set feature shape : 
 (45, 4)
Training set target shape : 
 (105,)
Test set target shape : 
 (45,)
Predictions:
 [1 0 2 2 2 0 1 2 1 1 2 0 0 0 0 2 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0 1 0 0 2 1
 0 0 0 2 1 1 0 0]
Prediction == ground truth:
 [ True  True  True False False  True  True  True  True  True  True  True
  True  True  True False  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True]
Accuracy: 0.9333333333333333
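For a classifier, `knn.score` returns the mean accuracy: the fraction of test samples whose prediction equals the true label. With the two arrays printed above, 42 of the 45 comparisons are True, so the accuracy is 42 / 45 ≈ 0.9333. A confusion matrix shows where the three errors fall: all are versicolor (class 1) samples predicted as virginica (class 2). The sketch below simply re-checks the printed arrays; it does not re-run the model.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Arrays copied from the printed output above
y_test = np.array([1, 0, 2, 1, 1, 0, 1, 2, 1, 1,
                   2, 0, 0, 0, 0, 1, 2, 1, 1, 2,
                   0, 2, 0, 2, 2, 2, 2, 2, 0, 0,
                   0, 0, 1, 0, 0, 2, 1, 0, 0, 0,
                   2, 1, 1, 0, 0])
y_pre = np.array([1, 0, 2, 2, 2, 0, 1, 2, 1, 1,
                  2, 0, 0, 0, 0, 2, 2, 1, 1, 2,
                  0, 2, 0, 2, 2, 2, 2, 2, 0, 0,
                  0, 0, 1, 0, 0, 2, 1, 0, 0, 0,
                  2, 1, 1, 0, 0])

# Accuracy = fraction of positions where prediction and label agree
print(np.mean(y_pre == y_test))        # 42 / 45
print(accuracy_score(y_test, y_pre))   # same value
# Rows: true class 0/1/2, columns: predicted class 0/1/2
print(confusion_matrix(y_test, y_pre))
```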


http://www.kler.cn/a/516034.html
