Commit 4ed8b19

Add utility functions for text processing and visualization (#17)

* Add utility functions for text processing and visualization

* Add compare_models.py (w/ test)

jaidhyani authored Feb 13, 2024
1 parent c72d4aa commit 4ed8b19

Showing 7 changed files with 453 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/checks.yml
@@ -6,7 +6,7 @@ on:
       - main
   pull_request:
     branches:
-      - '*'
+      - "*"

 permissions:
   actions: write
@@ -38,4 +38,4 @@ jobs:
       - name: isort
         run: isort --profile black --check .
       - name: pytest
-        run: pytest
\ No newline at end of file
+        run: pytest
3 changes: 0 additions & 3 deletions .vscode/settings.json
@@ -7,11 +7,8 @@
     "source.organizeImports": "explicit"
   },
   "python.analysis.typeCheckingMode": "basic",
-
   "isort.args": [
     "--profile black"
   ],
-
   "black-formatter.importStrategy": "fromEnvironment",
-
 }
148 changes: 148 additions & 0 deletions notebooks/vis_demo.ipynb

Large diffs are not rendered by default.

91 changes: 91 additions & 0 deletions src/delphi/eval/compare_models.py
@@ -0,0 +1,91 @@
from dataclasses import dataclass

import torch
from jaxtyping import Int
from transformers import PreTrainedModel

from delphi.eval.utils import get_all_and_next_logprobs_single


def identify_model(model: PreTrainedModel) -> str:
    return model.config.name_or_path


@dataclass
class TokenPrediction:
    token: int
    base_model_prob: float
    lift_model_prob: float


@dataclass
class NextTokenStats:
    base_model: str
    lift_model: str
    next_prediction: TokenPrediction
    topk: list[TokenPrediction]


def compare_models(
    model_a: PreTrainedModel,
    model_b: PreTrainedModel,
    sample_tok: Int[torch.Tensor, "seq"],
    top_k: int = 3,
) -> list[NextTokenStats | None]:
    """
    Compare the probabilities of the next token for two models and get the
    top k token predictions according to model B.
    Args:
    - model_a: The first model (assumed to be the base model)
    - model_b: The second model (assumed to be the improved model)
    - sample_tok: The tokenized prompt
    - top_k: The number of top token predictions to retrieve (default is 3)
    Returns:
    A list of NextTokenStats objects, one for each token in the prompt;
    the first entry is None, since the first token is never predicted.
    """
    assert (
        model_a.device == model_b.device
    ), "Both models must be on the same device for comparison."

    device = model_a.device
    sample_tok = sample_tok.to(device)

    logprobs_a, next_logprobs_a = get_all_and_next_logprobs_single(model_a, sample_tok)
    logprobs_b, next_logprobs_b = get_all_and_next_logprobs_single(model_b, sample_tok)

    # the helpers return log-probs; exponentiate everything into probabilities
    probs_a = torch.exp(logprobs_a)
    probs_b = torch.exp(logprobs_b)
    next_probs_a = torch.exp(next_logprobs_a)
    next_probs_b = torch.exp(next_logprobs_b)

    top_k_b = torch.topk(probs_b, top_k, dim=-1)
    top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices)

    top_k_b_tokens = top_k_b.indices
    top_k_b_probs = top_k_b.values

    comparisons = []
    # the first token is never predicted, so it gets a None entry
    comparisons.append(None)

    # predictions are aligned with the tokens they predict, i.e. sample_tok[1:]
    for next_tok, next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip(
        sample_tok[1:], next_probs_a, next_probs_b,
        top_k_b_tokens, top_k_a_probs, top_k_b_probs,
    ):
        nts = NextTokenStats(
            base_model=identify_model(model_a),
            lift_model=identify_model(model_b),
            next_prediction=TokenPrediction(
                token=int(next_tok.item()),  # id of the token being predicted
                base_model_prob=next_p_a.item(),
                lift_model_prob=next_p_b.item(),
            ),
            topk=[
                TokenPrediction(
                    token=int(top_toks_b[i].item()),
                    base_model_prob=top_probs_a[i].item(),
                    lift_model_prob=top_probs_b[i].item(),
                )
                for i in range(top_k)
            ],
        )
        comparisons.append(nts)

    return comparisons
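
For context, a minimal sketch of how compare_models is meant to be called. The checkpoints match the ones used in the test further down; the sample string is made up:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import compare_models
from delphi.eval.utils import tokenize

base = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
lift = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-Instruct-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

sample_tok = tokenize(tokenizer, "Once upon a time, there was a little girl.")
with torch.no_grad():
    stats = compare_models(base, lift, sample_tok, top_k=3)

# stats[0] is None; every later entry pairs the two models' probabilities for
# the true next token and records the lift model's top-3 candidate tokens
print(stats[1].next_prediction)
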
53 changes: 49 additions & 4 deletions src/delphi/eval/utils.py
@@ -4,6 +4,7 @@
 import torch
 from datasets import Dataset, load_dataset
 from jaxtyping import Float, Int
+from transformers import PreTrainedModel, PreTrainedTokenizerBase


 def get_all_logprobs(
@@ -14,19 +15,53 @@ def get_all_logprobs(
     return torch.log_softmax(logits, dim=-1)


+# convenience wrapper for calling on a single sample
+def get_single_logprobs(
+    model: Callable, input_ids: Int[torch.Tensor, "seq"]
+) -> Float[torch.Tensor, "seq vocab"]:
+    return get_all_logprobs(model, input_ids.unsqueeze(0))[0]
+
+
 def gather_logprobs(
     logprobs: Float[torch.Tensor, "batch seq vocab"],
     tokens: Int[torch.Tensor, "batch seq"],
 ) -> Float[torch.Tensor, "batch seq"]:
     return torch.gather(logprobs, -1, tokens.unsqueeze(-1)).squeeze(-1)


-def get_next_logprobs(
-    model: Callable, input_ids: Int[torch.Tensor, "batch seq"]
-) -> Float[torch.Tensor, "batch shorter_seq"]:
+def get_all_and_next_logprobs(
+    model: Callable,
+    input_ids: Int[torch.Tensor, "batch seq"],
+) -> tuple[
+    Float[torch.Tensor, "batch shorter_seq vocab"],
+    Float[torch.Tensor, "batch shorter_seq"],
+]:
     logprobs = get_all_logprobs(model, input_ids[:, :-1])
     next_tokens = input_ids[:, 1:]
-    return gather_logprobs(logprobs, next_tokens)
+    return logprobs, gather_logprobs(logprobs, next_tokens)
+
+
+def get_all_and_next_logprobs_single(
+    model: Callable,
+    input_ids: Int[torch.Tensor, "seq"],
+) -> tuple[
+    Float[torch.Tensor, "shorter_seq vocab"],
+    Float[torch.Tensor, "shorter_seq"],
+]:
+    all_logprobs, next_logprobs = get_all_and_next_logprobs(
+        model, input_ids.unsqueeze(0)
+    )
+    return all_logprobs[0], next_logprobs[0]
+
+
+def get_next_and_top_k_probs(
+    model: PreTrainedModel, input_ids: Int[torch.Tensor, "seq"], k: int = 3
+) -> tuple[Float[torch.Tensor, "shorter_seq"], torch.return_types.topk]:
+    all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)
+    all_probs = torch.exp(all_logprobs)
+    next_probs = torch.exp(next_logprobs)
+    top_k = torch.topk(all_probs, k, dim=-1)
+    return next_probs, top_k


 def load_validation_dataset(dataset_name: str) -> Dataset:
@@ -42,3 +77,13 @@ def load_validation_dataset(dataset_name: str) -> Dataset:
         split="train",
     )
     return cast(Dataset, dataset)
+
+
+def tokenize(
+    tokenizer: PreTrainedTokenizerBase, sample_txt: str
+) -> Int[torch.Tensor, "seq"]:
+    # supposedly this can be different from prepending the bos token id
+    return cast(
+        Int[torch.Tensor, "seq"],
+        tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0],
+    )
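
One convention worth spelling out: these helpers compute log-probs over input_ids[:, :-1], so position i holds the prediction for token i + 1 and every output is one step shorter than its input. A quick sketch of that invariant, reusing the TinyStories checkpoint from the tests in this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import get_all_and_next_logprobs_single, tokenize

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

input_ids = tokenize(tokenizer, "Once upon a time")  # shape: (seq,)
with torch.no_grad():
    all_logprobs, next_logprobs = get_all_and_next_logprobs_single(model, input_ids)

# position i scores the prediction of input_ids[i + 1], hence "shorter_seq"
assert all_logprobs.shape == (input_ids.shape[0] - 1, model.config.vocab_size)
assert next_logprobs.shape == (input_ids.shape[0] - 1,)
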
140 changes: 140 additions & 0 deletions src/delphi/eval/vis.py
@@ -0,0 +1,140 @@
import uuid
from typing import cast

import torch
from IPython.core.display import HTML
from IPython.core.display_functions import display
from jaxtyping import Float, Int
from transformers import PreTrainedTokenizerBase


def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]:
    # for the endoftext token
    # no prediction, no color
    colors = ["white"]
    for p in probs.tolist():
        red_gap = 150  # the higher it is, the less red the tokens will be
        green_blue_val = red_gap + int((255 - red_gap) * (1 - p))
        colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})")
    return colors


def to_tok_prob_str(tok: int, prob: float, tokenizer: PreTrainedTokenizerBase) -> str:
    # \xa0 is a non-breaking space, so leading spaces survive in the hover text
    tok_str = tokenizer.decode(tok).replace(" ", "\xa0").replace("\n", r"\n")
    prob_str = f"{prob:.2%}"
    return f"{prob_str:>6} |{tok_str}|"


def token_to_html(
    token: int,
    tokenizer: PreTrainedTokenizerBase,
    bg_color: str,
    data: dict,
) -> str:
    data = data or {}  # equivalent to if not data: data = {}
    # non-breaking space, w/o it leading spaces wouldn't be displayed
    str_token = tokenizer.decode(token).replace(" ", "\xa0")

    # background or user-select (for \n) goes here
    specific_styles = {}
    # for now just adds line break or doesn't
    br = ""

    if bg_color:
        specific_styles["background-color"] = bg_color
    if str_token == "\n":
        # replace new line character with two characters: \ and n
        str_token = r"\n"
        # add line break in html
        br += "<br>"
        # this is so we can copy the prompt without "\n"s
        specific_styles["user-select"] = "none"

    style_str = data_str = ""
    # converting style dict into the style attribute
    if specific_styles:
        inside_style_str = "; ".join(f"{k}: {v}" for k, v in specific_styles.items())
        style_str = f" style='{inside_style_str}'"
    if data:
        data_str = "".join(
            f" data-{k}='{v.replace(' ', '&nbsp;')}'" for k, v in data.items()
        )
    return f"<div class='token'{style_str}{data_str}>{str_token}</div>{br}"


_token_style = {
    "border": "1px solid #888",
    "display": "inline-block",
    # each character of the same width, so we can easily spot a space
    "font-family": "monospace",
    "font-size": "14px",
    "color": "black",
    "background-color": "white",
    "margin": "1px 0px 1px 1px",
    "padding": "0px 1px 1px 1px",
}
_token_style_str = " ".join([f"{k}: {v};" for k, v in _token_style.items()])


def vis_sample_prediction_probs(
    sample_tok: Int[torch.Tensor, "pos"],
    correct_probs: Float[torch.Tensor, "pos"],
    top_k_probs: torch.return_types.topk,
    tokenizer: PreTrainedTokenizerBase,
) -> str:
    colors = probs_to_colors(correct_probs)
    token_htmls = []

    # Generate a unique ID for this instance (so we can have multiple instances on the same page)
    unique_id = str(uuid.uuid4())

    token_class = f"token_{unique_id}"
    hover_div_id = f"hover_info_{unique_id}"

    for i in range(sample_tok.shape[0]):
        tok = cast(int, sample_tok[i].item())
        data = {}
        if i > 0:
            correct_prob = correct_probs[i - 1].item()
            data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer)
            top_k_probs_tokens = top_k_probs.indices[i - 1]
            top_k_probs_values = top_k_probs.values[i - 1]
            for j in range(top_k_probs_tokens.shape[0]):
                top_tok = top_k_probs_tokens[j].item()
                top_tok = cast(int, top_tok)
                top_prob = top_k_probs_values[j].item()
                data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer)

        token_htmls.append(
            token_to_html(tok, tokenizer, bg_color=colors[i], data=data).replace(
                "class='token'", f"class='{token_class}'"
            )
        )

    html_str = f"""
    <style>.{token_class} {{ {_token_style_str} }} #{hover_div_id} {{ height: 100px; font-family: monospace; }}</style>
    {"".join(token_htmls)} <div id='{hover_div_id}'></div>
    <script>
        (function() {{
            var token_divs = document.querySelectorAll('.{token_class}');
            var hover_info = document.getElementById('{hover_div_id}');
            token_divs.forEach(function(token_div) {{
                token_div.addEventListener('mousemove', function(e) {{
                    hover_info.innerHTML = "";
                    for (var d in this.dataset) {{
                        hover_info.innerHTML += "<b>" + d + "</b> ";
                        hover_info.innerHTML += this.dataset[d] + "<br>";
                    }}
                }});
                token_div.addEventListener('mouseout', function(e) {{
                    hover_info.innerHTML = "";
                }});
            }});
        }})();
    </script>
    """
    display(HTML(html_str))
    return html_str
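
For orientation, a sketch of how this visualization pairs with get_next_and_top_k_probs from utils.py, presumably what notebooks/vis_demo.ipynb demonstrates; the checkpoint matches the one used in the tests, and the sample string is invented:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.utils import get_next_and_top_k_probs, tokenize
from delphi.eval.vis import vis_sample_prediction_probs

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")

sample_tok = tokenize(tokenizer, "Once upon a time, there was a dog.")
with torch.no_grad():
    next_probs, top_k = get_next_and_top_k_probs(model, sample_tok, k=3)

# tokens are shaded by the probability the model assigned to the true next
# token; hovering one shows that probability plus the top-3 candidates
vis_sample_prediction_probs(sample_tok, next_probs, top_k, tokenizer)
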
23 changes: 23 additions & 0 deletions tests/eval/test_compare_models.py
@@ -0,0 +1,23 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from delphi.eval.compare_models import NextTokenStats, compare_models
from delphi.eval.utils import load_validation_dataset, tokenize


def test_compare_models():
    with torch.set_grad_enabled(False):
        model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M")
        model_instruct = AutoModelForCausalLM.from_pretrained(
            "roneneldan/TinyStories-Instruct-1M"
        )
        ds_txt = load_validation_dataset("tinystories-v2-clean")["story"]
        tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M")
        sample_tok = tokenize(tokenizer, ds_txt[0])
        K = 3
        model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K)
        # ignore the first element comparison
        assert model_comparison[0] is None
        assert isinstance(model_comparison[1], NextTokenStats)
        assert len(model_comparison) == sample_tok.shape[0]
        assert len(model_comparison[1].topk) == K
