from collections import defaultdict

from datasets import load_dataset
from tqdm.auto import tqdm


def load_clean_dataset(split: str, tokenized: bool = False) -> list[str]:
    """Load the cleaned TinyStories-v2 dataset from the HuggingFace hub.

    Args:
        split: dataset split to load. Only the prefix is checked
            (``startswith``) so that slice syntax like ``"train[:1000]"``
            is still accepted.
        tokenized: if True, load the pre-tokenized variant of the dataset
            and return its "tokens" column instead of the raw "text" column.

    Returns:
        A list with one entry per sample: strings for the text variant,
        token-id lists for the tokenized variant.

    Raises:
        AssertionError: if ``split`` does not start with "train" or
            "validation".
    """
    # checking just startswith, because you can include slice like "train[:1000]"
    assert split.startswith("train") or split.startswith(
        "validation"
    ), f"unexpected split: {split!r}"

    suffix = "-tokenized" if tokenized else ""
    hf_ds = load_dataset(f"jbrinkma/tinystories-v2-clean{suffix}", split=split)
    column = "tokens" if tokenized else "text"
    # hf_ds technically isn't guaranteed to be subscriptable, but it is in this case.
    # list(...) replaces the manual append loop; tqdm still shows progress.
    return list(tqdm(hf_ds[column]))  # type: ignore


def token_map(tokenized_dataset: list[list[int]]) -> dict[int, list[tuple[int, int]]]:
    """Map each token id to every position where it occurs in the dataset.

    Args:
        tokenized_dataset: list of prompts, each a list of token ids.

    Returns:
        Dict mapping a token id to the list of ``(prompt_idx, token_idx)``
        pairs at which it appears, in dataset order.
    """
    mapping: defaultdict[int, list[tuple[int, int]]] = defaultdict(list)
    for prompt_idx, prompt in enumerate(tokenized_dataset):
        for token_idx, token in enumerate(prompt):
            mapping[token].append((prompt_idx, token_idx))
    # Return a plain dict so lookups of absent tokens raise KeyError,
    # matching the original setdefault-based implementation.
    return dict(mapping)