Skip to content

Commit

Permalink
they actually used midpoint with step size 0.0625
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 7, 2023
1 parent 88b9a52 commit 4938046
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.14',
version = '0.0.15',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand Down
10 changes: 6 additions & 4 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import Module
import torch.nn.functional as F

from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint

from beartype import beartype
from beartype.typing import Tuple
Expand Down Expand Up @@ -556,7 +556,8 @@ def __init__(
sigma = 0.,
ode_atol = 1e-5,
ode_rtol = 1e-5,
ode_method = 'dopri5',
ode_method = 'midpoint',
ode_step_size = 0.0625,
cond_drop_prob = 0.
):
super().__init__()
Expand All @@ -569,7 +570,8 @@ def __init__(
self.odeint_kwargs = dict(
atol = ode_atol,
rtol = ode_rtol,
method= ode_method
method = ode_method,
options = dict(step_size = ode_step_size)
)

@property
Expand Down Expand Up @@ -604,7 +606,7 @@ def fn(t, x):

print('sampling')

trajectory = odeint(fn, y0, t, adjoint_params=(), **self.odeint_kwargs)
trajectory = odeint(fn, y0, t, **self.odeint_kwargs)

sampled = trajectory[-1] # last in trajectory
return sampled
Expand Down

0 comments on commit 4938046

Please sign in to comment.