DALEX shows columns in tables and plots that were removed as a step in a recipe (tidymodels) #253

Closed
jcpsantiago opened this issue Jun 30, 2020 · 3 comments
Labels
R 🐳 Related to R

Comments

@jcpsantiago

DALEX works fine with {tidymodels} when models are fit using the formula interface. However, when a recipe drops some columns, those columns still show up in the plots and tables instead of being excluded from the explanation.

Here's a reproducible example:

library(DALEX)
#> Welcome to DALEX (version: 1.2.1).
#> Find examples and detailed introduction at: https://pbiecek.github.io/ema/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
library(tidymodels)
#> ── Attaching packages ───────────────────────────────────── tidymodels 0.1.0 ──
#> ✓ broom     0.5.6          ✓ recipes   0.1.12    
#> ✓ dials     0.0.7          ✓ rsample   0.0.7     
#> ✓ dplyr     1.0.0          ✓ tibble    3.0.1     
#> ✓ ggplot2   3.3.2          ✓ tune      0.1.0     
#> ✓ infer     0.5.2          ✓ workflows 0.1.1     
#> ✓ parsnip   0.1.1.9000     ✓ yardstick 0.0.6     
#> ✓ purrr     0.3.4
#> ── Conflicts ──────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::explain() masks DALEX::explain()
#> x dplyr::filter()  masks stats::filter()
#> x dplyr::lag()     masks stats::lag()
#> x recipes::step()  masks stats::step()

xgb_model <- boost_tree(mode = "regression") %>%
  set_engine(engine = "xgboost")

xgb_recipe <- recipe(mpg ~ ., mtcars) %>% 
  step_rm(am, gear)

xgb_workflow <- workflows::workflow() %>%
  workflows::add_model(., xgb_model) %>%
  workflows::add_recipe(., xgb_recipe)

fitted <- xgb_workflow %>% 
  fit(data = mtcars)
#> [22:26:43] WARNING: amalgamation/../src/objective/regression_obj.cu:170: reg:linear is now deprecated in favor of reg:squarederror.

expl <- DALEX::explain(
  fitted,
  data = mtcars %>% select(-mpg), 
  y = mtcars$mpg, 
  predict_function = function(x, y){predict(x, new_data = y) %>% pull(.pred)})
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  ( default )
#>   -> data              :  32  rows  10  cols 
#>   -> target variable   :  32  values 
#>   -> model_info        :  package Model of class: workflow package unrecognized , ver. Unknown , task regression ( default ) 
#>   -> predict function  :  function(x, y) {     predict(x, new_data = y) %>% pull(.pred) } 
#>   -> predicted values  :  numerical, min =  10.33316 , mean =  19.83897 , max =  31.99849  
#>   -> residual function :  difference between y and yhat ( default )
#>   -> residuals         :  numerical, min =  -0.3375872 , mean =  0.2516516 , max =  1.901513  
#>  A new explainer has been created!

variable_importance(expl, type = "ratio")
#>        variable mean_dropout_loss    label
#> 1  _full_model_          1.000000 workflow
#> 2            vs          1.000000 workflow
#> 3            am          1.000000 workflow  <<<<<- shouldn't be here
#> 4          gear          1.000000 workflow  <<<<<- shouldn't be here
#> 5          carb          2.576819 workflow
#> 6          drat          3.091860 workflow
#> 7          qsec          6.172457 workflow
#> 8            hp          9.013178 workflow
#> 9            wt         32.111813 workflow
#> 10         disp         41.519480 workflow
#> 11          cyl         43.896437 workflow
#> 12   _baseline_        317.729359 workflow
plot(variable_attribution(expl, new_observation = mtcars[1,], type = "break_down"))

Created on 2020-06-30 by the reprex package (v0.3.0)

Session info
devtools::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.0 (2020-04-24)
#>  os       macOS Catalina 10.15.5      
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2020-06-30                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package       * version    date       lib source                             
#>  assertthat      0.2.1      2019-03-21 [1] CRAN (R 4.0.0)                     
#>  backports       1.1.7      2020-05-13 [1] CRAN (R 4.0.0)                     
#>  base64enc       0.1-3      2015-07-28 [1] CRAN (R 4.0.0)                     
#>  bayesplot       1.7.2      2020-05-28 [1] CRAN (R 4.0.0)                     
#>  boot            1.3-25     2020-04-26 [1] CRAN (R 4.0.0)                     
#>  broom         * 0.5.6      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  callr           3.4.3      2020-03-28 [1] CRAN (R 4.0.0)                     
#>  class           7.3-17     2020-04-26 [1] CRAN (R 4.0.0)                     
#>  cli             2.0.2      2020-02-28 [1] CRAN (R 4.0.0)                     
#>  codetools       0.2-16     2018-12-24 [1] CRAN (R 4.0.0)                     
#>  colorspace      1.4-1      2019-03-18 [1] CRAN (R 4.0.0)                     
#>  colourpicker    1.0        2017-09-27 [1] CRAN (R 4.0.0)                     
#>  crayon          1.3.4      2017-09-16 [1] CRAN (R 4.0.0)                     
#>  crosstalk       1.1.0.1    2020-03-13 [1] CRAN (R 4.0.0)                     
#>  curl            4.3        2019-12-02 [1] CRAN (R 4.0.0)                     
#>  DALEX         * 1.2.1      2020-04-25 [1] CRAN (R 4.0.0)                     
#>  data.table      1.12.8     2019-12-09 [1] CRAN (R 4.0.0)                     
#>  desc            1.2.0      2018-05-01 [1] CRAN (R 4.0.0)                     
#>  devtools        2.3.0      2020-04-10 [1] CRAN (R 4.0.0)                     
#>  dials         * 0.0.7      2020-06-10 [1] CRAN (R 4.0.0)                     
#>  DiceDesign      1.8-1      2019-07-31 [1] CRAN (R 4.0.0)                     
#>  digest          0.6.25     2020-02-23 [1] CRAN (R 4.0.0)                     
#>  dplyr         * 1.0.0      2020-05-29 [1] CRAN (R 4.0.0)                     
#>  DT              0.13       2020-03-23 [1] CRAN (R 4.0.0)                     
#>  dygraphs        1.1.1.6    2018-07-11 [1] CRAN (R 4.0.0)                     
#>  ellipsis        0.3.1      2020-05-15 [1] CRAN (R 4.0.0)                     
#>  evaluate        0.14       2019-05-28 [1] CRAN (R 4.0.0)                     
#>  fansi           0.4.1      2020-01-08 [1] CRAN (R 4.0.0)                     
#>  farver          2.0.3      2020-01-16 [1] CRAN (R 4.0.0)                     
#>  fastmap         1.0.1      2019-10-08 [1] CRAN (R 4.0.0)                     
#>  foreach         1.5.0      2020-03-30 [1] CRAN (R 4.0.0)                     
#>  fs              1.4.1      2020-04-04 [1] CRAN (R 4.0.0)                     
#>  furrr           0.1.0      2018-05-16 [1] CRAN (R 4.0.0)                     
#>  future          1.17.0     2020-04-18 [1] CRAN (R 4.0.0)                     
#>  generics        0.0.2      2018-11-29 [1] CRAN (R 4.0.0)                     
#>  ggplot2       * 3.3.2      2020-06-19 [1] CRAN (R 4.0.0)                     
#>  ggridges        0.5.2      2020-01-12 [1] CRAN (R 4.0.0)                     
#>  globals         0.12.5     2019-12-07 [1] CRAN (R 4.0.0)                     
#>  glue            1.4.1      2020-05-13 [1] CRAN (R 4.0.0)                     
#>  gower           0.2.1      2019-05-14 [1] CRAN (R 4.0.0)                     
#>  GPfit           1.0-8      2019-02-08 [1] CRAN (R 4.0.0)                     
#>  gridExtra       2.3        2017-09-09 [1] CRAN (R 4.0.0)                     
#>  gtable          0.3.0      2019-03-25 [1] CRAN (R 4.0.0)                     
#>  gtools          3.8.2      2020-03-31 [1] CRAN (R 4.0.0)                     
#>  hardhat         0.1.3      2020-05-20 [1] CRAN (R 4.0.0)                     
#>  highr           0.8        2019-03-20 [1] CRAN (R 4.0.0)                     
#>  htmltools       0.4.0      2019-10-04 [1] CRAN (R 4.0.0)                     
#>  htmlwidgets     1.5.1      2019-10-08 [1] CRAN (R 4.0.0)                     
#>  httpuv          1.5.4      2020-06-06 [1] CRAN (R 4.0.0)                     
#>  httr            1.4.1      2019-08-05 [1] CRAN (R 4.0.0)                     
#>  iBreakDown      1.2.0      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  igraph          1.2.5      2020-03-19 [1] CRAN (R 4.0.0)                     
#>  infer         * 0.5.2      2020-06-14 [1] CRAN (R 4.0.0)                     
#>  ingredients     1.2.0      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  inline          0.3.15     2018-05-18 [1] CRAN (R 4.0.0)                     
#>  ipred           0.9-9      2019-04-28 [1] CRAN (R 4.0.0)                     
#>  iterators       1.0.12     2019-07-26 [1] CRAN (R 4.0.0)                     
#>  janeaustenr     0.1.5      2017-06-10 [1] CRAN (R 4.0.0)                     
#>  knitr           1.28       2020-02-06 [1] CRAN (R 4.0.0)                     
#>  labeling        0.3        2014-08-23 [1] CRAN (R 4.0.0)                     
#>  later           1.1.0.1    2020-06-05 [1] CRAN (R 4.0.0)                     
#>  lattice         0.20-41    2020-04-02 [1] CRAN (R 4.0.0)                     
#>  lava            1.6.7      2020-03-05 [1] CRAN (R 4.0.0)                     
#>  lhs             1.0.2      2020-04-13 [1] CRAN (R 4.0.0)                     
#>  lifecycle       0.2.0      2020-03-06 [1] CRAN (R 4.0.0)                     
#>  listenv         0.8.0      2019-12-05 [1] CRAN (R 4.0.0)                     
#>  lme4            1.1-23     2020-04-07 [1] CRAN (R 4.0.0)                     
#>  loo             2.2.0      2019-12-19 [1] CRAN (R 4.0.0)                     
#>  lubridate       1.7.9      2020-06-08 [1] CRAN (R 4.0.0)                     
#>  magrittr        1.5        2014-11-22 [1] CRAN (R 4.0.0)                     
#>  markdown        1.1        2019-08-07 [1] CRAN (R 4.0.0)                     
#>  MASS            7.3-51.6   2020-04-26 [1] CRAN (R 4.0.0)                     
#>  Matrix          1.2-18     2019-11-27 [1] CRAN (R 4.0.0)                     
#>  matrixStats     0.56.0     2020-03-13 [1] CRAN (R 4.0.0)                     
#>  memoise         1.1.0      2017-04-21 [1] CRAN (R 4.0.0)                     
#>  mime            0.9        2020-02-04 [1] CRAN (R 4.0.0)                     
#>  miniUI          0.1.1.1    2018-05-18 [1] CRAN (R 4.0.0)                     
#>  minqa           1.2.4      2014-10-09 [1] CRAN (R 4.0.0)                     
#>  munsell         0.5.0      2018-06-12 [1] CRAN (R 4.0.0)                     
#>  nlme            3.1-148    2020-05-24 [1] CRAN (R 4.0.0)                     
#>  nloptr          1.2.2.1    2020-03-11 [1] CRAN (R 4.0.0)                     
#>  nnet            7.3-14     2020-04-26 [1] CRAN (R 4.0.0)                     
#>  parsnip       * 0.1.1.9000 2020-06-19 [1] Github (tidymodels/parsnip@3671e19)
#>  pillar          1.4.4      2020-05-05 [1] CRAN (R 4.0.0)                     
#>  pkgbuild        1.0.8      2020-05-07 [1] CRAN (R 4.0.0)                     
#>  pkgconfig       2.0.3      2019-09-22 [1] CRAN (R 4.0.0)                     
#>  pkgload         1.1.0      2020-05-29 [1] CRAN (R 4.0.0)                     
#>  plyr            1.8.6      2020-03-03 [1] CRAN (R 4.0.0)                     
#>  prettyunits     1.1.1      2020-01-24 [1] CRAN (R 4.0.0)                     
#>  pROC            1.16.2     2020-03-19 [1] CRAN (R 4.0.0)                     
#>  processx        3.4.2      2020-02-09 [1] CRAN (R 4.0.0)                     
#>  prodlim         2019.11.13 2019-11-17 [1] CRAN (R 4.0.0)                     
#>  promises        1.1.1      2020-06-09 [1] CRAN (R 4.0.0)                     
#>  ps              1.3.3      2020-05-08 [1] CRAN (R 4.0.0)                     
#>  purrr         * 0.3.4      2020-04-17 [1] CRAN (R 4.0.0)                     
#>  R6              2.4.1      2019-11-12 [1] CRAN (R 4.0.0)                     
#>  Rcpp            1.0.4.6    2020-04-09 [1] CRAN (R 4.0.0)                     
#>  RcppParallel    5.0.1      2020-05-06 [1] CRAN (R 4.0.0)                     
#>  recipes       * 0.1.12     2020-05-01 [1] CRAN (R 4.0.0)                     
#>  remotes         2.1.1      2020-02-15 [1] CRAN (R 4.0.0)                     
#>  reshape2        1.4.4      2020-04-09 [1] CRAN (R 4.0.0)                     
#>  rlang           0.4.6      2020-05-02 [1] CRAN (R 4.0.0)                     
#>  rmarkdown       2.3        2020-06-18 [1] CRAN (R 4.0.0)                     
#>  rpart           4.1-15     2019-04-12 [1] CRAN (R 4.0.0)                     
#>  rprojroot       1.3-2      2018-01-03 [1] CRAN (R 4.0.0)                     
#>  rsample       * 0.0.7      2020-06-04 [1] CRAN (R 4.0.0)                     
#>  rsconnect       0.8.16     2019-12-13 [1] CRAN (R 4.0.0)                     
#>  rstan           2.19.3     2020-02-11 [1] CRAN (R 4.0.0)                     
#>  rstanarm        2.19.3     2020-02-11 [1] CRAN (R 4.0.0)                     
#>  rstantools      2.0.0      2019-09-15 [1] CRAN (R 4.0.0)                     
#>  rstudioapi      0.11       2020-02-07 [1] CRAN (R 4.0.0)                     
#>  scales        * 1.1.1      2020-05-11 [1] CRAN (R 4.0.0)                     
#>  sessioninfo     1.1.1      2018-11-05 [1] CRAN (R 4.0.0)                     
#>  shiny           1.4.0.2    2020-03-13 [1] CRAN (R 4.0.0)                     
#>  shinyjs         1.1        2020-01-13 [1] CRAN (R 4.0.0)                     
#>  shinystan       2.5.0      2018-05-01 [1] CRAN (R 4.0.0)                     
#>  shinythemes     1.1.2      2018-11-06 [1] CRAN (R 4.0.0)                     
#>  SnowballC       0.7.0      2020-04-01 [1] CRAN (R 4.0.0)                     
#>  StanHeaders     2.21.0-5   2020-06-09 [1] CRAN (R 4.0.0)                     
#>  statmod         1.4.34     2020-02-17 [1] CRAN (R 4.0.0)                     
#>  stringi         1.4.6      2020-02-17 [1] CRAN (R 4.0.0)                     
#>  stringr         1.4.0      2019-02-10 [1] CRAN (R 4.0.0)                     
#>  survival        3.2-3      2020-06-13 [1] CRAN (R 4.0.0)                     
#>  testthat        2.3.2      2020-03-02 [1] CRAN (R 4.0.0)                     
#>  threejs         0.3.3      2020-01-21 [1] CRAN (R 4.0.0)                     
#>  tibble        * 3.0.1      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  tidymodels    * 0.1.0      2020-02-16 [1] CRAN (R 4.0.0)                     
#>  tidyposterior   0.0.3      2020-06-11 [1] CRAN (R 4.0.0)                     
#>  tidypredict     0.4.5      2020-02-10 [1] CRAN (R 4.0.0)                     
#>  tidyr           1.1.0      2020-05-20 [1] CRAN (R 4.0.0)                     
#>  tidyselect      1.1.0      2020-05-11 [1] CRAN (R 4.0.0)                     
#>  tidytext        0.2.4      2020-04-17 [1] CRAN (R 4.0.0)                     
#>  timeDate        3043.102   2018-02-21 [1] CRAN (R 4.0.0)                     
#>  tokenizers      0.2.1      2018-03-29 [1] CRAN (R 4.0.0)                     
#>  tune          * 0.1.0      2020-04-02 [1] CRAN (R 4.0.0)                     
#>  usethis         1.6.1      2020-04-29 [1] CRAN (R 4.0.0)                     
#>  vctrs           0.3.1      2020-06-05 [1] CRAN (R 4.0.0)                     
#>  withr           2.2.0      2020-04-20 [1] CRAN (R 4.0.0)                     
#>  workflows     * 0.1.1      2020-03-17 [1] CRAN (R 4.0.0)                     
#>  xfun            0.14       2020-05-20 [1] CRAN (R 4.0.0)                     
#>  xgboost         1.1.1.1    2020-06-14 [1] CRAN (R 4.0.0)                     
#>  xml2            1.3.2      2020-04-23 [1] CRAN (R 4.0.0)                     
#>  xtable          1.8-4      2019-04-21 [1] CRAN (R 4.0.0)                     
#>  xts             0.12-0     2020-01-19 [1] CRAN (R 4.0.0)                     
#>  yaml            2.2.1      2020-02-01 [1] CRAN (R 4.0.0)                     
#>  yardstick     * 0.0.6      2020-03-17 [1] CRAN (R 4.0.0)                     
#>  zoo             1.8-8      2020-05-02 [1] CRAN (R 4.0.0)                     
#> 
#> [1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library

I understand DALEX might not be targeting compatibility with the full {tidymodels} workflow, but maybe there's a way to make this work? I haven't looked at the source code yet, but I assume explain() uses all the columns of the supplied data, which are not the same as the columns left after the recipe is applied at prediction time. Simply removing the columns from the data won't work, because the recipe expects them to be there.
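A minimal sketch of how the retained columns could be inspected, assuming the recipes functions prep() and summary() can be used this way (an illustration, not something confirmed in this thread). It preps the same recipe on mtcars and compares the predictors that survive with the columns passed to explain(). The names `kept` and `dropped` are made up for the example.

```r
library(recipes)

# Same recipe as in the reprex: drop am and gear before modelling
xgb_recipe <- recipe(mpg ~ ., data = mtcars) %>%
  step_rm(am, gear)

# Estimate the recipe so that its variable summary reflects the applied steps
prepped <- prep(xgb_recipe, training = mtcars)

# summary() on a prepped recipe lists the remaining variables and their roles
info <- summary(prepped)
kept <- info$variable[info$role == "predictor"]

# Columns handed to explain() that the model never receives
dropped <- setdiff(setdiff(names(mtcars), "mpg"), kept)
dropped
```

The recipe still needs am and gear in the input data, so `dropped` would only be useful for filtering the explanation output, not the data itself.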

@pbiecek
Member

pbiecek commented Jul 2, 2020

Thanks.
DALEX treats models as black boxes, so it doesn't check which variables the model uses. When calculating importance or attribution, it checks every column of the data passed to the explain() function.

If you know how to check which variables are used by a model created with tidymodels, I can add support for such models.
But for now the tidymodels API is changing very quickly, and I haven't found any information on how to extract the data used for model training from a tidymodels model.

So the easiest (and universal) solution is to filter out variables with 'zero' contributions.
That way they will not appear in the plots. The easiest way to do this is with the filter function (the explanations are plain data frames).

For example

for variable importance

variable_importance(expl, type = "ratio") %>%
  filter((dropout_loss > 1) | (variable == "_full_model_")) %>%
  plot()

and for variable attribution

variable_attribution(expl, new_observation = mtcars[1,], type = "break_down") %>%
  filter(contribution != 0) %>%
  mutate(position = as.numeric(factor(position))) %>%
  plot()
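To see what the importance filter does without fitting anything, here is a small base R sketch that rebuilds the printed table from the issue by hand and applies the same condition. The values are the rounded ones shown above, so this is only an illustration. The column is named mean_dropout_loss as in the printed output, and the data frame name `vi` is made up.

```r
# Printed variable importance table from the issue, typed in by hand
vi <- data.frame(
  variable = c("_full_model_", "vs", "am", "gear", "carb", "drat",
               "qsec", "hp", "wt", "disp", "cyl", "_baseline_"),
  mean_dropout_loss = c(1, 1, 1, 1, 2.576819, 3.091860,
                        6.172457, 9.013178, 32.111813, 41.519480,
                        43.896437, 317.729359),
  stringsAsFactors = FALSE
)

# Keep variables whose permutation changes the loss, plus the reference row
keep <- vi$mean_dropout_loss > 1 | vi$variable == "_full_model_"
vi_filtered <- vi[keep, ]

# Variables that disappear from the table
dropped_vars <- vi$variable[!keep]
dropped_vars
```

With these rounded values, vs is filtered out together with am and gear even though the recipe did not remove it. A variable the model received but did not use cannot be told apart from one that was dropped by the recipe.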

Hope it helps.

@pbiecek pbiecek added the R 🐳 Related to R label Jul 8, 2020
@hbaniecki
Member

moved to #265

@jacekkotowski

Can this issue be related? ModelOriented/rSAFE#10

4 participants