当前位置: 首页 > article >正文

【机器学习】K近邻算法

目录

算法引入:

KNN算法的核心思想

KNN算法的步骤

KNN常用的距离度量方法

KNN算法的优缺点

优点:

缺点:

K值的选择

KNN的C++实现

复杂度分析:


K近邻算法(K-Nearest Neighbors, KNN)是一种简单但非常实用的监督学习算法,主要用于分类和回归问题。KNN 基于相似性度量(如欧几里得距离)来进行预测,核心思想是给定一个样本,找到与其最接近的 K 个邻居,根据这些邻居的类别或特征对该样本进行分类或预测。

算法引入:

我们假设平面上有两类点的集合,一类属于A类,一类属于B类,A类有三个点,B类有三个点。如果我们要加入一个橙色的点,那么它是属于A类还是B类? 

我们不去考虑算法的话,如果去归类A类还是B类,那么我们肯定想到的就是这个点它离哪个点最近,就属于哪一类。到这里就与我们的K近邻算法大概相同了,只不过我们要选取一个范围,在这个范围里找点进行判断,比如,我们选择它的邻域三个点,即K=3,那么这个区域里面有一个点属于A类,两个点属于B类,根据少数服从多数,我们就可以把它归于B类点。

K近邻算法有三个要素,如下图所示,第一个就是距离度量,这个距离有很多种距离,比如:欧几里得距离、曼哈顿距离、闵可夫斯基距离等。上面的例子中我们选择的欧几里得距离。第二个是K值,就是这个一个范围的的点个数,上面例子中,K=3。第三个是少数服从多数规则,上面例子中区域里面有一个点属于A类,两个点属于B类,根据少数服从多数,我们就可以把它归于B类点。


KNN算法的核心思想

1. 分类问题:
   - 给定一个未标记的数据点,通过计算该数据点与已标记的训练数据集中的每一个数据点的距离,选择距离最近的 K 个邻居。
   - 根据这 K 个邻居的类别,采用“多数投票”的方式来决定未标记数据点的类别。

2. 回归问题:
   - 计算未标记数据点的 K 个最近邻居的值,然后取这些邻居的平均值或加权平均值作为该点的预测值。

KNN算法的步骤

1. 数据准备:准备好训练数据集,包括特征和标签(分类问题中为类别,回归问题中为数值)。
2. 选择K值:选择邻居的数量K,一般是正整数。
3. 计算距离:对每个未标记数据点,计算它与训练集中每一个数据点的距离(常见的距离度量方法有欧几里得距离、曼哈顿距离等)。
4. 选择K个最近邻居:根据距离从小到大排序,选择距离最近的K个邻居。
5. 投票或平均:分类问题中,根据K个邻居的类别进行投票选择类别;回归问题中,计算邻居的平均值作为预测结果。


KNN常用的距离度量方法

1. 欧几里得距离:

   欧几里得距离是最常用的距离度量方法,适用于连续变量的情况。

2. 曼哈顿距离:
   
   曼哈顿距离适用于某些特定场景,尤其是当特征值的变化范围不均匀时。

3. 闵可夫斯基距离:
  
   闵可夫斯基距离是欧几里得距离和曼哈顿距离的推广形式,其中 p 是一个参数,当 p=2 时,便是欧几里得距离,当 p=1 时便是曼哈顿距离。


KNN算法的优缺点

优点:

1. 简单易理解:KNN算法实现简单,易于理解和解释。
2. 无参数模型:KNN不需要训练过程,可以直接使用数据进行预测。
3. 适用性广泛:KNN可以用于分类和回归问题,并且对非线性数据有较好的适应性。

缺点:

1. 计算复杂度高:KNN算法需要对每一个测试样本都计算与所有训练样本的距离,因此在大数据集下计算开销较大。
2. 内存开销大:KNN需要存储整个训练数据集,占用较大的存储空间。
3. 对噪声敏感:KNN对噪声数据较为敏感,特别是在K值较小的情况下,少量噪声数据可能会对结果产生很大影响。


K值的选择

- 小K值:K值较小时,模型会更加复杂,可能会过拟合。即使有少量噪声数据,也会对分类结果产生较大的影响。
- 大K值:K值较大时,模型会更加平滑,可能会欠拟合。K值过大会忽略数据的局部结构。

通常,K值通过交叉验证等方法来选择合适的值。


KNN的C++实现

下面是一个简单的KNN算法的C++实现,用于分类问题,采用欧几里得距离来计算邻居之间的距离。

