From 6639f3ea8eae8241faacb5d66b2a08249f0dbdbd Mon Sep 17 00:00:00 2001 From: Oleg S <97077423+RobotSail@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:45:13 -0500 Subject: [PATCH] feat: introduce `results` attribute on MMLU evaluator In order to test the validity of our MMLU results or get information on prior runs, we need to be able to access the full set of results from the lm_eval.evaluator.simple_evaluate API. This commit provides that ability by adding a results attribute on the MMLUEvaluator class and storing the results there. Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com> --- scripts/test_mmlu.py | 53 +++++++++++++++++++++++++++++++- src/instructlab/eval/mmlu.py | 58 +++++++++++++++++++++++++----------- 2 files changed, 93 insertions(+), 18 deletions(-) diff --git a/scripts/test_mmlu.py b/scripts/test_mmlu.py index fd10f27..a6035df 100755 --- a/scripts/test_mmlu.py +++ b/scripts/test_mmlu.py @@ -1,9 +1,41 @@ +# Standard +from typing import Dict, List, Tuple, TypedDict + # First Party from instructlab.eval.mmlu import MMLUEvaluator SYSTEM_PROMPT = """I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant.""" +class MMLUSample(TypedDict): + """ + Example of a single sample returned from lm_eval when running MMLU. + This is not a comprehensive type, just the subset of fields we care about for this test. + """ + + # Arguments is the list of (prompt, answer) pairs passed to MMLU as few-shot samples. + # They will not be present with few_shot=0 + arguments: List[Tuple[str, str]] + + +def all_samples_contain_system_prompt( + samples: Dict[str, List[MMLUSample]], prompt: str +) -> bool: + """ + Given a mapping of evaluation --> list of results, validates that all few-shot examples + included the system prompt + """ + for topic, samples_set in samples.items(): + for sample in samples_set: + for mmlu_prompt, _ in sample["arguments"]: + if prompt not in mmlu_prompt: + # we are looking for the exact system prompt, so no need to convert to normalize to lowercase + print(f"found a sample in the '{topic}' MMLU topic set") + return False + + return True + + def test_minimal_mmlu(): print("===> Executing 'test_minimal_mmlu'...") try: @@ -14,9 +46,28 @@ def test_minimal_mmlu(): tasks=tasks, system_prompt=SYSTEM_PROMPT, ) - overall_score, individual_scores = mmlu.run() + overall_score, individual_scores = mmlu.run( + extra_args={"log_samples": True, "write_out": True} + ) + samples = mmlu.results["samples"] + print(overall_score) print(individual_scores) + + # we need n-shots > 1 to be able to validate the inclusion of the system prompt + eligible_samples = { + topic: samples[topic] + for topic, shot in mmlu.results["n-shot"].items() + if shot > 1 + } + if eligible_samples: + if not all_samples_contain_system_prompt(eligible_samples, SYSTEM_PROMPT): + return False + else: + print( + "MMLU was run in zero-shot mode, cannot confirm that system prompt was included, skipping check..." + ) + except Exception as exc: print(f"'test_minimal_mmlu' failed: {exc}") return False diff --git a/src/instructlab/eval/mmlu.py b/src/instructlab/eval/mmlu.py index 8637ad4..d4b14a6 100644 --- a/src/instructlab/eval/mmlu.py +++ b/src/instructlab/eval/mmlu.py @@ -7,12 +7,12 @@ """ # Standard -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import os # Third Party -from lm_eval.evaluator import simple_evaluate # type: ignore -from lm_eval.tasks import TaskManager # type: ignore +from lm_eval.evaluator import simple_evaluate +from lm_eval.tasks import TaskManager import torch # First Party @@ -103,6 +103,7 @@ class AbstractMMLUEvaluator(Evaluator): batch_size batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times'. device PyTorch device (e.g. "cpu" or "cuda:0") for running models system_prompt system prompt to be used when applying the chat template + results full output from the `lm_eval.evaluator.simple_evaluate` function after MMLU has run. """ def __init__( @@ -124,18 +125,33 @@ def __init__( self.few_shots = few_shots self.batch_size = batch_size self.device = device + self._results = None - def run(self, server_url: str | None = None) -> tuple: + @property + def results(self) -> Dict[str, Any] | None: + """ + Returns the results of the last MMLU evaluation, if one has taken place. + + Returns: + Dict[str, Any] | None: The output from `lm_eval.evaluator.simple_evaluate` + """ + return self._results + + def run( + self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None + ) -> tuple: """ Runs evaluation Attributes server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated + extra_args Dictionary containing any extra arguments to be passed into the lm_eval `lm_eval.evaluator.simple_evaluate` function. Returns: overall_score Average score for the task group individual_scores Individual scores for each task in the task group """ + extra_args = {} if not extra_args else extra_args logger.debug(locals()) # TODO: make this a parameter for class? @@ -156,7 +172,10 @@ def run(self, server_url: str | None = None) -> tuple: return overall_score, individual_scores - def _run_mmlu(self, server_url: str | None = None) -> dict: + def _run_mmlu( + self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None + ) -> dict: + extra_args = {} if not extra_args else extra_args if server_url is not None: # Requires lm_eval >= 0.4.4 model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface" @@ -172,18 +191,23 @@ def _run_mmlu(self, server_url: str | None = None) -> dict: raise InvalidTasksDirError(self.tasks_dir) tm = TaskManager(verbosity="DEBUG", include_path=self.tasks_dir) should_apply_chat_template = self.system_prompt is not None - mmlu_output = self._simple_evaluate_with_error_handling( - model=model, - model_args=model_args, - tasks=self.tasks, - num_fewshot=self.few_shots, - batch_size=self.batch_size, - device=self.device, - task_manager=tm, - system_instruction=self.system_prompt, - apply_chat_template=should_apply_chat_template, - ) - results = mmlu_output["results"] + + # configure the args here so users can override them as necessary + simple_evaluate_kwargs = { + "model": model, + "model_args": model_args, + "tasks": self.tasks, + "num_fewshot": self.few_shots, + "batch_size": self.batch_size, + "device": self.device, + "task_manager": tm, + "system_instruction": self.system_prompt, + "apply_chat_template": should_apply_chat_template, + } + simple_evaluate_kwargs.update(extra_args) + + results = self._simple_evaluate_with_error_handling(**simple_evaluate_kwargs) + self._results = results return results # This method converts general errors from simple_evaluate