Skip to content

Commit

Permalink
Merge pull request #36 from lucasnewman/fix-cond-drop
Browse files Browse the repository at this point in the history
Fix conditional drop for CFG when conditioning on semantic/phoneme tokens
  • Loading branch information
lucidrains authored Nov 22, 2023
2 parents ca8482b + c113619 commit 79f6672
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion voicebox_pytorch/voicebox_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,8 @@ def forward(

# classifier free guidance

cond_ids = cond_token_ids

if cond_drop_prob > 0.:
cond_drop_mask = prob_mask_like(cond.shape[:1], cond_drop_prob, self.device)

Expand All @@ -1043,7 +1045,7 @@ def forward(
cond_emb = None

if self.condition_on_text:
cond_emb = self.to_cond_emb(cond_token_ids)
cond_emb = self.to_cond_emb(cond_ids)

cond_emb_length = cond_emb.shape[-2]
if cond_emb_length != seq_len:
Expand Down

0 comments on commit 79f6672

Please sign in to comment.