From a1107ba6f8c5ef8db748de2a9101c2b4b64ec551 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Perceval=20Wajsb=C3=BCrt?=
Date: Wed, 13 Dec 2023 18:23:43 +0100
Subject: [PATCH] fix: update spacy patch & remove ineffective and buggy
 registry override

---
 edsnlp/patch_spacy.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/edsnlp/patch_spacy.py b/edsnlp/patch_spacy.py
index 66b662c40..ff1a16fe4 100644
--- a/edsnlp/patch_spacy.py
+++ b/edsnlp/patch_spacy.py
@@ -114,6 +114,7 @@ def __init__(
         max_length: int = 10**6,
         meta: Dict[str, Any] = {},
         create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
+        create_vectors: Optional[Callable[["Vocab"], Any]] = None,
         batch_size: int = 1000,
         **kwargs,
     ) -> None:
@@ -139,12 +140,6 @@ def __init__(
 
         DOCS: https://spacy.io/api/language#init
         """
-
-        # EDS-NLP: disable spacy default call to load every factory
-        # since some of them may be missing dependencies (like torch)
-        # util.registry._entry_point_factories.get_all()
-        util.registry.factories = util.registry._entry_point_factories
-
         self._config = DEFAULT_CONFIG.merge(self.default_config)
         self._meta = dict(meta)
         self._path = None
@@ -158,8 +153,13 @@ def __init__(
         if vocab is True:
             vectors_name = meta.get("vectors", {}).get("name")
             vocab = create_vocab(self.lang, self.Defaults, vectors_name=vectors_name)
-        if (self.lang and vocab.lang) and (self.lang != vocab.lang):
-            raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
+            if not create_vectors and "@vectors" in self._config["nlp"]["vectors"]:
+                vectors_cfg = {"vectors": self._config["nlp"]["vectors"]}
+                create_vectors = registry.resolve(vectors_cfg)["vectors"]
+                vocab.vectors = create_vectors(vocab)
+        else:
+            if (self.lang and vocab.lang) and (self.lang != vocab.lang):
+                raise ValueError(Errors.E150.format(nlp=self.lang, vocab=vocab.lang))
         self.vocab: Vocab = vocab
         if self.lang is None:
             self.lang = self.vocab.lang
@@ -167,12 +167,9 @@ def __init__(
         self._disabled: Set[str] = set()
         self.max_length = max_length
         # Create the default tokenizer from the default config
-        create_tokenizer = (
-            create_tokenizer
-            or registry.resolve({"tokenizer": self._config["nlp"]["tokenizer"]})[
-                "tokenizer"
-            ]
-        )
+        if not create_tokenizer:
+            tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
+            create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
         self.tokenizer = create_tokenizer(self)
         self.batch_size = batch_size
         self.default_error_handler = raise_error