Enable CC_METHOD for dsv3 by default && fix test script && fix tgi stream api (#732)

Co-authored-by: shihaobai <[email protected]>
shihaobai authored Feb 17, 2025
1 parent c07e3a2 commit 250d7ad
Showing 5 changed files with 23 additions and 17 deletions.
@@ -58,7 +58,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
if mscale_all_dim:
mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim)
self.softmax_scale = self.softmax_scale * mscale * mscale
- self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
+ self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
if self.enable_dp:
@@ -66,7 +66,7 @@ def fuse_vb_o(self, layer_weight):
class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
- self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
+ self.enable_cc_method = not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
return

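Both hunks above make the same change: the CC method is now enabled unless it is explicitly switched off, so the environment flag callers set has flipped from ENABLE_CC_METHOD to DISABLE_CC_METHOD. A minimal sketch of the resulting behavior (the standalone helper below is illustrative only; the real code assigns the flag inline exactly as shown in the diff):

import os

def cc_method_enabled() -> bool:
    # Mirrors the changed lines: default is on; only an explicit
    # DISABLE_CC_METHOD=ON/TRUE/1 turns the CC method off.
    return not os.getenv("DISABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]

assert cc_method_enabled()            # nothing set -> enabled by default

os.environ["DISABLE_CC_METHOD"] = "1"
assert not cc_method_enabled()        # explicit opt-out -> disabled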
2 changes: 1 addition & 1 deletion lightllm/server/api_tgi.py
@@ -9,7 +9,7 @@
import json


- def format_tgi_params(params, num_beam: int):
+ def format_tgi_params(params, num_beam: int = 1):
"""
tgi params format -> lightllm server params format
pub(crate) struct GenerateParameters {
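The TGI-compatible streaming endpoint presumably calls format_tgi_params without a beam count, so giving num_beam a default of 1 is what fixes the stream API mentioned in the commit title. A tiny sketch of the pattern (the body is a placeholder; only the default argument is taken from the diff):

def format_tgi_params(params: dict, num_beam: int = 1) -> dict:
    # Placeholder body: the real function maps TGI GenerateParameters
    # onto lightllm server sampling parameters.
    converted = dict(params)
    converted["num_beam"] = num_beam
    return converted

# Callers that omit num_beam (e.g. a streaming request handler) now work
# instead of raising "missing 1 required positional argument: 'num_beam'".
print(format_tgi_params({"max_new_tokens": 16}))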
31 changes: 18 additions & 13 deletions test/model/model_infer.py
@@ -38,16 +38,20 @@ def test_model_inference(world_size, model_class, batch_size, input_len, output_

def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_len, ans_queue):
import torch
- from lightllm.distributed import set_custom_reduce
+ from lightllm.distributed import custom_comm_ops
+ from lightllm.utils.device_utils import set_current_device_id

import torch.distributed as dist

rank_id = model_kvargs["tp_rank"]
world_size = model_kvargs["world_size"]

torch.cuda.set_device(rank_id)

+ set_current_device_id(rank_id)
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size)
- set_custom_reduce()

+ custom_comm_ops.set_custom_reduce()
+ custom_comm_ops.set_custom_gather()
dist.barrier()

torch.cuda.empty_cache()
@@ -59,7 +63,9 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
test_data = test_data.reshape(-1)
test_data = torch.from_numpy(test_data).cuda()

- b_req_idx = model_part.req_manager.alloc(batch_size).int()
+ b_req_idx = torch.tensor(
+     [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+ )
b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
@@ -68,7 +74,8 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
b_seq_len[i] = input_len

total_token_num = input_len * batch_size
- mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+ mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()

logics = model_part.forward(
batch_size,
total_token_num,
@@ -89,7 +96,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
- mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+ mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
logics = model_part.forward(
batch_size,
total_token_num,
@@ -108,10 +115,6 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
model_part.mem_manager.free_all()
model_part.req_manager.free_all()

- if rank_id == 0:
-     print("can use mem size:", model_part.mem_manager.can_use_mem_size)
-     print("can use req size:", model_part.req_manager.can_use_req_size)

b_req_idx = None
b_start_loc = None
b_seq_len = None
@@ -124,15 +127,17 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_

prefill_start_time = time.time()

- b_req_idx = model_part.req_manager.alloc(batch_size).int()
+ b_req_idx = torch.tensor(
+     [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+ )
b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
for i in range(batch_size):
b_start_loc[i] = i * input_len
b_seq_len[i] = input_len

total_token_num = batch_size * input_len
- mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
+ mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
logics = model_part.forward(
batch_size,
total_token_num,
@@ -159,7 +164,7 @@ def tppart_model_infer(model_class, model_kvargs, batch_size, input_len, output_
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
- mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
+ mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
logics = model_part.forward(
batch_size,
total_token_num,
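Taken together, the model_infer.py hunks give the test a new setup and allocation flow: the device is registered through set_current_device_id, the custom reduce and gather ops come from the custom_comm_ops namespace, request slots are allocated one at a time into an int32 CUDA tensor, and the memory indexes returned by the KV manager are moved to the GPU. A condensed sketch assembled from the diff above (the helper function and its argument list are illustrative; the individual call signatures are copied from the changed lines):

import torch
import torch.distributed as dist
from lightllm.distributed import custom_comm_ops
from lightllm.utils.device_utils import set_current_device_id

def setup_and_alloc(model_part, model_kvargs, batch_size, input_len, test_data):
    # Illustrative condensation of the updated test flow, not the full script.
    rank_id = model_kvargs["tp_rank"]
    world_size = model_kvargs["world_size"]

    torch.cuda.set_device(rank_id)
    set_current_device_id(rank_id)
    dist.init_process_group(
        "nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size
    )

    # The bare set_custom_reduce() call is replaced by the custom_comm_ops
    # namespace, which now also registers a custom all-gather.
    custom_comm_ops.set_custom_reduce()
    custom_comm_ops.set_custom_gather()
    dist.barrier()

    # Request slots are allocated one at a time and packed into an int32
    # CUDA tensor instead of a single alloc(batch_size) call.
    b_req_idx = torch.tensor(
        [model_part.req_manager.alloc() for _ in range(batch_size)],
        dtype=torch.int32,
        device="cuda",
    )

    # KV-cache token indexes are moved onto the GPU explicitly.
    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]).cuda()
    return b_req_idx, mem_indexes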
3 changes: 2 additions & 1 deletion test/model/test_model.py
@@ -25,6 +25,7 @@
from lightllm.utils.config_utils import get_dtype
from lightllm.utils.config_utils import get_config_json


def get_model(weight_dir):
model_cfg = get_config_json(weight_dir)
model_type = model_cfg["model_type"]
@@ -68,7 +69,7 @@ def get_model(weight_dir):

class TestModelInfer(unittest.TestCase):
def test_model_infer(self):
model_dir = "/nvme/ci_performance/models/DeepSeek-V2-Lite-Chat/"
model_dir = "/nvme/models/llama3/Meta-Llama-3-8B/"
model_class = get_model(model_dir)
data_type = get_dtype(model_dir)
mode = "triton_gqa_flashdecoding"
