From 1b04cf5477fd89f518f970f8108e4f7467f19afd Mon Sep 17 00:00:00 2001 From: "siqiao.xsq" Date: Mon, 12 Feb 2024 10:38:45 +0800 Subject: [PATCH] Fix data loader and add example config --- easy_tpp/preprocess/data_loader.py | 6 ++++-- examples/configs/experiment_config.yaml | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/easy_tpp/preprocess/data_loader.py b/easy_tpp/preprocess/data_loader.py index a1d6d2c..bc1604f 100644 --- a/easy_tpp/preprocess/data_loader.py +++ b/easy_tpp/preprocess/data_loader.py @@ -38,8 +38,10 @@ def build_input_from_json(self, source_dir: str, split: str): # load locally if source_dir.split('.')[-1] == 'json': data = load_dataset('json', data_files={split_: source_dir}) - else: + elif source_dir.startswith('easytpp'): data = load_dataset(source_dir, split=split_) + else: + raise NotImplementedError py_assert(data[split_].data['dim_process'][0].as_py() == self.num_event_types, ValueError, @@ -49,7 +51,7 @@ def build_input_from_json(self, source_dir: str, split: str): time_seqs, type_seqs, time_delta_seqs = [], [], [] for k, v in source_data.items(): cur_time_seq, cur_type_seq, cur_time_delta_seq = [], [], [] - for k_, v_ in v: + for k_, v_ in v.items(): cur_time_seq.append(v_['time_since_start']) cur_type_seq.append(v_['type_event']) cur_time_delta_seq.append(v_['time_since_last_event']) diff --git a/examples/configs/experiment_config.yaml b/examples/configs/experiment_config.yaml index 4f09747..8dd76ae 100644 --- a/examples/configs/experiment_config.yaml +++ b/examples/configs/experiment_config.yaml @@ -2,10 +2,10 @@ pipeline_config_id: runner_config data: taxi: - data_format: pkl - train_dir: ./data/taxi/train.pkl - valid_dir: ./data/taxi/dev.pkl - test_dir: ./data/taxi/test.pkl + data_format: json + train_dir: ./data/taxi/train.json + valid_dir: ./data/taxi/dev.json + test_dir: ./data/taxi/test.json data_specs: num_event_types: 10 pad_token_id: 10