程序员的自我修养
Home » Apache Spark, 机器学习 » Spark MLlib之决策树(上)

Spark MLlib之决策树(上)

8条评论16,968次浏览

决策树

决策树是常用的分类算法之一,其对于探索式的知识发现往往有较好的表现。决策树原理十分简单,可处理大维度的数据,不用预先对模型的特征有所了解,这些特性使得决策树被广泛使用。决策树采用贪心算法,其建立过程同样需要训练数据。决策树算法有ID3、在ID3基础上发展起来的C4.5,以及C4.5的商业化版本C5.0,C5.0核心与C4.5相同,只是在执行效率和内存使用方面有所改进。

决策树的核心问题是决策树分支准则的确定,以及分裂点的确定。为了直观起见,推荐大家玩一个游戏:通过20个问题来猜出你心中所想的那个人

初次接触这个游戏的你是否觉得十分神奇,在20个不到的问题里真的就能猜出你心中所想的那个人,不论是你的女朋友、父母或者动漫人物、歌手、演员甚至是政界人物。其实仔细想想,一个20层的二叉树最后的叶子节点有多少个?1024*1024个,而我们能想到的人绝对是超不出这个数量的。这个网站的具体算法就是采用的类似决策树的算法,通过一个个问题来减少候选的数据,直至找出你所想的那个人。

多玩几次你就会发现,一般第一个或前几个问题就会问你:你描述的对象是男(女)性吗?这意味着什么,意味着第一个问题就能将候选数据减少一半左右。因为你想的那个人,除了男人就是女人了。这就是前面所说的决策树分支准则的确定。若将这个问题放在最后几个问题中,毫无疑问是个吃力不讨好的事情。那么如何才能将这些众多属性(如:性别、高矮、胖瘦、头发长短、是否是歌手、是否有money等)按照其重要程度来排个顺序,这就是ID3和C4.5算法所做的事情了。

预备知识

信息熵是信息论中的基本概念。信息论是C.E.Shannon于1948年提出并由此发展起来的,主要用于解决信息传递过程中的问题,也称为统计通信理论。信息论认为:信息是用来消除随机不确定性的,信息量的大小可由所消除的不确定大小来计量。详细了解

信息量的数学定义为:

\(I(u_i)=-log_2P(u_i)\)

其中 \(P(u_i)\) 为信息 \(u_i\) 发生的概率。信息熵是信息量的数学期望,是信源发出信息前的平均不确定性,也成为先验熵,信息熵的数学定义为:

\(Ent(U)=-\sum_iP(u_i)log_2P(u_i)\)

当已知信号 \(U\) 的概率分布 \(P(U)\) 且收到信号 \(V=v_i\) 后,发出信号的概率分布变为 \(P(U|v_j)\) ,于是信源的平均不确定性变为(也称为条件熵):

\(Ent(U|v_i)=-\sum_iP(u_i|v_i)log_2P(u_i|v_i)\)

一般来说, \(Ent(U|v_i) < Ent(U)\) ,于是定义信息增益为:

\(Gains(U,V)=Ent(U)-Ent(U|V)\)

ID3

ID3算法的主要思想就是每次计算出各个属性的信息增益,选择最大者为分裂属性。下面举例说明,为简单起见,随机杜撰了10条数据,分为2个维度:

性别(T1) 1 0 1 0 1 0 1 0 1 1
套餐类别(T2) A B A A C C B A A C
是否购买 true false true false true true false true true true

信息熵

根据公式,信息熵计算方式如下:

\(Ent(U)=-\sum_iP(u_i)log_2P(u_i)=-{7\over 10}log_2({7\over 10})-{3\over 10}log_2({3\over 10})=0.881\)

条件熵

\(Ent(U|T_1)={6\over 10}(-{5\over 6}log_2({5\over 6})-{1\over 6}log_2({1\over 6}))+{4\over 10}(-{2\over 4}log_2({2\over 4})-{2\over 4}log_2({2\over 4}))=0.790\)

\(Ent(U|T_2)={5\over 10}(-{4\over 5}log_2({4\over 5})-{1\over 5}log_2({1\over 5}))+{2\over 10}(-{2\over 2}log_2({2\over 2})-{0\over 2}log_2({0\over 2}))+{3\over 10}(-{3\over 3}log_2({3\over 3})-{0\over 3}log_2({0\over 3}))=0.361\)

信息增益

\(Gains(U,T_1)=Ent(U)-Ent(U|T_1)=0.091\)

\(Gains(U,T_2)=Ent(U)-Ent(U|T_2)=0.520\)

根据ID3的算法,目前来说这种情况下将会选择T2作为最佳分组变量,因为它消除信宿对信源的平均不确定性的能力最强。

C4.5

C4.5算法主要为了解决ID3算法中的一些问题。例如当类别值多的输入变量比类别值少的输入变量有更多的机会成为当前最佳分组变量。将上述数据中的T2维度中的A类别划分为2个子类别A1和A2,输入如下:

性别(T1) 1 0 1 0 1 0 1 0 1 1
套餐类别(T2) A1 B A2 A1 C C B A2 A1 C
是否购买 true false true false true true false true true true

再来计算T2的条件熵:

