当前位置: 首页 > article >正文

机器学习:k近邻

所有代码和文档均在golitter/Decoding-ML-Top10: 使用 Python 优雅地实现机器学习十大经典算法。 (github.com),欢迎查看。

K 邻近算法(K-Nearest Neighbors,简称 KNN)是一种经典的机器学习算法,主要用于分类和回归任务。它的核心思想是:给定一个新的数据点,通过查找训练数据中最接近的 K 个邻居,并根据这些邻居的标签来预测新数据点的标签。

KNN 是一种 基于实例的学习(Instance-based learning)算法。在训练阶段,它并不构建显式的模型,而是将训练数据存储起来,在预测阶段计算待预测点与训练集中所有点的距离,然后选择 K 个最近的邻居,根据邻居的标签进行投票或平均来做出预测。

KNN 的优点在于其简单易懂、无需训练过程,并且适用于大多数任务。它能够处理复杂的非线性问题,不依赖数据分布假设,能够很好地适应复杂的决策边界。

然而,KNN 的缺点也很明显。它的计算开销大,因为每次预测都需要计算所有训练数据的距离,导致在大数据集上表现不佳。此外,KNN 需要存储所有训练数据,占用较大的内存空间,并且对异常值敏感,可能会影响预测结果的准确性。

KNN算法步骤:

  1. 选择 K 个邻居的数量,K 值通常是一个奇数,以避免平票的情况。
  2. 计算待预测数据点与训练数据集中每个点的距离。
  3. 根据计算出的距离选择 K 个最接近的点。
  4. 对于分类任务,返回 K 个邻居中最多的类别;对于回归任务,返回 K 个邻居标签的均值。

代码实现

数据处理:使用iris.data数据集,用PCA进行降维。

import numpy as np
import pandas as pd


def pca(X: np.array, n_components: int) -> np.array:
	"""
	PCA 进行降维。
	"""
	# 1. 数据标准化(去均值)
	X_mean = np.mean(X, axis=0)
	X_centered = X - X_mean

	# 2. 计算协方差矩阵
	covariance_matrix = np.cov(X_centered, rowvar=False)

	# 3. 计算特征值和特征向量
	eigenvalues, eigenvectors = np.linalg.eig(covariance_matrix)

	# 4. 按特征值降序排序
	sorted_indices = np.argsort(eigenvalues)[::-1]
	top_eigenvectors = eigenvectors[:, sorted_indices[:n_components]]

	# 5. 投影到新空间
	X_pca = np.dot(X_centered, top_eigenvectors)

	return X_pca


def get_data():
	data = pd.read_csv('iris.csv', header=None)
	# print(data.dtypes)
	unq = data.iloc[:, -1].unique()
	for i, u in enumerate(unq):
		data.iloc[:, -1] = data.iloc[:, -1].apply(lambda x: i if x == u else x)

	# print(data.sample(5))
	xuanze = np.random.choice([True, False], len(data), replace=True, p=[0.8, 0.2])
	train_data = data[xuanze]
	test_data = data[~xuanze]
	train_data = np.array(
		train_data,
		dtype=np.float32,
	)
	test_data = np.array(test_data, dtype=np.float32)
	# 归一化
	train_data[:, :-1] = (train_data[:, :-1] - train_data[:, :-1].mean(axis=0)) / train_data[:, :-1].std(axis=0)
	test_data[:, :-1] = (test_data[:, :-1] - test_data[:, :-1].mean(axis=0)) / test_data[:, :-1].std(axis=0)
	return (
		pca(train_data[:, :-1], 2),
		train_data[:, -1].astype(np.int32),
		pca(test_data[:, :-1], 2),
		test_data[:, -1].astype(np.int32),
	)


if __name__ == '__main__':
	x_train, y_train, x_test, y_test = get_data()
	print(y_train.dtype)
	print(x_test, y_test)
	print(x_train.shape, y_train.shape)

knn过程:

from data_processing import get_data
import numpy as np
import matplotlib.pyplot as plt


def euclidean_distance(x_train: np.array, x_test: np.array) -> np.array:
	"""
	计算欧拉距离
	"""
	return np.sqrt(np.sum((x_train - x_test) ** 2, axis=1))


def knn(k: int, x_train: np.array, y_train: np.array, x_test: np.array) -> np.array:
	"""
	k近邻算法
	"""
	predictions = []
	for test in x_test:
		distances = euclidean_distance(x_train, test)
		nearest_indices = np.argsort(distances)[:k]  # 返回最近的k个点的索引
		nearest_labels = y_train[nearest_indices]  # 返回最近的k个点的标签
		prediction = np.argmax(np.bincount(nearest_labels))  # 返回最近的k个点中出现次数最多的标签
		predictions.append(prediction)
	return np.array(predictions)


def accuracy(predictions: np.array, y_test: np.array) -> float:
	"""
	计算准确率
	"""
	return np.sum(predictions == y_test) / len(y_test)


if __name__ == '__main__':
	k = 5
	x_train, y_train, x_test, y_test = get_data()
	predictions = knn(k, x_train, y_train, x_test)
	acc = accuracy(predictions, y_test)
	print(f'准确率为: {acc * 100:.2f}')

	# 绘制训练数据
	plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap='viridis', marker='o', label='Train Data', alpha=0.7)

	# 绘制测试数据
	plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test, cmap='coolwarm', marker='x', label='Test Data', alpha=0.7)

	# 绘制预测结果
	plt.scatter(
		x_test[:, 0],
		x_test[:, 1],
		c=predictions,
		cmap='coolwarm',
		marker='.',
		edgecolor='black',
		alpha=0.7,
		label='Predictions',
	)

	# 添加标题和标签
	plt.title('KNN Classification Results')
	plt.xlabel('Feature 1')
	plt.ylabel('Feature 2')
	plt.legend()

	# 显示图形
	plt.show()

在这里插入图片描述


http://www.kler.cn/a/552627.html

相关文章:

  • 巧用 PasteMate,联合 DeepSeek 与 LaTeX 高效生成 PDF 文档
  • C#中的图形渲染模式
  • 后端生成二维码,前端请求接口生成二维码并展示,且多个参数后边的参数没有正常传输问题处理
  • 个人shell脚本分享
  • 记一次 Git Fetch 后切换分支为空的情况
  • 【C++笔记】C++11的深度剖析(二)
  • GIT提错分支,回滚提交
  • SOME/IP--协议英文原文讲解7
  • 蓝桥杯 Java B 组之日期与时间计算(闰年、星期计算)
  • 使用API有效率地管理Dynadot域名,参与过期域名竞价
  • 系统学习算法:专题十一 floodfill算法
  • 无人机避障——配置新NX
  • 出现 [ app.json 文件内容错误] app.json: 在项目根目录未找到 app.json (env: Windows,mp 解决方法
  • C#程序中进行打印输出文本
  • opencascade 源码学习找到edge对应的face BRepBuilderAPI-BRepBuilderAPI_FindPlane
  • 架构师面试(二):计算机编程基础
  • 极限网关核心架构解析:从 Nginx 到 INFINI Gateway 的演进
  • ABB机器人的二次开发
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_palloc_block函数
  • golang面试题:两个interface{} 能不能比较?