Skip to content

Commit

Permalink
ray support added and used as a default (instead of joblib)
Browse files Browse the repository at this point in the history
  • Loading branch information
iwan-tee committed Apr 16, 2024
1 parent e431803 commit bf64bfe
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
38 changes: 34 additions & 4 deletions hiclass/Explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
else:
shap_installed = True

try:
import ray
except ImportError:
_has_ray = False
else:
_has_ray = True


class Explainer:
"""Explainer class for returning shap values for each of the three hierarchical classifiers."""
Expand Down Expand Up @@ -125,7 +132,7 @@ def explain(self, X):
else:
raise ValueError(f"Invalid model: {self.hierarchical_model}.")

def _explain_with_xr(self, X, use_joblib: bool = False):
    """
    Generate SHAP values for each node using the SHAP package.

    Parameters
    ----------
    X : array-like
        Samples to explain. Each element of ``X`` is reshaped to a single
        row (``sample.reshape(1, -1)``) before its SHAP values are computed.
    use_joblib : bool, default=False
        If True, parallelize with joblib's threading backend even when Ray
        is installed. Has no effect when only one worker is requested.

    Returns
    -------
    explanation : xarray.Dataset
        An xarray Dataset consisting of SHAP values for each sample.
    """
    # Function-scope import: joblib is already a dependency of this module
    # (Parallel / delayed are used below).
    from joblib import effective_n_jobs

    # Resolve joblib-style requests (-1 == all cores, -2 == all but one, ...)
    # to a concrete worker count. A bare ``self.n_jobs > 1`` check would send
    # n_jobs=-1 down the sequential path and raise TypeError on None.
    requested_jobs = self.n_jobs
    n_workers = 1 if not requested_jobs else effective_n_jobs(requested_jobs)

    rows = (sample.reshape(1, -1) for sample in X)

    if n_workers <= 1:
        # Sequential path: no parallel backend overhead.
        explanations = [self._calculate_shap_values(row) for row in rows]
    elif _has_ray and not use_joblib:
        if not ray.is_initialized():
            ray.init(num_cpus=n_workers)

        calculate_shap_values_remote = ray.remote(calculate_shap_values_wrapper)

        # Store the explainer in Ray's object store once; passing ``self``
        # by value would serialize it again for every single sample.
        explainer_ref = ray.put(self)

        tasks = [
            calculate_shap_values_remote.remote(explainer_ref, row)
            for row in rows
        ]
        explanations = ray.get(tasks)
    else:
        explanations = Parallel(n_jobs=n_workers, backend="threading")(
            delayed(self._calculate_shap_values)(row) for row in rows
        )

    dataset = xr.concat(explanations, dim="sample")
    return dataset

Expand Down Expand Up @@ -590,3 +615,8 @@ def shap_multi_plot(self, class_names, features, pred_class, features_names=None
class_names=class_names,
)
return explanations


# Module-level wrapper so the per-sample SHAP computation can be passed to
# ``ray.remote`` as a plain function.
def calculate_shap_values_wrapper(explainer, sample):
    """
    Compute SHAP values for one sample by delegating to the explainer.

    Parameters
    ----------
    explainer : object
        Object exposing a ``_calculate_shap_values(sample)`` method.
    sample : object
        The sample forwarded unchanged to ``_calculate_shap_values``.

    Returns
    -------
    object
        Whatever ``explainer._calculate_shap_values(sample)`` returns.
    """
    return explainer._calculate_shap_values(sample)
3 changes: 3 additions & 0 deletions tests/test_Explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ def test_shap_multi_plot(data, request, classifier):
explainer = Explainer(clf, data=x_train)

class_names = np.random.choice(predictions[0, :], size=2)
while class_names[0] == "" or class_names[1] == "":
class_names = np.random.choice(predictions[0, :], size=2)

explanations = explainer.shap_multi_plot(
class_names=np.random.choice(predictions[0, :], size=2),
features=x_test,
Expand Down

0 comments on commit bf64bfe

Please sign in to comment.