diff --git a/scripts/map_tokens.py b/scripts/map_tokens.py
index 56abb9a2..81e3e595 100644
--- a/scripts/map_tokens.py
+++ b/scripts/map_tokens.py
@@ -1,13 +1,15 @@
-from datasets import DatasetDict, IterableDataset, IterableDatasetDict
-from datasets.arrow_dataset import Dataset
+from typing import cast
+
+from datasets import Dataset
 
 
 def token_map(
-    tokenized_dataset: DatasetDict | Dataset | IterableDatasetDict | IterableDataset,
+    tokenized_dataset: Dataset,
 ) -> dict[int, list[tuple[int, int]]]:
     mapping = {}
-
+    tokenized_dataset = cast(Dataset, tokenized_dataset)
     for prompt_idx, prompt in enumerate(tokenized_dataset):
+        prompt = cast(dict, prompt)
         for token_idx, token in enumerate(prompt["tokens"]):
             mapping.setdefault(token, []).append((prompt_idx, token_idx))
     return mapping
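
For context, a minimal usage sketch of the narrowed signature. The import path and the toy `tokens` column are assumptions for illustration, not part of the diff; the real schema is whatever produces `prompt["tokens"]` above:

```python
# Hypothetical usage sketch. Assumes scripts/ is importable as a package
# and that each row carries a "tokens" column of integer token ids.
from datasets import Dataset

from scripts.map_tokens import token_map

# Two toy prompts: token 10 repeats inside prompt 0, token 11 spans prompts.
tokenized = Dataset.from_dict({"tokens": [[10, 11, 10], [11, 12]]})

mapping = token_map(tokenized)
# Each token id maps to every (prompt_idx, token_idx) position it occupies.
assert mapping == {10: [(0, 0), (0, 2)], 11: [(0, 1), (1, 0)], 12: [(1, 1)]}
```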