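Continuing my source-code reading journey: this time it's the estimate_bandwidth function from the mean-shift clustering algorithm.
estimate_bandwidth estimates the bandwidth for the mean-shift algorithm. If MeanShift is not given a bandwidth parameter, it runs estimate_bandwidth automatically (source code link).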
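As a formula, here is my own summary of what the function computes. The symbols m, q, k and d_k are notation I'm introducing, not scikit-learn's. Let m be the number of points after the optional subsampling and q the quantile parameter:

$$
k = \lfloor m\,q \rfloor, \qquad
\text{bandwidth} = \frac{1}{m} \sum_{i=1}^{m} d_k(x_i)
$$

Here d_k(x_i) is the distance from x_i to its k-th nearest neighbor among the m points, counting x_i itself as its own first neighbor at distance 0.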
```python
import numpy as np

from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_random_state, gen_batches


def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0,
                       n_jobs=1):
    """Estimate the bandwidth to use with the mean-shift algorithm.

    Note that this function takes time at least quadratic in n_samples. For large
    datasets, it's wise to set that parameter to a small value.

    Parameters
    ----------
    X : array-like, shape=[n_samples, n_features]
        Input points.
    quantile : float, default 0.3
        should be between [0, 1]
        0.5 means that the median of all pairwise distances is used.
    n_samples : int, optional
        The number of samples to use. If not given, all samples are used.
    random_state : int or RandomState
        Pseudo-random number generator state used for random sampling.
    n_jobs : int, optional (default = 1)
        The number of parallel jobs to run for neighbors search.
        If ``-1``, then the number of jobs is set to the number of CPU cores.

    Returns
    -------
    bandwidth : float
        The bandwidth parameter.
    """
    # Turn random_state into a pseudo-random number generator
    random_state = check_random_state(random_state)
    if n_samples is not None:
        # permutation shuffles the row indices; keep the first n_samples of them
        idx = random_state.permutation(X.shape[0])[:n_samples]
        X = X[idx]
    # Unsupervised nearest-neighbor search.
    # quantile sets the number of neighbors as a fraction of the sample count
    nbrs = NearestNeighbors(n_neighbors=int(X.shape[0] * quantile),
                            n_jobs=n_jobs)
    nbrs.fit(X)

    bandwidth = 0.
    # gen_batches(n, batch_size) yields slices covering 0..n in chunks of batch_size
    for batch in gen_batches(len(X), 500):
        # kneighbors returns, for each point in the batch, the distances to its
        # n_neighbors nearest neighbors. Because the query points are passed in
        # explicitly, each point counts itself as a neighbor (distance 0).
        # Without an n_neighbors argument, kneighbors uses the value given to
        # NearestNeighbors. The second return value (the neighbor indices) is
        # not needed, so it is discarded with _.
        d, _ = nbrs.kneighbors(X[batch, :], return_distance=True)
        # Add up, over the batch, each point's distance to the farthest of
        # its n_neighbors nearest neighbors
        bandwidth += np.max(d, axis=1).sum()

    # In essence: the mean distance to the k-th nearest neighbor
    return bandwidth / X.shape[0]
```
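To check this reading of the code, here is a brute-force, pure-Python sketch of the same quantity. It is a hypothetical helper (estimate_bandwidth_sketch), not scikit-learn code. It assumes Euclidean distance (the default metric of NearestNeighbors) and compares every pair of points, so it is quadratic, as the docstring warns. Its subsampling uses Python's random module, so for the same seed it will not pick the same subset as scikit-learn. The include_self flag is also my own addition, to show how much the self-distance matters.

```python
import math
import random


def estimate_bandwidth_sketch(X, quantile=0.3, n_samples=None, seed=0,
                              include_self=True):
    """Hypothetical brute-force sketch of estimate_bandwidth (not sklearn's API).

    X is a list of equal-length coordinate tuples.
    """
    X = list(X)
    if n_samples is not None:
        # Mirror the subsampling step: shuffle the indices, keep the first n_samples
        idx = list(range(len(X)))
        random.Random(seed).shuffle(idx)
        X = [X[i] for i in idx[:n_samples]]
    # quantile -> number of neighbors, truncated like int(X.shape[0] * quantile)
    k = int(len(X) * quantile)
    # Each point has len(X) candidate distances with itself, len(X) - 1 without
    limit = len(X) if include_self else len(X) - 1
    if not 1 <= k <= limit:
        raise ValueError("int(len(X) * quantile) must be between 1 and %d" % limit)
    total = 0.0
    for p in X:
        # Euclidean distances from p to every point, including p itself (0.0)
        dists = sorted(math.dist(p, q) for q in X)
        if not include_self:
            dists = dists[1:]  # drop one zero: p's distance to itself
        # The largest of the k nearest distances is the k-th smallest one
        total += dists[k - 1]
    # Average over all (sub)sampled points
    return total / len(X)


points = [(0.0,), (1.0,), (2.0,), (3.0,)]
print(estimate_bandwidth_sketch(points, quantile=0.5))                      # 1.0
print(estimate_bandwidth_sketch(points, quantile=0.5, include_self=False))  # 1.5
```

On the four points 0, 1, 2, 3 with quantile=0.5 (so k = 2), counting each point as its own neighbor gives 1.0, while dropping the self-distance gives 1.5. On small samples the difference is clearly visible.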
The comments are my own interpretation. If anything is off, feel free to point it out in the comments section.