forked from allenai/allennlp
Commit
Graph parser for semantic dependencies (allenai#1743)
Showing 18 changed files with 887 additions and 3 deletions.
114 changes: 114 additions & 0 deletions
allennlp/data/dataset_readers/semantic_dependency_parsing.py
@@ -0,0 +1,114 @@
from typing import Dict, List, Tuple
import logging
from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import AdjacencyField, MetadataField, SequenceLabelField
from allennlp.data.fields import Field, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.instance import Instance

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

FIELDS = ["id", "form", "lemma", "pos", "head", "deprel", "top", "pred", "frame"]


def parse_sentence(sentence_blob: str) -> Tuple[List[Dict[str, str]], List[Tuple[int, int]], List[str]]:
    """
    Parses a chunk of text in the SemEval SDP format.
    Each word in the sentence is returned as a dictionary with the following
    format:

        'id': '1',
        'form': 'Pierre',
        'lemma': 'Pierre',
        'pos': 'NNP',
        'head': '2',        # Note that this is the `syntactic` head.
        'deprel': 'nn',
        'top': '-',
        'pred': '+',
        'frame': 'named:x-c'

    Along with a list of arcs and their corresponding tags. Note that
    in semantic dependency parsing words can have more than one head
    (it is not a tree), meaning that the list of arcs and tags are
    not tied to the length of the sentence.
    """
    annotated_sentence = []
    arc_indices = []
    arc_tags = []
    predicates = []

    lines = [line.split("\t") for line in sentence_blob.split("\n")
             if line and not line.strip().startswith("#")]
    for line_idx, line in enumerate(lines):
        annotated_token = {k: v for k, v in zip(FIELDS, line)}
        if annotated_token['pred'] == "+":
            predicates.append(line_idx)
        annotated_sentence.append(annotated_token)

    for line_idx, line in enumerate(lines):
        for predicate_idx, arg in enumerate(line[len(FIELDS):]):
            if arg != "_":
                arc_indices.append((line_idx, predicates[predicate_idx]))
                arc_tags.append(arg)
    return annotated_sentence, arc_indices, arc_tags
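
As a hedged illustration of the column layout (the sentence, tags, and frame value below are invented, not drawn from the SemEval data), a blob with one predicate would parse as follows:

    blob = ("#20001001\n"  # comment line, skipped by parse_sentence
            "1\tDogs\tdog\tNNS\t2\tnsubj\t-\t-\t_\tARG1\n"
            "2\tbark\tbark\tVBP\t0\troot\t+\t+\tv:e-i\t_")
    annotated, arcs, tags = parse_sentence(blob)
    # Each token row carries the nine FIELDS plus one argument column per
    # predicate in the sentence. Only "bark" has pred == "+", so
    # arcs == [(0, 1)] and tags == ["ARG1"]: each tuple pairs an argument
    # token with the predicate it attaches to.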


def lazy_parse(text: str):
    for sentence in text.split("\n\n"):
        if sentence:
            yield parse_sentence(sentence)


@DatasetReader.register("semantic_dependencies")
class SemanticDependenciesDatasetReader(DatasetReader):
    """
    Reads a file in the SemEval 2015 Task 18 (Broad-coverage Semantic Dependency Parsing)
    format.

    Parameters
    ----------
    token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
        The token indexers to be applied to the words TextField.
    """
    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False) -> None:
        super().__init__(lazy)
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        logger.info("Reading semantic dependency parsing data from: %s", file_path)

        with open(file_path) as sdp_file:
            for annotated_sentence, directed_arc_indices, arc_tags in lazy_parse(sdp_file.read()):
                # If there are no arc indices, skip this instance.
                if not directed_arc_indices:
                    continue
                tokens = [word["form"] for word in annotated_sentence]
                pos_tags = [word["pos"] for word in annotated_sentence]
                yield self.text_to_instance(tokens, pos_tags, directed_arc_indices, arc_tags)

    @overrides
    def text_to_instance(self,  # type: ignore
                         tokens: List[str],
                         pos_tags: List[str] = None,
                         arc_indices: List[Tuple[int, int]] = None,
                         arc_tags: List[str] = None) -> Instance:
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}
        token_field = TextField([Token(t) for t in tokens], self._token_indexers)
        fields["tokens"] = token_field
        fields["metadata"] = MetadataField({"tokens": tokens})
        if pos_tags is not None:
            fields["pos_tags"] = SequenceLabelField(pos_tags, token_field, label_namespace="pos")
        if arc_indices is not None and arc_tags is not None:
            fields["arc_tags"] = AdjacencyField(arc_indices, token_field, arc_tags)

        return Instance(fields)
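
A hedged usage sketch for the reader (the file path is hypothetical; the ``semantic_dependencies`` type name comes from the ``register`` call above):

    reader = SemanticDependenciesDatasetReader()
    for instance in reader.read("path/to/train.sdp"):  # hypothetical path
        # Each instance carries "tokens", "metadata", "pos_tags" and
        # "arc_tags" fields; sentences with no semantic arcs are skipped.
        print(instance.fields["arc_tags"])

In a JSON experiment configuration, the same reader would be selected with ``"dataset_reader": {"type": "semantic_dependencies"}``.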
129 changes: 129 additions & 0 deletions
allennlp/data/fields/adjacency_field.py
@@ -0,0 +1,129 @@
from typing import Dict, List, Set, Tuple
import logging
import textwrap

from overrides import overrides
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.data.fields.field import Field
from allennlp.data.fields.sequence_field import SequenceField
from allennlp.data.vocabulary import Vocabulary

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class AdjacencyField(Field[torch.Tensor]):
    """
    An ``AdjacencyField`` defines directed adjacency relations between elements
    in a :class:`~allennlp.data.fields.sequence_field.SequenceField`.
    Because it's a labeling of some other field, we take that field as input here
    and use it to determine our padding and other things.

    This field will get converted into an array of shape
    (sequence_field_length, sequence_field_length), where the (i, j)th array element
    is either a binary flag indicating there is an edge from i to j, or an integer
    label k, indicating there is a label from i to j of type k.

    Parameters
    ----------
    indices : ``List[Tuple[int, int]]``
    sequence_field : ``SequenceField``
        A field containing the sequence that this ``AdjacencyField`` is labeling.  Most often,
        this is a ``TextField``, for tagging edge relations between tokens in a sentence.
    labels : ``List[str]``, optional (default = ``None``)
        Optional labels for the edges of the adjacency matrix.
    label_namespace : ``str``, optional (default = ``'labels'``)
        The namespace to use for converting tag strings into integers.  We convert tag strings to
        integers for you, and this parameter tells the ``Vocabulary`` object which mapping from
        strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
    padding_value : ``int``, optional (default = ``-1``)
        The value to use as padding.
    """
    # It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
    # This warning would be repeated for every instantiation of this class (i.e. for every data
    # instance), spewing a lot of warnings, so this class variable is used to only log a single
    # warning per namespace.
    _already_warned_namespaces: Set[str] = set()

    def __init__(self,
                 indices: List[Tuple[int, int]],
                 sequence_field: SequenceField,
                 labels: List[str] = None,
                 label_namespace: str = 'labels',
                 padding_value: int = -1) -> None:
        self.indices = indices
        self.labels = labels
        self.sequence_field = sequence_field
        self._label_namespace = label_namespace
        self._padding_value = padding_value
        self._indexed_labels: List[int] = None

        self._maybe_warn_for_namespace(label_namespace)
        field_length = sequence_field.sequence_length()

        if len(set(indices)) != len(indices):
            raise ConfigurationError(f"Indices must be unique, but found {indices}")

        if not all([0 <= index[1] < field_length and 0 <= index[0] < field_length for index in indices]):
            raise ConfigurationError(f"Label indices and sequence length "
                                     f"are incompatible: {indices} and {field_length}")

        if labels is not None and len(indices) != len(labels):
            raise ConfigurationError(f"Labelled indices were passed, but their lengths do not match: "
                                     f" {labels}, {indices}")

    def _maybe_warn_for_namespace(self, label_namespace: str) -> None:
        if not (self._label_namespace.endswith("labels") or self._label_namespace.endswith("tags")):
            if label_namespace not in self._already_warned_namespaces:
                logger.warning("Your label namespace was '%s'. We recommend you use a namespace "
                               "ending with 'labels' or 'tags', so we don't add UNK and PAD tokens by "
                               "default to your vocabulary.  See documentation for "
                               "`non_padded_namespaces` parameter in Vocabulary.",
                               self._label_namespace)
                self._already_warned_namespaces.add(label_namespace)

    @overrides
    def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
        if self._indexed_labels is None and self.labels is not None:
            for label in self.labels:
                counter[self._label_namespace][label] += 1  # type: ignore

    @overrides
    def index(self, vocab: Vocabulary):
        if self._indexed_labels is None and self.labels is not None:
            self._indexed_labels = [vocab.get_token_index(label, self._label_namespace)
                                    for label in self.labels]

    @overrides
    def get_padding_lengths(self) -> Dict[str, int]:
        return {'num_tokens': self.sequence_field.sequence_length()}

    @overrides
    def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
        desired_num_tokens = padding_lengths['num_tokens']
        tensor = torch.ones(desired_num_tokens, desired_num_tokens) * self._padding_value
        labels = self._indexed_labels or [1 for _ in range(len(self.indices))]

        for index, label in zip(self.indices, labels):
            tensor[index] = label
        return tensor
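
    # Editorial sketch, not part of the commit itself: for a hypothetical
    # 3-token sequence with indices [(0, 1), (2, 1)] and indexed labels
    # [0, 1], the default padding_value of -1 would give
    #
    #     [[-1.,  0., -1.],
    #      [-1., -1., -1.],
    #      [-1.,  1., -1.]]
    #
    # i.e. tensor[i, j] holds the label id of arc (i, j) and the padding
    # value elsewhere; with no labels, present arcs are simply marked with 1.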

    @overrides
    def empty_field(self) -> 'AdjacencyField':
        # pylint: disable=protected-access
        # The empty_list here is needed for mypy.
        empty_list: List[Tuple[int, int]] = []
        adjacency_field = AdjacencyField(empty_list,
                                         self.sequence_field.empty_field(),
                                         padding_value=self._padding_value)
        return adjacency_field

    def __str__(self) -> str:
        length = self.sequence_field.sequence_length()
        formatted_labels = "".join(["\t\t" + labels + "\n"
                                    for labels in textwrap.wrap(repr(self.labels), 100)])
        formatted_indices = "".join(["\t\t" + index + "\n"
                                     for index in textwrap.wrap(repr(self.indices), 100)])
        return f"AdjacencyField of length {length}\n" \
               f"\t\twith indices:\n {formatted_indices}\n" \
               f"\t\tand labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'."