Skip to content

Commit

Permalink
dict output format
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju committed Oct 22, 2024
1 parent 07aeedc commit b152ea3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 276 deletions.
271 changes: 0 additions & 271 deletions inst/jss_paper/code_sec_3.R

This file was deleted.

20 changes: 15 additions & 5 deletions python/shaprpy/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ def explain(
routput.rx2['timing'] = shapr.compute_time(rinternal)

# Some cleanup when doing testing
testing = rinternal.rx2('parameters').rx2('testing')[0]
if base.isTRUE(testing):
routput = shapr.testing_cleanup(routput)
#testing = rinternal.rx2('parameters').rx2('testing')[0]
#if base.isTRUE(testing):
# routput = shapr.testing_cleanup(routput)

# Convert R objects to Python objects
shapley_values_est = r2py(base.as_data_frame(routput.rx2('shapley_values_est')))
Expand All @@ -288,8 +288,18 @@ def explain(
#saving_path = StrVector(routput.rx2['saving_path']) # NOt sure why this is not working
saving_path = StrVector(rinternal.rx2['parameters'].rx2['output_args'].rx2['saving_path'])[0]
#internal = recurse_r_tree(routput.rx2('rinternal')) # Currently get an error with NULL elements here

return shapley_values_est, shapley_values_sd, pred_explain, MSEv, iterative_results, saving_path, rinternal
rtiming = routput.rx2['timing']

return {
"shapley_values_est": shapley_values_est,
"shapley_values_sd": shapley_values_sd,
"pred_explain": pred_explain,
"MSEv": MSEv,
"iterative_results": iterative_results,
"saving_path": saving_path,
"internal": rinternal,
"timing": rtiming
}


def compute_vS(rinternal, model, predict_model):
Expand Down

0 comments on commit b152ea3

Please sign in to comment.