当前位置: 首页 > article >正文

机器学习·最近邻方法(k-NN)

前言

上一篇简单介绍了决策树,而本篇讲解与决策树相近的 最近邻方法k-NN

机器学习·决策树-CSDN博客


一、算法原理对比
特性决策树最近邻方法(k-NN)
核心思想通过特征分割构建树结构,递归划分数据基于距离度量,用最近的k个样本投票预测
训练方式显式构建模型(预训练)惰性学习(无显式训练,预测时计算)
关键参数max_depthmin_samples_leafn_neighborsmetric(距离度量)
分割标准信息增益、基尼系数欧氏距离、曼哈顿距离、余弦相似度等
输出类型分类树(类别标签)、回归树(连续值)分类(多数投票)、回归(均值/中位数)

二、概念
  1. 决策树

    • 信息增益(Entropy)
      \( S = -\sum_{i=1}^N p_i \log_2 p_i \)
      选择分割时最大化信息增益,减少不确定性。

    • 基尼系数(Gini Index)
      \( G = 1 - \sum_{k} (p_k)^2 \)
      衡量数据不纯度,值越小分割越优。

    • 剪枝策略

      • 预剪枝:限制树深度(max_depth)、叶节点最小样本数(min_samples_leaf)。

      • 后剪枝:构建完整树后合并冗余节点。

  2. k-NN

    • 距离度量

      • 欧氏距离(默认):\( d(x,y) = \sqrt{\sum (x_i - y_i)^2} \)

      • 曼哈顿距离:\( d(x,y) = \sum |x_i - y_i| \)

      • 余弦相似度:衡量向量方向相似性。

    • 参数调优

      • n_neighbors:邻居数,小值易过拟合,大值易欠拟合。

      • weights:邻居权重(uniform均等权重,distance按距离反比加权)。


三、交叉验证与调优
  1. 交叉验证方法

    • k折交叉验证:数据分为k个子集,轮流用k-1个子集训练,1个子集验证,取平均性能。

    • 留出法:按比例(如70%-30%)划分训练集和验证集。

  2. GridSearchCV 参数调优

    from sklearn.model_selection import GridSearchCV
    
    # 决策树参数网格
    tree_params = {'max_depth': [3, 5, 7], 'max_features': [10, 20, 30]}
    tree_grid = GridSearchCV(DecisionTreeClassifier(), tree_params, cv=5)
    tree_grid.fit(X_train, y_train)
    
    # k-NN参数网格(需标准化)
    knn_pipe = Pipeline([('scaler', StandardScaler()), 
                         ('knn', KNeighborsClassifier())])
    knn_params = {'knn__n_neighbors': range(1, 10)}
    knn_grid = GridSearchCV(knn_pipe, knn_params, cv=5)
    knn_grid.fit(X_train, y_train)
     

四、实际应用与性能对比
  1. 客户流失预测任务

    • 数据集:电信客户流失数据(特征包括国际套餐、语音邮箱等)。

    • 结果对比

      模型留置集准确率交叉验证最佳准确率
      决策树(调优)94.6%94.0%
      k-NN(调优)89.0%88.5%
      随机森林95.3%93.5%
  2. MNIST手写数字识别

    • 数据集:8x8像素手写数字图片。

    • 结果对比

      模型留置集准确率交叉验证最佳准确率
      决策树(调优)84.4%66.6%
      k-NN(调优)98.7%97.6%
      随机森林93.4%-

五、优缺点总结
算法优点缺点
决策树1. 可解释性强,规则可视化
2. 支持类别/数值特征
3. 训练速度快
1. 对噪声敏感,易过拟合
2. 边界为轴平行,灵活性差
3. 无法外推
k-NN1. 简单易实现
2. 无需显式训练
3. 适应复杂边界(小k值)
1. 预测速度慢(大数据集)
2. 高维数据效果差(维度灾难)
3. 依赖距离度量

六、应用场景
  1. 选择决策树

    • 需要可解释性强的模型(如金融风控、医疗诊断)。

    • 数据特征存在明显分层逻辑(如年龄分段、阈值判断)。

    • 实时预测需求(快速推理)。

  2. 选择k-NN

    • 数据维度较低且分布复杂(如小规模图像分类)。

    • 需要快速原型验证(基线模型)。

    • 数据特征尺度一致(需标准化)。


