From 4254e7d8c964d727d7b4a47965b09b6dd0719d20 Mon Sep 17 00:00:00 2001
From: Aakash Apoorv
Date: Wed, 12 Jun 2024 23:51:14 +0000
Subject: [PATCH] Refactor validation loss calculation to accumulate before normalizing

---
 train_gpt2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_gpt2.py b/train_gpt2.py
index 9327989..a92431c 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -389,8 +389,8 @@ def get_lr(it):
                 x, y = x.to(device), y.to(device)
                 with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                     logits, loss = model(x, y)
-                loss = loss / val_loss_steps
                 val_loss_accum += loss.detach()
+            val_loss_accum /= val_loss_steps
         if ddp:
             dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
         if master_process: