KNN算法介绍及代码实例
KNN算法工作原理
-
算法概述:
- 给定一个数据点,通过计算它与训练集中每个点的距离,找到距离最近的K个点。
- 通过这些邻居点的标签或值,决定该数据点的预测结果。
-
分类任务:
- 使用K个邻居中多数类别的标签作为新数据点的预测类别。
-
回归任务:
- 使用K个邻居的平均值作为新数据点的预测值。
-
距离度量:
- 常用欧几里得距离公式:
- 除欧几里得距离外,曼哈顿距离、切比雪夫距离等。
实例
利用sklearn数据库得鸢尾花数据集为例,进行检测模型训练和目标分类预测。
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# 1. 加载数据集
data = load_iris()
X, y = data.data, data.target # 特征和标签
# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 3. 定义KNN分类器
k = 5 # K值
knn = KNeighborsClassifier(n_neighbors=k)
# 4. 训练模型
knn.fit(X_train, y_train)
# 5. 进行预测
y_pred = knn.predict(X_test)
# 6. 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"KNN分类准确率: {accuracy:.2f}")
# 7. 示例单点预测
sample_point = [5.1, 3.5, 1.4, 0.2] # 一个新的数据点
predicted_class = knn.predict([sample_point])
print(f"新样本预测类别: {data.target_names[predicted_class[0]]}")