diff --git a/src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py b/src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py
index 917b12f7..e1f2a8fb 100644
--- a/src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py
+++ b/src/modalities/dataloader/preprocessing/tokenization/tokenized_file_writer.py
@@ -2,6 +2,7 @@
 import pickle
 from itertools import repeat
 from pathlib import Path
+from typing import BinaryIO

 import numpy as np

@@ -66,12 +67,12 @@ def _update_data_length_in_initial_header(tokenized_dataset_file_path: Path, ind
             fout.write(data_section_length_in_bytes)

     @staticmethod
-    def _write_index_segment(file_descriptor, index_list: list[tuple[int, int]]) -> None:
+    def _write_index_segment(file_descriptor: BinaryIO, index_list: list[tuple[int, int]]) -> None:
         file_descriptor.write(pickle.dumps(index_list))

     @staticmethod
     def _write_data_segment(
-        file_descriptor, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
+        file_descriptor: BinaryIO, token_data: list[np.ndarray], token_size_in_bytes: int, write_batch_size: int
     ) -> list[tuple[int, int]]:
         def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
             # Converts an token_ids to its byte representation.
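
For context on the annotation change: `BinaryIO` is the `typing` protocol matched by files opened in binary mode (e.g. `open(path, "wb")`), which is all these writers require. Below is a minimal sketch of how the annotated helpers might be driven; the body of `encoded_token_to_bytes` is truncated in the diff, so the fixed-width little-endian conversion and the `write_tokens` wrapper shown here are assumptions for illustration, not the repository's implementation.

```python
from typing import BinaryIO


def encoded_token_to_bytes(encoded_token: int, token_size_in_bytes: int) -> bytes:
    # Assumption: fixed-width, little-endian byte layout per token id; the real
    # body of this helper is cut off by the diff and may differ.
    return encoded_token.to_bytes(token_size_in_bytes, byteorder="little")


def write_tokens(file_descriptor: BinaryIO, token_ids: list[int], token_size_in_bytes: int) -> int:
    # Hypothetical helper (not in the repository): serializes token ids and
    # returns the number of bytes written. Only binary write() is needed,
    # which is why BinaryIO is a sufficient annotation for the descriptor.
    payload = b"".join(encoded_token_to_bytes(t, token_size_in_bytes) for t in token_ids)
    return file_descriptor.write(payload)


if __name__ == "__main__":
    with open("tokens.bin", "wb") as fout:  # file name is illustrative only
        n_bytes = write_tokens(fout, [17, 4096, 3], token_size_in_bytes=4)
        print(f"wrote {n_bytes} bytes")
```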