From d3a8f4f699c1af672ac555ba0df12cbc62369808 Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Mon, 15 Oct 2018 13:01:34 -0700 Subject: [PATCH] Track dev loss in ATIS model (#1907) * get dev loss * add test --- .../data/dataset_readers/semantic_parsing/atis.py | 7 +++++-- .../dataset_readers/semantic_parsing/atis_test.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/allennlp/data/dataset_readers/semantic_parsing/atis.py b/allennlp/data/dataset_readers/semantic_parsing/atis.py index 5d321b9f09e..57c89d0e0a3 100644 --- a/allennlp/data/dataset_readers/semantic_parsing/atis.py +++ b/allennlp/data/dataset_readers/semantic_parsing/atis.py @@ -133,6 +133,7 @@ def text_to_instance(self, # type: ignore try: action_sequence = world.get_action_sequence(sql_query) except ParseError: + action_sequence = [] logger.debug(f'Parsing error') tokenized_utterance = self._tokenizer.tokenize(utterance.lower()) @@ -159,12 +160,14 @@ def text_to_instance(self, # type: ignore if sql_query_labels != None: fields['sql_queries'] = MetadataField(sql_query_labels) - if action_sequence and not self._keep_if_unparseable: + if self._keep_if_unparseable or action_sequence: for production_rule in action_sequence: index_fields.append(IndexField(action_map[production_rule], action_field)) + if not action_sequence: + index_fields = [IndexField(-1, action_field)] action_sequence_field = ListField(index_fields) fields['target_action_sequence'] = action_sequence_field - elif not self._keep_if_unparseable: + else: # If we are given a SQL query, but we are unable to parse it, and we do not specify explicitly # to keep it, then we will skip the it. return None diff --git a/allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py b/allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py index f7fbdd0928e..9a13936c187 100644 --- a/allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py +++ b/allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py @@ -6,6 +6,17 @@ from allennlp.semparse.worlds import AtisWorld class TestAtisReader(AllenNlpTestCase): + def test_atis_keep_unparseable(self): + database_file = cached_path("https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db") + reader = AtisDatasetReader(database_file=database_file, keep_if_unparseable=True) + instance = reader.text_to_instance(utterances=['show me the one way flights from detroit me to westchester county'], + sql_query_labels=['this is not a query that can be parsed']) + + # If we have a query that can't be parsed, we check that it only has one element in the list of index fields and + # that index is the padding index, -1. + assert len(instance.fields['target_action_sequence'].field_list) == 1 + assert instance.fields['target_action_sequence'].field_list[0].sequence_index == -1 + def test_atis_read_from_file(self): data_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "atis" / "sample.json" database_file = cached_path("https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db")