机器学习笔记——决策树(CART方法)

编程入门 行业动态 更新时间:2024-10-25 00:30:23

机器学习笔记——决策树(CART方法)

机器学习笔记——决策树(CART方法)

# -*- coding: utf-8 -*-
"""
Created on Sat Aug 24 11:14:58 2019

@author: wangtao_zuel
E-mail: wangtao_zuel@126

Decision tree, CART method: binary splitting with pre-pruning, optional
post-pruning (regression trees) and out-of-sample prediction.
"""
import numpy as np
import pandas as pd


def loadData(filepath, fileType):
    """Read a data file and return it as a numpy matrix.

    filepath -- path of the data file.
    fileType -- 'xlsx', 'csv', or anything else (treated as a
                tab-separated text file without a header row).

    When the sample labels are of str type, the reading logic should be
    adapted accordingly.
    """
    if fileType == 'xlsx':
        data = pd.read_excel(filepath)
    elif fileType == 'csv':
        data = pd.read_csv(filepath)
    else:
        # tab-separated plain-text file, no header row
        data = pd.read_csv(filepath, sep='\t', header=None)
    data = np.mat(data)
    return data


def binSplitDataSet(dataSet, featInd, featVal):
    """Binary-split the samples on a feature index / feature value.

    Rows whose feature value is <= featVal always go to the left matrix
    (matL), the rest to the right matrix (matR).
    """
    matL = dataSet[np.nonzero(dataSet[:, featInd] <= featVal)[0], :]
    matR = dataSet[np.nonzero(dataSet[:, featInd] > featVal)[0], :]
    return matL, matR


def regLeaf(dataSet):
    """Leaf builder: mean of the target column (regression trees)."""
    return np.mean(dataSet[:, -1])


def maxLeaf(dataSet):
    """Leaf builder: majority class (classification trees)."""
    results = uniqueCount(dataSet)
    return max(results, key=results.get)


def uniqueCount(dataMat):
    """Count the number of samples in each class (last column).

    Expects matrix-type data; with another container the iteration over
    "dataMat[:, -1].T.tolist()[0]" must be adapted.
    """
    results = {}
    for sample in dataMat[:, -1].T.tolist()[0]:
        if sample not in results:
            results[sample] = 0
        results[sample] += 1
    return results


def regErr(dataSet):
    """Error measure: total squared error (regression trees)."""
    var = np.var(dataSet[:, -1])
    m = dataSet.shape[0]
    err = m * var
    return err


def entErr(dataSet):
    """Error (impurity) measure: Shannon entropy."""
    results = uniqueCount(dataSet)
    sampleNum = dataSet.shape[0]
    shannonEnt = 0.0
    for key in results:
        prob = float(results[key]) / sampleNum
        shannonEnt -= prob * np.log2(prob)
    return shannonEnt


def giniErr(dataSet):
    """Error (impurity) measure: Gini impurity."""
    sampleNum = dataSet.shape[0]
    results = uniqueCount(dataSet)
    imp = 0.0
    for k1 in results:
        p1 = float(results[k1]) / sampleNum
        for k2 in results:
            if k1 == k2:
                continue
            p2 = float(results[k2]) / sampleNum
            imp += p1 * p2
    return imp


def chooseBestSplit(dataSet, leafType, errType, ops):
    """Pick the best split feature / feature value, with pre-pruning.

    Returns (featureIndex, featureValue), or (None, leafValue) when the
    node should become a leaf instead of being split.
    """
    # pre-pruning parameters: minimum error reduction worth splitting
    # for, and minimum number of samples allowed in a branch
    minErr = ops[0]
    minNum = ops[1]
    # every sample in this branch belongs to the same class -> leaf
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    m, n = dataSet.shape
    # error without splitting
    basicErr = errType(dataSet)
    bestErr = np.inf
    bestInd = 0
    bestVal = 0
    # scan every feature / value pair for the minimum combined error
    for featInd in range(n - 1):
        for featVal in set(dataSet[:, featInd].T.tolist()[0]):
            matL, matR = binSplitDataSet(dataSet, featInd, featVal)
            # pre-pruning: skip splits that leave a too-small branch
            if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
                continue
            newErr = errType(matL) + errType(matR)
            # BUG FIX: compare against the best error so far. The
            # original compared "newErr < basicErr", which kept the
            # LAST improving split instead of the minimum-error one.
            if newErr < bestErr:
                bestInd = featInd
                bestVal = featVal
                bestErr = newErr
    # improvement too small: splitting is not worth it (pre-pruning);
    # this also fires when no valid split was found (bestErr == inf)
    if (basicErr - bestErr) < minErr:
        return None, leafType(dataSet)
    # second guard, not redundant with the loop: handles the corner
    # case where the recorded best split still yields a tiny branch
    matL, matR = binSplitDataSet(dataSet, bestInd, bestVal)
    if (matL.shape[0] < minNum) or (matR.shape[0] < minNum):
        return None, leafType(dataSet)
    return bestInd, bestVal


