Skip to content

Commit

Permalink
feat: standardize p/r/f metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Oct 25, 2023
1 parent 0aa7ac5 commit e5eef12
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 50 deletions.
14 changes: 14 additions & 0 deletions edsnlp/scorers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@
]


def prf(pred, gold):
    """
    Compute precision, recall and F1-score between predicted and gold items.

    Items are matched by equality: the number of true positives is the size
    of the intersection of the two collections once converted to sets.

    Parameters
    ----------
    pred: Collection[Hashable]
        The predicted items
    gold: Collection[Hashable]
        The gold (reference) items

    Returns
    -------
    Dict[str, Any]
        A dict with the following keys:

        - `f`: the F1-score
        - `p`: the precision
        - `r`: the recall
        - `tp`: the number of true positives
        - `support`: the number of gold items
        - `positives`: the number of predicted items
    """
    tp = len(set(pred) & set(gold))
    n_pred = len(pred)
    n_gold = len(gold)
    return {
        # When there is nothing to predict and nothing was predicted, the
        # result is perfect: report f = 1.0 so that it agrees with p and r
        # (both 1.0 in that case) instead of silently dropping to 0.
        "f": 2 * tp / (n_pred + n_gold) if (n_pred or n_gold) else 1.0,
        # `tp == n_pred` also covers n_pred == 0 (tp is then 0), which
        # protects the division below from a ZeroDivisionError.
        "p": 1.0 if tp == n_pred else tp / n_pred,
        # Same guard for the recall when there are no gold items.
        "r": 1.0 if tp == n_gold else tp / n_gold,
        "tp": tp,
        "support": n_gold,  # num gold
        "positives": n_pred,  # num predicted
    }


