CH3

编程入门 行业动态 更新时间:2024-10-11 05:30:08

CH3

CH3

KNN简介

KNN(k-Nearest Neighbors)又称作k-近邻。k-nn就是把未标记分类的案列归为与它们最相似的带有分类标记的案例所在的类。

KNN的特点

优点缺点
简单且有效不产生模型
训练阶段很快分类过程比较慢
对数据分布无要求模型解释性较差
适合稀疏时间和多分类问题名义变量和缺失数据需要额外处理

KNN模型

K近邻模型有三个基本要素:距离度量、K值的选择、分类决策规则

实现步骤
  1. 计算距离:计算待测案例与训练样本之间的距离 。
  2. 选择一个合适的k:确定用于KNN算法的邻居数量,一般用交叉验证或仅凭经验选择一个合适的k值,待测案例与训练样本之间距离最小的k个样本组成一个案例池。
  3. 类别判定:根据案例池的数据采用投票法或者加权投票法等方法来决定待测案例所属的类别。
KD-Tree

kd-tree是一种分割k维数据空间的数据结构。主要应用于多维空间数据的搜索,经常使用在SIFT、KNN等多维数据搜索的场景中,以KNN(K近邻)为例,使用线性搜索的方式效率低下,k-d树本质是对多维空间的划分,其每个节点都为k维点的二叉树kd-tree,因此可以大大提高搜索效率。详细的构造方法和kd树的最近邻搜索方法可以参考李航老师的《统计学习方法》。

