
[Machine Learning] [Code] 30. KNN - Implementations with sklearn and a Custom KNN

KNN - Implementations with sklearn and a Custom KNN

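  • Preface
  • GitHub link
  • Training and predicting with KNN using the sklearn library
    • 1. Import the required libraries
    • 2. Load the data
      • 2.1. Print information about the dataset
    • 3. Split into training and test sets
    • 4. Visualization
    • 5. Create the model and predict
  • 2. Custom KNN model and prediction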

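Preface

With the theory article finished, this post adds the code.

Applying a machine-learning model with sklearn is very simple, so this post focuses on how to implement KNN yourself.

Link to the KNN theory article

GitHub link

Link to the GitHub repository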

Training and predicting with KNN using the sklearn library

1. Import the required libraries

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

2. Load the data

from sklearn.datasets import load_iris
iris_dataset = load_iris()

2.1. Print information about the dataset

print("Keys of iris_dataset:\n", iris_dataset.keys())
print(iris_dataset['DESCR'][:193] + "\n...")
print("Target names:", iris_dataset['target_names'])
print("Feature names:\n", iris_dataset['feature_names'])
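Besides the metadata printed above, it can help to check the array shapes and how many samples each class has before splitting. A small sketch that reloads the data so it runs on its own (the names `X` and `y` are introduced here for brevity):

```python
import numpy as np
from sklearn.datasets import load_iris

iris_dataset = load_iris()
X, y = iris_dataset["data"], iris_dataset["target"]

print("Shape of data:", X.shape)    # (150, 4): 150 samples, 4 features
print("Shape of target:", y.shape)  # (150,)

# Number of samples per class, in the same order as target_names
for name, count in zip(iris_dataset["target_names"], np.bincount(y)):
    print(f"{name}: {count}")
```

The iris classes are balanced (50 samples each), so plain accuracy is a reasonable metric here.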

3. Split into training and test sets

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)
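Since no `test_size` is passed, `train_test_split` falls back to its default of holding out 25% of the samples, and `random_state=0` makes the shuffle reproducible. A quick check of the resulting sizes (self-contained sketch):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset["data"], iris_dataset["target"], random_state=0)

# Default split: ceil(0.25 * 150) = 38 test samples, the remaining 112 for training
print("X_train shape:", X_train.shape)  # (112, 4)
print("X_test shape:", X_test.shape)    # (38, 4)
```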

4. Visualization

# label the columns using the strings in iris_dataset.feature_names
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# Create a scatter matrix from the dataframe, color by y_train
pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(16, 16),
                           marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8);
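The trailing `;` only suppresses the return value in Jupyter. Outside a notebook, one option is to render with a non-interactive backend and write the figure to a file. A sketch under that assumption; the file name `iris_scatter_matrix.png` is hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this also runs without a display
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset["data"], iris_dataset["target"], random_state=0)
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)

# scatter_matrix returns a (n_features x n_features) array of Axes
axes = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(16, 16),
                                  marker="o", hist_kwds={"bins": 20}, s=60, alpha=.8)
plt.savefig("iris_scatter_matrix.png")  # hypothetical output file name
plt.close("all")
```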


5. Create the model and predict

from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
scaler = MinMaxScaler()  # create the scaler object
scaler.fit(X_train)  # learn the per-feature min and max from the training data only

X_train_norm = scaler.transform(X_train)  # apply the normalisation to the training set
X_test_norm = scaler.transform(X_test)  # apply the same training-set scaling to the test set

knn = KNeighborsClassifier(n_neighbors=40)
knn.fit(X_train_norm, y_train)
y_pred = knn.predict(X_test_norm) 
print("Accuracy on test set: {:.5f}".format(accuracy_score(y_test, y_pred)))  # accuracy_score(y_true, y_pred)
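KNN compares raw distances, so a feature with a larger numeric range would dominate the vote unless the features are rescaled. `MinMaxScaler` maps each feature to [0, 1] via x' = (x - min) / (max - min), with min and max taken from the training set only, so no test-set information leaks into preprocessing. A sketch that recomputes this by hand and compares it with the scaler:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris["data"], iris["target"], random_state=0)

scaler = MinMaxScaler().fit(X_train)

# Min-max scaling by hand, using statistics from the training set only
col_min = X_train.min(axis=0)
col_max = X_train.max(axis=0)
manual_train = (X_train - col_min) / (col_max - col_min)
manual_test = (X_test - col_min) / (col_max - col_min)

print(np.allclose(manual_train, scaler.transform(X_train)))  # the two should agree
# Test values are scaled with the training min/max, so they may fall slightly outside [0, 1]
print(manual_test.min(axis=0), manual_test.max(axis=0))
```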

2. Custom KNN model and prediction
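The custom class supports two distance metrics in `_compute_distance` and predicts by a majority vote over the k nearest training points. Written out for two samples $\mathbf{x}, \mathbf{x}' \in \mathbb{R}^n$ (the notation here, including $N_k(\mathbf{x})$ for the indices of the k training points closest to $\mathbf{x}$, is introduced for illustration):

$$
d_{\text{euclidean}}(\mathbf{x}, \mathbf{x}') = \sqrt{\sum_{j=1}^{n} (x_j - x'_j)^2},
\qquad
d_{\text{manhattan}}(\mathbf{x}, \mathbf{x}') = \sum_{j=1}^{n} \lvert x_j - x'_j \rvert
$$

$$
\hat{y}(\mathbf{x}) = \arg\max_{c} \sum_{i \in N_k(\mathbf{x})} \mathbb{1}[y_i = c]
$$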

import numpy as np
from collections import Counter

class KNN:
    def __init__(self, k=3, distance_metric='euclidean'):
        self.k = k
        self.distance_metric = distance_metric

    # KNN is a lazy learner: fit only stores the training data
    def fit(self, X_train, y_train):
        self.X_train = np.array(X_train)
        self.y_train = np.array(y_train)

    # calculate distance
    def _compute_distance(self, x1, x2):
        if self.distance_metric == 'euclidean':
            return np.sqrt(np.sum((x1 - x2) ** 2))
        elif self.distance_metric == 'manhattan':
            return np.sum(np.abs(x1 - x2)) 
        else:
            raise ValueError("Unsupported distance metric")

    def predict(self, X_test):
        X_test = np.array(X_test)
        predictions = []

        for x in X_test:
            distances = [self._compute_distance(x, x_train) for x_train in self.X_train]  # distances to all training points
            k_indices = np.argsort(distances)[:self.k]  # indices of the k nearest points
            k_nearest_labels = [self.y_train[i] for i in k_indices]  # labels of those neighbors
            most_common = Counter(k_nearest_labels).most_common(1)[0][0]  # majority vote: the most common class
            predictions.append(most_common)

        return np.array(predictions)
    
    def score(self, X_test, y_test):
        y_pred = self.predict(X_test)
        return np.mean(y_pred == np.array(y_test))  # score

knn = KNN(k=40)
knn.fit(X_train_norm, y_train)
predictions = knn.predict(X_test_norm)
accuracy = knn.score(X_test_norm, y_test)

print(f"Predictions: {predictions}")
print(f"Accuracy: {accuracy}")
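The `predict` loop computes one distance at a time in Python, which is clear but slow for larger datasets. A common alternative, sketched here under the assumption that the full test-by-train distance matrix fits in memory, computes all distances at once with NumPy broadcasting and then applies the same majority-vote rule. `all_distances` and `knn_predict_vectorized` are hypothetical helper names, not part of the original code or of sklearn:

```python
import numpy as np
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler


def all_distances(X_test, X_train, metric="euclidean"):
    """Return an (n_test, n_train) distance matrix using broadcasting."""
    # diff has shape (n_test, n_train, n_features)
    diff = X_test[:, np.newaxis, :] - X_train[np.newaxis, :, :]
    if metric == "euclidean":
        return np.sqrt(np.sum(diff ** 2, axis=2))
    if metric == "manhattan":
        return np.sum(np.abs(diff), axis=2)
    raise ValueError("Unsupported distance metric")


def knn_predict_vectorized(X_train, y_train, X_test, k=3, metric="euclidean"):
    D = all_distances(np.asarray(X_test), np.asarray(X_train), metric)
    # Indices of the k smallest distances in each row
    k_idx = np.argsort(D, axis=1, kind="stable")[:, :k]
    k_labels = np.asarray(y_train)[k_idx]  # shape (n_test, k)
    # Majority vote per test sample, as in the loop version
    return np.array([Counter(row).most_common(1)[0][0] for row in k_labels])


iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris["data"], iris["target"], random_state=0)
scaler = MinMaxScaler().fit(X_train)
X_train_norm, X_test_norm = scaler.transform(X_train), scaler.transform(X_test)

y_pred = knn_predict_vectorized(X_train_norm, y_train, X_test_norm, k=40)
print("Accuracy:", np.mean(y_pred == y_test))
```

The intermediate `diff` array has n_test × n_train × n_features entries, which is fine for iris. For large datasets, computing it in chunks, or using sklearn's `KNeighborsClassifier` with its tree-based search (`algorithm="kd_tree"` or `"ball_tree"`), is usually preferable.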

