Skip to content

Commit

Permalink
logprobs utils & tests (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
jettjaniak authored Feb 6, 2024
1 parent 5448b4b commit 7f7f303
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/delphi/eval/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from collections.abc import Callable

import torch
from jaxtyping import Float, Int


def get_all_logprobs(
model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
) -> Float[torch.Tensor, "batch seq vocab"]:
# batch, seq, vocab
logits = model(input_ids).logits
return torch.log_softmax(logits, dim=-1)


def gather_logprobs(
logprobs: Float[torch.Tensor, "batch seq vocab"],
tokens: Int[torch.Tensor, "batch seq"],
) -> Float[torch.Tensor, "batch seq"]:
return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)


def get_next_logprobs(
model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
) -> Float[torch.Tensor, "batch shorter_seq"]:
logprobs = get_all_logprobs(model, input_ids[:, :-1])
next_tokens = input_ids[:, 1:]
return gather_logprobs(logprobs, next_tokens)
43 changes: 43 additions & 0 deletions tests/eval/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch

from delphi.eval.utils import gather_logprobs


def test_gather_logprobs():
# vocab size = 3
logprobs = torch.tensor(
[
# batch 0
[
# seq 0
[0.00, 0.01, 0.02],
# seq 1
[0.10, 0.11, 0.12],
],
# batch 1
[
# seq 0
[1.00, 1.01, 1.02],
# seq 1
[1.10, 1.11, 1.12],
],
]
)
tokens = torch.tensor(
[
# batch 0
[0, 2],
# batch 1
[1, 2],
]
)
expected_output = torch.tensor(
[
# batch 0
[0.00, 0.12],
# batch 1
[1.01, 1.12],
]
)
result = gather_logprobs(logprobs, tokens)
assert torch.allclose(result, expected_output)

0 comments on commit 7f7f303

Please sign in to comment.