KNN算法与实战案例详解
目录
- 一、KNN算法原理
-
- 1.样本距离公式
- 2.特征标准化
- 二、 实战:使用KNN完成鸢尾花分类
-
- 1.数据加载与预处理
- 2.KNN模型构建与训练
- 3.模型评估
- 三、交叉验证与K折交叉验证
-
- 1.什么是交叉验证?
- 2.K折交叉验证
- 四、实战:手写数字图片数据集分类与调参
-
- 1.加载数据与可视化
- 2.交叉验证调参
- 3.使用最优超参数进行训练与测试
- 五、网格搜索优化超参数
-
- 1.什么是网格搜索?
- 2.使用网格搜索调参
- 六、总结与未来展望
KNN(K-Nearest Neighbors, K近邻算法)是机器学习中一种经典的监督学习算法,常用于分类和回归问题。其基本思想可以通过一句俗语概括——“近朱者赤,近墨者黑”,即根据目标数据点附近的样本来决定其类别或值。KNN以其直观性和实现简单而受到广泛使用,尤其在分类问题中表现出色。
本文将对KNN算法的基础原理进行详细介绍,并通过实际案例展示如何使用该算法解决鸢尾花分类问题和手写数字识别问题。同时,还会讨论如何利用交叉验证和网格搜索来优化KNN模型的超参数。
一、KNN算法原理
KNN算法的核心思想是,给定一个样本点,寻找其在特征空间中最接近的K个邻居,根据这些邻居的类别来对样本点进行分类。如果是分类任务,则通过邻居投票决定样本的类别;如果是回归任务,则通常通过计算邻居的均值来预测目标值。
1.样本距离公式
在KNN算法中,样本之间的距离是至关重要的。常见的距离度量方式包括:
-
欧几里得距离:计算两个点之间的直线距离,是最常用的距离度量方式。
公式:
-
曼哈顿距离:计算两个点在各坐标轴方向上的距离之和。
公式:
-
明可夫斯基距离:是一种更广泛的距离计算方式,其中,p是可调节的超参数。
公式:
当 p=2 时,它是欧几里得距离;当 p=1 时,它是曼哈顿距离。
2.特征标准化
在计算样本距离时,如果不同特征的量纲不一致(如一个特征是毫米,另一个是千米),某些特征可能会主导距离计算。因此,在使用KNN时,我们通常需要对特征进行标准化。
Z-score标准化 是常用的方法,其公式为:
其中 μ 是特征的均值,σ 是特征的标准差。通过标准化,所有特征将转换为均值为0,标准差为1的标准正态分布。
在sklearn中,可以使用 StandardScaler 实现Z-score标准化。
from sklearn.preprocessing import StandardScaler
std = StandardScaler()
X_train_standard = std.fit_transform(X_train) # 对训练数据进行标准化
X_test_standard = std.transform(X_test) # 对测试数据使用相同的标准化
二、 实战:使用KNN完成鸢尾花分类
鸢尾花数据集是机器学习中的经典数据集,包含150个样本,每个样本有4个特征,分为3类。我们将使用KNN算法对该数据集进行分类。
1.数据加载与预处理
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data # 样本特征
y = iris.target # 样本标签
# 数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)
# 对特征进行标准化
std = StandardScaler()
X_train_standard = std.fit_transform(X_train)
X_test_standard <