
Commit

added test cases for token_map
menamerai committed Feb 10, 2024
1 parent 41b84e0 commit 86a1965
Showing 2 changed files with 30 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/delphi/dataset/token_map.py
@@ -1,15 +1,21 @@
import os
from pickle import dump
from typing import cast

from datasets import Dataset


def token_map(
    tokenized_dataset: Dataset,
    output_path: str | None = None,
    file_name: str | None = None,
) -> dict[int, list[tuple[int, int]]]:
    """Return a mapping of tokens to their (prompt_idx, token_idx) locations in the tokenized_dataset.

    Args:
        tokenized_dataset (Dataset): A tokenized dataset.
        output_path (str, optional): If provided, the mapping is also pickled to this path. Defaults to None.
        file_name (str, optional): Optional name for the output file. Defaults to None.

    Returns:
        dict[int, list[tuple[int, int]]]: A mapping of tokens to their (prompt_idx, token_idx)
@@ -22,4 +28,9 @@ def token_map(
        for token_idx, token in enumerate(prompt["tokens"]):
            mapping.setdefault(token, []).append((prompt_idx, token_idx))

    if output_path is not None:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, "wb") as f:
            dump(mapping, f)

    return mapping
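
For context, a minimal usage sketch of the updated function; the import path delphi.dataset.token_map is an assumption based on the file location src/delphi/dataset/token_map.py:

    from datasets import Dataset
    from delphi.dataset.token_map import token_map  # assumed import path

    dataset = Dataset.from_dict({"tokens": [[0, 1, 1]]})

    # Without output_path, the mapping is only returned in memory.
    mapping = token_map(dataset)
    assert mapping == {0: [(0, 0)], 1: [(0, 1), (0, 2)]}

    # With output_path, the same mapping is additionally pickled to disk.
    token_map(dataset, output_path="./data/token_map.pkl")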
19 changes: 19 additions & 0 deletions tests/test_dataset.py → tests/dataset/test_dataset.py
@@ -1,3 +1,6 @@
import os
from pickle import load

import pytest
from datasets import Dataset

@@ -25,6 +28,7 @@ def test_token_map():
        6: [(0, 7), (1, 7), (2, 7)],
        7: [(0, 8), (1, 8), (2, 8)],
    }

    tokenized_dataset = Dataset.from_dict(
        {
            "tokens": [
@@ -71,3 +75,18 @@ def test_token_map():
        6: [(0, 7), (0, 16), (0, 25)],
        7: [(0, 8), (0, 17), (0, 26)],
    }

    # Test saving the output
    tokenized_dataset = Dataset.from_dict(
        {
            "tokens": [
                [0, 1],
            ]
        }
    )
    mapping = token_map(tokenized_dataset, output_path="./data/token_map.pkl")
    assert mapping == {0: [(0, 0)], 1: [(0, 1)]}
    with open("./data/token_map.pkl", "rb") as f:
        saved_mapping = load(f)
    assert saved_mapping == mapping
    os.remove("./data/token_map.pkl")
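
The saved pickle can then be consumed downstream; a short sketch, with an illustrative path and token id:

    from pickle import load

    with open("./data/token_map.pkl", "rb") as f:
        token_locations = load(f)

    # Each key is a token id; each value lists every (prompt_idx, token_idx) occurrence.
    occurrences_of_token_1 = token_locations.get(1, [])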
