9.4 树剪枝
一棵树如果节点过多,表明该模型可能对数据进行了“过拟合”。那么,如何判断是否发生了过拟合?前面章节中使用了测试集上某种交叉验证技术来发现过拟合,决策树亦是如此。本节将对此进行讨论,并分析如何避免过拟合。
通过降低决策树的复杂度来避免过拟合的过程称为剪枝(pruning)。其实本章前面已经进行过剪枝处理。在函数chooseBestSplit()中的提前终止条件,实际上是在进行一种所谓的预剪枝(prepruning)
操作。另一种形式的剪枝需要使用测试集和训练集,称作后剪枝(postpruning)。本节将分析后剪枝的有效性,但首先来看一下预剪枝的不足之处。
9.4.1 预剪枝
上节两个简单实验的结果还是令人满意的,但背后存在一些问题。树构建算法其实对输入的参数tolS和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。为了说明这一点,在Python提示符下输入如下命令:
>>> regTrees.createTree(myMat,ops=(0,1))
与上节中只包含两个节点的树相比,这里构建的树过于臃肿,它甚至为数据集中每个样本都分配了一个叶节点。
图9-3中的散点图,看上去与图9-1非常相似。但如果仔细地观察y轴就会发现,前者的数量级是后者的100倍。这将不是问题,对吧?现在用该数据来构建一棵新的树(数据存放在ex2.txt中),在Python提示符下输入以下命令:
>>> myDat2=regTrees.loadDataSet('ex2.txt')
>>> myMat2=mat(myDat2)
>>> regTrees.createTree(myMat2)
{'spInd': 0, 'spVal': matrix([[ 0.499171]]), 'right': {'spInd': 0,
'spVal': matrix([[ 0.457563]]), 'right': -3.6244789069767438,
'left': 7.9699461249999999}, 'l
.
.
0, 'spVal': matrix([[ 0.958512]]), 'right': 112.42895575000001,
'left': 105.248
2350000001}}}}
不知你注意到没有,从图9-1数据集构建出来的树只有两个叶节点,而这里构建的新树则有很多叶节点。
产生这个现象的原因在于,停止条件tolS对误差的数量级十分敏感。如果在选项中花费时间并对上述误差容忍度取平方值,或许也能得到仅有两个叶节点组成的树:
>>> regTrees.createTree(myMat2,ops=(10000,4))
{'spInd': 0, 'spVal': matrix([[ 0.499171]]), 'right': -2.6377193297872341,
'left': 101.35815937735855}
然而,通过不断修改停止条件来得到合理结果并不是很好的办法。事实上,我们常常甚至不确定到底需要寻找什么样的结果。这正是机器学习所关注的内容,计算机应该可以给出总体的概貌。
下节将讨论后剪枝,即利用测试集来对树进行剪枝。由于不需要用户指定参数,后剪枝是一个更理想化的剪枝方法。
9.4.2 后剪枝
使用后剪枝方法需要将数据集分成测试集和训练集。首先指定参数,使得构建出的树足够大、足够复杂,便于剪枝。接下来从上而下找到叶节点,用测试集来判断将这些叶节点合并是否能降低测试误差。如果是的话就合并。
函数prune()的伪代码如下:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并为了解实际效果,打开regTrees.py并输入程序清单9-3的代码。
程序清单9-3 回归树剪枝函数
def isTree(obj):
return (type(obj).__name__=='dict')
def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0
def prune(tree, testData):
#❶ 没有测试数据则对树进行塌陷处理
if shape(testData)[0] == 0: return getMean(tree)
if (isTree(tree['right']) or isTree(tree['left'])):
lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'],tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +sum(power(rSet[:,-1] - tree['right'],
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print "merging"
return treeMean
else: return tree
else: return tree
程序清单9-3中包含三个函数:isTree()、getMean()和prune()。其中isTree()用于测试输入变量是否是一棵树,返回布尔类型的结果。换句话说,该函数用于判断当前处理的节点是否是叶节点。
函数getMean()是一个递归函数,它从上往下遍历树直到叶节点为止。如果找到两个叶节点则计算它们的平均值。该函数对树进行塌陷处理(即返回树平均值),在prune()函数中调用该函数时应明确这一点。
程序清单9-3的主函数是prune(),它有两个参数:待剪枝的树与剪枝所需的测试数据testData。
prune()函数首先需要确认测试集是否为空❶。一旦非空,则反复递归调用函数prune()对测试数据进行切分。因为树是由其他数据集(训练集)生成的,所以测试集上会有一些样本与原数据集样本的取值范围不同。一旦出现这种情况应当怎么办?数据发生过拟合应该进行剪枝吗?或者模型正确不需要任何剪枝?
这里假设发生了过拟合,从而对树进行剪枝。
接下来要检查某个分支到底是子树还是节点。如果是子树,就调用函数prune()来对该子树进行剪枝。在对左右两个分支完成剪枝之后,还需要检查它们是否仍然还是子树。如果两个分支已经不再是子树,那么就可以进行合并。具体做法是对合并前后的误差进行比较。如果合并后的误差比不合并的误差小就进行合并操作,反之则不合并直接返回。
接下来看看实际效果,将程序清单9-3的代码添加到regTrees.py文件并保存,在Python提示符下输入下面的命令:
>>> reload(regTrees)
为了创建所有可能中最大的树,输入如下命令:
>>> myTree=regTrees.createTree(myMat2, ops=(0,1))
输入以下命令导入测试数据:
>>> myDatTest=regTrees.loadDataSet('ex2test.txt')
>>> myMat2Test=mat(myDatTest)
输入以下命令,执行剪枝过程:
>>> regTrees.prune(myTree, myMat2Test)
merging
merging
merging
.
.
merging
{'spInd': 0, 'spVal': matrix([[ 0.499171]]), 'right': {'spInd': 0, 'spVal':
.
.
01, 'left': {'spInd': 0, 'spVal': matrix([[ 0.960398]]), 'right': 123.559747,
'left': 112.386764}}}, 'left': 92.523991499999994}}}}
可以看到,大量的节点已经被剪枝掉了,但没有像预期的那样剪枝成两部分,这说明后剪枝可能不如预剪枝有效。一般地,为了寻求最佳模型可以同时使用两种剪枝技术。
下节将重用部分已有的树构建代码来创建一种新的树。该树仍采用二元切分,但叶节点不再是简单的数值,取而代之的是一些线性模型。
本书评论