def make_examples(*args):
if len(args) == 2:
return (
Expand Down
64 changes: 32 additions & 32 deletions edsnlp/scorers/ner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from collections import defaultdict
from typing import Any, Dict, Iterable

from spacy.training import Example

from edsnlp import registry
from edsnlp.pipelines.base import SpanGetter, SpanGetterArg, get_spans
from edsnlp.scorers import make_examples
from edsnlp.scorers import make_examples, prf


def ner_exact_scorer(
examples: Iterable[Example], span_getter: SpanGetter
examples: Iterable[Example],
span_getter: SpanGetter,
micro_key: str = "micro",
) -> Dict[str, Any]:
"""
Scores the extracted entities that may be overlapping or nested
Expand All @@ -17,43 +20,42 @@ def ner_exact_scorer(
Parameters
----------
examples: Iterable[Example]
The examples to score
span_getter: SpanGetter
The span getter to use to extract the spans from the document
micro_key: str
The key to use to store the micro-averaged results for spans of all types
Returns
-------
Dict[str, Any]
"""
pred_spans = set()
gold_spans = set()
labels = defaultdict(lambda: (set(), set()))
labels["micro"] = (set(), set())
for eg_idx, eg in enumerate(examples):
for span in (
span_getter(eg.predicted)
if callable(span_getter)
else get_spans(eg.predicted, span_getter)
):
pred_spans.add((eg_idx, span.start, span.end, span.label_))
labels[span.label_][0].add((eg_idx, span.start, span.end, span.label_))
labels[micro_key][0].add((eg_idx, span.start, span.end, span.label_))

for span in (
span_getter(eg.reference)
if callable(span_getter)
else get_spans(eg.reference, span_getter)
):
gold_spans.add((eg_idx, span.start, span.end, span.label_))
labels[span.label_][1].add((eg_idx, span.start, span.end, span.label_))
labels[micro_key][1].add((eg_idx, span.start, span.end, span.label_))

tp = len(pred_spans & gold_spans)

return {
"ents_p": tp / len(pred_spans) if pred_spans else float(len(gold_spans) == 0),
"ents_r": tp / len(gold_spans) if gold_spans else float(len(gold_spans) == 0),
"ents_f": 2 * tp / (len(pred_spans) + len(gold_spans))
if pred_spans or gold_spans
else float(len(pred_spans) == len(gold_spans)),
"support": len(gold_spans),
}
return {name: prf(pred, gold) for name, (pred, gold) in labels.items()}


def ner_token_scorer(
examples: Iterable[Example], span_getter: SpanGetter
examples: Iterable[Example],
span_getter: SpanGetter,
micro_key: str = "micro",
) -> Dict[str, Any]:
"""
Scores the extracted entities that may be overlapping or nested
Expand All @@ -63,41 +65,39 @@ def ner_token_scorer(
Parameters
----------
examples: Iterable[Example]
The examples to score
span_getter: SpanGetter
The span getter to use to extract the spans from the document
micro_key: str
The key to use to store the micro-averaged results for spans of all types
Returns
-------
Dict[str, Any]
"""
pred_spans = set()
gold_spans = set()
# label -> pred, gold
labels = defaultdict(lambda: (set(), set()))
labels["micro"] = (set(), set())
for eg_idx, eg in enumerate(examples):
for span in (
span_getter(eg.predicted)
if callable(span_getter)
else get_spans(eg.predicted, span_getter)
):
for i in range(span.start, span.end):
pred_spans.add((eg_idx, i, span.label_))
labels[span.label_][0].add((eg_idx, i, span.label_))
labels[micro_key][0].add((eg_idx, i, span.label_))

for span in (
span_getter(eg.reference)
if callable(span_getter)
else get_spans(eg.reference, span_getter)
):
for i in range(span.start, span.end):
gold_spans.add((eg_idx, i, span.label_))

tp = len(pred_spans & gold_spans)

return {
"ents_p": tp / len(pred_spans) if pred_spans else float(tp == len(pred_spans)),
"ents_r": tp / len(gold_spans) if gold_spans else float(tp == len(gold_spans)),
"ents_f": 2 * tp / (len(pred_spans) + len(gold_spans))
if pred_spans or gold_spans
else float(len(pred_spans) == len(gold_spans)),
"support": len(gold_spans),
}
labels[span.label_][1].add((eg_idx, i, span.label_))
labels[micro_key][1].add((eg_idx, i, span.label_))

return {name: prf(pred, gold) for name, (pred, gold) in labels.items()}


@registry.scorers.register("eds.ner_exact_scorer")
Expand Down
29 changes: 11 additions & 18 deletions edsnlp/scorers/span_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from spacy.training import Example

from edsnlp import registry
from edsnlp.scorers import make_examples
from edsnlp.scorers import make_examples, prf
from edsnlp.utils.bindings import BINDING_GETTERS, Qualifiers, QualifiersArg
from edsnlp.utils.span_getters import SpanGetterArg, get_spans

Expand All @@ -14,6 +14,7 @@ def span_classification_scorer(
span_getter: SpanGetterArg,
qualifiers: Qualifiers,
include_falsy: bool = False,
micro_key: str = "micro",
):
"""
Scores the extracted entities that may be overlapping or nested
Expand All @@ -28,9 +29,11 @@ def span_classification_scorer(
qualifiers : Sequence[str]
The qualifiers to use to score the spans
include_falsy : bool
Whether to count predicted or gold occurences of falsy values when computing
Whether to count predicted or gold occurrences of falsy values when computing
the metrics. If `False`, only the non-falsy values will be counted and matched
together.
micro_key : str
The key to use to store the micro-averaged results for spans of all types
Returns
-------
Expand All @@ -40,6 +43,7 @@ def span_classification_scorer(
labels[None] = ([], [])
total_pred_count = 0
total_gold_count = 0
labels["micro"] = (set(), set())
for eg_idx, eg in enumerate(examples):
doc_spans = get_spans(eg.predicted, span_getter)
for span_idx, span in enumerate(doc_spans):
Expand Down Expand Up @@ -73,24 +77,13 @@ def span_classification_scorer(
"another NER pipe in your model."
)

def prf(pred, gold):
tp = len(set(pred) & set(gold))
np = len(pred)
ng = len(gold)
return {
"f": 2 * tp / max(1, np + ng),
"p": 1 if tp == np else (tp / np),
"r": 1 if tp == ng else (tp / ng),
"support": len(gold),
}

results = {name: prf(pred, gold) for name, (pred, gold) in labels.items()}
micro_results = results.pop(None)
return {
"qual_p": micro_results["p"],
"qual_r": micro_results["r"],
"qual_f": micro_results["f"],
"support": len(labels[None][1]),
"qual_p": results[micro_key]["p"],
"qual_r": results[micro_key]["r"],
"qual_f": results[micro_key]["f"],
"support": results[micro_key]["support"],
"positives": results[micro_key]["positives"],
"qual_per_type": results,
}

Expand Down

0 comments on commit e5eef12

Please sign in to comment.