Skip to content

Commit

Permalink
Merge pull request #401 from OHDSI/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
jreps authored Jun 23, 2023
2 parents f542d0c + 76cbe81 commit f5942b6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 17 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: PatientLevelPrediction
Type: Package
Title: Developing patient level prediction using data in the OMOP Common Data
Model
Version: 6.3.2
Version: 6.3.3
Date: 2023-05-15
Authors@R: c(
person("Jenna", "Reps", email = "[email protected]", role = c("aut", "cre")),
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
PatientLevelPrediction 6.3.3
======================
- fixed bug introduced with new reticulate update in model saving to json tests


PatientLevelPrediction 6.3.2
======================
- fixed bug with database insert if result is incomplete
Expand Down
4 changes: 2 additions & 2 deletions R/HelperFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ configurePython <- function(envname='PLP', envtype=NULL){
ParallelLogger::logInfo(paste0('Creating virtual conda environment called ', envname))
location <- reticulate::conda_create(envname=envname, packages = "python", conda = "auto")
}
packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib', 'sklearn-json')
packages <- c('numpy','scipy','scikit-learn', 'pandas','pydotplus','joblib')
ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname))
reticulate::conda_install(envname=envname, packages = packages, forge = TRUE, pip = FALSE,
pip_ignore_installed = TRUE, conda = "auto")
Expand All @@ -128,7 +128,7 @@ configurePython <- function(envname='PLP', envtype=NULL){
ParallelLogger::logInfo(paste0('Creating virtual python environment called ', envname))
location <- reticulate::virtualenv_create(envname=envname)
}
packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus','sklearn-json')
packages <- c('numpy', 'scikit-learn','scipy', 'pandas','pydotplus')
ParallelLogger::logInfo(paste0('Adding python dependancies to ', envname))
reticulate::virtualenv_install(envname=envname, packages = packages,
ignore_installed = TRUE)
Expand Down
28 changes: 14 additions & 14 deletions R/SklearnToJson.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,17 @@ sklearnFromJson <- function(path) {
with(py$open(path, "r"), as=file, {
model <- json$load(fp=file)
})
if (model["meta"] == "decision-tree") {
if (reticulate::py_bool(model["meta"] == "decision-tree")) {
model <- deSerializeDecisionTree(model)
} else if (model["meta"] == "rf") {
} else if (reticulate::py_bool(model["meta"] == "rf")) {
model <- deSerializeRandomForest(model)
} else if (model["meta"] == "adaboost") {
} else if (reticulate::py_bool(model["meta"] == "adaboost")) {
model <- deSerializeAdaboost(model)
} else if (model["meta"] == "naive-bayes") {
} else if (reticulate::py_bool(model["meta"] == "naive-bayes")) {
model <- deSerializeNaiveBayes(model)
} else if (model["meta"] == "mlp") {
} else if (reticulate::py_bool(model["meta"] == "mlp")) {
model <- deSerializeMlp(model)
} else if (model["meta"] == "svm") {
} else if (reticulate::py_bool(model["meta"] == "svm")) {
model <- deSerializeSVM(model)
} else {
stop("Unsupported model")
Expand Down Expand Up @@ -181,7 +181,7 @@ serializeRandomForest <- function(model) {
"params" = model$get_params(),
"n_classes_" = model$n_classes_)

if (model$`__dict__`["oob_score_"] != reticulate::py_none()) {
if (reticulate::py_bool(model$`__dict__`["oob_score_"] != reticulate::py_none())) {
serialized_model["oob_score_"] <- model$oob_score_
serialized_model["oob_decision_function_"] <- model$oob_decision_function_$tolist()
}
Expand Down Expand Up @@ -215,7 +215,7 @@ deSerializeRandomForest <- function(model_dict) {
model$min_impurity_split <- model_dict["min_impurity_split"]
model$n_classes_ <- model_dict["n_classes_"]

if (model_dict$oob_score_ != reticulate::py_none()){
if (reticulate::py_bool(model_dict$oob_score_ != reticulate::py_none())){
model$oob_score_ <- model_dict["oob_score_"]
model$oob_decision_function_ <- model_dict["oob_decision_function_"]
}
Expand Down Expand Up @@ -387,23 +387,23 @@ deSerializeSVM <- function(model_dict) {
model$`_probB` <- np$array(model_dict["probB_"])$astype(np$float64)
model$`_intercept_` <- np$array(model_dict["_intercept_"])$astype(np$float64)

if ((model_dict$support_vectors_["meta"] != reticulate::py_none()) &
(model_dict$support_vectors_["meta"] == "csr")) {
if (reticulate::py_bool((model_dict$support_vectors_["meta"] != reticulate::py_none())) &
(reticulate::py_bool(model_dict$support_vectors_["meta"] == "csr"))) {
model$support_vectors_ <- deSerializeCsrMatrix(model_dict$support_vectors_)
model$`_sparse` <- TRUE
} else {
model$support_vectors_ <- np$array(model_dict$support_vectors_)$astype(np$float64)
model$`_sparse` <- FALSE
}
if ((model_dict$dual_coef_["meta"] != reticulate::py_none()) &
(model_dict$dual_coef_["meta"] == "csr")) {
if (reticulate::py_bool((model_dict$dual_coef_["meta"] != reticulate::py_none())) &
(reticulate::py_bool(model_dict$dual_coef_["meta"] == "csr"))) {
model$dual_coef_ <- deSerializeCsrMatrix(model_dict$dual_coef_)
} else {
model$dual_coef_ <- np$array(model_dict$dual_coef_)$astype(np$float64)
}

if ((model_dict$`_dual_coef_`["meta"] != reticulate::py_none()) &
(model_dict$`_dual_coef_`["meta"] == "csr")) {
if (reticulate::py_bool((model_dict$`_dual_coef_`["meta"] != reticulate::py_none())) &
(reticulate::py_bool(model_dict$`_dual_coef_`["meta"] == "csr"))) {
model$`_dual_coef_` <- deSerializeCsrMatrix(model_dict$`dual_coef_`)
} else {
model$`_dual_coef_` <- np$array(model_dict$`_dual_coef_`)$astype(np$float64)
Expand Down

0 comments on commit f5942b6

Please sign in to comment.