
Basic use

Let’s start by loading the parttree package alongside rpart, which comes bundled with the base R installation and is what we’ll use for fitting our decision trees (at least, to start with). For the basic examples that follow, I’ll use the well-known Palmer Penguins dataset to demonstrate functionality. You can load this dataset via the parent package (as I have here), or import it directly as a CSV here.

library(rpart)     # For fitting decision trees
library(parttree)  # This package (will automatically load ggplot2 too)
#> Loading required package: ggplot2

theme_set(theme_linedraw())

# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
#> # A tibble: 6 × 8
#>   species island    bill_length_mm bill_depth_mm flipper_l…¹ body_…² sex    year
#>   <fct>   <fct>              <dbl>         <dbl>       <int>   <int> <fct> <int>
#> 1 Adelie  Torgersen           39.1          18.7         181    3750 male   2007
#> 2 Adelie  Torgersen           39.5          17.4         186    3800 fema…  2007
#> 3 Adelie  Torgersen           40.3          18           195    3250 fema…  2007
#> 4 Adelie  Torgersen           NA            NA            NA      NA NA     2007
#> 5 Adelie  Torgersen           36.7          19.3         193    3450 fema…  2007
#> 6 Adelie  Torgersen           39.3          20.6         190    3650 male   2007
#> # … with abbreviated variable names ¹​flipper_length_mm, ²​body_mass_g
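Note that some rows contain missing values (row 4 above, for example). As a quick sanity check, here is a count of the rows in which either of the two predictors used throughout this vignette is missing. rpart will set these rows aside, so the count should line up with the "observations deleted due to missingness" note in the tree printout further down.

```r
data("penguins", package = "palmerpenguins")

## Rows where either of the two predictors used below is missing
vars = c("flipper_length_mm", "bill_length_mm")
n_missing = sum(!complete.cases(penguins[, vars]))
n_missing
```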

Categorical predictions

Say we are interested in predicting penguin species as a function of 1) flipper length and 2) bill length. We can visualize these relationships as a simple scatter plot prior to doing any formal modeling.

p = 
  ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species))
p

Recasting in terms of a decision tree is easily done (e.g., with rpart). However, visualizing the resulting tree predictions against the raw data is hard to do out of the box, and this is where parttree enters the fray. The main function that users will interact with is geom_parttree(), which provides a new geom layer for ggplot2 objects.

## Fit a decision tree using the same variables as the above plot
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Visualize the tree partitions by adding it to our plot with geom_parttree()
p +  
  geom_parttree(data = tree, aes(fill=species), alpha = 0.1) +
  labs(caption = "Note: Points denote observations. Shaded regions denote model predictions.")
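Each shaded region corresponds to one terminal node of the tree, so the class that predict() returns for a new point should match the fill colour of the region that contains it. The snippet below checks this for a made-up penguin (the values 200 and 45 are purely illustrative). Going by the printed tree in the "Plot orientation" section, a flipper length below 206.5 combined with a bill length of at least 43.35 should land in the Chinstrap region.

```r
library(rpart)
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## A hypothetical new penguin: short flippers, long bill
new_penguin = data.frame(flipper_length_mm = 200, bill_length_mm = 45)

## Predicted class for the point (200, 45); compare with the shaded
## region that contains this point in the plot above
pred = predict(tree, newdata = new_penguin, type = "class")
pred
```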

Continuous predictions

Trees with continuous dependent variables (i.e. regression trees) are also supported. However, I recommend adjusting the plot fill aesthetic, since your model will likely partition the data into intervals that don't match up exactly with the raw data. The easiest way to do this is by setting your colour and fill aesthetics together as part of the same scale_colour_* call.

tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data=penguins)

ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree2, aes(fill=body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) + 
  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
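To see why the shaded tiles and the points rarely share exactly the same colour, it can help to look at the values behind the fill. The sketch below assumes, as the geom_parttree() call above implies, that parttree() names the prediction column after the response (body_mass_g). Each row is one terminal node, and its value is the node's fitted body mass, so the tiles can only take a handful of distinct values while the raw points vary continuously.

```r
library(rpart)
library(parttree)
data("penguins", package = "palmerpenguins")

tree2 = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

## One row per terminal node; body_mass_g holds that node's fitted value
pt2 = parttree(tree2)
pt2[, c("node", "body_mass_g")]

## Fitted values for the training data can only take these leaf values
sort(unique(predict(tree2)))
```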

Supported model classes

Currently, the package works with decision trees created by the rpart and partykit packages. Moreover, it supports other front-end packages that call rpart::rpart() as the underlying engine; in particular, the tidymodels (parsnip or workflows) and mlr3 packages. Here's a quick example with parsnip.

set.seed(123) ## For consistent jitter

library(parsnip)
library(titanic) ## Just for a different data set

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Plot the data and model partitions
titanic_train |>
  ggplot(aes(x=Pclass, y=Age)) +
  geom_parttree(data = ti_tree, aes(fill=Survived), alpha = 0.1) +
  geom_jitter(aes(col=Survived), alpha=0.7)
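Since partykit trees are supported too, the same workflow should carry over to a conditional inference tree from partykit::ctree(). The sketch below simply swaps the rpart model from the first penguin example for a ctree() fit on the same two predictors, leaving the rest of the ggplot2 code unchanged.

```r
library(partykit)
library(parttree)  # also loads ggplot2
data("penguins", package = "palmerpenguins")

## Conditional inference tree on the same two predictors as before
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Same plotting code as before; only the tree object has changed
p_ct =
  ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = ct, aes(fill = species), alpha = 0.1) +
  geom_point(aes(col = species))
p_ct
```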

Plot orientation

Under the hood, geom_parttree() calls the companion parttree() function, which coerces the rpart tree object into a data frame that ggplot2 can easily understand. For example, consider again our first "tree" model from earlier. Here's the print output of the raw model.

tree
#> n=342 (2 observations deleted due to missingness)
#> 
#> node), split, n, loss, yval, (yprob)
#>       * denotes terminal node
#> 
#> 1) root 342 191 Adelie (0.441520468 0.198830409 0.359649123)  
#>   2) flipper_length_mm< 206.5 213  64 Adelie (0.699530516 0.295774648 0.004694836)  
#>     4) bill_length_mm< 43.35 150   5 Adelie (0.966666667 0.033333333 0.000000000) *
#>     5) bill_length_mm>=43.35 63   5 Chinstrap (0.063492063 0.920634921 0.015873016) *
#>   3) flipper_length_mm>=206.5 129   7 Gentoo (0.015503876 0.038759690 0.945736434) *

And here’s what we get after we feed it to parttree().

parttree(tree)
#>   node   species                                                  path  xmin
#> 1    3    Gentoo                            flipper_length_mm >= 206.5 206.5
#> 2    4    Adelie  flipper_length_mm < 206.5 --> bill_length_mm < 43.35  -Inf
#> 3    5 Chinstrap flipper_length_mm < 206.5 --> bill_length_mm >= 43.35  -Inf
#>    xmax  ymin  ymax
#> 1   Inf  -Inf   Inf
#> 2 206.5  -Inf 43.35
#> 3 206.5 43.35   Inf

Again, the resulting data frame is designed to be amenable to a ggplot2 geom layer, with columns like xmin, xmax, etc. specifying aesthetics that ggplot2 recognises. (Fun fact: geom_parttree() is really just a thin wrapper around geom_rect().) The goal of the package is to abstract away these kinds of details from the user, so we can just specify geom_parttree(), with a valid tree object as the data input, and be done with it.

However, while this generally works well, it can sometimes lead to unexpected behaviour in terms of plot orientation. That's because it's hard to guess ahead of time what the user will specify as the x and y variables (i.e. axes) in their other plot layers. To see what I mean, let's redo our penguin plot from earlier, but this time switch the axes in the main ggplot() call.
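To make the geom_rect() connection concrete, here is a rough hand-rolled equivalent of the first penguin plot that feeds the parttree() output straight into geom_rect(). This is only a sketch of the idea, not necessarily what geom_parttree() does internally. It relies on ggplot2 drawing infinite bounds (the -Inf/Inf values in the output above) out to the edges of the plotting panel.

```r
library(rpart)
library(parttree)  # also loads ggplot2
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(tree)

## Draw the partitions with plain geom_rect(); each layer gets its own data,
## so no global aesthetics are set in ggplot()
p_rect =
  ggplot() +
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    alpha = 0.1
    ) +
  geom_point(
    data = penguins,
    aes(x = flipper_length_mm, y = bill_length_mm, col = species)
    )
p_rect
```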

## First, redo our first plot but this time switch the x and y variables
p3 = 
  ggplot(
    data = penguins, 
    aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
    ) +
  geom_point(aes(col = species))

## Add on our tree (and some preemptive titling...)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
    )

As was the case here, this kind of orientation mismatch is normally (hopefully) pretty easy to recognize. To fix, we can use the flipaxes = TRUE argument to flip the orientation of the geom_parttree layer.

p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flipaxes = TRUE  ## Flip the orientation
    ) +
  labs(title = "That's better")
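Conceptually, flipping the orientation just means swapping the x and y bounds of every rectangle. The sketch below does this by hand on the parttree() output to show what that swap amounts to. It is not meant to describe how the flipaxes argument is actually implemented. Note that inherit.aes = FALSE is needed because the flipped rectangles do not contain the bill_length_mm and flipper_length_mm columns used by the global aesthetics.

```r
library(rpart)
library(parttree)  # also loads ggplot2
data("penguins", package = "palmerpenguins")

tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = as.data.frame(parttree(tree))

## Swap the x and y bounds so the rectangles match a plot with bill length
## on the x-axis and flipper length on the y-axis
pt_flipped = pt
pt_flipped[, c("xmin", "xmax", "ymin", "ymax")] = pt[, c("ymin", "ymax", "xmin", "xmax")]

p_flip =
  ggplot(data = penguins, aes(x = bill_length_mm, y = flipper_length_mm)) +
  geom_rect(
    data = pt_flipped,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    alpha = 0.1, inherit.aes = FALSE
    ) +
  geom_point(aes(col = species))
p_flip
```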

Base graphics

While the package has been primarily designed to work with ggplot2, the parttree() infrastructure can also be used to generate plots with base graphics. Here, the ctree() function from partykit is used for fitting the tree.

library(partykit)
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm

## CTree and corresponding partition
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)

## Color palette
pal = palette.colors(4, "R4")[-1]

## Maximum/minimum for plotting range as rect() does not handle Inf well
m = 1000

## scatter plot() with added rect()
plot(
  bill_length_mm ~ flipper_length_mm, 
  data = penguins, col = pal[species], pch = 19
  )
rect(
  pmax(-m, pt$xmin), pmax(-m, pt$ymin), pmin(m, pt$xmax), pmin(m, pt$ymax),
  col = adjustcolor(pal, alpha.f = 0.1)[pt$species]
  )
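The hard-coded m = 1000 works here because both penguin measurements are well below 1000, but it would need adjusting for data on a different scale. One alternative that does not depend on the data is to clamp the infinite bounds to the current plotting region, which base R reports via par("usr") as c(x1, x2, y1, y2) once plot() has been called. A sketch of that variant, otherwise following the code above:

```r
library(partykit)
library(parttree)
data("penguins", package = "palmerpenguins")

## CTree and corresponding partition, as above
ct = ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins)
pt = parttree(ct)
pal = palette.colors(4, "R4")[-1]

plot(
  bill_length_mm ~ flipper_length_mm,
  data = penguins, col = pal[species], pch = 19
  )

## Extremes of the current plotting region: c(x1, x2, y1, y2)
usr = par("usr")

## Clamp any infinite rectangle bounds to the edges of the plot
rect(
  pmax(usr[1], pt$xmin), pmax(usr[3], pt$ymin),
  pmin(usr[2], pt$xmax), pmin(usr[4], pt$ymax),
  col = adjustcolor(pal, alpha.f = 0.1)[pt$species]
  )
```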