CART (Part 3)

    xiaoxiao  2021-03-26

    4. Pruning

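    Pruning comes in two forms: pre-pruning and post-pruning. Pre-pruning adds stopping conditions while the tree is being grown (for example, a minimum error reduction or a minimum number of samples per split) so that the tree does not split excessively. The code in the previous section already includes pre-pruning.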
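As a minimal sketch of what such a stopping condition can look like, the function below decides whether a candidate split should be accepted. The names `sse`, `should_split`, `min_gain` and `min_samples` are hypothetical and are not taken from the previous section's code; the threshold values are arbitrary.

```python
import numpy as np

def sse(y):
    """Sum of squared deviations from the mean (0.0 for an empty array)."""
    y = np.asarray(y, dtype=float)
    return float(np.sum((y - y.mean()) ** 2)) if y.size else 0.0

def should_split(y_parent, y_left, y_right, min_gain=1.0, min_samples=4):
    """Hypothetical pre-pruning check: accept a split only if both children
    are large enough and the error reduction is at least min_gain."""
    if len(y_left) < min_samples or len(y_right) < min_samples:
        return False
    gain = sse(y_parent) - (sse(y_left) + sse(y_right))
    return gain >= min_gain

# Made-up targets: two clearly separated groups of four samples each.
y = np.array([1.0, 1.1, 0.9, 1.0, 5.0, 5.1, 4.9, 5.0])
well_separated = should_split(y, y[:4], y[4:])  # large gain, both sides big enough
too_small = should_split(y, y[:1], y[1:])       # left child has a single sample
print(well_separated, too_small)
```

The first split is accepted because it removes almost all of the squared error. The second is rejected on sample count alone, before any error is computed.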

    The rest of this section focuses on post-pruning.

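    Post-pruning algorithm:

     Input: a fully grown tree, plus test data used to measure error

     Output: the pruned tree

     Steps:

     (1) If either child of the current node is itself a tree, recursively prune that child

     (2) Compute the error after merging the current two leaf nodes

     (3) Compute the error without merging

     (4) If merging reduces the error, merge the two leaf nodes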
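Steps (2) to (4) come down to comparing two sums of squared errors on held-out data. Here is a small numeric sketch; the leaf values and test targets are made up for illustration and do not come from any dataset in this series.

```python
import numpy as np

# Hypothetical predictions of a node whose two children are both leaves.
left_leaf, right_leaf = 1.0, 1.2

# Hypothetical test targets routed to each leaf.
y_left = np.array([1.1, 1.3, 0.8])
y_right = np.array([1.0, 1.2, 1.1])

# Step (3): test error if the two leaves are kept.
error_no_merge = (np.sum((y_left - left_leaf) ** 2)
                  + np.sum((y_right - right_leaf) ** 2))

# Step (2): test error if they are replaced by one leaf holding their average.
merged_leaf = (left_leaf + right_leaf) / 2.0
y_all = np.concatenate([y_left, y_right])
error_merge = np.sum((y_all - merged_leaf) ** 2)

# Step (4): merge only when that lowers the test error.
do_merge = error_merge < error_no_merge
print(error_no_merge, error_merge, do_merge)
```

With these numbers the separate leaves give an error of about 0.19 and the merged leaf about 0.15, so the merge is accepted.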

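    from numpy import shape, power, sum  # numpy's sum, not the builtin

    # binSplitDataSet(dataSet, feature, value) is defined in the previous section.

    def isTree(obj):
        # A subtree is stored as a dict; a leaf is a plain number.
        return isinstance(obj, dict)

    def getMean(tree):
        # Collapse a subtree into a single value: the mean of its two children.
        if isTree(tree['right']):
            tree['right'] = getMean(tree['right'])
        if isTree(tree['left']):
            tree['left'] = getMean(tree['left'])
        return (tree['left'] + tree['right']) / 2.0

    def prune(tree, testData):
        # No test data reaches this subtree: collapse it to its mean.
        if shape(testData)[0] == 0:
            return getMean(tree)
        # Split the test data and recursively prune any child that is still a subtree.
        if isTree(tree['right']) or isTree(tree['left']):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], lSet)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], rSet)
        # Both children are leaves now: compare the error with and without merging.
        if not isTree(tree['left']) and not isTree(tree['right']):
            lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
            errorNoMerge = (sum(power(lSet[:, -1] - tree['left'], 2)) +
                            sum(power(rSet[:, -1] - tree['right'], 2)))
            treeMean = (tree['left'] + tree['right']) / 2.0
            errorMerge = sum(power(testData[:, -1] - treeMean, 2))
            if errorMerge < errorNoMerge:
                print("merging")
                return treeMean
            else:
                return tree
        else:
            return tree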
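To try post-pruning end to end without the code from the earlier sections, here is a self-contained sketch. `bin_split` is a stand-in for `binSplitDataSet`, and its behaviour is an assumption (rows whose feature value is greater than the split value go to the left set). The tree is built by hand, and `prune_sketch` is a condensed variant of `prune` for plain NumPy arrays, not the article's exact code.

```python
import numpy as np

def bin_split(data, feat, val):
    """Assumed stand-in for binSplitDataSet: rows with data[:, feat] > val
    go to the left set, all other rows to the right set."""
    mask = data[:, feat] > val
    return data[mask], data[~mask]

def is_tree(node):
    return isinstance(node, dict)

def get_mean(tree):
    """Collapse a subtree to the mean of its two (collapsed) children."""
    left = get_mean(tree['left']) if is_tree(tree['left']) else tree['left']
    right = get_mean(tree['right']) if is_tree(tree['right']) else tree['right']
    return (left + right) / 2.0

def prune_sketch(tree, test):
    # No test rows reach this node: collapse it.
    if test.shape[0] == 0:
        return get_mean(tree)
    l_set, r_set = bin_split(test, tree['spInd'], tree['spVal'])
    if is_tree(tree['left']):
        tree['left'] = prune_sketch(tree['left'], l_set)
    if is_tree(tree['right']):
        tree['right'] = prune_sketch(tree['right'], r_set)
    # Merging is only considered when both children are leaves.
    if is_tree(tree['left']) or is_tree(tree['right']):
        return tree
    err_keep = (np.sum((l_set[:, -1] - tree['left']) ** 2)
                + np.sum((r_set[:, -1] - tree['right']) ** 2))
    merged = (tree['left'] + tree['right']) / 2.0
    err_merge = np.sum((test[:, -1] - merged) ** 2)
    return merged if err_merge < err_keep else tree

# Hand-built tree: the root split is useful, the split inside the left branch is not.
tree = {'spInd': 0, 'spVal': 0.5,
        'left': {'spInd': 0, 'spVal': 0.75, 'left': 5.3, 'right': 4.7},
        'right': 1.0}

# Made-up test data: column 0 is the feature, the last column is the target.
test = np.array([[0.1, 1.1], [0.3, 0.9], [0.6, 5.0],
                 [0.7, 5.1], [0.8, 4.9], [0.9, 5.0]])

pruned = prune_sketch(tree, test)
print(pruned)
```

On this data the inner split is merged into a single leaf, because its two leaves fit the test targets worse than their average does. The root split is kept, since merging it would raise the test error sharply.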

