Skip to content

Commit

Permalink
Merge pull request #398 from OHDSI/reticulate-fix
Browse files Browse the repository at this point in the history
wrap comparisons with python objects in py_bool
  • Loading branch information
jreps authored Jun 22, 2023
2 parents 603f9cc + ea01f2e commit bcfdda7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions R/HelperFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ configurePython <- function(envname='PLP', envtype=NULL){
ParallelLogger::logInfo(paste0('Creating virtual conda environment called ', envname))
location <- reticulate::conda_create(envname=envname, packages = "python", conda = "auto")
}
packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib', 'sklearn-json')
packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib')
ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname))
reticulate::conda_install(envname=envname, packages = packages, forge = TRUE, pip = FALSE,
pip_ignore_installed = TRUE, conda = "auto")
Expand All @@ -128,7 +128,7 @@ configurePython <- function(envname='PLP', envtype=NULL){
ParallelLogger::logInfo(paste0('Creating virtual python environment called ', envname))
location <- reticulate::virtualenv_create(envname=envname)
}
packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus','sklearn-json')
packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus')
ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname))
reticulate::virtualenv_install(envname=envname, packages = packages,
ignore_installed = TRUE)
Expand Down
28 changes: 14 additions & 14 deletions R/SklearnToJson.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ sklearnFromJson <- function(path) {
with(py$open(path, "r"), as=file, {
model <- json$load(fp=file)
})
if (model["meta"] == "decision-tree") {
if (reticulate::py_bool(model["meta"] == "decision-tree")) {
model <- deSerializeDecisionTree(model)
} else if (model["meta"] == "rf") {
} else if (reticulate::py_bool(model["meta"] == "rf")) {
model <- deSerializeRandomForest(model)
} else if (model["meta"] == "adaboost") {
} else if (reticulate::py_bool(model["meta"] == "adaboost")) {
model <- deSerializeAdaboost(model)
} else if (model["meta"] == "naive-bayes") {
} else if (reticulate::py_bool(model["meta"] == "naive-bayes")) {
model <- deSerializeNaiveBayes(model)
} else if (model["meta"] == "mlp") {
} else if (reticulate::py_bool(model["meta"] == "mlp")) {
model <- deSerializeMlp(model)
} else if (model["meta"] == "svm") {
} else if (reticulate::py_bool(model["meta"] == "svm")) {
model <- deSerializeSVM(model)
} else {
stop("Unsupported model")
Expand Down Expand Up @@ -181,7 +181,7 @@ serializeRandomForest <- function(model) {
"params" = model$get_params(),
"n_classes_" = model$n_classes_)

if (model$`__dict__`["oob_score_"] != reticulate::py_none()) {
if (reticulate::py_bool(model$`__dict__`["oob_score_"] != reticulate::py_none())) {
serialized_model["oob_score_"] <- model$oob_score_
serialized_model["oob_decision_function_"] <- model$oob_decision_function_$tolist()
}
Expand Down Expand Up @@ -215,7 +215,7 @@ deSerializeRandomForest <- function(model_dict) {
model$min_impurity_split <- model_dict["min_impurity_split"]
model$n_classes_ <- model_dict["n_classes_"]

if (model_dict$oob_score_ != reticulate::py_none()){
if (reticulate::py_bool(model_dict$oob_score_ != reticulate::py_none())){
model$oob_score_ <- model_dict["oob_score_"]
model$oob_decision_function_ <- model_dict["oob_decision_function_"]
}
Expand Down Expand Up @@ -387,23 +387,23 @@ deSerializeSVM <- function(model_dict) {
model$`_probB` <- np$array(model_dict["probB_"])$astype(np$float64)
model$`_intercept_` <- np$array(model_dict["_intercept_"])$astype(np$float64)

if ((model_dict$support_vectors_["meta"] != reticulate::py_none()) &
(model_dict$support_vectors_["meta"] == "csr")) {
if (reticulate::py_bool((model_dict$support_vectors_["meta"] != reticulate::py_none())) &
(reticulate::py_bool(model_dict$support_vectors_["meta"] == "csr"))) {
model$support_vectors_ <- deSerializeCsrMatrix(model_dict$support_vectors_)
model$`_sparse` <- TRUE
} else {
model$support_vectors_ <- np$array(model_dict$support_vectors_)$astype(np$float64)
model$`_sparse` <- FALSE
}
if ((model_dict$dual_coef_["meta"] != reticulate::py_none()) &
(model_dict$dual_coef_["meta"] == "csr")) {
if (reticulate::py_bool((model_dict$dual_coef_["meta"] != reticulate::py_none())) &
(reticulate::py_bool(model_dict$dual_coef_["meta"] == "csr"))) {
model$dual_coef_ <- deSerializeCsrMatrix(model_dict$dual_coef_)
} else {
model$dual_coef_ <- np$array(model_dict$dual_coef_)$astype(np$float64)
}

if ((model_dict$`_dual_coef_`["meta"] != reticulate::py_none()) &
(model_dict$`_dual_coef_`["meta"] == "csr")) {
if (reticulate::py_bool((model_dict$`_dual_coef_`["meta"] != reticulate::py_none())) &
(reticulate::py_bool(model_dict$`_dual_coef_`["meta"] == "csr"))) {
model$`_dual_coef_` <- deSerializeCsrMatrix(model_dict$`dual_coef_`)
} else {
model$`_dual_coef_` <- np$array(model_dict$`_dual_coef_`)$astype(np$float64)
Expand Down

0 comments on commit bcfdda7

Please sign in to comment.