From 756c302e804d4a680b16831170a917a82434469b Mon Sep 17 00:00:00 2001
From: Clive Chan
Date: Tue, 17 Jan 2023 18:50:04 -0800
Subject: [PATCH] Zero-grad more aggressively to save memory

---
 mingpt/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mingpt/trainer.py b/mingpt/trainer.py
index c0d08521..27385eee 100644
--- a/mingpt/trainer.py
+++ b/mingpt/trainer.py
@@ -93,10 +93,10 @@ def run(self):
             logits, self.loss = model(x, y)
 
             # backprop and update the parameters
-            model.zero_grad(set_to_none=True)
             self.loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
             self.optimizer.step()
+            model.zero_grad(set_to_none=True)
 
             self.trigger_callbacks('on_batch_end')
             self.iter_num += 1
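
Note: with set_to_none=True, zero_grad() releases the .grad tensors entirely
instead of filling them with zeros, and PyTorch re-allocates them on the next
backward(). Zeroing right after optimizer.step() (rather than just before
backward()) therefore frees the gradient buffers before the next forward pass,
so peak memory during the forward no longer includes stale gradients. A minimal
sketch of the reordered step, assuming a model whose forward returns
(logits, loss) as in minGPT; the train_step wrapper and its arguments are
illustrative, not part of the patch:

    import torch

    def train_step(model, optimizer, x, y, grad_norm_clip=1.0):
        logits, loss = model(x, y)      # forward pass: no .grad buffers alive
        loss.backward()                 # allocates and populates .grad buffers
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()                # consumes the gradients
        # Free the gradients here, after step(): they are deallocated for the
        # whole of the next iteration's data fetch and forward pass.
        model.zero_grad(set_to_none=True)
        return loss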