fix group pause #506

Open · wants to merge 1 commit into base: main
26 changes: 19 additions & 7 deletions lightllm/server/router/model_infer/mode_backend/beamsearch/impl.py
@@ -2,7 +2,14 @@
 from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend
 from lightllm.utils.infer_utils import set_random_seed
 from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
-from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferReqGroup, InferSamplingParams, requests_mapping, group_mapping
+from lightllm.server.router.model_infer.infer_batch import (
+    InferBatch,
+    InferReq,
+    InferReqGroup,
+    InferSamplingParams,
+    requests_mapping,
+    group_mapping,
+)
 from lightllm.server.io_struct import ReqRunStatus, FinishStatus
 from lightllm.utils.log_utils import init_logger
 from .pre_process import prepare_prefill_inputs, prepare_decode_inputs
@@ -20,14 +27,15 @@ def prefill_batch(self, batch_id):
     @calculate_time(show=True, min_cost_ms=200)
     def decode_batch(self, batch_id):
         return self.forward(batch_id, is_prefill=False)

     def build_group(self, batch):
         for r_id in batch.request_ids:
             req = requests_mapping[r_id]
             group_req_id = req.group_req_id
             best_of = req.sampling_param.best_of
-            if group_req_id not in group_mapping:
-                group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id, best_of=best_of)
+            # dealing with paused
+            del group_mapping[group_req_id]
+            group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id, best_of=best_of)
             group_mapping[group_req_id].add_req(r_id)

     def forward(self, batch_id, is_prefill):
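For context on why the `if group_req_id not in group_mapping` guard is being replaced: when a paused group request is re-prefilled, `build_group` runs again, and the old guard kept the stale `InferReqGroup`, with the request ids it had already accumulated, instead of rebuilding it. Below is a minimal sketch of that failure mode; the simplified `InferReqGroup` and `build_group_old` here are hypothetical stand-ins for illustration, not lightllm's actual implementation.

    # Hypothetical, simplified stand-ins used only to illustrate the
    # stale-group problem this PR appears to fix.
    class InferReqGroup:
        def __init__(self, group_req_id, best_of):
            self.group_req_id = group_req_id
            self.best_of = best_of
            self.req_ids = []  # requests registered into this group

        def add_req(self, r_id):
            self.req_ids.append(r_id)


    group_mapping = {}


    def build_group_old(request_ids, group_req_id="g0", best_of=2):
        # Old behavior: create the group only if absent, otherwise reuse it.
        for r_id in request_ids:
            if group_req_id not in group_mapping:
                group_mapping[group_req_id] = InferReqGroup(group_req_id, best_of)
            group_mapping[group_req_id].add_req(r_id)


    build_group_old(["r0", "r1"])  # initial prefill
    build_group_old(["r0", "r1"])  # re-prefill after a pause reuses the stale group
    print(group_mapping["g0"].req_ids)  # ['r0', 'r1', 'r0', 'r1'] -- duplicates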
@@ -41,9 +49,13 @@ def forward(self, batch_id, is_prefill):
             kwargs, run_reqs = prepare_decode_inputs(batch, self.radix_cache)

         logits = self.model.forward(**kwargs)
-        next_token_id_groups, next_token_logprob_groups, next_cumlogprob_groups = sample(logits, run_reqs, is_prefill, self.model.vocab_size, self.model.req_manager, self.eos_id)
+        next_token_id_groups, next_token_logprob_groups, next_cumlogprob_groups = sample(
+            logits, run_reqs, is_prefill, self.model.vocab_size, self.model.req_manager, self.eos_id
+        )

-        for req_group_obj, next_token_id_group, next_token_logprob_group, next_cumlogprob_group in zip(run_reqs, next_token_id_groups, next_token_logprob_groups, next_cumlogprob_groups):
+        for req_group_obj, next_token_id_group, next_token_logprob_group, next_cumlogprob_group in zip(
+            run_reqs, next_token_id_groups, next_token_logprob_groups, next_cumlogprob_groups
+        ):
             # prefill and decode is same
             for i in range(req_group_obj.best_of):
                 req_obj = req_group_obj.get_req(i)
@@ -75,7 +87,7 @@ def forward(self, batch_id, is_prefill):
                     req_obj.req_status,
                     req_obj.cur_kv_len,
                     req_obj.get_output_len(),
-                    [], # empty meta
+                    [],  # empty meta
                     0,  # unfinished
                     None,
                 )  # request
@@ -40,8 +40,9 @@ def build_group(self, batch):
             req = requests_mapping[r_id]
             group_req_id = req.group_req_id
             best_of = req.sampling_param.best_of
-            if group_req_id not in group_mapping:
-                group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id, best_of=best_of)
+            # dealing with paused
+            del group_mapping[group_req_id]
+            group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id, best_of=best_of)
             group_mapping[group_req_id].add_req(r_id)

     def forward(self, batch_id, is_prefill):

Collaborator, commenting on the `del group_mapping[group_req_id]` line:
here, how to handle the case "group_req_id is not in group_mapping"? @shihaobai
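One KeyError-safe variant of the new code, addressing the case the reviewer raises: `dict.pop` with a default removes a stale group when one exists and is a no-op when `group_req_id` was never registered. This is a sketch of an alternative, assuming `group_mapping` behaves as a plain dict; it is not the PR's code.

    def build_group(self, batch):
        for r_id in batch.request_ids:
            req = requests_mapping[r_id]
            group_req_id = req.group_req_id
            best_of = req.sampling_param.best_of
            # dealing with paused: drop any stale group, tolerating its absence
            group_mapping.pop(group_req_id, None)  # no KeyError on first build
            group_mapping[group_req_id] = InferReqGroup(group_req_id=group_req_id, best_of=best_of)
            group_mapping[group_req_id].add_req(r_id)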