diff --git a/Dockerfile b/Dockerfile index 8eb3bfa2a..edd447e19 100644 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,15 @@ WORKDIR /build/fairscale RUN git checkout fixing_memory_issues_with_keeping_overlap_may24 RUN pip3 install -e . +RUN pip install \ + py-rouge==1.1 \ + rouge_score==0.1.2 \ + parlai==1.7.1 \ + evaluate==0.4.0 + +ENV NLTK_DATA="/usr/share/nltk_data" +RUN python -c "import nltk; nltk.download('punkt', download_dir='${NLTK_DATA}')" + # Install metaseq WORKDIR /build RUN git clone https://github.com/facebookresearch/metaseq.git diff --git a/metaseq/cli/inference.py b/metaseq/cli/inference.py new file mode 100644 index 000000000..6feb4e72d --- /dev/null +++ b/metaseq/cli/inference.py @@ -0,0 +1,765 @@ +import argparse +import json +import logging +import os +import pathlib +import shutil +import sys +import time +import operator +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from glob import glob +from typing import Callable, Dict, List, Optional + +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from metaseq import options, utils +from metaseq.data import JsonlDataset, data_utils +from metaseq.data.prompt_generator import PromptGenerator +from metaseq.dataclass.configs import MetaseqConfig +from metaseq.dataclass.utils import convert_namespace_to_omegaconf +from metaseq.distributed import utils as distributed_utils +from metaseq.generation_metrics.metrics import GenerationMetrics, evaluate_inference_files +from metaseq.hub_utils import GeneratorInterface +from metaseq.logging import get_logger +from metaseq.utils import flatten_config + +logger = get_logger(__name__) + +TOTAL_TOKENS_GENERATED = 0 + + +def _tokenize_one_json( + json: Dict[str, str], + encode_fn: Callable[[str], List[int]], + prompt_generator: PromptGenerator, + prompt_eos_text: str, + delimiter=None, + max_seq_len: int = 2048, + max_prompt_len: int = 0, + truncation: str = "right", + prompt_end_text: Optional[str] = None, +): + """ + Inference Tokenization Function for Json Input + + :param Dict[str, str] json: The json line containing the prompt and target + :param Callable encode_fn: The function to use to encode the prompt and + target. Encodes a string to a list of token ids. + :param PromptGenerator prompt_generator: The prompt generator to use + :param str prompt_eos_text: The text to append at the end of prompt + :param str delimiter: The delimiter to use to split the prompt and target + :param int max_seq_len: The maximum sequence length + :param int max_prompt_len: The maximum prompt length + :param Optional[str] prompt_end_text: If provided and the prompt + doesn't end with the specified text before being given to the model then + the last tokens of the prompt will be replaced with the tokens + corresponding to this text, defaults to None + + :return: (torch.LongTensor, str) Prompt tokens and Target text + """ + source_key = "text" if "text" in json else "src" + assert source_key in json, "json must contain a 'text' or 'src' field" + + target_key = "tgt" if "tgt" in json else None + + # NOTE: if you pass in a delimiter, we assume the target is already included + # in the `text/src` field, and we split on the delimiter to get the prompt + # and target respectively. That means that if you pass in a delimiter, you + # should not pass in a `tgt` field. 
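    # Illustrative sketch of the two accepted input shapes (field values are
    # hypothetical, not taken from any real dataset). With delimiter="\n\nTL;DR:":
    #     {"text": "long article ...\n\nTL;DR: short summary"}
    #     -> prompt = "long article ...", target = " short summary"
    # Without a delimiter:
    #     {"src": "long article ...", "tgt": "short summary"}
    #     -> prompt = json["src"], target = json["tgt"]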
+ if delimiter is not None and target_key is not None: + raise AssertionError("You passed in a delimiter and a target key, but the target key will be ignored.") + + if delimiter is None: + prompt = json[source_key] + target = json[target_key] if target_key is not None else None + else: + prompt, target = json[source_key].rsplit(delimiter, 1) + + prompt = prompt_generator.get_next_prompt(prompt) + prompt_tokens = encode_fn(prompt.rstrip(" ")) + if target is not None: + target = target.rstrip() + + prompt_eos_tokens = encode_fn(prompt_eos_text) + if max_prompt_len > 0: + max_prompt_len = min(max_prompt_len - len(prompt_eos_tokens), max_seq_len - len(prompt_eos_tokens)) + + prompt_tokens, _ = data_utils.truncate_source(prompt_tokens, max_prompt_len, truncation, len(prompt_eos_tokens)) + + if prompt_end_text is not None and prompt_end_text != "": + force_prompt_end_tokens = encode_fn(prompt_end_text) + + # replace last `len(force_prompt_end_tokens)` tokens with + # `force_prompt_end_tokens` if necessary + if prompt_tokens[-len(force_prompt_end_tokens):] != force_prompt_end_tokens: + prompt_tokens = prompt_tokens[:-len(force_prompt_end_tokens)] + force_prompt_end_tokens + + prompt_tokens += prompt_eos_tokens + return torch.LongTensor(prompt_tokens), target + + +def update_generation_config(cfg: MetaseqConfig): + cfg.generation.sampling_topp = (cfg.generation.sampling_topp if cfg.generation.sampling_topp > 0.0 else -1.0) + cfg.generation.sampling = cfg.generation.sampling_topp > 0.0 + + assert cfg.generation.temperature >= 0.0, "temperature must be positive" + if cfg.generation.temperature == 0.0: + cfg.generation.sampling = False + cfg.generation.temperature = 1.0 + cfg.generation.sampling_topp = -1 + + +def convert_generated_tokens_to_text( + generator_interface, + generator, + indexs, + src_lengths, + all_tokens, + all_scores, + all_distributions, + all_logits, + best_n, + num_logprobs, + num_logits, + target_text, + input_items, + echo_prompt=False, + copy_input_to_output=False, + input_keys_to_copy=None, + output_tokens_and_offsets=False, + **kwargs, +): + results = [] + tokens_generated = 0 + batch_size = all_tokens.size(0) + for batch_idx in range(batch_size): + instance_result = defaultdict(list) + if copy_input_to_output and len(input_items) > 0: + for key in input_keys_to_copy.split(","): + instance_result[key] = input_items[batch_idx].get(key, None) + instance_result["instance_idx"] = indexs[batch_idx].item() + + if len(target_text) > 0: + instance_result["target_text"] = target_text[batch_idx] + if output_tokens_and_offsets: + instance_result["target_tokens"] = [ + generator_interface.bpe.bpe.decode([t]) for t in generator_interface.encode_fn(target_text[batch_idx]) + ] + + for beam_idx in range(min(generator.beam_size, best_n)): + # first beam is always the highest scoring + tokens = all_tokens[batch_idx, beam_idx].tolist() + scores = all_scores[batch_idx, beam_idx].tolist() + distributions = all_distributions[batch_idx, beam_idx] if num_logprobs > 0 else None + logits = all_logits[batch_idx, beam_idx] if num_logits > 0 else None + + src_length = src_lengths[batch_idx].item() + prompt = tokens[1:src_length][:generator.max_len_b] + if not echo_prompt: + tokens = tokens[src_length:][:generator.max_len_b] + scores = scores[src_length:][:generator.max_len_b] + if num_logprobs > 0: + distributions = distributions[src_length:][:generator.max_len_b] + if num_logits > 0: + logits = logits[src_length:][:generator.max_len_b] + tokens_generated += len(tokens) + else: + tokens_generated += 
len(tokens) - src_length + + tokens, scores, distributions, logits = generator_interface._filter_special( + generator_interface._pad_token_ind, generator_interface._special_token_inds, tokens, scores, distributions, + logits + ) + + # cut off 'eos' tokens at the start + tokens_no_eos = tokens[1:] if echo_prompt else tokens + scores_with_eos = [None] + scores[1:] if echo_prompt else scores + # turn it into a string + prompt = generator_interface.bpe.bpe.decode(prompt) + generated_text = generator_interface.bpe.bpe.decode(tokens_no_eos) + # re-encode it so we get offsets + token_offsets = [s for s, e in generator_interface.bpe.bpe.encode(generated_text).offsets] + + if "prompt_text" not in instance_result: + instance_result["prompt_text"] = prompt + + beam_result = {"generated_text": generated_text} + + decoded_tokens = [generator_interface.bpe.bpe.decode([t]) for t in tokens] + instance_length = len(decoded_tokens) + assert instance_length == len(scores_with_eos) + # TODO: len(generator_interface.bpe.bpe.encode(generator_interface.bpe.bpe.decode([50118, 50118]))) != len([50118, 50118]) + # assert instance_length == len(result["text_offset"]) + if output_tokens_and_offsets: + beam_result.update( + { + "tokens": decoded_tokens, + # text offset is useful for cutting off prompts or prefixes + # or evaluating PPL on just a subset of tokens + "text_offset": token_offsets, + "token_scores": scores_with_eos, + } + ) + + if num_logprobs > 0: + # final result is a List[Dict[str, float]] + # where each item in the list corresponds to a token in the + # sequence, and the dict provides the probabilities of the + # top-k tokens at that timestep. + out_logprobs = [] + all_top_scores, all_top_tokens = distributions.topk(k=num_logprobs, dim=-1) + for top_scores, top_tokens in zip(all_top_scores, all_top_tokens): + logprob_item = { + generator_interface.bpe.bpe.decode([t.item()]): { + 'token_id': t.item(), + 'logprob_score': s.item() + } + for t, s in zip(top_tokens, top_scores) + } + out_logprobs.append(logprob_item) + + if echo_prompt: + # use null instead of giving bunk probs for EOS token + beam_result["top_logprobs"] = [None] + out_logprobs[1:] + else: + beam_result["top_logprobs"] = out_logprobs + + # Seq generation may finish after reaching max_generation_length or after generating an + # token. We should account for both cases in the assertion below. + assert instance_length in [len(beam_result["top_logprobs"]) - 1, generator.max_generation_length] + + if num_logits > 0: + assert all_logits is not None + # similar to top_logprobs + out_logits = [] + all_top_logits, all_top_tokens = logits.topk(k=num_logits, dim=-1) + for top_logits, top_tokens in zip(all_top_logits, all_top_tokens): + logit_item = { + generator_interface.bpe.bpe.decode([t.item()]): { + 'token_id': t.item(), + 'logit': l.item() + } + for t, l in zip(top_tokens, top_logits) + } + out_logits.append(logit_item) + if echo_prompt: + # use null instead of giving zero logits for EOS token at the start + beam_result["top_logits"] = [None] + out_logits[1:] + else: + beam_result["top_logits"] = out_logits + + # Seq generation may finish after reaching max_generation_length or after generating an + # token. We should account for both cases in the assertion below. 
+ assert instance_length in [len(beam_result["top_logits"]) - 1, generator.max_generation_length] + + instance_result["beam_results"].append(beam_result) + results.append(instance_result) + return results, tokens_generated + + +def write_inference_results_to_file(output_file: str, **kwargs): + global TOTAL_TOKENS_GENERATED + + try: + if distributed_utils.get_model_parallel_rank() == 0: + batch_result, tokens_generated = convert_generated_tokens_to_text(**kwargs) + TOTAL_TOKENS_GENERATED += (tokens_generated * distributed_utils.get_data_parallel_world_size()) + + progress_bar = kwargs.pop("progress_bar") + translate_time = kwargs.pop("translate_time") + if distributed_utils.get_global_rank() == 0: + progress_bar.set_postfix( + { + "Tokens Per Second": + f"{tokens_generated * distributed_utils.get_data_parallel_world_size() / translate_time:.2f}" + } + ) + + with open(output_file, "a", encoding="utf-8") as f: + for result in batch_result: + f.write(json.dumps(result, ensure_ascii=False) + "\n") + + except Exception as e: + import traceback + logger.error(f"Error in writing inference results to file.") + logger.error(f"stack trace: {traceback.format_exc()}") + + +def generate(cfg: MetaseqConfig, args: argparse.Namespace): + global TOTAL_TOKENS_GENERATED + TOTAL_TOKENS_GENERATED = 0 + + generator_interface = GeneratorInterface(cfg) + generator_interface.load_model() + logger.info(f"loaded model {cfg.distributed_training.distributed_rank}") + + logger.info(f"Resolving generator settings") + update_generation_config(cfg) + + MAX_SEQ_LEN = utils.resolve_max_positions( + generator_interface.task.max_positions(), + *[model.max_positions() for model in generator_interface.models], + ) + logger.info(f"Max sequence length for generation is set to: {MAX_SEQ_LEN}") + + logger.info("Praparing batches for generation") + dataset, length, tgt_text, input_items = [], [], [], [] + if os.path.isdir(cfg.generation.input): + test_files = glob(f"{cfg.generation.input}/**/*.jsonl", recursive=True) + else: + test_files = [cfg.generation.input] + + logger.info(f"Found the following files for generation: {test_files}") + + data_glob_path = None + if args.few_shot_samples_data_path: + data_glob_path = f'{args.few_shot_samples_data_path}/**/*.jsonl' + + prompt_generator = PromptGenerator( + prompt_template_s=args.prompt_template, + few_shot_data_glob_path=data_glob_path, + n_fewshot_samples=args.n_few_shot, + rng_seed=args.seed, + ) + + for file in test_files: + json_dataset = JsonlDataset( + path=file, + tokenizer=partial( + _tokenize_one_json, + encode_fn=generator_interface.encode_fn, + prompt_generator=prompt_generator, + prompt_eos_text=args.prompt_eos_text, + delimiter=args.prompt_delimiter, + max_seq_len=MAX_SEQ_LEN, + max_prompt_len=args.max_prompt_len, + truncation=args.truncation, + prompt_end_text=args.prompt_end_text, + ), + epoch=1, + data_subshard_count=1, + output_raw_items=args.copy_input_to_output + ) + + # TODO: stream instead of loading entire dataset + # will fail if loading dataset greater than RAM + for data in json_dataset: + if args.copy_input_to_output: + token_data, input_data = data + input_items.append(input_data) + else: + token_data = data + ids = token_data[0] + target_text = token_data[1] + dataset.append(ids) + length.append(len(ids)) + tgt_text.append(target_text) + + batches = generator_interface.task.get_batch_iterator( + dataset=generator_interface.task.build_dataset_for_inference(dataset, length), + max_tokens=None, + max_sentences=cfg.dataset.batch_size, + max_positions=None, + 
ignore_invalid_inputs=False, + skip_remainder_batch=False, + seed=cfg.common.seed, + data_buffer_size=cfg.generation.buffer_size, + num_shards=distributed_utils.get_data_parallel_world_size(), + shard_id=distributed_utils.get_data_parallel_rank(), + num_workers=4, + ).next_epoch_itr(shuffle=False) + + logger.info(f"Preparing generator with settings {cfg.generation}") + args.stop_tokens = ( + [generator_interface.encode_fn(token)[0] + for token in args.stop_tokens.split(",")] if args.stop_tokens is not None else None + ) + generator = generator_interface.task.build_generator( + generator_interface.models, + cfg.generation, + extra_gen_cls_kwargs={ + "stop": args.stop_tokens, + "need_logprobs": args.top_k_logprobs > 0, + "need_logits": args.top_k_logits > 0, + }, + ) + min_len = generator.min_len + max_len_b = generator.max_len_b + + start_time = time.time() + total_generation_time = 1e-10 + + logger.info(f'Results path: {cfg.common_eval.results_path}') + + if args.copy_input_to_output and args.input_keys_to_copy is None: + raise argparse.ArgumentError(None, "'input_keys_to_copy' must be specified when 'copy_input_to_output' is True.") + + if distributed_utils.get_global_rank() == 0: + progress_bar = tqdm(batches, desc="Generating") + else: + progress_bar = batches + + async_processor = ThreadPoolExecutor(max_workers=1) + for batch in progress_bar: + if batch == {}: continue + src_lengths = batch["net_input"]["src_lengths"] + + # generator.max_generation_length is the max generation size WITHOUT the src_length. + # E.g.: if --max_len_b 128, this would be also 128. + generator.max_generation_length = min(MAX_SEQ_LEN, max_len_b) + + # size of the largest src item in this batch + max_seq_len = src_lengths.max().item() + + # generator.max_len_b is the max generation size WITH the src_length. + # E.g.: if --max_len_b 128, this would be (128 + ). + generator.max_len_b = min(MAX_SEQ_LEN, max_len_b + max_seq_len) + # generator.min_len is the min generation size WITH the src_length. + # E.g.: if --min_len 1, this would be (1 + ). 
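        # Worked example for the three limits set here (numbers are illustrative
        # only): with --max_len_b 128, --min_len 1, MAX_SEQ_LEN 2048 and a longest
        # prompt of 20 tokens in this batch, we get max_generation_length = 128,
        # max_len_b = 148 and min_len = 21, i.e. both bounds include the prompt length.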
+ generator.min_len = min(MAX_SEQ_LEN, min_len + max_seq_len) + + if not args.use_cpu: + batch = utils.move_to_cuda(batch) + + translate_start_time = time.time() + output_generation = generator_interface.task.inference_step(generator, generator_interface.models, batch) + translate_time = time.time() - translate_start_time + total_generation_time += translate_time + + all_tokens = output_generation["tokens"].cpu() + all_scores = output_generation["scores"].cpu() + if args.top_k_logprobs > 0: + all_distributions = output_generation["distributions"].cpu( + ) # (bsz, beam_size, seq_len + prompt_size, self.vocab_size) + else: + all_distributions = None + + if args.top_k_logits > 0: + all_logits = output_generation["logits"].cpu() # (bsz, beam_size, seq_len + prompt_size, self.vocab_size) + else: + all_logits = None + + batch_itemgetter = operator.itemgetter(*batch["id"].tolist()) + batch_target_text = list(batch_itemgetter(tgt_text)) if tgt_text[0] is not None else [] + batch_input_items = list(batch_itemgetter(input_items)) if args.copy_input_to_output else [] + + # Async write to file + async_processor.submit( + write_inference_results_to_file, + output_file=os.path.join( + str(cfg.common_eval.results_path), + f"worker_prediction_results_rank{distributed_utils.get_data_parallel_rank()}.jsonl" + ), + generator_interface=generator_interface, + generator=generator, + indexs=batch["id"], + src_lengths=src_lengths, + all_tokens=all_tokens, + all_scores=all_scores, + all_distributions=all_distributions, + all_logits=all_logits, + best_n=args.best_n, + num_logprobs=args.top_k_logprobs, + num_logits=args.top_k_logits, + target_text=batch_target_text, + input_items=batch_input_items, + copy_input_to_output=args.copy_input_to_output, + input_keys_to_copy=args.input_keys_to_copy, + output_tokens_and_offsets=args.output_tokens_and_offsets, + echo_prompt=args.echo_prompt, + progress_bar=progress_bar, + translate_time=translate_time, + ) + + # All processes wait for the async writes to finish at the end + async_processor.shutdown() + distributed_utils.global_barrier() + + if distributed_utils.get_global_rank() == 0: + progress_bar.close() + + if args.merge_preds_on_all_ranks: + unified_output_prediction_file = os.path.join(str(cfg.common_eval.results_path), "all_prediction_results.jsonl") + result_file_max_chunk_bytes_size = None + if args.result_file_max_chunk_mb_size is not None: + result_file_max_chunk_bytes_size = args.result_file_max_chunk_mb_size * 1024 * 1024 + + current_unified_chunk_idx = 0 + + def get_unified_chunk_for_idx(idx): + # if we didn't set a max_size then return filename without + # "chunk" part + file_name = unified_output_prediction_file + if result_file_max_chunk_bytes_size is not None: + file_name = unified_output_prediction_file.replace(".jsonl", f"_chunk{idx}.jsonl") + + return open(file_name, "w+", encoding="utf-8") + + # NOTE: we need to be careful to close this later on + current_unified_file = get_unified_chunk_for_idx(current_unified_chunk_idx) + + workers_output_predictions_files = glob( + os.path.join(str(cfg.common_eval.results_path), "worker_prediction_results*") + ) + + line_iterator = data_utils.multiple_file_line_generator(workers_output_predictions_files) + for predicted_line in line_iterator: + current_unified_file.write(predicted_line) + + # if we specified a maximum chunk size and the bytes in the + # current chunk are already more than our max_size then close + # current chunk and open next one + if ( + result_file_max_chunk_bytes_size is not None + and 
current_unified_file.tell() > result_file_max_chunk_bytes_size + ): + # close current file and open next chunk + current_unified_file.close() + current_unified_chunk_idx += 1 + current_unified_file = get_unified_chunk_for_idx(current_unified_chunk_idx) + + # close last "unified file" + current_unified_file.close() + + logger.info(f"Written generated results to {unified_output_prediction_file}") + + logger.info( + "Total time: {:.3f} seconds; generation time: {:.3f} seconds; avg tokens/second: {:.2f} ".format( + time.time() - start_time, + total_generation_time, + TOTAL_TOKENS_GENERATED / total_generation_time, + ) + ) + + if args.metrics_list is not None or args.dataset_configuration_name is not None: + if isinstance(args.metrics_list, str): + args.metrics_list = args.metrics_list.split(",") + if isinstance(args.evaluation_libraries, str): + args.evaluation_libraries = args.evaluation_libraries.split(",") + + prediction_files_pattern = os.path.join(str(cfg.common_eval.results_path), "worker_prediction_results*") + + output_evaluation_file_path = os.path.join(str(cfg.common_eval.results_path), "evaluation_results.json") + output_evaluation_individual_file_path = os.path.join( + str(cfg.common_eval.results_path), "evaluation_individual_results.jsonl" + ) + output_evaluation_exceptions_file_path = os.path.join( + str(cfg.common_eval.results_path), "evaluation_exceptions.jsonl" + ) + + evaluate_inference_files( + inference_file_glob_pattern=prediction_files_pattern, + evaluation_output_file_path=output_evaluation_file_path, + individual_results_output_file_path=output_evaluation_individual_file_path, + exceptions_ouput_file_path=output_evaluation_exceptions_file_path, + libraries=args.evaluation_libraries, + metrics=args.metrics_list, + dataset_configuration_name=args.dataset_configuration_name, + model_configuration_name=args.model_configuration_name, + output_metrics_for_all=args.output_metrics_for_all, + ) + + if args.save_generation_info: + generation_info = { + "checkpoint_name": ( + # .../1.3b-resharded-inference-1x1/reshard.pt → 1.3b-resharded-inference-1x1 + pathlib.Path(str(cfg.common_eval.path)).parent.name + ), + # .../hellaswag/valid → hellaswag + "dataset_name": + pathlib.Path(cfg.generation.input).parent.name, + "num_parameters": + sum(param.numel() + for param in generator.model.parameters()) * distributed_utils.get_model_parallel_world_size(), + "generator": { + "vocab_size": generator.vocab_size, + }, + "parsed_args": + args.__dict__, + "raw_args": + sys.argv, + } + + output_generation_info_file = os.path.join(str(cfg.common_eval.results_path), "generation_info.json") + with open(output_generation_info_file, "w") as f: + json.dump(generation_info, f) + + +def extra_args(parser): + parser.add_argument("--model-dir", help="Trained checkpoint directory") + parser.add_argument("--use-cpu", action="store_true", help="Use CPU instead for inference") + + # Prompt Parameters + parser.add_argument( + "--prompt-delimiter", + type=str, + default=None, + help=( + "Prompt delimiter for LM datasets. 
If the delimiter is specified then we'll expect a " + "sample text to contain both the source and target texts, separated by the delimiter.", + ) + ) + + parser.add_argument("--prompt-template", type=str, default=None, help=("A Jinja2 template passed to PromptGenerator", )) + parser.add_argument( + "--n-few-shot", + type=int, + default=0, + help="Number of examples to use for few-shot generation", + ) + parser.add_argument( + "--few-shot-samples-data-path", + type=str, + default=None, + help="Path to a folder which contains a jsonl file from which the samples for few-shot generation are taken", + ) + parser.add_argument( + "--max-prompt-len", + type=int, + default=0, + help="Maximum number of tokens to use for the prompt. If 0, then the entire prompt will be used.", + ) + parser.add_argument( + "--prompt-eos-text", + type=str, + default="", + help="This will be appended to the end of prompt text before generation.", + ) + parser.add_argument( + "--truncation", + type=str, + default="right", + help="Truncation strategy for the prompt. Can be 'left', 'right', or 'none'.", + ) + parser.add_argument( + "--prompt-end-text", + type=str, + default=None, + help=( + "If provided, each prompt will be forced to end with this text, independently of whether the prompt was truncated or not. " + "For example, this might be helpful if you want to ensure all your prompts end with `TLDR`. The provided tokens will overwrite " + "the last len(prompt_end_text) tokens of the prompt." + ), + ) + + # Generation Parameters + parser.add_argument("--echo-prompt", action="store_true", help="Echo prompt in output") + parser.add_argument("--stop-tokens", type=str, default=None, help="a list of terminating tokens") + parser.add_argument( + "--top-k-logprobs", + type=int, + default=0, + help="Return this cutoff of the probability distribution. Can not be used together with 'top-k-logits' argument.", + ) + parser.add_argument( + "--top-k-logits", + type=int, + default=0, + help="Return the top K logits in the inference output file. Can not be used together with 'top-k-logprobs' argument.", + ) + parser.add_argument("--best-n", type=int, default=1, help="return this cutoff of the beam search") + parser.add_argument("--merge-preds-on-all-ranks", action="store_true", help="merge prediction results on all ranks") + parser.add_argument("--copy-input-to-output", action="store_true", help="copy input to output") + parser.add_argument("--input-keys-to-copy", type=str, default=None, help="comma separated list of input keys to copy") + parser.add_argument("--output-tokens-and-offsets", action="store_true", help="output tokens and offsets") + + # Metrics + parser.add_argument( + "--evaluation-libraries", + type=str, + default="parlai", + help= + f"name of the library that should be used to compute the evaluation metrics. 
Possible options: {GenerationMetrics.metric_libraries}" + ) + parser.add_argument( + "--metrics-list", + type=str, + default=None, + help="comma separated list of metrics to calculate", + ) + parser.add_argument( + "--dataset-configuration-name", + type=str, + default=None, + help="If provided then metric prameters will be obtained from this dataset configuration", + ) + parser.add_argument( + "--model-configuration-name", + type=str, + default=None, + help= + "If provided then this model configuration will be obtained from the dataset configuration and used to compute metrics", + ) + parser.add_argument( + "--result-file-max-chunk-mb-size", + type=float, + default=None, + help= + "If provided, the inference results will be separated accross many files, each being chunked with the specified size", + ) + parser.add_argument( + "--output-metrics-for-all", + action="store_true", + help="output all metrics for each instance", + ) + parser.add_argument("--pretty-metrics", action="store_true", help="pretty print metrics") + parser.add_argument( + "--save-generation-info", + action="store_true", + help="if provided then a file with information on the parameters used for generation will also be saved" + ) + + return parser + + +def cli_main(): + """ + Generation using trained model. + """ + parser = options.get_generation_parser() + parser = extra_args(parser) + + # dumb defaults overriding + parser.set_defaults( + lr_scheduler=None, criterion=None, task="language_modeling", bpe="hf_byte_bpe", arch="transformer_lm_megatron" + ) + args = options.parse_args_and_arch(parser) + + # set args + args.bpe_vocab = args.vocab_filename + args.bpe_merges = args.merges_filename + args.path = args.model_dir + + # Output log for RANK 0 process. All ranks perform inference in data parallel so we record + # the log only for one process. 
+ if os.environ.get("RANK", "0") == "0": + # Clean up results directory + shutil.rmtree(args.results_path, ignore_errors=True) + + if not args.log_file: + args.log_file = "evaluation.log" + args.log_file = os.path.join(args.results_path, args.log_file) + if os.path.dirname(args.log_file) != "": + os.makedirs(os.path.dirname(args.log_file), exist_ok=True) + handler = logging.FileHandler(filename=args.log_file) + logger.addHandler(handler) + + cfg = convert_namespace_to_omegaconf(args) + if os.environ.get("RANK", "0") == "0": + os.makedirs(os.path.join(args.results_path, "config"), exist_ok=True) + OmegaConf.save( + config=flatten_config(cfg), + f=os.path.join(args.results_path, "config", "config.yml"), + ) + + os.environ["NCCL_DEBUG"] = "WARN" + distributed_utils.call_main(cfg, generate, args=args) + + +if __name__ == "__main__": + cli_main() diff --git a/metaseq/cli/train.py b/metaseq/cli/train.py index 011ca0e8a..41e168bed 100644 --- a/metaseq/cli/train.py +++ b/metaseq/cli/train.py @@ -40,6 +40,7 @@ from metaseq.file_io import PathManager from metaseq.logging import meters, metrics, progress_bar from metaseq.trainer import Trainer +from metaseq.utils import flatten_config logging.basicConfig( format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", @@ -75,7 +76,7 @@ def main(cfg: DictConfig) -> None: # TODO(roller): only works when launched with a sweep script # should fix that OmegaConf.save( - config=_flatten_config(cfg), + config=flatten_config(cfg), f=os.path.join(os.environ["METASEQ_SAVE_DIR"], "config.yml"), ) @@ -288,7 +289,7 @@ def train( "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), ) - progress.update_config(_flatten_config(cfg)) + progress.update_config(flatten_config(cfg)) trainer.begin_epoch(epoch_itr.epoch) valid_subsets = cfg.dataset.valid_subset.split(",") @@ -411,19 +412,6 @@ def train( return valid_losses, should_stop -def _flatten_config(cfg: DictConfig): - config = OmegaConf.to_container(cfg) - # remove any legacy Namespaces and replace with a single "args" - namespace = None - for k, v in list(config.items()): - if isinstance(v, argparse.Namespace): - namespace = v - del config[k] - if namespace is not None: - config["args"] = vars(namespace) - return config - - def validate_and_save( cfg: DictConfig, trainer: Trainer, diff --git a/metaseq/data/datasets/__init__.py b/metaseq/data/datasets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metaseq/data/datasets/cnn_dm_transformers.py b/metaseq/data/datasets/cnn_dm_transformers.py new file mode 100644 index 000000000..f025691e1 --- /dev/null +++ b/metaseq/data/datasets/cnn_dm_transformers.py @@ -0,0 +1,21 @@ +from typing import Any, Dict + +from metaseq.data.datasets import openai_generated_transformers, shared_transformers +from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem + + +def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem: + item: OAITeacherGeneratedDatasetItem = raw_dict + + item = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(item) + item = openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="") + + return item + + +def convert_teacher_domain_to_original_domain(model_output: str) -> str: + """Convert output from text-davinci-003 to the format of original dataset""" + + original_domain_output = shared_transformers.remove_non_alpha_from_beginning(model_output) + + return original_domain_output diff --git a/metaseq/data/datasets/dataset_configurations.py 
b/metaseq/data/datasets/dataset_configurations.py new file mode 100644 index 000000000..545a10500 --- /dev/null +++ b/metaseq/data/datasets/dataset_configurations.py @@ -0,0 +1,84 @@ +from metaseq.data.datasets import e2e_transformers, hellaswag_transformers, piqa_transformers, reddit_transformers, cnn_dm_transformers +from metaseq.data.datasets.types import CommonDatasetConfiguration, DatasetConfiguration, DatasetConfigurationTeacherGenerated, DatasetModelConfig, DatasetModelHooks, DatasetTeacherGeneratedDataHooks, IdentityDict + +# Visual diagram of where hooks/functions are called during inference or data generation +# https://excalidraw.com/#json=zoAk_TdynBHQnP9vZufGm,ekcVg_HqiF79cAp58_HKRQ +DATASET_CONFIGURATIONS = { + "cnn_dailymail": + DatasetConfiguration( + common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]), + model_config=DatasetModelConfig( + model_hooks=DatasetModelHooks( + convert_model_type_output_to_original_domain=IdentityDict( + { + "distilled": cnn_dm_transformers.convert_teacher_domain_to_original_domain, + } + ) + ) + ), + teacher_generated_config=DatasetConfigurationTeacherGenerated( + data_hooks=DatasetTeacherGeneratedDataHooks( + before_transforming_into_metaseq_inference=cnn_dm_transformers.before_transforming_into_metaseq_inference, + convert_test_target_to_original_domain_label=cnn_dm_transformers.convert_teacher_domain_to_original_domain + ), + ), + ), + "e2e_nlg": + DatasetConfiguration( + common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]), + teacher_generated_config=DatasetConfigurationTeacherGenerated( + data_hooks=DatasetTeacherGeneratedDataHooks( + before_transforming_into_metaseq_inference=e2e_transformers.before_transforming_into_metaseq_inference, + ), + ), + ), + "hellaswag": + DatasetConfiguration( + common=CommonDatasetConfiguration(metric_libraries=["grindstone"], metric_names=["accuracy"]), + model_config=DatasetModelConfig( + model_hooks=DatasetModelHooks( + convert_model_type_output_to_original_domain=IdentityDict( + { + "distilled": hellaswag_transformers.hellaswag_convert_model_output_domain_to_original_domain, + } + ) + ) + ), + teacher_generated_config=DatasetConfigurationTeacherGenerated( + data_hooks=DatasetTeacherGeneratedDataHooks( + before_transforming_into_metaseq_inference=hellaswag_transformers. + hellaswag_before_transforming_into_metaseq_inference, + convert_test_target_to_original_domain_label=hellaswag_transformers. 
+ hellaswag_convert_model_output_domain_to_original_domain, + ), + ), + ), + "piqa": + DatasetConfiguration( + common=CommonDatasetConfiguration(metric_libraries=["grindstone"], metric_names=["accuracy"]), + model_config=DatasetModelConfig( + model_hooks=DatasetModelHooks( + convert_model_type_output_to_original_domain=IdentityDict( + { + "distilled": piqa_transformers.model_output_to_orig, + } + ) + ) + ), + teacher_generated_config=DatasetConfigurationTeacherGenerated( + data_hooks=DatasetTeacherGeneratedDataHooks( + before_transforming_into_metaseq_inference=piqa_transformers.preprocess_teacher_generated_data, + convert_test_target_to_original_domain_label=piqa_transformers.model_output_to_orig + ) + ) + ), + "openai_tldr_reddit": + DatasetConfiguration( + common=CommonDatasetConfiguration(metric_libraries=["grindstone", "coco"], metric_names=["bertscore", "rouge-L"]), + teacher_generated_config=DatasetConfigurationTeacherGenerated( + data_hooks=DatasetTeacherGeneratedDataHooks( + before_transforming_into_metaseq_inference=reddit_transformers.before_transforming_into_metaseq_inference, + ), + ), + ), +} diff --git a/metaseq/data/datasets/e2e_transformers.py b/metaseq/data/datasets/e2e_transformers.py new file mode 100644 index 000000000..4d8223fbe --- /dev/null +++ b/metaseq/data/datasets/e2e_transformers.py @@ -0,0 +1,14 @@ +from typing import Any, Dict + +from metaseq.data.datasets import openai_generated_transformers +from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem + + +def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem: + item: OAITeacherGeneratedDatasetItem = raw_dict + + item = openai_generated_transformers.sanitize_beginning(item) + item = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(item) + item = openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="") + + return item diff --git a/metaseq/data/datasets/hellaswag_transformers.py b/metaseq/data/datasets/hellaswag_transformers.py new file mode 100644 index 000000000..7f8d2d8cb --- /dev/null +++ b/metaseq/data/datasets/hellaswag_transformers.py @@ -0,0 +1,65 @@ +from typing import Any, Dict + +from metaseq.data.datasets import openai_generated_transformers +from metaseq.data.datasets.shared_transformers import get_first_number +from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem, OAITeacherGeneratedDatasetItemLogprobs + + +def _adjust_teacher_generated_format(data: Dict) -> OAITeacherGeneratedDatasetItem: + """ + The format that we have for the teacher generated data for hellaswag is not + what we expect as OpenAI output, so this function will transform the data + the correct shape. 
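    Illustrative example (all field values are hypothetical):

        {"prompt": "ctx ...", "label": "2", "ind": 7,
         "response": {"text": " (3) ...", "finish_reason": "stop", "logprobs": {...}}}

    becomes

        {"source": "ctx ...", "human": "3", "text": " (3) ...",
         "finish_reason": "stop", "index": 7, "logprobs": {...}}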
+ """ + raw_response = data["response"] + + # move human label from range [0,3] to [1,4] so we match with the + # download_hellaswag script + human_label = int(data.get("label", "-2")) + human_label += 1 + + return { + "source": data["prompt"], + "human": str(human_label), + "text": raw_response["text"], + "finish_reason": raw_response["finish_reason"], + "index": data["ind"], + "logprobs": raw_response["logprobs"], + } + + +def hellaswag_before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem: + # Transform data to the correct OpenAI output shape + item = _adjust_teacher_generated_format(raw_dict) + + # remove everything after EOS + item = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(item) + + # replace EOS with + item = openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="") + + # if found, remove everything after the token that has a closing bracket in + # it. This regex will match any token that has a closing bracked in it. For + # examples: + # - ") " + # - ")" + item = openai_generated_transformers.truncate_after_token(item, r".*?\).*?") + + # verify that the target text contains a number. This will throw if not + # found and item will be skipped + try: + get_first_number(item["text"]) + except AssertionError: + raise ValueError(f"Could not find a number in the generated text: {item['text']}") + + return item + + +def hellaswag_convert_model_output_domain_to_original_domain(model_output: str) -> str: + # example model_output: + # ' (4) something something' + number_s = get_first_number(model_output) + choice_idx = number_s.strip() + + # model generated label is in range [1,4] + return choice_idx diff --git a/metaseq/data/datasets/openai_generated_transformers.py b/metaseq/data/datasets/openai_generated_transformers.py new file mode 100644 index 000000000..f13bb2d79 --- /dev/null +++ b/metaseq/data/datasets/openai_generated_transformers.py @@ -0,0 +1,149 @@ +from typing import List + +import regex + +from metaseq.data.datasets.shared_transformers import remove_non_alpha_from_beginning +from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem + + +def sanitize_beginning(data: OAITeacherGeneratedDatasetItem) -> OAITeacherGeneratedDatasetItem: + """ + This function will remove all non-letter characters from the beginning of + the text and the tokens list. + + Also note that this removes any non-alpha character/tokens from the + beginning of the text, which is not always desired. + """ + + data["text"] = remove_non_alpha_from_beginning(data["text"]) + + # also remove non-letter tokens from the beginning of the tokens list + logprobs_dict = data["logprobs"] + token_list: List[str] = logprobs_dict["tokens"] + + first_valid_idx = 0 + while not token_list[first_valid_idx].strip().isalpha(): + first_valid_idx += 1 + + for key in [ + "tokens", + "token_logprobs", + "top_logprobs", + "text_offset", + ]: + logprobs_dict[key] = logprobs_dict[key][first_valid_idx:] + + return data + + +def remove_all_tokens_after_eos_sanitizer( + data: OAITeacherGeneratedDatasetItem, eos_token_name="<|endoftext|>" +) -> OAITeacherGeneratedDatasetItem: + """ + This function will remove all tokens after the first EOS token. + + :param str eos_token_name: The name of the EOS token, defaults to + "<|endoftext|>" + """ + + # it can be that there are some samples whose last token is not + # "<|endoftext|>". 
According to conversation with Subho here [1] we should + # remove all tokens after the endoftext + # + # [1]: + # https://teams.microsoft.com/l/message/19:ZYlNWDJ8jxO0FSqvmpwH1-sCI7RjTudf408_odtYMCU1@thread.tacv2/1677880904286?tenantId=72f988bf-86f1-41af-91ab-2d7cd011db47&groupId=72b4c54c-a4e8-4f3e-b2c3-2bbeaf09e0ff&parentMessageId=1677718389179&teamName=Distillery&channelName=General&createdTime=1677880904286&allowXTenantAccess=false + logprobs_dict = data["logprobs"] + token_list: List[str] = logprobs_dict["tokens"] + + # sanity check + assert eos_token_name in token_list + + eos_index = token_list.index(eos_token_name) + + # remove everything after this index (even the eos token) + for key in [ + "tokens", + "token_logprobs", + "top_logprobs", + "text_offset", + ]: + logprobs_dict[key] = logprobs_dict[key][:eos_index] + + return data + + +def replace_eos_sanitizer( + data: OAITeacherGeneratedDatasetItem, + eos_replacement: str = "", + eos_token_name="<|endoftext|>" +) -> OAITeacherGeneratedDatasetItem: + """ + This function will replace the EOS token name with the given replacement + string. + + :param str eos_replacement: New name for the EOS token we want to use, + defaults to "" + :param str eos_token_name: Old name that was being used for the EOS token, + defaults to "<|endoftext|>" + """ + + logprobs_dict = data["logprobs"] + + tokens = logprobs_dict["tokens"] + for t_idx in range(len(tokens)): + if tokens[t_idx] == eos_token_name: + tokens[t_idx] = eos_replacement + + top_logprobs = logprobs_dict["top_logprobs"] + for logprob_dict in top_logprobs: + if eos_token_name in logprob_dict: + # remove existing item and assign it to eos_replacement token + logprob_dict[eos_replacement] = logprob_dict.pop(eos_token_name) + + return data + + +def truncate_after_token(data: OAITeacherGeneratedDatasetItem, rgx: str) -> OAITeacherGeneratedDatasetItem: + """ + This function will truncate the text and tokens list AFTER the first token + that matches the given regex. 
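    Illustrative example (tokens are hypothetical): with rgx=r".*?\).*?" and
    data["logprobs"]["tokens"] == [" (", "3", ")", " because", " ..."], the
    token lists are truncated to [" (", "3", ")"] and data["text"] is rebuilt
    as " (3)".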
+ + :param str rgx: The regex to match the token after which we should truncate + """ + token_matcher = regex.compile(rgx, flags=regex.MULTILINE | regex.DOTALL) + + logprobs_dict = data["logprobs"] + token_list: List[str] = logprobs_dict["tokens"] + + # find the first token that matches the regex + index_of_last_token = 0 + seen_text = "" + for token in token_list: + if token_matcher.match(token): + break + index_of_last_token += 1 + seen_text += token + + # if we processed all tokens and exceeded the length of the list then we + # didn't find any token that matches the regex, so we raise an error + if index_of_last_token == len(token_list): + raise ValueError(f"Could not find any token that matches the regex {rgx}.") + + # right now we're at the index of the token that matched the regex, + # meaning we want to drop everything after this index, so we add the + # last token to seen text and then increment the index + seen_text += token_list[index_of_last_token] + index_of_token_after_last = index_of_last_token + 1 + + for key in [ + "tokens", + "token_logprobs", + "top_logprobs", + "text_offset", + ]: + logprobs_dict[key] = logprobs_dict[key][:index_of_token_after_last] + + # now we need to truncate the text as well + data["text"] = seen_text + + return data diff --git a/metaseq/data/datasets/piqa_transformers.py b/metaseq/data/datasets/piqa_transformers.py new file mode 100644 index 000000000..5f06844e4 --- /dev/null +++ b/metaseq/data/datasets/piqa_transformers.py @@ -0,0 +1,64 @@ +from typing import Any, Dict +from metaseq.data.datasets import openai_generated_transformers +from metaseq.data.datasets.shared_transformers import get_first_number + +from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem + + +def _adjust_teacher_generated_format(data: Dict) -> OAITeacherGeneratedDatasetItem: + assert data["exception_name"] is None, ( + "Skipping line since there was an error while generating the text: " + data["exception_message"] + ) + + raw_response = data["response"] + + # move human label from range [0,1] to [1,2] so we match with the + # download_piqa script + human_label = int(data.get("label", "-2")) + human_label += 1 + + return { + "source": data["prompt"], + "human": str(human_label), + "text": raw_response["text"], + "finish_reason": raw_response["finish_reason"], + "index": 0, + "logprobs": raw_response["logprobs"], + } + + +def preprocess_teacher_generated_data(raw_dict: Any) -> OAITeacherGeneratedDatasetItem: + # Transform data to the correct OpenAI output shape + item = _adjust_teacher_generated_format(raw_dict) + + # remove everything after EOS + item = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(item) + + # replace EOS with + item = openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="") + + # if found, remove everything after the token that has a closing bracket in + # it. This regex will match any token that has a closing bracked in it. For + # examples: + # - ") " + # - ")" + item = openai_generated_transformers.truncate_after_token(item, r".*?\).*?") + + # verify that the target text contains a number. 
This will throw if not + # found and item will be skipped + try: + get_first_number(item["text"]) + except AssertionError: + raise ValueError(f"Could not find a number in the generated text: {item['text']}") + + return item + + +def model_output_to_orig(model_output: str) -> str: + # example model_output: + # ' (2) something something' + number_s = get_first_number(model_output) + choice_idx = number_s.strip() + + # model generated label is in range [1,2] + return choice_idx diff --git a/metaseq/data/datasets/reddit_transformers.py b/metaseq/data/datasets/reddit_transformers.py new file mode 100644 index 000000000..4d8223fbe --- /dev/null +++ b/metaseq/data/datasets/reddit_transformers.py @@ -0,0 +1,14 @@ +from typing import Any, Dict + +from metaseq.data.datasets import openai_generated_transformers +from metaseq.data.datasets.types import OAITeacherGeneratedDatasetItem + + +def before_transforming_into_metaseq_inference(raw_dict: Any) -> OAITeacherGeneratedDatasetItem: + item: OAITeacherGeneratedDatasetItem = raw_dict + + item = openai_generated_transformers.sanitize_beginning(item) + item = openai_generated_transformers.remove_all_tokens_after_eos_sanitizer(item) + item = openai_generated_transformers.replace_eos_sanitizer(item, eos_replacement="") + + return item diff --git a/metaseq/data/datasets/shared_transformers.py b/metaseq/data/datasets/shared_transformers.py new file mode 100644 index 000000000..68c2ef4dd --- /dev/null +++ b/metaseq/data/datasets/shared_transformers.py @@ -0,0 +1,47 @@ +from typing import Any +import regex + + +def identity_transformer(x: Any) -> Any: + return x + + +def remove_non_alpha_from_beginning(text: str) -> str: + """ + Given string, remove all non-"letter" characters from beginning. + Will remove spaces, numbers, special characters (like hyphens and colons) + """ + while not text[0].isalpha(): + text = text[1:] + + return text + + +def get_first_match_of_first_capture_group(rgx: str, input_str: str) -> str: + """ + Given a regex with a capture group and a string, it will return the first + match in the input string. + + Note that only the first capture group is used. If no capture groups are + present then it will raise an exception. + + :param regexp rgx: Regex string to use + :return str: the first match in the input string. + """ + extract_number_regex = regex.compile(rgx, flags=regex.MULTILINE | regex.DOTALL) + + matches = extract_number_regex.findall(input_str) + assert len(matches) > 0, f"Could not find a match for r'{rgx}' in '{input_str}'" + return matches[0] + + +def get_first_number(s: str) -> str: + """ + Given string return first number surrounded by parenthesis + + Examples: + + get_first_number(' (2) something something') -> '2' + get_first_number(' (31) something something') -> '31' + """ + return get_first_match_of_first_capture_group(r".*?(\d+).*?", s) diff --git a/metaseq/data/datasets/types.py b/metaseq/data/datasets/types.py new file mode 100644 index 000000000..b876303d5 --- /dev/null +++ b/metaseq/data/datasets/types.py @@ -0,0 +1,193 @@ +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, TypedDict, Union + +from metaseq.data.datasets.shared_transformers import identity_transformer + + +class DatasetItemLogprob(TypedDict): + token_id: Optional[str] + logprob_score: float + + +class DatasetItem(TypedDict): + """ + This is the format that is used by Metaseq as input for finetuning and + inference. 
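    Illustrative example of one item in this format (values are hypothetical):

        {"src": "Summarize: long article ...", "tgt": "short summary"}

    The optional token-level fields below are only filled in when that
    information is available (e.g. logprobs coming from a teacher model).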
+ """ + src: str + tgt: str + + top_logprobs: Optional[List[Dict[str, DatasetItemLogprob]]] + tokens: Optional[List[str]] + token_scores: Optional[List[float]] + + +class MetaseqInferenceOutputItemBeamResult(TypedDict): + generated_text: str + + top_logprobs: List[Dict[str, DatasetItemLogprob]] + tokens: List[str] + token_scores: List[float] + + +class MetaseqInferenceOutputItem(TypedDict): + """ + This is the format that Metaseq's inference component will produce. + """ + instance_idx: int + target_text: str + prompt_text: str + beam_results: List[MetaseqInferenceOutputItemBeamResult] + + +class OAITeacherGeneratedDatasetItemLogprobs(TypedDict): + tokens: List[str] + token_logprobs: List[float] + text_offset: List[int] + top_logprobs: List[Dict[str, float]] + + +class OAITeacherGeneratedDatasetItem(TypedDict): + """ + This is the format that OpenAI teacher-generated datasets are in. + """ + source: str + """ + The prompt that was used to generate the text. + """ + + human: str + """ + The ground truth human label that corresponds with this prompt. + """ + + text: str + """ + The text generated by the OpenAI model. + """ + + index: int + finish_reason: str + logprobs: OAITeacherGeneratedDatasetItemLogprobs + + +@dataclass +class DatasetTeacherGeneratedDataHooks: + + before_transforming_into_metaseq_inference: Callable[[OAITeacherGeneratedDatasetItem], + OAITeacherGeneratedDatasetItem] = field( + default_factory=lambda: identity_transformer + ) + """ + This hook will get called on each item of the dataset just before it is + transformed into the format that metaseq expects for inference. + + For example, this hook will get executed when we're converting the OpenAI + teacher-generated dataset into the format that Metaseq expects as input for + inference/finetuning, or when we're converting it to inference format with + the purpose of evaluating the teacher performance. + """ + + convert_test_target_to_original_domain_label: Callable[[str], str] = field(default_factory=lambda: identity_transformer) + """ + Called when we're converting the test set into the format that Metaseq + expects as input for inference/finetuning. This hook will get called on the + target text created by the teacher so that it is converted into the same + domain as the human labels. Note that this will get called AFTER + `before_transforming_into_metaseq_inference`. + + For example, if the human labels are numbers like "1", "2", "3", but the + model outputs numbers as words "one", "two", "three", then this hook will + get called to convert the model output into actual numbers. + """ + + +@dataclass +class DatasetModelHooks: + convert_model_type_output_to_original_domain: Dict[str, Callable[[str], str]] = field( + default_factory=lambda: + # by default every model type just retuns the identity transformer + defaultdict(lambda: identity_transformer) + ) + """ + This dict maps the model type to a function that will convert the model's output + into the same domain as the original dataset. + + For example: + { + "distilled": lambda x: int(s.plit(" ")[0]), + } + """ + + +class IdentityDict(defaultdict): + """ + Wrapper around default dictionary that may be initialized with static keys. + The dictionary maps function names to the function implemenation where + undefined keys map to identity function. 
+ """ + + def __init__(self, values: dict) -> None: + super().__init__(lambda: identity_transformer) + self.update(values) + + +@dataclass +class CommonDatasetConfiguration: + metric_libraries: Optional[List[str]] = None + """ + The list of metric libraries that will be used to evaluate the model's + output + """ + + metric_names: Optional[List[str]] = None + """ + The list of metric names that will be used to evaluate the model's output + for this dataset. + """ + + +@dataclass +class DatasetDataHooks: + ... + + +@dataclass +class DatasetConfigurationOriginal: + """ + Configuration for an Original dataset (the one with human labels). + """ + data_hooks: DatasetDataHooks = field(default_factory=lambda: DatasetDataHooks()) + + +@dataclass +class DatasetConfigurationTeacherGenerated: + """ + Configuration for a dataset that comes from a teacher-generated output. + """ + data_hooks: DatasetTeacherGeneratedDataHooks = field(default_factory=lambda: DatasetTeacherGeneratedDataHooks()) + + +@dataclass +class DatasetModelConfig: + """ + Configuration for a model for a dataset. + """ + model_hooks: DatasetModelHooks = field(default_factory=lambda: DatasetModelHooks()) + + +@dataclass +class DatasetConfiguration: + """ + Top level type for a dataset configuration. Includes configuration for the + models, and different variations of the dataset (teacher generated, orig), + etc. + """ + common: CommonDatasetConfiguration + model_config: DatasetModelConfig = field(default_factory=lambda: DatasetModelConfig()) + + teacher_generated_config: DatasetConfigurationTeacherGenerated = field( + default_factory=lambda: DatasetConfigurationTeacherGenerated() + ) + orig_config: DatasetConfigurationOriginal = field(default_factory=lambda: DatasetConfigurationOriginal()) diff --git a/metaseq/data/prompt_generator.py b/metaseq/data/prompt_generator.py new file mode 100644 index 000000000..1882a4063 --- /dev/null +++ b/metaseq/data/prompt_generator.py @@ -0,0 +1,314 @@ +from argparse import ArgumentError +import json +import random +import math +from jinja2 import Environment, meta + +from collections import defaultdict +from glob import glob +from tokenizers import ByteLevelBPETokenizer, Tokenizer +from typing import Generator, List, Optional + +from metaseq.logging import get_logger + +logger = get_logger(__name__) + +prompt_template_default = """ +{%- for shot in shots -%} +{{ shot['src'] }} {{ shot['tgt'] }} + +{%+ endfor -%} +{{ input_s }}""" + + +class PromptGenerator: + """ + This class is responsible for decorating the current sample that is intended + to be used as the input to the LM with additional information. The goal of + adding this additional information is to help the LM better understand the + context of the current sample. + + See :meth:`get_next_prompt` for more details on how the prompt is decorated + and what decorations are possible. 
+ """ + + def __init__( + self, + prompt_template_s: Optional[str] = None, + n_fewshot_samples: int = 0, + few_shot_data_glob_path: Optional[str] = None, + fewshot_sample_method='random', + rng_seed=42, + max_tokens: int = None, + shots_to_input_truncation_percentage: float = 0.5, + tokenizer_vocab_file_path="/mnt/input_data_dir/pretrained_models/OPT/dependencies/gpt2-vocab.json", + tokenizer_merges_file_path="/mnt/input_data_dir/pretrained_models/OPT/dependencies/gpt2-merges.txt", + ): + """ + Initializes the PromptGenerator + + :param Optional[str] prompt_template_s: A Jinja template string + Must contain `input_s` + If `n_fewshot_samples` > 0, must contain `shots` loop + Each shot should be a dictionary with `src` and `tgt` keys + :param Optional[str] few_shot_data_glob_path: If provided, this will be the + path to the few-shot data that will be used to generate the prompts. + If None, no few-shot data will be used. Defaults to None. + :param int n_fewshot_samples: The number of few-shot samples that we + want to add to the prompt. If this is > 0 then `few_shot_data_path` + must also be provided. Defaults to 0 + :param str fewshot_sample_method: The method used to choose samples for the few shots. + This may be random or fixed. If fixed the generator will choose first N samples from + the given few_shot_data_path + :param int rng_seed: Seed for the random number generator, defaults to + 42 + :param int max_tokens: The maximum tokens the prompt can use. Defaults to None. + We are not using the exact tokenizer that Open AI uses and thus we force output LESS tokens than the actual value by + subtracting an extra 1% of the max_tokens given to ensure result will be within the allowed range by the true tokenizer. + :param int shots_to_input_truncation_percentage: The ratio or percentage of tokens to truncation from the shot samples vs the input sample + :param str tokenizer_vocab_file_path: Path to tokenizer vocabulary file, defaults to "/mnt/input_data_dir/pretrained_models/OPT/dependencies/gpt2-vocab.json" + :param str tokenizer_merges_file_path: Path to tokenizer merges file, defaults to "/mnt/input_data_dir/pretrained_models/OPT/dependencies/gpt2-merges.txt" + """ + + if prompt_template_s is None: + prompt_template_s = prompt_template_default + + logger.debug(f'Using prompt template:\n{prompt_template_s}') + jinja_env = Environment() + prompt_template_ast = jinja_env.parse(prompt_template_s) + template_variables = meta.find_undeclared_variables(prompt_template_ast) + + assert 'input_s' in template_variables, ( + f"You provided the PromptGenerator with template that did not contain required variable 'input_s'. " + f"Please update the template and try again.\n" + f"Template:\n{prompt_template_s}" + ) + self._n_fewshot_samples = n_fewshot_samples + if n_fewshot_samples > 0: + assert 'shots' in template_variables, ( + f"You configured the PromptGenerator to use {n_fewshot_samples} shot samples; however, " + f"the template provided do not contain the required variable 'shots'. 
" + f"Please update the template and try again.\n" + f"Template:\n{prompt_template_s}" + ) + + self._prompt_template = jinja_env.from_string(prompt_template_s) + + if few_shot_data_glob_path is not None: + assert n_fewshot_samples > 0, "A path for the few shot data was provided but n_fewshot_samples is not a positive number" + self._fewshot_data_generator = self._create_fewshot_sample_generator( + few_shot_data_glob_path, fewshot_sample_method + ) + + self._rng = random.Random(rng_seed) + + assert 0 <= shots_to_input_truncation_percentage <= 1, "Parameter shots_to_input_truncation_percentage must be withing range [0, 1]" + self._shots_to_input_truncation_percentage = shots_to_input_truncation_percentage + + self._adjusted_max_tokens = None + self._tokenizer = None + # This value is not correct, but we can't compute tokens without tokenizer + self._num_tokens_of_fixed_prompt = 0 + + if max_tokens is not None: + # We are not using the exact tokenizer that Open AI uses and thus it make output LESS tokens than the actual value + # We subtract extra 1% from the max value to ensure result will be within the allowed range by the true tokenizer + tokenizer_gap = math.ceil(0.01 * max_tokens) + self._adjusted_max_tokens = max_tokens - tokenizer_gap + logger.debug(f'Maximum prompt tokens = {max_tokens} - (1% * {max_tokens})') + logger.debug(f' = {max_tokens} - {tokenizer_gap}') + logger.debug(f' = {self._adjusted_max_tokens}') + self._tokenizer: Tokenizer = ByteLevelBPETokenizer.from_file( + tokenizer_vocab_file_path, + tokenizer_merges_file_path, + ) + + empty_shots = [{'src': '', 'tgt': ''} for shot in range(n_fewshot_samples)] + fixed_str_portion_of_prompt = self._prompt_builder(shot_samples=empty_shots, input_s='') + logger.debug(f'Empty {self._n_fewshot_samples}-shot template\n{fixed_str_portion_of_prompt}') + self._num_tokens_of_fixed_prompt = len(self._tokenizer.encode(fixed_str_portion_of_prompt)) + logger.debug(f'Num tokens for fixed portion of prompt = {self._num_tokens_of_fixed_prompt}') + + assert self._adjusted_max_tokens > self._num_tokens_of_fixed_prompt, ( + f'You provided the PromptGenerator with template that contributes more minimum tokens {self._num_tokens_of_fixed_prompt} than the maximum allowed tokens {self._adjusted_max_tokens}.' + f'Please increase the max tokens number or reduce the prompt parameters.' + ) + + def get_next_prompt(self, input_s: str) -> str: + """ + This function takes the current sample that is being used as the input + to the LM and decorates it according to the :class:`PromptGenerator`'s + configuration. + + These are the possible `decorations` that this method might do to the + input `input_s` + + - addition of `self._n_fewshot_samples` few-shot samples to the prompt + + Note that the PromptGenerator will return the `input_s` input + as-is if it is not configured to `decorate` the prompt in any way. + + :param str input_s: current sample that is the input to the LM. + This `base prompt` will be `decorated` by the PromptGenerator to + create the final prompt. + :return str: the final decorated prompt that will be used as the input + to the LM. 
+ """ + + # decorate the prompt with few-shot samples if needed + samples_for_this_shot = [] + + if self._n_fewshot_samples > 0: + samples_for_this_shot = next(self._fewshot_data_generator) + assert len( + samples_for_this_shot + ) == self._n_fewshot_samples, f"Fewshot data generator returned the wrong number of samples, expected {self._n_fewshot_samples} but got {len(samples_for_this_shot)}" + + prompt = self._prompt_builder(samples_for_this_shot, input_s) + + if self._adjusted_max_tokens is not None and self._tokenizer: + prompt_token_ids = self._tokenizer.encode(prompt).ids + prompt_token_ids_len = len(prompt_token_ids) + logger.debug(f'Generated prompt is {prompt_token_ids_len} tokens.') + + # If the prompt tokens exceeds the maximum then truncate to fit + if prompt_token_ids_len >= self._adjusted_max_tokens: + logger.warn( + f'Generated prompt has {prompt_token_ids_len} but maximum allowed is {self._adjusted_max_tokens}. Prompt shots and input will truncated.' + ) + + num_tokens_to_truncate = (prompt_token_ids_len - self._adjusted_max_tokens) + self._num_tokens_of_fixed_prompt + original_prompt = prompt + original_prompt_token_ids_len = prompt_token_ids_len + truncated_samples = [] + + if self._n_fewshot_samples > 0: + num_shot_tokens_to_truncate = math.ceil( + num_tokens_to_truncate * self._shots_to_input_truncation_percentage + ) + num_tokens_to_truncate_from_each_shot = math.ceil(num_shot_tokens_to_truncate / self._n_fewshot_samples) + num_input_tokens_to_truncate = num_tokens_to_truncate - num_shot_tokens_to_truncate + + for sample_index, sample in enumerate(samples_for_this_shot): + + src_token_ids = self._tokenizer.encode(sample['src']).ids + tgt_token_ids = self._tokenizer.encode(sample['tgt']).ids + src_token_percentage = len(src_token_ids) / (len(src_token_ids) + len(tgt_token_ids)) + tgt_token_percentage = len(tgt_token_ids) / (len(src_token_ids) + len(tgt_token_ids)) + + src_tokens_to_truncate = math.ceil(num_tokens_to_truncate_from_each_shot * src_token_percentage) + tgt_tokens_to_truncate = math.ceil(num_tokens_to_truncate_from_each_shot * tgt_token_percentage) + + logger.debug( + f'Sample {sample_index} truncation params:\n' + f'Tokens to truncate: {num_tokens_to_truncate_from_each_shot}\n' + f'Source tokens: {len(src_token_ids)} ({src_token_percentage:.2f}%)\n' + f'Right truncating {src_tokens_to_truncate} from source.\n' + f'Target tokens: {len(tgt_token_ids)} ({tgt_token_percentage:.2f}%)\n' + f'Right truncating {tgt_tokens_to_truncate} from target.' 
+ ) + + truncated_sample = { + 'src': self._tokenizer.decode(src_token_ids[:-src_tokens_to_truncate]), + 'tgt': self._tokenizer.decode(tgt_token_ids[:-tgt_tokens_to_truncate]) + } + truncated_samples.append(truncated_sample) + else: + # Since there are 0 shots, all the truncation occurs on the input string + num_input_tokens_to_truncate = num_tokens_to_truncate + + logger.debug(f'Right truncating {num_input_tokens_to_truncate} from input.') + truncated_input_s = self._truncate_n_tokens(input_s, num_input_tokens_to_truncate) + prompt = self._prompt_builder(truncated_samples, truncated_input_s) + + logger.debug(f'Original Prompt (Tokens: {original_prompt_token_ids_len}):\n{original_prompt}\n') + prompt_token_ids = self._tokenizer.encode(prompt).ids + prompt_token_ids_len = len(prompt_token_ids) + logger.debug(f'Truncated Prompt (Tokens: {prompt_token_ids_len})') + + return prompt + + def _truncate_n_tokens(self, s: str, num_tokens_to_truncate: int) -> str: + """Given string, encode, truncate, decode.""" + return self._tokenizer.decode(self._tokenizer.encode(s).ids[:-num_tokens_to_truncate]) + + def _prompt_builder(self, shot_samples: List[dict], input_s: str) -> str: + return self._prompt_template.render(shots=shot_samples, input_s=input_s) + + def _create_fewshot_sample_generator(self, data_glob_path: str, sample_method: str) -> Generator[List[dict], None, None]: + """ + Creates a generator that yields a list of few-shot samples. The samples + themselves are chosen randomly from the few-shot data. Each "iteration" + of the generator returned by this function will yield a list of + `self._n_fewshot_samples` samples. + + :param str data_path: Path to the directly which contains the few-shot + data. These must be `jsonl` documents. + :yield Generator[List[dict], None, None]: Infinite generator that yields + a list of random few-shot samples for each iteration. The size of + the list is `self._n_fewshot_samples`. + """ + logger.info(f"Creating few-shot sample generator from {data_glob_path}") + + file_paths = glob(data_glob_path) + assert len(file_paths) > 0, f"No files found in {data_glob_path}" + + logger.info(f"Found the following files for creation of few-shot samples: {file_paths}") + + # get index of every newline in each file + file_idx_to_prompt_start_idx = defaultdict(list) + file_idx_to_handle = {} + for file_idx, file_path in enumerate(file_paths): + # NOTE: this approach prevents us from having to load the entire + # file into memory but it does mean that we need to have a file + # handle open for each file + file_handle = open(file_path, "r") + file_idx_to_handle[file_idx] = file_handle + + # prompt_offset is the index of the first character of the prompt + # within the file. 
Initially it is 0 since the first prompt starts + # at the beginning of the file + prompt_offset = 0 + + file_idx_to_prompt_start_idx[file_idx].append(prompt_offset) + + line = file_handle.readline() + while line != "": + prompt_offset += len(line) + file_idx_to_prompt_start_idx[file_idx].append(prompt_offset) + line = file_handle.readline() + + # pop the last entry since it happens just before EOF (there is no sample after it) + file_idx_to_prompt_start_idx[file_idx].pop() + + # The generator state has been set up, now we only need to yield the + # samples when needed + while True: + # Every time we need a new sample, we randomly choose a file and a + # sample from that file + chosen_file_idx = self._rng.sample(file_idx_to_prompt_start_idx.keys(), 1)[0] + file_handle = file_idx_to_handle[chosen_file_idx] + + if sample_method == 'random': + # Sampling self.n_fewshot_samples samples at a time ensures there + # are no duplicates for the samples of the current prompt + chosen_offsets = self._rng.sample(file_idx_to_prompt_start_idx[chosen_file_idx], self._n_fewshot_samples) + elif sample_method == 'fixed': + chosen_offsets = [file_idx_to_prompt_start_idx[chosen_file_idx][i] for i in range(self._n_fewshot_samples)] + else: + raise ArgumentError( + f"You passed an unknown sample method: {sample_method}. Allowed values are: 'random' or 'fixed'" + ) + + samples = [None] * self._n_fewshot_samples + + for i, offset in enumerate(chosen_offsets): + file_handle.seek(offset) + line = file_handle.readline() + + sample = json.loads(line) + assert set(sample.keys()).issuperset({"src", "tgt"}), "jsonl entry must contain 'src' and 'tgt' fields" + + samples[i] = sample + + yield samples #type: ignore diff --git a/metaseq/generation_metrics/__init__.py b/metaseq/generation_metrics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metaseq/generation_metrics/coco_metrics.py b/metaseq/generation_metrics/coco_metrics.py new file mode 100644 index 000000000..658771d53 --- /dev/null +++ b/metaseq/generation_metrics/coco_metrics.py @@ -0,0 +1,72 @@ +from collections import defaultdict +from typing import Dict, List, Union + +from metaseq.generation_metrics.wrappers.coco import measure_scores +from metaseq.generation_metrics.wrappers.coco.measure_scores import CustomCOCOEvalCap, ResultsForPrompt +from metaseq.logging import get_logger + +logger = get_logger(__name__) + + +class CocoMetrics: + """ + A note on these metrics is that they'll automatically perform *grouping* of + the items to be evaluted by their ``prompt``. The *generated text* is then + compared against all ``target texts`` related with the prompt at once and + the average is returned. + + Existing metrics: + + - bleu + - meteor + - rouge-L / rouge + - cider + - spice + - nist + """ + + allowed_metrics = CustomCOCOEvalCap.allowed_metrics.union({"nist"}) + + def __init__(self, metrics: List[str] = []) -> None: + + # if empty list then assume we want to run all metrics + if len(metrics) == 0: + metrics = list(CocoMetrics.allowed_metrics) + + # ignore metrics that we don't know about. Not an error since these + # other metrics might be used by other "evaluators" (e.g. 
ParlAI) + final_metrics = set(metrics).intersection(CocoMetrics.allowed_metrics) + + logger.info("Registering COCO evaluator with the following list of metrics: %s", final_metrics) + self.metrics_list = final_metrics + + self.prompt_to_results: Dict[ + str, ResultsForPrompt] = defaultdict(lambda: ResultsForPrompt( + prompt="", + generated_text="", + target_texts=[], + )) + + def __call__(self, prompt: str, prediction: str, label: Union[List[str], str]) -> Dict: + prompt_result = self.prompt_to_results[prompt] + prompt_result["prompt"] = prompt + + # only keep the first generated_text + if prompt_result["generated_text"] == "": + prompt_result["generated_text"] = prediction + + if type(label) == list: + prompt_result["target_texts"].extend(label) + else: + prompt_result["target_texts"].append(label) # type: ignore + + # TODO is this ok? Returning metrics just for every sample would be + # really slow. Does it even make sense considering we're assured to only + # have access to a single target? + return {} + + def accumulate(self) -> Dict: + return measure_scores.evaluate( + list(self.prompt_to_results.values()), + metrics=self.metrics_list, + ) diff --git a/metaseq/generation_metrics/grindstone_metrics.py b/metaseq/generation_metrics/grindstone_metrics.py new file mode 100644 index 000000000..486ce558e --- /dev/null +++ b/metaseq/generation_metrics/grindstone_metrics.py @@ -0,0 +1,150 @@ +from collections import defaultdict +from typing import Callable, Dict, List, Union +from bert_score import BERTScorer +import numpy as np +from rouge_score import rouge_scorer + +EvalFunctionType = Callable[[str, List[str]], Dict[str, float]] + + +def _build_rouge_function(rouge_types: List[str]) -> EvalFunctionType: + """ + Returns a function that will compute the Rouge metrics for a pair of + "prediction" and "references". + + These are computed using the same library and params that are used in the + HELM benchamark. + + :param List[str] rouge_types: Rouge types we want to have this evaluator + return. These come from :ref:`rouge_scorer.RougeScorer`. + :return EvalFunctionType: evaluation function that returns a score for each + of the :param rouge_types:. + """ + # Use the same rouge scorers that are used in Helm + # https://github.dev/stanford-crfm/helm/blob/a1c5e4293d6b475aca567abaadf4eae8c978bd4e/src/helm/benchmark/metrics/summarization_metrics.py#L46-L51 + scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True) + + def evaluate_closure(pred: str, references: List[str]) -> Dict[str, float]: + # https://github.dev/stanford-crfm/helm/blob/a1c5e4293d6b475aca567abaadf4eae8c978bd4e/src/helm/benchmark/metrics/summarization_metrics.py#L131-L137 + max_scores = defaultdict(lambda: -np.inf) + for reference in references: + score = scorer.score(prediction=pred, target=reference) + + for rouge_type in rouge_types: + score_for_type = score[rouge_type].fmeasure + if score_for_type > max_scores[rouge_type]: + max_scores[rouge_type] = score_for_type + + return max_scores + + return evaluate_closure + + +def _build_bert_scorer() -> EvalFunctionType: + """ + Returns a function that will compute the BERTScore for a pair of + "prediction" and "references". + + These are computed using the same library and params that are used in the + HELM benchamark. + + :return EvalFunctionType: evaluation function that returns the Precision, + Recall, and F1, from the BERTScorer. 
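For illustration, the closures returned by these builders are plain callables taking a prediction and a list of references; a short sketch using the ROUGE builder defined above:

.. code-block:: python

    rouge_fn = _build_rouge_function(["rouge1", "rouge2", "rougeL"])
    scores = rouge_fn(
        "the cat sat on the mat",
        ["a cat sat on a mat", "the cat is on the mat"],
    )
    # scores maps each requested type to the maximum F-measure over the
    # references, e.g. {"rouge1": ..., "rouge2": ..., "rougeL": ...}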
+ """ + # same params as Helm: + # https://github.dev/stanford-crfm/helm/blob/a1c5e4293d6b475aca567abaadf4eae8c978bd4e/src/helm/benchmark/metrics/summarization_metrics.py#L68-L69 + scorer = BERTScorer( + model_type="microsoft/deberta-large-mnli", + lang="en", + rescale_with_baseline=True, + ) + + def evaluate_closure(pred: str, references: List[str]) -> Dict[str, float]: + # https://github.dev/stanford-crfm/helm/blob/a1c5e4293d6b475aca567abaadf4eae8c978bd4e/src/helm/benchmark/metrics/summarization_metrics.py#L150-L152 + precision, recall, f1score = scorer.score(cands=[pred], refs=[references]) + return {"BERTScore-P": precision[0].item(), "BERTScore-R": recall[0].item(), "BERTScore-F": f1score[0].item()} + + return evaluate_closure + + +def _build_accuracy_scorer() -> EvalFunctionType: + + # Scorer comes directly from HELM: + # https://github.dev/stanford-crfm/helm/blob/80ecb204a9fed6a54a53e1f5950e9c1e145acddc/src/helm/benchmark/metrics/basic_metrics.py#L137-L138 + def exact_match(gold: str, pred: str) -> float: + if not pred: + return 0 + + return 1 if gold.strip() == pred.strip() else 0 + + def evaluate_closure(pred: str, references: List[str]) -> Dict[str, float]: + computed_exact_match = np.mean([exact_match(ref, pred) for ref in references]) + + return {"accuracy": computed_exact_match} + + return evaluate_closure + + +class GrindstoneMetrics: + """ + This class presents a collection of metrics+evaluators with the goal of + ensuring the algorithms we use here are the same as those used in popular + benchmarks (e.g. HELM, BabelBench) so we can properly compare against them. + """ + + def __init__(self, metrics: List[str]) -> None: + self.metrics = set(m.lower() for m in metrics) + + if "rouge-l" in self.metrics: + # Add all Rouge-* if rouge-L is provided so we're compatible with + # other evaluators + self.metrics.add("rouge") + + # the sum of all results of a given metric + self.sum_of_metrics: Dict[str, float] = defaultdict(lambda: 0.0) + + # how many examples we've seen + self.exs = 0 + + # register metric functions + self.metric_name_to_eval_fn: Dict[str, EvalFunctionType] = {} + if "bertscore" in self.metrics: + self.metric_name_to_eval_fn["BERTScore"] = _build_bert_scorer() + + if "rouge" in self.metrics: + self.metric_name_to_eval_fn["rouge"] = _build_rouge_function(["rouge1", "rouge2", "rougeL"]) + + if "accuracy" in self.metrics: + self.metric_name_to_eval_fn["accuracy"] = _build_accuracy_scorer() + + def __call__(self, _prompt: str, prediction: str, label: Union[List[str], str]) -> Dict: + response = {} + + label_list: List[str] + if type(label) == str: + label_list = [label] # type: ignore + else: + label_list = label # type: ignore + + self.exs += 1 + + # get metrics from each of the selected metrics classes + for eval_fn in self.metric_name_to_eval_fn.values(): + output = eval_fn(prediction, label_list) + + for metric, score in output.items(): + # self.sum_of_metrics is a defaultdict(0.0), so score always + # starts at 0 + self.sum_of_metrics[metric] += score + + response.update(output) + + return response + + def accumulate(self) -> Dict: + result = {} + + for metric, accumulated_score in self.sum_of_metrics.items(): + result[metric] = accumulated_score / self.exs + + return result diff --git a/metaseq/generation_metrics/hf_metrics.py b/metaseq/generation_metrics/hf_metrics.py new file mode 100644 index 000000000..13a348160 --- /dev/null +++ b/metaseq/generation_metrics/hf_metrics.py @@ -0,0 +1,54 @@ +import logging +from typing import Dict, List, Union + +import evaluate 
+ +from metaseq.logging import get_logger + +logger = get_logger(__name__) + + +class HFEvaluateMetrics(): + """ + Existing Metrics: + + .. code-block:: python + + [ + 'precision', 'code_eval', 'roc_auc', 'cuad', 'xnli', 'rouge', 'pearsonr', 'mse', 'super_glue', + 'comet', 'cer', 'sacrebleu', 'mahalanobis', 'wer', 'competition_math', 'f1', 'recall', 'coval', + 'mauve', 'xtreme_s', 'bleurt', 'ter', 'accuracy', 'exact_match', 'indic_glue', 'spearmanr', 'mae', + 'squad', 'chrf', 'glue', 'perplexity', 'mean_iou', 'squad_v2', 'meteor', 'bleu', 'wiki_split', 'sari', + 'frugalscore', 'google_bleu', 'bertscore', 'matthews_correlation', 'seqeval','trec_eval', 'rl_reliability', + 'poseval', 'brier_score', 'mase', 'mape', 'smape', 'nist_mt', 'character', 'charcut_mt', 'mcnemar', + 'exact_match', 'wilcoxon', 'word_length', 'word_count', 'text_duplicates', 'perplexity', 'label_distribution', + 'toxicity', 'regard', 'honest' + ] + """ + + def __init__(self, metrics: List[str], **kwargs) -> None: + self.metrics = metrics + self.infer_metrics(**kwargs) + + def infer_metrics(self, **kwargs) -> None: + self.metric_cls = {} + for metric in self.metrics: + try: + self.metric_cls[metric] = evaluate.load(metric, **kwargs) + except Exception: + logger.warning(f"HF Evaluate: Metric {metric} is not supported.") + + def __call__(self, _prompt: str, prediction: str, label: Union[List[str], str]) -> Dict: + response = {} + for metric, metric_cls in self.metric_cls.items(): + kwargs = {} + if metric == "perplexity": + # TODO: Update this to not use a hardcoded default model_id + kwargs = {"model_id": "gpt2"} + + response.update(metric_cls.compute(predictions=[prediction], references=[label], **kwargs)) + return response + + def accumulate(self) -> Dict: + logging.warning("Accumulate is not supported for HF Evaluate Metrics yet.") + return {} diff --git a/metaseq/generation_metrics/metrics.py b/metaseq/generation_metrics/metrics.py new file mode 100644 index 000000000..81ad9b5e1 --- /dev/null +++ b/metaseq/generation_metrics/metrics.py @@ -0,0 +1,336 @@ +import os +from glob import glob +import json +from typing import Any, Dict, List, Optional, Union +import fire +from datetime import datetime + +from metaseq.data import data_utils +from metaseq.data.datasets.dataset_configurations import DATASET_CONFIGURATIONS +from metaseq.data.datasets.types import MetaseqInferenceOutputItem + +from metaseq.generation_metrics.coco_metrics import CocoMetrics +from metaseq.generation_metrics.grindstone_metrics import GrindstoneMetrics +from metaseq.generation_metrics.hf_metrics import HFEvaluateMetrics +from metaseq.generation_metrics.parlai_metrics import ParlAiMetrics +from metaseq.logging import get_logger +from metaseq.scripts import script_utils + +repo_root = script_utils.get_repo_root_path() +logger = get_logger(__name__) + + +class GenerationMetrics(): + metric_libraries = {"all", "parlai", "hf", "coco", "grindstone"} + + def __init__(self, metrics: List[str], libraries: List[str] = ["parlai"], **kwargs) -> None: + self.metrics = [m.strip() for m in metrics] + self.libraries = [l.lower().strip() for l in libraries] + self.infer_metrics(**kwargs) + + def infer_metrics(self, **kwargs) -> None: + self.metric_cls = {} + + for library in self.libraries: + assert library in self.metric_libraries, f"library {self.libraries} not supported" + + if "all" in self.libraries or "parlai" in self.libraries: + # parlai expects metrics list as a string + metrics = ",".join(self.metrics) + self.metric_cls["parlai"] = 
ParlAiMetrics(metrics_list=metrics, **kwargs) + + if "all" in self.libraries or "hf" in self.libraries: + self.metric_cls["hf"] = HFEvaluateMetrics(metrics=self.metrics, **kwargs) + + if "all" in self.libraries or "coco" in self.libraries: + self.metric_cls["coco"] = CocoMetrics(metrics=self.metrics, **kwargs) + + if "all" in self.libraries or "grindstone" in self.libraries: + self.metric_cls["grindstone"] = GrindstoneMetrics(metrics=self.metrics, **kwargs) + + def __call__(self, prompt: str, prediction: str, label: Union[List[str], str]) -> Dict[str, float]: + response = {} + for metric_lib, metric_cls in self.metric_cls.items(): + response[metric_lib] = metric_cls(prompt, prediction, label) + return response + + def accumulate(self) -> Dict[str, float]: + # TODO: "accumulate" is not supported by HFEvaluateMetrics + assert "hf" not in self.metric_cls, "'accumulate' is not supported by HFEvaluateMetrics" + + response = {} + for metric_lib, metric_cls in self.metric_cls.items(): + response[metric_lib] = metric_cls.accumulate() + + return response + + +def evaluate_inference_files( + inference_file_glob_pattern: str, + evaluation_output_file_path: str, + individual_results_output_file_path: str, + exceptions_ouput_file_path: str, + libraries: Optional[List[str]] = None, + metrics: Optional[List[str]] = None, + dataset_configuration_name: Optional[str] = None, + model_configuration_name: Optional[str] = None, + output_metrics_for_all: bool = False, +): + """ + Evaluate inference files using the provided metrics and libraries, or using + the provided dataset and model configuration. + + :param str inference_file_glob_pattern: Glob pattern that will be used to + find the inference files to evaluate. + :param str evaluation_output_file: Path to the file where the evaluation + results will be written. + :param str exceptions_ouput_file: Path to the file where any exceptions that + happen during evaluation will be recorded. + :param Optional[List[str]] libraries: List of libraries to use during + evaluation, defaults to None + :param Optional[List[str]] metrics: List of metrics to use during + evaluation, defaults to None + :param Optional[str] dataset_configuration_name: The dataset configuratio + name to use during evaluation. 
If provided AND this configuration has + `libraries` and `metrics` then this will overwrite the `libraries` and + `metrics` args with the values of the configuration, defaults to None + :param Optional[str] model_configuration_name: Name of the model + configuration to get from the dataset configuration, defaults to None + :param bool output_metrics_for_all: If True then the result file will + contain a line for the evaluation result of each inference, defaults to + False + """ + + if model_configuration_name is not None: + assert dataset_configuration_name is not None, "Expected dataset configuration name to be provided when model configuration name is provided" + + # initialize evaluator + if dataset_configuration_name is not None: + logger.info(f"Using dataset configuration: {dataset_configuration_name}") + + dataset_config = DATASET_CONFIGURATIONS[dataset_configuration_name] + if dataset_config.common.metric_names is not None: + logger.info(f"Overriding metrics with dataset configuration: {dataset_config.common.metric_names}") + metrics = dataset_config.common.metric_names + + if dataset_config.common.metric_libraries is not None: + logger.info( + f"Overriding evaluation libraries with dataset configuration: {dataset_config.common.metric_libraries}" + ) + libraries = dataset_config.common.metric_libraries + + assert libraries is not None, "Expected evaluation libraries to be provided when running evaluation" + assert metrics is not None, "Expected evaluation metrics to be provided when running evaluation" + + # load hook (or default to identity) + model_domain_to_original_domain_transformer = lambda x: x + if dataset_configuration_name is not None and model_configuration_name is not None: + model_domain_to_original_domain_transformer = ( + DATASET_CONFIGURATIONS[dataset_configuration_name].model_config.model_hooks. 
+ convert_model_type_output_to_original_domain[model_configuration_name] + ) + + # create evaluator + evaluator = GenerationMetrics(metrics=metrics, libraries=libraries) + + # load inference files + inference_files = glob(inference_file_glob_pattern, recursive=True) + assert len( + inference_files + ) > 0, f"Found no files to evaluate on that match the provided pattern: {inference_file_glob_pattern}" + line_iterator = data_utils.multiple_file_line_generator(inference_files) + + # evaluate + with open(evaluation_output_file_path, "w") as f_results, \ + open(individual_results_output_file_path, "w") as f_individual_results, \ + open(exceptions_ouput_file_path, "w") as f_exceptions: + total_num_exceptions = 0 + total_num_evaluated_successfully = 0 + + for predicted_line in line_iterator: + inference_row: MetaseqInferenceOutputItem = json.loads(predicted_line) + + row_metrics = [] + target_text = inference_row["target_text"] + + for beam_idx, beam in enumerate(inference_row["beam_results"]): + + model_output = beam["generated_text"] + + try: + model_output = model_domain_to_original_domain_transformer(model_output) + + row_metrics.append( + evaluator( + prompt=inference_row["prompt_text"], + prediction=model_output, + label=target_text, + ) + ) + + total_num_evaluated_successfully += 1 + except Exception as e: + logger.error( + f"Exception attempting evaluate line {line_iterator.current_line_num}\n" + f"Exception: {e}\n" + f"Raw item:\n\t{inference_row}" + ) + total_num_exceptions += 1 + + f_exceptions.write( + json.dumps( + { + "file": line_iterator.current_file_path, + "line_number": line_iterator.current_line_num, + "error": f"{type(e).__name__} - {e}", + "raw_line": predicted_line, + "beam_idx": beam_idx, + } + ) + "\n" + ) + + if output_metrics_for_all: + f_individual_results.write(json.dumps(row_metrics) + "\n") + + accumulated_metrics: Dict[str, Any] = evaluator.accumulate() + accumulated_metrics["evaluation_info"] = { + "total_num_evaluated_successfully": total_num_evaluated_successfully, + "total_num_exceptions": total_num_exceptions, + } + + f_results.write(json.dumps(accumulated_metrics, indent=2)) + + logger.info(f"Saved evaluation results to {evaluation_output_file_path}") + total_rows = total_num_exceptions + total_num_evaluated_successfully + if total_num_exceptions > 0: + logger.info(f"Saved evaluation exceptions to {exceptions_ouput_file_path}") + logger.info( + f"Total number of exceptions encountered: {total_num_exceptions} ({total_num_exceptions / total_rows * 100:0.2f}% of total rows)" + ) + logger.info( + f"Total number of inference rows evaluated successfully: {total_num_evaluated_successfully} ({total_num_evaluated_successfully / total_rows * 100:0.2f}% of total rows)" + ) + logger.info(f"(Accumulated Metrics)\n{json.dumps(accumulated_metrics, indent=2)}") + + +def cli_main( + input_file_path: str, + output_folder_path=repo_root / '_results', + libraries: Union[str, List[str]] = ["parlai", "grindstone"], + metrics: Union[str, List[str]] = ["rouge-L", "bertscore"], + pretty: bool = False, + output_all: bool = False, + dataset_configuration_name: Optional[str] = None, + model_configuration_name: Optional[str] = None, + prompt: str = "", + label: Optional[str] = None, + prediction: Optional[str] = None, +): + """ + This script can be used in two ways: + + 1. Pass it an input file which is obtained from Metaseq's inference + component, and it will evaluate how good are the model's predictions vs + the ground truth label. + 2. 
Pass it a single triplet of "prompt", "label" (aka 'ground truth') and
+ "prediction" and it will print out the result of evaluating that single
+ triplet. This is mainly used for quick testing.
+
+ **Example usage:**
+
+ **Evaluating inference results directly:**
+
+ .. code-block:: bash
+
+ python -m metaseq.generation_metrics.metrics \\
+ --pretty=True \\
+ --metrics="rouge-L,bertscore" \\
+ --libraries="coco,grindstone" \\
+ --input-file-path="/mnt/input_data_dir/examples/evaluation/all_prediction_results.jsonl"
+
+
+ **Evaluating inference results using dataset configuration:**
+
+ .. code-block:: bash
+
+ python -m metaseq.generation_metrics.metrics \\
+ --pretty=True \\
+ --dataset_configuration_name="hellaswag" \\
+ --model_configuration_name="distilled" \\
+ --input-file-path="/mnt/input_data_dir/examples/evaluation/all_prediction_results.jsonl"
+
+ :param str input_file_path: This is the path to an inference file that was
+ produced by Metaseq's inference component.
+ :param str output_folder_path: The path to the folder which will contain the
+ results of evaluating the items in `input_file_path`.
+ :param str libraries: A comma-separated list of the libraries we want to use
+ to compute the evaluation. Allowed values are [``parlai``,
+ ``grindstone``, ``hf``, ``coco``, ``all``]. If ``all`` is provided then *all*
+ libraries will be used. The final result will have an entry for each of
+ the libraries. Defaults to "parlai,grindstone"
+ :param str metrics: A comma-separated list of the metrics that we want to
+ compute. If a library being used supports a given metric then that metric
+ will be included among the metrics computed by said library. This means
+ that if multiple libraries are chosen and a subset of them supports a
+ given metric then it will be computed for all libraries in that subset.
+ Defaults to "rouge-L,bertscore"
+ :param bool pretty: If True then the final result will be pretty-printed,
+ defaults to False
+ :param bool output_all: If True then the evaluation score of every item will
+ be returned together with the aggregate. If False, only the aggregate is
+ returned, defaults to False
+ :param Optional[str] dataset_configuration_name: If provided then the
+ evaluation configuration will be taken from the dataset config with this
+ name, defaults to None
+ :param Optional[str] model_configuration_name: If provided then this model
+ configuration will be obtained from the dataset configuration and used
+ to compute metrics, defaults to None
+ :param str prompt: To be used when you want to evaluate a single triplet.
+ Represents the prompt. Defaults to ""
+ :param Optional[str] label: To be used when you want to evaluate a single
+ triplet. Represents the label/ground truth. Defaults to None
+ :param Optional[str] prediction: To be used when you want to evaluate a
+ single triplet. Represents the prediction/generated text.
+ Defaults to None + """ + + # convert metrics and libraries args to lists if they are not already + if isinstance(metrics, str): + metrics = metrics.split(",") + + if isinstance(libraries, str): + libraries = libraries.split(",") + + output_folder_path, [aggregated_results_file_path, individual_results_file_path, + exceptions_file_path] = script_utils.create_output_folder_and_get_filepaths( + output_folder_path, 'metrics_evaluation', + [f'aggregated_results.json', f'individual_results.jsonl', f'exceptions.jsonl'] + ) + + if input_file_path is not None and output_folder_path is not None: + evaluate_inference_files( + inference_file_glob_pattern=input_file_path, + evaluation_output_file_path=aggregated_results_file_path, + individual_results_output_file_path=individual_results_file_path, + exceptions_ouput_file_path=exceptions_file_path, + libraries=libraries, + metrics=metrics, + dataset_configuration_name=dataset_configuration_name, + model_configuration_name=model_configuration_name, + output_metrics_for_all=output_all, + ) + + elif prediction is not None and label is not None: + gen_metrics = GenerationMetrics(metrics=metrics, libraries=libraries) + results = gen_metrics(prompt, prediction, label) + if pretty: + print(json.dumps(results, indent=2)) + else: + print(results) + + else: + raise ValueError("Either input_file_path and output_folder_path or prediction and label must be provided") + + +if __name__ == "__main__": + fire.Fire(cli_main) diff --git a/metaseq/generation_metrics/parlai_metrics.py b/metaseq/generation_metrics/parlai_metrics.py new file mode 100644 index 000000000..7da43209e --- /dev/null +++ b/metaseq/generation_metrics/parlai_metrics.py @@ -0,0 +1,55 @@ +from typing import Dict, List, Optional, Union + +from parlai.core.metrics import Metric, TeacherMetrics + +from metaseq.logging import get_logger + +logger = get_logger(__name__) + + +class ParlAiMetrics(TeacherMetrics): + """ + Existing Metrics: + + .. code-block:: python + + [ + 'accuracy', 'auc', 'bleu-4', 'clen', 'clip', 'ctpb', 'ctps', 'ctrunc', 'ctrunclen', + 'exps', 'exs', 'f1', 'gen_n_toks', 'gnorm', 'gpu_mem', 'hits@1', 'hits@5', + 'interdistinct-1', 'interdistinct-2', 'intradistinct-1', 'intradictinct-2', + 'jga', 'llen', 'loss', 'lr', 'ltpb', 'ltps', 'ltrunc', 'ltrunclen', 'precision', + 'recall', 'rouge-1', 'rouge-2', 'rouge-L', 'token_acc', 'token_em', 'total_train_updates', + 'tpb', 'tps', 'ups' + ] + + - DEFAULT_METRICS = {'bleu-4', 'accuracy', 'f1'} + - ROUGE_METRICS = {'rouge-1', 'rouge-2', 'rouge-L'} + - BLEU_METRICS = {'bleu-1', 'bleu-2', 'bleu-3', 'bleu-4'} + - DISTINCT_METRICS = {'interdistinct-1', 'interdistinct-2', 'intradistinct-1', 'intradistinct-2'} + + Alias: + + .. code-block:: python + + [ + 'default', 'rouge', 'bleu', 'distinct', 'all' + ] + """ + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + def __call__(self, _prompt: str, prediction: str, label: Union[List[str], str]) -> Dict: + self.evaluate_response(observation={'text': prediction}, labels=[label] if isinstance(label, str) else label) + return {k: v.value() for k, v in self.report_recent().items()} + + def accumulate(self) -> Dict: + return {k: v.value() for k, v in self.report().items()} + + def add(self, key: str, value: Optional[Metric]) -> None: + """ + Record an accumulation to a metric. 
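An illustrative usage sketch for this wrapper (the metric names follow ParlAI's comma-separated ``metrics_list`` format; only the methods defined in this class are used):

.. code-block:: python

    parlai_metrics = ParlAiMetrics(metrics_list="rouge,f1")
    per_example = parlai_metrics(
        "prompt is unused here",
        "the cat sat",
        ["the cat sat on the mat"],
    )
    aggregate = parlai_metrics.accumulate()
    # per_example reports metrics for the most recent response only;
    # aggregate reports metrics accumulated over every response seen so far.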
+ """ + # Fixing bug with self._recent_data in parlai.core.metrics.TeacherMetrics.add + self._data[key] = self._data.get(key) + value + self._recent_data[key] = value diff --git a/metaseq/generation_metrics/wrappers/coco/__init__.py b/metaseq/generation_metrics/wrappers/coco/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/metaseq/generation_metrics/wrappers/coco/measure_scores.py b/metaseq/generation_metrics/wrappers/coco/measure_scores.py new file mode 100644 index 000000000..9748b7089 --- /dev/null +++ b/metaseq/generation_metrics/wrappers/coco/measure_scores.py @@ -0,0 +1,292 @@ +# Adapted from https://github.com/tuetschek/e2e-metrics/blob/master/measure_scores.py + +#!/usr/bin/env python3 + +from __future__ import print_function + +import json +import sys +from builtins import str, zip +from collections import defaultdict +from glob import glob +from pprint import pp +from typing import List, Set, TypedDict +import fire +from pycocotools.coco import COCO +from pycocoevalcap.eval import COCOEvalCap +from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.meteor.meteor import Meteor +from pycocoevalcap.rouge.rouge import Rouge +from pycocoevalcap.cider.cider import Cider +from pycocoevalcap.spice.spice import Spice +from .pymteval import BLEUScore, NISTScore +from metaseq.generation_metrics import grindstone_metrics + + +class ResultsForPrompt(TypedDict): + """ + Represents the result of a given prompt with its generated text and all the + target senteces that belong to this prompt. + """ + prompt: str + generated_text: str + target_texts: List[str] + + +class CustomCOCORouge(Rouge): + """ + Here we're overriding COCO's default Rouge implementation since it differs + from the implementation used in Helm and BabelBench. Specifically, the + difference is that the algorithm uses a different tokenizer. This causes + only small discrepancies in the final score, but there's no need to + introduce extra noise (in the score). + + :param _type_ Rouge: _description_ + """ + + def __init__(self): + self.inner_scorer = grindstone_metrics._build_rouge_function(["rougeL"]) + + def calc_score(self, candidate, refs): + # ensure input has the correct shape + if type(candidate) == list: + candidate = candidate[0] + + if type(refs) == str: + refs = [refs] + + result = self.inner_scorer(pred=candidate, references=refs) + return result["rougeL"] + + +class CustomCOCOEvalCap(COCOEvalCap): + """ + This is a reimplementation of pycocoevalcap.eval.COCOEvalCap which allows us + to specify which metrics we want to run. + """ + + allowed_metrics = {"bleu", "meteor", "rouge-L", "rouge", "cider", "spice"} + + def __init__(self, coco, cocoRes, metrics: Set[str] = allowed_metrics): + if "rouge" in metrics: + # rouge is a supergroup that contains the following + metrics.add("rouge-L") + + assert len(metrics.difference(CustomCOCOEvalCap.allowed_metrics)) == 0, ( + "The provided list of metrics to CustomCOCOEvalCap is invalid. 
" + f"Allowed metrics are: {CustomCOCOEvalCap.allowed_metrics}" + ) + + super().__init__(coco, cocoRes) + self.metrics_list = metrics + + def evaluate(self): + imgIds = self.params['image_id'] + # imgIds = self.coco.getImgIds() + gts = {} + res = {} + for imgId in imgIds: + gts[imgId] = self.coco.imgToAnns[imgId] + res[imgId] = self.cocoRes.imgToAnns[imgId] + + # ================================================= + # Set up scorers + # ================================================= + print('tokenization...') + tokenizer = PTBTokenizer() + gts = tokenizer.tokenize(gts) + res = tokenizer.tokenize(res) + + # ================================================= + # Set up scorers + # ================================================= + print('setting up scorers...') + + # maps the name of the metric we want to calculate to a lazy function + # which returns the instance of the scorer so that we can avoid + # initializing scorers that we won't end up using. + metric_name_to_scorers = { + "bleu": (lambda: Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), + "meteor": (lambda: Meteor(), "METEOR"), + "rouge-L": (lambda: CustomCOCORouge(), "ROUGE_L"), + "cider": (lambda: Cider(), "CIDEr"), + "spice": (lambda: Spice(), "SPICE") + } + + # ================================================= + # Compute scores + # ================================================= + for metric in self.metrics_list: + if metric in metric_name_to_scorers: + get_scorer, method = metric_name_to_scorers[metric] + scorer = get_scorer() + + print('computing %s score...' % (scorer.method())) + score, scores = scorer.compute_score(gts, res) + if type(method) == list: + for sc, scs, m in zip(score, scores, method): + self.setEval(sc, m) + self.setImgToEvalImgs(scs, gts.keys(), m) + print("%s: %0.3f" % (m, sc)) + else: + self.setEval(score, method) + self.setImgToEvalImgs(scores, gts.keys(), method) + print("%s: %0.3f" % (method, score)) + + self.setEvalImgs() + + +def create_coco_refs(data_ref): + """Create MS-COCO human references JSON.""" + out = {'info': {}, 'licenses': [], 'images': [], 'type': 'captions', 'annotations': []} + ref_id = 0 + for inst_id, refs in enumerate(data_ref): + out['images'].append({'id': 'inst-%d' % inst_id}) + for ref in refs: + out['annotations'].append({'image_id': 'inst-%d' % inst_id, 'id': ref_id, 'caption': ref}) + ref_id += 1 + return out + + +def create_coco_sys(data_sys): + """Create MS-COCO system outputs JSON.""" + out = [] + for inst_id, inst in enumerate(data_sys): + out.append({'image_id': 'inst-%d' % inst_id, 'caption': inst}) + return out + + +def run_pymteval(data_ref, data_sys, metrics: Set[str]): + """Run document-level BLEU and NIST in their Python implementation (should give the + same results as Perl).""" + print('Running Py-MTEval metrics...', file=sys.stderr) + + bleu = None + if "bleu" in metrics: + bleu = BLEUScore() + + nist = None + if "nist" in metrics: + nist = NISTScore() + + # collect statistics + for sents_ref, sent_sys in zip(data_ref, data_sys): + if bleu: + bleu.append(sent_sys, sents_ref) + + if nist: + nist.append(sent_sys, sents_ref) + + result = {} + + if bleu: + result["BLEU"] = bleu.score() + + if nist: + result["NIST"] = nist.score() + + return result + + +def run_coco_eval(data_ref, data_sys, metrics: Set[str]): + """Run the COCO evaluator, return the resulting evaluation object (contains both + system- and segment-level scores.""" + # convert references and system outputs to MS-COCO format in-memory + coco_ref = create_coco_refs(data_ref) + coco_sys = 
create_coco_sys(data_sys) + + print('Running MS-COCO evaluator...', file=sys.stderr) + coco = COCO() + coco.dataset = coco_ref + coco.createIndex() + + coco_res = coco.loadRes(resFile=coco_sys) + + coco_metrics = metrics.intersection(CustomCOCOEvalCap.allowed_metrics) + coco_eval = CustomCOCOEvalCap(coco, coco_res, coco_metrics) + coco_eval.evaluate() + + return coco_eval + + +def evaluate(results: List[ResultsForPrompt], metrics: Set[str]): + """Main procedure, running the MS-COCO & MTEval evaluators on the loaded data.""" + + # generated[i], and targets[i] all correspond to the i'th sample + # in the data + all_generated: List[str] = [] + all_targets: List[List[str]] = [] + + for res in results: + all_generated.append(res["generated_text"]) + all_targets.append(res["target_texts"]) + + assert len(all_generated) == len(all_targets) + + # run the MS-COCO evaluator + coco_eval = run_coco_eval(all_targets, all_generated, metrics) + scores = {metric: score for metric, score in list(coco_eval.eval.items())} + + # run MT-Eval + mteval_scores = run_pymteval(all_targets, all_generated, metrics) + + scores.update(mteval_scores) + + return scores + + +def get_prediction_results(all_predictions_file_glob: str) -> List[ResultsForPrompt]: + res = defaultdict(lambda: ResultsForPrompt( + prompt="", + generated_text="", + target_texts=[], + )) + + all_prediction_files = sorted(glob(all_predictions_file_glob, recursive=True)) + + for pred_file_path in all_prediction_files: + for line in open(pred_file_path, "r"): + data = json.loads(line) + + prompt: str = data["prompt_text"] + target: str = data["target_text"] + + # TODO here we're taking the BEST from the beam. But we could also + # take all results in the beam. Any ideas? + generated: str = data["beam_results"][0]['generated_text'] + + res_for_this_prompt = res[prompt] + res_for_this_prompt["prompt"] = prompt + res_for_this_prompt["target_texts"].append(target) + res_for_this_prompt["generated_text"] = generated + + return list(res.values()) + + +def cli_main(all_predictions_file_glob: str, output_path: str = "_results/coco_evaluation_results.json"): + """ + E2E Challenge evaluation -- MS-COCO & MTEval wrapper + + :param str all_predictions_file_glob: Glob string to be used to find all the + predictions we want to evaluate + :param str output_path: Output path to where we want to save the computed + metrics, defaults to "_results/coco_evaluation_results.json" + """ + parsed_results = get_prediction_results(all_predictions_file_glob) + + print(f"Running evaluation on {len(parsed_results)} items") + + scores = evaluate(results=parsed_results, ) + + # print out the results + with open(output_path, 'w+') as f: + json.dump(scores, f, indent=2) + + print() + pp(scores) + print() + + +if __name__ == '__main__': + fire.Fire(cli_main) diff --git a/metaseq/logging/__init__.py b/metaseq/logging/__init__.py index e69de29bb..6bac2e767 100644 --- a/metaseq/logging/__init__.py +++ b/metaseq/logging/__init__.py @@ -0,0 +1,21 @@ +import logging +import os +import sys +import time +from typing import List + + +def get_logger(name: str, logger_blocklist: List[str] = []): + for module in logger_blocklist: + logging.getLogger(module).setLevel(logging.WARNING) + + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, + ) + logging.Formatter.converter = time.gmtime # Enforce UTC timestamps + logger = logging.getLogger(name) + + return logger 
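A minimal usage sketch for the new logging helper (the blocklisted module names are arbitrary examples):

.. code-block:: python

    from metaseq.logging import get_logger

    logger = get_logger(__name__, logger_blocklist=["urllib3", "filelock"])
    logger.info("logging configured with UTC timestamps")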
diff --git a/metaseq/utils.py b/metaseq/utils.py index 87baadc6b..699ec5381 100644 --- a/metaseq/utils.py +++ b/metaseq/utils.py @@ -9,6 +9,7 @@ import math import os import random +import argparse import re import sys import warnings @@ -21,6 +22,19 @@ import torch.nn.functional as F from metaseq.distributed import utils as distributed_utils +from omegaconf import DictConfig, OmegaConf + +try: + from megatron.mpu import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + ) + from megatron.mpu.utils import VocabUtility + + has_megatron_submodule = True +except (ImportError, ModuleNotFoundError): + has_megatron_submodule = False try: from amp_C import multi_tensor_l2norm @@ -80,6 +94,23 @@ def _move_to_cpu(tensor): return apply_to_sample(_move_to_cpu, sample) +def load_align_dict(replace_unk): + if replace_unk is None: + align_dict = None + elif isinstance(replace_unk, str) and len(replace_unk) > 0: + # Load alignment dictionary for unknown word replacement if it was passed as an argument. + align_dict = {} + with open(replace_unk, "r") as f: + for line in f: + cols = line.split() + align_dict[cols[0]] = cols[1] + else: + # No alignment dictionary provided but we still want to perform unknown word replacement by copying the + # original source word. + align_dict = {} + return align_dict + + def make_positions(tensor, padding_idx: int): """Replace non-padding symbols with their position numbers. @@ -93,6 +124,10 @@ def make_positions(tensor, padding_idx: int): return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx +def strip_token(tensor, token): + return tensor[tensor.ne(token)] + + def item(tensor): if hasattr(tensor, "item"): return tensor.item() @@ -348,6 +383,67 @@ def get_perplexity(loss, round=2, base=2): return float("inf") +def vocab_parallel_token_accuracy( + vocab_parallel_logits: torch.Tensor, + target: torch.Tensor, + ignored_token_ids: Optional[List[int]] = None, + ignored_tokens_mask: Optional[torch.Tensor] = None +): + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()) + # Subtract the maximum value. This is done for numerical stability. + # This helps prevent exp() from overflowing while computing the softmax. + vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indecies + get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + # If given explicit token ids to ignore, add them to the mask + if ignored_token_ids is not None: + for token_id in ignored_token_ids: + target_mask |= target == token_id + + # If given mask tensor, add it to the mask + if ignored_tokens_mask is not None: + target_mask |= ignored_tokens_mask + + # Shift the target ids to be in the range [0, vocab-size - vocab_start_index). 
+ masked_target = target.clone() - vocab_start_index + + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + target_mask_1d = target_mask.view(-1) + + # Compute accuracy across vocab parallel logits. + # Find the token index with the maximum logit value and compare it to the target token index. + # The result has 1s where the prediction was correct and 0s otherwise. Mask out the invalid tokens. + + # Note that max logit index will be in range [0, partition-vocab-size) while target token index + # will be in range [0, vocab-size - vocab_start_index). + # All reduce SUM is critical to this implementation since the right index could be on a + # different model parallel partition. By SUM, we get 1 if any partition has the right index. + # This is also why sum() is computed after all reduce step. + accuracy = (torch.argmax(logits_2d, dim=-1) == masked_target_1d).float() + accuracy.masked_fill_(target_mask_1d, 0.0) + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(accuracy, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + accuracy_correct_tokens = accuracy.sum() + n_unmasked_tokens = (~target_mask_1d).sum() + torch.distributed.all_reduce(n_unmasked_tokens, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group()) + + return accuracy_correct_tokens / n_unmasked_tokens + + def has_parameters(module): try: next(module.parameters()) @@ -580,3 +676,30 @@ def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): ["{:.6f}".format(p) for p in src_probs.tolist()] for src_probs in attn_valid ] return alignment + + +def flatten_config(cfg: DictConfig): + config = OmegaConf.to_container(cfg) + # remove any legacy Namespaces and replace with a single "args" + namespace = None + for k, v in list(config.items()): + if isinstance(v, argparse.Namespace): + namespace = v + del config[k] + if namespace is not None: + config["args"] = vars(namespace) + return config + + +# a function to print current rank and then message +def print_with_rank(*args): + rank = get_tensor_model_parallel_rank() + print(f"MRank {rank}, ", end="") + print(*args) + +def print_tensor_with_rank(tensor): + # get the indices of the non-zero values + indices = tensor.nonzero() + # print the indices and values of the non-zero values + for index in indices: + print_with_rank(f"I: {index.tolist()}, V: {tensor[index[0], index[1]]}")
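As a rough, single-device reference for what `vocab_parallel_token_accuracy` computes when the vocabulary is not sharded across tensor-model-parallel ranks (so no all-reduce is needed), assuming unsharded logits of shape [*, vocab_size]:

.. code-block:: python

    import torch

    def token_accuracy_reference(logits, target, ignored_token_ids=None):
        # logits: [*, vocab_size], target: [*]
        ignore_mask = torch.zeros_like(target, dtype=torch.bool)
        for token_id in (ignored_token_ids or []):
            ignore_mask |= target == token_id
        # 1.0 where the argmax prediction matches the target, 0.0 elsewhere
        correct = (logits.argmax(dim=-1) == target).float()
        correct.masked_fill_(ignore_mask, 0.0)
        # fraction of non-ignored tokens predicted correctly
        return correct.sum() / (~ignore_mask).sum()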