当前位置: 首页 > article >正文

4.sklearn-K近邻算法、模型选择与调优

文章目录

  • 环境配置(必看)
  • 头文件引用
    • 1.sklearn转换器和估计器
        • 1.1 转换器 - 特征工程的父类
        • 1.2 估计器(sklearn机器学习算法的实现)
    • 2.K-近邻算法
      • 2.1 简介:
      • 2.2 K-近邻算法API
      • 2.3 K-近邻算法代码
      • 2.4 运行结果
      • 2.5 K-近邻算法优缺点
    • 3.模型选择与调优
      • 3.1 交叉验证(cross validation)
      • 3.2 网格搜索(Grid Search)
      • 3.3 交叉验证,网格搜索(模型选择与调优)API:
      • 3.4 代码
      • 3.5 运行结果
  • 本章学习资源

环境配置(必看)

Anaconda-创建虚拟环境的手把手教程相关环境配置看此篇文章,本专栏深度学习相关的版本和配置,均按照此篇文章进行安装。

头文件引用

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier

1.sklearn转换器和估计器

1.1 转换器 - 特征工程的父类
1 实例化 (实例化的是一个转换器类(Transformer))
2 调用fit_transform(对于文档建立分类词频矩阵,不能同时调用)
  标准化:
  		(x - mean) / std	(特征 - 均值)/ 标准差
  		fit_transform()
  		fit()           	计算 每一列的平均值、标准差
  		transform()     	(x - mean) / std进行最终的转换
1.2 估计器(sklearn机器学习算法的实现)
1 实例化一个estimator
2 estimator.fit(x_train, y_train) 计算
  —— 调用完毕,模型生成
3 模型评估:
	1)直接比对真实值和预测值
   		y_predict = estimator.predict(x_test)
        y_test == y_predict
    2)计算准确率
        accuracy = estimator.score(x_test, y_test)

2.K-近邻算法

2.1 简介:

KNN核心思想:
   你的“邻居”来推断出你的类别
    1 K-近邻算法(KNN)原理
      k = 1
      容易受到异常点的影响
   如何确定谁是邻居?
       计算距离:
        	距离公式
        	欧氏距离  --  算法默认的是使用欧式距离
        	曼哈顿距离 绝对值距离
        	明可夫斯基距离

	如果取的k值不一样?会是什么结果?
        k 值取得过小,容易受到异常点的影响
        k 值取得过大,样本不均衡的影响

2.2 K-近邻算法API

sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')

API注释:

n_neighbors:
	int,可选(默认= 5),k_neighbors查询默认使用的邻居数
algorithm:{‘auto’,‘ball_tree’,‘kd_tree’,‘brute’}
	快速k近邻搜索算法,默认参数为auto,可以理解为算法自己决定合适的搜索算法。除此之外,用户也可以自己指定搜索算法ball_tree、kd_tree、brute方法进行搜索,
brute:
	是蛮力搜索,也就是线性扫描,当训练集很大时,计算非常耗时。
kd_tree:
	构造kd树存储数据以便对其进行快速检索的树形数据结构,kd树也就是数据结构中的二叉树。以中值切分构造的树,每个结点是一个超矩形,在维数小于20时效率高。
ball tree:
	是为了克服kd树高维失效而发明的,其构造过程是以质心C和半径r分割样本空间,每个节点是一个超球体

2.3 K-近邻算法代码

分析:

  1. x_test = transfer.transform(x_test),测试集只是使用transform进行标准化,是因为要和训练集x_train 做一样的处理,训练集调用transfer.fit_transform()计算出的均值,标准差的值均在模型中,x_test = transfer.transform(x_test)就是直接使用测试集的参数进行计算。
def knn_iris():
    """
    用KNN算法对鸢尾花进行分类
    :return:
    """

    # 1.获取数据
    iris = load_iris()
    # 2.划分数据集  参数:特征值,目标值,随机数种子
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=22)
    # 3.特征工程:标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)                                 
    # 4.KNN算法预估器  n_neighbors=3就是K值等于3
    estimator = KNeighborsClassifier(n_neighbors=3)
    estimator.fit(x_train, y_train)
    # 5.模型评估
    # 方法1: 直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print(f"y_predict:\n{y_predict}")
    print(f"直接比对真实值和预测值: {y_test == y_predict}")
    # 方法2: 计算准确率
    score = estimator.score(x_test, y_test)
    print(f"准确率为: {score}")

    return None

2.4 运行结果

在这里插入图片描述

2.5 K-近邻算法优缺点

优点:简单,易于理解,易于实现,无需训练
缺点:
    1)必须指定K值,K值选择不当则分类精度不能保证
    2)懒惰算法,对测试样本分类时的计算量大,内存开销大
    使用场景:小数据场景,几千~几万样本,具体场景具体业务去测试

