-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathsplit.py
47 lines (37 loc) · 1.27 KB
/
split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import os
import re
# Root directory of the MultiWOZ 2.4 release; the id lists and data.json are
# read from here and the split files are written back here.
data_dir = "data/MULTIWOZ2.4"


def _read_id_list(path):
    """Return the dialogue ids listed one per line in the text file *path*.

    Surrounding whitespace (including the line terminator) is stripped from
    each line, and blank lines are skipped. Stripping instead of slicing off
    the last character matters in two cases:
    - if the final line has no trailing newline, slicing would cut off a
      real character of the last id;
    - if the file uses CRLF line endings, slicing would leave a carriage
      return in every id, so no id would ever match.
    """
    with open(path, 'r', encoding='utf-8') as fin:
        return [name for name in (line.strip() for line in fin) if name]


# Dialogue ids reserved for the test and validation splits (one id per line,
# despite the .json extension).
testListFile = _read_id_list(os.path.join(data_dir, 'testListFile.json'))
valListFile = _read_id_list(os.path.join(data_dir, 'valListFile.json'))
# Load every dialogue, keyed by dialogue name; the context manager closes the
# file instead of leaking the handle as json.load(open(...)) did.
with open(os.path.join(data_dir, "data.json"), encoding="utf-8") as f:
    data = json.load(f)

# Sets make each per-dialogue membership test O(1) instead of a list scan.
test_ids = set(testListFile)
val_ids = set(valListFile)

test_dials = {}
val_dials = {}
train_dials = {}
for dialogue_name, dialogue in data.items():
    # Test membership takes precedence over val; anything in neither list
    # belongs to the training split.
    if dialogue_name in test_ids:
        test_dials[dialogue_name] = dialogue
    elif dialogue_name in val_ids:
        val_dials[dialogue_name] = dialogue
    else:
        train_dials[dialogue_name] = dialogue

# Dict keys are unique, so each split's size equals the number of dialogues
# routed to it.
count_train, count_val, count_test = len(train_dials), len(val_dials), len(test_dials)
print("# of dialogues: Train {}, Val {}, Test {}".format(count_train, count_val, count_test))
# Persist each split as pretty-printed JSON next to the source data, in the
# order dev, test, train.
for out_name, split_dials in (
    ('dev_dials.json', val_dials),
    ('test_dials.json', test_dials),
    ('train_dials.json', train_dials),
):
    with open(os.path.join(data_dir, out_name), 'w') as f:
        json.dump(split_dials, f, indent=4)