diff --git a/setup.py b/setup.py index b43756c..06bb471 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.2.3', + version = '0.2.5', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', diff --git a/voicebox_pytorch/trainer.py b/voicebox_pytorch/trainer.py index 5cc4172..5f1a571 100644 --- a/voicebox_pytorch/trainer.py +++ b/voicebox_pytorch/trainer.py @@ -149,14 +149,12 @@ def __init__( self.cfm_wrapper, self.optim, self.scheduler, - self.dl, - self.valid_dl + self.dl ) = self.accelerator.prepare( self.cfm_wrapper, self.optim, self.scheduler, - self.dl, - self.valid_dl + self.dl ) # dataloader iterators @@ -281,11 +279,13 @@ def train_step(self): if self.is_main and not (steps % self.save_results_every): wave, = next(self.valid_dl_iter) - + unwrapped_model = self.accelerator.unwrap_model(self.cfm_wrapper) + with torch.inference_mode(): - self.cfm_wrapper.eval() + unwrapped_model.eval() - valid_loss = self.cfm_wrapper(wave) + wave = wave.to(unwrapped_model.device) + valid_loss = unwrapped_model(wave) self.print(f'{steps}: valid loss {valid_loss:0.3f}') self.accelerator.log({"valid_loss": valid_loss}, step=steps)