Skip to content

Commit

Permalink
differentiate between attention mask and conditioning mask - conditio…
Browse files Browse the repository at this point in the history
…n should be auto-generated by binary conditioning mask during training, as in section 3.2
  • Loading branch information
lucidrains committed Sep 26, 2023
1 parent 9499426 commit 91654c3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 37 deletions.
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ from voicebox_pytorch import (
)

wav2vec = HubertWithKmeans(
checkpoint_path = './path/to/hubert/checkpoint.pt',
kmeans_path = './path/to/hubert/kmeans.bin'
checkpoint_path = './hubert_base_ls960.pt',
kmeans_path = './hubert_base_ls960_L9_km500.bin'
)

text_to_semantic = TextToSemantic(
Expand Down Expand Up @@ -71,12 +71,10 @@ cfm_wrapper = ConditionalFlowMatcherWrapper(
# mock data

audio = torch.randn(2, 12000)
cond = torch.randn(2, 12000)

# train

loss = cfm_wrapper(audio, cond = cond)

loss = cfm_wrapper(audio)
loss.backward()

# after much training
Expand All @@ -86,6 +84,7 @@ texts = [
'she sells sea shells by the seashore'
]

cond = torch.randn(2, 12000)
sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1, <audio length>)
```

Expand Down Expand Up @@ -113,18 +112,19 @@ cfm_wrapper = ConditionalFlowMatcherWrapper(

# mock data

cond = torch.randn(2, 1024, 512)
x = torch.randn(2, 1024, 512)

# train

loss = cfm_wrapper(x, cond = cond)
loss = cfm_wrapper(x)

loss.backward()

# after much training

sampled = cfm_wrapper.sample(cond = cond)
cond = torch.randn(2, 1024, 512)

sampled = cfm_wrapper.sample(cond = cond) # (2, 1024, 512)
```

## Todo
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand Down
79 changes: 51 additions & 28 deletions voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def __init__(
def forward(
self,
x,
mask = None,
adaptive_rmsnorm_cond = None
):
skip_connects = []
Expand All @@ -324,7 +325,7 @@ def forward(
x = skip_combiner(x)

attn_input = attn_prenorm(x, **rmsnorm_kwargs)
x = attn(attn_input, rotary_emb = rotary_emb) + x
x = attn(attn_input, mask = mask, rotary_emb = rotary_emb) + x

ff_input = ff_prenorm(x, **rmsnorm_kwargs)
x = ff(ff_input) + x
Expand Down Expand Up @@ -749,19 +750,27 @@ def forward(
self,
x,
*,
cond,
times,
cond_token_ids,
self_attn_mask = None,
cond_drop_prob = 0.1,
target = None,
mask = None
cond = None,
cond_mask = None
):
batch, seq_len, cond_dim = cond.shape
assert cond_dim == x.shape[-1]

# project in, in case codebook dim is not equal to model dimensions

x, cond = map(self.proj_in, (x, cond))
x = self.proj_in(x)

if exists(cond):
cond = self.proj_in(cond)

cond = default(cond, x)

# shapes

batch, seq_len, cond_dim = cond.shape
assert cond_dim == x.shape[-1]

# auto manage shape of times, for odeint times

Expand All @@ -771,16 +780,25 @@ def forward(
if times.ndim == 1 and times.shape[0] == 1:
times = repeat(times, '1 -> b', b = cond.shape[0])

# construct mask if not given
# construct conditioning mask if not given

if not exists(mask):
if coin_flip():
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
mask = mask_from_frac_lengths(seq_len, frac_lengths)
else:
mask = prob_mask_like((batch, seq_len), self.p_drop_prob, self.device)
if self.training:
if not exists(cond_mask):
if coin_flip():
frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
cond_mask = mask_from_frac_lengths(seq_len, frac_lengths)
else:
cond_mask = prob_mask_like((batch, seq_len), self.p_drop_prob, self.device)
else:
if not exists(cond_mask):
cond_mask = torch.ones((batch, seq_len), device = cond.device, dtype = torch.bool)

cond = cond * rearrange(~mask, '... -> ... 1')
cond_mask_with_pad_dim = rearrange(cond_mask, '... -> ... 1')

# as described in section 3.2

x = x * cond_mask_with_pad_dim
cond = cond * ~cond_mask_with_pad_dim

# classifier free guidance

Expand Down Expand Up @@ -826,7 +844,11 @@ def forward(

# attend

x = self.transformer(x, adaptive_rmsnorm_cond = time_emb)
x = self.transformer(
x,
mask = self_attn_mask,
adaptive_rmsnorm_cond = time_emb
)

x = self.to_pred(x)

Expand All @@ -835,26 +857,26 @@ def forward(
if not exists(target):
return x

if not exists(mask):
if not exists(cond_mask):
return F.mse_loss(x, target)

loss = F.mse_loss(x, target, reduction = 'none')

loss = reduce(loss, 'b n d -> b n', 'mean')
loss = loss.masked_fill(~mask, 0.)
loss = loss.masked_fill(~cond_mask, 0.)

# masked mean

num = reduce(loss, 'b n -> b', 'sum')
den = mask.sum(dim = -1).clamp(min = 1e-5)
den = cond_mask.sum(dim = -1).clamp(min = 1e-5)
loss = num / den

return loss.mean()

# wrapper for the CNF

def is_probably_audio_from_shape(t):
return t.ndim == 2 or (t.ndim == 3 and t.shape[1] == 1)
return exists(t) and (t.ndim == 2 or (t.ndim == 3 and t.shape[1] == 1))

class ConditionalFlowMatcherWrapper(Module):
@beartype
Expand Down Expand Up @@ -906,7 +928,7 @@ def sample(
text_token_ids: Optional[Tensor] = None,
semantic_token_ids = None,
phoneme_ids = None,
mask = None,
cond_mask = None,
steps = 3,
cond_scale = 1.,
decode_to_audio = True
Expand All @@ -930,7 +952,7 @@ def sample(
num_cond_inputs = sum([*map(exists, (texts, text_token_ids, semantic_token_ids, phoneme_ids))])
assert num_cond_inputs <= 1

mask = None
self_attn_mask = None
cond_token_ids = None

if self.condition_on_text:
Expand All @@ -939,7 +961,7 @@ def sample(

if not exists(semantic_token_ids):

semantic_token_ids, mask = self.text_to_semantic.generate(
semantic_token_ids, self_attn_mask = self.text_to_semantic.generate(
source = default(text_token_ids, texts),
source_type = 'text',
target_type = 'speech',
Expand Down Expand Up @@ -978,7 +1000,8 @@ def fn(t, x, *, packed_shape = None):
cond_token_ids = cond_token_ids,
cond = cond,
cond_scale = cond_scale,
mask = mask
cond_mask = cond_mask,
self_attn_mask = self_attn_mask
)

if exists(packed_shape):
Expand Down Expand Up @@ -1030,10 +1053,10 @@ def forward(
self,
x1,
*,
cond,
semantic_token_ids = None,
phoneme_ids = None,
mask = None,
cond = None,
cond_mask = None,
input_sampling_rate = None # will assume it to be the same as the audio encoder decoder sampling rate, if not given. if given, will resample
):
"""
Expand Down Expand Up @@ -1062,7 +1085,7 @@ def forward(
x1 = resample(x1, input_sampling_rate, audio_enc_dec_sampling_rate)
x1 = self.voicebox.audio_enc_dec.encode(x1)

if cond_is_raw_audio:
if exists(cond) and cond_is_raw_audio:
cond = resample(cond, input_sampling_rate, audio_enc_dec_sampling_rate)
cond = self.voicebox.audio_enc_dec.encode(cond)

Expand Down Expand Up @@ -1111,7 +1134,7 @@ def forward(
loss = self.voicebox(
w,
cond = cond,
mask = mask,
cond_mask = cond_mask,
times = times,
target = flow,
cond_token_ids = cond_token_ids,
Expand Down

0 comments on commit 91654c3

Please sign in to comment.