From d23d6f51e4d72cb5105ae79ecb1e854ca490acb7 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Thu, 21 Nov 2024 18:17:02 -0500
Subject: [PATCH] [C++] Invoke storage allocation for CUDA Graph explicitly
 (#3042)

This PR adds a function that invokes the storage allocation function
generated by the CUDA Graph rewrite. With this function, we now manually
trigger the storage allocation at initialization time.

We need this because the storage allocation may contain a CUDA IPC memory
allocation that has to run through a Disco session. So if a function that
needs the CUDA graph storage allocation first runs outside a Disco
session, an error may occur unless the allocation has been initialized in
advance.
---
 cpp/serve/function_table.cc                   |  5 +++
 cpp/serve/function_table.h                    |  1 +
 .../attach_cuda_graph_alloc_init_func.py      | 33 +++++++++++++++++++
 python/mlc_llm/compiler_pass/pipeline.py      |  3 +-
 4 files changed, 41 insertions(+), 1 deletion(-)
 create mode 100644 python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py

diff --git a/cpp/serve/function_table.cc b/cpp/serve/function_table.cc
index 790d80047e..63ce1492f9 100644
--- a/cpp/serve/function_table.cc
+++ b/cpp/serve/function_table.cc
@@ -152,6 +152,10 @@ void FunctionTable::Init(String reload_lib_path, Device device, picojson::object
   }
   ICHECK_EQ(this->model_metadata_.tensor_parallel_shards, num_shards);
   ICHECK_EQ(this->model_metadata_.pipeline_parallel_stages, num_stages);
+  // Invoke the CUDA graph allocation init function if it is defined.
+  if (cuda_graph_alloc_init_func_.defined()) {
+    this->cuda_graph_alloc_init_func_();
+  }
 }
 
 ObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device) {
@@ -231,6 +235,7 @@ void FunctionTable::_InitFunctions() {
   this->apply_penalty_func_ = mod->GetFunction("apply_penalty_inplace", true);
   this->apply_bitmask_func_ = mod->GetFunction("apply_bitmask_inplace", true);
   this->alloc_embedding_tensor_func_ = mod_get_func("alloc_embedding_tensor");
+  this->cuda_graph_alloc_init_func_ = mod_get_func("cuda_graph_alloc_init");
   this->create_kv_cache_func_ = mod_get_func("create_flashinfer_paged_kv_cache");
   if (this->model_metadata_.sliding_window_size != -1 || !this->create_kv_cache_func_.defined()) {
     PackedFunc f_create_rnn_state = mod_get_func("create_rnn_state");
diff --git a/cpp/serve/function_table.h b/cpp/serve/function_table.h
index f2a513ec87..46fae540da 100644
--- a/cpp/serve/function_table.h
+++ b/cpp/serve/function_table.h
@@ -99,6 +99,7 @@ struct FunctionTable {
   PackedFunc apply_penalty_func_;
   PackedFunc apply_bitmask_func_;
   PackedFunc alloc_embedding_tensor_func_;
+  PackedFunc cuda_graph_alloc_init_func_;
   PackedFunc create_kv_cache_func_;
   PackedFunc reset_kv_cache_func_;
   bool support_backtracking_kv_;
diff --git a/python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py b/python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py
new file mode 100644
index 0000000000..70f6598852
--- /dev/null
+++ b/python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py
@@ -0,0 +1,33 @@
+"""The pass that attaches an empty function for initialization."""
+
+import tvm
+from tvm import IRModule, relax
+
+
+@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphAllocInitFunc")
+class AttachCUDAGraphAllocInitFunc:  # pylint: disable=too-few-public-methods
+    """Attach an empty function for initialization."""
+
+    def __init__(self):
+        pass
+
+    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
+        """Entrypoint"""
+        bb = relax.BlockBuilder(mod)
+        alloc_func_gv = None
+        for gv, _ in mod.functions_items():
+            if gv.name_hint.startswith("cuda_graph_alloc"):
+                assert alloc_func_gv is None
+                alloc_func_gv = gv
+        if alloc_func_gv is None:
+            return mod
+
+        with bb.function("cuda_graph_alloc_init", []):
+            bb.emit_func_output(
+                relax.op.call_builtin_with_ctx(
+                    "vm.builtin.cuda_graph.get_cached_alloc",
+                    args=[alloc_func_gv, relax.PrimValue(0)],
+                    sinfo_args=relax.ObjectStructInfo(),
+                )
+            )
+        return bb.finalize()
diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py
index 363be1a59b..af1cf9f0e9 100644
--- a/python/mlc_llm/compiler_pass/pipeline.py
+++ b/python/mlc_llm/compiler_pass/pipeline.py
@@ -12,6 +12,7 @@
 from mlc_llm.interface.compiler_flags import IPCAllReduceStrategyType
 from mlc_llm.support import logging
 
+from .attach_cuda_graph_alloc_init_func import AttachCUDAGraphAllocInitFunc
 from .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc
 from .attach_logit_processor import AttachLogitProcessFunc
 from .attach_sampler import AttachGPUSamplingFunc
@@ -159,7 +160,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             ),
             ScatterTupleGetItem(),
             PipelineParallelRewrite(),
-            _DebugDump("after-pipeline-rewrite.py", debug_dump, show_meta=False),
             tvm.relax.transform.RewriteDataflowReshape(),
             tvm.relax.transform.ToNonDataflow(),
             tvm.relax.transform.RemovePurityChecking(),
@@ -172,6 +172,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             tvm.relax.transform.StaticPlanBlockMemory(),
             AttachMetadataWithMemoryUsage(metadata),
             tvm.relax.transform.RewriteCUDAGraph(),
+            AttachCUDAGraphAllocInitFunc(),
             tvm.relax.transform.LowerGPUIPCAllocStorage(),
             tvm.relax.transform.LowerAllocTensor(),
             tvm.relax.transform.KillAfterLastUse(),
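
For readers invoking a compiled library outside of MLC's C++ runtime, the
following is a minimal sketch (not part of the patch) of how a Python client
could eagerly call the attached "cuda_graph_alloc_init" function, mirroring
the new logic in FunctionTable::Init. The library path and device choice are
illustrative assumptions, not values from this patch.

```python
# Hedged sketch: eagerly trigger CUDA graph storage allocation after loading
# a compiled model library, mirroring FunctionTable::Init in the C++ change.
import tvm
from tvm import relax

lib = tvm.runtime.load_module("/path/to/model_lib.so")  # hypothetical path
vm = relax.VirtualMachine(lib, tvm.cuda(0))

# The pass attaches "cuda_graph_alloc_init" only when the CUDA graph rewrite
# produced an allocation function, so the lookup may legitimately fail.
try:
    init_func = vm["cuda_graph_alloc_init"]
except AttributeError:
    init_func = None

if init_func is not None:
    # Runs vm.builtin.cuda_graph.get_cached_alloc once, so any CUDA IPC
    # storage is allocated before the first real inference call.
    init_func()
```

Doing this once at load time matches the intent of the commit: the
allocation happens in a controlled context (e.g., inside a Disco session)
rather than lazily on the first inference call.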