
[Bug] pipeline hangs indefinitely while loading the model, but command-line deployment works #3107

Open
3 tasks done
NB-Group opened this issue Jan 31, 2025 · 2 comments

NB-Group commented Jan 31, 2025

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.

Describe the bug

Loading the deepseek-r1-distill-qwen-7b-gptq-int4 model with pipeline gets stuck, but deploying from the command line works fine.
I wrote "hangs" in the title because it really does hang:

Image

Reproduction

Here is the problematic code:
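(For context: the snippet assumes the usual top-level lmdeploy imports earlier in the module, i.e. from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig.)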

        self.backend_config = TurbomindEngineConfig(dtype='auto', model_format='gptq', tp=1, session_len=131072,
                                                    max_batch_size=1, cache_max_entry_count=0.8, cache_chunk_size=-1,
                                                    cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0,
                                                    rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None,
                                                    revision=None, max_prefill_token_num=8192, num_tokens_per_iter=0,
                                                    max_prefill_iters=1)
        self.gen_config = GenerationConfig(top_p=0.8,
                                           top_k=40,
                                           temperature=0.8)
        self.pipe = pipeline(config["LLM_MODEL_PATH"],
                             backend_config=self.backend_config, log_level="INFO")

Command-line deployment, on the other hand, works fine:

Models >lmdeploy chat .\deepseek-r1-distill-qwen-7b-gptq-int4-turbomind\ --model-format gptq
Add dll path C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin, please note cuda version should >= 11.3 when compiled with cuda 11
2025-01-31 17:11:14,567 - lmdeploy - WARNING - supported_models.py:106 - .\deepseek-r1-distill-qwen-7b-gptq-int4-turbomind\ seems to be a turbomind workspace, which can only be ran with turbomind engine.
chat_template_config:
ChatTemplateConfig(model_name='deepseek-r1', system=None, meta_instruction=None, eosys=None, user=None, eoh=None, assistant=None, eoa=None, tool=None, eotool=None, separator=None, capability='chat', stop_words=None)
engine_cfg:
TurbomindEngineConfig(dtype='auto', model_format='gptq', tp=1, session_len=131072, max_batch_size=1, cache_max_entry_count=0.8, cache_chunk_size=-1, cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0, rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None, revision=None, max_prefill_token_num=8192, num_tokens_per_iter=0, max_prefill_iters=1)
[WARNING] gemm_config.in is not found; using default GEMM algo

double enter to end input >>> 你好

<|begin▁of▁sentence|><|User|>你好<|Assistant|><think>

</think>

你好!很高兴见到你,有什么我可以帮忙的吗?

Environment

sys.platform: win32
Python: 3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0: NVIDIA GeForce RTX 4060 Laptop GPU
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8
NVCC: Cuda compilation tools, release 12.8, V12.8.61
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.42.34436 for x64
GCC: n/a
PyTorch: 2.6.0+cu126
PyTorch compiling details: PyTorch built with:
  - C++ Version: 201703
  - MSVC 192930157
  - Intel(R) oneAPI Math Kernel Library Version 2025.0.1-Product Build 20241031 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.5.3 (Git Hash 66f0cb9eb66affd2da3bf5f8d897376f04aae6af)
  - OpenMP 2019
  - LAPACK is enabled (usually provided by MKL)
  - CPU capability usage: AVX2
  - CUDA Runtime 12.6
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 90.5.1
  - Magma 2.5.4
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, COMMIT_SHA=2236df1770800ffea5697b11b0bb0d910b2e59e1, CUDA_VERSION=12.6, CUDNN_VERSION=9.5.1, CXX_COMPILER=C:/actions-runner/_work/pytorch/pytorch/pytorch/.ci/pytorch/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /Zc:__cplusplus /bigobj /FS /utf-8 -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DLIBKINETO_NOXPUPTI=ON -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE /wd4624 /wd4068 /wd4067 /wd4267 /wd4661 /wd4717 /wd4244 /wd4804 /wd4273, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, TORCH_VERSION=2.6.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

TorchVision: 0.21.0+cu126
LMDeploy: 0.7.0.post2+
transformers: 4.48.2
gradio: Not Found
fastapi: 0.115.8
pydantic: 2.10.6
triton: Not Found

Error traceback

Here is the log:

C:\Users\NB_Group\.conda\envs\MOSS\python.exe F:\Code\Python\MOSS_Ultra\tests\test_language_model.py 
Loading language model...
Add dll path C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin, please note cuda version should >= 11.3 when compiled with cuda 11
2025-01-31 17:30:52,472 - lmdeploy - WARNING - supported_models.py:106 - G:\Models\deepseek-r1-distill-qwen-7b-gptq-int4-turbomind seems to be a turbomind workspace, which can only be ran with turbomind engine.
2025-01-31 17:30:52,472 - lmdeploy - INFO - api.py:81 - Using turbomind engine
2025-01-31 17:30:52,472 - lmdeploy - INFO - async_engine.py:259 - input backend=turbomind, backend_config=TurbomindEngineConfig(dtype='auto', model_format='gptq', tp=1, session_len=131072, max_batch_size=1, cache_max_entry_count=0.8, cache_chunk_size=-1, cache_block_seq_len=64, enable_prefix_caching=False, quant_policy=0, rope_scaling_factor=0.0, use_logn_attn=False, download_dir=None, revision=None, max_prefill_token_num=8192, num_tokens_per_iter=0, max_prefill_iters=1)
2025-01-31 17:30:52,472 - lmdeploy - INFO - async_engine.py:260 - input chat_template_config=None
2025-01-31 17:30:52,499 - lmdeploy - INFO - async_engine.py:269 - updated chat_template_onfig=ChatTemplateConfig(model_name='deepseek-r1', system=None, meta_instruction=None, eosys=None, user=None, eoh=None, assistant=None, eoa=None, tool=None, eotool=None, separator=None, capability=None, stop_words=None)
2025-01-31 17:30:53,896 - lmdeploy - INFO - turbomind.py:282 - model_source: workspace
2025-01-31 17:30:53,902 - lmdeploy - INFO - turbomind.py:190 - turbomind model config:

{
  "model_config": {
    "model_name": "deepseek",
    "chat_template": "deepseek-r1",
    "model_arch": "Qwen2ForCausalLM",
    "head_num": 28,
    "kv_head_num": 4,
    "hidden_units": 3584,
    "vocab_size": 152064,
    "embedding_size": 152064,
    "num_layer": 28,
    "inter_size": [
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944,
      18944
    ],
    "norm_eps": 1e-06,
    "attn_bias": 1,
    "start_id": 151643,
    "end_id": 151645,
    "size_per_head": 128,
    "group_size": 128,
    "weight_type": "int4",
    "session_len": 131072,
    "tp": 1,
    "model_format": "gptq",
    "expert_num": [],
    "expert_inter_size": 0,
    "experts_per_token": 0,
    "moe_shared_gate": false,
    "norm_topk_prob": false,
    "routed_scale": 1.0,
    "topk_group": 1,
    "topk_method": "greedy",
    "moe_group_num": 1,
    "q_lora_rank": 0,
    "kv_lora_rank": 0,
    "qk_rope_dim": 0,
    "v_head_dim": 0,
    "tune_layer_num": 1
  },
  "attention_config": {
    "rotary_embedding": 128,
    "rope_theta": 10000.0,
    "softmax_scale": 0.0,
    "attention_factor": -1.0,
    "max_position_embeddings": 131072,
    "original_max_position_embeddings": 0,
    "rope_scaling_type": "",
    "rope_scaling_factor": 0.0,
    "use_dynamic_ntk": 0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 1.0,
    "beta_fast": 32.0,
    "beta_slow": 1.0,
    "use_logn_attn": 0,
    "cache_block_seq_len": 64
  },
  "lora_config": {
    "lora_policy": "",
    "lora_r": 0,
    "lora_scale": 0.0,
    "lora_max_wo_r": 0,
    "lora_rank_pattern": "",
    "lora_scale_pattern": ""
  },
  "engine_config": {
    "dtype": "auto",
    "model_format": "gptq",
    "tp": 1,
    "session_len": 131072,
    "max_batch_size": 1,
    "cache_max_entry_count": 0.8,
    "cache_chunk_size": -1,
    "cache_block_seq_len": 64,
    "enable_prefix_caching": false,
    "quant_policy": 0,
    "rope_scaling_factor": 0.0,
    "use_logn_attn": false,
    "download_dir": null,
    "revision": null,
    "max_prefill_token_num": 8192,
    "num_tokens_per_iter": 8192,
    "max_prefill_iters": 16
  }
}
[TM][WARNING] [LlamaTritonModel] `max_context_token_num` is not set, default to 131072.
[TM][INFO] Barrier(1)
[TM][INFO] Model: 
head_num: 28
kv_head_num: 4
size_per_head: 128
num_layer: 28
vocab_size: 152064
attn_bias: 1
max_batch_size: 1
max_prefill_token_num: 8192
max_context_token_num: 131072
num_tokens_per_iter: 8192
max_prefill_iters: 16
session_len: 131072
cache_max_entry_count: 0.8
cache_block_seq_len: 64
cache_chunk_size: -1
enable_prefix_caching: 0
start_id: 151643
tensor_para_size: 1
pipeline_para_size: 1
enable_custom_all_reduce: 0
model_name: deepseek
model_dir: G:\Models\deepseek-r1-distill-qwen-7b-gptq-int4-turbomind\triton_models\weights
quant_policy: 0
group_size: 128
expert_per_token: 0
moe_method: 1

[TM][INFO] TM_FUSE_SILU_ACT=1
[TM][INFO] [LlamaWeight<T>::prepare] workspace size: 271581184

[WARNING] gemm_config.in is not found; using default GEMM algo
[TM][INFO] [BlockManager] block_size = 3 MB
[TM][INFO] [BlockManager] max_block_count = 346
[TM][INFO] [BlockManager] chunk_size = 346
[TM][WARNING] No enough blocks for `session_len` (131072), `session_len` truncated to 22144.
[TM][INFO] LlamaBatch<T>::Start()
[TM][INFO] [Gemm2] Tuning sequence: 8, 16, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 6144, 8192
[TM][INFO] [Gemm2] 8
[TM][INFO] [Gemm2] 16
[TM][INFO] [Gemm2] 32
[TM][INFO] [Gemm2] 48
[TM][INFO] [Gemm2] 64
[TM][INFO] [Gemm2] 96
[TM][INFO] [Gemm2] 128
[TM][INFO] [Gemm2] 192
[TM][INFO] [Gemm2] 256
[TM][INFO] [Gemm2] 384
[TM][INFO] [Gemm2] 512
[TM][INFO] [Gemm2] 768
[TM][INFO] [Gemm2] 1024
[TM][INFO] [Gemm2] 1536
[TM][INFO] [Gemm2] 2048
[TM][INFO] [Gemm2] 3072
[TM][INFO] [Gemm2] 4096
[TM][INFO] [Gemm2] 6144
[TM][INFO] [Gemm2] 8192

Then it just stays stuck here.

lzhangzz (Collaborator) commented Feb 1, 2025

It looks like it may have run out of memory during tuning; the OOM exception handling there is not very robust, so it can hang without reporting any error. I suggest trying lower values for cache_max_entry_count and max_prefill_token_num.

You can start with cache_max_entry_count=0.2 and max_prefill_token_num=1024.
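For reference, a minimal sketch of how those two changes could be applied to the snippet from the report (config["LLM_MODEL_PATH"] is carried over from the original code, the other original arguments are omitted here for brevity, and 0.2 / 1024 are only the suggested starting values, not tuned settings):

    from lmdeploy import pipeline, TurbomindEngineConfig

    backend_config = TurbomindEngineConfig(
        dtype='auto',
        model_format='gptq',
        tp=1,
        max_batch_size=1,
        cache_max_entry_count=0.2,    # smaller KV-cache fraction leaves memory headroom during GEMM tuning
        max_prefill_token_num=1024,   # smaller prefill chunks reduce peak memory
    )
    pipe = pipeline(config["LLM_MODEL_PATH"],
                    backend_config=backend_config, log_level='INFO')

If this still hangs, the two values can be lowered further in the same way.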


NB-Group (Author) commented Feb 1, 2025

Oh! Got it, thank you for your answer!
