diff --git a/openeogeotrellis/backend.py b/openeogeotrellis/backend.py index 80bace52b..a4c8b4747 100644 --- a/openeogeotrellis/backend.py +++ b/openeogeotrellis/backend.py @@ -849,6 +849,16 @@ def _set_permissions(job_dir: Path): model: JavaObject = RandomForestModel._load_java(sc=gps.get_spark_context(), path="file:" + unpacked_model_path) return model elif architecture == "catboost": + if use_s3: + # TODO: Verify that local files work. If it does, we can remove the model_dir_path implementation. + # Download the model to the tmp directory and load it as a java object. + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir + "/catboost_model.cbm") + logger.info(f"Downloading ml_model from {model_url} to {tmp_path}") + with open(tmp_path, 'wb') as f: + f.write(requests.get(model_url).content) + model: JavaObject = CatBoostClassificationModel.load_native_model(tmp_path) + return model filename = Path(model_dir_path + "/catboost_model.cbm") with open(filename, 'wb') as f: f.write(requests.get(model_url).content)