From 1fc2f63417a9e2a37d9419a4be6199914cfdde72 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:00:08 -0400 Subject: [PATCH 01/14] add classification --- tests/toolkit/test_dataset.py | 33 +++++++++++ tsfm_public/toolkit/dataset.py | 102 +++++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/tests/toolkit/test_dataset.py b/tests/toolkit/test_dataset.py index 6e23adad..32dd5c84 100644 --- a/tests/toolkit/test_dataset.py +++ b/tests/toolkit/test_dataset.py @@ -8,8 +8,10 @@ import numpy as np import pandas as pd import pytest +from torch.utils.data import DataLoader, default_collate from tsfm_public.toolkit.dataset import ( + ClassificationDFDataset, ForecastDFDataset, PretrainDFDataset, ts_padding, @@ -261,3 +263,34 @@ def test_forecasting_df_dataset_non_autoregressive(ts_data_with_categorical): # check that past values of targets are zeroed out assert np.all(ds[0]["past_values"][:, 0].numpy() == 0) + + +def test_clasification_df_dataset(ts_data): + data = ts_data.copy() + data["label"] = (data["val"] > 4).astype(int) + + ds = ClassificationDFDataset( + data, + timestamp_column="time_date", + input_columns=["val", "val2"], + label_column=["label"], + id_columns=["id", "id2"], + context_length=4, + ) + + # check length + assert len(ds) == len(data) - ds.context_length + 1 + + # check alignment + assert ds[-1]["timestamp"] == ts_data.iloc[-1]["time_date"] + + # check shape under dataloader + def my_collate(batch): + valid_keys = ["past_values", "target_values"] + batch_ = [{k: item[k] for k in valid_keys} for item in batch] + return default_collate(batch_) + + dl = DataLoader(ds, batch_size=2, collate_fn=my_collate) + b = next(iter(dl)) + + assert len(b["target_values"].shape) == 1 diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py index 1487563d..7c54bdd7 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -543,6 +543,7 
@@ def __init__( timestamp_column=timestamp_column, num_workers=num_workers, context_length=context_length, + prediction_length=0, cls=self.BaseRegressionDFDataset, input_columns=input_columns, target_columns=target_columns, @@ -616,6 +617,107 @@ def __getitem__(self, index): return ret +class ClassificationDFDataset(BaseConcatDFDataset): + """ + A dataset for use with time series classification. + Args: + data_df (DataFrame, required): input data + datetime_col (str, optional): datetime column in the data_df. Defaults to None + x_cols (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df are taken, except the datetime_col. Defaults to an empty list. + y_cols (list, required): list of columns of y. Defaults to an empty list. + group_ids (list, optional): list of group_ids to split the data_df to different groups. If group_ids is defined, it will trigger the groupby method in DataFrame. If empty, entire data frame is treated as one group. + seq_len (int, required): the sequence length. Defaults to 1 + num_workers (int, optional): the number of workers used for creating a list of datasets from group_ids. Defaults to 1. 
+ """ + + def __init__( + self, + data: pd.DataFrame, + id_columns: List[str] = [], + timestamp_column: Optional[str] = None, + input_columns: List[str] = [], + label_column: str = "label", + static_categorical_columns: List[str] = [], + context_length: int = 1, + num_workers: int = 1, + stride: int = 1, + ): + super().__init__( + data_df=data, + id_columns=id_columns, + timestamp_column=timestamp_column, + num_workers=num_workers, + context_length=context_length, + prediction_length=0, + cls=self.BaseClassificationDFDataset, + input_columns=input_columns, + label_column=label_column, + static_categorical_columns=static_categorical_columns, + stride=stride, + ) + + self.n_inp = 2 + + class BaseClassificationDFDataset(BaseDFDataset): + def __init__( + self, + data_df: pd.DataFrame, + group_id: Optional[Union[List[int], List[str]]] = None, + context_length: int = 1, + prediction_length: int = 0, + drop_cols: list = [], + id_columns: List[str] = [], + timestamp_column: Optional[str] = None, + label_column: str = "label", + input_columns: List[str] = [], + static_categorical_columns: List[str] = [], + stride: int = 1, + ): + self.label_column = label_column + self.input_columns = input_columns + self.static_categorical_columns = static_categorical_columns + + x_cols = input_columns + y_cols = label_column + + super().__init__( + data_df=data_df, + id_columns=id_columns, + timestamp_column=timestamp_column, + x_cols=x_cols, + y_cols=y_cols, + context_length=context_length, + prediction_length=prediction_length, + group_id=group_id, + drop_cols=drop_cols, + stride=stride, + ) + + def __getitem__(self, index): + # seq_x: batch_size x seq_len x num_x_cols + + time_id = index * self.stride + seq_x = self.X[time_id : time_id + self.context_length].values + # seq_y = self.y[time_id + self.context_length - 1 : time_id + self.context_length].values.ravel() + seq_y = self.y.iloc[time_id + self.context_length - 1].values[0] + + ret = { + "past_values": np_to_torch(seq_x), + 
"target_values": torch.tensor(seq_y), + } + if self.datetime_col: + ret["timestamp"] = self.timestamps[time_id + self.context_length - 1] + + if self.group_id: + ret["id"] = self.group_id + + if self.static_categorical_columns: + categorical_values = self.data_df[self.static_categorical_columns].values[0, :] + ret["static_categorical_values"] = np_to_torch(categorical_values) + + return ret + + def np_to_torch(data: np.array, float_type=np.float32): if data.dtype == "float": return torch.from_numpy(data.astype(float_type)) From 965d08b304531168bdc82544d93b4d310d2106e9 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:00:47 -0400 Subject: [PATCH 02/14] add utility for reading UCR data --- tsfm_public/toolkit/util.py | 524 ++++++++++++++++++++++++++++++++++++ 1 file changed, 524 insertions(+) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index c471e2f9..d1fe1a9c 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -8,6 +8,7 @@ from distutils.util import strtobool from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import pandas as pd @@ -483,6 +484,529 @@ def convert_tsf_to_dataframe( ) +def convert_tsfile_to_dataframe( + full_file_path_and_name, + return_separate_X_and_y=True, + replace_missing_vals_with="NaN", +): + """Load data from a .ts file into a Pandas DataFrame. + Parameters + ---------- + full_file_path_and_name: str + The full pathname of the .ts file to read. + return_separate_X_and_y: bool + true if X and Y values should be returned as separate Data Frames ( + X) and a numpy array (y), false otherwise. + This is only relevant for data that + replace_missing_vals_with: str + The value that missing values in the text file should be replaced + with prior to parsing. 
+ Returns + ------- + DataFrame (default) or ndarray (i + If return_separate_X_and_y then a tuple containing a DataFrame and a + numpy array containing the relevant time-series and corresponding + class values. + DataFrame + If not return_separate_X_and_y then a single DataFrame containing + all time-series and (if relevant) a column "class_vals" the + associated class values. + """ + # Initialize flags and variables used when parsing the file + metadata_started = False + data_started = False + + has_problem_name_tag = False + has_timestamps_tag = False + has_univariate_tag = False + has_class_labels_tag = False + has_data_tag = False + + previous_timestamp_was_int = None + prev_timestamp_was_timestamp = None + num_dimensions = None + is_first_case = True + instance_list = [] + class_val_list = [] + line_num = 0 + # Parse the file + with open(full_file_path_and_name, "r", encoding="utf-8") as file: + for line in file: + # Strip white space from start/end of line and change to + # lowercase for use below + line = line.strip().lower() + # Empty lines are valid at any point in a file + if line: + # Check if this line contains metadata + # Please note that even though metadata is stored in this + # function it is not currently published externally + if line.startswith("@problemname"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len == 1: + raise IOError("problemname tag requires an associated value") + # problem_name = line[len("@problemname") + 1:] + has_problem_name_tag = True + metadata_started = True + elif line.startswith("@timestamps"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len != 2: + raise IOError("timestamps tag 
requires an associated Boolean " "value") + elif tokens[1] == "true": + timestamps = True + elif tokens[1] == "false": + timestamps = False + else: + raise IOError("invalid timestamps value") + has_timestamps_tag = True + metadata_started = True + elif line.startswith("@univariate"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len != 2: + raise IOError("univariate tag requires an associated Boolean " "value") + elif tokens[1] == "true": + # univariate = True + pass + elif tokens[1] == "false": + # univariate = False + pass + else: + raise IOError("invalid univariate value") + has_univariate_tag = True + metadata_started = True + elif line.startswith("@classlabel"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len == 1: + raise IOError("classlabel tag requires an associated Boolean " "value") + if tokens[1] == "true": + class_labels = True + elif tokens[1] == "false": + class_labels = False + else: + raise IOError("invalid classLabel value") + # Check if we have any associated class values + if token_len == 2 and class_labels: + raise IOError("if the classlabel tag is true then class values " "must be supplied") + has_class_labels_tag = True + class_label_list = [token.strip() for token in tokens[2:]] + metadata_started = True + elif line.startswith("@targetlabel"): + if data_started: + raise IOError("metadata must come before data") + tokens = line.split(" ") + token_len = len(tokens) + if token_len == 1: + raise IOError("targetlabel tag requires an associated Boolean value") + if tokens[1] == "true": + class_labels = True + elif tokens[1] == "false": + class_labels = False + else: + raise IOError("invalid targetlabel 
value") + if token_len > 2: + raise IOError( + "targetlabel tag should not be accompanied with info " + "apart from true/false, but found " + f"{tokens}" + ) + has_class_labels_tag = True + metadata_started = True + # Check if this line contains the start of data + elif line.startswith("@data"): + if line != "@data": + raise IOError("data tag should not have an associated value") + if data_started and not metadata_started: + raise IOError("metadata must come before data") + else: + has_data_tag = True + data_started = True + # If the 'data tag has been found then metadata has been + # parsed and data can be loaded + elif data_started: + # Check that a full set of metadata has been provided + if ( + not has_problem_name_tag + or not has_timestamps_tag + or not has_univariate_tag + or not has_class_labels_tag + or not has_data_tag + ): + raise IOError("a full set of metadata has not been provided " "before the data") + # Replace any missing values with the value specified + line = line.replace("?", replace_missing_vals_with) + # Check if we are dealing with data that has timestamps + if timestamps: + # We're dealing with timestamps so cannot just split + # line on ':' as timestamps may contain one + has_another_value = False + has_another_dimension = False + timestamp_for_dim = [] + values_for_dimension = [] + this_line_num_dim = 0 + line_len = len(line) + char_num = 0 + while char_num < line_len: + # Move through any spaces + while char_num < line_len and str.isspace(line[char_num]): + char_num += 1 + # See if there is any more data to read in or if + # we should validate that read thus far + if char_num < line_len: + # See if we have an empty dimension (i.e. 
no + # values) + if line[char_num] == ":": + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + instance_list[this_line_num_dim].append(pd.Series(dtype="object")) + this_line_num_dim += 1 + has_another_value = False + has_another_dimension = True + timestamp_for_dim = [] + values_for_dimension = [] + char_num += 1 + else: + # Check if we have reached a class label + if line[char_num] != "(" and class_labels: + class_val = line[char_num:].strip() + if class_val not in class_label_list: + raise IOError( + "the class value '" + + class_val + + "' on line " + + str(line_num + 1) + + " is not " + "valid" + ) + class_val_list.append(class_val) + char_num = line_len + has_another_value = False + has_another_dimension = False + timestamp_for_dim = [] + values_for_dimension = [] + else: + # Read in the data contained within + # the next tuple + if line[char_num] != "(" and not class_labels: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " does " + "not " + "start " + "with a " + "'('" + ) + char_num += 1 + tuple_data = "" + while char_num < line_len and line[char_num] != ")": + tuple_data += line[char_num] + char_num += 1 + if char_num >= line_len or line[char_num] != ")": + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " does " + "not end" + " with a " + "')'" + ) + # Read in any spaces immediately + # after the current tuple + char_num += 1 + while char_num < line_len and str.isspace(line[char_num]): + char_num += 1 + + # Check if there is another value or + # dimension to process after this tuple + if char_num >= line_len: + has_another_value = False + has_another_dimension = False + elif line[char_num] == ",": + has_another_value = True + has_another_dimension = False + elif line[char_num] == ":": + has_another_value = False + has_another_dimension = True + char_num += 1 + # Get the numeric value for the + # tuple by reading from 
the end of + # the tuple data backwards to the + # last comma + last_comma_index = tuple_data.rfind(",") + if last_comma_index == -1: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains a tuple that has " + "no comma inside of it" + ) + try: + value = tuple_data[last_comma_index + 1 :] + value = float(value) + except ValueError: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains a tuple that does " + "not have a valid numeric " + "value" + ) + # Check the type of timestamp that + # we have + timestamp = tuple_data[0:last_comma_index] + try: + timestamp = int(timestamp) + timestamp_is_int = True + timestamp_is_timestamp = False + except ValueError: + timestamp_is_int = False + if not timestamp_is_int: + try: + timestamp = timestamp.strip() + timestamp_is_timestamp = True + except ValueError: + timestamp_is_timestamp = False + # Make sure that the timestamps in + # the file (not just this dimension + # or case) are consistent + if not timestamp_is_timestamp and not timestamp_is_int: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains a tuple that " + "has an invalid timestamp '" + timestamp + "'" + ) + if ( + previous_timestamp_was_int is not None + and previous_timestamp_was_int + and not timestamp_is_int + ): + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains tuples where the " + "timestamp format is " + "inconsistent" + ) + if ( + prev_timestamp_was_timestamp is not None + and prev_timestamp_was_timestamp + and not timestamp_is_timestamp + ): + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains tuples where the " + "timestamp format is " + "inconsistent" + ) + # Store the values + timestamp_for_dim += [timestamp] + values_for_dimension += [value] + 
# If this was our first tuple then + # we store the type of timestamp we + # had + if prev_timestamp_was_timestamp is None and timestamp_is_timestamp: + prev_timestamp_was_timestamp = True + previous_timestamp_was_int = False + + if previous_timestamp_was_int is None and timestamp_is_int: + prev_timestamp_was_timestamp = False + previous_timestamp_was_int = True + # See if we should add the data for + # this dimension + if not has_another_value: + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + + if timestamp_is_timestamp: + timestamp_for_dim = pd.DatetimeIndex(timestamp_for_dim) + + instance_list[this_line_num_dim].append( + pd.Series( + index=timestamp_for_dim, + data=values_for_dimension, + ) + ) + this_line_num_dim += 1 + timestamp_for_dim = [] + values_for_dimension = [] + elif has_another_value: + raise IOError( + "dimension " + str(this_line_num_dim + 1) + " on " + "line " + str(line_num + 1) + " ends with a ',' that " + "is not followed by " + "another tuple" + ) + elif has_another_dimension and class_labels: + raise IOError( + "dimension " + str(this_line_num_dim + 1) + " on " + "line " + str(line_num + 1) + " ends with a ':' while " + "it should list a class " + "value" + ) + elif has_another_dimension and not class_labels: + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + instance_list[this_line_num_dim].append(pd.Series(dtype=np.float32)) + this_line_num_dim += 1 + num_dimensions = this_line_num_dim + # If this is the 1st line of data we have seen + # then note the dimensions + if not has_another_value and not has_another_dimension: + if num_dimensions is None: + num_dimensions = this_line_num_dim + if num_dimensions != this_line_num_dim: + raise IOError( + "line " + str(line_num + 1) + " does not have the " + "same number of " + "dimensions as the " + "previous line of " + "data" + ) + # Check that we are not expecting some more data, + # and if not, store that processed above + if 
has_another_value: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " ends with a ',' that is " + "not followed by another " + "tuple" + ) + elif has_another_dimension and class_labels: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " ends with a ':' while it " + "should list a class value" + ) + elif has_another_dimension and not class_labels: + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + instance_list[this_line_num_dim].append(pd.Series(dtype="object")) + this_line_num_dim += 1 + num_dimensions = this_line_num_dim + # If this is the 1st line of data we have seen then + # note the dimensions + if not has_another_value and num_dimensions != this_line_num_dim: + raise IOError( + "line " + str(line_num + 1) + " does not have the same " + "number of dimensions as the " + "previous line of data" + ) + # Check if we should have class values, and if so + # that they are contained in those listed in the + # metadata + if class_labels and len(class_val_list) == 0: + raise IOError("the cases have no associated class values") + else: + dimensions = line.split(":") + # If first row then note the number of dimensions ( + # that must be the same for all cases) + if is_first_case: + num_dimensions = len(dimensions) + if class_labels: + num_dimensions -= 1 + for _dim in range(0, num_dimensions): + instance_list.append([]) + is_first_case = False + # See how many dimensions that the case whose data + # in represented in this line has + this_line_num_dim = len(dimensions) + if class_labels: + this_line_num_dim -= 1 + # All dimensions should be included for all series, + # even if they are empty + if this_line_num_dim != num_dimensions: + raise IOError( + "inconsistent number of dimensions. 
" + "Expecting " + str(num_dimensions) + " but have read " + str(this_line_num_dim) + ) + # Process the data for each dimension + for dim in range(0, num_dimensions): + dimension = dimensions[dim].strip() + + if dimension: + data_series = dimension.split(",") + data_series = [float(i) for i in data_series] + instance_list[dim].append(pd.Series(data_series)) + else: + instance_list[dim].append(pd.Series(dtype="object")) + if class_labels: + class_val_list.append(dimensions[num_dimensions].strip()) + line_num += 1 + # Check that the file was not empty + if line_num: + # Check that the file contained both metadata and data + if metadata_started and not ( + has_problem_name_tag + and has_timestamps_tag + and has_univariate_tag + and has_class_labels_tag + and has_data_tag + ): + raise IOError("metadata incomplete") + + elif metadata_started and not data_started: + raise IOError("file contained metadata but no data") + + elif metadata_started and data_started and len(instance_list) == 0: + raise IOError("file contained metadata but no data") + # Create a DataFrame from the data parsed above + data = pd.DataFrame(dtype=np.float32) + for dim in range(0, num_dimensions): + data["dim_" + str(dim)] = instance_list[dim] + # Check if we should return any associated class labels separately + if class_labels: + if return_separate_X_and_y: + return data, np.asarray(class_val_list) + else: + data["class_vals"] = pd.Series(class_val_list) + return data + else: + return data + else: + raise IOError("empty file") + + def get_split_params( split_config: Dict[str, Union[float, List[Union[int, float]]]], context_length: Optional[int] = None, From 377a0806af285c44b388cdd0cb78d3460e0bde70 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:01:08 -0400 Subject: [PATCH 03/14] force int64 for label --- tsfm_public/toolkit/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsfm_public/toolkit/dataset.py 
b/tsfm_public/toolkit/dataset.py index 7c54bdd7..9939e314 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -703,7 +703,7 @@ def __getitem__(self, index): ret = { "past_values": np_to_torch(seq_x), - "target_values": torch.tensor(seq_y), + "target_values": torch.tensor(seq_y, dtype=torch.int64), } if self.datetime_col: ret["timestamp"] = self.timestamps[time_id + self.context_length - 1] From 63fb71fe7e49dc698835f3b1516cad87362fa2c7 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:01:15 -0400 Subject: [PATCH 04/14] add sources --- tsfm_public/toolkit/util.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index d1fe1a9c..0b59cde3 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -358,6 +358,18 @@ def convert_tsf_to_dataframe( replace_missing_vals_with="NaN", value_column_name="series_value", ): + """Read a .tsf file into a pandas dataframe. + + Args: + full_file_path_and_name (_type_): _description_ + replace_missing_vals_with (str, optional): _description_. Defaults to "NaN". + value_column_name (str, optional): _description_. Defaults to "series_value". + + This code adopted from the Monash forecasting repository github: + https://github.com/rakshitha123/TSForecasting/blob/master/utils/data_loader.py + + """ + col_names = [] col_types = [] all_data = {} @@ -511,6 +523,11 @@ class values. If not return_separate_X_and_y then a single DataFrame containing all time-series and (if relevant) a column "class_vals" the associated class values. 
+ + + This code adopted from sktime: + https://github.com/sktime/sktime/blob/v0.30.0/sktime/datasets/_readers_writers/ts.py#L32-L615 + """ # Initialize flags and variables used when parsing the file metadata_started = False From 7a04e06e235dd41a3700c42b66a3d10be52f4a13 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:01:33 -0400 Subject: [PATCH 05/14] updates to handle mask --- tsfm_public/toolkit/dataset.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py index 9939e314..ba0075c7 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -333,7 +333,7 @@ def __getitem__(self, index): time_id = index * self.stride seq_x = self.X[time_id : time_id + self.context_length].values ret = { - "past_values": np_to_torch(seq_x), + "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)), "past_observed_mask": np_to_torch(~np.isnan(seq_x)), } if self.datetime_col: @@ -534,6 +534,7 @@ def __init__( context_length: int = 1, num_workers: int = 1, stride: int = 1, + fill_value: Union[float, int] = 0.0, ): # self.y_cols = y_cols @@ -549,6 +550,7 @@ def __init__( target_columns=target_columns, static_categorical_columns=static_categorical_columns, stride=stride, + fill_value=fill_value, ) self.n_inp = 2 @@ -571,6 +573,7 @@ def __init__( input_columns: List[str] = [], static_categorical_columns: List[str] = [], stride: int = 1, + fill_value: Union[float, int] = 0.0, ): self.target_columns = target_columns self.input_columns = input_columns @@ -590,6 +593,7 @@ def __init__( group_id=group_id, drop_cols=drop_cols, stride=stride, + fill_value=fill_value, ) def __getitem__(self, index): @@ -601,8 +605,9 @@ def __getitem__(self, index): # return _torch(seq_x, seq_y) ret = { - "past_values": np_to_torch(seq_x), - "target_values": np_to_torch(seq_y), + "past_values": 
np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)), + "target_values": np_to_torch(np.nan_to_num(seq_y, nan=self.fill_value)), + "past_observed_mask": np_to_torch(~np.isnan(seq_x)), } if self.datetime_col: ret["timestamp"] = self.timestamps[time_id + self.context_length - 1] @@ -641,6 +646,7 @@ def __init__( context_length: int = 1, num_workers: int = 1, stride: int = 1, + fill_value: Union[float, int] = 0.0, ): super().__init__( data_df=data, @@ -654,6 +660,7 @@ def __init__( label_column=label_column, static_categorical_columns=static_categorical_columns, stride=stride, + fill_value=fill_value, ) self.n_inp = 2 @@ -672,6 +679,7 @@ def __init__( input_columns: List[str] = [], static_categorical_columns: List[str] = [], stride: int = 1, + fill_value: Union[float, int] = 0.0, ): self.label_column = label_column self.input_columns = input_columns @@ -691,6 +699,7 @@ def __init__( group_id=group_id, drop_cols=drop_cols, stride=stride, + fill_value=fill_value, ) def __getitem__(self, index): @@ -702,8 +711,9 @@ def __getitem__(self, index): seq_y = self.y.iloc[time_id + self.context_length - 1].values[0] ret = { - "past_values": np_to_torch(seq_x), - "target_values": torch.tensor(seq_y, dtype=torch.int64), + "past_values": np_to_torch(np.nan_to_num(seq_x, nan=self.fill_value)), + "target_values": torch.tensor(np.nan_to_num(seq_y, nan=self.fill_value), dtype=torch.int64), + "past_observed_mask": np_to_torch(~np.isnan(seq_x)), } if self.datetime_col: ret["timestamp"] = self.timestamps[time_id + self.context_length - 1] From 777f450440171fd2177f1b923b00b85d6e08d411 Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:03:07 -0400 Subject: [PATCH 06/14] Initial fixing of convert_tsfile_to_dataframe --- tsfm_public/toolkit/util.py | 597 +++++------------------------------- 1 file changed, 73 insertions(+), 524 deletions(-) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 
0b59cde3..b5ec8c2c 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -496,532 +496,81 @@ def convert_tsf_to_dataframe( ) -def convert_tsfile_to_dataframe( - full_file_path_and_name, - return_separate_X_and_y=True, - replace_missing_vals_with="NaN", -): - """Load data from a .ts file into a Pandas DataFrame. - Parameters - ---------- - full_file_path_and_name: str - The full pathname of the .ts file to read. - return_separate_X_and_y: bool - true if X and Y values should be returned as separate Data Frames ( - X) and a numpy array (y), false otherwise. - This is only relevant for data that - replace_missing_vals_with: str - The value that missing values in the text file should be replaced - with prior to parsing. - Returns - ------- - DataFrame (default) or ndarray (i - If return_separate_X_and_y then a tuple containing a DataFrame and a - numpy array containing the relevant time-series and corresponding - class values. - DataFrame - If not return_separate_X_and_y then a single DataFrame containing - all time-series and (if relevant) a column "class_vals" the - associated class values. 
- - - This code adopted from sktime: - https://github.com/sktime/sktime/blob/v0.30.0/sktime/datasets/_readers_writers/ts.py#L32-L615 +def convert_tsfile_to_dataframe(path): + + with open(path, "r", encoding="utf-8") as file: + + df = pd.DataFrame() + + timestamps = False + seriesLength = 0 + classification = False + + id = 0 - """ - # Initialize flags and variables used when parsing the file - metadata_started = False - data_started = False - - has_problem_name_tag = False - has_timestamps_tag = False - has_univariate_tag = False - has_class_labels_tag = False - has_data_tag = False - - previous_timestamp_was_int = None - prev_timestamp_was_timestamp = None - num_dimensions = None - is_first_case = True - instance_list = [] - class_val_list = [] - line_num = 0 - # Parse the file - with open(full_file_path_and_name, "r", encoding="utf-8") as file: for line in file: - # Strip white space from start/end of line and change to - # lowercase for use below - line = line.strip().lower() - # Empty lines are valid at any point in a file - if line: - # Check if this line contains metadata - # Please note that even though metadata is stored in this - # function it is not currently published externally - if line.startswith("@problemname"): - # Check that the data has not started - if data_started: - raise IOError("metadata must come before data") - # Check that the associated value is valid - tokens = line.split(" ") - token_len = len(tokens) - if token_len == 1: - raise IOError("problemname tag requires an associated value") - # problem_name = line[len("@problemname") + 1:] - has_problem_name_tag = True - metadata_started = True - elif line.startswith("@timestamps"): - # Check that the data has not started - if data_started: - raise IOError("metadata must come before data") - # Check that the associated value is valid - tokens = line.split(" ") - token_len = len(tokens) - if token_len != 2: - raise IOError("timestamps tag requires an associated Boolean " "value") - elif 
tokens[1] == "true": - timestamps = True - elif tokens[1] == "false": - timestamps = False - else: - raise IOError("invalid timestamps value") - has_timestamps_tag = True - metadata_started = True - elif line.startswith("@univariate"): - # Check that the data has not started - if data_started: - raise IOError("metadata must come before data") - # Check that the associated value is valid - tokens = line.split(" ") - token_len = len(tokens) - if token_len != 2: - raise IOError("univariate tag requires an associated Boolean " "value") - elif tokens[1] == "true": - # univariate = True - pass - elif tokens[1] == "false": - # univariate = False - pass - else: - raise IOError("invalid univariate value") - has_univariate_tag = True - metadata_started = True - elif line.startswith("@classlabel"): - # Check that the data has not started - if data_started: - raise IOError("metadata must come before data") - # Check that the associated value is valid - tokens = line.split(" ") - token_len = len(tokens) - if token_len == 1: - raise IOError("classlabel tag requires an associated Boolean " "value") - if tokens[1] == "true": - class_labels = True - elif tokens[1] == "false": - class_labels = False - else: - raise IOError("invalid classLabel value") - # Check if we have any associated class values - if token_len == 2 and class_labels: - raise IOError("if the classlabel tag is true then class values " "must be supplied") - has_class_labels_tag = True - class_label_list = [token.strip() for token in tokens[2:]] - metadata_started = True - elif line.startswith("@targetlabel"): - if data_started: - raise IOError("metadata must come before data") - tokens = line.split(" ") - token_len = len(tokens) - if token_len == 1: - raise IOError("targetlabel tag requires an associated Boolean value") - if tokens[1] == "true": - class_labels = True - elif tokens[1] == "false": - class_labels = False - else: - raise IOError("invalid targetlabel value") - if token_len > 2: - raise IOError( - 
"targetlabel tag should not be accompanied with info " - "apart from true/false, but found " - f"{tokens}" - ) - has_class_labels_tag = True - metadata_started = True - # Check if this line contains the start of data - elif line.startswith("@data"): - if line != "@data": - raise IOError("data tag should not have an associated value") - if data_started and not metadata_started: - raise IOError("metadata must come before data") - else: - has_data_tag = True - data_started = True - # If the 'data tag has been found then metadata has been - # parsed and data can be loaded - elif data_started: - # Check that a full set of metadata has been provided - if ( - not has_problem_name_tag - or not has_timestamps_tag - or not has_univariate_tag - or not has_class_labels_tag - or not has_data_tag - ): - raise IOError("a full set of metadata has not been provided " "before the data") - # Replace any missing values with the value specified - line = line.replace("?", replace_missing_vals_with) - # Check if we are dealing with data that has timestamps - if timestamps: - # We're dealing with timestamps so cannot just split - # line on ':' as timestamps may contain one - has_another_value = False - has_another_dimension = False - timestamp_for_dim = [] - values_for_dimension = [] - this_line_num_dim = 0 - line_len = len(line) - char_num = 0 - while char_num < line_len: - # Move through any spaces - while char_num < line_len and str.isspace(line[char_num]): - char_num += 1 - # See if there is any more data to read in or if - # we should validate that read thus far - if char_num < line_len: - # See if we have an empty dimension (i.e. 
no - # values) - if line[char_num] == ":": - if len(instance_list) < (this_line_num_dim + 1): - instance_list.append([]) - instance_list[this_line_num_dim].append(pd.Series(dtype="object")) - this_line_num_dim += 1 - has_another_value = False - has_another_dimension = True - timestamp_for_dim = [] - values_for_dimension = [] - char_num += 1 - else: - # Check if we have reached a class label - if line[char_num] != "(" and class_labels: - class_val = line[char_num:].strip() - if class_val not in class_label_list: - raise IOError( - "the class value '" - + class_val - + "' on line " - + str(line_num + 1) - + " is not " - "valid" - ) - class_val_list.append(class_val) - char_num = line_len - has_another_value = False - has_another_dimension = False - timestamp_for_dim = [] - values_for_dimension = [] - else: - # Read in the data contained within - # the next tuple - if line[char_num] != "(" and not class_labels: - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " does " - "not " - "start " - "with a " - "'('" - ) - char_num += 1 - tuple_data = "" - while char_num < line_len and line[char_num] != ")": - tuple_data += line[char_num] - char_num += 1 - if char_num >= line_len or line[char_num] != ")": - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " does " - "not end" - " with a " - "')'" - ) - # Read in any spaces immediately - # after the current tuple - char_num += 1 - while char_num < line_len and str.isspace(line[char_num]): - char_num += 1 - - # Check if there is another value or - # dimension to process after this tuple - if char_num >= line_len: - has_another_value = False - has_another_dimension = False - elif line[char_num] == ",": - has_another_value = True - has_another_dimension = False - elif line[char_num] == ":": - has_another_value = False - has_another_dimension = True - char_num += 1 - # Get the numeric value for the - # tuple by reading from 
the end of - # the tuple data backwards to the - # last comma - last_comma_index = tuple_data.rfind(",") - if last_comma_index == -1: - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " contains a tuple that has " - "no comma inside of it" - ) - try: - value = tuple_data[last_comma_index + 1 :] - value = float(value) - except ValueError: - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " contains a tuple that does " - "not have a valid numeric " - "value" - ) - # Check the type of timestamp that - # we have - timestamp = tuple_data[0:last_comma_index] - try: - timestamp = int(timestamp) - timestamp_is_int = True - timestamp_is_timestamp = False - except ValueError: - timestamp_is_int = False - if not timestamp_is_int: - try: - timestamp = timestamp.strip() - timestamp_is_timestamp = True - except ValueError: - timestamp_is_timestamp = False - # Make sure that the timestamps in - # the file (not just this dimension - # or case) are consistent - if not timestamp_is_timestamp and not timestamp_is_int: - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " contains a tuple that " - "has an invalid timestamp '" + timestamp + "'" - ) - if ( - previous_timestamp_was_int is not None - and previous_timestamp_was_int - and not timestamp_is_int - ): - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " contains tuples where the " - "timestamp format is " - "inconsistent" - ) - if ( - prev_timestamp_was_timestamp is not None - and prev_timestamp_was_timestamp - and not timestamp_is_timestamp - ): - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " contains tuples where the " - "timestamp format is " - "inconsistent" - ) - # Store the values - timestamp_for_dim += [timestamp] - values_for_dimension += [value] - 
# If this was our first tuple then - # we store the type of timestamp we - # had - if prev_timestamp_was_timestamp is None and timestamp_is_timestamp: - prev_timestamp_was_timestamp = True - previous_timestamp_was_int = False - - if previous_timestamp_was_int is None and timestamp_is_int: - prev_timestamp_was_timestamp = False - previous_timestamp_was_int = True - # See if we should add the data for - # this dimension - if not has_another_value: - if len(instance_list) < (this_line_num_dim + 1): - instance_list.append([]) - - if timestamp_is_timestamp: - timestamp_for_dim = pd.DatetimeIndex(timestamp_for_dim) - - instance_list[this_line_num_dim].append( - pd.Series( - index=timestamp_for_dim, - data=values_for_dimension, - ) - ) - this_line_num_dim += 1 - timestamp_for_dim = [] - values_for_dimension = [] - elif has_another_value: - raise IOError( - "dimension " + str(this_line_num_dim + 1) + " on " - "line " + str(line_num + 1) + " ends with a ',' that " - "is not followed by " - "another tuple" - ) - elif has_another_dimension and class_labels: - raise IOError( - "dimension " + str(this_line_num_dim + 1) + " on " - "line " + str(line_num + 1) + " ends with a ':' while " - "it should list a class " - "value" - ) - elif has_another_dimension and not class_labels: - if len(instance_list) < (this_line_num_dim + 1): - instance_list.append([]) - instance_list[this_line_num_dim].append(pd.Series(dtype=np.float32)) - this_line_num_dim += 1 - num_dimensions = this_line_num_dim - # If this is the 1st line of data we have seen - # then note the dimensions - if not has_another_value and not has_another_dimension: - if num_dimensions is None: - num_dimensions = this_line_num_dim - if num_dimensions != this_line_num_dim: - raise IOError( - "line " + str(line_num + 1) + " does not have the " - "same number of " - "dimensions as the " - "previous line of " - "data" - ) - # Check that we are not expecting some more data, - # and if not, store that processed above - if 
has_another_value: - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " ends with a ',' that is " - "not followed by another " - "tuple" - ) - elif has_another_dimension and class_labels: - raise IOError( - "dimension " - + str(this_line_num_dim + 1) - + " on line " - + str(line_num + 1) - + " ends with a ':' while it " - "should list a class value" - ) - elif has_another_dimension and not class_labels: - if len(instance_list) < (this_line_num_dim + 1): - instance_list.append([]) - instance_list[this_line_num_dim].append(pd.Series(dtype="object")) - this_line_num_dim += 1 - num_dimensions = this_line_num_dim - # If this is the 1st line of data we have seen then - # note the dimensions - if not has_another_value and num_dimensions != this_line_num_dim: - raise IOError( - "line " + str(line_num + 1) + " does not have the same " - "number of dimensions as the " - "previous line of data" - ) - # Check if we should have class values, and if so - # that they are contained in those listed in the - # metadata - if class_labels and len(class_val_list) == 0: - raise IOError("the cases have no associated class values") - else: - dimensions = line.split(":") - # If first row then note the number of dimensions ( - # that must be the same for all cases) - if is_first_case: - num_dimensions = len(dimensions) - if class_labels: - num_dimensions -= 1 - for _dim in range(0, num_dimensions): - instance_list.append([]) - is_first_case = False - # See how many dimensions that the case whose data - # in represented in this line has - this_line_num_dim = len(dimensions) - if class_labels: - this_line_num_dim -= 1 - # All dimensions should be included for all series, - # even if they are empty - if this_line_num_dim != num_dimensions: - raise IOError( - "inconsistent number of dimensions. 
" - "Expecting " + str(num_dimensions) + " but have read " + str(this_line_num_dim) - ) - # Process the data for each dimension - for dim in range(0, num_dimensions): - dimension = dimensions[dim].strip() - - if dimension: - data_series = dimension.split(",") - data_series = [float(i) for i in data_series] - instance_list[dim].append(pd.Series(data_series)) - else: - instance_list[dim].append(pd.Series(dtype="object")) - if class_labels: - class_val_list.append(dimensions[num_dimensions].strip()) - line_num += 1 - # Check that the file was not empty - if line_num: - # Check that the file contained both metadata and data - if metadata_started and not ( - has_problem_name_tag - and has_timestamps_tag - and has_univariate_tag - and has_class_labels_tag - and has_data_tag - ): - raise IOError("metadata incomplete") - - elif metadata_started and not data_started: - raise IOError("file contained metadata but no data") - - elif metadata_started and data_started and len(instance_list) == 0: - raise IOError("file contained metadata but no data") - # Create a DataFrame from the data parsed above - data = pd.DataFrame(dtype=np.float32) - for dim in range(0, num_dimensions): - data["dim_" + str(dim)] = instance_list[dim] - # Check if we should return any associated class labels separately - if class_labels: - if return_separate_X_and_y: - return data, np.asarray(class_val_list) - else: - data["class_vals"] = pd.Series(class_val_list) - return data - else: - return data - else: - raise IOError("empty file") + if line.startswith("#"): + continue + + elif line.startswith("@timestamps"): + timestamp_line = line.split() + if timestamp_line[-1].strip()=='true': + timestamps = True + + elif line.startswith("@seriesLength") or line.startswith("@serieslength"): + seriesLength_line = line.split() + seriesLength = int(seriesLength_line[-1].strip()) + + elif line.startswith("@classLabel"): + classification = True + + elif not line.startswith("@"): + + sub_df = pd.DataFrame() + + if not 
timestamps: + split_line = line.split(":") + split_line = [element.split(",") for element in split_line] + + sub_df['id'] = [id]*seriesLength + + for i, column in enumerate(split_line[:-1]): + sub_df[f'value_{i}'] = column + + else: + split_line = line.split("):") + split_line = [element.split("),(") for element in split_line] + + sub_df['id'] = [id]*seriesLength + + for i, column in enumerate(split_line[:-1]): + timestamp_column = [value.split(",")[0].strip("()") for value in column] + value_column = [value.split(",")[1] for value in column] + sub_df[f'value_{i}'] = value_column + + sub_df.insert(0, "timestamp", timestamp_column) + + target = split_line[-1][0].strip() + sub_df['target'] = [target]*seriesLength + + df = pd.concat([df, sub_df]) + + id += 1 + + ## convert targets to floats or integers + ## non-numeric classification labels will be converted to integers as well + try: + df['target'] = pd.to_numeric(df['target']) + except: + string_labels = df['target'].unique() + label_to_int_map = {str_label: num for num, str_label in enumerate(string_labels)} + df['target'] = df['target'].map(label_to_int_map) + + ## make sure labels are 0 indexed if classification + if classification and df['target'].min() != 0: + df['target'] = df['target'] - 1 + + return df def get_split_params( From 55926e331d91442fb95c3adbf01bb6af80ac128f Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Tue, 2 Jul 2024 21:00:10 -0400 Subject: [PATCH 07/14] fixed timestamps --- tsfm_public/toolkit/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index b5ec8c2c..a1ccd4d3 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -544,7 +544,7 @@ def convert_tsfile_to_dataframe(path): sub_df['id'] = [id]*seriesLength for i, column in enumerate(split_line[:-1]): - timestamp_column = [value.split(",")[0].strip("()") for value in column] + timestamp_column 
= [pd.to_datetime(value.split(",")[0].strip("()")) for value in column] value_column = [value.split(",")[1] for value in column] sub_df[f'value_{i}'] = value_column From 8584e9f41f81808f33838f808adcf754460278a2 Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Mon, 8 Jul 2024 15:49:33 -0400 Subject: [PATCH 08/14] Fixing bug in original --- tsfm_public/toolkit/util.py | 588 +++++++++++++++++++++++++++++++----- 1 file changed, 515 insertions(+), 73 deletions(-) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index a1ccd4d3..79d666ba 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -496,81 +496,523 @@ def convert_tsf_to_dataframe( ) -def convert_tsfile_to_dataframe(path): - - with open(path, "r", encoding="utf-8") as file: - - df = pd.DataFrame() - - timestamps = False - seriesLength = 0 - classification = False - - id = 0 +def convert_tsfile_to_dataframe( + full_file_path_and_name, + return_separate_X_and_y=True, + replace_missing_vals_with="NaN", +): + """Load data from a .ts file into a Pandas DataFrame. + Parameters + ---------- + full_file_path_and_name: str + The full pathname of the .ts file to read. + return_separate_X_and_y: bool + true if X and Y values should be returned as separate Data Frames ( + X) and a numpy array (y), false otherwise. + This is only relevant for data that + replace_missing_vals_with: str + The value that missing values in the text file should be replaced + with prior to parsing. + Returns + ------- + DataFrame (default) or ndarray (i + If return_separate_X_and_y then a tuple containing a DataFrame and a + numpy array containing the relevant time-series and corresponding + class values. + DataFrame + If not return_separate_X_and_y then a single DataFrame containing + all time-series and (if relevant) a column "class_vals" the + associated class values. 
+ + + This code adopted from sktime: + https://github.com/sktime/sktime/blob/v0.30.0/sktime/datasets/_readers_writers/ts.py#L32-L615 + """ + # Initialize flags and variables used when parsing the file + metadata_started = False + data_started = False + + has_problem_name_tag = False + has_timestamps_tag = False + has_univariate_tag = False + has_class_labels_tag = False + has_data_tag = False + + previous_timestamp_was_int = None + prev_timestamp_was_timestamp = None + num_dimensions = None + is_first_case = True + instance_list = [] + class_val_list = [] + line_num = 0 + # Parse the file + with open(full_file_path_and_name, "r", encoding="utf-8") as file: for line in file: - if line.startswith("#"): - continue - - elif line.startswith("@timestamps"): - timestamp_line = line.split() - if timestamp_line[-1].strip()=='true': - timestamps = True - - elif line.startswith("@seriesLength") or line.startswith("@serieslength"): - seriesLength_line = line.split() - seriesLength = int(seriesLength_line[-1].strip()) - - elif line.startswith("@classLabel"): - classification = True - - elif not line.startswith("@"): - - sub_df = pd.DataFrame() - - if not timestamps: - split_line = line.split(":") - split_line = [element.split(",") for element in split_line] - - sub_df['id'] = [id]*seriesLength - - for i, column in enumerate(split_line[:-1]): - sub_df[f'value_{i}'] = column - - else: - split_line = line.split("):") - split_line = [element.split("),(") for element in split_line] - - sub_df['id'] = [id]*seriesLength - - for i, column in enumerate(split_line[:-1]): - timestamp_column = [pd.to_datetime(value.split(",")[0].strip("()")) for value in column] - value_column = [value.split(",")[1] for value in column] - sub_df[f'value_{i}'] = value_column - - sub_df.insert(0, "timestamp", timestamp_column) - - target = split_line[-1][0].strip() - sub_df['target'] = [target]*seriesLength - - df = pd.concat([df, sub_df]) - - id += 1 - - ## convert targets to floats or integers - ## 
non-numeric classification labels will be converted to integers as well - try: - df['target'] = pd.to_numeric(df['target']) - except: - string_labels = df['target'].unique() - label_to_int_map = {str_label: num for num, str_label in enumerate(string_labels)} - df['target'] = df['target'].map(label_to_int_map) - - ## make sure labels are 0 indexed if classification - if classification and df['target'].min() != 0: - df['target'] = df['target'] - 1 - - return df + # Strip white space from start/end of line and change to + # lowercase for use below + line = line.strip().lower() + # Empty lines are valid at any point in a file + if line: + # Check if this line contains metadata + # Please note that even though metadata is stored in this + # function it is not currently published externally + if line.startswith("@problemname"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len == 1: + raise IOError("problemname tag requires an associated value") + # problem_name = line[len("@problemname") + 1:] + has_problem_name_tag = True + metadata_started = True + elif line.startswith("@timestamps"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len != 2: + raise IOError("timestamps tag requires an associated Boolean " "value") + elif tokens[1] == "true": + timestamps = True + elif tokens[1] == "false": + timestamps = False + else: + raise IOError("invalid timestamps value") + has_timestamps_tag = True + metadata_started = True + elif line.startswith("@univariate"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + 
token_len = len(tokens) + if token_len != 2: + raise IOError("univariate tag requires an associated Boolean " "value") + elif tokens[1] == "true": + # univariate = True + pass + elif tokens[1] == "false": + # univariate = False + pass + else: + raise IOError("invalid univariate value") + has_univariate_tag = True + metadata_started = True + elif line.startswith("@classlabel"): + # Check that the data has not started + if data_started: + raise IOError("metadata must come before data") + # Check that the associated value is valid + tokens = line.split(" ") + token_len = len(tokens) + if token_len == 1: + raise IOError("classlabel tag requires an associated Boolean " "value") + if tokens[1] == "true": + class_labels = True + elif tokens[1] == "false": + class_labels = False + else: + raise IOError("invalid classLabel value") + # Check if we have any associated class values + if token_len == 2 and class_labels: + raise IOError("if the classlabel tag is true then class values " "must be supplied") + has_class_labels_tag = True + class_label_list = [token.strip() for token in tokens[2:]] + metadata_started = True + elif line.startswith("@targetlabel"): + if data_started: + raise IOError("metadata must come before data") + tokens = line.split(" ") + token_len = len(tokens) + if token_len == 1: + raise IOError("targetlabel tag requires an associated Boolean value") + if tokens[1] == "true": + class_labels = True + elif tokens[1] == "false": + class_labels = False + else: + raise IOError("invalid targetlabel value") + if token_len > 2: + raise IOError( + "targetlabel tag should not be accompanied with info " + "apart from true/false, but found " + f"{tokens}" + ) + has_class_labels_tag = True + metadata_started = True + # Check if this line contains the start of data + elif line.startswith("@data"): + if line != "@data": + raise IOError("data tag should not have an associated value") + if data_started and not metadata_started: + raise IOError("metadata must come before 
data") + else: + has_data_tag = True + data_started = True + # If the 'data tag has been found then metadata has been + # parsed and data can be loaded + elif data_started: + # Check that a full set of metadata has been provided + if ( + not has_problem_name_tag + or not has_timestamps_tag + or not has_univariate_tag + or not has_class_labels_tag + or not has_data_tag + ): + raise IOError("a full set of metadata has not been provided " "before the data") + # Replace any missing values with the value specified + line = line.replace("?", replace_missing_vals_with) + # Check if we are dealing with data that has timestamps + if timestamps: + # We're dealing with timestamps so cannot just split + # line on ':' as timestamps may contain one + has_another_value = False + has_another_dimension = False + timestamp_for_dim = [] + values_for_dimension = [] + this_line_num_dim = 0 + line_len = len(line) + char_num = 0 + while char_num < line_len: + # Move through any spaces + while char_num < line_len and str.isspace(line[char_num]): + char_num += 1 + # See if there is any more data to read in or if + # we should validate that read thus far + if char_num < line_len: + # See if we have an empty dimension (i.e. 
no + # values) + if line[char_num] == ":": + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + instance_list[this_line_num_dim].append(pd.Series(dtype="object")) + this_line_num_dim += 1 + has_another_value = False + has_another_dimension = True + timestamp_for_dim = [] + values_for_dimension = [] + char_num += 1 + else: + # Check if we have reached a class label + if line[char_num] != "(" and class_labels: + class_val = line[char_num:].strip() + class_val_list.append(class_val) + char_num = line_len + has_another_value = False + has_another_dimension = False + timestamp_for_dim = [] + values_for_dimension = [] + else: + # Read in the data contained within + # the next tuple + if line[char_num] != "(" and not class_labels: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " does " + "not " + "start " + "with a " + "'('" + ) + char_num += 1 + tuple_data = "" + while char_num < line_len and line[char_num] != ")": + tuple_data += line[char_num] + char_num += 1 + if char_num >= line_len or line[char_num] != ")": + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " does " + "not end" + " with a " + "')'" + ) + # Read in any spaces immediately + # after the current tuple + char_num += 1 + while char_num < line_len and str.isspace(line[char_num]): + char_num += 1 + + # Check if there is another value or + # dimension to process after this tuple + if char_num >= line_len: + has_another_value = False + has_another_dimension = False + elif line[char_num] == ",": + has_another_value = True + has_another_dimension = False + elif line[char_num] == ":": + has_another_value = False + has_another_dimension = True + char_num += 1 + # Get the numeric value for the + # tuple by reading from the end of + # the tuple data backwards to the + # last comma + last_comma_index = tuple_data.rfind(",") + if last_comma_index == -1: + raise IOError( + 
"dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains a tuple that has " + "no comma inside of it" + ) + try: + value = tuple_data[last_comma_index + 1 :] + value = float(value) + except ValueError: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains a tuple that does " + "not have a valid numeric " + "value" + ) + # Check the type of timestamp that + # we have + timestamp = tuple_data[0:last_comma_index] + try: + timestamp = int(timestamp) + timestamp_is_int = True + timestamp_is_timestamp = False + except ValueError: + timestamp_is_int = False + if not timestamp_is_int: + try: + timestamp = timestamp.strip() + timestamp_is_timestamp = True + except ValueError: + timestamp_is_timestamp = False + # Make sure that the timestamps in + # the file (not just this dimension + # or case) are consistent + if not timestamp_is_timestamp and not timestamp_is_int: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains a tuple that " + "has an invalid timestamp '" + timestamp + "'" + ) + if ( + previous_timestamp_was_int is not None + and previous_timestamp_was_int + and not timestamp_is_int + ): + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains tuples where the " + "timestamp format is " + "inconsistent" + ) + if ( + prev_timestamp_was_timestamp is not None + and prev_timestamp_was_timestamp + and not timestamp_is_timestamp + ): + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " contains tuples where the " + "timestamp format is " + "inconsistent" + ) + # Store the values + timestamp_for_dim += [timestamp] + values_for_dimension += [value] + # If this was our first tuple then + # we store the type of timestamp we + # had + if prev_timestamp_was_timestamp is None and timestamp_is_timestamp: + 
prev_timestamp_was_timestamp = True + previous_timestamp_was_int = False + + if previous_timestamp_was_int is None and timestamp_is_int: + prev_timestamp_was_timestamp = False + previous_timestamp_was_int = True + # See if we should add the data for + # this dimension + if not has_another_value: + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + + if timestamp_is_timestamp: + timestamp_for_dim = pd.DatetimeIndex(timestamp_for_dim) + + instance_list[this_line_num_dim].append( + pd.Series( + index=timestamp_for_dim, + data=values_for_dimension, + ) + ) + this_line_num_dim += 1 + timestamp_for_dim = [] + values_for_dimension = [] + elif has_another_value: + raise IOError( + "dimension " + str(this_line_num_dim + 1) + " on " + "line " + str(line_num + 1) + " ends with a ',' that " + "is not followed by " + "another tuple" + ) + elif has_another_dimension and class_labels: + raise IOError( + "dimension " + str(this_line_num_dim + 1) + " on " + "line " + str(line_num + 1) + " ends with a ':' while " + "it should list a class " + "value" + ) + elif has_another_dimension and not class_labels: + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + instance_list[this_line_num_dim].append(pd.Series(dtype=np.float32)) + this_line_num_dim += 1 + num_dimensions = this_line_num_dim + # If this is the 1st line of data we have seen + # then note the dimensions + if not has_another_value and not has_another_dimension: + if num_dimensions is None: + num_dimensions = this_line_num_dim + if num_dimensions != this_line_num_dim: + raise IOError( + "line " + str(line_num + 1) + " does not have the " + "same number of " + "dimensions as the " + "previous line of " + "data" + ) + # Check that we are not expecting some more data, + # and if not, store that processed above + if has_another_value: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " ends with a ',' that is " + "not 
followed by another " + "tuple" + ) + elif has_another_dimension and class_labels: + raise IOError( + "dimension " + + str(this_line_num_dim + 1) + + " on line " + + str(line_num + 1) + + " ends with a ':' while it " + "should list a class value" + ) + elif has_another_dimension and not class_labels: + if len(instance_list) < (this_line_num_dim + 1): + instance_list.append([]) + instance_list[this_line_num_dim].append(pd.Series(dtype="object")) + this_line_num_dim += 1 + num_dimensions = this_line_num_dim + # If this is the 1st line of data we have seen then + # note the dimensions + if not has_another_value and num_dimensions != this_line_num_dim: + raise IOError( + "line " + str(line_num + 1) + " does not have the same " + "number of dimensions as the " + "previous line of data" + ) + # Check if we should have class values, and if so + # that they are contained in those listed in the + # metadata + if class_labels and len(class_val_list) == 0: + raise IOError("the cases have no associated class values") + else: + dimensions = line.split(":") + # If first row then note the number of dimensions ( + # that must be the same for all cases) + if is_first_case: + num_dimensions = len(dimensions) + if class_labels: + num_dimensions -= 1 + for _dim in range(0, num_dimensions): + instance_list.append([]) + is_first_case = False + # See how many dimensions that the case whose data + # in represented in this line has + this_line_num_dim = len(dimensions) + if class_labels: + this_line_num_dim -= 1 + # All dimensions should be included for all series, + # even if they are empty + if this_line_num_dim != num_dimensions: + raise IOError( + "inconsistent number of dimensions. 
" + "Expecting " + str(num_dimensions) + " but have read " + str(this_line_num_dim) + ) + # Process the data for each dimension + for dim in range(0, num_dimensions): + dimension = dimensions[dim].strip() + + if dimension: + data_series = dimension.split(",") + data_series = [float(i) for i in data_series] + instance_list[dim].append(pd.Series(data_series)) + else: + instance_list[dim].append(pd.Series(dtype="object")) + if class_labels: + class_val_list.append(dimensions[num_dimensions].strip()) + line_num += 1 + # Check that the file was not empty + if line_num: + # Check that the file contained both metadata and data + if metadata_started and not ( + has_problem_name_tag + and has_timestamps_tag + and has_univariate_tag + and has_class_labels_tag + and has_data_tag + ): + raise IOError("metadata incomplete") + + elif metadata_started and not data_started: + raise IOError("file contained metadata but no data") + + elif metadata_started and data_started and len(instance_list) == 0: + raise IOError("file contained metadata but no data") + # Create a DataFrame from the data parsed above + data = pd.DataFrame(dtype=np.float32) + for dim in range(0, num_dimensions): + data["dim_" + str(dim)] = instance_list[dim] + # Check if we should return any associated class labels separately + if class_labels: + if return_separate_X_and_y: + return data, np.asarray(class_val_list) + else: + data["class_vals"] = pd.Series(class_val_list) + return data + else: + return data + else: + raise IOError("empty file") def get_split_params( From a12d21cdd875a9f3f124901740716f6a8a204a18 Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Tue, 9 Jul 2024 09:43:44 -0400 Subject: [PATCH 09/14] convert_tsfile initial implementation --- tsfm_public/toolkit/util.py | 50 +++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 79d666ba..2d629123 100644 --- 
a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -1192,3 +1192,53 @@ def join_list_without_repeat(*lists: List[List[Any]]) -> List[Any]: final = final + [item for item in alist if item not in final_set] final_set = set(final) return final + +def convert_tsfile(filename: str, classification=False) -> pd.DataFrame: + """Converts a .ts file into a pandas dataframe. + Returns the result in canonical multi-time series format, with an ID column, and timestamp. + + Args: + filename (str): Input file name. + classification (bool): classification dataset + + Returns: + pd.DataFrame: Converted time series + """ + + final_df = pd.DataFrame() + + df = convert_tsfile_to_dataframe(filename, return_separate_X_and_y=False) + + rows, columns = df.shape + + for i in range(rows): + temp_df = pd.DataFrame() + for j in range(columns): + if j!=columns-1: + series_to_df = df.iloc[i].iloc[j].to_frame().reset_index() + if j==0: + repeat = len(series_to_df) + if type(series_to_df['index'][0])==pd.Timestamp: ## include timestamp columns if data includes timestamps + temp_df['timestamp'] = series_to_df['index'] + temp_df['id'] = [i]*repeat + temp_df[f'value_{j}'] = series_to_df[0] + else: + target = df.iloc[i].iloc[j] + temp_df['target'] = [target]*repeat + + final_df = pd.concat([final_df, temp_df],ignore_index=True) + + ## convert targets to floats or integers + ## non-numeric classification labels will be converted to integers as well + try: + final_df['target'] = pd.to_numeric(final_df['target']) + except: + string_labels = final_df['target'].unique() + label_to_int_map = {str_label: num for num, str_label in enumerate(string_labels)} + final_df['target'] = final_df['target'].map(label_to_int_map) + + ## make sure labels are 0 indexed if classification + if classification and final_df['target'].min() != 0: + final_df['target'] = final_df['target'] - 1 + + return final_df From 4d3296070853f97ca84219e552f37da381175b35 Mon Sep 17 00:00:00 2001 From: hirokimii 
<145586445+hirokimii@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:12:13 -0400 Subject: [PATCH 10/14] regression timestamp fix --- tsfm_public/toolkit/util.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 2d629123..2e11a0a8 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -647,6 +647,7 @@ class values. ) has_class_labels_tag = True metadata_started = True + regression = True # Check if this line contains the start of data elif line.startswith("@data"): if line != "@data": @@ -704,6 +705,16 @@ class values. # Check if we have reached a class label if line[char_num] != "(" and class_labels: class_val = line[char_num:].strip() + if not regression: + if class_val not in class_label_list: + raise IOError( + "the class value '" + + class_val + + "' on line " + + str(line_num + 1) + + " is not " + "valid" + ) class_val_list.append(class_val) char_num = line_len has_another_value = False From d6310e01cddcf41e21c42cc5226f1ff30d933d4c Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:22:48 -0400 Subject: [PATCH 11/14] classification_dataset bool --- tsfm_public/toolkit/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 2e11a0a8..b305c80b 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -626,6 +626,7 @@ class values. has_class_labels_tag = True class_label_list = [token.strip() for token in tokens[2:]] metadata_started = True + classification_dataset = True elif line.startswith("@targetlabel"): if data_started: raise IOError("metadata must come before data") @@ -647,7 +648,6 @@ class values. 
) has_class_labels_tag = True metadata_started = True - regression = True # Check if this line contains the start of data elif line.startswith("@data"): if line != "@data": @@ -705,7 +705,7 @@ class values. # Check if we have reached a class label if line[char_num] != "(" and class_labels: class_val = line[char_num:].strip() - if not regression: + if classification_dataset: if class_val not in class_label_list: raise IOError( "the class value '" From f64ba7707c9416b51e2bf9504b0d5cdcb135c04e Mon Sep 17 00:00:00 2001 From: hirokimii <145586445+hirokimii@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:25:05 -0400 Subject: [PATCH 12/14] small fix --- tsfm_public/toolkit/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index b305c80b..29fb7cf0 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -648,6 +648,7 @@ class values. ) has_class_labels_tag = True metadata_started = True + classification_dataset = False # Check if this line contains the start of data elif line.startswith("@data"): if line != "@data": From beb7b9d5bec966406a30c571bdfca3af2244b5a4 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:18:20 -0400 Subject: [PATCH 13/14] remove classification specific steps --- tsfm_public/toolkit/util.py | 56 ++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 29fb7cf0..14ce9261 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -1205,20 +1205,19 @@ def join_list_without_repeat(*lists: List[List[Any]]) -> List[Any]: final_set = set(final) return final -def convert_tsfile(filename: str, classification=False) -> pd.DataFrame: + +def convert_tsfile(filename: str) -> pd.DataFrame: """Converts a .ts file into a pandas dataframe. 
Returns the result in canonical multi-time series format, with an ID column, and timestamp. Args: filename (str): Input file name. - classification (bool): classification dataset Returns: pd.DataFrame: Converted time series """ - - final_df = pd.DataFrame() + dfs = [] df = convert_tsfile_to_dataframe(filename, return_separate_X_and_y=False) rows, columns = df.shape @@ -1226,31 +1225,36 @@ def convert_tsfile(filename: str, classification=False) -> pd.DataFrame: for i in range(rows): temp_df = pd.DataFrame() for j in range(columns): - if j!=columns-1: + if j != columns - 1: series_to_df = df.iloc[i].iloc[j].to_frame().reset_index() - if j==0: + if j == 0: repeat = len(series_to_df) - if type(series_to_df['index'][0])==pd.Timestamp: ## include timestamp columns if data includes timestamps - temp_df['timestamp'] = series_to_df['index'] - temp_df['id'] = [i]*repeat - temp_df[f'value_{j}'] = series_to_df[0] + if ( + type(series_to_df["index"][0]) == pd.Timestamp + ): ## include timestamp columns if data includes timestamps + temp_df["timestamp"] = series_to_df["index"] + temp_df["id"] = [i] * repeat + temp_df[f"value_{j}"] = series_to_df[0] else: target = df.iloc[i].iloc[j] - temp_df['target'] = [target]*repeat - - final_df = pd.concat([final_df, temp_df],ignore_index=True) - - ## convert targets to floats or integers - ## non-numeric classification labels will be converted to integers as well - try: - final_df['target'] = pd.to_numeric(final_df['target']) - except: - string_labels = final_df['target'].unique() - label_to_int_map = {str_label: num for num, str_label in enumerate(string_labels)} - final_df['target'] = final_df['target'].map(label_to_int_map) - - ## make sure labels are 0 indexed if classification - if classification and final_df['target'].min() != 0: - final_df['target'] = final_df['target'] - 1 + temp_df["target"] = [target] * repeat + + dfs.append(temp_df) + + final_df = pd.concat(dfs, ignore_index=True) + + # to be moved to a preprocessor + # ## 
convert targets to floats or integers + # ## non-numeric classification labels will be converted to integers as well + # try: + # final_df["target"] = pd.to_numeric(final_df["target"]) + # except KeyError: + # string_labels = final_df["target"].unique() + # label_to_int_map = {str_label: num for num, str_label in enumerate(string_labels)} + # final_df["target"] = final_df["target"].map(label_to_int_map) + + # ## make sure labels are 0 indexed if classification + # if classification and final_df["target"].min() != 0: + # final_df["target"] = final_df["target"] - 1 return final_df From 6778707eb229fe5965bf9796af06c177677254d4 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 7 Aug 2024 16:10:42 -0400 Subject: [PATCH 14/14] add tests, update tsfile conversion --- tests/toolkit/test_util.py | 31 ++++++++++++++++++++++++++++++- tsfm_public/toolkit/util.py | 37 ++++++++++++++++++++----------------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/tests/toolkit/test_util.py b/tests/toolkit/test_util.py index ac17a474..ece0faad 100644 --- a/tests/toolkit/test_util.py +++ b/tests/toolkit/test_util.py @@ -1,9 +1,11 @@ """Tests for util functions""" +import tempfile + import pandas as pd import pytest -from tsfm_public.toolkit.util import convert_to_univariate, get_split_params, train_test_split +from tsfm_public.toolkit.util import convert_to_univariate, convert_tsfile, get_split_params, train_test_split split_cases = [ @@ -73,3 +75,30 @@ def test_convert_to_univariate(ts_data): df_uni = convert_to_univariate( ts_data, timestamp_column=timestamp_column, id_columns=id_columns, target_columns=target_columns ) + + +def test_convert_tsfile(): + data = """#Test +# +#The classes are +#1. one +#2. two +#3. 
three +@problemName test +@timeStamps false +@missing false +@univariate true +@equalLength true +@seriesLength 5 +@classLabel true 1 2 3 +@data +1,2,3,4,5:1 +10,20,30,40,50:2 +11,12,13,14,15:3 +""" + with tempfile.NamedTemporaryFile() as t: + t.write(data.encode("utf-8")) + t.flush() + df = convert_tsfile(t.name) + + assert df.shape == (15, 3) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 14ce9261..d31d4dfe 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -1215,29 +1215,32 @@ def convert_tsfile(filename: str) -> pd.DataFrame: Returns: pd.DataFrame: Converted time series + + + To do: + - address renaming of columns + - check that we catch all timestamp column types + """ dfs = [] df = convert_tsfile_to_dataframe(filename, return_separate_X_and_y=False) - rows, columns = df.shape + # rows, columns = df.shape + value_columns = [c for c in df.columns if c != "class_vals"] - for i in range(rows): - temp_df = pd.DataFrame() - for j in range(columns): - if j != columns - 1: - series_to_df = df.iloc[i].iloc[j].to_frame().reset_index() - if j == 0: - repeat = len(series_to_df) - if ( - type(series_to_df["index"][0]) == pd.Timestamp - ): ## include timestamp columns if data includes timestamps - temp_df["timestamp"] = series_to_df["index"] - temp_df["id"] = [i] * repeat - temp_df[f"value_{j}"] = series_to_df[0] - else: - target = df.iloc[i].iloc[j] - temp_df["target"] = [target] * repeat + for row in df.itertuples(): + l = len(row.dim_0) + temp_df = pd.DataFrame({"id": [row.Index] * l}) + + for j, c in enumerate(value_columns): + c_data = getattr(row, c) + if isinstance(c_data.index, pd.Timestamp) and "timestamp" not in temp_df.columns: + ## include timestamp columns if data includes timestamps + temp_df["timestamp"] = c_data.index + temp_df[f"value_{j}"] = c_data.values + + temp_df["target"] = row.class_vals dfs.append(temp_df)