Skip to content

Commit

Permalink
Add QuaRel semantic parser (allenai#1857)
Browse files Browse the repository at this point in the history
  • Loading branch information
OyvindTafjord authored Oct 4, 2018
1 parent 8236624 commit 8ff8324
Show file tree
Hide file tree
Showing 41 changed files with 3,075 additions and 5 deletions.
4 changes: 3 additions & 1 deletion allennlp/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,11 @@ def train_model(params: Params,
if key in datasets_for_vocab_creation)
)

model = Model.from_params(vocab=vocab, params=params.pop('model'))

# Initializing the model can have side effect of expanding the vocabulary
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

model = Model.from_params(vocab=vocab, params=params.pop('model'))
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)
validation_iterator_params = params.pop("validation_iterator", None)
Expand Down
1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@
from allennlp.data.dataset_readers.quora_paraphrase import QuoraParaphraseDatasetReader
from allennlp.data.dataset_readers.semantic_parsing import (
WikiTablesDatasetReader, AtisDatasetReader, NlvrDatasetReader, TemplateText2SqlDatasetReader)
from allennlp.data.dataset_readers.semantic_parsing.quarel import QuarelDatasetReader
508 changes: 508 additions & 0 deletions allennlp/data/dataset_readers/semantic_parsing/quarel.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions allennlp/data/fields/knowledge_graph_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,13 @@ def _span_overlap_fraction(self,
# Some tables have empty cells.
return 0
seen_entity_words = set()
token_index_left = token_index
while token_index < len(tokens) and tokens[token_index].text in entity_words:
seen_entity_words.add(tokens[token_index].text)
token_index += 1
while token_index_left >= 0 and tokens[token_index_left].text in entity_words:
seen_entity_words.add(tokens[token_index_left].text)
token_index_left -= 1
return len(seen_entity_words) / len(entity_words)

def _span_lemma_overlap_fraction(self,
Expand All @@ -415,9 +419,13 @@ def _span_lemma_overlap_fraction(self,
# Some tables have empty cells.
return 0
seen_entity_lemmas = set()
token_index_left = token_index
while token_index < len(tokens) and tokens[token_index].lemma_ in entity_lemmas:
seen_entity_lemmas.add(tokens[token_index].lemma_)
token_index += 1
while token_index_left >= 0 and tokens[token_index_left].lemma_ in entity_lemmas:
seen_entity_lemmas.add(tokens[token_index_left].lemma_)
token_index_left -= 1
return len(seen_entity_lemmas) / len(entity_lemmas)

# pylint: enable=unused-argument,no-self-use
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from allennlp.models.reading_comprehension.bidaf import BidirectionalAttentionFlow
from allennlp.models.semantic_parsing.nlvr.nlvr_coverage_semantic_parser import NlvrCoverageSemanticParser
from allennlp.models.semantic_parsing.nlvr.nlvr_direct_semantic_parser import NlvrDirectSemanticParser
from allennlp.models.semantic_parsing.quarel.quarel_semantic_parser import QuarelSemanticParser
from allennlp.models.semantic_parsing.wikitables.wikitables_mml_semantic_parser import WikiTablesMmlSemanticParser
from allennlp.models.semantic_parsing.wikitables.wikitables_erm_semantic_parser import WikiTablesErmSemanticParser
from allennlp.models.semantic_parsing.atis.atis_semantic_parser import AtisSemanticParser
Expand Down
Empty file.
749 changes: 749 additions & 0 deletions allennlp/models/semantic_parsing/quarel/quarel_semantic_parser.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions allennlp/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from allennlp.predictors.event2mind import Event2MindPredictor
from allennlp.predictors.nlvr_parser import NlvrParserPredictor
from allennlp.predictors.open_information_extraction import OpenIePredictor
from allennlp.predictors.quarel_parser import QuarelParserPredictor
from allennlp.predictors.semantic_role_labeler import SemanticRoleLabelerPredictor
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor
from allennlp.predictors.simple_seq2seq import SimpleSeq2SeqPredictor
Expand Down
2 changes: 1 addition & 1 deletion allennlp/predictors/open_information_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from allennlp.data import DatasetReader, Instance
from allennlp.data.tokenizers import WordTokenizer
from allennlp.models import Model
from allennlp.service.predictors.predictor import Predictor
from allennlp.predictors.predictor import Predictor
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.tokenizers import Token

Expand Down
3 changes: 2 additions & 1 deletion allennlp/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
'dialog_qa': 'dialog_qa',
'event2mind': 'event2mind',
'simple_tagger': 'sentence-tagger',
'srl': 'semantic-role-labeling'
'srl': 'semantic-role-labeling',
'quarel_parser': 'quarel-parser'
}

class Predictor(Registrable):
Expand Down
100 changes: 100 additions & 0 deletions allennlp/predictors/quarel_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import cast, Tuple

from overrides import overrides

from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance
from allennlp.data.dataset_readers.semantic_parsing.quarel import QuarelDatasetReader
from allennlp.predictors.predictor import Predictor
from allennlp.semparse.contexts.quarel_utils import get_explanation, from_qr_spec_string
from allennlp.semparse.contexts.quarel_utils import words_from_entity_string, from_entity_cues_string


@Predictor.register('quarel-parser')
class QuarelParserPredictor(Predictor):
"""
Wrapper for the quarel_semantic_parser model.
"""
def _my_json_to_instance(self, json_dict: JsonDict) -> Tuple[Instance, JsonDict]:
"""
"""

# Make a cast here to satisfy mypy
dataset_reader = cast(QuarelDatasetReader, self._dataset_reader)

# TODO: Fix protected access usage
question_data = dataset_reader.preprocess(json_dict, predict=True)[0]

qr_spec_override = None
dynamic_entities = None
if 'entitycues' in json_dict:
entity_cues = from_entity_cues_string(json_dict['entitycues'])
dynamic_entities = dataset_reader._dynamic_entities.copy() # pylint: disable=protected-access
for entity, cues in entity_cues.items():
key = "a:" + entity
entity_strings = [words_from_entity_string(entity).lower()]
entity_strings += cues
dynamic_entities[key] = " ".join(entity_strings)

if 'qrspec' in json_dict:
qr_spec_override = from_qr_spec_string(json_dict['qrspec'])
old_entities = dynamic_entities
if old_entities is None:
old_entities = dataset_reader._dynamic_entities.copy() # pylint: disable=protected-access
dynamic_entities = {}
for qset in qr_spec_override:
for entity in qset:
key = "a:" + entity
value = old_entities.get(key, words_from_entity_string(entity).lower())
dynamic_entities[key] = value

question = question_data['question']
tokenized_question = dataset_reader._tokenizer.tokenize(question.lower()) # pylint: disable=protected-access
world_extractions = question_data.get('world_extractions')

instance = dataset_reader.text_to_instance(question,
world_extractions=world_extractions,
qr_spec_override=qr_spec_override,
dynamic_entities_override=dynamic_entities)

world_extractions_out = {"world1": "N/A", "world2": "N/A"}
if world_extractions is not None:
world_extractions_out.update(world_extractions)

extra_info = {'question': json_dict['question'],
'question_tokens': tokenized_question,
"world_extractions": world_extractions_out}
return instance, extra_info

@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
instance, _ = self._my_json_to_instance(json_dict)
return instance

@overrides
def predict_json(self, inputs: JsonDict) -> JsonDict:
instance, return_dict = self._my_json_to_instance(inputs)
world = instance.fields['world'].metadata # type: ignore
outputs = self._model.forward_on_instance(instance)

answer_index = outputs['answer_index']
if answer_index == 0:
answer = "A"
elif answer_index == 1:
answer = "B"
else:
answer = "None"
outputs['answer'] = answer

return_dict.update(outputs)

if answer != "None":
explanation = get_explanation(return_dict['logical_form'],
return_dict['world_extractions'],
answer_index,
world)
else:
explanation = [{"header": "No consistent interpretation found!", "content": []}]

return_dict['explanation'] = explanation
return sanitize(return_dict)
Loading

0 comments on commit 8ff8324

Please sign in to comment.