Skip to content

Commit

Permalink
Browse files: browse the repository at this point in the history.
  • Loading branch information
alilevy committed Feb 15, 2024
2 parents 3b3686c + c249212 commit 9da8b8e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 18 deletions.
18 changes: 5 additions & 13 deletions easy_tpp/preprocess/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,19 @@ def build_input_from_json(self, source_dir: str, split: str):
split_ = 'validation' if split == 'dev' else split
# load locally
if source_dir.split('.')[-1] == 'json':
- data = load_dataset('json', data_files={split_: source_dir})
+ data = load_dataset('json', data_files={split_: source_dir}, split=split_)
elif source_dir.startswith('easytpp'):
data = load_dataset(source_dir, split=split_)
else:
raise NotImplementedError

- py_assert(data[split_].data['dim_process'][0].as_py() == self.num_event_types,
+ py_assert(data['dim_process'][0] == self.num_event_types,
ValueError,
"inconsistent dim_process in different splits?")

- source_data = data[split_]['event_seqs'][0]
- time_seqs, type_seqs, time_delta_seqs = [], [], []
- for k, v in source_data.items():
- cur_time_seq, cur_type_seq, cur_time_delta_seq = [], [], []
- for k_, v_ in v.items():
- cur_time_seq.append(v_['time_since_start'])
- cur_type_seq.append(v_['type_event'])
- cur_time_delta_seq.append(v_['time_since_last_event'])
- time_seqs.append(cur_time_seq)
- type_seqs.append(cur_type_seq)
- time_delta_seqs.append(cur_time_delta_seq)
+ time_seqs = data['time_since_start']
+ type_seqs = data['type_event']
+ time_delta_seqs = data['time_since_last_event']

input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
return input_dict
Expand Down
6 changes: 3 additions & 3 deletions examples/configs/experiment_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ pipeline_config_id: runner_config
data:
taxi:
data_format: json
- train_dir: ./data/taxi/train.json
- valid_dir: ./data/taxi/dev.json
- test_dir: ./data/taxi/test.json
+ train_dir: easytpp/taxi # ./data/taxi/train.json
+ valid_dir: easytpp/taxi # ./data/taxi/dev.json
+ test_dir: easytpp/taxi # ./data/taxi/test.json
data_specs:
num_event_types: 10
pad_token_id: 10
Expand Down
8 changes: 7 additions & 1 deletion examples/script_data_processing/make_hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def make_json_serializable(input_dict):

return input_dict


def make_hf_dataset(source_dir, target_dir, split='test'):
data_pkl = load_pickle(source_dir)

Expand Down Expand Up @@ -50,4 +51,9 @@ def make_hf_dataset(source_dir, target_dir, split='test'):


if __name__ == '__main__':
- make_hf_dataset('../data/taxi/test.pkl', 'test.json', split='test')
+ test_data_dir = ['taxi/test.pkl', 'taxi/test.json']
+ dev_data_dir = ['taxi/dev.pkl', 'taxi/dev.json']
+ train_data_dir = ['taxi/train.pkl', 'taxi/train.json']
+ make_hf_dataset(source_dir=test_data_dir[0], target_dir=test_data_dir[1])
+ make_hf_dataset(source_dir=dev_data_dir[0], target_dir=dev_data_dir[1], split='dev')
+ make_hf_dataset(source_dir=train_data_dir[0], target_dir=train_data_dir[1], split='train')
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
- __version__ = '0.0.7'
+ __version__ = '0.0.7.1'

0 comments on commit 9da8b8e

Please sign in to comment.