-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathfoolConcode.py
145 lines (122 loc) · 3.96 KB
/
foolConcode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import random
import time
import re
from collections import Counter
# from nltk.translate.bleu_score import corpus_bleu
import numpy as np
from nltk.util import ngrams
# from bleu_freq import corpus_bleu, SmoothingFunction
from CodeBLEU.code_bleu import code_bleu
from bleu_ignoring import corpus_bleu, SmoothingFunction
from pygments import lex
from pygments.lexers.jvm import JavaLexer
from pygments.lexers.c_cpp import CLexer, CppLexer
from matplotlib import pyplot as plt
from ast import literal_eval as make_tuple
from pygments.token import Comment
from statistical_test import stat_test
# --- Configuration ----------------------------------------------------------
MAXN = 4    # largest n-gram order collected from the training code
mc = 500    # number of most frequent n-grams kept via freq.most_common(mc)
# Smoothing function handed to every corpus_bleu call in this script.
sm_func = SmoothingFunction(epsilon=0.0001).method1
total = 0   # not referenced again in the visible part of the script

# --- n-gram frequencies over the training set -------------------------------
with open('concode/train.json') as f:
    # One JSON object per line; only its 'code' field is used.  Blank lines
    # are skipped, so the last record is kept even when the file does not end
    # with a newline (slicing off the final split element would drop it).
    data = [json.loads(line)['code'] for line in f if line.strip()]

ref = []  # references are appended once the answers file has been read
start_time = time.process_time()
all_ngrams = []
total_tokens = 0
for code in data:
    # Code is treated as already tokenized: tokens are separated by one space.
    tokenized = code.split(' ')
    total_tokens += len(tokenized)
    # A distinct name for the order avoids shadowing the outer loop variable.
    for n in range(1, MAXN + 1):
        all_ngrams.extend(ngrams(tokenized, n))

# Occurrence count of every 1..MAXN-gram seen in the training code.
freq = Counter(all_ngrams)
print(time.process_time() - start_time, 'seconds')
# print(len(all_ngrams), len(freq))
print('{} tokens'.format(total_tokens))
def _read_lines(path):
    """Return the newline-separated records stored in the file at *path*.

    Only the single empty string that a trailing newline leaves behind is
    removed.  A file whose last record is not newline-terminated therefore
    keeps that record, and blank lines elsewhere are preserved as empty
    records so that line positions stay aligned across files.
    """
    with open(path) as f:
        lines = f.read().split('\n')
    if lines[-1] == '':
        lines.pop()
    return lines


# Model predictions: one space-separated token sequence per line.
tmp = _read_lines('concode/predictions.txt')
hyp = [line.split(' ') for line in tmp]

# with open('nexgen/tgt-test.txt') as f:
# Reference answers: one JSON object per line, tokens taken from 'code'.
tmp = [json.loads(line)['code']
       for line in _read_lines('concode/answers.json')]
# Every entry of `ref` is a list of references for one hypothesis; exactly
# one reference is stored per entry here.
ref.extend([code.split(' ')] for code in tmp)
# --- Fake predictions ---------------------------------------------------------
# For every reference, synthesize a hypothesis (stored in hyp2) that is mostly
# made of the most frequent training n-grams, mixed with a few tokens copied
# from the reference itself.
hyp2 = []
target = []
# The mc most frequent n-grams mapped to their counts, most frequent first.
comm_ngrams = dict(freq.most_common(mc))
most_common_dict = comm_ngrams
c = 0
fltr = []
for j, ref_group in enumerate(ref):
    reference = ref_group[0]
    res = []
    cn = iter(comm_ngrams)  # walks the common n-grams, most frequent first
    i = 1                   # cursor into the reference; wraps around below
    # Grow the fake hypothesis until it is at least as long as the reference.
    while len(res) < len(reference):
        try:
            if random.random() < 0.825:  # 0.82
                # Prepend the tokens of the next common n-gram.
                k = next(cn)
                res = list(k) + res
            else:
                # Append one token copied from the reference.
                res.append(reference[i])
                i = (i + 1) % len(reference)
        except (StopIteration, IndexError):
            # StopIteration: the common n-grams are exhausted.
            # IndexError: reference[1] does not exist for a one-token
            # reference.  Either way, restart the n-gram iterator and retry.
            # Unlike a bare `except`, this lets KeyboardInterrupt/SystemExit
            # propagate.
            cn = iter(comm_ngrams)
    hyp2.append(res)
    c += 1
def _report_scores(title, references, hypotheses):
    """Print exact-match count and three corpus scores for *hypotheses*.

    Prints *title*, the number of hypotheses equal to their first reference,
    and then the results of corpus_bleu with ``ignoring=most_common_dict``
    (labelled CrystalBLEU), corpus_bleu without it (labelled BLEU) and
    code_bleu (labelled CodeBLEU).  Each score is preceded by the CPU time
    its computation took.  Reads the module-level ``sm_func`` and
    ``most_common_dict``.

    Returns:
        Tuple ``(em, crystalbleu, bleu_vanilla, codebleu)``.
    """
    print(title)
    # A hypothesis counts as an exact match when it equals its first reference.
    em = sum(1 for refs, pred in zip(references, hypotheses)
             if refs[0] == pred)
    print(f'Exact match: {em}')

    start_time = time.process_time()
    crystalbleu = corpus_bleu(
        references, hypotheses, smoothing_function=sm_func,
        ignoring=most_common_dict)
    print(time.process_time() - start_time, 'seconds for CrystalBLEU')
    print('CrystalBLEU:', crystalbleu)

    start_time = time.process_time()
    bleu_vanilla = corpus_bleu(
        references, hypotheses, smoothing_function=sm_func)
    print(time.process_time() - start_time, 'seconds for BLEU')
    print('BLEU:', bleu_vanilla)

    start_time = time.process_time()
    codebleu = code_bleu(references, hypotheses)
    print(time.process_time() - start_time, 'seconds for CodeBLEU')
    print('CodeBLEU:', codebleu)

    return em, crystalbleu, bleu_vanilla, codebleu


print(len(ref))
# Score the model's predictions, then the synthesized ones, against the same
# references.  The module-level result names keep the values of the last run.
em, crystalbleu, bleu_vanilla, codebleu = _report_scores(
    'Real predictions:', ref, hyp)
print('--------------------------------')
em, crystalbleu, bleu_vanilla, codebleu = _report_scores(
    'Fake predictions:', ref, hyp2)
# stat_test(ref, hyp, hyp2, most_common_dict)

# Show ten randomly drawn examples (indices drawn with replacement): first the
# references, then the real predictions, then the fake ones, always in the
# same sampled order.
samp = random.choices(range(len(ref)), k=10)
divider = '------------------------------------------------------------------'

reference_lines = [' '.join(ref[idx][0]) for idx in samp]
print('\n'.join(reference_lines))
print(divider)

prediction_lines = [' '.join(hyp[idx]) for idx in samp]
print('\n'.join(prediction_lines))
print(divider)

fake_lines = [' '.join(hyp2[idx]) for idx in samp]
print('\n'.join(fake_lines))