Add QaNet model (allenai#2446)
* Add max length limit for passage and question in SQuAD reader.

* Add QaNet model.

* Fixes the SQuAD reader and adds docs.

* Move `get_best_span()` function out of bidaf.

* Update the docstring of QANet and BiDAF.

* Move `ResidualWithLayerDropout` to a separate module file.

* Update the docstring and test cases for the length limits in squad reader.

* Keep the old `get_best_span` function in `bidaf.py`.

* Add docstring for `get_best_span` function.

* Separate test case for the `get_best_span` function.

* Fixes docs.

* Update the training configuration file.

* Ignores a pylint error.

* Add docs for layer dropout.

* Fixes docs.

* Remove the unsqueeze()
yizhongw authored and matt-gardner committed Feb 1, 2019
1 parent e417486 commit 08a8c5e
Showing 22 changed files with 1,206 additions and 15 deletions.
49 changes: 42 additions & 7 deletions allennlp/data/dataset_readers/reading_comprehension/squad.py
@@ -1,6 +1,6 @@
import json
import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Optional

from overrides import overrides

@@ -26,6 +26,14 @@ class SquadReader(DatasetReader):
``metadata['token_offsets']``. This is so that we can more easily use the official SQuAD
evaluation script to get metrics.
We also support limiting the maximum length for both passage and question. However, some gold
answer spans may exceed the maximum passage length, which would cause errors when making instances.
We simply skip these spans to avoid errors. If all of the gold answer spans of an example
are skipped, we skip the whole example during training. During validation or testing, since
we cannot skip examples, we instead use the last token as a pseudo gold answer span. The
computed loss will not be accurate as a result, but this does not affect answer evaluation,
because we keep all the original gold answer texts.
Parameters
----------
tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
@@ -34,14 +42,29 @@ class SquadReader(DatasetReader):
token_indexers : ``Dict[str, TokenIndexer]``, optional
We similarly use this for both the question and the passage. See :class:`TokenIndexer`.
Default is ``{"tokens": SingleIdTokenIndexer()}``.
lazy : ``bool``, optional (default=False)
If this is true, ``instances()`` will return an object whose ``__iter__`` method
reloads the dataset each time it's called. Otherwise, ``instances()`` returns a list.
passage_length_limit : ``int``, optional (default=None)
If specified, we will truncate the passage to this length.
question_length_limit : ``int``, optional (default=None)
If specified, we will truncate the question to this length.
skip_invalid_examples : ``bool``, optional (default=False)
If this is true, we will skip examples whose gold answer spans are all removed by truncation.
"""
def __init__(self,
tokenizer: Tokenizer = None,
token_indexers: Dict[str, TokenIndexer] = None,
lazy: bool = False) -> None:
lazy: bool = False,
passage_length_limit: int = None,
question_length_limit: int = None,
skip_invalid_examples: bool = False) -> None:
super().__init__(lazy)
self._tokenizer = tokenizer or WordTokenizer()
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
self.passage_length_limit = passage_length_limit
self.question_length_limit = question_length_limit
self.skip_invalid_examples = skip_invalid_examples

@overrides
def _read(self, file_path: str):
@@ -68,25 +91,32 @@ def _read(self, file_path: str):
zip(span_starts, span_ends),
answer_texts,
tokenized_paragraph)
yield instance
if instance is not None:
yield instance

@overrides
def text_to_instance(self, # type: ignore
question_text: str,
passage_text: str,
char_spans: List[Tuple[int, int]] = None,
answer_texts: List[str] = None,
passage_tokens: List[Token] = None) -> Instance:
passage_tokens: List[Token] = None) -> Optional[Instance]:
# pylint: disable=arguments-differ
if not passage_tokens:
passage_tokens = self._tokenizer.tokenize(passage_text)
question_tokens = self._tokenizer.tokenize(question_text)
if self.passage_length_limit is not None:
passage_tokens = passage_tokens[: self.passage_length_limit]
if self.question_length_limit is not None:
question_tokens = question_tokens[: self.question_length_limit]
char_spans = char_spans or []

# We need to convert character indices in `passage_text` to token indices in
# `passage_tokens`, as the latter is what we'll actually use for supervision.
token_spans: List[Tuple[int, int]] = []
passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
for char_span_start, char_span_end in char_spans:
if char_span_end > passage_offsets[-1][1]:
continue
(span_start, span_end), error = util.char_span_to_token_span(passage_offsets,
(char_span_start, char_span_end))
if error:
@@ -98,8 +128,13 @@ def text_to_instance(self, # type: ignore
logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1])
logger.debug("Answer: %s", passage_text[char_span_start:char_span_end])
token_spans.append((span_start, span_end))

return util.make_reading_comprehension_instance(self._tokenizer.tokenize(question_text),
# All gold answer spans were filtered out by the passage length limit
if char_spans and not token_spans:
if self.skip_invalid_examples:
return None
else:
token_spans.append((len(passage_tokens) - 1, len(passage_tokens) - 1))
return util.make_reading_comprehension_instance(question_tokens,
passage_tokens,
self._token_indexers,
passage_text,
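For context, here is a minimal usage sketch of the new reader options. The length limits (400/50) and the file path are illustrative values, not taken from this commit:

from allennlp.data.dataset_readers.reading_comprehension.squad import SquadReader

# Truncate passages to 400 tokens and questions to 50 tokens; during training,
# drop examples whose gold answer spans all fall beyond the truncation point.
reader = SquadReader(passage_length_limit=400,
                     question_length_limit=50,
                     skip_invalid_examples=True)
instances = reader.read("/path/to/squad-train-v1.1.json")  # hypothetical path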
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -14,6 +14,7 @@
from allennlp.models.event2mind import Event2Mind
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
from allennlp.models.reading_comprehension.qanet import QaNet
from allennlp.models.semantic_parsing.nlvr.nlvr_coverage_semantic_parser import NlvrCoverageSemanticParser
from allennlp.models.semantic_parsing.nlvr.nlvr_direct_semantic_parser import NlvrDirectSemanticParser
from allennlp.models.semantic_parsing.quarel.quarel_semantic_parser import QuarelSemanticParser
1 change: 1 addition & 0 deletions allennlp/models/reading_comprehension/__init__.py
@@ -8,3 +8,4 @@
from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
from allennlp.models.reading_comprehension.bidaf_ensemble import BidafEnsemble
from allennlp.models.reading_comprehension.dialog_qa import DialogQA
from allennlp.models.reading_comprehension.qanet import QaNet
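These __init__ imports are what make the new model visible to AllenNLP's registry when a configuration file is parsed. A tiny sketch, under the assumption that the class registers itself as "qanet" (the @Model.register decorator is not shown in this diff):

from allennlp.models import Model

# Look the class up by its assumed registry name ("qanet" is an assumption here).
qanet_class = Model.by_name("qanet")
print(qanet_class.__name__)  # QaNet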
14 changes: 7 additions & 7 deletions allennlp/models/reading_comprehension/bidaf.py
@@ -7,6 +7,7 @@
from allennlp.common.checks import check_dimensions_match
from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.models.reading_comprehension.util import get_best_span
from allennlp.modules import Highway
from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder
from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention
@@ -139,12 +140,11 @@ def forward(self, # type: ignore
ending position of the answer with the passage. This is an `inclusive` token index.
If this is given, we will compute a loss that gets included in the output dictionary.
metadata : ``List[Dict[str, Any]]``, optional
If present, this should contain the question ID, original passage text, and token
offsets into the passage for each instance in the batch. We use this for computing
official metrics using the official SQuAD evaluation script. The length of this list
should be the batch size, and each dictionary should have the keys ``id``,
``original_passage``, and ``token_offsets``. If you only want the best span string and
don't care about official metrics, you can omit the ``id`` key.
metadata : ``List[Dict[str, Any]]``, optional
If present, this should contain the question tokens, passage tokens, original passage
text, and token offsets into the passage for each instance in the batch. The length
of this list should be the batch size, and each dictionary should have the keys
``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.
Returns
-------
@@ -245,7 +245,7 @@ def forward(self, # type: ignore
span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7)
span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7)
best_span = self.get_best_span(span_start_logits, span_end_logits)
best_span = get_best_span(span_start_logits, span_end_logits)

output_dict = {
"passage_question_attention": passage_question_attention,
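For reference, a self-contained sketch of what the shared get_best_span helper can look like: it scores every (start, end) pair by the sum of its start and end logits and keeps the best pair with end >= start. This is an illustration of the technique, not necessarily the exact code added to allennlp/models/reading_comprehension/util.py:

import torch

def get_best_span(span_start_logits: torch.Tensor,
                  span_end_logits: torch.Tensor) -> torch.Tensor:
    """Both inputs are (batch_size, passage_length); returns (batch_size, 2) start/end indices."""
    batch_size, passage_length = span_start_logits.size()
    # Score of every (start, end) pair is the sum of its start and end logits.
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow spans whose end precedes their start: log(0) = -inf below the diagonal.
    valid_mask = torch.triu(torch.ones(passage_length, passage_length,
                                       device=span_start_logits.device)).log()
    best_flat = (span_scores + valid_mask).view(batch_size, -1).argmax(dim=-1)
    return torch.stack([best_flat // passage_length, best_flat % passage_length], dim=-1)

Broadcasting the two logit vectors into an all-pairs score matrix lets one helper serve BiDAF, the BiDAF ensemble, and QANet without a Python loop over candidate spans.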
3 changes: 2 additions & 1 deletion allennlp/models/reading_comprehension/bidaf_ensemble.py
@@ -8,6 +8,7 @@
from allennlp.models.archival import load_archive
from allennlp.models.model import Model
from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
from allennlp.models.reading_comprehension.util import get_best_span
from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.training.metrics import SquadEmAndF1
@@ -140,4 +141,4 @@ def ensemble(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor:

span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults)
span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults)
return BidirectionalAttentionFlow.get_best_span(span_start_probs.log(), span_end_probs.log()) # type: ignore
return get_best_span(span_start_probs.log(), span_end_probs.log()) # type: ignore
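The ensemble keeps its old behaviour but now goes through the shared helper. A minimal sketch with random stand-in probabilities, just to show the shape of that call:

import torch
from allennlp.models.reading_comprehension.util import get_best_span

# Two hypothetical sub-model outputs; each tensor is (batch_size, passage_length).
subresults = [
    {"span_start_probs": torch.softmax(torch.randn(2, 10), -1),
     "span_end_probs": torch.softmax(torch.randn(2, 10), -1)}
    for _ in range(2)
]
span_start_probs = sum(r["span_start_probs"] for r in subresults) / len(subresults)
span_end_probs = sum(r["span_end_probs"] for r in subresults) / len(subresults)
best_span = get_best_span(span_start_probs.log(), span_end_probs.log())  # (batch_size, 2)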