revised web demo and data pre-processing pipeline
Pengcheng Yin committed Mar 2, 2019
1 parent 41bd69e commit a895df8
Showing 18 changed files with 167 additions and 1,127 deletions.
2 changes: 1 addition & 1 deletion asdl/lang/prolog/prolog_transition_system.py
@@ -241,7 +241,7 @@ def is_equal_ast(this_ast, other_ast):
@Registrable.register('prolog')
class PrologTransitionSystem(TransitionSystem):
def compare_ast(self, hyp_ast, ref_ast):
- raise NotImplementedError
+ return is_equal_ast(hyp_ast, ref_ast)

def ast_to_surface_code(self, asdl_ast):
return ast_to_prolog_expr(asdl_ast)
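With this change, `compare_ast` defers to the module-level `is_equal_ast` helper instead of raising `NotImplementedError`, so Prolog hypotheses can be checked for structural AST equality like the other languages. A rough, self-contained sketch of that delegation (the class and equality check below are simplified stand-ins, not the repository's actual ASDL types):

```
# Simplified stand-ins; only the delegation from compare_ast to is_equal_ast is illustrated.
def is_equal_ast(this_ast, other_ast):
    return this_ast == other_ast  # placeholder; the real helper compares ASDL trees

class PrologTransitionSystemSketch:
    def compare_ast(self, hyp_ast, ref_ast):
        # previously: raise NotImplementedError
        return is_equal_ast(hyp_ast, ref_ast)

print(PrologTransitionSystemSketch().compare_ast(('answer', 'A'), ('answer', 'A')))  # True
```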
1 change: 1 addition & 0 deletions components/__init__.py
@@ -5,3 +5,4 @@
if six.PY3:
from datasets.conala.evaluator import ConalaEvaluator
from datasets.wikisql.evaluator import WikiSQLEvaluator

10 changes: 9 additions & 1 deletion components/standalone_parser.py
@@ -1,11 +1,19 @@
from __future__ import print_function
import argparse
import sys

import six
import torch
from model import parser

from common.registerable import Registrable

from datasets.geo.example_processor import GeoQueryExampleProcessor
from datasets.atis.example_processor import ATISExampleProcessor
from datasets.django.example_processor import DjangoExampleProcessor

if six.PY3:
from datasets.conala.example_processor import ConalaExampleProcessor


class StandaloneParser(object):
"""
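The imports shown above wire the per-dataset example processors into the standalone parser, with the Conala processor guarded by a Python 3 check, mirroring `components/__init__.py`. A minimal sketch of that `six.PY3` guarded-import idiom, using a hypothetical stand-in module so it runs outside the repository:

```
# Illustration of the guarded import used above (not part of the commit).
# `optional_processor` is a hypothetical stand-in for a Python-3-only module
# such as datasets.conala.example_processor.
import six

if six.PY3:
    import json as optional_processor  # stand-in for the real Python-3-only import
else:
    optional_processor = None  # the feature is simply unavailable under Python 2

print('optional processor available:', optional_processor is not None)
```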
6 changes: 6 additions & 0 deletions datasets/atis/data_process/README.md
@@ -0,0 +1,6 @@
This data pre-processing script is authored by [Li Dong](http://homepages.inf.ed.ac.uk/s1478528/)

```
Li Dong and Mirella Lapata
Language to Logical Form with Neural Attention, ACL 2016
```
15 changes: 0 additions & 15 deletions datasets/atis/data_process/generate_number_word_mapping.py

This file was deleted.

143 changes: 8 additions & 135 deletions datasets/atis/data_process/process_atis.py
@@ -7,8 +7,10 @@
from .utils import *
from .utils_date_number import *


e2m_dict, m2e_dict, e2type_dict = read_entity_mention_mapping()


def q_process(_q):
is_successful = True
q = _q.strip().lower()
@@ -20,14 +22,14 @@ def q_process(_q):
const_index_dict = {}
type_index_dict = {}
while True:
- q_list = filter(lambda x: len(x) > 0, ' '.join(map(lambda x: x, q.split(' '))).split(' '))
+ q_list = list(filter(lambda x: len(x) > 0, ' '.join(map(lambda x: x, q.split(' '))).split(' ')))
found_flag = False
for n in range(7, 0, -1):
if len(q_list) >= n:
for i in range(0, len(q_list) - n + 1):
m = ' '.join(q_list[i:i + n])
should_replace_flag = True
- if m2e_dict.has_key(m):
+ if m in m2e_dict:
e = m2e_dict[m]
t = e2type_dict[e]
if (t == 'st' or is_state_token(m)) and (i > 0) and is_city_token(q_list[i - 1]):
@@ -47,10 +49,10 @@
else:
should_replace_flag = False
if should_replace_flag:
- if not const_index_dict.has_key(e):
+ if e not in const_index_dict:
type_index_dict[t] = type_index_dict.get(t, -1) + 1
const_index_dict[e] = type_index_dict[t]
- if not e2type_dict.has_key(e):
+ if e not in e2type_dict:
e2type_dict[e] = t
q = q.replace(' %s ' % (m,), ' %s%d ' % (t, const_index_dict[e]))
found_flag = True
@@ -62,137 +64,8 @@

q = add_padding(fix_missing_link(q, const_index_dict, e2type_dict))

- q_list = filter(lambda x: len(x) > 0, ' '.join(
-     map(lambda x: norm_word(x), q.split(' '))).split(' '))
+ q_list = list(filter(lambda x: len(x) > 0, ' '.join(
+     map(lambda x: norm_word(x), q.split(' '))).split(' ')))

return q_list, const_index_dict, type_index_dict


def qf_process(_q, _f, split):
is_successful = True
q, f = _q, _f
# tokenize q
q = add_padding(' '.join(norm_airline(norm_daynumber(
norm_q_time(norm_dollar(q.split(' ')))))))
# tokenize f
f = add_padding(' '.join([norm_form(it) for it in filter(lambda x: len(x) > 0, f.replace(
'(', ' ( ').replace(')', ' ) ').strip().split(' '))]))

if split == 'test':
f = norm_lambda_variable(f)
f = add_padding(fix_form_type_entity_mismatch(f))

# find entities in q, and replace them with type_id
const_index_dict = {}
type_index_dict = {}
while True:
q_list = filter(lambda x: len(x) > 0, ' '.join(map(lambda x: x, q.split(' '))).split(' '))
found_flag=False
for n in range(7, 0, -1):
if len(q_list) >= n:
for i in range(0, len(q_list)-n+1):
m = ' '.join(q_list[i:i+n])
should_replace_flag=True
if m2e_dict.has_key(m):
e = m2e_dict[m]
t = e2type_dict[e]
if (t == 'st' or is_state_token(m)) and (i > 0) and is_city_token(q_list[i-1]):
should_replace_flag = False
elif (n == 1) and (m == 'may') and (i+1 < len(q_list)) and (q_list[i+1] == 'i'):
should_replace_flag = False
elif n == 1:
if m.startswith('$'):
e,t = m[1:]+':do','do'
elif m.startswith('_'):
t, _e = m[1:].split('_')
e = '%s:%s' % (_e,t)
elif is_normalized_time_mention_str(m):
e,t = convert_time_m2e(m),'ti'
else:
should_replace_flag = False
else:
should_replace_flag = False
if should_replace_flag:
if not const_index_dict.has_key(e):
type_index_dict[t] = type_index_dict.get(t, -1) + 1
const_index_dict[e] = type_index_dict[t]
if not e2type_dict.has_key(e):
e2type_dict[e] = t
q = q.replace(' %s ' % (m,), ' %s%d ' % (t, const_index_dict[e]))
found_flag=True
break
if found_flag:
break
if not found_flag:
break

q = add_padding(fix_missing_link(q, const_index_dict, e2type_dict))

# replace const entity with ``type_id''
for e_name, e_type in sort_entity_list(entity_re.findall(f.replace(' ', ' '))):
const_str='%s:%s' % (e_name, e_type)
if const_index_dict.has_key(const_str):
f = f.replace(' %s ' % (const_str,), ' %s%d ' % (e2type_dict[const_str], const_index_dict[const_str]))
else:
if e_type in set(('ci','ap','ti','rc','mn','dn')):
is_successful = False

q_list = filter(lambda x: len(x) > 0, ' '.join(
map(lambda x: norm_word(x), q.split(' '))).split(' '))
f_list = filter(lambda x: len(x) > 0, f.strip().split(' '))

return (q_list, f_list, const_index_dict, type_index_dict, is_successful)


def process_main(d, split):
with open('data/atis/%s.raw' % (split,), 'r') as f_in:
raw_list = filter(lambda x: len(x) > 0, map(
lambda x: x.strip(), f_in.read().decode('utf-8').split('\n')))
qf_list = [(raw_list[2 * i], raw_list[2 * i + 1])
for i in xrange(len(raw_list) / 2)]
l_list = []
for q, f in qf_list:
q_list, f_list, is_successful = qf_process(q, f, split)
if ((split != 'test') and is_successful) or (split == 'test'):
q_orig, f_orig, _ = stat_atis.qf_process(q, f, split)
q_orig_length, f_orig_length = len(q_orig), len(f_orig)
l_list.append((q_list, f_list, q_orig_length, f_orig_length))
l_list.sort(key=lambda x: len(x[0]))
with open(d + '%s.txt' % (split,), 'w') as f_out:
f_out.write('\n'.join(map(lambda x: '%s\t%s' %
(' '.join(x[0]), ' '.join(x[1])), l_list)).encode('utf-8'))
print('maximun length:', max([len(x[0]) for x in l_list]), max([len(x[1]) for x in l_list]))

def vocab_main(d):
cq, cf = {}, {}
l_list = []
with open(d + 'train.txt', 'r') as f_in:
l_list.extend(filter(lambda x: len(x) > 0, map(
lambda x: x.strip(), f_in.read().decode('utf-8').split('\n'))))
for l in l_list:
q, f = l.decode('utf-8').strip().split('\t')
for it in q.split(' '):
cq[it] = cq.get(it, 0) + 1
for it in f.split(' '):
cf[it] = cf.get(it, 0) + 1
with open(d + 'vocab.q.txt', 'w') as f_out:
for w, c in sorted([(k, v) for k, v in cq.iteritems()], key=lambda x: x[1], reverse=True):
print(('%s\t%d' % (w, c)).encode('utf-8'), file=f_out)
with open(d + 'vocab.f.txt', 'w') as f_out:
for w, c in sorted([(k, v) for k, v in cf.iteritems()], key=lambda x: x[1], reverse=True):
print(('%s\t%d' % (w, c)).encode('utf-8'), file=f_out)


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-s", "--split", help="official split",
default=1)
parser.add_argument("-d", "--data_dir", help="input folder",
default="/disk/scratch_ssd/lidong/deep_qa/atis/")
args = parser.parse_args()

for split in ('train', 'dev', 'test'):
print(split, ':')
process_main(args.data_dir, split)

vocab_main(args.data_dir)
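The edits that survive in `q_process` are Python 3 compatibility fixes: `filter()` returns a lazy iterator in Python 3 rather than a list, and `dict.has_key()` was removed, so membership is tested with `in`. A small standalone illustration of both idioms (the mention/entity data below is invented):

```
# Python 3 forms of the two idioms changed above; the data here is made up.
m2e_dict = {'san francisco': 'san_francisco:ci'}  # hypothetical mention -> entity map

tokens = filter(lambda x: len(x) > 0, 'flights  to  san  francisco'.split(' '))
tokens = list(tokens)  # filter() is lazy in Python 3; materialize before len()/indexing

mention = 'san francisco'
if mention in m2e_dict:  # Python 2's m2e_dict.has_key(mention)
    print(m2e_dict[mention], len(tokens), tokens)
```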
