决策树是机器学习中一种简单明了的分类算法,用程序语言描述就是if...elif...else...,关键问题则是如何选择合适的特征对数据集进行切割,常见算法有: ID3、C4.5、CART等。

今天主要记录一下ID3这个算法,想使用这个算法首先要了解信息增益,想了解信息增益则要先明白什么是”熵”。熵描述了一个系统的混乱复杂程度,有一个理论叫做”熵增加”,含义就是一个没有外力干涉的系统混乱程度总是增加的,比如一个房间如果没人打扫的话只会越来越混乱,而不会自己变得整洁。

计算熵的公式如下:

$$H=-\sum_{i=i}^{n}P(x_i)log_2P(x_i)$$

其中\(P(x_i)\) 表示P发生的概率。

举个栗子,比如我们有下面这些数据:

饮料肉类水果闹肚子
牛奶牛肉香蕉
可乐鸭肉苹果
可乐鸡肉香蕉
牛奶猪肉苹果
咖啡鱼肉香蕉

上面记录了食物和是否闹肚子之间的关系,那么闹肚子的概率2/5不闹肚子的概率是3/5,所以整个样本的熵就是:

$$-{2 \over 5}log_2{2 \over 5} - {3 \over 5}log_2{3 \over 5}\approx0.971$$

当水果吃香蕉的时候,闹肚子的概率为2/3不闹肚子的概率是1/3,所以当吃香蕉时候条件熵为: $$-{2 \over 3}log_2{2 \over 3} - {1 \over 3}log_2{1 \over 3}\approx0.918$$

同理,吃苹果时候从来不闹肚子,所以条件熵为0。

那么在特征”水果”上的信息增益就是:

0.971 - ((3/5)*0.918+(2/5)*0) = 0.42

注意计算信息增益时候需要将 条件熵乘以这个情形发生的概率

同理可以求得特征”饮料”的信息增益:0.571和特征”肉类”的信息增益:0.971。

选择信息增益最大的特征来分割数据,所以选肉类。

计算熵的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from math import log

def create_dataset():
"""构造数据集"""
return [
['milk','beef','banana','N'],
['coca','fish','apple','N'],
['coca','beef','banana','Y'],
['milk','pork','apple','N'],
['coffee','fish','banana','Y'],
],['drink','meat','fruit']


def split_dataset(dataset,col,val):
"""切分col列值为val的数据集"""
res_dataset = []
for each in dataset:
if each[col] == val:
temp = each[:]
temp.remove(val)
res_dataset.append(temp)
return res_dataset


def calc_shannon_ent(dataset):
"""计算熵"""
count = len(dataset)
labels = {}
for each in dataset:
label = each[-1]
if label not in labels.keys():
labels[label] = 0
labels[label] += 1
ent = 0
for one in labels.keys():
prod = labels[one] / count
temp = prod * log(prod, 2)
ent += temp
return -ent


dataset,labels = create_dataset()
print(calc_shannon_ent(dataset))
sub_data = split_dataset(dataset,2,'banana')
print(sub_data)
print(calc_shannon_ent(sub_data))

输出如下:

1
2
3
0.9709505944546686
[['milk', 'beef', 'N'], ['coca', 'chicken', 'Y'], ['coffee', 'fish', 'Y']]
0.9182958340544896

接下来就是计算信息增益:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def choose_feature2split(dataset):
"""返回信息增益最高的特征是第几列"""
best_infogain = 0
best_feature_col = -1
all_count = len(dataset)
base_ent = calc_shannon_ent(dataset) # 整个样本的熵
feature_num = len(dataset[0]) - 1 # 这里有2个前提:数据集中列数相同,且最后一列为标签
for i in range(feature_num):
feat_list = set([each[i] for each in dataset]) # 得到某特征有多少不同的值
ent_temp = 0
for feat in feat_list:
sub_dataset = split_dataset(dataset,i,feat)
prob = len(sub_dataset) / all_count # 计算这种条件发生的概率
ent_temp += prob * calc_shannon_ent(sub_dataset)
tmp_infogain = base_ent - ent_temp
print("col:%s,gain:%s" % (i,tmp_infogain))
if tmp_infogain > best_infogain:
best_infogain = tmp_infogain
best_feature_col = i
return best_feature_col

输出如下:

1
2
3
col:0,gain:0.5709505944546686
col:1,gain:0.9709505944546686
col:2,gain:0.4199730940219749

ID3算法优点就是理解起来比较容易,缺点则是容易造成 过拟合 问题,另外在某些极端情况下,比如某个特征每一行值都独一无二(比如例子中的肉类),这个算法倾向于优先根据此特征划分,效率极差。而且这个算法没法处理连续型数据,比较适合 类别较少的离散数据

也因为如此,所以C4.5算法中使用 信息增益率 替换了信息增益判断,具体细节以后再写。接下来构造树:

1
2
3
4
5
6
7
8
9
10
11
12
13
def create_tree(dataset,labels):
cls_list = [each[-1] for each in dataset] # 前提是最后一列为标签
if len(set(cls_list)) == 1:
# 当标签只剩一种时候返回
return cls_list[0]
best_feat_col = choose_feature2split(dataset)
best_feat_label = labels[best_feat_col]
tree = {best_feat_label:{}}
for val in set([each[best_feat_col] for each in dataset]):
sub_labels = labels[:]
del(sub_labels[best_feat_col])
tree[best_feat_label][val] = create_tree(split_dataset(dataset,best_feat_col,val),sub_labels)
return tree

输出如下:

1
2
3
4
dataset,labels = create_dataset()
tree = create_tree(dataset,labels)

{'meat': {'beef': 'N', 'chicken': 'Y', 'duck': 'N', 'fish': 'Y', 'pork': 'N'}}

可以看出,根据上面示例构造出的决策树仅仅根据meat来决定。

将原始数据集修改一下:

1
2
3
4
5
6
7
8
def create_dataset():
return [
['milk', 'beef', 'banana', 'N'],
['coca', 'fish', 'apple', 'N'],
['coca', 'beef', 'banana', 'Y'],
['milk', 'pork', 'apple', 'Y'],
['coffee', 'fish', 'banana', 'Y'],
],['drink','meat','fruit']

可以得到下面这中决策树:

1
2
3
{'drink': {'coca': {'meat': {'beef': 'Y', 'fish': 'N'}},
'coffee': 'Y',
'milk': {'meat': {'beef': 'N', 'pork': 'Y'}}}}

接下来进行分类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def classify(tree, labels, test_vec):
root_label = list(tree.keys())[0] # py3中keys是一个生成器
sub_tree = tree[root_label]
feat_col = labels.index(root_label) # 找到特征是第几列
for key in sub_tree.keys():
if test_vec[feat_col] == key:
if isinstance(sub_tree[key], dict):
cls_label = classify(sub_tree[key], labels, test_vec) # 判断节点
else:
cls_label = sub_tree[key] # 叶子节点
return cls_label

print(classify(tree,labels,['milk','beef','apple']))
N

输出了N,但如果我们把beef改成fish这种不在决策树中的情况,则会报错。这点也说明了这个算法的过拟合缺点。