七、结论
  1. 模型选择优先级

    • 优先尝试简单模型(如决策树、k-NN),再过渡到复杂模型(随机森林、神经网络)。

    • 决策树在结构化数据中表现优异,k-NN适合小规模非结构化数据。

  2. 调优核心

    • 决策树:控制深度(max_depth)和叶节点样本数(min_samples_leaf)。

    • k-NN:选择合适的邻居数(n_neighbors)和距离度量(metric)。

  3. 交叉验证必要性

    • 避免过拟合,确保模型泛化性,尤其在参数调优时不可或缺。


八、完整代码

1.客户流失预测任务

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_digits
from sklearn.tree import export_graphviz
import pydotplus
from io import StringIO
import matplotlib.pyplot as plt
from IPython.display import Image


# 客户离网率预测任务
# 数据预处理
df = pd.read_csv('https://labfile.oss.aliyuncs.com/courses/1283/telecom_churn.csv')
df['International plan'] = pd.factorize(df['International plan'])[0]
df['Voice mail plan'] = pd.factorize(df['Voice mail plan'])[0]
df['Churn'] = df['Churn'].astype('int')
states = df['State']
y = df['Churn']
df.drop(['State', 'Churn'], axis=1, inplace=True)

# 划分数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(df.values, y, test_size=0.3, random_state=17)

# 训练决策树和K近邻模型(随机参数)
tree = DecisionTreeClassifier(max_depth=5, random_state=17)
knn = KNeighborsClassifier(n_neighbors=10)
tree.fit(X_train, y_train)
knn.fit(X_train, y_train)

# 模型评估
tree_pred = tree.predict(X_holdout)
print("决策树准确率(随机参数):", accuracy_score(y_holdout, tree_pred))
knn_pred = knn.predict(X_holdout)
print("K近邻准确率(随机参数):", accuracy_score(y_holdout, knn_pred))

# 决策树交叉验证调优
tree_params = {'max_depth': range(5, 7),'max_features': range(16, 18)}
tree_grid = GridSearchCV(tree, tree_params, cv=5, n_jobs=-1, verbose=True)
tree_grid.fit(X_train, y_train)
print("决策树最佳参数:", tree_grid.best_params_)
print("决策树最佳分数:", tree_grid.best_score_)
print("决策树调优后准确率:", accuracy_score(y_holdout, tree_grid.predict(X_holdout)))

# 绘制决策树
dot_data = StringIO()
export_graphviz(tree_grid.best_estimator_, feature_names=df.columns, out_file=dot_data, filled=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(value=graph.create_png())

# K近邻交叉验证调优
knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_jobs=-1))])
knn_params = {'knn__n_neighbors': range(6, 8)}
knn_grid = GridSearchCV(knn_pipe, knn_params, cv=5, n_jobs=-1, verbose=True)
knn_grid.fit(X_train, y_train)
print("K近邻最佳参数:", knn_grid.best_params_)
print("K近邻最佳分数:", knn_grid.best_score_)
print("K近邻调优后准确率:", accuracy_score(y_holdout, knn_grid.predict(X_holdout)))

# 训练随机森林模型
forest = RandomForestClassifier(n_estimators=100, n_jobs=-1, random_state=17)
print("随机森林交叉验证分数:", np.mean(cross_val_score(forest, X_train, y_train, cv=5)))
forest_params = {'max_depth': range(8, 10),'max_features': range(5, 7)}
forest_grid = GridSearchCV(forest, forest_params, cv=5, n_jobs=-1, verbose=True)
forest_grid.fit(X_train, y_train)
print("随机森林最佳参数:", forest_grid.best_params_)
print("随机森林最佳分数:", forest_grid.best_score_)
print("随机森林准确率:", accuracy_score(y_holdout, forest_grid.predict(X_holdout)))


# 简单分类任务
# 生成数据
def form_linearly_separable_data(n=500, x1_min=0, x1_max=30, x2_min=0, x2_max=30):
    data, target = [], []
    for i in range(n):
        x1 = np.random.randint(x1_min, x1_max)
        x2 = np.random.randint(x2_min, x2_max)
        if np.abs(x1 - x2) > 0.5:
            data.append([x1, x2])
            target.append(np.sign(x1 - x2))
    return np.array(data), np.array(target)