#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>
using namespace std;

// 定义一个点,包含特征和标签
struct Point {
    vector<double> features;
    int label;
};

// 计算欧几里得距离
double euclideanDistance(const vector<double>& a, const vector<double>& b) {
    double distance = 0.0;
    for (int i = 0; i < a.size(); i++) {
        distance += pow(a[i] - b[i], 2);
    }
    return sqrt(distance);
}

// KNN算法实现
int knn(const vector<Point>& train_data, const vector<double>& test_point, int k) {
    vector<pair<double, int>> distances; // 距离和标签的对
    
    // 计算每个训练数据点到测试点的距离
    for (const auto& point : train_data) {
        double distance = euclideanDistance(point.features, test_point);
        distances.push_back({distance, point.label});
    }

    // 按距离排序
    sort(distances.begin(), distances.end());

    // 统计前k个最近邻的类别
    vector<int> label_count(100, 0); // 假设标签在0-99之间
    for (int i = 0; i < k; i++) {
        label_count[distances[i].second]++;
    }

    // 返回出现次数最多的类别
    int max_count = 0;
    int predicted_label = -1;
    for (int i = 0; i < label_count.size(); i++) {
        if (label_count[i] > max_count) {
            max_count = label_count[i];
            predicted_label = i;
        }
    }

    return predicted_label;
}

int main() {
    int n, m, k;
    cin >> n;//训练数据的个数
    cin >> m;//测试数据的维度
    cin >> k;
    
    // 输入训练数据
    vector<Point> train_data(n);
    for (int i = 0; i < n; i++) {
        train_data[i].features.resize(m);
        for (int j = 0; j < m; j++) {
            cin >> train_data[i].features[j];
        }
        cin >> train_data[i].label;
    }

    // 输入测试点
    vector<double> test_point(m);
    for (int j = 0; j < m; j++) {
        cin >> test_point[j];
    }

    // 使用KNN进行分类
    int predicted_label = knn(train_data, test_point, k);
    cout << predicted_label << endl;

    return 0;
}

复杂度分析:

- 时间复杂度:对于每个测试点,KNN需要计算与所有训练点的距离,因此时间复杂度为 O(n * m),其中 n 是训练集大小,m 是特征维度。
- 空间复杂度:主要用于存储训练数据和距离结果,空间复杂度为 O(n)。


K近邻算法是一个简单直观的非参数分类算法,适用于低维、小数据集的情况。然而,由于它的计算复杂性较高,KNN在大数据集或高维数据上的表现不佳。因此,KNN算法通常被用作基准模型或在小规模数据集上使用。


http://www.kler.cn/a/394762.html

相关文章:

  • NVR录像机汇聚管理EasyNVR多品牌NVR管理工具/设备:大华IPC摄像头局域网访问异常解决办法
  • LeetCode题解:5.最长回文子串【Python题解超详细,中心拓展、动态规划、暴力解法】
  • 高级java每日一道面试题-2024年11月06日-JVM篇-什么是 Class 文件? Class 文件主要的信息结构有哪些?
  • 实验5:网络设备发现、管理和维护
  • 分享 pdf 转 word 的免费平台
  • Dolby TrueHD和Dolby Digital Plus (E-AC-3)编码介绍
  • C++——视频问题总结
  • 猎板PCB罗杰斯板材的应用案例
  • 【填鸭表单】TDuckX-v2.0发布!
  • 【深度学习】神经网络优化方法 正则化方法 价格分类案例
  • 力扣-Mysql-3322- 英超积分榜排名 III(中等)
  • PyTorch——从入门到精通:PyTorch简介与安装(最新版)【PyTorch系统学习】
  • golang分布式缓存项目 Day4 一致性哈希
  • 前端权限控制代码
  • 计算机毕业设计 | SpringBoot社区物业管理系统 小区管理(附源码)
  • 14.最长公共前缀-力扣(LeetCode)
  • CSS:怎么把网站都变成灰色
  • uniapp解析蓝牙设备响应数据bug
  • 3588 yolov8 onnx 量化转 rknn 并运行
  • spark的学习-06
  • k8s 1.28.2 集群部署 docker registry 接入 MinIO 存储
  • leveldb存储token的简单实现
  • 数据结构-布隆过滤器和可逆布隆过滤器
  • vue中 通过cropperjs 实现图片裁剪
  • 开源项目低代码表单设计器FcDesigner扩展右侧组件的配置规则
  • Spring Cloud Gateway(分发请求)