1.定义Kd树类及其方法
package CH3_KNearestNeibor/*** Created by WZZC on 2019/11/29**/
/**** @param label 分类指标*  @param value 节点数据*  @param dim   当前切分维度*  @param left  左子节点*  @param right 右子节点*/
case class TreeNode(label: String,value: Seq[Double],dim: Int,var left: TreeNode,var right: TreeNode)extends Serializable {}object TreeNode {import statisticslearn.DataUtils.distanceUtils._/***创建KD 树** @param value* @param dim* @param shape* @return*/def creatKdTree(value: Seq[(String, Seq[Double])],dim: Int,shape: Int): TreeNode = {// 数据按照当前划分的维度排序val sorted: Seq[(String, Seq[Double])] = value.sortBy(tp2 => tp2._2(dim))//中间位置的索引val midIndex: Int = value.length / 2sorted match {// 当节点为空时,返回nullcase Nil => null//节点不为空时,递归调用方法case _ =>val left: Seq[(String, Seq[Double])] = sorted.slice(0, midIndex)val right: Seq[(String, Seq[Double])] =sorted.slice(midIndex + 1, value.length)val leftNode = creatKdTree(left, (dim + 1) % shape, shape) //左子节点递归创建树val rightNode = creatKdTree(right, (dim + 1) % shape, shape) //右子节点递归创建树TreeNode(sorted(midIndex)._1,sorted(midIndex)._2,dim,leftNode,rightNode)}}/*** 从root节点开始,DFS搜索直到叶子节点,同时在stack中顺序存储已经访问的节点。* 如果搜索到叶子节点,当前的叶子节点被设为最近邻节点。* 然后通过stack回溯:* 如果当前点的距离比最近邻点距离近,更新最近邻节点.* 然后检查以最近距离为半径的圆是否和父节点的超平面相交.* 如果相交,则必须到父节点的另外一侧,用同样的DFS搜索法,开始检查最近邻节点。* 如果不相交,则继续往上回溯,而父节点的另一侧子节点都被淘汰,不再考虑的范围中.* 当搜索回到root节点时,搜索完成,得到最近邻节点。** @param treeNode* @param data* @param k* @return*/def knn(treeNode: TreeNode, data: Seq[Double], k: Int = 1) = {//    implicit def vec2Seq(a:DenseVector[Double])=a.toArray.toSeqvar resArr = new Array[(Double, TreeNode)](k).map(_ => (Double.MaxValue, null)).asInstanceOf[Array[(Double, TreeNode)]]def finder(treeNode: TreeNode): TreeNode = {if (treeNode != null) {val dimr = data(treeNode.dim) - treeNode.value(treeNode.dim)if (dimr > 0) finder(treeNode.right) else finder(treeNode.left)val distc: Double = euclidean(treeNode.value, data)if (distc < resArr.last._1) {resArr.update(k - 1, (distc, treeNode))resArr = resArr.sortBy(_._1)}if (math.abs(dimr) < resArr.last._1)if (dimr > 0) finder(treeNode.left) else finder(treeNode.right)}resArr.last._2}finder(treeNode)resArr}}
2.Spark实现 Knn模型
package CH3_KNearestNeiborimport org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._/*** Created by WZZC on 2019/11/29**/
case class KnnModel(data: DataFrame, labelName: String) extends Serializable {private val spark = data.sparkSession//  import spark.implicits._// 使用.rdd的时候不能使用 col
//  private val sfadsfaggaggsagafasavsa: String = UUID.randomUUID().toStringprivate val ftsName: String = Identifiable.randomUID("KnnModel")// 数据特征名称private val fts: Array[String] = data.columns.filterNot(_ == labelName)val shapes: Int = fts.lengthdef vec2Seq = udf((vec: DenseVector) => vec.toArray.toSeq)/**** @param dataFrame* @return*/def dataTransForm(dataFrame: DataFrame) = {new VectorAssembler().setInputCols(fts).setOutputCol(ftsName).transform(dataFrame)}private val kdtrees: Array[TreeNode] = dataTransForm(data).withColumn(ftsName, vec2Seq(col(ftsName))).select(labelName, ftsName).withColumn("partitionIn", spark_partition_id()).rdd //在大数据情况下,分区构建kdtree.map(row => {val partitionIn = row.getInt(2)val label = row.getString(0)val features = row.getAs[Seq[Double]](1)(partitionIn, label, features)}).groupBy(_._1).mapValues(_.toSeq.map(tp3 => (tp3._2, tp3._3))).mapValues(nn => TreeNode.creatKdTree(nn, 0, shapes)).values.collect()/**** @param predictDf* @param k* @return*/def predict(predictDf: DataFrame, k: Int): DataFrame = {// 此处方法重载需要注意:overloaded method needs result typedef nsearchUdf = udf((seq: Seq[Double]) => predict(seq, k))dataTransForm(predictDf).withColumn(ftsName, vec2Seq(col(ftsName))).withColumn(labelName, nsearchUdf(col(ftsName))).drop(ftsName)}/**** @param predictData* @param k* @return*/def predict(predictData: Seq[Double], k: Int): String = {// 查询的时候遍历每个kdtree,然后取结果集再排序val res: Array[(Double, Seq[Double], String)] = kdtrees.map(node => {TreeNode.knn(node, predictData, k).map(tp2 => (tp2._1, tp2._2.value, tp2._2.label))}).flatMap(_.toSeq).sortBy(_._1).take(k)// 按照投票选举的方法选择分类结果val cl = res.map(tp3 => (tp3._3, 1)).groupBy(_._1).mapValues(_.map(_._2).sum).maxBy(_._2)._1cl}}
3.算法测试
package CH3_KNearestNeiborimport org.apache.spark.sql.SparkSession/*** Created by WZZC on 2019/11/29**/
object KNNRunner {def main(args: Array[String]): Unit = {val spark = SparkSession.builder().appName(s"${this.getClass.getSimpleName}").master("local[*]").getOrCreate()val iris = spark.read.option("inferSchema", true).option("header", true).csv("/data/iris.csv")val model: KnnModel = KnnModel(iris, "class")model.predict(iris, 3).show()spark.stop()}}

更多推荐

CH3

本文发布于:2024-03-24 00:05:39,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1744444.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!