Skip to content

Commit

Permalink
Merge pull request #182 from ibm-granite/issue_181
Browse files Browse the repository at this point in the history
check bounds on inner datasets
  • Loading branch information
wgifford authored Nov 6, 2024
2 parents aac8497 + 91cd577 commit fea9e7e
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,16 @@ def pad_zero(self, data_df):
def __len__(self):
return max((len(self.X) - self.context_length - self.prediction_length) // self.stride + 1, 0)

def _check_index(self, index: int) -> int:
if index >= len(self):
raise IndexError("Index exceeds dataset length")

if index < 0:
if -index > len(self):
raise ValueError("Absolute value of index should not exceed dataset length")
index = len(self) + index
return index

def __getitem__(self, index: int):
"""
Args:
Expand Down Expand Up @@ -358,6 +368,8 @@ def __init__(
)

def __getitem__(self, index):
index = self._check_index(index)

time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
ret = {
Expand Down Expand Up @@ -565,6 +577,7 @@ def apply_masking_specification(self, past_values_tensor: np.ndarray) -> np.ndar

def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols
index = self._check_index(index)

time_id = index * self.stride

Expand Down Expand Up @@ -715,6 +728,7 @@ def __init__(

def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols
index = self._check_index(index)

time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
Expand Down Expand Up @@ -840,6 +854,7 @@ def __init__(

def __getitem__(self, index):
# seq_x: batch_size x seq_len x num_x_cols
index = self._check_index(index)

time_id = index * self.stride
seq_x = self.X[time_id : time_id + self.context_length].values
Expand Down

0 comments on commit fea9e7e

Please sign in to comment.