Commit f9faae4

update pretraining

frankxu2004 committed Oct 19, 2019
1 parent b604f19 commit f9faae4

Showing 18 changed files with 100 additions and 74 deletions.
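In brief, the commit wires mined CoNaLa data into the pipeline: dataset preprocessing can now mix in the first num_mined examples from conala-mined.jsonl and writes per-size train/dev/test/vocab binaries, the evaluator gains an optional CSV dump of decoded results (args.save_decode_to), the deprecated volatile flag on tensors is replaced with torch.no_grad(), and the CoNaLa training script is parameterized by mined-data size.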
134 changes: 72 additions & 62 deletions datasets/conala/dataset.py
@@ -4,6 +4,7 @@
 import pickle
 
 from components.action_info import get_action_infos
+from datasets.conala.evaluator import ConalaEvaluator
 from datasets.conala.util import *
 from asdl.lang.py3.py3_transition_system import python_ast_to_asdl_ast, asdl_ast_to_python_ast, Python3TransitionSystem
 
@@ -16,7 +17,8 @@
 from components.action_info import ActionInfo
 
 
-def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3):
+def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, code_freq=3,
+                              mined_data_file=None, num_mined=0):
     np.random.seed(1234)
 
     asdl_text = open(grammar_file).read()
@@ -32,19 +34,12 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, c
     dev_examples = train_examples[:200]
     train_examples = train_examples[200:]
 
-    # full_train_examples = train_examples[:]
-    # np.random.shuffle(train_examples)
-    # dev_examples = []
-    # dev_questions = set()
-    # dev_examples_id = []
-    # for i, example in enumerate(full_train_examples):
-    #     qid = example.meta['example_dict']['question_id']
-    #     if qid not in dev_questions and len(dev_examples) < 200:
-    #         dev_questions.add(qid)
-    #         dev_examples.append(example)
-    #         dev_examples_id.append(i)
-
-    # train_examples = [e for i, e in enumerate(full_train_examples) if i not in dev_examples_id]
+    if mined_data_file and num_mined > 0:
+        print("use mined data: ", num_mined)
+        mined_examples = preprocess_dataset(mined_data_file, name='mined', transition_system=transition_system,
+                                            firstk=num_mined)
+        train_examples += mined_examples
 
     print(f'{len(train_examples)} training instances', file=sys.stderr)
     print(f'{len(dev_examples)} dev instances', file=sys.stderr)
 
@@ -71,58 +66,65 @@ def preprocess_conala_dataset(train_file, test_file, grammar_file, src_freq=3, c
     print('Avg action len: %d' % np.average(action_lens), file=sys.stderr)
     print('Actions larger than 100: %d' % len(list(filter(lambda x: x > 100, action_lens))), file=sys.stderr)
 
-    pickle.dump(train_examples, open('data/conala/train.var_str_sep.bin', 'wb'))
-    pickle.dump(full_train_examples, open('data/conala/train.var_str_sep.full.bin', 'wb'))
-    pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.bin', 'wb'))
-    pickle.dump(test_examples, open('data/conala/test.var_str_sep.bin', 'wb'))
-    pickle.dump(vocab, open('data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.bin' % (src_freq, code_freq), 'wb'))
+    pickle.dump(train_examples, open('data/conala/train.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(full_train_examples, open('data/conala/train.var_str_sep.full.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(dev_examples, open('data/conala/dev.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(test_examples, open('data/conala/test.var_str_sep.mined_{}.bin'.format(num_mined), 'wb'))
+    pickle.dump(vocab, open('data/conala/vocab.var_str_sep.new_dev.src_freq%d.code_freq%d.mined_%s.bin' % (src_freq, code_freq, num_mined), 'wb'))
 
 
-def preprocess_dataset(file_path, transition_system, name='train'):
-    dataset = json.load(open(file_path))
+def preprocess_dataset(file_path, transition_system, name='train', firstk=None):
+    try:
+        dataset = json.load(open(file_path))
+    except:
+        dataset = [json.loads(jline) for jline in open(file_path).readlines()]
     examples = []
+    evaluator = ConalaEvaluator(transition_system)
 
     f = open(file_path + '.debug', 'w')
 
     for i, example_json in enumerate(dataset):
-        example_dict = preprocess_example(example_json)
-        if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
-            print(example_json['question_id'])
+        if firstk and i >= firstk:
+            break
+        try:
+            example_dict = preprocess_example(example_json)
+            if example_json['question_id'] in (18351951, 9497290, 19641579, 32283692):
+                print(example_json['question_id'])
+                continue
+
+            python_ast = ast.parse(example_dict['canonical_snippet'])
+            canonical_code = astor.to_source(python_ast).strip()
+            tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
+            tgt_actions = transition_system.get_actions(tgt_ast)
+
+            # sanity check
+            hyp = Hypothesis()
+            for t, action in enumerate(tgt_actions):
+                assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
+                if isinstance(action, ApplyRuleAction):
+                    assert action.production in transition_system.get_valid_continuating_productions(hyp)
+
+                p_t = -1
+                f_t = None
+                if hyp.frontier_node:
+                    p_t = hyp.frontier_node.created_time
+                    f_t = hyp.frontier_field.field.__repr__(plain=True)
+
+                # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
+                hyp = hyp.clone_and_apply_action(action)
+
+            assert hyp.frontier_node is None and hyp.frontier_field is None
+            hyp.code = code_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
+            assert code_from_hyp == canonical_code
+
+            decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
+            assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
+            assert transition_system.compare_ast(transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
+                                                 transition_system.surface_code_to_ast(example_json['snippet']))
+
+            tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
+        except:
+            continue
 
-        python_ast = ast.parse(example_dict['canonical_snippet'])
-        canonical_code = astor.to_source(python_ast).strip()
-        tgt_ast = python_ast_to_asdl_ast(python_ast, transition_system.grammar)
-        tgt_actions = transition_system.get_actions(tgt_ast)
-
-        # sanity check
-        hyp = Hypothesis()
-        for t, action in enumerate(tgt_actions):
-            assert action.__class__ in transition_system.get_valid_continuation_types(hyp)
-            if isinstance(action, ApplyRuleAction):
-                assert action.production in transition_system.get_valid_continuating_productions(hyp)
-
-            p_t = -1
-            f_t = None
-            if hyp.frontier_node:
-                p_t = hyp.frontier_node.created_time
-                f_t = hyp.frontier_field.field.__repr__(plain=True)
-
-            # print('\t[%d] %s, frontier field: %s, parent: %d' % (t, action, f_t, p_t))
-            hyp = hyp.clone_and_apply_action(action)
-
-        assert hyp.frontier_node is None and hyp.frontier_field is None
-        hyp.code = code_from_hyp = astor.to_source(asdl_ast_to_python_ast(hyp.tree, transition_system.grammar)).strip()
-        assert code_from_hyp == canonical_code
-
-        decanonicalized_code_from_hyp = decanonicalize_code(code_from_hyp, example_dict['slot_map'])
-        assert compare_ast(ast.parse(example_json['snippet']), ast.parse(decanonicalized_code_from_hyp))
-        assert transition_system.compare_ast(transition_system.surface_code_to_ast(decanonicalized_code_from_hyp),
-                                             transition_system.surface_code_to_ast(example_json['snippet']))
-
-        tgt_action_infos = get_action_infos(example_dict['intent_tokens'], tgt_actions)
 
         example = Example(idx=f'{i}-{example_json["question_id"]}',
                           src_sent=example_dict['intent_tokens'],
                           tgt_actions=tgt_action_infos,
@@ -136,7 +138,10 @@ def preprocess_dataset(file_path, transition_system, name='train'):
 
         # log!
        f.write(f'Example: {example.idx}\n')
-        f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
+        if 'rewritten_intent' in example.meta['example_dict']:
+            f.write(f"Original Utterance: {example.meta['example_dict']['rewritten_intent']}\n")
+        else:
+            f.write(f"Original Utterance: {example.meta['example_dict']['intent']}\n")
         f.write(f"Original Snippet: {example.meta['example_dict']['snippet']}\n")
         f.write(f"\n")
         f.write(f"Utterance: {' '.join(example.src_sent)}\n")
@@ -150,9 +155,11 @@ def preprocess_dataset(file_path, transition_system, name='train'):
 
 def preprocess_example(example_json):
     intent = example_json['intent']
-    rewritten_intent = example_json['rewritten_intent']
+    if 'rewritten_intent' in example_json:
+        rewritten_intent = example_json['rewritten_intent']
+    else:
+        rewritten_intent = None
     snippet = example_json['snippet']
     question_id = example_json['question_id']
 
     if rewritten_intent is None:
         rewritten_intent = intent
@@ -190,8 +197,11 @@ def generate_vocab_for_paraphrase_model(vocab_path, save_path):
 
 if __name__ == '__main__':
     # the json files can be download from http://conala-corpus.github.io
-    preprocess_conala_dataset(train_file='data/conala/conala-train.json',
-                              test_file='data/conala/conala-test.json',
-                              grammar_file='asdl/lang/py3/py3_asdl.simplified.txt', src_freq=3, code_freq=3)
+    for num in (10000, 20000):
+        preprocess_conala_dataset(train_file='data/conala/conala-train.json',
+                                  test_file='data/conala/conala-test.json',
+                                  mined_data_file='data/conala/conala-mined.jsonl',
+                                  grammar_file='asdl/lang/py3/py3_asdl.simplified.txt',
+                                  src_freq=3, code_freq=3, num_mined=num)
 
     # generate_vocab_for_paraphrase_model('data/conala/vocab.src_freq3.code_freq3.bin', 'data/conala/vocab.para.src_freq3.code_freq3.bin')
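Editor's note: the reworked preprocess_dataset accepts either a single JSON array (conala-train.json) or JSON Lines (conala-mined.jsonl) and caps mined examples via firstk. A minimal sketch of that loading pattern — load_json_or_jsonl is a hypothetical helper name, not part of the repo:

import json

def load_json_or_jsonl(file_path, firstk=None):
    # Try the whole file as one JSON document first (conala-train/test),
    # then fall back to one JSON object per line (conala-mined.jsonl).
    try:
        with open(file_path) as f:
            data = json.load(f)
    except json.JSONDecodeError:
        with open(file_path) as f:
            data = [json.loads(line) for line in f if line.strip()]
    # Cap the number of examples, as preprocess_dataset does with firstk.
    return data if firstk is None else data[:firstk]

# e.g.: mined = load_json_or_jsonl('data/conala/conala-mined.jsonl', firstk=10000)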
13 changes: 11 additions & 2 deletions datasets/conala/evaluator.py
@@ -1,3 +1,5 @@
+import csv
+
 from components.evaluator import Evaluator
 from common.registerable import Registrable
 from components.dataset import Dataset
@@ -32,7 +34,10 @@ def get_sentence_bleu(self, example, hyp):
                                    tokenize_for_bleu_eval(hyp.decanonical_code),
                                    smoothing_function=SmoothingFunction().method3)
 
-    def evaluate_dataset(self, dataset, decode_results, fast_mode=False):
+
+    def evaluate_dataset(self, dataset, decode_results, fast_mode=False, args=None):
+        if args.save_decode_to:
+            csv_writer = csv.writer(open(args.save_decode_to + '.csv', 'w'))
         examples = dataset.examples if isinstance(dataset, Dataset) else dataset
         assert len(examples) == len(decode_results)
 
@@ -88,7 +93,11 @@ def evaluate_dataset(self, dataset, decode_results, fast_mode=False):
 
             top_decanonical_code_tokens = hyp_list[0].decanonical_code_tokens
             sent_bleu_score = hyp_list[0].bleu_score
-
+            # write results to file
+            if args.save_decode_to:
+                csv_writer.writerow([" ".join(example.src_sent),
+                                     " ".join(example.reference_code_tokens),
+                                     " ".join(top_decanonical_code_tokens)])
             best_hyp_idx = np.argmax(example_hyp_bleu_scores)
             oracle_sent_bleu = example_hyp_bleu_scores[best_hyp_idx]
             _best_hyp_code_tokens = hyp_list[best_hyp_idx].decanonical_code_tokens
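Editor's note: evaluate_dataset defaults args to None but dereferences args.save_decode_to unconditionally, so a caller that omits args would hit an AttributeError. A defensive sketch of the same CSV hook, assuming nothing beyond the diff above (maybe_open_decode_csv is a hypothetical name):

import csv

def maybe_open_decode_csv(args):
    # Returns a csv.writer for '<save_decode_to>.csv' when requested, else None.
    # The None/getattr guard is this sketch's addition; the commit assumes
    # args is always passed (as evaluation.py now does).
    if args is not None and getattr(args, 'save_decode_to', None):
        return csv.writer(open(args.save_decode_to + '.csv', 'w', newline=''))
    return None

# Per example, the commit writes three space-joined columns:
# source tokens, reference code tokens, top-hypothesis code tokens.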
2 changes: 1 addition & 1 deletion evaluation.py
@@ -58,7 +58,7 @@ def decode(examples, model, args, verbose=False, **kwargs):
 def evaluate(examples, parser, evaluator, args, verbose=False, return_decode_result=False, eval_top_pred_only=False):
     decode_results = decode(examples, parser, args, verbose=verbose)
 
-    eval_result = evaluator.evaluate_dataset(examples, decode_results, fast_mode=eval_top_pred_only)
+    eval_result = evaluator.evaluate_dataset(examples, decode_results, fast_mode=eval_top_pred_only, args=args)
 
     if return_decode_result:
         return eval_result, decode_results
1 change: 1 addition & 0 deletions exp.py
@@ -145,6 +145,7 @@ def train(args):
             model.save(model_file)
 
         # perform validation
+        is_better = False
        if args.dev_file:
            if epoch % args.valid_every_epoch == 0:
                print('[Epoch %d] begin validation' % epoch, file=sys.stderr)
7 changes: 5 additions & 2 deletions model/nn_utils.py
@@ -87,8 +87,11 @@ def to_input_variable(sequences, vocab, cuda=False, training=True, append_bounda
 
     word_ids = word2id(sequences, vocab)
     sents_t = input_transpose(word_ids, vocab['<pad>'])
-
-    sents_var = Variable(torch.LongTensor(sents_t), volatile=(not training), requires_grad=False)
+    if training:
+        sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
+    else:
+        with torch.no_grad():
+            sents_var = Variable(torch.LongTensor(sents_t), requires_grad=False)
     if cuda:
         sents_var = sents_var.cuda()
 
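Editor's note: volatile=True stopped having any effect in PyTorch 0.4, where torch.no_grad() became the supported way to disable gradient tracking at inference; this hunk and the parser.py one below are that migration. A minimal sketch of the idiom (to_input_tensor is a hypothetical stand-in for to_input_variable):

import torch

def to_input_tensor(sents_t, training=True):
    # sents_t: list of equal-length lists of token ids.
    if training:
        return torch.LongTensor(sents_t)
    # Inference path: ops inside this context record no autograd history.
    with torch.no_grad():
        return torch.LongTensor(sents_t)

Strictly speaking, creating an integer tensor records no history either way, so the context manager here mainly documents intent; it matters for the differentiable ops later run on these inputs.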
6 changes: 4 additions & 2 deletions model/parser.py
@@ -503,7 +503,8 @@ def parse(self, src_sent, context=None, beam_size=5, debug=False):
 
         zero_action_embed = Variable(self.new_tensor(args.action_embed_size).zero_())
 
-        hyp_scores = Variable(self.new_tensor([0.]), volatile=True)
+        with torch.no_grad():
+            hyp_scores = Variable(self.new_tensor([0.]))
 
         # For computing copy probabilities, we marginalize over tokens with the same surface form
         # `aggregated_primitive_tokens` stores the position of occurrence of each source token
@@ -525,7 +526,8 @@ def parse(self, src_sent, context=None, beam_size=5, debug=False):
             exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num, src_encodings_att_linear.size(1), src_encodings_att_linear.size(2))
 
             if t == 0:
-                x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_(), volatile=True)
+                with torch.no_grad():
+                    x = Variable(self.new_tensor(1, self.decoder_lstm.input_size).zero_())
                 if args.no_parent_field_type_embed is False:
                     offset = args.action_embed_size  # prev_action
                     offset += args.att_vec_size * (not args.no_input_feed)
scripts/atis/train.sh: mode changed 100755 → 100644 (no content changes)
11 changes: 6 additions & 5 deletions scripts/conala/train.sh
100755 → 100644
@@ -1,10 +1,11 @@
 #!/bin/bash
 set -e
 
-seed=${1:-0}
-vocab="data/conala/vocab.var_str_sep.src_freq3.code_freq3.bin"
-train_file="data/conala/train.var_str_sep.bin"
-dev_file="data/conala/dev.var_str_sep.bin"
+seed=0
+mined_num=$1
+vocab="data/conala/vocab.var_str_sep.new_dev.src_freq3.code_freq3.mined_${mined_num}.bin"
+train_file="data/conala/train.var_str_sep.mined_${mined_num}.bin"
+dev_file="data/conala/dev.var_str_sep.mined_${mined_num}.bin"
 dropout=0.3
 hidden_size=256
 embed_size=128
@@ -17,7 +18,7 @@ lr_decay=0.5
 beam_size=15
 lstm='lstm'  # lstm
 lr_decay_after_epoch=15
-model_name=model.sup.conala.${lstm}.hidden${hidden_size}.embed${embed_size}.action${action_embed_size}.field${field_embed_size}.type${type_embed_size}.dr${dropout}.lr${lr}.lr_de${lr_decay}.lr_da${lr_decay_after_epoch}.beam${beam_size}.$(basename ${vocab}).$(basename ${train_file}).glorot.par_state.seed${seed}
+model_name=model.sup.conala.${lstm}.hidden${hidden_size}.embed${embed_size}.action${action_embed_size}.field${field_embed_size}.type${type_embed_size}.dr${dropout}.lr${lr}.lr_de${lr_decay}.lr_da${lr_decay_after_epoch}.beam${beam_size}.$(basename ${vocab}).$(basename ${train_file}).glorot.par_state.seed${seed}.mined_${mined_num}
 
 echo "**** Writing results to logs/conala/${model_name}.log ****"
 mkdir -p logs/conala
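Editor's usage note: the script now takes the mined-data size as its only positional argument, e.g. "bash scripts/conala/train.sh 10000", so the vocab/train/dev paths resolve to the mined_10000 binaries produced by the preprocessing step; the random seed is now fixed at 0 rather than read from $1.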
Mode changed 100755 → 100644 (no content changes):
scripts/django/train.sh
scripts/geo/test.sh
scripts/geo/train.sh
scripts/jobs/test.sh
scripts/jobs/train.sh
scripts/wikisql/test.sh
scripts/wikisql/train.sh
server/static/d3Tree.js
server/static/parser.js
server/static/tree-viewer.css
