diff --git a/scripts/test_mmlu.py b/scripts/test_mmlu.py index fd10f27..a6035df 100755 --- a/scripts/test_mmlu.py +++ b/scripts/test_mmlu.py @@ -1,9 +1,41 @@ +# Standard +from typing import Dict, List, Tuple, TypedDict + # First Party from instructlab.eval.mmlu import MMLUEvaluator SYSTEM_PROMPT = """I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant.""" +class MMLUSample(TypedDict): + """ + Example of a single sample returned from lm_eval when running MMLU. + This is not a comprehensive type, just the subset of fields we care about for this test. + """ + + # Arguments is the list of (prompt, answer) pairs passed to MMLU as few-shot samples. + # They will not be present with few_shot=0 + arguments: List[Tuple[str, str]] + + +def all_samples_contain_system_prompt( + samples: Dict[str, List[MMLUSample]], prompt: str +) -> bool: + """ + Given a mapping of evaluation --> list of results, validates that all few-shot examples + included the system prompt + """ + for topic, samples_set in samples.items(): + for sample in samples_set: + for mmlu_prompt, _ in sample["arguments"]: + if prompt not in mmlu_prompt: + # we are looking for the exact system prompt, so no need to convert to normalize to lowercase + print(f"found a sample in the '{topic}' MMLU topic set") + return False + + return True + + def test_minimal_mmlu(): print("===> Executing 'test_minimal_mmlu'...") try: @@ -14,9 +46,28 @@ def test_minimal_mmlu(): tasks=tasks, system_prompt=SYSTEM_PROMPT, ) - overall_score, individual_scores = mmlu.run() + overall_score, individual_scores = mmlu.run( + extra_args={"log_samples": True, "write_out": True} + ) + samples = mmlu.results["samples"] + print(overall_score) print(individual_scores) + + # we need n-shots > 1 to be able to validate the inclusion of the system prompt + eligible_samples = { + topic: samples[topic] + for topic, shot in mmlu.results["n-shot"].items() + if shot > 1 + } + if eligible_samples: + if not all_samples_contain_system_prompt(eligible_samples, SYSTEM_PROMPT): + return False + else: + print( + "MMLU was run in zero-shot mode, cannot confirm that system prompt was included, skipping check..." + ) + except Exception as exc: print(f"'test_minimal_mmlu' failed: {exc}") return False diff --git a/src/instructlab/eval/mmlu.py b/src/instructlab/eval/mmlu.py index 8637ad4..d4b14a6 100644 --- a/src/instructlab/eval/mmlu.py +++ b/src/instructlab/eval/mmlu.py @@ -7,12 +7,12 @@ """ # Standard -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import os # Third Party -from lm_eval.evaluator import simple_evaluate # type: ignore -from lm_eval.tasks import TaskManager # type: ignore +from lm_eval.evaluator import simple_evaluate +from lm_eval.tasks import TaskManager import torch # First Party @@ -103,6 +103,7 @@ class AbstractMMLUEvaluator(Evaluator): batch_size batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times'. device PyTorch device (e.g. "cpu" or "cuda:0") for running models system_prompt system prompt to be used when applying the chat template + results full output from the `lm_eval.evaluator.simple_evaluate` function after MMLU has run. """ def __init__( @@ -124,18 +125,33 @@ def __init__( self.few_shots = few_shots self.batch_size = batch_size self.device = device + self._results = None - def run(self, server_url: str | None = None) -> tuple: + @property + def results(self) -> Dict[str, Any] | None: + """ + Returns the results of the last MMLU evaluation, if one has taken place. + + Returns: + Dict[str, Any] | None: The output from `lm_eval.evaluator.simple_evaluate` + """ + return self._results + + def run( + self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None + ) -> tuple: """ Runs evaluation Attributes server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated + extra_args Dictionary containing any extra arguments to be passed into the lm_eval `lm_eval.evaluator.simple_evaluate` function. Returns: overall_score Average score for the task group individual_scores Individual scores for each task in the task group """ + extra_args = {} if not extra_args else extra_args logger.debug(locals()) # TODO: make this a parameter for class? @@ -156,7 +172,10 @@ def run(self, server_url: str | None = None) -> tuple: return overall_score, individual_scores - def _run_mmlu(self, server_url: str | None = None) -> dict: + def _run_mmlu( + self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None + ) -> dict: + extra_args = {} if not extra_args else extra_args if server_url is not None: # Requires lm_eval >= 0.4.4 model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface" @@ -172,18 +191,23 @@ def _run_mmlu(self, server_url: str | None = None) -> dict: raise InvalidTasksDirError(self.tasks_dir) tm = TaskManager(verbosity="DEBUG", include_path=self.tasks_dir) should_apply_chat_template = self.system_prompt is not None - mmlu_output = self._simple_evaluate_with_error_handling( - model=model, - model_args=model_args, - tasks=self.tasks, - num_fewshot=self.few_shots, - batch_size=self.batch_size, - device=self.device, - task_manager=tm, - system_instruction=self.system_prompt, - apply_chat_template=should_apply_chat_template, - ) - results = mmlu_output["results"] + + # configure the args here so users can override them as necessary + simple_evaluate_kwargs = { + "model": model, + "model_args": model_args, + "tasks": self.tasks, + "num_fewshot": self.few_shots, + "batch_size": self.batch_size, + "device": self.device, + "task_manager": tm, + "system_instruction": self.system_prompt, + "apply_chat_template": should_apply_chat_template, + } + simple_evaluate_kwargs.update(extra_args) + + results = self._simple_evaluate_with_error_handling(**simple_evaluate_kwargs) + self._results = results return results # This method converts general errors from simple_evaluate