

ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 526464) and requested shape (2, 1, 32768, 32768) #62

Open
scsonic opened this issue Mar 17, 2024 · 2 comments

Comments

@scsonic

scsonic commented Mar 17, 2024

Creating a single PNG image: succeeds
Creating a video: fails
GPU: A100 80GB

After changing to --n_frames=2048: still always fails

!JAX_TRACEBACK_FILTERING=off python3 -u -m lwm.vision_generation \
--prompt={prompt} \
--output_file={output_filename} \
--temperature_image=1.0 \
--top_k_image=8192 \
--cfg_scale_image=5.0 \
--vqgan_checkpoint="{vqgan_checkpoint}" \
--n_frames=2048 \
--dtype='fp32' \
--load_llama_config='7b' \
--update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,use_flash_attention=True,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
--load_checkpoint="params::{lwm_checkpoint}" \
--tokenizer.vocab_file="{llama_tokenizer_path}"
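The last dimension in the error reported below (526464) can be reproduced from these flags with simple arithmetic. This is a minimal sketch, assuming each generated frame costs 257 tokens (256 VQGAN codes plus one delimiter token); that per-frame figure is inferred from the numbers in the error message, not read from the LWM source. `max_input_length=128` is the value passed to `generate_video_pred` in the traceback, and `required_length` / `max_frames` are hypothetical helper names used only for illustration.

```python
# Hypothetical sketch: TOKENS_PER_FRAME is inferred from the error message
# (526464 = 128 + 2048 * 257), not taken from the LWM source code.
TOKENS_PER_FRAME = 257   # assumed: 256 VQGAN codes + 1 delimiter token
MAX_INPUT_LENGTH = 128   # from generate_video_pred(..., max_input_length=128)


def required_length(n_frames: int, max_input_length: int = MAX_INPUT_LENGTH) -> int:
    """Total sequence positions needed to hold the prompt plus n_frames frames."""
    return max_input_length + n_frames * TOKENS_PER_FRAME


def max_frames(max_sequence_length: int, max_input_length: int = MAX_INPUT_LENGTH) -> int:
    """Largest n_frames whose total length still fits in max_sequence_length."""
    return (max_sequence_length - max_input_length) // TOKENS_PER_FRAME


print(required_length(2048))  # 526464, matching the last dim of the attention mask
print(max_frames(32768))      # 127, under the 257-tokens-per-frame assumption
```

If the assumption holds, `--n_frames=2048` needs roughly 16x the 32768 positions configured by `max_sequence_length`, so the attention mask and the causal mask can never agree in shape; either the frame count has to shrink or the context length has to grow (with the corresponding memory cost).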

The output:
/tmp/notebook/content/LWM
env: PYTHONPHAT=/tmp/notebook/content/LWM
env: NUMEXPR_MAX_THREADS=12
I0317 09:57:37.786386 139819503906816 xla_bridge.py:660] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA
I0317 09:57:37.787172 139819503906816 xla_bridge.py:660] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
2024-03-17 09:57:38.663241: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
100%|█████████████████████████████████████████████| 1/1 [00:24<00:00, 24.97s/it]
0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 258, in <module>
run(main)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 247, in main
videos.extend(generate_video_pred(prompts, images, max_input_length=128))
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 215, in generate_video_pred
output, sharded_rng = _sharded_forward_generate(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 257, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
args_flat, _, params, in_tree, out_tree, _, _, _ = infer_params_fn(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 781, in infer_params
return common_infer_params(pjit_info_args, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 493, in common_infer_params
jaxpr, consts, canonicalized_out_shardings_flat, out_layouts_flat = _pjit_jaxpr(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 996, in _pjit_jaxpr
jaxpr, final_consts, out_type = _create_pjit_jaxpr(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 349, in memoized_fun
ans = call(fun, *args)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/pjit.py", line 936, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2288, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2310, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/tmp/notebook/content/LWM/lwm/vision_generation.py", line 116, in _forward_generate
output = model.generate_vision(
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 710, in generate_vision
return self._sample_vision(
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 515, in _sample_vision
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 453, in prepare_inputs_for_generation
past_key_values = self.init_cache(batch_size, max_length)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 151, in init_cache
init_variables = self.module.init(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2319, in init
_, v_out = self.init_with_output(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2215, in init_with_output
return init_with_output(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/scope.py", line 1137, in wrapper
return apply(fn, mutable=mutable, flags=init_flags)(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/scope.py", line 1101, in wrapper
y = fn(root, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 2972, in scope_fn
return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 396, in __call__
outputs = self.transformer(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/vision_llama.py", line 315, in __call__
outputs = self.h(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/llama.py", line 981, in __call__
hidden_states, _ = nn.scan(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 378, in wrapped_fn
ret = trafo_fn(module_scopes, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 325, in wrapper
y, out_variable_groups_xs_t = fn(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 1024, in inner
broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 148, in scan_fn
_, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 774, in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/linear_util.py", line 191, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 120, in body_fn
broadcast_out, c, ys = fn(broadcast_in, c, *xs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/core/lift.py", line 1005, in scanned
c, y = fn(scope, c, *args)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/transforms.py", line 370, in core_fn
res = fn(cloned, *args, **kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/llama.py", line 757, in __call__
attn_outputs = self.attention(
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 694, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/flax/linen/module.py", line 1226, in _call_wrapped_method
y = run_fun(self, *args, **kwargs)
File "/tmp/notebook/content/LWM/lwm/llama.py", line 627, in __call__
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1227, in broadcast_to
return util._broadcast_to(array, shape)
File "/home/everlyai-user/.local/lib/python3.10/site-packages/jax/_src/numpy/util.py", line 428, in _broadcast_to
raise ValueError(msg.format(arr_shape, shape))
ValueError: Incompatible shapes for broadcasting: (2, 1, 1, 526464) and requested shape (2, 1, 32768, 32768)
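The failing line (`llama.py`, line 627) expands a `(batch, mask_len)` attention mask to `(batch, 1, 1, mask_len)` and broadcasts it to the causal mask's `(batch, 1, seq_len, seq_len)` shape. The small reproduction below is a sketch that uses NumPy, which follows the same broadcasting rules as `jnp.broadcast_to`; `broadcast_mask` is a hypothetical helper, and the sizes 8 and 12 stand in for 32768 and 526464. Broadcasting only succeeds when the mask length equals the causal-mask length, so this is a length mismatch rather than something padding or strides would explain.

```python
import numpy as np


def broadcast_mask(attention_mask: np.ndarray, seq_len: int) -> np.ndarray:
    """Mimic the expand_dims + broadcast_to step from the traceback (illustrative only)."""
    # (batch, mask_len) -> (batch, 1, 1, mask_len)
    expanded = np.expand_dims(attention_mask, axis=(-3, -2))
    # Target shape mirrors the causal mask: (batch, 1, seq_len, seq_len)
    target = (attention_mask.shape[0], 1, seq_len, seq_len)
    return np.broadcast_to(expanded, target)


# Mask length equals the causal-mask length: broadcasting works.
ok = broadcast_mask(np.ones((2, 8), dtype=bool), seq_len=8)
print(ok.shape)  # (2, 1, 8, 8)

# Mask longer than the causal mask (like 526464 vs 32768): broadcasting fails.
try:
    broadcast_mask(np.ones((2, 12), dtype=bool), seq_len=8)
except ValueError as err:
    print("ValueError:", err)
```

A size-1 axis can be stretched to any length, but a non-1 axis (here 12 vs 8) must match exactly, which is why the last dimension triggers the error.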

@OkinoLeiba

My initial thought is to add dummy values or perform some sort of interpolation... or one-hot encoding. Are the dimensions strides?

@ZQpengyu

ZQpengyu commented May 7, 2024

Hello, has this issue been resolved? I've encountered a similar problem: when I increase the frame rate, it also throws a shape error. I also suspect I'm not actually using the GPU.
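One way to check whether JAX is actually using the GPU is to ask it which backend and devices it sees. This is a minimal sketch using `jax.default_backend()` and `jax.devices()`; the printed device names depend on the machine, and the final hint about the CUDA-enabled jaxlib install is a general suggestion rather than something confirmed in this thread.

```python
import jax

# Backend JAX selected by default: "gpu" for CUDA, "tpu", or "cpu".
backend = jax.default_backend()
print("default backend:", backend)

# Every device JAX can see; a working CUDA setup should list GPU devices here.
for device in jax.devices():
    print(device.platform, device.device_kind)

if backend == "cpu":
    # Falling back to CPU usually means the CUDA-enabled jaxlib is missing
    # or could not initialize against the installed driver.
    print("JAX is running on CPU only; check the CUDA-enabled jaxlib install.")
```

In the log from the original report, the line "Available platform names are: CUDA" suggests JAX did find CUDA on that machine; running the check above in your own environment shows whether the same is true there.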

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants