Support automatically calculating max_total_token_num #81
Open
singularity-s0 wants to merge 5 commits into ModelTC:main from singularity-s0:auto_max_token_num
+68 −2
Changes from 4 commits
Commits (5):
3f6ccda  support auto calculate total token num  (singularity-s0)
e68926a  fix cuda multiprocessing error  (singularity-s0)
a9359d9  bug fix  (singularity-s0)
a1ef209  update kv_cache_size calculation  (singularity-s0)
7e35330  Merge pull request #1 from singularity-s0/main  (singularity-s0)
@@ -0,0 +1,59 @@
import torch
torch.multiprocessing.set_start_method('spawn', force=True)  # the fork start method causes a CUDA re-initialization error
import os
import json


def get_total_free_gpu_memory(tp):
    """
    Returns the total amount of free memory available on all GPUs, in gigabytes.
    """
    devices = min(tp, torch.cuda.device_count())
    total_free = 0
    for i in range(devices):
        total_free += torch.cuda.mem_get_info(i)[0]  # mem_get_info returns (free, total) in bytes
    total_free = total_free / (1024 ** 3)
    return total_free


def get_total_weight_size(weight_dir):
    """
    Returns the total size of all parameters in the model, in gigabytes.
    """
    total_size = 0
    files = os.listdir(weight_dir)
    candidate_files = [f for f in files if f.endswith('.safetensors')]
    if len(candidate_files) == 0:
        candidate_files = [f for f in files if f.endswith('.bin')]
    assert len(candidate_files) != 0, "only pytorch .bin and safetensors weight formats are supported."
    for file in candidate_files:
        total_size += os.path.getsize(os.path.join(weight_dir, file))
    total_size = total_size / (1024 ** 3)
    return total_size


def get_kv_cache_size(model_dir):
    """
    Returns the size of the KV cache for a single token, in gigabytes.
    """
    # Read from config.json
    config_path = os.path.join(model_dir, 'config.json')
    assert os.path.exists(config_path), "config.json not found in model directory."
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
        hidden_size = config['hidden_size']
        layer_num = config['num_hidden_layers']
        num_attention_heads = config['num_attention_heads']
        num_key_value_heads = config.get('num_key_value_heads', num_attention_heads)  # models may not be using GQA
        dtype = config.get('torch_dtype', 'float16')  # TODO: dtype may not be specified in config.json; should we load the weights to check?
    except Exception:
        raise Exception("Error reading config.json when trying to determine max_total_token_num. Please manually specify max_total_token_num in the startup arguments.")
    dtype_size = torch.empty(0, dtype=getattr(torch, dtype)).element_size()
    # per token: 2 (K and V) * layer_num * num_key_value_heads * head_dim * dtype_size, where head_dim = hidden_size / num_attention_heads
    kv_cache_size = hidden_size * dtype_size * 2 * layer_num / num_attention_heads * num_key_value_heads / (1024 ** 3)
    return kv_cache_size


def calc_max_total_token_num(tp, weight_dir, mem_fill_rate=0.8):
    """
    Calculate the max total token num that can be supported by the model.
    """
    kv_cache_size = get_kv_cache_size(weight_dir)
    max_token_num = (get_total_free_gpu_memory(tp) - get_total_weight_size(weight_dir)) * mem_fill_rate / kv_cache_size
    return int(max_token_num)
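For a sense of the numbers involved, here is a small usage sketch. The module name and all figures are illustrative assumptions (a Llama-2-7B-style config with hidden_size 4096, 32 layers, 32 attention heads, no GQA, float16), not values taken from this PR.

# Illustrative usage; "auto_max_token_num" is a made-up module name for the file added in this PR.
from auto_max_token_num import calc_max_total_token_num

# Llama-2-7B-style config: per-token KV cache = 4096 * 2 bytes * 2 (K and V) * 32 layers
#                                             = 524288 bytes ~= 0.00049 GB
# With, say, 24 GB free on one GPU and ~13 GB of fp16 weights:
#   (24 - 13) * 0.8 / 0.00049 ~= 18000 tokens
print(calc_max_total_token_num(tp=1, weight_dir="/models/llama-2-7b"))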
Review conversation:

(Inline comment on the config.json key lookups in get_kv_cache_size) @singularity-s0 This code may not be very robust when the key name in config.json changes.

"get_kv_cache_size and xxxx" is best implemented as a member function of TpPartBaseModel and should be inherited and implemented by its subclasses.
It seems that max_total_token_num (and batch_max_tokens that depends on it) gets passed to a lot of subsystems before the model is initialized. We need this value to be ready early. Is there any way to achieve this if implemented as a member function of TpPartBaseModel?
@singularity-s0 You can try to add a method in TpPartBaseModel, but it is not easy to get and set batch_max_tokens in TpPartBaseModel. Let me think about how to implement it elegantly. What are your suggestions?
Ideally, since each instance of LightLLM server is bound to only one model, model configuration can (and should) be loaded before all other subsystems are initialized (because other subsystems may depend on model configuration, as in the case of max_total_token_num). A refactor would be the most elegant way to address this.

Other parameters like max_req_total_len and dtype (which is currently hardcoded to fp16) might also be dependent on the model's config.json and would benefit from this refactor.

However, I imagine such a refactor would not be easy. Hacky solutions are also available, but it is ultimately up to you to decide which way is the best.
@singularity-s0 You can write a standalone recommendation program to generate a value for max_total_token_num. That will be more appropriate.
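A minimal sketch of such a standalone recommendation program, assuming the functions from this PR are importable (the module name and CLI flag names here are made up):

# Hypothetical standalone helper; "auto_max_token_num" and the flags are illustrative only.
import argparse
from auto_max_token_num import calc_max_total_token_num

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Recommend a max_total_token_num value")
    parser.add_argument("--model_dir", required=True, help="directory containing config.json and the weights")
    parser.add_argument("--tp", type=int, default=1, help="tensor parallel size")
    parser.add_argument("--mem_fill_rate", type=float, default=0.8, help="fraction of free memory to reserve for the KV cache")
    args = parser.parse_args()
    value = calc_max_total_token_num(args.tp, args.model_dir, args.mem_fill_rate)
    print(f"Recommended max_total_token_num: {value}")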