Moshi Generation Does Not Work as Expected #36160

Open · 1 comment · May be fixed by #36171
SeungyounShin opened this issue Feb 13, 2025

@SeungyounShin

System Info

🐛 Bug Report

Description

The provided Moshi example code does not work correctly with the Transformers library. The generate function fails to produce usable output when generating new tokens, and the expected input formats raise questions.

Here is the generated moshi_output.wav (attached recording: moshi_bug_report.mp4).

I tried different temperature settings, generation configurations, and other samples, but it only produces a static 'chijijik...' sound.

cc @ylacombe

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import math
import os

import soundfile as sf
import torch
import transformers
from datasets import load_dataset, Audio
from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer

# Disable all automatic compilation features
os.environ['TORCH_COMPILE'] = '0'
os.environ['TORCHDYNAMO_DISABLE'] = '1'  # Fully disables TorchDynamo
os.environ['TORCHDYNAMO_VERBOSE'] = '0'  # Suppresses unnecessary logs
os.environ['TORCHDYNAMO_RECOMPILE_LIMIT'] = '0'  # Avoid recompile limits

# Apply global config settings for eager mode
torch._dynamo.config.suppress_errors = True  # Avoids crashes and falls back to eager mode
torch._dynamo.config.cache_size_limit = 0  # Prevents recompilation limits
torch._dynamo.reset()  # Clears any cached compile traces


librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

device = "cuda"
# prepare user input audio 
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=24000))
audio_sample = librispeech_dummy[-1]["audio"]["array"] # (107520,)
# WAV_PATH = f"./audio/moshi_opening.wav"
# audio_sample, sample_rate = sf.read(WAV_PATH)
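# Mimi produces 12.5 tokens per second of 24 kHz audio, i.e. one token per 24000 / 12.5 = 1920 samples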
waveform_to_token_ratio = 1 / 1920

model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", attn_implementation="eager", torch_dtype=torch.float16)
feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko")
tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
model = model.to(device)

user_input_values = feature_extractor(raw_audio=audio_sample, sampling_rate=24000, return_tensors="pt").to(device=device, dtype=torch.float16)

# prepare moshi input values - we suppose moshi didn't say anything while the user spoke
moshi_input_values = torch.zeros_like(user_input_values.input_values) # (1, 1, 107520)

# prepare moshi input ids - we suppose moshi didn't say anything while the user spoke
num_tokens = math.ceil(moshi_input_values.shape[-1] * waveform_to_token_ratio)
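# ceil(107520 samples / 1920 samples-per-token) = 56 pad tokens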
input_ids = torch.ones((1, num_tokens), device=device, dtype=torch.int64) * tokenizer.encode("<pad>")[0]

# Force disable torch.compile inside Transformers
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation
)

# generate 50 new tokens (~4 s of audio at 12.5 tokens/s)
output = model.generate(
    input_ids=input_ids,
    user_input_values=user_input_values.input_values,
    moshi_input_values=moshi_input_values,
    max_new_tokens=50,
    temperature=0.8,
    do_sample=True,
)

text_tokens = output.sequences
# decode text tokens
text = tokenizer.decode(text_tokens[0], skip_special_tokens=True)
print(text)

# decode audio tokens
audio_waveforms = output.audio_sequences.squeeze(0).squeeze(0) # (L,)
audio_waveforms = audio_waveforms.double()

# cut audio for input length
audio_waveforms = audio_waveforms[:user_input_values.input_values.shape[-1]]

# save audio
sf.write("moshi_output.wav", audio_waveforms.cpu().numpy(), 24000)

Expected behavior

The model should generate intelligible speech rather than static noise.

@SeungyounShin (Author)

Another approach (passing the audio codes directly):

import math
import os

import soundfile as sf
import torch
import transformers
from datasets import load_dataset, Audio
from transformers import MoshiForConditionalGeneration, AutoFeatureExtractor, AutoTokenizer, MimiModel

# Disable all automatic compilation features
os.environ['TORCH_COMPILE'] = '0'
os.environ['TORCHDYNAMO_DISABLE'] = '1'  # Fully disables TorchDynamo
os.environ['TORCHDYNAMO_VERBOSE'] = '0'  # Suppresses unnecessary logs
os.environ['TORCHDYNAMO_RECOMPILE_LIMIT'] = '0'  # Avoid recompile limits

# Apply global config settings for eager mode
torch._dynamo.config.suppress_errors = True  # Avoids crashes and falls back to eager mode
torch._dynamo.config.cache_size_limit = 0  # Prevents recompilation limits
torch._dynamo.reset()  # Clears any cached compile traces


librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

device = "cuda"
# prepare user input audio 
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=24000))
audio_sample = librispeech_dummy[-1]["audio"]["array"] # (107520,)
# WAV_PATH = f"./audio/moshi_opening.wav"
# audio_sample, sample_rate = sf.read(WAV_PATH)
waveform_to_token_ratio = 1 / 1920

model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko", attn_implementation="eager", torch_dtype=torch.float16)
feature_extractor = AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko")
tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
model = model.to(device)

user_input_values = feature_extractor(raw_audio=audio_sample, sampling_rate=24000, return_tensors="pt").to(device=device, dtype=torch.float16)

# prepare moshi input values - we suppose moshi didn't say anything while the user spoke
moshi_input_values = torch.zeros_like(user_input_values.input_values) # (1, 1, 107520)

# prepare moshi input ids - we suppose moshi didn't say anything while the user spoke
num_tokens = math.ceil(moshi_input_values.shape[-1] * waveform_to_token_ratio)
input_ids = torch.ones((1, num_tokens), device=device, dtype=torch.int64) * tokenizer.encode("<pad>")[0]
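# NOTE: the pad-token input_ids above (and the input_values) are superseded by the explicit tensors below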

# Force disable torch.compile inside Transformers
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.forward
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.generate
)
transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation = torch._dynamo.disable(
    transformers.models.moshi.modeling_moshi.MoshiForConditionalGeneration.prepare_inputs_for_generation
)


# The first prompt variant (with word tokens 2295, 667, 261) was immediately overwritten;
# it is kept commented out for reference, and the all-padding variant is the one used below
# input_ids = torch.tensor([[3, 3, 3, 3, 3, 0, 2295, 3, 667, 261, 3, 3, 3, 3, 3, 3, 3, 3, 3]]).to(device)
input_ids = torch.tensor([[3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]).to(device)

# (batch_size, num_codebooks, sequence_length)
user_audio_codes = torch.tensor([
    [[1049],[1520],[1502],[1562],[1168],[1016],[ 976],[2008]],
    [[127],[441],[1873],[317],[1622],[757],[ 1686],[1621]],
    [[1272],[724],[665],[1417],[896],[1331],[ 1926],[1451]],
    [[962],[773],[720],[1869],[1944],[720],[ 1223],[467]],
    [[1707],[773],[1606],[966],[700],[1467],[ 886],[1349]],
    [[1707],[773],[1515],[1869],[38],[1878],[ 815],[1146]],
    [[1216],[677],[230],[1015],[1924],[1582],[ 595],[605]],
    [[1942],[1732],[440],[714],[1324],[1568],[ 41],[1938]],
    [[2005],[238],[739],[1381],[27],[87],[ 13],[959]],
    [[1293],[1628],[1744],[1263],[1051],[1686],[ 1007],[94]],
    [[70],[424],[683],[1552],[869],[19],[ 634],[1624]],
    [[1164],[1393],[71],[159],[1166],[1679],[ 1235],[919]],
    [[1957],[287],[370],[829],[928],[210],[1943],[ 144]],
    [[502],[403],[1626],[164],[1736],[1572],[ 976],[1551]],
    [[769],[243],[1149],[290],[481],[1030],[ 1238],[568]],
    [[1743],[243],[1559],[1348],[340],[347],[ 1238],[1744]],
    [[716],[1056],[1502],[1712],[306],[1030],[ 1238],[1744]],
    [[1544],[1056],[1149],[164],[1736],[1572],[ 1238],[1744]],
    [[1544],[1056],[1178],[1348],[340],[1443],[ 1238],[1744]],
]).to(device).permute(2,1,0)
moshi_audio_codes = torch.tensor([
    [[1049],[1700],[1626],[546],[306],[ 1443],[1871],[2008]],
    [[127],[243],[457],[290],[1019],[1030],[ 428],[152]],
    [[1880],[91],[1029],[390],[1335],[1569],[ 192],[424]],
    [[972],[91],[457],[164],[1809],[478],[ 1044],[1744]],
    [[972],[91],[457],[290],[1335],[1447],[ 428],[1053]],
    [[1766],[1783],[1626],[400],[1317],[332],[ 1433],[1439]],
    [[1853],[398],[1089],[325],[1246],[1256],[ 922],[514]], # Good
    [[1853],[161],[1762],[1589],[484],[1162],[ 273],[166]],
    [[662],[1919],[380],[610],[204],[39],[ 390],[391]],
    [[966],[833],[1832],[425],[1466],[ 1575],[2003],[245]], # day
    [[1575],[1068],[1588],[264],[92],[854],[ 1765],[2007]], # ,
    [[198],[168],[928],[229],[266],[612],[ 1164],[238]],
    [[198],[1635],[1653],[1709],[757],[889],[ 286],[410]],
    [[582],[669],[480],[826],[1335],[1707],[ 849],[1084]],
    [[978],[811],[1178],[164],[1217],[347],[ 1543],[1511]],
    [[729],[1056],[457],[290],[1335],[1572],[ 1433],[1744]],
    [[940],[811],[1178],[164],[1314],[347],[ 43],[1648]],
    [[940],[91],[1029],[546],[1335],[347],[ 1433],[424]],
    [[1098],[1519],[1697],[390],[1836],[1443],[ 1234],[2008]],
]).to(device).permute(2,1,0)
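# Each literal above has shape (seq_len=19, num_codebooks=8, 1); permute(2, 1, 0)
# rearranges it into the expected (batch_size=1, num_codebooks=8, seq_len=19)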

# generate 50 new tokens (~4 s of audio at 12.5 tokens/s)
output = model.generate(
    input_ids=input_ids,
    user_audio_codes=user_audio_codes,
    moshi_audio_codes=moshi_audio_codes,
    max_new_tokens=50,
    temperature=1.0,
    do_sample=True,
)

text_tokens = output.sequences
# decode text tokens
text = tokenizer.decode(text_tokens[0], skip_special_tokens=False)
print(text)

# decode audio tokens
audio_waveforms = output.audio_sequences.squeeze(0).squeeze(0) # (L,)
audio_waveforms = audio_waveforms.double()

# save audio
sf.write("moshi_output_2.wav", audio_waveforms.cpu().numpy(), 24000)

# save user audio
mimi_model = MimiModel.from_pretrained("kyutai/mimi")
mimi_model = mimi_model.to(device)
user_audio_values = mimi_model.decode(user_audio_codes)
sf.write("moshi_user_audio_2.wav",  user_audio_values.audio_values.detach().cpu().squeeze().numpy(), 24000)


These acoustic codes work in the original Moshi repository but do not work in Transformers.
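
For comparison, a small sketch (reusing the mimi_model loaded above) that decodes moshi_audio_codes the same way; if both code streams decode to clean speech, the codes themselves are valid and the bug is narrowed down to the generation path:

# Decode Moshi's side of the prompt with Mimi as well; clean speech here means the
# noise must originate inside MoshiForConditionalGeneration.generate
moshi_audio_values = mimi_model.decode(moshi_audio_codes)
sf.write("moshi_prompt_audio_2.wav", moshi_audio_values.audio_values.detach().cpu().squeeze().numpy(), 24000)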
