Skip to content

Latest commit

 

History

History
290 lines (240 loc) · 12.8 KB

README.md

File metadata and controls

290 lines (240 loc) · 12.8 KB

tidytrees

Regression and classification trees (e.g. from packages like partykit or rpart) are a very powerful set of statistical learning algorithms.

Nevertheless each tree package has its own way of representing and storing the trees, usually as a nested recursive list with attributes. This makes it very hard to interact with them.

This package provides an interface to convert tree objects from various packages into a “tidy” data frame, with a row for each node showing its defining set of rules and its characteristics.

Installation

You can install the last version of tidytrees with

# install.packages("devtools")
devtools::install_github("bakaburg1/tidytrees")

Simple use

tidytrees exposes the generic function tidy_tree which has a method for various tree objects (see ?tidy_tree for the list supported methods). The output is a tibble with a row of each tree node. For each node the relative rules are reported, plus other information like the node id, the number of observations related to the node in the data from which the model is derived, the depth of the node in the tree.

library(tidytrees)
library(partykit)
library(rpart)

# The function works with partykit...
model <- ctree(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model)
#> # A tibble: 12 x 6
#>    rule                                         id n.obs terminal depth estimate
#>    <chr>                                     <int> <int> <lgl>    <dbl>    <dbl>
#>  1 Species in setosa                             2    50 FALSE        1     3.43
#>  2 Sepal.Length <= 5 & Species in setosa         3    28 FALSE        2     3.20
#>  3 Sepal.Length <= 4.9 & Species in setosa       4    20 TRUE         3     3.14
#>  4 Sepal.Length <= 5 & Sepal.Length > 4.9 &…     5     8 TRUE         3     3.36
#>  5 Sepal.Length > 5 & Species in setosa          6    22 FALSE        2     3.71
#>  6 Sepal.Length <= 5.3 & Sepal.Length > 5 &…     7    12 TRUE         3     3.62
#>  7 Sepal.Length > 5.3 & Species in setosa        8    10 TRUE         3     3.82
#>  8 Species in versicolor, virginica              9   100 FALSE        1     2.87
#>  9 Sepal.Length <= 6.3 & Species in versico…    10    58 FALSE        2     2.74
#> 10 Sepal.Length <= 5.5 & Species in versico…    11    12 TRUE         3     2.47
#> 11 Sepal.Length <= 6.3 & Sepal.Length > 5.5…    12    46 TRUE         3     2.81
#> 12 Sepal.Length > 6.3 & Species in versicol…    13    42 TRUE         2     3.05

# ... and with rpart trees (more models to come)
model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model)
#> # A tibble: 8 x 6
#>   rule                                          id n.obs depth terminal estimate
#>   <chr>                                      <dbl> <int> <dbl> <lgl>       <dbl>
#> 1 Species = versicolor,virginica                 2   100     1 FALSE        2.87
#> 2 Species = versicolor,virginica & Sepal.Le…     4    58     2 FALSE        2.74
#> 3 Species = versicolor,virginica & Sepal.Le…     8    12     3 TRUE         2.47
#> 4 Species = versicolor,virginica & Sepal.Le…     9    46     3 TRUE         2.81
#> 5 Species = versicolor,virginica & Sepal.Le…     5    42     2 TRUE         3.05
#> 6 Species = setosa                               3    50     1 FALSE        3.43
#> 7 Species = setosa & Sepal.Length < 5.05         6    28     2 TRUE         3.20
#> 8 Species = setosa & Sepal.Length >= 5.05        7    22     2 TRUE         3.71

The rules can optionally be rendered in a R compatible format, for easy use as data filters, or as list of rules.

library(tidytrees)
library(dplyr)
library(rpart)

model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

# Evaluation friendly rules

out <- tidy_tree(model, eval_ready = T)

iris %>% filter(eval(str2expression(out$rule[3]))) %>% str
#> 'data.frame':    12 obs. of  5 variables:
#>  $ Sepal.Length: num  5.5 4.9 5.2 5 5.5 5.5 5.4 5.5 5.5 5 ...
#>  $ Sepal.Width : num  2.3 2.4 2.7 2 2.4 2.4 3 2.5 2.6 2.3 ...
#>  $ Petal.Length: num  4 3.3 3.9 3.5 3.8 3.7 4.5 4 4.4 3.3 ...
#>  $ Petal.Width : num  1.3 1 1.4 1 1.1 1 1.5 1.3 1.2 1 ...
#>  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 2 2 2 2 2 2 2 2 2 2 ...

# Rules as vectors
out <- tidy_tree(model, rule_as_text = F)

