Skip to content

Commit

Permalink
Track dev loss in ATIS model (allenai#1907)
Browse files Browse the repository at this point in the history
* get dev loss

* add test
  • Loading branch information
kl2806 authored Oct 15, 2018
1 parent c450565 commit d3a8f4f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
7 changes: 5 additions & 2 deletions allennlp/data/dataset_readers/semantic_parsing/atis.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def text_to_instance(self, # type: ignore
try:
action_sequence = world.get_action_sequence(sql_query)
except ParseError:
action_sequence = []
logger.debug(f'Parsing error')

tokenized_utterance = self._tokenizer.tokenize(utterance.lower())
Expand All @@ -159,12 +160,14 @@ def text_to_instance(self, # type: ignore

if sql_query_labels != None:
fields['sql_queries'] = MetadataField(sql_query_labels)
if action_sequence and not self._keep_if_unparseable:
if self._keep_if_unparseable or action_sequence:
for production_rule in action_sequence:
index_fields.append(IndexField(action_map[production_rule], action_field))
if not action_sequence:
index_fields = [IndexField(-1, action_field)]
action_sequence_field = ListField(index_fields)
fields['target_action_sequence'] = action_sequence_field
elif not self._keep_if_unparseable:
else:
# If we are given a SQL query, but we are unable to parse it, and we do not specify explicitly
# to keep it, then we will skip it.
return None
Expand Down
11 changes: 11 additions & 0 deletions allennlp/tests/data/dataset_readers/semantic_parsing/atis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from allennlp.semparse.worlds import AtisWorld

class TestAtisReader(AllenNlpTestCase):
def test_atis_keep_unparseable(self):
    """An unparseable SQL label is kept as a one-element padded action sequence
    when the reader is built with ``keep_if_unparseable=True``."""
    db_path = cached_path("https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db")
    atis_reader = AtisDatasetReader(database_file=db_path, keep_if_unparseable=True)

    utterance = 'show me the one way flights from detroit me to westchester county'
    unparseable_query = 'this is not a query that can be parsed'
    parsed_instance = atis_reader.text_to_instance(utterances=[utterance],
                                                   sql_query_labels=[unparseable_query])

    # For a query that can't be parsed, the list of index fields should hold exactly one
    # element, and that element should be the padding index, -1.
    action_index_fields = parsed_instance.fields['target_action_sequence'].field_list
    assert len(action_index_fields) == 1
    assert action_index_fields[0].sequence_index == -1

def test_atis_read_from_file(self):
data_path = AllenNlpTestCase.FIXTURES_ROOT / "data" / "atis" / "sample.json"
database_file = cached_path("https://s3-us-west-2.amazonaws.com/allennlp/datasets/atis/atis.db")
Expand Down

0 comments on commit d3a8f4f

Please sign in to comment.