Skip to content

Commit

Permalink
feat: allow overriding the config when loading a pipeline from the disk
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Nov 30, 2023
1 parent 2d2c59b commit 7b07ff2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
19 changes: 19 additions & 0 deletions edsnlp/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,15 +968,32 @@ def blank(

def load(
config: Union[Path, str, Config],
overrides: Optional[Dict[str, Any]] = None,
*,
exclude: Optional[Union[str, Iterable[str]]] = None,
):
"""
Load a pipeline from a config file or a directory.
Examples
--------
```{ .python .no-check }
import edsnlp
nlp = edsnlp.load(
"path/to/config.cfg",
overrides={"components": {"my_component": {"arg": "value"}}},
)
```
Parameters
----------
config: Union[Path, str, Config]
The config to use for the pipeline, or the path to a config file or a directory.
overrides: Optional[Dict[str, Any]]
Overrides to apply to the config when loading the pipeline. These are the
same parameters as the ones used when initializing the pipeline.
exclude: Optional[Union[str, Iterable[str]]]
The names of the components, or attributes to exclude from the loading
process. :warning: The `exclude` argument will be mutated in place.
Expand All @@ -991,6 +1008,8 @@ def load(
if path.is_dir():
path = (Path(path) if isinstance(path, str) else path).absolute()
config = Config.from_disk(path / "config.cfg")
if overrides:
config = config.merge(overrides)
pwd = os.getcwd()
try:
os.chdir(path)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_sequence(frozen_ml_nlp: Pipeline):
def test_disk_serialization(tmp_path, ml_nlp):
nlp = ml_nlp

assert nlp.get_pipe("transformer").stride == 96
ner = nlp.get_pipe("ner")
ner.update_labels(["PERSON", "GIFT"])
nlp.to_disk(tmp_path / "model")
Expand All @@ -90,8 +91,12 @@ def test_disk_serialization(tmp_path, ml_nlp):
)
)

nlp = edsnlp.load(tmp_path / "model")
nlp = edsnlp.load(
tmp_path / "model",
overrides={"components": {"transformer": {"stride": 64}}},
)
assert nlp.get_pipe("ner").labels == ["PERSON", "GIFT"]
assert nlp.get_pipe("transformer").stride == 64


config_str = """\
Expand Down

0 comments on commit 7b07ff2

Please sign in to comment.