当前位置: 首页 > article >正文

数据结构与算法Bonus-KNN问题的代码求解过程

一、问题提出

(一)要求

1.随机生成>=10万个三维点的点云,并以适当方式存储

2.自行实现一个KNN算法,对任意Query点,返回最邻近的K个点

3.不允许使用第三方库(e.g.flann,PCL,opencv)!

4.语言任选(推荐C++或者Python)

(二)规则

1.正确实现(3')

2.优于Flann、PCL在相同输入下的KNN求解函数中的一种(2')

3.优于Flann、PCL在相同输入下的KNN求解函数中的两种(2')

4.创新性评估(3')

二、KNN算法概述

KNN(K-Nearest Neighbor)算法,也称为K最邻近法,是一种基本的机器学习算法,属于有监督学习中的分类算法。该算法最初由Cover和Hart于1968年提出,具有简单直观的特点。

KNN算法的思路是:如果一个样本在特征空间中的K个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。这里的K通常是一个用户指定的整数,通常选取奇数以避免出现平局的情况。

KNN算法的距离度量通常是欧氏距离,但也可以使用其他距离度量方法。在选择K值时,较小的K值可能会使算法对噪声更加敏感,而较大的K值可能会使算法分类边界变得模糊。因此,选择合适的K值对于KNN算法的性能至关重要。

三、算法描述

(一)语言选择

所选语言为MATLAB,软件版本为MATLAB R2022a

(二)算法原理

K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的多数属于某个类,就把该输入实例分类到这个类中。(这就类似于现实生活中少数服从多数的思想),也就是在训练数据集中寻找与待预测样本A距离最近的K个样本,如果K个样本中大多数属于类别甲,少数属于类别乙,那个就可以认为样本A属于类别甲。

(三)距离度量

一般计算样本在多维空间的距离有两种方式:欧式距离和曼哈顿距离。

在实际KNN问题应用中,距离函数的选择应该根据数据的特性和分析的需要而定,选择欧式距离表示。

(四)算法流程

1.设置样本点数量,此处定义N为样本点数目的一半。

2.设置矩阵label1、label2存储样本点所属类别,label1为第一类,label2位第二类。

3.生成随机数矩阵data1、data2。为了使样本点更具分散性,此处选择直接使用rand函数,而不是正态随机数randn和mvnrnd函数。

4.利用三维绘图函数scatter3绘制两类样本点,如下图所示(此处使用N=100举例,增加可视性):

其中,红色代表第一类样本点,蓝色代表第二类样本点。

  1. 设置K值为11(K必为奇数),即周围11个样本点。
  2. 遍历从(3,3,3)至(7,7,7)范围内的125个点,间隔为1个单位。

7.计算待预测样本与训练数据集中样本特征之间的欧式距离dis。

8.按照距离递增的顺序,使用sort函数排序,返回排序后的矩阵B及其索引值矩阵index。

9.选取距离最近的K个样本以及所属类别的次数,输出最近的K个样本坐标(见附件:最邻近K点输出结果.xlsx)。

10.分别用变量c1、c2存储出现的类别次数,返回前k个点所出现频率最高的类别作为预测分类结果。

11.数据可视化处理,如下图所示(N为100时):

当n=50000时(即分析10万个样本点时),图像如下图所示:

四、代码实现 

% KNN算法
clear all;
clc;
%总体样本点数量为2N
N=50000;
% 每一个数据有两个特征
label1 = ones(N,1);%第一类点云序号,记为1
label2 = 1+ones(N,1);%第二类点云序号,记为2
%生成第一类数据
data1 =  10*rand(N,3);%坐标范围为[0,10]
data1(data1<0)=0;
%生成第二类数据
data2 =  10*rand(N,3);
data2(data1<0)=0;
scatter3(data1(:,1),data1(:,2),data1(:,3),'ro')%红色圆圈代表第一类数据
hold on;
scatter3(data2(:,1),data2(:,2),data2(:,3),'b^')%蓝色三角代表第二类数据
hold on;
data = [data1;data2];%两类数据整合,放在一个矩阵里
label = [label1;label2];%两类数据类别序号整合,也放在一个矩阵里
K= 11;%K值为11,表示周围最近的11个点。K为奇数
for i1 = 3:7
    for i2 = 3:7
        for i3=3:7
            testdata = [i1 i2 i3];
            distance=zeros(2*N,1);
            dis = sum((data-testdata).^2,2);%返回包含每一行总和的列向量
            [B index]= sort(dis);%返回索引值index
            for j=1:K
                disp("与点("+num2str(i1)+","+num2str(i2)+","+num2str(i3)+")相邻的第"+num2str(j)+"个点的坐标为:("+num2str(data(index(j,1),1))+","+num2str(data(index(j,1),2))+","+num2str(data(index(j,1),3))+")");
            end
            disp(' ');%换行
            newLabel = label(index(1:K));
            c1 = 0;
            c2 = 0;
            for ii = 1:K
                if newLabel(ii)==1
                    c1 = c1+1;%第一类的点数量加一
                else
                    c2 = c2+1;%第二类的点数量加一
                end
            end
            if c1>c2%第一类的数量的点更多
                scatter3(testdata(1),testdata(2),testdata(3),50,'ro','filled')
            else%第二类的数量的点更多
                scatter3(testdata(1),testdata(2),testdata(3),50,'bo','filled')
            end
        end
    end
end
legend('第一类','第二类')

整体过程偏向于暴力,仅供参考 


http://www.kler.cn/a/273254.html

相关文章:

  • Vue2+OpenLayers调用WMTS服务初始化天地图示例(提供Gitee源码)
  • Docker Desktop 构建java8基础镜像jdk安装配置失效解决
  • 解决:ubuntu22.04中IsaacGymEnv保存视频报错的问题
  • 大语言模型的稀疏性:提升效率与性能的新方向
  • HTML和CSS相关的问题,为什么页面加载速度慢?
  • Bash语言的语法糖
  • java15~17 密封类
  • 【JS逆向学习】猿人学第六题 js混淆 回溯
  • 数目之差
  • 【Paper Reading】6.RLHF-V 提出用RLHF的1.4k的数据微调显著降低MLLM的虚幻问题
  • upload-labs 0.1 靶机详解
  • 【Spring MVC】Spring MVC拦截器(Interceptor)
  • 《我的AUTOSAR之路》ECUM(二) 唤醒处理
  • 【Java】高级篇1:异常处理
  • 小迪安全42WEB攻防-通用漏洞文件包含LFIRFI伪协议
  • python 深度学习的学习路径
  • 机器学习----特征缩放
  • Mysql增删改查(详解)
  • 摄影第一课
  • java 程序连接 redis 集群 的时候报错 MUTLI is currently not supported in cluster mode
  • 阳光保险MySQL数据库平稳迁移OceanBase,稳定运营超700天
  • XDAG节点版本更新(0.6.5升级到0.7.0)
  • SpringBoot异常:类文件具有错误的版本 61.0, 应为 52.0的解决办法
  • 辐射全国、面向世界、聚焦未来——华为(深圳)全球具身智能产业创新中心正式成立
  • LeetCode刷题记录:(11)组合(初识回溯算法)
  • 前端路由跳转bug