From fcc723ecff687917ae713c29e45e3ca7432396e3 Mon Sep 17 00:00:00 2001 From: Jeroen Dries Date: Wed, 24 Jul 2024 15:21:30 +0200 Subject: [PATCH] #142 make sure onnx session is cached --- src/openeo_gfmap/inference/model_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/openeo_gfmap/inference/model_inference.py b/src/openeo_gfmap/inference/model_inference.py index ca7ff41..080ecf9 100644 --- a/src/openeo_gfmap/inference/model_inference.py +++ b/src/openeo_gfmap/inference/model_inference.py @@ -75,8 +75,9 @@ def extract_dependencies(cls, base_url: str, dependency_name: str) -> str: return abs_path + @classmethod @functools.lru_cache(maxsize=6) - def load_ort_session(self, model_url: str): + def load_ort_session(cls, model_url: str): """Loads an onnx session from a publicly available URL. The URL must be a direct download link to the ONNX session file. The `lru_cache` decorator avoids loading multiple time the model within the same worker. @@ -181,7 +182,7 @@ def execute(self, inarr: xr.DataArray) -> xr.DataArray: raise ValueError("The model_url must be defined in the parameters.") # Load the model and the input_name parameters - session = self.load_ort_session(self._parameters.get("model_url")) + session = ModelInference.load_ort_session(self._parameters.get("model_url")) input_name = self._parameters.get("input_name") if input_name is None: