diff --git a/changelog.md b/changelog.md index 52e66d11e..003ddc278 100644 --- a/changelog.md +++ b/changelog.md @@ -11,6 +11,7 @@ - Specify whether to log the validation results or not (`logger=False`) - Added support for the CoNLL format with `edsnlp.data.read_conll` and with a specific `eds.conll_dict2doc` converter - Added a Trainable Biaffine Dependency Parser (`eds.biaffine_dep_parser`) component and metrics +- New `eds.extractive_qa` component to perform extractive question answering using questions as prompts to tag entities instead of a list of predefined labels as in `eds.ner_crf`. ### Fixed diff --git a/docs/pipes/trainable/extractive-qa.md b/docs/pipes/trainable/extractive-qa.md new file mode 100644 index 000000000..93ee51f9f --- /dev/null +++ b/docs/pipes/trainable/extractive-qa.md @@ -0,0 +1,8 @@ +# Extractive Question Answering {: #edsnlp.pipes.trainable.extractive_qa.factory.create_component } + +::: edsnlp.pipes.trainable.extractive_qa.factory.create_component + options: + heading_level: 2 + show_bases: false + show_source: false + only_class_level: true diff --git a/docs/pipes/trainable/index.md b/docs/pipes/trainable/index.md index 3444be280..751c0f6b7 100644 --- a/docs/pipes/trainable/index.md +++ b/docs/pipes/trainable/index.md @@ -14,6 +14,7 @@ All trainable components implement the [`TorchComponent`][edsnlp.core.torch_comp | `eds.text_cnn` | Contextualize embeddings with a CNN | | `eds.span_pooler` | A span embedding component that aggregates word embeddings | | `eds.ner_crf` | A trainable component to extract entities | +| `eds.extractive_qa` | A trainable component for extractive question answering | | `eds.span_classifier` | A trainable component for multi-class multi-label span classification | | `eds.span_linker` | A trainable entity linker (i.e. to a list of concepts) | | `eds.biaffine_dep_parser` | A trainable biaffine dependency parser | diff --git a/edsnlp/pipes/__init__.py b/edsnlp/pipes/__init__.py index 02a2a0489..c5055e95f 100644 --- a/edsnlp/pipes/__init__.py +++ b/edsnlp/pipes/__init__.py @@ -76,6 +76,7 @@ from .qualifiers.reported_speech.factory import create_component as rspeech from .trainable.ner_crf.factory import create_component as ner_crf from .trainable.biaffine_dep_parser.factory import create_component as biaffine_dep_parser + from .trainable.extractive_qa.factory import create_component as extractive_qa from .trainable.span_classifier.factory import create_component as span_classifier from .trainable.span_linker.factory import create_component as span_linker from .trainable.embeddings.span_pooler.factory import create_component as span_pooler diff --git a/edsnlp/pipes/trainable/extractive_qa/__init__.py b/edsnlp/pipes/trainable/extractive_qa/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py b/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py new file mode 100644 index 000000000..4cb0c5385 --- /dev/null +++ b/edsnlp/pipes/trainable/extractive_qa/extractive_qa.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Set + +from spacy.tokens import Doc, Span +from typing_extensions import Literal + +from edsnlp.core.pipeline import Pipeline +from edsnlp.pipes.trainable.embeddings.typing import ( + WordEmbeddingComponent, +) +from edsnlp.pipes.trainable.ner_crf.ner_crf import NERBatchOutput, TrainableNerCrf +from edsnlp.utils.filter import align_spans, filter_spans +from edsnlp.utils.span_getters import ( + SpanGetterArg, + SpanSetterArg, + get_spans, +) +from edsnlp.utils.typing import AsList + + +class TrainableExtractiveQA(TrainableNerCrf): + """ + The `eds.extractive_qa` component is a trainable extractive question answering + component. This can be seen as a Named Entity Recognition (NER) component where the + types of entities predicted by the model are not pre-defined during the training + but are provided as prompts (i.e., questions) at inference time. + + The `eds.extractive_qa` shares a lot of similarities with the `eds.ner_crf` + component, and therefore most of the arguments are the same. + + !!! note "Extractive vs Abstractive Question Answering" + + Extractive Question Answering differs from Abstractive Question Answering in + that the answer is extracted from the text, rather than generated (à la + ChatGPT) from scratch. To normalize the answers, you can use the + `eds.span_linker` component in `synonym` mode and search for the closest + `synonym` in a predefined list. + + Examples + -------- + + ```python + import edsnlp, edsnlp.pipes as eds + + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.extractive_qa( + embedding=eds.transformer( + model="prajjwal1/bert-tiny", + window=128, + stride=96, + ), + mode="joint", + target_span_getter="ner-gold", + span_setter="ents", + questions={ + "disease": "What disease does the patient have?", + "drug": "What drug is the patient taking?", + }, # (1)! + ), + name="qa", + ) + ``` + + To train the model, refer to the [Training](/tutorials/make-a-training-script) + tutorial. + + Once the model is trained, you can use the questions attribute (next section) on the + document you run the model on, or you can change the global questions attribute: + + ```python + nlp.pipes.qa.questions = { + "disease": "When did the patient get sick?", + } + ``` + + # Dynamic Questions + + You can also provide + + ```{ .python .no-check } + eds.extractive_qa(..., questions_attribute="questions") + ``` + + to get the questions dynamically from an attribute on the Doc or Span objects + (e.g., `doc._.questions`). This is useful when you want to have different questions + depending on the document. + + To provide questions from a dataframe, you can use the following code: + + ```{ .python .no-check } + dataframe = pd.DataFrame({"questions": ..., "note_text": ..., "note_id": ...}) + stream = edsnlp.data.from_pandas( + dataframe, + converter="omop", + doc_attributes={"questions": "questions"}, + ) + stream.map_pipeline(nlp) + stream.set_processing(backend="multiprocessing") + out = stream.to_pandas(converters="ents") + ``` + + + Parameters + ---------- + name : str + Name of the component + embedding : WordEmbeddingComponent + The word embedding component + questions : Dict[str, AsList[str]] + The questions to ask, as a mapping between the entity type and the list of + questions to ask for this entity type (or single string if only one question). + questions_attribute : Optional[str] + The attribute to use to get the questions dynamically from the Doc or Span + objects (as returned by the `context_getter` argument). If None, the questions + will be fixed and only taken from the `questions` argument. + context_getter : Optional[SpanGetterArg] + What context to use when computing the span embeddings (defaults to the whole + document). For example `{"section": "conclusion"}` to only extract the + entities from the conclusion. + target_span_getter : SpanGetterArg + Method to call to get the gold spans from a document, for scoring or training. + By default, takes all entities in `doc.ents`, but we recommend you specify + a given span group name instead. + span_setter : Optional[SpanSetterArg] + The span setter to use to set the predicted spans on the Doc object. If None, + the component will infer the span setter from the target_span_getter config. + infer_span_setter : Optional[bool] + Whether to complete the span setter from the target_span_getter config. + False by default, unless the span_setter is None. + mode : Literal["independent", "joint", "marginal"] + The CRF mode to use : independent, joint or marginal + window : int + The window size to use for the CRF. If 0, will use the whole document, at + the cost of a longer computation time. If 1, this is equivalent to assuming + that the tags are independent and will the component be faster, but with + degraded performance. Empirically, we found that a window size of 10 or 20 + works well. + stride : Optional[int] + The stride to use for the CRF windows. Defaults to `window // 2`. + """ + + def __init__( + self, + nlp: Optional[Pipeline] = None, + name: Optional[str] = "extractive_qa", + *, + embedding: WordEmbeddingComponent, + questions: Dict[str, AsList[str]] = {}, + questions_attribute: str = "questions", + context_getter: Optional[SpanGetterArg] = None, + target_span_getter: Optional[SpanGetterArg] = None, + span_setter: Optional[SpanSetterArg] = None, + infer_span_setter: Optional[bool] = None, + mode: Literal["independent", "joint", "marginal"] = "joint", + window: int = 40, + stride: Optional[int] = None, + ): + self.questions_attribute: Optional[str] = questions_attribute + self.questions = questions + super().__init__( + nlp=nlp, + name=name, + embedding=embedding, + context_getter=context_getter, + span_setter=span_setter, + target_span_getter=target_span_getter, + mode=mode, + window=window, + stride=stride, + infer_span_setter=infer_span_setter, + ) + self.update_labels(["answer"]) + self.labels_to_idx = defaultdict(lambda: 0) + + def set_extensions(self): + super().set_extensions() + if self.questions_attribute: + if not Doc.has_extension(self.questions_attribute): + Doc.set_extension(self.questions_attribute, default=None) + if not Span.has_extension(self.questions_attribute): + Span.set_extension(self.questions_attribute, default=None) + + def post_init(self, docs: Iterable[Doc], exclude: Set[str]): + pass + + @property + def cfg(self): + cfg = dict(super().cfg) + cfg.pop("labels") + return cfg + + def preprocess(self, doc, **kwargs): + contexts = ( + list(get_spans(doc, self.context_getter)) + if self.context_getter + else [doc[:]] + ) + prompt_contexts_and_labels = sorted( + { + (prompt, label, context) + for context in contexts + for label, questions in ( + *self.questions.items(), + *(getattr(doc._, self.questions_attribute) or {}).items(), + *( + (getattr(context._, self.questions_attribute) or {}).items() + if context is not doc + else () + ), + ) + for prompt in questions + } + ) + questions = [x[0] for x in prompt_contexts_and_labels] + labels = [x[1] for x in prompt_contexts_and_labels] + ctxs = [x[2] for x in prompt_contexts_and_labels] + return { + "lengths": [len(ctx) for ctx in ctxs], + "$labels": labels, + "$contexts": ctxs, + "embedding": self.embedding.preprocess( + doc, + contexts=ctxs, + prompts=questions, + **kwargs, + ), + } + + def preprocess_supervised(self, doc, **kwargs): + prep = self.preprocess(doc, **kwargs) + contexts = prep["$contexts"] + labels = prep["$labels"] + tags = [] + + for context, label, target_ents in zip( + contexts, + labels, + align_spans( + list(get_spans(doc, self.target_span_getter)), + contexts, + ), + ): + span_tags = [[0] * len(self.labels) for _ in range(len(context))] + start = context.start + target_ents = [ent for ent in target_ents if ent.label_ == label] + + # TODO: move this to the LinearChainCRF class + for ent in filter_spans(target_ents): + label_idx = self.labels_to_idx[ent.label_] + if ent.start == ent.end - 1: + span_tags[ent.start - start][label_idx] = 4 + else: + span_tags[ent.start - start][label_idx] = 2 + span_tags[ent.end - 1 - start][label_idx] = 3 + for i in range(ent.start + 1 - start, ent.end - 1 - start): + span_tags[i][label_idx] = 1 + tags.append(span_tags) + + return { + **prep, + "targets": tags, + } + + def postprocess( + self, + docs: List[Doc], + results: NERBatchOutput, + inputs: List[Dict[str, Any]], + ): + spans: Dict[Doc, list[Span]] = defaultdict(list) + contexts = [ctx for sample in inputs for ctx in sample["$contexts"]] + labels = [label for sample in inputs for label in sample["$labels"]] + tags = results["tags"].cpu() + for context_idx, _, start, end in self.crf.tags_to_spans(tags).tolist(): + span = contexts[context_idx][start:end] + label = labels[context_idx] + span.label_ = label + spans[span.doc].append(span) + for doc in docs: + self.set_spans(doc, spans.get(doc, [])) + return docs diff --git a/edsnlp/pipes/trainable/extractive_qa/factory.py b/edsnlp/pipes/trainable/extractive_qa/factory.py new file mode 100644 index 000000000..60b9e508f --- /dev/null +++ b/edsnlp/pipes/trainable/extractive_qa/factory.py @@ -0,0 +1,14 @@ +from typing import TYPE_CHECKING + +from edsnlp import registry + +from .extractive_qa import TrainableExtractiveQA + +create_component = registry.factory.register( + "eds.extractive_qa", + assigns=[], + deprecated=[], +)(TrainableExtractiveQA) + +if TYPE_CHECKING: + create_component = TrainableExtractiveQA diff --git a/mkdocs.yml b/mkdocs.yml index fcc844502..8ee044304 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -129,6 +129,7 @@ nav: - 'Span Classifier': pipes/trainable/span-classifier.md - 'Span Linker': pipes/trainable/span-linker.md - 'Biaffine Dependency Parser': pipes/trainable/biaffine-dependency-parser.md + - 'Extractive QA': pipes/trainable/extractive-qa.md - tokenizers.md - Data Connectors: - data/index.md diff --git a/pyproject.toml b/pyproject.toml index 855e47759..074e65753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -268,8 +268,9 @@ where = ["."] "eds.text_cnn" = "edsnlp.pipes.trainable.embeddings.text_cnn.factory:create_component" "eds.span_pooler" = "edsnlp.pipes.trainable.embeddings.span_pooler.factory:create_component" "eds.ner_crf" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" +"eds.extractive_qa" = "edsnlp.pipes.trainable.extractive_qa.factory:create_component" "eds.nested_ner" = "edsnlp.pipes.trainable.ner_crf.factory:create_component" -"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" +"eds.span_qualifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" "eds.span_classifier" = "edsnlp.pipes.trainable.span_classifier.factory:create_component" "eds.span_linker" = "edsnlp.pipes.trainable.span_linker.factory:create_component" "eds.biaffine_dep_parser" = "edsnlp.pipes.trainable.biaffine_dep_parser.factory:create_component" diff --git a/tests/pipelines/trainable/test_extractive_qa.py b/tests/pipelines/trainable/test_extractive_qa.py new file mode 100644 index 000000000..bce68d7e3 --- /dev/null +++ b/tests/pipelines/trainable/test_extractive_qa.py @@ -0,0 +1,101 @@ +from spacy.tokens import Span + +import edsnlp +import edsnlp.pipes as eds + + +def test_ner(): + nlp = edsnlp.blank("eds") + nlp.add_pipe( + eds.extractive_qa( + embedding=eds.transformer( + model="prajjwal1/bert-tiny", + window=20, + stride=10, + ), + # During training, where do we get the gold entities from ? + target_span_getter=["ner-gold"], + # Which prompts for each label ? + questions={ + "PERSON": "Quels sont les personnages ?", + "GIFT": "Quels sont les cadeaux ?", + }, + questions_attribute="question", + # During prediction, where do we set the predicted entities ? + span_setter="ents", + ), + ) + + doc = nlp( + "L'aîné eut le Moulin, le second eut l'âne, et le plus jeune n'eut que le Chat." + ) + doc._.question = { + "FAVORITE": ["Qui a eu de l'argent ?"], + } + # doc[0:2], doc[4:5], doc[6:8], doc[9:11], doc[13:16], doc[20:21] + doc.spans["ner-gold"] = [ + Span(doc, 0, 2, "PERSON"), # L'aîné + Span(doc, 4, 5, "GIFT"), # Moulin + Span(doc, 6, 8, "PERSON"), # le second + Span(doc, 9, 11, "GIFT"), # l'âne + Span(doc, 13, 16, "PERSON"), # le plus jeune + Span(doc, 20, 21, "GIFT"), # Chat + ] + nlp.post_init([doc]) + + ner = nlp.pipes.extractive_qa + batch = ner.prepare_batch([doc], supervision=True) + results = ner.module_forward(batch) + + list(ner.pipe([doc]))[0] + + assert results["loss"] is not None + trf_inputs = [ + seq.replace(" [PAD]", "") + for seq in ner.embedding.tokenizer.batch_decode(batch["embedding"]["input_ids"]) + ] + assert trf_inputs == [ + "[CLS] quels sont les cadeaux? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501 + "[CLS] quels sont les cadeaux? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501 + "[CLS] quels sont les cadeaux? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501 + "[CLS] quels sont les personnages? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501 + "[CLS] quels sont les personnages? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501 + "[CLS] quels sont les personnages? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501 + "[CLS] qui a eu de l'argent? [SEP] l'aine eut le moulin, le second eut l'ane, et [SEP]", # noqa: E501 + "[CLS] qui a eu de l'argent? [SEP] le second eut l'ane, et le plus jeune n'eut que le [SEP]", # noqa: E501 + "[CLS] qui a eu de l'argent? [SEP] le plus jeune n'eut que le chat. [SEP]", # noqa: E501 + ] + assert batch["targets"].squeeze(2).tolist() == [ + [0, 0, 0, 0, 4, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0], + [2, 3, 0, 0, 0, 0, 2, 3, 0, 0, 0, 0, 0, 2, 1, 3, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + + assert nlp.config.to_yaml_str() == ( + "nlp:\n" + " lang: eds\n" + " pipeline:\n" + " - extractive_qa\n" + " tokenizer:\n" + " '@tokenizers': eds.tokenizer\n" + "components:\n" + " extractive_qa:\n" + " '@factory': eds.extractive_qa\n" + " embedding:\n" + " '@factory': eds.transformer\n" + " model: prajjwal1/bert-tiny\n" + " window: 20\n" + " stride: 10\n" + " questions:\n" + " PERSON: Quels sont les personnages ?\n" + " GIFT: Quels sont les cadeaux ?\n" + " questions_attribute: question\n" + " target_span_getter:\n" + " - ner-gold\n" + " span_setter:\n" + " ents: true\n" + " infer_span_setter: false\n" + " mode: joint\n" + " window: 40\n" + " stride: 20\n" + )