diff --git a/README.md b/README.md
index 11d0b73..cf5f059 100644
--- a/README.md
+++ b/README.md
@@ -28,17 +28,36 @@ $ pip install voicebox-pytorch
## Usage
+Training and sampling with `TextToSemantic` module from SpearTTS
+
```python
import torch
from voicebox_pytorch import (
VoiceBox,
- ConditionalFlowMatcherWrapper
+ EncodecVoco,
+ ConditionalFlowMatcherWrapper,
+ HubertWithKmeans,
+ TextToSemantic
+)
+
+wav2vec = HubertWithKmeans(
+ checkpoint_path = './path/to/hubert/checkpoint.pt',
+ kmeans_path = './path/to/hubert/kmeans.bin'
+)
+
+text_to_semantic = TextToSemantic(
+ wav2vec = wav2vec,
+ dim = 512,
+ source_depth = 1,
+ target_depth = 1,
+ use_openai_tokenizer = True
)
model = VoiceBox(
dim = 512,
- num_phoneme_tokens = 256,
+ audio_enc_dec = EncodecVoco(),
+ num_cond_tokens = 500,
depth = 2,
dim_head = 64,
heads = 16
@@ -46,30 +65,66 @@ model = VoiceBox(
cfm_wrapper = ConditionalFlowMatcherWrapper(
voicebox = model,
- use_torchode = False # by default will use torchdiffeq with midpoint as in paper, but can use the promising torchode package too
+ text_to_semantic = text_to_semantic
)
-x = torch.randn(2, 1024, 512)
-phonemes = torch.randint(0, 256, (2, 1024))
-mask = torch.randint(0, 2, (2, 1024)).bool()
-
-loss = cfm_wrapper(
- x,
- phoneme_ids = phonemes,
- cond = x,
- mask = mask
-)
+# mock data
+
+audio = torch.randn(2, 12000)
+cond = torch.randn(2, 12000)
+
+# train
+
+loss = cfm_wrapper(audio, cond = cond)
loss.backward()
-# after much training above...
+# after much training
+
+texts = [
+ 'the rain in spain falls mainly in the plains',
+ 'she sells sea shells by the seashore'
+]
+
+sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1, <audio length>)