
Error while running bash command: run_sample_video.sh | Error: "TypeError: missing a required argument: 'segment_ids'" #77

Open
samitm-123 opened this issue Jun 18, 2024 · 6 comments

Comments

@samitm-123

I receive this error when I run this bash command: `!bash LWM/scripts/run_sample_video.sh`. I have followed all the directions listed in the repo.

/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/content/LWM/lwm/vision_generation.py", line 256, in <module>
    run(main)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/LWM/lwm/vision_generation.py", line 92, in main
    model = FlaxVideoLLaMAForCausalLM(
  File "/content/LWM/lwm/vision_llama.py", line 141, in __init__
    super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_flax_utils.py", line 224, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/content/LWM/lwm/vision_llama.py", line 166, in init_weights
    random_params = self.module.init(rngs, input_ids, vision_masks, attention_mask, segment_ids, position_ids, return_dict=False)["params"]
  File "/content/LWM/lwm/vision_llama.py", line 396, in __call__
    outputs = self.transformer(
  File "/content/LWM/lwm/vision_llama.py", line 315, in __call__
    outputs = self.h(
  File "/content/LWM/lwm/llama.py", line 945, in __call__
    hidden_states, _ = nn.scan(
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 151, in scan_fn
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
  File "/usr/local/lib/python3.10/dist-packages/flax/core/axes_scan.py", line 123, in body_fn
    broadcast_out, c, ys = fn(broadcast_in, c, *xs)
  File "/content/LWM/lwm/llama.py", line 724, in __call__
    attn_outputs = self.attention(
  File "/content/LWM/lwm/llama.py", line 615, in __call__
    attn_output = ring_attention_sharded(
  File "/usr/lib/python3.10/inspect.py", line 3186, in bind
    return self._bind(args, kwargs)
  File "/usr/lib/python3.10/inspect.py", line 3101, in _bind
    raise TypeError(msg) from None
TypeError: missing a required argument: 'segment_ids'

Would appreciate some help here.
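For anyone debugging this: the `TypeError` is raised by `inspect.Signature.bind`, which the code uses to validate the arguments passed to the ring-attention callable — the call site supplies one argument fewer than the function declares. A minimal sketch (the signature below is hypothetical, not the project's actual one) reproduces the same error:

```python
import inspect

# Hypothetical signature standing in for ring_attention_sharded; the
# traceback shows inspect.Signature.bind() failing because the caller
# omits the declared segment_ids parameter.
def ring_attention_sharded(q, k, v, attn_bias, segment_ids):
    return q  # placeholder body; only the signature matters here

sig = inspect.signature(ring_attention_sharded)

try:
    # Binding four arguments against a five-parameter signature raises
    # the exact error class and message seen above.
    sig.bind("q", "k", "v", "bias")
except TypeError as e:
    print(e)  # → missing a required argument: 'segment_ids'
```

So the fix has to come from the repo itself: either the call site needs to pass `segment_ids`, or the attention function's signature needs to stop requiring it.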

@gabeweisz

Seeing the same error. Commit 97ae4b6 works for me on GPU

@samitm-123
Author

@gabeweisz I get the following error:
(lwm) madhu@madhupc:~/LWM$ bash scripts/run_sample_image.sh
Traceback (most recent call last):
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/madhu/LWM/lwm/vision_generation.py", line 11, in <module>
    from tux import (
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/__init__.py", line 1, in <module>
    from .checkpoint import StreamingCheckpointer
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/tux/checkpoint.py", line 4, in <module>
    import flax
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/configurations.py", line 92, in <module>
    flax_filter_frames = define_bool_state(
  File "/home/madhu/anaconda3/envs/lwm/lib/python3.10/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: 'Config' object has no attribute 'define_bool_state'

This is our Google Colab; would you mind taking a look and telling us what changes should be made to run this model?

https://colab.research.google.com/drive/1Bx-wRzOspvq5JLctNKRHwHq-vIgw7wlv?usp=sharing
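The `define_bool_state` failure is a jax/flax version mismatch: newer JAX releases removed `jax.config.define_bool_state`, which older flax builds still call at import time. One possible fix, assuming the version combination reported working elsewhere in this thread (the `jaxlib` pin is my addition to keep it in step with `jax`), is to pin the packages explicitly:

```shell
# Assumption: these pins match the combination reported working in this
# thread. Newer jax removed jax.config.define_bool_state, which this
# flax release still calls during import.
pip install "jax==0.4.25" "jaxlib==0.4.25" "flax==0.8.2" "chex==0.1.86"
```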

@gabeweisz

For the version of the repo I pointed you to, it works for me using jax 0.4.25 with flax==0.8.2 and chex==0.1.86.

I don't have access to your Google Colab, but maybe the authors of this project will chime in with more information.

@madhuvanthp

> For the version of the repo I pointed you to, it works for me using jax 0.4.25 with flax==0.8.2 and chex==0.1.86.
>
> I don't have access to your Google Colab, but maybe the authors of this project will chime in with more information.

So for GPUs you used commit 97ae4b6 and solely followed the instructions for that specific version? Or did you run some other commands? Also, would you mind sharing your entire requirements.txt file? The versions in the requirements.txt from 97ae4b6 are different from what you mentioned. I am struggling to get this working with my GPU.

@gabeweisz

I used commit 97ae4b6 and did not change anything.

I installed packages using the requirements.txt in that commit, and then updated the two packages that I mention above manually using pip.

I most likely have a different GPU than you do, but this is what worked for me.

@gabeweisz

The newest commit (b8e3602) fixes this error for me.
