Extracts the terminal (leaf) nodes of a decision tree that uses one or two numeric predictor variables. These leaf nodes are then converted into a data frame in which each row represents one partition (that is, one leaf or terminal node) and can easily be plotted in coordinate space.
Arguments
- tree
A tree object. Supported classes include rpart::rpart.object, the compatible classes from the parsnip, workflows, or mlr3 front-ends, and the constparty class inheriting from partykit::party().
- keep_as_dt
Logical. The function relies on data.table for internal data manipulation, but it coerces the final return object into a regular data frame (the default behavior) unless the user specifies TRUE.
- flipaxes
Logical. By default, the function sets the y-axis variable as the first split variable in the provided tree. Specify TRUE to override this behavior.
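As a minimal sketch of the keep_as_dt argument, assuming the rpart and parttree packages are installed and reusing the same iris tree as the Examples section, the two calls below differ only in the class of the object they return:

```r
# Sketch: compare the default return type with keep_as_dt = TRUE.
# Assumes the rpart and parttree packages are installed.
library("rpart")
library("parttree")

rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris)

pt_df = parttree(rp)                     # default: coerced to a regular data frame
pt_dt = parttree(rp, keep_as_dt = TRUE)  # keep the internal data.table object

print(class(pt_df))
print(class(pt_dt))
```

Keeping the data.table may be convenient when the result feeds into further data.table operations; the default data frame is the more general choice.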
Value
A data frame comprising seven columns: the leaf node, the predicted value for that leaf (named after the response variable, e.g. Species in the examples below), its path, and a set of coordinates understandable to ggplot2 (i.e., xmin, xmax, ymin, ymax).
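Because the coordinate columns follow ggplot2's naming for rectangles, the result can be passed directly to ggplot2::geom_rect(). The sketch below assumes rpart, ggplot2 and parttree are installed; placing Petal.Length on the x-axis follows the xmin/xmax values in the rpart example output, and the styling (alpha, overlaid points) is purely illustrative.

```r
# Sketch: draw each leaf as a rectangle using the xmin/xmax/ymin/ymax columns.
# Assumes the rpart, ggplot2 and parttree packages are installed.
library("rpart")
library("ggplot2")
library("parttree")

rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris)
pt = parttree(rp)

p = ggplot() +
  # One rectangle per leaf; -Inf/Inf bounds extend to the panel edges.
  geom_rect(
    data = pt,
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = Species),
    alpha = 0.2
  ) +
  # Overlay the training observations for reference.
  geom_point(
    data = iris,
    aes(x = Petal.Length, y = Petal.Width, col = Species)
  ) +
  labs(x = "Petal.Length", y = "Petal.Width")

print(p)
```

Infinite bounds are a convenient way to represent leaves that are unbounded on one side, since ggplot2 clips them to the edge of the plotting panel rather than dropping them.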
Details
This function can be used with a regression or classification tree containing one or (at most) two numeric predictors.
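A hedged sketch of the regression case and the single-predictor case follows. The response and predictor choices (Sepal.Length, Petal.Length, Petal.Width) are illustrative assumptions, and no output is shown because none was part of the original examples. In each case the result should contain one row per terminal node of the fitted tree.

```r
# Sketch: parttree() on a regression tree and on a one-predictor tree.
# Assumes the rpart and parttree packages are installed; variable choices
# are illustrative only.
library("rpart")
library("parttree")

# Regression tree with two numeric predictors. For rpart's anova method,
# each leaf's fitted value is the mean response within that leaf.
rp_reg = rpart(Sepal.Length ~ Petal.Length + Petal.Width, data = iris)
pt_reg = parttree(rp_reg)
print(pt_reg)

# Classification tree with a single numeric predictor: each leaf
# corresponds to an interval of Petal.Length.
rp_one = rpart(Species ~ Petal.Length, data = iris)
pt_one = parttree(rp_one)
print(pt_one)
```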
Examples
## rpart trees
library("rpart")
rp = rpart(Species ~ Petal.Length + Petal.Width, data = iris)
parttree(rp)
#> node Species path xmin xmax ymin
#> 1 2 setosa Petal.Length < 2.45 -Inf 2.45 -Inf
#> 2 6 versicolor Petal.Length >= 2.45 --> Petal.Width < 1.75 2.45 Inf -Inf
#> 3 7 virginica Petal.Length >= 2.45 --> Petal.Width >= 1.75 2.45 Inf 1.75
#> ymax
#> 1 Inf
#> 2 1.75
#> 3 Inf
## conditional inference trees
library("partykit")
#> Loading required package: grid
#> Loading required package: libcoin
#> Loading required package: mvtnorm
ct = ctree(Species ~ Petal.Length + Petal.Width, data = iris)
parttree(ct)
#> node Species
#> 2 2 setosa
#> 5 5 versicolor
#> 6 6 versicolor
#> 7 7 virginica
#> path xmin xmax
#> 2 Petal.Length <= 1.9 -Inf 1.9
#> 5 Petal.Length > 1.9 --> Petal.Width <= 1.7 --> Petal.Length <= 4.8 1.9 4.8
#> 6 Petal.Length > 1.9 --> Petal.Width <= 1.7 --> Petal.Length > 4.8 4.8 Inf
#> 7 Petal.Length > 1.9 --> Petal.Width > 1.7 1.9 Inf
#> ymin ymax
#> 2 -Inf Inf
#> 5 -Inf 1.7
#> 6 -Inf 1.7
#> 7 1.7 Inf
## rpart via partykit
rp2 = as.party(rp)
parttree(rp2)
#> node Species path xmin xmax ymin
#> 2 2 setosa Petal.Length < 2.45 -Inf 2.45 -Inf
#> 4 4 versicolor Petal.Length >= 2.45 --> Petal.Width < 1.75 2.45 Inf -Inf
#> 5 5 virginica Petal.Length >= 2.45 --> Petal.Width >= 1.75 2.45 Inf 1.75
#> ymax
#> 2 Inf
#> 4 1.75
#> 5 Inf