1 简述
1.1 id3是一种基于决策树的分类算法,由J.Ross Quinlan 在1986年开发。id3根据信息增益,运用自顶向下的贪心策略 建立决策树。信息增益用于度量某个属性对样本集合分类的好坏程度。 由于采用了信息增益,id3算法建立的决策树规模比较小, 查询速度快。id3算法的改进是C4.5算法,C4.5算法可以 处理连续数据,采用信息增益率,而不是信息增益。 理解信息增益,需要先看一下信息熵。 1.2 信息熵 信息熵是随机变量的期望。度量信息的不确定程度。 信息的熵越大,信息就越不容易搞清楚。处理信息就是 为了把信息搞清楚,就是熵减少的过程。 Entropy(X) = -Sum(p(xi) * log(p(xi))) {i: 0 <= i <= n} p(x)是概率密度函数;对数是以2为底; 1.3 信息增益 用于度量属性A降低样本集合X熵的贡献大小。信息增益 越大,越适于对X分类。 Gain(A, X) = Entropy(X) - Sum(|Xv| / |X| * Entropy(Xv)) {v: A的所有可能值} Xv表示A中所有为v的值;|Xv|表示A中所有为v的值的数量; 2 id3算法流程 输入:样本集合S,属性集合A 输出:id3决策树。 1) 若所有种类的属性都处理完毕,返回;否则执行2) 2)计算出信息增益最大属性a,把该属性作为一个节点。 如果仅凭属性a就可以对样本分类,则返回;否则执行3) 3)对属性a的每个可能的取值v,执行一下操作: i. 将所有属性a的值是v的样本作为S的一个子集Sv; ii. 生成属性集合AT=A-{a}; iii.以样本集合Sv和属性集合AT为输入,递归执行id3算法; 3 一个的例子 3.1 这个例子来源于Quinlan的论文。 假设,有种户外活动。该活动能否正常进行与各种天气因素有关。 不同的天气因素组合会产生两种后果,也就是分成2类:能进行活动或不能。 我们用P表示该活动可以进行,N表示该活动无法进行。 下表描述样本集合是不同天气因素对该活动的影响。 Attribute class outlook temperature humidity windy --------------------------------------------------------- sunny hot high false N sunny hot high true N overcast hot high false P rain mild high false P rain cool normal false P rain cool normal true N overcast cool normal true P sunn y mild high false N sunny cool normal false P rain mild normal false P sunny mild normal true P overcast mild high true P overcast hot normal false P rain mild high true N 3.2 该活动无法进行的概率是:5/14 该活动可以进行的概率是:9/14 因此样本集合的信息熵是:-5/14log(5/14) - 9/14log(9/14) = 0.940 3.3 接下来我们再看属性outlook信息熵的计算: outlook为sunny时, 该活动无法进行的概率是:3/5 该活动可以进行的概率是:2/5 因此sunny的信息熵是:-3/5log(3/5) - 2/5log(2/5) = 0.971 同理可以计算outlook属性取其他值时候的信息熵: outlook为overcast时的信息熵:0 outlook为rain时的信息熵:0.971 属性outlook的信息增益:gain(outlook) = 0.940 - (5/14*0.971 + 4/14*0 + 5/14*0.971) = 0.246 相似的方法可以计算其他属性的信息增益: gain(temperature) = 0.029 gain(humidity) = 0.151 gain(windy) = 0.048 信息增益最大的属性是outlook。 3.4 根据outlook把样本分成3个子集,然后把这3个子集和余下的属性 作为输入递归执行算法。 4 代码演示 4.1 代码说明: 代码只是演示上一节的例子,写的比较仓促,没有经过仔细的设计和编码, 只是在fedora 16上做了初步的测试,所以有一些错误和不适当的地方。 4.2 编译: g++ -g -W -Wall -Wextra -o mytest main.cpp id3.cpp 4.3 执行: ./mytest 4.4 id3.h: ================================================ // 2012年 07月 12日 星期四 15:07:10 CST // author: 李小丹(Li Shao Dan) 字 殊恒(shuheng) // K.I.S.S // S.P.O.T #ifndef ID3_H #define ID3_H #include <list> #include <map> #include <utility> // value and index: >= 0, and index 0 is classification // value and index: not decision is -1 class id3_classify { public: id3_classify(int); ~id3_classify(); public: int push_sample(const int *, int); int classify(); int match(const int *); void print_tree(); private: typedef std::list<std::list<std::pair<int, int> > > sample_space_t; struct tree_node { int index; int classification; std::map<int, struct tree_node *> next; sample_space_t unclassified; }; private: void clear(struct tree_node *); int recur_classify(struct tree_node *, int); int recur_match(const int *, struct tree_node *); int max_gain(struct tree_node *); double cal_entropy(const std::map<int, int> &, double); int cal_max_gain(const sample_space_t &); int cal_split(struct tree_node *, int); void att_statistics(const sample_space_t &, std::map<int, std::map<int, int> > &, std::map<int, std::map<int, std::map<int, int> > > &, std::map<int, int> &); double cal_gain(std::map<int, int> &, std::map<int, std::map<int, int> > &, double, double); int is_classfied(const sample_space_t &); void dump_tree(struct tree_node *); private: sample_space_t unclassfied; struct tree_node *root; std::map<int, int> *attribute_values; int dimension; }; #endif =================================================== id3.cpp: ================================================== // 2012年 07月 16日 星期一 10:07:43 CST // author: 李小丹(Li Shao Dan) 字 殊恒(shuheng) // K.I.S.S // S.P.O.T #include <iostream> #include <cmath> #include <cassert> #include "id3.h" using namespace std; id3_classify::id3_classify(int d) :root(new struct tree_node), dimension(d) { root->index = -1; root->classification = -1; } id3_classify::~id3_classify() { clear(root); } int id3_classify::push_sample(const int *vec, int c) { list<pair<int, int> > v; for(int i = 0; i < dimension; ++i) v.push_back(make_pair(i + 1, vec[i])); v.push_front(make_pair(0, c)); root->unclassified.push_back(v); return 0; } int id3_classify::classify() { return recur_classify(root, dimension); } int id3_classify::match(const int *v) { return recur_match(v, root); } void id3_classify::clear(struct tree_node *node) { unclassfied.clear(); std::map<int, struct tree_node *> &next = node->next; for(std::map<int, struct tree_node *>::iterator pos = next.begin(); pos != next.end(); ++pos) clear(pos->second); next.clear(); delete node; } int id3_classify::recur_classify(struct tree_node *node, int dim) { sample_space_t &unclassified = node->unclassified; int cls; if((cls = is_classfied(unclassified)) >= 0) { node->index = -1; node->classification = cls; return 0; } int ret = max_gain(node); unclassified.clear(); if(ret < 0) return 0; map<int, struct tree_node *> &next = node->next; for(map<int, struct tree_node *>::iterator pos = next.begin(); pos != next.end(); ++pos) recur_classify(pos->second, dim - 1); return 0; } int id3_classify::is_classfied(const sample_space_t &ss) { const list<pair<int, int> > &f = ss.front(); if(f.size() == 1) return f.front().second; int cls; for(list<pair<int, int> >::const_iterator p = f.begin(); p != f.end(); ++p) { if(!p->first) { cls = p->second; break; } } for(sample_space_t::const_iterator s = ss.begin(); s != ss.end(); ++s) { const list<pair<int, int> > &v = *s; for(list<pair<int, int> >::const_iterator vp = v.begin(); vp != v.end(); ++vp) { if(!vp->first) { if(cls != vp->second) return -1; else break; } } } return cls; } int id3_classify::max_gain(struct tree_node *node) { // index of max attribute gain int mai = cal_max_gain(node->unclassified); assert(mai >= 0); node->index = mai; cal_split(node, mai); return 0; } int id3_classify::cal_max_gain(const sample_space_t &ss) { map<int, map<int, int> >att_val; map<int, map<int, map<int, int> > >val_cls; map<int, int> cls; att_statistics(ss, att_val, val_cls, cls); double s = (double)ss.size(); double entropy = cal_entropy(cls, s); double mag = -1; // max information gain int mai = -1; // index of max information gain for(map<int, map<int, int> >::iterator p = att_val.begin(); p != att_val.end(); ++p) { double g; if((g = cal_gain(p->second, val_cls[p->first], s, entropy)) > mag) { mag = g; mai = p->first; } } if(!att_val.size() && !val_cls.size() && cls.size()) return 0; return mai; } void id3_classify::att_statistics(const sample_space_t &ss, map<int, map<int, int> > &att_val, map<int, map<int, map<int, int> > > &val_cls, map<int, int> &cls) { for(sample_space_t::const_iterator spl = ss.begin(); spl != ss.end(); ++spl) { const list<pair<int, int> > &v = *spl; int c; for(list<pair<int, int> >::const_iterator vp = v.begin(); vp != v.end(); ++vp) { if(!vp->first) { c = vp->second; break; } } ++cls[c]; for(list<pair<int, int> >::const_iterator vp = v.begin(); vp != v.end(); ++vp) { if(vp->first) { ++att_val[vp->first][vp->second]; ++val_cls[vp->first][vp->second][c]; } } } } double id3_classify::cal_entropy(const map<int, int> &att, double s) { double entropy = 0; for(map<int, int>::const_iterator pos = att.begin(); pos != att.end(); ++pos) { double tmp = pos->second / s; entropy += tmp * log2(tmp); } return -entropy; } double id3_classify::cal_gain(map<int, int> &att_val, map<int, map<int, int> > &val_cls, double s, double entropy) { double gain = entropy; for(map<int, int>::const_iterator att = att_val.begin(); att != att_val.end(); ++att) { double r = att->second / s; double e = cal_entropy(val_cls[att->first], att->second); gain -= r * e; } return gain; } int id3_classify::cal_split(struct tree_node *node, int idx) { map<int, struct tree_node *> &next = node->next; sample_space_t &unclassified = node->unclassified; for(sample_space_t::iterator sp = unclassified.begin(); sp != unclassified.end(); ++sp) { list<pair<int, int> > &v = *sp; for(list<pair<int, int> >::iterator vp = v.begin(); vp != v.end(); ++vp) { if(vp->first == idx) { struct tree_node *tmp; if(!(tmp = next[vp->second])) { tmp = new struct tree_node; tmp->index = -1; tmp->classification = -1; next[vp->second] = tmp; } v.erase(vp); tmp->unclassified.push_back(v); break; } } } return 0; } int id3_classify::recur_match(const int *v, struct tree_node *node) { if(node->index < 0) return node->classification; map<int, struct tree_node *>::iterator p; map<int, struct tree_node *> &next = node->next; if((p = next.find(v[node->index-1])) == next.end()) return -1; return recur_match(v, p->second); } void id3_classify::print_tree() { return dump_tree(root); } void id3_classify::dump_tree(struct tree_node *node) { cout << "I: " << node->index << endl; cout << "C: " << node->classification << endl; cout << "N: " << node->next.size() << endl; cout << "+++++++++++++++++++++++\n"; map<int, struct tree_node *> &next = node->next; for(map<int, struct tree_node *>::iterator p = next.begin(); p != next.end(); ++p) { dump_tree(p->second); } } ==================================================== main.cpp: =================================================== // 2012年 07月 18日 星期三 13:59:10 CST // author: 李小丹(Li Shao Dan) 字 殊恒(shuheng) // K.I.S.S // S.P.O.T #include <iostream> #include "id3.h" using namespace std; int main() { enum outlook {SUNNY, OVERCAST, RAIN}; enum temp {HOT, MILD, COOL}; enum hum {HIGH, NORMAL}; enum windy {WEAK, STRONG}; int samples[14][4] = { {SUNNY , HOT , HIGH , WEAK }, {SUNNY , HOT , HIGH , STRONG}, {OVERCAST, HOT , HIGH , WEAK }, {RAIN , MILD, HIGH , WEAK }, {RAIN , COOL, NORMAL, WEAK }, {RAIN , COOL, NORMAL, STRONG}, {OVERCAST, COOL, NORMAL, STRONG}, {SUNNY , MILD, HIGH , WEAK }, {SUNNY , COOL, NORMAL, WEAK }, {RAIN , MILD, NORMAL, WEAK }, {SUNNY , MILD, NORMAL, STRONG}, {OVERCAST, MILD, HIGH , STRONG}, {OVERCAST, HOT , NORMAL, WEAK }, {RAIN , MILD, HIGH , STRONG}}; id3_classify cls(4); cls.push_sample((int *)&samples[0], 0); cls.push_sample((int *)&samples[1], 0); cls.push_sample((int *)&samples[2], 1); cls.push_sample((int *)&samples[3], 1); cls.push_sample((int *)&samples[4], 1); cls.push_sample((int *)&samples[5], 0); cls.push_sample((int *)&samples[6], 1); cls.push_sample((int *)&samples[7], 0); cls.push_sample((int *)&samples[8], 1); cls.push_sample((int *)&samples[9], 1); cls.push_sample((int *)&samples[10], 1); cls.push_sample((int *)&samples[11], 1); cls.push_sample((int *)&samples[12], 1); cls.push_sample((int *)&samples[13], 0); cls.classify(); cls.print_tree(); cout << "===============================\n"; for(int i = 0; i < 14; ++i) cout << cls.match((int *)&samples[i]) << endl; return 0; } ================================================