From 08a8c5e56d8728c17eb96de3ae603ce6799a5b2f Mon Sep 17 00:00:00 2001 From: Yizhong Wang Date: Thu, 31 Jan 2019 22:05:11 -0800 Subject: [PATCH] Add QaNet model (#2446) * Add max length limit for passage and question in SQuAD reader. * Add QaNet model. * fixes the squad reader and adds doc. * Move `get_best_span()` function out of bidaf. * Update the docstring of QANet and BiDAF * Move `ResidualWithLayerDropout` to a separate module file. * Update the docstring and test cases for the length limits in squad reader. * Keep the old `get_best_span` function in `bidaf.py`. * Add docstring for `get_best_span` function. * Separate test case for the `get_best_span` function. * Fixes docs. * Update the training configuration file. * ignores pylint error. * add docs for layer dropout. * fixes docs. * Remove the unsqueeze() --- .../reading_comprehension/squad.py | 49 +++- allennlp/models/__init__.py | 1 + .../models/reading_comprehension/__init__.py | 1 + .../models/reading_comprehension/bidaf.py | 14 +- .../reading_comprehension/bidaf_ensemble.py | 3 +- .../models/reading_comprehension/qanet.py | 261 ++++++++++++++++++ allennlp/models/reading_comprehension/util.py | 33 +++ allennlp/modules/__init__.py | 1 + .../modules/residual_with_layer_dropout.py | 59 ++++ allennlp/modules/seq2seq_encoders/__init__.py | 2 + .../modules/seq2seq_encoders/qanet_encoder.py | 251 +++++++++++++++++ .../reading_comprehension/squad_test.py | 33 +++ allennlp/tests/fixtures/qanet/experiment.json | 141 ++++++++++ .../reading_comprehension/qanet_test.py | 72 +++++ .../models/reading_comprehension/util_test.py | 32 +++ .../residual_with_layer_dropout_test.py | 42 +++ .../seq2seq_encoders/qanet_encoder_test.py | 44 +++ .../allennlp.models.reading_comprehension.rst | 10 + ...lp.modules.residual_with_layer_dropout.rst | 7 + doc/api/allennlp.modules.rst | 1 + doc/api/allennlp.modules.seq2seq_encoders.rst | 5 + training_config/qanet.jsonnet | 159 +++++++++++ 22 files changed, 1206 insertions(+), 15 deletions(-) create mode 100644 allennlp/models/reading_comprehension/qanet.py create mode 100644 allennlp/models/reading_comprehension/util.py create mode 100644 allennlp/modules/residual_with_layer_dropout.py create mode 100644 allennlp/modules/seq2seq_encoders/qanet_encoder.py create mode 100644 allennlp/tests/fixtures/qanet/experiment.json create mode 100644 allennlp/tests/models/reading_comprehension/qanet_test.py create mode 100644 allennlp/tests/models/reading_comprehension/util_test.py create mode 100644 allennlp/tests/modules/residual_with_layer_dropout_test.py create mode 100644 allennlp/tests/modules/seq2seq_encoders/qanet_encoder_test.py create mode 100644 doc/api/allennlp.modules.residual_with_layer_dropout.rst create mode 100644 training_config/qanet.jsonnet diff --git a/allennlp/data/dataset_readers/reading_comprehension/squad.py b/allennlp/data/dataset_readers/reading_comprehension/squad.py index bdc1f6845ee..853c462a640 100644 --- a/allennlp/data/dataset_readers/reading_comprehension/squad.py +++ b/allennlp/data/dataset_readers/reading_comprehension/squad.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Optional from overrides import overrides @@ -26,6 +26,14 @@ class SquadReader(DatasetReader): ``metadata['token_offsets']``. This is so that we can more easily use the official SQuAD evaluation script to get metrics. + We also support limiting the maximum length for both passage and question. 
However, some gold
+    answer spans may exceed the maximum passage length, which will cause errors when making
+    instances. We simply skip these spans to avoid such errors. If all of the gold answer spans
+    of an example are skipped, we skip the whole example during training. During validation or
+    testing, since we cannot skip examples, we instead use the last passage token as a pseudo
+    gold answer span. The computed loss will not be accurate as a result, but this does not
+    affect answer evaluation, because we keep all of the original gold answer texts.
+
     Parameters
     ----------
     tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
@@ -34,14 +42,29 @@ class SquadReader(DatasetReader):
     token_indexers : ``Dict[str, TokenIndexer]``, optional
         We similarly use this for both the question and the passage.  See :class:`TokenIndexer`.
         Default is ``{"tokens": SingleIdTokenIndexer()}``.
+    lazy : ``bool``, optional (default=False)
+        If this is true, ``instances()`` will return an object whose ``__iter__`` method
+        reloads the dataset each time it's called. Otherwise, ``instances()`` returns a list.
+    passage_length_limit : ``int``, optional (default=None)
+        If specified, we will truncate the passage if its length exceeds this limit.
+    question_length_limit : ``int``, optional (default=None)
+        If specified, we will truncate the question if its length exceeds this limit.
+    skip_invalid_examples : ``bool``, optional (default=False)
+        If this is true, we will skip examples for which no valid gold answer span remains.
     """
     def __init__(self,
                  tokenizer: Tokenizer = None,
                  token_indexers: Dict[str, TokenIndexer] = None,
-                 lazy: bool = False) -> None:
+                 lazy: bool = False,
+                 passage_length_limit: int = None,
+                 question_length_limit: int = None,
+                 skip_invalid_examples: bool = False) -> None:
         super().__init__(lazy)
         self._tokenizer = tokenizer or WordTokenizer()
         self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
+        self.passage_length_limit = passage_length_limit
+        self.question_length_limit = question_length_limit
+        self.skip_invalid_examples = skip_invalid_examples
 
     @overrides
     def _read(self, file_path: str):
@@ -68,7 +91,8 @@ def _read(self, file_path: str):
                                                  zip(span_starts, span_ends),
                                                  answer_texts,
                                                  tokenized_paragraph)
-                yield instance
+                if instance is not None:
+                    yield instance
 
     @overrides
     def text_to_instance(self,  # type: ignore
@@ -76,17 +100,23 @@ def text_to_instance(self,  # type: ignore
                          question_text: str,
                          passage_text: str,
                          char_spans: List[Tuple[int, int]] = None,
                          answer_texts: List[str] = None,
-                         passage_tokens: List[Token] = None) -> Instance:
+                         passage_tokens: List[Token] = None) -> Optional[Instance]:
         # pylint: disable=arguments-differ
         if not passage_tokens:
             passage_tokens = self._tokenizer.tokenize(passage_text)
+        question_tokens = self._tokenizer.tokenize(question_text)
+        if self.passage_length_limit is not None:
+            passage_tokens = passage_tokens[: self.passage_length_limit]
+        if self.question_length_limit is not None:
+            question_tokens = question_tokens[: self.question_length_limit]
         char_spans = char_spans or []
-
         # We need to convert character indices in `passage_text` to token indices in
         # `passage_tokens`, as the latter is what we'll actually use for supervision.
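+        # For example, for a (hypothetical) passage starting with "The cat sat on the mat."
+        # and the gold answer "cat" at character span (4, 7), the token span would be (1, 1).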
token_spans: List[Tuple[int, int]] = [] passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens] for char_span_start, char_span_end in char_spans: + if char_span_end > passage_offsets[-1][1]: + continue (span_start, span_end), error = util.char_span_to_token_span(passage_offsets, (char_span_start, char_span_end)) if error: @@ -98,8 +128,13 @@ def text_to_instance(self, # type: ignore logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1]) logger.debug("Answer: %s", passage_text[char_span_start:char_span_end]) token_spans.append((span_start, span_end)) - - return util.make_reading_comprehension_instance(self._tokenizer.tokenize(question_text), + # The original answer is filtered out + if char_spans and not token_spans: + if self.skip_invalid_examples: + return None + else: + token_spans.append((len(passage_tokens) - 1, len(passage_tokens) - 1)) + return util.make_reading_comprehension_instance(question_tokens, passage_tokens, self._token_indexers, passage_text, diff --git a/allennlp/models/__init__.py b/allennlp/models/__init__.py index 23e4a4cf073..2e40320cb3f 100644 --- a/allennlp/models/__init__.py +++ b/allennlp/models/__init__.py @@ -14,6 +14,7 @@ from allennlp.models.event2mind import Event2Mind from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow +from allennlp.models.reading_comprehension.qanet import QaNet from allennlp.models.semantic_parsing.nlvr.nlvr_coverage_semantic_parser import NlvrCoverageSemanticParser from allennlp.models.semantic_parsing.nlvr.nlvr_direct_semantic_parser import NlvrDirectSemanticParser from allennlp.models.semantic_parsing.quarel.quarel_semantic_parser import QuarelSemanticParser diff --git a/allennlp/models/reading_comprehension/__init__.py b/allennlp/models/reading_comprehension/__init__.py index 7fe235f9fc2..5477edb1a3a 100644 --- a/allennlp/models/reading_comprehension/__init__.py +++ b/allennlp/models/reading_comprehension/__init__.py @@ -8,3 +8,4 @@ from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow from allennlp.models.reading_comprehension.bidaf_ensemble import BidafEnsemble from allennlp.models.reading_comprehension.dialog_qa import DialogQA +from allennlp.models.reading_comprehension.qanet import QaNet diff --git a/allennlp/models/reading_comprehension/bidaf.py b/allennlp/models/reading_comprehension/bidaf.py index 10ef43c7c3d..62da15528e2 100644 --- a/allennlp/models/reading_comprehension/bidaf.py +++ b/allennlp/models/reading_comprehension/bidaf.py @@ -7,6 +7,7 @@ from allennlp.common.checks import check_dimensions_match from allennlp.data import Vocabulary from allennlp.models.model import Model +from allennlp.models.reading_comprehension.util import get_best_span from allennlp.modules import Highway from allennlp.modules import Seq2SeqEncoder, SimilarityFunction, TimeDistributed, TextFieldEmbedder from allennlp.modules.matrix_attention.legacy_matrix_attention import LegacyMatrixAttention @@ -139,12 +140,11 @@ def forward(self, # type: ignore ending position of the answer with the passage. This is an `inclusive` token index. If this is given, we will compute a loss that gets included in the output dictionary. metadata : ``List[Dict[str, Any]]``, optional - If present, this should contain the question ID, original passage text, and token - offsets into the passage for each instance in the batch. 
We use this for computing - official metrics using the official SQuAD evaluation script. The length of this list - should be the batch size, and each dictionary should have the keys ``id``, - ``original_passage``, and ``token_offsets``. If you only want the best span string and - don't care about official metrics, you can omit the ``id`` key. + metadata : ``List[Dict[str, Any]]``, optional + If present, this should contain the question tokens, passage tokens, original passage + text, and token offsets into the passage for each instance in the batch. The length + of this list should be the batch size, and each dictionary should have the keys + ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``. Returns ------- @@ -245,7 +245,7 @@ def forward(self, # type: ignore span_end_probs = util.masked_softmax(span_end_logits, passage_mask) span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e7) span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e7) - best_span = self.get_best_span(span_start_logits, span_end_logits) + best_span = get_best_span(span_start_logits, span_end_logits) output_dict = { "passage_question_attention": passage_question_attention, diff --git a/allennlp/models/reading_comprehension/bidaf_ensemble.py b/allennlp/models/reading_comprehension/bidaf_ensemble.py index 018d804114d..5734e3a6d5b 100644 --- a/allennlp/models/reading_comprehension/bidaf_ensemble.py +++ b/allennlp/models/reading_comprehension/bidaf_ensemble.py @@ -8,6 +8,7 @@ from allennlp.models.archival import load_archive from allennlp.models.model import Model from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow +from allennlp.models.reading_comprehension.util import get_best_span from allennlp.common import Params from allennlp.data import Vocabulary from allennlp.training.metrics import SquadEmAndF1 @@ -140,4 +141,4 @@ def ensemble(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor: span_start_probs = sum(subresult['span_start_probs'] for subresult in subresults) / len(subresults) span_end_probs = sum(subresult['span_end_probs'] for subresult in subresults) / len(subresults) - return BidirectionalAttentionFlow.get_best_span(span_start_probs.log(), span_end_probs.log()) # type: ignore + return get_best_span(span_start_probs.log(), span_end_probs.log()) # type: ignore diff --git a/allennlp/models/reading_comprehension/qanet.py b/allennlp/models/reading_comprehension/qanet.py new file mode 100644 index 00000000000..114f4781a1f --- /dev/null +++ b/allennlp/models/reading_comprehension/qanet.py @@ -0,0 +1,261 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.functional import nll_loss + +from allennlp.data import Vocabulary +from allennlp.models.model import Model +from allennlp.models.reading_comprehension.util import get_best_span +from allennlp.modules import Highway +from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder +from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention +from allennlp.nn import util, InitializerApplicator, RegularizerApplicator +from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy, SquadEmAndF1 +from allennlp.nn.util import masked_softmax + + +@Model.register("qanet") +class QaNet(Model): + """ + This class implements Adams Wei Yu's `QANet Model `_ + for machine reading comprehension published at ICLR 2018. + + The overall architecture of QANet is very similar to BiDAF. 
The main difference is that QANet + replaces the RNN encoder with CNN + self-attention. There are also some minor differences in the + modeling layer and output layer. + + Parameters + ---------- + vocab : ``Vocabulary`` + text_field_embedder : ``TextFieldEmbedder`` + Used to embed the ``question`` and ``passage`` ``TextFields`` we get as input to the model. + num_highway_layers : ``int`` + The number of highway layers to use in between embedding the input and passing it through + the phrase layer. + phrase_layer : ``Seq2SeqEncoder`` + The encoder (with its own internal stacking) that we will use in between embedding tokens + and doing the passage-question attention. + matrix_attention_layer : ``MatrixAttention`` + The matrix attention function that we will use when comparing encoded passage and question + representations. + modeling_layer : ``Seq2SeqEncoder`` + The encoder (with its own internal stacking) that we will use in between the bidirectional + attention and predicting span start and end. + dropout_prob : ``float``, optional (default=0.1) + If greater than 0, we will apply dropout with this probability between layers. + initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) + Used to initialize the model parameters. + regularizer : ``RegularizerApplicator``, optional (default=``None``) + If provided, will be used to calculate the regularization penalty during training. + """ + def __init__(self, vocab: Vocabulary, + text_field_embedder: TextFieldEmbedder, + num_highway_layers: int, + phrase_layer: Seq2SeqEncoder, + matrix_attention_layer: MatrixAttention, + modeling_layer: Seq2SeqEncoder, + dropout_prob: float = 0.1, + initializer: InitializerApplicator = InitializerApplicator(), + regularizer: Optional[RegularizerApplicator] = None) -> None: + super().__init__(vocab, regularizer) + + text_embed_dim = text_field_embedder.get_output_dim() + encoding_in_dim = phrase_layer.get_input_dim() + encoding_out_dim = phrase_layer.get_output_dim() + modeling_in_dim = modeling_layer.get_input_dim() + modeling_out_dim = modeling_layer.get_output_dim() + + self._text_field_embedder = text_field_embedder + + self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim) + self._highway_layer = Highway(encoding_in_dim, num_highway_layers) + + self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim, encoding_in_dim) + self._phrase_layer = phrase_layer + + self._matrix_attention = matrix_attention_layer + + self._modeling_proj_layer = torch.nn.Linear(encoding_out_dim * 4, modeling_in_dim) + self._modeling_layer = modeling_layer + + self._span_start_predictor = torch.nn.Linear(modeling_out_dim * 2, 1) + self._span_end_predictor = torch.nn.Linear(modeling_out_dim * 2, 1) + + self._span_start_accuracy = CategoricalAccuracy() + self._span_end_accuracy = CategoricalAccuracy() + self._span_accuracy = BooleanAccuracy() + self._metrics = SquadEmAndF1() + self._dropout = torch.nn.Dropout(p=dropout_prob) if dropout_prob > 0 else lambda x: x + + initializer(self) + + def forward(self, # type: ignore + question: Dict[str, torch.LongTensor], + passage: Dict[str, torch.LongTensor], + span_start: torch.IntTensor = None, + span_end: torch.IntTensor = None, + metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: + # pylint: disable=arguments-differ + """ + Parameters + ---------- + question : Dict[str, torch.LongTensor] + From a ``TextField``. + passage : Dict[str, torch.LongTensor] + From a ``TextField``. 
The model assumes that this passage contains the answer to the + question, and predicts the beginning and ending positions of the answer within the + passage. + span_start : ``torch.IntTensor``, optional + From an ``IndexField``. This is one of the things we are trying to predict - the + beginning position of the answer with the passage. This is an `inclusive` token index. + If this is given, we will compute a loss that gets included in the output dictionary. + span_end : ``torch.IntTensor``, optional + From an ``IndexField``. This is one of the things we are trying to predict - the + ending position of the answer with the passage. This is an `inclusive` token index. + If this is given, we will compute a loss that gets included in the output dictionary. + metadata : ``List[Dict[str, Any]]``, optional + If present, this should contain the question tokens, passage tokens, original passage + text, and token offsets into the passage for each instance in the batch. The length + of this list should be the batch size, and each dictionary should have the keys + ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``. + + Returns + ------- + An output dictionary consisting of: + span_start_logits : torch.FloatTensor + A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log + probabilities of the span start position. + span_start_probs : torch.FloatTensor + The result of ``softmax(span_start_logits)``. + span_end_logits : torch.FloatTensor + A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log + probabilities of the span end position (inclusive). + span_end_probs : torch.FloatTensor + The result of ``softmax(span_end_logits)``. + best_span : torch.IntTensor + The result of a constrained inference over ``span_start_logits`` and + ``span_end_logits`` to find the most probable span. Shape is ``(batch_size, 2)`` + and each offset is a token index. + loss : torch.FloatTensor, optional + A scalar loss to be optimised. + best_span_str : List[str] + If sufficient metadata was provided for the instances in the batch, we also return the + string from the original passage that the model thinks is the best answer to the + question. 
+ """ + question_mask = util.get_text_field_mask(question).float() + passage_mask = util.get_text_field_mask(passage).float() + + embedded_question = self._dropout(self._text_field_embedder(question)) + embedded_passage = self._dropout(self._text_field_embedder(passage)) + embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question)) + embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage)) + + batch_size = embedded_question.size(0) + + projected_embedded_question = self._encoding_proj_layer(embedded_question) + projected_embedded_passage = self._encoding_proj_layer(embedded_passage) + + encoded_question = self._dropout(self._phrase_layer(projected_embedded_question, question_mask)) + encoded_passage = self._dropout(self._phrase_layer(projected_embedded_passage, passage_mask)) + + # Shape: (batch_size, passage_length, question_length) + passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question) + # Shape: (batch_size, passage_length, question_length) + passage_question_attention = masked_softmax( + passage_question_similarity, + question_mask, + memory_efficient=True) + # Shape: (batch_size, passage_length, encoding_dim) + passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention) + + # Shape: (batch_size, question_length, passage_length) + question_passage_attention = masked_softmax( + passage_question_similarity.transpose(1, 2), + passage_mask, + memory_efficient=True) + # Shape: (batch_size, passage_length, passage_length) + attention_over_attention = torch.bmm(passage_question_attention, question_passage_attention) + # Shape: (batch_size, passage_length, encoding_dim) + passage_passage_vectors = util.weighted_sum(encoded_passage, attention_over_attention) + + # Shape: (batch_size, passage_length, encoding_dim * 4) + merged_passage_attention_vectors = self._dropout( + torch.cat([encoded_passage, passage_question_vectors, + encoded_passage * passage_question_vectors, + encoded_passage * passage_passage_vectors], + dim=-1) + ) + + modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)] + + for _ in range(3): + modeled_passage = self._dropout(self._modeling_layer(modeled_passage_list[-1], passage_mask)) + modeled_passage_list.append(modeled_passage) + + # Shape: (batch_size, passage_length, modeling_dim * 2)) + span_start_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1) + # Shape: (batch_size, passage_length) + span_start_logits = self._span_start_predictor(span_start_input).squeeze(-1) + + # Shape: (batch_size, passage_length, modeling_dim * 2) + span_end_input = torch.cat([modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1) + span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1) + span_start_logits = util.replace_masked_values(span_start_logits, passage_mask, -1e32) + span_end_logits = util.replace_masked_values(span_end_logits, passage_mask, -1e32) + + # Shape: (batch_size, passage_length) + span_start_probs = torch.nn.functional.softmax(span_start_logits, dim=-1) + span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1) + + best_span = get_best_span(span_start_logits, span_end_logits) + + output_dict = { + "passage_question_attention": passage_question_attention, + "span_start_logits": span_start_logits, + "span_start_probs": span_start_probs, + "span_end_logits": span_end_logits, + "span_end_probs": span_end_probs, + "best_span": best_span, + } + + # Compute the loss for 
training. + if span_start is not None: + loss = nll_loss(util.masked_log_softmax(span_start_logits, passage_mask), span_start.squeeze(-1)) + self._span_start_accuracy(span_start_logits, span_start.squeeze(-1)) + loss += nll_loss(util.masked_log_softmax(span_end_logits, passage_mask), span_end.squeeze(-1)) + self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) + self._span_accuracy(best_span, torch.stack([span_start, span_end], -1)) + output_dict["loss"] = loss + + # Compute the EM and F1 on SQuAD and add the tokenized input to the output. + if metadata is not None: + output_dict['best_span_str'] = [] + question_tokens = [] + passage_tokens = [] + for i in range(batch_size): + question_tokens.append(metadata[i]['question_tokens']) + passage_tokens.append(metadata[i]['passage_tokens']) + passage_str = metadata[i]['original_passage'] + offsets = metadata[i]['token_offsets'] + predicted_span = tuple(best_span[i].detach().cpu().numpy()) + start_offset = offsets[predicted_span[0]][0] + end_offset = offsets[predicted_span[1]][1] + best_span_string = passage_str[start_offset:end_offset] + output_dict['best_span_str'].append(best_span_string) + answer_texts = metadata[i].get('answer_texts', []) + if answer_texts: + self._metrics(best_span_string, answer_texts) + output_dict['question_tokens'] = question_tokens + output_dict['passage_tokens'] = passage_tokens + return output_dict + + def get_metrics(self, reset: bool = False) -> Dict[str, float]: + exact_match, f1_score = self._metrics.get_metric(reset) + return { + 'start_acc': self._span_start_accuracy.get_metric(reset), + 'end_acc': self._span_end_accuracy.get_metric(reset), + 'span_acc': self._span_accuracy.get_metric(reset), + 'em': exact_match, + 'f1': f1_score, + } diff --git a/allennlp/models/reading_comprehension/util.py b/allennlp/models/reading_comprehension/util.py new file mode 100644 index 00000000000..5abca5cf0cd --- /dev/null +++ b/allennlp/models/reading_comprehension/util.py @@ -0,0 +1,33 @@ +import torch + + +def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor: + """ + This acts the same as the static method ``BidirectionalAttentionFlow.get_best_span()`` + in ``allennlp/models/reading_comprehension/bidaf.py``. We keep it here so that users can + directly import this function without the class. + + We call the inputs "logits" - they could either be unnormalized logits or normalized log + probabilities. A log_softmax operation is a constant shifting of the entire logit + vector, so taking an argmax over either one gives the same result. + """ + if span_start_logits.dim() != 2 or span_end_logits.dim() != 2: + raise ValueError("Input shapes must be (batch_size, passage_length)") + batch_size, passage_length = span_start_logits.size() + device = span_start_logits.device + # (batch_size, passage_length, passage_length) + span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1) + # Only the upper triangle of the span matrix is valid; the lower triangle has entries where + # the span ends before it starts. + span_log_mask = torch.triu(torch.ones((passage_length, passage_length), + device=device)).log() + valid_span_log_probs = span_log_probs + span_log_mask + + # Here we take the span matrix and flatten it, then find the best span using argmax. We + # can recover the start and end indices from this flattened list using simple modular + # arithmetic. 
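+    # For example (hypothetical numbers), with passage_length = 10, the flattened index 23
+    # corresponds to start index 23 // 10 = 2 and end index 23 % 10 = 3.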
+    # (batch_size, passage_length * passage_length)
+    best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
+    span_start_indices = best_spans // passage_length
+    span_end_indices = best_spans % passage_length
+    return torch.stack([span_start_indices, span_end_indices], dim=-1)
diff --git a/allennlp/modules/__init__.py b/allennlp/modules/__init__.py
index 97b72df7cc5..13241dac8bc 100644
--- a/allennlp/modules/__init__.py
+++ b/allennlp/modules/__init__.py
@@ -23,3 +23,4 @@
 from allennlp.modules.attention import Attention
 from allennlp.modules.input_variational_dropout import InputVariationalDropout
 from allennlp.modules.bimpm_matching import BiMpmMatching
+from allennlp.modules.residual_with_layer_dropout import ResidualWithLayerDropout
diff --git a/allennlp/modules/residual_with_layer_dropout.py b/allennlp/modules/residual_with_layer_dropout.py
new file mode 100644
index 00000000000..011bfaffe7e
--- /dev/null
+++ b/allennlp/modules/residual_with_layer_dropout.py
@@ -0,0 +1,59 @@
+import torch
+
+
+class ResidualWithLayerDropout(torch.nn.Module):
+    """
+    A residual connection with the layer dropout technique `Deep Networks with Stochastic
+    Depth <https://arxiv.org/abs/1603.09382>`_ .
+
+    This module accepts the input and output of a layer, decides whether this layer should
+    be stochastically dropped, and returns either the input (if dropped) or the sum of the
+    input and output. During testing, it re-calibrates the layer's output by its survival
+    probability, i.e. the expected fraction of training passes in which the layer participates.
+    """
+    def __init__(self, undecayed_dropout_prob: float = 0.5) -> None:
+        super().__init__()
+        if undecayed_dropout_prob < 0 or undecayed_dropout_prob > 1:
+            raise ValueError(f"undecayed dropout probability has to be between 0 and 1, "
+                             f"but got {undecayed_dropout_prob}")
+        self.undecayed_dropout_prob = undecayed_dropout_prob
+
+    def forward(self,
+                layer_input: torch.Tensor,
+                layer_output: torch.Tensor,
+                layer_index: int = None,
+                total_layers: int = None) -> torch.Tensor:
+        # pylint: disable=arguments-differ
+        """
+        Applies dropout to this layer, for the whole mini-batch.
+        ``dropout_prob = layer_index / total_layers * undecayed_dropout_prob`` if ``layer_index``
+        and ``total_layers`` are specified; otherwise ``undecayed_dropout_prob`` is used directly.
+
+        Parameters
+        ----------
+        layer_input : ``torch.FloatTensor``, required
+            The input tensor of this layer.
+        layer_output : ``torch.FloatTensor``, required
+            The output tensor of this layer, with the same shape as the layer_input.
+        layer_index : ``int``
+            The layer index, starting from 1. This is used to calculate the dropout probability
+            together with the `total_layers` parameter.
+        total_layers : ``int``
+            The total number of layers.
+
+        Returns
+        -------
+        output : ``torch.FloatTensor``
+            A tensor with the same shape as `layer_input` and `layer_output`.
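+            During evaluation this is ``(1 - dropout_prob) * layer_output + layer_input``;
+            during training it is either ``layer_input`` (when the layer is dropped) or
+            ``layer_output + layer_input``.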
+ """ + if layer_index is not None and total_layers is not None: + dropout_prob = 1.0 * self.undecayed_dropout_prob * layer_index / total_layers + else: + dropout_prob = 1.0 * self.undecayed_dropout_prob + if self.training: + if torch.rand(1) < dropout_prob: + return layer_input + else: + return layer_output + layer_input + else: + return (1 - dropout_prob) * layer_output + layer_input diff --git a/allennlp/modules/seq2seq_encoders/__init__.py b/allennlp/modules/seq2seq_encoders/__init__.py index f6df58eeccc..90bae029ff5 100644 --- a/allennlp/modules/seq2seq_encoders/__init__.py +++ b/allennlp/modules/seq2seq_encoders/__init__.py @@ -38,6 +38,8 @@ from allennlp.modules.seq2seq_encoders.multi_head_self_attention import MultiHeadSelfAttention from allennlp.modules.seq2seq_encoders.pass_through_encoder import PassThroughEncoder from allennlp.modules.seq2seq_encoders.feedforward_encoder import FeedForwardEncoder +from allennlp.modules.seq2seq_encoders.qanet_encoder import QaNetEncoder + logger = logging.getLogger(__name__) # pylint: disable=invalid-name diff --git a/allennlp/modules/seq2seq_encoders/qanet_encoder.py b/allennlp/modules/seq2seq_encoders/qanet_encoder.py new file mode 100644 index 00000000000..1ff0eeefcff --- /dev/null +++ b/allennlp/modules/seq2seq_encoders/qanet_encoder.py @@ -0,0 +1,251 @@ +from typing import List + +from overrides import overrides +import torch +from torch.nn import Dropout +from torch.nn import LayerNorm +from allennlp.modules.feedforward import FeedForward +from allennlp.modules.residual_with_layer_dropout import ResidualWithLayerDropout +from allennlp.modules.seq2seq_encoders.multi_head_self_attention import MultiHeadSelfAttention +from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder +from allennlp.nn.activations import Activation +from allennlp.nn.util import add_positional_features +from allennlp.common.checks import check_dimensions_match + + +@Seq2SeqEncoder.register("qanet_encoder") +class QaNetEncoder(Seq2SeqEncoder): + """ + Stack multiple QANetEncoderBlock into one sequence encoder. + + Parameters + ---------- + input_dim : ``int``, required. + The input dimension of the encoder. + hidden_dim : ``int``, required. + The hidden dimension used for convolution output channels, multi-head attention output + and the final output of feedforward layer. + attention_projection_dim : ``int``, required. + The dimension of the linear projections for the self-attention layers. + feedforward_hidden_dim : ``int``, required. + The middle dimension of the FeedForward network. The input and output + dimensions are fixed to ensure sizes match up for the self attention layers. + num_blocks : ``int``, required. + The number of stacked encoder blocks. + num_convs_per_block: ``int``, required. + The number of convolutions in each block. + conv_kernel_size: ``int``, required. + The kernel size for convolution. + num_attention_heads : ``int``, required. + The number of attention heads to use per layer. + use_positional_encoding: ``bool``, optional, (default = True) + Whether to add sinusoidal frequencies to the input tensor. This is strongly recommended, + as without this feature, the self attention layers have no idea of absolute or relative + position (as they are just computing pairwise similarity between vectors of elements), + which can be important features for many tasks. + dropout_prob : ``float``, optional, (default = 0.1) + The dropout probability for the feedforward network. 
+ layer_dropout_undecayed_prob : ``float``, optional, (default = 0.1) + The initial dropout probability for layer dropout, and this might decay w.r.t the depth + of the layer. For each mini-batch, the convolution/attention/ffn sublayer is + stochastically dropped according to its layer dropout probability. + attention_dropout_prob : ``float``, optional, (default = 0) + The dropout probability for the attention distributions in the attention layer. + """ + def __init__(self, + input_dim: int, + hidden_dim: int, + attention_projection_dim: int, + feedforward_hidden_dim: int, + num_blocks: int, + num_convs_per_block: int, + conv_kernel_size: int, + num_attention_heads: int, + use_positional_encoding: bool = True, + dropout_prob: float = 0.1, + layer_dropout_undecayed_prob: float = 0.1, + attention_dropout_prob: float = 0) -> None: + super().__init__() + + self._input_projection_layer = None + + if input_dim != hidden_dim: + self._input_projection_layer = torch.nn.Linear(input_dim, hidden_dim) + else: + self._input_projection_layer = lambda x: x + + self._encoder_blocks: List[QaNetEncoderBlock] = [] + for block_index in range(num_blocks): + encoder_block = QaNetEncoderBlock(hidden_dim, + hidden_dim, + attention_projection_dim, + feedforward_hidden_dim, + num_convs_per_block, + conv_kernel_size, + num_attention_heads, + use_positional_encoding, + dropout_prob, + layer_dropout_undecayed_prob, + attention_dropout_prob) + self.add_module(f"encoder_block_{block_index}", encoder_block) + self._encoder_blocks.append(encoder_block) + + self._input_dim = input_dim + self._output_dim = hidden_dim + + @overrides + def get_input_dim(self) -> int: + return self._input_dim + + @overrides + def get_output_dim(self) -> int: + return self._output_dim + + @overrides + def is_bidirectional(self) -> bool: + return False + + @overrides + def forward(self, inputs: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: # pylint: disable=arguments-differ + inputs = self._input_projection_layer(inputs) + output = inputs + for encoder_block in self._encoder_blocks: + output = encoder_block(output, mask) + return output + + +@Seq2SeqEncoder.register("qanet_encoder_block") +class QaNetEncoderBlock(Seq2SeqEncoder): + """ + Implements the encoder block described in `QANet: Combining Local Convolution with Global + Self-attention for Reading Comprehension `_ . + + One encoder block mainly contains 4 parts: + + 1. Add position embedding. + 2. Several depthwise seperable convolutions. + 3. Multi-headed self attention, which uses 2 learnt linear projections + to perform a dot-product similarity between every pair of elements + scaled by the square root of the sequence length. + 4. A two-layer FeedForward network. + + Parameters + ---------- + input_dim : ``int``, required. + The input dimension of the encoder. + hidden_dim : ``int``, required. + The hidden dimension used for convolution output channels, multi-head attention output + and the final output of feedforward layer. + attention_projection_dim : ``int``, required. + The dimension of the linear projections for the self-attention layers. + feedforward_hidden_dim : ``int``, required. + The middle dimension of the FeedForward network. The input and output + dimensions are fixed to ensure sizes match up for the self attention layers. + num_convs: ``int``, required. + The number of convolutions in each block. + conv_kernel_size: ``int``, required. + The kernel size for convolution. + num_attention_heads : ``int``, required. 
+ The number of attention heads to use per layer. + use_positional_encoding: ``bool``, optional, (default = True) + Whether to add sinusoidal frequencies to the input tensor. This is strongly recommended, + as without this feature, the self attention layers have no idea of absolute or relative + position (as they are just computing pairwise similarity between vectors of elements), + which can be important features for many tasks. + dropout_prob : ``float``, optional, (default = 0.1) + The dropout probability for the feedforward network. + layer_dropout_undecayed_prob : ``float``, optional, (default = 0.1) + The initial dropout probability for layer dropout, and this might decay w.r.t the depth + of the layer. For each mini-batch, the convolution/attention/ffn sublayer is randomly + dropped according to its layer dropout probability. + attention_dropout_prob : ``float``, optional, (default = 0) + The dropout probability for the attention distributions in the attention layer. + """ + def __init__(self, + input_dim: int, + hidden_dim: int, + attention_projection_dim: int, + feedforward_hidden_dim: int, + num_convs: int, + conv_kernel_size: int, + num_attention_heads: int, + use_positional_encoding: bool = True, + dropout_prob: float = 0.1, + layer_dropout_undecayed_prob: float = 0.1, + attention_dropout_prob: float = 0) -> None: + super().__init__() + + check_dimensions_match(input_dim, hidden_dim, 'input_dim', 'hidden_dim') + + self._use_positional_encoding = use_positional_encoding + + self._conv_norm_layers = torch.nn.ModuleList([LayerNorm(hidden_dim) for _ in range(num_convs)]) + self._conv_layers = torch.nn.ModuleList() + for _ in range(num_convs): + padding = torch.nn.ConstantPad1d((conv_kernel_size // 2, (conv_kernel_size - 1) // 2), 0) + depthwise_conv = torch.nn.Conv1d(hidden_dim, hidden_dim, conv_kernel_size, groups=hidden_dim) + pointwise_conv = torch.nn.Conv1d(hidden_dim, hidden_dim, 1) + self._conv_layers.append( + torch.nn.Sequential(padding, depthwise_conv, pointwise_conv, Activation.by_name("relu")()) + ) + + self.attention_norm_layer = LayerNorm(hidden_dim) + self.attention_layer = MultiHeadSelfAttention(num_heads=num_attention_heads, + input_dim=hidden_dim, + attention_dim=attention_projection_dim, + values_dim=attention_projection_dim, + attention_dropout_prob=attention_dropout_prob) + self.feedforward_norm_layer = LayerNorm(hidden_dim) + self.feedforward = FeedForward(hidden_dim, + activations=[Activation.by_name('relu')(), + Activation.by_name('linear')()], + hidden_dims=[feedforward_hidden_dim, hidden_dim], + num_layers=2, + dropout=dropout_prob) + + self.dropout = Dropout(dropout_prob) + self.residual_with_layer_dropout = ResidualWithLayerDropout(layer_dropout_undecayed_prob) + self._input_dim = input_dim + self._output_dim = hidden_dim + + @overrides + def get_input_dim(self) -> int: + return self._input_dim + + @overrides + def get_output_dim(self) -> int: + return self._output_dim + + @overrides + def is_bidirectional(self): + return False + + @overrides + def forward(self, inputs: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: # pylint: disable=arguments-differ + if self._use_positional_encoding: + output = add_positional_features(inputs) + else: + output = inputs + + total_sublayers = len(self._conv_layers) + 2 + sublayer_count = 0 + + for conv_norm_layer, conv_layer in zip(self._conv_norm_layers, self._conv_layers): + conv_norm_out = self.dropout(conv_norm_layer(output)) + conv_out = self.dropout(conv_layer(conv_norm_out.transpose_(1, 2)).transpose_(1, 
2)) + sublayer_count += 1 + output = self.residual_with_layer_dropout(output, conv_out, + sublayer_count, total_sublayers) + + attention_norm_out = self.dropout(self.attention_norm_layer(output)) + attention_out = self.dropout(self.attention_layer(attention_norm_out, mask)) + sublayer_count += 1 + output = self.residual_with_layer_dropout(output, attention_out, + sublayer_count, total_sublayers) + + feedforward_norm_out = self.dropout(self.feedforward_norm_layer(output)) + feedforward_out = self.dropout(self.feedforward(feedforward_norm_out)) + sublayer_count += 1 + output = self.residual_with_layer_dropout(output, feedforward_out, + sublayer_count, total_sublayers) + return output diff --git a/allennlp/tests/data/dataset_readers/reading_comprehension/squad_test.py b/allennlp/tests/data/dataset_readers/reading_comprehension/squad_test.py index 9e3c2f30983..8d5492a7796 100644 --- a/allennlp/tests/data/dataset_readers/reading_comprehension/squad_test.py +++ b/allennlp/tests/data/dataset_readers/reading_comprehension/squad_test.py @@ -6,6 +6,7 @@ from allennlp.data.dataset_readers import SquadReader from allennlp.common.testing import AllenNlpTestCase + class TestSquadReader: @pytest.mark.parametrize("lazy", (True, False)) def test_read_from_file(self, lazy): @@ -43,3 +44,35 @@ def test_can_build_from_params(self): # pylint: disable=protected-access assert reader._tokenizer.__class__.__name__ == 'WordTokenizer' assert reader._token_indexers["tokens"].__class__.__name__ == 'SingleIdTokenIndexer' + + def test_length_limit_works(self): + # We're making sure the length of the text is correct if length limit is provided. + reader = SquadReader(passage_length_limit=30, + question_length_limit=10, + skip_invalid_examples=True) + instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'squad.json')) + assert len(instances[0].fields["question"].tokens) == 10 + assert len(instances[0].fields["passage"].tokens) == 30 + # invalid examples where all the answers exceed the passage length should be skipped. + assert len(instances) == 3 + + # Length limit still works if we do not skip the invalid examples + reader = SquadReader(passage_length_limit=30, + question_length_limit=10, + skip_invalid_examples=False) + instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'squad.json')) + assert len(instances[0].fields["question"].tokens) == 10 + assert len(instances[0].fields["passage"].tokens) == 30 + # invalid examples should not be skipped. 
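+        # (their gold answer spans fall back to the last passage token instead)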
+ assert len(instances) == 5 + + # Make sure the answer texts does not change, so that the evaluation will not be affected + reader_unlimited = SquadReader(passage_length_limit=30, + question_length_limit=10, + skip_invalid_examples=False) + instances_unlimited = ensure_list( + reader_unlimited.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'squad.json')) + for instance_x, instance_y in zip(instances, instances_unlimited): + print(instance_x.fields["metadata"]["answer_texts"]) + assert set(instance_x.fields["metadata"]["answer_texts"]) \ + == set(instance_y.fields["metadata"]["answer_texts"]) diff --git a/allennlp/tests/fixtures/qanet/experiment.json b/allennlp/tests/fixtures/qanet/experiment.json new file mode 100644 index 00000000000..8ab9f814af4 --- /dev/null +++ b/allennlp/tests/fixtures/qanet/experiment.json @@ -0,0 +1,141 @@ +{ + "dataset_reader": { + "type": "squad", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + }, + "token_characters": { + "type": "characters", + "min_padding_length": 5 + } + }, + "passage_length_limit": 400, + "question_length_limit": 50, + "skip_invalid_examples": true + }, + "validation_dataset_reader": { + "type": "squad", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + }, + "token_characters": { + "type": "characters", + "min_padding_length": 5 + } + }, + "passage_length_limit": 1000, + "question_length_limit": 100, + "skip_invalid_examples": false + }, + "train_data_path": "allennlp/tests/fixtures/data/squad.json", + "validation_data_path": "allennlp/tests/fixtures/data/squad.json", + "model": { + "type": "qanet", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "embedding_dim": 16, + "trainable": false + }, + "token_characters": { + "type": "character_encoding", + "embedding": { + "num_embeddings": 262, + "embedding_dim": 8 + }, + "encoder": { + "type": "cnn", + "embedding_dim": 8, + "num_filters": 8, + "ngram_filter_sizes": [5] + } + } + } + }, + "num_highway_layers": 2, + "phrase_layer": { + "type": "qanet_encoder", + "input_dim": 16, + "hidden_dim": 16, + "attention_projection_dim": 16, + "feedforward_hidden_dim": 16, + "num_blocks": 1, + "num_convs_per_block": 2, + "conv_kernel_size": 2, + "num_attention_heads": 4, + "dropout_prob": 0.1, + "layer_dropout_undecayed_prob": 0.1, + "attention_dropout_prob": 0 + }, + "matrix_attention_layer": { + "type": "linear", + "tensor_1_dim": 16, + "tensor_2_dim": 16, + "combination": "x,y,x*y" + }, + "modeling_layer": { + "type": "qanet_encoder", + "input_dim": 16, + "hidden_dim": 16, + "attention_projection_dim": 16, + "feedforward_hidden_dim": 16, + "num_blocks": 2, + "num_convs_per_block": 2, + "conv_kernel_size": 5, + "num_attention_heads": 4, + "dropout_prob": 0.1, + "layer_dropout_undecayed_prob": 0.1, + "attention_dropout_prob": 0 + }, + "dropout_prob": 0.1, + "regularizer": [ + [ + ".*", + { + "type": "l2", + "alpha": 1e-07 + } + ] + ] + }, + "iterator": { + "type": "bucket", + "sorting_keys": [ + [ + "passage", + "num_tokens" + ], + [ + "question", + "num_tokens" + ] + ], + "batch_size": 5, + "padding_noise": 0.0 + }, + "trainer": { + "num_epochs": 1, + "grad_norm": 5, + "patience": 10, + "validation_metric": "+f1", + "cuda_device": -1, + "optimizer": { + "type": "adam", + "lr": 0.001, + "betas": [ + 0.8, + 0.999 + ], + "eps": 1e-07 + }, + "moving_average": { + "type": "exponential", + "decay": 0.9999 + } + } +} \ No newline at end of file diff --git 
a/allennlp/tests/models/reading_comprehension/qanet_test.py b/allennlp/tests/models/reading_comprehension/qanet_test.py
new file mode 100644
index 00000000000..4a71be1d23a
--- /dev/null
+++ b/allennlp/tests/models/reading_comprehension/qanet_test.py
@@ -0,0 +1,72 @@
+# pylint: disable=no-self-use,invalid-name
+import numpy
+from numpy.testing import assert_almost_equal
+from allennlp.common import Params
+from allennlp.data import DatasetReader, Vocabulary
+from allennlp.common.testing import ModelTestCase
+from allennlp.data.dataset import Batch
+from allennlp.models import Model
+
+
+class QaNetTest(ModelTestCase):
+    def setUp(self):
+        super().setUp()
+        self.set_up_model(self.FIXTURES_ROOT / 'qanet' / 'experiment.json',
+                          self.FIXTURES_ROOT / 'data' / 'squad.json')
+
+    def test_forward_pass_runs_correctly(self):
+        batch = Batch(self.instances)
+        batch.index_instances(self.vocab)
+        training_tensors = batch.as_tensor_dict()
+        output_dict = self.model(**training_tensors)
+
+        metrics = self.model.get_metrics(reset=True)
+        # We've set up the data such that there's a fake answer that consists of the whole
+        # paragraph. _Any_ valid prediction for that question should produce an F1 of greater than
+        # zero, while if we somehow haven't been able to load the evaluation data, or there was an
+        # error with using the evaluation script, this will fail. This makes sure that we've
+        # loaded the evaluation data correctly and have hooked things up to the official evaluation
+        # script.
+        assert metrics['f1'] > 0
+
+        span_start_probs = output_dict['span_start_probs'][0].data.numpy()
+        span_end_probs = output_dict['span_end_probs'][0].data.numpy()
+        assert_almost_equal(numpy.sum(span_start_probs, -1), 1, decimal=6)
+        assert_almost_equal(numpy.sum(span_end_probs, -1), 1, decimal=6)
+        span_start, span_end = tuple(output_dict['best_span'][0].data.numpy())
+        assert span_start >= 0
+        assert span_start <= span_end
+        assert span_end < self.instances[0].fields['passage'].sequence_length()
+        assert isinstance(output_dict['best_span_str'][0], str)
+
+    def test_model_can_train_save_and_load(self):
+        self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-4)
+
+    def test_batch_predictions_are_consistent(self):
+        # This is the same issue as in the bidaf test case.
+        # The CNN encoder has problems with this kind of test - it's not properly masked yet, so
+        # changing the amount of padding in the batch will result in small differences in the
+        # output of the encoder. So, we'll remove the CNN encoder entirely from the model for
+        # this test.
+        # Save some state.
+        # pylint: disable=protected-access,attribute-defined-outside-init
+        saved_model = self.model
+        saved_instances = self.instances
+
+        # Modify the state, run the test with modified state.
+        params = Params.from_file(self.param_file)
+        reader = DatasetReader.from_params(params['dataset_reader'])
+        reader._token_indexers = {'tokens': reader._token_indexers['tokens']}
+        self.instances = reader.read(self.FIXTURES_ROOT / 'data' / 'squad.json')
+        vocab = Vocabulary.from_instances(self.instances)
+        for instance in self.instances:
+            instance.index_fields(vocab)
+        del params['model']['text_field_embedder']['token_embedders']['token_characters']
+        params['model']['phrase_layer']['num_convs_per_block'] = 0
+        params['model']['modeling_layer']['num_convs_per_block'] = 0
+        self.model = Model.from_params(vocab=vocab, params=params['model'])
+
+        self.ensure_batch_predictions_are_consistent()
+
+        # Restore the state.
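+        # (so that subsequent tests in this class see the original model and instances)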
+ self.model = saved_model + self.instances = saved_instances diff --git a/allennlp/tests/models/reading_comprehension/util_test.py b/allennlp/tests/models/reading_comprehension/util_test.py new file mode 100644 index 00000000000..6825692688e --- /dev/null +++ b/allennlp/tests/models/reading_comprehension/util_test.py @@ -0,0 +1,32 @@ +# pylint: disable=no-self-use +from numpy.testing import assert_almost_equal +import torch +from allennlp.common.testing import AllenNlpTestCase +from allennlp.models.reading_comprehension.util import get_best_span + + +class TestRcUtil(AllenNlpTestCase): + def test_get_best_span(self): + span_begin_probs = torch.FloatTensor([[0.1, 0.3, 0.05, 0.3, 0.25]]).log() + span_end_probs = torch.FloatTensor([[0.65, 0.05, 0.2, 0.05, 0.05]]).log() + begin_end_idxs = get_best_span(span_begin_probs, span_end_probs) + assert_almost_equal(begin_end_idxs.data.numpy(), [[0, 0]]) + + # When we were using exclusive span ends, this was an edge case of the dynamic program. + # We're keeping the test to make sure we get it right now, after the switch in inclusive + # span end. The best answer is (1, 1). + span_begin_probs = torch.FloatTensor([[0.4, 0.5, 0.1]]).log() + span_end_probs = torch.FloatTensor([[0.3, 0.6, 0.1]]).log() + begin_end_idxs = get_best_span(span_begin_probs, span_end_probs) + assert_almost_equal(begin_end_idxs.data.numpy(), [[1, 1]]) + + # Another instance that used to be an edge case. + span_begin_probs = torch.FloatTensor([[0.8, 0.1, 0.1]]).log() + span_end_probs = torch.FloatTensor([[0.8, 0.1, 0.1]]).log() + begin_end_idxs = get_best_span(span_begin_probs, span_end_probs) + assert_almost_equal(begin_end_idxs.data.numpy(), [[0, 0]]) + + span_begin_probs = torch.FloatTensor([[0.1, 0.2, 0.05, 0.3, 0.25]]).log() + span_end_probs = torch.FloatTensor([[0.1, 0.2, 0.5, 0.05, 0.15]]).log() + begin_end_idxs = get_best_span(span_begin_probs, span_end_probs) + assert_almost_equal(begin_end_idxs.data.numpy(), [[1, 2]]) diff --git a/allennlp/tests/modules/residual_with_layer_dropout_test.py b/allennlp/tests/modules/residual_with_layer_dropout_test.py new file mode 100644 index 00000000000..5943623b191 --- /dev/null +++ b/allennlp/tests/modules/residual_with_layer_dropout_test.py @@ -0,0 +1,42 @@ +# pylint: disable=no-self-use,invalid-name +from numpy.testing import assert_almost_equal +import torch + +from allennlp.modules import ResidualWithLayerDropout +from allennlp.common.testing import AllenNlpTestCase + + +class TestResidualWithLayerDropout(AllenNlpTestCase): + def test_dropout_works_for_training(self): + layer_input_tensor = torch.FloatTensor([[2, 1], [-3, -2]]) + layer_output_tensor = torch.FloatTensor([[1, 3], [2, -1]]) + + # The layer output should be dropped + residual_with_layer_dropout = ResidualWithLayerDropout(1) + residual_with_layer_dropout.train() + result = residual_with_layer_dropout(layer_input_tensor, layer_output_tensor).data.numpy() + assert result.shape == (2, 2) + assert_almost_equal(result, [[2, 1], [-3, -2]]) + + result = residual_with_layer_dropout(layer_input_tensor, layer_output_tensor, 1, 1).data.numpy() + assert result.shape == (2, 2) + assert_almost_equal(result, [[2, 1], [-3, -2]]) + + # The layer output should not be dropped + residual_with_layer_dropout = ResidualWithLayerDropout(0.0) + residual_with_layer_dropout.train() + result = residual_with_layer_dropout(layer_input_tensor, layer_output_tensor).data.numpy() + assert result.shape == (2, 2) + assert_almost_equal(result, [[2 + 1, 1 + 3], [-3 + 2, -2 - 1]]) + + def 
test_dropout_works_for_testing(self): + layer_input_tensor = torch.FloatTensor([[2, 1], [-3, -2]]) + layer_output_tensor = torch.FloatTensor([[1, 3], [2, -1]]) + + # During testing, the layer output is re-calibrated according to the survival probability, + # and then added to the input. + residual_with_layer_dropout = ResidualWithLayerDropout(0.2) + residual_with_layer_dropout.eval() + result = residual_with_layer_dropout(layer_input_tensor, layer_output_tensor).data.numpy() + assert result.shape == (2, 2) + assert_almost_equal(result, [[2 + 1 * 0.8, 1 + 3 * 0.8], [-3 + 2 * 0.8, -2 - 1 * 0.8]]) diff --git a/allennlp/tests/modules/seq2seq_encoders/qanet_encoder_test.py b/allennlp/tests/modules/seq2seq_encoders/qanet_encoder_test.py new file mode 100644 index 00000000000..6bbf1cfe1d9 --- /dev/null +++ b/allennlp/tests/modules/seq2seq_encoders/qanet_encoder_test.py @@ -0,0 +1,44 @@ +# pylint: disable=invalid-name,no-self-use,too-many-public-methods +import torch + +from allennlp.common.testing import AllenNlpTestCase +from allennlp.modules.seq2seq_encoders import QaNetEncoder +from allennlp.common.params import Params + + +class QaNetEncoderTest(AllenNlpTestCase): + + def test_qanet_encoder_can_build_from_params(self): + params = Params({ + "input_dim": 16, + "hidden_dim": 16, + "attention_projection_dim": 16, + "feedforward_hidden_dim": 16, + "num_blocks": 2, + "num_convs_per_block": 2, + "conv_kernel_size": 3, + "num_attention_heads": 4, + "dropout_prob": 0.1, + "layer_dropout_undecayed_prob": 0.1, + "attention_dropout_prob": 0 + }) + + encoder = QaNetEncoder.from_params(params) + assert isinstance(encoder, QaNetEncoder) + assert encoder.get_input_dim() == 16 + assert encoder.get_output_dim() == 16 + + def test_qanet_encoder_runs_forward(self): + encoder = QaNetEncoder(input_dim=16, + hidden_dim=16, + attention_projection_dim=16, + feedforward_hidden_dim=16, + num_blocks=2, + num_convs_per_block=2, + conv_kernel_size=3, + num_attention_heads=4, + dropout_prob=0.1, + layer_dropout_undecayed_prob=0.1, + attention_dropout_prob=0.1) + inputs = torch.randn(2, 12, 16) + assert list(encoder(inputs).size()) == [2, 12, 16] diff --git a/doc/api/allennlp.models.reading_comprehension.rst b/doc/api/allennlp.models.reading_comprehension.rst index 20fb9bd1777..860feccf45e 100644 --- a/doc/api/allennlp.models.reading_comprehension.rst +++ b/doc/api/allennlp.models.reading_comprehension.rst @@ -20,3 +20,13 @@ allennlp.models.reading_comprehension :members: :undoc-members: :show-inheritance: + +.. automodule:: allennlp.models.reading_comprehension.qanet + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: allennlp.models.reading_comprehension.util + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/api/allennlp.modules.residual_with_layer_dropout.rst b/doc/api/allennlp.modules.residual_with_layer_dropout.rst new file mode 100644 index 00000000000..b88ec012b9d --- /dev/null +++ b/doc/api/allennlp.modules.residual_with_layer_dropout.rst @@ -0,0 +1,7 @@ +allennlp.modules.residual_with_layer_dropout +============================================ + +.. 
automodule:: allennlp.modules.residual_with_layer_dropout + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/api/allennlp.modules.rst b/doc/api/allennlp.modules.rst index e660105791f..a63d4fd556e 100644 --- a/doc/api/allennlp.modules.rst +++ b/doc/api/allennlp.modules.rst @@ -36,3 +36,4 @@ allennlp.modules allennlp.modules.bimpm_matching allennlp.modules.masked_layer_norm allennlp.modules.sampled_softmax_loss + allennlp.modules.residual_with_layer_dropout diff --git a/doc/api/allennlp.modules.seq2seq_encoders.rst b/doc/api/allennlp.modules.seq2seq_encoders.rst index 7da2a9e074e..69c7cd3688a 100644 --- a/doc/api/allennlp.modules.seq2seq_encoders.rst +++ b/doc/api/allennlp.modules.seq2seq_encoders.rst @@ -50,3 +50,8 @@ allennlp.modules.seq2seq_encoders :members: :undoc-members: :show-inheritance: + +.. automodule:: allennlp.modules.seq2seq_encoders.qanet_encoder + :members: + :undoc-members: + :show-inheritance: diff --git a/training_config/qanet.jsonnet b/training_config/qanet.jsonnet new file mode 100644 index 00000000000..0d4ee0bb723 --- /dev/null +++ b/training_config/qanet.jsonnet @@ -0,0 +1,159 @@ +// Configuration for the basic QANet model from "QANet: Combining Local +// Convolution with Global Self-Attention for Reading Comprehension" +// (https://arxiv.org/abs/1804.09541). +{ + "dataset_reader": { + "type": "squad", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + }, + "token_characters": { + "type": "characters", + "min_padding_length": 5 + } + }, + "passage_length_limit": 400, + "question_length_limit": 50, + "skip_invalid_examples": true + }, + "validation_dataset_reader": { + "type": "squad", + "token_indexers": { + "tokens": { + "type": "single_id", + "lowercase_tokens": true + }, + "token_characters": { + "type": "characters", + "min_padding_length": 5 + } + }, + "passage_length_limit": 1000, + "question_length_limit": 100, + "skip_invalid_examples": false + }, + "vocabulary": { + "min_count": { + "token_characters": 200 + }, + "pretrained_files": { + // This embedding file is created from the Glove 840B 300d embedding file. + // We kept all the original lowercased words and their embeddings. But there are also many words + // with only the uppercased version. To include as many words as possible, we lowered those words + // and used the embeddings of uppercased words as an alternative. 
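+      // Combined with "only_include_pretrained_words" below, this restricts the "tokens"
+      // vocabulary to words that appear in this embedding file.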
+ "tokens": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.lower.converted.zip" + }, + "only_include_pretrained_words": true + }, + "train_data_path": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json", + "validation_data_path": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json", + "model": { + "type": "qanet", + "text_field_embedder": { + "token_embedders": { + "tokens": { + "type": "embedding", + "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.lower.converted.zip", + "embedding_dim": 300, + "trainable": false + }, + "token_characters": { + "type": "character_encoding", + "embedding": { + "embedding_dim": 64 + }, + "encoder": { + "type": "cnn", + "embedding_dim": 64, + "num_filters": 200, + "ngram_filter_sizes": [ + 5 + ] + } + } + } + }, + "num_highway_layers": 2, + "phrase_layer": { + "type": "qanet_encoder", + "input_dim": 128, + "hidden_dim": 128, + "attention_projection_dim": 128, + "feedforward_hidden_dim": 128, + "num_blocks": 1, + "num_convs_per_block": 4, + "conv_kernel_size": 7, + "num_attention_heads": 8, + "dropout_prob": 0.1, + "layer_dropout_undecayed_prob": 0.1, + "attention_dropout_prob": 0 + }, + "matrix_attention_layer": { + "type": "linear", + "tensor_1_dim": 128, + "tensor_2_dim": 128, + "combination": "x,y,x*y" + }, + "modeling_layer": { + "type": "qanet_encoder", + "input_dim": 128, + "hidden_dim": 128, + "attention_projection_dim": 128, + "feedforward_hidden_dim": 128, + "num_blocks": 7, + "num_convs_per_block": 2, + "conv_kernel_size": 5, + "num_attention_heads": 8, + "dropout_prob": 0.1, + "layer_dropout_undecayed_prob": 0.1, + "attention_dropout_prob": 0 + }, + "dropout_prob": 0.1, + "regularizer": [ + [ + ".*", + { + "type": "l2", + "alpha": 1e-07 + } + ] + ] + }, + "iterator": { + "type": "bucket", + "sorting_keys": [ + [ + "passage", + "num_tokens" + ], + [ + "question", + "num_tokens" + ] + ], + "batch_size": 25, + "max_instances_in_memory": 600 + }, + "trainer": { + "num_epochs": 50, + "grad_norm": 5, + "patience": 10, + "validation_metric": "+em", + "cuda_device": 0, + "optimizer": { + "type": "adam", + "lr": 0.001, + "betas": [ + 0.8, + 0.999 + ], + "eps": 1e-07 + }, + "moving_average": { + "type": "exponential", + "decay": 0.9999 + } + } +} \ No newline at end of file
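
Two usage sketches, neither taken from the patch itself. First, assuming the standard
``allennlp train`` command from this release, the model can be trained with the new
configuration like so (``/tmp/qanet`` is a hypothetical serialization directory):

    allennlp train training_config/qanet.jsonnet -s /tmp/qanet

Second, since ``QaNetEncoder`` is a regular ``Seq2SeqEncoder``, it can also be used on its own.
A minimal sketch, mirroring the shapes used in ``qanet_encoder_test.py`` above:

    import torch
    from allennlp.modules.seq2seq_encoders import QaNetEncoder

    # These small dimensions mirror the unit test; real configurations
    # (see training_config/qanet.jsonnet) use larger ones.
    encoder = QaNetEncoder(input_dim=16,
                           hidden_dim=16,
                           attention_projection_dim=16,
                           feedforward_hidden_dim=16,
                           num_blocks=2,
                           num_convs_per_block=2,
                           conv_kernel_size=3,
                           num_attention_heads=4)
    inputs = torch.randn(2, 12, 16)  # (batch_size, sequence_length, input_dim)
    outputs = encoder(inputs)        # (batch_size, sequence_length, hidden_dim) == (2, 12, 16)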