diff --git a/DESCRIPTION b/DESCRIPTION index be822bd1b..a58e5485f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -2,7 +2,7 @@ Package: PatientLevelPrediction Type: Package Title: Developing patient level prediction using data in the OMOP Common Data Model -Version: 6.3.2 +Version: 6.3.3 Date: 2023-05-15 Authors@R: c( person("Jenna", "Reps", email = "jreps@its.jnj.com", role = c("aut", "cre")), diff --git a/NEWS.md b/NEWS.md index ee19962b5..8c55aceed 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +PatientLevelPrediction 6.3.3 +====================== +- fixed bug introduced with new reticulate update in model saving to json tests + + PatientLevelPrediction 6.3.2 ====================== - fixed bug with database insert if result is incomplete diff --git a/R/HelperFunctions.R b/R/HelperFunctions.R index 63c76c990..e3b939ff6 100644 --- a/R/HelperFunctions.R +++ b/R/HelperFunctions.R @@ -115,7 +115,7 @@ configurePython <- function(envname='PLP', envtype=NULL){ ParallelLogger::logInfo(paste0('Creating virtual conda environment called ', envname)) location <- reticulate::conda_create(envname=envname, packages = "python", conda = "auto") } - packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib', 'sklearn-json') + packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib') ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname)) reticulate::conda_install(envname=envname, packages = packages, forge = TRUE, pip = FALSE, pip_ignore_installed = TRUE, conda = "auto") @@ -128,7 +128,7 @@ configurePython <- function(envname='PLP', envtype=NULL){ ParallelLogger::logInfo(paste0('Creating virtual python environment called ', envname)) location <- reticulate::virtualenv_create(envname=envname) } - packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus','sklearn-json') + packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus') ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname)) reticulate::virtualenv_install(envname=envname, packages = packages, ignore_installed = TRUE) diff --git a/R/SklearnToJson.R b/R/SklearnToJson.R index 95607dd3c..9c1dba1ae 100644 --- a/R/SklearnToJson.R +++ b/R/SklearnToJson.R @@ -54,17 +54,17 @@ sklearnFromJson <- function(path) { with(py$open(path, "r"), as=file, { model <- json$load(fp=file) }) - if (model["meta"] == "decision-tree") { + if (reticulate::py_bool(model["meta"] == "decision-tree")) { model <- deSerializeDecisionTree(model) - } else if (model["meta"] == "rf") { + } else if (reticulate::py_bool(model["meta"] == "rf")) { model <- deSerializeRandomForest(model) - } else if (model["meta"] == "adaboost") { + } else if (reticulate::py_bool(model["meta"] == "adaboost")) { model <- deSerializeAdaboost(model) - } else if (model["meta"] == "naive-bayes") { + } else if (reticulate::py_bool(model["meta"] == "naive-bayes")) { model <- deSerializeNaiveBayes(model) - } else if (model["meta"] == "mlp") { + } else if (reticulate::py_bool(model["meta"] == "mlp")) { model <- deSerializeMlp(model) - } else if (model["meta"] == "svm") { + } else if (reticulate::py_bool(model["meta"] == "svm")) { model <- deSerializeSVM(model) } else { stop("Unsupported model") @@ -181,7 +181,7 @@ serializeRandomForest <- function(model) { "params" = model$get_params(), "n_classes_" = model$n_classes_) - if (model$`__dict__`["oob_score_"] != reticulate::py_none()) { + if (reticulate::py_bool(model$`__dict__`["oob_score_"] != reticulate::py_none())) { serialized_model["oob_score_"] <- model$oob_score_ serialized_model["oob_decision_function_"] <- model$oob_decision_function_$tolist() } @@ -215,7 +215,7 @@ deSerializeRandomForest <- function(model_dict) { model$min_impurity_split <- model_dict["min_impurity_split"] model$n_classes_ <- model_dict["n_classes_"] - if (model_dict$oob_score_ != reticulate::py_none()){ + if (reticulate::py_bool(model_dict$oob_score_ != reticulate::py_none())){ model$oob_score_ <- model_dict["oob_score_"] model$oob_decision_function_ <- model_dict["oob_decision_function_"] } @@ -387,23 +387,23 @@ deSerializeSVM <- function(model_dict) { model$`_probB` <- np$array(model_dict["probB_"])$astype(np$float64) model$`_intercept_` <- np$array(model_dict["_intercept_"])$astype(np$float64) - if ((model_dict$support_vectors_["meta"] != reticulate::py_none()) & - (model_dict$support_vectors_["meta"] == "csr")) { + if (reticulate::py_bool((model_dict$support_vectors_["meta"] != reticulate::py_none())) & + (reticulate::py_bool(model_dict$support_vectors_["meta"] == "csr"))) { model$support_vectors_ <- deSerializeCsrMatrix(model_dict$support_vectors_) model$`_sparse` <- TRUE } else { model$support_vectors_ <- np$array(model_dict$support_vectors_)$astype(np$float64) model$`_sparse` <- FALSE } - if ((model_dict$dual_coef_["meta"] != reticulate::py_none()) & - (model_dict$dual_coef_["meta"] == "csr")) { + if (reticulate::py_bool((model_dict$dual_coef_["meta"] != reticulate::py_none())) & + (reticulate::py_bool(model_dict$dual_coef_["meta"] == "csr"))) { model$dual_coef_ <- deSerializeCsrMatrix(model_dict$dual_coef_) } else { model$dual_coef_ <- np$array(model_dict$dual_coef_)$astype(np$float64) } - if ((model_dict$`_dual_coef_`["meta"] != reticulate::py_none()) & - (model_dict$`_dual_coef_`["meta"] == "csr")) { + if (reticulate::py_bool((model_dict$`_dual_coef_`["meta"] != reticulate::py_none())) & + (reticulate::py_bool(model_dict$`_dual_coef_`["meta"] == "csr"))) { model$`_dual_coef_` <- deSerializeCsrMatrix(model_dict$`dual_coef_`) } else { model$`_dual_coef_` <- np$array(model_dict$`_dual_coef_`)$astype(np$float64)