Machine Learning with R: A Complete Guide to Decision Trees


Updated: August 22, 2022.

R Decision Trees

Decision trees are among the most fundamental algorithms in supervised machine learning, and they handle both regression and classification tasks. In a nutshell, you can think of a decision tree as a glorified collection of if-else statements. What makes these if-else statements different from traditional programming is that the logical conditions are “generated” by the machine learning algorithm, but more on that later.

Interested in more basic machine learning guides? Check our detailed guide on Logistic Regression with R.

Today you’ll learn the basic theory behind the decision tree algorithm and how to implement it in R.


Introduction to R Decision Trees

Decision trees are intuitive. All they do is ask questions such as “Is the gender male?” or “Is the value of a particular variable higher than some threshold?”. Based on the answers, either more questions are asked or the classification is made. Simple!

To predict class labels, the decision tree starts from the root (root node). Choosing which attribute should represent the root node boils down to figuring out which attribute best separates the training records. The calculation is typically done with the Gini impurity formula. It’s simple math but gets tedious to do manually if you have many attributes.
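To make the Gini calculation concrete, here is a minimal sketch in base R. The helper names gini_impurity() and split_gini() are made up for this illustration, and the 2.5 threshold on Petal.Length is a candidate picked by hand, not one produced by the algorithm. The impurity of a node is 1 minus the sum of squared class proportions, and a candidate split is scored by the weighted average impurity of the two groups it creates (lower is better):

```r
# Gini impurity of a vector of class labels: 1 - sum(p_i^2)
gini_impurity <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Weighted Gini impurity of the two groups created by a yes/no condition
# (assumes both groups are non-empty)
split_gini <- function(labels, condition) {
  n <- length(labels)
  left <- labels[condition]
  right <- labels[!condition]
  (length(left) / n) * gini_impurity(left) +
    (length(right) / n) * gini_impurity(right)
}

# Impurity before any split: three equally sized classes
root_gini <- gini_impurity(iris$Species)

# Impurity after a hand-picked candidate split (the threshold is an assumption)
candidate_gini <- split_gini(iris$Species, iris$Petal.Length < 2.5)

root_gini
candidate_gini
```

On the full Iris dataset the impurity drops from roughly 0.667 to roughly 0.333, because this condition isolates every setosa flower in one group. A tree-building algorithm runs this kind of comparison over many candidate conditions and keeps the best one.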

After determining the root node, the tree “branches out”: each branch is split again so that records that are still mixed after the root split get separated further.

That’s why you’ll often hear the analogy that a decision tree is just multiple if-else statements. The analogy makes sense to a degree, but the conditional statements are calculated automatically. In simple words, the machine learns the best conditions for your data.

Let’s take a look at the following decision tree representation to drive these points further home:

Image 1 – Example decision tree (source)

As you can see, the variables Outlook?, Humidity?, and Windy? are used to predict the dependent variable, Play.
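Written out by hand, a tree like this really is a handful of nested if-else statements. The sketch below is only an illustration: the function name predict_play() is hypothetical, and the branch values (sunny, overcast, rainy, high humidity) follow the classic weather example and are assumptions, since they can’t be read from the text alone.

```r
# Hand-written version of a small "play or not" tree.
# Branch values follow the classic weather example and are assumptions.
predict_play <- function(outlook, humidity, windy) {
  if (outlook == "overcast") {
    "yes"
  } else if (outlook == "sunny") {
    # On sunny days the decision depends on humidity
    if (humidity == "high") "no" else "yes"
  } else {
    # Remaining case (rainy): the decision depends on wind
    if (windy) "no" else "yes"
  }
}

predict_play(outlook = "sunny", humidity = "high", windy = FALSE)
predict_play(outlook = "rainy", humidity = "normal", windy = FALSE)
```

The difference with a trained decision tree is that nobody types these conditions in. The algorithm picks the variables and the cut-off values from the training data.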

You now know the basic theory behind the algorithm, and you’ll learn how to implement it in R next.

Dataset Loading and Preparation

There’s no machine learning without data, and there’s no working with data without libraries. You’ll need the following packages to follow along:

library(caTools)
library(rpart)
library(rpart.plot)
library(caret)
library(Boruta)
library(cvms)
library(dplyr)

head(iris)

As you can see, we’ll use the Iris dataset to build our decision tree classifier. This is what the first few rows look like (output from the head() function call):

Image 2 – Iris dataset head

The dataset is familiar to pretty much anyone with a week of experience in data science and machine learning, so it doesn’t require further introduction. It’s also as clean as they come, which will save us a lot of time in this section.

The only thing we have to do before continuing to predictive modeling is to split this dataset randomly into training and testing subsets. You can use the following code snippet to do a split in a 75:25 ratio:

set.seed(42)
sample_split <- sample.split(Y = iris$Species, SplitRatio = 0.75)
train_set <- subset(x = iris, sample_split == TRUE)
test_set <- subset(x = iris, sample_split == FALSE)
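When you pass class labels as Y, sample.split() keeps the class ratios roughly the same in both subsets. If you’re curious what that means in practice, here’s a base-R sketch of the same idea. It’s an illustration under our own assumptions (rounding down within each class, variable names ending in _demo), not the caTools implementation, so the selected rows will differ from the ones above:

```r
# Base-R sketch of a stratified 75:25 split (not the caTools implementation)
set.seed(42)
split_ratio <- 0.75

# For each class, draw floor(ratio * class size) row indices for training
rows_by_class <- split(seq_len(nrow(iris)), iris$Species)
train_idx <- unlist(lapply(
  rows_by_class,
  function(rows) sample(rows, size = floor(split_ratio * length(rows)))
))

train_demo <- iris[train_idx, ]
test_demo <- iris[-train_idx, ]

# Every class is represented equally in both subsets
table(train_demo$Species)
table(test_demo$Species)
```

Each species ends up with 37 rows in the training subset and 13 in the testing subset, so no class is over- or under-represented when the model is evaluated.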

And that’s it! Let’s start with modeling next.

Predictive Modeling with R Decision Trees

We’re using the rpart package to build the model. The formula syntax is identical to the one used for linear and logistic regression. You’ll need to put the target variable on the left and features on the right, separated with the ~ sign. If you want to use all features, put a dot (.) instead of feature names.

Also, don’t forget to specify method = "class" since we’re dealing with a classification task here.

Here’s how to train the model:

model <- rpart(Species ~ ., data = train_set, method = "class")
model

The output of calling model is shown in the following image:

Image 3 – Decision tree classifier model

From this image alone, you can see the “rules” the decision tree model uses to make classifications. If you’d like a more visual representation, you can use the rpart.plot package to visualize the tree:

rpart.plot(model)

Image 4 – Visual representation of the decision tree

You can see how many classifications were correct (in the train set) by examining the bottom nodes. Setosa was classified correctly every time, versicolor was misclassified as virginica 5% of the time, and virginica was misclassified as versicolor 3% of the time. It’s a simple graph, but you can read everything from it.
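If you’d rather see numbers than read them off a plot, you can tabulate the training-set predictions yourself. The sketch below is self-contained and uses a plain base-R random split, so its row selection (and therefore its exact percentages) won’t match the plot above. The names ending in _demo are ours.

```r
library(rpart)

# Plain random 75:25 split in base R (rows differ from the caTools split)
set.seed(42)
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
train_demo <- iris[train_idx, ]

model_demo <- rpart(Species ~ ., data = train_demo, method = "class")
train_preds <- predict(model_demo, newdata = train_demo, type = "class")

# Rows: predicted class, columns: actual class
train_table <- table(Predicted = train_preds, Actual = train_demo$Species)
train_table

# Share of actual classes within each predicted class
round(prop.table(train_table, margin = 1), 2)
```

The row-wise proportions play a similar role to the percentages in the bottom nodes of the plot: they show how “pure” each predicted group is.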

Decision trees are also useful for examining feature importance, that is, how much predictive power lies in each feature. You can use the varImp() function from caret to find out. The following snippet calculates the importances and sorts them in descending order:

importances <- varImp(model)
importances %>%
  arrange(desc(Overall))
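A fitted rpart model also stores its own importance scores in the variable.importance element, so you can get a quick ranking without any extra packages. Treat the sketch below as a rough alternative: for brevity it fits the tree on the full Iris dataset (an assumption made for this example), and the scores are computed differently from varImp(), so the numbers won’t match.

```r
library(rpart)

# Fit on the full dataset just to keep the example short
model_demo <- rpart(Species ~ ., data = iris, method = "class")

# Named numeric vector of importance scores stored on the model
importance <- model_demo$variable.importance
sort(importance, decreasing = TRUE)

# Same scores expressed as a percentage of the total
round(100 * importance / sum(importance), 1)
```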

The results are shown in the image below:

Image 5 – Feature importances

If the varImp() doesn’t do it for you and you’re looking for something more advanced, look no further than Boruta.

Feature Importances with Boruta

Boruta is a feature ranking and selection algorithm based on the random forest algorithm. It will tell you if features in your dataset are relevant for making predictions. There are ways to adjust this “relevancy”, such as tweaking the p-value and other parameters, but that’s not something we’ll go over today.

A call to the Boruta() function looks just like a call to rpart(), with the additional doTrace parameter for limiting the console output. The code snippet below shows you how to find the importances and how to print them sorted in descending order:

library(Boruta)

boruta_output <- Boruta(Species ~ ., data = train_set, doTrace = 0)
rough_fix_mod <- TentativeRoughFix(boruta_output)
boruta_signif <- getSelectedAttributes(rough_fix_mod)
importances <- attStats(rough_fix_mod)
importances <- importances[importances$decision != "Rejected", c("meanImp", "decision")]
importances[order(-importances$meanImp), ]

Image 6 – Boruta importances

In case you want to present these results visually, the package has you covered:

plot(boruta_output, cex.axis = 0.7, las = 2, xlab = "", main = "Feature importance")

Image 7 – Boruta plot

Look for the green boxes, as green means the feature was confirmed as important. Red would indicate a rejected (unimportant) feature, and yellow a tentative one. The blue boxes are the shadow attributes Boruta creates as a reference for judging importance, so you can ignore them. The higher the box plot is on the Y-axis, the more important the feature. It’s that easy!

You’ve built and explored the model so far, but there’s no use in it yet. The next section shows you how to make predictions on previously unseen data and evaluate the model.

Generating Predictions

Predicting new instances is now a trivial task. All you have to do is use the predict() function and pass in the testing subset. Also, make sure to specify type = "class" for everything to work correctly. Here’s an example:

preds <- predict(model, newdata = test_set, type = "class")
preds
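To see what type = "class" does for you, compare it with type = "prob", which returns one probability per class instead of a label. The sketch below is self-contained and uses its own base-R split and _demo names (both are assumptions made for the example), so the rows differ from the test_set above.

```r
library(rpart)

# Plain random 75:25 split in base R
set.seed(42)
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
model_demo <- rpart(Species ~ ., data = iris[train_idx, ], method = "class")
test_demo <- iris[-train_idx, ]

# One column per class, each row sums to 1
probs <- predict(model_demo, newdata = test_demo, type = "prob")
head(probs)

# A factor with one predicted label per row
classes <- predict(model_demo, newdata = test_demo, type = "class")
head(classes)
```

The class labels are what you need for a confusion matrix, while the probabilities are handy when you care about how confident the tree is in each prediction.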

The results are shown in the following image: 

Image 8 – Decision tree predictions

But how good are these predictions? Let’s evaluate. The confusion matrix is one of the most commonly used tools for evaluating classification models. The confusionMatrix() function from caret also reports other metrics, such as accuracy, sensitivity, and specificity.

Here’s how you can print the confusion matrix:

confusionMatrix(data = preds, reference = test_set$Species)

And here are the results:

Image 9 – Confusion matrix on the test set

As you can see, there are some misclassifications in the versicolor and virginica classes, similar to what we’ve seen in the training set. Overall, the model is just short of 90% accuracy, which is more than acceptable for a simple decision tree classifier.
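Accuracy itself is nothing more than the share of correct predictions, so it’s easy to double-check by hand. Here’s a self-contained sketch in base R plus rpart. It uses its own random split and _demo names (assumptions made for the example), so the value it prints won’t necessarily match the one reported above.

```r
library(rpart)

# Plain random 75:25 split in base R
set.seed(42)
train_idx <- sample(seq_len(nrow(iris)), size = floor(0.75 * nrow(iris)))
model_demo <- rpart(Species ~ ., data = iris[train_idx, ], method = "class")
test_demo <- iris[-train_idx, ]
preds_demo <- predict(model_demo, newdata = test_demo, type = "class")

# Rows: predicted class, columns: actual class
conf_table <- table(Predicted = preds_demo, Actual = test_demo$Species)
conf_table

# Correct predictions sit on the diagonal
accuracy <- sum(diag(conf_table)) / sum(conf_table)
accuracy
```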

But let’s be honest: the amount of detail in the previous image is overwhelming. What if you want to display the confusion matrix only, and display it visually as a heatmap? That’s where the cvms package comes in. Its plot_confusion_matrix() function draws a confusion matrix from a tibble of counts, which is just what we need.

Keep in mind the parameters of the plot_confusion_matrix() function. They are intuitive to understand, and the values are the column names of cfm. Yours might be different:

library(cvms)

# caret expects predictions first (data) and true labels second (reference)
cm <- confusionMatrix(data = preds, reference = test_set$Species)
cfm <- as_tibble(cm$table)
plot_confusion_matrix(cfm, target_col = "Reference", prediction_col = "Prediction", counts_col = "n")

Image 10 – Confusion matrix plot

Much better, isn’t it? Now you have something to present. Let’s wrap things up in the following section.


Summary of R Decision Trees

Decision trees are an excellent introduction to the whole family of tree-based algorithms. They’re commonly used as a baseline model, which more sophisticated tree-based algorithms (such as random forests and gradient boosting) need to outperform.

Today you’ve learned the basic logic and intuition behind decision trees, and how to implement and evaluate the algorithm in R. You can expect the whole suite of tree-based algorithms covered soon, so stay tuned to the Appsilon blog if you want to learn more.

If you want to implement machine learning in your organization, you can always reach out to Appsilon for help.
