add data inspection utilities and tests
1. Add data inspection utilities to the data loader.
2. Use OmegaConf to load and save YAML configs.
3. Add tests to run NHP.
iLampard committed Oct 28, 2024
1 parent 1f85043 commit 21699a9
Showing 7 changed files with 195 additions and 43 deletions.
17 changes: 9 additions & 8 deletions easy_tpp/config_factory/config.py
@@ -1,33 +1,34 @@
 from abc import abstractmethod
 from typing import Any
+from omegaconf import OmegaConf

-from easy_tpp.utils import save_yaml_config, load_yaml_config, Registrable, logger
+from easy_tpp.utils import save_yaml_config, Registrable, logger


 class Config(Registrable):

-    def save_to_yaml_file(self, fn):
-        """Save the config into the yaml file 'fn'.
+    def save_to_yaml_file(self, config_dir):
+        """Save the config into the yaml file 'config_dir'.
         Args:
-            fn (str): Target filename.
+            config_dir (str): Target filename.
         Returns:
         """
         yaml_config = self.get_yaml_config()
-        save_yaml_config(fn, yaml_config)
+        OmegaConf.save(yaml_config, config_dir)

     @staticmethod
-    def build_from_yaml_file(yaml_fn, **kwargs):
+    def build_from_yaml_file(yaml_dir, **kwargs):
         """Load yaml config file from disk.
         Args:
-            yaml_fn (str): Path of the yaml config file.
+            yaml_dir (str): Path of the yaml config file.
         Returns:
             EasyTPP.Config: Config object corresponding to cls.
         """
-        config = load_yaml_config(yaml_fn)
+        config = OmegaConf.load(yaml_dir)
         pipeline_config = config.get('pipeline_config_id')
         config_cls = Config.by_name(pipeline_config.lower())
         logger.critical(f'Load pipeline config class {config_cls.__name__}')
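Since loading and saving now go through OmegaConf rather than the removed load_yaml_config/save_yaml_config pair, a config round trip reduces to two library calls. A minimal sketch (the file names here are illustrative):

from omegaconf import OmegaConf

# Load a YAML config; OmegaConf returns a DictConfig that supports dict-style .get()
config = OmegaConf.load('configs/experiment_config.yaml')
print(config.get('pipeline_config_id'))  # e.g. 'data_config'

# Save a (possibly modified) config back to YAML
OmegaConf.save(config, 'configs/experiment_config_copy.yaml')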
6 changes: 4 additions & 2 deletions easy_tpp/config_factory/data_config.py
@@ -69,8 +69,9 @@ def copy(self):
             max_len=self.max_len)


+@Config.register('data_config')
 class DataConfig(Config):
-    def __init__(self, train_dir, valid_dir, test_dir, specs=None):
+    def __init__(self, train_dir, valid_dir, test_dir, data_format, specs=None):
         """Initialize the DataConfig object.
         Args:
@@ -83,7 +84,7 @@ def __init__(self, train_dir, valid_dir, test_dir, specs=None):
         self.valid_dir = valid_dir
         self.test_dir = test_dir
         self.data_specs = specs or DataSpecConfig()
-        self.data_format = train_dir.split('.')[-1]
+        self.data_format = train_dir.split('.')[-1] if data_format is None else data_format

     def get_yaml_config(self):
         """Return the config in dict (yaml compatible) format.
@@ -113,6 +114,7 @@ def parse_from_yaml_config(yaml_config):
             train_dir=yaml_config.get('train_dir'),
             valid_dir=yaml_config.get('valid_dir'),
             test_dir=yaml_config.get('test_dir'),
+            data_format=yaml_config.get('data_format'),
             specs=DataSpecConfig.parse_from_yaml_config(yaml_config.get('data_specs'))
         )

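The new data_format argument decouples the format from the training path: when it is None, the old behavior (inferring from the file extension of train_dir) still applies. A small sketch of the fallback, assuming DataConfig is imported from easy_tpp.config_factory.data_config and the paths are hypothetical:

from easy_tpp.config_factory.data_config import DataConfig

# Local files: the format can be inferred from the extension.
cfg = DataConfig(train_dir='./data/taxi/train.pkl',
                 valid_dir='./data/taxi/dev.pkl',
                 test_dir='./data/taxi/test.pkl',
                 data_format=None)
assert cfg.data_format == 'pkl'

# Hugging Face hub ids carry no extension, so the format must be explicit.
cfg = DataConfig(train_dir='easytpp/taxi',
                 valid_dir='easytpp/taxi',
                 test_dir='easytpp/taxi',
                 data_format='json')
assert cfg.data_format == 'json'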
182 changes: 150 additions & 32 deletions easy_tpp/preprocess/data_loader.py
@@ -1,11 +1,14 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from collections import Counter
 from easy_tpp.preprocess.dataset import TPPDataset
 from easy_tpp.preprocess.dataset import get_data_loader
 from easy_tpp.preprocess.event_tokenizer import EventTokenizer
 from easy_tpp.utils import load_pickle, py_assert


 class TPPDataLoader:
-    def __init__(self, data_config, backend, **kwargs):
+    def __init__(self, data_config, **kwargs):
         """Initialize the dataloader
         Args:
@@ -14,45 +17,75 @@ def __init__(self, data_config, backend, **kwargs):
"""
self.data_config = data_config
self.num_event_types = data_config.data_specs.num_event_types
self.backend = backend
self.backend = kwargs.get('backend', 'torch')
self.kwargs = kwargs

def build_input_from_pkl(self, source_dir: str, split: str):
data = load_pickle(source_dir)
def build_input(self, source_dir, data_format, split):
"""Helper function to load and process dataset based on file format.
Args:
source_dir (str): Path to dataset directory.
split (str): Dataset split, e.g., 'train', 'dev', 'test'.
Returns:
dict: Dictionary containing sequences of event times, types, and intervals.
"""

if data_format == 'pkl':
return self._build_input_from_pkl(source_dir, split)
elif data_format == 'json':
return self._build_input_from_json(source_dir, split)
else:
raise ValueError(f"Unsupported file format: {data_format}")

def _build_input_from_pkl(self, source_dir, split):
"""Load and process data from a pickle file.
Args:
source_dir (str): Path to the pickle file.
split (str): Dataset split, e.g., 'train', 'dev', 'test'.
Returns:
dict: Dictionary with processed event sequences.
"""
data = load_pickle(source_dir)
py_assert(data["dim_process"] == self.num_event_types,
ValueError,
"inconsistent dim_process in different splits?")
ValueError, "Inconsistent dim_process in different splits.")

source_data = data[split]
time_seqs = [[x["time_since_start"] for x in seq] for seq in source_data]
type_seqs = [[x["type_event"] for x in seq] for seq in source_data]
time_delta_seqs = [[x["time_since_last_event"] for x in seq] for seq in source_data]
return {
'time_seqs': [[x["time_since_start"] for x in seq] for seq in source_data],
'type_seqs': [[x["type_event"] for x in seq] for seq in source_data],
'time_delta_seqs': [[x["time_since_last_event"] for x in seq] for seq in source_data]
}

input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
return input_dict
def _build_input_from_json(self, source_dir, split):
"""Load and process data from a JSON file.
def build_input_from_json(self, source_dir: str, split: str):
Args:
source_dir (str): Path to the JSON file or Hugging Face dataset name.
split (str): Dataset split, e.g., 'train', 'dev', 'test'.
Returns:
dict: Dictionary with processed event sequences.
"""
from datasets import load_dataset
split_ = 'validation' if split == 'dev' else split
# load locally
if source_dir.split('.')[-1] == 'json':
data = load_dataset('json', data_files={split_: source_dir}, split=split_)
split_mapped = 'validation' if split == 'dev' else split
if source_dir.endswith('.json'):
data = load_dataset('json', data_files={split_mapped: source_dir}, split=split_mapped)
elif source_dir.startswith('easytpp'):
data = load_dataset(source_dir, split=split_)
data = load_dataset(source_dir, split=split_mapped)
else:
raise NotImplementedError
raise ValueError("Unsupported source directory format for JSON.")

py_assert(data['dim_process'][0] == self.num_event_types,
ValueError,
"inconsistent dim_process in different splits?")

time_seqs = data['time_since_start']
type_seqs = data['type_event']
time_delta_seqs = data['time_since_last_event']
ValueError, "Inconsistent dim_process in different splits.")

input_dict = dict({'time_seqs': time_seqs, 'time_delta_seqs': time_delta_seqs, 'type_seqs': type_seqs})
return input_dict
return {
'time_seqs': data['time_since_start'],
'type_seqs': data['type_event'],
'time_delta_seqs': data['time_since_last_event']
}
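For context, _build_input_from_pkl expects the conventional EasyTPP pickle layout: a top-level 'dim_process' plus one list of event-dict sequences per split. A toy file in that layout (all values made up for illustration):

import pickle

toy = {
    'dim_process': 2,  # number of event types; must match data_specs.num_event_types
    'train': [
        [  # one sequence = a list of event dicts
            {'time_since_start': 0.0, 'time_since_last_event': 0.0, 'type_event': 0},
            {'time_since_start': 1.5, 'time_since_last_event': 1.5, 'type_event': 1},
        ],
    ],
}
with open('toy_train.pkl', 'wb') as f:
    pickle.dump(toy, f)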

@@ -68,12 +101,7 @@ def get_loader(self, split='train', **kwargs):
             EasyTPP.DataLoader: the data loader for tpp data.
         """
         data_dir = self.data_config.get_data_dir(split)
-        data_source_type = data_dir.split('.')[-1]
-
-        if data_source_type == 'pkl':
-            data = self.build_input_from_pkl(data_dir, split)
-        else:
-            data = self.build_input_from_json(data_dir, split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)

         dataset = TPPDataset(data)
         tokenizer = EventTokenizer(self.data_config.data_specs)
@@ -109,3 +137,93 @@ def test_loader(self, **kwargs):
             EasyTPP.DataLoader: data loader for test set.
         """
         return self.get_loader('test', **kwargs)
+
+    def get_statistics(self, split='train'):
+        """Get basic statistics about the dataset.
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        Returns:
+            dict: Dictionary containing statistics about the dataset.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+
+        num_sequences = len(data['time_seqs'])
+        sequence_lengths = [len(seq) for seq in data['time_seqs']]
+        avg_sequence_length = sum(sequence_lengths) / num_sequences
+        all_event_types = [event for seq in data['type_seqs'] for event in seq]
+        event_type_counts = Counter(all_event_types)
+
+        # Calculate time_delta_seqs statistics
+        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
+        mean_time_delta = np.mean(all_time_deltas) if all_time_deltas else 0
+        min_time_delta = np.min(all_time_deltas) if all_time_deltas else 0
+        max_time_delta = np.max(all_time_deltas) if all_time_deltas else 0
+
+        stats = {
+            "num_sequences": num_sequences,
+            "avg_sequence_length": avg_sequence_length,
+            "event_type_distribution": dict(event_type_counts),
+            "max_sequence_length": max(sequence_lengths),
+            "min_sequence_length": min(sequence_lengths),
+            "mean_time_delta": mean_time_delta,
+            "min_time_delta": min_time_delta,
+            "max_time_delta": max_time_delta
+        }
+
+        return stats
+
+    def plot_event_type_distribution(self, split='train'):
+        """Plot the distribution of event types in the dataset.
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        stats = self.get_statistics(split)
+        event_type_distribution = stats['event_type_distribution']
+
+        plt.figure(figsize=(8, 6))
+        plt.bar(event_type_distribution.keys(), event_type_distribution.values(), color='skyblue')
+        plt.xlabel('Event Types')
+        plt.ylabel('Frequency')
+        plt.title(f'Event Type Distribution ({split} set)')
+        plt.show()
+
+    def plot_event_delta_times_distribution(self, split='train'):
+        """Plot the distribution of event delta times in the dataset.
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+
+        # Flatten the time_delta_seqs to get all delta times
+        all_time_deltas = [delta for seq in data['time_delta_seqs'] for delta in seq]
+
+        plt.figure(figsize=(10, 6))
+        plt.hist(all_time_deltas, bins=30, color='skyblue', edgecolor='black')
+        plt.xlabel('Event Delta Times')
+        plt.ylabel('Frequency')
+        plt.title(f'Event Delta Times Distribution ({split} set)')
+        plt.grid(axis='y', alpha=0.75)
+        plt.show()
+
+    def plot_sequence_length_distribution(self, split='train'):
+        """Plot the distribution of sequence lengths in the dataset.
+        Args:
+            split (str): Dataset split, e.g., 'train', 'dev', 'test'. Default is 'train'.
+        """
+        data_dir = self.data_config.get_data_dir(split)
+        data = self.build_input(data_dir, self.data_config.data_format, split)
+        sequence_lengths = [len(seq) for seq in data['time_seqs']]
+
+        plt.figure(figsize=(8, 6))
+        plt.hist(sequence_lengths, bins=10, color='salmon', edgecolor='black')
+        plt.xlabel('Sequence Length')
+        plt.ylabel('Frequency')
+        plt.title(f'Sequence Length Distribution ({split} set)')
+        plt.show()
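Taken together, the new methods give a quick inspection workflow on top of any configured loader. A usage sketch, with purely illustrative numbers in the comment:

loader = TPPDataLoader(data_config)  # backend now defaults to 'torch' via kwargs
stats = loader.get_statistics(split='train')
# e.g. {'num_sequences': 1000, 'avg_sequence_length': 38.2,
#       'event_type_distribution': {0: 5123, 1: 4876, ...},
#       'max_sequence_length': 120, 'min_sequence_length': 2,
#       'mean_time_delta': 0.42, 'min_time_delta': 0.0, 'max_time_delta': 9.7}
loader.plot_event_type_distribution(split='train')
loader.plot_event_delta_times_distribution(split='train')
loader.plot_sequence_length_distribution(split='train')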
10 changes: 10 additions & 0 deletions examples/data_inspection/config.yaml
@@ -0,0 +1,10 @@
+pipeline_config_id: data_config
+
+data_format: json
+train_dir: easytpp/taxi # ./data/taxi/train.json
+valid_dir: easytpp/taxi # ./data/taxi/dev.json
+test_dir: easytpp/taxi # ./data/taxi/test.json
+data_specs:
+  num_event_types: 10
+  pad_token_id: 10
+  padding_side: right
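Note that data_format: json is what lets the Hugging Face hub id (easytpp/taxi) work here: there is no file extension to infer the format from, so the explicit field added to DataConfig above carries it instead. The commented-out local paths would work equally well.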
20 changes: 20 additions & 0 deletions examples/data_inspection/data_inspection.py
@@ -0,0 +1,20 @@
+import os
+import sys
+# Get the directory of the current file
+current_file_path = os.path.abspath(__file__)
+sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))))
+
+from easy_tpp.config_factory import Config
+from easy_tpp.preprocess.data_loader import TPPDataLoader
+
+
+def main():
+    config = Config.build_from_yaml_file('./config.yaml')
+    tpp_loader = TPPDataLoader(config)
+    stats = tpp_loader.get_statistics(split='train')
+    print(stats)
+    tpp_loader.plot_event_type_distribution()
+    tpp_loader.plot_event_delta_times_distribution()
+
+if __name__ == '__main__':
+    main()
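Since the script resolves './config.yaml' relative to the working directory, run it from inside examples/data_inspection/ (e.g. python data_inspection.py); the sys.path shim at the top makes the repository root importable without installing the package.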
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ torch
 tensorboard
 packaging
 datasets
+omegaconf
2 changes: 1 addition & 1 deletion version.py
@@ -1 +1 @@
-__version__ = '0.0.7.1'
+__version__ = '0.0.8'
