diff --git a/src/nanotron/data/nanoset.py b/src/nanotron/data/nanoset.py index 9d62b33d..48ac2716 100644 --- a/src/nanotron/data/nanoset.py +++ b/src/nanotron/data/nanoset.py @@ -91,9 +91,9 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: self.dataset_buffer = memoryview(self.dataset_buffer_mmap) # uint16 -> 2 bytes per token, int32 -> 4 bytes per token - offset = dataset_sample * self.sequence_length * (np.iinfo(self.token_dtype).bits / 8) + offset = int(dataset_sample) * self.sequence_length * int(np.iinfo(self.token_dtype).bits / 8) input_ids_tokens = np.frombuffer( - self.dataset_buffer, dtype=self.token_dtype, count=(self.sequence_length + 1), offset=int(offset) + self.dataset_buffer, dtype=self.token_dtype, count=(self.sequence_length + 1), offset=offset ) # Return tokens as np.int32 as Torch can't handle uint16