From 87992caa82add6a468f45ec599545520a37db90d Mon Sep 17 00:00:00 2001 From: itayhubara Date: Thu, 1 Feb 2024 18:28:24 +0200 Subject: [PATCH] fix masking bug --- llm_finetune/scripts/train.py | 2 +- llm_finetune/scripts/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/llm_finetune/scripts/train.py b/llm_finetune/scripts/train.py index 6b81a6a0b..ecf2b8b5a 100644 --- a/llm_finetune/scripts/train.py +++ b/llm_finetune/scripts/train.py @@ -106,7 +106,7 @@ class ScriptArguments: default=6, metadata={"help": "Log every X updates steps."} ) target_eval_loss: float = field( - default=1.19, metadata={"help": "target eval loss - NOT FINAL."} + default=0.92, metadata={"help": "target eval loss - NOT FINAL."} ) output_dir: str = field( default="results", metadata={"help": "Where to store the final model."} diff --git a/llm_finetune/scripts/utils.py b/llm_finetune/scripts/utils.py index 1747eb03b..c376a6c8b 100644 --- a/llm_finetune/scripts/utils.py +++ b/llm_finetune/scripts/utils.py @@ -143,7 +143,8 @@ def group_texts(examples, block_size): k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } - result["labels"] = result["input_ids"].copy() + if 'labels' not in result: + result["labels"] = result["input_ids"].copy() return result