fetch_data.py
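"""Fetch experiment data from the App Engine datastore via the remote API.

Downloads model.Actions entities and writes them as JSON files under
experiment_data/.
"""
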
from __future__ import print_function
import argparse
import datetime
import json
import os

DATA_DIR = os.path.join(os.path.dirname(
    os.path.realpath(__file__)), 'experiment_data')


class AttrDict(dict):
    """Dict whose items can also be read as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def json_serial(obj):
    """JSON serializer for objects not serializable by default JSON code."""
    if isinstance(obj, datetime.datetime):
        return obj.isoformat()
    raise TypeError('Type not serializable')


try:
    import dev_appserver
    dev_appserver.fix_sys_path()
except ImportError:
    print('Please make sure the App Engine SDK is in your PYTHONPATH.')
    raise

from google.appengine.ext.remote_api import remote_api_stub
import model
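

# NDB's Query.fetch_page returns a (results, cursor, more) triple; fetch_all
# below loops while `more` is set, passing the cursor back in, and appends
# the final partial page after the loop exits.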
def fetch_all(query, limit=100):
    """Fetch every entity matched by `query`, `limit` entities at a time."""
    results = []
    start_dt = datetime.datetime.now()
    # Fetch all entities in batches, following the page cursor.
    entities, cursor, more = query.fetch_page(limit)
    while more:
        results.extend(entities)
        entities, cursor, more = query.fetch_page(limit, start_cursor=cursor)
        print(
            len(results),
            'entities at',
            datetime.datetime.now(),
        )  # Progress and time tracker.
    results.extend(entities)
    print(
        len(results),
        'total entities fetched in',
        (datetime.datetime.now() - start_dt).seconds,
        'seconds',
    )  # Progress and time tracker.
    return results


def filter_to_str(filter):
    """Build an output file name from a filter; None fields become ''."""
    return '{}.{}.{}.{}.json'.format(
        filter.experiment_id if filter.experiment_id is not None else '',
        filter.task_id if filter.task_id is not None else '',
        filter.participant_id if filter.participant_id is not None else '',
        filter.participant_index if filter.participant_index is not None else '',
    )
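# For example, a filter with experiment_id='study', task_id='typing',
# participant_id=None, participant_index=3 (hypothetical values) yields
# 'study.typing..3.json'.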


def main(project_id, participants=None, tasks=None, data_dir=DATA_DIR,
         separate=True, overwrite=False):
    """Fetch experiment data.

    Args:
        project_id (str): App Engine project id to fetch from.
        participants (Optional[[int]]): List of participant ids for the
            study. Necessary because the projection query is not behaving
            correctly.
        tasks (Optional[[str]]): List of task names.
        data_dir (Optional[str]): Output directory. Defaults to DATA_DIR.
        separate (Optional[bool]): Separate into tasks / participants.
            Defaults to True.
        overwrite (Optional[bool]): Refetch files that already exist.
            Defaults to False.

    Creates data_dir with the following files (JSON-serialized lists of
    entities):

    - actions.json: model.Actions entities.
    """
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    # Route datastore calls to the deployed app over the remote API.
    remote_api_stub.ConfigureRemoteApiForOAuth(
        '{}.appspot.com'.format(project_id),
        '/_ah/remote_api')
    for entity_type, fname in [(model.Actions, 'actions.json')]:
        filters = []
        if fname == 'actions.json' and not participants:
            # Discover the distinct (experiment, task, participant)
            # combinations present in the datastore.
            filters = entity_type.query(
                projection=[
                    entity_type.experiment_id,
                    entity_type.task_id,
                    entity_type.participant_id,
                    entity_type.participant_index,
                ],
                distinct=True
            ).fetch()
        if not separate:
            with open(os.path.join(data_dir, fname), 'w') as f:
                print('fetching {}...'.format(fname))
                records = [
                    rec.to_dict() for rec in fetch_all(
                        entity_type.query().order(entity_type.date))
                ]
                json.dump(records, f, default=json_serial)
        else:
            if not participants:
                print('about to fetch {}'.format(
                    list(map(filter_to_str, filters))))
            elif tasks:
                # Build one filter per (task, participant) pair.
                filters = []
                for task in tasks:
                    for p in participants:
                        d = AttrDict()
                        d['experiment_id'] = 'study'
                        d['task_id'] = task
                        d['participant_index'] = p
                        d['participant_id'] = None
                        filters.append(d)
            else:
                raise ValueError('Participants specified but not tasks.')
            for filter in filters:
                name = filter_to_str(filter)
                path = os.path.join(data_dir, name)
                if not os.path.exists(path) or overwrite:
                    print('fetching {}...'.format(name))
                    with open(path, 'w') as f:
                        records = [
                            rec.to_dict() for rec in fetch_all(
                                entity_type.query(
                                    entity_type.experiment_id == filter.experiment_id,
                                    entity_type.task_id == filter.task_id,
                                    entity_type.participant_id == filter.participant_id,
                                    entity_type.participant_index == filter.participant_index,
                                ).order(entity_type.start_time)
                            )
                        ]
                        # Drop bulky history blobs not needed downstream.
                        try:
                            for rec in records:
                                del rec['next_state']['currentItem']['currentItemId']['history']
                                del rec['prev_state']['currentItem']['currentItemId']['history']
                        except KeyError:
                            pass
                        json.dump(records, f, default=json_serial)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('project_id', help='Your Project ID.')
    parser.add_argument('--participants', '-p', type=int, nargs='*')
    parser.add_argument('--tasks', '-t', type=str, nargs='*')
    parser.add_argument('--overwrite', action='store_true',
                        help='Refetch files that already exist.')
    args = parser.parse_args()
    main(
        args.project_id,
        participants=args.participants or None,
        tasks=args.tasks or None,
        overwrite=args.overwrite,
    )
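
# Example invocation (hypothetical project id and task names):
#   python fetch_data.py my-experiment-project -t typing search -p 1 2 3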