Skip to content

Commit

Permalink
add ability to sample with torchode, thanks to @b-chiang for the tip …
Browse files Browse the repository at this point in the history
…on how to make it work
  • Loading branch information
lucidrains committed Aug 7, 2023
1 parent 4938046 commit 70f2fcf
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ sampled = cfm_wrapper.sample(
- [x] basic loss
- [x] get neural ode working with torchdyn
- [x] get basic mask generation logic with the p_drop of 0.2-0.3 for ICL
- [x] just use torchdiffeq, nothing else is mature. torchode looks promising but cannot support ndim > 2
- [x] take care of p_drop, which differs between voicebox and duration model
- [x] support torchdiffeq and torchode

- [ ] consider switching to adaptive rmsnorm for time conditioning
- [ ] integrate with either hifi-gan or soundstream / encodec
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.15',
version = '0.0.16',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand All @@ -20,7 +20,8 @@
'beartype',
'einops>=0.6.1',
'torch>=2.0',
'torchdiffeq'
'torchdiffeq',
'torchode'
],
classifiers=[
'Development Status :: 4 - Beta',
Expand Down
54 changes: 46 additions & 8 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from torch.nn import Module
import torch.nn.functional as F

import torchode as to
from torchode.single_step_methods import SingleStepMethod

from torchdiffeq import odeint

from beartype import beartype
Expand Down Expand Up @@ -556,8 +559,10 @@ def __init__(
sigma = 0.,
ode_atol = 1e-5,
ode_rtol = 1e-5,
ode_method = 'midpoint',
ode_step_size = 0.0625,
use_torchode = False,
torchdiffeq_ode_method = 'midpoint', # use midpoint for torchdiffeq, as in paper
torchode_method_klass: SingleStepMethod = to.Tsit5, # use tsit5 for torchode, as torchode does not have midpoint (recommended by Bryan @b-chiang)
cond_drop_prob = 0.
):
super().__init__()
Expand All @@ -567,10 +572,13 @@ def __init__(

self.cond_drop_prob = cond_drop_prob

self.use_torchode = use_torchode
self.torchode_method_klass = torchode_method_klass

self.odeint_kwargs = dict(
atol = ode_atol,
rtol = ode_rtol,
method = ode_method,
method = torchdiffeq_ode_method,
options = dict(step_size = ode_step_size)
)

Expand All @@ -585,30 +593,60 @@ def sample(
phoneme_ids,
cond,
mask = None,
steps = 2,
steps = 3,
cond_scale = 1.
):
shape = cond.shape
batch = shape[0]

self.voicebox.eval()

def fn(t, x):
return self.voicebox.forward_with_cond_scale(
x = x.reshape(*shape)

out = self.voicebox.forward_with_cond_scale(
x,
times = t,
phoneme_ids = phoneme_ids,
cond = cond,
cond_scale = cond_scale
)

batch = cond.shape[0]
return rearrange(out, 'b ... -> b (...)')

y0 = torch.randn_like(cond)
t = torch.linspace(0, 1, steps, device = self.device)

print('sampling')
if not self.use_torchode:
print('sampling with torchdiffeq')

trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = trajectory[-1]
else:
print('sampling with torchode')

term = to.ODETerm(fn)
step_method = self.torchode_method_klass(term = term)

step_size_controller = to.IntegralController(
atol = self.odeint_kwargs['atol'],
rtol = self.odeint_kwargs['rtol'],
term = term
)

solver = to.AutoDiffAdjoint(step_method, step_size_controller)
jit_solver = torch.compile(solver)

t = repeat(t, 'n -> b n', b = batch)
y0 = rearrange(y0, 'b ... -> b (...)')

init_value = to.InitialValueProblem(y0 = y0, t_eval = t)

sol = jit_solver.solve(init_value)

trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
sampled = sol.ys[:, -1]
sampled = sampled.reshape(*shape)

sampled = trajectory[-1] # last in trajectory
return sampled

def forward(
Expand Down

0 comments on commit 70f2fcf

Please sign in to comment.