diff --git a/tests/toolkit/test_util.py b/tests/toolkit/test_util.py index ac17a474..ece0faad 100644 --- a/tests/toolkit/test_util.py +++ b/tests/toolkit/test_util.py @@ -1,9 +1,11 @@ """Tests for util functions""" +import tempfile + import pandas as pd import pytest -from tsfm_public.toolkit.util import convert_to_univariate, get_split_params, train_test_split +from tsfm_public.toolkit.util import convert_to_univariate, convert_tsfile, get_split_params, train_test_split split_cases = [ @@ -73,3 +75,30 @@ def test_convert_to_univariate(ts_data): df_uni = convert_to_univariate( ts_data, timestamp_column=timestamp_column, id_columns=id_columns, target_columns=target_columns ) + + +def test_convert_tsfile(): + data = """#Test +# +#The classes are +#1. one +#2. two +#3. three +@problemName test +@timeStamps false +@missing false +@univariate true +@equalLength true +@seriesLength 5 +@classLabel true 1 2 3 +@data +1,2,3,4,5:1 +10,20,30,40,50:2 +11,12,13,14,15:3 +""" + with tempfile.NamedTemporaryFile() as t: + t.write(data.encode("utf-8")) + t.flush() + df = convert_tsfile(t.name) + + assert df.shape == (15, 3) diff --git a/tsfm_public/toolkit/util.py b/tsfm_public/toolkit/util.py index 14ce9261..d31d4dfe 100644 --- a/tsfm_public/toolkit/util.py +++ b/tsfm_public/toolkit/util.py @@ -1215,29 +1215,32 @@ def convert_tsfile(filename: str) -> pd.DataFrame: Returns: pd.DataFrame: Converted time series + + + To do: + - address renaming of columns + - check that we catch all timestamp column types + """ dfs = [] df = convert_tsfile_to_dataframe(filename, return_separate_X_and_y=False) - rows, columns = df.shape + # rows, columns = df.shape + value_columns = [c for c in df.columns if c != "class_vals"] - for i in range(rows): - temp_df = pd.DataFrame() - for j in range(columns): - if j != columns - 1: - series_to_df = df.iloc[i].iloc[j].to_frame().reset_index() - if j == 0: - repeat = len(series_to_df) - if ( - type(series_to_df["index"][0]) == pd.Timestamp - ): ## include timestamp columns if data includes timestamps - temp_df["timestamp"] = series_to_df["index"] - temp_df["id"] = [i] * repeat - temp_df[f"value_{j}"] = series_to_df[0] - else: - target = df.iloc[i].iloc[j] - temp_df["target"] = [target] * repeat + for row in df.itertuples(): + l = len(row.dim_0) + temp_df = pd.DataFrame({"id": [row.Index] * l}) + + for j, c in enumerate(value_columns): + c_data = getattr(row, c) + if isinstance(c_data.index, pd.Timestamp) and "timestamp" not in temp_df.columns: + ## include timestamp columns if data includes timestamps + temp_df["timestamp"] = c_data.index + temp_df[f"value_{j}"] = c_data.values + + temp_df["target"] = row.class_vals dfs.append(temp_df)