当前位置: 首页 > article >正文

聚类_K均值

import numpy as np
import matplotlib.pyplot as plt 
from sklearn.datasets import make_blobs

1.数据预处理

#创建基于高斯分布的样本点, x是点的坐标,y是所属聚类值
x, y = make_blobs(n_samples=100, centers=6, random_state=100, cluster_std=0.6)
# 设置图形尺寸,单位英寸
plt.figure(figsize=(6,6))
plt.scatter(x[:,0], x[:,1],c = y)
plt.show
<function matplotlib.pyplot.show(close=None, block=None)>

在这里插入图片描述

2.模型实现

from scipy.spatial.distance import cdist

class KMeansModel():
    #参数k聚类数, 最大迭代次数,初始质心
    def __init__(self, k_cluster=6, max_iteration=100, centroids=[]):
        self.k_cluster = k_cluster
        self.max_iteration = max_iteration
        self.centroids = np.array(centroids, dtype = np.float32)
        
        
        
    def fit(self, points):
        # 随机选取初始质心点
        if(self.centroids.shape==(0,)):
            self.centroids = points[np.random.randint(0, points.shape[0], self.k_cluster), :]
        for i in range(self.max_iteration):
            #计算所有测试点和所有质心的距离,返回100*6的矩阵
            distances = cdist(points, self.centroids)
            #选取行方向最小的书作为测试点的质心
            c_index = np.argmin(distances, axis=1)
            if(i == 0):
                print("c shape", c_index.shape,c_index[0])
            #计算每类数据的均值作为新的质心
            for i in range(self.k_cluster):
                if i in c_index:
                    self.centroids[i] = np.mean(points[c_index == i], axis=0)
    
    def predict(self, points):
        distances = cdist(points, self.centroids)
        
        #选取距离最近的质心作为分类
        c_index = np.argmin(distances, axis=1)
        return c_index

3.测试

def plot_kmeans(x, y, centroids, subplot):
    plt.subplot(subplot)
    plt.scatter(x[:,0], x[:,1], c=y)
    plt.scatter(centroids[:,0], centroids[:,1],s=100,c='r')

# 训练
kmean_model = KMeansModel(centroids=np.array([[1,1],[2,2],[3,3],[4,4],[5,5],[6,6]]))
plt.figure(figsize=(18,8))
plot_kmeans(x, y, kmean_model.centroids, 121)
           
kmean_model.fit(x)
print(kmean_model.centroids)
plot_kmeans(x, y, kmean_model.centroids, 122)

#预测
x_new = np.array([[10,7],[0,0]])
y_predict = kmean_model.predict(x_new)
print("predict y ", y_predict)
plt.scatter(x_new[:,0],x_new[:,1],s=100, c= "black")
c shape (100,) 0
[[ 4.343336  -5.112518 ]
 [-1.6609049  6.7436223]
 [-8.57988   -3.3460388]
 [ 2.7469435  6.05025  ]
 [ 2.490612   7.7450833]
 [ 4.1287684  6.6914167]]
predict y  [5 3]





<matplotlib.collections.PathCollection at 0x1576e5a9850>

在这里插入图片描述


http://www.kler.cn/a/315444.html

相关文章:

  • Mysql--运维篇--备份和恢复(逻辑备份,mysqldump,物理备份,热备份,温备份,冷备份,二进制文件备份和恢复等)
  • Hadoop3.x 万字解析,从入门到剖析源码
  • MySQL主从:如何处理“Got Fatal Error 1236”或 MY-013114 错误(percona译文)
  • 2025宝塔API一键建站系统PHP源码
  • IDEA编译器集成Maven环境以及项目的创建(2)
  • TiDB常见操作指南:从入门到进阶
  • 基于 Web 的工业设备监测系统:非功能性需求与标准化数据访问机制的架构设计
  • git重置本地提交与远程保持一致
  • 阅读笔记——《围城》
  • git 版本管理的常用命令
  • c++249多态
  • 【计算机网络篇】计算机网络概述
  • 安全第一:API 接口接入前的防护性注意要点
  • Java21 中的虚拟线程
  • 校园美食猎人:Spring Boot技术的美食探索应用
  • xxl-job适配sqlite本地数据库及mysql数据库。可根据配置指定使用哪种数据库。
  • 鸿蒙OS 线程间通信
  • 【VLM小白指北 (1) 】An Introduction to Vision-Language Modeling
  • CTFShow-反序列化
  • 聚焦API安全未来,F5打造无缝集成的解决方案
  • 2024年中国研究生数学建模竞赛D题大数据驱动的地理综合问题
  • harbor集成trivy镜像扫描工具
  • 模仿抖音用户ID加密ID的算法MB4E,提高自己平台ID安全性
  • C# Winform调用控制台程序(通过Process类)
  • Java设计模式(单例模式)——单例模式存在的问题(完整详解,附有代码+案例)
  • svn 1.14.5