def creatTree(dataSet, leafType, errType, ops):
    """Recursively build the CART tree as a dict of nested nodes.

    (Name kept as "creatTree" for backward compatibility.)
    """
    # choose the best split feature / feature value
    spInd, spVal = chooseBestSplit(dataSet, leafType, errType, ops)
    # leaf node: chooseBestSplit returned (None, leafValue)
    if spInd is None:
        return spVal
    tree = {}
    tree['spInd'] = spInd
    tree['spVal'] = spVal
    # recurse on the two branches
    matL, matR = binSplitDataSet(dataSet, spInd, spVal)
    tree['left'] = creatTree(matL, leafType, errType, ops)
    tree['right'] = creatTree(matR, leafType, errType, ops)
    return tree


# ---------------------------------------------------------------------
# Post-pruning
# ---------------------------------------------------------------------

def isTree(obj):
    """Return True when the branch is a subtree (a dict node)."""
    return (type(obj).__name__ == 'dict')


def getMean(tree):
    """Collapse a subtree bottom-up into the mean of its leaves.

    BUG FIX: the original returned early from a single branch
    ("return getMean(tree['left'])"), silently discarding the other
    branch; both branches must be collapsed before averaging.
    """
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    return (tree['left'] + tree['right']) / 2


def regPrune(tree, testData):
    """Recursive post-pruning against a held-out test set.

    The test set should ideally be of comparable size to the training
    set.  This pruning style suits continuous targets only (collapsing
    to a mean does not make sense for fixed class labels).
    """
    # no test samples reach this node: collapse it entirely
    if testData.shape[0] == 0:
        return getMean(tree)
    # recurse until both children of this node are leaves
    if (isTree(tree['left'])) or (isTree(tree['right'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = regPrune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = regPrune(tree['right'], rSet)
    # both children are leaves: decide whether merging reduces error
    if (not isTree(tree['left'])) and (not isTree(tree['right'])):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # squared error on the test set without merging
        notMergeErr = sum(np.power(lSet[:, -1] - tree['left'], 2)) + \
            sum(np.power(rSet[:, -1] - tree['right'], 2))
        treeMerge = (tree['left'] + tree['right']) / 2
        mergeErr = sum(np.power(testData[:, -1] - treeMerge, 2))
        if mergeErr < notMergeErr:
            print("Merging!")
            return treeMerge
        else:
            return tree
    # not all children are leaves: no merge at this node
    else:
        return tree


def outJudge(dataSet, tree):
    """Classify out-of-sample data and write the results to Excel."""
    outputData = pd.DataFrame(dataSet)
    classResults = []
    for ii in range(dataSet.shape[0]):
        result = judgeType(dataSet[ii, :].A[0], tree)
        classResults.append(result)
    outputData['classResults'] = classResults
    # NOTE(review): the "encoding" kwarg of to_excel was removed in
    # recent pandas versions -- confirm against the pinned pandas
    outputData.to_excel('./data/machine_learning/mytree.xlsx',
                        index=False, encoding='utf-8-sig')
    print("样本外数据分类(判断)完成!")


def judgeType(data, tree):
    """Recursively walk the tree and return the leaf value for data."""
    spInd = tree['spInd']
    spVal = tree['spVal']
    if data[spInd] <= spVal:
        # subtree -> recurse, leaf -> return its value
        if isTree(tree['left']):
            return judgeType(data, tree['left'])
        return tree['left']
    else:
        if isTree(tree['right']):
            return judgeType(data, tree['right'])
        return tree['right']


def treeCart(trainDataPath, outDataPath='', testDataPath='',
             leafType=regLeaf, errType=regErr, ops=(1, 4),
             prune=False, fileType='txt'):
    """Main entry point.

    trainDataPath -- path of the training data.
    outDataPath   -- path of the out-of-sample data (optional).
    testDataPath  -- path of the test data, required when pruning.
    leafType      -- leaf-building function (regLeaf / maxLeaf).
    errType       -- error function (regErr / entErr / giniErr).
    ops           -- pre-pruning parameters: (min error reduction,
                     min samples per branch).
    prune         -- whether to run post-pruning.
    fileType      -- data file type (xlsx / txt / csv); txt files must
                     be tab-separated.
    """
    dataMat = loadData(trainDataPath, fileType)
    try:
        myTree = creatTree(dataMat, leafType, errType, ops)
        if prune:
            testData = loadData(testDataPath, fileType)
            myTree = regPrune(myTree, testData)
        print('决策树构建完成!')
        print(myTree)
        # prediction (classification) on out-of-sample data
        if outDataPath != '':
            outData = loadData(outDataPath, fileType)
            outJudge(outData, myTree)
    # BUG FIX: was a bare "except:", which also swallowed SystemExit
    # and KeyboardInterrupt
    except Exception:
        print("检查是否正确输入参数!")
        print('请在函数treeCart中输入叶节点创建方式参数:\n\t1、按平均值创建:leafType=regLeaf\n\t2、按最多样本创建:leafType=maxLeaf')
        print('请在treeCart中输入误差计算方式参数:\n\t1、香农熵:errType=entErr\n\t2、基尼不纯度:errType=giniErr\n\t3、平方误差:regErr')
        print('请在treeCart中输入预剪枝参数ops:\n\t其中第一个元素表示能忽略的最小误差,第二个元素表示当某分支下样本数小于该元素时,不考虑建立该分支')
        print('示例:treeCart(trainDataPath,leafType=regLeaf,errType=regErr,ops=(1,4))')

更多推荐

机器学习笔记——决策树(CART方法)

本文发布于:2024-03-12 03:29:56,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1730593.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:学习笔记   机器   方法   决策树   CART

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!