X, y = form_linearly_separable_data()
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='autumn', edgecolors='black')

# 训练决策树并绘制分类边界
tree = DecisionTreeClassifier(random_state=17).fit(X, y)


def get_grid(X):
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1))
    return xx, yy


xx, yy = get_grid(X)
predicted = tree.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.pcolormesh(xx, yy, predicted, cmap='autumn')
plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='autumn', edgecolors='black', linewidth=1.5)
plt.title('Easy task. Decision tree compexifies everything')

# 可视化决策树
dot_data = StringIO()
export_graphviz(tree, feature_names=['x1', 'x2'], out_file=dot_data, filled=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(value=graph.create_png())

# 训练K近邻模型
knn = KNeighborsClassifier(n_neighbors=1).fit(X, y)
xx, yy = get_grid(X)
predicted = knn.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
plt.pcolormesh(xx, yy, predicted, cmap='autumn')
plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='autumn', edgecolors='black', linewidth=1.5)
plt.title('Easy task, kNN. Not bad')


# MNIST手写数字识别任务
# 加载数据
data = load_digits()
X, y = data.data, data.target

# 绘制MNIST手写数字
f, axes = plt.subplots(1, 4, sharey=True, figsize=(16, 6))
for i in range(4):
    axes[i].imshow(X[i, :].reshape([8, 8]), cmap='Greys')

# 划分数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.3, random_state=17)

# 训练决策树和K近邻模型(随机参数)
tree = DecisionTreeClassifier(max_depth=5, random_state=17)
knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_neighbors=10))])
tree.fit(X_train, y_train)
knn_pipe.fit(X_train, y_train)

# 模型预测与评估
tree_pred = tree.predict(X_holdout)
knn_pred = knn_pipe.predict(X_holdout)
print("MNIST任务中决策树准确率(随机参数):", accuracy_score(y_holdout, tree_pred))
print("MNIST任务中K近邻准确率(随机参数):", accuracy_score(y_holdout, knn_pred))

# 决策树交叉验证调优
tree_params = {'max_depth': [10, 20, 30],'max_features': [30, 50, 64]}
tree_grid = GridSearchCV(tree, tree_params, cv=5, n_jobs=-1, verbose=True)
tree_grid.fit(X_train, y_train)
print("MNIST任务中决策树最佳参数:", tree_grid.best_params_)
print("MNIST任务中决策树最佳分数:", tree_grid.best_score_)

# K近邻交叉验证调优
print("MNIST任务中K近邻交叉验证分数:", np.mean(cross_val_score(KNeighborsClassifier(n_neighbors=1), X_train, y_train, cv=5)))

# 训练随机森林模型
print("MNIST任务中随机森林交叉验证分数:", np.mean(cross_val_score(RandomForestClassifier(random_state=17), X_train, y_train, cv=5)))


# 最近邻方法复杂情形
# 生成数据
def form_noisy_data(n_obj=1000, n_feat=100, random_seed=17):
    np.seed = random_seed
    y = np.random.choice([-1, 1], size=n_obj)
    x1 = 0.3 * y
    x_other = np.random.random(size=[n_obj, n_feat - 1])
    return np.hstack([x1.reshape([n_obj, 1]), x_other]), y


X, y = form_noisy_data()

# 划分数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(X, y, test_size=0.3, random_state=17)

# 训练K近邻模型并绘制验证曲线
cv_scores, holdout_scores = [], []
n_neighb = [1, 2, 3, 5] + list(range(50, 550, 50))

for k in n_neighb:
    knn_pipe = Pipeline([('scaler', StandardScaler()), ('knn', KNeighborsClassifier(n_neighbors=k))])
    cv_scores.append(np.mean(cross_val_score(knn_pipe, X_train, y_train, cv=5)))
    knn_pipe.fit(X_train, y_train)
    holdout_scores.append(accuracy_score(y_holdout, knn_pipe.predict(X_holdout)))

plt.plot(n_neighb, cv_scores, label='CV')
plt.plot(n_neighb, holdout_scores, label='holdout')
plt.title('Easy task. kNN fails')
plt.legend()

