scikit-learn source reading: cluster.mean_shift.estimate_bandwidth

Updated: 2024-10-27 13:29:21

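Continuing my tour of the scikit-learn source, this time I look at estimate_bandwidth, a helper for the mean-shift clustering algorithm.
estimate_bandwidth estimates the bandwidth used by mean-shift. If MeanShift is created without a bandwidth parameter, it calls estimate_bandwidth automatically.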
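The fallback rule can be pictured with a tiny sketch. `resolve_bandwidth` is a hypothetical helper, not part of scikit-learn; the estimator is passed in as a plain callable (standing in for estimate_bandwidth) so the snippet runs with the standard library only.

```python
from typing import Callable, Optional, Sequence

Point = Sequence[float]


def resolve_bandwidth(
    X: Sequence[Point],
    bandwidth: Optional[float],
    estimator: Callable[[Sequence[Point]], float],
) -> float:
    """Hypothetical helper mirroring MeanShift's bandwidth fallback."""
    # A user-supplied bandwidth always wins; the estimator is never called.
    if bandwidth is not None:
        return bandwidth
    # No bandwidth given: estimate it from the data
    # (in scikit-learn this is where estimate_bandwidth runs).
    return estimator(X)
```

Because the estimate costs at least quadratic time in the number of samples, passing an explicit bandwidth also skips that cost entirely.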

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import check_random_state, gen_batches


def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0,
                       n_jobs=1):
    """Estimate the bandwidth to use with the mean-shift algorithm.

    Note that this function takes time at least quadratic in n_samples. For
    large datasets, it's wise to set that parameter to a small value.

    Parameters
    ----------
    X : array-like, shape=[n_samples, n_features]
        Input points.

    quantile : float, default 0.3
        should be between [0, 1]
        0.5 means that the median of all pairwise distances is used.

    n_samples : int, optional
        The number of samples to use. If not given, all samples are used.

    random_state : int or RandomState
        Pseudo-random number generator state used for random sampling.

    n_jobs : int, optional (default = 1)
        The number of parallel jobs to run for neighbors search.
        If ``-1``, then the number of jobs is set to the number of CPU cores.

    Returns
    -------
    bandwidth : float
        The bandwidth parameter.
    """

    # Turn random_state into a NumPy RandomState instance
    random_state = check_random_state(random_state)
    if n_samples is not None:
        # permutation shuffles the row indices; keep the first n_samples rows
        idx = random_state.permutation(X.shape[0])[:n_samples]
        X = X[idx]
    # Unsupervised nearest-neighbor search.
    # quantile is the fraction of the samples used as neighbors:
    # n_neighbors = int(len(X) * quantile), which must come out >= 1
    nbrs = NearestNeighbors(n_neighbors=int(X.shape[0] * quantile),
                            n_jobs=n_jobs)
    nbrs.fit(X)

    bandwidth = 0.
    # gen_batches(n, batch_size) yields slices covering 0..n,
    # each at most batch_size long
    for batch in gen_batches(len(X), 500):
        # kneighbors returns the distances from each point in the batch to its
        # n_neighbors nearest neighbors (kneighbors' own n_neighbors argument
        # defaults to the value given to NearestNeighbors). Because the query
        # points are passed explicitly, each point is its own nearest neighbor
        # at distance 0. The second return value (indices) is unused, hence _.
        d, _ = nbrs.kneighbors(X[batch, :], return_distance=True)
        # Add up, for each point, the distance to the farthest of its
        # n_neighbors nearest neighbors (the k-th neighbor distance)
        bandwidth += np.max(d, axis=1).sum()
    # In essence: the mean distance to the k-th nearest neighbor
    return bandwidth / X.shape[0]
```
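The batching loop only bounds memory: each kneighbors call handles at most 500 query points, and the accumulated sum is the same as querying all points at once. Below is a simplified, hypothetical stand-in for `sklearn.utils.gen_batches` (not the library's code) showing which slices the loop visits.

```python
from typing import Iterator


def gen_batches_sketch(n: int, batch_size: int) -> Iterator[slice]:
    """Hypothetical stand-in for gen_batches: consecutive slices over range(n)."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    start = 0
    while start < n:
        # The last slice may be shorter than batch_size.
        end = min(start + batch_size, n)
        yield slice(start, end)
        start = end
```

For n = 1200 and batch_size = 500 it yields slice(0, 500), slice(500, 1000) and slice(1000, 1200).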
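Stripped of the tree search and batching, the function computes one statistic: the mean, over the (possibly subsampled) points, of each point's distance to its k-th nearest neighbor, with k = int(n * quantile) and each point counted as its own first neighbor. The pure-Python sketch below recomputes that statistic by brute force. It is an illustration under simplifying assumptions (Euclidean distance, O(n²) distance computations, Python's `random` instead of NumPy's RandomState, so subsamples differ from scikit-learn's for the same seed), not scikit-learn's implementation; `estimate_bandwidth_sketch` is a hypothetical name.

```python
import math
import random
from typing import List, Optional, Sequence


def estimate_bandwidth_sketch(
    X: Sequence[Sequence[float]],
    quantile: float = 0.3,
    n_samples: Optional[int] = None,
    seed: int = 0,
) -> float:
    """Hypothetical brute-force version of the estimate_bandwidth statistic."""
    points: List[Sequence[float]] = list(X)
    if n_samples is not None:
        # Subsample without replacement (plays the role of permutation()[:n_samples]).
        points = random.Random(seed).sample(points, min(n_samples, len(points)))
    n = len(points)
    # k is computed after subsampling; int() truncates, e.g. 100 * 0.3 -> 30.
    k = int(n * quantile)
    if k < 1:
        raise ValueError("int(n * quantile) must be at least 1")
    total = 0.0
    for p in points:
        # Distances to every point, including p itself (distance 0),
        # as with kneighbors() when the query points are passed explicitly.
        dists = sorted(math.dist(p, q) for q in points)
        # The k-th smallest distance is the farthest of the k nearest neighbors.
        total += dists[k - 1]
    # Mean k-th-neighbor distance over all (possibly subsampled) points.
    return total / n
```

With the 1-D points 0, 1, 3, 6 and quantile=0.5, k = 2, so each point contributes the distance to its nearest other point (1, 1, 2, 3) and the result is 7 / 4 = 1.75. With quantile=0.3, k = int(1.2) = 1, every point's only "neighbor" is itself, and the sketch returns 0.0, so a small quantile on a small sample can collapse the estimate.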

The comments in the code are my own interpretation; if anything is inaccurate, corrections in the comments are welcome.


Published: 2023-06-14 05:53:00