Decision Trees With R

"Decision trees can give a clear picture of the underlying structure in data and relationships between variables. They are an excellent tool for data inspection and to understand the interactions between variables."

The methods described below show how to quickly build decision trees with functions from the tree, party and rpart packages.

Data Preparation
Let's use the ‘census income’ dataset and apply various decision tree methods to predict whether a person's income will exceed $50K/yr. The dataset is also known as the ‘adult’ data. Some of the attributes available to predict income are age, employment type, education, marital status and work hours per week. Below, the data is split into training (80%) and test (20%) sets, which are used for building the model and for making predictions.
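fullData <- read.csv("", header = FALSE, stringsAsFactors = TRUE, strip.white = TRUE)  # import; text columns become factors, which tree() needs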
names(fullData) <- c("age", "workclass", "fnlwgt", "education", "educationnum", "maritalstatus", "occupation", "relationship", "race", "sex", "capitalgain", "capitalloss", "hoursperweek", "nativecountry", "response")
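fullData <- fullData[, c(15, 1:13)]  # put the response first and drop nativecountry (column 14), which has more than 32 levels
set.seed(100)  # make the random split reproducible
train <- sample(1:nrow(fullData), 0.8 * nrow(fullData))  # training row indices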
inputData <- fullData[train, ] # training data
testData <- fullData[-train, ] # test data
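Because tree() accepts at most 32 levels per factor predictor, an alternative to dropping a column such as nativecountry is to regroup its rare levels. The sketch below is written in base R and uses a hypothetical helper, `lump_levels()`, that is not part of any package. It keeps the most frequent levels and folds all the others into "Other". The toy factor is made up for illustration.

```r
# Hypothetical helper (not from any package): keep the (max_levels - 1) most
# frequent levels of a factor and merge every remaining level into `other`.
lump_levels <- function(f, max_levels = 32, other = "Other") {
  f <- as.factor(f)
  if (nlevels(f) <= max_levels) return(f)          # already within the limit
  counts <- sort(table(f), decreasing = TRUE)      # level frequencies, most common first
  keep <- names(counts)[seq_len(max_levels - 1)]   # leave one slot for `other`
  out <- as.character(f)
  out[!is.na(out) & !(out %in% keep)] <- other     # NAs stay NA
  factor(out)
}

# Toy factor with 40 levels, where "L1" is by far the most common
toy <- factor(c(rep("L1", 50), paste0("L", 1:40)))
lumped <- lump_levels(toy, max_levels = 32)
nlevels(toy)     # 40
nlevels(lumped)  # 32
```

With the census data, something like `fullData$nativecountry <- lump_levels(fullData$nativecountry)` would let you keep column 14 instead of dropping it.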

Census income - snapshot

Using the tree package

Step 1: Build the tree

Fit a ‘tree’ model on the training data and calculate the misclassification error. The tree may over-fit, meaning its rules become too specific to the training data. Pruning the tree can improve prediction accuracy to an extent. Note that any factor predictor can have at most 32 levels, so consider regrouping the levels if you have more than 32.

library(tree)
treeMod <- tree(response ~ ., data = inputData)  # model the tree, including all the variables
plot(treeMod)  # plot the tree model
text(treeMod, pretty = 0)  # add text labels to the plot
out <- predict(treeMod)  # class probabilities for the training data
input.response <- as.character(inputData$response)  # actuals
pred.response <- colnames(out)[max.col(out, ties.method = "first")]  # predicted
mean(input.response != pred.response)  # misclassification rate
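The `max.col()` step converts the matrix of class probabilities returned by predict() into one predicted label per row. The sketch below wraps that logic in a hypothetical helper, `probs_to_labels()`, and also builds a confusion matrix with table() to show where the errors fall. The probabilities and labels are made up for illustration. For tree objects, `predict(treeMod, type = "class")` returns the labels directly.

```r
# Hypothetical helper: turn a class-probability matrix (one column per class,
# as returned by predict() on a classification tree) into predicted labels
probs_to_labels <- function(prob, ties = "first") {
  colnames(prob)[max.col(prob, ties.method = ties)]
}

# Made-up probabilities for five observations and two classes
prob <- matrix(c(0.9, 0.1,
                 0.3, 0.7,
                 0.6, 0.4,
                 0.2, 0.8,
                 0.5, 0.5),
               ncol = 2, byrow = TRUE,
               dimnames = list(NULL, c("<=50K", ">50K")))
actual <- c("<=50K", ">50K", ">50K", ">50K", "<=50K")

predicted <- probs_to_labels(prob)    # row 5 is a tie; "first" picks "<=50K"
mean(actual != predicted)             # 0.2 (only row 3 is wrong)
table(actual = actual, predicted = predicted)  # confusion matrix
```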

Full Decision Tree With R

Step 2: Prune the tree
Your tree may need ‘pruning’ to avoid over-fitting the training data, which hurts accuracy on test data. Some of the more specific rules can be relaxed when a higher-level rule is good enough to predict the outcome. On the other hand, you may find that you need many more rules even with a large number of predictors and a large volume of data. In that case, your predictors may not be ‘good enough’ at explaining the response, or you may need to check the integrity of the data.

As a rule of thumb, choose a smaller tree size (the ‘best’ argument) so that the rules are less specific, as long as prediction accuracy is not compromised.
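set.seed(100)  # cv.tree() assigns the folds at random
cvTree <- cv.tree(treeMod, FUN = prune.misclass)  # run the cross-validation (10 folds by default)
plot(cvTree)  # plot the CV error against tree size
treePrunedMod <- prune.misclass(treeMod, best = 9)  # size at the lowest point of the CV plot; also try smaller sizes such as 4
plot(treePrunedMod)  # plot the pruned tree
text(treePrunedMod, pretty = 0)  # add text labels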

Decision Tree CV Plot in R

In the plot above, the lower x-axis is the number of terminal nodes (the tree size). The upper x-axis is the cost-complexity parameter (k) at which each size is obtained. The y-axis is the cross-validated misclassification count. This makes the plot very useful for deciding the number of terminal nodes at which the tree should be pruned. The two red lines mark the two candidate sizes (numbers of terminal nodes) at which you might prune the tree. Ideally, keep the tree as simple as possible (fewer nodes) while keeping the misclassification error as low as possible. If every size between 4 and 9 terminal nodes gives the same misclassification error, 4 terminal nodes should be the first choice.
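The "smallest tree with the lowest CV error" choice described above can be made programmatically. The sketch below uses a hypothetical helper, `smallest_best_size()`. It assumes you pass it the `$size` and `$dev` components of a cv.tree() result. The numbers in the example are made up and are not read from the plot above.

```r
# Hypothetical helper: given tree sizes and their cross-validated errors
# (e.g. cvTree$size and cvTree$dev), return the smallest size whose
# CV error equals the minimum CV error.
smallest_best_size <- function(size, dev) {
  min(size[dev == min(dev)])
}

# Made-up CV results: sizes 9, 7 and 4 tie for the lowest error
size <- c(16, 9, 7, 4, 2, 1)
dev  <- c(1250, 1190, 1190, 1190, 1320, 1640)
smallest_best_size(size, dev)  # 4
```

With the objects above, `prune.misclass(treeMod, best = smallest_best_size(cvTree$size, cvTree$dev))` would prune to that size. Because cv.tree() assigns folds at random, the chosen size can change from run to run unless a seed is set.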

Step 3: Re-calculate the misclassification error with the pruned tree

Pruning usually raises the training error a little, but it can improve accuracy on unseen data because the rules are now general enough to fit larger subgroups.
out <- predict(treePrunedMod)  # class probabilities from the pruned tree
pred.response <- colnames(out)[max.col(out, ties.method = c("random"))] # predicted
mean(inputData$response != pred.response) # Calculate Mis-classification error.

Pruned Decision Tree With R

Step 4: Predict
out <- predict(treePrunedMod, testData, type = "class")  # predict testData classes with the pruned tree
mean(testData$response != out)  # test misclassification rate

Using the party package

The ctree() function in the party package can model binary, nominal, ordinal and numeric response variables. The kind of tree it builds depends on the type of the response variable. Pruning is not required with this approach, because ctree() stops splitting when no split passes its statistical significance test.

Step 1: Build the model tree
library(party)
fit <- ctree(response ~ pred1 + pred2 + pred3, data = inputData)  # build the tree model; pred1..pred3 stand for your predictor (independent) variables
plot(fit, main = "Conditional Inference Tree")  # plot the ctree

cTree Plot in R

Step 2: Predict On New or Test Data
pred.response <- as.character(predict(fit, newdata = testData))  # predict on test data
input.response <- as.character(testData$response)  # actuals
mean(input.response != pred.response)  # misclassification rate

Using the rpart package

The rpart package can be used to model categorical, numeric and survival responses.
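The kind of tree rpart() grows is controlled by its `method` argument ("class", "anova", "poisson" or "exp"). If `method` is omitted, rpart guesses it from the response. Survival responses use a `Surv()` object from the survival package. The minimal sketch below uses two datasets that ship with R and rpart (kyphosis and mtcars) rather than the census data, and fits one classification tree and one regression tree.

```r
library(rpart)  # a recommended package included with standard R installations

# Categorical response -> classification tree
fitClass <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")

# Numeric response -> regression tree
fitAnova <- rpart(mpg ~ wt + hp + cyl, data = mtcars, method = "anova")

# predict() gives class labels for the first model and fitted numbers for the second
head(predict(fitClass, type = "class"))
head(predict(fitAnova))
```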

Step 1: Build the tree

Fit rpart() on the training data and calculate the misclassification error.

library(rpart)
rpartMod <- rpart(response ~ ., data = inputData, method = "class")  # build the model
printcp(rpartMod)  # print the cptable

Classification tree:
rpart(formula = response ~ ., data = inputData, method = "class")
Variables actually used in tree construction:
[1] age           capitalgain   education     fnlwgt        maritalstatus
[6] occupation    workclass    

Root node error: 81/376 = 0.21543
n= 376 
        CP nsplit rel error  xerror     xstd
1 0.086420      0   1.00000 1.00000 0.098418
2 0.074074      3   0.71605 0.91358 0.095179
3 0.049383      4   0.64198 0.88889 0.094194
4 0.028807      5   0.59259 0.85185 0.092665
5 0.012346      8   0.50617 0.88889 0.094194
6 0.010000     10   0.48148 0.86420 0.093182

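The cptable above can be used to choose a complexity parameter for pruning. The sketch below re-enters the printed table by hand and compares two choices. The first is the CP with the lowest cross-validated error (xerror). The second uses the one-standard-error (1-SE) rule, a common cost-complexity pruning heuristic that the steps here do not otherwise use: take the smallest tree whose xerror is within one xstd of the minimum. The helper `one_se_cp()` is hypothetical.

```r
# The cptable printed above, re-entered by hand (rows go from smallest to largest tree)
cptab <- data.frame(
  CP     = c(0.086420, 0.074074, 0.049383, 0.028807, 0.012346, 0.010000),
  nsplit = c(0, 3, 4, 5, 8, 10),
  xerror = c(1.00000, 0.91358, 0.88889, 0.85185, 0.88889, 0.86420),
  xstd   = c(0.098418, 0.095179, 0.094194, 0.092665, 0.094194, 0.093182)
)

# Hypothetical helper for the 1-SE rule: first (smallest) tree whose xerror
# is no more than the minimum xerror plus its standard error
one_se_cp <- function(tab) {
  best <- which.min(tab$xerror)
  threshold <- tab$xerror[best] + tab$xstd[best]
  tab$CP[which(tab$xerror <= threshold)[1]]
}

cp_min <- cptab$CP[which.min(cptab$xerror)]  # 0.028807 (5 splits)
cp_1se <- one_se_cp(cptab)                  # 0.074074 (3 splits)
```

With the fitted model, `prune(rpartMod, cp = cp_1se)` (or `cp = cp_min`) should return the correspondingly smaller tree. Because xerror comes from random cross-validation folds, your table and the chosen CP may differ from these numbers.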
Let's predict on the training data and calculate the misclassification rate.
out <- predict(rpartMod) # predict probabilities
pred.response <- colnames(out)[max.col(out, ties.method = c("random"))] # predict response
mean(inputData$response != pred.response) # % misclassification error
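The printcp() output above already indicates roughly what this number should be. rpart reports errors relative to the root node error, so the training misclassification rate of the full tree (last row, rel error 0.48148) is approximately:

$$\text{training error} = \text{root node error} \times \text{rel error} = \frac{81}{376} \times 0.48148 \approx \frac{39}{376} \approx 0.104$$

Leaves where the two classes are exactly tied may be labelled differently by the random tie-breaking in max.col(), so the value from mean() could differ slightly.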

Step 2: Predict the Test Data
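out <- predict(rpartMod, testData, type = "class")  # predicted classes for the test data
mean(testData$response != out)  # test misclassification rate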
