Press "Enter" to skip to content

通俗易懂的机器学习——python手动实现DBSCAN聚类算法(不依赖已有框架)

手动实现DBSCAN算法

 

DBSCAN的工作原理

 

1.对于每个实例(即每个数据)都会计算在距离它一段距离的邻域中的实例数

 

2.如果在邻域中的实例数超过了最小样本数规定的阈值,则该实例被视为核心实例

 

3.核心实例邻域内的实例都视为同一个集群,即视为他们的类别相同

 

4.任何不是核心实例,并且在他的邻域中实例数没有超过最小样本数的实例被视为异常数据

 

通过工作原理我们可以看到,算法进行聚类的方式大致方式是进行广度优先搜索并对搜到的数据染色,下面我们开始一个DBSCAN算法的简单手动实现

 

DBSCAN手动实现

 

由于算法是运用广度优先搜索,所以需要用到队列,我们在这里使用numpy仅使用了其批量处理数据的功能,最后利用matplotlib进行画图。

 

from queue import Queue
import numpy as np
import matplotlib.pyplot as plt

 

手写DBSCAN类

 

class DBSCAN:
    def __init__(self, min_samples=10, r=0.15):
        self.min_samples=min_samples
        self.r=r
        self.X = None
        self.label = None
        self.n_class = 0
        
    def fit(self, X):
        self.X = X
        self.label = np.zeros(X.shape[0])
        q = Queue()
        for i in range(len(self.X)):
            if self.label[i] == 0:
                q.put(self.X[i])
                if self.X[(np.sqrt(np.sum((self.X - self.X[i]) ** 2, axis=1)) <= self.r) & (self.label==0)].shape[0] >= self.min_samples:
                    self.n_class += 1
                while not q.empty():
                    p = q.get()
                    neighbors = self.X[(np.sqrt(np.sum((self.X - p) ** 2, axis=1)) <= self.r) & (self.label==0)]
                    if neighbors.shape[0] >= self.min_samples:
                        mark = (np.sqrt(np.sum((self.X - p) ** 2, axis=1)) <= self.r)
                        self.label[mark] = np.ones(self.label[mark].shape) * self.n_class
                        # print(self.label)
                        for x in neighbors:
                            q.put(x)
                            
    def plot_dbscan_2D(self):
        plt.rcParams['font.sans-serif'] = ["SimHei"]
        plt.rcParams['axes.unicode_minus']=False
        for i in range(self.n_class+1):
            if i == 0:
                label = '异常数据'
            else:
                label = '第'+str(i) + '类数据'
            plt.scatter(self.X[self.label==i,0], self.X[self.label==i,1],label=label)
        plt.legend()
        plt.show()

 

可以看到这个类特别的简单,稍微了解过广度优先搜索的就可以很快理解

 

代码参数分析

 

min_samples:在邻域内要求的最小实例数
r:集群内一个实例到另一个实例的最大距离
n_class:表示通过聚类获得的集群的个数

 

代码测试

 

下面我们使用sklearn中的make_circles和make_moons生成数据集来测试聚类结果

 

from sklearn.datasets import make_moons, make_circles
X,_=make_circles(n_samples=1000,factor=0.5,noise=0.1)
db = DBSCAN(4, 0.15)
db.fit(X)
db.plot_dbscan_2D()
X,_ = make_moons(n_samples=1000, noise=0.05)
db = DBSCAN(10, 0.15)
db.fit(X)
db.plot_dbscan_2D()

 

运行结果:

从运行结果可以看出效果还是相当不错的,不过在使用make_circles产生的数据集上类别多分了两个,这种情况其实是由于异常数据和超参数产生的,可以通过调整超参数或者在搜索结束后之选用较大的集群来解决这个问题。我们这里仅讨论思维与实现方式,具体细节有兴趣的可以再自行修改

 

DBSCAN的特点以及应用场景

 

从代码测试可以看得出来,DBSCAN算法可以适用于任何数据集,不受数据分布情况的影响,同时时间复杂度为O(nlogn),所以他是一种快速并且适应性强的算法。

 

我们可以运用DBSCAN算法剔除异常数据,同时聚类的结果相当于给无标签数据打上了标签,可以用于无标签数据或标签不全的数据的预处理。

Be First to Comment

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注