Gradually revising eval for NER
caufieldjh committed Oct 26, 2023
1 parent d4579c9 commit 9fdd007
Showing 1 changed file with 20 additions and 23 deletions.
43 changes: 20 additions & 23 deletions src/ontogpt/evaluation/ctd/eval_ctd_ner.py
@@ -19,6 +19,8 @@
performance of 80.64% F-score. A total of 7 teams achieved performance
higher than DNorm."
+This evaluation also considers chemical named entity recognition.
"""
import gzip
import logging
@@ -48,23 +50,14 @@
TEST_SET_BIOC = DATABASE_DIR / "CDR_TestSet.BioC.xml.gz"
TRAIN_SET_BIOC = DATABASE_DIR / "CDR_TrainSet.BioC.xml.gz"

-RMAP = {"CID": "induces"}

logger = logging.getLogger(__name__)


-def negated(ChemicalToDiseaseRelationship) -> bool:
-    return (
-        ChemicalToDiseaseRelationship.qualifier
-        and ChemicalToDiseaseRelationship.qualifier.lower() == "not"
-    )


-class PredictionRE(BaseModel):
-    """A prediction for a relationship extraction task."""
+class PredictionNER(BaseModel):
+    """A prediction for a named entity recognition task."""

test_object: Optional[TextWithTriples] = None
"""source of truth to evaluate against"""
"""Source of truth to evaluate against."""

true_positives: Optional[List[Tuple]] = None
num_true_positives: Optional[int] = None
@@ -76,6 +69,9 @@ class PredictionRE(BaseModel):
predicted_object: Optional[TextWithTriples] = None
named_entities: Optional[List[Any]] = None

+    # TODO: allow this to take a subset of entities.
+    # Or set up another child class for each entity type,
+    # since they may require some fancy adaptations
def calculate_scores(self, labelers: Optional[List[BasicOntologyInterface]] = None):
self.scores = {}

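The TODO above suggests per-entity-type prediction classes. A minimal sketch of what one such child class could look like, given the `PredictionNER` model defined above (purely illustrative; `PredictionChemicalNER` and its `entity_type` field are not part of the codebase):

```python
# Purely illustrative: one shape the TODO's per-entity-type child
# classes could take. PredictionChemicalNER is not in the codebase;
# it assumes the PredictionNER model defined in this file.
class PredictionChemicalNER(PredictionNER):
    """A prediction restricted to chemical entities."""

    entity_type: str = "chemical"
```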
@@ -90,8 +86,8 @@ def label(x):
def all_objects(dm: Optional[TextWithTriples]):
if dm is not None:
return list(
-                    set(link.subject for link in dm.triples if not negated(link))
-                    | set(link.object for link in dm.triples if not negated(link))
+                    set(link.subject for link in dm.triples)
+                    | set(link.object for link in dm.triples)
)
else:
return list()
@@ -101,7 +97,6 @@ def pairs(dm: TextWithTriples) -> Set:
return set(
(label(link.subject), label(link.object))
for link in dm.triples
-                    if not negated(link)
)
else:
return set()
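With the `negated` filter removed, every subject and object mention now counts toward the score. For reference, the bookkeeping that `calculate_scores` performs on these sets reduces to plain set arithmetic; a minimal, self-contained sketch (the `gold` and `predicted` sets are hypothetical stand-ins for the outputs of `pairs()`):

```python
# Hypothetical gold and predicted label sets standing in for the
# outputs of pairs() over the test and predicted triples.
gold = {("naloxone", "hypertension"), ("lidocaine", "seizure")}
predicted = {("naloxone", "hypertension"), ("lidocaine", "bradycardia")}

true_positives = gold & predicted   # predicted and in the gold standard
false_positives = predicted - gold  # predicted but not in the gold standard
false_negatives = gold - predicted  # in the gold standard but missed

print(len(true_positives), len(false_positives), len(false_negatives))  # 1 1 1
```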
@@ -122,23 +117,25 @@ def pairs(dm: TextWithTriples) -> Set:
self.num_true_positives = len(self.true_positives)


-class EvaluationObjectSetRE(BaseModel):
-    """A result of predicting relation extractions."""
+class EvaluationObjectSetNER(BaseModel):
+    """A result of performing named entity recognition."""

precision: float = 0
recall: float = 0
f1: float = 0

training: Optional[List[TextWithTriples]] = None
-    predictions: Optional[List[PredictionRE]] = None
+    predictions: Optional[List[PredictionNER]] = None
test: Optional[List[TextWithTriples]] = None


@dataclass
-class EvalCTD(SPIRESEvaluationEngine):
+class EvalCTDNER(SPIRESEvaluationEngine):
subject_prefix = "MESH"
object_prefix = "MESH"

+    # TODO: use a restricted version of ctd schema for NER alone,
+    # but retain the original entity types
def __post_init__(self):
self.extractor = SPIRESEngine(template="ctd.ChemicalToDiseaseDocument", model=self.model)
# synonyms are derived entirely from training set
@@ -195,7 +192,7 @@ def create_training_set(self, num=100):
completion = ke.serialize_object()
yield dict(prompt=prompt, completion=completion)

-    def eval(self) -> EvaluationObjectSetRE:
+    def eval(self) -> EvaluationObjectSetNER:
"""Evaluate the ability to extract relations."""
labeler = get_adapter("sqlite:obo:mesh")
if self.num_tests and isinstance(self.num_tests, int):
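For reference, `create_training_set` (leading context above) yields one prompt/completion dict per document. Assuming those records are later written out as JSON Lines for fine-tuning, a single record would look roughly like the sketch below (the literal strings are hypothetical; the real text comes from the prompt builder and `ke.serialize_object()` shown above):

```python
# Hypothetical shape of one record yielded by create_training_set;
# actual prompt and completion text is produced by the extractor.
record = dict(
    prompt="Split the following text into chemicals and diseases: ...",
    completion="chemical: naloxone\ndisease: hypertension",
)
```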
@@ -205,7 +202,7 @@ def eval(self) -> EvaluationObjectSetRE:
ke = self.extractor
docs = list(self.load_test_cases())
shuffle(docs)
-        eos = EvaluationObjectSetRE(
+        eos = EvaluationObjectSetNER(
test=docs[:num_test],
training=[],
predictions=[],
@@ -270,7 +267,7 @@ def included(t: ChemicalToDiseaseRelationship):
logger.info(
f"{len(predicted_obj.triples)} filtered triples (CID only, between MESH only)"
)
-            pred = PredictionRE(
+            pred = PredictionNER(
predicted_object=predicted_obj, test_object=doc, named_entities=named_entities
)
named_entities.clear()
@@ -283,7 +280,7 @@ def included(t: ChemicalToDiseaseRelationship):
self.calc_stats(eos)
return eos

-    def calc_stats(self, eos: EvaluationObjectSetRE):
+    def calc_stats(self, eos: EvaluationObjectSetNER):
num_true_positives = sum(p.num_true_positives for p in eos.predictions)
num_false_positives = sum(p.num_false_positives for p in eos.predictions)
num_false_negatives = sum(p.num_false_negatives for p in eos.predictions)
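These totals feed the usual precision/recall/F1 definitions. A minimal sketch of the arithmetic, assuming `tp`, `fp`, and `fn` stand for the sums computed just above (the numbers are hypothetical, not results from this evaluation):

```python
# Standard precision/recall/F1 over summed counts; hypothetical totals.
tp, fp, fn = 90, 10, 30
precision = tp / (tp + fp)  # 0.9: fraction of predictions that are correct
recall = tp / (tp + fn)     # 0.75: fraction of gold annotations recovered
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean, ~0.818
print(precision, recall, f1)
```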
