Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Pengcheng Yin committed Oct 25, 2018
1 parent 753fb47 commit b01b0ee
Show file tree
Hide file tree
Showing 31 changed files with 6,966 additions and 12 deletions.
3 changes: 3 additions & 0 deletions asdl/lang/lambda_dcs/lambda_dcs_transition_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ def get_primitive_field_actions(self, realized_field):
return [GenTokenAction(realized_field.value)]
else:
return []

def is_valid_hypothesis(self, hyp, **kwargs):
    """Accept every hypothesis unconditionally.

    No check is performed on ``hyp`` or on the extra keyword arguments;
    the method always reports the hypothesis as valid.

    Returns:
        bool: always ``True``.
    """
    return True
9 changes: 9 additions & 0 deletions asdl/lang/py/py_transition_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ def get_primitive_field_actions(self, realized_field):
actions.append(GenTokenAction(tok))

return actions

def is_valid_hypothesis(self, hyp, **kwargs):
    """Check whether a hypothesis renders to parseable, tokenizable code.

    The hypothesis tree is converted to surface code with
    ``self.ast_to_surface_code``; that code must then be accepted by
    ``ast.parse`` and by ``self.tokenize_code``.

    Args:
        hyp: hypothesis object exposing a ``tree`` attribute.
        **kwargs: accepted but not used by this implementation.

    Returns:
        bool: ``True`` when all three steps complete without raising,
        ``False`` when any of them raises an ``Exception``.
    """
    try:
        hyp_code = self.ast_to_surface_code(hyp.tree)
        ast.parse(hyp_code)
        self.tokenize_code(hyp_code)
    except Exception:
        # Any ordinary failure marks the hypothesis as invalid.
        # ``Exception`` (rather than a bare ``except``) lets
        # KeyboardInterrupt / SystemExit propagate so that a long
        # decoding run can still be interrupted.
        return False
    return True
6 changes: 6 additions & 0 deletions components/action_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,9 @@ def __init__(self, action=None):
# for GenToken actions only
self.copy_from_src = False
self.src_token_position = -1

def __repr__(self):
    """Return a one-line debug description of this action record.

    The string contains the repr of ``self.action``, the integer fields
    ``self.t`` and ``self.parent_t``, and the frontier field rendered
    via ``frontier_field.__repr__(True)`` (or the text ``None`` when the
    frontier field is falsy).
    """
    action_text = repr(self.action)
    step = self.t
    parent_step = self.parent_t

    if self.frontier_field:
        field_text = self.frontier_field.__repr__(True)
    else:
        field_text = 'None'

    return '%s (t=%d, p_t=%d, frontier_field=%s)' % (
        action_text, step, parent_step, field_text)
66 changes: 66 additions & 0 deletions components/standalone_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import print_function
import argparse
import sys

import torch

from model.utils import get_parser_class
from datasets.utils import get_example_processor_cls


def init_args():
    """Parse the command-line options of the standalone parser.

    Recognized flags: ``--cuda``, ``--load_model``, ``--beam_size`` and
    ``--decode_max_time_step``. Arguments are read from ``sys.argv`` by
    ``argparse``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    cli = argparse.ArgumentParser()

    # ---- general configuration ----
    cli.add_argument('--cuda', action='store_true', default=False, help='Use gpu')

    # ---- decoding / validation / testing ----
    cli.add_argument('--load_model', default=None, type=str,
                     help='Load a pre-trained model')
    cli.add_argument('--beam_size', default=5, type=int,
                     help='Beam size for beam search')
    cli.add_argument('--decode_max_time_step', default=100, type=int,
                     help='Maximum number of time steps used in decoding and sampling')

    return cli.parse_args()


class StandaloneParser(object):
    """
    A tranX parser that can parse raw input issued by an end user. It
    bundles a `Parser` with an `ExampleProcessor`, which makes it
    convenient for demo purposes.
    """

    def __init__(self, dataset_name, model_path, beam_size=5, cuda=False):
        """Load a saved parser and build the matching example processor.

        Args:
            dataset_name: key passed to ``get_example_processor_cls``.
            model_path: path of the saved model file.
            beam_size: beam size forwarded to the parser in ``parse``.
            cuda: forwarded to the parser class' ``load`` method.
        """
        # First pass over the checkpoint only reads the saved arguments;
        # the identity ``map_location`` leaves storages where they were
        # deserialized instead of remapping them to another device.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        saved_args = checkpoint['args']
        print('load parser from [%s]' % model_path, file=sys.stderr)

        parser_cls = get_parser_class(saved_args.lang)
        model = parser_cls.load(model_path, cuda=cuda)
        model.eval()

        self.parser = model
        processor_cls = get_example_processor_cls(dataset_name)
        self.example_processor = processor_cls(model.transition_system)
        self.beam_size = beam_size

    def parse(self, utterance):
        """Parse a raw utterance and return the valid, post-processed hypotheses.

        The stripped utterance is pre-processed by the example processor,
        decoded by the parser with ``self.beam_size``, filtered with the
        transition system's ``is_valid_hypothesis``, and each surviving
        hypothesis is handed to ``post_process_hypothesis``. The token
        list and the number of valid hypotheses are printed to stdout.

        Returns:
            list: hypotheses accepted by ``is_valid_hypothesis``.
        """
        cleaned = utterance.strip()
        tokens, meta = self.example_processor.pre_process_utterance(cleaned)
        print(tokens)

        candidates = self.parser.parse(tokens, beam_size=self.beam_size)
        valid_hypotheses = [
            hyp for hyp in candidates
            if self.parser.transition_system.is_valid_hypothesis(hyp)
        ]
        print(len(valid_hypotheses))

        for hyp in valid_hypotheses:
            self.example_processor.post_process_hypothesis(hyp, meta)

        return valid_hypotheses
Empty file added datasets/__init__.py
Empty file.
Empty file added datasets/atis/__init__.py
Empty file.
Empty file.
15 changes: 15 additions & 0 deletions datasets/atis/data_process/generate_number_word_mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-

from num2words import num2words


def generate_variant(w):
    """Return the spelling variants of a number word.

    The result always starts with ``w`` itself; when ``w`` contains a
    hyphen, a second entry with every hyphen removed is appended.
    """
    if '-' in w:
        return [w, w.replace('-', '')]
    return [w]

import io

# Write one tab-separated line per integer from 1 through 31: the number
# itself, the variants of its cardinal word, then the variants of its
# ordinal word.
#
# ``io.open`` with an explicit encoding and ``range`` keep the script
# runnable on both Python 2 and Python 3; the previous ``xrange`` /
# ``print >>f_out`` form only worked on Python 2.
with io.open('../data/atis/number_word_mapping.txt', 'w', encoding='utf-8') as f_out:
    for i in range(1, 31 + 1):
        columns = [u'%d' % i]
        columns.extend(generate_variant(num2words(i)))
        columns.extend(generate_variant(num2words(i, ordinal=True)))
        f_out.write(u'\t'.join(columns) + u'\n')
Loading

0 comments on commit b01b0ee

Please sign in to comment.