当前位置: 首页 > article >正文

机器学习(6):K 近邻算法

1 介绍

1.1 基本概念

        k近邻算法是一种基本分类和回归方法。K近邻算法(KNN),即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数属于某个类,就把该输入实例分类到这个类中。(这就类似于现实生活中少数服从多数的思想):

        如上图所示,有两类不同的样本数据,分别用蓝色的小正方形和红色的小三角形表示,而图正中间的那个绿色的圆所标示的数据则是待分类的数据。这也就是我们的目的,来了一个新的数据点,我要得到它的类别是什么?好的,下面我们根据k近邻的思想来给绿色圆点进行分类。

  • 如果K=3,绿色圆点的最邻近的3个点是2个红色小三角形和1个蓝色小正方形,少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于红色的三角形一类。
  • 如果K=5,绿色圆点的最邻近的5个邻居是2个红色三角形和3个蓝色的正方形,还是少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于蓝色的正方形一类。

        从上面例子我们可以看出,k近邻的算法思想非常的简单,也非常的容易理解,那么我们是不是就到此结束了,该算法的原理我们也已经懂了,也知道怎么给新来的点如何进行归类,只要找到离它最近的k个实例,哪个类别最多即可。

1.2 k的选取以及特征归一化

1.2.1 选取k值

        k近邻的k值我们应该怎么选取呢?如果我们选取较小的k值,那么就会意味着我们的整体模型会变得复杂,容易发生过拟合。假设我们有训练数据和待分类点如下图:

        上图中有俩类,一个是黑色的圆点,一个是蓝色的长方形,现在我们的待分类点是红色的五边形。根据我们的k近邻算法步骤来决定待分类点应该归为哪一类。我们由图中可以得到,很容易我们能够看出来五边形离黑色的圆点最近,k又等于1,那太好了,我们最终判定待分类点是黑色的圆点。

        由这个处理过程我们很容易能够感觉出问题了,如果k太小了,比如等于1,那么模型就太复杂了,我们很容易学习到噪声,也就非常容易判定为噪声类别,而在上图,如果,k大一点,k等于8,把长方形都包括进来,我们很容易得到我们正确的分类应该是蓝色的长方形!如下图:

        所谓的过拟合就是在训练集上准确率非常高,而在测试集上准确率低,经过上例,我们可以得到k太小会导致过拟合,很容易将一些噪声(如上图离五边形很近的黑色圆点)学习到模型中,而忽略了数据真实的分布!

        如果我们选取较大的k值,就相当于用较大邻域中的训练数据进行预测,这时与输入实例较远的(不相似)训练实例也会对预测起作用,使预测发生错误,k值的增大意味着整体模型变得简单。

        我们想,如果k=N(N为训练样本的个数),那么无论输入实例是什么,都将简单地预测它属于在训练实例中最多的类。直接拿训练数据统计了一下各个数据的类别,找最大的而已!这好像下图所示:

        这个时候,模型过于简单,完全忽略训练数据实例中的大量有用信息,是不可取的。k值既不能过大,也不能过小,我们k值的选择,在下图红色圆边界之间这个范围是最好的,如下图:

        这里只是为了更好让大家理解,真实例子中不可能只有俩维特征,但是原理是一样的,我们就是想找到较好的k值大小。那么我们一般怎么选取呢?我们一般选取一个较小的数值,通常采取交叉验证法来选取最优的k值。

1.2.2 距离的度量

        k近邻算法是在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数属于某个类,我们就说预测点属于哪个类。定义中所说的最邻近是如何度量呢?我们怎么知道谁跟测试点最邻近。这里就会引出我们几种度量俩个点之间距离的标准。我们可以有以下几种度量方式:

        其中当p=2的时候,就是我们最常见的欧式距离,我们也一般都用欧式距离来衡量我们高维空间中俩点的距离。在实际应用中。距离函数的选择应该根据数据的特性和分析的需要而定,一般选取p=2欧式距离表示。

1.2.3 特征归一化

        首先举例如下,我用一个人身高(cm)与脚码(尺码)大小来作为特征值,类别为男性或者女性。我们现在如果有5个训练样本,分布如下:

  • A [(179,42),男]
  • B [(178,43),男]
  • C [(165,36)女]
  • D [(177,42),男]
  • E [(160,35),女]

        通过上述训练样本,很容易看到第一维身高特征是第二维脚码特征的4倍左右,那么在进行距离度量的时候,我们就会偏向于第一维特征。这样造成俩个特征并不是等价重要的,最终可能会导致距离计算错误,从而导致预测错误。例如:

        现在我来了一个测试样本 F(167,43),让我们来预测他是男性还是女性,我们采取k=3来预测。下面我们用欧式距离分别算出F离训练样本的欧式距离,然后选取最近的3个,多数类别就是我们最终的结果,计算如下:

        由计算可以得到,最近的前三个分别是C,D,E三个样本,那么由C,E为女性,D为男性,女性多于男性得到我们要预测的结果为女性。

        这样问题就来了,一个女性的脚43码的可能性,远远小于男性脚43码的可能性,那么为什么算法还是会预测F为女性呢?那是因为由于各个特征量纲的不同,在这里导致了身高的重要性已经远远大于脚码了,这是不客观的。所以我们应该让每个特征都是同等重要的!这也是我们要归一化的原因!归一化公式如下:

