Skip to content

Commit

Permalink
On further discussion with Bryan, the equations in the paper seem right.…
Browse files (browse the repository at this point in the history)
… will ignore torchcfm
  • Loading branch information
lucidrains committed Aug 7, 2023
1 parent 8d26c7b commit 88b9a52
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.12',
version = '0.0.14',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand Down
13 changes: 4 additions & 9 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ def fn(t, x):
y0 = torch.randn_like(cond)
t = torch.linspace(0, 1, steps, device = self.device)

print('sampling')

trajectory = odeint(fn, y0, t, adjoint_params=(), **self.odeint_kwargs)

sampled = trajectory[-1] # last in trajectory
Expand All @@ -613,31 +615,24 @@ def forward(
*,
phoneme_ids,
cond,
mask = None,
use_torchcfm_impl = False
mask = None
):
"""
following eq (5) (6) in https://arxiv.org/pdf/2306.15687.pdf
using https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/conditional_flow_matching.py as reference
"""

batch, seq_len, dtype, σ = *x1.shape[:2], x1.dtype, self.sigma

x0 = torch.randn_like(x1)

# a tiny difference between the paper and Alex Tong's torchcfm implementation, is that he uses a different sample gaussian noise for epsilon for calculating w
# not sure what is correct

to_eps = identity if not use_torchcfm_impl else torch.randn_like

# random times

times = torch.rand((batch,), dtype = dtype, device = self.device)
t = rearrange(times, 'b -> b 1 1')

# sample xt (w in the paper)

w = (1 - (1 - σ) * t) * to_eps(x0) + t * x1
w = (1 - (1 - σ) * t) * x0 + t * x1

flow = x1 - (1 - σ) * x0

Expand Down

0 comments on commit 88b9a52

Please sign in to comment.