Commit
Merge pull request #70 from infopz/main
Solved issue #40
ivyleavedtoadflax authored Oct 23, 2023
2 parents 99e6187 + 21302ac commit 0978d89
Showing 3 changed files with 99 additions and 20 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/checks.yaml
@@ -1,5 +1,7 @@
 name: CI Code Checks
 on: [pull_request]
+permissions:
+  contents: write
 jobs:
   build:
     name: code checks
@@ -41,4 +43,4 @@ jobs:
           folder: coverage

       - name: Type checking with mypy
-        run: mypy --config setup.cfg src
+        run: mypy --config setup.cfg src
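The new permissions block grants the workflow's GITHUB_TOKEN write access to repository contents; this is presumably required by the step that publishes the coverage folder, since that step pushes back to the repository.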
44 changes: 25 additions & 19 deletions src/nervaluate/evaluate.py
@@ -144,6 +144,10 @@ def compute_metrics( # type: ignore
     true_named_entities = [clean_entities(ent) for ent in true_named_entities if ent["label"] in tags]
     pred_named_entities = [clean_entities(ent) for ent in pred_named_entities if ent["label"] in tags]

+    # Sort the lists to improve the speed of the overlap comparison
+    true_named_entities.sort(key=lambda x: x["start"])
+    pred_named_entities.sort(key=lambda x: x["end"])
+
     # go through each predicted named-entity
     for pred in pred_named_entities:
         found_overlap = False
@@ -169,6 +173,10 @@ def compute_metrics( # type: ignore
         else:
             # check for overlaps with any of the true entities
             for true in true_named_entities:
+                # Only enter this block if an overlap is possible
+                if pred["end"] < true["start"]:
+                    break
+
                 # overlapping needs to take into account last token as well
                 pred_range = range(pred["start"], pred["end"] + 1)
                 true_range = range(true["start"], true["end"] + 1)
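Together with the sort added above (true entities ordered by start, predictions by end), this early break lets the inner scan stop at the first true entity that starts beyond the prediction's end, since every later true entity starts later still. A minimal sketch of that invariant, separate from the diff and using a hypothetical helper name that is not part of this commit:

def overlapping_trues(pred, trues):
    # Mirrors the loop above: sorting trues by "start" is what makes the break safe.
    hits = []
    for true in sorted(trues, key=lambda x: x["start"]):
        if pred["end"] < true["start"]:
            break  # every remaining true starts even later, so no overlap is possible
        if true["start"] <= pred["end"] and true["end"] >= pred["start"]:
            hits.append(true)
    return hits

pred = {"start": 0, "end": 17, "label": "A"}
trues = [
    {"start": 14, "end": 17, "label": "B"},
    {"start": 0, "end": 12, "label": "A"},
    {"start": 30, "end": 35, "label": "A"},  # triggers the early break: 30 > 17
]
print(len(overlapping_trues(pred, trues)))  # 2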
@@ -214,29 +222,27 @@ def compute_metrics( # type: ignore

                     found_overlap = True

-                    break
                 else:
                     # Scenario VI: Entities overlap, but the entity type is
                     # different.

                     # overall results
                     evaluation["strict"]["incorrect"] += 1
                     evaluation["ent_type"]["incorrect"] += 1
                     evaluation["partial"]["partial"] += 1
                     evaluation["exact"]["incorrect"] += 1

                     # aggregated by entity type results
                     # Results against the true entity

                     evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1
                     evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1
                     evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1
                     evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1

                     # Results against the predicted entity
                     # evaluation_agg_entities_type[pred['label']]['strict']['spurious'] += 1
                     found_overlap = True
-                    break

         # Scenario II: Entities are spurious (i.e., over-generated).
         if not found_overlap:
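Dropping these two break statements is the substance of the fix for issue #40: previously the scan stopped at the first true entity that matched either scenario, so a prediction spanning several true entities was scored against only one of them, and the result depended on the order of true_named_entities. With the lists sorted and the early exit above, the loop now scores the prediction against every true entity it overlaps, as the new test below demonstrates.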
71 changes: 71 additions & 0 deletions tests/test_nervaluate.py
@@ -816,3 +816,74 @@ def test_compute_precision_recall():
     out = compute_precision_recall(results)

     assert out == expected
+
+
+def test_compute_metrics_one_pred_two_true():
+    true_named_entities_1 = [
+        {"start": 0, "end": 12, "label": "A"},
+        {"start": 14, "end": 17, "label": "B"},
+    ]
+    true_named_entities_2 = [
+        {"start": 14, "end": 17, "label": "B"},
+        {"start": 0, "end": 12, "label": "A"},
+    ]
+    pred_named_entities = [
+        {"start": 0, "end": 17, "label": "A"},
+    ]
+
+    results1, _ = compute_metrics(true_named_entities_1, pred_named_entities, ["A", "B"])
+    results2, _ = compute_metrics(true_named_entities_2, pred_named_entities, ["A", "B"])
+
+    expected = {
+        'ent_type': {
+            'correct': 1,
+            'incorrect': 1,
+            'partial': 0,
+            'missed': 0,
+            'spurious': 0,
+            'possible': 2,
+            'actual': 2,
+            'precision': 0,
+            'recall': 0,
+            'f1': 0
+        },
+        'partial': {
+            'correct': 0,
+            'incorrect': 0,
+            'partial': 2,
+            'missed': 0,
+            'spurious': 0,
+            'possible': 2,
+            'actual': 2,
+            'precision': 0,
+            'recall': 0,
+            'f1': 0
+        },
+        'strict': {
+            'correct': 0,
+            'incorrect': 2,
+            'partial': 0,
+            'missed': 0,
+            'spurious': 0,
+            'possible': 2,
+            'actual': 2,
+            'precision': 0,
+            'recall': 0,
+            'f1': 0
+        },
+        'exact': {
+            'correct': 0,
+            'incorrect': 2,
+            'partial': 0,
+            'missed': 0,
+            'spurious': 0,
+            'possible': 2,
+            'actual': 2,
+            'precision': 0,
+            'recall': 0,
+            'f1': 0
+        }
+    }
+
+    assert results1 == expected
+    assert results2 == expected
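For reference, the expected counts follow from scoring the one prediction against both true entities it overlaps: against true A (0-12, same label) the overlap falls under Scenario V (strict incorrect, ent_type correct, partial partial, exact incorrect), and against true B (14-17, different label) it falls under Scenario VI (strict incorrect, ent_type incorrect, partial partial, exact incorrect). That yields partial = 2, strict and exact incorrect = 2, an ent_type split of 1 correct / 1 incorrect, and possible = actual = 2, since the prediction is now scored once per overlapping true entity. The precision, recall, and f1 fields stay at 0 because compute_metrics only tallies counts; compute_precision_recall fills them in afterwards.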
