diff --git a/setup.py b/setup.py index cb22c4e..1dac4ca 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.0.14', + version = '0.0.15', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index dcf950c..6288d97 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -6,7 +6,7 @@ from torch.nn import Module import torch.nn.functional as F -from torchdiffeq import odeint_adjoint as odeint +from torchdiffeq import odeint from beartype import beartype from beartype.typing import Tuple @@ -556,7 +556,8 @@ def __init__( sigma = 0., ode_atol = 1e-5, ode_rtol = 1e-5, - ode_method = 'dopri5', + ode_method = 'midpoint', + ode_step_size = 0.0625, cond_drop_prob = 0. ): super().__init__() @@ -569,7 +570,8 @@ def __init__( self.odeint_kwargs = dict( atol = ode_atol, rtol = ode_rtol, - method= ode_method + method = ode_method, + options = dict(step_size = ode_step_size) ) @property @@ -604,7 +606,7 @@ def fn(t, x): print('sampling') - trajectory = odeint(fn, y0, t, adjoint_params=(), **self.odeint_kwargs) + trajectory = odeint(fn, y0, t, **self.odeint_kwargs) sampled = trajectory[-1] # last in trajectory return sampled