out$rule[3]
#> [[1]]
#> [1] "Species = versicolor,virginica" "Sepal.Length < 6.35"           
#> [3] "Sepal.Length < 5.55"

# Both
out <- tidy_tree(model, rule_as_text = F, eval_ready = T)

out$rule[3]
#> [[1]]
#> [1] "Species %in% c(\"versicolor\", \"virginica\")"
#> [2] "Sepal.Length < 6.35"                          
#> [3] "Sepal.Length < 5.55"

Tree models tend to create explicit, nested rules with redundant components, useful to retain the whole branching information. The package can simplify such rules in order to make them more human-friendly while keeping the minimal necessary set of conditions to identify a partition. The simplified rules are ordered alphabetically to group conditions on the same variables together.

library(tidytrees)
library(dplyr)
library(rpart)

model <- rpart(Sepal.Length ~ Species + Sepal.Width, data = iris)

# Full rules

tidy_tree(model)$rule[5:9]
#> [1] "Species = versicolor,virginica & Species = versicolor"                                            
#> [2] "Species = versicolor,virginica & Species = versicolor & Sepal.Width < 2.75"                       
#> [3] "Species = versicolor,virginica & Species = versicolor & Sepal.Width >= 2.75"                      
#> [4] "Species = versicolor,virginica & Species = versicolor & Sepal.Width >= 2.75 & Sepal.Width < 3.05" 
#> [5] "Species = versicolor,virginica & Species = versicolor & Sepal.Width >= 2.75 & Sepal.Width >= 3.05"

# Simplified rules
tidy_tree(model, simplify_rules = T)$rule[5:9]
#> [1] "Species = versicolor"                                           
#> [2] "Sepal.Width < 2.75 & Species = versicolor"                      
#> [3] "Sepal.Width >= 2.75 & Species = versicolor"                     
#> [4] "Sepal.Width < 3.05 & Sepal.Width >= 2.75 & Species = versicolor"
#> [5] "Sepal.Width >= 3.05 & Species = versicolor"

# It works also on a list of conditions
tidy_tree(model, rule_as_text = F, simplify_rules = T)$rule[5:9]
#> [[1]]
#> [1] "Species = versicolor"
#> 
#> [[2]]
#> [1] "Sepal.Width < 2.75"   "Species = versicolor"
#> 
#> [[3]]
#> [1] "Sepal.Width >= 2.75"  "Species = versicolor"
#> 
#> [[4]]
#> [1] "Sepal.Width < 3.05"   "Sepal.Width >= 2.75"  "Species = versicolor"
#> 
#> [[5]]
#> [1] "Sepal.Width >= 3.05"  "Species = versicolor"

# Can be applied to previously created rules

rules <- tidy_tree(model)$rule[5:9]

simplify_rules(rules)
#> [1] "Species = versicolor"                                           
#> [2] "Sepal.Width < 2.75 & Species = versicolor"                      
#> [3] "Sepal.Width >= 2.75 & Species = versicolor"                     
#> [4] "Sepal.Width < 3.05 & Sepal.Width >= 2.75 & Species = versicolor"
#> [5] "Sepal.Width >= 3.05 & Species = versicolor"

Node predictions

The output contains optionally the predicted value in the node and estimation intervals, with the possibility to chose the interval coverage (default = 95%).

library(tidytrees)
library(dplyr)
library(rpart)

# Intervals for continuous...
model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model, add_interval = T, interval_level = .89)
#> # A tibble: 8 x 8
#>   rule                       id n.obs depth terminal estimate conf.low conf.high
#>   <chr>                   <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl>
#> 1 Species = versicolor,v…     2   100     1 FALSE        2.87     2.83      2.91
#> 2 Species = versicolor,v…     4    58     2 FALSE        2.74     2.69      2.79
#> 3 Species = versicolor,v…     8    12     3 TRUE         2.47     2.38      2.55
#> 4 Species = versicolor,v…     9    46     3 TRUE         2.81     2.76      2.87
#> 5 Species = versicolor,v…     5    42     2 TRUE         3.05     3.00      3.10
#> 6 Species = setosa            3    50     1 FALSE        3.43     3.36      3.49
#> 7 Species = setosa & Sep…     6    28     2 TRUE         3.20     3.14      3.27
#> 8 Species = setosa & Sep…     7    22     2 TRUE         3.71     3.64      3.79

# ... and discrete outcomes
model <- rpart(Species ~ Sepal.Width + Sepal.Length, data = iris)

