173 lines
6.2 KiB
Plaintext
173 lines
6.2 KiB
Plaintext
|
|
R Under development (unstable) (2019-04-05 r76323) -- "Unsuffered Consequences"
|
|
Copyright (C) 2019 The R Foundation for Statistical Computing
|
|
Platform: x86_64-pc-linux-gnu (64-bit)
|
|
|
|
R is free software and comes with ABSOLUTELY NO WARRANTY.
|
|
You are welcome to redistribute it under certain conditions.
|
|
Type 'license()' or 'licence()' for distribution details.
|
|
|
|
R is a collaborative project with many contributors.
|
|
Type 'contributors()' for more information and
|
|
'citation()' on how to cite R or R packages in publications.
|
|
|
|
Type 'demo()' for some demos, 'help()' for on-line help, or
|
|
'help.start()' for an HTML browser interface to help.
|
|
Type 'q()' to quit R.
|
|
|
|
> # Any necessary setup
|
|
> library(rpart)
|
|
> options(na.action="na.omit")
|
|
> options(digits=4) # to match earlier output
|
|
> set.seed(1234)
|
|
>
|
|
> mystate <- data.frame(state.x77, region=factor(state.region))
|
|
> names(mystate) <- c("population","income" , "illiteracy","life" ,
|
|
+ "murder", "hs.grad", "frost", "area", "region")
|
|
> #
|
|
> # Test out the "user mode" functions, with an anova variant
|
|
> #
|
|
>
|
|
> # The 'evaluation' function. Called once per node.
|
|
> # Produce a label (1 or more elements long) for labeling each node,
|
|
> # and a deviance. The latter is
|
|
> # - of length 1
|
|
> # - equal to 0 if the node is "pure" in some sense (unsplittable)
|
|
> # - does not need to be a deviance: any measure that gets larger
|
|
> # as the node is less acceptable is fine.
|
|
> # - the measure underlies cost-complexity pruning, however
|
|
> temp1 <- function(y, wt, parms) {
|
|
+ wmean <- sum(y*wt)/sum(wt)
|
|
+ rss <- sum(wt*(y-wmean)^2)
|
|
+ list(label= wmean, deviance=rss)
|
|
+ }
|
|
>
|
|
> # The split function, where most of the work occurs.
|
|
> # Called once per split variable per node.
|
|
> # If continuous=T
|
|
> # The actual x variable is ordered
|
|
> # y is supplied in the sort order of x, with no missings,
|
|
> # return two vectors of length (n-1):
|
|
> # goodness = goodness of the split, larger numbers are better.
|
|
> # 0 = couldn't find any worthwhile split
|
|
> # the ith value of goodness evaluates splitting obs 1:i vs (i+1):n
|
|
> # direction= -1 = send "x< cutpoint" to the left side of the tree
|
|
> # 1 = send "x< cutpoint" to the right
|
|
> # this is not a big deal, but making larger "mean y's" move towards
|
|
> # the right of the tree, as we do here, seems to make it easier to
|
|
> # read
|
|
> # If continuous=F, x is a set of integers defining the groups for an
|
|
> # unordered predictor. In this case:
|
|
> # direction = a vector of length m= "# groups". It asserts that the
|
|
> # best split can be found by lining the groups up in this order
|
|
> # and going from left to right, so that only m-1 splits need to
|
|
> # be evaluated rather than 2^(m-1)
|
|
> # goodness = m-1 values, as before.
|
|
> #
|
|
> # The reason for returning a vector of goodness is that the C routine
|
|
> # enforces the "minbucket" constraint. It selects the best return value
|
|
> # that is not too close to an edge.
|
|
> temp2 <- function(y, wt, x, parms, continuous) {
|
|
+ # Center y
|
|
+ n <- length(y)
|
|
+ y <- y- sum(y*wt)/sum(wt)
|
|
+
|
|
+ if (continuous) {
|
|
+ # continuous x variable
|
|
+ temp <- cumsum(y*wt)[-n]
|
|
+
|
|
+ left.wt <- cumsum(wt)[-n]
|
|
+ right.wt <- sum(wt) - left.wt
|
|
+ lmean <- temp/left.wt
|
|
+ rmean <- -temp/right.wt
|
|
+ goodness <- (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2)
|
|
+ list(goodness= goodness, direction=sign(lmean))
|
|
+ }
|
|
+ else {
|
|
+ # Categorical X variable
|
|
+ ux <- sort(unique(x))
|
|
+ wtsum <- tapply(wt, x, sum)
|
|
+ ysum <- tapply(y*wt, x, sum)
|
|
+ means <- ysum/wtsum
|
|
+
|
|
+ # For anova splits, we can order the categories by their means
|
|
+ # then use the same code as for a non-categorical
|
|
+ ord <- order(means)
|
|
+ n <- length(ord)
|
|
+ temp <- cumsum(ysum[ord])[-n]
|
|
+ left.wt <- cumsum(wtsum[ord])[-n]
|
|
+ right.wt <- sum(wt) - left.wt
|
|
+ lmean <- temp/left.wt
|
|
+ rmean <- -temp/right.wt
|
|
+ list(goodness= (left.wt*lmean^2 + right.wt*rmean^2)/sum(wt*y^2),
|
|
+ direction = ux[ord])
|
|
+ }
|
|
+ }
|
|
>
|
|
> # The init function:
|
|
> # fix up y to deal with offsets
|
|
> # return a dummy parms list
|
|
> # numresp is the number of values produced by the eval routine's "label"
|
|
> # numy is the number of columns for y
|
|
> # summary is a function used to print one line in summary.rpart
|
|
> # In general, this function would also check for bad data, see rpart.poisson
|
|
> # for instance.
|
|
> temp3 <- function(y, offset, parms, wt) {
|
|
+ if (!is.null(offset)) y <- y-offset
|
|
+ list(y=y, parms=0, numresp=1, numy=1,
|
|
+ summary= function(yval, dev, wt, ylevel, digits ) {
|
|
+ paste(" mean=", format(signif(yval, digits)),
|
|
+ ", MSE=" , format(signif(dev/wt, digits)),
|
|
+ sep='')
|
|
+ })
|
|
+ }
|
|
>
|
|
>
|
|
> alist <- list(eval=temp1, split=temp2, init=temp3)
|
|
>
|
|
> fit1 <- rpart(income ~population +illiteracy + murder + hs.grad + region,
|
|
+ mystate, control=rpart.control(minsplit=10, xval=0),
|
|
+ method=alist)
|
|
>
|
|
> fit2 <- rpart(income ~population +illiteracy + murder + hs.grad + region,
|
|
+ mystate, control=rpart.control(minsplit=10, xval=0),
|
|
+ method='anova')
|
|
>
|
|
> # Other than their call statement, and a longer "functions" component in
|
|
> # fit1, fit1 and fit2 should be identical.
|
|
> all.equal(fit1$frame, fit2$frame)
|
|
[1] TRUE
|
|
> all.equal(fit1$splits, fit2$splits)
|
|
[1] TRUE
|
|
> all.equal(fit1$csplit, fit2$csplit)
|
|
[1] TRUE
|
|
> all.equal(fit1$where, fit2$where)
|
|
[1] TRUE
|
|
> all.equal(fit1$cptable, fit2$cptable)
|
|
[1] TRUE
|
|
>
|
|
> # Now try xpred on it
|
|
> xvtemp <- rep(1:5, length=50)
|
|
> xp1 <- xpred.rpart(fit1, xval=xvtemp)
|
|
> xp2 <- xpred.rpart(fit2, xval=xvtemp)
|
|
> aeq <- function(x,y) all.equal(as.vector(x), as.vector(y))
|
|
> aeq(xp1, xp2)
|
|
[1] TRUE
|
|
>
|
|
> fit3 <- rpart(income ~population +illiteracy + murder + hs.grad + region,
|
|
+ mystate, control=rpart.control(minsplit=10, xval=xvtemp),
|
|
+ method='anova')
|
|
> zz <- apply((mystate$income - xp1)^2,2, sum)
|
|
> aeq(zz/fit1$frame$dev[1], fit3$cptable[,4]) #reproduce xerror
|
|
[1] TRUE
|
|
>
|
|
> zz2 <- sweep((mystate$income-xp1)^2,2, zz/nrow(xp1))
|
|
> zz2 <- sqrt(apply(zz2^2, 2, sum))/ fit1$frame$dev[1]
|
|
> aeq(zz2, fit3$cptable[,5]) #reproduce se(xerror)
|
|
[1] TRUE
|
|
>
|
|
>
|
|
> proc.time()
|
|
user system elapsed
|
|
0.155 0.025 0.175
|