From 7b07ff225cc30d79e5be59d7287c60812f135c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?= Date: Fri, 1 Dec 2023 00:20:34 +0100 Subject: [PATCH] feat: allow overriding the config when loading a pipeline from the disk --- edsnlp/core/pipeline.py | 19 +++++++++++++++++++ tests/test_pipeline.py | 7 ++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/edsnlp/core/pipeline.py b/edsnlp/core/pipeline.py index 1376502ed..aab198391 100644 --- a/edsnlp/core/pipeline.py +++ b/edsnlp/core/pipeline.py @@ -968,15 +968,32 @@ def blank( def load( config: Union[Path, str, Config], + overrides: Optional[Dict[str, Any]] = None, + *, exclude: Optional[Union[str, Iterable[str]]] = None, ): """ Load a pipeline from a config file or a directory. + Examples + -------- + + ```{ .python .no-check } + import edsnlp + + nlp = edsnlp.load( + "path/to/config.cfg", + overrides={"components": {"my_component": {"arg": "value"}}}, + ) + ``` + Parameters ---------- config: Union[Path, str, Config] The config to use for the pipeline, or the path to a config file or a directory. + overrides: Optional[Dict[str, Any]] + Overrides to apply to the config when loading the pipeline. These are the + same parameters as the ones used when initializing the pipeline. exclude: Optional[Union[str, Iterable[str]]] The names of the components, or attributes to exclude from the loading process. :warning: The `exclude` argument will be mutated in place. @@ -991,6 +1008,8 @@ def load( if path.is_dir(): path = (Path(path) if isinstance(path, str) else path).absolute() config = Config.from_disk(path / "config.cfg") + if overrides: + config = config.merge(overrides) pwd = os.getcwd() try: os.chdir(path) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1c3ab9d66..f607fdbea 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -68,6 +68,7 @@ def test_sequence(frozen_ml_nlp: Pipeline): def test_disk_serialization(tmp_path, ml_nlp): nlp = ml_nlp + assert nlp.get_pipe("transformer").stride == 96 ner = nlp.get_pipe("ner") ner.update_labels(["PERSON", "GIFT"]) nlp.to_disk(tmp_path / "model") @@ -90,8 +91,12 @@ def test_disk_serialization(tmp_path, ml_nlp): ) ) - nlp = edsnlp.load(tmp_path / "model") + nlp = edsnlp.load( + tmp_path / "model", + overrides={"components": {"transformer": {"stride": 64}}}, + ) assert nlp.get_pipe("ner").labels == ["PERSON", "GIFT"] + assert nlp.get_pipe("transformer").stride == 64 config_str = """\