机器学习2——决策树

蚀窗接口
• 阅读 2104
决策树(Decision Tree)是一种十分常用的分类方法,作为一个预测模型,决策树表示对象属性与对象值之间的一种映射关系。

1. 信息熵和信息增益

1.1 信息熵

公式表示为:机器学习2——决策树其中S表示样本集,c表示样本集合中类别个数,Pi表示第i个类别的概率。

  • 信息熵的意思就是一个变量i(就是这里的类别)可能的变化越多(只和值的种类多少以及发生概率有关),它携带的信息量就越大(因为是相加累计),即类别变量i的信息熵越大。
  • 二分类问题中,当X的概率P(X)为0.5时,表示变量的不确定性最大,此时的熵达到最大值1。
  • 信息熵反映系统的确定程度:信息熵越低,系统越确定;信息熵越高,系统越不确定

1.2 条件熵

公式表示为:机器学习2——决策树其中ti表示属性T的取值。条件熵的直观理解:假设单独计算明天下雨的信息熵:H(Y)=2,而在已知今天阴天情况下计算明天下雨的条件熵:H(Y|X)=0.5(熵变小,确定性变大,明天下雨的概率变大,信息量减少),这样相减后为1.5,在获得阴天这个信息后,下雨信息不确定性减少了1.5,信息增益很大,所以今天是否时阴天这个特征信息X对明天下雨这个随机变量Y的来说是很重要的。

1.3 信息增益

公式表示为:机器学习2——决策树
信息增益考察某个特征对整个系统的贡献。

2. 算法实现

2.0 数据集描述

通过“不浮出水面能否生存 no surfacing” 和 “是否有脚蹼 flippers”来判断5种海洋生物是否属于鱼类。
机器学习2——决策树

2.1 计算信息熵

from math import log

def calcInforEnt(dataSet):
    num = len(dataSet)
    labelCount = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCount.keys():
            labelCount[currentLabel] = 0
        labelCount[currentLabel] += 1  # 统计类别数目,labelCount = {'yes': 2, 'no': 3} 
    inforEnt = 0.0
    for key in labelCount:
        prob = float(labelCount[key]) / num
        inforEnt -= prob * log(prob, 2)
    return inforEnt

测试:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
calcInforEnt(dataSet)  # 0.9709505944546686

2.2 划分数据集

按照给定特征值划分数据集

def splitDataSet(dateSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

测试:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
splitDataSet(dataSet, 0, 1)  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
splitDataSet(dataSet, 0, 0)  # [[1, 'no'], [1, 'no']]

2.3 选择最好的特征划分数据集

遍历整个数据集,循环计算信息熵和splitDataSet()函数,找到最好的特征划分方式。

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 数据集的最后一列表示类标签
    baseEntropy = calcInforEnt(dataSet)
    bestInforGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 取出每个属性的所有值,组成一个数组
        uniqueVals = set(featList)  # 去重
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy +=  prob * calcInforEnt(subDataSet)
        inforGain = bestInforGain - newEntropy
        if inforGain > bestInforGain:
            bestInforGain = inforGain
            bestFeature = i
    return bestFeature

测试:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
chooseBestFeatureToSplit(dataSet)  # 0

2.4 递归构建决策树

工作原理:得到原始数据集,基于最佳的属性划分数据集,由于属性存在两个或以上属性值,因此存在两个或以上的数据分支。第一次划分结束后,数据向下传递到树分支中,每个分支按照条件继续分叉。
递归结束条件:程序遍历完所有划分数据集的属性,或者每一个分支下的实例属于相同分类

import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): 
            classCount[vote] = 0
        classCount[vote] += 1
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]

# 创建树
def ctrateTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0] == len(classList)):
        return classList[0]  # 类型相同,停止划分
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # 遍历结束,返回出现频率最高的特征
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = ctrateTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree 

测试:

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
ctrateTree(dataSet, labels)  # {'no surfacing': {0: 'no', 1: {'flippers':{0: 'no', 1: 'yes'}}}}

2.5 使用决策树进行分类

比较测试数据与决策树上的值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类型。

def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key
            if type(secondDict[key]).__name__ == 'dict':  # 判断类型
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

测试:

inputTree = {'no surfacing': {0: 'no', 1: {'flippers':{0: 'no', 1: 'yes'}}}}
featLabels = ['no surfacing', 'flippers']
classify(inputTree, featLabels, [1, 0])  # 'no'
classify(inputTree, featLabels, [1, 1])  # 'yes'

3. 评价

