
[Bugfix] Command-R Max Model Length #3727

Merged · 4 commits into vllm-project:main · Mar 29, 2024

Conversation

ywang96 (Member) commented Mar 29, 2024

Currently, the max context window for CohereForAI/c4ai-command-r-v01 is not defined by max_position_embeddings but by a special model_max_length key instead. This has been discussed in these two threads: 1, 2.

We still use max_position_embeddings as the default max_model_len due to memory concerns, but when the user specifies a value higher than max_position_embeddings and lower than or equal to model_max_length, we now allow it to go through.
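
Roughly, the intended resolution logic is the sketch below. This is a standalone approximation for illustration only; the function name, argument names, and structure are made up and are not the actual diff in this PR.

```python
from types import SimpleNamespace


def resolve_max_model_len(hf_config, user_specified_len=None):
    """Rough sketch of the intended behavior; not vLLM's actual implementation."""
    # Default context window: max_position_embeddings (8192 for Command-R).
    default_len = getattr(hf_config, "max_position_embeddings", None)
    # True context window, when the config provides it (131072 for Command-R).
    hard_limit = getattr(hf_config, "model_max_length", None)

    if user_specified_len is None:
        # Keep the conservative default so we don't allocate KV cache for 128k tokens.
        return default_len

    upper_bound = hard_limit if hard_limit is not None else default_len
    if user_specified_len > upper_bound:
        raise ValueError(
            f"User-specified max_model_len ({user_specified_len}) is greater "
            f"than the derived max_model_len ({upper_bound}).")

    # Values above max_position_embeddings but within model_max_length are
    # allowed, at the user's own risk.
    return user_specified_len


# Command-R-style config values, for illustration:
cfg = SimpleNamespace(max_position_embeddings=8192, model_max_length=131072)
print(resolve_max_model_len(cfg))          # 8192 (default)
print(resolve_max_model_len(cfg, 131072))  # 131072 (allowed)
```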

This PR fixes #3676

cc @saurabhdash I'm not sure if there's a cleaner/better way to do this but please take a look.

@pseudotensor

cool thanks!

ywang96 requested a review from youkaichao on Mar 29, 2024 at 08:39
@saurabhdash

> Currently, the max context window for CohereForAI/c4ai-command-r-v01 is not defined by max_position_embeddings but by a special model_max_length key instead. […]

Just so that I understand correctly, this allows people to increase the context length up to 128k and throws a warning if it's larger?

esmeetu (Collaborator) commented Mar 29, 2024

A quick question: why not just change the model's max_position_embeddings value in the config.json file?

@pseudotensor

@saurabhdash It's not a warning, it's a fatal raise.

pseudotensor commented Mar 29, 2024

@esmeetu Because that's not normally done for any other models, and it isn't maintainable when pulling weights into a cached location that may be updated. It's also not correct, since RoPE scaling is not the same as just changing the embedding size, AFAIK.

simon-mo added the `release-blocker` label on Mar 29, 2024
ywang96 (Member, Author) commented Mar 29, 2024

> @pseudotensor That makes sense. But this PR is a little bit tricky, and model_max_length is not a common parameter across the open models. Furthermore, if we apply this, it will very likely trigger OOM errors in most users' environments, because a 128k context takes too much memory and seems unnecessary for them. For model cache convenience, I think you could fork that repo and tune the parameter as you need. @simon-mo WDYT?

Just to clarify - we still use 8192 as the default max_model_len if the user doesn't specify one. This PR really just allows users to go above that, up to the true context length, at their own risk.

simon-mo (Collaborator) left a comment


can you put up a quick manual test for this?

ywang96 (Member, Author) commented Mar 29, 2024

> can you put up a quick manual test for this?

@simon-mo Sure, here's a quick test script:

```python
from vllm import LLM
import sys

try:
    # Optionally take a user-specified max_model_len from the command line.
    if len(sys.argv) == 2:
        user_specified = int(sys.argv[-1])
    else:
        user_specified = None
    llm = LLM(model="CohereForAI/c4ai-command-r-v01",
              tensor_parallel_size=4,
              max_model_len=user_specified)
    # Report the max_model_len the engine actually derived.
    print(f"Max Length: {llm.llm_engine.model_config.max_model_len}")
except Exception as e:
    print(e)
```

On A100-80G:

- Default:

  ```
  INFO 03-29 17:41:43 model_runner.py:867] Graph capturing finished in 14 secs.
  Max Length: 8192
  ```

- Specifying max_model_len=131072 with TP4:

  ```
  INFO 03-29 17:51:17 ray_gpu_executor.py:240] # GPU blocks: 6158, # CPU blocks: 819
  The model's max seq len (131072) is larger than the maximum number of tokens that can be stored in KV cache (98528). Try increasing `gpu_memory_utilization` or decreasing `max_model_len` when initializing the engine.
  ```

- Specifying max_model_len=131072 with TP8:

  ```
  INFO 03-29 17:53:38 model_runner.py:867] Graph capturing finished in 14 secs.
  Max Length: 131072
  ```

- Specifying max_model_len=131073:

  ```
  User-specified max_model_len (131073) is greater than the derived max_model_len (max_position_embeddings=8192 or model_max_length=131072 in model's config.json). This may lead to incorrect model outputs or CUDA errors. Make sure the value is correct and within the model context size.
  ```
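
(Side note on the TP4 failure above: assuming vLLM's default KV-cache block size of 16 tokens, the 6158 GPU blocks reported in the log hold 6158 × 16 = 98528 tokens, which is below the requested 131072, hence the error; with TP8 there is enough GPU memory for the full 128k context.)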

There's one small "visual" bug I fixed: the variable max_len_key was used incorrectly.

simon-mo merged commit 97356f3 into vllm-project:main on Mar 29, 2024
34 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 31, 2024
Labels
release-blocker: This PR/issue blocks the next release, therefore deserves highest priority
5 participants