
Commit

addressing assorted nits
Jai committed Feb 11, 2024
1 parent 90a2a16 commit 0f213bc
Showing 4 changed files with 57 additions and 95 deletions.
16 changes: 8 additions & 8 deletions notebooks/vis_demo.ipynb

Large diffs are not rendered by default.

42 changes: 12 additions & 30 deletions src/delphi/eval/compare_models.py
@@ -1,21 +1,14 @@
from dataclasses import dataclass
from typing import cast

import torch
import torch.nn as nn
from jaxtyping import Int
from transformers import PreTrainedModel

from delphi.eval.utils import get_correct_and_all_probs
from delphi.eval.utils import get_all_and_next_logprobs_single


@dataclass
class ModelId:
model_name: str


def identify_model(model: PreTrainedModel) -> ModelId:
return ModelId(model_name=model.config.name_or_path)
def identify_model(model: PreTrainedModel) -> str:
return model.config.name_or_path


@dataclass
@@ -27,28 +20,18 @@ class TokenPrediction:

@dataclass
class NextTokenStats:
base_model: ModelId
lift_model: ModelId
base_model: str
lift_model: str
next_prediction: TokenPrediction
topk: list[TokenPrediction]


def _pad_start(tensor: torch.Tensor) -> torch.Tensor:
value_to_prepend = -1
if len(tensor.shape) == 1:
return torch.cat((torch.tensor([value_to_prepend]), tensor))
else:
# input: 2D tensor of shape [seq_len - 1, top_k]
pre = torch.full((1, tensor.size()[-1]), value_to_prepend)
return torch.cat((pre, tensor), dim=0)


def compare_models(
model_a: PreTrainedModel,
model_b: PreTrainedModel,
sample_tok: Int[torch.Tensor, "seq"],
top_k: int = 3,
) -> list[NextTokenStats]:
) -> list[NextTokenStats | None]:
"""
Compare the probabilities of the next token for two models and get the top k token predictions according to model B.
Args:
@@ -67,19 +50,18 @@ def compare_models(
device = model_a.device
sample_tok = sample_tok.to(device)

next_probs_a, probs_a = get_correct_and_all_probs(model_a, sample_tok)
next_probs_b, probs_b = get_correct_and_all_probs(model_b, sample_tok)
probs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
probs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok)

top_k_b = torch.topk(probs_b, top_k, dim=-1)
top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)

next_probs_a = _pad_start(next_probs_a)
next_probs_b = _pad_start(next_probs_b)
top_k_b_tokens = _pad_start(top_k_b.indices)
top_k_a_probs = _pad_start(top_k_a_probs)
top_k_b_probs = _pad_start(top_k_b.values)
top_k_b_tokens = top_k_b.indices
top_k_b_probs = top_k_b.values

comparisons = []
# ignore first token when evaluating predictions
comparisons.append(None)

for next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip(
next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs
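To make the new return shape concrete, here is a minimal usage sketch of the revised `compare_models` (not part of the commit; the model names, tokenizer call, and sample sentence are assumptions for illustration):

```python
# Hypothetical usage sketch of the revised compare_models API.
# Model names and the sample sentence are illustrative assumptions, not part of the diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import compare_models

tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
model_a = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
model_b = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-Instruct-1M")

# A single tokenized sequence of shape [seq].
sample_tok = tokenizer("Once upon a time", return_tensors="pt").input_ids[0]

with torch.no_grad():
    stats = compare_models(model_a, model_b, sample_tok, top_k=3)

# One entry per token; the first entry is None because there is no
# prediction to evaluate for the first token.
assert len(stats) == sample_tok.shape[0]
assert stats[0] is None
for position, s in enumerate(stats[1:], start=1):
    print(position, s.next_prediction, s.topk)
```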
24 changes: 18 additions & 6 deletions src/delphi/eval/utils.py
@@ -30,7 +30,8 @@ def gather_logprobs(


def get_all_and_next_logprobs(
model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
model: Callable,
input_ids: Int[torch.Tensor, "batch seq"],
) -> tuple[
Float[torch.Tensor, "batch shorter_seq vocab"],
Float[torch.Tensor, "batch shorter_seq"],
@@ -40,17 +41,28 @@ def get_all_and_next_logprobs(
return logprobs, gather_logprobs(logprobs, next_tokens)


def get_next_and_top_k_probs(
model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
def get_all_and_next_logprobs_single(
model: Callable,
input_ids: Int[torch.Tensor, "seq"],
) -> tuple[
Float[torch.Tensor, "shorter_seq vocab"],
Float[torch.Tensor, "shorter_seq"],
torch.return_types.topk,
]:
all_logprobs, next_logprobs = get_all_and_next_logprobs(
model, input_ids.unsqueeze(0)
)
all_probs = torch.exp(all_logprobs[0])
next_probs = torch.exp(next_logprobs[0])
return all_logprobs[0], next_logprobs[0]


def get_next_and_top_k_probs(
model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
) -> tuple[
Float[torch.Tensor, "shorter_seq"],
torch.return_types.topk,
]:
all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
all_probs = torch.exp(all_logprobs)
next_probs = torch.exp(next_logprobs)
top_k = torch.topk(all_probs, k, dim=-1)
return next_probs, top_k

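For orientation, a rough sketch of how the batched and single-sequence helpers relate after this change (shapes follow the annotations above; the model, tokenizer, and sample text are illustrative assumptions):

```python
# Sketch of the relationship between the batched and single-sequence helpers,
# assuming the type annotations shown in the diff. The model and inputs are illustrative.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import (
    get_all_and_next_logprobs,
    get_all_and_next_logprobs_single,
    get_next_and_top_k_probs,
)

torch.set_grad_enabled(False)

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
input_ids = tokenizer("The cat sat", return_tensors="pt").input_ids  # shape [1, seq]

# Batched variant: log-probs over the vocab plus the log-prob of each realized next token.
all_lp, next_lp = get_all_and_next_logprobs(model, input_ids)
print(all_lp.shape)   # [1, seq - 1, vocab]
print(next_lp.shape)  # [1, seq - 1]

# Single-sequence variant: same values with the batch dimension stripped.
all_lp_1, next_lp_1 = get_all_and_next_logprobs_single(model, input_ids[0])
print(all_lp_1.shape)  # [seq - 1, vocab]

# Convenience wrapper: probabilities of the realized next tokens and a top-k over the vocab.
next_probs, top_k = get_next_and_top_k_probs(model, input_ids[0], k=3)
print(next_probs.shape, top_k.indices.shape)  # [seq - 1], [seq - 1, 3]
```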
70 changes: 19 additions & 51 deletions tests/eval/test_compare_models.py
@@ -1,55 +1,23 @@
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import NextTokenStats, compare_models
from delphi.eval.utils import load_text_from_dataset, load_validation_dataset, tokenize

torch.set_grad_enabled(False)


# define a pytest fixture for the model name
@pytest.fixture
def model_name():
return "roneneldan/TinyStories-1M"


# define a pytest fixture for a default tokenizer using the model_name fixture
@pytest.fixture
def tokenizer(model_name):
return AutoTokenizer.from_pretrained(model_name)


# define a pytest fixture for a default model using the model_name fixture
@pytest.fixture
def model(model_name):
return AutoModelForCausalLM.from_pretrained(model_name)


# define a pytest fixture for the raw dataset
@pytest.fixture
def ds_txt():
return load_text_from_dataset(load_validation_dataset("tinystories-v2-clean"))[:100]


# define a pytest fixture for the tokenized dataset
@pytest.fixture
def ds_tok(tokenizer, ds_txt):
return [tokenize(tokenizer, txt) for txt in ds_txt]


# define a pytest fixture for a tokenized sample
@pytest.fixture
def sample_tok(ds_tok):
return ds_tok[0]


def test_compare_models(model, sample_tok):
model_instruct = AutoModelForCausalLM.from_pretrained(
"roneneldan/TinyStories-Instruct-1M"
)
K = 3
model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
assert isinstance(model_comparison[0], NextTokenStats)
assert len(model_comparison) == sample_tok.shape[0]
assert len(model_comparison[0].topk) == K
from delphi.eval.utils import load_validation_dataset, tokenize


def test_compare_models():
with torch.set_grad_enabled(False):
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
model_instruct = AutoModelForCausalLM.from_pretrained(
"roneneldan/TinyStories-Instruct-1M"
)
ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
sample_tok = tokenize(tokenizer, ds_txt[0])
K = 3
model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
# ignore the first element comparison
assert model_comparison[0] is None
assert isinstance(model_comparison[1], NextTokenStats)
assert len(model_comparison) == sample_tok.shape[0]
assert len(model_comparison[1].topk) == K
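
For completeness, the rewritten test can be run on its own; a small driver like the following (the test path is taken from the file header above) is equivalent to invoking pytest from the command line:

```python
# Run the rewritten comparison test in isolation; equivalent to running
# pytest on the file from the command line. Path taken from the file header above.
import pytest

if __name__ == "__main__":
    exit_code = pytest.main(["tests/eval/test_compare_models.py", "-v"])
    raise SystemExit(exit_code)
```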