tidy_tree(model, add_interval = T, interval_level = .89)
#> # A tibble: 24 x 9
#>    rule              id n.obs depth terminal estimate conf.low conf.high y.level
#>    <chr>          <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl> <chr>  
#>  1 Sepal.Length …     2    52     1 FALSE      0.865   0.765      0.934  setosa 
#>  2 Sepal.Length …     2    52     1 FALSE      0.115   0.0527     0.212  versic…
#>  3 Sepal.Length …     2    52     1 FALSE      0.0192  0.00109    0.0860 virgin…
#>  4 Sepal.Length …     4    45     2 TRUE       0.978   0.901      0.999  setosa 
#>  5 Sepal.Length …     4    45     2 TRUE       0.0222  0.00126    0.0988 versic…
#>  6 Sepal.Length …     4    45     2 TRUE       0       0          0.0624 virgin…
#>  7 Sepal.Length …     5     7     2 TRUE       0.143   0.00805    0.512  setosa 
#>  8 Sepal.Length …     5     7     2 TRUE       0.714   0.349      0.944  versic…
#>  9 Sepal.Length …     5     7     2 TRUE       0.143   0.00805    0.512  virgin…
#> 10 Sepal.Length …     3    98     1 FALSE      0.0510  0.0209     0.103  setosa 
#> # … with 14 more rows

The default intervals are based on the normal approximation for continuous values and on binom.test() for discrete ones. But the estimation function is pluggable, so users can provide their own.

library(tidytrees)
library(dplyr)
library(rpart)

# Quantile intervals for continuous outcomes
model <- rpart(Sepal.Width ~ Species + Sepal.Length, data = iris)

tidy_tree(model, add_interval = T, est_fun = function(values, add_interval, interval_level) {
    data.frame(
        estimate = median(values),
        conf.low = quantile(values, (1 - interval_level) / 2),
        conf.high = quantile(values, .5 + interval_level/2)
    )
})
#> # A tibble: 8 x 8
#>   rule                       id n.obs depth terminal estimate conf.low conf.high
#>   <chr>                   <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl>
#> 1 Species = versicolor,v…     2   100     1 FALSE        2.9      2.2       3.50
#> 2 Species = versicolor,v…     4    58     2 FALSE        2.75     2.2       3.4 
#> 3 Species = versicolor,v…     8    12     3 TRUE         2.45     2.08      2.92
#> 4 Species = versicolor,v…     9    46     3 TRUE         2.8      2.2       3.4 
#> 5 Species = versicolor,v…     5    42     2 TRUE         3        2.60      3.80
#> 6 Species = setosa            3    50     1 FALSE        3.4      2.92      4.18
#> 7 Species = setosa & Sep…     6    28     2 TRUE         3.2      2.70      3.6 
#> 8 Species = setosa & Sep…     7    22     2 TRUE         3.7      3.35      4.30

# Bayesian regularized credibility intervals for discrete outcomes
model <- rpart(Species ~ Sepal.Width + Sepal.Length, data = iris)

tidy_tree(model, add_interval = T, est_fun = function(values, add_interval, interval_level) {
    table(values) %>%
        lapply(function(cases) {
            qbeta(
                c(.5, (1 - interval_level) / 2, .5 + interval_level/2),
                cases + 1.1,
                length(values) - cases + 1.1
            ) %>% matrix(nrow = 1) %>% as.data.frame() %>% 
                setNames(c('estimate', 'cred.low', 'cred.high'))
        }) %>% bind_rows()
})
#> # A tibble: 24 x 8
#>    rule                      id n.obs depth terminal estimate cred.low cred.high
#>    <chr>                  <dbl> <int> <dbl> <lgl>       <dbl>    <dbl>     <dbl>
#>  1 Sepal.Length < 5.45        2    52     1 FALSE      0.855  0.745       0.931 
#>  2 Sepal.Length < 5.45        2    52     1 FALSE      0.126  0.0558      0.232 
#>  3 Sepal.Length < 5.45        2    52     1 FALSE      0.0332 0.00519     0.103 
#>  4 Sepal.Length < 5.45 &…     4    45     2 TRUE       0.962  0.882       0.994 
#>  5 Sepal.Length < 5.45 &…     4    45     2 TRUE       0.0382 0.00599     0.118 
#>  6 Sepal.Length < 5.45 &…     4    45     2 TRUE       0.0170 0.000803    0.0810
#>  7 Sepal.Length < 5.45 &…     5     7     2 TRUE       0.208  0.0353      0.531 
#>  8 Sepal.Length < 5.45 &…     5     7     2 TRUE       0.675  0.349       0.911 
#>  9 Sepal.Length < 5.45 &…     5     7     2 TRUE       0.208  0.0353      0.531 
#> 10 Sepal.Length >= 5.45       3    98     1 FALSE      0.0580 0.0231      0.115 
#> # … with 14 more rows