Skip to content

Commit

Permalink
🐛 Fix displacy render function (#10)
Browse files Browse the repository at this point in the history
* 🐛 Fix displacy render

Signed-off-by: Gabriele Picco <[email protected]>

* ✅ Add displacy test

Signed-off-by: Gabriele Picco <[email protected]>

Signed-off-by: Gabriele Picco <[email protected]>
Co-authored-by: Gabriele Picco <[email protected]>
  • Loading branch information
GabrielePicco and Gabriele Picco authored Oct 4, 2022
1 parent 89b8d5b commit 79db790
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
20 changes: 20 additions & 0 deletions zshot/tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import spacy

from zshot import PipelineConfig, displacy
from zshot.tests.config import EX_ENTITIES, EX_DOCS
from zshot.tests.linker.test_linker import DummyLinkerEnd2End
from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
from zshot.utils.data_models import Span
from zshot.utils.alignment_utils import align_spans, AlignmentMode, filter_overlapping_spans

Expand Down Expand Up @@ -164,3 +170,17 @@ def test_alignment_expand_overlaps_no_score():
assert filtered_spans[0].label == "A"
assert filtered_spans[1].start == 3 and filtered_spans[1].end == 8
assert filtered_spans[1].label == "C"


def test_displacy_render():
nlp = spacy.blank("en")

nlp.add_pipe("zshot", config=PipelineConfig(
mentions_extractor=DummyMentionsExtractor(),
linker=DummyLinkerEnd2End(),
entities=EX_ENTITIES), last=True)
doc = nlp(EX_DOCS[1])
assert len(doc.ents) > 0
assert len(doc._.spans) > 0
res = displacy.render(doc, style="ent", jupyter=False)
assert res is not None
4 changes: 2 additions & 2 deletions zshot/utils/displacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ def color_from_label(label: str):
class displacy:

@staticmethod
def render(doc, options: Dict = None, **kwargs):
def render(doc, options: Dict = None, **kwargs) -> str:
if options:
options['colors'] = ents_colors(doc)
else:
options = {'colors': ents_colors(doc)}
s_displacy.render(doc, options=options, **kwargs)
return s_displacy.render(doc, options=options, **kwargs)

@staticmethod
def serve(doc, options: Dict = None, **kwargs):
Expand Down

0 comments on commit 79db790

Please sign in to comment.