Extracts the terminal (leaf) nodes of a decision tree fitted with one or two numeric predictor variables and converts them into a data frame. Each row represents one partition (i.e., one leaf or terminal node) that can easily be plotted in coordinate space.

Usage

parttree(tree, keep_as_dt = FALSE, flipaxes = FALSE)

Arguments

tree

A tree object. Supported classes include rpart::rpart.object, the compatible classes from the parsnip, workflows, or mlr3 front-ends, and the constparty class inheriting from partykit::party().

keep_as_dt

Logical. The function relies on data.table for internal data manipulation, but coerces the final return object into a regular data frame by default. Set to TRUE to keep the return object as a data.table.

flipaxes

Logical. By default, the function sets the first split variable in the tree as the y-axis variable. Set to TRUE to flip the axes.

Value

A data frame comprising seven columns: the leaf node, the predicted value for that leaf (in a column named after the response variable, e.g. Species in the examples below), the path to the leaf, and a set of coordinates understandable to ggplot2 (i.e., xmin, xmax, ymin, ymax).
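
To show how the coordinate columns can be read, here is a minimal base R sketch. It assumes a data frame shaped like the rpart output in the Examples section (the values are transcribed from that output and the path column is omitted for brevity). The helper find_leaf() is hypothetical and not part of parttree; it simply returns the prediction of the leaf whose rectangle contains a given point.

```r
# Leaf rectangles transcribed from the rpart example output on this page.
leaves <- data.frame(
  node    = c(2, 6, 7),
  Species = c("setosa", "versicolor", "virginica"),
  xmin    = c(-Inf, 2.45, 2.45),
  xmax    = c(2.45, Inf, Inf),
  ymin    = c(-Inf, -Inf, 1.75),
  ymax    = c(Inf, 1.75, Inf),
  stringsAsFactors = FALSE
)

# Hypothetical helper (not part of parttree): return the prediction of the
# leaf whose rectangle contains the point (x, y). Intervals are treated as
# closed on the left and open on the right, mirroring the "<" / ">=" splits
# shown in the rpart paths.
find_leaf <- function(leaves, x, y) {
  hit <- x >= leaves$xmin & x < leaves$xmax &
         y >= leaves$ymin & y < leaves$ymax
  leaves$Species[hit]
}

# x is Petal.Length and y is Petal.Width in the example output.
find_leaf(leaves, x = 5.0, y = 2.0)
```

Because the rectangles tile the whole plane, every finite point falls into exactly one leaf under this convention. Trees that split with "<=" and ">" (as in the ctree example) would need the comparison operators swapped accordingly.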

Details

This function can be used with a regression or classification tree containing one or (at most) two numeric predictors.

Examples

## rpart trees
library("rpart")
rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris)
parttree(rp)
#>   node    Species                                         path xmin xmax ymin
#> 1    2     setosa                          Petal.Length < 2.45 -Inf 2.45 -Inf
#> 2    6 versicolor  Petal.Length >= 2.45 --> Petal.Width < 1.75 2.45  Inf -Inf
#> 3    7  virginica Petal.Length >= 2.45 --> Petal.Width >= 1.75 2.45  Inf 1.75
#>   ymax
#> 1  Inf
#> 2 1.75
#> 3  Inf

## conditional inference trees
library("partykit")
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris)
parttree(ct)
#>   node    Species
#> 2    2     setosa
#> 5    5 versicolor
#> 6    6 versicolor
#> 7    7  virginica
#>                                                                path xmin xmax
#> 2                                               Petal.Length <= 1.9 -Inf  1.9
#> 5 Petal.Length > 1.9 --> Petal.Width <= 1.7 --> Petal.Length <= 4.8  1.9  4.8
#> 6  Petal.Length > 1.9 --> Petal.Width <= 1.7 --> Petal.Length > 4.8  4.8  Inf
#> 7                          Petal.Length > 1.9 --> Petal.Width > 1.7  1.9  Inf
#>   ymin ymax
#> 2 -Inf  Inf
#> 5 -Inf  1.7
#> 6 -Inf  1.7
#> 7  1.7  Inf

## rpart via partykit
rp2 = as.party(rp)
parttree(rp2)
#>   node    Species                                         path xmin xmax ymin
#> 2    2     setosa                          Petal.Length < 2.45 -Inf 2.45 -Inf
#> 4    4 versicolor  Petal.Length >= 2.45 --> Petal.Width < 1.75 2.45  Inf -Inf
#> 5    5  virginica Petal.Length >= 2.45 --> Petal.Width >= 1.75 2.45  Inf 1.75
#>   ymax
#> 2  Inf
#> 4 1.75
#> 5  Inf
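
As a further illustration, the returned coordinates can be passed to ggplot2::geom_rect() to draw the partitions over the raw data. This is only a sketch: it assumes the ggplot2 package is installed and that, as in the output above, xmin/xmax bound Petal.Length and ymin/ymax bound Petal.Width. The object names pt and p are chosen here for illustration.

```r
## plotting the partitions (sketch; assumes ggplot2 is installed)
library("rpart")
library("parttree")
library("ggplot2")

rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris)
pt = parttree(rp)

p = ggplot() +
  # raw observations
  geom_point(
    data = iris,
    aes(x = Petal.Length, y = Petal.Width, colour = Species)
  ) +
  # one rectangle per leaf; infinite bounds extend to the panel edges
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = Species),
    alpha = 0.2, colour = "black"
  )
p
```

The point layer determines the axis ranges, since the -Inf and Inf bounds of the rectangles do not contribute to the scale limits.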