diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index f9901bb7f..3423b3191 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -219,4 +219,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="""Maximum sequence length that can be captured by the cuda graph for decodign stage.
                 The default value is 8192. It will turn into eagar mode if encounters a larger value. """,
     )
+    parser.add_argument(
+        "--profiler",
+        type=str,
+        choices=["torch_profile", "nvtx"],
+        default=None,
+        help="""Enable profiler support.
+        This exposes the '/profiler_start' and '/profiler_stop' APIs;
+        the profiling features below are only enabled within that range.
+        Options:
+        'torch_profile' sets up torch.profiler.profile(); trace files are saved to './trace'
+        or to the directory given by the 'LIGHTLLM_TRACE_DIR' env variable;
+        'nvtx' adds NVTX marks for an external profiler such as NVIDIA Nsight Systems
+        (which you need to set up yourself).
+        An NVTX range named 'LIGHTLLM_PROFILE' is added around the profiling range.""",
+    )
     return parser
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index 721d9e44f..8798fca0c 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -350,6 +350,24 @@ async def kv_move_status(websocket: WebSocket):
     return
 
 
+@app.get("/profiler_start")
+async def profiler_start() -> Response:
+    if g_objs.args.profiler:
+        g_objs.httpserver_manager.profiler_msg("start")
+        return {"status": "ok"}
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)
+
+
+@app.get("/profiler_stop")
+async def profiler_stop() -> Response:
+    if g_objs.args.profiler:
+        g_objs.httpserver_manager.profiler_msg("stop")
+        return {"status": "ok"}
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)
+
+
 @app.on_event("shutdown")
 async def shutdown():
     logger.info("Received signal to shutdown. Performing graceful shutdown...")
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index dc1a8f3eb..e5e09763b 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -13,7 +13,7 @@
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 from typing import Union, List, Tuple, Dict
 from ..tokenizer import get_tokenizer
-from ..io_struct import BatchStrOut, AbortReq, FinishStatus
+from ..io_struct import BatchStrOut, AbortReq, ProfilerReq, FinishStatus
 from ..pd_io_struct import NodeRole
 from ..embed_cache.utils import get_shm_name_data, create_shm
 from ..req_id_generator import convert_sub_id_to_group_id
@@ -438,6 +438,10 @@ async def timer_to_pd_master(self):
                 await asyncio.sleep(10)
                 logger.info("reconnection to pd_master")
 
+    def profiler_msg(self, msg):
+        abort_req = ProfilerReq(msg)
+        self.send_to_router.send_pyobj(abort_req)
+
 
 class ReqStatus:
     def __init__(self, req_id, multimodal_params) -> None:
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
index 9d717455d..02ef4cb4d 100644
--- a/lightllm/server/io_struct.py
+++ b/lightllm/server/io_struct.py
@@ -344,3 +344,8 @@ def __init__(self):
 class AbortReq:
     def __init__(self, group_req_id):
         self.group_req_id = group_req_id
+
+
+class ProfilerReq:
+    def __init__(self, msg):
+        self.msg = msg
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index e53dba811..5a6fcd3bf 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -20,11 +20,12 @@
 from lightllm.utils.infer_utils import calculate_time
 from .dynamic_prompt.shared_arr import SharedInt
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
-from ..io_struct import BatchTokenIdOut, AbortReq, ReqRunStatus, FinishStatus, ReqDetokenizationState
+from ..io_struct import BatchTokenIdOut, AbortReq, ProfilerReq, ReqRunStatus, FinishStatus, ReqDetokenizationState
 from .stats import Stats
 from .pause_strategy import Fcfs, select_paused_reqs
 from ..tokenizer import get_tokenizer
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.profiler import LocalProfiler
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.server.req_id_generator import convert_sub_id_to_group_id
 from lightllm.server.metrics.manager import MetricClient
@@ -79,6 +80,8 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
         # 主要是为了防止调度失误,造成 OOM 等错误
         self.router_lock = mp.Lock()
         g_router_lock.obj = self.router_lock
+
+        self.profiler = LocalProfiler(mode=args.profiler, name="lightllm-router") if args.profiler else None
         return
 
     async def wait_to_model_ready(self):
@@ -132,6 +135,7 @@ async def wait_to_model_ready(self):
                 "mem_fraction": self.args.mem_fraction,
                 "batch_max_tokens": self.args.batch_max_tokens,
                 "pd_rpyc_port": self.args.pd_tp_infer_rpyc_ports[rank_id],  # 非 pd 模式可以不设置
+                "profiler": self.profiler if self.world_size == 1 else None,
             }
             init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))
 
@@ -335,8 +339,12 @@ async def _prefill_batch(self, batch: Batch):
         await self._init_batch(batch)
         if not self.is_splitfuse_mode:
             # 在 非 splitfuse 模式下,才需要真的执行 prefill 的操作。
+            if self.profiler:
+                mark_range = self.profiler.mark_range_start(f"prefill len={batch.input_tokens()}")
             rets = [self.model_rpcs[tp_rank].prefill_batch(batch.batch_id) for tp_rank in range(self.world_size)]
             ans = await asyncio.gather(*rets)
+            if self.profiler:
+                self.profiler.mark_range_end(mark_range)
             if self.world_size != 1:
                 req_to_out_status = obtain(ans[0])
             else:
@@ -355,12 +363,16 @@ async def _decode_batch(self, batch: Batch):
         start_time = time.time()
         self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
+        if self.profiler:
+            mark_range = self.profiler.mark_range_start(f"decode bs={len(batch.reqs)}")
         rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)]
         ans = await asyncio.gather(*rets)
         if self.world_size != 1:
             req_to_out_status = obtain(ans[0])
         else:
             req_to_out_status = ans[0]
+        if self.profiler:
+            self.profiler.mark_range_end(mark_range)
         self._update_out_status_to_batch(batch, req_to_out_status)
         unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status()
@@ -486,9 +498,24 @@ async def loop_for_netio_req(self):
                 group_req_id = abort_req.group_req_id
                 await self.abort(group_req_id)
                 self.send_to_detokenization.send_pyobj(abort_req)
+            elif isinstance(recv_req, ProfilerReq):
+                await self.profiler_ops(recv_req.msg)
             else:
                 assert False, f"Error Req Inf {recv_req}"
 
+    async def profiler_ops(self, msg):
+        # assert self.profiler
+        if self.world_size != 1:
+            rets = [self.model_rpcs[tp_rank].profiler_ops(msg) for tp_rank in range(self.world_size)]
+            await asyncio.gather(*rets)
+
+        if msg == "start":
+            self.profiler.start()
+        elif msg == "stop":
+            self.profiler.stop()
+        else:
+            assert False, "invalid profiler ops"
+
     def clean_up(self):
         for model_rpc in self.model_rpcs:
             model_rpc.rpc_server_process.kill()
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 9c9cff922..27cba938e 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -52,6 +52,7 @@
 from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
+from lightllm.utils.profiler import LocalProfiler
 
 
 class ModeBackend:
@@ -88,6 +89,15 @@ def init_model(self, kvargs):
         self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None)
         max_total_token_num = kvargs["max_total_token_num"]
 
+        self.profiler = None
+        if kvargs.get("profiler") is not None:
+            # when world_size == 1, model and router run in the same process, so share the same profiler object
+            assert world_size == 1
+            self.profiler = kvargs.get("profiler")
+        elif self.args.profiler:
+            # when world_size > 1
+            self.profiler = LocalProfiler(mode=self.args.profiler, name=f"lightllm-model_backend-{self.tp_rank}")
+
         dist.init_process_group(
             "nccl", init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}', rank=self.tp_rank, world_size=world_size
         )
@@ -336,3 +346,12 @@ def remove_batch(self, batch_id):
             del batch
         g_infer_state_lock.release()
         return
+
+    def profiler_ops(self, msg):
+        assert self.profiler
+        if msg == "start":
+            self.profiler.start()
+        elif msg == "stop":
+            self.profiler.stop()
+        else:
+            assert False, "invalid profiler ops"
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 63420e7af..3dbc414c4 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -133,6 +133,11 @@ def exposed_remove_batch(self, batch_id):
     def exposed_get_max_total_token_num(self):
         return self.backend.get_max_total_token_num()
 
+    def exposed_profiler_ops(self, msg):
+        if self.world_size != 1:
+            msg = obtain(msg)
+        return self.backend.profiler_ops(msg)
+
 
 class ModelRpcClient:
     def __init__(self, model_rpc, world_size, rpc_server_process=None):
@@ -161,6 +166,7 @@ async def _func(*args, **kwargs):
             self._filter_batch = async_wrap(self.model.filter_batch)
             self._merge_batch = async_wrap(self.model.merge_batch)
             self._remove_batch = async_wrap(self.model.remove_batch)
+            self._profiler_ops = async_wrap(self.model.profiler_ops)
             self._get_max_total_token_num = async_wrap(self.model.get_max_total_token_num)
         else:
             self._init_model = self.model.exposed_init_model
@@ -171,6 +177,7 @@ async def _func(*args, **kwargs):
             self._filter_batch = self.model.exposed_filter_batch
             self._merge_batch = self.model.exposed_merge_batch
             self._remove_batch = self.model.exposed_remove_batch
+            self._profiler_ops = self.model.exposed_profiler_ops
             self._get_max_total_token_num = self.model.exposed_get_max_total_token_num
         return
 
@@ -242,6 +249,14 @@ async def get_max_total_token_num(self):
         else:
             return ans
 
+    async def profiler_ops(self, msg):
+        ans = self._profiler_ops(msg)
+        if self.use_rpc:
+            await ans
+            return
+        else:
+            return
+
 
 def _init_env(args, port, info_queue, mem_queue, router_lock):
     # 注册graceful 退出的处理
diff --git a/lightllm/utils/profiler.py b/lightllm/utils/profiler.py
new file mode 100644
index 000000000..5a71df0e9
--- /dev/null
+++ b/lightllm/utils/profiler.py
@@ -0,0 +1,82 @@
+import os
+from typing import Any, Literal, Optional
+import torch
+
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class LocalProfiler:
+    def __init__(self, mode: Literal["torch_profile", "nvtx"], name: Optional[str] = None):
+        self.mode: Literal["torch_profile", "nvtx"] = mode
+        self.name: Optional[str] = name
+        self.active: bool = False
+        if self.mode == "torch_profile":
+            trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace")
+            self._torch_profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,  # additional overhead
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir, worker_name=name, use_gzip=True),
+            )
+            logger.warning(
+                "Profiler support (--profiler=XXX) for torch.profile enabled, trace files will be saved to %s",
+                trace_dir,
+            )
+            logger.warning("do not enable this feature in production")
+        elif self.mode == "nvtx":
+            self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE"
+            self._nvtx_toplevel_id = None
+            logger.warning(
+                """Profiler support (--profiler=XXX) for NVTX enabled, top-level NVTX mark is %s,
+                use it with external profiling tools""",
+                self._nvtx_toplevel_mark,
+            )
+            logger.warning(
+                """e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx
+                -e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [--other_nsys_options]
+                python -m lightllm.server.api_server --profiler=nvtx [--other_lightllm_options]""",
+                self._nvtx_toplevel_mark,
+            )
+        elif self.mode is not None:
+            assert False, "invalid profiler mode"
+
+    def start(self):
+        if self.active:
+            logger.error("profiler already started, ignore")
+            return
+        logger.warning("Profiler support: profiling start")
+        self.active = True
+        if self.mode == "torch_profile":
+            self._torch_profiler.start()
+        elif self.mode == "nvtx":
+            self._nvtx_toplevel_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark)
+
+    def stop(self):
+        if not self.active:
+            logger.error("profiler not started, ignore")
+            return
+        logger.warning("Profiler support: profiling stop")
+        self.active = False
+        if self.mode == "torch_profile":
+            logger.warning("Profiler support: torch_profiler saving trace file, it might take a while...")
+            self._torch_profiler.stop()
+            logger.warning("Profiler support: torch_profiler saving done")
+        elif self.mode == "nvtx":
+            torch.cuda.nvtx.range_end(self._nvtx_toplevel_id)
+
+    def mark_range_start(self, message: str) -> Any:
+        "return the handle of the range, to be used in mark_range_end()"
+        if self.active:
+            # only supported in NVTX mode
+            if self.mode == "nvtx":
+                return torch.cuda.nvtx.range_start(message)
+
+    def mark_range_end(self, handle: Any):
+        if self.active:
+            # only supported in NVTX mode
+            if self.mode == "nvtx":
+                return torch.cuda.nvtx.range_end(handle)
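
For quick reference, a minimal usage sketch of the endpoints added above (not part of the patch): it assumes the server was started with --profiler=torch_profile and is reachable at http://localhost:8000; the host, port, and HTTP client are illustrative.

    import requests  # any HTTP client works; requests is only an example

    base_url = "http://localhost:8000"  # illustrative; match your server's --host/--port

    requests.get(f"{base_url}/profiler_start")  # open the profiling range
    # ... send inference traffic (e.g. /generate requests) so prefill/decode activity is captured ...
    requests.get(f"{base_url}/profiler_stop")   # close the range; torch_profile traces are written to ./trace or $LIGHTLLM_TRACE_DIR

With --profiler=nvtx, the same two requests only open and close the 'LIGHTLLM_PROFILE' NVTX range; an external tool such as the nsys command shown in profiler.py must be attached to record anything.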