diff --git a/python/ngen_cal/src/ngen/cal/__main__.py b/python/ngen_cal/src/ngen/cal/__main__.py index 1d369f3c..230951c7 100644 --- a/python/ngen_cal/src/ngen/cal/__main__.py +++ b/python/ngen_cal/src/ngen/cal/__main__.py @@ -4,7 +4,8 @@ import yaml from os import chdir from pathlib import Path -from ngen.cal.configuration import General +from ngen.cal.configuration import General, Model +from ngen.cal.ngen import Ngen from ngen.cal.search import dds, dds_set, pso_search from ngen.cal.strategy import Algorithm from ngen.cal.agent import Agent @@ -42,6 +43,13 @@ def main(general: General, model_conf: Mapping[str, Any]): import numpy as np np.random.seed(general.random_seed) + # model scope plugins setup in constructor + model = Model(model=model_conf) + + # NOTE: if support for new models is added, this will need to be modified + assert isinstance(model.model, Ngen), f"ngen.cal.ngen.Ngen expected, got {type(model.model)}" + model_inner = model.model.unwrap() + plugins = cast(List[Union[Callable, ModuleType]], general.plugins) plugin_manager = setup_plugin_manager(plugins) @@ -57,8 +65,13 @@ def main(general: General, model_conf: Mapping[str, Any]): into a single variable vector and calibrating a set of heterogenous formultions... """ start_iteration = 0 - #Initialize the starting agent - agent = Agent(model_conf, general.workdir, general.log, general.restart, general.strategy.parameters) + + # Initialize the starting agent + agent = Agent(model, general.workdir, general.log, general.restart, general.strategy.parameters) + + # Agent mutates the model config, so `ngen_cal_model_configure` is called afterwards + model_inner._plugin_manager.hook.ngen_cal_model_configure(config=model_inner) + if general.strategy.algorithm == Algorithm.dds: func = dds_set #FIXME what about explicit/dds start_iteration = general.start_iteration diff --git a/python/ngen_cal/src/ngen/cal/_hookspec.py b/python/ngen_cal/src/ngen/cal/_hookspec.py index ef1e5bf6..dbda5d88 100644 --- a/python/ngen_cal/src/ngen/cal/_hookspec.py +++ b/python/ngen_cal/src/ngen/cal/_hookspec.py @@ -13,6 +13,7 @@ from hypy.nexus import Nexus from ngen.cal.configuration import General + from ngen.cal.model import ModelExec from ngen.cal.meta import JobMeta hookspec = pluggy.HookspecMarker(PROJECT_SLUG) @@ -51,6 +52,19 @@ def ngen_cal_finish(exception: Exception | None) -> None: class ModelHooks: + @hookspec + def ngen_cal_model_configure(self, config: ModelExec) -> None: + """ + Called before calibration begins. + This allow plugins to perform initial configuration. + + Plugins' configuration data should be provided using the + `plugins_settings` field in the `model` section of an `ngen.cal` + configuration file. + By convention, the name of the plugin should be used as top level key in + the `plugin_settings` dictionary. + """ + @hookspec(firstresult=True) def ngen_cal_model_observations( self, diff --git a/python/ngen_cal/src/ngen/cal/agent.py b/python/ngen_cal/src/ngen/cal/agent.py index 4f124c8d..6979094a 100644 --- a/python/ngen_cal/src/ngen/cal/agent.py +++ b/python/ngen_cal/src/ngen/cal/agent.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from ngen.cal.meta import JobMeta -from ngen.cal.configuration import Model +from ngen.cal.configuration import Model, NoModel from ngen.cal.utils import pushd from pathlib import Path from typing import TYPE_CHECKING @@ -59,9 +59,14 @@ def best_params(self) -> str: class Agent(BaseAgent): - def __init__(self, model_conf, workdir: Path, log: bool=False, restart: bool=False, parameters: Mapping[str, Any] | None = {}): + def __init__(self, model: Model, workdir: Path, log: bool=False, restart: bool=False, parameters: Mapping[str, Any] | None = {}): self._workdir = workdir self._job = None + assert not isinstance(model.model, NoModel), "invariant" + # NOTE: if support for new models is added, support for other model + # type variants will be required + ngen_model = model.model.unwrap() + self._model = model if restart: # find prior ngen workdirs # FIXME if a user starts with an independent calibration strategy @@ -72,17 +77,15 @@ def __init__(self, model_conf, workdir: Path, log: bool=False, restart: bool=Fal # 0 correctly since not all basin params can be loaded. # There are probably some similar issues with explicit and independent, since they have # similar data semantics - workdirs = list(Path.glob(workdir, model_conf['type']+"_*_worker")) + workdirs = list(Path.glob(workdir, ngen_model.type+"_*_worker")) if len(workdirs) > 1: print("More than one existing workdir, cannot restart") elif len(workdirs) == 1: - self._job = JobMeta(model_conf['type'], workdir, workdirs[0], log=log) + self._job = JobMeta(ngen_model.type, workdir, workdirs[0], log=log) if self._job is None: - self._job = JobMeta(model_conf['type'], workdir, log=log) - resolved_binary = Path(model_conf['binary']).resolve() - model_conf['workdir'] = self.job.workdir - self._model = Model(model=model_conf, binary=resolved_binary) + self._job = JobMeta(ngen_model.type, workdir, log=log) + ngen_model.workdir = self.job.workdir self._model.model.resolve_paths(self.job.workdir) self._params = parameters @@ -117,4 +120,4 @@ def duplicate(self) -> Agent: data = self.model.__root__.copy(deep=True) #return a new agent, which has a unique Model instance #and its own Job/workspace - return Agent(data.dict(by_alias=True), self._workdir) + return Agent(data, self._workdir) diff --git a/python/ngen_cal/src/ngen/cal/ngen.py b/python/ngen_cal/src/ngen/cal/ngen.py index 27c7db13..bf2e1ea3 100644 --- a/python/ngen_cal/src/ngen/cal/ngen.py +++ b/python/ngen_cal/src/ngen/cal/ngen.py @@ -595,6 +595,11 @@ def get_binary(self) -> str: return self.__root__.get_binary() def update_config(self, *args, **kwargs): return self.__root__.update_config(*args, **kwargs) + + def unwrap(self) -> NgenBase: + """convenience method that returns the underlying __root__ instance""" + return self.__root__ + #proxy methods for model @property def adjustables(self): diff --git a/python/ngen_cal/tests/conftest.py b/python/ngen_cal/tests/conftest.py index 64881ce9..678b1cf8 100644 --- a/python/ngen_cal/tests/conftest.py +++ b/python/ngen_cal/tests/conftest.py @@ -5,7 +5,7 @@ import json import pandas as pd # type: ignore import geopandas as gpd # type: ignore -from ngen.cal.configuration import General +from ngen.cal.configuration import General, Model from ngen.cal.ngen import Ngen from ngen.cal.meta import JobMeta from ngen.cal.calibration_cathment import CalibrationCatchment @@ -104,9 +104,10 @@ def meta(ngen_config, general_config, mocker) -> Generator[JobMeta, None, None]: yield m @pytest.fixture -def agent(ngen_config, general_config) -> Generator['Agent', None, None]: - a = Agent(ngen_config.__root__.dict(), general_config.workdir, general_config.log) - yield a +def agent(ngen_config, general_config) -> Agent: + model = Model(model=ngen_config) + a = Agent(model, general_config.workdir, general_config.log) + return a @pytest.fixture def eval(ngen_config) -> Generator[EvaluationOptions, None, None]: diff --git a/python/ngen_cal/tests/test_plugin_system.py b/python/ngen_cal/tests/test_plugin_system.py index b76e9baf..36401b2e 100644 --- a/python/ngen_cal/tests/test_plugin_system.py +++ b/python/ngen_cal/tests/test_plugin_system.py @@ -15,6 +15,8 @@ from hypy.nexus import Nexus from ngen.cal.configuration import General + from ngen.cal.model import ModelExec + from pathlib import Path def test_setup_plugin_manager(): @@ -74,6 +76,10 @@ def ngen_cal_start(self) -> None: def ngen_cal_finish(self) -> None: """Called after exiting the calibration loop.""" + @hookimpl + def ngen_cal_model_configure(self, config: ModelExec) -> None: + """Test model model configure plugin""" + @hookimpl def ngen_cal_model_output(self) -> None: """Test model output plugin"""