接口介绍

决策树算法的理论请参考本博客的这篇文章。 本文使用rpart包来完成决策树建模,主要内容整理自谢益辉的一篇中文论文。

rpart包主要有两个函数组成,分别介绍如下:

rpart(formula, data, weight s, subset, na. action = na. rpart, method, model= FALSE, x= FALSE,y= TRUE, parms, control, cost, . . . )
  • fomula 回归方程形式: 例如 y~ x 1+ x2+ x3。
  • data 数据: 包含前面方程中变量的数据框( data frame) 。
  • na.action 缺失数据的处理办法: 默认办法是删除因变量缺失的观测而保留自变量缺失的观测。
  • method 根据树末端的数据类型选择相应变量分割方法,本参数有四种取值: 连续型>anova; 离散型>class; 计数型( 泊松过程)>poisson; 生存分析型>exp。程序会根据因变量的类型自动选择方法, 但一般情况下最好还是指明本参数, 以便让程序清楚做哪一种树模型。
  • parms 用来设置三个参数:先验概率、损失矩阵、分类纯度的度量方法。anova没有参数;poisson分割有一个参数,先验分布变异系数的比率,默认为1;生存分布的参数和poisson一致;对离散型,可以设置先验分布的分布的概率(prior),损失矩阵(loss),分类纯度(split);priors必须为正值且和为1,loss必须对角为0且非对角为正数,split可以是gini(基尼系数)或者information(信息增益);
  • control 控制每个节点上的最小样本量、交叉验证的次数、复杂性参量: 即cp: complexity pamemeter, 这个参数意味着对每一步拆分, 模型的拟合优度必须提高的程度, 等等。

PS:交叉验证指,比如xval是10折交叉验证:将数据集分为10组,进行10次拟合,第i次拟合用除了第i组以外的数据训练,用第i组进行预测;目的是减少misclaassification rate。

prune(tree, cp, . . . )
  • tree 一个回归树对象, 常是rpart()的结果对象。
  • cp 复杂性参量, 指定剪枝采用的阈值。

方法论

通常分为两步建立回归树,先使用前剪枝,生成一颗较大的树;然后通过后剪枝,根据统计估量删除底部的一些节点来对树进行修剪。

使用rpart()构建树的过程中,当给定条件满足时构建过程就停止。偏差的减少小于某一个给定界限值、节点中的样本数量小于某个给定界限、树的深度大于一个给定的界限,上面三个界限分别由rpart()的三个参数(cp、minsplit、maxdepth)确定,默认值是0.01、20和30。如果要避免树的过度拟合问题,就要经常检查这些默认值的有效性,这可以通过对得到的树采取事后修剪的过程来实现。

选择树的方法一般有两种,一种是最小化交叉验证的相对方差(xerror)。另外一种是在剪枝理论中,比较著名的规则就是1-SE(1标准差) 规则, 其意思是: 首先要保证预测误差( 通过交叉验证获得, 在程序中表示为xerror) 尽量小,但不一定要取最小值, 而是允许它在“最小的误差”一个相应标准差的范围内, 然后在此范围内选取尽量小的复杂性参量,进而以它为依据进行剪枝。这个规则体现了兼顾树的规模( 复杂性) 和误差大小的思想, 因为一般说来, 随着拆分的增多, 复杂性参量会单调下降(纯度越来越高) , 但是预测误差则会先降后升, 这样, 就无法使复杂性和误差同时降到最低,因此允许误差可以在一个标准差内波动。

示例

rpart包自带数据集stagec,包含了146位患了stage c前列腺(prostate)癌的病人。变量介绍如下:

  • pgtime: 出现症状或复发时间,单位年;
  • pgstat:状态变量,1为复发,0为删减;
  • age:年龄;
  • eet:是否内分泌治疗,1为no,2为yes;
  • g2:g2阶段肿瘤细胞百分比;
  • grade:肿瘤等级,farrow体系;
  • gleason:肿瘤等级,gleason体系;
  • ploidy:肿瘤的倍体状态。

执行以下代码对数据集建模:

library(rpart)
## rpart.control对树进行一些设置  
## xval是10折交叉验证:将数据集分为10组,进行10次拟合,第i次拟合用除了第i组以外的数据训练,用第i组进行预测;目的是减少misclaassification rate  
## minsplit是最小分支节点数,这里指大于等于20,那么该节点会继续分划下去,否则停止  
## cp全称为complexity parameter,指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度  
ct <- rpart.control(xval=10, minsplit=20, cp=0.01)  
     
