From 69a467dbeb808e811bc44c3ac11be549e9ce6a10 Mon Sep 17 00:00:00 2001
From: shankarg87
Date: Thu, 25 Jul 2024 18:21:04 -0700
Subject: [PATCH] Fix batching code

* Process large number of rows in batches
* Hardcoding some parameters
---
 src/aihero/research/finetuning/callback.py |  5 +++++
 src/aihero/research/finetuning/infer.py    | 26 +++++++++++++++++---------
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/src/aihero/research/finetuning/callback.py b/src/aihero/research/finetuning/callback.py
index 0ea5437..1316671 100644
--- a/src/aihero/research/finetuning/callback.py
+++ b/src/aihero/research/finetuning/callback.py
@@ -36,6 +36,7 @@ def __init__(
             run_tests_str=run_tests_str,
             run_metrics_str=run_metrics_str,
             max_new_tokens=max_new_tokens,
+            batch_size=trainer.args.per_device_eval_batch_size,
         )
 
         # Sample a few rows from the test split to generate a table of predictions
@@ -56,7 +57,9 @@ def initialize(self: "LLMSampleCB") -> None:
         """Generate initial predictions for the sample split and log them to WANDB."""
         self._wandb.init()
 
+        self.batch_inference.model.eval()
         _, (records_table, metrics) = self.batch_inference.run_initial_predictions(self.sample_split)
+        self.batch_inference.model.train()
 
         # Log the table of sample predictions to W&B
         self._wandb.log({"sample_predictions": records_table})
@@ -69,8 +72,10 @@ def on_evaluate(self, args: Any, state: Any, control: Any, **kwargs: dict[str, A
         """Log the sample predictions and metrics to WANDB on eval callback."""
         super().on_evaluate(args, state, control, **kwargs)
 
+        self.batch_inference.model.eval()
         # Generate the table of sample predictions
         _, (records_table, metrics) = self.batch_inference.infer(self.sample_split)
+        self.batch_inference.model.train()
 
         # Log the table of sample predictions to W&B
         self._wandb.log({"sample_predictions": records_table})
diff --git a/src/aihero/research/finetuning/infer.py b/src/aihero/research/finetuning/infer.py
index 98e179f..537fd6a 100644
--- a/src/aihero/research/finetuning/infer.py
+++ b/src/aihero/research/finetuning/infer.py
@@ -45,7 +45,7 @@ def __init__(self, batch_inference_job: BatchInferenceJob):
             size = 0
             randomize = False
 
-        print(self.dataset_dict)
+        self.model.eval()
         self.batch_inference_split = self.dataset_dict["batch_inference"]
         if size:
             if randomize:
@@ -58,6 +58,7 @@ def __init__(self, batch_inference_job: BatchInferenceJob):
             run_tests_str=run_tests_str,
             run_metrics_str=run_metrics_str,
             max_new_tokens=self.batch_inference_job.generator.max_seq_length or MAX_NEW_TOKENS,
+            batch_size=8,  # Needs to be added to eval arguments
         )
 
     def load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
@@ -249,6 +250,7 @@ def __init__(
         """Initialize the batch inference class."""
         self.gen_config = GenerationConfig.from_pretrained(model.name_or_path, max_new_tokens=max_new_tokens)
         self.model = model
+        self.batch_size = batch_size
         self.tokenizer = tokenizer
         self.task = task
         self.run_tests_str = run_tests_str
@@ -265,15 +267,21 @@ def __init__(
 
     def generate(self, prompts: List[str]) -> Any:
         """Generate a completion from a prompt."""
-        tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"].cuda()
+        tokens = self.tokenizer(prompts, return_tensors="pt", padding=True)
+        outputs = []
         with torch.inference_mode():
-            output = self.model.generate(
-                inputs=tokenized_prompts,
-                generation_config=self.gen_config,
-                pad_token_id=self.tokenizer.eos_token_id,
-            )
-        decoded_outputs = self.tokenizer.batch_decode(output, skip_special_tokens=True)
-        return [output[len(tokens) :] for tokens, output in zip(tokenized_prompts, decoded_outputs)]
+            for i in tqdm(range(0, len(prompts), self.batch_size), leave=False):
+                output = self.model.generate(
+                    inputs=tokens["input_ids"][i : i + self.batch_size].cuda(),
+                    attention_mask=tokens["attention_mask"][i : i + self.batch_size].cuda(),
+                    generation_config=self.gen_config,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    repetition_penalty=1.2,  # TODO: Add to generation config?
+                    num_return_sequences=1,
+                )
+                outputs.append(output)
+        decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        return [output[len(tokens) :] for tokens, output in zip(tokens["input_ids"].cuda(), decoded_outputs)]
 
     def run_initial_predictions(self, rows: Dataset) -> Tuple[list[dict[str, Any]], Tuple[Table, dict[str, Any]]]:
         """Generate initial predictions for the sample split."""
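For reference, below is a minimal, self-contained sketch of the same batching pattern, outside the repository's BatchInference class. The model name ("gpt2"), batch size, max_new_tokens value, left padding, and token-level prompt trimming are illustrative assumptions, not code from this repo; the patch above instead trims prompts by slicing the decoded strings and accumulates per-batch output tensors before decoding.

from typing import List

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


def batched_generate(prompts: List[str], model_name: str = "gpt2", batch_size: int = 8) -> List[str]:
    # Hypothetical helper: generate completions for many prompts in fixed-size batches.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = "left"  # assumption: left padding for a decoder-only model
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    gen_config = GenerationConfig(max_new_tokens=64)  # assumed value
    tokens = tokenizer(prompts, return_tensors="pt", padding=True)
    input_len = tokens["input_ids"].shape[1]  # padded prompt length, shared by all rows

    completions: List[str] = []
    with torch.inference_mode():
        for i in tqdm(range(0, len(prompts), batch_size), leave=False):
            output = model.generate(
                inputs=tokens["input_ids"][i : i + batch_size],
                attention_mask=tokens["attention_mask"][i : i + batch_size],
                generation_config=gen_config,
                pad_token_id=tokenizer.eos_token_id,
            )
            # Drop the prompt tokens and decode only the newly generated ones.
            completions.extend(tokenizer.batch_decode(output[:, input_len:], skip_special_tokens=True))
    return completions

Passing attention_mask per batch, as the patch also does, keeps the model from attending to padding, and left padding keeps each prompt adjacent to its generated continuation so the completion can be cut off at a fixed token offset.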