Compilation issue with fast inference #203

Open
mquillot opened this issue Feb 6, 2025 · 1 comment

mquillot commented Feb 6, 2025

Hello,

I followed the README, but when I run the fast inference command I get an error.

Command I run:

poetry run python -i fam/llm/fast_inference.py

I get the following error:

BUILDING MODEL
Using device=cpu
Loading model ...
using dtype=float16
Time to load model: 0.35 seconds
Compiling...Can take up to 2 mins.
Traceback (most recent call last):
  File "metavoice-src/fam/llm/fast_inference.py", line 203, in <module>
    tts = tyro.cli(TTS)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/tyro/_cli.py", line 187, in cli
    output = _cli_impl(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/tyro/_cli.py", line 454, in _cli_impl
    out, consumed_keywords = _calling.call_from_args(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/tyro/_calling.py", line 241, in call_from_args
    return unwrapped_f(*positional_args, **kwargs), consumed_keywords  # type: ignore
  File "metavoice-src/fam/llm/fast_inference.py", line 100, in __init__
    self.model, self.tokenizer, self.smodel, self.model_size = build_model(
  File "metavoice-src/fam/llm/fast_inference_utils.py", line 375, in build_model
    y = generate(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "metavoice-src/fam/llm/fast_inference_utils.py", line 211, in generate
    next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/lazy.py", line 90, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/nn_module.py", line 328, in call_function
    return tx.inline_user_function_return(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1252, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars.items)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 294, in call_function
    return super().call_function(tx, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 248, in call_function
    return super().call_function(tx, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 81, in call_function
    return tx.inline_user_function_return(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2261, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2376, in inline_call_
    tracer.run()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1264, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 542, in call_function
    tensor_variable = wrap_fx_proxy(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1314, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 1399, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1525, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1486, in get_fake_value
    ret_val = wrap_fake_exception(
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1027, in wrap_fake_exception
    return fn()
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1487, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1592, in run_node
    raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1571, in run_node
    return node.target(*args, **kwargs)
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., size=(2, 16, s0, 128)), FakeTensor(..., size=(2, 16, 2048, 128), dtype=torch.float16), FakeTensor(..., size=(2, 16, 2048, 128), dtype=torch.float16)), **{'attn_mask': FakeTensor(..., size=(1, 1, s0, 2048), dtype=torch.bool), 'dropout_p': 0.0}):
Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: c10::Half and value.dtype: c10::Half instead.

from user code:
   File "metavoice-src/fam/llm/fast_inference_utils.py", line 131, in prefill
    logits = model(x, spk_emb, input_pos)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "metavoice-src/fam/llm/fast_model.py", line 160, in forward
    x = layer(x, input_pos, mask)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "metavoice-src/fam/llm/fast_model.py", line 179, in forward
    h = x + self.attention(self.attention_norm(x), mask, input_pos)
  File "mquillot_metavoice_2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "metavoice-src/fam/llm/fast_model.py", line 222, in forward
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
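
The two hints at the end of the log can also be applied from Python. The sketch below is a minimal example under one assumption: that the environment variables must be set before torch is imported, since they appear to be read at import time. The variable names and the suppress_errors flag are copied from the log above; nothing else here comes from the project.

```python
import os

# Variable names are copied from the hint printed by torch._dynamo above.
# Assumption: they need to be set before `import torch` to take effect.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

# Optional eager fallback, also quoted from the log. It is left commented out
# because it only hides the compile failure; the dtype mismatch would likely
# still surface when the same attention call runs in eager mode.
# import torch._dynamo
# torch._dynamo.config.suppress_errors = True

print({k: os.environ[k] for k in ("TORCH_LOGS", "TORCHDYNAMO_VERBOSE")})
```

Placing these lines at the very top of the entry script, or exporting the variables in the shell before running the command, should give the same effect.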

The error is raised from the following call in fam/llm/fast_inference.py:

        self.model, self.tokenizer, self.smodel, self.model_size = build_model(
            precision=self.precision,
            checkpoint_path=Path(self._first_stage_ckpt),
            spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
            device=self._device,
            compile=True,
            compile_prefill=True,
            quantisation_mode=quantisation_mode,
        )

At startup, I also get the following deprecation notice and warnings. I am not sure they are related, but they may be useful:

The "poetry.dev-dependencies" section is deprecated and will be removed in a future version. Use "poetry.group.dev.dependencies" instead.
WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.2.1+cu121)
    Python  3.10.13 (you have 3.10.16)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
mquillot_metavoice_2/lib/python3.10/site-packages/torch/cuda/__init__.py:141: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:108.)
  return torch._C._cuda_getDeviceCount() > 0
mquillot_metavoice_2/lib/python3.10/site-packages/transformers/utils/hub.py:124: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.
  warnings.warn(
mquillot_metavoice_2/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
mquillot_metavoice_2/lib/python3.10/site-packages/transformers/models/encodec/modeling_encodec.py:123: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.register_buffer("padding_total", torch.tensor(kernel_size - stride, dtype=torch.int64), persistent=False)
mquillot_metavoice_2/lib/python3.10/site-packages/df/io.py:9: UserWarning: `torchaudio.backend.common.AudioMetaData` has been moved to `torchaudio.AudioMetaData`. Please update the import path.
  from torchaudio.backend.common import AudioMetaData
using dtype=float16
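
The xFormers warning above reports that the package was built against a different PyTorch version than the one installed. A small sketch for checking this kind of mismatch with the standard library only; the two version strings are copied from the warning, and the helper names installed_version and base_version are hypothetical.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def base_version(version: str) -> str:
    """Drop a local build suffix such as '+cu121'."""
    return version.split("+", 1)[0]


# Values copied from the xFormers warning in the log above.
built_for = "2.1.0+cu121"
found = "2.2.1+cu121"
mismatch = base_version(built_for) != base_version(found)
print("PyTorch version mismatch reported by xFormers:", mismatch)

# Report what is installed in the current environment.
for pkg in ("torch", "xformers"):
    print(pkg, installed_version(pkg) or "not installed")
```

Whether this mismatch is related to the compilation failure is not established in this thread; the check only makes the warning easier to confirm.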

Everything else seems fine, but the compilation fails.

Do you have any idea what is happening?

I followed the installation guidelines in the README exactly.

Natalie-Caruana commented

I had a similar issue and resolved it by modifying fam/llm/fast_inference_utils.py. Specifically, I changed line 281 to:

model = model.to(device=device, dtype=precision)

The problem was that the precision had been hardcoded to bfloat16, which conflicted with the dtype of other cached tensors. Changing it to use the configured precision resolved the issue for me.
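
To illustrate why the dtypes have to agree, here is a minimal sketch, not the project's code. TinyAttention, setup_cache, build, and runs are hypothetical names, and the sketch assumes the cache is allocated separately from the weight cast. It uses float32 and bfloat16 so that it runs on CPU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAttention(nn.Module):
    """Hypothetical stand-in for an attention layer with a preallocated KV cache."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.dim = dim
        self.wq = nn.Linear(dim, dim, bias=False)
        self.k_cache = None
        self.v_cache = None

    def setup_cache(self, cache_len: int, dtype: torch.dtype) -> None:
        # The cache is allocated separately from the weights, in its own dtype.
        shape = (1, 1, cache_len, self.dim)
        self.k_cache = torch.zeros(shape, dtype=dtype)
        self.v_cache = torch.zeros(shape, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The query inherits the dtype of the projection weights.
        q = self.wq(x.to(self.wq.weight.dtype)).unsqueeze(1)
        return F.scaled_dot_product_attention(q, self.k_cache, self.v_cache)


def build(weight_dtype: torch.dtype, cache_dtype: torch.dtype) -> TinyAttention:
    model = TinyAttention().to(dtype=weight_dtype)
    model.setup_cache(cache_len=4, dtype=cache_dtype)
    return model


def runs(model: TinyAttention) -> bool:
    """Return True if a forward pass succeeds, False on a RuntimeError."""
    try:
        model(torch.randn(1, 3, 8))
        return True
    except RuntimeError:
        return False


# Weights cast to a hardcoded dtype while the cache uses the requested one.
mismatched_ok = runs(build(torch.bfloat16, torch.float32))
# Weights and cache both follow the requested precision.
matched_ok = runs(build(torch.float32, torch.float32))
print("mismatched:", mismatched_ok, "matched:", matched_ok)
```

The point of the sketch is only that scaled_dot_product_attention rejects a query whose dtype differs from the key and value, so whichever precision is chosen has to be applied to both the weights and the cache.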
