diff --git a/README.md b/README.md
index aca9888..78fc126 100644
--- a/README.md
+++ b/README.md
@@ -75,5 +75,5 @@ A [simple benchmark](https://github.com/bigscience-workshop/Megatron-DeepSpeed/i
 [WMT](https://huggingface.co/datasets/wmt19) and [TyDi QA](https://huggingface.co/datasets/tydiqa)
 E.g.
 ```shell
-python3 -m evaluation.scripts.simple_benchmark --model_name_or_path=gpt2
+python3 -m evaluation.eval --model_name_or_path=gpt2 --eval_tasks tydiqa_secondary
 ```
diff --git a/evaluation/eval.py b/evaluation/eval.py
new file mode 100644
index 0000000..053a291
--- /dev/null
+++ b/evaluation/eval.py
@@ -0,0 +1,84 @@
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Optional, List
+import os
+
+import torch
+from transformers import (
+    HfArgumentParser,
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TrainingArguments,
+    set_seed,
+)
+import evaluation.tasks  # needed for AutoTask.__subclasses__() to work correctly
+from evaluation.tasks.auto_task import AutoTask
+from evaluation.utils.log import get_logger
+
+
+@dataclass
+class EvaluationArguments:
+    """
+    Arguments for any adjustable params in this evaluation script
+    """
+    model_name_or_path: str = field(
+        metadata={"help": "The model checkpoint that we want to evaluate, could be name or the path."}
+    )
+    eval_tasks: List[str] = field(
+        metadata={"help": "A list of tasks to run the evaluation on, e.g. tydiqa_secondary"}
+    )
+    config_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "Pretrained config name or path if not the same as model_name."}
+    )
+    tokenizer_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
+    )
+    tag: Optional[str] = field(
+        default=None,
+        metadata={"help": "Identifier for the evaluation run."}
+    )
+
+
+def main():
+    parser = HfArgumentParser((EvaluationArguments, TrainingArguments))
+    eval_args, train_args = parser.parse_args_into_dataclasses()
+
+    if not eval_args.eval_tasks:
+        raise ValueError('Must provide at least one eval task!')
+
+    # initialize device
+    device = torch.device(train_args.device)
+
+    logger = get_logger()
+    logger.info(f"Beginning evaluation on device {train_args.device}")
+
+    # Load model & tokenizer
+    logger.info("Loading model...")
+    tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
+    tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "left"
+
+    model = AutoModelForCausalLM.from_pretrained(
+        eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token_id,
+    )
+    model.config.pad_token_id = model.config.eos_token_id
+    model.resize_token_embeddings(len(tokenizer))
+    model.to(device)
+
+    # Exporting results
+    tag = eval_args.tag or datetime.now().strftime("%y%m%d_%H%M%S")
+    output_dir = os.path.join(train_args.output_dir, tag)
+    os.makedirs(output_dir, exist_ok=True)
+
+    for eval_task in eval_args.eval_tasks:
+        logger.info(f"Benchmarking {eval_task}...")
+        task = AutoTask.from_task_name(eval_task, tokenizer=tokenizer, model=model, device=device)
+        set_seed(train_args.seed)
+        task.evaluate()
+        task.save_metrics(output_dir, logger)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/evaluation/scripts/simple_benchmark.py b/evaluation/scripts/simple_benchmark.py
deleted file mode 100644
index 8de06bd..0000000
--- a/evaluation/scripts/simple_benchmark.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import logging
-from dataclasses import dataclass, field
-from datetime import datetime
-from typing import Optional
-import os
-
-import torch
-from datasets import load_dataset
-from tqdm import tqdm
-from transformers import (
-    HfArgumentParser,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    set_seed,
-)
-
-from evaluation.datasets.tydiqa import TyDiQADataset
-from evaluation.utils.io import save_json
-
-logger = logging.getLogger(__name__)
-
-torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-@dataclass
-class EvaluationArguments:
-    """
-    Arguments for any adjustable params in this evaluation script
-    """
-    model_name_or_path: Optional[str] = field(
-        default=None,
-        metadata={"help": "The model checkpoint that we want to evaluate, could be name or the path."}
-    )
-    config_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "Pretrained config name or path if not the same as model_name."}
-    )
-    tokenizer_name: Optional[str] = field(
-        default=None,
-        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
-    )
-    output_dir: Optional[str] = field(
-        default="outputs",
-        metadata={"help": "Directory for saving evaluation outputs."}
-    )
-    random_seed: Optional[int] = field(
-        default=24,
-        metadata={"help": "Customized random seed"}
-    )
-
-
-def main():
-    # parse arguments
-    parser = HfArgumentParser(EvaluationArguments)
-    eval_args, = parser.parse_args_into_dataclasses()
-
-    # set up logging
-    logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-    )
-    logger.setLevel(logging.INFO)
-
-    # set random seed
-    set_seed(eval_args.random_seed)
-
-    logger.info("Beginning evaluation")
-
-    # Load model & tokenizer
-    logger.info("Loading model...")
-    tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
-    tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.padding_side = "left"
-
-    model = AutoModelForCausalLM.from_pretrained(eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token)
-    model.config.pad_token_id = model.config.eos_token_id
-    model.resize_token_embeddings(len(tokenizer))
-    model.to(torch_device)
-
-    # Load dataset
-    logger.info("Benchmarking TyDiQA...")
-    target_langs = ["english"]
-    data = load_dataset("tydiqa", "secondary_task", split="validation")
-    dataset = TyDiQADataset(data, tokenizer, target_langs)
-
-    tydiqa_substring_matches = 0
-    for sample in tqdm(dataset):
-        output = model.generate(
-            input_ids=sample["input_ids"].to(torch_device),
-            attention_mask=sample["attention_mask"].to(torch_device),
-            max_length=min(sample["input_len"]*2, model.config.n_positions),
-        )
-
-        prompt_len = len(sample["prompt"])
-        decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
-        predicted_answer = decoded_output[prompt_len:]
-
-        target_answers = sample["target_answer"]
-        substring_match = any([target_answer in predicted_answer.lower() for target_answer in target_answers])
-        tydiqa_substring_matches += substring_match
-    tydiqa_metrics = {
-        "substring_matches": tydiqa_substring_matches / len(dataset) * 100
-    }
-    logger.info(f"TyDiQA: {tydiqa_metrics['substring_matches']}% of samples contain substring matches")
-
-    # Exporting results
-    if eval_args.output_dir:
-        output_dir = os.path.join(eval_args.output_dir, datetime.now().strftime("%y%m%d_%H%M%S"))
-        os.makedirs(output_dir, exist_ok=True)
-        # Exporting TyDiQA results
-        tydiqa_filename = os.path.join(output_dir, "tydiqa.json")
-        save_json(tydiqa_metrics, tydiqa_filename)
-        logger.info(f"TyDiQA: result exported to {tydiqa_filename}")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/evaluation/tasks/__init__.py b/evaluation/tasks/__init__.py
new file mode 100644
index 0000000..acd0f36
--- /dev/null
+++ b/evaluation/tasks/__init__.py
@@ -0,0 +1,9 @@
+# recursively import every submodule at runtime
+# source: https://stackoverflow.com/questions/3365740/how-to-import-all-submodules
+import pkgutil
+
+__all__ = []
+for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
+    __all__.append(module_name)
+    _module = loader.find_module(module_name).load_module(module_name)
+    globals()[module_name] = _module
diff --git a/evaluation/tasks/auto_task.py b/evaluation/tasks/auto_task.py
new file mode 100644
index 0000000..dfa6f3b
--- /dev/null
+++ b/evaluation/tasks/auto_task.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+import os
+
+from evaluation.utils.io import save_json
+
+
+class AutoTask(ABC):
+    def __init__(self, tokenizer, model, device):
+        self.tokenizer = tokenizer
+        self.model = model
+        self.device = device
+        self.metrics = {}
+
+    @classmethod
+    def from_task_name(cls, task_name: str, tokenizer, model, device):
+        all_tasks = cls.__subclasses__()
+        for task in all_tasks:
+            if task.get_display_name() == task_name:
+                return task(tokenizer=tokenizer, model=model, device=device)
+
+        raise ValueError(f'Invalid task: {task_name}')
+
+    @staticmethod
+    @abstractmethod
+    def get_display_name() -> str:
+        pass
+
+    @abstractmethod
+    def evaluate(self) -> None:
+        pass
+
+    def save_metrics(self, output_dir, logger=None) -> str:
+        output_filename = os.path.join(output_dir, f"{self.get_display_name()}.json")
+        save_json(self.metrics, output_filename)
+
+        if logger:
+            logger.info(f"{self.get_display_name()}: result exported to {output_filename}")
+        return output_filename
diff --git a/evaluation/datasets/__init__.py b/evaluation/tasks/tydiqa_primary/__init__.py
similarity index 100%
rename from evaluation/datasets/__init__.py
rename to evaluation/tasks/tydiqa_primary/__init__.py
diff --git a/evaluation/datasets/tydiqa.py b/evaluation/tasks/tydiqa_primary/tydiqa_primary.py
similarity index 100%
rename from evaluation/datasets/tydiqa.py
rename to evaluation/tasks/tydiqa_primary/tydiqa_primary.py
diff --git a/evaluation/scripts/__init__.py b/evaluation/tasks/tydiqa_secondary/__init__.py
similarity index 100%
rename from evaluation/scripts/__init__.py
rename to evaluation/tasks/tydiqa_secondary/__init__.py
diff --git a/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py b/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
new file mode 100644
index 0000000..7d97e42
--- /dev/null
+++ b/evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
@@ -0,0 +1,95 @@
+# Module for any additional processing required for the TyDi QA dataset
+# HuggingFace dataset link: https://huggingface.co/datasets/tydiqa
+from typing import Dict
+
+from jinja2 import Template
+from torch.utils.data import Dataset
+from datasets import load_dataset
+from tqdm import tqdm
+
+from evaluation.tasks.auto_task import AutoTask
+
+TEMPLATE = Template(
+    """
+    {%- set _blank=["passage", "text", "text snippet", "context"]|random -%}
+    {%- set _position = ["above", "following"] |random -%}
+    {%- if _position == "above" -%}
+    {{context}}{{"\n"}}
+    {%- endif -%}
+    Given the {{_position}} {{_blank}}, answer the question: {{question}}
+    {%- if _position == "following" -%}
+    {{"\n"}}{{context}}
+    {%- endif -%}
+    {{"\n"}}Answer:
+    """
+)
+
+
+class TyDiQADataset(Dataset):
+    def __init__(self, tokenizer, target_langs):
+        super().__init__()
+        tydiqa = load_dataset("tydiqa", "secondary_task", split="validation")
split="validation") + self.items = [] + + for sample in tydiqa: + lang = sample["id"].split("-")[0] + if lang in target_langs: + # Filter out samples in languages that are not used during training + prompt = TEMPLATE.render( + id = sample["id"], + context = sample["context"], + question = sample["question"], + ) + prompt = prompt.strip() # Remove trailing white space and newline + + # Tokenize and construct this sample + inputs = tokenizer( + prompt, + padding=True, + return_tensors='pt', + ) + self.items.append( + { + "prompt": prompt, + "lang": lang, + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "input_len": inputs["attention_mask"].shape[1], + "target_answer": [ans.lower() for ans in sample["answers"]['text']], + } + ) + + def __len__(self): + return len(self.items) + + def __getitem__(self, index): + return self.items[index] + + +class TydiqaSecondaryTask(AutoTask): + @staticmethod + def get_display_name() -> str: + return 'tydiqa_secondary' + + def evaluate(self) -> None: + dataset = TyDiQADataset(self.tokenizer, target_langs=["english"]) + + substring_matches = 0 + for sample in tqdm(dataset, desc=f'Evaluating {self.get_display_name()}'): + output = self.model.generate( + input_ids=sample["input_ids"].to(self.device), + attention_mask=sample["attention_mask"].to(self.device), + max_length=min(sample["input_len"] * 2, self.model.config.n_positions), + ) + + prompt_len = len(sample["prompt"]) + decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True) + predicted_answer = decoded_output[prompt_len:] + + target_answers = sample["target_answer"] + substring_match = any([target_answer in predicted_answer.lower() for target_answer in target_answers]) + substring_matches += substring_match + + self.metrics = { + "substring_matches": substring_matches / len(dataset) * 100 + } diff --git a/evaluation/datasets/wmt.py b/evaluation/tasks/wmt/wmt.py similarity index 100% rename from evaluation/datasets/wmt.py rename to evaluation/tasks/wmt/wmt.py diff --git a/evaluation/utils/log.py b/evaluation/utils/log.py new file mode 100644 index 0000000..63854c8 --- /dev/null +++ b/evaluation/utils/log.py @@ -0,0 +1,14 @@ +import logging + + +def get_logger(): + logger = logging.getLogger("evaluation") + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt="%m/%d/%Y %H:%M:%S", + ) + handler = logging.StreamHandler() + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + return logger \ No newline at end of file