From 88c50a85cd6f900213d0855c6e69243aad9ff0cb Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 10:47:47 +0000 Subject: [PATCH 1/7] Pass explainer weights to loss_after_permutation --- .../dalex/model_explanations/_variable_importance/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 2090944b9..587a69363 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -18,7 +18,7 @@ def calculate_variable_importance(explainer, if processes == 1: result = [None] * B for i in range(B): - result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.model, explainer.predict_function, + result[i] = loss_after_permutation(explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function, loss_function, variables, N, np.random) else: # Create number generator for each iteration @@ -26,7 +26,7 @@ def calculate_variable_importance(explainer, generators = [default_rng(s) for s in ss.spawn(B)] pool = mp.get_context('spawn').Pool(processes) result = pool.starmap_async(loss_after_permutation, [ - (explainer.data, explainer.y, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for + (explainer.data, explainer.y, explainer.weights, explainer.model, explainer.predict_function, loss_function, variables, N, generators[i]) for i in range(B)]).get() pool.close() @@ -49,7 +49,7 @@ def calculate_variable_importance(explainer, return result, raw_permutations -def loss_after_permutation(data, y, model, predict, loss_function, variables, N, rng): +def loss_after_permutation(data, y, weights, model, predict, loss_function, variables, N, rng): if isinstance(N, int): N = min(N, data.shape[0]) sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False) From 5348a5146a209a56e6fe82cb6bcc4ee1e443576e Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 10:49:29 +0000 Subject: [PATCH 2/7] Define weights for sampled data --- .../dalex/model_explanations/_variable_importance/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 587a69363..0f8de5ad5 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -55,9 +55,11 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari sampled_rows = rng.choice(np.arange(data.shape[0]), N, replace=False) sampled_data = data.iloc[sampled_rows, :] observed = y[sampled_rows] + sample_weights = weights[sampled_rows] if weights is not None else None else: sampled_data = data observed = y + sample_weights = weights # loss on the full model or when outcomes are permuted loss_full = loss_function(observed, predict(model, sampled_data)) From 5fc4b91f27de368b6329e211dc1533b5c095225b Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 10:50:41 +0000 Subject: [PATCH 3/7] Function to handle loss functions with or without sample_weight arg --- .../_variable_importance/utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 0f8de5ad5..47a72d7ce 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -82,3 +82,18 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari loss_features['_baseline_'] = loss_baseline return pd.DataFrame(loss_features, index=[0]) + + +def calculate_loss(loss_function, observed, predicted, sample_weights=None): + # Determine if loss function accepts 'sample_weight' + loss_args = inspect.signature(loss_function).parameters + supports_weight = "sample_weight" in loss_args + + if supports_weight: + return loss_function(observed, predicted, sample_weight=sample_weights) + else: + if sample_weights: + warnings.warn( + f"Loss function {loss_function.__name__} does not take sample weights. Calculating unweighted loss." + ) + return loss_function(observed, predicted) From f15ddbe3ee06b6dd191b6808b33739bb8ca9cc11 Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 10:52:03 +0000 Subject: [PATCH 4/7] Replace loss function calls with wrapper --- .../dalex/model_explanations/_variable_importance/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 47a72d7ce..623fd093f 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -62,10 +62,11 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari sample_weights = weights # loss on the full model or when outcomes are permuted - loss_full = loss_function(observed, predict(model, sampled_data)) + loss_full = calculate_loss(loss_function, observed, predict(model, sampled_data), sample_weights) sampled_rows2 = rng.choice(range(observed.shape[0]), observed.shape[0], replace=False) - loss_baseline = loss_function(observed[sampled_rows2], predict(model, sampled_data)) + sample_weights_rows2 = sample_weights[sampled_rows2] if sample_weights is not None else None + loss_baseline = calculate_loss(loss_function, observed[sampled_rows2], predict(model, sampled_data), sample_weights_rows2) loss_features = {} for variables_set_key in variables: @@ -76,7 +77,7 @@ def loss_after_permutation(data, y, weights, model, predict, loss_function, vari predicted = predict(model, ndf) - loss_features[variables_set_key] = loss_function(observed, predicted) + loss_features[variables_set_key] = calculate_loss(loss_function, observed, predicted, sample_weights) loss_features['_full_model_'] = loss_full loss_features['_baseline_'] = loss_baseline From 1497a42196eeaa00e7b2c18037b378066a6d69c6 Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 10:52:54 +0000 Subject: [PATCH 5/7] Add imports --- .../dalex/model_explanations/_variable_importance/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 623fd093f..2a147e280 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -1,4 +1,6 @@ +import inspect import multiprocessing as mp +import warnings from numpy.random import SeedSequence, default_rng import numpy as np From 9cb1401b8edec80a63e5344c797739ab97f44be7 Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 11:17:15 +0000 Subject: [PATCH 6/7] Avoid ambiguous truth values --- .../dalex/model_explanations/_variable_importance/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index 2a147e280..f92431a6f 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -95,7 +95,7 @@ def calculate_loss(loss_function, observed, predicted, sample_weights=None): if supports_weight: return loss_function(observed, predicted, sample_weight=sample_weights) else: - if sample_weights: + if sample_weights is not None: warnings.warn( f"Loss function {loss_function.__name__} does not take sample weights. Calculating unweighted loss." ) From 5af346222bce700002620fcd79f4a00fbe2e2b28 Mon Sep 17 00:00:00 2001 From: danielarifmurphy <48385579+danielarifmurphy@users.noreply.github.com> Date: Wed, 1 May 2024 11:30:13 +0000 Subject: [PATCH 7/7] More explicit warning if weights passed but not used in loss calc --- .../dalex/model_explanations/_variable_importance/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dalex/dalex/model_explanations/_variable_importance/utils.py b/python/dalex/dalex/model_explanations/_variable_importance/utils.py index f92431a6f..9d08f9bd1 100644 --- a/python/dalex/dalex/model_explanations/_variable_importance/utils.py +++ b/python/dalex/dalex/model_explanations/_variable_importance/utils.py @@ -97,6 +97,6 @@ def calculate_loss(loss_function, observed, predicted, sample_weights=None): else: if sample_weights is not None: warnings.warn( - f"Loss function {loss_function.__name__} does not take sample weights. Calculating unweighted loss." + f"Loss function `{loss_function.__name__}` does not have `sample_weight` argument. Calculating unweighted loss." ) return loss_function(observed, predicted)