
[Bug]: Bunches of Issues in Mamba and Mamba2 #90

Open
WorldEditors opened this issue Dec 9, 2024 · 5 comments
Labels: bug (Something isn't working)

@WorldEditors (Contributor)

Describe the bug

  1. In both Mamba and Mamba2, there is an obvious difference between full-sequence inference and chunk-wise inference.

  2. In Mamba2, only hidden_size=2048 works; other values trigger an error.

  3. What is cache_position used for? It is quite confusing.

Steps to reproduce the bug

import torch
import fla
print(fla)

from fla.models.mamba2.modeling_mamba2 import Mamba2Mixer, Mamba2Cache
from fla.models.mamba2.configuration_mamba2 import Mamba2Config

from fla.models.mamba.modeling_mamba import MambaMixer, MambaCache
from fla.models.mamba.configuration_mamba import MambaConfig

bsz = 1
hidden_size = 2048  # only hidden_size=2048 works; other values trigger an error
seq_len = 1024
seg_len = 128
lr = 1.0e-3
use_cache = True   # works fine if this is False

seg_num = (seq_len - 1) // seg_len + 1

#config = MambaConfig(hidden_size=hidden_size)
config = Mamba2Config(hidden_size=hidden_size)

#encoder = MambaMixer(
#    config,
#    layer_idx=0)
encoder = Mamba2Mixer(
    config,
    layer_idx=0)

encoder = encoder.to('cuda')
inputs = torch.randn(bsz, seq_len, hidden_size, device='cuda')
outputs = torch.randn(bsz, seq_len, hidden_size, device='cuda')

#cache = MambaCache(config, bsz, dtype=inputs.dtype, device=inputs.device)
cache = Mamba2Cache(config, bsz, dtype=inputs.dtype, device=inputs.device)
cache.reset()

# full-sequence inference over the whole input in one call
cache_position = torch.arange(0, config.conv_kernel, device=inputs.device)
y_full = encoder(inputs, cache_params=cache, cache_position=cache_position)

cache.reset()

# chunk-wise inference over the same inputs, reusing the cache between chunks
for seg_id in range(seg_num):
    b = seg_id * seg_len
    e = min(b + seg_len, seq_len)
    cache_position = torch.arange(b, e, device=inputs.device)
    y = encoder(inputs[:, b:e], cache_params=cache, cache_position=cache_position)
    err = torch.sum((y_full[:, b:e] - y) ** 2)
    print(seg_id, err)

Got the following output:

0 tensor(2.8154e-07, device='cuda:0', grad_fn=<SumBackward0>)
1 tensor(3285.6797, device='cuda:0', grad_fn=<SumBackward0>)
2 tensor(687.7205, device='cuda:0', grad_fn=<SumBackward0>)
3 tensor(801.7145, device='cuda:0', grad_fn=<SumBackward0>)
4 tensor(772.4307, device='cuda:0', grad_fn=<SumBackward0>)
5 tensor(688.4492, device='cuda:0', grad_fn=<SumBackward0>)
6 tensor(690.9346, device='cuda:0', grad_fn=<SumBackward0>)
7 tensor(897.0207, device='cuda:0', grad_fn=<SumBackward0>)

Expected behavior

The per-chunk error should be close to zero.

Environment info

  1. torch: 2.4.1
  2. triton: 3.0.0
WorldEditors added the bug label on Dec 9, 2024
@WorldEditors (Contributor, Author)

Probably also related to state-spaces/mamba#641

@vasqu (Contributor) commented Dec 30, 2024

Maybe a bit late, but I'll try to clarify a few things:

  1. Chunk-wise inference isn't supported by Mamba here. Only step-by-step / one-token-at-a-time decoding works after the first initial input (see the sketch after this list).
  2. I assume the hidden size needs to be a certain multiple of something (possibly due to the kernel). Hard to say without seeing the error itself. Edit: I see that you pass hidden_size as the input width for the mixer - that could also be the reason for the failure, since expand * hidden_size is expected iirc. Don't know why 2048 works nonetheless, then.
  3. I think cache_position is used to follow a standard across HF transformers models, which really need the positions for RoPE, for example. It indeed isn't strictly necessary here; we likely only need to differentiate between an initial inference step and the following autoregressive decoding steps.
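
Roughly what I mean by "step by step" for point 1, as a minimal sketch against the Mamba2Mixer API from the repro above. I'm assuming the fla port mirrors the HF transformers convention: the first call covers the whole prompt, and every later call passes exactly one token together with a single-element, non-zero cache_position so the mixer takes the cached single-step path. The concrete cache_position values (beyond "starts at 0 for the prefill, positive afterwards") are an assumption on my part.

import torch

from fla.models.mamba2.modeling_mamba2 import Mamba2Mixer, Mamba2Cache
from fla.models.mamba2.configuration_mamba2 import Mamba2Config

bsz, prompt_len, decode_steps = 1, 128, 8

config = Mamba2Config(hidden_size=2048)
encoder = Mamba2Mixer(config, layer_idx=0).to('cuda')

cache = Mamba2Cache(config, bsz, dtype=torch.float32, device='cuda')
cache.reset()

# 1) initial pass: the whole prompt in one call, cache_position starting at 0
#    (same convention as in the repro above)
prompt = torch.randn(bsz, prompt_len, config.hidden_size, device='cuda')
cache_position = torch.arange(0, config.conv_kernel, device='cuda')
y = encoder(prompt, cache_params=cache, cache_position=cache_position)

# 2) autoregressive decoding: strictly one token per call; a non-zero,
#    single-element cache_position signals the cached single-step path
#    that updates the conv/ssm states held by Mamba2Cache in place
for step in range(decode_steps):
    token = torch.randn(bsz, 1, config.hidden_size, device='cuda')
    cache_position = torch.tensor([prompt_len + step], device='cuda')
    y = encoder(token, cache_params=cache, cache_position=cache_position)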

I also added a bunch of fixes in transformers recently which should probably be ported over here as well. No idea when I'll have time, but others can gladly take over!

@vasqu (Contributor) commented Dec 30, 2024

Before I forget, the reference for the fixes: huggingface/transformers#35154

@WorldEditors (Contributor, Author)

@vasqu Thanks very much for answering.
I do think chunk-wise inference is very much needed; I'm not sure why most repos don't support it.

@vasqu (Contributor) commented Dec 31, 2024

No idea tbh. It relates to context parallelism in a way as well, so it has benefits for multiple purposes - both inference and training.

I'm not too familiar with the details around context parallelism, but shouldn't only the mamba and conv ops pose problems, i.e. need special treatment? I think we'd need to pass initial states for the mamba op and the conv "cache" for the conv. Should be feasible imo.
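
To make those two pieces concrete, here is a self-contained toy sketch in plain PyTorch (no fla or mamba kernels; the depthwise causal conv plus a per-channel decay recurrence are simplified stand-ins I made up for the real conv1d + SSM ops). The point is only that if each chunk receives (a) the recurrent state left by the previous chunk and (b) the last kernel_size - 1 inputs as a conv "cache", chunk-wise processing matches full-sequence processing.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D, K, chunk = 1, 1024, 16, 4, 128
x = torch.randn(B, T, D)
w = torch.randn(D, 1, K)     # depthwise causal conv weights
a = torch.rand(D) * 0.9      # per-channel decay of a toy linear recurrence

def causal_conv(x_chunk, conv_state):
    # x_chunk: (B, L, D); conv_state: (B, K-1, D) trailing inputs of the previous chunk
    xp = torch.cat([conv_state, x_chunk], dim=1).transpose(1, 2)   # (B, D, K-1+L)
    y = F.conv1d(xp, w, groups=D).transpose(1, 2)                  # (B, L, D), causal
    return y, x_chunk[:, -(K - 1):]                                # new conv "cache"

def recurrence(u, h):
    # h_t = a * h_{t-1} + u_t; returns all h_t plus the final state
    ys = []
    for t in range(u.shape[1]):
        h = a * h + u[:, t]
        ys.append(h)
    return torch.stack(ys, dim=1), h

# full-sequence reference
h0 = torch.zeros(B, D)
conv0 = torch.zeros(B, K - 1, D)
u_full, _ = causal_conv(x, conv0)
y_full, _ = recurrence(u_full, h0)

# chunk-wise, carrying the conv cache and the recurrent state between chunks
h, conv_state, outs = h0, conv0, []
for s in range(0, T, chunk):
    u, conv_state = causal_conv(x[:, s:s + chunk], conv_state)
    y, h = recurrence(u, h)
    outs.append(y)

print(torch.allclose(torch.cat(outs, dim=1), y_full, atol=1e-5))   # True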
