当前位置: 首页 > article >正文

24/11/5 算法笔记 DBSCAN聚类算法

DBSCAN(Density-Based Spatial Clustering of Applications with Noise)是一种基于密度的聚类算法,它能够在具有噪声的空间数据库中发现任意形状的簇。

DBSCAN密度聚类思想
DBSCAN的聚类定义很简单:由密度可达关系导出的最大密度相连的样本集合,即为我们最终聚类的一个类别,或者说一个簇。
这个DBSCAN的簇里面可以有一个或者多个核心对象。如果只有一个核心对象,则簇里其他的非核心对象样本都在这个核心对象的ϵ \epsilonϵ-邻域里;如果有多个核心对象,则簇里的任意一个核心对象的ϵ \epsilonϵ-邻域中一定有一个其他的核心对象,否则这两个核心对象无法密度可达。这些核心对象的ϵ \epsilonϵ-邻域里所有的样本的集合组成的一个D B S C A N DBSCANDBSCAN聚类簇。
那么怎么才能找到这样的簇样本集合呢?DBSCAN使用的方法很简单,它任意选择一个没有类别的核心对象作为种子,然后找到所有这个核心对象能够密度可达的样本集合,即为一个聚类簇。接着继续选择另一个没有类别的核心对象去寻找密度可达的样本集合,这样就得到另一个聚类簇。一直运行到所有核心对象都有类别为止。

下面是一个纯用python实现的DBSCAN算法:

import numpy as np

class DBSCAN:
    def __init__(self,eps,min_samples):
        self.eps = eps
        self.min_samples = min_samples

    def fit(self,x):
        self.labels_ = np.full(shape = x.shape[0],fill_value = -1,dtype = int)
#-1表示噪声点
        self.cluster_id_ = 0  #簇标志
        self.x_ = x
        
        for i in range(x.shape[0]):
            if self.labels_[i] != -1:
                continue
            self._expand_cluster(i)
    def _expand_cluster(self,point_idx):
        neighbours = self._region_query(self.x_[point_idx])
        if len(neighbors) < self.min_samples:
            self.labels_[point_idx] = -1  # 标记为噪声点
            return
        
        self.cluster_id_ += 1
        self.labels_[point_idx] = self.cluster_id_
        neighbors_queue = list(neighbors)

        while len(neighbors_queue) > 0:
            neighbor = neighbors_queue.pop(0) #取出邻居点
            if self.labels_[neighbor] == -1:   #检查邻居点的标签
                self.labels_[neighbor] = self.cluster_id_
            if len(self._region_query(self.X_[neighbor])) >= self.min_samples: #区域查询
                for next_neighbor in self._region_query(self.X_[neighbor]): #扩展邻居点
                    if self.labels_[next_neighbor] == -1:
                        self.labels_[next_neighbor] = self.cluster_id_
                        neighbors_queue.append(next_neighbor)


    def _region_query(self,point):#用于查找给定点 point 在指定半径 eps 内的邻居点的核心操作
        return np.where((self.x_ - point) @ (self.x_point)<(self.eps**2)[0]

    
    def get_labels(self):
        return self.labels_

    if __name__ == "__main__":
    np.random.seed(0)
    X = np.random.rand(100, 2)

    dbscan = DBSCAN(eps=0.3, min_samples=5)
    dbscan.fit(X)

    labels = dbscan.get_labels()

    import matplotlib.pyplot as plt

    plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', marker='o', s=50)
    plt.title('DBSCAN Clustering')
    plt.show()



下面解释每段代码:

fit 方法

def fit(self, X):
    self.labels_ = np.full(shape=X.shape[0], fill_value=-1, dtype=int)  # -1 表示噪声点
    self.cluster_id_ = 0
    self.X_ = X

    for i in range(X.shape[0]):
        if self.labels_[i] != -1:
            continue
        self._expand_cluster(i)

fit 方法是 DBSCAN 算法的主要入口点,它初始化聚类标签数组 labels_,所有点最初都被标记为 -1(表示噪声点)。然后,它遍历每个点,如果点还没有被访问过(即标签仍然是 -1),则调用 _expand_cluster 方法来扩展以该点为中心的聚类。

_expand_cluster 方法

def _expand_cluster(self, point_idx):
    neighbors = self._region_query(self.X_[point_idx])
    if len(neighbors) < self.min_samples:
        self.labels_[point_idx] = -1  # 标记为噪声点
        return

    self.cluster_id_ += 1
    self.labels_[point_idx] = self.cluster_id_
    neighbors_queue = list(neighbors)

    while len(neighbors_queue) > 0:
        neighbor = neighbors_queue.pop(0)
        if self.labels_[neighbor] == -1:
            self.labels_[neighbor] = self.cluster_id_
        if len(self._region_query(self.X_[neighbor])) >= self.min_samples:
            for next_neighbor in self._region_query(self.X_[neighbor]):
                if self.labels_[next_neighbor] == -1:
                    self.labels_[next_neighbor] = self.cluster_id_
                    neighbors_queue.append(next_neighbor)

_expand_cluster 方法用于扩展聚类。它首先查询给定点的邻居,如果邻居数量少于 min_samples,则将该点标记为噪声点。否则,它将点标记为一个新的聚类,并将其邻居添加到队列中。然后,它迭代地处理邻居的邻居,直到队列为空。

_region_query 方法,(核心)

def _region_query(self, point):
    return np.where((self.X_ - point) @ (self.X_ - point) < self.eps**2)[0]

用于找到给定点在 eps 距离内的邻居点。它通过计算点与所有其他点之间的欧氏距离,并返回距离小于 eps 的点的索引。

get_labels 方法

def get_labels(self):
    return self.labels_

get_labels 方法返回聚类结果的标签数组


http://www.kler.cn/a/387975.html

相关文章:

  • 如何在 Ubuntu 22.04 上安装 Caddy Web 服务器教程
  • 代码随想录 哈希 test 8
  • 计算机网络之---数据链路层的功能与作用
  • MySQL 如何赶上 PostgreSQL 的势头?
  • mybatisX插件的使用,以及打包成配置
  • 68.基于SpringBoot + Vue实现的前后端分离-心灵治愈交流平台系统(项目 + 论文PPT)
  • 高中诊断考如何影响高考?答案都在这 5 个方面
  • PySimpleGUI和Pymysql
  • 安全、高效、有序的隧道照明能源管理解决方案
  • uniapp配置消息推送unipush 厂商推送设置配置 FCM 教程
  • 了解云计算工作负载保护的重要性及必要性
  • 东胜物流软件 AttributeAdapter.aspx SQL 注入漏洞复现
  • 前端根据后端返回的文本流逐个展示文本内容
  • Java基础——类和对象的定义链表的创建,输出
  • 通过 ssh config 快速免密连接服务器
  • 【dvwa靶场:XSS系列】XSS (Reflected)低-中-高级别,通关啦
  • 【开发】Java的内存溢出
  • uni-app打包后报错云服务空间未关联
  • unity关于自定义渲染、内存管理、性能调优、复杂物理模拟、并行计算以及插件开发
  • [python] 如何debug python脚本中C++后端的core dump
  • 【嵌入式开发——ARM】1ARM架构
  • 牛客周赛 Round 67
  • Android 实现一个系统级的悬浮秒表
  • 基于 STM32 的天气时钟项目中添加天气数据的网络获取功能
  • Edge浏览器打开PDF无法显示电子签章
  • mac 本地docker-mysql主从复制部署