Skip to content

Commit

Permalink
clearer pack unpack
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 8, 2023
1 parent 13ceb17 commit c958e8f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.18',
version = '0.0.19',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand Down
12 changes: 9 additions & 3 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ def is_odd(n):
def coin_flip():
return random() < 0.5

def pack_one(t, pattern):
return pack([t], pattern)

def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

# tensor helpers

def prob_mask_like(shape, prob, device):
Expand Down Expand Up @@ -603,7 +609,7 @@ def sample(

def fn(t, x, *, packed_shape = None):
if exists(packed_shape):
x, = unpack(x, packed_shape, 'b *')
x = unpack_one(x, packed_shape, 'b *')

out = self.voicebox.forward_with_cond_scale(
x,
Expand All @@ -630,7 +636,7 @@ def fn(t, x, *, packed_shape = None):
print('sampling with torchode')

t = repeat(t, 'n -> b n', b = batch)
y0, packed_shape = pack([y0], 'b *')
y0, packed_shape = pack_one(y0, 'b *')

fn = partial(fn, packed_shape = packed_shape)

Expand All @@ -651,7 +657,7 @@ def fn(t, x, *, packed_shape = None):
sol = jit_solver.solve(init_value)

sampled = sol.ys[:, -1]
sampled = sampled.reshape(*shape)
sampled = unpack_one(sampled, packed_shape, 'b *')

return sampled

Expand Down

0 comments on commit c958e8f

Please sign in to comment.