「Question」Support 24GB 4090 inference with multiple nodes #205

Open
aisensiy opened this issue Jan 2, 2025 · 6 comments

aisensiy commented Jan 2, 2025

Is your feature request related to a problem? Please describe.

Currently, the RTX 4090 is essentially the only consumer-grade GPU that is practical for FP8 inference. I am attempting to run DeepSeek V3 across 4 nodes with 8 GPUs each, but even with a very small context length (128) I hit an out-of-memory error.

I want to confirm whether this is due to my configuration or whether a model of this scale simply cannot run even on 32 RTX 4090 GPUs.

Here is my vLLM command; I am using the latest version (0.6.6):

export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"
 
vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --host 0.0.0.0 --port $PORT \
    --gpu-memory-utilization 0.98 \
    --max-model-len 128 \
    --tensor-parallel-size 8 --pipeline-parallel-size 4 --enforce-eager
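
As a rough back-of-envelope check, assuming about 671B total parameters stored at roughly one byte per parameter in FP8:

# ~671B params x ~1 byte (FP8) ≈ 671 GB of weights in total.
# Spread across 32 GPUs that is roughly 21 GB per GPU, which is consistent with the
# 17-20.6 GiB per-GPU figures in the logs below and leaves only a few GiB of each
# 24 GiB card for the KV cache, activations, and the CUDA context.
echo "scale=1; 671 / 32" | bc   # ≈ 20.9 GB of weights per GPU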

Here are some of the outputs:

INFO 01-02 02:07:12 model_runner.py:1099] Loading model weights took 16.9583 GB
(RayWorkerWrapper pid=2040) INFO 01-02 02:07:12 model_runner.py:1099] Loading model weights took 16.9687 GB [repeated 2x across cluster]
(RayWorkerWrapper pid=2031) INFO 01-02 02:07:18 model_runner.py:1099] Loading model weights took 16.9687 GB [repeated 2x across cluster]
(RayWorkerWrapper pid=2025) INFO 01-02 02:07:26 model_runner.py:1099] Loading model weights took 16.9687 GB [repeated 2x across cluster]


(RayWorkerWrapper pid=1053, ip=10.96.18.24) INFO 01-02 02:38:57 model_runner.py:1099] Loading model weights took 20.5795 GB [repeated 3x across cluster]
(RayWorkerWrapper pid=1052, ip=10.96.37.211) INFO 01-02 02:39:09 model_runner.py:1099] Loading model weights took 20.5795 GB [repeated 8x across cluster]
(RayWorkerWrapper pid=1057, ip=10.96.62.52) WARNING 01-02 02:39:13 fused_moe.py:374] Using default MoE config. Performance might be sub-optimal! Config file not found at /root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
(RayWorkerWrapper pid=1057, ip=10.96.37.211) INFO 01-02 02:39:09 model_runner.py:1099] Loading model weights took 20.5795 GB [repeated 7x across cluster]
WARNING 01-02 02:39:15 fused_moe.py:374] Using default MoE config. Performance might be sub-optimal! Config file not found at /root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json
(RayWorkerWrapper pid=1054, ip=10.96.62.52) INFO 01-02 02:39:17 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20250102-023917.pkl...
(RayWorkerWrapper pid=1053, ip=10.96.18.24) INFO 01-02 02:39:17 worker.py:241] Memory profiling takes 8.36 seconds
(RayWorkerWrapper pid=1053, ip=10.96.18.24) INFO 01-02 02:39:17 worker.py:241] the current vLLM instance can use total_gpu_memory (23.64GiB) x gpu_memory_utilization (0.98) = 23.17GiB
(RayWorkerWrapper pid=1053, ip=10.96.18.24) INFO 01-02 02:39:17 worker.py:241] model weights take 20.58GiB; non_torch_memory takes 0.23GiB; PyTorch activation peak memory takes 0.40GiB; the rest of the memory reserved for KV Cache is 1.96GiB.
(RayWorkerWrapper pid=1054, ip=10.96.62.52) INFO 01-02 02:39:17 model_runner_base.py:149] Completed writing input of failed execution to /tmp/err_execute_model_input_20250102-023917.pkl.
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] Error executing method determine_num_available_blocks. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] Traceback (most recent call last):
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return func(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1747, in execute_model
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     output: SamplerOutput = self.model.sample(
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/models/deepseek_v3.py", line 546, in sample
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     next_tokens = self.sampler(logits, sampling_metadata)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return self._call_impl(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return forward_call(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 274, in forward
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 392, in _apply_top_k_top_p
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 380.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 14.81 MiB is free. Process 258800 has 23.62 GiB memory in use. Of the allocated memory 22.83 GiB is allocated by PyTorch, and 162.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] 
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] The above exception was the direct cause of the following exception:
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] 
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] Traceback (most recent call last):
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 459, in execute_method
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return executor(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return func(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/worker.py", line 202, in determine_num_available_blocks
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     self.model_runner.profile_run()
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return func(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1331, in profile_run
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     return func(*args, **kwargs)
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467]     raise type(err)(
(RayWorkerWrapper pid=1054, ip=10.96.62.52) ERROR 01-02 02:39:17 worker_base.py:467] torch.OutOfMemoryError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250102-023917.pkl): CUDA out of memory. Tried to allocate 380.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 14.81 MiB is free. Process 258800 has 23.62 GiB memory in use. Of the allocated memory 22.83 GiB is allocated by PyTorch, and 162.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
INFO 01-02 02:39:18 worker.py:241] Memory profiling takes 9.27 seconds
INFO 01-02 02:39:18 worker.py:241] the current vLLM instance can use total_gpu_memory (23.64GiB) x gpu_memory_utilization (0.98) = 23.17GiB
INFO 01-02 02:39:18 worker.py:241] model weights take 16.96GiB; non_torch_memory takes 0.22GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 5.62GiB.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/.pylibs/bin/vllm", line 8, in <module>
[rank0]:     sys.exit(main())
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/scripts.py", line 201, in main
[rank0]:     args.dispatch_function(args)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/scripts.py", line 42, in serve
[rank0]:     uvloop.run(run_server(args))
[rank0]:   File "/usr/local/lib/python3.10/site-packages/uvloop/__init__.py", line 82, in run
[rank0]:     return loop.run_until_complete(wrapper())
[rank0]:   File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
[rank0]:   File "/usr/local/lib/python3.10/site-packages/uvloop/__init__.py", line 61, in wrapper
[rank0]:     return await main
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 740, in run_server
[rank0]:     async with build_async_engine_client(args) as engine_client:
[rank0]:   File "/usr/local/lib/python3.10/contextlib.py", line 199, in __aenter__
[rank0]:     return await anext(self.gen)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 118, in build_async_engine_client
[rank0]:     async with build_async_engine_client_from_engine_args(
[rank0]:   File "/usr/local/lib/python3.10/contextlib.py", line 199, in __aenter__
[rank0]:     return await anext(self.gen)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 151, in build_async_engine_client_from_engine_args
[rank0]:     engine_client = build_engine()
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 707, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 594, in __init__
[rank0]:     self.engine = self._engine_class(*args, **kwargs)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 267, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 276, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 416, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/executor/distributed_gpu_executor.py", line 39, in determine_num_available_blocks
[rank0]:     num_blocks = self._run_workers("determine_num_available_blocks", )
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/executor/ray_gpu_executor.py", line 413, in _run_workers
[rank0]:     ray_worker_outputs = ray.get(ray_worker_outputs)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/ray/_private/worker.py", line 2753, in get
[rank0]:     values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/ray/_private/worker.py", line 904, in get_objects
[rank0]:     raise value.as_instanceof_cause()
[rank0]: ray.exceptions.RayTaskError(OutOfMemoryError): ray::RayWorkerWrapper.execute_method() (pid=1054, ip=10.96.62.52, actor_id=5871d8424c552844053512e201000000, repr=<vllm.executor.ray_utils.RayWorkerWrapper object at 0x7f80da8db550>)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1747, in execute_model
[rank0]:     output: SamplerOutput = self.model.sample(
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/models/deepseek_v3.py", line 546, in sample
[rank0]:     next_tokens = self.sampler(logits, sampling_metadata)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 274, in forward
[rank0]:     logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 392, in _apply_top_k_top_p
[rank0]:     logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 380.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 14.81 MiB is free. Process 258800 has 23.62 GiB memory in use. Of the allocated memory 22.83 GiB is allocated by PyTorch, and 162.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: ray::RayWorkerWrapper.execute_method() (pid=1054, ip=10.96.62.52, actor_id=5871d8424c552844053512e201000000, repr=<vllm.executor.ray_utils.RayWorkerWrapper object at 0x7f80da8db550>)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 468, in execute_method
[rank0]:     raise e
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 459, in execute_method
[rank0]:     return executor(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/worker.py", line 202, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1331, in profile_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/root/.pylibs/lib/python3.10/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
[rank0]:     raise type(err)(
[rank0]: torch.OutOfMemoryError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250102-023917.pkl): CUDA out of memory. Tried to allocate 380.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 14.81 MiB is free. Process 258800 has 23.62 GiB memory in use. Of the allocated memory 22.83 GiB is allocated by PyTorch, and 162.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
(RayWorkerWrapper pid=2032) WARNING 01-02 02:39:15 fused_moe.py:374] Using default MoE config. Performance might be sub-optimal! Config file not found at /root/.pylibs/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json [repeated 30x across cluster]
(RayWorkerWrapper pid=2032) INFO 01-02 02:39:18 worker.py:241] Memory profiling takes 9.26 seconds [repeated 29x across cluster]
(RayWorkerWrapper pid=2032) INFO 01-02 02:39:18 worker.py:241] the current vLLM instance can use total_gpu_memory (23.64GiB) x gpu_memory_utilization (0.98) = 23.17GiB [repeated 29x across cluster]
(RayWorkerWrapper pid=2032) INFO 01-02 02:39:18 worker.py:241] model weights take 16.97GiB; non_torch_memory takes 0.46GiB; PyTorch activation peak memory takes 0.38GiB; the rest of the memory reserved for KV Cache is 5.36GiB. [repeated 29x across cluster]
[rank0]:[W102 02:39:20.918569899 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
/usr/local/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Additional context

  1. I have already consulted ChatGPT, DeepSeek, and other large language models and tried various parameters (different context lengths, with and without eager mode), but the issue persists.
  2. Using ray status, I can see that all 32 GPUs across the four nodes are detected, and they are indeed utilized after running the script.
  3. I followed the Distributed Inference and Serving instructions, and the same setup works for Qwen 2.5 72B on two nodes of 4090s.
@GeeeekExplorer
Contributor

We haven't tried running it on 32 RTX 4090 GPUs. You might try reducing gpu_memory_utilization.

@tianyunzqs

I tried it on 8x80GB (H800) GPUs, and it also goes OOM.

vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --gpu-memory-utilization 0.95 \
    --max-model-len 128 \
    --tensor-parallel-size 8
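
A quick capacity check suggests why, assuming ~671B parameters at about one byte each in FP8:

# 8 GPUs x 80 GB = 640 GB of total GPU memory, but the FP8 checkpoint alone is
# roughly 671 GB, so the weights do not fit even before any KV cache is allocated.
echo "$(( 8 * 80 )) GB total GPU memory vs ~671 GB of FP8 weights"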

@aisensiy
Author

aisensiy commented Jan 3, 2025

> We haven't tried running it on 32 RTX 4090 GPUs. You might try reducing gpu_memory_utilization.

Changing it does not help; it still goes OOM.

@GeeeekExplorer
Contributor

> We haven't tried running it on 32 RTX 4090 GPUs. You might try reducing gpu_memory_utilization.
>
> Changing it does not help; it still goes OOM.

The model has 61 layers, which is not divisible by 4; try specifying VLLM_PP_LAYER_PARTITION="16,15,15,15".
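
For example, applied to the command from the issue (a sketch: only the environment variable is new, the other flags are unchanged; with a Ray cluster the variable may need to be set on every node so that all pipeline ranks see the same partition):

export VLLM_PP_LAYER_PARTITION="16,15,15,15"   # 16+15+15+15 = 61 layers across the 4 pipeline stages
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --host 0.0.0.0 --port $PORT \
    --gpu-memory-utilization 0.98 \
    --max-model-len 128 \
    --tensor-parallel-size 8 --pipeline-parallel-size 4 --enforce-eager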

@aisensiy
Author

aisensiy commented Jan 4, 2025

> We haven't tried running it on 32 RTX 4090 GPUs. You might try reducing gpu_memory_utilization.
>
> Changing it does not help; it still goes OOM.
>
> The model has 61 layers, which is not divisible by 4; try specifying VLLM_PP_LAYER_PARTITION="16,15,15,15".

Thanks for the reply, but I still get OOM...

@youkaichao

@aisensiy

> logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

The OOM happens during sampling, whose memory use is proportional to --max-num-seqs.

> Loading model weights took 20.5795 GB

Your GPU has 24 GiB of memory, and the model weights alone already take 20.5 GiB, which leaves very little headroom.

Suggestion:

Try with --max-model-len 128 --max-num-seqs 1 to see if you can run the model.

It is also possible to reduce the model weight memory by using --cpu-offload-gb, but I'm not sure whether it is compatible with the DeepSeek V3 model.
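
For reference, a minimal sketch of that sanity check, reusing the command from the issue with --max-num-seqs 1 added (all other flags unchanged):

vllm serve deepseek-ai/DeepSeek-V3 \
    --trust-remote-code \
    --host 0.0.0.0 --port $PORT \
    --gpu-memory-utilization 0.98 \
    --max-model-len 128 \
    --max-num-seqs 1 \
    --tensor-parallel-size 8 --pipeline-parallel-size 4 --enforce-eager

If --cpu-offload-gb turns out to be compatible, appending it (for example --cpu-offload-gb 4, a purely illustrative value) would further shrink the per-GPU weight footprint at the cost of throughput.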

I'm preparing a blog post that explains the memory footprint; please stay tuned.
