From 62e2890eaa3816e19e272c58c95efe896b073505 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 27 Nov 2023 08:36:43 -0800
Subject: [PATCH] save on unnecessary gradient synchronizations

---
 setup.py                    |  2 +-
 voicebox_pytorch/trainer.py | 14 ++++++++++----
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/setup.py b/setup.py
index 59f5608..f747a6d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.5',
+  version = '0.4.6',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
diff --git a/voicebox_pytorch/trainer.py b/voicebox_pytorch/trainer.py
index 01a209b..a4e5d26 100644
--- a/voicebox_pytorch/trainer.py
+++ b/voicebox_pytorch/trainer.py
@@ -1,6 +1,8 @@
-from pathlib import Path
 import re
+from pathlib import Path
 from shutil import rmtree
+from functools import partial
+from contextlib import nullcontext
 
 from beartype import beartype
 
@@ -256,12 +258,16 @@ def train_step(self):
 
         # training step
 
-        for _ in range(self.grad_accum_every):
+        for grad_accum_step in range(self.grad_accum_every):
+            is_last = grad_accum_step == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.cfm_wrapper) if not is_last else nullcontext
+
             wave, = next(self.dl_iter)
 
-            loss = self.cfm_wrapper(wave)
+            with self.accelerator.autocast(), context():
+                loss = self.cfm_wrapper(wave)
 
-            self.accelerator.backward(loss / self.grad_accum_every)
+                self.accelerator.backward(loss / self.grad_accum_every)
 
         accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
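
The change above wraps every gradient-accumulation step except the last in Accelerate's no_sync context, so DistributedDataParallel defers its gradient all-reduce until the final backward of the accumulation window; under multi-GPU training this removes grad_accum_every - 1 redundant synchronizations per optimizer step. What follows is a minimal standalone sketch of that pattern, assuming a toy model, optimizer, and dataloader; the names and hyperparameters are illustrative placeholders, not taken from the repository.

# sketch of the no_sync gradient-accumulation pattern applied in the patch;
# the model, data, and hyperparameters are placeholders, not repo code
from functools import partial
from contextlib import nullcontext

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

grad_accum_every = 4

accelerator = Accelerator()

model = nn.Linear(16, 1)  # stand-in for the cfm_wrapper module
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
dataloader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 1)), batch_size = 4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
dl_iter = iter(dataloader)

for grad_accum_step in range(grad_accum_every):
    is_last = grad_accum_step == (grad_accum_every - 1)

    # skip the DDP all-reduce on all but the last accumulation step
    context = partial(accelerator.no_sync, model) if not is_last else nullcontext

    x, y = next(dl_iter)

    with accelerator.autocast(), context():
        loss = nn.functional.mse_loss(model(x), y)

        # backward must run inside no_sync for the skipped synchronization to apply
        accelerator.backward(loss / grad_accum_every)

optimizer.step()
optimizer.zero_grad()

Using nullcontext on the final step lets the all-reduce happen, so the optimizer steps on fully synchronized, averaged gradients; on a single device no_sync is effectively a no-op and the loop behaves like plain gradient accumulation.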