From 072ac1e6b13b638bd1eefafeb4b5d7896e194d52 Mon Sep 17 00:00:00 2001 From: Jett Date: Wed, 22 May 2024 12:23:58 +0200 Subject: [PATCH] stale eval code purge --- src/delphi/eval/compare_models.py | 91 ------------------------------- src/delphi/eval/constants.py | 26 --------- src/delphi/eval/utils.py | 61 +-------------------- src/delphi/eval/vis.py | 75 ------------------------- tests/eval/test_compare_models.py | 23 -------- tests/eval/test_utils_eval.py | 11 +--- 6 files changed, 3 insertions(+), 284 deletions(-) delete mode 100644 src/delphi/eval/compare_models.py delete mode 100644 src/delphi/eval/constants.py delete mode 100644 tests/eval/test_compare_models.py diff --git a/src/delphi/eval/compare_models.py b/src/delphi/eval/compare_models.py deleted file mode 100644 index e03b300c..00000000 --- a/src/delphi/eval/compare_models.py +++ /dev/null @@ -1,91 +0,0 @@ -from dataclasses import dataclass - -import torch -from jaxtyping import Int -from transformers import PreTrainedModel - -from delphi.eval.utils import get_all_and_next_logprobs_single - - -def identify_model(model: PreTrainedModel) -> str: - return model.config.name_or_path - - -@dataclass -class TokenPrediction: - token: int - base_model_prob: float - lift_model_prob: float - - -@dataclass -class NextTokenStats: - base_model: str - lift_model: str - next_prediction: TokenPrediction - topk: list[TokenPrediction] - - -def compare_models( - model_a: PreTrainedModel, - model_b: PreTrainedModel, - sample_tok: Int[torch.Tensor, "seq"], - top_k: int = 3, -) -> list[NextTokenStats | None]: - """ - Compare the probabilities of the next token for two models and get the top k token predictions according to model B. - Args: - - model_a: The first model (assumed to be the base model) - - model_b: The second model (assumed to be the improved model) - - sample_tok: The tokenized prompt - - top_k: The number of top token predictions to retrieve (default is 5) - Returns: - A list of NextTokenStats objects, one for each token in the prompt. - Tensors are aligned to the token they are predicting (by prepending a -1 to the start of the tensor) - """ - assert ( - model_a.device == model_b.device - ), "Both models must be on the same device for comparison." 
- - device = model_a.device - sample_tok = sample_tok.to(device) - - logprobs_a, next_probs_a = get_all_and_next_logprobs_single(model_a, sample_tok) - logprobs_b, next_probs_b = get_all_and_next_logprobs_single(model_b, sample_tok) - - probs_a = torch.exp(logprobs_a) - probs_b = torch.exp(logprobs_b) - - top_k_b = torch.topk(probs_b, top_k, dim=-1) - top_k_a_probs = torch.gather(probs_a, 1, top_k_b.indices) - - top_k_b_tokens = top_k_b.indices - top_k_b_probs = top_k_b.values - - comparisons = [] - # ignore first token when evaluating predictions - comparisons.append(None) - - for next_p_a, next_p_b, top_toks_b, top_probs_a, top_probs_b in zip( - next_probs_a, next_probs_b, top_k_b_tokens, top_k_a_probs, top_k_b_probs - ): - nts = NextTokenStats( - base_model=identify_model(model_a), - lift_model=identify_model(model_b), - next_prediction=TokenPrediction( - token=int(next_p_a.item()), - base_model_prob=next_p_a.item(), - lift_model_prob=next_p_b.item(), - ), - topk=[ - TokenPrediction( - token=int(top_toks_b[i].item()), - base_model_prob=top_probs_a[i].item(), - lift_model_prob=top_probs_b[i].item(), - ) - for i in range(top_k) - ], - ) - comparisons.append(nts) - - return comparisons diff --git a/src/delphi/eval/constants.py b/src/delphi/eval/constants.py deleted file mode 100644 index 3a586e00..00000000 --- a/src/delphi/eval/constants.py +++ /dev/null @@ -1,26 +0,0 @@ -corpus_dataset = "delphi-suite/tinystories-v2-clean" -tokenized_corpus_dataset = "delphi-suite/tinystories-v2-clean-tokenized-v0" - -LLAMA2_MODELS = [ - "llama2-100k", - "llama2-200k", - "llama2-400k", - "llama2-800k", - "llama2-1.6m", - "llama2-3.2m", - "llama2-6.4m", - "llama2-12.8m", - "llama2-25.6m", -] - -LLAMA2_NEXT_LOGPROBS_DATASETS_MAP = { - "llama2-100k": "delphi-suite/v0-next-logprobs-llama2-100k", - "llama2-200k": "delphi-suite/v0-next-logprobs-llama2-200k", - "llama2-400k": "delphi-suite/v0-next-logprobs-llama2-400k", - "llama2-800k": "delphi-suite/v0-next-logprobs-llama2-800k", - "llama2-1.6m": "delphi-suite/v0-next-logprobs-llama2-1.6m", - "llama2-3.2m": "delphi-suite/v0-next-logprobs-llama2-3.2m", - "llama2-6.4m": "delphi-suite/v0-next-logprobs-llama2-6.4m", - "llama2-12.8m": "delphi-suite/v0-next-logprobs-llama2-12.8m", - "llama2-25.6m": "delphi-suite/v0-next-logprobs-llama2-25.6m", -} diff --git a/src/delphi/eval/utils.py b/src/delphi/eval/utils.py index faf33757..0026e7a7 100644 --- a/src/delphi/eval/utils.py +++ b/src/delphi/eval/utils.py @@ -1,14 +1,10 @@ -import logging from collections.abc import Callable -from typing import Any, cast +from typing import Any import numpy as np import torch -from datasets import Dataset, load_dataset from jaxtyping import Float, Int -from transformers import PreTrainedModel, PreTrainedTokenizerBase - -from delphi.eval import constants +from transformers import PreTrainedModel def get_all_logprobs( @@ -68,59 +64,6 @@ def get_next_and_top_k_probs( return next_probs, top_k -def load_delphi_dataset(dataset_name: str, split: str, slice: str = "") -> Dataset: - # check that split is either "train" or "validation" - if split not in ["train", "validation"]: - raise ValueError(f"Split must be either 'train' or 'validation', not {split}") - if "/" not in dataset_name: - dataset_name = f"delphi-suite/{dataset_name}" - data_files_str = f"data/{split}-*.parquet" - dataset = load_dataset( - dataset_name, - data_files=data_files_str, - verification_mode="no_checks", - # Currently, load_dataset returns a dataset dict *unless* a split is specified, - # EVEN IF NO SPLIT WITHIN THE DATA 
FILES SPECIFIED. If there's no split arg, - # huggingface just just says everything is in the "train" split and returns {"train": dataset}. - # In our case the data_files glob already specifies just the validation files, so we - # shouldn't need to specify a split. But we do need to specify a split to get a dataset object, - # or we'd get a Dataset dict. See https://github.com/huggingface/datasets/issues/5189 - split=f"train{slice}", - ) - dataset = cast(Dataset, dataset) - logging.info(f" Loaded {data_files_str} ({len(dataset)} entries)") - return dataset - - -def load_validation_dataset(dataset_name: str, slice: str = "") -> Dataset: - return load_delphi_dataset(dataset_name, "validation", slice) - - -def load_train_dataset(dataset_name: str, slice: str = "") -> Dataset: - return load_delphi_dataset(dataset_name, "train", slice) - - -def tokenize( - tokenizer: PreTrainedTokenizerBase, sample_txt: str -) -> Int[torch.Tensor, "seq"]: - # supposedly this can be different than prepending the bos token id - return cast( - Int[torch.Tensor, "seq"], - tokenizer.encode(tokenizer.bos_token + sample_txt, return_tensors="pt")[0], - ) - - -def load_logprob_dataset(model: str): - return load_dataset(f"transcendingvictor/{model}-validation-logprobs") - - -def load_logprob_datasets(split: str = "validation") -> dict[str, list[list[float]]]: - return { - model: cast(dict, load_logprob_dataset(model)[split])["logprobs"] # type: ignore - for model in constants.LLAMA2_MODELS - } - - def dict_filter_quantile( d: dict[Any, float], q_start: float, q_end: float ) -> dict[Any, float]: diff --git a/src/delphi/eval/vis.py b/src/delphi/eval/vis.py index f13924ad..cea76e88 100644 --- a/src/delphi/eval/vis.py +++ b/src/delphi/eval/vis.py @@ -12,17 +12,6 @@ from transformers import PreTrainedTokenizerBase -def probs_to_colors(probs: Float[torch.Tensor, "next_pos"]) -> list[str]: - # for the endoftext token - # no prediction, no color - colors = ["white"] - for p in probs.tolist(): - red_gap = 150 # the higher it is, the less red the tokens will be - green_blue_val = red_gap + int((255 - red_gap) * (1 - p)) - colors.append(f"rgb(255, {green_blue_val}, {green_blue_val})") - return colors - - def single_loss_diff_to_color(loss_diff: float) -> str: # if loss_diff is negative, we want the color to be red # if loss_diff is positive, we want the color to be green @@ -116,70 +105,6 @@ def token_to_html( ) -def vis_sample_prediction_probs( - sample_tok: Int[torch.Tensor, "pos"], - correct_probs: Float[torch.Tensor, "pos"], - top_k_probs: torch.return_types.topk, - tokenizer: PreTrainedTokenizerBase, -) -> str: - colors = probs_to_colors(correct_probs) - token_htmls = [] - - # Generate a unique ID for this instance (so we can have multiple instances on the same page) - unique_id = str(uuid.uuid4()) - - token_class = f"token_{unique_id}" - hover_div_id = f"hover_info_{unique_id}" - - for i in range(sample_tok.shape[0]): - tok = cast(int, sample_tok[i].item()) - data = {} - if i > 0: - correct_prob = correct_probs[i - 1].item() - data["next"] = to_tok_prob_str(tok, correct_prob, tokenizer) - top_k_probs_tokens = top_k_probs.indices[i - 1] - top_k_probs_values = top_k_probs.values[i - 1] - for j in range(top_k_probs_tokens.shape[0]): - top_tok = top_k_probs_tokens[j].item() - top_tok = cast(int, top_tok) - top_prob = top_k_probs_values[j].item() - data[f"top{j}"] = to_tok_prob_str(top_tok, top_prob, tokenizer) - - token_htmls.append( - token_to_html( - tok, tokenizer, bg_color=colors[i], data=data, class_name=token_class - ) - ) 
- - html_str = f""" - - {"".join(token_htmls)}
- - """ - display(HTML(html_str)) - return html_str - - def vis_pos_map( pos_list: list[tuple[int, int]], selected_tokens: list[int], diff --git a/tests/eval/test_compare_models.py b/tests/eval/test_compare_models.py deleted file mode 100644 index 0521b0cb..00000000 --- a/tests/eval/test_compare_models.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from delphi.eval.compare_models import NextTokenStats, compare_models -from delphi.eval.utils import load_validation_dataset, tokenize - - -def test_compare_models(): - with torch.set_grad_enabled(False): - model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-1M") - model_instruct = AutoModelForCausalLM.from_pretrained( - "roneneldan/TinyStories-Instruct-1M" - ) - ds_txt = load_validation_dataset("tinystories-v2-clean")["story"] - tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-1M") - sample_tok = tokenize(tokenizer, ds_txt[0]) - K = 3 - model_comparison = compare_models(model, model_instruct, sample_tok, top_k=K) - # ignore the first element comparison - assert model_comparison[0] is None - assert isinstance(model_comparison[1], NextTokenStats) - assert len(model_comparison) == sample_tok.shape[0] - assert len(model_comparison[1].topk) == K diff --git a/tests/eval/test_utils_eval.py b/tests/eval/test_utils_eval.py index a259d16b..54e0034a 100644 --- a/tests/eval/test_utils_eval.py +++ b/tests/eval/test_utils_eval.py @@ -3,11 +3,7 @@ import pytest import torch -from delphi.eval.utils import ( - dict_filter_quantile, - gather_logprobs, - load_validation_dataset, -) +from delphi.eval.utils import dict_filter_quantile, gather_logprobs def test_gather_logprobs(): @@ -50,11 +46,6 @@ def test_gather_logprobs(): assert torch.allclose(result, expected_output) -def test_load_validation_dataset(): - text = load_validation_dataset("tinystories-v2-clean") - tokenized = load_validation_dataset("tinystories-v2-clean-tokenized-v0") - - @pytest.mark.filterwarnings( "ignore::RuntimeWarning" ) # ignore warnings from numpy empty slice
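
For context on the datasets.load_dataset behavior that the comment removed from load_delphi_dataset describes (see https://github.com/huggingface/datasets/issues/5189), here is a minimal sketch, assuming the delphi-suite parquet layout referenced in the removed code; it is illustrative only and not part of the patched codebase:

    from datasets import Dataset, load_dataset

    # With only data_files, load_dataset returns a DatasetDict and labels everything
    # "train", even though the glob matches just the validation parquet shards.
    dd = load_dataset(
        "delphi-suite/tinystories-v2-clean",
        data_files="data/validation-*.parquet",
        verification_mode="no_checks",
    )
    print(type(dd))  # DatasetDict with a single "train" key

    # Passing a split (optionally sliced) returns a Dataset object directly,
    # which is why the removed helper passed split=f"train{slice}" despite the
    # data_files glob already restricting loading to the validation files.
    ds = load_dataset(
        "delphi-suite/tinystories-v2-clean",
        data_files="data/validation-*.parquet",
        verification_mode="no_checks",
        split="train[:100]",
    )
    assert isinstance(ds, Dataset)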