diff --git a/setup.py b/setup.py index 59f5608..f747a6d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.4.5', + version = '0.4.6', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', diff --git a/voicebox_pytorch/trainer.py b/voicebox_pytorch/trainer.py index 01a209b..a4e5d26 100644 --- a/voicebox_pytorch/trainer.py +++ b/voicebox_pytorch/trainer.py @@ -1,6 +1,8 @@ -from pathlib import Path import re +from pathlib import Path from shutil import rmtree +from functools import partial +from contextlib import nullcontext from beartype import beartype @@ -256,12 +258,16 @@ def train_step(self): # training step - for _ in range(self.grad_accum_every): + for grad_accum_step in range(self.grad_accum_every): + is_last = grad_accum_step == (self.grad_accum_every - 1) + context = partial(self.accelerator.no_sync, self.cfm_wrapper) if not is_last else nullcontext + wave, = next(self.dl_iter) - loss = self.cfm_wrapper(wave) + with self.accelerator.autocast(), context(): + loss = self.cfm_wrapper(wave) - self.accelerator.backward(loss / self.grad_accum_every) + self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, {'loss': loss.item() / self.grad_accum_every})