From 6800095d99d240da982decae93074d92d7da1464 Mon Sep 17 00:00:00 2001 From: Jacob Reinhold <5241441+jcreinhold@users.noreply.github.com> Date: Thu, 11 Apr 2024 12:24:00 -0400 Subject: [PATCH] Fix shap_values compatibility with shap>=0.43.0 by adjusting check_additivity parameter handling for TreeExplainer (#872) Signed-off-by: Jacob Reinhold --- econml/_shap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/econml/_shap.py b/econml/_shap.py index eed548202..3eca0ae61 100644 --- a/econml/_shap.py +++ b/econml/_shap.py @@ -12,6 +12,7 @@ """ +import inspect import shap from collections import defaultdict import numpy as np @@ -161,7 +162,7 @@ def _shap_explain_model_cate(cme_model, models, X, d_t, d_y, featurizer=None, fe output_names=output_names_, input_names=input_names_, background_samples=background_samples) - if explainer.__class__.__name__ == "Tree": + if "check_additivity" in inspect.signature(explainer).parameters: shap_out = explainer(F, check_additivity=False) else: shap_out = explainer(F) @@ -340,7 +341,7 @@ def _shap_explain_multitask_model_cate(cme_model, multitask_model_cate, X, d_t, input_names=input_names_, background_samples=background_samples) - if explainer.__class__.__name__ == "Tree": + if "check_additivity" in inspect.signature(explainer).parameters: shap_out = explainer(F, check_additivity=False) else: shap_out = explainer(F)