Skip to content

Commit

Permalink
fix: don't overwrite Span._.value ext and return user assigned value …
Browse files Browse the repository at this point in the history
…before getter
  • Loading branch information
percevalw committed Dec 14, 2023
1 parent c3d17b9 commit 350c131
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
13 changes: 12 additions & 1 deletion edsnlp/pipes/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from operator import attrgetter
from typing import (
List,
Expand All @@ -21,6 +22,9 @@


def value_getter(span: Span):
key = span._._get_key("value")
if key in span.doc.user_data:
return span.doc.user_data[key]
return span._.get(span.label_) if span._.has(span.label_) else None


Expand Down Expand Up @@ -49,10 +53,17 @@ def set_extensions(self):
"""
Set `Doc`, `Span` and `Token` extensions.
"""
if Span.has_extension("value"):
if Span.get_extension("value")[2] is not value_getter:
warnings.warn(
"A Span extension 'value' already exists with a different getter. "
"Keeping the existing extension, but some components of edsnlp may "
"not work as expected."
)
return
Span.set_extension(
"value",
getter=value_getter,
force=True,
)

def get_spans(self, doc: Doc): # noqa: F811
Expand Down
51 changes: 51 additions & 0 deletions tests/pipelines/ner/test_value_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import pytest
import spacy
from spacy.tokens import Span
from spacy.tokens.underscore import Underscore

import edsnlp


def test_warn_value_extension():
old_value_extension = Underscore.span_extensions.pop("value", None)
try:
Underscore._extensions = {}
Span.set_extension("value", getter=lambda span: "stuff")
existing_nlp = spacy.blank("fr")
with pytest.warns(UserWarning) as record:
existing_nlp.add_pipe(
"eds.terminology",
name="test",
config=dict(label="Any", terms={}),
)

assert any(
"A Span extension 'value' already exists with a different getter"
in str(r.message)
for r in record
)
finally:
Underscore.span_extensions.pop("value", None)
if old_value_extension is not None:
Underscore.span_extensions["value"] = old_value_extension


def test_value_extension():
# From https://github.com/aphp/edsnlp/issues/220

# Setting up a first pipeline
existing_nlp = spacy.blank("fr")
existing_nlp.add_pipe(
"eds.terminology",
name="test",
config=dict(label="Any", terms={}),
)

# Setting up another custom pipeline somewhere else in the code
nlp = edsnlp.blank("eds")
text = "hello this is a test"
doc = nlp(text)
my_span = doc[0:3]
my_span._.value = "CustomValue"

assert my_span._.value == "CustomValue"

0 comments on commit 350c131

Please sign in to comment.