fix: add back support for deprecated span pooler span_getter arg
percevalw committed May 16, 2024
1 parent ac4c270 commit 79d2601
Showing 7 changed files with 55 additions and 6 deletions.
14 changes: 12 additions & 2 deletions edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import warnings
 from typing import (
     Any,
     Dict,
@@ -9,6 +10,7 @@
 
 import foldedtensor as ft
 import torch
+from confit import VisibleDeprecationWarning
 from spacy.tokens import Doc, Span
 from typing_extensions import Literal, TypedDict
@@ -81,11 +83,19 @@ def __init__(
         span_getter: Any = None,
     ):
         if span_getter is not None:
-            raise ValueError(
+            warnings.warn(
                 "The `span_getter` parameter of the `eds.span_pooler` component is "
                 "deprecated. Please use the `span_getter` parameter of the "
-                "`eds.span_classifier` or `eds.span_linker` components instead."
+                "`eds.span_classifier` or `eds.span_linker` components instead.",
+                VisibleDeprecationWarning,
             )
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if sub_span_getter is not None and span_getter is None:  # pragma: no cover
+            self.span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if sub_context_getter is not None:  # pragma: no cover
+            self.context_getter = sub_context_getter
+
         self.output_size = embedding.output_size if hidden_size is None else hidden_size
 
         super().__init__(nlp, name)
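
The hunk above is the substance of the commit: passing the deprecated `span_getter` argument no longer raises a `ValueError` but emits a `VisibleDeprecationWarning`, so existing configurations keep loading. A minimal sketch of this pattern follows; all names in it are illustrative stand-ins, not edsnlp's API:

import warnings


class MyDeprecationWarning(UserWarning):
    """Illustrative stand-in for confit's VisibleDeprecationWarning."""


def make_component(span_getter=None):
    if span_getter is not None:
        warnings.warn(
            "`span_getter` is deprecated here; pass it to the "
            "higher-level task component instead.",
            MyDeprecationWarning,
        )
    return {}  # build and return the component as usual


# Old call sites keep working and emit a warning instead of crashing:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    make_component(span_getter={"ents": True})
    print(caught[0].category.__name__)  # MyDeprecationWarning
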
7 changes: 7 additions & 0 deletions edsnlp/pipes/trainable/embeddings/text_cnn/text_cnn.py
@@ -63,6 +63,13 @@ def __init__(
         residual: bool = True,
         normalize: Literal["pre", "post", "none"] = "pre",
     ):
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if sub_span_getter is not None:  # pragma: no cover
+            self.span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if sub_context_getter is not None:  # pragma: no cover
+            self.context_getter = sub_context_getter
+
         super().__init__(nlp, name)
         self.embedding = embedding
         self.output_size = output_size or embedding.output_size
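
This file, like several below, applies the same propagation pattern: a wrapper layer inherits `span_getter` / `context_getter` from the embedding it wraps, using getattr with a None default so embeddings lacking these attributes still work. A hedged, self-contained sketch (the class names are illustrative, not the edsnlp classes):

class InnerEmbedding:
    span_getter = {"ents": True}
    context_getter = None


class WrapperLayer:
    def __init__(self, embedding):
        # Inherit getters from the wrapped embedding when they are set.
        sub_span_getter = getattr(embedding, "span_getter", None)
        if sub_span_getter is not None:
            self.span_getter = sub_span_getter
        sub_context_getter = getattr(embedding, "context_getter", None)
        if sub_context_getter is not None:
            self.context_getter = sub_context_getter
        self.embedding = embedding


wrapper = WrapperLayer(InnerEmbedding())
print(wrapper.span_getter)                 # {'ents': True}, inherited
print(hasattr(wrapper, "context_getter"))  # False: None values are not copied
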
8 changes: 5 additions & 3 deletions edsnlp/pipes/trainable/embeddings/transformer/transformer.py
@@ -1,11 +1,12 @@
+import warnings
 from pathlib import Path
 from typing import List, Optional, Set, Tuple, Union
 
 import foldedtensor as ft
 import tokenizers
 import tokenizers.normalizers
 import torch
-from confit import validate_arguments
+from confit import VisibleDeprecationWarning, validate_arguments
 from transformers import AutoModel, AutoTokenizer
 from transformers import BitsAndBytesConfig as BitsAndBytesConfig_
 from typing_extensions import Literal, TypedDict
@@ -148,10 +149,11 @@ def __init__(
         super().__init__(nlp, name)
 
         if span_getter is not None:
-            raise ValueError(
+            warnings.warn(
                 "The `span_getter` parameter of the `eds.transformer` component is "
                 "deprecated. Please use the `context_getter` parameter of the "
-                "other higher level task components instead."
+                "other higher level task components instead.",
+                VisibleDeprecationWarning,
             )
         self.transformer = AutoModel.from_pretrained(
             model,
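
Because both components now warn instead of raising, users who cannot migrate immediately can silence just this category with the standard library's warning filters. A short usage sketch (the import is taken from the diff above):

import warnings

from confit import VisibleDeprecationWarning

# Ignore only this deprecation category during the migration window;
# all other warnings keep their default behavior.
warnings.filterwarnings("ignore", category=VisibleDeprecationWarning)
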
5 changes: 5 additions & 0 deletions edsnlp/pipes/trainable/ner_crf/ner_crf.py
@@ -183,6 +183,11 @@ def __init__(
                 "You cannot set both the `labels` key of the `target_span_getter` "
                 "parameter and the `labels` parameter."
             )
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if (
+            sub_context_getter is not None and context_getter is None
+        ):  # pragma: no cover
+            context_getter = sub_context_getter
 
         super().__init__(
             nlp=nlp,
11 changes: 11 additions & 0 deletions edsnlp/pipes/trainable/span_classifier/span_classifier.py
@@ -212,6 +212,17 @@ def __init__(
             )
             assert attributes is None
             attributes = qualifiers
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if (
+            sub_span_getter is not None and span_getter is None
+        ):  # pragma: no cover # noqa: E501
+            span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if (
+            sub_context_getter is not None and context_getter is None
+        ):  # pragma: no cover
+            context_getter = sub_context_getter
 
         self.values = values
         self.keep_none = keep_none
         self.bindings: List[Tuple[str, List[str], List[Any]]] = [
9 changes: 9 additions & 0 deletions edsnlp/pipes/trainable/span_linker/span_linker.py
@@ -286,6 +286,15 @@ def __init__(
     ):
         self.attribute = attribute
 
+        sub_span_getter = getattr(embedding, "span_getter", None)
+        if sub_span_getter is not None and span_getter is None:  # pragma: no cover
+            span_getter = sub_span_getter
+        sub_context_getter = getattr(embedding, "context_getter", None)
+        if (
+            sub_context_getter is not None and context_getter is None
+        ):  # pragma: no cover
+            context_getter = sub_context_getter
+
         super().__init__(
             nlp,
             name,
7 changes: 6 additions & 1 deletion edsnlp/utils/span_getters.py
@@ -273,11 +273,16 @@ def __init__(
         self,
         context_words: NonNegativeInt = 0,
         context_sents: Optional[NonNegativeInt] = 1,
+        span_getter: Optional[SpanGetterArg] = None,
     ):
         self.context_words = context_words
         self.context_sents = context_sents
+        self.span_getter = validate_span_getter(span_getter, optional=True)
 
-    def __call__(self, span: Span) -> List[Span]:
+    def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]:
+        if isinstance(span, Doc):  # pragma: no cover
+            return [self(s) for s in get_spans(span, self.span_getter)]
+
         n_sents: int = self.context_sents
         n_words = self.context_words
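
This last file generalizes the context getter: it gains an optional `span_getter`, and when called on a whole Doc it now applies itself to each span that getter selects. A hedged re-implementation of just this dispatch (illustrative class, not the edsnlp one):

from typing import Callable, Iterable, List, Optional, Union

from spacy.tokens import Doc, Span


class ContextGetterSketch:
    def __init__(self, span_getter: Optional[Callable[[Doc], Iterable[Span]]] = None):
        # edsnlp validates the argument via validate_span_getter; here a
        # plain callable (or None, meaning doc.ents) stands in.
        self.span_getter = span_getter

    def __call__(self, span: Union[Doc, Span]) -> Union[Span, List[Span]]:
        if isinstance(span, Doc):
            spans = self.span_getter(span) if self.span_getter else span.ents
            return [self(s) for s in spans]
        # Single-span path: the real class widens the span by
        # context_words / context_sents; returning it unchanged keeps
        # this sketch minimal.
        return span
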
