diff --git a/tests/toolkit/test_time_series_preprocessor.py b/tests/toolkit/test_time_series_preprocessor.py index 9c052fec..5a813c77 100644 --- a/tests/toolkit/test_time_series_preprocessor.py +++ b/tests/toolkit/test_time_series_preprocessor.py @@ -212,6 +212,18 @@ def test_create_timestamps(): 2, [103.5, 107.0], ), + ( + pd.Timestamp(2021, 12, 31), + "QE", + None, + 4, + [ + pd.Timestamp(2022, 3, 31), + pd.Timestamp(2022, 6, 30), + pd.Timestamp(2022, 9, 30), + pd.Timestamp(2022, 12, 31), + ], + ), ] for start, freq, sequence, periods, expected in test_cases: @@ -220,8 +232,9 @@ def test_create_timestamps(): assert ts == expected # test based on provided sequence - ts = create_timestamps(start, time_sequence=sequence, periods=periods) - assert ts == expected + if sequence is not None: + ts = create_timestamps(start, time_sequence=sequence, periods=periods) + assert ts == expected # it is an error to provide neither freq or sequence with pytest.raises(ValueError): diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index a69db0a6..68bfa5db 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -759,14 +759,24 @@ def create_timestamps( # more complex logic is required to support all edge cases if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)): - # if isinstance(freq, str): - # freq = pd._libs.tslibs.timedeltas.Timedelta(freq) - - return pd.date_range( - last_timestamp, - freq=freq, - periods=periods + 1, - ).tolist()[1:] + try: + # try date range directly + return pd.date_range( + last_timestamp, + freq=freq, + periods=periods + 1, + ).tolist()[1:] + except ValueError as e: + # if it fails, we can try to compute a timedelta from the provided string + if isinstance(freq, str): + freq = pd._libs.tslibs.timedeltas.Timedelta(freq) + return pd.date_range( + last_timestamp, + freq=freq, + periods=periods + 1, + ).tolist()[1:] + else: + raise e else: # numerical timestamp column return [last_timestamp + i * freq for i in range(1, periods + 1)]