From c635bc4747b5ee5f14d7a31602010d4892e6e27e Mon Sep 17 00:00:00 2001 From: Mark Neumann Date: Wed, 3 Oct 2018 11:19:04 -0700 Subject: [PATCH] Graph parser for semantic dependencies (#1743) --- allennlp/data/dataset_readers/__init__.py | 1 + .../semantic_dependency_parsing.py | 114 ++++++ allennlp/data/fields/__init__.py | 1 + allennlp/data/fields/adjacency_field.py | 129 +++++++ allennlp/models/__init__.py | 1 + allennlp/models/graph_parser.py | 344 ++++++++++++++++++ .../bilinear_matrix_attention.py | 19 +- .../semantic_dependency_parsing.py | 33 ++ .../tests/data/fields/adjacency_field_test.py | 54 +++ allennlp/tests/fixtures/data/dm.sdp | 35 ++ .../fixtures/graph_parser/experiment.json | 44 +++ allennlp/tests/models/graph_parser_test.py | 27 ++ doc/api/allennlp.data.dataset_readers.rst | 2 + ...et_readers.semantic_dependency_parsing.rst | 7 + doc/api/allennlp.data.fields.rst | 7 + doc/api/allennlp.models.graph_parser.rst | 7 + doc/api/allennlp.models.rst | 1 + training_config/semantic_dependencies.json | 64 ++++ 18 files changed, 887 insertions(+), 3 deletions(-) create mode 100644 allennlp/data/dataset_readers/semantic_dependency_parsing.py create mode 100644 allennlp/data/fields/adjacency_field.py create mode 100644 allennlp/models/graph_parser.py create mode 100644 allennlp/tests/data/dataset_readers/semantic_dependency_parsing.py create mode 100644 allennlp/tests/data/fields/adjacency_field_test.py create mode 100644 allennlp/tests/fixtures/data/dm.sdp create mode 100644 allennlp/tests/fixtures/graph_parser/experiment.json create mode 100644 allennlp/tests/models/graph_parser_test.py create mode 100644 doc/api/allennlp.data.dataset_readers.semantic_dependency_parsing.rst create mode 100644 doc/api/allennlp.models.graph_parser.rst create mode 100644 training_config/semantic_dependencies.json diff --git a/allennlp/data/dataset_readers/__init__.py b/allennlp/data/dataset_readers/__init__.py index a48fad845fd..64b77447b83 100644 --- a/allennlp/data/dataset_readers/__init__.py +++ b/allennlp/data/dataset_readers/__init__.py @@ -19,6 +19,7 @@ from allennlp.data.dataset_readers.penn_tree_bank import PennTreeBankConstituencySpanDatasetReader from allennlp.data.dataset_readers.reading_comprehension import SquadReader, TriviaQaReader, QuACReader from allennlp.data.dataset_readers.semantic_role_labeling import SrlReader +from allennlp.data.dataset_readers.semantic_dependency_parsing import SemanticDependenciesDatasetReader from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader from allennlp.data.dataset_readers.sequence_tagging import SequenceTaggingDatasetReader from allennlp.data.dataset_readers.snli import SnliReader diff --git a/allennlp/data/dataset_readers/semantic_dependency_parsing.py b/allennlp/data/dataset_readers/semantic_dependency_parsing.py new file mode 100644 index 00000000000..76d02120437 --- /dev/null +++ b/allennlp/data/dataset_readers/semantic_dependency_parsing.py @@ -0,0 +1,114 @@ +from typing import Dict, List, Tuple +import logging +from overrides import overrides + +from allennlp.common.file_utils import cached_path +from allennlp.data.dataset_readers.dataset_reader import DatasetReader +from allennlp.data.fields import AdjacencyField, MetadataField, SequenceLabelField +from allennlp.data.fields import Field, TextField +from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer +from allennlp.data.tokenizers import Token +from allennlp.data.instance import Instance + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name 
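To make the column format concrete before the reader code, here is a minimal sketch (not part of the patch) that splits a single token row taken from the dm.sdp test fixture the same way the reader below does: the first nine columns map onto the FIELDS list defined next, and every remaining column records the role the token plays for the k-th predicate in the sentence.

```python
# Sketch only: one token row from the dm.sdp fixture (tab-separated in the real file).
# The first nine columns are named by FIELDS below; each remaining column holds the
# role this token plays with respect to the k-th predicate ("_" meaning no edge).
row = "2\tVinken\t_generic_proper_ne_\tNNP\t9\tnsubj\t-\t-\tnamed:x-c" \
      "\tcompound\t_\t_\tARG1\tARG1\t_\t_\t_\t_\t_\t_"
columns = row.split("\t")
fields = ["id", "form", "lemma", "pos", "head", "deprel", "top", "pred", "frame"]
token = dict(zip(fields, columns[:len(fields)]))
predicate_args = columns[len(fields):]
print(token["form"], token["pos"], predicate_args[:4])  # Vinken NNP ['compound', '_', '_', 'ARG1']
```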
+ +FIELDS = ["id", "form", "lemma", "pos", "head", "deprel", "top", "pred", "frame"] + +def parse_sentence(sentence_blob: str) -> Tuple[List[Dict[str, str]], List[Tuple[int, int]], List[str]]: + """ + Parses a chunk of text in the SemEval SDP format. + + Each word in the sentence is returned as a dictionary with the following + format: + 'id': '1', + 'form': 'Pierre', + 'lemma': 'Pierre', + 'pos': 'NNP', + 'head': '2', # Note that this is the `syntactic` head. + 'deprel': 'nn', + 'top': '-', + 'pred': '+', + 'frame': 'named:x-c' + + Along with a list of arcs and their corresponding tags. Note that + in semantic dependency parsing words can have more than one head + (it is not a tree), meaning that the list of arcs and tags are + not tied to the length of the sentence. + """ + annotated_sentence = [] + arc_indices = [] + arc_tags = [] + predicates = [] + + lines = [line.split("\t") for line in sentence_blob.split("\n") + if line and not line.strip().startswith("#")] + for line_idx, line in enumerate(lines): + annotated_token = {k:v for k, v in zip(FIELDS, line)} + if annotated_token['pred'] == "+": + predicates.append(line_idx) + annotated_sentence.append(annotated_token) + + for line_idx, line in enumerate(lines): + for predicate_idx, arg in enumerate(line[len(FIELDS):]): + if arg != "_": + arc_indices.append((line_idx, predicates[predicate_idx])) + arc_tags.append(arg) + return annotated_sentence, arc_indices, arc_tags + + +def lazy_parse(text: str): + for sentence in text.split("\n\n"): + if sentence: + yield parse_sentence(sentence) + + +@DatasetReader.register("semantic_dependencies") +class SemanticDependenciesDatasetReader(DatasetReader): + """ + Reads a file in the SemEval 2015 Task 18 (Broad-coverage Semantic Dependency Parsing) + format. + + Parameters + ---------- + token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) + The token indexers to be applied to the words TextField. + """ + def __init__(self, + token_indexers: Dict[str, TokenIndexer] = None, + lazy: bool = False) -> None: + super().__init__(lazy) + self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} + + @overrides + def _read(self, file_path: str): + # if `file_path` is a URL, redirect to the cache + file_path = cached_path(file_path) + + logger.info("Reading semantic dependency parsing data from: %s", file_path) + + with open(file_path) as sdp_file: + for annotated_sentence, directed_arc_indices, arc_tags in lazy_parse(sdp_file.read()): + # If there are no arc indices, skip this instance. 
+ if not directed_arc_indices: + continue + tokens = [word["form"] for word in annotated_sentence] + pos_tags = [word["pos"] for word in annotated_sentence] + yield self.text_to_instance(tokens, pos_tags, directed_arc_indices, arc_tags) + + @overrides + def text_to_instance(self, # type: ignore + tokens: List[str], + pos_tags: List[str] = None, + arc_indices: List[Tuple[int, int]] = None, + arc_tags: List[str] = None) -> Instance: + # pylint: disable=arguments-differ + fields: Dict[str, Field] = {} + token_field = TextField([Token(t) for t in tokens], self._token_indexers) + fields["tokens"] = token_field + fields["metadata"] = MetadataField({"tokens": tokens}) + if pos_tags is not None: + fields["pos_tags"] = SequenceLabelField(pos_tags, token_field, label_namespace="pos") + if arc_indices is not None and arc_tags is not None: + fields["arc_tags"] = AdjacencyField(arc_indices, token_field, arc_tags) + + return Instance(fields) diff --git a/allennlp/data/fields/__init__.py b/allennlp/data/fields/__init__.py index 6def3e8142d..fc4893e43b7 100644 --- a/allennlp/data/fields/__init__.py +++ b/allennlp/data/fields/__init__.py @@ -5,6 +5,7 @@ from allennlp.data.fields.field import Field from allennlp.data.fields.array_field import ArrayField +from allennlp.data.fields.adjacency_field import AdjacencyField from allennlp.data.fields.index_field import IndexField from allennlp.data.fields.knowledge_graph_field import KnowledgeGraphField from allennlp.data.fields.label_field import LabelField diff --git a/allennlp/data/fields/adjacency_field.py b/allennlp/data/fields/adjacency_field.py new file mode 100644 index 00000000000..c8583cd36e4 --- /dev/null +++ b/allennlp/data/fields/adjacency_field.py @@ -0,0 +1,129 @@ +from typing import Dict, List, Set, Tuple +import logging +import textwrap + +from overrides import overrides +import torch + +from allennlp.common.checks import ConfigurationError +from allennlp.data.fields.field import Field +from allennlp.data.fields.sequence_field import SequenceField +from allennlp.data.vocabulary import Vocabulary + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class AdjacencyField(Field[torch.Tensor]): + """ + A ``AdjacencyField`` defines directed adjacency relations between elements + in a :class:`~allennlp.data.fields.sequence_field.SequenceField`. + Because it's a labeling of some other field, we take that field as input here + and use it to determine our padding and other things. + + This field will get converted into an array of shape (sequence_field_length, sequence_field_length), + where the (i, j)th array element is either a binary flag indicating there is an edge from i to j, + or an integer label k, indicating there is a label from i to j of type k. + + Parameters + ---------- + indices : ``List[Tuple[int, int]]`` + sequence_field : ``SequenceField`` + A field containing the sequence that this ``AdjacencyField`` is labeling. Most often, + this is a ``TextField``, for tagging edge relations between tokens in a sentence. + labels : ``List[str]``, optional, default = None + Optional labels for the edges of the adjacency matrix. + label_namespace : ``str``, optional (default='labels') + The namespace to use for converting tag strings into integers. We convert tag strings to + integers for you, and this parameter tells the ``Vocabulary`` object which mapping from + strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word). 
+ padding_value : ``int``, (optional, default = -1) + The value to use as padding. + """ + # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens. + # This warning will be repeated for every instantiation of this class (i.e for every data + # instance), spewing a lot of warnings so this class variable is used to only log a single + # warning per namespace. + _already_warned_namespaces: Set[str] = set() + + def __init__(self, + indices: List[Tuple[int, int]], + sequence_field: SequenceField, + labels: List[str] = None, + label_namespace: str = 'labels', + padding_value: int = -1) -> None: + self.indices = indices + self.labels = labels + self.sequence_field = sequence_field + self._label_namespace = label_namespace + self._padding_value = padding_value + self._indexed_labels: List[int] = None + + self._maybe_warn_for_namespace(label_namespace) + field_length = sequence_field.sequence_length() + + if len(set(indices)) != len(indices): + raise ConfigurationError(f"Indices must be unique, but found {indices}") + + if not all([0 <= index[1] < field_length and 0 <= index[0] < field_length for index in indices]): + raise ConfigurationError(f"Label indices and sequence length " + f"are incompatible: {indices} and {field_length}") + + if labels is not None and len(indices) != len(labels): + raise ConfigurationError(f"Labelled indices were passed, but their lengths do not match: " + f" {labels}, {indices}") + + def _maybe_warn_for_namespace(self, label_namespace: str) -> None: + if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")): + if label_namespace not in self._already_warned_namespaces: + logger.warning("Your label namespace was '%s'. We recommend you use a namespace " + "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by " + "default to your vocabulary. 
See documentation for " + "`non_padded_namespaces` parameter in Vocabulary.", + self._label_namespace) + self._already_warned_namespaces.add(label_namespace) + + @overrides + def count_vocab_items(self, counter: Dict[str, Dict[str, int]]): + if self._indexed_labels is None and self.labels is not None: + for label in self.labels: + counter[self._label_namespace][label] += 1 # type: ignore + + @overrides + def index(self, vocab: Vocabulary): + if self._indexed_labels is None and self.labels is not None: + self._indexed_labels = [vocab.get_token_index(label, self._label_namespace) + for label in self.labels] + + @overrides + def get_padding_lengths(self) -> Dict[str, int]: + return {'num_tokens': self.sequence_field.sequence_length()} + + @overrides + def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: + desired_num_tokens = padding_lengths['num_tokens'] + tensor = torch.ones(desired_num_tokens, desired_num_tokens) * self._padding_value + labels = self._indexed_labels or [1 for _ in range(len(self.indices))] + + for index, label in zip(self.indices, labels): + tensor[index] = label + return tensor + + @overrides + def empty_field(self) -> 'AdjacencyField': + # pylint: disable=protected-access + # The empty_list here is needed for mypy + empty_list: List[Tuple[int, int]] = [] + adjacency_field = AdjacencyField(empty_list, + self.sequence_field.empty_field(), + padding_value=self._padding_value) + return adjacency_field + + def __str__(self) -> str: + length = self.sequence_field.sequence_length() + formatted_labels = "".join(["\t\t" + labels + "\n" + for labels in textwrap.wrap(repr(self.labels), 100)]) + formatted_indices = "".join(["\t\t" + index + "\n" + for index in textwrap.wrap(repr(self.indices), 100)]) + return f"AdjacencyField of length {length}\n" \ + f"\t\twith indices:\n {formatted_indices}\n" \ + f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." diff --git a/allennlp/models/__init__.py b/allennlp/models/__init__.py index b3bfdc5b5cc..992707d8172 100644 --- a/allennlp/models/__init__.py +++ b/allennlp/models/__init__.py @@ -23,3 +23,4 @@ from allennlp.models.simple_tagger import SimpleTagger from allennlp.models.esim import ESIM from allennlp.models.bimpm import BiMpm +from allennlp.models.graph_parser import GraphParser diff --git a/allennlp/models/graph_parser.py b/allennlp/models/graph_parser.py new file mode 100644 index 00000000000..7e5d1f31999 --- /dev/null +++ b/allennlp/models/graph_parser.py @@ -0,0 +1,344 @@ +from typing import Dict, Optional, Tuple, Any, List +import logging +import copy + +from overrides import overrides +import torch +from torch.nn.modules import Dropout +import numpy + +from allennlp.common.checks import check_dimensions_match, ConfigurationError +from allennlp.data import Vocabulary +from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder, Embedding, InputVariationalDropout +from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention +from allennlp.modules import FeedForward +from allennlp.models.model import Model +from allennlp.nn import InitializerApplicator, RegularizerApplicator, Activation +from allennlp.nn.util import get_text_field_mask +from allennlp.nn.util import get_lengths_from_binary_sequence_mask +from allennlp.training.metrics import F1Measure + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +@Model.register("graph_parser") +class GraphParser(Model): + """ + A Parser for arbitrary graph stuctures. 
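+    Edges are predicted independently for every ordered pair of words: a bilinear
+    scorer decides whether an arc between the pair is present, and a second bilinear
+    classifier assigns it a tag, so the prediction can be any labelled directed graph
+    rather than a tree.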
+
+    Parameters
+    ----------
+    vocab : ``Vocabulary``, required
+        A Vocabulary, required in order to compute sizes for input/output projections.
+    text_field_embedder : ``TextFieldEmbedder``, required
+        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
+    encoder : ``Seq2SeqEncoder``
+        The encoder (with its own internal stacking) that we will use to generate representations
+        of tokens.
+    tag_representation_dim : ``int``, required.
+        The dimension of the MLPs used for arc tag prediction.
+    arc_representation_dim : ``int``, required.
+        The dimension of the MLPs used for arc prediction.
+    tag_feedforward : ``FeedForward``, optional, (default = None).
+        The feedforward network used to produce tag representations.
+        By default, a 1 layer feedforward network with an elu activation is used.
+    arc_feedforward : ``FeedForward``, optional, (default = None).
+        The feedforward network used to produce arc representations.
+        By default, a 1 layer feedforward network with an elu activation is used.
+    pos_tag_embedding : ``Embedding``, optional.
+        Used to embed the ``pos_tags`` ``SequenceLabelField`` we get as input to the model.
+    dropout : ``float``, optional, (default = 0.0)
+        The variational dropout applied to the output of the encoder and MLP layers.
+    input_dropout : ``float``, optional, (default = 0.0)
+        The dropout applied to the embedded text input.
+    edge_prediction_threshold : ``float``, optional (default = 0.5)
+        The probability at which to consider a scored edge to be 'present'
+        in the decoded graph. Must be between 0 and 1 (exclusive).
+    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
+        Used to initialize the model parameters.
+    regularizer : ``RegularizerApplicator``, optional (default=``None``)
+        If provided, will be used to calculate the regularization penalty during training.
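For a sense of how these constructor arguments fit together, a minimal sketch that builds the model directly in Python follows; this is an illustration rather than part of the patch, the tiny dimensions mirror the test fixture configuration added later in this PR, and the vocabulary handling is deliberately simplified.

```python
# Sketch only: constructing the model directly.  Dimensions mirror
# allennlp/tests/fixtures/graph_parser/experiment.json; a real vocabulary
# would be built from the training instances.
import torch
from allennlp.data import Vocabulary
from allennlp.models import GraphParser
from allennlp.modules import Embedding
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

vocab = Vocabulary()
vocab.add_token_to_namespace("ARG1", namespace="labels")  # arc tag labels live in "labels"
embedder = BasicTextFieldEmbedder({"tokens": Embedding(vocab.get_vocab_size("tokens"), 2)})
encoder = PytorchSeq2SeqWrapper(torch.nn.LSTM(2, 4, batch_first=True))
model = GraphParser(vocab=vocab,
                    text_field_embedder=embedder,
                    encoder=encoder,
                    tag_representation_dim=3,
                    arc_representation_dim=3)
```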
+ """ + def __init__(self, + vocab: Vocabulary, + text_field_embedder: TextFieldEmbedder, + encoder: Seq2SeqEncoder, + tag_representation_dim: int, + arc_representation_dim: int, + tag_feedforward: FeedForward = None, + arc_feedforward: FeedForward = None, + pos_tag_embedding: Embedding = None, + dropout: float = 0.0, + input_dropout: float = 0.0, + edge_prediction_threshold: float = 0.5, + initializer: InitializerApplicator = InitializerApplicator(), + regularizer: Optional[RegularizerApplicator] = None) -> None: + super(GraphParser, self).__init__(vocab, regularizer) + + self.text_field_embedder = text_field_embedder + self.encoder = encoder + self.edge_prediction_threshold = edge_prediction_threshold + if not 0 < edge_prediction_threshold < 1: + raise ConfigurationError(f"edge_prediction_threshold must be between " + f"0 and 1 (exclusive) but found {edge_prediction_threshold}.") + + encoder_dim = encoder.get_output_dim() + + self.head_arc_feedforward = arc_feedforward or \ + FeedForward(encoder_dim, 1, + arc_representation_dim, + Activation.by_name("elu")()) + self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) + + self.arc_attention = BilinearMatrixAttention(arc_representation_dim, + arc_representation_dim, + use_input_biases=True) + + num_labels = self.vocab.get_vocab_size("labels") + self.head_tag_feedforward = tag_feedforward or \ + FeedForward(encoder_dim, 1, + tag_representation_dim, + Activation.by_name("elu")()) + self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) + + self.tag_bilinear = BilinearMatrixAttention(tag_representation_dim, + tag_representation_dim, + label_dim=num_labels) + + self._pos_tag_embedding = pos_tag_embedding or None + self._dropout = InputVariationalDropout(dropout) + self._input_dropout = Dropout(input_dropout) + + representation_dim = text_field_embedder.get_output_dim() + if pos_tag_embedding is not None: + representation_dim += pos_tag_embedding.get_output_dim() + + check_dimensions_match(representation_dim, encoder.get_input_dim(), + "text field embedding dim", "encoder input dim") + check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), + "tag representation dim", "tag feedforward output dim") + check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), + "arc representation dim", "arc feedforward output dim") + + self._unlabelled_f1 = F1Measure(positive_label=1) + self._arc_loss = torch.nn.BCEWithLogitsLoss(reduce=False) + self._tag_loss = torch.nn.CrossEntropyLoss(reduce=False) + initializer(self) + + @overrides + def forward(self, # type: ignore + tokens: Dict[str, torch.LongTensor], + pos_tags: torch.LongTensor = None, + metadata: List[Dict[str, Any]] = None, + arc_tags: torch.LongTensor = None) -> Dict[str, torch.Tensor]: + # pylint: disable=arguments-differ + """ + Parameters + ---------- + tokens : Dict[str, torch.LongTensor], required + The output of ``TextField.as_array()``. + pos_tags : ``torch.LongTensor``, optional, (default = None). + The output of a ``SequenceLabelField`` containing POS tags. + arc_tags : torch.LongTensor, optional (default = None) + A torch tensor representing the sequence of integer indices denoting the parent of every + word in the dependency parse. Has shape ``(batch_size, sequence_length, sequence_length)``. + + Returns + ------- + An output dictionary. 
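+        It contains ``arc_probs``, ``arc_tag_probs``, ``mask`` and ``tokens``, and
+        additionally ``loss``, ``arc_loss`` and ``tag_loss`` when gold ``arc_tags``
+        are provided.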
+ """ + embedded_text_input = self.text_field_embedder(tokens) + if pos_tags is not None and self._pos_tag_embedding is not None: + embedded_pos_tags = self._pos_tag_embedding(pos_tags) + embedded_text_input = torch.cat([embedded_text_input, embedded_pos_tags], -1) + elif self._pos_tag_embedding is not None: + raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.") + + mask = get_text_field_mask(tokens) + embedded_text_input = self._input_dropout(embedded_text_input) + encoded_text = self.encoder(embedded_text_input, mask) + + float_mask = mask.float() + encoded_text = self._dropout(encoded_text) + + # shape (batch_size, sequence_length, arc_representation_dim) + head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text)) + child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text)) + + # shape (batch_size, sequence_length, tag_representation_dim) + head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text)) + child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text)) + # shape (batch_size, sequence_length, sequence_length) + arc_scores = self.arc_attention(head_arc_representation, + child_arc_representation) + # shape (batch_size, num_tags, sequence_length, sequence_length) + arc_tag_logits = self.tag_bilinear(head_tag_representation, + child_tag_representation) + # Switch to (batch_size, sequence_length, sequence_length, num_tags) + arc_tag_logits = arc_tag_logits.permute(0, 2, 3, 1).contiguous() + + minus_inf = -1e8 + minus_mask = (1 - float_mask) * minus_inf + arc_scores = arc_scores + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) + + arc_probs, arc_tag_probs = self._greedy_decode(arc_scores, + arc_tag_logits, + mask) + + output_dict = { + "arc_probs": arc_probs, + "arc_tag_probs": arc_tag_probs, + "mask": mask, + "tokens": [meta["tokens"] for meta in metadata], + } + + if arc_tags is not None: + arc_nll, tag_nll = self._construct_loss(arc_scores=arc_scores, + arc_tag_logits=arc_tag_logits, + arc_tags=arc_tags, + mask=mask) + output_dict["loss"] = arc_nll + tag_nll + output_dict["arc_loss"] = arc_nll + output_dict["tag_loss"] = tag_nll + + # Make the arc tags not have negative values anywhere + # (by default, no edge is indicated with -1). + arc_indices = (arc_tags != -1).float() + tag_mask = float_mask.unsqueeze(1) * float_mask.unsqueeze(2) + one_minus_arc_probs = 1 - arc_probs + # We stack scores here because the f1 measure expects a + # distribution, rather than a single value. 
+ self._unlabelled_f1(torch.stack([one_minus_arc_probs, arc_probs], -1), arc_indices, tag_mask) + + return output_dict + + @overrides + def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + arc_tag_probs = output_dict["arc_tag_probs"].cpu().detach().numpy() + arc_probs = output_dict["arc_probs"].cpu().detach().numpy() + mask = output_dict["mask"] + lengths = get_lengths_from_binary_sequence_mask(mask) + arcs = [] + arc_tags = [] + for instance_arc_probs, instance_arc_tag_probs, length in zip(arc_probs, arc_tag_probs, lengths): + + arc_matrix = instance_arc_probs > self.edge_prediction_threshold + edges = [] + edge_tags = [] + for i in range(length): + for j in range(length): + if arc_matrix[i, j] == 1: + edges.append((i, j)) + tag = instance_arc_tag_probs[i, j].argmax(-1) + edge_tags.append(self.vocab.get_token_from_index(tag, "labels")) + arcs.append(edges) + arc_tags.append(edge_tags) + + output_dict["arcs"] = arcs + output_dict["arc_tags"] = arc_tags + return output_dict + + def _construct_loss(self, + arc_scores: torch.Tensor, + arc_tag_logits: torch.Tensor, + arc_tags: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes the arc and tag loss for an adjacency matrix. + + Parameters + ---------- + arc_scores : ``torch.Tensor``, required. + A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a + binary classification decision for whether an edge is present between two words. + arc_tag_logits : ``torch.Tensor``, required. + A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to generate + a distribution over edge tags for a given edge. + arc_tags : ``torch.Tensor``, required. + A tensor of shape (batch_size, sequence_length, sequence_length). + The labels for every arc. + mask : ``torch.Tensor``, required. + A mask of shape (batch_size, sequence_length), denoting unpadded + elements in the sequence. + + Returns + ------- + arc_nll : ``torch.Tensor``, required. + The negative log likelihood from the arc loss. + tag_nll : ``torch.Tensor``, required. + The negative log likelihood from the arc tag loss. + """ + float_mask = mask.float() + arc_indices = (arc_tags != -1).float() + # Make the arc tags not have negative values anywhere + # (by default, no edge is indicated with -1). + arc_tags = arc_tags * arc_indices + arc_nll = self._arc_loss(arc_scores, arc_indices) * float_mask.unsqueeze(1) * float_mask.unsqueeze(2) + # We want the mask for the tags to only include the unmasked words + # and we only care about the loss with respect to the gold arcs. + tag_mask = float_mask.unsqueeze(1) * float_mask.unsqueeze(2) * arc_indices + + batch_size, sequence_length, _, num_tags = arc_tag_logits.size() + original_shape = [batch_size, sequence_length, sequence_length] + reshaped_logits = arc_tag_logits.view(-1, num_tags) + reshaped_tags = arc_tags.view(-1) + tag_nll = self._tag_loss(reshaped_logits, reshaped_tags.long()).view(original_shape) * tag_mask + + valid_positions = tag_mask.sum() + + arc_nll = arc_nll.sum() / valid_positions.float() + tag_nll = tag_nll.sum() / valid_positions.float() + return arc_nll, tag_nll + + @staticmethod + def _greedy_decode(arc_scores: torch.Tensor, + arc_tag_logits: torch.Tensor, + mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decodes the head and head tag predictions by decoding the unlabeled arcs + independently for each word and then again, predicting the head tags of + these greedily chosen arcs indpendently. 
+
+        Parameters
+        ----------
+        arc_scores : ``torch.Tensor``, required.
+            A tensor of shape (batch_size, sequence_length, sequence_length) used to generate a
+            binary classification decision for whether an edge is present between two words.
+        arc_tag_logits : ``torch.Tensor``, required.
+            A tensor of shape (batch_size, sequence_length, sequence_length, num_tags) used to
+            generate a distribution over tags for each arc.
+        mask : ``torch.Tensor``, required.
+            A mask of shape (batch_size, sequence_length).
+
+        Returns
+        -------
+        arc_probs : ``torch.Tensor``
+            A tensor of shape (batch_size, sequence_length, sequence_length) representing the
+            probability of an arc being present for each pair of words.
+        arc_tag_probs : ``torch.Tensor``
+            A tensor of shape (batch_size, sequence_length, sequence_length, num_tags)
+            representing the distribution over edge tags for a given edge.
+        """
+        # Mask the diagonal, because we don't want self edges.
+        inf_diagonal_mask = torch.diag(arc_scores.new(mask.size(1)).fill_(-numpy.inf))
+        arc_scores = arc_scores + inf_diagonal_mask
+        # shape (batch_size, sequence_length, sequence_length, num_tags)
+        arc_tag_logits = arc_tag_logits + inf_diagonal_mask.unsqueeze(0).unsqueeze(-1)
+        # Mask padded tokens, because we only want to consider actual word -> word edges.
+        minus_mask = (1 - mask).byte().unsqueeze(2)
+        arc_scores.masked_fill_(minus_mask, -numpy.inf)
+        arc_tag_logits.masked_fill_(minus_mask.unsqueeze(-1), -numpy.inf)
+        # shape (batch_size, sequence_length, sequence_length)
+        arc_probs = arc_scores.sigmoid()
+        # shape (batch_size, sequence_length, sequence_length, num_tags)
+        arc_tag_probs = torch.nn.functional.softmax(arc_tag_logits, dim=-1)
+        return arc_probs, arc_tag_probs
+
+    @overrides
+    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
+        metrics = {}
+        precision, recall, f1_measure = self._unlabelled_f1.get_metric(reset)
+        metrics["precision"] = precision
+        metrics["recall"] = recall
+        metrics["f1"] = f1_measure
+        return metrics
diff --git a/allennlp/modules/matrix_attention/bilinear_matrix_attention.py b/allennlp/modules/matrix_attention/bilinear_matrix_attention.py
index f3e495c6d52..9d06978e2f3 100644
--- a/allennlp/modules/matrix_attention/bilinear_matrix_attention.py
+++ b/allennlp/modules/matrix_attention/bilinear_matrix_attention.py
@@ -30,17 +30,26 @@ class BilinearMatrixAttention(MatrixAttention):
         If True, we add biases to the inputs such that the final computation
         is equivelent to the original bilinear matrix multiplication plus a projection of both
         inputs.
+    label_dim : ``int``, optional (default = 1)
+        The number of output classes. Typically in an attention setting this will be one,
+        but this parameter allows this class to function as an equivalent to ``torch.nn.Bilinear``
+        for matrices, rather than vectors.
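A short sketch of the new ``label_dim`` behaviour (an illustration with arbitrary dimensions, not part of the patch): with ``label_dim > 1`` the module scores every pair of rows once per output class, so the result has one score matrix per class.

```python
# Sketch only: label_dim > 1 produces one (rows_1, rows_2) score matrix per class,
# analogous to torch.nn.Bilinear applied to every pair of rows.
import torch
from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention

attention = BilinearMatrixAttention(matrix_1_dim=3, matrix_2_dim=3, label_dim=5)
matrix_1 = torch.randn(2, 7, 3)   # (batch_size, rows_1, matrix_1_dim)
matrix_2 = torch.randn(2, 9, 3)   # (batch_size, rows_2, matrix_2_dim)
scores = attention(matrix_1, matrix_2)
assert scores.shape == (2, 5, 7, 9)  # (batch_size, label_dim, rows_1, rows_2)
```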
""" def __init__(self, matrix_1_dim: int, matrix_2_dim: int, activation: Activation = None, - use_input_biases: bool = False) -> None: + use_input_biases: bool = False, + label_dim: int = 1) -> None: super().__init__() if use_input_biases: matrix_1_dim += 1 matrix_2_dim += 1 - self._weight_matrix = Parameter(torch.Tensor(matrix_1_dim, matrix_2_dim)) + + if label_dim == 1: + self._weight_matrix = Parameter(torch.Tensor(matrix_1_dim, matrix_2_dim)) + else: + self._weight_matrix = Parameter(torch.Tensor(label_dim, matrix_1_dim, matrix_2_dim)) self._bias = Parameter(torch.Tensor(1)) self._activation = activation or Activation.by_name('linear')() @@ -60,6 +69,10 @@ def forward(self, matrix_1: torch.Tensor, matrix_2: torch.Tensor) -> torch.Tenso matrix_1 = torch.cat([matrix_1, bias1], -1) matrix_2 = torch.cat([matrix_2, bias2], -1) - intermediate = torch.matmul(matrix_1.unsqueeze(1), self._weight_matrix.unsqueeze(0)) + + weight = self._weight_matrix + if weight.dim() == 2: + weight = weight.unsqueeze(0) + intermediate = torch.matmul(matrix_1.unsqueeze(1), weight) final = torch.matmul(intermediate, matrix_2.unsqueeze(1).transpose(2, 3)) return self._activation(final.squeeze(1) + self._bias) diff --git a/allennlp/tests/data/dataset_readers/semantic_dependency_parsing.py b/allennlp/tests/data/dataset_readers/semantic_dependency_parsing.py new file mode 100644 index 00000000000..7f073205b14 --- /dev/null +++ b/allennlp/tests/data/dataset_readers/semantic_dependency_parsing.py @@ -0,0 +1,33 @@ +# pylint: disable=no-self-use,invalid-name +from allennlp.data.dataset_readers.semantic_dependency_parsing import SemanticDependenciesDatasetReader +from allennlp.common.util import ensure_list +from allennlp.common.testing import AllenNlpTestCase + + +class TestSemanticDependencyParsingDatasetReader: + def test_read_from_file(self): + reader = SemanticDependenciesDatasetReader() + instances = reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'dm.sdp') + instances = ensure_list(instances) + + instance = instances[0] + arcs = instance.fields["arc_tags"] + tokens = [x.text for x in instance.fields["tokens"].tokens] + assert tokens == ['Pierre', 'Vinken', ',', '61', 'years', 'old', ',', + 'will', 'join', 'the', 'board', 'as', 'a', + 'nonexecutive', 'director', 'Nov.', '29', '.'] + assert arcs.indices == [(1, 0), (1, 5), (1, 8), (4, 3), (5, 4), + (8, 11), (8, 16), (10, 8), (10, 9), + (14, 11), (14, 12), (14, 13), (16, 15)] + assert arcs.labels == ['compound', 'ARG1', 'ARG1', 'ARG1', 'measure', + 'ARG1', 'loc', 'ARG2', 'BV', 'ARG2', 'BV', 'ARG1', 'of'] + + instance = instances[1] + arcs = instance.fields["arc_tags"] + tokens = [x.text for x in instance.fields["tokens"].tokens] + assert tokens == ['Mr.', 'Vinken', 'is', 'chairman', 'of', 'Elsevier', + 'N.V.', ',', 'the', 'Dutch', 'publishing', 'group', '.'] + assert arcs.indices == [(1, 0), (1, 2), (3, 2), (3, 4), (5, 4), (5, 6), + (5, 11), (11, 8), (11, 9), (11, 10)] + assert arcs.labels == ['compound', 'ARG1', 'ARG2', 'ARG1', 'ARG2', 'compound', + 'appos', 'BV', 'ARG1', 'compound'] diff --git a/allennlp/tests/data/fields/adjacency_field_test.py b/allennlp/tests/data/fields/adjacency_field_test.py new file mode 100644 index 00000000000..62e94968b76 --- /dev/null +++ b/allennlp/tests/data/fields/adjacency_field_test.py @@ -0,0 +1,54 @@ +# pylint: disable=invalid-name +import pytest +import numpy +from allennlp.common.checks import ConfigurationError +from allennlp.common.testing import AllenNlpTestCase +from allennlp.data.fields import AdjacencyField, TextField 
+from allennlp.data.token_indexers import SingleIdTokenIndexer +from allennlp.data import Vocabulary, Token + + +class TestAdjacencyField(AllenNlpTestCase): + + def setUp(self): + super().setUp() + self.text = TextField([Token(t) for t in ["here", "is", "a", "sentence", "."]], + {"words": SingleIdTokenIndexer("words")}) + + def test_adjacency_field_can_index_with_vocab(self): + vocab = Vocabulary() + vocab.add_token_to_namespace("a", namespace="labels") + vocab.add_token_to_namespace("b", namespace="labels") + vocab.add_token_to_namespace("c", namespace="labels") + + labels = ["a", "b"] + indices = [(0, 1), (2, 1)] + adjacency_field = AdjacencyField(indices, self.text, labels) + adjacency_field.index(vocab) + tensor = adjacency_field.as_tensor(adjacency_field.get_padding_lengths()) + numpy.testing.assert_equal(tensor.numpy(), numpy.array([[-1, 0, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, 1, -1, -1, -1], + [-1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1]])) + + def test_adjacency_field_raises_with_out_of_bounds_indices(self): + with pytest.raises(ConfigurationError): + _ = AdjacencyField([(0, 24)], self.text) + + def test_adjacency_field_raises_with_mismatching_labels_for_indices(self): + with pytest.raises(ConfigurationError): + _ = AdjacencyField([(0, 1), (0, 2)], self.text, ["label1"]) + + def test_adjacency_field_raises_with_duplicate_indices(self): + with pytest.raises(ConfigurationError): + _ = AdjacencyField([(0, 1), (0, 1)], self.text, ["label1"]) + + def test_adjacency_field_empty_field_works(self): + field = AdjacencyField([(0, 1)], self.text) + empty_field = field.empty_field() + assert empty_field.indices == [] + + def test_printing_doesnt_crash(self): + adjacency_field = AdjacencyField([(0, 1)], self.text, ["label1"]) + print(adjacency_field) diff --git a/allennlp/tests/fixtures/data/dm.sdp b/allennlp/tests/fixtures/data/dm.sdp new file mode 100644 index 00000000000..a624370840c --- /dev/null +++ b/allennlp/tests/fixtures/data/dm.sdp @@ -0,0 +1,35 @@ +#SDP 2015 +#20001001 +1 Pierre Pierre NNP 2 nn - + named:x-c _ _ _ _ _ _ _ _ _ _ _ +2 Vinken _generic_proper_ne_ NNP 9 nsubj - - named:x-c compound _ _ ARG1 ARG1 _ _ _ _ _ _ +3 , _ , 2 punct - - _ _ _ _ _ _ _ _ _ _ _ _ +4 61 _generic_card_ne_ CD 5 num - + card:i-i-c _ _ _ _ _ _ _ _ _ _ _ +5 years year NNS 6 npadvmod - + n:x _ ARG1 _ _ _ _ _ _ _ _ _ +6 old old JJ 2 amod - + a:e-p _ _ measure _ _ _ _ _ _ _ _ +7 , _ , 2 punct - - _ _ _ _ _ _ _ _ _ _ _ _ +8 will will MD 9 aux - - _ _ _ _ _ _ _ _ _ _ _ _ +9 join join VB 0 root + + v:e-i-p _ _ _ _ _ _ ARG1 _ _ _ loc +10 the the DT 11 det - + q:i-h-h _ _ _ _ _ _ _ _ _ _ _ +11 board board NN 9 dobj - - n_of:x-i _ _ _ _ ARG2 BV _ _ _ _ _ +12 as as IN 9 prep - + p:e-u-i _ _ _ _ _ _ _ _ _ _ _ +13 a a DT 15 det - + q:i-h-h _ _ _ _ _ _ _ _ _ _ _ +14 nonexecutive _generic_jj_ JJ 15 amod - + a:e-u _ _ _ _ _ _ _ _ _ _ _ +15 director director NN 12 pobj - - n_of:x-i _ _ _ _ _ _ ARG2 BV ARG1 _ _ +16 Nov. Nov. NNP 9 tmod - + mofy:x-c _ _ _ _ _ _ _ _ _ _ _ +17 29 _generic_dom_card_ne_ CD 16 num - + dofm:x-c _ _ _ _ _ _ _ _ _ of _ +18 . _ . 9 punct - - _ _ _ _ _ _ _ _ _ _ _ _ + +#20001002 +1 Mr. Mr. NNP 2 nn - + n:x _ _ _ _ _ _ _ _ +2 Vinken _generic_proper_ne_ NNP 4 nsubj - - named:x-c compound ARG1 _ _ _ _ _ _ +3 is is VBZ 4 cop + + v_id:e-p-i _ _ _ _ _ _ _ _ +4 chairman chairman NN 0 root - - n_of:x _ ARG2 ARG1 _ _ _ _ _ +5 of of IN 4 prep - + p:e-x-i _ _ _ _ _ _ _ _ +6 Elsevier _generic_proper_ne_ NNP 7 nn - - named:x-c _ _ ARG2 compound _ _ _ appos +7 N.V. N.V. 
NNP 5 pobj - + n:x _ _ _ _ _ _ _ _ +8 , _ , 7 punct - - _ _ _ _ _ _ _ _ _ +9 the the DT 12 det - + q:i-h-h _ _ _ _ _ _ _ _ +10 Dutch Dutch JJ 12 amod - + a:e-p _ _ _ _ _ _ _ _ +11 publishing publish NN 12 nn - + v:e-i-p _ _ _ _ _ _ _ _ +12 group group NN 7 appos - + n_of:x _ _ _ _ BV ARG1 compound _ +13 . _ . 4 punct - - _ _ _ _ _ _ _ _ _ diff --git a/allennlp/tests/fixtures/graph_parser/experiment.json b/allennlp/tests/fixtures/graph_parser/experiment.json new file mode 100644 index 00000000000..9dfa5497ec0 --- /dev/null +++ b/allennlp/tests/fixtures/graph_parser/experiment.json @@ -0,0 +1,44 @@ +{ + "dataset_reader":{ + "type":"semantic_dependencies" + }, + "train_data_path": "allennlp/tests/fixtures/data/dm.sdp", + "validation_data_path": "allennlp/tests/fixtures/data/dm.sdp", + "model": { + "type": "graph_parser", + "text_field_embedder": { + "tokens": { + "type": "embedding", + "embedding_dim": 2, + "trainable": true + } + }, + "encoder": { + "type": "lstm", + "input_size": 2, + "hidden_size": 4, + "num_layers": 1 + }, + "arc_representation_dim": 3, + "tag_representation_dim": 3 + }, + + "iterator": { + "type": "bucket", + "sorting_keys": [["tokens", "num_tokens"]], + "padding_noise": 0.0, + "batch_size" : 5 + }, + "trainer": { + "num_epochs": 1, + "grad_norm": 1.0, + "patience": 500, + "cuda_device": -1, + "optimizer": { + "type": "adadelta", + "lr": 0.000001, + "rho": 0.95 + } + } + } + diff --git a/allennlp/tests/models/graph_parser_test.py b/allennlp/tests/models/graph_parser_test.py new file mode 100644 index 00000000000..b332d09cc8a --- /dev/null +++ b/allennlp/tests/models/graph_parser_test.py @@ -0,0 +1,27 @@ +# pylint: disable=no-self-use,invalid-name,no-value-for-parameter + + +from allennlp.common.testing.model_test_case import ModelTestCase + +class GraphParserTest(ModelTestCase): + + def setUp(self): + super(GraphParserTest, self).setUp() + self.set_up_model(self.FIXTURES_ROOT / "graph_parser" / "experiment.json", + self.FIXTURES_ROOT / "data" / "dm.sdp") + + def test_graph_parser_can_save_and_load(self): + self.ensure_model_can_train_save_and_load(self.param_file) + + def test_batch_predictions_are_consistent(self): + self.ensure_batch_predictions_are_consistent() + + def test_model_can_decode(self): + self.model.eval() + training_tensors = self.dataset.as_tensor_dict() + output_dict = self.model(**training_tensors) + decode_output_dict = self.model.decode(output_dict) + + assert set(decode_output_dict.keys()) == set(['arc_loss', 'tag_loss', 'loss', + 'arcs', 'arc_tags', 'arc_tag_probs', + 'arc_probs', 'tokens', 'mask']) diff --git a/doc/api/allennlp.data.dataset_readers.rst b/doc/api/allennlp.data.dataset_readers.rst index bdeff7ac7bb..9f90c65421a 100644 --- a/doc/api/allennlp.data.dataset_readers.rst +++ b/doc/api/allennlp.data.dataset_readers.rst @@ -21,6 +21,7 @@ allennlp.data.dataset_readers allennlp.data.dataset_readers.penn_tree_bank allennlp.data.dataset_readers.quora_paraphrase allennlp.data.dataset_readers.reading_comprehension + allennlp.data.dataset_readers.semantic_dependency_parsing allennlp.data.dataset_readers.semantic_parsing allennlp.data.dataset_readers.semantic_role_labeling allennlp.data.dataset_readers.seq2seq @@ -28,3 +29,4 @@ allennlp.data.dataset_readers allennlp.data.dataset_readers.snli allennlp.data.dataset_readers.stanford_sentiment_tree_bank allennlp.data.dataset_readers.universal_dependencies + allennlp.data.dataset_readers.quora_paraphrase diff --git a/doc/api/allennlp.data.dataset_readers.semantic_dependency_parsing.rst 
b/doc/api/allennlp.data.dataset_readers.semantic_dependency_parsing.rst new file mode 100644 index 00000000000..3d693faa35d --- /dev/null +++ b/doc/api/allennlp.data.dataset_readers.semantic_dependency_parsing.rst @@ -0,0 +1,7 @@ +allennlp.data.dataset_readers.semantic_dependency_parsing +========================================================= + +.. automodule:: allennlp.data.dataset_readers.semantic_dependency_parsing + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/api/allennlp.data.fields.rst b/doc/api/allennlp.data.fields.rst index 0be5e2636fb..2b3a550c93b 100644 --- a/doc/api/allennlp.data.fields.rst +++ b/doc/api/allennlp.data.fields.rst @@ -19,6 +19,7 @@ allennlp.data.fields * :ref:`SequenceField` * :ref:`SequenceLabelField` * :ref:`TextField` +* :ref:`AdjacencyField` .. _field: .. automodule:: allennlp.data.fields.field @@ -97,3 +98,9 @@ allennlp.data.fields :members: :undoc-members: :show-inheritance: + +.. _adjacency-field: +.. automodule:: allennlp.data.fields.adjacency_field + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/doc/api/allennlp.models.graph_parser.rst b/doc/api/allennlp.models.graph_parser.rst new file mode 100644 index 00000000000..ea474782112 --- /dev/null +++ b/doc/api/allennlp.models.graph_parser.rst @@ -0,0 +1,7 @@ +allennlp.models.graph_parser +============================ + +.. automodule:: allennlp.models.graph_parser + :members: + :undoc-members: + :show-inheritance: diff --git a/doc/api/allennlp.models.rst b/doc/api/allennlp.models.rst index 8d1c7086175..44b326f269c 100644 --- a/doc/api/allennlp.models.rst +++ b/doc/api/allennlp.models.rst @@ -21,6 +21,7 @@ allennlp.models allennlp.models.ensemble allennlp.models.esim allennlp.models.event2mind + allennlp.models.graph_parser allennlp.models.reading_comprehension allennlp.models.semantic_parsing allennlp.models.semantic_role_labeler diff --git a/training_config/semantic_dependencies.json b/training_config/semantic_dependencies.json new file mode 100644 index 00000000000..51a69456641 --- /dev/null +++ b/training_config/semantic_dependencies.json @@ -0,0 +1,64 @@ +{ + "dataset_reader":{ + "type":"semantic_dependencies" + }, + "train_data_path": "/home/markn/data/semantic_dependency_parsing/semeval2015_data/dm/data/english/english_dm_augmented_train.sdp", + "validation_data_path": "/home/markn/data/semantic_dependency_parsing/semeval2015_data/dm/data/english/english_dm_augmented_dev.sdp", + "test_data_path": "/home/markn/data/semantic_dependency_parsing/semeval2015_data/dm/data/english/english_id_dm_augmented_test.sdp", + "model": { + "type": "graph_parser", + "text_field_embedder": { + "tokens": { + "type": "embedding", + "embedding_dim": 100, + "pretrained_file": "/home/markn/data/glove/glove.6B/glove.6B.100d.txt", + "trainable": true, + "sparse": true + } + }, + "pos_tag_embedding":{ + "embedding_dim": 100, + "vocab_namespace": "pos", + "sparse": true + }, + "encoder": { + "type": "stacked_bidirectional_lstm", + "input_size": 200, + "hidden_size": 400, + "num_layers": 3, + "recurrent_dropout_probability": 0.3, + "use_highway": true + }, + "arc_representation_dim": 500, + "tag_representation_dim": 100, + "dropout": 0.3, + "input_dropout": 0.3, + "initializer": [ + [".*feedforward.*weight", {"type": "xavier_uniform"}], + [".*feedforward.*bias", {"type": "zero"}], + [".*tag_bilinear.*weight", {"type": "xavier_uniform"}], + [".*tag_bilinear.*bias", {"type": "zero"}], + [".*weight_ih.*", {"type": "xavier_uniform"}], + [".*weight_hh.*", {"type": 
"orthogonal"}], + [".*bias_ih.*", {"type": "zero"}], + [".*bias_hh.*", {"type": "lstm_hidden_bias"}]] + }, + + "iterator": { + "type": "bucket", + "sorting_keys": [["tokens", "num_tokens"]], + "batch_size" : 128 + }, + "trainer": { + "num_epochs": 80, + "grad_norm": 5.0, + "patience": 50, + "cuda_device": 0, + "validation_metric": "+LAS", + "optimizer": { + "type": "dense_sparse_adam", + "betas": [0.9, 0.9] + } + } + } +