forked from pcyin/tranX
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Pengcheng Yin
committed
Oct 25, 2018
1 parent
753fb47
commit b01b0ee
Showing
31 changed files
with
6,966 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import print_function | ||
import argparse | ||
import sys | ||
|
||
import torch | ||
|
||
from model.utils import get_parser_class | ||
from datasets.utils import get_example_processor_cls | ||
|
||
|
||
def init_args(argv=None):
    """Define and parse the command-line options of the standalone parser.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from the process command
            line by ``argparse``, which is the original behavior.

    Returns:
        The parsed ``argparse.Namespace`` with the attributes ``cuda``,
        ``load_model``, ``beam_size`` and ``decode_max_time_step``.
    """
    arg_parser = argparse.ArgumentParser()

    #### General configuration ####
    arg_parser.add_argument('--cuda', action='store_true', default=False, help='Use gpu')

    #### decoding/validation/testing ####
    arg_parser.add_argument('--load_model', default=None, type=str, help='Load a pre-trained model')
    arg_parser.add_argument('--beam_size', default=5, type=int, help='Beam size for beam search')
    arg_parser.add_argument('--decode_max_time_step', default=100, type=int, help='Maximum number of time steps used '
                                                                                   'in decoding and sampling')

    # Forwarding `argv` lets callers (and tests) parse an explicit argument
    # list instead of depending on the global command line.
    args = arg_parser.parse_args(argv)

    return args
|
||
|
||
class StandaloneParser(object):
    """
    A tranX parser that can parse raw input issued by an end user. It is a
    bundle of a `Parser` and an `ExampleProcessor`, which makes it useful for
    demo purposes.
    """

    def __init__(self, dataset_name, model_path, beam_size=5, cuda=False):
        """Load a saved parser and build the matching example processor.

        Args:
            dataset_name: Name passed to `get_example_processor_cls` to select
                the example processor class.
            model_path: Path of the saved model file.
            beam_size: Beam size forwarded to the parser in `parse`.
            cuda: Forwarded to the parser class's `load` method.
        """
        # Only the saved `args` entry is needed here (its `lang` selects the
        # parser class); the identity `map_location` leaves the loaded
        # storages where they were deserialized instead of moving them.
        # NOTE(review): torch.load unpickles the file -- only point this at
        # model files from a trusted source.
        parser_saved_args = torch.load(model_path,
                                       map_location=lambda storage, loc: storage)['args']
        print('load parser from [%s]' % model_path, file=sys.stderr)
        parser = get_parser_class(parser_saved_args.lang).load(model_path, cuda=cuda)
        parser.eval()

        self.parser = parser
        self.example_processor = get_example_processor_cls(dataset_name)(parser.transition_system)
        self.beam_size = beam_size

    def parse(self, utterance):
        """Parse a raw utterance and return its valid hypotheses.

        The utterance is stripped, pre-processed by the example processor and
        decoded by the parser with `self.beam_size`. Hypotheses rejected by
        the transition system's `is_valid_hypothesis` are dropped, and each
        remaining one is passed to `post_process_hypothesis` together with the
        utterance metadata.

        Args:
            utterance: The raw input string.

        Returns:
            The list of valid hypotheses, in the order the parser produced
            them.
        """
        utterance = utterance.strip()
        processed_utterance_tokens, utterance_meta = self.example_processor.pre_process_utterance(utterance)
        # Diagnostics go to stderr, like the load message in __init__, so that
        # stdout stays clean for programs embedding this parser.
        print(processed_utterance_tokens, file=sys.stderr)
        hypotheses = self.parser.parse(processed_utterance_tokens, beam_size=self.beam_size)

        transition_system = self.parser.transition_system
        valid_hypotheses = [hyp for hyp in hypotheses if transition_system.is_valid_hypothesis(hyp)]
        print(len(valid_hypotheses), file=sys.stderr)

        for hyp in valid_hypotheses:
            self.example_processor.post_process_hypothesis(hyp, utterance_meta)

        return valid_hypotheses
Empty file.
Empty file.
Empty file.
15 changes: 15 additions & 0 deletions
15
datasets/atis/data_process/generate_number_word_mapping.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
from num2words import num2words | ||
|
||
|
||
def generate_variant(w):
    """Return the spelling variants of the word `w`.

    The result always starts with `w` itself. If `w` contains a hyphen, a
    second entry with every hyphen removed is appended
    (e.g. ``'twenty-one'`` -> ``['twenty-one', 'twentyone']``).
    """
    r_list = [w]
    if '-' in w:
        r_list.append(w.replace('-', ''))
    return r_list
|
||
# Write one tab-separated row for each integer from 1 to 31: the digit string,
# followed by the variants of its cardinal word and of its ordinal word.
# The file is opened in binary mode and every row is UTF-8 encoded explicitly,
# so the script runs unchanged on both Python 2 and Python 3 (the previous
# `xrange` / `print >>f_out` form only worked on Python 2).
with open('../data/atis/number_word_mapping.txt', 'wb') as f_out:
    for i in range(1, 31 + 1):
        columns = [str(i)]
        columns.extend(generate_variant(num2words(i)))
        columns.extend(generate_variant(num2words(i, ordinal=True)))
        f_out.write(u'\t'.join(columns).encode('utf-8') + b'\n')
Oops, something went wrong.