feat(misc): Profiler support
Use --profiler=MODE to enable it; currently supports the torch_profile
and nvtx (used with NVIDIA Nsight Systems) modes.
SiYu Wu committed Nov 20, 2024
1 parent 06afb4a commit 65a0ab2
Showing 8 changed files with 187 additions and 2 deletions.
15 changes: 15 additions & 0 deletions lightllm/server/api_cli.py
@@ -219,4 +219,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="""Maximum sequence length that can be captured by the cuda graph for decodign stage.
The default value is 8192. It will turn into eagar mode if encounters a larger value. """,
)
parser.add_argument(
"--profiler",
type=str,
choices=["torch_profile", "nvtx"],
default=None,
help="""Enable profiler support.
This will expose '/profiler_start' and '/profiler_stop' API,
below profiling features will only been enabled in this range.
Options:
'torch_profile' will setup torch.profiler.profile(), traces file will been saved to './trace',
or set by 'LIGHTLLM_TRACE_DIR' env;
'nvtx' will add NVTX marks for external profiler like NVIDIA Nsight System
(you should setup it by youself).
A NVTX named 'LIGHTLLM_PROFILE' will been added within the profiling range.""",
)
return parser
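For reference, a minimal client-side sketch of how this option might be exercised end to end (the launch command, host, and port below are illustrative assumptions, not part of this commit):

# Hypothetical usage sketch. Assumes the server was started with something like
#   python -m lightllm.server.api_server --profiler=torch_profile ...
# and is reachable at localhost:8000 (adjust to your deployment).
import requests

BASE_URL = "http://localhost:8000"

resp = requests.get(f"{BASE_URL}/profiler_start")
resp.raise_for_status()  # returns {"status": "ok"} when --profiler is set, HTTP 500 otherwise

# ... send the inference requests to be profiled here ...

resp = requests.get(f"{BASE_URL}/profiler_stop")
resp.raise_for_status()  # torch_profile mode then writes traces to ./trace or $LIGHTLLM_TRACE_DIR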
18 changes: 18 additions & 0 deletions lightllm/server/api_server.py
@@ -350,6 +350,24 @@ async def kv_move_status(websocket: WebSocket):
return


@app.get("/profiler_start")
async def profiler_start() -> Response:
if g_objs.args.profiler:
g_objs.httpserver_manager.profiler_msg("start")
return {"status": "ok"}
else:
return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)


@app.get("/profiler_stop")
async def profiler_stop() -> Response:
if g_objs.args.profiler:
g_objs.httpserver_manager.profiler_msg("stop")
return {"status": "ok"}
else:
return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)


@app.on_event("shutdown")
async def shutdown():
logger.info("Received signal to shutdown. Performing graceful shutdown...")
6 changes: 5 additions & 1 deletion lightllm/server/httpserver/manager.py
@@ -13,7 +13,7 @@
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from typing import Union, List, Tuple, Dict
from ..tokenizer import get_tokenizer
from ..io_struct import BatchStrOut, AbortReq, FinishStatus
from ..io_struct import BatchStrOut, AbortReq, ProfilerReq, FinishStatus
from ..pd_io_struct import NodeRole
from ..embed_cache.utils import get_shm_name_data, create_shm
from ..req_id_generator import convert_sub_id_to_group_id
@@ -438,6 +438,10 @@ async def timer_to_pd_master(self):
await asyncio.sleep(10)
logger.info("reconnection to pd_master")

def profiler_msg(self, msg):
profiler_req = ProfilerReq(msg)
self.send_to_router.send_pyobj(profiler_req)


class ReqStatus:
def __init__(self, req_id, multimodal_params) -> None:
5 changes: 5 additions & 0 deletions lightllm/server/io_struct.py
@@ -344,3 +344,8 @@ def __init__(self):
class AbortReq:
def __init__(self, group_req_id):
self.group_req_id = group_req_id


class ProfilerReq:
def __init__(self, msg):
self.msg = msg
29 changes: 28 additions & 1 deletion lightllm/server/router/manager.py
@@ -20,11 +20,12 @@
from lightllm.utils.infer_utils import calculate_time
from .dynamic_prompt.shared_arr import SharedInt
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
from ..io_struct import BatchTokenIdOut, AbortReq, ReqRunStatus, FinishStatus, ReqDetokenizationState
from ..io_struct import BatchTokenIdOut, AbortReq, ProfilerReq, ReqRunStatus, FinishStatus, ReqDetokenizationState
from .stats import Stats
from .pause_strategy import Fcfs, select_paused_reqs
from ..tokenizer import get_tokenizer
from lightllm.utils.log_utils import init_logger
from lightllm.utils.profiler import LocalProfiler
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.req_id_generator import convert_sub_id_to_group_id
from lightllm.server.metrics.manager import MetricClient
@@ -79,6 +80,8 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
# Mainly to prevent scheduling mistakes from causing OOM and similar errors
self.router_lock = mp.Lock()
g_router_lock.obj = self.router_lock

self.profiler = LocalProfiler(mode=args.profiler, name="lightllm-router") if args.profiler else None
return

async def wait_to_model_ready(self):
@@ -132,6 +135,7 @@ async def wait_to_model_ready(self):
"mem_fraction": self.args.mem_fraction,
"batch_max_tokens": self.args.batch_max_tokens,
"pd_rpyc_port": self.args.pd_tp_infer_rpyc_ports[rank_id], # 非 pd 模式可以不设置
"profiler": self.profiler if self.world_size == 1 else None,
}
init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))

@@ -335,8 +339,12 @@ async def _prefill_batch(self, batch: Batch):
await self._init_batch(batch)
if not self.is_splitfuse_mode:
# Only in non-splitfuse mode do we actually need to run the prefill operation.
if self.profiler:
mark_range = self.profiler.mark_range_start(f"prefill len={batch.input_tokens()}")
rets = [self.model_rpcs[tp_rank].prefill_batch(batch.batch_id) for tp_rank in range(self.world_size)]
ans = await asyncio.gather(*rets)
if self.profiler:
self.profiler.mark_range_end(mark_range)
if self.world_size != 1:
req_to_out_status = obtain(ans[0])
else:
@@ -355,12 +363,16 @@ async def _decode_batch(self, batch: Batch):
async def _decode_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
if self.profiler:
mark_range = self.profiler.mark_range_start(f"decode bs={len(batch.reqs)}")
rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)]
ans = await asyncio.gather(*rets)
if self.world_size != 1:
req_to_out_status = obtain(ans[0])
else:
req_to_out_status = ans[0]
if self.profiler:
self.profiler.mark_range_end(mark_range)

self._update_out_status_to_batch(batch, req_to_out_status)
unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status()
@@ -486,9 +498,24 @@ async def loop_for_netio_req(self):
group_req_id = abort_req.group_req_id
await self.abort(group_req_id)
self.send_to_detokenization.send_pyobj(abort_req)
elif isinstance(recv_req, ProfilerReq):
await self.profiler_ops(recv_req.msg)
else:
assert False, f"Error Req Inf {recv_req}"

async def profiler_ops(self, msg):
# assert self.profiler
if self.world_size != 1:
rets = [self.model_rpcs[tp_rank].profiler_ops(msg) for tp_rank in range(self.world_size)]
await asyncio.gather(*rets)

if msg == "start":
self.profiler.start()
elif msg == "stop":
self.profiler.stop()
else:
assert False, "invalid profiler ops"

def clean_up(self):
for model_rpc in self.model_rpcs:
model_rpc.rpc_server_process.kill()
19 changes: 19 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -52,6 +52,7 @@
from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping
from lightllm.server.router.token_load import TokenLoad
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
from lightllm.utils.profiler import LocalProfiler


class ModeBackend:
@@ -88,6 +89,15 @@ def init_model(self, kvargs):
self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None)
max_total_token_num = kvargs["max_total_token_num"]

self.profiler = None
if kvargs.get("profiler") is not None:
# when world_size == 1, the model and the router run in the same process, so they share the same profiler object
assert world_size == 1
self.profiler = kvargs.get("profiler")
elif self.args.profiler:
# when world_size > 1
self.profiler = LocalProfiler(mode=self.args.profiler, name=f"lightllm-model_backend-{self.tp_rank}")

dist.init_process_group(
"nccl", init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}', rank=self.tp_rank, world_size=world_size
)
@@ -336,3 +346,12 @@ def remove_batch(self, batch_id):
del batch
g_infer_state_lock.release()
return

def profiler_ops(self, msg):
assert self.profiler
if msg == "start":
self.profiler.start()
elif msg == "stop":
self.profiler.stop()
else:
assert False, "invalid profiler ops"
15 changes: 15 additions & 0 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -133,6 +133,11 @@ def exposed_remove_batch(self, batch_id):
def exposed_get_max_total_token_num(self):
return self.backend.get_max_total_token_num()

def exposed_profiler_ops(self, msg):
if self.world_size != 1:
msg = obtain(msg)
return self.backend.profiler_ops(msg)


class ModelRpcClient:
def __init__(self, model_rpc, world_size, rpc_server_process=None):
@@ -161,6 +166,7 @@ async def _func(*args, **kwargs):
self._filter_batch = async_wrap(self.model.filter_batch)
self._merge_batch = async_wrap(self.model.merge_batch)
self._remove_batch = async_wrap(self.model.remove_batch)
self._profiler_ops = async_wrap(self.model.profiler_ops)
self._get_max_total_token_num = async_wrap(self.model.get_max_total_token_num)
else:
self._init_model = self.model.exposed_init_model
@@ -171,6 +177,7 @@
self._filter_batch = self.model.exposed_filter_batch
self._merge_batch = self.model.exposed_merge_batch
self._remove_batch = self.model.exposed_remove_batch
self._profiler_ops = self.model.exposed_profiler_ops
self._get_max_total_token_num = self.model.exposed_get_max_total_token_num
return

@@ -242,6 +249,14 @@ async def get_max_total_token_num(self):
else:
return ans

async def profiler_ops(self, msg):
ans = self._profiler_ops(msg)
if self.use_rpc:
await ans
return
else:
return


def _init_env(args, port, info_queue, mem_queue, router_lock):
# register handlers for graceful shutdown
82 changes: 82 additions & 0 deletions lightllm/utils/profiler.py
@@ -0,0 +1,82 @@
import os
from typing import Any, Literal, Optional
import torch

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class LocalProfiler:
def __init__(self, mode: Literal["torch_profile", "nvtx"], name: Optional[str] = None):
self.mode: Literal["torch_profile", "nvtx"] = mode
self.name: Optional[str] = name
self.active: bool = False
if self.mode == "torch_profile":
trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace")
self._torch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True, # additional overhead
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir, worker_name=name, use_gzip=True),
)
logger.warning(
"Profiler support (--profiler=XXX) for torch.profile enabled, trace file will been saved to %s",
trace_dir,
)
logger.warning("do not enable this feature in production")
elif self.mode == "nvtx":
self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE"
self._nvtx_toplevel_id = None
logger.warning(
"""Profiler support (--profiler=XXX) for NVTX enabled, toplevel NVTX mark is %s,
use it with external profiling tools""",
self._nvtx_toplevel_mark,
)
logger.warning(
"""e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx
-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [--other_nsys_options]
python -m lightllm.server.api_server --profiler=nvtx [--other_lightllm_options]""",
self._nvtx_toplevel_mark,
)
elif self.mode is not None:
assert False, "invalid profiler mode"

def start(self):
if self.active:
logger.error("profiler already started, ignore")
return
logger.warning("Profiler support: profiling start")
self.active = True
if self.mode == "torch_profile":
self._torch_profiler.start()
elif self.mode == "nvtx":
self._nvtx_toplevel_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark)

def stop(self):
if not self.active:
logger.error("profiler not started, ignore")
return
logger.warning("Profiler support: profiling stop")
self.active = False
if self.mode == "torch_profile":
logger.warning("Profiler support: torch_profiler saving trace file, it might take a while...")
self._torch_profiler.stop()
logger.warning("Profiler support: torch_profiler saving done")
elif self.mode == "nvtx":
torch.cuda.nvtx.range_end(self._nvtx_toplevel_id)

def mark_range_start(self, message: str) -> Any:
"return the handle of the range, to be used in mark_range_end()"
if self.active:
# only supported in NVTX mode
if self.mode == "nvtx":
return torch.cuda.nvtx.range_start(message)

def mark_range_end(self, handle: Any):
if self.active:
# only supported in NVTX mode
if self.mode == "nvtx":
return torch.cuda.nvtx.range_end(handle)
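As a rough illustration of the intended lifecycle, here is a standalone sketch of driving LocalProfiler directly (in the server this object is owned by the router / model backend and toggled through the '/profiler_start' and '/profiler_stop' endpoints; the placeholder workload is an assumption):

# Minimal standalone sketch of the LocalProfiler lifecycle (illustrative only).
from lightllm.utils.profiler import LocalProfiler

profiler = LocalProfiler(mode="nvtx", name="lightllm-example")

profiler.start()                                   # opens the top-level LIGHTLLM_PROFILE NVTX range
handle = profiler.mark_range_start("decode bs=8")  # nested range; returns None outside nvtx mode
# ... run the GPU work to be profiled here ...
profiler.mark_range_end(handle)
profiler.stop()                                    # closes the top-level range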
