Skip to content

Commit

Permalink
add tests, update tsfile conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Aug 7, 2024
1 parent beb7b9d commit 6778707
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
31 changes: 30 additions & 1 deletion tests/toolkit/test_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for util functions"""

import tempfile

import pandas as pd
import pytest

from tsfm_public.toolkit.util import convert_to_univariate, get_split_params, train_test_split
from tsfm_public.toolkit.util import convert_to_univariate, convert_tsfile, get_split_params, train_test_split


split_cases = [
Expand Down Expand Up @@ -73,3 +75,30 @@ def test_convert_to_univariate(ts_data):
df_uni = convert_to_univariate(
ts_data, timestamp_column=timestamp_column, id_columns=id_columns, target_columns=target_columns
)


def test_convert_tsfile():
data = """#Test
#
#The classes are
#1. one
#2. two
#3. three
@problemName test
@timeStamps false
@missing false
@univariate true
@equalLength true
@seriesLength 5
@classLabel true 1 2 3
@data
1,2,3,4,5:1
10,20,30,40,50:2
11,12,13,14,15:3
"""
with tempfile.NamedTemporaryFile() as t:
t.write(data.encode("utf-8"))
t.flush()
df = convert_tsfile(t.name)

assert df.shape == (15, 3)
37 changes: 20 additions & 17 deletions tsfm_public/toolkit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,29 +1215,32 @@ def convert_tsfile(filename: str) -> pd.DataFrame:
Returns:
pd.DataFrame: Converted time series
To do:
- address renaming of columns
- check that we catch all timestamp column types
"""

dfs = []
df = convert_tsfile_to_dataframe(filename, return_separate_X_and_y=False)

rows, columns = df.shape
# rows, columns = df.shape
value_columns = [c for c in df.columns if c != "class_vals"]

for i in range(rows):
temp_df = pd.DataFrame()
for j in range(columns):
if j != columns - 1:
series_to_df = df.iloc[i].iloc[j].to_frame().reset_index()
if j == 0:
repeat = len(series_to_df)
if (
type(series_to_df["index"][0]) == pd.Timestamp
): ## include timestamp columns if data includes timestamps
temp_df["timestamp"] = series_to_df["index"]
temp_df["id"] = [i] * repeat
temp_df[f"value_{j}"] = series_to_df[0]
else:
target = df.iloc[i].iloc[j]
temp_df["target"] = [target] * repeat
for row in df.itertuples():
l = len(row.dim_0)
temp_df = pd.DataFrame({"id": [row.Index] * l})

for j, c in enumerate(value_columns):
c_data = getattr(row, c)
if isinstance(c_data.index, pd.Timestamp) and "timestamp" not in temp_df.columns:
## include timestamp columns if data includes timestamps
temp_df["timestamp"] = c_data.index
temp_df[f"value_{j}"] = c_data.values

temp_df["target"] = row.class_vals

dfs.append(temp_df)

Expand Down

0 comments on commit 6778707

Please sign in to comment.