Python编写决策树算法

编程入门 行业动态 更新时间:2024-10-27 22:19:31

Python编写决策树<a href=https://www.elefans.com/category/jswz/34/1770096.html style=算法"/>

Python编写决策树算法

main.py

import numpy as np
# pickle用于进行序列化与反序列化
# 序列化过程将文本信息转变为二进制数据流。这样就信息就容易存储在硬盘之中,
# 当需要读取文件的时候,从硬盘中读取数据,然后再将其反序列化便可以得到原始的数据。
import pickle
import os
import treePlotter# 创建训练数据
def CreateTrainingDataset():X = [[0, 2, 0, 0, 'N'],[0, 2, 0, 1, 'N'],[1, 2, 0, 0, 'Y'],[2, 1, 0, 0, 'Y'],[2, 0, 1, 0, 'Y'],[2, 0, 1, 1, 'N'],[1, 0, 1, 1, 'Y'],[0, 1, 0, 0, 'N'],[0, 0, 1, 0, 'Y'],[2, 1, 1, 0, 'Y'],[0, 1, 1, 1, 'Y'],[1, 1, 0, 1, 'Y'],[1, 2, 1, 0, 'Y'],[2, 1, 0, 1, 'N']]attributeList = ["age", "income", "student", "credit_rating"]return X, attributeList# 创建测试数据
def CreateTestDataset():X = [[0, 1, 0, 0],[0, 2, 1, 0],[2, 1, 1, 0],[0, 1, 1, 1],[1, 1, 0, 1],[1, 0, 1, 0],[2, 1, 0, 1]]attributeList = ["age", "income", "student", "credit_rating"]return X, attributeList# 计算类别的统计信息
def GetClassInfo(Dataset):    # 例如{'Y': 10, 'N':5}classInfo = {}for item in Dataset:if item[-1] not in classInfo.keys():classInfo[item[-1]] = 1else:classInfo[item[-1]] += 1classInfo = dict(sorted(classInfo.items(), key=lambda x: x[1], reverse=True))return classInfo# 计算最大占比类
def CalMostClass(classInfo):maxClass = list(classInfo.keys())[0]return maxClass# 计算数据集的信息熵
def ComputeEntropy(Dataset):ClassInfo = GetClassInfo(Dataset)entropy = 0amount = 0p = []  # p[]存放的是第k个类的数据个数for _, val in ClassInfo.items():p.append(val)amount += valfor pk in p:entropy -= (pk / amount) * np.log2(pk / amount)return entropy# 计算数据集在某个属性上的的信息增益Gain(attributeList)
# Gain(D, a)
def computeAttrGainNPartition(Dataset, attributeIndex):gain = ComputeEntropy(Dataset)  # Initialize:初始化等于数据集D的信息熵# 按属性的值划分数据集子集LEN_DATASET = len(Dataset)# attributePartition = {"attrVal1": [[], [] ,.., []], ..., "attrValn": [[], [] ,.., []]}attributePartition = {}for dataItem in Dataset:if dataItem[attributeIndex] not in attributePartition.keys():attributePartition[dataItem[attributeIndex]] = []attributePartition[dataItem[attributeIndex]].append(dataItem)else:attributePartition[dataItem[attributeIndex]].append(dataItem)amount = 0lenth = []Ent = []# 计算信息增益for key, valDataSet in attributePartition.items():Ent.append(ComputeEntropy(valDataSet))lenth.append(len(valDataSet))amount += len(valDataSet)for i in range(len(Ent)):gain -= (lenth[i] / LEN_DATASET) * Ent[i]return gain, attributePartition# 建决策树
def CreateDecisionTree(Dataset, attributeList):attrList = attributeListTree = {}classInfo = GetClassInfo(Dataset)LEN_DATASET = len(Dataset)# 建立叶子节点情况1:给定的属性集为空 ---- 不能划分if len(attributeList) == 0:return CalMostClass(classInfo)# 建立叶子节点情况2:给定的数据集所有label都相同 ---- 无需划分for key, valLen in classInfo.items():if valLen == LEN_DATASET:return keybreak# 建立叶子节点情况3:样本在属性集上取值都相等 ---- 无法划分temp = Dataset[0][:-1]sameCnt = 0for dataItem in Dataset:if temp == dataItem[:-1]:sameCnt += 1if sameCnt == LEN_DATASET:return CalMostClass(classInfo)# 选择最佳划分属性theBestAttrIndex = 0theBestAttrGain = 0theBestAttrPartition = {}for attributeIndex in range(len(attributeList)):gain, attributePartition = computeAttrGainNPartition(Dataset, attributeIndex)if gain > theBestAttrGain:theBestAttrGain = gaintheBestAttrIndex = attributeIndextheBestAttrPartition = attributePartitionattrName = attributeList[theBestAttrIndex]# python的list对象按索引删除对象,使用的是del()函数del (attributeList[theBestAttrIndex])# # 为了方便后面建子树,将此时的attr对应的那列去除for key, valList in theBestAttrPartition.items():for index in range(len(valList)):temp = valList[index][:theBestAttrIndex]temp.extend(valList[index][theBestAttrIndex + 1:])valList[index] = temp# 根据属性的值,建立分叉节点Tree[attrName] = {}for keyAttrVal, valDataset in theBestAttrPartition.items():# 因为python对iterable list对象的传参是按地址传参,会改变attributeList的值# 所以在传attributeList参数的时候,创建一个副本,就相当于按值传递了subLabels = attributeList[:]# valDataset是已去除attr的data,attributeList是已去除attr的attributeListTree[attrName][keyAttrVal] = CreateDecisionTree(valDataset, subLabels)return Tree# 测试做分类
def Predict(DataSet, testArrtList, decisionTree):predicted_label = []for dataItem in DataSet:cur_decisionTree = decisionTree# 如果root就是叶子结点leafif type(cur_decisionTree) == set:   # 例如:{'N'}node = list(cur_decisionTree)else:node = list(cur_decisionTree.keys())[0]# 只要temp处在attributeList,说明当前处在树枝结点(非叶子)上, 否则处在叶子结点while node in testArrtList:cur_index = testArrtList.index(node)  # 0 2cur_element = dataItem[cur_index]  # 0 0cur_decisionTree = cur_decisionTree[node][cur_element]  # {'student': {0: 'N', 1: 'Y'}} Nif type(cur_decisionTree) == dict:node = list(cur_decisionTree.keys())[0]  # studentelse:node = cur_decisionTreepredicted_label.append(node)return predicted_label# 将模型保存起来
def SaveModel(decisionTree, filename):# 由于pickle是将文本序列化成binary文件,故需用wbf = open(filename, 'wb')pickle.dump(decisionTree, f)# 读取模型
def LoadModel(filename):# 由于pickle读取的是binary文件,故需用rbf = open(filename, 'rb')return pickle.load(f)if __name__ == '__main__':base = os.path.dirname(os.path.abspath(__file__))trainingDataset, attributeList = CreateTrainingDataset()testDataset, testArrtList = CreateTestDataset()path = base + "/DecisionTreeModel.txt"print(path)# 建决策树decisionTree = CreateDecisionTree(trainingDataset, attributeList)# 保存模型SaveModel(decisionTree, path)# 读取模型model = LoadModel(path)print(model)# 对测试数据进行预测labelresult = Predict(testDataset, testArrtList ,model)print(result)treePlotter.createPlot(model)# 链接:   密码: uds1

treePlotter.py

  • 这个程序调用的是别人的程序,遗憾的是找不到出处了
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2019/1/28 下午 09:02
# @Author  : YuXin Chenimport matplotlib.pyplot as pltdecisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \xytext=centerPt, textcoords='axes fraction', \va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)def getNumLeafs(myTree):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else:numLeafs += 1return numLeafsdef getTreeDepth(myTree):maxDepth = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = getTreeDepth(secondDict[key]) + 1else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepthdef plotMidText(cntrPt, parentPt, txtString):xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalwplotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalDdef createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalw = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.xOff = -0.5 / plotTree.totalwplotTree.yOff = 1.0plotTree(inTree, (0.5, 1.0), '')plt.show()

更多推荐

Python编写决策树算法

本文发布于:2023-07-28 19:46:25,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1291685.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:算法   决策树   Python

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!