决策树
决策树是常用的分类算法之一,其对于探索式的知识发现往往有较好的表现。决策树原理十分简单,可处理大维度的数据,不用预先对模型的特征有所了解,这些特性使得决策树被广泛使用。决策树采用贪心算法,其建立过程同样需要训练数据。决策树算法有ID3、在ID3基础上发展起来的C4.5,以及C4.5的商业化版本C5.0,C5.0核心与C4.5相同,只是在执行效率和内存使用方面有所改进。
决策树的核心问题是决策树分支准则的确定,以及分裂点的确定。为了直观起见,推荐大家玩一个游戏:通过20个问题来猜出你心中所想的那个人。
初次接触这个游戏的你是否觉得十分神奇,在20个不到的问题里真的就能猜出你心中所想的那个人,不论是你的女朋友、父母或者动漫人物、歌手、演员甚至是政界人物。其实仔细想想,一个20层的二叉树最后的叶子节点有多少个?1024*1024个,而我们能想到的人绝对是超不出这个数量的。这个网站的具体算法就是采用的类似决策树的算法,通过一个个问题来减少候选的数据,直至找出你所想的那个人。
多玩几次你就会发现,一般第一个或前几个问题就会问你:你描述的对象是男(女)性吗?这意味着什么,意味着第一个问题就能将候选数据减少一半左右。因为你想的那个人,除了男人就是女人了。这就是前面所说的决策树分支准则的确定。若将这个问题放在最后几个问题中,毫无疑问是个吃力不讨好的事情。那么如何才能将这些众多属性(如:性别、高矮、胖瘦、头发长短、是否是歌手、是否有money等)按照其重要程度来排个顺序,这就是ID3和C4.5算法所做的事情了。
预备知识
信息熵是信息论中的基本概念。信息论是C.E.Shannon于1948年提出并由此发展起来的,主要用于解决信息传递过程中的问题,也称为统计通信理论。信息论认为:信息是用来消除随机不确定性的,信息量的大小可由所消除的不确定大小来计量。详细了解。
信息量的数学定义为:
\(I(u_i)=-log_2P(u_i)\)
其中 \(P(u_i)\) 为信息 \(u_i\) 发生的概率。信息熵是信息量的数学期望,是信源发出信息前的平均不确定性,也成为先验熵,信息熵的数学定义为:
\(Ent(U)=-\sum_iP(u_i)log_2P(u_i)\)
当已知信号 \(U\) 的概率分布 \(P(U)\) 且收到信号 \(V=v_i\) 后,发出信号的概率分布变为 \(P(U|v_j)\) ,于是信源的平均不确定性变为(也称为条件熵):
\(Ent(U|v_i)=-\sum_iP(u_i|v_i)log_2P(u_i|v_i)\)
一般来说, \(Ent(U|v_i) < Ent(U)\) ,于是定义信息增益为:
\(Gains(U,V)=Ent(U)-Ent(U|V)\)
ID3
ID3算法的主要思想就是每次计算出各个属性的信息增益,选择最大者为分裂属性。下面举例说明,为简单起见,随机杜撰了10条数据,分为2个维度:
性别(T1) | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 1 |
套餐类别(T2) | A | B | A | A | C | C | B | A | A | C |
是否购买 | true | false | true | false | true | true | false | true | true | true |
信息熵
根据公式,信息熵计算方式如下:
\(Ent(U)=-\sum_iP(u_i)log_2P(u_i)=-{7\over 10}log_2({7\over 10})-{3\over 10}log_2({3\over 10})=0.881\)
条件熵
\(Ent(U|T_1)={6\over 10}(-{5\over 6}log_2({5\over 6})-{1\over 6}log_2({1\over 6}))+{4\over 10}(-{2\over 4}log_2({2\over 4})-{2\over 4}log_2({2\over 4}))=0.790\)
\(Ent(U|T_2)={5\over 10}(-{4\over 5}log_2({4\over 5})-{1\over 5}log_2({1\over 5}))+{2\over 10}(-{2\over 2}log_2({2\over 2})-{0\over 2}log_2({0\over 2}))+{3\over 10}(-{3\over 3}log_2({3\over 3})-{0\over 3}log_2({0\over 3}))=0.361\)
信息增益
\(Gains(U,T_1)=Ent(U)-Ent(U|T_1)=0.091\)
\(Gains(U,T_2)=Ent(U)-Ent(U|T_2)=0.520\)
根据ID3的算法,目前来说这种情况下将会选择T2作为最佳分组变量,因为它消除信宿对信源的平均不确定性的能力最强。
C4.5
C4.5算法主要为了解决ID3算法中的一些问题。例如当类别值多的输入变量比类别值少的输入变量有更多的机会成为当前最佳分组变量。将上述数据中的T2维度中的A类别划分为2个子类别A1和A2,输入如下:
性别(T1) | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 1 |
套餐类别(T2) | A1 | B | A2 | A1 | C | C | B | A2 | A1 | C |
是否购买 | true | false | true | false | true | true | false | true | true | true |
再来计算T2的条件熵:
\(Ent(U|T_2)={3\over 10}(-{2\over 3}log_2({2\over 3})-{1\over 3}log_2({1\over 3}))+{2\over 10}(-{2\over 2}log_2({2\over 2})-{0\over 2}log_2({0\over 2}))+{2\over 10}(-{2\over 2}log_2({2\over 2})-{0\over 2}log_2({0\over 2}))+{3\over 10}(-{3\over 3}log_2({3\over 3})-{0\over 3}log_2({0\over 3}))=0.275\)
新的信息增益:
\(Gains(U,T_2)=Ent(U)-Ent(U|T_2)=0.606\)
可以看到比调整T2类别之前的信息增益要大了。为了消除这种不公平的现象,C4.5采用新的方式来定义信息增益:
\(GainsR(U,V)={Gains(U,V)/Ent(V)}\)
使用该公式来计算改变前和改变后T2的信息增益率:
\(GainsR(U,T_2)={Gains(U,T_2)/Ent(T_2)}=0.520\div (-{5\over 10}log_2({5\over 10})-{2\over 10}log_2({2\over 10})-{3\over 10}log_2({3\over 10}))=0.350\)
\(GainsR(U,T_2')={Gains(U,T_2')/Ent(T_2')}=0.606\div (-{3\over 10}log_2({3\over 10})-{2\over 10}log_2({2\over 10})-{2\over 10}log_2({2\over 10})-{3\over 10}log_2({3\over 10}))=0.307\)
可以看到T2改变后的信息增益率小于改变前的。
Spark实现决策树
训练数据
1 2 3 4 5 |
1 17.99 10.38 122.8 1001 0.1184 0.2776 0.3001 0.1471 0.2419 0.07871 1.095 0.9053 8.589 153.4 0.006399 0.04904 0.05373 0.01587 0.03003 0.006193 25.38 17.33 184.6 2019 0.1622 0.6656 0.7119 0.2654 0.4601 1 20.57 17.77 132.9 1326 0.08474 0.07864 0.0869 0.07017 0.1812 0.05667 0.5435 0.7339 3.398 74.08 0.005225 0.01308 0.0186 0.0134 0.01389 0.003532 24.99 23.41 158.8 1956 0.1238 0.1866 0.2416 0.186 0.275 1 19.69 21.25 130 1203 0.1096 0.1599 0.1974 0.1279 0.2069 0.05999 0.7456 0.7869 4.585 94.03 0.00615 0.04006 0.03832 0.02058 0.0225 0.004571 23.57 25.53 152.5 1709 0.1444 0.4245 0.4504 0.243 0.3613 1 11.42 20.38 77.58 386.1 0.1425 0.2839 0.2414 0.1052 0.2597 0.09744 0.4956 1.156 3.445 27.23 0.00911 0.07458 0.05661 0.01867 0.05963 0.009208 14.91 26.5 98.87 567.7 0.2098 0.8663 0.6869 0.2575 0.6638 ... |
附件下载:sample_tree_data
数据格式:第一列为标签,表示类别;后面的列为维度。
代码实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Gini /** * Created with IntelliJ IDEA. * User: He Qi * Date: 14-8-28 * Time: 10:14 */ object DecisionTreeTest extends App { val sparkConf = new SparkConf().setAppName("DecisionTree").setMaster("local[2]") val sc = new SparkContext(sparkConf) val data = sc.textFile("/home/yurnom/data/sample_tree_data.csv") val parsedData = data.map { line => val parts = line.split('\t').map(_.toDouble) LabeledPoint(parts(0), Vectors.dense(parts.tail)) } val maxDepth = 5 val model = DecisionTree.train(parsedData, Classification, Gini, maxDepth) val labelAndPreds = parsedData.map { point => val prediction = model.predict(point.features) (point.label, prediction) } val trainErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / parsedData.count println("Training Error = " + trainErr) } |
最后
目前Spark还没有Java版本的决策树,后面版本中可能会添加进来。Scala版本的决策树我这边运行起来报了一个这样的错:
关于如何寻找分裂点不属于ID3和C4.5算法的范畴,还有决策树的裁剪也是一个比较重要的部分,希望以后有机会补充。
asd
尝试一下用2.10.X的scala跑一下
将numberClasses的值设置为大于实际的类别的数就可以了
你好,运行你的程序一直报空指针(数据也是用的spark中的数据),报错行为:
val prediction = model.predict(point.features)
推测为model为空指针,但是如果
val labelAndPreds = parsedData.collect().map { point =>
val feature = point.features
val prediction = model.predict(feature)
(point.label, prediction)
}
这样就能正常运行
赞一个
我在运行完C4.5的代码后,显示
defined object DecisionTreeTest
是什么意思?这是有错误吗?运行结果在哪里看?