# 决策树训练与评估
tree = DecisionTreeClassifier(random_state=17, max_depth=1)
tree_cv_score = np.mean(cross_val_score(tree, X_train, y_train, cv=5))
tree.fit(X_train, y_train)
tree_holdout_score = accuracy_score(y_holdout, tree.predict(X_holdout))
print('Decision tree. CV: {}, holdout: {}'.format(tree_cv_score, tree_holdout_score))

2.MNIST手写数字识别

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

# 载入 MNIST 手写数字数据集
data = load_digits()
X, y = data.data, data.target

# 查看第一个样本的 8x8 矩阵形式
print(X[0, :].reshape([8, 8]))

# 绘制一些 MNIST 手写数字
f, axes = plt.subplots(1, 4, sharey=True, figsize=(16, 6))
for i in range(4):
    axes[i].imshow(X[i, :].reshape([8, 8]), cmap='Greys')
plt.show()

# 分割数据集
X_train, X_holdout, y_train, y_holdout = train_test_split(
    X, y, test_size=0.3, random_state=17)

# 使用随机参数训练决策树和 k-NN
tree = DecisionTreeClassifier(max_depth=5, random_state=17)
knn_pipe = Pipeline([('scaler', StandardScaler()),
                     ('knn', KNeighborsClassifier(n_neighbors=10))])

tree.fit(X_train, y_train)
knn_pipe.fit(X_train, y_train)

# 在留置集上做出预测并评估
tree_pred = tree.predict(X_holdout)
knn_pred = knn_pipe.predict(X_holdout)
tree_accuracy = accuracy_score(y_holdout, tree_pred)
knn_accuracy = accuracy_score(y_holdout, knn_pred)
print(f"决策树(随机参数)在留置集上的准确率: {tree_accuracy}")
print(f"k-NN(随机参数)在留置集上的准确率: {knn_accuracy}")

# 使用交叉验证调优决策树模型
tree_params = {'max_depth': [10, 20, 30],
               'max_features': [30, 50, 64]}

tree_grid = GridSearchCV(tree, tree_params,
                         cv=5, n_jobs=-1, verbose=True)

tree_grid.fit(X_train, y_train)

# 查看交叉验证得到的最佳参数组合和相应的准确率
best_tree_params = tree_grid.best_params_
best_tree_score = tree_grid.best_score_
print(f"决策树最佳参数: {best_tree_params}")
print(f"决策树最佳交叉验证准确率: {best_tree_score}")

# 使用交叉验证调优 k-NN 模型
knn_cv_score = np.mean(cross_val_score(KNeighborsClassifier(
    n_neighbors=1), X_train, y_train, cv=5))
print(f"调优后 k-NN 的交叉验证准确率: {knn_cv_score}")

# 训练随机森林模型
forest_cv_score = np.mean(cross_val_score(RandomForestClassifier(
    random_state=17), X_train, y_train, cv=5))
print(f"随机森林的交叉验证准确率: {forest_cv_score}")

结果:


http://www.kler.cn/a/551845.html

相关文章:

  • 【算法】快排
  • harbor安装教程
  • PHP支付宝--转账到支付宝账户
  • 数据结构:最小生成树
  • 个人博客测试报告
  • SMU寒假训练第三周周报
  • DeepSeek 和 ChatGPT 在特定任务中的表现:逻辑推理与创意生成
  • 告别冷冰冰:如何训练AI写出温暖人心的广告文案
  • 基于flask+vue的租房信息可视化系统
  • Redis 启用自动内存碎片清理异常
  • 【MySQL安装】
  • 3.5 使用Tokenizer编解码文本:从原理到企业级实践
  • Redis实战-扩展Redis
  • Windows服务器搭建时间同步服务
  • C++ 设计模式-代理模式
  • IDEA——Mac版快捷键
  • 禁止WPS强制打开PDF文件
  • 数据倾斜定义以及在Spark中如何处理数据倾斜问题
  • kafka的Docker镜像使用说明:wurstmeister/kafka
  • 亚马逊企业购大客户业务拓展经理张越:跨境电商已然成为全球零售电商领域中熠熠生辉的强劲增长点