当前位置: 首页 > article >正文

k均值聚类将数据分成多个簇

K-Means 聚类并将数据分成多个簇,可以使用以下方法:

实现思路

  1. 随机初始化 K 个聚类中心
  2. 计算每个点到聚类中心的距离
  3. 将点分配到最近的簇
  4. 更新聚类中心
  5. 重复上述过程直到收敛

完整代码:

import torch
import matplotlib.pyplot as plt

def kmeans(X, k, max_iters=100, tol=1e-4):
    """
    使用 PyTorch 实现 K-Means 聚类,并返回聚类结果
    :param X: (n, d) 输入数据
    :param k: 簇的个数
    :param max_iters: 最大迭代次数
    :param tol: 收敛阈值
    :return: (最终聚类中心, 每个样本的簇索引)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    X = X.to(device)

    n, d = X.shape
    indices = torch.randperm(n)[:k]  # 随机选择 k 个数据点作为初始聚类中心
    centroids = X[indices].clone()

    for i in range(max_iters):
        distances = torch.cdist(X, centroids)  # 计算所有点到聚类中心的欧式距离
        cluster_assignments = torch.argmin(distances, dim=1)  # 分配每个点到最近的簇

        new_centroids = torch.stack([
            X[cluster_assignments == j].mean(dim=0) if (cluster_assignments == j).sum() > 0
            else centroids[j]  # 避免空簇
            for j in range(k)
        ])

        shift = torch.norm(new_centroids - centroids, p=2)  # 计算变化量
        if shift < tol:
            print(f'K-Means 提前收敛于第 {i+1} 轮')
            break

        centroids = new_centroids

    return centroids.cpu(), cluster_assignments.cpu()

# 生成数据
torch.manual_seed(42)
X = torch.randn(200, 2)  # 200 个 2D 点
k = 3

# 运行 K-Means
centroids, labels = kmeans(X, k)

# 输出最终结果
print("最终聚类中心:")
print(centroids)

# 统计每个簇的样本数量
for i in range(k):
    count = (labels == i).sum().item()
    print(f"簇 {i} 的数据点数量: {count}")

# 可视化聚类结果
def plot_kmeans(X, labels, centroids, k):
    """
    可视化 K-Means 聚类结果
    :param X: 数据点
    :param labels: 聚类标签
    :param centroids: 聚类中心
    :param k: 簇的个数
    """
    X = X.numpy()
    labels = labels.numpy()
    centroids = centroids.numpy()

    plt.figure(figsize=(8, 6))

    # 画出每个簇的点
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
    for i in range(k):
        plt.scatter(X[labels == i, 0], X[labels == i, 1],
                    c=colors[i % len(colors)], label=f'Cluster {i}', alpha=0.6)

    # 画出聚类中心
    plt.scatter(centroids[:, 0], centroids[:, 1],
                c='black', marker='X', s=200, label='Centroids')

    plt.legend()
    plt.title("K-Means Clustering using PyTorch")
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.grid()
    plt.show()

# 绘制聚类结果
plot_kmeans(X, labels, centroids, k)

备注:

  • 初始化
    • 采用 torch.randperm(n)[:k] 选择 k 个数据点作为初始聚类中心。
  • 计算距离
    • torch.cdist(X, centroids) 计算所有点到各个聚类中心的欧式距离。
  • 分配簇
    • torch.argmin(distances, dim=1) 选择最近的聚类中心。
  • 更新中心
    • X[cluster_assignments == j].mean(dim=0) 计算每个簇的新中心。
    • 如果某个簇为空,保持原来的中心不变,避免空簇问题。
  • 判断收敛
    • torch.norm(new_centroids - centroids, p=2) 计算中心点的移动量,若小于阈值 tol,则提前终止。
  • 按簇分类数据
    • clusters = [X[labels == i] for i in range(k)] 将数据划分到不同簇。

http://www.kler.cn/a/525308.html

相关文章:

  • C++中常用的排序方法之——冒泡排序
  • DeepSeek-R1本地部署笔记
  • c++:vector
  • 消息队列篇--通信协议篇--应用层协议和传输层协议理解
  • csapp2.4节——浮点数
  • Fullcalendar @fullcalendar/react 样式错乱丢失问题和导致页面卡顿崩溃问题
  • 高级编码参数
  • 【Attention】KV Cache
  • TypeScript 学习 -类型 - 10
  • 快速提升网站收录:内容创作的艺术
  • 工具的应用——安装copilot
  • 高速PCB设计指南3——PCB 传输线和受控阻抗
  • 供应链系统设计-供应链中台系统设计(十)- 清结算中心概念片篇
  • Python3 【内置函数】:使用示例参考手册
  • JVM--类加载器
  • 超越传统图结构:记忆模拟新突破
  • C语言从入门到进阶
  • 【deepseek】本地部署DeepSeek R1模型:使用Ollama打造个人AI助手
  • 并发编程 - 线程同步(二)
  • 【2024年华为OD机试】 (A卷,200分)- 服务中心选址(JavaScriptJava PythonC/C++)
  • Python异步编程核武器:asyncio.gather() 的终极使用手册
  • 使用scikit-learn中的KNN包实现对鸢尾花数据集或者自定义数据集的的预测。
  • SpringBoot+Vue的理解(含axios/ajax)-前后端交互前端篇
  • 【开源免费】基于SpringBoot+Vue.JS社区智慧养老监护管理平台(JAVA毕业设计)
  • gif动画图像优化,相同的图在第2,4,6帧中重复出现,会增加图像体积吗?
  • 迭代推理机制提升AI精准性