3.模型选择与调优

3.1 交叉验证(cross validation)

交叉验证:将拿到的训练数据,分为训练和验证集。以下图为例:将数据分成4份,其中一份作为验证集。然后经过4()的测试,每次都更换
不同的验证集。即得到4组模型的结果,取平均值作为最终结果。又称4折交叉验证。

在这里插入图片描述

3.2 网格搜索(Grid Search)

通常情况下,有很多参数是需要手动指定的(如k-近邻算法中的K值),这种叫超参数。但是手动过程繁杂,所以需要对模型预设几种超参数组合。
每组超参数都采用交叉验证来进行评估。最后选出最优参数组合建立模型。

在这里插入图片描述

3.3 交叉验证,网格搜索(模型选择与调优)API:

sklearn.model_selection.GridSearchCV(estimator, param_grid=None,cv=None)
对估计器的指定参数值进行详尽搜索
	estimator:估计器对象
	param_grid:估计器参数(dict){“n_neighbors”:[1,3,5]}
	cv:指定几折交叉验证
	fit:输入训练数据
	score:准确率
结果分析:
	bestscore__:在交叉验证中验证的最好结果
	bestestimator:最好的参数模型
	cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果

3.4 代码

def knn_iris_gscv():
    """
    用KNN算法对鸢尾花进行分类,添加网格搜索和交叉验证
    :return:
    """

    # 1.获取数据
    iris = load_iris()

    # 2.划分数据集
    x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=20)

    # 3.特征工程:标准化
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)

    # 4.KNN算法预估器
    estimator = KNeighborsClassifier()

    # 加入网格搜索和交叉验证
    # 参数准备
    param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11]}   # 网格搜索
    # cv=10 代表10折运算(交叉验证)
    estimator = GridSearchCV(estimator, param_grid=param_dict, cv=10)
    estimator.fit(x_train, y_train)

    # 5.模型评估
    # 方法1: 直接比对真实值和预测值
    y_predict = estimator.predict(x_test)
    print(f"y_predict:\n{y_predict}")
    print(f"直接比对真实值和预测值: {y_test == y_predict}")
    # 方法2: 计算准确率
    score = estimator.score(x_test, y_test)
    print(f"准确率为: {score}")

    # 最佳参数:
    print("最佳参数: \n", estimator.best_params_)
    # 最佳结果:
    print("最佳结果: \n", estimator.best_score_)
    # 最佳参数:
    print("最佳估计器: \n", estimator.best_estimator_)
    # 交叉验证结果:
    print("交叉验证结果: \n", estimator.cv_results_)
    return None

3.5 运行结果

在这里插入图片描述

本章学习资源

黑马程序员3天快速入门python机器学习
我是跟着视频进行的学习,欢迎大家一起来学习!


http://www.kler.cn/a/283004.html

相关文章:

  • Matplotlib库中show()函数的用法
  • 「Mac玩转仓颉内测版12」PTA刷题篇3 - L1-003 个位数统计
  • 【Excel】身份证号最后一位“X”怎么计算
  • JAVA:探索 EasyExcel 的技术指南
  • Cyberchef配合Wireshark提取并解析HTTP/TLS流量数据包中的文件
  • 三维测量与建模笔记 - 特征提取与匹配 - 4.2 梯度算子、Canny边缘检测、霍夫变换直线检测
  • MySQL集群技术1——编译部署mysql
  • “重启就能解决一切问题”,iPhone重启方法大揭秘
  • 解决:无法从域控制器读取配置信息
  • 2024.8.29 C++
  • C#面:ASP.NET MVC 中还有哪些注释属性用来验证?
  • RKNPU2从入门到实践 ---- 【8】借助 RKNN Toolkit lite2 在RK3588开发板上部署RKNN模型
  • 设计模式--装饰器模式
  • 理解torch.argmax() ,我是错误的
  • 融资和融券分别是什么意思,融资融券开通后能融到多少资金?
  • Datawhale X 李宏毅苹果书 AI夏令营_深度学习基础学习心得Task2.2
  • Java 入门指南:Java NIO —— Selector(选择器)
  • 【hot100篇-python刷题记录】【搜索二维矩阵】
  • 分布式锁的实现:ZooKeeper 的解决方案
  • hive数据迁移
  • 低代码革命:JNPF平台如何简化企业应用开发
  • Linux 中的中断响应机制
  • TCP keepalive和HTTP keepalive区别
  • SCP拷贝失败解决办法
  • 基于单片机的指纹识别考勤系统设计
  • Web应用服务器Tomcat