Skip to content

Commit

Permalink
complete conditioning training and inference using spear-tts TextToSe…
Browse files Browse the repository at this point in the history
…mantic module
  • Loading branch information
lucidrains committed Sep 24, 2023
1 parent 8d4eb1e commit d36b1bc
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 77 deletions.
99 changes: 76 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,103 @@ $ pip install voicebox-pytorch

## Usage

Training and sampling with `TextToSemantic` module from <a href="https://github.com/lucidrains/spear-tts-pytorch">SpearTTS</a>

```python
import torch

from voicebox_pytorch import (
VoiceBox,
ConditionalFlowMatcherWrapper
EncodecVoco,
ConditionalFlowMatcherWrapper,
HubertWithKmeans,
TextToSemantic
)

wav2vec = HubertWithKmeans(
checkpoint_path = './path/to/hubert/checkpoint.pt',
kmeans_path = './path/to/hubert/kmeans.bin'
)

text_to_semantic = TextToSemantic(
wav2vec = wav2vec,
dim = 512,
source_depth = 1,
target_depth = 1,
use_openai_tokenizer = True
)

model = VoiceBox(
dim = 512,
num_phoneme_tokens = 256,
audio_enc_dec = EncodecVoco(),
num_cond_tokens = 500,
depth = 2,
dim_head = 64,
heads = 16
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
voicebox = model,
use_torchode = False # by default will use torchdiffeq with midpoint as in paper, but can use the promising torchode package too
text_to_semantic = text_to_semantic
)

x = torch.randn(2, 1024, 512)
phonemes = torch.randint(0, 256, (2, 1024))
mask = torch.randint(0, 2, (2, 1024)).bool()

loss = cfm_wrapper(
x,
phoneme_ids = phonemes,
cond = x,
mask = mask
)
# mock data

audio = torch.randn(2, 12000)
cond = torch.randn(2, 12000)

# train

loss = cfm_wrapper(audio, cond = cond)

loss.backward()

# after much training above...
# after much training

texts = [
'the rain in spain falls mainly in the plains',
'she sells sea shells by the seashore'
]

sampled = cfm_wrapper.sample(cond = cond, texts = texts) # (2, 1, <audio length>)
```

For unconditional training, `condition_on_text` on `VoiceBox` must be set to `False`

```python
import torch
from voicebox_pytorch import (
VoiceBox,
ConditionalFlowMatcherWrapper
)

model = VoiceBox(
dim = 512,
num_cond_tokens = 500,
depth = 2,
dim_head = 64,
heads = 16,
condition_on_text = False
)

cfm_wrapper = ConditionalFlowMatcherWrapper(
voicebox = model
)

# mock data

cond = torch.randn(2, 1024, 512)
x = torch.randn(2, 1024, 512)

sampled = cfm_wrapper.sample(
phoneme_ids = phonemes,
cond = x,
mask = mask
) # (2, 1024, 512) <- same as cond
# train

loss = cfm_wrapper(x, cond = cond)

loss.backward()

# after much training

sampled = cfm_wrapper.sample(cond = cond)
```

## Todo
Expand All @@ -84,13 +139,11 @@ sampled = cfm_wrapper.sample(
- [x] add encodec / voco for starters
- [x] setup training and sampling with raw audio, if `audio_enc_dec` is passed in
- [x] integrate with log mel spec / encodec - vocos
- [x] spear-tts-integration

- [ ] spear-tts-integration
- [ ] extract sample hz from wav2vec module from spear-tts, and handle conditioning and mask during training and sampling automatically. use verified soundstorm logic as a guide

- [ ] basic accelerate trainer
- [ ] figure out the correct settings for `MelVoco` encode, as the reconstructed audio is longer in length
- [ ] calculate how many seconds corresponds to each frame and add as property on `AudioEncoderDecoder` - when sampling, allow for specifying in seconds
- [ ] basic trainer

## Citations

Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.35',
version = '0.1.0',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
Expand All @@ -16,11 +16,11 @@
'text to speech'
],
install_requires=[
'accelerate',
'audiolm-pytorch>=1.2.28',
'naturalspeech2-pytorch>=0.0.41',
'beartype',
'einops>=0.6.1',
'lightning>=2.0.7',
'spear-tts-pytorch>=0.3.4',
'torch>=2.0',
'torchdiffeq',
Expand Down
4 changes: 4 additions & 0 deletions voicebox_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,7 @@
DurationPredictor,
ConditionalFlowMatcherWrapper,
)

from spear_tts_pytorch import TextToSemantic

from audiolm_pytorch import HubertWithKmeans
Loading

0 comments on commit d36b1bc

Please sign in to comment.