Merge pull request #52 from bigscience-workshop/refactor
Refactor overall directory structure
jaketae authored Aug 17, 2021
2 parents dbffc07 + 8012c3d commit 77b0910
Showing 11 changed files with 241 additions and 118 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -75,5 +75,5 @@ A [simple benchmark](https://github.com/bigscience-workshop/Megatron-DeepSpeed/i
[WMT](https://huggingface.co/datasets/wmt19) and [TyDi QA](https://huggingface.co/datasets/tydiqa)
E.g.
```shell
-python3 -m evaluation.scripts.simple_benchmark --model_name_or_path=gpt2
+python3 -m evaluation.eval --model_name_or_path=gpt2 --eval_tasks tydiqa_secondary
```
84 changes: 84 additions & 0 deletions evaluation/eval.py
@@ -0,0 +1,84 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, List
import os

import torch
from transformers import (
HfArgumentParser,
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
set_seed,
)
import evaluation.tasks  # needed for AutoTask.__subclasses__() to work correctly
from evaluation.tasks.auto_task import AutoTask
from evaluation.utils.log import get_logger


@dataclass
class EvaluationArguments:
"""
    Arguments for any adjustable parameters in this evaluation script
"""
model_name_or_path: str = field(
metadata={"help": "The model checkpoint that we want to evaluate, could be name or the path."}
)
eval_tasks: List[str] = field(
metadata={"help": "A list of tasks to run the evaluation on, e.g. tydiqa_secondary"}
)
config_name: Optional[str] = field(
default=None,
metadata={"help": "Pretrained config name or path if not the same as model_name."}
)
tokenizer_name: Optional[str] = field(
default=None,
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
)
tag: Optional[str] = field(
default=None,
metadata={"help": "Identifier for the evaluation run."}
)


def main():
parser = HfArgumentParser((EvaluationArguments, TrainingArguments))
eval_args, train_args = parser.parse_args_into_dataclasses()

if not eval_args.eval_tasks:
raise ValueError('Must provide at least one eval task!')

# initialize device
device = torch.device(train_args.device)

logger = get_logger()
logger.info(f"Beginning evaluation on device {train_args.device}")

# Load model & tokenizer
logger.info("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(eval_args.tokenizer_name or eval_args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
        eval_args.model_name_or_path, pad_token_id=tokenizer.eos_token_id,
)
model.config.pad_token_id = model.config.eos_token_id
model.resize_token_embeddings(len(tokenizer))
model.to(device)

# Exporting results
tag = eval_args.tag or datetime.now().strftime("%y%m%d_%H%M%S")
output_dir = os.path.join(train_args.output_dir, tag)
os.makedirs(output_dir, exist_ok=True)

for eval_task in eval_args.eval_tasks:
logger.info(f"Benchmarking {eval_task}...")
task = AutoTask.from_task_name(eval_task, tokenizer=tokenizer, model=model, device=device)
set_seed(train_args.seed)
task.evaluate()
task.save_metrics(output_dir, logger)


if __name__ == "__main__":
main()
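
Because `HfArgumentParser` merges `EvaluationArguments` with `TrainingArguments`, the script also accepts the standard `TrainingArguments` flags; in particular `--output_dir` is required, and `--seed` feeds the per-task `set_seed` call above. A minimal invocation sketch (the flag values and `outputs/` directory below are placeholders, not part of this commit, and it assumes `evaluation` is importable as a package):

```python
# Sketch only: drive evaluation.eval.main() by faking argv; values are illustrative.
import sys

from evaluation.eval import main

sys.argv = [
    "eval",
    "--model_name_or_path", "gpt2",
    "--eval_tasks", "tydiqa_secondary",
    "--output_dir", "outputs",  # required by TrainingArguments
    "--seed", "42",
]
main()
```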
117 changes: 0 additions & 117 deletions evaluation/scripts/simple_benchmark.py

This file was deleted.

9 changes: 9 additions & 0 deletions evaluation/tasks/__init__.py
@@ -0,0 +1,9 @@
# recursively import every submodule at runtime
# source: https://stackoverflow.com/questions/3365740/how-to-import-all-submodules
import pkgutil

__all__ = []
for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
__all__.append(module_name)
_module = loader.find_module(module_name).load_module(module_name)
globals()[module_name] = _module
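
Note that `loader.find_module(...).load_module(...)` is deprecated in newer Python releases. A behavior-equivalent sketch using `importlib` (not part of this commit) would look like:

```python
# Alternative sketch: same eager submodule import via importlib,
# avoiding the deprecated find_module/load_module loader API.
import importlib
import pkgutil

__all__ = []
for _, module_name, _ in pkgutil.walk_packages(__path__):
    __all__.append(module_name)
    globals()[module_name] = importlib.import_module(f"{__name__}.{module_name}")
```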
38 changes: 38 additions & 0 deletions evaluation/tasks/auto_task.py
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
import os

from evaluation.utils.io import save_json


class AutoTask(ABC):
def __init__(self, tokenizer, model, device):
self.tokenizer = tokenizer
self.model = model
self.device = device
self.metrics = {}

@classmethod
def from_task_name(cls, task_name: str, tokenizer, model, device):
all_tasks = cls.__subclasses__()
for task in all_tasks:
if task.get_display_name() == task_name:
return task(tokenizer=tokenizer, model=model, device=device)

raise ValueError(f'Invalid task: {task_name}')

@staticmethod
@abstractmethod
def get_display_name() -> str:
pass

@abstractmethod
def evaluate(self) -> None:
pass

def save_metrics(self, output_dir, logger=None) -> str:
output_filename = os.path.join(output_dir, f"{self.get_display_name()}.json")
save_json(self.metrics, output_filename)

if logger:
logger.info(f"{self.get_display_name()}: result exported to {output_filename}")
return output_filename
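
`AutoTask.from_task_name` discovers tasks by scanning direct subclasses, which is why `eval.py` imports `evaluation.tasks` for its side effects. A sketch of what a new task could look like under this interface (the class name, task name, and metric below are illustrative, not part of the commit):

```python
# Illustrative only: a minimal AutoTask subclass; "dummy" is a made-up task name.
from evaluation.tasks.auto_task import AutoTask


class DummyTask(AutoTask):
    @staticmethod
    def get_display_name() -> str:
        return "dummy"

    def evaluate(self) -> None:
        # A real task would run the model here; this just records a placeholder metric.
        self.metrics = {"dummy_score": 1.0}


# Once the defining module has been imported, the task is discoverable:
# task = AutoTask.from_task_name("dummy", tokenizer=tokenizer, model=model, device=device)
```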
File renamed without changes.
File renamed without changes.
File renamed without changes.
95 changes: 95 additions & 0 deletions evaluation/tasks/tydiqa_secondary/tydiqa_secondary.py
@@ -0,0 +1,95 @@
# Module for any additional processing required for the TyDi QA dataset
# HuggingFace dataset link: https://huggingface.co/datasets/tydiqa
from typing import Dict

from jinja2 import Template
from torch.utils.data import Dataset
from datasets import load_dataset
from tqdm import tqdm

from evaluation.tasks.auto_task import AutoTask

TEMPLATE = Template(
"""
{%- set _blank=["passage", "text", "text snippet", "context"]|random -%}
{%- set _position = ["above", "following"] |random -%}
{%- if _position == "above" -%}
{{context}}{{"\n"}}
{%- endif -%}
Given the {{_position}} {{_blank}}, answer the question: {{question}}
{%- if _position == "following" -%}
{{"\n"}}{{context}}
{%- endif -%}
{{"\n"}}Answer:
"""
)


class TyDiQADataset(Dataset):
def __init__(self, tokenizer, target_langs):
super().__init__()
tydiqa = load_dataset("tydiqa", "secondary_task", split="validation")
self.items = []

for sample in tydiqa:
lang = sample["id"].split("-")[0]
if lang in target_langs:
# Filter out samples in languages that are not used during training
                prompt = TEMPLATE.render(
                    id=sample["id"],
                    context=sample["context"],
                    question=sample["question"],
                )
                prompt = prompt.strip()  # Remove leading/trailing whitespace and newlines

# Tokenize and construct this sample
inputs = tokenizer(
prompt,
padding=True,
return_tensors='pt',
)
self.items.append(
{
"prompt": prompt,
"lang": lang,
"input_ids": inputs["input_ids"],
"attention_mask": inputs["attention_mask"],
"input_len": inputs["attention_mask"].shape[1],
"target_answer": [ans.lower() for ans in sample["answers"]['text']],
}
)

def __len__(self):
return len(self.items)

def __getitem__(self, index):
return self.items[index]


class TydiqaSecondaryTask(AutoTask):
@staticmethod
def get_display_name() -> str:
return 'tydiqa_secondary'

def evaluate(self) -> None:
dataset = TyDiQADataset(self.tokenizer, target_langs=["english"])

substring_matches = 0
for sample in tqdm(dataset, desc=f'Evaluating {self.get_display_name()}'):
output = self.model.generate(
input_ids=sample["input_ids"].to(self.device),
attention_mask=sample["attention_mask"].to(self.device),
max_length=min(sample["input_len"] * 2, self.model.config.n_positions),
)

prompt_len = len(sample["prompt"])
decoded_output = self.tokenizer.decode(output[0], skip_special_tokens=True)
predicted_answer = decoded_output[prompt_len:]

target_answers = sample["target_answer"]
            substring_match = any(target_answer in predicted_answer.lower() for target_answer in target_answers)
substring_matches += substring_match

self.metrics = {
"substring_matches": substring_matches / len(dataset) * 100
}
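
The Jinja template places the context either before or after the question at random, so prompts are only reproducible because `eval.py` calls `set_seed` (which also seeds Python's `random` module) before `task.evaluate()`. A quick sketch of rendering the template standalone, assuming the package layout of this commit; the context and question strings are invented, not TyDi QA data:

```python
# Sketch: render the prompt template outside the dataset class.
import random

from evaluation.tasks.tydiqa_secondary.tydiqa_secondary import TEMPLATE

random.seed(0)  # assumes Jinja's |random filter draws from the stdlib global RNG
prompt = TEMPLATE.render(
    context="The Eiffel Tower is located in Paris.",
    question="Where is the Eiffel Tower?",
).strip()
print(prompt)
```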
File renamed without changes.
14 changes: 14 additions & 0 deletions evaluation/utils/log.py
@@ -0,0 +1,14 @@
import logging


def get_logger():
logger = logging.getLogger("evaluation")
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt="%m/%d/%Y %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
return logger
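
`get_logger` attaches a new `StreamHandler` on every call. That is harmless in `eval.py`, which calls it once, but repeated calls would produce duplicate log lines. An idempotent variant sketch (not in the commit):

```python
# Variant sketch: only attach a handler the first time the logger is requested.
import logging


def get_logger():
    logger = logging.getLogger("evaluation")
    if not logger.handlers:  # avoid stacking duplicate handlers on repeated calls
        formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            datefmt="%m/%d/%Y %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
```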