\(Ent(U|T_2)={3\over 10}(-{2\over 3}log_2({2\over 3})-{1\over 3}log_2({1\over 3}))+{2\over 10}(-{2\over 2}log_2({2\over 2})-{0\over 2}log_2({0\over 2}))+{2\over 10}(-{2\over 2}log_2({2\over 2})-{0\over 2}log_2({0\over 2}))+{3\over 10}(-{3\over 3}log_2({3\over 3})-{0\over 3}log_2({0\over 3}))=0.275\)

新的信息增益:

\(Gains(U,T_2)=Ent(U)-Ent(U|T_2)=0.606\)

可以看到比调整T2类别之前的信息增益要大了。为了消除这种不公平的现象,C4.5采用新的方式来定义信息增益:

\(GainsR(U,V)={Gains(U,V)/Ent(V)}\)

使用该公式来计算改变前和改变后T2的信息增益率:

\(GainsR(U,T_2)={Gains(U,T_2)/Ent(T_2)}=0.520\div (-{5\over 10}log_2({5\over 10})-{2\over 10}log_2({2\over 10})-{3\over 10}log_2({3\over 10}))=0.350\)

\(GainsR(U,T_2')={Gains(U,T_2')/Ent(T_2')}=0.606\div (-{3\over 10}log_2({3\over 10})-{2\over 10}log_2({2\over 10})-{2\over 10}log_2({2\over 10})-{3\over 10}log_2({3\over 10}))=0.307\)

可以看到T2改变后的信息增益率小于改变前的。

Spark实现决策树

训练数据

附件下载:sample_tree_data

数据格式:第一列为标签,表示类别;后面的列为维度。

代码实现

最后

目前Spark还没有Java版本的决策树,后面版本中可能会添加进来。Scala版本的决策树我这边运行起来报了一个这样的错:java.lang.NoSuchMethodError: scala.collection.Iterator.aggregate。目前对Scala还只是初步了解,不清楚出错的具体原因,看信息貌似是scala的iterator没有aggregate方法,查看scala的源码也没有找到该方法。最后尝试从2.8-2.11.2版本的Scala,发现都不能正确输出。

关于如何寻找分裂点不属于ID3和C4.5算法的范畴,还有决策树的裁剪也是一个比较重要的部分,希望以后有机会补充。

参考文献

(转载本站文章请注明作者和出处 程序员的自我修养 – SelfUp.cn ,请勿用于任何商业用途)
8条评论
  1. soso说道:

    尝试一下用2.10.X的scala跑一下

  2. 匿名说道:

    将numberClasses的值设置为大于实际的类别的数就可以了

  3. 匿名说道:

    你好,运行你的程序一直报空指针(数据也是用的spark中的数据),报错行为:
    val prediction = model.predict(point.features)
    推测为model为空指针,但是如果
    val labelAndPreds = parsedData.collect().map { point =>
    val feature = point.features
    val prediction = model.predict(feature)
    (point.label, prediction)
    }
    这样就能正常运行

  4. 匿名说道:

    赞一个

  5. EVIL说道:

    我在运行完C4.5的代码后,显示
    defined object DecisionTreeTest
    是什么意思?这是有错误吗?运行结果在哪里看?

发表评论


profile
  • 文章总数:79篇
  • 评论总数:254条
  • 分类总数:31个
  • 标签总数:44个
  • 运行时间:1192天

大家好,欢迎来到selfup.cn。

这不是一个只谈技术的博客,这里记录我成长的点点滴滴,coding、riding and everthing!

最新评论
  • Anonymous: :arrow: :neutral: :cry:
  • Anonymous: java.io.NotSerializableExcepti on: DStream checkpointing has been enabled but the DStreams with their...
  • wick: HI,请问一下,U,S,V得到后,怎么得到近似矩阵呢(用sp ark java),谢谢。
  • Michael Whitaker: Thank you for this blog, it was very helpful in troubleshooting my own issues. It seems that no...
  • Anonymous: :mad:
  • Anonymous: :???:
  • Anonymous: :mad: :mad: :mad:
  • 洋流: 哥们,我问个问题,你把testOnborrow去掉了。。如果 得到的jedis资源是个不可用的,服务从来都不出问题么?
  • 洋流: 哥们,我问个问题,你把testOnborrow去掉了。。如果 得到的jedis资源是个不可用的,服务从来都不出问题么?
  • Anonymous: :razz: :evil: :grin:
  • 张瑞昌: 有很多,比较常见的是Jacob迭代法,一次迭代O(n^3), 迭代次数不清楚。 如果是手动算的话按照定义求就可以了
  • Anonymous: :mrgreen:
  • lc277: 你好 我想问下一般删除节点要多久,要删除的datanode大概用了 1t,解除授权已经30多小时还没完成,请问是出现什么问题了吗 麻烦告诉下谢谢 qq1844554123
  • Anonymous: 你好 我想问下一般删除节点要多久,要删除的datanode大概用了 1t,解除授权已经30多小时还没完成,请问是出现什么问题了吗
  • Anonymous: :smile: :grin: :eek:
  • 李雪璇: 想要完整代码,可以帮忙发给我吗
  • Anonymous: 请问一下,那个 user的推荐结果楼主查看了么? 为什么输入数据 最高是五分,输出结果都是7分8分啥的?怎么设置输出的分数的最 大值?
  • Anonymous: 那个 user的推荐结果楼主查看了么? 为什么输入数据 最高是五分,输出结果都是7分8分啥的?
  • Anonymous: stopGracefullyOnShutdown在yarn- client模式下我测试的无效,你的呢
  • Anonymous: 另外,import的lib包能否发个列表.