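There are two kinds of pruning: pre-pruning and post-pruning. Pre-pruning adds stopping conditions while the tree is being built, so that the tree does not split too far. The code in the previous section already includes pre-pruning.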
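To illustrate the kind of constraint meant here, the sketch below rejects a candidate split in two cases: a child would receive too few samples, or the split would barely reduce the error. The function `should_stop` and its thresholds `min_samples` and `min_err_drop` are hypothetical. The previous section's code may name and tune them differently.

```python
def should_stop(n_left, n_right, err_before, err_after,
                min_samples=4, min_err_drop=1.0):
    """Return True if a candidate split should be rejected (pre-pruning).

    Hypothetical thresholds:
      min_samples  - each child must keep at least this many samples
      min_err_drop - the split must reduce the squared error by at least this much
    """
    if n_left < min_samples or n_right < min_samples:
        return True  # a child would be too small: splitting risks overfitting
    if err_before - err_after < min_err_drop:
        return True  # error reduction too small to justify a new node
    return False


# A split that leaves only 2 samples on one side is rejected:
print(should_stop(2, 30, err_before=50.0, err_after=10.0))   # True
# A split that keeps enough samples and cuts the error by 40 is accepted:
print(should_stop(12, 20, err_before=50.0, err_after=10.0))  # False
```

Raising either threshold stops the tree earlier, trading some fit on the training data for a smaller tree.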
This section focuses on post-pruning.
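Post-pruning algorithm:
Input: a fully grown tree and a set of test data
Output: the pruned tree
Steps:
(1) If either child of the current node is a subtree, apply the pruning procedure to that subtree recursively.
(2) Compute the squared error on the test data when the current node's two leaves are merged into one.
(3) Compute the squared error on the test data when they are not merged.
(4) If merging lowers the error, merge the two leaves.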
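Steps (2) to (4) can be checked in isolation for a single node whose two children are already leaves. The function `merge_decision` and the sample numbers below are hypothetical. The sketch follows the same rule as the code that follows: the merged leaf predicts the plain average of the two leaf values.

```python
import numpy as np


def merge_decision(y_left, y_right, left_val, right_val):
    """Steps (2)-(4) for one node whose two children are both leaves.

    y_left / y_right: target values of the test samples routed to each leaf.
    left_val / right_val: the constant predictions stored in the two leaves.
    Returns (error_no_merge, error_merge, should_merge).
    """
    y_left = np.asarray(y_left, dtype=float)
    y_right = np.asarray(y_right, dtype=float)
    # (3) error if both leaves are kept: each sample is predicted by its own leaf
    error_no_merge = (np.sum((y_left - left_val) ** 2) +
                      np.sum((y_right - right_val) ** 2))
    # (2) error if merged: a single leaf predicting the average of the two values
    merged_val = (left_val + right_val) / 2.0
    y_all = np.concatenate([y_left, y_right])
    error_merge = np.sum((y_all - merged_val) ** 2)
    # (4) merge only if that lowers the test error
    return error_no_merge, error_merge, bool(error_merge < error_no_merge)


# Test targets all sit near 2.0, so a single leaf at 2.0 fits better:
e_nm, e_m, m = merge_decision([1.9, 2.1], [2.0, 2.0], 1.0, 3.0)
print(round(e_nm, 2), round(e_m, 2), m)  # 4.02 0.02 True
```

If the test targets instead cluster around the two leaf values (e.g. `[1.0, 1.2]` and `[3.1, 2.9]`), keeping the leaves gives the lower error and no merge happens.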
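import numpy as np

# binSplitDataSet(dataSet, feature, value) comes from the previous section's code;
# it splits the rows of dataSet on the given feature and threshold into (left, right).


def isTree(obj):
    """Internal nodes are stored as dicts; leaves are plain numbers."""
    return isinstance(obj, dict)


def getMean(tree):
    """Collapse a subtree bottom-up into the average of its two children."""
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, testData):
    """Post-prune a regression tree against held-out test data."""
    # No test data reaches this node: collapse the whole subtree into one leaf
    if np.shape(testData)[0] == 0:
        return getMean(tree)
    # (1) If either child is a subtree, prune it recursively with its share of the test data
    if isTree(tree['right']) or isTree(tree['left']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)
    # Both children are now leaves: decide whether to merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        # (3) Squared error on the test data if the two leaves are kept
        errorNoMerge = (np.sum(np.power(lSet[:, -1] - tree['left'], 2)) +
                        np.sum(np.power(rSet[:, -1] - tree['right'], 2)))
        # (2) Squared error if the two leaves are merged into their average
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = np.sum(np.power(testData[:, -1] - treeMean, 2))
        # (4) Merge only if that lowers the test error
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        return tree
    return tree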
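When no test samples reach a node, `prune` calls `getMean`, which collapses the whole subtree into one leaf. The sketch below re-implements that collapse without modifying the tree, to show which value comes out. The name `collapse` and the tree literal are hypothetical. The tree uses the same dict layout (`spInd`, `spVal`, `left`, `right`) as the code above.

```python
def is_tree(obj):
    # Internal nodes are dicts; leaves are plain numbers
    return isinstance(obj, dict)


def collapse(tree):
    """Collapse a (sub)tree into one value the same way getMean does:
    replace each subtree by the average of its two children, bottom-up.
    Unlike getMean, this sketch does not modify the tree in place."""
    left = collapse(tree['left']) if is_tree(tree['left']) else tree['left']
    right = collapse(tree['right']) if is_tree(tree['right']) else tree['right']
    return (left + right) / 2.0


# Hypothetical tree: leaves 4 and 2 on the left branch, 10 on the right
tree = {'spInd': 0, 'spVal': 0.5,
        'left': {'spInd': 0, 'spVal': 0.8, 'left': 4.0, 'right': 2.0},
        'right': 10.0}
print(collapse(tree))  # 6.5, i.e. ((4 + 2) / 2 + 10) / 2
```

The result is an average of averages, not the mean of all leaves ((4 + 2 + 10) / 3 ≈ 5.33). It also ignores how many training samples each leaf covered.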
