Skip to content

Commit

Permalink
Fix data loader and add example config
Browse files Browse the repository at this point in the history
  • Loading branch information
alilevy committed Feb 12, 2024
1 parent 801a08e commit 1b04cf5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 4 additions & 2 deletions easy_tpp/preprocess/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def build_input_from_json(self, source_dir: str, split: str):
# load locally
if source_dir.split('.')[-1] == 'json':
data = load_dataset('json', data_files={split_: source_dir})
else:
elif source_dir.startswith('easytpp'):
data = load_dataset(source_dir, split=split_)
else:
raise NotImplementedError

py_assert(data[split_].data['dim_process'][0].as_py() == self.num_event_types,
ValueError,
Expand All @@ -49,7 +51,7 @@ def build_input_from_json(self, source_dir: str, split: str):
time_seqs, type_seqs, time_delta_seqs = [], [], []
for k, v in source_data.items():
cur_time_seq, cur_type_seq, cur_time_delta_seq = [], [], []
for k_, v_ in v:
for k_, v_ in v.items():
cur_time_seq.append(v_['time_since_start'])
cur_type_seq.append(v_['type_event'])
cur_time_delta_seq.append(v_['time_since_last_event'])
Expand Down
8 changes: 4 additions & 4 deletions examples/configs/experiment_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ pipeline_config_id: runner_config

data:
taxi:
data_format: pkl
train_dir: ./data/taxi/train.pkl
valid_dir: ./data/taxi/dev.pkl
test_dir: ./data/taxi/test.pkl
data_format: json
train_dir: ./data/taxi/train.json
valid_dir: ./data/taxi/dev.json
test_dir: ./data/taxi/test.json
data_specs:
num_event_types: 10
pad_token_id: 10
Expand Down

0 comments on commit 1b04cf5

Please sign in to comment.