diff --git a/easy_tpp/config_factory/config.py b/easy_tpp/config_factory/config.py
index 4bdfff0..de0cc9d 100644
--- a/easy_tpp/config_factory/config.py
+++ b/easy_tpp/config_factory/config.py
@@ -1,33 +1,34 @@
 from abc import abstractmethod
 from typing import Any

+from omegaconf import OmegaConf
-from easy_tpp.utils import save_yaml_config, load_yaml_config, Registrable, logger
+from easy_tpp.utils import save_yaml_config, Registrable, logger


 class Config(Registrable):
-    def save_to_yaml_file(self, fn):
-        """Save the config into the yaml file 'fn'.
+    def save_to_yaml_file(self, config_dir):
+        """Save the config into the yaml file 'config_dir'.

         Args:
-            fn (str): Target filename.
+            config_dir (str): Path of the target yaml file.

         Returns:
         """
         yaml_config = self.get_yaml_config()
-        save_yaml_config(fn, yaml_config)
+        OmegaConf.save(yaml_config, config_dir)

     @staticmethod
-    def build_from_yaml_file(yaml_fn, **kwargs):
+    def build_from_yaml_file(yaml_dir, **kwargs):
         """Load yaml config file from disk.

         Args:
-            yaml_fn (str): Path of the yaml config file.
+            yaml_dir (str): Path of the yaml config file.

         Returns:
             EasyTPP.Config: Config object corresponding to cls.
         """
-        config = load_yaml_config(yaml_fn)
+        config = OmegaConf.load(yaml_dir)
         pipeline_config = config.get('pipeline_config_id')
         config_cls = Config.by_name(pipeline_config.lower())
         logger.critical(f'Load pipeline config class {config_cls.__name__}')
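For review context, a minimal sketch of the OmegaConf round trip that replaces load_yaml_config/save_yaml_config (file names here are illustrative, not part of the patch):

    from omegaconf import OmegaConf

    # Illustrative only: mirrors the new load/save paths in Config
    conf = OmegaConf.load('config.yaml')           # replaces load_yaml_config
    pipeline_id = conf.get('pipeline_config_id')   # DictConfig supports dict-style .get()
    OmegaConf.save(conf, 'config_backup.yaml')     # replaces save_yaml_config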
diff --git a/easy_tpp/config_factory/data_config.py b/easy_tpp/config_factory/data_config.py
index e3f1d9a..9b3fd30 100644
--- a/easy_tpp/config_factory/data_config.py
+++ b/easy_tpp/config_factory/data_config.py
@@ -69,8 +69,9 @@ def copy(self):
             max_len=self.max_len)


+@Config.register('data_config')
 class DataConfig(Config):
-    def __init__(self, train_dir, valid_dir, test_dir, specs=None):
+    def __init__(self, train_dir, valid_dir, test_dir, data_format=None, specs=None):
         """Initialize the DataConfig object.

         Args:
@@ -83,7 +84,7 @@ def __init__(self, train_dir, valid_dir, test_dir, specs=None):
         self.valid_dir = valid_dir
         self.test_dir = test_dir
         self.data_specs = specs or DataSpecConfig()
-        self.data_format = train_dir.split('.')[-1]
+        self.data_format = train_dir.split('.')[-1] if data_format is None else data_format

     def get_yaml_config(self):
         """Return the config in dict (yaml compatible) format.
@@ -113,6 +114,7 @@ def parse_from_yaml_config(yaml_config):
             train_dir=yaml_config.get('train_dir'),
             valid_dir=yaml_config.get('valid_dir'),
             test_dir=yaml_config.get('test_dir'),
+            data_format=yaml_config.get('data_format'),
             specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs'))
         )
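With the fallback above, an explicit data_format always wins and the train_dir extension is only used when the YAML omits the field; a quick illustration with made-up paths:

    # Inferred from the file extension when data_format is not given
    cfg = DataConfig(train_dir='./data/taxi/train.json',
                     valid_dir='./data/taxi/dev.json',
                     test_dir='./data/taxi/test.json')
    assert cfg.data_format == 'json'

    # An explicit value is needed for Hugging Face Hub names, which carry no extension
    cfg = DataConfig(train_dir='easytpp/taxi', valid_dir='easytpp/taxi',
                     test_dir='easytpp/taxi', data_format='json')
    assert cfg.data_format == 'json'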
diff --git a/easy_tpp/preprocess/data_loader.py b/easy_tpp/preprocess/data_loader.py
index de4c588..95c50ad 100644
--- a/easy_tpp/preprocess/data_loader.py
+++ b/easy_tpp/preprocess/data_loader.py
@@ -1,3 +1,6 @@
+from collections import Counter
+import matplotlib.pyplot as plt
+import numpy as np
 from easy_tpp.preprocess.dataset import TPPDataset
 from easy_tpp.preprocess.dataset import get_data_loader
 from easy_tpp.preprocess.event_tokenizer import EventTokenizer
@@ -5,7 +8,7 @@

 class TPPDataLoader:

-    def __init__(self, data_config, backend, **kwargs):
+    def __init__(self, data_config, **kwargs):
         """Initialize the dataloader

         Args:
@@ -14,45 +17,75 @@ def __init__(self, data_config, backend, **kwargs):
         """
         self.data_config = data_config
         self.num_event_types = data_config.data_specs.num_event_types
-        self.backend = backend
+        self.backend = kwargs.get('backend', 'torch')
         self.kwargs = kwargs

-    def build_input_from_pkl(self, source_dir: str, split: str):
-        data = load_pickle(source_dir)
+    def build_input(self, source_dir, data_format, split):
+        """Load and process a dataset, dispatching on its file format.
+
+        Args:
+            source_dir (str): Path to the dataset file or directory.
+            data_format (str): Data format, e.g., 'pkl' or 'json'.
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'.
+
+        Returns:
+            dict: Dictionary containing sequences of event times, types, and intervals.
+        """
+        if data_format == 'pkl':
+            return self._build_input_from_pkl(source_dir, split)
+        elif data_format == 'json':
+            return self._build_input_from_json(source_dir, split)
+        else:
+            raise ValueError(f"Unsupported file format: {data_format}")
+
+    def _build_input_from_pkl(self, source_dir, split):
+        """Load and process data from a pickle file.
+
+        Args:
+            source_dir (str): Path to the pickle file.
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'.
+
+        Returns:
+            dict: Dictionary with processed event sequences.
+        """
+        data = load_pickle(source_dir)
         py_assert(data["dim_process"] == self.num_event_types,
-                  ValueError,
-                  "inconsistent dim_process in different splits?")
+                  ValueError, "Inconsistent dim_process in different splits.")

         source_data = data[split]
-        time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data]
-        type_seqs = [[x["type_event"] for x in seq] for seq in source_data]
-        time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data]
+        return {
+            'time_seqs': [[x["time_since_start"] for x in seq] for seq in source_data],
+            'type_seqs': [[x["type_event"] for x in seq] for seq in source_data],
+            'time_delta_seqs': [[x["time_since_last_event"] for x in seq] for seq in source_data]
+        }

-        input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
-        return input_dict
+    def _build_input_from_json(self, source_dir, split):
+        """Load and process data from a JSON file or a Hugging Face dataset.

-    def build_input_from_json(self, source_dir: str, split: str):
+        Args:
+            source_dir (str): Path to the JSON file or Hugging Face dataset name.
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'.
+
+        Returns:
+            dict: Dictionary with processed event sequences.
+        """
         from datasets import load_dataset

-        split_ = 'validation' if split == 'dev' else split
-        # load locally
-        if source_dir.split('.')[-1] == 'json':
-            data = load_dataset('json', data_files={split_: source_dir}, split=split_)
+        split_mapped = 'validation' if split == 'dev' else split
+        if source_dir.endswith('.json'):
+            data = load_dataset('json', data_files={split_mapped: source_dir}, split=split_mapped)
         elif source_dir.startswith('easytpp'):
-            data = load_dataset(source_dir, split=split_)
+            data = load_dataset(source_dir, split=split_mapped)
         else:
-            raise NotImplementedError
+            raise ValueError("Unsupported source directory format for JSON.")

         py_assert(data['dim_process'][0] == self.num_event_types,
-                  ValueError,
-                  "inconsistent dim_process in different splits?")
-
-        time_seqs = data['time_since_start']
-        type_seqs = data['type_event']
-        time_delta_seqs = data['time_since_last_event']
+                  ValueError, "Inconsistent dim_process in different splits.")

-        input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
-        return input_dict
+        return {
+            'time_seqs': data['time_since_start'],
+            'type_seqs': data['type_event'],
+            'time_delta_seqs': data['time_since_last_event']
+        }

     def get_loader(self, split='train', **kwargs):
         """Get the corresponding data loader.
@@ -68,12 +101,7 @@ def get_loader(self, split='train', **kwargs):
             EasyTPP.DataLoader: the data loader for tpp data.
         """
         data_dir = self.data_config.get_data_dir(split)
-        data_source_type = data_dir.split('.')[-1]
-
-        if data_source_type == 'pkl':
-            data = self.build_input_from_pkl(data_dir, split)
-        else:
-            data = self.build_input_from_json(data_dir, split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)

         dataset = TPPDataset(data)
         tokenizer = EventTokenizer(self.data_config.data_specs)
@@ -109,3 +137,93 @@ def test_loader(self, **kwargs):
             EasyTPP.DataLoader: data loader for test set.
         """
         return self.get_loader('test', **kwargs)
+
+    def get_statistics(self, split='train'):
+        """Get basic statistics about the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+
+        Returns:
+            dict: Dictionary containing statistics about the dataset.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+
+        num_sequences = len(data['time_seqs'])
+        sequence_lengths = [len(seq) for seq in data['time_seqs']]
+        avg_sequence_length = sum(sequence_lengths) / num_sequences
+        all_event_types = [event for seq in data['type_seqs'] for event in seq]
+        event_type_counts = Counter(all_event_types)
+
+        # Statistics of the inter-event times (time_delta_seqs)
+        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
+        mean_time_delta = np.mean(all_time_deltas) if all_time_deltas else 0
+        min_time_delta = np.min(all_time_deltas) if all_time_deltas else 0
+        max_time_delta = np.max(all_time_deltas) if all_time_deltas else 0
+
+        stats = {
+            "num_sequences": num_sequences,
+            "avg_sequence_length": avg_sequence_length,
+            "event_type_distribution": dict(event_type_counts),
+            "max_sequence_length": max(sequence_lengths),
+            "min_sequence_length": min(sequence_lengths),
+            "mean_time_delta": mean_time_delta,
+            "min_time_delta": min_time_delta,
+            "max_time_delta": max_time_delta
+        }
+
+        return stats
+
+    def plot_event_type_distribution(self, split='train'):
+        """Plot the distribution of event types in the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        stats = self.get_statistics(split)
+        event_type_distribution = stats['event_type_distribution']
+
+        plt.figure(figsize=(8, 6))
+        plt.bar(event_type_distribution.keys(), event_type_distribution.values(), color='skyblue')
+        plt.xlabel('Event Types')
+        plt.ylabel('Frequency')
+        plt.title(f'Event Type Distribution ({split} set)')
+        plt.show()
+
+    def plot_event_delta_times_distribution(self, split='train'):
+        """Plot the distribution of event delta times in the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)

+        # Flatten the time_delta_seqs to get all delta times
+        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
+
+        plt.figure(figsize=(10, 6))
+        plt.hist(all_time_deltas, bins=30, color='skyblue', edgecolor='black')
+        plt.xlabel('Event Delta Times')
+        plt.ylabel('Frequency')
+        plt.title(f'Event Delta Times Distribution ({split} set)')
+        plt.grid(axis='y', alpha=0.75)
+        plt.show()
+
+    def plot_sequence_length_distribution(self, split='train'):
+        """Plot the distribution of sequence lengths in the dataset.
+
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+        sequence_lengths = [len(seq) for seq in data['time_seqs']]
+
+        plt.figure(figsize=(8, 6))
+        plt.hist(sequence_lengths, bins=10, color='salmon', edgecolor='black')
+        plt.xlabel('Sequence Length')
+        plt.ylabel('Frequency')
+        plt.title(f'Sequence Length Distribution ({split} set)')
+        plt.show()
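Together these methods form a small dataset-inspection API on top of the loader; a usage sketch, assuming data_config is a DataConfig like the example config below:

    loader = TPPDataLoader(data_config)            # backend now defaults to 'torch'
    stats = loader.get_statistics(split='train')
    print(stats['num_sequences'], stats['mean_time_delta'])
    loader.plot_event_type_distribution(split='train')
    loader.plot_sequence_length_distribution(split='train')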
diff --git a/examples/data_inspection/config.yaml b/examples/data_inspection/config.yaml
new file mode 100644
index 0000000..91effc9
--- /dev/null
+++ b/examples/data_inspection/config.yaml
@@ -0,0 +1,10 @@
+pipeline_config_id: data_config
+
+data_format: json
+train_dir: easytpp/taxi # ./data/taxi/train.json
+valid_dir: easytpp/taxi # ./data/taxi/dev.json
+test_dir: easytpp/taxi # ./data/taxi/test.json
+data_specs:
+  num_event_types: 10
+  pad_token_id: 10
+  padding_side: right
\ No newline at end of file
diff --git a/examples/data_inspection/data_inspection.py b/examples/data_inspection/data_inspection.py
new file mode 100644
index 0000000..8955c48
--- /dev/null
+++ b/examples/data_inspection/data_inspection.py
@@ -0,0 +1,20 @@
+import os
+import sys
+# Make the repository root importable when this script is run directly
+current_file_path = os.path.abspath(__file__)
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))))
+
+from easy_tpp.config_factory import Config
+from easy_tpp.preprocess.data_loader import TPPDataLoader
+
+
+def main():
+    config = Config.build_from_yaml_file('./config.yaml')
+    tpp_loader = TPPDataLoader(config)
+    stats = tpp_loader.get_statistics(split='train')
+    print(stats)
+    tpp_loader.plot_event_type_distribution()
+    tpp_loader.plot_event_delta_times_distribution()
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9a13add..25978a9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ torch
 tensorboard
 packaging
 datasets
+omegaconf
diff --git a/version.py b/version.py
index e8a828e..9123cf0 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-__version__ = '0.0.7.1'
+__version__ = '0.0.8'
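As a smoke test for the new registry wiring, the example config should resolve to DataConfig; a sketch, assuming build_from_yaml_file returns the parsed config object (the tail of that method lies outside this diff):

    from easy_tpp.config_factory import Config

    config = Config.build_from_yaml_file('./examples/data_inspection/config.yaml')
    print(type(config).__name__)  # expected: DataConfig, registered as 'data_config'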