以上决策树的ID3的实现方式,没有剪枝的步骤,容易发生过度拟合,导致决策树过高。C4.5决策树的改进策略:

  • 用信息增益率来选择属性,克服了用信息增益选择属性偏向选择多值属性的不足
  • 在构造树的过程中进行剪枝,参考剪枝算法
  • 对连续属性进行离散化
  • 能够对不完整的数据进行处理

4. 参考

点赞
收藏
评论区
推荐文章
HelloWorld官方 HelloWorld官方
4年前
C++ 基本语法
C程序可以定义为对象的集合,这些对象通过调用彼此的方法进行交互。现在让我们简要地看一下什么是类、对象,方法、即时变量。对象对象具有状态和行为。例如:一只狗的状态颜色、名称、品种,行为摇动、叫唤、吃。对象是类的实例。类类可以定义为描述对象行为/状态的模板/蓝图。方法从基本上说,一个方法表示一种行为。一个类可以包含多个
Stella981 Stella981
3年前
JavaScript原型深入浅出
不学会怎么处理对象,你在JavaScript道路就就走不了多远。它们几乎是JavaScript编程语言每个方面的基础。事实上,学习如何创建对象可能是你刚开始学习的第一件事。对象是键/值对。创建对象的最常用方法是使用花括号{},并使用点表示法向对象添加属性和方法。letanimal{}animal.name
Stella981 Stella981
3年前
Python 常用的ORM框架简介
ORM概念ORM(ObjectRalationalMapping,对象关系映射)用来把对象模型表示的对象映射到基于SQL的关系模型数据库结构中去。这样,我们在具体的操作实体对象的时候,就不需要再去和复杂的SQL语句打交道,只需简单的操作实体对象的属性和方法。ORM技术是在对象和关系之间提供了一条桥梁,前台的对象型数据和数据
Wesley13 Wesley13
3年前
Java面试题
91,什么是ORM?        对象关系映射(ObjectRelationalMapping,简称ORM)是一种为了解决程序的面向对象模型与数据库的关系模型互不匹配问题的技术;        简单的说,ORM是通过使用描述对象和数据库之间映射的元数据(在Java中可以用XML或者是注解),将程序中的对象自动持久化到关系数据库中或者将
Stella981 Stella981
3年前
JPA、Hibernate、Spring data jpa之间的关系,终于明白了
什么么是JPA?全称JavaPersistenceAPI,可以通过注解或者XML描述【对象关系表】之间的映射关系,并将实体对象持久化到数据库中。为我们提供了:1)ORM映射元数据:JPA支持XML和注解两种元数据的形式,元数据描述对象和表之间的映射关系,框架据此将实体对象持久化到数据库表中;如:@Entity、@Table、@C
Wesley13 Wesley13
3年前
Java中的Map集合
Map接口简介Map接口是一种双列集合,它的每个元素都包含一个键对象Key和值对象Value,键和值对象之间存在一种对应关系,称为映射。从Map集合中访问元素时,只要指定了Key,就能找到对应的Value,Map中的键必须是唯一的,不能重复,如果存储了相同的键,后存储的值会覆盖原有的值,简而言之就是键相同,值覆盖。Map常用
Stella981 Stella981
3年前
LightGBM 算法原理
LightGBM的动机GBDT(GradientBoostingDecisionTree)是机器学习中一个长盛不衰的模型,其主要思想是利用弱分类器(决策树)迭代训练以得到最优模型,该模型具有训练效果好、不易过拟合等优点。GBDT在工业界应用广泛,通常被用于点击率预测,搜索排序等任务而GBDT在每一次迭代的时
Wesley13 Wesley13
3年前
KNN分类算法原理分析及代码实现
1、分类与聚类的概念与区别分类:是从一组已知的训练样本中发现分类模型,并且使用这个分类模型来预测待分类样本。目前常用的分类算法主要有:朴素贝叶斯分类算法(NaïveBayes)、支持向量机分类算法(SupportVectorMachines)、KNN最近邻算法(kNearestNeighbors)、神经网络算法(NNet)以及决策树(De
Stella981 Stella981
3年前
GPU上的随机森林:比Apache Spark快2000倍
作者|AaronRichter编译|VK来源|TowardsDataScience随机森林是一种机器学习算法,以其鲁棒性、准确性和可扩展性而受到许多数据科学家的信赖。该算法通过bootstrap聚合训练出多棵决策树,然后通过集成对输出进行预测。由于其集成特征的特点,随机森林是一种可以在分布式计算环境中实现的算法。树可以在集群中跨进程和机器并
大数据——决策树(decision tree)
大数据————决策树(decisiontree)决策树(decisiontree):是一种基本的分类与回归方法,主要讨论分类的决策树。在分类问题中,表示基于特征对实例进行分类的过程,可以认为是ifthen的集合,也可以认为是定义在特征空间