当前位置: 首页 > article >正文

KNN算法与实战案例详解

目录

  • 一、KNN算法原理
    • 1.样本距离公式
    • 2.特征标准化
  • 二、 实战:使用KNN完成鸢尾花分类
    • 1.数据加载与预处理
    • 2.KNN模型构建与训练
    • 3.模型评估
  • 三、交叉验证与K折交叉验证
    • 1.什么是交叉验证?
    • 2.K折交叉验证
  • 四、实战:手写数字图片数据集分类与调参
    • 1.加载数据与可视化
    • 2.交叉验证调参
    • 3.使用最优超参数进行训练与测试
  • 五、网格搜索优化超参数
    • 1.什么是网格搜索?
    • 2.使用网格搜索调参
  • 六、总结与未来展望

KNN(K-Nearest Neighbors, K近邻算法)是机器学习中一种经典的监督学习算法,常用于分类和回归问题。其基本思想可以通过一句俗语概括——“近朱者赤,近墨者黑”,即根据目标数据点附近的样本来决定其类别或值。KNN以其直观性和实现简单而受到广泛使用,尤其在分类问题中表现出色。

本文将对KNN算法的基础原理进行详细介绍,并通过实际案例展示如何使用该算法解决鸢尾花分类问题和手写数字识别问题。同时,还会讨论如何利用交叉验证和网格搜索来优化KNN模型的超参数。

一、KNN算法原理

KNN算法的核心思想是,给定一个样本点,寻找其在特征空间中最接近的K个邻居,根据这些邻居的类别来对样本点进行分类。如果是分类任务,则通过邻居投票决定样本的类别;如果是回归任务,则通常通过计算邻居的均值来预测目标值。

1.样本距离公式

在KNN算法中,样本之间的距离是至关重要的。常见的距离度量方式包括:

  • 欧几里得距离:计算两个点之间的直线距离,是最常用的距离度量方式。

    公式:
    在这里插入图片描述

  • 曼哈顿距离:计算两个点在各坐标轴方向上的距离之和。

    公式:
    在这里插入图片描述

  • 明可夫斯基距离:是一种更广泛的距离计算方式,其中,p是可调节的超参数。

    公式:
    在这里插入图片描述
    当 p=2 时,它是欧几里得距离;当 p=1 时,它是曼哈顿距离。

2.特征标准化

在计算样本距离时,如果不同特征的量纲不一致(如一个特征是毫米,另一个是千米),某些特征可能会主导距离计算。因此,在使用KNN时,我们通常需要对特征进行标准化。
Z-score标准化 是常用的方法,其公式为:在这里插入图片描述
其中 μ 是特征的均值,σ 是特征的标准差。通过标准化,所有特征将转换为均值为0,标准差为1的标准正态分布。
在sklearn中,可以使用 StandardScaler 实现Z-score标准化。

from sklearn.preprocessing import StandardScaler

std = StandardScaler()
X_train_standard = std.fit_transform(X_train)  # 对训练数据进行标准化
X_test_standard = std.transform(X_test)  # 对测试数据使用相同的标准化

二、 实战:使用KNN完成鸢尾花分类

鸢尾花数据集是机器学习中的经典数据集,包含150个样本,每个样本有4个特征,分为3类。我们将使用KNN算法对该数据集进行分类。

1.数据加载与预处理

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data  # 样本特征
y = iris.target  # 样本标签

# 数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)

# 对特征进行标准化
std = StandardScaler()
X_train_standard = std.fit_transform(X_train)
X_test_standard <

http://www.kler.cn/news/310784.html

相关文章:

  • 基于51单片机的自动清洗系统(自动洗衣机)
  • 【QT】系统-上
  • ​补​充​元​象​二​面​
  • Hexo框架学习——从安装到配置
  • SpringBoot:解析excel
  • PowerBI 关于FILTERS函数和VALUES函数
  • Spring模块详解Ⅳ(Spring ORM和Spring Transaction)
  • RedisTemplate混用带来的序列化问题
  • json.dumps 中的参数
  • 预警提醒并生成日志,便于后期追溯的智慧地产开源了
  • 让IT部门弄一个炫酷的数字驾驶舱就是数字化转型成功?
  • Vue 3 中动态赋值 ref 的应用
  • windows下使用 vscode 远程X11服务GUI显示的三种方法
  • 从种草到销售:家居品牌构建O2O私域运营的完整闭环
  • 考研数学精解【3】
  • 四、(JS)JS中常见的加载事件
  • 软考(中级-软件设计师)(0919)
  • 百度Android IM SDK组件能力建设及应用
  • Golang、Python、C语言、Java的圆桌会议
  • https和http区别
  • 【网络】TCP/IP 五层网络模型:网络层
  • 计算机专业毕设-校园新闻网站
  • vue实现二维码生成器应用
  • 【ARM】Cache深度解读
  • redis 在企业开发实践中注意事项
  • MATLAB中的无线通信系统部署和优化工具有哪些
  • 【Leetcode152】分割回文串(回溯 | 递归)
  • python 实现double factorial recursive双阶乘递归算法
  • 运行npm install 时,卡在sill idealTree buildDeps没有反应
  • 固件升级之Bootloader(三)