1.3 KNN 算法的优缺点

        优点

  • 简单易用:KNN 算法的原理简单,易于理解和实现。
  • 无需训练:KNN 不需要显式的训练过程,所有的计算都在预测时进行。
  • 适用于多分类问题:KNN 可以轻松处理多分类问题。

        缺点

  • 计算复杂度高:KNN 需要在预测时计算所有样本的距离,当数据集较大时,计算复杂度会很高。
  • 对噪声敏感:KNN 对噪声数据较为敏感,噪声数据可能会影响预测结果。
  • 需要选择合适的 K 值:K 值的选择对模型的性能有很大影响,选择合适的 K 值是一个挑战。

2 KNN 算法的实现步骤

2.1 导入必要的库

        首先,我们需要导入一些常用的 Python 库,如 numpy 用于数值计算,matplotlib 用于绘图,sklearn 用于加载数据集和评估模型。

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

2.2 加载数据集

        我们使用 sklearn 中的 load_iris 函数加载经典的鸢尾花数据集。这个数据集包含 150 个样本,每个样本有 4 个特征,目标是将样本分为 3 类。

# 加载Iris数据集
iris = datasets.load_iris()
X = iris.data[:, :2]  # 只取前两个特征,便于可视化
y = iris.target

2.3 数据预处理

        在应用 KNN 算法之前,通常需要对数据进行标准化处理,以确保每个特征对距离计算的贡献是相同的。

# 将数据集拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

2.4 训练 KNN 模型

        接下来,我们使用 sklearn 中的 KNeighborsClassifier 来训练 KNN 模型。这里我们选择 K=3,即选择 3 个最近邻。

# 创建KNN模型,设置K值为3
knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型
knn.fit(X_train, y_train)

2.5 预测与评估

        使用训练好的模型对测试集进行预测,并计算模型的准确率。

# 在测试集上进行预测
y_pred = knn.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"KNN模型的准确率: {accuracy:.4f}")

2.6 可视化 KNN 分类结果

# 绘制决策边界和数据点
h = .02  # 网格步长
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1

# 创建一个二维网格,表示不同的样本空间
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                     np.arange(y_min, y_max, h))

# 使用KNN模型预测网格中的每个点的类别
Z = knn.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 绘制决策边界
plt.contourf(xx, yy, Z, alpha=0.8)

# 绘制训练数据点
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o', s=50)
plt.title("KNN Demo")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

2.7 调整 K 值

        K 值的选择对模型的性能有重要影响。我们通过交叉验证或可视化方法选择最佳的 K 值。

# 尝试不同的K值并绘制准确率变化
k_range = range(1, 21)
accuracies = []

for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    y_pred = knn.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    accuracies.append(accuracy)

# 绘制K值与准确率的关系
plt.rcParams["font.sans-serif"] = ["SimHei"]  # 设置字体
plt.plot(k_range, accuracies, marker='o')
plt.title("K值与准确率的关系")
plt.xlabel("K值")
plt.ylabel("准确率")
plt.show()

2.8 使用 KNN 进行回归任务

        KNN 同样可以用于回归任务(KNN Regression)。在回归任务中,KNN 根据 K 个最近邻的目标值进行平均来预测输出。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

# 生成示例数据
X = np.random.rand(100, 1) * 10
y = np.sin(X).ravel() + 0.1 * np.random.randn(100)

# 拆分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建KNN回归模型
knn_reg = KNeighborsRegressor(n_neighbors=5)

# 训练模型
knn_reg.fit(X_train, y_train)

# 在测试集上进行预测
y_pred = knn_reg.predict(X_test)

# 可视化回归结果
plt.scatter(X_test, y_test, color='red', label='True Values')
plt.scatter(X_test, y_pred, color='blue', label='Predicted Values')
plt.title("KNN Regression")
plt.xlabel("Feature")
plt.ylabel("Target")
plt.legend()
plt.show()


http://www.kler.cn/a/514891.html

相关文章:

  • VirtualBox can‘t enable the AMD-V extension
  • 扬帆数据结构算法之雅舟航程,漫步C++幽谷——LeetCode刷题之移除链表元素、反转链表、找中间节点、合并有序链表、链表的回文结构
  • 剑指Offer|LCR 040.最大矩形
  • Solidity06 Solidity变量数据存储和作用域
  • 安装centos7之后问题解决
  • 根除埃博拉病毒(2015MCM美赛A)
  • 嵌入式入门(一)-STM32CubeMX
  • c++中的链表list
  • 【Android】创建基类BaseActivity和BaseFragment
  • Spring注解篇:@RestController详解
  • AI大模型-提示工程学习笔记11-思维树
  • 【线性代数】列主元法求矩阵的逆
  • 云原生架构下的AI智能编排:ScriptEcho赋能前端开发
  • 2025_1_22_进程替换
  • Simula语言的云计算
  • C语言进阶习题【1】指针和数组(4)——指针笔试题3
  • RabbitMQ的消息可靠性保证
  • 网络(一)
  • C语言程序环境与预处理—从源文件到执行程序,这里面有怎么的工序?绝对0基础!
  • 【 MySQL 学习4】排序