-
Notifications
You must be signed in to change notification settings - Fork 2
/
compute_rating_stat.py
74 lines (58 loc) · 2.08 KB
/
compute_rating_stat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import json
from os.path import join, exists
import re
import argparse
import numpy as np
import pickle as pkl
def _count_data(path):
""" count number of data in the given path"""
matcher = re.compile(r'[0-9]+\.json')
match = lambda name: bool(matcher.match(name))
names = os.listdir(path)
n_data = len(list(filter(match, names)))
return n_data
def read_word_list_from_file(filename):
    """Return every line of *filename*, stripped of surrounding whitespace.

    Blank lines are kept as empty strings, so the result has one entry per
    line of the file.
    """
    words = []
    with open(filename) as handle:
        for raw_line in handle:
            words.append(raw_line.strip())
    return words
def main(data_dir, split):
    """Compute rating statistics for one data split and save them.

    Reads ``<data_dir>/<split>/<i>.json`` for i = 0 .. n-1, where n is the
    count returned by ``_count_data`` for that directory. Each file's
    ``'overall'`` field is converted with ``int`` and must lie in 1..5.

    Prints the mean rating, the per-rating counts, their ratio, and the
    inverse of the ratio (labelled "Class weights"). It then pickles two
    numpy float arrays into ``<data_dir>/ratings/<split>/``:

    * ``gold_ratings.pkl``: length-n array of ratings shifted to 0..4.
    * ``rating_count.pkl``: length-5 array; entry k counts rating k + 1.

    Args:
        data_dir: Root data directory.
        split: Name of the split sub-directory (help text suggests
            train / val / test).

    Raises:
        ValueError: if a file's rating is outside 1..5.
        FileNotFoundError: if some ``<i>.json`` with i < n is missing, i.e.
            the numbered files are not contiguous from 0.
    """
    n_classes = 5
    split_dir = join(data_dir, split)
    n_data = _count_data(split_dir)
    all_ratings = np.zeros(n_data)
    rating_count = np.zeros(n_classes)

    rating_dir = join(data_dir, 'ratings', split)
    os.makedirs(rating_dir, exist_ok=True)

    for i in range(n_data):
        path = join(split_dir, '{}.json'.format(i))
        # The context manager closes each file promptly. The original
        # json.load(open(...)) left every handle open until GC.
        with open(path) as f:
            js = json.load(f)
        rating = int(js['overall'])
        # Without this check a rating of 0 would silently land in the last
        # bucket via negative indexing (rating_count[-1]).
        if not 1 <= rating <= n_classes:
            raise ValueError('{}: rating {!r} is outside 1..{}'.format(
                path, js['overall'], n_classes))
        all_ratings[i] = rating
        rating_count[rating - 1] += 1

    print("Average rating: {}".format(all_ratings.mean()))
    print("Rating count:")
    print(rating_count)
    print("Rating ratio:")
    normalized_rating_count = rating_count / rating_count.sum()
    print(normalized_rating_count)
    print("Class weights")
    # A rating that never occurs has ratio 0, which yields inf here
    # (numpy emits a RuntimeWarning rather than raising).
    print(1.0 / normalized_rating_count)

    # Shift ratings 1..5 to 0-based class indices 0..4.
    all_ratings -= 1
    with open(join(rating_dir, 'gold_ratings.pkl'), 'wb') as f:
        pkl.dump(all_ratings, f, pkl.HIGHEST_PROTOCOL)
    with open(join(rating_dir, 'rating_count.pkl'), 'wb') as f:
        pkl.dump(rating_count, f, pkl.HIGHEST_PROTOCOL)
if __name__ == "__main__":
    # Both arguments are required: without them main() would fail later
    # with an unhelpful TypeError from os.path.join(None, ...).
    parser = argparse.ArgumentParser(
        description='Preprocess review data'
    )
    parser.add_argument('-data_dir', type=str, action='store', required=True,
                        help='The directory of the data.')
    parser.add_argument('-split', type=str, action='store', required=True,
                        help='train or val or test.')
    args = parser.parse_args()
    main(args.data_dir, args.split)