diff --git a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
index 52d5e163b..5d2ed02e9 100644
--- a/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
+++ b/edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import warnings
 from typing import (
     Any,
     Dict,
@@ -9,6 +10,7 @@
 
 import foldedtensor as ft
 import torch
+from confit import VisibleDeprecationWarning
 from spacy.tokens import Doc, Span
 from typing_extensions import Literal, TypedDict
 
@@ -81,11 +83,19 @@ def __init__(
         span_getter: Any = None,
     ):
         if span_getter is not None:
-            raise ValueError(
+            warnings.warn(
                 "The `span_getter` parameter of the `eds.span_pooler` component is "
                 "deprecated. Please use the `span_getter` parameter of the "
-                "`eds.span_classifier` or `eds.span_linker` components instead."
+                "`eds.span_classifier` or `eds.span_linker` components instead.",
+                VisibleDeprecationWarning,
             )
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if sub_span_getter is not None and span_getter is None:  # pragma: no cover
+            self.span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if sub_context_getter is not None:  # pragma: no cover
+            self.context_getter = sub_context_getter
+
         self.output_size = embedding.output_size if hidden_size is None else hidden_size
         super().__init__(nlp, name)
 
diff --git a/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py b/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py
index 6c46847b6..7a1a3a39a 100644
--- a/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py
+++ b/edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py
@@ -63,6 +63,13 @@ def __init__(
         residual: bool = True,
         normalize: Literal["pre", "post", "none"] = "pre",
     ):
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if sub_span_getter is not None:  # pragma: no cover
+            self.span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if sub_context_getter is not None:  # pragma: no cover
+            self.context_getter = sub_context_getter
+
         super().__init__(nlp, name)
         self.embedding = embedding
         self.output_size = output_size or embedding.output_size
diff --git a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
index da6d19c76..4d1b68efb 100644
--- a/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
+++ b/edsnlp/pipes/trainable/embeddings/transformer/transformer.py
@@ -1,3 +1,4 @@
+import warnings
 from pathlib import Path
 from typing import List, Optional, Set, Tuple, Union
 
@@ -5,7 +6,7 @@
 import tokenizers
 import tokenizers.normalizers
 import torch
-from confit import validate_arguments
+from confit import VisibleDeprecationWarning, validate_arguments
 from transformers import AutoModel, AutoTokenizer
 from transformers import BitsAndBytesConfig as BitsAndBytesConfig_
 from typing_extensions import Literal, TypedDict
@@ -148,10 +149,11 @@ def __init__(
         super().__init__(nlp, name)
 
         if span_getter is not None:
-            raise ValueError(
+            warnings.warn(
                 "The `span_getter` parameter of the `eds.transformer` component is "
                 "deprecated. Please use the `context_getter` parameter of the "
-                "other higher level task components instead."
+                "other higher level task components instead.",
+                VisibleDeprecationWarning,
             )
         self.transformer = AutoModel.from_pretrained(
             model,
diff --git a/edsnlp/pipes/trainable/ner_crf/ner_crf.py b/edsnlp/pipes/trainable/ner_crf/ner_crf.py
index 19ae1af27..ac2d19800 100644
--- a/edsnlp/pipes/trainable/ner_crf/ner_crf.py
+++ b/edsnlp/pipes/trainable/ner_crf/ner_crf.py
@@ -183,6 +183,11 @@ def __init__(
                 "You cannot set both the `labels` key of the `target_span_getter` "
                 "parameter and the `labels` parameter."
             )
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if (
+            sub_context_getter is not None and context_getter is None
+        ):  # pragma: no cover
+            context_getter = sub_context_getter
 
         super().__init__(
             nlp=nlp,
diff --git a/edsnlp/pipes/trainable/span_classifier/span_classifier.py b/edsnlp/pipes/trainable/span_classifier/span_classifier.py
index a9f66e769..b7111cbe2 100644
--- a/edsnlp/pipes/trainable/span_classifier/span_classifier.py
+++ b/edsnlp/pipes/trainable/span_classifier/span_classifier.py
@@ -212,6 +212,17 @@ def __init__(
             )
             assert attributes is None
             attributes = qualifiers
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if (
+            sub_span_getter is not None and span_getter is None
+        ):  # pragma: no cover  # noqa: E501
+            span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if (
+            sub_context_getter is not None and context_getter is None
+        ):  # pragma: no cover
+            context_getter = sub_context_getter
+
         self.values = values
         self.keep_none = keep_none
         self.bindings: List[Tuple[str, List[str], List[Any]]] = [
diff --git a/edsnlp/pipes/trainable/span_linker/span_linker.py b/edsnlp/pipes/trainable/span_linker/span_linker.py
index 46cfa2498..f73593915 100644
--- a/edsnlp/pipes/trainable/span_linker/span_linker.py
+++ b/edsnlp/pipes/trainable/span_linker/span_linker.py
@@ -286,6 +286,15 @@ def __init__(
     ):
         self.attribute = attribute
 
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if sub_span_getter is not None and span_getter is None:  # pragma: no cover
+            span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if (
+            sub_context_getter is not None and context_getter is None
+        ):  # pragma: no cover
+            context_getter = sub_context_getter
+
         super().__init__(
             nlp,
             name,
diff --git a/edsnlp/utils/span_getters.py b/edsnlp/utils/span_getters.py
index 80e6accd1..6a39c3332 100644
--- a/edsnlp/utils/span_getters.py
+++ b/edsnlp/utils/span_getters.py
@@ -273,11 +273,16 @@ def __init__(
         self,
         context_words: NonNegativeInt = 0,
         context_sents: Optional[NonNegativeInt] = 1,
+        span_getter: Optional[SpanGetterArg] = None,
     ):
         self.context_words = context_words
         self.context_sents = context_sents
+        self.span_getter = validate_span_getter(span_getter, optional=True)
+
+    def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]:
+        if isinstance(span, Doc):  # pragma: no cover
+            return [self(s) for s in get_spans(span, self.span_getter)]
 
-    def __call__(self, span: Span) -> List[Span]:
         n_sents: int = self.context_sents
         n_words = self.context_words