From ef3a98906b16a5e577e38f81c3069749e3974fb6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 29 Jan 2025 09:17:44 -0500 Subject: [PATCH] fix iterator overflow when gradient accumulation is 1 --- src/transformers/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 00938a630764..2221b5da9d4b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2487,6 +2487,8 @@ def _inner_training_loop( remainder = args.gradient_accumulation_steps update_step = -1 total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + if args.gradient_accumulation_steps == 1: + total_updates -= 1 for _ in range(total_updates): update_step += 1 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder