Merge pull request #72 from ibm-granite/classification_2
Classification 2
wgifford authored Aug 7, 2024
2 parents 21cf29b + 6778707 commit c9bbacf
Showing 4 changed files with 779 additions and 4 deletions.
33 changes: 33 additions & 0 deletions tests/toolkit/test_dataset.py
@@ -8,8 +8,10 @@
 import numpy as np
 import pandas as pd
 import pytest
+from torch.utils.data import DataLoader, default_collate
 
 from tsfm_public.toolkit.dataset import (
+    ClassificationDFDataset,
     ForecastDFDataset,
     PretrainDFDataset,
     ts_padding,
@@ -261,3 +263,34 @@ def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical):
 
     # check that past values of targets are zeroed out
     assert np.all(ds[0]["past_values"][:, 0].numpy() == 0)
+
+
+def test_classification_df_dataset(ts_data):
+    data = ts_data.copy()
+    data["label"] = (data["val"] > 4).astype(int)
+
+    ds = ClassificationDFDataset(
+        data,
+        timestamp_column="time_date",
+        input_columns=["val", "val2"],
+        label_column=["label"],
+        id_columns=["id", "id2"],
+        context_length=4,
+    )
+
+    # check length
+    assert len(ds) == len(data) - ds.context_length + 1
+
+    # check alignment
+    assert ds[-1]["timestamp"] == ts_data.iloc[-1]["time_date"]
+
+    # check shape under dataloader
+    def my_collate(batch):
+        valid_keys = ["past_values", "target_values"]
+        batch_ = [{k: item[k] for k in valid_keys} for item in batch]
+        return default_collate(batch_)
+
+    dl = DataLoader(ds, batch_size=2, collate_fn=my_collate)
+    b = next(iter(dl))
+
+    assert len(b["target_values"].shape) == 1
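
For orientation, a minimal standalone sketch of how the new ClassificationDFDataset is consumed, mirroring the test above. The synthetic frame stands in for the ts_data fixture, and its layout (two id columns, a timestamp column, two value columns) is an assumption; the constructor arguments are taken directly from the test.

import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, default_collate

from tsfm_public.toolkit.dataset import ClassificationDFDataset

# Synthetic stand-in for the ts_data fixture (layout assumed, single group).
n = 10
df = pd.DataFrame(
    {
        "id": ["A"] * n,
        "id2": ["X"] * n,
        "time_date": pd.date_range("2021-01-01", periods=n, freq="D"),
        "val": np.arange(n, dtype=float),
        "val2": 2.0 * np.arange(n, dtype=float),
    }
)
df["label"] = (df["val"] > 4).astype(int)

ds = ClassificationDFDataset(
    df,
    timestamp_column="time_date",
    input_columns=["val", "val2"],
    label_column=["label"],
    id_columns=["id", "id2"],
    context_length=4,
)

# Drop non-tensor keys (timestamp, id) so default_collate can stack the batch.
def my_collate(batch):
    valid_keys = ["past_values", "target_values"]
    return default_collate([{k: item[k] for k in valid_keys} for item in batch])

dl = DataLoader(ds, batch_size=2, collate_fn=my_collate)
b = next(iter(dl))
# b["past_values"]: (batch, context_length, 2); b["target_values"]: (batch,)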
31 changes: 30 additions & 1 deletion tests/toolkit/test_util.py
@@ -1,9 +1,11 @@
 """Tests for util functions"""
 
+import tempfile
+
 import pandas as pd
 import pytest
 
-from tsfm_public.toolkit.util import convert_to_univariate, get_split_params, train_test_split
+from tsfm_public.toolkit.util import convert_to_univariate, convert_tsfile, get_split_params, train_test_split
 
 
 split_cases = [
@@ -73,3 +75,30 @@ def test_convert_to_univariate(ts_data):
     df_uni = convert_to_univariate(
         ts_data, timestamp_column=timestamp_column, id_columns=id_columns, target_columns=target_columns
     )
+
+
+def test_convert_tsfile():
+    data = """#Test
+#
+#The classes are
+#1. one
+#2. two
+#3. three
+@problemName test
+@timeStamps false
+@missing false
+@univariate true
+@equalLength true
+@seriesLength 5
+@classLabel true 1 2 3
+@data
+1,2,3,4,5:1
+10,20,30,40,50:2
+11,12,13,14,15:3
+"""
+    with tempfile.NamedTemporaryFile() as t:
+        t.write(data.encode("utf-8"))
+        t.flush()
+        df = convert_tsfile(t.name)
+
+    assert df.shape == (15, 3)
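
The fixture above is a sktime-style .ts file: header lines start with @, and each @data row is one comma-separated series terminated by ":<class label>". The asserted shape (15, 3) is consistent with a long-format result of 3 series x 5 observations each. A pandas-only sketch of that flattening, with column names assumed purely for illustration (convert_tsfile's actual output columns may differ):

import pandas as pd

rows = ["1,2,3,4,5:1", "10,20,30,40,50:2", "11,12,13,14,15:3"]

records = []
for series_id, row in enumerate(rows):
    values, label = row.rsplit(":", 1)
    for value in values.split(","):
        # Column names here are assumptions for illustration only.
        records.append({"series_id": series_id, "value": float(value), "label": int(label)})

df = pd.DataFrame(records)
assert df.shape == (15, 3)  # same shape the test asserts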
118 changes: 115 additions & 3 deletions tsfm_public/toolkit/dataset.py
@@ -333,7 +333,7 @@ def __getitem__(self, index):
         time_id = index * self.stride
         seq_x = self.X[time_id : time_id + self.context_length].values
         ret = {
-            "past_values": np_to_torch(seq_x),
+            "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)),
             "past_observed_mask": np_to_torch(~np.isnan(seq_x)),
         }
         if self.datetime_col:
@@ -534,6 +534,7 @@ def __init__(
         context_length: int = 1,
         num_workers: int = 1,
         stride: int = 1,
+        fill_value: Union[float, int] = 0.0,
     ):
         # self.y_cols = y_cols
 
@@ -543,11 +544,13 @@
             timestamp_column=timestamp_column,
             num_workers=num_workers,
             context_length=context_length,
+            prediction_length=0,
             cls=self.BaseRegressionDFDataset,
             input_columns=input_columns,
             target_columns=target_columns,
             static_categorical_columns=static_categorical_columns,
             stride=stride,
+            fill_value=fill_value,
         )
 
         self.n_inp = 2
@@ -570,6 +573,7 @@ def __init__(
             input_columns: List[str] = [],
             static_categorical_columns: List[str] = [],
             stride: int = 1,
+            fill_value: Union[float, int] = 0.0,
         ):
             self.target_columns = target_columns
             self.input_columns = input_columns
@@ -589,6 +593,7 @@
                 group_id=group_id,
                 drop_cols=drop_cols,
                 stride=stride,
+                fill_value=fill_value,
             )
 
         def __getitem__(self, index):
@@ -600,8 +605,115 @@ def __getitem__(self, index):
             # return _torch(seq_x, seq_y)
 
             ret = {
-                "past_values": np_to_torch(seq_x),
-                "target_values": np_to_torch(seq_y),
+                "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)),
+                "target_values": np_to_torch(np.nan_to_num(seq_y, nan=self.fill_value)),
+                "past_observed_mask": np_to_torch(~np.isnan(seq_x)),
             }
             if self.datetime_col:
                 ret["timestamp"] = self.timestamps[time_id + self.context_length - 1]
+
+            if self.group_id:
+                ret["id"] = self.group_id
+
+            if self.static_categorical_columns:
+                categorical_values = self.data_df[self.static_categorical_columns].values[0, :]
+                ret["static_categorical_values"] = np_to_torch(categorical_values)
+
+            return ret
+
+
+class ClassificationDFDataset(BaseConcatDFDataset):
+    """
+    A dataset for use with time series classification.
+
+    Args:
+        data (DataFrame, required): Input data frame.
+        id_columns (list, optional): Columns identifying the individual time series. If provided, it will
+            trigger a groupby on the data frame; if empty, the entire data frame is treated as one group.
+            Defaults to an empty list.
+        timestamp_column (str, optional): Timestamp column in the data. Defaults to None.
+        input_columns (list, optional): Columns to use as model inputs. Defaults to an empty list.
+        label_column (str, optional): Column containing the class labels. Defaults to "label".
+        static_categorical_columns (list, optional): Columns holding static categorical features.
+            Defaults to an empty list.
+        context_length (int, optional): Length of the context window. Defaults to 1.
+        num_workers (int, optional): Number of workers used when creating the per-group datasets.
+            Defaults to 1.
+        stride (int, optional): Stride between consecutive windows. Defaults to 1.
+        fill_value (float or int, optional): Value used to impute missing values. Defaults to 0.0.
+    """
+
+    def __init__(
+        self,
+        data: pd.DataFrame,
+        id_columns: List[str] = [],
+        timestamp_column: Optional[str] = None,
+        input_columns: List[str] = [],
+        label_column: str = "label",
+        static_categorical_columns: List[str] = [],
+        context_length: int = 1,
+        num_workers: int = 1,
+        stride: int = 1,
+        fill_value: Union[float, int] = 0.0,
+    ):
+        super().__init__(
+            data_df=data,
+            id_columns=id_columns,
+            timestamp_column=timestamp_column,
+            num_workers=num_workers,
+            context_length=context_length,
+            prediction_length=0,
+            cls=self.BaseClassificationDFDataset,
+            input_columns=input_columns,
+            label_column=label_column,
+            static_categorical_columns=static_categorical_columns,
+            stride=stride,
+            fill_value=fill_value,
+        )
+
+        self.n_inp = 2
+
+    class BaseClassificationDFDataset(BaseDFDataset):
+        def __init__(
+            self,
+            data_df: pd.DataFrame,
+            group_id: Optional[Union[List[int], List[str]]] = None,
+            context_length: int = 1,
+            prediction_length: int = 0,
+            drop_cols: list = [],
+            id_columns: List[str] = [],
+            timestamp_column: Optional[str] = None,
+            label_column: str = "label",
+            input_columns: List[str] = [],
+            static_categorical_columns: List[str] = [],
+            stride: int = 1,
+            fill_value: Union[float, int] = 0.0,
+        ):
+            self.label_column = label_column
+            self.input_columns = input_columns
+            self.static_categorical_columns = static_categorical_columns
+
+            x_cols = input_columns
+            y_cols = label_column
+
+            super().__init__(
+                data_df=data_df,
+                id_columns=id_columns,
+                timestamp_column=timestamp_column,
+                x_cols=x_cols,
+                y_cols=y_cols,
+                context_length=context_length,
+                prediction_length=prediction_length,
+                group_id=group_id,
+                drop_cols=drop_cols,
+                stride=stride,
+                fill_value=fill_value,
+            )
+
+        def __getitem__(self, index):
+            # seq_x: batch_size x seq_len x num_x_cols
+
+            time_id = index * self.stride
+            seq_x = self.X[time_id : time_id + self.context_length].values
+            # seq_y = self.y[time_id + self.context_length - 1 : time_id + self.context_length].values.ravel()
+            seq_y = self.y.iloc[time_id + self.context_length - 1].values[0]
+
+            ret = {
+                "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)),
+                "target_values": torch.tensor(np.nan_to_num(seq_y, nan=self.fill_value), dtype=torch.int64),
+                "past_observed_mask": np_to_torch(~np.isnan(seq_x)),
+            }
+            if self.datetime_col:
+                ret["timestamp"] = self.timestamps[time_id + self.context_length - 1]