## method:树的末端数据类型选择相应的变量分割方法:  
## 连续性method=“anova”,离散型method=“class”,计数型method=“poisson”,生存分析型method=“exp”  
## parms用来设置三个参数:先验概率、损失矩阵、分类纯度的度量方法(gini和information)
## parms的解释:For classification splitting, the list can contain any of: the vector of prior probabilities (component prior), the loss matrix (component loss) or the splitting index (component split). The priors must be positive and sum to 1. The loss matrix must have zeros on the diagonal and positive off-diagonal elements. The splitting index can be gini or information. The default priors are proportional to the data counts, the losses default to 1, and the split defaults to gini.
## cost:给每个变量一个成本,选择某个变量进行split时,split改进量/成本作为评价标准,默认都为1
str(stagec)
## 'data.frame':    146 obs. of  8 variables:
##  $ pgtime : num  6.1 9.4 5.2 3.2 1.9 4.8 5.8 7.3 3.7 15.9 ...
##  $ pgstat : int  0 0 1 1 1 0 0 0 1 0 ...
##  $ age    : int  64 62 59 62 64 69 75 71 73 64 ...
##  $ eet    : int  2 1 2 2 2 1 2 2 2 2 ...
##  $ g2     : num  10.26 NA 9.99 3.57 22.56 ...
##  $ grade  : int  2 3 3 2 4 3 2 3 3 3 ...
##  $ gleason: int  4 8 7 4 8 7 NA 7 6 7 ...
##  $ ploidy : Factor w/ 3 levels "diploid","tetraploid",..: 1 3 1 1 2 1 2 3 1 2 ...
progstat <- factor(stagec$pgstat, levels=0:1, labels=c("No", "Prog"))
cfit <- rpart(progstat~age + eet + g2 + grade + gleason + ploidy,
              data=stagec, method="class", control=ct,
              parms=list(split="gini")
              )
print(cfit)
## n= 146 
## 
## node), split, n, loss, yval, (yprob)
##       * denotes terminal node
## 
##  1) root 146 54 No (0.6301370 0.3698630)  
##    2) grade< 2.5 61  9 No (0.8524590 0.1475410) *
##    3) grade>=2.5 85 40 Prog (0.4705882 0.5294118)  
##      6) g2< 13.2 40 17 No (0.5750000 0.4250000)  
##       12) ploidy=diploid,tetraploid 31 11 No (0.6451613 0.3548387)  
##         24) g2>=11.845 7  1 No (0.8571429 0.1428571) *
##         25) g2< 11.845 24 10 No (0.5833333 0.4166667)  
##           50) g2< 11.005 17  5 No (0.7058824 0.2941176) *
##           51) g2>=11.005 7  2 Prog (0.2857143 0.7142857) *
##       13) ploidy=aneuploid 9  3 Prog (0.3333333 0.6666667) *
##      7) g2>=13.2 45 17 Prog (0.3777778 0.6222222)  
##       14) g2>=17.91 22  8 No (0.6363636 0.3636364)  
##         28) age>=62.5 15  4 No (0.7333333 0.2666667) *
##         29) age< 62.5 7  3 Prog (0.4285714 0.5714286) *
##       15) g2< 17.91 23  3 Prog (0.1304348 0.8695652) *
# 绘制决策树
library(rpart.plot)
rpart.plot(cfit, main="Raw Decision Tree")

## rpart包提供了复杂度损失修剪的修剪方法,printcp会告诉分裂到每一层,cp是多少,平均相对误差是多少  
## 交叉验证的估计误差(“xerror”列),以及标准误差(“xstd”列),平均相对误差=xerror±xstd  
printcp(cfit)
## 
## Classification tree:
## rpart(formula = progstat ~ age + eet + g2 + grade + gleason + 
##     ploidy, data = stagec, method = "class", parms = list(split = "gini"), 
##     control = ct)
## 
## Variables actually used in tree construction:
## [1] age    g2     grade  ploidy
## 
## Root node error: 54/146 = 0.36986
## 
## n= 146 
## 
##         CP nsplit rel error  xerror    xstd
## 1 0.104938      0   1.00000 1.00000 0.10802
## 2 0.055556      3   0.68519 1.12963 0.11036
## 3 0.027778      4   0.62963 0.92593 0.10618
## 4 0.018519      6   0.57407 0.92593 0.10618
## 5 0.010000      7   0.55556 0.96296 0.10715
## 使用1-SE法则选出最优cp值:找到xerror最小的行,得到误差阈值为该行的xerror+xstd
## 找到所有xerror小于这个阈值的行,取其中最大值的上限为prune的阈值
cfit2 <- prune(cfit, cp=0.03)
rpart.plot(cfit2, main="Pruned Decision Tree")