
Commit

save on unnecessary gradient synchronizations
lucidrains committed Nov 27, 2023
1 parent 0a59a31 commit 62e2890
Showing 2 changed files with 11 additions and 5 deletions.
setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.4.5',
+  version = '0.4.6',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
voicebox_pytorch/trainer.py (10 additions, 4 deletions)

@@ -1,6 +1,8 @@
-from pathlib import Path
 import re
+from pathlib import Path
 from shutil import rmtree
+from functools import partial
+from contextlib import nullcontext
 
 from beartype import beartype
 
@@ -256,12 +258,16 @@ def train_step(self):
 
         # training step
 
-        for _ in range(self.grad_accum_every):
+        for grad_accum_step in range(self.grad_accum_every):
+            is_last = grad_accum_step == (self.grad_accum_every - 1)
+            context = partial(self.accelerator.no_sync, self.cfm_wrapper) if not is_last else nullcontext
+
             wave, = next(self.dl_iter)
 
-            loss = self.cfm_wrapper(wave)
+            with self.accelerator.autocast(), context():
+                loss = self.cfm_wrapper(wave)
 
-            self.accelerator.backward(loss / self.grad_accum_every)
+                self.accelerator.backward(loss / self.grad_accum_every)
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
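For readers of the diff, the same pattern can be sketched as a standalone loop. This is a minimal sketch assuming Hugging Face accelerate; the model, optimizer, dl_iter, and grad_accum_every names are illustrative stand-ins, not part of the voicebox-pytorch API. On every micro-batch except the last one in the accumulation window, the forward and backward passes run under accelerator.no_sync(model), so DDP defers its gradient all-reduce until the final backward, which is the saving the commit message refers to.

# Minimal sketch of the gradient-accumulation pattern in this commit,
# assuming Hugging Face accelerate; model, optimizer, dl_iter and
# grad_accum_every are hypothetical names used only for illustration.

from functools import partial
from contextlib import nullcontext

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr = 3e-4)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size = 8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
dl_iter = iter(dataloader)

grad_accum_every = 4

for grad_accum_step in range(grad_accum_every):
    is_last = grad_accum_step == (grad_accum_every - 1)

    # skip DDP's gradient all-reduce on every micro-batch except the last one
    context = partial(accelerator.no_sync, model) if not is_last else nullcontext

    inputs, targets = next(dl_iter)

    with accelerator.autocast(), context():
        loss = nn.functional.mse_loss(model(inputs), targets)

        # backward must run inside the no_sync context for the sync to be skipped
        accelerator.backward(loss / grad_accum_every)

# gradients are synchronized only once, on the final micro-batch's backward
optimizer.step()
optimizer.zero_grad()

Because the gradients from the earlier micro-batches stay local and are reduced only once, an accumulation window of N micro-batches costs a single synchronization instead of N; on a single, non-distributed process accelerator.no_sync is effectively a no-op, so the same loop still runs.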
