diff --git a/README.md b/README.md index 11d0b73..cf5f059 100644 --- a/README.md +++ b/README.md @@ -28,17 +28,36 @@ $ pip install voicebox-pytorch ## Usage +Training and sampling with `TextToSemantic` module from SpearTTS + ```python import torch from voicebox_pytorch import ( VoiceBox, - ConditionalFlowMatcherWrapper + EncodecVoco, + ConditionalFlowMatcherWrapper, + HubertWithKmeans, + TextToSemantic +) + +wav2vec = HubertWithKmeans( + checkpoint_path = './path/to/hubert/checkpoint.pt', + kmeans_path = './path/to/hubert/kmeans.bin' +) + +text_to_semantic = TextToSemantic( + wav2vec = wav2vec, + dim = 512, + source_depth = 1, + target_depth = 1, + use_openai_tokenizer = True ) model = VoiceBox( dim = 512, - num_phoneme_tokens = 256, + audio_enc_dec = EncodecVoco(), + num_cond_tokens = 500, depth = 2, dim_head = 64, heads = 16 @@ -46,30 +65,66 @@ model = VoiceBox( cfm_wrapper = ConditionalFlowMatcherWrapper( voicebox = model, - use_torchode = False # by default will use torchdiffeq with midpoint as in paper, but can use the promising torchode package too + text_to_semantic = text_to_semantic ) -x = torch.randn(2, 1024, 512) -phonemes = torch.randint(0, 256, (2, 1024)) -mask = torch.randint(0, 2, (2, 1024)).bool() - -loss = cfm_wrapper( - x, - phoneme_ids = phonemes, - cond = x, - mask = mask -) +# mock data + +audio = torch.randn(2, 12000) +cond = torch.randn(2, 12000) + +# train + +loss = cfm_wrapper(audio, cond = cond) loss.backward() -# after much training above... +# after much training + +texts = [ + 'the rain in spain falls mainly in the plains', + 'she sells sea shells by the seashore' +] + +sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1,