From 4bb65f216f2f1dbddb022f4fdd8925c2856baa58 Mon Sep 17 00:00:00 2001
From: Kaiyu Xie <26294424+kaiyux@users.noreply.github.com>
Date: Tue, 12 Mar 2024 18:15:52 +0800
Subject: [PATCH] Update TensorRT-LLM (#1274)

* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
---
 .clang-format | 1 +
 .gitignore | 10 +
 README.md | 3 +
 benchmarks/cpp/README.md | 5 -
 benchmarks/cpp/bertBenchmark.cpp | 2 +-
 benchmarks/cpp/gptManagerBenchmark.cpp | 161 +-
 benchmarks/cpp/gptSessionBenchmark.cpp | 36 +-
 benchmarks/python/allowed_configs.py | 1 +
 benchmarks/python/base_benchmark.py | 6 +-
 benchmarks/python/benchmark.py | 9 +-
 benchmarks/python/build.py | 11 +-
 benchmarks/python/mem_monitor.py | 9 +-
 cpp/CMakeLists.txt | 10 +-
 .../tensorrt_llm/batch_manager/callbacks.h | 4 +-
 .../batch_manager/inferenceRequest.h | 4 +-
 .../batch_manager/kvCacheConfig.h | 7 +
 .../batch_manager/kvCacheManager.h | 41 +-
 .../tensorrt_llm/batch_manager/llmRequest.h | 9 +-
 .../tensorrt_llm/batch_manager/namedTensor.h | 8 +-
 .../batch_manager/trtGptModelOptionalParams.h | 12 +-
 cpp/include/tensorrt_llm/common/arrayView.h | 7 +
 cpp/include/tensorrt_llm/common/mpiUtils.h | 34 +-
 cpp/include/tensorrt_llm/executor/executor.h | 92 +-
 cpp/include/tensorrt_llm/executor/tensor.h | 5 +-
 cpp/include/tensorrt_llm/executor/types.h | 15 +-
 .../tensorrt_llm/runtime/decodingMode.h | 5 +
 cpp/include/tensorrt_llm/runtime/gptDecoder.h | 12 +-
 .../tensorrt_llm/runtime/gptDecoderBatch.h | 14 +
 .../tensorrt_llm/runtime/iGptDecoderBatch.h | 14 +-
 cpp/include/tensorrt_llm/runtime/ipcUtils.h | 2 +-
 .../tensorrt_llm/runtime/promptTuningParams.h | 2 +-
 .../tensorrt_llm/runtime/samplingConfig.h | 2 +-
 cpp/tensorrt_llm/CMakeLists.txt | 1 -
 .../libtensorrt_llm_batch_manager_static.a | 4 +-
 ...sorrt_llm_batch_manager_static.pre_cxx11.a | 4 +-
 .../aarch64-linux-gnu/version.txt | 6 +-
 .../libtensorrt_llm_batch_manager_static.a | 4 +-
 ...sorrt_llm_batch_manager_static.pre_cxx11.a | 4 +-
 .../x86_64-linux-gnu/version.txt | 4 +-
 cpp/tensorrt_llm/common/allocator.h | 6 +-
 cpp/tensorrt_llm/common/assert.h | 16 +-
 cpp/tensorrt_llm/common/cublasMMWrapper.cpp | 62 +-
 cpp/tensorrt_llm/common/cublasMMWrapper.h | 48 +-
 cpp/tensorrt_llm/common/cudaDriverWrapper.cpp | 14 +-
 cpp/tensorrt_llm/common/cudaDriverWrapper.h | 28 +-
 cpp/tensorrt_llm/common/cudaFp8Utils.cu | 32 +-
 cpp/tensorrt_llm/common/cudaFp8Utils.h | 40 +-
 cpp/tensorrt_llm/common/cudaTypeUtils.cuh | 6 +-
 cpp/tensorrt_llm/common/cudaUtils.h | 46 +-
 cpp/tensorrt_llm/common/envUtils.cpp | 6 +-
 cpp/tensorrt_llm/common/logger.cpp | 1 -
 cpp/tensorrt_llm/common/logger.h | 16 +-
 cpp/tensorrt_llm/common/memoryUtils.cu | 146 +-
 cpp/tensorrt_llm/common/memoryUtils.h | 20 +-
 cpp/tensorrt_llm/common/mpiUtils.cpp | 22 +-
 cpp/tensorrt_llm/common/quantization.h | 14 +-
 cpp/tensorrt_llm/common/reduceKernelUtils.cuh | 18 +-
 cpp/tensorrt_llm/common/tensor.cpp | 4 +-
 cpp/tensorrt_llm/common/tensor.h | 44 +-
 cpp/tensorrt_llm/common/tllmException.cpp | 4 +-
 .../epilogue/thread/fused_activations.h | 4 +-
 .../epilogue_per_row_per_col_scale.h | 4 +-
 .../gemm/device/splitk_gemm_grouped.h | 12 +-
 .../gemm/kernel/fpA_intB_gemm.h | 4 +-
 .../gemm/kernel/moe_cutlass_kernel.h | 2 +-
 .../gemm/kernel/moe_problem_visitor.h | 10 +-
 .../gemm/threadblock/dq_mma_base.h | 4 +-
 .../dq_mma_multistage_finegrained.h | 4 +-
 .../threadblock/dq_mma_multistage_percol.h | 6 +-
 .../gemm/threadblock/dq_mma_pipelined.h | 6 +-
.../warp/mma_tensorop_compute_B_with_f16.h | 2 +- .../gemm/warp/mma_tensorop_dequantizer.h | 54 +- .../fine_grained_scale_zero_iterator.h | 12 +- .../libtensorrt_llm_executor_static.a | 4 +- ...ibtensorrt_llm_executor_static.pre_cxx11.a | 4 +- .../executor/aarch64-linux-gnu/version.txt | 6 +- .../libtensorrt_llm_executor_static.a | 4 +- ...ibtensorrt_llm_executor_static.pre_cxx11.a | 4 +- .../executor/x86_64-linux-gnu/version.txt | 4 +- cpp/tensorrt_llm/kernels/banRepeatNgram.cu | 24 +- cpp/tensorrt_llm/kernels/banRepeatNgram.h | 6 +- .../kernels/beamSearchTopkKernels.cu | 170 +- .../kernels/beamSearchTopkKernels.h | 91 +- .../fmhaRunner.cpp | 98 +- .../fmhaRunner.h | 50 +- .../fused_multihead_attention_common.h | 16 +- .../fused_multihead_attention_v2.h | 40 +- .../tmaDescriptor.h | 4 +- .../kernels/customAllReduceKernels.cu | 22 +- .../kernels/customAllReduceKernels.h | 2 +- .../cutlass_kernels/cutlass_heuristic.cpp | 54 +- .../cutlass_kernels/cutlass_heuristic.h | 10 +- .../cutlass_kernels/cutlass_preprocessors.cpp | 154 +- .../cutlass_kernels/cutlass_preprocessors.h | 18 +- .../fpA_intB_gemm/fpA_intB_gemm.h | 34 +- .../fpA_intB_gemm/fpA_intB_gemm_template.h | 54 +- .../fpA_intB_gemm_template_sm90.h | 24 +- .../launchers/fpA_intB_launcher_sm90.h | 6 +- .../launchers/fpA_intB_launcher_sm90.inl | 8 +- .../cutlass_kernels/int8_gemm/int8_gemm.h | 14 +- .../int8_gemm/int8_gemm_template.h | 38 +- .../moe_gemm/moe_gemm_kernels.h | 8 +- .../moe_gemm/moe_gemm_kernels_template.h | 40 +- .../decoderMaskedMultiheadAttention.cu | 12 +- .../kernels/decoderMaskedMultiheadAttention.h | 46 +- .../decoderMaskedMultiheadAttentionLaunch.h | 40 +- .../decoderMaskedMultiheadAttentionTemplate.h | 278 +- .../decoderXQAImpl.cpp | 8 +- .../decoderXQAImpl.h | 16 +- .../decoderXQAImplPrecompiled.cpp | 64 +- .../decoderXQAImplPrecompiled.h | 16 +- .../decoderXQARunner.cpp | 12 +- .../decoderXQARunner.h | 18 +- .../xqaParams.h | 24 +- .../decoderMaskedMultiheadAttentionUtils.h | 170 +- cpp/tensorrt_llm/kernels/decodingCommon.cu | 10 +- cpp/tensorrt_llm/kernels/decodingCommon.h | 6 +- cpp/tensorrt_llm/kernels/decodingKernels.cu | 209 +- cpp/tensorrt_llm/kernels/decodingKernels.h | 44 +- cpp/tensorrt_llm/kernels/gptKernels.cu | 22 +- cpp/tensorrt_llm/kernels/gptKernels.h | 6 +- cpp/tensorrt_llm/kernels/groupGemm.cu | 4 +- cpp/tensorrt_llm/kernels/kvCacheUtils.h | 4 +- cpp/tensorrt_llm/kernels/layernormKernels.cu | 40 +- cpp/tensorrt_llm/kernels/layernormKernels.h | 4 +- cpp/tensorrt_llm/kernels/lookupKernels.cu | 12 +- cpp/tensorrt_llm/kernels/lookupKernels.h | 4 +- .../kernels/mixtureOfExperts/moe_kernels.cu | 264 +- .../kernels/mixtureOfExperts/moe_kernels.h | 74 +- .../kernels/onlineSoftmaxBeamsearchKernels.cu | 26 +- .../kernels/onlineSoftmaxBeamsearchKernels.h | 4 +- .../onlineSoftmaxBeamsearchKernelsTemplate.h | 427 ++- .../parallelDecoding/kvCacheUpdateKernels.cu | 40 +- .../parallelDecoding/kvCacheUpdateKernels.h | 24 +- cpp/tensorrt_llm/kernels/penaltyKernels.cu | 77 +- cpp/tensorrt_llm/kernels/penaltyKernels.h | 45 +- .../kernels/preQuantScaleKernel.cu | 10 +- .../kernels/preQuantScaleKernel.h | 2 +- cpp/tensorrt_llm/kernels/quantization.cu | 34 +- cpp/tensorrt_llm/kernels/quantization.h | 4 +- cpp/tensorrt_llm/kernels/rmsnormKernels.cu | 40 +- cpp/tensorrt_llm/kernels/rmsnormKernels.h | 4 +- .../kernels/samplingAirTopPKernels.cu | 91 +- .../kernels/samplingAirTopPKernels.h | 97 + .../kernels/samplingTopKKernels.cu | 293 ++- .../kernels/samplingTopKKernels.h | 67 +- .../kernels/samplingTopPKernels.cu | 174 
+- .../kernels/samplingTopPKernels.h | 96 +- cpp/tensorrt_llm/kernels/selectiveScan.cu | 28 +- cpp/tensorrt_llm/kernels/splitkGroupGemm.cu | 4 +- .../kernels/unfusedAttentionKernels.cu | 470 ++-- .../kernels/unfusedAttentionKernels.h | 92 +- .../unfusedAttentionKernels_2_template.h | 148 +- .../kernels/weightOnlyBatchedGemv/common.h | 26 +- .../kernels/weightOnlyBatchedGemv/enabled.h | 2 +- .../kernels/weightOnlyBatchedGemv/kernel.h | 56 +- .../weightOnlyBatchedGemv/kernelLauncher.cu | 12 +- .../weightOnlyBatchedGemv/kernelLauncher.h | 2 +- .../weightOnlyBatchedGemv/sm90/kernel.h | 8 +- .../kernels/weightOnlyBatchedGemv/utility.h | 16 +- .../layers/baseBeamSearchLayer.cu | 36 +- cpp/tensorrt_llm/layers/baseBeamSearchLayer.h | 4 +- .../layers/dynamicDecodeLayer.cpp | 30 +- cpp/tensorrt_llm/layers/dynamicDecodeLayer.h | 43 +- .../layers/onlineBeamSearchLayer.cu | 62 +- cpp/tensorrt_llm/layers/samplingLayer.cpp | 6 +- cpp/tensorrt_llm/layers/topKSamplingLayer.cu | 19 +- cpp/tensorrt_llm/layers/topKSamplingLayer.h | 2 +- cpp/tensorrt_llm/layers/topPSamplingLayer.cu | 49 +- cpp/tensorrt_llm/layers/topPSamplingLayer.h | 3 +- cpp/tensorrt_llm/plugins/CMakeLists.txt | 1 - cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp | 2 +- cpp/tensorrt_llm/plugins/api/tllmPlugin.h | 2 +- .../bertAttentionPlugin.cpp | 122 +- .../bertAttentionPlugin/bertAttentionPlugin.h | 38 +- .../plugins/common/checkMacrosPlugin.cpp | 4 +- .../plugins/common/checkMacrosPlugin.h | 4 +- .../plugins/common/gemmPluginProfiler.cpp | 32 +- .../plugins/common/gemmPluginProfiler.h | 48 +- cpp/tensorrt_llm/plugins/common/plugin.h | 8 +- .../plugins/gemmPlugin/gemmPlugin.cpp | 106 +- .../plugins/gemmPlugin/gemmPlugin.h | 40 +- .../gptAttentionCommon/gptAttentionCommon.cpp | 228 +- .../gptAttentionCommon/gptAttentionCommon.h | 30 +- .../gptAttentionCommonImpl.h | 4 +- .../gptAttentionPlugin/gptAttentionPlugin.cpp | 156 +- .../gptAttentionPlugin/gptAttentionPlugin.h | 60 +- .../plugins/identityPlugin/identityPlugin.cpp | 48 +- .../plugins/identityPlugin/identityPlugin.h | 34 +- .../layernormQuantizationPlugin.cpp | 76 +- .../layernormQuantizationPlugin.h | 34 +- .../plugins/lookupPlugin/lookupPlugin.cpp | 70 +- .../plugins/lookupPlugin/lookupPlugin.h | 34 +- .../plugins/loraPlugin/loraPlugin.cpp | 122 +- .../plugins/loraPlugin/loraPlugin.h | 38 +- .../mixtureOfExpertsPlugin.cpp | 116 +- .../mixtureOfExperts/mixtureOfExpertsPlugin.h | 54 +- .../plugins/ncclPlugin/allgatherPlugin.cpp | 52 +- .../plugins/ncclPlugin/allgatherPlugin.h | 34 +- .../plugins/ncclPlugin/allreducePlugin.cpp | 60 +- .../plugins/ncclPlugin/allreducePlugin.h | 34 +- .../plugins/ncclPlugin/recvPlugin.cpp | 52 +- .../plugins/ncclPlugin/recvPlugin.h | 34 +- .../ncclPlugin/reduceScatterPlugin.cpp | 52 +- .../plugins/ncclPlugin/reduceScatterPlugin.h | 34 +- .../plugins/ncclPlugin/sendPlugin.cpp | 52 +- .../plugins/ncclPlugin/sendPlugin.h | 34 +- .../quantizePerTokenPlugin.cpp | 50 +- .../quantizePerTokenPlugin.h | 34 +- .../quantizeTensorPlugin.cpp | 54 +- .../quantizeTensorPlugin.h | 34 +- .../rmsnormQuantizationPlugin.cpp | 74 +- .../rmsnormQuantizationPlugin.h | 34 +- .../selectiveScanPlugin.cpp | 68 +- .../selectiveScanPlugin/selectiveScanPlugin.h | 46 +- .../smoothQuantGemmPlugin.cpp | 90 +- .../smoothQuantGemmPlugin.h | 40 +- .../weightOnlyGroupwiseQuantMatmulPlugin.cpp | 142 +- .../weightOnlyGroupwiseQuantMatmulPlugin.h | 38 +- .../weightOnlyQuantMatmulPlugin.cpp | 92 +- .../weightOnlyQuantMatmulPlugin.h | 38 +- .../pybind/batch_manager/gptManager.cpp | 6 +- 
.../pybind/batch_manager/gptManager.h | 2 +- .../pybind/batch_manager/inferenceRequest.cpp | 4 +- .../pybind/batch_manager/namedTensor.cpp | 2 +- .../pybind/batch_manager/namedTensor.h | 2 +- cpp/tensorrt_llm/pybind/bindings.cpp | 39 +- cpp/tensorrt_llm/pybind/utils/pathCaster.h | 6 +- cpp/tensorrt_llm/runtime/gptDecoder.cpp | 57 +- cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp | 78 +- cpp/tensorrt_llm/runtime/gptJsonConfig.cpp | 4 +- cpp/tensorrt_llm/runtime/gptSession.cpp | 4 +- cpp/tensorrt_llm/runtime/ipcUtils.cpp | 8 +- cpp/tensorrt_llm/runtime/loraUtils.cpp | 8 +- cpp/tensorrt_llm/runtime/loraUtils.h | 8 +- cpp/tensorrt_llm/runtime/ncclCommunicator.h | 4 +- .../runtime/promptTuningParams.cpp | 4 +- cpp/tensorrt_llm/runtime/runtimeBuffers.cpp | 8 +- cpp/tensorrt_llm/runtime/runtimeBuffers.h | 12 +- cpp/tensorrt_llm/runtime/runtimeKernels.cu | 13 +- cpp/tensorrt_llm/runtime/tllmBuffers.h | 4 +- cpp/tensorrt_llm/runtime/utils/debugUtils.cu | 8 +- cpp/tensorrt_llm/runtime/utils/debugUtils.h | 2 +- cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp | 8 +- cpp/tensorrt_llm/runtime/utils/numpyUtils.h | 4 +- cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp | 150 +- cpp/tensorrt_llm/thop/dynamicDecodeOp.h | 2 +- cpp/tensorrt_llm/thop/fp8Op.cpp | 10 +- cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp | 6 +- .../kernels/banRepeatNGramsKernelsTest.cpp | 26 +- cpp/tests/kernels/decodingKernelTest.cpp | 856 ++++-- cpp/tests/kernels/mixtureOfExpertsTest.cu | 40 +- .../kernels/sampling/samplingAirTopPTest.cpp | 41 +- .../kernels/sampling/samplingPenaltyTest.cpp | 457 ++-- cpp/tests/kernels/sampling/samplingTest.cpp | 397 +-- cpp/tests/kernels/sampling/samplingTest.h | 70 +- .../kernels/sampling/samplingTopKTest.cpp | 84 +- .../kernels/sampling/samplingTopPTest.cpp | 80 +- .../kernels/sampling/samplingUtilsTest.cu | 17 +- cpp/tests/kernels/shiftKCacheKernelTest.cu | 90 +- cpp/tests/kernels/stopCriteriaKernelsTest.cpp | 36 +- cpp/tests/layers/baseSamplingLayerTest.cpp | 4 +- cpp/tests/layers/baseSamplingLayerTest.h | 2 +- cpp/tests/layers/dynamicDecodeLayerTest.cpp | 8 +- .../generate_expected_medusa_output.py | 3 +- cpp/tests/resources/scripts/test_cpp.py | 12 +- cpp/tests/runtime/gptDecoderBatchTest.cpp | 5 +- cpp/tests/runtime/gptDecoderTest.cpp | 2 +- cpp/tests/runtime/gptSessionTest.cpp | 14 +- cpp/tests/runtime/runtimeKernelTest.cpp | 14 +- cpp/tests/runtime/tllmBuffersTest.cpp | 2 +- cpp/tests/runtime/tllmRuntimeTest.cpp | 2 +- cpp/tests/runtime/transposeKVKernelTest.cpp | 46 +- docker/Dockerfile.multi | 2 +- docker/Makefile | 19 +- docker/common/install_base.sh | 5 +- docker/common/install_pytorch.sh | 5 +- docker/common/install_tensorrt.sh | 15 +- docs/source/performance.md | 16 +- docs/source/precision.md | 11 + examples/baichuan/README.md | 16 +- examples/baichuan/convert_checkpoint.py | 33 +- examples/baichuan/requirements.txt | 2 + examples/bloom/convert_checkpoint.py | 2 - examples/bloom/requirements.txt | 2 + examples/chatglm/convert_checkpoint.py | 599 ++--- examples/chatglm/requirements.txt | 2 + examples/cpp_library/main.cpp | 6 +- examples/cpp_library/tensorrt_llm_libutils.h | 4 +- examples/enc_dec/README.md | 15 +- examples/enc_dec/bart/convert.py | 2 +- examples/enc_dec/build.py | 12 +- examples/enc_dec/run.py | 29 +- examples/falcon/requirements.txt | 2 + examples/gemma/README.md | 40 +- examples/gemma/convert_checkpoint.py | 153 +- examples/gemma/requirements.txt | 2 + examples/gpt/nemo_lora_convert.py | 3 +- examples/gpt/requirements.txt | 2 + examples/gptj/convert_checkpoint.py | 10 +- 
examples/gptneox/requirements.txt | 2 + examples/hf_lora_convert.py | 2 +- examples/high-level-api/README.md | 41 +- examples/high-level-api/llm_examples.py | 94 +- .../run_auto_parallel_examples.sh | 18 + examples/high-level-api/run_examples.sh | 4 + examples/high-level-api/run_quant_examples.sh | 0 examples/internlm/README.md | 16 +- examples/internlm/convert_checkpoint.py | 1497 ----------- examples/internlm/requirements.txt | 2 + examples/llama/README.md | 23 +- examples/llama/convert_checkpoint.py | 565 ++-- examples/llama/requirements.txt | 2 + examples/mamba/requirements.txt | 2 + examples/medusa/convert_checkpoint.py | 15 - examples/medusa/requirements.txt | 2 + examples/mixtral/requirements.txt | 2 + examples/mpt/convert_checkpoint.py | 145 +- examples/mpt/requirements.txt | 2 + examples/multimodal/README.md | 78 +- examples/multimodal/build_visual_engine.py | 83 +- examples/multimodal/run.py | 129 +- .../TritonFlashAttentionPlugin.cpp | 86 +- .../TritonFlashAttentionPlugin.h | 46 +- .../manual_plugin/tritonPlugins.cpp | 6 +- examples/opt/requirements.txt | 2 + examples/phi/convert_checkpoint.py | 250 +- examples/phi/requirements.txt | 2 + examples/quantization/quantize.py | 354 +-- examples/quantization/requirements.txt | 2 + examples/qwen/build.py | 12 +- examples/qwen/requirements.txt | 3 +- examples/qwenvl/README.md | 2 +- examples/qwenvl/requirements.txt | 3 +- examples/qwenvl/run.py | 15 +- examples/qwenvl/run_chat.py | 3 +- examples/server/requirements.txt | 2 + examples/server/test_executor.py | 2 +- examples/skywork/requirements.txt | 2 + examples/utils.py | 4 +- examples/whisper/requirements.txt | 2 + requirements-windows.txt | 3 + requirements.txt | 7 +- scripts/replace_version.sh | 9 + tensorrt_llm/__init__.py | 3 + tensorrt_llm/_utils.py | 19 +- tensorrt_llm/auto_parallel/__init__.py | 6 + tensorrt_llm/auto_parallel/auto_parallel.py | 263 ++ tensorrt_llm/auto_parallel/config.py | 393 +++ tensorrt_llm/auto_parallel/device_mesh.py | 612 +++++ tensorrt_llm/auto_parallel/node_graph.py | 347 +++ tensorrt_llm/auto_parallel/parallelization.py | 2311 +++++++++++++++++ tensorrt_llm/auto_parallel/pipeline_graph.py | 1024 ++++++++ .../auto_parallel/runtime_profiling.py | 150 ++ tensorrt_llm/auto_parallel/shape_info.py | 337 +++ tensorrt_llm/auto_parallel/simplifier.py | 835 ++++++ tensorrt_llm/auto_parallel/solver.py | 641 +++++ .../auto_parallel/tensor_parallel/__init__.py | 0 .../tensor_parallel/activation_node.py | 41 + .../tensor_parallel/assertion_node.py | 34 + .../tensor_parallel/cast_node.py | 45 + .../tensor_parallel/comm_spec.py | 58 + .../tensor_parallel/concatenation_node.py | 56 + .../tensor_parallel/constant_node.py | 45 + .../tensor_parallel/elementwise_node.py | 49 + .../tensor_parallel/fill_node.py | 59 + .../tensor_parallel/gather_node.py | 196 ++ .../tensor_parallel/identity_node.py | 56 + .../tensor_parallel/input_node.py | 79 + .../tensor_parallel/matmul_node.py | 798 ++++++ .../auto_parallel/tensor_parallel/node.py | 376 +++ .../tensor_parallel/normalization_node.py | 60 + .../tensor_parallel/output_node.py | 79 + .../auto_parallel/tensor_parallel/p2p_node.py | 67 + .../tensor_parallel/plugin_node.py | 35 + .../tensor_parallel/plugin_nodes/__init__.py | 0 .../tensor_parallel/plugin_nodes/gemm_node.py | 27 + .../plugin_nodes/gpt_attention_node.py | 379 +++ .../plugin_nodes/identity_node.py | 11 + .../plugin_nodes/look_up_node.py | 19 + .../plugin_nodes/normalization_node.py | 28 + .../tensor_parallel/reduce_node.py | 73 + .../tensor_parallel/select_node.py | 
56 + .../tensor_parallel/shape_consistency.py | 832 ++++++ .../tensor_parallel/shape_node.py | 41 + .../tensor_parallel/sharding_spec.py | 418 +++ .../tensor_parallel/sharding_strategy.py | 77 + .../tensor_parallel/shuffle_node.py | 238 ++ .../tensor_parallel/slice_node.py | 100 + .../tensor_parallel/softmax_node.py | 54 + .../tensor_parallel/unary_node.py | 42 + tensorrt_llm/auto_parallel/utils.py | 319 +++ tensorrt_llm/builder.py | 204 +- tensorrt_llm/commands/build.py | 222 +- tensorrt_llm/executor.py | 34 +- tensorrt_llm/hlapi/llm.py | 449 ++-- tensorrt_llm/layers/__init__.py | 5 +- tensorrt_llm/layers/attention.py | 95 +- tensorrt_llm/layers/embedding.py | 20 + tensorrt_llm/layers/linear.py | 77 +- tensorrt_llm/layers/lora.py | 33 +- tensorrt_llm/layers/moe.py | 17 +- tensorrt_llm/layers/ssm.py | 9 +- tensorrt_llm/{runtime => }/lora_manager.py | 2 +- tensorrt_llm/models/enc_dec/model.py | 99 +- tensorrt_llm/models/gemma/model.py | 270 +- tensorrt_llm/models/gemma/weight.py | 64 +- tensorrt_llm/models/gpt/model.py | 6 - tensorrt_llm/models/llama/convert.py | 408 ++- tensorrt_llm/models/llama/model.py | 99 +- tensorrt_llm/models/llama/weight.py | 62 +- tensorrt_llm/models/mamba/model.py | 9 +- tensorrt_llm/models/modeling_utils.py | 128 +- tensorrt_llm/models/mpt/model.py | 6 +- tensorrt_llm/models/phi/convert.py | 63 + tensorrt_llm/models/phi/model.py | 20 +- tensorrt_llm/models/qwen/model.py | 19 +- tensorrt_llm/plugin/plugin.py | 12 + tensorrt_llm/quantization/__init__.py | 3 +- tensorrt_llm/quantization/layers.py | 71 +- tensorrt_llm/quantization/mode.py | 4 + tensorrt_llm/quantization/quantize.py | 8 +- tensorrt_llm/quantization/quantize_by_ammo.py | 358 +++ tensorrt_llm/runtime/__init__.py | 2 - tensorrt_llm/runtime/engine.py | 87 - tensorrt_llm/runtime/generation.py | 170 +- tensorrt_llm/runtime/model_runner.py | 262 +- .../plugin_gen/templates/plugin_common.cpp | 8 +- .../plugin_gen/templates/plugin_common.h | 30 +- tensorrt_llm/top_model_mixin.py | 150 +- tensorrt_llm/version.py | 2 +- tests/attention/test_bert_attention.py | 17 +- tests/attention/test_gpt_attention.py | 52 +- tests/attention/test_gpt_attention_IFB.py | 43 +- .../attention/test_gpt_attention_no_cache.py | 9 +- tests/bindings/test_bindings.py | 22 + tests/bindings/test_gpt_manager.py | 8 +- tests/bindings/test_gpt_session.py | 7 +- tests/dump_checkpoint_stats.py | 29 + tests/functional/test_alibi.py | 8 +- tests/functional/test_assertion.py | 7 +- tests/functional/test_cumsum.py | 8 +- tests/functional/test_einsum.py | 7 +- tests/functional/test_embedding_single_gpu.py | 8 +- tests/functional/test_geglu.py | 7 +- tests/functional/test_gelu.py | 26 +- tests/functional/test_group_norm.py | 8 +- tests/functional/test_identity.py | 11 +- tests/functional/test_masked_scatter.py | 8 +- tests/functional/test_masked_select.py | 8 +- tests/functional/test_matmul.py | 20 +- tests/functional/test_moe.py | 21 +- tests/functional/test_nccl.py | 15 +- tests/functional/test_permute.py | 7 +- tests/functional/test_repeat_interleave.py | 7 +- tests/functional/test_selective_scan.py | 31 +- tests/functional/test_slice.py | 9 +- tests/functional/test_split.py | 8 +- tests/functional/test_topk.py | 8 +- tests/functional/torch_ref.py | 1 + tests/hlapi/test_llm.py | 132 +- tests/hlapi/test_llm_quant.py | 10 +- tests/model/test_bert.py | 26 +- tests/model/test_bloom.py | 18 +- tests/model/test_falcon.py | 19 +- tests/model/test_gpt.py | 29 +- tests/model/test_gpt_e2e.py | 6 +- tests/model/test_gptj.py | 14 +- 
tests/model/test_gptneox.py | 14 +- tests/model/test_llama.py | 38 +- tests/model/test_mamba.py | 47 +- tests/model/test_mistral.py | 44 +- tests/model/test_phi.py | 21 +- tests/model_api/test_model_api_multi_gpu.py | 42 +- tests/model_api/test_model_level_api.py | 72 +- tests/model_api/test_model_quantization.py | 70 +- tests/quantization/test_fp8_quantization.py | 55 +- tests/quantization/test_functional.py | 12 +- tests/quantization/test_quant_layer.py | 106 +- tests/quantization/test_smooth_quant_gemm.py | 16 +- .../test_smooth_quant_layer_norm.py | 7 +- .../test_smooth_quant_rms_norm.py | 8 +- ...test_weight_only_groupwise_quant_matmul.py | 113 +- .../test_weight_only_quant_matmul.py | 49 +- tests/test_layer.py | 115 +- tests/test_llama_conversion.sh | 190 ++ tests/test_model_dtype.py | 4 +- tests/utils/util.py | 77 + 488 files changed, 23178 insertions(+), 10463 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h create mode 100644 examples/high-level-api/run_auto_parallel_examples.sh mode change 100644 => 100755 examples/high-level-api/run_examples.sh mode change 100644 => 100755 examples/high-level-api/run_quant_examples.sh delete mode 100644 examples/internlm/convert_checkpoint.py create mode 100644 scripts/replace_version.sh create mode 100644 tensorrt_llm/auto_parallel/__init__.py create mode 100644 tensorrt_llm/auto_parallel/auto_parallel.py create mode 100644 tensorrt_llm/auto_parallel/config.py create mode 100644 tensorrt_llm/auto_parallel/device_mesh.py create mode 100644 tensorrt_llm/auto_parallel/node_graph.py create mode 100644 tensorrt_llm/auto_parallel/parallelization.py create mode 100644 tensorrt_llm/auto_parallel/pipeline_graph.py create mode 100644 tensorrt_llm/auto_parallel/runtime_profiling.py create mode 100644 tensorrt_llm/auto_parallel/shape_info.py create mode 100644 tensorrt_llm/auto_parallel/simplifier.py create mode 100644 tensorrt_llm/auto_parallel/solver.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/__init__.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/input_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/output_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py create mode 100644 
tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/select_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/shape_consistency.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/shape_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/sharding_spec.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/sharding_strategy.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/shuffle_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/slice_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/softmax_node.py create mode 100644 tensorrt_llm/auto_parallel/tensor_parallel/unary_node.py create mode 100644 tensorrt_llm/auto_parallel/utils.py rename tensorrt_llm/{runtime => }/lora_manager.py (99%) create mode 100644 tensorrt_llm/models/phi/convert.py create mode 100644 tensorrt_llm/quantization/quantize_by_ammo.py delete mode 100644 tensorrt_llm/runtime/engine.py create mode 100644 tests/dump_checkpoint_stats.py create mode 100755 tests/test_llama_conversion.sh diff --git a/.clang-format b/.clang-format index 1983a9ca5..12bb2f112 100644 --- a/.clang-format +++ b/.clang-format @@ -59,6 +59,7 @@ PenaltyBreakString: 1000 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 60 PointerAlignment: Left +QualifierAlignment: Right ReflowComments: true SeparateDefinitionBlocks: Always SortIncludes: CaseSensitive diff --git a/.gitignore b/.gitignore index 0296d6d11..cb9aee85b 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,16 @@ venv/ .local/ .hypothesis/ .idea/ +dump*/ +.trt-internal +*.dot +*.prof +*.log +*.pkl +*.hdf5 +*.lock +config.json +/*.svg cpp/cmake-build-* cpp/.ccache/ tensorrt_llm/libs diff --git a/README.md b/README.md index f36299deb..c43596475 100644 --- a/README.md +++ b/README.md @@ -355,6 +355,9 @@ however, that it is recommended to use the C++ version. ## Troubleshooting +* If you encounter accuracy issues in the generated text, you may want to increase + the internal precision in the attention layer. For that, pass the `--context_fmha_fp32_acc enable` to + `trtllm-build`. * It's recommended to add options `–shm-size=1g –ulimit memlock=-1` to the docker or nvidia-docker run command. 
Otherwise you may see NCCL errors when diff --git a/benchmarks/cpp/README.md b/benchmarks/cpp/README.md index 622611df7..e3e6a200f 100644 --- a/benchmarks/cpp/README.md +++ b/benchmarks/cpp/README.md @@ -39,7 +39,6 @@ Take GPT-350M as an example for single GPU ``` ./benchmarks/gptSessionBenchmark \ - --model gpt_350m \ --engine_dir "../../benchmarks/gpt_350m/" \ --batch_size "1" \ --input_output_len "60,20" @@ -50,7 +49,6 @@ Take GPT-350M as an example for single GPU Take GPT-175B as an example for multiple GPUs ``` mpirun -n 8 ./benchmarks/gptSessionBenchmark \ - --model gpt_175b \ --engine_dir "../../benchmarks/gpt_175b/" \ --batch_size "1" \ --input_output_len "60,20" @@ -125,7 +123,6 @@ cd cpp/build Take GPT-350M as an example for single GPU V1 batching ``` ./benchmarks/gptManagerBenchmark \ - --model gpt \ --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \ --type V1 \ --dataset ../../benchmarks/cpp/preprocessed_dataset.json @@ -135,7 +132,6 @@ Take GPT-350M as an example for single GPU V1 batching Take GPT-350M as an example for 2-GPU inflight batching ``` mpirun -n 2 ./benchmarks/gptManagerBenchmark \ - --model gpt \ --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \ --type IFB \ --dataset ../../benchmarks/cpp/preprocessed_dataset.json @@ -165,7 +161,6 @@ Given a `static_emulated_batch_size` of `n` the server will wait for `n` request Take GPT-350M as an example for single GPU with static batching ``` ./benchmarks/gptManagerBenchmark \ - --model gpt \ --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \ --type IFB \ --static_emulated_batch_size 32 \ diff --git a/benchmarks/cpp/bertBenchmark.cpp b/benchmarks/cpp/bertBenchmark.cpp index aa7e8f760..712149cb7 100644 --- a/benchmarks/cpp/bertBenchmark.cpp +++ b/benchmarks/cpp/bertBenchmark.cpp @@ -237,7 +237,7 @@ int main(int argc, char* argv[]) benchmarkBert(result["model"].as(), result["engine_dir"].as(), batchSizes, inLens, logger, result["warm_up"].as(), result["num_runs"].as(), result["duration"].as()); } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_ERROR(e.what()); return 1; diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index 66b98c3dc..4de58c95f 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -24,6 +24,7 @@ #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/tllmLogger.h" #include "tensorrt_llm/runtime/worldConfig.h" @@ -64,20 +65,18 @@ struct BenchmarkParams class WorkItem { public: - WorkItem(std::shared_ptr ir, uint64_t requestId) - : mInferenceRequest(ir) + WorkItem(std::shared_ptr inferenceRequest, uint64_t requestId) + : mInferenceRequest(std::move(inferenceRequest)) , mRequestId(requestId) { } - ~WorkItem() {} - - uint64_t requestId() const + [[nodiscard]] uint64_t requestId() const { return mRequestId; } - std::shared_ptr getInferenceRequest() const + [[nodiscard]] std::shared_ptr getInferenceRequest() const { return mInferenceRequest; } @@ -93,7 +92,7 @@ class WorkItemsQueue public: void clear() { - std::lock_guard lk(mMutex); + std::lock_guard lock(mMutex); mPendingWorkItems.clear(); mPendingWorkItemsReqIds.clear(); mInProgressWorkItems.clear(); @@ -289,7 +288,7 @@ class Recorder if (outputFile.is_open()) { - for (const auto& header : headers) + for (auto const& header : headers) { outputFile << header 
<< ","; } @@ -340,13 +339,12 @@ class ExecutorServer mExecutor = std::make_shared(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig); } - ~ExecutorServer() {} - void enqueue(std::vector requests, bool warmup = false) { try { - std::vector inputLengths, maxNewTokens; + std::vector inputLengths; + std::vector maxNewTokens; for (auto const& request : requests) { inputLengths.push_back(request.getInputTokenIds().size()); @@ -363,11 +361,10 @@ class ExecutorServer mActiveCount++; } } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_THROW("%s", e.what()); } - return; } void waitForResponses(std::optional numRequests, bool warmup = false) @@ -415,17 +412,16 @@ class ExecutorServer class GptServer { public: - GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth, + GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, SizeType maxBeamWidth, batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams, std::shared_ptr recorder, std::optional terminateReqId, std::chrono::milliseconds waitSleep, - std::optional const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, bool logIterationData) + std::optional const staticEmulatedBatchSize, + std::optional const batchTimeout, bool logIterationData) : mRecorder(std::move(recorder)) , mTerminateReqId(terminateReqId) , mWaitSleep(waitSleep) , mStaticEmulatedBatchSize(staticEmulatedBatchSize) - , mEmulatedBatchEndTimestamp( - std::chrono::steady_clock::now() + std::chrono::milliseconds(staticEmulatedTimeoutMs)) - , mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs) + , mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0})) , mActiveCount(0) { ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log) @@ -473,16 +469,21 @@ class GptServer mRecorder->recordStart(request, requestId); mWorkItemsQueue.push(request, requestId); } - catch (const tc::TllmException& e) + catch (tc::TllmException const& e) { throw; } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_THROW("%s", e.what()); } } + void resetBatchDeadline() + { + mBatchDeadline = (std::chrono::steady_clock::now() + mBatchTimeout).time_since_epoch(); + } + void waitForEmpty() const { while (!mWorkItemsQueue.empty()) @@ -502,9 +503,9 @@ class GptServer } // Return up to max_num_requests inference requests. - std::list> getInferenceRequests(const int max_num_requests) + std::list> getInferenceRequests(int const max_num_requests) { - std::list> rval; + std::list> inferenceRequests; auto& comm = COMM_SESSION; if (max_num_requests > 0) { @@ -515,12 +516,12 @@ class GptServer auto const numNewWorkItems = std::min(static_cast(mWorkItemsQueue.numPendingWorkItems()), static_cast(max_num_requests)); - bool readyForNextBatch = numNewWorkItems > 0; + bool const timeout = std::chrono::steady_clock::now().time_since_epoch() > mBatchDeadline.load(); + bool readyForNextBatch = numNewWorkItems > 0 && timeout; if (mStaticEmulatedBatchSize) { if (numNewWorkItems > 0) { - bool const timeout = std::chrono::steady_clock::now() > mEmulatedBatchEndTimestamp; bool const previousBatchFinished = mActiveCount == 0; bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value(); readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch); @@ -529,26 +530,23 @@ class GptServer { // Timeout should only begin once we have at least 1 pending request. 
// Reset timeout when no requests are pending or we submit a new batch. - mEmulatedBatchEndTimestamp - = std::chrono::steady_clock::now() + std::chrono::milliseconds(mStaticEmulatedTimeoutMs); + resetBatchDeadline(); } } if (readyForNextBatch) { - int count = 0; // Only add a single batch at a time when emulating static batching auto const numItemsToAdd = std::min( numNewWorkItems, static_cast(mStaticEmulatedBatchSize.value_or(numNewWorkItems))); mActiveCount += numItemsToAdd; - while (count < numItemsToAdd) + while (inferenceRequests.size() < numItemsToAdd) { auto [workItem, markedInProgress] = mWorkItemsQueue.pop(); if (markedInProgress) { - rval.emplace_back(workItem->getInferenceRequest()); - count++; + inferenceRequests.emplace_back(workItem->getInferenceRequest()); } else { @@ -561,14 +559,14 @@ class GptServer } if (world_size > 1) { - auto numNewWorkItems = static_cast(rval.size()); + auto numNewWorkItems = static_cast(inferenceRequests.size()); comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0); if (numNewWorkItems > 0) { std::vector packed; - for (auto const& ir : rval) + for (auto const& infReq : inferenceRequests) { - auto vpacked = ir->serialize(); + auto vpacked = infReq->serialize(); packed.push_back(static_cast(vpacked.size())); packed.insert( packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end())); @@ -590,18 +588,18 @@ class GptServer for (int64_t count = 0; count < numNewWorkItems; ++count) { int64_t n = *(packed_ptr++); - auto ir = InferenceRequest::deserialize(packed_ptr); + auto infReq = InferenceRequest::deserialize(packed_ptr); packed_ptr += n; - rval.emplace_back(ir); + inferenceRequests.emplace_back(infReq); } } } } - return rval; + return inferenceRequests; } void sendResponse(uint64_t requestId, [[maybe_unused]] std::list const& response_tensors, - bool final_response, [[maybe_unused]] const std::string& errMsg) + bool final_response, [[maybe_unused]] std::string const& errMsg) { // `response_tensors` contains `outputIds, sequenceLength, [contextLogits, generationLogits], logProbs, // cumLogProbs`. 
`contextLogits, generationLogits` are optional, only contained when `gather_context_logits` and @@ -616,7 +614,7 @@ class GptServer mActiveCount--; } } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_ERROR("Failed to send response for requestId %lu\n%s", requestId, e.what()); } @@ -628,9 +626,9 @@ class GptServer WorkItemsQueue mWorkItemsQueue; std::optional mTerminateReqId; std::chrono::milliseconds mWaitSleep; - std::optional mStaticEmulatedBatchSize; - std::chrono::time_point mEmulatedBatchEndTimestamp; - int32_t mStaticEmulatedTimeoutMs; + std::optional mStaticEmulatedBatchSize; + std::chrono::milliseconds mBatchTimeout; + std::atomic mBatchDeadline; std::atomic mActiveCount; }; // class GptServer @@ -674,10 +672,9 @@ std::shared_ptr makeRequest(std::uint64_t reqId, Sample const& auto request = std::make_shared(reqId); auto const& inputIds = sample.inputIds; request->setInputIds(bufferManager.copyFrom( - inputIds, ITensor::makeShape({static_cast(inputIds.size())}), MemoryType::kPINNED)); + inputIds, ITensor::makeShape({static_cast(inputIds.size())}), MemoryType::kCPU)); auto const requestOutputLen = sample.outputLen; - request->setMaxNewTokens( - bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kPINNED)); + request->setMaxNewTokens(bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kCPU)); request->setBeamWidth(beamWidthTensor); if (eosId != nullptr) { @@ -704,14 +701,15 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWid { auto samplingConfig = texec::SamplingConfig{beamWidth}; auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false}; - return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId); + return {sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId}; } void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType, std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp, - std::optional const& eosId, std::optional const& padId, BenchmarkParams const& benchmarkParams, - batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits, - bool returnGenerationLogits, std::optional const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, + std::optional const& eosId, std::optional const& padId, + BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy, + std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits, + std::optional const staticEmulatedBatchSize, std::optional const batchTimeout, bool logIterationData) { auto const worldConfig = WorldConfig::mpi(); @@ -736,14 +734,14 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)}; // Load dataset - const auto samples = parseWorkloadJson(datasetPath, maxNumSamples); - const auto numSamples = samples.size(); + auto const samples = parseWorkloadJson(datasetPath, maxNumSamples); + auto const numSamples = samples.size(); - const int maxBeamWidth = beamWidth; + int const maxBeamWidth = beamWidth; auto recorder = std::make_shared(opCsvFile); uint64_t terminateReqId = numSamples + 1; auto gptServer = std::make_shared(engineDir, modelType, maxBeamWidth, schedulerPolicy, 
optionalParams, - recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs, logIterationData); + recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData); ITensor::SharedPtr eosIdTensor{ eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr}; @@ -761,6 +759,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType if (worldConfig.getRank() == 0) { // Warm up + gptServer->resetBatchDeadline(); SizeType reqId = 0; for (auto i = 0; i < warmUp; ++i) { @@ -774,6 +773,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType // Benchmark recorder->initialize(); + gptServer->resetBatchDeadline(); for (std::size_t i = 0; i < numSamples; ++i) { auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager, @@ -806,23 +806,19 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits, std::optional const staticEmulatedBatchSize, bool logIterationData) { - // Check that mpi size is 1 for now - auto const worldConfig = WorldConfig::mpi(); - if (worldConfig.getSize() > 1) - { - TLLM_THROW("benchmarkExecutor does not yet support mpiSize > 1"); - } + auto const& world = tensorrt_llm::mpi::MpiComm::world(); + auto worldRank = world.getRank(); // Load dataset - const auto samples = parseWorkloadJson(datasetPath, maxNumSamples); - const auto numSamples = samples.size(); + auto const samples = parseWorkloadJson(datasetPath, maxNumSamples); + auto const numSamples = samples.size(); auto recorder = std::make_shared(opCsvFile); auto executorServer = std::make_shared(engineDir, modelType, beamWidth, schedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData); - if (worldConfig.getRank() == 0) + if (worldRank == 0) { // Warm up { @@ -849,7 +845,7 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m delays.push_back(static_cast(samples[i].delay * 1000)); } - bool hasDelay = std::any_of(delays.begin(), delays.end(), [](const auto& delay) { return delay > 0; }); + bool hasDelay = std::any_of(delays.begin(), delays.end(), [](auto const& delay) { return delay > 0; }); if (hasDelay && staticEmulatedBatchSize) { TLLM_THROW("Executor benchmark doesn't support delays with emulated static batch sizes"); @@ -910,9 +906,6 @@ int main(int argc, char* argv[]) cxxopts::Options options( "TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models."); options.add_options()("h,help", "Print usage"); - // TODO(rkobus): remove because unused - options.add_options()( - "m,model", "Model name specified for engines.", cxxopts::value()->default_value("gpt_350m")); options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value()); options.add_options()( "api", "API type: gptManager or executor.", cxxopts::value()->default_value("gptManager")); @@ -929,8 +922,8 @@ int main(int argc, char* argv[]) options.add_options()( "warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value()->default_value("2")); options.add_options()( - "eos_id", "Specify the end-of-sequence token id.", cxxopts::value()->default_value("-1")); - options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value()); + "eos_id", 
"Specify the end-of-sequence token id.", cxxopts::value()->default_value("-1")); + options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value()); options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value()); options.add_options()( "kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value()); @@ -949,11 +942,15 @@ int main(int argc, char* argv[]) options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.", cxxopts::value()->default_value("guaranteed_no_evict")); + options.add_options()("first_batch_delay", + "Delay before submitting the first batch of requests. This can be used to increase the size of the first " + "batch.", + cxxopts::value()); options.add_options()("static_emulated_batch_size", - "Emulate static batching performance with the provided batch size.", cxxopts::value()); + "Emulate static batching performance with the provided batch size.", cxxopts::value()); options.add_options()("static_emulated_timeout", "Timeout (ms) before launching a partial batch in emulated static batching mode", - cxxopts::value()->default_value("500")); + cxxopts::value()->default_value("500")); options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.", cxxopts::value()->default_value("error")); options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.", @@ -1042,23 +1039,31 @@ int main(int argc, char* argv[]) // Argument: Enable return context logits bool returnGenerationLogits = result["return_generation_logits"].as(); - std::optional padId; + std::optional padId; // Argument: Padding token id if (result.count("pad_id")) { - padId = result["pad_id"].as(); + padId = result["pad_id"].as(); } // Argument: End-of-sentence token id - std::optional eosId = result["eos_id"].as(); + std::optional eosId = result["eos_id"].as(); - std::optional staticEmulatedBatchSize; + std::optional batchTimeout; + // Argument: first_batch_delay + if (result.count("first_batch_delay")) + { + batchTimeout = std::chrono::milliseconds{result["first_batch_delay"].as()}; + } + + std::optional staticEmulatedBatchSize; // Argument: Static emulated batch size if (result.count("static_emulated_batch_size")) { - staticEmulatedBatchSize = result["static_emulated_batch_size"].as(); + staticEmulatedBatchSize = result["static_emulated_batch_size"].as(); + + batchTimeout = std::chrono::milliseconds{result["static_emulated_timeout"].as()}; } - auto const staticEmulatedTimeout = result["static_emulated_timeout"].as(); // Argument: Scheduler policy batch_scheduler::SchedulerPolicy schedulerPolicy; @@ -1114,10 +1119,10 @@ int main(int argc, char* argv[]) { benchmarkGptManager(result["engine_dir"].as(), modelType, datasetPath, opCsvFile, maxNumSamples, beamWidth, result["warm_up"].as(), eosId, padId, benchmarkParams, schedulerPolicy, - waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout, + waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout, logIterationData); } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_ERROR(e.what()); return 1; @@ -1131,7 +1136,7 @@ int main(int argc, char* argv[]) beamWidth, result["warm_up"].as(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData); } - catch (const 
std::exception& e) + catch (std::exception const& e) { TLLM_LOG_ERROR(e.what()); return 1; diff --git a/benchmarks/cpp/gptSessionBenchmark.cpp b/benchmarks/cpp/gptSessionBenchmark.cpp index fecc11293..290f50889 100644 --- a/benchmarks/cpp/gptSessionBenchmark.cpp +++ b/benchmarks/cpp/gptSessionBenchmark.cpp @@ -15,7 +15,6 @@ * limitations under the License. */ #include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/runtime/gptJsonConfig.h" #include "tensorrt_llm/runtime/gptSession.h" @@ -56,12 +55,11 @@ size_t monitorMemory(std::atomic_bool& done) return peakMem; } -void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath, - std::vector const& batchSizes, int beamWidth, std::vector> const& inOutLen, - std::shared_ptr const& logger, int warmUp, int numRuns, int duration, - GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits, bool disableForceMaxTokens) +void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector const& batchSizes, int beamWidth, + std::vector> const& inOutLen, std::shared_ptr const& logger, int warmUp, + int numRuns, int duration, GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits, + bool disableForceMaxTokens) { - std::string modelNameHyphen = modelName; std::filesystem::path jsonFileName = dataPath / "config.json"; auto const json = GptJsonConfig::parse(jsonFileName); auto const modelConfig = json.getModelConfig(); @@ -69,7 +67,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con SizeType deviceCount{0}; TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount)); auto const worldConfig = WorldConfig::mpi(deviceCount, json.getTensorParallelism(), json.getPipelineParallelism()); - auto const enginePath = dataPath / json.engineFilename(worldConfig, modelNameHyphen); + auto const enginePath = dataPath / json.engineFilename(worldConfig); auto const dtype = modelConfig.getDataType(); auto const maxNumTokens = modelConfig.getMaxNumTokens(); auto const useHalf = (dtype == nvinfer1::DataType::kHALF); @@ -104,7 +102,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con auto& memoryCounter = MemoryCounters::getInstance(); TLLM_LOG_INFO(memoryCounter.toString()); - + std::atomic_bool done; for (auto const batchSize : batchSizes) { if (inputPacked && maxNumTokens != std::nullopt) @@ -114,10 +112,11 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con "benchmark on %d tokens", maxNumTokens.value(), maxBatchSize * maxInputLength); } - std::atomic_bool done = false; + done = false; + auto peakMemFuture = std::async(&monitorMemory, std::ref(done)); + size_t peakMem; try { - auto peakMemFuture = std::async(&monitorMemory, std::ref(done)); TLLM_LOG_INFO(memoryCounter.toString()); std::vector inputLengthsHost(batchSize, maxInputLength); @@ -205,7 +204,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con TLLM_LOG_INFO(memoryCounter.toString()); done = true; - size_t peakMem = peakMemFuture.get(); + peakMemFuture.wait(); + peakMem = peakMemFuture.get(); printf("Benchmarking done. Iteration: %d, duration: %.2f sec.\n", iterIdx, curDuration / 1000); @@ -275,6 +275,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con std::size_t found = std::string(e.what()).find("out of memory"); // We need to kill the memory monitor when OOM. 
done = true; + peakMemFuture.wait(); + peakMem = peakMemFuture.get(); // Unexpected error; rethrow if (found == std::string::npos) @@ -297,6 +299,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con { // We need to kill memory monitor when any other issue occurs done = true; + peakMemFuture.wait(); + peakMem = peakMemFuture.get(); throw; } } @@ -311,8 +315,6 @@ int main(int argc, char* argv[]) cxxopts::Options options( "TensorRT-LLM C++ Runtime Benchmark", "TensorRT-LLM C++ Runtime Benchmark for GPT and GPT-like models."); options.add_options()("h,help", "Print usage"); - options.add_options()( - "m,model", "Model name specified for engines.", cxxopts::value()->default_value("gpt_350m")); options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value()); options.add_options()("batch_size", "Specify batch size(s) you want to benchmark. Multiple batch sizes can be separated by \";\", example: " @@ -459,11 +461,11 @@ int main(int argc, char* argv[]) try { - benchmarkGptSession(result["model"].as(), result["engine_dir"].as(), batchSizes, - beamWidth, inOutLen, logger, result["warm_up"].as(), result["num_runs"].as(), - result["duration"].as(), sessionConfig, enableCudaGraph, printAllLogits, disableForceMaxTokens); + benchmarkGptSession(result["engine_dir"].as(), batchSizes, beamWidth, inOutLen, logger, + result["warm_up"].as(), result["num_runs"].as(), result["duration"].as(), sessionConfig, + enableCudaGraph, printAllLogits, disableForceMaxTokens); } - catch (const std::exception& e) + catch (std::exception const& e) { TLLM_LOG_ERROR(e.what()); return 1; diff --git a/benchmarks/python/allowed_configs.py b/benchmarks/python/allowed_configs.py index 67c41e706..d3bfb54aa 100644 --- a/benchmarks/python/allowed_configs.py +++ b/benchmarks/python/allowed_configs.py @@ -86,6 +86,7 @@ class EncDecBuildConfig: max_output_len: Optional[int] = None builder_opt: Optional[int] = None n_mels: Optional[int] = None + skip_cross_qkv: bool = False def __post_init__(self) -> None: assert self.head_size is not None diff --git a/benchmarks/python/base_benchmark.py b/benchmarks/python/base_benchmark.py index 22d11d17c..b497193f0 100644 --- a/benchmarks/python/base_benchmark.py +++ b/benchmarks/python/base_benchmark.py @@ -89,7 +89,11 @@ def __init__(self, (f'Engine world size ({world_size}) != Runtime world size ({self.world_size})') # Load config into self for key, value in self.config['pretrained_config'].items(): - setattr(self, key, value) + if key == "ssm_cfg": + for ssm_key, ssm_value in value.items(): + setattr(self, "mamba_" + ssm_key, ssm_value) + else: + setattr(self, key, value) self.quant_mode = QuantMode.from_quant_algo( quant_algo=self.quantization['quant_algo'], diff --git a/benchmarks/python/benchmark.py b/benchmarks/python/benchmark.py index 8b48afade..d4a3bfe0b 100644 --- a/benchmarks/python/benchmark.py +++ b/benchmarks/python/benchmark.py @@ -327,9 +327,16 @@ def main(args): torch.cuda.empty_cache() latencies = [] + # Disable Host memory monitor when cuda graph is enabled for cuda graph performance. 
+ disable_host_mem_monitor = False + if args.enable_cuda_graph: + logger.warning( + 'Disable host memory monitor when cuda graph is enabled.') + disable_host_mem_monitor = True if not disable_mem_monitor: - memory_monitor = MemoryMonitor() + memory_monitor = MemoryMonitor( + disable_host_mem_monitor=disable_host_mem_monitor) memory_monitor.start() iter_idx = 0 diff --git a/benchmarks/python/build.py b/benchmarks/python/build.py index f775ace33..1aaa319a9 100644 --- a/benchmarks/python/build.py +++ b/benchmarks/python/build.py @@ -648,9 +648,12 @@ def build_gpt(args): 'tp_size': world_size, }, } + config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config) elif family == "internlm": + quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization) + config = { 'architecture': 'LLaMAForCausalLM', @@ -673,8 +676,10 @@ def build_gpt(args): build_config['n_positions'], 'hidden_act': build_config['hidden_act'], - 'quantization': - quant_mode.to_dict(), + 'quantization': { + 'quant_algo': quant_algo, + 'kv_cache_quant_algo': kv_cache_quant_algo + }, 'mapping': { 'world_size': world_size, 'tp_size': world_size @@ -696,6 +701,7 @@ def build_gpt(args): "has_zero_point": True, "pre_quant_scale": False, }) + config = PretrainedConfig.from_dict(config) tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config) elif family == "qwen": @@ -1038,6 +1044,7 @@ def enc_dec_build_helper(component, config, args): or quant_mode.is_int8_weight_only()), quant_mode=quant_mode, n_mels=n_mels, + skip_cross_qkv=config['skip_cross_qkv'], ) # build engine diff --git a/benchmarks/python/mem_monitor.py b/benchmarks/python/mem_monitor.py index 2665260ff..5b654310f 100644 --- a/benchmarks/python/mem_monitor.py +++ b/benchmarks/python/mem_monitor.py @@ -22,7 +22,7 @@ class MemoryMonitor: - def __init__(self, query_interval=0.1): + def __init__(self, query_interval=0.1, disable_host_mem_monitor=False): self.query_interval = query_interval # second(s) self.mem_monitor_process = None # bytes @@ -35,6 +35,8 @@ def __init__(self, query_interval=0.1): self.signal_event = Event() # Sending signal to subprocess self.peak_mem_queue = Queue() # Receiving results from subprocess + self.disable_host_mem_monitor = disable_host_mem_monitor + def start(self): self.mem_monitor_process = Process(target=self._upd_peak_memory_usage, args=(self.signal_event, @@ -70,7 +72,10 @@ def _upd_peak_memory_usage(self, signal_event, peak_mem_queue): peak_mem_queue.put((peak_host_used, peak_device_used)) def get_memory_usage(self): - host_used, _, _ = host_memory_info(self.pid) + if self.disable_host_mem_monitor: + host_used = 0 + else: + host_used, _, _ = host_memory_info(self.pid) device_used, _, _ = device_memory_info() return host_used, device_used diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 54596054e..4cffce5a3 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -36,6 +36,7 @@ option(NVTX_DISABLE "Disable all NVTX features" ON) option(WARNING_IS_ERROR "Treat all warnings as errors" OFF) option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF) option(FAST_MATH "Compiling in fast math mode" OFF) +option(INDEX_RANGE_CHECK "Compiling with index range checks" OFF) if(NVTX_DISABLE) add_compile_definitions("NVTX_DISABLE") @@ -97,6 +98,11 @@ if(FAST_BUILD) message(WARNING "Skip some kernels to accelerate compilation") endif() +if(INDEX_RANGE_CHECK) + add_compile_definitions("INDEX_RANGE_CHECK") + message(WARNING "Check index range to detect OOB 
accesses") +endif() + # Determine CUDA version before enabling the language extension check_language(CUDA) if(CMAKE_CUDA_COMPILER) @@ -162,10 +168,6 @@ message(STATUS " version: ${CUDAToolkit_VERSION}") message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}") message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}") -find_library( - CUDNN_LIB cudnn - HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR} - PATH_SUFFIXES lib64 lib lib/x64) set(CUBLAS_LIB CUDA::cublas) set(CUBLASLT_LIB CUDA::cublasLt) set(CUDA_DRV_LIB CUDA::cuda_driver) diff --git a/cpp/include/tensorrt_llm/batch_manager/callbacks.h b/cpp/include/tensorrt_llm/batch_manager/callbacks.h index 5f8be3150..d681a741b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/callbacks.h +++ b/cpp/include/tensorrt_llm/batch_manager/callbacks.h @@ -29,9 +29,9 @@ class InferenceRequest; class NamedTensor; using GetInferenceRequestsCallback = std::function>(int32_t)>; -using SendResponseCallback = std::function const&, bool, const std::string&)>; +using SendResponseCallback = std::function const&, bool, std::string const&)>; using PollStopSignalCallback = std::function()>; // json of stats as a string -using ReturnBatchManagerStatsCallback = std::function; +using ReturnBatchManagerStatsCallback = std::function; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h index 1edfc1e5c..d94d8c1f7 100644 --- a/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/inferenceRequest.h @@ -312,9 +312,9 @@ class InferenceRequest : public GenericInferenceRequest serialize() const; - static std::shared_ptr deserialize(const std::vector& packed); + static std::shared_ptr deserialize(std::vector const& packed); - static std::shared_ptr deserialize(const int64_t* packed_ptr); + static std::shared_ptr deserialize(int64_t const* packed_ptr); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h index fe2515c70..91be3fed3 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h @@ -50,6 +50,13 @@ class KvCacheConfig { } + bool operator==(KvCacheConfig const& other) const + { + return maxTokens == other.maxTokens && maxAttentionWindow == other.maxAttentionWindow + && sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction + && enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm; + } + std::optional maxTokens; std::optional maxAttentionWindow; std::optional sinkTokenLength; diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index e37a78907..aa4c915be 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -176,6 +176,13 @@ class GenerationRequest mNumTokens += n; } + void removeTokens(SizeType n) + { + TLLM_CHECK(n <= mNumTokens); + TLLM_CHECK(mNumTokens - n >= 0); + mNumTokens -= n; + } + [[nodiscard]] SizeType getSequenceSlotIdx() const { return mSeqSlotIdx; @@ -214,6 +221,14 @@ class GenerationRequest } } + void removeLastBlock() + { + for (auto& beamBlockIds : mCacheBlockIds) + { + beamBlockIds.pop_back(); + } + } + void setNumPrepopulatedTokens(std::vector numPrepopulatedTokens) { 
mNumPrepopulatedTokens = std::move(numPrepopulatedTokens); @@ -280,32 +295,40 @@ class BlockManager //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks void schedulingReleaseBlocks(GenerationRequest& sequence); - [[nodiscard]] SizeType getNumFreeBlocks() const + //! \brief Release last block in the sequence + void releaseLastBlock(GenerationRequest& sequence); + + [[nodiscard]] SizeType getNumFreeBlocks() const noexcept { return mFreeBlocks.size(); } - [[nodiscard]] SizeType getNumAllocatedBlocks() const + [[nodiscard]] SizeType getNumReusedBlocks() const noexcept + { + return mReusedBlocks; + } + + [[nodiscard]] SizeType getNumAllocatedBlocks() const noexcept { return getMaxNumBlocks() - getNumFreeBlocks(); } - [[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const + [[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const noexcept { return getNumFreeBlocks() >= numRequired; } - [[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const + [[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const noexcept { return mSchedulingNumFreeBlocks >= numRequired; } - [[nodiscard]] SizeType getMaxNumBlocks() const + [[nodiscard]] SizeType getMaxNumBlocks() const noexcept { return static_cast(mAllBlocksByIdx.size()); } - [[nodiscard]] SizeType getTokensPerBlock() const + [[nodiscard]] SizeType getTokensPerBlock() const noexcept { return mTokensPerBlock; } @@ -478,11 +501,15 @@ class KVCacheManager return mEnableBlockReuse; } + void removeToken(SizeType seqSlotIdx); + void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths); + private: void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth); void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx); void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx); - void updateNewBlockPointer(const GenerationRequest& seq, SizeType seqSlotIdx, SizeType blockIdx); + void updateNewBlockPointer(GenerationRequest const& seq, SizeType seqSlotIdx, SizeType blockIdx); + void updateToken(SizeType seqSlotIdx, bool addToken); private: // Number of elements per one blocks diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index c9d340bc5..43ef4bd89 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -474,7 +474,7 @@ class GenericLlmRequest return mDraftTokens->size(); } - void setReturnContextLogits(const bool returnContextLogits) + void setReturnContextLogits(bool const returnContextLogits) { mReturnContextLogits = returnContextLogits; } @@ -484,7 +484,7 @@ class GenericLlmRequest return mReturnContextLogits; } - void setReturnGenerationLogits(const bool returnGenerationLogits) + void setReturnGenerationLogits(bool const returnGenerationLogits) { mReturnGenerationLogits = returnGenerationLogits; } @@ -556,6 +556,11 @@ class GenericLlmRequest return mState == REQUEST_STATE_GENERATION_IN_PROGRESS; } + [[nodiscard]] bool isGenerationCompleteState() const noexcept + { + return mState == REQUEST_STATE_GENERATION_COMPLETE; + } + /// To determine whether the context is unchunked. When a context is chunked into only a part, it /// is still different from the unchunked state, which indicates the initial status. 
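
The kvCacheManager.h additions above (removeTokens, removeLastBlock, releaseLastBlock, removeToken, rewindKVCache) support rolling tokens back out of the block-based KV cache, e.g. when draft tokens are rejected. The sketch below only illustrates the block-level bookkeeping idea under assumed semantics; every name is hypothetical and this is not the KVCacheManager implementation:

```cpp
#include <cassert>
#include <vector>

struct SequenceCache
{
    int tokensPerBlock;
    int numTokens = 0;
    std::vector<int> blockIds; // one entry per allocated cache block

    void addToken()
    {
        if (numTokens % tokensPerBlock == 0)
        {
            blockIds.push_back(static_cast<int>(blockIds.size())); // allocate a new block
        }
        ++numTokens;
    }

    void removeTokens(int n)
    {
        assert(n <= numTokens);
        numTokens -= n;
        // Release trailing blocks no longer covered by the remaining tokens.
        int const blocksNeeded = (numTokens + tokensPerBlock - 1) / tokensPerBlock;
        while (static_cast<int>(blockIds.size()) > blocksNeeded)
        {
            blockIds.pop_back();
        }
    }
};

int main()
{
    SequenceCache seq{/*tokensPerBlock=*/4};
    for (int i = 0; i < 10; ++i)
    {
        seq.addToken();          // 10 tokens -> 3 blocks
    }
    seq.removeTokens(3);         // rewind to 7 tokens -> 2 blocks
    assert(seq.blockIds.size() == 2);
    return 0;
}
```
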
[[nodiscard]] bool isFullContextRequest() const noexcept diff --git a/cpp/include/tensorrt_llm/batch_manager/namedTensor.h b/cpp/include/tensorrt_llm/batch_manager/namedTensor.h index 6b9af2e85..714d72526 100644 --- a/cpp/include/tensorrt_llm/batch_manager/namedTensor.h +++ b/cpp/include/tensorrt_llm/batch_manager/namedTensor.h @@ -64,7 +64,7 @@ class NamedTensor : public GenericNamedTensor const& _shape, std::string _name, const void* _data = nullptr); + nvinfer1::DataType _type, std::vector const& _shape, std::string _name, void const* _data = nullptr); NamedTensor(TensorPtr _tensor, std::string _name) : Base(std::move(_tensor), std::move(_name)){}; @@ -74,6 +74,10 @@ class NamedTensor : public GenericNamedTensor serialize() const; - static NamedTensor deserialize(const int64_t* packed); + void serialize(int64_t* out, const size_t totalSize) const; + + [[nodiscard]] size_t serializedSize() const; + + static NamedTensor deserialize(int64_t const* packed); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h index 751f57d81..e42a342f9 100644 --- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h +++ b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h @@ -50,11 +50,19 @@ class TrtGptModelOptionalParams explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig) : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), - executorConfig.getEnableTrtOverlap(), executorConfig.getDeviceIds(), executorConfig.getNormalizeLogProbs(), - executorConfig.getEnableChunkedContext()) + executorConfig.getEnableTrtOverlap(), + executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(), + executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext()) { } + bool operator==(TrtGptModelOptionalParams const& other) const + { + return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap + && deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs + && enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode; + } + KvCacheConfig kvCacheConfig; bool enableTrtOverlap; diff --git a/cpp/include/tensorrt_llm/common/arrayView.h b/cpp/include/tensorrt_llm/common/arrayView.h index cd409e684..31dcd7453 100644 --- a/cpp/include/tensorrt_llm/common/arrayView.h +++ b/cpp/include/tensorrt_llm/common/arrayView.h @@ -16,6 +16,7 @@ #pragma once +#include "tensorrt_llm/common/assert.h" #include namespace tensorrt_llm::common @@ -80,11 +81,17 @@ class ArrayView [[nodiscard]] reference operator[](size_type index) { +#ifdef INDEX_RANGE_CHECK + TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize); +#endif return mData[index]; } [[nodiscard]] const_reference operator[](size_type index) const { +#ifdef INDEX_RANGE_CHECK + TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize); +#endif return mData[index]; } diff --git a/cpp/include/tensorrt_llm/common/mpiUtils.h b/cpp/include/tensorrt_llm/common/mpiUtils.h index 56304b029..51e622e30 100644 --- a/cpp/include/tensorrt_llm/common/mpiUtils.h +++ b/cpp/include/tensorrt_llm/common/mpiUtils.h @@ -56,6 +56,7 @@ enum class MpiType kUINT64, kFP8, kBF16, + kCHAR, }; //! \brief For converting a C++ data type to a TensorRT data type. 
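
The new INDEX_RANGE_CHECK CMake option and the guarded ArrayView::operator[] above follow a compile-time-gated bounds-check pattern. A minimal sketch of the same idea, using a hypothetical CheckedView class and a plain exception in place of TLLM_CHECK_WITH_INFO:

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>

template <typename T>
class CheckedView
{
public:
    CheckedView(T* data, std::size_t size)
        : mData(data)
        , mSize(size)
    {
    }

    T& operator[](std::size_t index)
    {
#ifdef INDEX_RANGE_CHECK
        // Compiled in only when the build defines INDEX_RANGE_CHECK
        // (cmake -DINDEX_RANGE_CHECK=ON per the hunk above).
        if (index >= mSize)
        {
            throw std::out_of_range(
                "Index " + std::to_string(index) + " is out of bounds [0, " + std::to_string(mSize) + ")");
        }
#endif
        return mData[index];
    }

private:
    T* mData;
    std::size_t mSize;
};

int main()
{
    int buffer[4] = {1, 2, 3, 4};
    CheckedView<int> view(buffer, 4);
    return view[3] == 4 ? 0 : 1; // view[4] would throw when INDEX_RANGE_CHECK is defined
}
```
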
@@ -133,6 +134,12 @@ struct MpiTypeConverter static constexpr auto value = MpiType::kUINT64; }; +template <> +struct MpiTypeConverter +{ + static constexpr auto value = MpiType::kCHAR; +}; + #ifdef ENABLE_FP8 template <> struct MpiTypeConverter<__nv_fp8_e4m3> @@ -202,8 +209,8 @@ class MpiComm ~MpiComm() noexcept; // no copy - MpiComm(const MpiComm&) = delete; - MpiComm& operator=(const MpiComm&) = delete; + MpiComm(MpiComm const&) = delete; + MpiComm& operator=(MpiComm const&) = delete; // move MpiComm(MpiComm&&) noexcept; @@ -253,7 +260,24 @@ class MpiComm } } - void bcast(std::vector& packed, int root) const; + template + void bcast(std::vector& vec, int root) const + { + auto const rank = getRank(); + auto vecSize = (rank == root) ? static_cast(vec.size()) : int64_t(0); + bcast(&vecSize, 1, MpiType::kINT64, root); + vec.resize(vecSize); + + if constexpr (std::is_fundamental_v>) + { + auto const mpiType = MpiTypeConverter>::value; + bcast(vec.data(), vec.size(), mpiType, root); + } + else + { + bcast(vec.data(), vec.size() * sizeof(T), MpiType::kBYTE, root); + } + } void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const; @@ -297,8 +321,8 @@ class MpiComm } } - void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const; - void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const; + void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const; + void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const; void barrier() const; void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 07ab22a14..f2a77b6d1 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -34,6 +34,9 @@ namespace tensorrt_llm::executor { +class Model; +class Serialization; + /// @brief Sampling configuration class SamplingConfig { @@ -51,6 +54,8 @@ class SamplingConfig ~SamplingConfig(); + bool operator==(SamplingConfig const& other) const; + [[nodiscard]] SizeType getBeamWidth() const; [[nodiscard]] std::optional getTopK() const; [[nodiscard]] std::optional getTopP() const; @@ -68,6 +73,7 @@ class SamplingConfig [[nodiscard]] std::optional getEarlyStopping() const; private: + friend class Serialization; SizeType mBeamWidth; std::optional mTopK; std::optional mTopP; @@ -86,12 +92,16 @@ class SamplingConfig }; /// @brief Configuration that controls the outputs of a Result -struct OutputConfig +class OutputConfig { - bool returnLogProbs{false}; - bool returnContextLogits{false}; - bool returnGenerationLogits{false}; - bool excludeInputFromOutput{false}; +public: + OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false, + bool excludeInputFromOutput = false); + + bool returnLogProbs; + bool returnContextLogits; + bool returnGenerationLogits; + bool excludeInputFromOutput; }; /// @brief Configuration for speculative decoding. 
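
The new templated MpiComm::bcast above broadcasts an arbitrary std::vector by first broadcasting its size from the root, resizing on every rank, and then shipping the payload either with the matching MPI type (fundamental element types) or as raw bytes. A sketch of the same protocol written against plain MPI; bcastVector and its simplified type mapping are illustrative, not the MpiComm wrapper:

```cpp
#include <mpi.h>

#include <cstddef>
#include <cstdint>
#include <type_traits>
#include <vector>

template <typename T>
void bcastVector(std::vector<T>& vec, int root, MPI_Comm comm)
{
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    // Step 1: agree on the element count.
    std::int64_t size = (rank == root) ? static_cast<std::int64_t>(vec.size()) : 0;
    MPI_Bcast(&size, 1, MPI_INT64_T, root, comm);
    vec.resize(static_cast<std::size_t>(size));

    // Step 2: ship the payload, using the native MPI type where one is known.
    if constexpr (std::is_same_v<T, int>)
    {
        MPI_Bcast(vec.data(), static_cast<int>(size), MPI_INT, root, comm);
    }
    else
    {
        static_assert(std::is_trivially_copyable_v<T>, "byte-wise broadcast needs trivially copyable T");
        MPI_Bcast(vec.data(), static_cast<int>(size * static_cast<std::int64_t>(sizeof(T))), MPI_BYTE, root, comm);
    }
}

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    std::vector<int> tokens;
    if (rank == 0)
    {
        tokens = {11, 22, 33};
    }
    bcastVector(tokens, /*root=*/0, MPI_COMM_WORLD); // every rank now holds {11, 22, 33}

    MPI_Finalize();
    return 0;
}
```
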
Allows to include draft tokens, draft logits and specify acceptance @@ -109,6 +119,7 @@ class SpeculativeDecodingConfig [[nodiscard]] std::optional getAcceptanceThreshold() const; private: + friend class Serialization; VecTokens mTokens; std::optional mLogits; std::optional mAcceptanceThreshold; @@ -128,6 +139,7 @@ class PromptTuningConfig [[nodiscard]] Tensor getEmbeddingTable() const; private: + friend class Serialization; Tensor mEmbeddingTable; }; @@ -142,6 +154,8 @@ class LoraConfig [[nodiscard]] Tensor getConfig() const; private: + friend class Serialization; + Tensor mWeights; Tensor mConfig; }; @@ -207,6 +221,7 @@ class Request void setLoraConfig(LoraConfig loraConfig); private: + friend class Serialization; class Impl; std::unique_ptr mImpl; }; @@ -298,15 +313,49 @@ class KvCacheConfig SizeType const kDefaultIterStatsMaxIterations = 1000; +/// @brief A configuration class for the parallel execution parameters +/// Currently only supports commType = CommunicationType::kMPI +class ParallelConfig +{ +public: + /// @brief Constructor + /// @param commType The communication type. See CommunicationType. + /// @param commMode The communication mode. See CommunicationMode. + /// @param deviceIds The IDs of the GPUs involved in the execution of the model + /// @param participantIds The participant IDs (MPI ranks if commType == kMPI) involved in the execution of the + /// model. The first participant is considered to be the leader. + ParallelConfig(CommunicationType commType = CommunicationType::kMPI, + CommunicationMode commMode = CommunicationMode::kLEADER, + std::optional> deviceIds = std::nullopt, + std::optional> participantIds = std::nullopt); + ~ParallelConfig(); + + [[nodiscard]] CommunicationType getCommunicationType() const; + [[nodiscard]] CommunicationMode getCommunicationMode() const; + [[nodiscard]] std::optional> getDeviceIds() const; + [[nodiscard]] std::optional> getParticipantIds() const; + + void setCommunicationType(CommunicationType type); + void setCommunicationMode(CommunicationMode mode); + void setDeviceIds(std::vector deviceIds); + void setParticipantIds(std::vector participantIds); + +private: + CommunicationType mCommType; + CommunicationMode mCommMode; + std::optional> mDeviceIds; + std::optional> mParticipantIds; +}; + /// @brief Configuration class for the model executor class ExecutorConfig { public: ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(), KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true, - bool enableTrtOverlap = false, std::optional> deviceIds = std::nullopt, - SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations, - BatchingType batchingType = BatchingType::kINFLIGHT); + bool enableTrtOverlap = false, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations, + BatchingType batchingType = BatchingType::kINFLIGHT, + std::optional parallelConfig = std::nullopt); [[nodiscard]] SizeType getMaxBeamWidth() const; [[nodiscard]] SchedulerConfig getSchedulerConfig() const; @@ -314,9 +363,9 @@ class ExecutorConfig [[nodiscard]] bool getEnableChunkedContext() const; [[nodiscard]] bool getNormalizeLogProbs() const; [[nodiscard]] bool getEnableTrtOverlap() const; - [[nodiscard]] std::optional> getDeviceIds() const; [[nodiscard]] SizeType getIterStatsMaxIterations() const; [[nodiscard]] BatchingType getBatchingType() const; + [[nodiscard]] std::optional getParallelConfig() const; void setMaxBeamWidth(SizeType maxBeamWidth); void 
setSchedulerConfig(SchedulerConfig schedulerConfig); @@ -324,9 +373,9 @@ class ExecutorConfig void setEnableChunkedContext(bool enableChunkedContext); void setNormalizeLogProbs(bool normalizeLogProbs); void setEnableTrtOverlap(bool enableTrtOverlap); - void setDeviceIds(std::optional> deviceIds); void setIterStatsMaxIterations(SizeType iterStatsMaxIterations); void setBatchingType(BatchingType batchingType); + void setParallelConfig(ParallelConfig parallelConfig); private: SizeType mMaxBeamWidth; @@ -335,24 +384,11 @@ class ExecutorConfig bool mEnableChunkedContext; bool mNormalizeLogProbs; bool mEnableTrtOverlap; - std::optional> mDeviceIds; SizeType mIterStatsMaxIterations; BatchingType mBatchingType; + std::optional mParallelConfig; }; -/// TODO: -/// @brief A class to identify processes involved in the execution of a model -/// Currently only supports MPI communication -class Communicator -{ -public: - Communicator(CommunicatorType commType, CommMode mode, SizeType currentId, std::vector const& commIds, - std::optional orchestratorId){}; - ~Communicator() = default; -}; - -class Model; - /// @brief The executor is responsible for receiving new requests and sending responses, and running the inference class Executor { @@ -364,14 +400,12 @@ class Executor /// @param modelType The type of model /// @param executorConfig The configuration for the executor /// @param comm An optional inter-process communicator configuration - Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig, - std::optional comm = std::nullopt); + Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig); Executor(std::vector const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType, - ExecutorConfig executorConfig, std::optional comm = std::nullopt); + ExecutorConfig executorConfig); - Executor( - std::shared_ptr model, ExecutorConfig executorConfig, std::optional comm = std::nullopt); + Executor(std::shared_ptr model, ExecutorConfig executorConfig); ~Executor(); diff --git a/cpp/include/tensorrt_llm/executor/tensor.h b/cpp/include/tensorrt_llm/executor/tensor.h index 8bf2851d6..9ce936685 100644 --- a/cpp/include/tensorrt_llm/executor/tensor.h +++ b/cpp/include/tensorrt_llm/executor/tensor.h @@ -180,11 +180,11 @@ class Tensor ~Tensor() = default; - Tensor(const Tensor& other) noexcept = default; + Tensor(Tensor const& other) noexcept = default; Tensor(Tensor&& other) noexcept = default; - Tensor& operator=(const Tensor& other) noexcept = default; + Tensor& operator=(Tensor const& other) noexcept = default; Tensor& operator=(Tensor&& other) noexcept = default; @@ -267,6 +267,7 @@ class Tensor friend std::shared_ptr const& detail::toITensor(Tensor const& tensor); friend Tensor detail::ofITensor(std::shared_ptr tensor); + friend class Serialization; }; } // namespace tensorrt_llm::executor diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 34a7c6dc9..36872629b 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -155,21 +155,16 @@ enum class SchedulerPolicy kGUARANTEED_NO_EVICT = 1, }; -enum class CommunicatorType +enum class CommunicationType { kMPI = 0 }; -enum class CommMode +enum class CommunicationMode { - kLEADER, // With the leader mode, only the leader will be returning from the executor constructor and - // therefore only the leader can enqueue requests and get responses - kORCHESTRATOR, // With the 
orchestrator mode, only the orchestrator will be returning from the executor constructor - // and therefore only the leader can enqueue requests and get responses The orchestrator doesn't - // participate in the computations - kALL, // With the ALL mode, all participants are expected to make the same calls to the executor API - // So they all need to send the same requests - // Responses will be the same for all participants + kLEADER, // With the leader mode, only the leader can enqueue requests. The requests will be + // broadcasted to the workers. All participants can get response via awaitResponses. The leader is the + // first participant in the provided participant IDS, or 0 if participant ID is not provided }; } // namespace tensorrt_llm::executor diff --git a/cpp/include/tensorrt_llm/runtime/decodingMode.h b/cpp/include/tensorrt_llm/runtime/decodingMode.h index 84e79a851..c3d75f542 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingMode.h +++ b/cpp/include/tensorrt_llm/runtime/decodingMode.h @@ -81,6 +81,11 @@ class DecodingMode using UnderlyingType = uint8_t; + bool operator==(DecodingMode const& other) const + { + return mState == other.mState; + } + private: constexpr DecodingMode(UnderlyingType state) : mState(state) diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 0a41bf92b..1e369d13b 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -17,10 +17,13 @@ #pragma once #include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingMode.h" #include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/gptModelConfig.h" #include "tensorrt_llm/runtime/samplingConfig.h" +#include "tensorrt_llm/runtime/worldConfig.h" #include #include @@ -59,7 +62,7 @@ class IGptDecoder DecodingInput const& decodingInput, BufferManager const& manager) = 0; - virtual const SamplingConfig& getSamplingConfig() = 0; + virtual SamplingConfig const& getSamplingConfig() = 0; static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds, ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, @@ -71,6 +74,11 @@ class IGptDecoder SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold, curandState_t* curandState, BufferManager::CudaStreamPtr const& stream); + static void updateKVCacheBasedOnAcceptedTokens(ITensor const& acceptedOffsets, ITensor const& packedAcceptedIds, + ITensor const& pointerArray, ITensor const& pastKeyValueLengths, GptModelConfig const& modelConfig, + WorldConfig const& worldConfig, BufferManager::CudaStreamPtr stream, SizeType rewindDraftTokenCount, + SizeType maxAttentionWindow, SizeType maxBlocksPerSeq, nvinfer1::DataType dtype); + static std::unique_ptr create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength, BufferManager::CudaStreamPtr const& stream); @@ -97,7 +105,7 @@ class GptDecoder : public virtual IGptDecoder void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput, BufferManager const& manager) override; - const SamplingConfig& getSamplingConfig() override + SamplingConfig const& getSamplingConfig() override { return mSamplingConfig; } diff 
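
With the executor.h and types.h changes above, device IDs move out of ExecutorConfig (setDeviceIds is removed) into an optional ParallelConfig carrying kMPI communication in kLEADER mode, where only the leader rank enqueues requests, and the Executor constructors no longer take a Communicator. A hedged construction sketch based on those declarations; the SizeType element type of the device-ID vector is an assumption:

```cpp
#include <optional>
#include <vector>

#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

tle::ExecutorConfig makeExecutorConfig()
{
    // Device IDs (and, optionally, participant ranks) now travel inside ParallelConfig.
    tle::ParallelConfig parallelConfig(tle::CommunicationType::kMPI, tle::CommunicationMode::kLEADER,
        std::vector<tle::SizeType>{0, 1}, // GPUs running the model; element type assumed to be SizeType
        std::nullopt);                    // participant IDs default to the MPI ranks

    tle::ExecutorConfig executorConfig(/*maxBeamWidth=*/1);
    executorConfig.setParallelConfig(parallelConfig); // replaces the removed setDeviceIds()
    return executorConfig;
}
```
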
--git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h index b176bd58a..3ce80e103 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatch.h @@ -153,6 +153,18 @@ class GptDecoderBatch : public IGptDecoderBatch return mFinishedSum; } + //! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu + [[nodiscard]] TensorPtr getNextDraftTokens() const override + { + return mNextDraftTokens; + } + + //! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu + [[nodiscard]] TensorPtr getNextDraftTokenLengths() const override + { + return mNextDraftTokenLengths; + } + private: //! @brief Gather final beam search results for request `batchIdx`. [[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const; @@ -204,6 +216,8 @@ class GptDecoderBatch : public IGptDecoderBatch TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned + TensorPtr mNextDraftTokens; + TensorPtr mNextDraftTokenLengths; SizeType mMaxSequenceLength{}; SizeType mMaxAttentionWindow{}; SizeType mSinkTokenLength{}; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h index 73552dcbc..dd151fea1 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatch.h @@ -46,15 +46,10 @@ class Request , endId{endId} , computeCumLogProbs(false) , computeLogProbs(false) + , generatedTokensPerStep(1) { } - // the number of tokens generated per step - SizeType generatedTokensPerStep() const - { - return draftTokens ? draftTokens->getSize() + 1 : 1; - } - // mandatory parameters ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu SizeType inputLen; // the input length without draft tokens @@ -71,6 +66,7 @@ class Request bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request + SizeType generatedTokensPerStep; }; class Input @@ -184,6 +180,12 @@ class IGptDecoderBatch : public virtual IStatefulGptDecoder std::vector const& samplingConfigs) = 0; + //! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu + virtual TensorPtr getNextDraftTokens() const = 0; + + //! 
@returns [batchSize], lengths of the predicted draft tokens for next step, on gpu + virtual TensorPtr getNextDraftTokenLengths() const = 0; + protected: IGptDecoderBatch() = default; }; diff --git a/cpp/include/tensorrt_llm/runtime/ipcUtils.h b/cpp/include/tensorrt_llm/runtime/ipcUtils.h index 7e3935964..017861dae 100644 --- a/cpp/include/tensorrt_llm/runtime/ipcUtils.h +++ b/cpp/include/tensorrt_llm/runtime/ipcUtils.h @@ -36,7 +36,7 @@ class IpcMemory IpcMemory(WorldConfig const& worldConfig, std::size_t bufferSize); ~IpcMemory(); - [[nodiscard]] const std::vector& getCommPtrsTensor() const + [[nodiscard]] std::vector const& getCommPtrsTensor() const { return mCommPtrs; } diff --git a/cpp/include/tensorrt_llm/runtime/promptTuningParams.h b/cpp/include/tensorrt_llm/runtime/promptTuningParams.h index cdc4c2e90..6a735460f 100644 --- a/cpp/include/tensorrt_llm/runtime/promptTuningParams.h +++ b/cpp/include/tensorrt_llm/runtime/promptTuningParams.h @@ -67,7 +67,7 @@ class PromptTuningParams : public GenericPromptTuningParams // Fill the tasks tensor for the batch using the provided tasksHost // Function assumes that the first numContextRequests requests in the batch are context requests void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests, - const std::vector& reqBeamWidths, const std::vector& reqPromptLengths, + std::vector const& reqBeamWidths, std::vector const& reqPromptLengths, BufferManager const& manager, bool packedInput); }; diff --git a/cpp/include/tensorrt_llm/runtime/samplingConfig.h b/cpp/include/tensorrt_llm/runtime/samplingConfig.h index 2b3e96c25..180e94a19 100644 --- a/cpp/include/tensorrt_llm/runtime/samplingConfig.h +++ b/cpp/include/tensorrt_llm/runtime/samplingConfig.h @@ -43,7 +43,7 @@ class SamplingConfig auto const hasValues = accessor(0).has_value(); for (size_t ci = 0; ci < configs.size(); ++ci) { - const auto& configValue = accessor(ci); + auto const& configValue = accessor(ci); TLLM_CHECK(hasValues == configValue.has_value()); if (hasValues) { diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt index 8bf09659f..a622763da 100644 --- a/cpp/tensorrt_llm/CMakeLists.txt +++ b/cpp/tensorrt_llm/CMakeLists.txt @@ -188,7 +188,6 @@ endif() set(TRTLLM_LINK_LIBS ${CUBLAS_LIB} ${CUBLASLT_LIB} - ${CUDNN_LIB} ${CMAKE_DL_LIBS} ${MPI_C_LIBRARIES} ${NCCL_LIB} diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 535bbe5e6..e9a6edcb6 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0ecc134ad10a54b2953c772e72db2f71e84130d5736087b033e9e5b78594db6d -size 2113376 +oid sha256:c56ee13bb109917ab10df168ca15e6057436df1cd8b64a4268c6e7aae78a5ad8 +size 2126310 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index f87299781..8caf13772 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:9aa3f3d7f8313c099df8e9bd4c9707922a4f1c4025c4c99986acf6df781738c7 -size 2128450 +oid sha256:339532215fa4c16e68ca28ee23d0a0e09c9caefa7bd19b563d2f7b83cad6822e +size 2142070 diff --git a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt index 42bdc3127..35e4cafb8 100644 --- a/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/aarch64-linux-gnu/version.txt @@ -1,3 +1,3 @@ -add62ff328028bbcded1af694fe758c5 libtensorrt_llm_batch_manager_static.a -9e8846e200e2aaaeace862741a90c3ab libtensorrt_llm_batch_manager_static.pre_cxx11.a -230623fa285048a2de5c54c2cc0f364fb9f2c559 commit +c9c505e2cb6e95b7cfc124c04ab1fcb3 libtensorrt_llm_batch_manager_static.a +2f5cec5a5b42e0031bc2edc688c1e74b libtensorrt_llm_batch_manager_static.pre_cxx11.a +741fb083cc42933439ae54557b177b6d7064da4f commit diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a index 0391ed88a..a2cced8e4 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7b25de974b6ca5f0dcb279f16f38199167d1efc35c01770d3234bec2dfb5dc86 -size 2097848 +oid sha256:a4060f2d60472850344e5b5799f9ad88390f4ad9c056e3843f3bdbcc046ca68b +size 2106440 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a index bcf6e8380..0290f5c4e 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/libtensorrt_llm_batch_manager_static.pre_cxx11.a @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5f06cee5ae2bcf393196265cd9a3ef832690cd4c5c53934bbfb169d50ab33c41 -size 2055004 +oid sha256:829f1ed5af0b0d2577e57fd13979706fe0b3636bd6338aac3c34a615f64afedc +size 2064310 diff --git a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt index b66afa5aa..06938d29a 100644 --- a/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt +++ b/cpp/tensorrt_llm/batch_manager/x86_64-linux-gnu/version.txt @@ -1,2 +1,2 @@ -bb62a31b8e17dae284d784ba43d5bc02 libtensorrt_llm_batch_manager_static.a -19327f59c7f5b6235e15b322d5f5a0f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a +2db5c985786dad3dd16c22ec54af0803 libtensorrt_llm_batch_manager_static.a +96940249ff7b3ff09754b89ad25fcf9f libtensorrt_llm_batch_manager_static.pre_cxx11.a diff --git a/cpp/tensorrt_llm/common/allocator.h b/cpp/tensorrt_llm/common/allocator.h index a65b03962..49d4422e9 100644 --- a/cpp/tensorrt_llm/common/allocator.h +++ b/cpp/tensorrt_llm/common/allocator.h @@ -42,11 +42,11 @@ class IAllocator virtual ~IAllocator() = default; // no copying - IAllocator(const IAllocator&) = delete; - IAllocator& operator=(const IAllocator&) = delete; + IAllocator(IAllocator const&) = delete; + IAllocator& operator=(IAllocator const&) = delete; template - [[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, const bool setZero = true) + [[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, bool const setZero = true) { 
TLLM_LOG_TRACE(__PRETTY_FUNCTION__); // TODO martinma: why do we need this size extension? diff --git a/cpp/tensorrt_llm/common/assert.h b/cpp/tensorrt_llm/common/assert.h index 9a6c84e25..7f51dbf1b 100644 --- a/cpp/tensorrt_llm/common/assert.h +++ b/cpp/tensorrt_llm/common/assert.h @@ -23,7 +23,7 @@ namespace tensorrt_llm::common { -[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") +[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") { throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str())); } @@ -38,8 +38,10 @@ class DebugConfig #if defined(_WIN32) #define TLLM_LIKELY(x) (__assume((x) == 1), (x)) +#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x)) #else #define TLLM_LIKELY(x) __builtin_expect((x), 1) +#define TLLM_UNLIKELY(x) __builtin_expect((x), 0) #endif #define TLLM_CHECK(val) \ @@ -61,20 +63,22 @@ class DebugConfig #define TLLM_CHECK_DEBUG(val) \ do \ { \ - if (DebugConfig::isCheckDebugEnabled()) \ + if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ { \ TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ } \ } while (0) -#define TLLM_CHECK_DEBUG_WITH_INFO(val, info) \ +#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \ do \ { \ - if (DebugConfig::isCheckDebugEnabled()) \ + if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ { \ - TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \ + TLLM_LIKELY(static_cast(val)) \ + ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError( \ + __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ } \ } while (0) diff --git a/cpp/tensorrt_llm/common/cublasMMWrapper.cpp b/cpp/tensorrt_llm/common/cublasMMWrapper.cpp index 961bd1375..c674bb6d8 100644 --- a/cpp/tensorrt_llm/common/cublasMMWrapper.cpp +++ b/cpp/tensorrt_llm/common/cublasMMWrapper.cpp @@ -42,7 +42,7 @@ CublasMMWrapper::~CublasMMWrapper() mMutex = nullptr; } -CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper) +CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper) : mCublasHandle(wrapper.mCublasHandle) , mCublasLtHandle(wrapper.mCublasLtHandle) , mStream(wrapper.mStream) @@ -50,8 +50,8 @@ CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper) { } -void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, - const int k, const int lda, const int ldb, const int ldc) +void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, int const lda, int const ldb, int const ldc) { // -------------------------------------- // Create descriptors for the original matrices @@ -79,15 +79,15 @@ void CublasMMWrapper::destroyDescriptors() mCDesc = NULL; } -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc) +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc) { Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); } -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, 
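
The assert.h changes above add TLLM_UNLIKELY and let TLLM_CHECK_DEBUG_WITH_INFO take printf-style format arguments. A small stand-alone sketch of such a variadic check macro; MY_CHECK_WITH_INFO, MY_UNLIKELY, and formatMessage are illustrative stand-ins, not the TensorRT-LLM macros:

```cpp
#include <cstdarg>
#include <cstdio>
#include <exception>
#include <stdexcept>
#include <string>

// printf-style formatting helper; stands in for the library's fmtstr().
inline std::string formatMessage(char const* fmt, ...)
{
    char buf[1024];
    va_list args;
    va_start(args, fmt);
    std::vsnprintf(buf, sizeof(buf), fmt, args);
    va_end(args);
    return std::string(buf);
}

// Hint that the failure branch is cold, in the spirit of TLLM_UNLIKELY.
#define MY_UNLIKELY(x) __builtin_expect((x), 0)

// Variadic check: the message is only formatted when the condition fails.
#define MY_CHECK_WITH_INFO(cond, fmt, ...)                                                                             \
    do                                                                                                                 \
    {                                                                                                                  \
        if (MY_UNLIKELY(!static_cast<bool>(cond)))                                                                     \
        {                                                                                                              \
            throw std::runtime_error(formatMessage("Assertion failed: " fmt, ##__VA_ARGS__));                          \
        }                                                                                                              \
    } while (0)

int main()
{
    int index = 7;
    int size = 4;
    try
    {
        MY_CHECK_WITH_INFO(index < size, "Index %d is out of bounds [0, %d)", index, size);
    }
    catch (std::exception const& e)
    {
        std::puts(e.what()); // prints: Assertion failed: Index 7 is out of bounds [0, 4)
    }
    return 0;
}
```
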
const int n, const int k, - const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, - const std::optional& heuristic) +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, + std::optional const& heuristic) { if (heuristic) { @@ -102,8 +102,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c } } -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta) +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta) { bool usingCublasLt = mAType == CUDA_R_16F; @@ -111,9 +111,9 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c /* usingCublasLt */ usingCublasLt); } -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta, - const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt) +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt) { half h_alpha = (half) (f_alpha); half h_beta = (half) (f_beta); @@ -126,8 +126,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c int batch_count = 1; // fp32 use cublas as default // fp16 use cublasLt as default - const void* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); int workspaceSize = mCublasWorkspace == NULL ? 
0 : CUBLAS_WORKSPACE_SIZE; if (usingCublasLt) @@ -154,10 +154,10 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c } } -void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, - const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, - const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha, - const float f_beta) +void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, + const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha, + float const f_beta) { half h_alpha = (half) f_alpha; half h_beta = (half) f_beta; @@ -165,26 +165,26 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati std::lock_guard lock(*mMutex); int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } -void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, - const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, - const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, - cudaDataType_t CType, const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType) +void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, + void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, + cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType) { half h_alpha = (half) f_alpha; half h_beta = (half) f_beta; std::lock_guard lock(*mMutex); bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; - const void* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - const void* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, @@ -267,8 +267,8 @@ void CublasMMWrapper::setStream(cudaStream_t stream) mStream = stream; } -bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, - const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo) +bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo) { TLLM_CHECK_WITH_INFO( descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); @@ -291,12 +291,12 @@ bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t tr } std::vector CublasMMWrapper::getTactics(cublasOperation_t transa, - cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc) + cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc) { TLLM_CHECK_WITH_INFO( descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); - const auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); + auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); sync_check_cuda_error(); diff --git a/cpp/tensorrt_llm/common/cublasMMWrapper.h b/cpp/tensorrt_llm/common/cublasMMWrapper.h index 91e2b85a9..85178ad1f 100644 --- a/cpp/tensorrt_llm/common/cublasMMWrapper.h +++ b/cpp/tensorrt_llm/common/cublasMMWrapper.h @@ -65,39 +65,39 @@ class CublasMMWrapper ~CublasMMWrapper(); - CublasMMWrapper(const CublasMMWrapper& wrapper); + CublasMMWrapper(CublasMMWrapper const& wrapper); /********************** GEMMs **********************/ - void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, - const int lda, const void* B, const int ldb, void* C, const int ldc); + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc); - void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, - const int lda, const void* B, const int ldb, void* C, const int ldc, - const std::optional& algo); + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, + std::optional const& algo); - void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, - const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta); + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta); - void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, - const int lda, const void* B, const int ldb, void* C, const int ldc, float 
f_alpha, float f_beta, - const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt); + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt); - void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB, - void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f, - const float f_beta = 0.0f); + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB, + void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f, + float const f_beta = 0.0f); - void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B, - cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType, - const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType); + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B, + cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType, + int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType); /********************** Tactic selection helpers **********************/ - bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo); + bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo); std::vector getTactics(cublasOperation_t transa, cublasOperation_t transb, - const int m, const int n, const int k, const int lda, const int ldb, const int ldc); + int const m, int const n, int const k, int const lda, int const ldb, int const ldc); std::vector getTactics(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, @@ -126,8 +126,8 @@ class CublasMMWrapper CublasDataType getCublasDataType(cudaDataType_t data_type); - void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, - const int lda, const int ldb, const int ldc); + void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc); void destroyDescriptors(); cublasHandle_t getCublasHandle() diff --git a/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp b/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp index efe1022b3..4c816d44c 100644 --- a/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp +++ b/cpp/tensorrt_llm/common/cudaDriverWrapper.cpp @@ 
-43,7 +43,7 @@ CUDADriverWrapper::CUDADriverWrapper() handle = dllOpen(CUDA_LIB_NAME); TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); - auto load_sym = [](void* handle, const char* name) + auto load_sym = [](void* handle, char const* name) { void* ret = dllGetSym(handle, name); return ret; @@ -69,7 +69,7 @@ CUDADriverWrapper::~CUDADriverWrapper() dllClose(handle); } -CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, const char** pStr) const +CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const { return (*_cuGetErrorName)(error, pStr); } @@ -94,7 +94,7 @@ CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const return (*_cuLinkDestroy)(state); } -CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, const void* image) const +CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const { return (*_cuModuleLoadData)(module, image); } @@ -105,24 +105,24 @@ CUresult CUDADriverWrapper::cuLinkCreate( return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); } -CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const +CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const { return (*_cuModuleGetFunction)(hfunc, hmod, name); } -CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const +CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const { return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); } -CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, +CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, CUjit_option* options, void** optionValues) const { return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); } CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, - const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const + char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const { return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); } diff --git a/cpp/tensorrt_llm/common/cudaDriverWrapper.h b/cpp/tensorrt_llm/common/cudaDriverWrapper.h index 758bfa1f5..d5eb5f2d7 100644 --- a/cpp/tensorrt_llm/common/cudaDriverWrapper.h +++ b/cpp/tensorrt_llm/common/cudaDriverWrapper.h @@ -37,7 +37,7 @@ class CUDADriverWrapper ~CUDADriverWrapper(); - CUresult cuGetErrorName(CUresult error, const char** pStr) const; + CUresult cuGetErrorName(CUresult error, char const** pStr) const; CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; @@ -47,19 +47,19 @@ class CUDADriverWrapper CUresult cuLinkDestroy(CUlinkState state) const; - CUresult cuModuleLoadData(CUmodule* module, const void* image) const; + CUresult cuModuleLoadData(CUmodule* module, void const* image) const; CUresult cuLinkCreate( unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; - CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const; + CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; - CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* 
bytes, CUmodule hmod, const char* name) const; + CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; - CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, unsigned int numOptions, + CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, CUjit_option* options, void** optionValues) const; - CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, + CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const; CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, @@ -72,18 +72,18 @@ class CUDADriverWrapper private: void* handle; - CUresult (*_cuGetErrorName)(CUresult, const char**); + CUresult (*_cuGetErrorName)(CUresult, char const**); CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); CUresult (*_cuModuleUnload)(CUmodule); CUresult (*_cuLinkDestroy)(CUlinkState); CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); - CUresult (*_cuModuleLoadData)(CUmodule*, const void*); - CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*); - CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, const char*); - CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**); + CUresult (*_cuModuleLoadData)(CUmodule*, void const*); + CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); + CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); + CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); CUresult (*_cuLinkAddData)( - CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); + CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void**); CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, @@ -91,11 +91,11 @@ class CUDADriverWrapper CUstream hStream, void** kernelParams, void** extra); }; -inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line) +inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line) { if (stat != CUDA_SUCCESS) { - const char* msg = nullptr; + char const* msg = nullptr; wrap.cuGetErrorName(stat, &msg); fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line); } diff --git a/cpp/tensorrt_llm/common/cudaFp8Utils.cu b/cpp/tensorrt_llm/common/cudaFp8Utils.cu index 2d8329386..b404742db 100644 --- a/cpp/tensorrt_llm/common/cudaFp8Utils.cu +++ b/cpp/tensorrt_llm/common/cudaFp8Utils.cu @@ -121,16 +121,16 @@ void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaSt } template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( - float* dst, const float* src, const int64_t numel, cudaStream_t stream); + float* dst, float const* src, const int64_t numel, cudaStream_t stream); template void invokeFakeQuantize( - float* dst, const __nv_fp8_e4m3* src, const int64_t numel, cudaStream_t stream); + float* 
dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream); template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( - half* dst, const half* src, const int64_t numel, cudaStream_t stream); + half* dst, half const* src, const int64_t numel, cudaStream_t stream); template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* dst, const __nv_bfloat16* src, const int64_t numel, cudaStream_t stream); + __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream); template void invokeFakeQuantize( - half* dst, const float* src, const int64_t numel, cudaStream_t stream); + half* dst, float const* src, const int64_t numel, cudaStream_t stream); __device__ float atomicMaxExtd(float* address, float val) { @@ -146,7 +146,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or bfloat16"); // The address in 64 bits. - uint64_t address_u64 = reinterpret_cast(address); + uint64_t address_u64 = reinterpret_cast(address); // Pack the input value into 32 bits. union @@ -155,7 +155,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val) uint16_t u[2]; } old, tmp = {}; - const int loc = (address_u64 & 0x2) >> 1; + int const loc = (address_u64 & 0x2) >> 1; tmp.v[loc] = val; // 4B aligned pointer. @@ -223,7 +223,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons auto val = fabs(static_cast(weights[i])); max = max > val ? max : val; } - const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) if constexpr (std::is_same_v) { @@ -231,7 +231,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons } else { - const auto address_u64 = reinterpret_cast(quant_ptr + col); + auto const address_u64 = reinterpret_cast(quant_ptr + col); if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) atomicMaxExtd(quant_ptr + col, scale); else @@ -244,7 +244,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons } else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) { - const auto nrows = size / n; + auto const nrows = size / n; for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) { float max = 0.f; @@ -256,7 +256,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons max = blockReduceMax(max); if (threadIdx.x == 0) { - const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); quant_ptr[row] = scale; } } @@ -272,7 +272,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons max = blockReduceMax(max); if (threadIdx.x == 0) { - const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); atomicMaxExtd(quant_ptr, scale); } } @@ -326,19 +326,19 @@ __global__ void dynamicQuantizeMatrixPerToken( extern __shared__ __align__(sizeof(float)) char _shmem[]; T_IN* shmem = reinterpret_cast(_shmem); constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); - const auto nrows = numel / lda; + auto const nrows = numel / lda; for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) { float max = 0.f; for 
(int64_t i = threadIdx.x; i < lda; i += blockDim.x) { - const auto in = input[row * lda + i]; + auto const in = input[row * lda + i]; shmem[i] = in; auto val = fabs(static_cast(in)); max = max > val ? max : val; } max = blockAllReduceMax(max); // __syncthreads() called so we can read shmem - const auto s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) { // true means we are quantizing @@ -359,7 +359,7 @@ void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T { dim3 grid(numel / lda); bool use_shmem = true; - const auto shmem_size = lda * sizeof(T_IN); + auto const shmem_size = lda * sizeof(T_IN); if (shmem_size >= (48 << 10)) { cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken, diff --git a/cpp/tensorrt_llm/common/cudaFp8Utils.h b/cpp/tensorrt_llm/common/cudaFp8Utils.h index ad9aa9457..aa93b55a5 100644 --- a/cpp/tensorrt_llm/common/cudaFp8Utils.h +++ b/cpp/tensorrt_llm/common/cudaFp8Utils.h @@ -181,37 +181,37 @@ struct PackType<__nv_fp8_e4m3, 8> }; #endif -__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, const __nv_fp8x4_e4m3* in) +__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in) { - const char4 tmp_val = reinterpret_cast(in)[0]; - *out1 = __nv_bfloat162((float) reinterpret_cast(&tmp_val.x)[0], - (float) reinterpret_cast(&tmp_val.y)[0]); - *out2 = __nv_bfloat162((float) reinterpret_cast(&tmp_val.z)[0], - (float) reinterpret_cast(&tmp_val.w)[0]); + const char4 tmp_val = reinterpret_cast(in)[0]; + *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); } -__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(const __nv_fp8x2_e4m3* in) +__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in) { - const char2 tmp_val = reinterpret_cast(in)[0]; - __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast(&tmp_val.x)[0], - (float) reinterpret_cast(&tmp_val.y)[0]); + const char2 tmp_val = reinterpret_cast(in)[0]; + __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); return out; } -__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, const __nv_fp8x4_e4m3* in) +__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in) { - const char4 tmp_val = reinterpret_cast(in)[0]; - *out1 = half2((float) reinterpret_cast(&tmp_val.x)[0], - (float) reinterpret_cast(&tmp_val.y)[0]); - *out2 = half2((float) reinterpret_cast(&tmp_val.z)[0], - (float) reinterpret_cast(&tmp_val.w)[0]); + const char4 tmp_val = reinterpret_cast(in)[0]; + *out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); } -__inline__ __device__ half2 fp8x2_e4m3_to_half2(const __nv_fp8x2_e4m3* in) +__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in) { 
- const char2 tmp_val = reinterpret_cast(in)[0]; - half2 out = half2((float) reinterpret_cast(&tmp_val.x)[0], - (float) reinterpret_cast(&tmp_val.y)[0]); + const char2 tmp_val = reinterpret_cast(in)[0]; + half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); return out; } diff --git a/cpp/tensorrt_llm/common/cudaTypeUtils.cuh b/cpp/tensorrt_llm/common/cudaTypeUtils.cuh index 8a8a5bcea..2db8937af 100644 --- a/cpp/tensorrt_llm/common/cudaTypeUtils.cuh +++ b/cpp/tensorrt_llm/common/cudaTypeUtils.cuh @@ -32,14 +32,14 @@ namespace common { template -inline __device__ T ldg(const T* val) +inline __device__ T ldg(T const* val) { return __ldg(val); } #if ENABLE_BF16 template <> -inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) +inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 return val[0]; @@ -49,7 +49,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) } template <> -inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) +inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 return val[0]; diff --git a/cpp/tensorrt_llm/common/cudaUtils.h b/cpp/tensorrt_llm/common/cudaUtils.h index 90bca9637..c068c50eb 100644 --- a/cpp/tensorrt_llm/common/cudaUtils.h +++ b/cpp/tensorrt_llm/common/cudaUtils.h @@ -81,12 +81,12 @@ enum class OperationType }; /* **************************** debug tools ********************************* */ -static const char* _cudaGetErrorEnum(cudaError_t error) +static char const* _cudaGetErrorEnum(cudaError_t error) { return cudaGetErrorString(error); } -static const char* _cudaGetErrorEnum(cublasStatus_t error) +static char const* _cudaGetErrorEnum(cublasStatus_t error) { switch (error) { @@ -114,7 +114,7 @@ static const char* _cudaGetErrorEnum(cublasStatus_t error) } template -void check(T result, char const* const func, const char* const file, int const line) +void check(T result, char const* const func, char const* const file, int const line) { if (result) { @@ -133,7 +133,7 @@ inline bool isCudaLaunchBlocking() if (firstCall) { - const char* env = std::getenv("CUDA_LAUNCH_BLOCKING"); + char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); result = env != nullptr && std::string(env) == "1"; firstCall = false; } @@ -141,12 +141,12 @@ inline bool isCudaLaunchBlocking() return result; } -inline void syncAndCheck(const char* const file, int const line) +inline void syncAndCheck(char const* const file, int const line) { #ifndef NDEBUG - const bool checkError = true; + bool const checkError = true; #else - const bool checkError = isCudaLaunchBlocking(); + bool const checkError = isCudaLaunchBlocking(); #endif if (checkError) @@ -279,7 +279,7 @@ inline int getDeviceCount() /// Get the memory info /// \return The free and total amount of memory in bytes -inline std::tuple getDeviceMemoryInfo(const bool useUvm) +inline std::tuple getDeviceMemoryInfo(bool const useUvm) { if (useUvm) { @@ -351,7 +351,7 @@ auto constexpr ceilDiv(T numerator, U denominator) } template -void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string name = "") +void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "") { if (buf == nullptr) { @@ -390,9 +390,9 @@ void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string } template -void printToStream(const T* result, const int size, 
FILE* strm) +void printToStream(T const* result, int const size, FILE* strm) { - const bool split_rows = (strm == stdout); + bool const split_rows = (strm == stdout); if (result == nullptr) { TLLM_LOG_WARNING("It is an nullptr, skip! \n"); @@ -414,13 +414,13 @@ void printToStream(const T* result, const int size, FILE* strm) } template -void printToScreen(const T* result, const int size) +void printToScreen(T const* result, int const size) { printToStream(result, size, stdout); } template -void print2dToStream(const T* result, const int r, const int c, const int stride, FILE* strm) +void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm) { if (result == nullptr) { @@ -429,20 +429,20 @@ void print2dToStream(const T* result, const int r, const int c, const int stride } for (int ri = 0; ri < r; ++ri) { - const T* ptr = result + ri * stride; + T const* ptr = result + ri * stride; printToStream(ptr, c, strm); } fprintf(strm, "\n"); } template -void print2dToScreen(const T* result, const int r, const int c, const int stride) +void print2dToScreen(T const* result, int const r, int const c, int const stride) { print2dToStream(result, r, c, stride, stdout); } template -void print2dToFile(std::string fname, const T* result, const int r, const int c, const int stride) +void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride) { FILE* fp = fopen(fname.c_str(), "wt"); if (fp != nullptr) @@ -493,7 +493,7 @@ inline void print_element_(int64_t ill) } template -inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_ptr) +inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr) { T* tmp; if (is_device_ptr) @@ -538,14 +538,14 @@ inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_p } } -template void printMatrix(const float* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(const half* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr); #ifdef ENABLE_BF16 -template void printMatrix(const __nv_bfloat16* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr); #endif -template void printMatrix(const uint32_t* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(const uint64_t* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(const int* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr); } // namespace tensorrt_llm::common diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index 6aed80086..26af333c1 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -25,7 +25,7 @@ namespace tensorrt_llm::common // XQA kernels (optimized kernels for generation phase). 
bool forceXQAKernels() { - const char* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA"); + char const* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA"); static bool forceXQA = false; if (force_xqa_env_var != nullptr) { @@ -45,7 +45,7 @@ bool getEnvMmhaMultiblockDebug() if (!init) { init = true; - const char* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); + char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); if (enable_mmha_debug_var) { if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') @@ -64,7 +64,7 @@ int getEnvMmhaBlocksPerSequence() if (!init) { init = true; - const char* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); + char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); if (mmhaBlocksPerSequenceEnv) { mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); diff --git a/cpp/tensorrt_llm/common/logger.cpp b/cpp/tensorrt_llm/common/logger.cpp index c0fb2a3f8..dcbeef9d3 100644 --- a/cpp/tensorrt_llm/common/logger.cpp +++ b/cpp/tensorrt_llm/common/logger.cpp @@ -65,5 +65,4 @@ Logger* Logger::getLogger() thread_local Logger instance; return &instance; } - } // namespace tensorrt_llm::common diff --git a/cpp/tensorrt_llm/common/logger.h b/cpp/tensorrt_llm/common/logger.h index 25f67ac2f..a5eda656e 100644 --- a/cpp/tensorrt_llm/common/logger.h +++ b/cpp/tensorrt_llm/common/logger.h @@ -54,26 +54,26 @@ class Logger #if defined(_MSC_VER) template - void log(Level level, char const* format, const Args&... args); + void log(Level level, char const* format, Args const&... args); template - void log(Level level, int rank, char const* format, const Args&... args); + void log(Level level, int rank, char const* format, Args const&... args); #else template - void log(Level level, char const* format, const Args&... args) __attribute__((format(printf, 3, 0))); + void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); template - void log(Level level, int rank, char const* format, const Args&... args) __attribute__((format(printf, 4, 0))); + void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0))); #endif template - void log(Level level, std::string const& format, const Args&... args) + void log(Level level, std::string const& format, Args const&... args) { return log(level, format.c_str(), args...); } template - void log(const Level level, const int rank, const std::string& format, const Args&... args) + void log(const Level level, int const rank, std::string const& format, Args const&... args) { return log(level, rank, format.c_str(), args...); } @@ -122,7 +122,7 @@ class Logger return fmtstr("%s[%s] ", kPREFIX, getLevelName(level)); } - static inline std::string getPrefix(const Level level, const int rank) + static inline std::string getPrefix(const Level level, int const rank) { return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank); } @@ -148,7 +148,7 @@ void Logger::log(Logger::Level level, char const* format, Args const&... args) } template -void Logger::log(const Logger::Level level, const int rank, char const* format, const Args&... args) +void Logger::log(const Logger::Level level, int const rank, char const* format, Args const&... 
args) { if (level_ <= level) { diff --git a/cpp/tensorrt_llm/common/memoryUtils.cu b/cpp/tensorrt_llm/common/memoryUtils.cu index ee39dfb7d..d13217b20 100644 --- a/cpp/tensorrt_llm/common/memoryUtils.cu +++ b/cpp/tensorrt_llm/common/memoryUtils.cu @@ -112,63 +112,63 @@ template void deviceFill(int* devptr, size_t size, int value, cudaStream_t strea template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); template -void cudaD2Hcpy(T* tgt, const T* src, const size_t size) +void cudaD2Hcpy(T* tgt, T const* src, const size_t size) { check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); } -template void cudaD2Hcpy(float* tgt, const float* src, size_t size); -template void cudaD2Hcpy(half* tgt, const half* src, size_t size); +template void cudaD2Hcpy(float* tgt, float const* src, size_t size); +template void cudaD2Hcpy(half* tgt, half const* src, size_t size); #ifdef ENABLE_BF16 -template void cudaD2Hcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size); +template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); #endif -template void cudaD2Hcpy(int* tgt, const int* src, size_t size); -template void cudaD2Hcpy(bool* tgt, const bool* src, size_t size); +template void cudaD2Hcpy(int* tgt, int const* src, size_t size); +template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size); #ifdef ENABLE_FP8 -template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size); +template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); #endif -template void cudaD2Hcpy(unsigned long long* tgt, const unsigned long long* src, size_t size); -template void cudaD2Hcpy(unsigned int* tgt, const unsigned int* src, size_t size); -template void cudaD2Hcpy(int8_t* tgt, const int8_t* src, size_t size); +template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size); template -void cudaH2Dcpy(T* tgt, const T* src, const size_t size) +void cudaH2Dcpy(T* tgt, T const* src, const size_t size) { check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); } -template void cudaH2Dcpy(float* tgt, const float* src, size_t size); -template void cudaH2Dcpy(half* tgt, const half* src, size_t size); +template void cudaH2Dcpy(float* tgt, float const* src, size_t size); +template void cudaH2Dcpy(half* tgt, half const* src, size_t size); #ifdef ENABLE_BF16 -template void cudaH2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size); +template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); #endif -template void cudaH2Dcpy(int* tgt, const int* src, size_t size); -template void cudaH2Dcpy(bool* tgt, const bool* src, size_t size); +template void cudaH2Dcpy(int* tgt, int const* src, size_t size); +template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size); #ifdef ENABLE_FP8 -template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size); +template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); #endif -template void cudaH2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size); -template void cudaH2Dcpy(unsigned int* tgt, const unsigned int* src, size_t size); -template void cudaH2Dcpy(int8_t* tgt, const int8_t* src, size_t size); +template void cudaH2Dcpy(unsigned long long* tgt, unsigned 
long long const* src, size_t size); +template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size); template -void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream) +void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) { check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); } -template void cudaD2Dcpy(float* tgt, const float* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(half* tgt, const half* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream); #ifdef ENABLE_BF16 -template void cudaD2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); #endif -template void cudaD2Dcpy(int* tgt, const int* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); #ifdef ENABLE_FP8 -template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream); #endif -template void cudaD2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); template __global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) @@ -204,7 +204,7 @@ template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const si #endif template -void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream) +void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) { if (stream != NULL) { @@ -216,19 +216,19 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream) } } -template void cudaAutoCpy(float* tgt, const float* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(half* tgt, const half* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream); #ifdef ENABLE_BF16 -template void cudaAutoCpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); #endif -template void cudaAutoCpy(int* tgt, const int* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(uint8_t* tgt, const uint8_t* src, size_t size, cudaStream_t stream); -template void 
cudaAutoCpy(uint32_t* tgt, const uint32_t* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(unsigned long* tgt, const unsigned long* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(char* tgt, const char* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); @@ -242,7 +242,7 @@ template void cudaAutoCpy( unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); template -__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const int seq_offset) +__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset) { const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; curandState_t local_state; @@ -254,7 +254,7 @@ __global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const i } template <> -__global__ void cuda_random_uniform_kernel(int* buffer, const size_t size, const int seq_offset) +__global__ void cuda_random_uniform_kernel(int* buffer, const size_t size, int const seq_offset) { const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; curandState_t local_state; @@ -266,7 +266,7 @@ __global__ void cuda_random_uniform_kernel(int* buffer, const size_t size, } template <> -__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, const int seq_offset) +__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, int const seq_offset) { const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; curandState_t local_state; @@ -278,7 +278,7 @@ __global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size } template <> -__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, const int seq_offset) +__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, int const seq_offset) { const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; curandState_t local_state; @@ -462,30 +462,30 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); } -template void invokeCudaD2DcpyConvert(int8_t* tgt, const float* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, const int8_t* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, const int* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(half* 
tgt, const int* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, const float* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(half* tgt, const float* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, const half* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(uint32_t* tgt, const int* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, const uint32_t* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, const float* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, const half* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const float* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const int* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); #endif // ENABLE_BF16 template __global__ void cudaD2DScaleCpyConvert( - T_OUT* dst, const T_IN* src, const float* scale, bool invert_scale, const size_t size) + T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size) { - const float scale_value = invert_scale ? 1.0f / scale[0] : scale[0]; + float const scale_value = invert_scale ? 
1.0f / scale[0] : scale[0]; for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) { dst[tid] = cuda_cast(cuda_cast(src[tid]) * scale_value); @@ -494,7 +494,7 @@ __global__ void cudaD2DScaleCpyConvert( template void invokeCudaD2DScaleCpyConvert( - T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream) + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream) { cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); } @@ -524,7 +524,7 @@ void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaSt } template -void saveToBinary(const T* ptr, const size_t size, std::string filename) +void saveToBinary(T const* ptr, const size_t size, std::string filename) { std::vector h_ptr(size); @@ -541,14 +541,14 @@ void saveToBinary(const T* ptr, const size_t size, std::string filename) out.write((char*) float_ptr.data(), size * sizeof(float)); } -template void saveToBinary(const float* ptr, const size_t size, std::string filename); -template void saveToBinary(const half* ptr, const size_t size, std::string filename); +template void saveToBinary(float const* ptr, const size_t size, std::string filename); +template void saveToBinary(half const* ptr, const size_t size, std::string filename); #ifdef ENABLE_BF16 -template void saveToBinary(const __nv_bfloat16* ptr, const size_t size, std::string filename); +template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename); #endif // ENABLE_BF16 template <> -void saveToBinary(const int* ptr, const size_t size, std::string filename) +void saveToBinary(int const* ptr, const size_t size, std::string filename) { std::vector h_ptr(size); cudaD2Hcpy(h_ptr.data(), ptr, size); @@ -831,7 +831,7 @@ size_t cuda_datatype_size(TRTLLMCudaDataType dt) } template -__global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_within_range) +__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { @@ -844,7 +844,7 @@ __global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_ } template -bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) +bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) { cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); @@ -858,12 +858,12 @@ bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_ } template bool invokeCheckRange( - const int* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); + int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); /* * Determine the total workspace size based on a vector containing multiple variable sizes. */ -size_t calcAlignedSize(const std::vector& sizes, const size_t ALIGN_BYTES) +size_t calcAlignedSize(std::vector const& sizes, const size_t ALIGN_BYTES) { const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); // Check ALIGN_BYTES is a power of 2 @@ -885,7 +885,7 @@ size_t calcAlignedSize(const std::vector& sizes, const size_t ALIGN_BYTE * of each variable. 
*/ void calcAlignedPointers( - std::vector& outPtrs, const void* p, const std::vector& sizes, size_t ALIGN_BYTES) + std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES) { const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); // Check ALIGN_BYTES is a power of 2 diff --git a/cpp/tensorrt_llm/common/memoryUtils.h b/cpp/tensorrt_llm/common/memoryUtils.h index c70bf0c06..6d2c18c90 100644 --- a/cpp/tensorrt_llm/common/memoryUtils.h +++ b/cpp/tensorrt_llm/common/memoryUtils.h @@ -40,16 +40,16 @@ template void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); template -void cudaD2Hcpy(T* tgt, const T* src, const size_t size); +void cudaD2Hcpy(T* tgt, T const* src, const size_t size); template -void cudaH2Dcpy(T* tgt, const T* src, const size_t size); +void cudaH2Dcpy(T* tgt, T const* src, const size_t size); template -void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL); +void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL); template -void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL); +void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL); template void cudaRandomUniform(T* buffer, const size_t size); @@ -234,9 +234,9 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud template void invokeCudaD2DScaleCpyConvert( - T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0); + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0); -inline bool checkIfFileExist(const std::string& file_path) +inline bool checkIfFileExist(std::string const& file_path) { std::ifstream in(file_path, std::ios::in | std::ios::binary); if (in.is_open()) @@ -248,7 +248,7 @@ inline bool checkIfFileExist(const std::string& file_path) } template -void saveToBinary(const T* ptr, const size_t size, std::string filename); +void saveToBinary(T const* ptr, const size_t size, std::string filename); template void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream); @@ -256,10 +256,10 @@ void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream); size_t cuda_datatype_size(TRTLLMCudaDataType dt); template -bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream); +bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream); -size_t calcAlignedSize(const std::vector& sizes, size_t ALIGN_BYTES = 256); +size_t calcAlignedSize(std::vector const& sizes, size_t ALIGN_BYTES = 256); void calcAlignedPointers( - std::vector& outPtrs, const void* p, const std::vector& sizes, size_t ALIGN_BYTES = 256); + std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES = 256); } // namespace common } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/common/mpiUtils.cpp b/cpp/tensorrt_llm/common/mpiUtils.cpp index aeaf5e671..9c8cb2856 100644 --- a/cpp/tensorrt_llm/common/mpiUtils.cpp +++ b/cpp/tensorrt_llm/common/mpiUtils.cpp @@ -50,6 +50,7 @@ MPI_Datatype getMpiDtype(MpiType dtype) {MpiType::kUINT64, MPI_UINT64_T}, {MpiType::kFP8, MPI_UINT8_T}, {MpiType::kBF16, MPI_UINT16_T}, + {MpiType::kCHAR, MPI_CHAR}, }; return dtype_map.at(dtype); } @@ -126,23 +127,6 @@ void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const MPICHECK(MPI_Bcast(buffer, size, 
getMpiDtype(dtype), root, mComm)); } -void MpiComm::bcast(std::vector& packed, int root) const -{ - int64_t nWords1; - auto const rank = getRank(); - if (rank == root) - { - nWords1 = static_cast(packed.size()); - } - auto const mpiInt64 = MpiTypeConverter::value; - bcast(&nWords1, 1, mpiInt64, root); - if (rank != root) - { - packed.resize(nWords1); - } - bcast(packed.data(), packed.size(), mpiInt64, root); -} - void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const { MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm)); @@ -162,12 +146,12 @@ MpiComm MpiComm::split(int color, int key) const return MpiComm{splitComm, true}; } -void MpiComm::allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const +void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const { MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); } -void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const +void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const { MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); } diff --git a/cpp/tensorrt_llm/common/quantization.h b/cpp/tensorrt_llm/common/quantization.h index 61a7cea41..93fd29600 100644 --- a/cpp/tensorrt_llm/common/quantization.h +++ b/cpp/tensorrt_llm/common/quantization.h @@ -39,7 +39,7 @@ class QuantMode constexpr QuantMode(QuantMode const&) noexcept = default; - constexpr QuantMode& operator=(const QuantMode& other) noexcept = default; + constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; static constexpr QuantMode none() noexcept { @@ -276,32 +276,32 @@ class QuantMode return quantMode; } - constexpr QuantMode operator+(const QuantMode& other) const noexcept + constexpr QuantMode operator+(QuantMode const& other) const noexcept { return QuantMode(mValue | other.mValue); } - constexpr QuantMode& operator+=(const QuantMode& other) noexcept + constexpr QuantMode& operator+=(QuantMode const& other) noexcept { return *this = *this + other; } - constexpr QuantMode operator-(const QuantMode& other) const noexcept + constexpr QuantMode operator-(QuantMode const& other) const noexcept { return QuantMode(mValue & ~other.mValue); } - constexpr QuantMode& operator-=(const QuantMode& other) noexcept + constexpr QuantMode& operator-=(QuantMode const& other) noexcept { return *this = *this - other; } - constexpr bool operator==(const QuantMode& other) const noexcept + constexpr bool operator==(QuantMode const& other) const noexcept { return mValue == other.mValue; } - constexpr bool operator!=(const QuantMode& other) const noexcept + constexpr bool operator!=(QuantMode const& other) const noexcept { return !(*this == other); } diff --git a/cpp/tensorrt_llm/common/reduceKernelUtils.cuh b/cpp/tensorrt_llm/common/reduceKernelUtils.cuh index 1e924fb1c..16f048a7e 100644 --- a/cpp/tensorrt_llm/common/reduceKernelUtils.cuh +++ b/cpp/tensorrt_llm/common/reduceKernelUtils.cuh @@ -63,11 +63,11 @@ struct BytesToType<16> }; template -__device__ inline void copy(const void* local, void* data) +__device__ inline void copy(void const* local, void* data) { using T = typename BytesToType::type; - const T* in = static_cast(local); + T const* in = static_cast(local); T* out = static_cast(data); *out = *in; } @@ -257,8 +257,8 @@ __inline__ __device__ void cgBlockReduceSumElements(float* 
element_list, float* cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); - const int tid = cta.thread_rank(); - const int blockz = blockDim.x; + int const tid = cta.thread_rank(); + int const blockz = blockDim.x; for (int i = 0; i < NUM; i++) { #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) @@ -325,7 +325,7 @@ struct TopK __device__ __forceinline__ void init() { - const bool IS_FP16 = std::is_same::value; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; for (int i = 0; i < MAX_K; i++) @@ -337,7 +337,7 @@ struct TopK }; template -__device__ __forceinline__ TopK reduce_topk_op(const TopK& a, const TopK& b) +__device__ __forceinline__ TopK reduce_topk_op(TopK const& a, TopK const& b) { TopK res = a; for (int i = 0; i < MAX_K; ++i) @@ -368,19 +368,19 @@ struct TopK_2 }; template -__device__ __forceinline__ TopK_2 reduce_topk_op_2(const TopK_2& a, const TopK_2& b) +__device__ __forceinline__ TopK_2 reduce_topk_op_2(TopK_2 const& a, TopK_2 const& b) { return a.u > b.u ? a : b; } template -__device__ __forceinline__ T clamp_inf_for_half(const float input) +__device__ __forceinline__ T clamp_inf_for_half(float const input) { return input; } template <> -__device__ __forceinline__ half clamp_inf_for_half(const float input) +__device__ __forceinline__ half clamp_inf_for_half(float const input) { // clamp inf values to enable fp16 training return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); diff --git a/cpp/tensorrt_llm/common/tensor.cpp b/cpp/tensorrt_llm/common/tensor.cpp index 1cfce46cc..f6be58826 100644 --- a/cpp/tensorrt_llm/common/tensor.cpp +++ b/cpp/tensorrt_llm/common/tensor.cpp @@ -152,7 +152,7 @@ Tensor Tensor::slice(std::vector shape, size_t offset) const return Tensor(this->where, this->type, shape, this->getPtrWithOffset(offset)); } -TensorMap::TensorMap(const std::unordered_map& tensor_map) +TensorMap::TensorMap(std::unordered_map const& tensor_map) { for (auto& kv : tensor_map) { @@ -167,7 +167,7 @@ TensorMap::TensorMap(const std::unordered_map& tensor_map) } } -TensorMap::TensorMap(const std::vector& tensor_map) +TensorMap::TensorMap(std::vector const& tensor_map) { for (size_t i = 0; i < tensor_map.size(); i++) { diff --git a/cpp/tensorrt_llm/common/tensor.h b/cpp/tensorrt_llm/common/tensor.h index 748269c8c..f303e2cef 100644 --- a/cpp/tensorrt_llm/common/tensor.h +++ b/cpp/tensorrt_llm/common/tensor.h @@ -191,7 +191,7 @@ struct TensorDataType }; template <> -struct TensorDataType +struct TensorDataType { static constexpr DataType value = TYPE_INT32_PTR; }; @@ -419,8 +419,8 @@ class TensorMap public: TensorMap() = default; - TensorMap(const std::unordered_map& tensor_map); - TensorMap(const std::vector& tensor_map); + TensorMap(std::unordered_map const& tensor_map); + TensorMap(std::vector const& tensor_map); TensorMap(std::initializer_list> tensor_map); ~TensorMap(); @@ -429,7 +429,7 @@ class TensorMap return tensor_map_.size(); } - inline bool contains(const std::string& key) const + inline bool contains(std::string const& key) const { TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str()); return tensor_map_.find(key) != tensor_map_.end(); @@ -437,7 +437,7 @@ class TensorMap std::vector keys() const; - inline void insert(const std::string& key, const Tensor& value) + inline void insert(std::string const& key, Tensor const& value) { 
TLLM_CHECK_WITH_INFO(!contains(key), fmtstr("Duplicated key %s", key.c_str())); TLLM_CHECK_WITH_INFO( @@ -445,7 +445,7 @@ class TensorMap tensor_map_.insert({key, value}); } - inline void insertIfValid(const std::string& key, const Tensor& value) + inline void insertIfValid(std::string const& key, Tensor const& value) { if (value.isValid()) { @@ -462,7 +462,7 @@ class TensorMap Tensor at(int tmp) = delete; Tensor at(size_t tmp) = delete; - inline Tensor& at(const std::string& key) + inline Tensor& at(std::string const& key) { TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); TLLM_CHECK_WITH_INFO(contains(key), @@ -471,7 +471,7 @@ class TensorMap return tensor_map_.at(key); } - inline Tensor at(const std::string& key) const + inline Tensor at(std::string const& key) const { TLLM_CHECK_WITH_INFO(contains(key), fmtstr( @@ -479,7 +479,7 @@ class TensorMap return tensor_map_.at(key); } - inline std::optional atOpt(const std::string& key) const + inline std::optional atOpt(std::string const& key) const { if (contains(key)) return tensor_map_.at(key); @@ -487,7 +487,7 @@ class TensorMap return std::nullopt; } - inline Tensor& at(const std::string& key, Tensor& default_tensor) + inline Tensor& at(std::string const& key, Tensor& default_tensor) { TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); if (contains(key)) @@ -497,7 +497,7 @@ class TensorMap return default_tensor; } - inline Tensor at(const std::string& key, Tensor& default_tensor) const + inline Tensor at(std::string const& key, Tensor& default_tensor) const { TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); if (contains(key)) @@ -507,7 +507,7 @@ class TensorMap return default_tensor; } - inline Tensor& at(const std::string& key, Tensor&& default_tensor) + inline Tensor& at(std::string const& key, Tensor&& default_tensor) { TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); if (contains(key)) @@ -517,7 +517,7 @@ class TensorMap return default_tensor; } - inline Tensor at(const std::string& key, Tensor&& default_tensor) const + inline Tensor at(std::string const& key, Tensor&& default_tensor) const { if (contains(key)) { @@ -527,7 +527,7 @@ class TensorMap } template - inline T getVal(const std::string& key) const + inline T getVal(std::string const& key) const { TLLM_CHECK_WITH_INFO(contains(key), fmtstr( @@ -536,7 +536,7 @@ class TensorMap } template - inline std::optional getValOpt(const std::string& key) const + inline std::optional getValOpt(std::string const& key) const { if (contains(key)) { @@ -549,7 +549,7 @@ class TensorMap } template - inline T getVal(const std::string& key, T default_value) const + inline T getVal(std::string const& key, T default_value) const { if (contains(key)) { @@ -559,7 +559,7 @@ class TensorMap } template - inline T getValWithOffset(const std::string& key, size_t index) const + inline T getValWithOffset(std::string const& key, size_t index) const { TLLM_CHECK_WITH_INFO(contains(key), fmtstr( @@ -568,7 +568,7 @@ class TensorMap } template - inline T getValWithOffset(const std::string& key, size_t index, T default_value) const + inline T getValWithOffset(std::string const& key, size_t index, T default_value) const { if (contains(key)) { @@ -578,7 +578,7 @@ class TensorMap } template - inline T* getPtr(const std::string& key) const + inline T* getPtr(std::string const& key) const { TLLM_CHECK_WITH_INFO(contains(key), fmtstr( @@ -587,7 +587,7 @@ class TensorMap } template - inline T* getPtr(const std::string& key, T* default_ptr) const + 
inline T* getPtr(std::string const& key, T* default_ptr) const { if (contains(key)) { @@ -597,7 +597,7 @@ class TensorMap } template - inline T* getPtrWithOffset(const std::string& key, size_t index) const + inline T* getPtrWithOffset(std::string const& key, size_t index) const { TLLM_CHECK_WITH_INFO(contains(key), fmtstr( @@ -606,7 +606,7 @@ class TensorMap } template - inline T* getPtrWithOffset(const std::string& key, size_t index, T* default_ptr) const + inline T* getPtrWithOffset(std::string const& key, size_t index, T* default_ptr) const { if (contains(key)) { diff --git a/cpp/tensorrt_llm/common/tllmException.cpp b/cpp/tensorrt_llm/common/tllmException.cpp index f87c9e159..64938f8b0 100644 --- a/cpp/tensorrt_llm/common/tllmException.cpp +++ b/cpp/tensorrt_llm/common/tllmException.cpp @@ -34,7 +34,7 @@ int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; #if !defined(_MSC_VER) -TllmException::TllmException(char const* file, std::size_t line, const std::string& msg) +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) : std::runtime_error{""} { mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); @@ -43,7 +43,7 @@ TllmException::TllmException(char const* file, std::size_t line, const std::stri std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); } #else -TllmException::TllmException(char const* file, std::size_t line, const std::string& msg) +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) : mNbFrames{} , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} { diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h index 2ed13dde1..f3c622b88 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h @@ -65,7 +65,7 @@ __forceinline__ __device__ float copysignf_pos(float a, float b) __forceinline__ __device__ float tanh_opt(float x) { #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - const float exp_val = -1.f * fabs(2 * x); + float const exp_val = -1.f * fabs(2 * x); return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); #else return fast_tanh(x); @@ -76,7 +76,7 @@ __forceinline__ __device__ float tanh_opt(float x) template <> struct GELU_taylor { - static const bool kIsHeavy = true; + static bool const kIsHeavy = true; CUTLASS_DEVICE float operator()(float const& z) const diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h index 1781fc3ac..d3d4d0a45 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -157,8 +157,8 @@ class EpilogueVisitorPerRowPerCol MatrixCoord extent_real_; ElementwiseFunctor elementwise_; - const bool per_token_quant_; - const bool per_channel_quant_; + bool const per_token_quant_; + bool const per_channel_quant_; AlphaScaleElementType* ptr_alpha_row_; AlphaScaleElementType* ptr_alpha_col_; diff --git 
a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h index c33d90e93..bfd3666b9 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -65,7 +65,7 @@ namespace device ///////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const GemmCoord* problem_sizes, int splitk, +__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, int64_t* splitk_buffer_offsets) { // in_tensor: [problem_idx, k_partition, hidden_size] @@ -73,9 +73,9 @@ __global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const // so, we need to use splitk_buffer_offsets. // out_tensor: problem_idx * [hidden_size] - const int problem_idx = blockIdx.y; + int const problem_idx = blockIdx.y; GemmCoord problem = problem_sizes[problem_idx]; - const int hidden_size = problem.m() * problem.n(); + int const hidden_size = problem.m() * problem.n(); const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; T_OUT* out_tensor_ = out_tensor[problem_idx]; @@ -143,7 +143,7 @@ class BaseSplitkGrouped private: /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) { int32_t tiles = 0; for (int32_t i = 0; i < problem_count; ++i) @@ -182,7 +182,7 @@ class BaseSplitkGrouped /// Reorder `data` according to `indices` template - static void reorder_array(T* data, const std::vector& indices) + static void reorder_array(T* data, std::vector const& indices) { // For now, simply create a copy of the data and then copy over to the original. std::vector copy(indices.size()); @@ -314,7 +314,7 @@ class BaseSplitkGrouped /// Computes the number of threadblocks to launch for the grouped kernel static int sufficient( - const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) { // Determine the number of blocks that would be launched to fill up a single // wave on the GPU with each SM having maximum occupancy. 
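The splitkReduction kernel in the hunk above sums the per-partition partial results of each grouped GEMM problem, using splitk_buffer_offsets because each problem has its own hidden size. A rough host-side sketch of that same indexing is shown below; the helper name, the std::vector-based interface, and the single value type are illustrative assumptions, not part of this patch.

#include <cstdint>
#include <vector>

// Host-side reference of the reduction done by splitkReduction: for problem p,
// the split-K buffer holds `splitk` partial results of length hidden_sizes[p]
// back to back, starting at element offset splitk_buffer_offsets[p] * splitk.
template <typename T>
void reduceSplitKHost(std::vector<T*>& out_tensors,                    // out_tensors[p] -> hidden_sizes[p] elements
    T const* in_tensor,                                                // [problem, k_partition, hidden] flattened
    std::vector<int> const& hidden_sizes,                              // m*n per problem
    std::vector<int64_t> const& splitk_buffer_offsets, int splitk)
{
    for (size_t p = 0; p < out_tensors.size(); ++p)
    {
        T const* in = in_tensor + splitk_buffer_offsets[p] * splitk;
        for (int i = 0; i < hidden_sizes[p]; ++i)
        {
            T acc = T(0);
            for (int k = 0; k < splitk; ++k)
            {
                acc += in[static_cast<int64_t>(k) * hidden_sizes[p] + i]; // accumulate the k-th partial result
            }
            out_tensors[p][i] = acc;
        }
    }
}
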
diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index afa28606a..053f73103 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -142,7 +142,7 @@ struct GemmFpAIntB Arguments() {} CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, const int group_size, + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, typename Epilogue::OutputTileIterator::TensorRef ref_C, @@ -206,7 +206,7 @@ struct GemmFpAIntB } CUTLASS_HOST_DEVICE - Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, const int gemm_k_size, + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, void* workspace = nullptr) : problem_size(args.problem_size) , group_size(args.group_size) diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h index 5677f31d5..06da3848f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -174,7 +174,7 @@ struct MoeFCGemm /// Ctor CUTLASS_HOST_DEVICE Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, - const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) : problem_count(problem_count) diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h index cd9270d14..796dc2fe7 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -119,7 +119,7 @@ struct BaseMoeProblemVisitor /// Get the grid shape CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) { return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), @@ -177,12 +177,12 @@ struct BaseMoeProblemVisitor } CUTLASS_HOST_DEVICE - static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) { return ProblemSizeHelper::tile_count(grid); } - static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) { int32_t 
total_tiles = 0; for (int32_t i = 0; i < problem_count; ++i) @@ -328,12 +328,12 @@ struct MoeProblemVisitor CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, - const int warp_tileB_k_offset) + int const warp_tileB_k_offset) { warp_mma(D, A, B, C); } @@ -68,7 +68,7 @@ CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& template CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, const int warp_tileB_k_offset) + typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) { warp_mma(D, A, B, C, warp_tileB_k_offset); } diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h index 76564f14b..a10b95866 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -572,8 +572,8 @@ class DqMmaMultistagewarp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { this->warp_tile_iterator_B_.set_kgroup_index( diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h index 5ec515c28..9b0140f6e 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -219,7 +219,7 @@ class DqMmaMultistagewarp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { this->warp_tile_iterator_B_.set_kgroup_index( diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h index e8f5a92c3..aa692681a 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h 
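// The hunks in this patch (here and below) only move the `const` qualifier from the
// left of the type to the right ("east const"); they do not change behavior. A minimal
// standalone sketch (not part of this patch) of the equivalences involved:
#include <type_traits>

int main()
{
    // For a plain value the two spellings name the same type.
    static_assert(std::is_same<const int, int const>::value, "same type");

    // For pointers, only the position relative to `*` matters:
    // `const int*` and `int const*` are both pointer-to-const-int,
    // while `int* const` is a const pointer to mutable int.
    static_assert(std::is_same<const int*, int const*>::value, "same type");
    static_assert(!std::is_same<int const*, int* const>::value, "different types");
    return 0;
}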
@@ -184,7 +184,7 @@ class DqMmaPipelined : public DqMmaBase=80. int thread_idx, ///< ID within the threadblock @@ -353,8 +353,8 @@ class DqMmaPipelined : public DqMmaBasewarp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); ++this->warp_tile_iterator_A_; - const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h index 8e9451694..c7f51d6fe 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -218,7 +218,7 @@ class MmaTensorOpComputeBWithF16 /// Performs a warp-level matrix multiply-accumulate operation CUTLASS_DEVICE void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, - const int warp_tileB_k_offset) const + int const warp_tileB_k_offset) const { using MmaOperandA = typename ArchMmaOperator::FragmentA; diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h index bdac36fd9..2b0d05396 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -136,11 +136,11 @@ class MmaTensorOpDequantizer= 800) && defined(ENABLE_BF16)) using _MmaOperandB = typename ArchMmaOperator::FragmentB; @@ -174,7 +174,7 @@ class MmaTensorOpDequantizer(&scale_frag); + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); CUTLASS_PRAGMA_UNROLL for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) @@ -222,7 +222,7 @@ class MmaTensorOpDequantizer= 800) && defined(ENABLE_BF16)) using _MmaOperandB = typename ArchMmaOperator::FragmentB; @@ -231,8 +231,8 @@ class MmaTensorOpDequantizer(&scale_frag); - const __nv_bfloat16* zero_ptr = reinterpret_cast(&zero_frag); + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); CUTLASS_PRAGMA_UNROLL @@ -335,11 +335,11 @@ class MmaTensorOpDequantizer; @@ -406,7 +406,7 @@ class MmaTensorOpDequantizer; @@ -505,11 +505,11 @@ class MmaTensorOpDequantizer::value / 8; const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; @@ -182,11 +182,11 @@ class FineGrainedScaleZeroIterator -__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const FinishedState* 
finished_buf, - const int** parent_ids_buf, const int* batch_slots, int batch_size, int beam_width, int max_seq_len, - const int* no_repeat_ngram_size_buf, int vocab_size_padded, const int* sequence_lengths) +__global__ void ban_repeat_ngram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf, + int const** parent_ids_buf, int const* batch_slots, int batch_size, int beam_width, int max_seq_len, + int const* no_repeat_ngram_size_buf, int vocab_size_padded, int const* sequence_lengths) { /** * Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching @@ -46,13 +46,13 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi * in-bound positions only. For leftside out-of-boundary tokens, access by global memory. */ - const int output_idx = blockIdx.x * blockDim.x + threadIdx.x; - const int local_batch_idx = blockIdx.y / beam_width; + int const output_idx = blockIdx.x * blockDim.x + threadIdx.x; + int const local_batch_idx = blockIdx.y / beam_width; auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx; - const int beam_idx = blockIdx.y % beam_width; - const bool beam_search = beam_width > 1; - const int no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot]; - const int step = sequence_lengths[batch_slot]; + int const beam_idx = blockIdx.y % beam_width; + bool const beam_search = beam_width > 1; + int const no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot]; + int const step = sequence_lengths[batch_slot]; // case 1: ngram_size == 0 --> this means no ngram limit // case 2: generated length must be greater than ngram_size to do ngram check @@ -133,9 +133,9 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi } template -void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, - const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width, - int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream) +void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf, + int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width, + int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream) { // each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation // getting the max of current batch and allocate shmem as needed is ideal. 
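// Host-side reference sketch of the banning rule documented in the kernel comment above
// (illustrative only, not part of this patch; banRepeatNgramRef and kNegInf are made-up
// names): any token that would complete an n-gram already present in `output_ids` gets
// its logit pushed to -inf, so it cannot be sampled again.
#include <cmath>
#include <vector>

static void banRepeatNgramRef(std::vector<float>& logits, std::vector<int> const& output_ids, int ngram_size)
{
    int const step = static_cast<int>(output_ids.size());
    if (ngram_size == 0 || step < ngram_size)
    {
        return; // case 1: no n-gram limit; case 2: not enough generated tokens yet
    }
    float const kNegInf = -INFINITY;
    // Compare every past window of (ngram_size - 1) tokens with the last (ngram_size - 1)
    // generated tokens; on a match, ban the token that followed that window.
    for (int start = 0; start <= step - ngram_size; ++start)
    {
        bool match = true;
        for (int j = 0; j < ngram_size - 1; ++j)
        {
            if (output_ids[start + j] != output_ids[step - ngram_size + 1 + j])
            {
                match = false;
                break;
            }
        }
        if (match)
        {
            logits[output_ids[start + ngram_size - 1]] = kNegInf;
        }
    }
}

int main()
{
    std::vector<int> output_ids{3, 5, 7, 3, 5}; // the last two tokens (3, 5) already appeared at position 0
    std::vector<float> logits(16, 0.f);
    banRepeatNgramRef(logits, output_ids, /*ngram_size=*/3);
    return logits[7] == -INFINITY ? 0 : 1;      // token 7 completed the earlier (3, 5, 7) trigram, so it is banned
}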
But here the ngram_buf is on GPU, while diff --git a/cpp/tensorrt_llm/kernels/banRepeatNgram.h b/cpp/tensorrt_llm/kernels/banRepeatNgram.h index bf7215466..b9d15c4ee 100644 --- a/cpp/tensorrt_llm/kernels/banRepeatNgram.h +++ b/cpp/tensorrt_llm/kernels/banRepeatNgram.h @@ -26,9 +26,9 @@ namespace kernels { template -void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, - const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width, - int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream); +void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf, + int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width, + int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu index c21b9a400..cf4dfa05a 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu +++ b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.cu @@ -49,8 +49,8 @@ __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float template __launch_bounds__(THREADBLOCK_SIZE) __global__ - void beam_topK_kernel(const T* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, const bool* finished, - const int* sequence_lengths, const int vocab_size, T diversity_rate, float length_penalty) + void beam_topK_kernel(T const* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, bool const* finished, + int const* sequence_lengths, int const vocab_size, T diversity_rate, float length_penalty) { typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -59,7 +59,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ int block_id = blockIdx.x; // batch beam index. TopK partial; - const bool IS_FP16 = std::is_same::value; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; #pragma unroll @@ -101,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ { int thread_id = threadIdx.x; int block_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; TopK partial; if (thread_id == 0) @@ -136,7 +136,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ int tid = threadIdx.x; int bid = blockIdx.x; TopK partial; - const bool IS_FP16 = std::is_same::value; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; #pragma unroll @@ -167,32 +167,32 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ } template -__global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, - T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size, - const float length_penalty, const int* end_ids) +__global__ void topk_stage_1_opt3(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, + T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size, + float const length_penalty, int const* end_ids) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - const int tid = threadIdx.x; - const int bid = blockIdx.x; + int const tid = threadIdx.x; + int const bid = blockIdx.x; - const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) - const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; - const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; + int const row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) + int const block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam + int const tmp_log_buf_index = row_id * vocab_size; + int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; TopK_2 partial; - const bool IS_FP16 = std::is_same::value; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; if (finished != nullptr && finished[row_id] == true) { if (tid < k) { - const int index = tmp_topk_buf_index + tid; + int const index = tmp_topk_buf_index + tid; if (block_lane == 0 && tid == 0) { - const int end_id = end_ids[row_id / k]; + int const end_id = end_ids[row_id / k]; topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; } @@ -226,7 +226,7 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob if (tid == 0) { - const int index = tmp_topk_buf_index + ite; + int const index = tmp_topk_buf_index + ite; topk_tmp_id_buf[index] = total.p; topk_tmp_val_buf[index] = total.u; tmp_log_probs[total.p] = -MAX_T_VAL; @@ -236,15 +236,15 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob } template -__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, - BeamHypotheses beam_hyps, const int* end_ids, const int vocab_size, const int k) +__global__ void topk_stage_2_opt3(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, + BeamHypotheses beam_hyps, int const* end_ids, int const vocab_size, int const k) { - const int size = k * k * BLOCKS_PER_BEAM_; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; + int const size = k * k * BLOCKS_PER_BEAM_; + int const tid = threadIdx.x; + int const batch_id = blockIdx.x; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]}; + float const length_penalty{beam_hyps.length_penalties == nullptr ? 
1.0f : beam_hyps.length_penalties[batch_id]}; typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -263,7 +263,7 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk __syncthreads(); if (beam_hyps.num_beams != nullptr) { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) { // initialize the buffer @@ -304,9 +304,9 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk } else { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); - const int num_beam = beam_hyps.num_beams[global_batch_idx]; + int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); + int const num_beam = beam_hyps.num_beams[global_batch_idx]; int beam_idx = num_beam; // If there are beam_width finished sentences, check that the score of // selected candidatet is higher than min_normed_score or not. If @@ -345,20 +345,20 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk } } } - const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) + int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) * beam_hyps.max_seq_len; beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; for (int j = beam_hyps.step - 1; j >= 0; j--) { - const int src_idx = j * beam_hyps.batch_size * k + int const src_idx = j * beam_hyps.batch_size * k + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; prev_id = beam_hyps.parent_ids_src[src_idx]; } - const int tgt_beam_idx = global_batch_idx * k + beam_idx; + int const tgt_beam_idx = global_batch_idx * k + beam_idx; beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; beam_hyps.normed_scores[tgt_beam_idx] = normed_score; beam_hyps.min_normed_scores[global_batch_idx] @@ -389,21 +389,21 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk } template -__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, - T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size, - const float length_penalty) +__global__ void topk_stage_1_opt2_general(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, + T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size, + float const length_penalty) { - const bool IS_FP16 = std::is_same::value; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs - const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam - const int tmp_log_buf_index = row_id * vocab_size; - const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; + int const tid = threadIdx.x; + int const bid = blockIdx.x; + int const row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs + int const block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam + int const tmp_log_buf_index = row_id * vocab_size; + int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; TopK_2 partial; for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) @@ -426,7 +426,7 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_ if (tid == 0) { - const int index = tmp_topk_buf_index + ite; + int const index = tmp_topk_buf_index + ite; topk_tmp_id_buf[index] = total.p; topk_tmp_val_buf[index] = total.u; tmp_log_probs[total.p] = -MAX_T_VAL; @@ -436,15 +436,15 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_ } template -__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, - BeamHypotheses beam_hyps, const int* end_ids, const int k, const int vocab_size) +__global__ void topk_stage_2_opt2_general(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, + BeamHypotheses beam_hyps, int const* end_ids, int const k, int const vocab_size) { - const int size = k * k * BLOCKS_PER_BEAM; - const int tid = threadIdx.x; - const int batch_id = blockIdx.x; - const bool IS_FP16 = std::is_same::value; + int const size = k * k * BLOCKS_PER_BEAM; + int const tid = threadIdx.x; + int const batch_id = blockIdx.x; + bool const IS_FP16 = std::is_same::value; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]}; + float const length_penalty{beam_hyps.length_penalties == nullptr ? 
1.0f : beam_hyps.length_penalties[batch_id]}; typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -463,7 +463,7 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, __syncthreads(); if (beam_hyps.num_beams != nullptr) { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) { beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; @@ -503,9 +503,9 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, } else { - const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; - const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); - const int num_beam = beam_hyps.num_beams[global_batch_idx]; + int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); + int const num_beam = beam_hyps.num_beams[global_batch_idx]; int beam_idx = num_beam; // If there are beam_width finished sentences, check that the score of // selected candidatet is higher than min_normed_score or not. If @@ -544,20 +544,20 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, } } } - const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) + int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) * beam_hyps.max_seq_len; beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; for (int j = beam_hyps.step - 1; j >= 0; j--) { - const int src_idx = j * beam_hyps.batch_size * k + int const src_idx = j * beam_hyps.batch_size * k + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; prev_id = beam_hyps.parent_ids_src[src_idx]; } - const int tgt_beam_idx = global_batch_idx * k + beam_idx; + int const tgt_beam_idx = global_batch_idx * k + beam_idx; beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; beam_hyps.normed_scores[tgt_beam_idx] = normed_score; beam_hyps.min_normed_scores[global_batch_idx] @@ -613,18 +613,18 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps, - const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width, - const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids, + bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width, + int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids, cudaStream_t stream) { // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a // token. - const int vocab_size = vocab_size_padded_; + int const vocab_size = vocab_size_padded_; // Beam size should be less than or equal to vocab size. assert(beam_width <= vocab_size); // Beam search needs the sequence lengths of beams to apply length penalty. 
assert(length_penalty == 0.0f || sequence_lengths != nullptr); - const int max_block_per_beam = 8; + int const max_block_per_beam = 8; int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float @@ -685,13 +685,13 @@ void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, #undef CASE_K_DIV template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids, - BeamHypotheses* beam_hyps, const bool* finished, const int* sequence_lengths, const int batch_size, - const int beam_width, const int vocab_size_padded_, const float diversity_rate, const float length_penalty, - const int* end_ids, cudaStream_t stream); + BeamHypotheses* beam_hyps, bool const* finished, int const* sequence_lengths, int const batch_size, + int const beam_width, int const vocab_size_padded_, float const diversity_rate, float const length_penalty, + int const* end_ids, cudaStream_t stream); template -__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output, - const int* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model) +__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, + int const* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model) { if (blockIdx.x == 0) { @@ -711,7 +711,7 @@ __global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, } template -void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output, const int* sequence_length, +void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream) { @@ -739,30 +739,30 @@ void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const } } -template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, const float* output, - const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, +template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, float const* output, + int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, const half* output, - const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, +template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, half const* output, + int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, const half2* output, - const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, +template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, half2 const* output, + int const* sequence_length, const size_t batch_size, const size_t 
beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length, - const __nv_bfloat16* output, const int* sequence_length, const size_t batch_size, const size_t beam_width, + __nv_bfloat16 const* output, int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); #endif -__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, - const float* cum_log_probs, const int batch_size, const int beam_width) +__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, + float const* cum_log_probs, int const batch_size, int const beam_width) { - const int bid = blockIdx.x; - const int tgt_start_idx = beam_hyps.num_beams[bid]; - const int max_seq_len{beam_hyps.max_seq_len}; - const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]}; + int const bid = blockIdx.x; + int const tgt_start_idx = beam_hyps.num_beams[bid]; + int const max_seq_len{beam_hyps.max_seq_len}; + float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]}; if (beam_hyps.is_done[bid]) { return; @@ -771,10 +771,10 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta { if (threadIdx.x == 0) { - const int src_beam_idx = bid * beam_width + beam_idx; - const int tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx; + int const src_beam_idx = bid * beam_width + beam_idx; + int const tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx; - const int last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1; + int const last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1; beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx] = beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx]; @@ -810,8 +810,8 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta } } -void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs, - const int batch_size, const int beam_width, cudaStream_t stream) +void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs, + int const batch_size, int const beam_width, cudaStream_t stream) { insertUnfinishedPath<<>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width); } diff --git a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h index 02ecfe681..8627808de 100644 --- a/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h +++ b/cpp/tensorrt_llm/kernels/beamSearchTopkKernels.h @@ -35,57 +35,64 @@ namespace kernels // After we collect `beam_width` beams, we will sort them by their norm_scores. struct BeamHypotheses { - // TODO: simplify the pointers - // Pointers initialized in function prepareOutputs in gptDecoder.cpp - bool* is_done{nullptr}; // [batchSize], whether the batch is finished - const int* input_lengths{nullptr}; // [batchSize] - float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr() - float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used? 
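// Worked example of the length-normalized score produced by apply_length_penalty in the
// top-k stage-2 kernels above and stored in normed_scores (cum_log / length^length_penalty):
// a hypothesis with cum_log_prob = -6.0 and length 8 scores -6.0 / 8^1.0 = -0.75 under
// length_penalty = 1.0, and keeps the raw -6.0 under length_penalty = 0.0. Minimal host
// sketch (illustrative only; normedScoreRef is not a function from this patch):
#include <cmath>

static float normedScoreRef(float cum_log_prob, int length, float length_penalty)
{
    return cum_log_prob / std::pow(static_cast<float>(length), length_penalty);
}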
- float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch - float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty) - int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch - int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], - int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src - - // Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu - const int* end_ids{nullptr}; // get from SoftmaxParams - const int* output_ids_src{nullptr}; // for gatherTree - const int* parent_ids_src{nullptr}; // for gatherTree - const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading - const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading - float* log_probs_src{nullptr}; // get from outputs.output_log_probs - int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams - // For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate - int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing - int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing - - // Other scalar values and buffers - int batch_size{0}; - int beam_width{0}; - int ite{0}; - int local_batch_size{0}; - int max_seq_len{0}; - int step{0}; // useless in online version of beam search - int vocab_size{0}; - float* diversity_rates{nullptr}; - float* length_penalties{nullptr}; - int* early_stoppings{nullptr}; - bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs + // BS: batch_size + // BM: beam_width + // mSL: max_seq_length + // %%: parameter name when we call [generation.py] dynamic_decoder.forward + + // Pointers initialized in these two functions: + // [gptDecoder.cpp] GptDecoder::forward or [dynamicDecodeOp.cpp] FtDynamicDecode::forward + bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done + float* cum_log_probs{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs + float* log_probs{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs + float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores + float* normed_scores{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores + int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams + int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_is_done + int* sequence_lengths_tgt{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt + int const* input_lengths{nullptr}; // [BS*BM] %% context_length + + // Pointers initialized in [onlineBeamSearchLayer.cu] invokeSoftMax: + int const* end_ids{nullptr}; // [BS*BM] %% self.end_ids + FinishedState* finished; // [BS*BM] %% self.finished + float* cum_log_probs_src{nullptr}; // [BS, BM] %% self.cum_log_probs + float* log_probs_src{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled + int* sequence_lengths_src{nullptr}; // [BS*BM] %% self.sequence_length_buffer + int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp] + int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp] + + float* diversity_rates{nullptr}; // [BS] from SamplingConfig + float* length_penalties{nullptr}; // [BS] from SamplingConfig + int* early_stoppings{nullptr}; // [BS] from SamplingConfig + + // Pointers for function gatherTree + int const* output_ids_src{nullptr}; // + int const* parent_ids_src{nullptr}; // + + 
// Scalar values + bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs, always be true now + int batch_size{0}; // + int beam_width{0}; // + int ite{0}; // index of local_batch, always be 0 if pp_size==1 + int local_batch_size{0}; // + int max_seq_len{0}; // + int step{0}; // only used in [beamSearchTopkKernels.cu], always be 0 in [onlineSoftmaxBeamsearchKernels*.cu.h] + int vocab_size{0}; // vocab_size_padded }; template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps, - const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width, - const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids, + bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width, + int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids, cudaStream_t stream); template -void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, const T* encoder_output, - const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, +void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, T const* encoder_output, + int const* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); -void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs, - const int batch_size, const int beam_width, cudaStream_t stream); +void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs, + int const batch_size, int const beam_width, cudaStream_t stream); void invokeCopyBatchMajorToGeneralPtr( void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream); diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp index ac0940fef..32312e0b6 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp @@ -58,13 +58,13 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype) else if (dtype == DATA_TYPE_INT32) { int32_t inorm = static_cast(norm); - alpha = reinterpret_cast(inorm); + alpha = reinterpret_cast(inorm); } else if (dtype == DATA_TYPE_BF16) { // TODO HACK!! BF16 Outputs are computed in FP32 for FP8. // This is because cublas does not allow current FP32 output. - alpha = reinterpret_cast(norm); + alpha = reinterpret_cast(norm); } else { @@ -77,7 +77,7 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype) class FusedMHARunnerV2::mhaImpl { public: - mhaImpl(const Data_type data_type, const int numHeads, const int headSize, const float qScaling, int sm_) + mhaImpl(const Data_type data_type, int const numHeads, int const headSize, float const qScaling, int sm_) : mDataType(data_type) , mNumHeads(numHeads) , mHeadSize(headSize) @@ -105,17 +105,17 @@ class FusedMHARunnerV2::mhaImpl // Shared setup function. 
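// The set_alpha() helper above bit-copies the softmax scale into a uint32_t so the kernel
// parameter block can carry it regardless of the math data type; the hunks here only switch
// its reinterpret_casts to east-const spelling. A minimal standalone sketch of that trick
// (packScaleAsU32 is an illustrative name, not part of this patch):
#include <cstdint>
#include <cstring>

static uint32_t packScaleAsU32(float scale)
{
    uint32_t bits = 0;
    // memcpy reinterprets the float's bit pattern without aliasing issues; the runner
    // achieves the same effect with a reinterpret_cast on a same-sized value.
    std::memcpy(&bits, &scale, sizeof(bits));
    return bits;
}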
template - void setup_params(Params& params, const int b, const int s_q, const int s_kv, const int sliding_window_size, - const int total_seqlen, const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank) + void setup_params(Params& params, int const b, int const s_q, int const s_kv, int const sliding_window_size, + int const total_seqlen, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank) { - const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling)); + float const inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling)); // Note that we apply scales and bias in the order of // (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi - const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f; - const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale; - const float scale_softmax = 1.f; // Seems to be only required for int8 - const float scale_bmm2 = 1.f; + float const scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f; + float const scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale; + float const scale_softmax = 1.f; // Seems to be only required for int8 + float const scale_bmm2 = 1.f; Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType; // Use exp2f optimization for warp-specialized ws kernels on Hopper. @@ -153,8 +153,8 @@ class FusedMHARunnerV2::mhaImpl } // Support packed QKV. - void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi, - const bool scale_alibi, const int tp_size, const int tp_rank) + void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, bool const has_alibi, + bool const scale_alibi, int const tp_size, int const tp_rank) { // Determine launch parameters. @@ -165,10 +165,10 @@ class FusedMHARunnerV2::mhaImpl TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Head size should be greater than 0."); mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1); - const bool isSm70 = (sm == kSM_70); - const bool isSm90 = (sm == kSM_90); - const bool isSm8x = (sm == kSM_86 || sm == kSM_89); - const bool isSm80 = (sm == kSM_80); + bool const isSm70 = (sm == kSM_70); + bool const isSm90 = (sm == kSM_90); + bool const isSm8x = (sm == kSM_86 || sm == kSM_89); + bool const isSm80 = (sm == kSM_80); if (isSm70) { mLaunchParams.flash_attention = true; @@ -238,9 +238,9 @@ class FusedMHARunnerV2::mhaImpl } // Support paged_kv_cache and chunked_attention. - void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, - const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi, - const bool scale_alibi, const int tp_size, const int tp_rank) + void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence, + int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi, + bool const scale_alibi, int const tp_size, int const tp_rank) { // Determine launch parameters. @@ -253,9 +253,9 @@ class FusedMHARunnerV2::mhaImpl mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? 
mHeadSize : pow(2, int(log2(mHeadSize)) + 1); // Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256 - const bool isSm90 = (sm == kSM_90); - const bool isSm8x = (sm == kSM_86 || sm == kSM_89); - const bool isSm80 = (sm == kSM_80); + bool const isSm90 = (sm == kSM_90); + bool const isSm8x = (sm == kSM_86 || sm == kSM_89); + bool const isSm80 = (sm == kSM_80); // always use flash attention kernels. mLaunchParams.flash_attention = true; @@ -383,7 +383,7 @@ class FusedMHARunnerV2::mhaImpl // QKV [TOTAL, 3, h, d] // NOTE: we may need to use actual seqlen to set oob_value - const char* qkv_ptr = reinterpret_cast(mParams.qkv_ptr); + char const* qkv_ptr = reinterpret_cast(mParams.qkv_ptr); tensor_size_qkv[3] = mTotalSeqLen; // Q: STEP_Q @@ -467,7 +467,7 @@ class FusedMHARunnerV2::mhaImpl : (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B)); // Q ptr. - const char* q_ptr = reinterpret_cast(mPagedKVParams.q_ptr); + char const* q_ptr = reinterpret_cast(mPagedKVParams.q_ptr); // Q: STEP_Q. q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN, @@ -518,7 +518,7 @@ class FusedMHARunnerV2::mhaImpl paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream); } - void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads) + void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads) { // BF16 FMHA only accumulates on FP32 mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc; @@ -541,11 +541,11 @@ class FusedMHARunnerV2::mhaImpl return MHARunner::fmha_supported(mHeadSize, sm); } - void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream) + void run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream) { mParams.qkv_ptr = qkvPtr; mParams.o_ptr = outputPtr; - mParams.cu_seqlens = reinterpret_cast(cuSeqlenPtr); + mParams.cu_seqlens = reinterpret_cast(cuSeqlenPtr); if (sm == kSM_90 && mLaunchParams.use_tma) { @@ -556,8 +556,8 @@ class FusedMHARunnerV2::mhaImpl xmmaKernel->run(mParams, mLaunchParams, stream); } - void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost, - const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr, + void run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost, + const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr, cudaStream_t stream) { KVBlockArrayForContextFMHA pagedKVCacheForContextMHA; @@ -568,10 +568,10 @@ class FusedMHARunnerV2::mhaImpl mPagedKVParams.tma_desc_paged_kv = reinterpret_cast(pagedKVTmaDesc); mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA; mPagedKVParams.o_ptr = outputPtr; - mPagedKVParams.cu_q_seqlens = reinterpret_cast(cuQSeqlenPtr); - mPagedKVParams.cu_seqlens = reinterpret_cast(cuKVSeqlenPtr); + mPagedKVParams.cu_q_seqlens = reinterpret_cast(cuQSeqlenPtr); + mPagedKVParams.cu_seqlens = reinterpret_cast(cuKVSeqlenPtr); // paged kv block device ptrs on host (used by tma descriptors). 
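// Numeric sketch of the scale ordering documented in setup_params() above (illustrative
// only; Bmm1Scales and chooseBmm1Scales are made-up names): the 1/(sqrt(head_size) * q_scaling)
// factor is applied either before or after the ALiBi bias is added, depending on scale_alibi,
// but exactly one of scale_bmm1 / scale_after_alibi carries it.
#include <cmath>

struct Bmm1Scales
{
    float scale_bmm1;
    float scale_after_alibi;
};

static Bmm1Scales chooseBmm1Scales(int head_size, float q_scaling, bool scale_alibi)
{
    float const inv_sqrt_scale = 1.f / (std::sqrt(static_cast<float>(head_size)) * q_scaling);
    // Consumed as: (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi.
    return scale_alibi ? Bmm1Scales{1.0f, inv_sqrt_scale} : Bmm1Scales{inv_sqrt_scale, 1.0f};
}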
- mLaunchParams.paged_kv_block_ptrs = reinterpret_cast(pagedKVBlockPtrsOnHost); + mLaunchParams.paged_kv_block_ptrs = reinterpret_cast(pagedKVBlockPtrsOnHost); if (sm == kSM_90 && mLaunchParams.use_tma) { @@ -587,7 +587,7 @@ class FusedMHARunnerV2::mhaImpl return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s); } - int getSFromMaxSeqLen(const int max_seq_len) + int getSFromMaxSeqLen(int const max_seq_len) { int S = 1024; @@ -625,35 +625,35 @@ class FusedMHARunnerV2::mhaImpl Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams; Launch_params mLaunchParams; int sm; - const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel; - const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel; + FusedMultiHeadAttentionXMMAKernelV2 const* xmmaKernel; + FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* pagedKVXmmaKernel; bool use_flash_attention = false; const Data_type mDataType; - const int mNumHeads; - const int mHeadSize; - const float mQScaling; + int const mNumHeads; + int const mHeadSize; + float const mQScaling; int mTotalSeqLen; }; //////////////////////////////////////////////////////////////////////////////////////////////////// FusedMHARunnerV2::FusedMHARunnerV2( - const Data_type data_type, const int numHeads, const int headSize, const float qScaling) + const Data_type data_type, int const numHeads, int const headSize, float const qScaling) : pimpl(new mhaImpl(data_type, numHeads, headSize, qScaling, tensorrt_llm::common::getSMVersion())) { } FusedMHARunnerV2::~FusedMHARunnerV2() = default; -void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, - const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank) +void FusedMHARunnerV2::setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, + bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank) { pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank); } -void FusedMHARunnerV2::setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, - const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi, - const bool scale_alibi, const int tp_size, const int tp_rank) +void FusedMHARunnerV2::setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence, + int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi, + bool const scale_alibi, int const tp_size, int const tp_rank) { pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank); @@ -665,18 +665,18 @@ bool FusedMHARunnerV2::fmha_supported() } void FusedMHARunnerV2::setup_flags( - const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads) + bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads) { pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads); } -void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream) +void FusedMHARunnerV2::run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream) { pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream); } -void FusedMHARunnerV2::run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* 
pagedKVBlockPtrsOnHost, - const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr, +void FusedMHARunnerV2::run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost, + const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr, cudaStream_t stream) { pimpl->run_paged_kv( @@ -689,7 +689,7 @@ bool FusedMHARunnerV2::isValid(int s) const } // static function to check if fmha is supported when building plugins -bool MHARunner::fmha_supported(const int headSize, const int sm) +bool MHARunner::fmha_supported(int const headSize, int const sm) { if (sm == kSM_70) { diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h index a0b594825..e0912dbe3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h @@ -41,33 +41,33 @@ namespace kernels class MHARunner { public: - MHARunner(const Data_type dataType, const int numHeads, const int headSize, const float qScaling); + MHARunner(const Data_type dataType, int const numHeads, int const headSize, float const qScaling); MHARunner() = default; virtual ~MHARunner() = default; - virtual void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, - const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) + virtual void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, + bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0) = 0; - virtual void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, - const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, - const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) + virtual void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence, + int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, + bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0) = 0; - static bool fmha_supported(const int headSize, const int sm); + static bool fmha_supported(int const headSize, int const sm); virtual bool fmha_supported() = 0; - virtual void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, - const int num_kv_heads /* MQA or GQA */) + virtual void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, + int const num_kv_heads /* MQA or GQA */) = 0; - virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0; + virtual void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) = 0; - virtual void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host, - const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output, + virtual void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host, + const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output, cudaStream_t stream) = 0; @@ -86,28 +86,28 @@ class 
MHARunner class FusedMHARunnerV2 : public MHARunner { public: - FusedMHARunnerV2(const Data_type dataType, const int numHeads, const int headSize, const float qScaling); + FusedMHARunnerV2(const Data_type dataType, int const numHeads, int const headSize, float const qScaling); ~FusedMHARunnerV2(); // for pimpl - void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, - const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, - const int tp_rank = 0) override; + void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, + bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, + int const tp_rank = 0) override; - void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, - const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, - const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, - const int tp_rank = 0) override; + void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence, + int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, + bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, + int const tp_rank = 0) override; bool fmha_supported() override; - void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override; - void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host, - const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output, + void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) override; + void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host, + const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output, cudaStream_t stream) override; - void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, - const int num_kv_heads /* MQA or GQA */) override; + void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, + int const num_kv_heads /* MQA or GQA */) override; bool isValid(int s) const override; diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h index 6b118ac46..93b3977c7 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h @@ -84,9 +84,9 @@ struct AlibiParams struct Fused_multihead_attention_params_v2 { // The QKV matrices. - const void* qkv_ptr; + void const* qkv_ptr; // The mask to implement drop-out. - const void* packed_mask_ptr; + void const* packed_mask_ptr; // The O matrix (output). void* o_ptr; @@ -106,7 +106,7 @@ struct Fused_multihead_attention_params_v2 bool enable_i2f_trick; // array of length b+1 holding prefix sum of actual sequence lengths - const int* cu_seqlens; + int const* cu_seqlens; // use C/32 Format. bool interleaved = false; @@ -177,13 +177,13 @@ struct Fused_multihead_attention_params_v2 struct Fused_multihead_attention_paged_kv_params_v2 { // The Q matrices. - const void* q_ptr; + void const* q_ptr; // Paged KV Cache buffer. 
KVBlockArrayForContextFMHA paged_kv_cache; // The O matrix (output). void* o_ptr; // The packed mask for random mask. - const void* packed_mask_ptr; + void const* packed_mask_ptr; // The stride between rows of the Q matrices. int64_t q_stride_in_bytes; @@ -211,9 +211,9 @@ struct Fused_multihead_attention_paged_kv_params_v2 AlibiParams alibi_params; // array of length b+1 holding prefix sum of actual kv sequence lengths. - const int* cu_seqlens; + int const* cu_seqlens; // Chunked attention (only handles one tile of Q). - const int* cu_q_seqlens; + int const* cu_q_seqlens; // q with shape [B, S, H, D] in const cache. cudaTmaDesc tma_desc_q; @@ -301,7 +301,7 @@ struct Launch_params // number of paged kv blocks for context sequence. int blocks_per_context_sequence = 0; // device ptrs on the host for paged kv cache. - const int64_t* paged_kv_block_ptrs = nullptr; + int64_t const* paged_kv_block_ptrs = nullptr; // if flash attention is used (only FP16) bool flash_attention = false; // if warp_specialized kernels are used (only SM90 HGMMA + TMA) diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h index b9cd9d659..abe91a813 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_v2.h @@ -63,13 +63,13 @@ class TFusedMultiHeadAttentionXMMAKernel return (uint64_t) s << 32 | d; } - virtual uint64_t hashID(const KernelMeta& kernelMeta) const + virtual uint64_t hashID(KernelMeta const& kernelMeta) const { return hashID(kernelMeta.mS, kernelMeta.mD); } TFusedMultiHeadAttentionXMMAKernel( - const TKernelMeta* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm) + TKernelMeta const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm) : mDataType(type) , mKernelMeta(pMetaStart) , mKernelMetaCount(nMetaCount) @@ -86,7 +86,7 @@ class TFusedMultiHeadAttentionXMMAKernel for (unsigned int i = 0; i < mKernelMetaCount; ++i) { - const auto& kernelMeta = mKernelMeta[i]; + auto const& kernelMeta = mKernelMeta[i]; if (kernelMeta.mSM == mSM && kernelMeta.mDataType == mDataType) { CUmodule hmod{0}; @@ -125,9 +125,9 @@ class TFusedMultiHeadAttentionXMMAKernel virtual void run(TKernelParam& params, Launch_params& launch_params, cudaStream_t ss) const { - const auto findIter = mFunctions.find(hashID(params.s, params.d)); + auto const findIter = mFunctions.find(hashID(params.s, params.d)); - const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; + auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; const CUfunction func = findIter->second.mDeviceFunction; void* kernelParams[] = {¶ms, nullptr}; @@ -142,10 +142,10 @@ class TFusedMultiHeadAttentionXMMAKernel tensorrt_llm::common::CUDADriverWrapper mDriver; Data_type mDataType; - const TKernelMeta* mKernelMeta; + TKernelMeta const* mKernelMeta; unsigned int mKernelMetaCount; unsigned int mSM; - std::unordered_map mModules; + std::unordered_map mModules; struct FusedMultiHeadAttentionKernelInfo { @@ -161,14 +161,14 @@ template class TFusedMHAKernelFactory { public: - const TFusedMHAKernelList* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList, + TFusedMHAKernelList const* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList, unsigned int nbKernels, Data_type type, unsigned int sm) { static 
std::mutex s_mutex; std::lock_guard lg(s_mutex); - const auto id = hashID(type, sm); - const auto findIter = mKernels.find(id); + auto const id = hashID(type, sm); + auto const findIter = mKernels.find(id); if (findIter == mKernels.end()) { TFusedMHAKernelList* newKernel = new TFusedMHAKernelList{pKernelList, nbKernels, type, sm}; @@ -214,7 +214,7 @@ class FusedMultiHeadAttentionXMMAKernelV2 Fused_multihead_attention_params_v2> { public: - FusedMultiHeadAttentionXMMAKernelV2(const FusedMultiHeadAttentionKernelMetaInfoV2* pMetaStart, + FusedMultiHeadAttentionXMMAKernelV2(FusedMultiHeadAttentionKernelMetaInfoV2 const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm) : TFusedMultiHeadAttentionXMMAKernel(pMetaStart, nMetaCount, type, sm) @@ -231,7 +231,7 @@ class FusedMultiHeadAttentionXMMAKernelV2 | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull); } - virtual uint64_t hashID(const KernelMeta& kernelMeta) const + virtual uint64_t hashID(KernelMeta const& kernelMeta) const { return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep, @@ -278,7 +278,7 @@ class FusedMultiHeadAttentionXMMAKernelV2 } } - const auto findIter + auto const findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll, launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi, static_cast(launch_params.attention_mask_type), launch_params.granular_tiling)); @@ -290,7 +290,7 @@ class FusedMultiHeadAttentionXMMAKernelV2 launch_params.flash_attention, !launch_params.useKernelWithoutAlibi, static_cast(launch_params.attention_mask_type), launch_params.granular_tiling); - const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; + auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; const CUfunction func = findIter->second.mDeviceFunction; void* kernelParams[] = {&params, nullptr}; @@ -369,7 +369,7 @@ class FusedMultiHeadAttentionXMMAKernelV2 using FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory; -inline const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2(Data_type type, unsigned int sm) +inline FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type type, unsigned int sm) { return FusedMHAKernelFactoryV2::Get().getXMMAKernels( sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm); @@ -384,7 +384,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2 Fused_multihead_attention_paged_kv_params_v2> { public: - FusedMultiHeadAttentionPagedKVXMMAKernelV2(const FusedMultiHeadAttentionPagedKVKernelMetaInfoV2* pMetaStart, + FusedMultiHeadAttentionPagedKVXMMAKernelV2(FusedMultiHeadAttentionPagedKVKernelMetaInfoV2 const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm) : TFusedMultiHeadAttentionXMMAKernel(pMetaStart, nMetaCount, type, sm) @@ -402,7 +402,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2 | (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ?
1ull : 0ull); } - virtual uint64_t hashID(const KernelMeta& kernelMeta) const + virtual uint64_t hashID(KernelMeta const& kernelMeta) const { return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep, kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization, @@ -413,7 +413,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2 Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const { - const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, + auto const findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi, static_cast(launch_params.attention_mask_type), launch_params.granular_tiling)); @@ -426,7 +426,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2 !launch_params.useKernelWithoutAlibi, static_cast(launch_params.attention_mask_type), launch_params.granular_tiling); - const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; + auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; const CUfunction func = findIter->second.mDeviceFunction; void* kernelParams[] = {&params, nullptr}; @@ -488,7 +488,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2 using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory; -inline const FusedMultiHeadAttentionPagedKVXMMAKernelV2* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm) +inline FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm) { return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2, sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm); diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/tmaDescriptor.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/tmaDescriptor.h index 1a3835e81..1aafc6cf9 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/tmaDescriptor.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/tmaDescriptor.h @@ -186,7 +186,7 @@ class Multiple_tma_descriptor // set the desctriptor. int set_tma_desctriptor( // ptr to gmem - const void* gmem_ptr, + void const* gmem_ptr, // format is really data_type in TMA terminology. cudaTmaDescFormat format, // interleave mode. @@ -221,7 +221,7 @@ class Multiple_tma_descriptor // set the desctriptor. int set_tma_desctriptor( // ptr to gmem - const void* gmem_ptr, + void const* gmem_ptr, // format is really data_type in TMA terminology. cudaTmaDescFormat format, // interleave mode. diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu b/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu index 87c4c012e..e0be1c745 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu @@ -108,10 +108,10 @@ inline __device__ int4 add128b(T& a, T& b) } __inline__ __device__ void multi_gpu_barrier( - uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, const int tidx, const int bidx) + uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx) { // At the end of the function, we know that at least block 0 from all other GPUs has reached that point.
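multi_gpu_barrier is a flag-based rendezvous: each rank publishes a monotonically increasing flag into its own signal slot in peer-visible memory, then spins until every peer slot has caught up, so block 0 of every GPU is known to have reached the same point before the all-reduce reads remote buffers. A host-side analogue of that publish-then-spin pattern, using std::atomic and std::thread purely as stand-ins for the CUDA mechanics:

    #include <atomic>
    #include <cstdint>
    #include <thread>
    #include <vector>

    // Each participant writes `flag` into its own slot, then waits until every
    // slot has reached `flag` (the same idea the kernel applies with per-rank
    // signal words and a flag that increases on every invocation).
    void toyBarrier(std::vector<std::atomic<uint32_t>>& signals, size_t rank, uint32_t flag)
    {
        signals[rank].store(flag, std::memory_order_release);
        for (auto& s : signals)
        {
            while (s.load(std::memory_order_acquire) < flag)
            {
                // spin
            }
        }
    }

    int main()
    {
        constexpr size_t worldSize = 4;
        std::vector<std::atomic<uint32_t>> signals(worldSize);
        for (auto& s : signals)
            s.store(0);
        std::vector<std::thread> ranks;
        for (size_t r = 0; r < worldSize; ++r)
            ranks.emplace_back([&signals, r] { toyBarrier(signals, r, /*flag=*/1); });
        for (auto& t : ranks)
            t.join();
        return 0;
    }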
- volatile uint32_t* my_signals = signals[rank]; + uint32_t volatile* my_signals = signals[rank]; if (tidx < world_size) { // The 1st block notifies the other ranks. @@ -139,8 +139,8 @@ __global__ void multiGpuBarrierKernel(AllReduceParams params) template static __global__ void oneShotAllReduceKernel(AllReduceParams params) { - const int bidx = blockIdx.x; - const int tidx = threadIdx.x; + int const bidx = blockIdx.x; + int const tidx = threadIdx.x; // The number of elements packed into one for comms static constexpr int NUM_ELTS = 16 / sizeof(T); @@ -151,7 +151,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); // The source pointers. Distributed round-robin for the different warps. - const T* src_d[RANKS_PER_NODE]; + T const* src_d[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { @@ -172,7 +172,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - vals[ii].packed = *reinterpret_cast(&src_d[ii][iter_offset]); + vals[ii].packed = *reinterpret_cast(&src_d[ii][iter_offset]); } // Sum the values from the different ranks. @@ -194,9 +194,9 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) { // The block index. - const int bidx = blockIdx.x; + int const bidx = blockIdx.x; // The thread index with the block. - const int tidx = threadIdx.x; + int const tidx = threadIdx.x; // The number of elements packed into one for comms static constexpr int NUM_ELTS = 16 / sizeof(T); @@ -233,7 +233,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params) #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - vals[ii].packed = *reinterpret_cast(&src_d[ii][local_offset]); + vals[ii].packed = *reinterpret_cast(&src_d[ii][local_offset]); } // Sum the values from the different ranks. @@ -396,14 +396,14 @@ void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream) multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param); } -AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value) +AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value) { void* const* buffer_ptrs = reinterpret_cast(buffer); AllReduceParams params; // Even plugins use ping buffers, odd plugins use pong. // That way, we don't need to wait for other GPUs to be done // before copying input tensor to workspace. - const auto buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize; + auto const buffer_offset = (flag_value % 2 == 0) ? 
0 : tpSize; for (int i = 0; i < tpSize; ++i) { diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h index ad8c34eae..988f9d8c4 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h @@ -57,7 +57,7 @@ struct AllReduceParams void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; void* local_output_buffer_ptr; - static AllReduceParams deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value); + static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value); }; template diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp index 41d4a0205..a38b80647 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -70,7 +70,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) } bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, - const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) + int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only) { // All tile sizes have a k_tile of 64. @@ -89,7 +89,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, return false; } - const int k_elements_per_split = k / split_k_factor; + int const k_elements_per_split = k / split_k_factor; if ((k_elements_per_split % k_tile) != 0) { return false; @@ -97,9 +97,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } // Check that the workspace has sufficient space for this split-k factor - const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; - const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; - const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; if (required_ws_bytes > workspace_bytes) { @@ -110,7 +110,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, } std::vector get_candidate_tiles( - const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only) + int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only) { enum class CutlassGemmType : char { @@ -170,7 +170,7 @@ std::vector get_candidate_tiles( } std::vector get_candidate_tiles_sm90( - const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only) + int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only) { enum class CutlassGemmType : char { @@ -226,8 +226,8 @@ bool supports_mcast_along_n(const CutlassTileConfigSM90 tile) return valid_tiles.count(tile) == 1; } -std::vector get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only, - const bool int8_configs_only, const int max_split_k, const bool enable_hopper_gmma) +std::vector get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only, + bool const int8_configs_only, int const max_split_k, bool const enable_hopper_gmma) { if (sm == 90 && enable_hopper_gmma) { @@ -235,14 +235,14 @@ std::vector get_candidate_configs(int sm, const bool is_weigh = get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only); std::vector candidate_configs; - for (const auto& tile_config : tiles) + for (auto const& tile_config : tiles) { CutlassGemmConfig config( tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); candidate_configs.push_back(config); - const bool has_m_mcast = supports_mcast_along_m(tile_config); - const bool has_n_mcast = supports_mcast_along_n(tile_config); + bool const has_m_mcast = supports_mcast_along_m(tile_config); + bool const has_n_mcast = supports_mcast_along_n(tile_config); if (has_m_mcast) { CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, @@ -270,9 +270,9 @@ std::vector get_candidate_configs(int sm, const bool is_weigh = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only); std::vector candidate_configs; - const int min_stages = int8_configs_only ? 3 : 2; - const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); - for (const auto& tile_config : tiles) + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 
4 : 2); + for (auto const& tile_config : tiles) { for (int stages = min_stages; stages <= max_stages; ++stages) { @@ -292,9 +292,9 @@ std::vector get_candidate_configs(int sm, const bool is_weigh return candidate_configs; } -CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector& candidate_configs, - const std::vector& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, - const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only) +CutlassGemmConfig estimate_best_config_from_occupancies(std::vector const& candidate_configs, + std::vector const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, + int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only) { if (occupancies.size() != candidate_configs.size()) @@ -311,7 +311,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector= multi_processor_count * 256 ? 1 : split_k_limit; + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; for (int ii = 0; ii < candidate_configs.size(); ++ii) { CutlassGemmConfig candidate_config = candidate_configs[ii]; @@ -330,21 +330,21 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector num_waves_total) && (current_score < config_score + score_slack))) { diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h index 5e2a7ba35..e910e96cb 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h @@ -27,13 +27,13 @@ namespace cutlass_kernels { std::vector get_candidate_configs(int sm, - const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false, - const int max_split_k = 1, const bool enable_hopper_gmma = false); + bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only = false, + int const max_split_k = 1, bool const enable_hopper_gmma = false); tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( - const std::vector& candidate_configs, - const std::vector& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, - const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only); + std::vector const& candidate_configs, + std::vector const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, + int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only); } // namespace cutlass_kernels } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp index 3c9000f9d..2489e7d6f 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp @@ -158,8 +158,8 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 // For int4, each group of 32 rows is permuted using the map below: // 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* 
quantized_tensor, - const std::vector& shape, QuantType quant_type, const int64_t arch_version) +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, const int64_t arch_version) { // We only want to run this step for weight only quant. @@ -170,19 +170,19 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8 const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); - const int K = 16 / BITS_PER_ELT; - const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; - const int ELTS_PER_REG = 32 / BITS_PER_ELT; + int const BITS_PER_ELT = get_bits_in_quant_type(quant_type); + int const K = 16 / BITS_PER_ELT; + int const ELTS_PER_BYTE = 8 / BITS_PER_ELT; + int const ELTS_PER_REG = 32 / BITS_PER_ELT; - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); int MMA_SHAPE_N = 8; int B_ROWS_PER_MMA = 8 * K; - const int elts_in_int32 = 32 / BITS_PER_ELT; + int const elts_in_int32 = 32 / BITS_PER_ELT; - const int num_vec_cols = num_cols / elts_in_int32; + int const num_vec_cols = num_cols / elts_in_int32; TLLM_CHECK_WITH_INFO( arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta."); @@ -205,11 +205,11 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8 for (int write_col = 0; write_col < num_vec_cols; ++write_col) { - const int write_row = base_row + tile_row; - const int tile_read_row + int const write_row = base_row + tile_row; + int const tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); - const int read_row = base_row + tile_read_row; - const int read_col = write_col; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; @@ -227,9 +227,9 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8 // issue for relatively large models. template void subbyte_transpose_impl( - int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector& shape) + int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) { - const int bits_per_elt = get_bits_in_quant_type(quant_type); + int const bits_per_elt = get_bits_in_quant_type(quant_type); TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); const size_t num_experts = shape.size() == 2 ? 
1 : shape[0]; @@ -240,7 +240,7 @@ void subbyte_transpose_impl( const size_t col_bytes_trans = num_rows * bits_per_elt / 8; const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; - const uint8_t* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, ""); @@ -260,8 +260,8 @@ void subbyte_transpose_impl( "num_col_bytes = %ld.", VECTOR_WIDTH, col_bytes_trans, col_bytes)); - const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; - const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; for (size_t expert = 0; expert < num_experts; ++expert) { @@ -271,16 +271,16 @@ void subbyte_transpose_impl( for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) { - const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); - const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start + ii; + int const row = row_tile_start + ii; for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte + jj; + int const col = col_tile_start_byte + jj; const size_t logical_src_offset = matrix_offset + row * col_bytes + col; @@ -313,11 +313,11 @@ void subbyte_transpose_impl( // is square in the number of elements (not necessarily the number of bytes). 
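subbyte_transpose_impl moves INT4 values that are packed two per byte, which is why every access splits an element index into a byte index (i / ELTS_PER_BYTE) and a 4-bit offset (4 * (i % ELTS_PER_BYTE)). A standalone sketch of that packed-int4 addressing (helper names are invented for illustration):

    #include <cstdint>

    // Element i lives in byte i / 2, at bit offset 4 * (i % 2).
    inline uint8_t getInt4(uint8_t const* buf, int i)
    {
        return 0xF & (buf[i / 2] >> (4 * (i % 2)));
    }

    inline void setInt4(uint8_t* buf, int i, uint8_t v)
    {
        int const shift = 4 * (i % 2);
        buf[i / 2] = static_cast<uint8_t>((buf[i / 2] & ~(0xF << shift)) | ((v & 0xF) << shift));
    }

The transpose inner loop is essentially a read-modify-write of such 4-bit pairs mirrored across the tile diagonal.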
for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { - const int ii_byte = ii / ELTS_PER_BYTE; - const int ii_bit_offset = ii % ELTS_PER_BYTE; + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; - const int jj_byte = jj / ELTS_PER_BYTE; - const int jj_bit_offset = jj % ELTS_PER_BYTE; + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); @@ -338,15 +338,15 @@ void subbyte_transpose_impl( const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; - const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); - const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); for (int ii = 0; ii < M_TILE_L1; ++ii) { - const int row = row_tile_start_trans + ii; + int const row = row_tile_start_trans + ii; for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { - const int col = col_tile_start_byte_trans + jj; + int const col = col_tile_start_byte_trans + jj; const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; @@ -364,8 +364,8 @@ void subbyte_transpose_impl( } } -void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type) +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) { if (quant_type == QuantType::INT8_WEIGHT_ONLY) @@ -409,7 +409,7 @@ void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) { - const int num_bytes = num_elts / 2; + int const num_bytes = num_elts / 2; // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little // instructions as possible in the CUDA code. @@ -451,9 +451,9 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { - const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; - const int src_shift = 4 * src_idx; - const int dest_shift = 4 * dest_idx; + int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; const uint32_t src_bits = (current_register >> src_shift) & 0xF; transformed_register |= (src_bits << dest_shift); @@ -478,8 +478,8 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size } } -void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type, LayoutDetails details) +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, LayoutDetails details) { // We only want to run this step for weight only quant. @@ -490,23 +490,23 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const const size_t num_rows = shape.size() == 2 ? 
shape[0] : shape[1]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); - const int elts_in_int32 = 32 / BITS_PER_ELT; + int const BITS_PER_ELT = get_bits_in_quant_type(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; - const int rows_per_tile = details.rows_per_column_tile; + int const rows_per_tile = details.rows_per_column_tile; TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32), fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows)); - const uint32_t* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile), fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows)); - const int num_vec_rows = num_rows / elts_in_int32; - const int vec_rows_per_tile = rows_per_tile / elts_in_int32; - const int interleave = details.columns_interleaved; + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; for (int expert = 0; expert < num_experts; ++expert) { @@ -532,8 +532,8 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const } } -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, - const std::vector& shape, QuantType quant_type, bool force_interleave) +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave) { int arch = getSMVersion(); if (force_interleave && arch == 90) @@ -546,7 +546,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); size_t num_elts = 1; - for (const auto& dim : shape) + for (auto const& dim : shape) { num_elts *= dim; } @@ -620,7 +620,7 @@ Outputs template void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector& shape, QuantType quant_type, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, bool force_interleave) { @@ -633,8 +633,8 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_ const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; - const int bits_in_type = get_bits_in_quant_type(quant_type); - const int bytes_per_out_col = num_cols * bits_in_type / 8; + int const bits_in_type = get_bits_in_quant_type(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; std::vector weight_buf; if (unprocessed_quantized_weight == nullptr) @@ -643,15 +643,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_ unprocessed_quantized_weight = weight_buf.data(); } - const int input_mat_size = num_rows * num_cols; - const int quantized_mat_size = num_rows * bytes_per_out_col; - const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); std::vector per_col_max(num_cols); for (int expert = 0; expert < num_experts; ++expert) { - const WeightType* current_weight = input_weight_ptr + expert * input_mat_size; + WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; // First we find the per column max for this expert weight. @@ -662,7 +662,7 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_ for (int ii = 0; ii < num_rows; ++ii) { - const WeightType* current_weight_row = current_weight + ii * num_cols; + WeightType const* current_weight_row = current_weight + ii * num_cols; for (int jj = 0; jj < num_cols; ++jj) { per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); @@ -681,15 +681,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_ for (int ii = 0; ii < num_rows; ++ii) { int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; - const WeightType* current_weight_row = current_weight + ii * num_cols; + WeightType const* current_weight_row = current_weight + ii * num_cols; for (int jj = 0; jj < bytes_per_out_col; ++jj) { if (quant_type == QuantType::INT8_WEIGHT_ONLY) { - const float col_scale = per_col_max[jj]; - const float weight_elt = float(current_weight_row[jj]); - const float scaled_weight = round(weight_elt / col_scale); + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = round(weight_elt / col_scale); const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); current_quantized_weight_row[jj] = clipped_weight; } @@ -700,12 +700,12 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_ int8_t packed_int4s = 0; for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { - const int input_idx = 2 * jj + packed_idx; + int const input_idx = 2 * jj + packed_idx; if (input_idx < num_cols) { - const float col_scale = per_col_max[input_idx]; - const float weight_elt = float(current_weight_row[input_idx]); - const float scaled_weight = round(weight_elt / col_scale); + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = round(weight_elt / col_scale); int int_weight = int(scaled_weight); const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); @@ -729,47 +729,47 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_ } template void symmetric_quantize( - int8_t*, int8_t*, half*, 
const float*, const std::vector&, QuantType, bool); + int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); template void symmetric_quantize( - int8_t*, int8_t*, half*, const half*, const std::vector&, QuantType, bool); + int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); #ifdef ENABLE_BF16 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector&, QuantType, bool); + int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector&, QuantType, bool); + int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); #endif template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, - const std::vector& shape, QuantType quant_type, bool force_interleave) +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave) { symmetric_quantize( processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); } template void symmetric_quantize( - int8_t*, float*, const float*, const std::vector&, QuantType, bool); + int8_t*, float*, float const*, std::vector const&, QuantType, bool); template void symmetric_quantize( - int8_t*, half*, const float*, const std::vector&, QuantType, bool); + int8_t*, half*, float const*, std::vector const&, QuantType, bool); -template void symmetric_quantize(int8_t*, half*, const half*, const std::vector&, QuantType, bool); +template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); #ifdef ENABLE_BF16 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( - int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector&, QuantType, bool); + int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); template void symmetric_quantize<__nv_bfloat16, half>( - int8_t*, __nv_bfloat16*, const half*, const std::vector&, QuantType, bool); + int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); template void symmetric_quantize( - int8_t*, half*, const __nv_bfloat16*, const std::vector&, QuantType, bool); + int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); template void symmetric_quantize<__nv_bfloat16, float>( - int8_t*, __nv_bfloat16*, const float*, const std::vector&, QuantType, bool); + int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); #endif } // namespace cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h index 1c3678c12..c97bb29ca 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h @@ -38,26 +38,26 @@ int get_bits_in_quant_type(QuantType quant_type); // Shapes here can be 2 or 3D. 
2-D shapes are [num_rows, num_cols] // 3-D shapes are [num_experts, num_rows, num_cols] -void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type, const int64_t arch_version); +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, const int64_t arch_version); -void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, - const std::vector& shape, QuantType quant_type); +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); -void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, - const std::vector& shape, QuantType quant_type, bool force_interleave = false); +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave = false); template -void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, - const std::vector& shape, QuantType quant_type, bool force_interleave); +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave); // This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight // to implement a simple reference implementation. 
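The unprocessed output exists so that tests can compare CUTLASS results against a plain reference. For orientation, the symmetric per-column INT8 path in cutlass_preprocessors.cpp boils down to: take the per-column absolute maximum, derive a scale of col_max / 2^(bits - 1), then round, clip and store. A self-contained reference sketch of just that step (no packing, permutation or interleaving):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // weights: row-major [num_rows x num_cols]. Produces one float scale per
    // column and int8 weights such that weight ~= quantized * scale.
    void referenceSymmetricQuantize(float const* weights, int num_rows, int num_cols,
        std::vector<int8_t>& quantized, std::vector<float>& scales)
    {
        quantized.assign(size_t(num_rows) * num_cols, 0);
        scales.assign(num_cols, 0.f);
        for (int jj = 0; jj < num_cols; ++jj)
        {
            float col_max = 0.f;
            for (int ii = 0; ii < num_rows; ++ii)
                col_max = std::max(col_max, std::abs(weights[size_t(ii) * num_cols + jj]));
            scales[jj] = col_max / 128.f; // quant_range_scale = 1 / 2^(bits - 1) for INT8
            float const col_scale = scales[jj] > 0.f ? scales[jj] : 1.f; // guard all-zero columns (sketch only)
            for (int ii = 0; ii < num_rows; ++ii)
            {
                float const scaled = std::round(weights[size_t(ii) * num_cols + jj] / col_scale);
                quantized[size_t(ii) * num_cols + jj] = int8_t(std::max(-128.f, std::min(127.f, scaled)));
            }
        }
    }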
template void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, - ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector& shape, QuantType quant_type, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, bool force_interleave); } // namespace cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h index 84f9c14c7..a64b908dd 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -58,27 +58,27 @@ class CutlassFpAIntBGemmRunnerInterface virtual ~CutlassFpAIntBGemmRunnerInterface() {} - virtual void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, + virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; - virtual void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, + virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; - virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, - const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; - virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, - const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size, + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; // Returns desired workspace size in bytes. 
- virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; + virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; virtual std::vector getConfigs() const = 0; @@ -96,20 +96,20 @@ class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterfac CutlassFpAIntBGemmRunner(); ~CutlassFpAIntBGemmRunner(); - void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, + void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; - void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k, + void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; - void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, - const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; - void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, - const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size, + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; @@ -120,15 +120,15 @@ class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterfac // stream); // Returns desired workspace size in bytes. 
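Hedged usage sketch of this interface (the concrete runner specialization, its template arguments, the include path and error handling are assumptions, not shown in this hunk): a caller is expected to query candidate tactics with getConfigs(), size the reduction workspace with getWorkspaceSize(), and then hand both to one of the gemm overloads.

    #include <cuda_runtime.h>
    // #include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"  // assumed include path

    // `runner` is assumed to point at some concrete CutlassFpAIntBGemmRunner
    // specialization; A, B, weight_scales and C must match the types it was
    // instantiated with. Taking configs.front() stands in for real profiling.
    void runWeightOnlyGemm(tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface* runner,
        void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, cudaStream_t stream)
    {
        auto const configs = runner->getConfigs();           // candidate tile/stage tactics
        size_t const ws_bytes = runner->getWorkspaceSize(m, n, k);
        char* workspace = nullptr;
        cudaMalloc(reinterpret_cast<void**>(&workspace), ws_bytes);
        runner->gemm(A, B, weight_scales, C, m, n, k, configs.front(), workspace, ws_bytes, stream);
        cudaStreamSynchronize(stream);
        cudaFree(workspace);
    }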
- size_t getWorkspaceSize(const int m, const int n, const int k) override; + size_t getWorkspaceSize(int const m, int const n, int const k) override; std::vector getConfigs() const override; private: template - void dispatch_to_arch(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, - const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, - int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, + void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); private: diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h index c61709826..cbdd62865 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -52,8 +52,8 @@ namespace cutlass_kernels template -void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, - const T* weight_zero_points, const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size, +void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const* weight_scales, + T const* weight_zero_points, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { @@ -127,7 +127,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; - const int ldb = cutlass::platform::is_same::value + int const ldb = cutlass::platform::is_same::value ? n : k * GemmKernel::kInterleave; @@ -171,7 +171,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* } } - const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast(const_cast(A)), k}, {reinterpret_cast(const_cast(B)), ldb}, @@ -230,8 +230,8 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* // quanitzation is only supported on Ampere+ GPUs. 
template -void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, - const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size, +void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points, + T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { @@ -261,8 +261,8 @@ void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_ template -void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, - const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size, +void dispatch_gemm_config(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points, + T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { @@ -300,9 +300,9 @@ constexpr bool is_fp8() template -void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, - const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, - int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, +void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr) { @@ -412,9 +412,9 @@ template template void CutlassFpAIntBGemmRunner::dispatch_to_arch(const ActivationType* A, const WeightType* B, - const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases, - const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, + OutputType>::dispatch_to_arch(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -453,16 +453,16 @@ void CutlassFpAIntBGemmRunner void CutlassFpAIntBGemmRunner::gemm( - const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases, - const float alpha, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) { - dispatch_to_arch((const 
ActivationType*) A, (const WeightType*) B, - (const ScaleZeroType*) weight_scales, (const ScaleZeroType*) weight_zero_points, (const BiasType*) biases, + dispatch_to_arch((ActivationType const*) A, (WeightType const*) B, + (ScaleZeroType const*) weight_scales, (ScaleZeroType const*) weight_zero_points, (BiasType const*) biases, alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); } else @@ -475,8 +475,8 @@ void CutlassFpAIntBGemmRunner void CutlassFpAIntBGemmRunner::gemm( - const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases, - void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -487,15 +487,15 @@ void CutlassFpAIntBGemmRunner void CutlassFpAIntBGemmRunner::gemm( - const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k, + void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) { - dispatch_to_arch((const ActivationType*) A, (const WeightType*) B, - (const ScaleZeroType*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig, + dispatch_to_arch((ActivationType const*) A, (WeightType const*) B, + (ScaleZeroType const*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); } else @@ -507,7 +507,7 @@ void CutlassFpAIntBGemmRunner void CutlassFpAIntBGemmRunner::gemm( - const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, + void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -529,12 +529,12 @@ template size_t CutlassFpAIntBGemmRunner::getWorkspaceSize( - const int m, const int n, const int k) + int const m, int const n, int const k) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); // These are the min tile sizes for each config, which would launch the maximum number of blocks - const int max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); - const int max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); + int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); + int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. 
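The returned size is simply (CTAs along M at the smallest M tile) x (CTAs along N at the smallest N tile) x split-k limit x 4 bytes. A quick standalone check of that arithmetic with made-up tile constants (the real MIN_M_TILE, MIN_N_TILE and SPLIT_K_LIMIT come from the runner class and may differ):

    #include <cstddef>
    #include <cstdio>

    constexpr int ceilDiv(int a, int b)
    {
        return (a + b - 1) / b;
    }

    int main()
    {
        // Hypothetical values, for illustration only.
        constexpr int MIN_M_TILE = 32, MIN_N_TILE = 64, SPLIT_K_LIMIT = 7;
        int const m = 4096, n = 11008;
        int const max_grid_m = ceilDiv(m, MIN_M_TILE); // 128
        int const max_grid_n = ceilDiv(n, MIN_N_TILE); // 172
        size_t const bytes = size_t(max_grid_m) * max_grid_n * SPLIT_K_LIMIT * 4;
        std::printf("worst-case split-k workspace: %zu bytes\n", bytes); // 616448
        return 0;
    }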
return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h index 3d4bb2936..36b9ae8de 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h @@ -44,9 +44,9 @@ namespace cutlass_kernels template -void sm90_dispatch_epilogue_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, - const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, - int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, +void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { @@ -114,9 +114,9 @@ constexpr bool are_tile_shapes_supported() template -void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, - const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, - int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, +void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -153,9 +153,9 @@ void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* template -void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, - const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, - int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, +void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -190,9 +190,9 @@ void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, con template -void sm90_dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, - const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, - int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, +void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + 
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h index 9dc593927..9405287cc 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h @@ -28,9 +28,9 @@ namespace cutlass_kernels template -void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B, - const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases, - const float alpha, OutputType* C, int m, int n, int k, const int group_size, +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl index cb7e49e96..0affcd648 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl @@ -59,9 +59,9 @@ namespace cutlass_kernels template -void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B, - const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases, - const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -233,7 +233,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const Weigh StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1)); // Use the output as the bias to avoid making a tma descriptor with a nullptr. 
- auto output_as_bias_type = reinterpret_cast(C); + auto output_as_bias_type = reinterpret_cast(C); typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1}, {reinterpret_cast(B), stride_B, reinterpret_cast(A), diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h index f06ba4b4d..f3561dc50 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm.h @@ -47,13 +47,13 @@ class CutlassInt8GemmRunnerInterface virtual ~CutlassInt8GemmRunnerInterface() {} - virtual void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, + virtual void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream) = 0; // Returns desired workspace size in bytes. - virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; + virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; virtual std::vector getConfigs() const = 0; @@ -70,18 +70,18 @@ class CutlassInt8GemmRunner : public virtual CutlassInt8GemmRunnerInterface CutlassInt8GemmRunner(); ~CutlassInt8GemmRunner(); - void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, const float* alphaRow, + void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream) override; // Returns desired workspace size in bytes. 
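The int8 runner interface above follows a query-then-launch pattern: the caller asks the runner for a worst-case workspace size, allocates it once, and passes the same scratch buffer into every gemm() call. The sketch below is a simplified, CUDA-free analogue of that pattern only; class names, the placeholder workspace bound, and the dummy launch are illustrative, not the TensorRT-LLM API.

#include <cstddef>
#include <cstdio>
#include <vector>

class GemmRunnerInterface
{
public:
    virtual ~GemmRunnerInterface() = default;
    // Returns the desired workspace size in bytes for the given problem shape.
    virtual std::size_t getWorkspaceSize(int m, int n, int k) = 0;
    // Launches the GEMM using caller-provided scratch space.
    virtual void gemm(void const* A, void const* B, void* C, int m, int n, int k,
        char* workspacePtr, std::size_t workspaceBytes) = 0;
};

class DummyGemmRunner : public GemmRunnerInterface
{
public:
    std::size_t getWorkspaceSize(int m, int n, int k) override
    {
        (void) k;
        return static_cast<std::size_t>(m) * n * sizeof(float); // placeholder bound
    }

    void gemm(void const* A, void const* B, void* C, int m, int n, int k,
        char* workspacePtr, std::size_t workspaceBytes) override
    {
        (void) A; (void) B; (void) C; (void) workspacePtr;
        std::printf("launch %dx%dx%d gemm with %zu workspace bytes\n", m, n, k, workspaceBytes);
    }
};

int main()
{
    DummyGemmRunner runner;
    int const m = 128, n = 256, k = 512;
    std::vector<char> workspace(runner.getWorkspaceSize(m, n, k));
    runner.gemm(nullptr, nullptr, nullptr, m, n, k, workspace.data(), workspace.size());
    return 0;
}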
- size_t getWorkspaceSize(const int m, const int n, const int k) override; + size_t getWorkspaceSize(int const m, int const n, int const k) override; std::vector getConfigs() const override; private: - void dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, + void dispatchToArch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr); int mSm; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h index 6abcfa6a0..8bb5b12db 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h @@ -60,8 +60,8 @@ namespace cutlass_kernels { template -void genericInt8GemmKernelLauncher(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, +void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -174,8 +174,8 @@ void genericInt8GemmKernelLauncher(const int8_t* A, const int8_t* B, tk::QuantMo template struct dispatchStages { - static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + static void dispatch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -188,8 +188,8 @@ struct dispatchStages template struct dispatchStages { - static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + static void dispatch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -202,8 +202,8 @@ template struct dispatchStages 2)>::type> { - static void dispatch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, + static void dispatch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) { @@ -214,8 +214,8 @@ struct dispatchStages -void 
dispatchGemmConfig(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, +void dispatchGemmConfig(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace, size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr) { @@ -255,8 +255,8 @@ void dispatchGemmConfig(const int8_t* A, const int8_t* B, tk::QuantMode quantOpt } template -void dispatchGemmToCutlass(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, T* C, int m, int n, int k, char* workspace, size_t workspaceBytes, +void dispatchGemmToCutlass(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, T* C, int m, int n, int k, char* workspace, size_t workspaceBytes, tkc::CutlassGemmConfig gemmConfig, cudaStream_t stream, int* occupancy = nullptr) { @@ -320,8 +320,8 @@ CutlassInt8GemmRunner::~CutlassInt8GemmRunner() } template -void CutlassInt8GemmRunner::dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, - const float* alphaCol, const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, +void CutlassInt8GemmRunner::dispatchToArch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, + float const* alphaCol, float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -353,8 +353,8 @@ void CutlassInt8GemmRunner::dispatchToArch(const int8_t* A, const int8_t* B, } template -void CutlassInt8GemmRunner::gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, - const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, +void CutlassInt8GemmRunner::gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, + float const* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -373,12 +373,12 @@ std::vector CutlassInt8GemmRunner::getConfigs() const } template -size_t CutlassInt8GemmRunner::getWorkspaceSize(const int m, const int n, const int k) +size_t CutlassInt8GemmRunner::getWorkspaceSize(int const m, int const n, int const k) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); // These are the min tile sizes for each config, which would launch the maximum number of blocks - const int maxGridM = cutlass::ceil_div(m, MIN_M_TILE); - const int maxGridN = cutlass::ceil_div(m, MIN_N_TILE); + int const maxGridM = cutlass::ceil_div(m, MIN_M_TILE); + int const maxGridN = cutlass::ceil_div(m, MIN_N_TILE); // We need 4 bytes per block in the worst case. We launch SPLIT_K_LIMIT in z dim. 
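The workspace bound computed in getWorkspaceSize is a simple counting argument: size the grid from the smallest tile shapes (which launch the most blocks), allow up to SPLIT_K_LIMIT slices in the z dimension, and reserve 4 bytes per block for the split-K reduction bookkeeping. A host-side sketch of that arithmetic; the tile sizes and split-K limit below are illustrative placeholders, not the library's constants.

#include <cstddef>
#include <cstdio>

// Equivalent of cutlass::ceil_div for grid-dimension estimates.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr std::size_t splitk_workspace_bytes(int m, int n, int min_m_tile, int min_n_tile, int split_k_limit)
{
    int const max_grid_m = ceil_div(m, min_m_tile);
    int const max_grid_n = ceil_div(n, min_n_tile);
    // 4 bytes per block in the worst case, times the split-K slices in z.
    return static_cast<std::size_t>(max_grid_m) * max_grid_n * split_k_limit * 4;
}

int main()
{
    // e.g. m=4096, n=11008 with 16x64 minimum tiles and a split-K limit of 7.
    std::printf("%zu bytes\n", splitk_workspace_bytes(4096, 11008, 16, 64, 7));
    return 0;
}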
return static_cast(maxGridM * maxGridN * SPLIT_K_LIMIT * 4); } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h index 3f4def7d7..723143bc9 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -52,23 +52,23 @@ class MoeGemmRunner best_config_ = std::move(best_config); } - void moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + void moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream); - void moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, int64_t* total_rows_before_expert, + void moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); std::vector getConfigs(); private: template - void dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + void dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr); template - void runGemm(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + void runGemm(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h index f628ab6ea..57683c43d 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -47,9 +47,9 @@ namespace tensorrt_llm // ============================= Variable batched Gemm things =========================== template -void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, +void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { #ifdef ENABLE_BF16 @@ -120,15 +120,15 @@ void genericMoeGemmKernelLauncher(const T* A, const WeightType* B, const T* weig } int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); - const int threadblock_count = multi_processor_count * occupancy; + int const threadblock_count = multi_processor_count * occupancy; typename EpilogueOp::Params epilogue_op( ElementAccumulator(1.f), biases ? 
ElementAccumulator(1.f) : ElementAccumulator(0.f)); - const int group_size = gemm_k; + int const group_size = gemm_k; typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, - reinterpret_cast(A), reinterpret_cast(B), - reinterpret_cast(weight_scales), reinterpret_cast(biases), + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), reinterpret_cast(biases), reinterpret_cast(C), total_rows_before_expert, gemm_n, gemm_k); GemmGrouped gemm; @@ -151,7 +151,7 @@ template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -165,7 +165,7 @@ template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -181,7 +181,7 @@ template 2)>::type> { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, + static void dispatch(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -194,7 +194,7 @@ struct dispatch_stages -void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, +void dispatchGemmConfig(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -224,7 +224,7 @@ void dispatchGemmConfig(const T* A, const WeightType* B, const T* weight_scales, // This overload is only enabled when T == WeightType. 
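The dispatchMoeGemmToCutlass overloads are selected with enable_if-style SFINAE on whether the activation type and weight type match, so the unquantized and weight-only-quantized paths never collide at compile time. A simplified analogue of that overload set:

#include <cstdint>
#include <cstdio>
#include <type_traits>

// One overload is enabled only when T == WeightType (unquantized path), the
// other only when they differ (mixed-type, weight-only quantized path).
template <typename T, typename WeightType,
    typename std::enable_if<std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatch_example()
{
    std::puts("same-type path (e.g. fp16 activations x fp16 weights)");
}

template <typename T, typename WeightType,
    typename std::enable_if<!std::is_same<T, WeightType>::value>::type* = nullptr>
void dispatch_example()
{
    std::puts("mixed-type path (e.g. fp16 activations x int8 weights)");
}

int main()
{
    dispatch_example<float, float>();        // picks the same-type overload
    dispatch_example<float, std::int8_t>();  // picks the mixed-type overload
    return 0;
}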
template ::value && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -277,7 +277,7 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s // compile time template ::value && !std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -328,7 +328,7 @@ void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_s // This overload will handle simt gemms. It is disabled via SFINAE for tensorop. template ::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, T const* weight_scales, T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int sm_version, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) @@ -370,8 +370,8 @@ MoeGemmRunner::MoeGemmRunner() template template -void MoeGemmRunner::dispatchToArch(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, +void MoeGemmRunner::dispatchToArch(T const* A, WeightType const* B, T const* weight_scales, + T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy) { if (sm_ >= 70 && sm_ < 75) @@ -407,8 +407,8 @@ void MoeGemmRunner::dispatchToArch(const T* A, const template template -void MoeGemmRunner::runGemm(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, +void MoeGemmRunner::runGemm(T const* A, WeightType const* B, T const* weight_scales, + T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) { auto chosen_conf = this->best_config_; @@ -437,8 +437,8 @@ void MoeGemmRunner::runGemm(const T* A, const Weight } template -void MoeGemmRunner::moeGemmBiasAct(const T* A, const WeightType* B, const T* weight_scales, - const T* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, +void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, T const* weight_scales, + T const* biases, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, 
ActivationType activation_type, cudaStream_t stream) { switch (activation_type) @@ -465,7 +465,7 @@ void MoeGemmRunner::moeGemmBiasAct(const T* A, const WeightType* } template -void MoeGemmRunner::moeGemm(const T* A, const WeightType* B, const T* weight_scales, T* C, +void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, T const* weight_scales, T* C, int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, cudaStream_t stream) { diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu index ecc455b26..73637028b 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.cu @@ -29,8 +29,8 @@ namespace mmha // Forward declaration of the kernel launcher to avoid including decoderMaskedMultiheadAttentionLaunch.h template -void mmha_launch_kernel(const T_PARAMS& params, const KVCacheBuffer& kv_cache_buffer, - const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream); +void mmha_launch_kernel(const T_PARAMS& params, KVCacheBuffer const& kv_cache_buffer, + KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream); } // namespace mmha @@ -56,11 +56,11 @@ namespace break; template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const KVCacheBuffer& kv_cache_buffer, - const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) +void multihead_attention_(const KERNEL_PARAMS_TYPE& params, KVCacheBuffer const& kv_cache_buffer, + KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream) { - const bool has_implicit_rel_attn_bias = params.max_distance > 0 && params.relative_attention_bias != nullptr; - const int head_size = params.hidden_size_per_head; + bool const has_implicit_rel_attn_bias = params.max_distance > 0 && params.relative_attention_bias != nullptr; + int const head_size = params.hidden_size_per_head; TLLM_CHECK_WITH_INFO(!has_implicit_rel_attn_bias || head_size == 32 || head_size == 64 || head_size == 128, "MMHA kernels haven't instantiate implicit_relative_attention_bias paths for head size %d.", head_size); switch (params.hidden_size_per_head) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h index be8752afe..240aa59f0 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h @@ -81,19 +81,19 @@ struct Multihead_attention_params_base T* out = nullptr; // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; + T const *q = nullptr, *q_bias = nullptr; // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; + T const *k = nullptr, *k_bias = nullptr; // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; + T const *v = nullptr, *v_bias = nullptr; // The indirections to use for cache when beam sampling. 
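A minimal sketch of how the beam-search cache indirection is consumed, assuming the layout implied by the kernel further down: cache_indir holds, per (batch x beam) row and timestep, the beam whose KV entry should be read, and the KV cache row is then batch_idx * beam_width + that offset. Shapes and values here are illustrative.

#include <cstdio>
#include <vector>

int main()
{
    int const batch_size = 1, beam_width = 2, max_seq_len = 4;
    // cache_indir[(b * beam_width + beam) * max_seq_len + t] = source beam for timestep t.
    std::vector<int> cache_indir(batch_size * beam_width * max_seq_len, 0);

    // Suppose beam 1 forked from beam 0 at t = 2: timesteps 0..1 still live in
    // beam 0's cache rows, timesteps 2..3 in beam 1's own rows.
    int const beam = 1;
    for (int t = 2; t < max_seq_len; ++t)
        cache_indir[beam * max_seq_len + t] = 1;

    int const batch_idx = 0;
    for (int t = 0; t < max_seq_len; ++t)
    {
        int const beam_offset = cache_indir[beam * max_seq_len + t];
        int const seq_idx = batch_idx * beam_width + beam_offset; // KV cache row to read
        std::printf("timestep %d -> read KV from sequence row %d\n", t, seq_idx);
    }
    return 0;
}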
- const int* cache_indir = nullptr; + int const* cache_indir = nullptr; // scales - const float* query_weight_output_scale = nullptr; - const float* attention_qk_scale = nullptr; - const float* attention_output_weight_input_scale_inv = nullptr; + float const* query_weight_output_scale = nullptr; + float const* attention_qk_scale = nullptr; + float const* attention_output_weight_input_scale_inv = nullptr; // Stride to handle the case when KQV is a single buffer int stride = 0; @@ -134,22 +134,22 @@ struct Multihead_attention_params_base float inv_sqrt_dh = 0.0f; // If relative position embedding is used - const T* relative_attention_bias = nullptr; + T const* relative_attention_bias = nullptr; int relative_attention_bias_stride = 0; int max_distance = 0; // The slope per head of linear position bias to attention score (H). - const T* linear_bias_slopes = nullptr; + T const* linear_bias_slopes = nullptr; - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; + T const* ia3_key_weights = nullptr; + T const* ia3_value_weights = nullptr; + int const* ia3_tasks = nullptr; - const float* qkv_scale_quant_orig = nullptr; - const float* attention_out_scale_orig_quant = nullptr; + float const* qkv_scale_quant_orig = nullptr; + float const* attention_out_scale_orig_quant = nullptr; - const float* kv_scale_orig_quant = nullptr; - const float* kv_scale_quant_orig = nullptr; + float const* kv_scale_orig_quant = nullptr; + float const* kv_scale_quant_orig = nullptr; bool int8_kv_cache = false; bool fp8_kv_cache = false; @@ -176,7 +176,7 @@ struct Multihead_attention_params_base // threadblock counter to identify the complete of partial attention computations int* block_counter = nullptr; - const int* memory_length_per_sample = nullptr; + int const* memory_length_per_sample = nullptr; }; template @@ -194,10 +194,10 @@ struct Multihead_attention_params : public Multihead_attention_params_ bool* finished = nullptr; // required in case of masked attention with different length - const int* length_per_sample = nullptr; + int const* length_per_sample = nullptr; // input lengths to identify the paddings (i.e. input seq < padding < new generated seq). - const int* input_lengths = nullptr; + int const* input_lengths = nullptr; }; template using Masked_multihead_attention_params = Multihead_attention_params; @@ -214,10 +214,10 @@ struct Multihead_attention_params : public Multihead_attention_params_b bool* finished = nullptr; // required in case of masked attention with different length - const int* length_per_sample = nullptr; + int const* length_per_sample = nullptr; // input lengths to identify the paddings (i.e. input seq < padding < new generated seq). 
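The paired kv_scale_orig_quant / kv_scale_quant_orig pointers follow the usual convention suggested by their names: one scale converts original-precision values into the 8-bit cache, the other converts cached values back. A host-side sketch of that round trip under that assumption; the calibration value is made up.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float const amax = 12.5f;                       // calibrated max |value| of K/V
    float const kv_scale_orig_quant = 127.f / amax; // applied when writing the cache
    float const kv_scale_quant_orig = amax / 127.f; // applied when reading the cache

    float const v = -3.7f;
    auto const q = static_cast<std::int8_t>(
        std::lround(std::clamp(v * kv_scale_orig_quant, -127.f, 127.f)));
    float const v_back = static_cast<float>(q) * kv_scale_quant_orig;

    std::printf("original %f -> int8 %d -> dequantized %f\n", v, static_cast<int>(q), v_back);
    return 0;
}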
- const int* input_lengths = nullptr; + int const* input_lengths = nullptr; }; template using Cross_multihead_attention_params = Multihead_attention_params; @@ -248,9 +248,9 @@ DECLARE_MMHA_NORMAL_AND_PAGED(__nv_bfloat16); template inline int estimate_min_multi_block_count(int max_timesteps, int max_dynamic_shmem_per_block) { - const auto qk_elts = static_cast((max_timesteps + 1 + 4 - 1) / 4); + auto const qk_elts = static_cast((max_timesteps + 1 + 4 - 1) / 4); int size_per_elts = 16; - const auto qk_sz = qk_elts * 16; + auto const qk_sz = qk_elts * 16; size_t logits_sz = 0; #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS if (sizeof(T) != 4) diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h index a75c595a1..e549b0718 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionLaunch.h @@ -41,15 +41,15 @@ namespace mmha { template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, int threads_per_block) +inline size_t smem_size_in_bytes(Multihead_attention_params const& params, int threads_per_block) { using Tk = typename kernel_type_t::Type; // The amount of shared memory needed to store the Q*K^T values in float. - const int max_timesteps = DO_CROSS_ATTENTION + int const max_timesteps = DO_CROSS_ATTENTION ? params.cyclic_attention_window_size : min((DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep), params.cyclic_attention_window_size); - const auto qk_elts = static_cast(divUp(max_timesteps + 1, 4)); // explicit cast because of the sign - const auto qk_sz = qk_elts * 16; + auto const qk_elts = static_cast(divUp(max_timesteps + 1, 4)); // explicit cast because of the sign + auto const qk_sz = qk_elts * 16; // The extra memory needed if we are not using floats for the final logits. size_t logits_sz = 0; @@ -92,7 +92,7 @@ inline size_t smem_size_in_bytes(const Multihead_attention_params -inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params& params, +inline void multi_block_grid_setup(dim3& grid, Multihead_attention_params const& params, int blocks_per_sm, int block_size, int tlength) { if (!params.multi_block_mode) @@ -103,17 +103,17 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params< int balanced_seq_len_tile = mmha::divUp(params.multi_processor_count * blocks_per_sm, params.batch_size * params.num_heads); - const int threads_per_value = mmha::threads_per_value(mmha::dh_max(Dh)); + int const threads_per_value = mmha::threads_per_value(mmha::dh_max(Dh)); // Make sure that each block at least processes one loop of kv (unroll size is default at 8). - const int seq_len_per_kv_loop = mmha::divUp(block_size, threads_per_value) * 8; + int const seq_len_per_kv_loop = mmha::divUp(block_size, threads_per_value) * 8; int max_seq_len_tile = params.max_seq_len_tile; - const bool multi_block_debug_flag = getEnvMmhaMultiblockDebug(); + bool const multi_block_debug_flag = getEnvMmhaMultiblockDebug(); // User defined number of blocks. if (multi_block_debug_flag) { - const int env_seq_len_tile = getEnvMmhaBlocksPerSequence(); + int const env_seq_len_tile = getEnvMmhaBlocksPerSequence(); balanced_seq_len_tile = env_seq_len_tile > 0 ? 
env_seq_len_tile : balanced_seq_len_tile; } else @@ -211,12 +211,12 @@ inline void multi_block_grid_setup(dim3& grid, const Multihead_attention_params< template -void mmha_launch_kernel_ex(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, - const KCacheBuffer& k_cache_buffer, const cudaStream_t& stream, int tlength) +void mmha_launch_kernel_ex(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, + KCacheBuffer const& k_cache_buffer, cudaStream_t const& stream, int tlength) { dim3 grid{static_cast(params.num_heads), static_cast(params.batch_size), 1}; - const int kernel_total_blocks = params.batch_size * params.num_heads; + int const kernel_total_blocks = params.batch_size * params.num_heads; // Don't tune the block size if batchxhead is large enough. // The max number of warps we can launch per SM is 32 limited by registers. if (kernel_total_blocks >= params.multi_processor_count * 4) @@ -296,8 +296,8 @@ void mmha_launch_kernel_ex(const KernelParamsType& params, const KVCacheBuffer& template -void mmha_launch_kernel_dispatch_pos_shift(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, - const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream, int tlength) +void mmha_launch_kernel_dispatch_pos_shift(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, + KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream, int tlength) { if (params.position_shift_enabled && !KernelParamsType::DO_CROSS_ATTENTION) { @@ -315,8 +315,8 @@ void mmha_launch_kernel_dispatch_pos_shift(const KernelParamsType& params, const template -void mmha_launch_kernel_dispatch_8bits_kv_cache(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, - const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream, int tlength) +void mmha_launch_kernel_dispatch_8bits_kv_cache(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, + KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream, int tlength) { if (params.int8_kv_cache) { @@ -339,8 +339,8 @@ void mmha_launch_kernel_dispatch_8bits_kv_cache(const KernelParamsType& params, template -void mmha_launch_kernel_dispatch(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, - const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) +void mmha_launch_kernel_dispatch(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, + KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream) { int const tlength = params.timestep; if (params.multi_block_mode) @@ -356,8 +356,8 @@ void mmha_launch_kernel_dispatch(const KernelParamsType& params, const KVCacheBu } template -void mmha_launch_kernel(const KernelParamsType& params, const KVCacheBuffer& kv_cache_buffer, - const KVLinearBuffer& shift_k_cache, const cudaStream_t& stream) +void mmha_launch_kernel(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer, + KVLinearBuffer const& shift_k_cache, cudaStream_t const& stream) { assert((params.rotary_embedding_dim != 0) == (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h index 9e86f6f7a..0b9f29c59 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h +++ 
b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h @@ -652,14 +652,14 @@ struct V_vec_accum_fp32_ //////////////////////////////////////////////////////////////////////////////////////////////////// template -__inline__ __device__ constexpr Tout vec_conversion(const Tin& x) +__inline__ __device__ constexpr Tout vec_conversion(Tin const& x) { static_assert(std::is_same::value, "Type mismatch"); return x; } template <> -__inline__ __device__ Float8_ vec_conversion(const uint4& a) +__inline__ __device__ Float8_ vec_conversion(uint4 const& a) { Float8_ fc; fc.x = half2_to_float2(a.x); @@ -671,7 +671,7 @@ __inline__ __device__ Float8_ vec_conversion(const uint4& a) #ifdef ENABLE_BF16 template <> -__inline__ __device__ Float8_ vec_conversion(const bf16_8_t& a) +__inline__ __device__ Float8_ vec_conversion(bf16_8_t const& a) { Float8_ fc; fc.x = bf1622float2(a.x); @@ -685,39 +685,39 @@ __inline__ __device__ Float8_ vec_conversion(const bf16_8_t& #ifdef ENABLE_FP8 // fp8_t template <> -__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) +__inline__ __device__ float vec_conversion(__nv_fp8_e4m3 const& a) { return float(a); } template <> -__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) +__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(float const& a) { return __nv_fp8_e4m3(a); } // fp8_2_t template <> -__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) +__inline__ __device__ float2 vec_conversion(fp8_2_t const& a) { return float2(a); } template <> -__inline__ __device__ fp8_2_t vec_conversion(const float2& a) +__inline__ __device__ fp8_2_t vec_conversion(float2 const& a) { return fp8_2_t(a); } // fp8_4_t template <> -__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) +__inline__ __device__ float4 vec_conversion(fp8_4_t const& a) { return float4(a); } template <> -__inline__ __device__ fp8_4_t vec_conversion(const float4& a) +__inline__ __device__ fp8_4_t vec_conversion(float4 const& a) { return fp8_4_t(a); } @@ -752,7 +752,7 @@ inline __device__ float qk_dot_(const Q_vec (&q)[N], const K_vec (&k)[N]) } template -inline __device__ float qk_scale_dot_(const Q_vec (&q)[N], const K_vec (&k)[N], const float k_scale) +inline __device__ float qk_scale_dot_(const Q_vec (&q)[N], const K_vec (&k)[N], float const k_scale) { #ifdef MMHA_USE_FP32_ACCUM_FOR_FMA using K_vec_accum = typename K_vec_accum_fp32_::Type; @@ -791,7 +791,7 @@ struct Qk_dot } template - static inline __device__ float scale_dot(const Q_vec (&q)[N], const K_vec (&k)[N], const float k_scale) + static inline __device__ float scale_dot(const Q_vec (&q)[N], const K_vec (&k)[N], float const k_scale) { #ifdef MMHA_USE_HMMA static_assert("HMMA doesn't support k scales"); @@ -800,7 +800,7 @@ struct Qk_dot } template - static inline __device__ bool is_leader(const int tidx) + static inline __device__ bool is_leader(int const tidx) { return (tidx % THREADS_PER_KEY) == 0; } @@ -809,14 +809,14 @@ struct Qk_dot //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void hmma_fp32(float4& c, const K_vec& a, K_vec b) +inline __device__ void hmma_fp32(float4& c, K_vec const& a, K_vec b) { // Not supported. 
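The vec_conversion helpers touched in this hunk follow a common pattern: the primary template only permits identity conversions (enforced with a static_assert), and explicit specializations supply each real widening or narrowing path. A simplified, host-compilable analogue; the struct types below are stand-ins, not the CUDA vector types.

#include <cstdio>
#include <type_traits>

struct float2_t  { float x, y; };
struct double2_t { double x, y; };

// Primary template: only the identity conversion is allowed.
template <typename Tout, typename Tin>
constexpr Tout vec_conversion(Tin const& x)
{
    static_assert(std::is_same<Tin, Tout>::value, "Type mismatch");
    return x;
}

// Explicit specialization: a genuine widening conversion.
template <>
double2_t vec_conversion<double2_t, float2_t>(float2_t const& a)
{
    return {static_cast<double>(a.x), static_cast<double>(a.y)};
}

int main()
{
    float2_t f{1.5f, -2.25f};
    double2_t d = vec_conversion<double2_t, float2_t>(f);
    std::printf("%f %f\n", d.x, d.y);
    return 0;
}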
assert(false); } template <> -inline __device__ void hmma_fp32(float4& c, const uint32_t& a, uint32_t b) +inline __device__ void hmma_fp32(float4& c, uint32_t const& a, uint32_t b) { asm volatile( "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" @@ -829,14 +829,14 @@ inline __device__ void hmma_fp32(float4& c, const uint32_t& a, uint32_t b) } template <> -inline __device__ void hmma_fp32(float4& c, const uint2& a, uint2 b) +inline __device__ void hmma_fp32(float4& c, uint2 const& a, uint2 b) { hmma_fp32(c, a.x, b.x); hmma_fp32(c, a.y, b.y); } template <> -inline __device__ void hmma_fp32(float4& c, const uint4& a, uint4 b) +inline __device__ void hmma_fp32(float4& c, uint4 const& a, uint4 b) { hmma_fp32(c, a.x, b.x); hmma_fp32(c, a.y, b.y); @@ -918,7 +918,7 @@ struct Qk_dot } template - static inline __device__ float scale_dot(const Q_vec (&q)[N], const K_vec (&k)[N], const float k_scale) + static inline __device__ float scale_dot(const Q_vec (&q)[N], const K_vec (&k)[N], float const k_scale) { #ifdef MMHA_USE_HMMA static_assert("HMMA doesn't support k scales"); @@ -927,7 +927,7 @@ struct Qk_dot } template - static inline __device__ bool is_leader(const int tidx) + static inline __device__ bool is_leader(int const tidx) { // Use HMMA.FP32, leader threads are in the diagonal roughly (0, 4, 9, 13, 18, 22, 27, 31). #if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA) @@ -943,7 +943,7 @@ struct Qk_dot leader = int(lane / THREADS_PER_KEY) * int(THREADS_PER_KEY / 8); } #else - const bool leader = 0; + bool const leader = 0; #endif // defined MMHA_USE_HMMA return (tidx % THREADS_PER_KEY) == leader; } @@ -953,7 +953,7 @@ struct Qk_dot template inline __device__ void Logit_value_fma( - V_vec_accum& out, const Tk* logits_smem, const V_vec_m& v_vec, const float v_scale, const bool is_mask) + V_vec_accum& out, Tk const* logits_smem, V_vec_m const& v_vec, float const v_scale, bool const is_mask) { #if defined(MMHA_USE_FP32_ACCUM_FOR_LOGITS) float logit = is_mask ? 0.f : reinterpret_cast(logits_smem)[0]; @@ -1331,12 +1331,12 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Note that the maximum sequence length supported by the model might be greater than this. // Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers. // By default, you can assume that they are the same. - const auto cyclic_kv_cache_len = static_cast(params.cyclic_attention_window_size); + auto const cyclic_kv_cache_len = static_cast(params.cyclic_attention_window_size); // The number of sink tokens in kv cache to support streamingllm - const auto sink_token_len = static_cast(params.sink_token_length); + auto const sink_token_len = static_cast(params.sink_token_length); // The current timestep (including paddings). // It is only used to calculate the smem stride. - const auto timestep = static_cast(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep); + auto const timestep = static_cast(DO_MULTI_BLOCK ? params.timesteps_per_block : params.timestep); #ifdef ENABLE_MULTI_BLOCK_OPTION constexpr bool MULTI_BLOCK_FLAG = DO_MULTI_BLOCK; @@ -1357,7 +1357,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS if (sizeof(Tk) != 4) { - const auto max_timesteps = DO_CROSS_ATTENTION ? cyclic_kv_cache_len : min(timestep, cyclic_kv_cache_len); + auto const max_timesteps = DO_CROSS_ATTENTION ? 
cyclic_kv_cache_len : min(timestep, cyclic_kv_cache_len); logits_smem_ += divUp(max_timesteps + 1, 4u) * 16; } Tk* logits_smem = reinterpret_cast(logits_smem_); @@ -1431,27 +1431,27 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske static_assert(THREADS_PER_BLOCK >= QK_VECS_PER_Dh_MAX); // The batch/beam idx - const auto batch_beam_idx = blockIdx.y; + auto const batch_beam_idx = blockIdx.y; if (params.finished != nullptr && params.finished[batch_beam_idx]) { return; } // The head. - const unsigned hi{blockIdx.x}; + unsigned const hi{blockIdx.x}; // The head index of keys and values adjusted for MQA/GQA. - const int qhead_per_kv{params.num_heads / params.num_kv_heads}; - const unsigned hi_kv{hi / qhead_per_kv}; + int const qhead_per_kv{params.num_heads / params.num_kv_heads}; + unsigned const hi_kv{hi / qhead_per_kv}; // The number of heads. - const auto num_heads = static_cast(params.num_heads); + auto const num_heads = static_cast(params.num_heads); // The number of heads for keys and values adjusted for MQA/GQA. - const auto num_heads_kv = static_cast(params.num_kv_heads); + auto const num_heads_kv = static_cast(params.num_kv_heads); // The thread in the block. - const unsigned tidx{threadIdx.x}; + unsigned const tidx{threadIdx.x}; // The column tile along L dimension on K^T -- noted as T_c in flash-attention paper - const unsigned c_tile{MULTI_BLOCK_FLAG ? blockIdx.z : 0}; + unsigned const c_tile{MULTI_BLOCK_FLAG ? blockIdx.z : 0}; // Indicate if we need to compute the K/V cache element (add KV bias, IA3, RoPE, etc.) and update the cache. // For Self-Attention, it's always required. @@ -1477,40 +1477,40 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The actual sequence length excluding the paddings. // minus 1 because it includes the current timestep while tlength denotes the kv cache length. - const int tlength = DO_CROSS_ATTENTION + int const tlength = DO_CROSS_ATTENTION ? params.memory_length_per_sample[batch_beam_idx] - 1 : (params.length_per_sample ? (params.length_per_sample[batch_beam_idx] - 1) : static_cast(timestep)); // We will use cyclic kv cache when it exceeds the limit. // The length position for storing new key and value. - const int cyclic_tlength = kvCacheBuffer.getKVTokenIdx(tlength); + int const cyclic_tlength = kvCacheBuffer.getKVTokenIdx(tlength); // When enable cyclic kv cache and one more block mode, we need to shift the index to the actual index in the // sequence. Otherwise, if the token is not the sink token, we need to add the bubblen length to the index. - const bool enable_use_seq_idx_kv = kvCacheBuffer.mEnableOneMoreBlock && tlength > cyclic_kv_cache_len; - const int shift_for_cyclic_kv = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : kvCacheBuffer.mBubbleLen; - const int shift_for_cyclic_k = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : pastKCache.mBubbleLen; + bool const enable_use_seq_idx_kv = kvCacheBuffer.mEnableOneMoreBlock && tlength > cyclic_kv_cache_len; + int const shift_for_cyclic_kv = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : kvCacheBuffer.mBubbleLen; + int const shift_for_cyclic_k = (enable_use_seq_idx_kv) ? tlength - cyclic_kv_cache_len : pastKCache.mBubbleLen; // The actual kv cache length. // tlength is the past length actually. - const int kv_loop_length = min(tlength, cyclic_kv_cache_len); + int const kv_loop_length = min(tlength, cyclic_kv_cache_len); // The context length for beam searching optimization (all points to beam 0). 
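The MQA/GQA head mapping used above is a simple integer division: query head hi reads KV head hi / (num_heads / num_kv_heads). A tiny sketch with illustrative head counts:

#include <cstdio>

int main()
{
    int const num_heads = 32, num_kv_heads = 8;        // GQA with groups of 4
    int const qhead_per_kv = num_heads / num_kv_heads; // 4 query heads share one KV head
    for (int hi = 0; hi < num_heads; hi += qhead_per_kv)
    {
        int const hi_kv = hi / qhead_per_kv;
        std::printf("query heads %d..%d -> kv head %d\n", hi, hi + qhead_per_kv - 1, hi_kv);
    }
    return 0;
}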
// TODO: with cyclic kv cache, we set it 0 for now (will optimize in the future) // as context kv cache might be overwritten by the new kv cache - const int beam0_context_length + int const beam0_context_length = HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx]; // The position of the current timestep, and it is used to apply the position embedding - const int current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length; + int const current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length; // The offset in the Q and K buffer also accounts for the batch. - const auto qk_vec_idx = tidx * QK_VEC_SIZE; - const auto is_valid_qk_vec = qk_vec_idx < Dh; + auto const qk_vec_idx = tidx * QK_VEC_SIZE; + auto const is_valid_qk_vec = qk_vec_idx < Dh; - const bool load_qkv_quant = params.qkv_scale_quant_orig != nullptr; - const bool write_attention_quant = params.attention_out_scale_orig_quant != nullptr; + bool const load_qkv_quant = params.qkv_scale_quant_orig != nullptr; + bool const write_attention_quant = params.attention_out_scale_orig_quant != nullptr; // Quant/Dequant scales for 8bits kv cache. using T_scale = typename kv_cache_scale_type_t::Type; T_scale kv_scale_orig_quant, k_scale_quant_orig; - const float k_scale_quant_orig_f = (ENABLE_8BITS_K_CACHE ? params.kv_scale_quant_orig[0] : 1.0f); - const float kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f); + float const k_scale_quant_orig_f = (ENABLE_8BITS_K_CACHE ? params.kv_scale_quant_orig[0] : 1.0f); + float const kv_scale_quant_orig_f = (ENABLE_8BITS_KV_CACHE ? params.kv_scale_quant_orig[0] : 1.0f); convert_from_float(&k_scale_quant_orig, k_scale_quant_orig_f); convert_from_float(&kv_scale_orig_quant, (ENABLE_8BITS_KV_CACHE ? params.kv_scale_orig_quant[0] : 1.0f)); @@ -1535,29 +1535,29 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The stride between tokens. We may be able to always use params.stride. uint32_t q_stride = params.stride ? static_cast(params.stride) : (num_heads * Dh); // The offset. 
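A sketch of the strided flat index used for the Q/K loads that follow, assuming flat_index_strided3(b, h, e, stride, dh) computes b * stride + h * dh + e; the packed-QKV stride below is only an illustration of why the stride can exceed num_heads * Dh when Q, K and V share one buffer.

#include <cstdio>

constexpr int flat_index_strided3(int b, int h, int e, int stride, int dh)
{
    return b * stride + h * dh + e; // assumed layout: [batch][head][element] with a row stride
}

int main()
{
    int const dh = 128, num_heads = 32, num_kv_heads = 8;
    int const stride = (num_heads + 2 * num_kv_heads) * dh; // one token's packed QKV row
    int const q_offset = flat_index_strided3(/*batch_beam_idx=*/3, /*hi=*/5, /*qk_vec_idx=*/16, stride, dh);
    std::printf("q_offset = %d\n", q_offset);
    return 0;
}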
- const auto q_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi, qk_vec_idx, q_stride, Dh); + auto const q_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi, qk_vec_idx, q_stride, Dh); if (load_qkv_quant) { using Packed_Int8_t = typename packed_type::value>::type; using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_quant_orig[0]; - const auto q_quant - = *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); + auto const q_scaling = params.qkv_scale_quant_orig[0]; + auto const q_quant + = *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); convert_from_float(&q, mul(q_scaling, float_from_int8(q_quant))); } else { - q = vec_conversion(*reinterpret_cast(¶ms.q[q_offset])); + q = vec_conversion(*reinterpret_cast(¶ms.q[q_offset])); } if constexpr (DO_CROSS_ATTENTION) { - const auto k_idx = QK_VEC_SIZE * tidx; - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi, Dh, k_idx); + auto const k_idx = QK_VEC_SIZE * tidx; + int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi, Dh, k_idx); Tcache* k_cache = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength)); - k = vec_conversion(*reinterpret_cast(&k_cache[inBlockIdx])); + k = vec_conversion(*reinterpret_cast(&k_cache[inBlockIdx])); } else { @@ -1565,36 +1565,36 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The stride between tokens. We may be able to always use params.stride. uint32_t k_stride = params.stride ? static_cast(params.stride) : (num_heads_kv * Dh); // The offset. - const auto k_offset + auto const k_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi_kv, qk_vec_idx, k_stride, Dh); if (load_qkv_quant) { using Packed_Int8_t = typename packed_type::value>::type; using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_quant_orig[1]; - const auto k_quant - = *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); + auto const k_scaling = params.qkv_scale_quant_orig[1]; + auto const k_quant + = *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); convert_from_float(&k, mul(k_scaling, float_from_int8(k_quant))); } else { - k = vec_conversion(*reinterpret_cast(¶ms.k[k_offset])); + k = vec_conversion(*reinterpret_cast(¶ms.k[k_offset])); } } if (params.q_bias != nullptr) { - const auto q_bias_offset = tensorrt_llm::common::flat_index2(hi, qk_vec_idx, Dh); + auto const q_bias_offset = tensorrt_llm::common::flat_index2(hi, qk_vec_idx, Dh); q_bias - = vec_conversion(*reinterpret_cast(¶ms.q_bias[q_bias_offset])); + = vec_conversion(*reinterpret_cast(¶ms.q_bias[q_bias_offset])); } if (HANDLE_KV && params.k_bias != nullptr) { - const auto k_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, qk_vec_idx, Dh); + auto const k_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, qk_vec_idx, Dh); k_bias - = vec_conversion(*reinterpret_cast(¶ms.k_bias[k_bias_offset])); + = vec_conversion(*reinterpret_cast(¶ms.k_bias[k_bias_offset])); } } @@ -1606,20 +1606,20 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } // The width of the beam. - const auto beam_width = static_cast(params.beam_width); + auto const beam_width = static_cast(params.beam_width); // The batch idx. - const int batch_idx = batch_beam_idx / beam_width; + int const batch_idx = batch_beam_idx / beam_width; // Do we apply IA3? 
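The IA3 path applied just below rescales the key vector by a learned multiplicative weight per (task, head, channel), with the task id looked up per batch entry. A host-side sketch under the assumption that flat_index2(i, j, n) == i * n + j; shapes and values are illustrative.

#include <cstdio>
#include <vector>

int main()
{
    int const num_tasks = 2, num_heads = 4, dh = 8;
    std::vector<float> ia3_key_weights(num_tasks * num_heads * dh, 1.0f);
    std::vector<int> ia3_tasks = {1};     // task id for batch index 0
    std::vector<float> k(dh, 0.5f);       // one head's key vector

    int const batch_idx = 0, hi = 2;
    int const ia3_ti_hi = ia3_tasks[batch_idx] * num_heads + hi; // flat_index2(task, hi, num_heads)
    for (int e = 0; e < dh; ++e)
        k[e] *= ia3_key_weights[ia3_ti_hi * dh + e];             // flat_index2(ia3_ti_hi, e, Dh)

    std::printf("k[0] after IA3 = %f\n", k[0]);
    return 0;
}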
- const bool do_ia3 = HANDLE_KV && params.ia3_tasks != nullptr; + bool const do_ia3 = HANDLE_KV && params.ia3_tasks != nullptr; // Compute the IA3 task. One per batch index. - const auto ia3_ti_hi = do_ia3 + auto const ia3_ti_hi = do_ia3 ? tensorrt_llm::common::flat_index2(static_cast(params.ia3_tasks[batch_idx]), hi, num_heads) : 0; if (do_ia3 && is_valid_qk_vec) { k = mul(k, - vec_conversion(*reinterpret_cast( + vec_conversion(*reinterpret_cast( ¶ms.ia3_key_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, qk_vec_idx, Dh)]))); } k_wo_pos = k; @@ -1650,15 +1650,15 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } case PositionEmbeddingType::kROPE_GPT_NEOX: { - const bool do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + bool const do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; T* q_smem_ = reinterpret_cast(smem_); T* k_smem_ = q_smem_ + params.rotary_embedding_dim; - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = qk_vec_idx / half_rotary_dim; - const int intra_half_idx = qk_vec_idx % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + int const half_rotary_dim = params.rotary_embedding_dim / 2; + int const half_idx = qk_vec_idx / half_rotary_dim; + int const intra_half_idx = qk_vec_idx % half_rotary_dim; + int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts assert(half_rotary_dim % QK_VEC_SIZE == 0); @@ -1673,7 +1673,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske __syncthreads(); - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + int const transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; if (do_rotary) { @@ -1770,8 +1770,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } // Pre-compute the pointer for the relative attention bias. - const T* relative_attention_bias_ptr = nullptr; - const T* relative_attention_bias_ptr_fixed = nullptr; // record the base for offset + T const* relative_attention_bias_ptr = nullptr; + T const* relative_attention_bias_ptr_fixed = nullptr; // record the base for offset if (has_relative_attention_bias) { // "hi" is unsigned, subtracting int from unsigned int causes underflow. Cast to int @@ -1819,7 +1819,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The positions of the cache buffer (for this B * H) and the vector within that chunk associated with this // thread. - const auto k_idx = chunk_index(tidx); + auto const k_idx = chunk_index(tidx); // The number of vectors per thread. constexpr unsigned K_VECS_PER_THREAD{Dh_MAX / K_ELTS_PER_CHUNK}; @@ -1830,7 +1830,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #pragma unroll for (unsigned ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = vec_conversion(*reinterpret_cast( + q_vec[ii] = vec_conversion(*reinterpret_cast( &q_smem[tensorrt_llm::common::flat_index2(ii, k_idx.y, K_ELTS_PER_CHUNK)])); } @@ -1843,11 +1843,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The number of unrolled keys per ieration. 
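For orientation on the kROPE_GPT_NEOX branch above: NeoX-style rotary embedding pairs element i of the first half of the rotary dimension with element i of the second half and rotates each pair by a position- and frequency-dependent angle. The reference below is a plain scalar version; the kernel's shared-memory transpose and vectorized pairing are omitted, so it is not a drop-in equivalent.

#include <cmath>
#include <cstdio>
#include <vector>

void apply_rope_neox(std::vector<float>& x, int rotary_dim, int position, float base = 10000.f)
{
    int const half = rotary_dim / 2;
    for (int i = 0; i < half; ++i)
    {
        float const inv_freq = std::pow(base, -2.f * i / rotary_dim);
        float const angle = position * inv_freq;
        float const c = std::cos(angle), s = std::sin(angle);
        float const x1 = x[i], x2 = x[i + half];
        x[i]        = x1 * c - x2 * s;
        x[i + half] = x1 * s + x2 * c;
    }
}

int main()
{
    std::vector<float> q(64, 1.0f);
    apply_rope_neox(q, /*rotary_dim=*/64, /*position=*/7);
    std::printf("q[0]=%f q[32]=%f\n", q[0], q[32]);
    return 0;
}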
constexpr unsigned UNROLLED_K_PER_ITER = K_PER_ITER * K_LOOP_UNROLL; - const auto timesteps_per_block = static_cast(params.timesteps_per_block); + auto const timesteps_per_block = static_cast(params.timesteps_per_block); // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible. - const int context_length + int const context_length = DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length); // Clarifications: // - in self attn, input_length is input text length, tlength is current timestep @@ -1860,31 +1860,31 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // - for cross attn, no-beam/beam search: cache length is fixed, not differ context/generation cache --> // context_length = tlength Suggestion: we could have a flag HANDLE_GEN_CACHE - const auto context_ti_end = MULTI_BLOCK_FLAG + auto const context_ti_end = MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP : divUp(static_cast(context_length), UNROLLED_K_PER_WARP) * UNROLLED_K_PER_WARP; // The generation ti_end. - const auto generation_ti_end = MULTI_BLOCK_FLAG + auto const generation_ti_end = MULTI_BLOCK_FLAG ? divUp(timesteps_per_block, K_PER_WARP) * K_PER_WARP : divUp(static_cast(kv_loop_length), K_PER_WARP) * K_PER_WARP; // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. // Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers. // By default, you can assume that they are the same. - const auto bi_seq_len_offset = static_cast(batch_beam_idx) * params.max_attention_window_size; + auto const bi_seq_len_offset = static_cast(batch_beam_idx) * params.max_attention_window_size; // Beam indices are based on the max_attention_window_size while each layer may have different // cyclic_attention_window_size So we need to rebuild the beam_indices if max_attention_window_size is not equal to // cyclic_attention_window_size. - const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + int const* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - const auto c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG + auto const c_tile_times_timesteps_per_block = c_tile * timesteps_per_block; // 0 if !MULTI_BLOCK_FLAG //////////////////////////////////////////////////////////////////////////////////////////////// // Key cache loops for dot(Q, K). // Is it the leader? - const bool is_leader = Qk_dot::is_leader(tidx); + bool const is_leader = Qk_dot::is_leader(tidx); // The slope for ALiBi. float linear_bias_slope = 0.f; @@ -1899,7 +1899,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Explicit batching of LDGs (by K_LOOP_UNROLL) as it doesn't depend on indirection tables. for (int ti = k_idx.x; ti < context_ti_end; ti += UNROLLED_K_PER_ITER) { - const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; + int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; // The keys loaded from the key cache. 
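The context_ti_end / generation_ti_end bounds above are rounded up to a multiple of the per-warp tile so that every thread of a warp executes the same number of loop iterations (a requirement for __shfl_sync), with out-of-range timesteps masked out of the dot product instead of skipped. A small sketch of that rounding:

#include <cstdio>

constexpr int divUp(int a, int b) { return (a + b - 1) / b; }

int main()
{
    int const context_length = 70;
    int const K_PER_WARP = 32;
    int const context_ti_end = divUp(context_length, K_PER_WARP) * K_PER_WARP; // 96, not 70
    for (int ti = 0; ti < context_ti_end; ++ti)
    {
        bool const is_active = ti < context_length; // inactive tail iterations contribute zero
        if (ti % K_PER_WARP == 0)
            std::printf("warp tile starting at %d, first element active=%d\n", ti, is_active ? 1 : 0);
    }
    return 0;
}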
K_vec_m k_vec_cache[K_LOOP_UNROLL][K_VECS_PER_THREAD]; @@ -1926,21 +1926,21 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske valid_time_now = pastKCache.getKVTokenIdx(valid_time_now); } } - const int seqIdx = batch_idx * beam_width; + int const seqIdx = batch_idx * beam_width; // Base pointer to k cache block for beam's batch TKcache* k_cache_batch = reinterpret_cast(pastKCache.getKBlockPtr(seqIdx, valid_time_now)); int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj); - k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast(&k_cache_batch[inBlockIdx]); + k_vec_cache[k_loop][k_vec_i] = *reinterpret_cast(&k_cache_batch[inBlockIdx]); } } #pragma unroll for (int k_loop = 0; k_loop < K_LOOP_UNROLL; ++k_loop) { - const int local_time_now = time_now + k_loop * K_PER_ITER; - const int local_ti = ti + k_loop * K_PER_ITER; + int const local_time_now = time_now + k_loop * K_PER_ITER; + int const local_ti = ti + k_loop * K_PER_ITER; // Perform the dot product and normalize qk. // @@ -1953,7 +1953,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } // Is it active? - const bool is_active = local_time_now < context_length; + bool const is_active = local_time_now < context_length; if constexpr (IMPLICIT_REL_ATTN_BIAS) { @@ -2044,14 +2044,14 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske && (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length)) { // The input length; - const int input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length; + int const input_length_ = MULTI_BLOCK_FLAG ? beam0_context_length % timesteps_per_block : beam0_context_length; // The beginning of the generation. - const int generation_start_ti = k_idx.x + input_length_ / K_PER_WARP * K_PER_WARP; + int const generation_start_ti = k_idx.x + input_length_ / K_PER_WARP * K_PER_WARP; // Iterate over the output tokens. for (int ti = generation_start_ti; ti < generation_ti_end; ti += K_PER_ITER) { - const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; + int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; // The keys loaded from the key cache. K_vec_m k_vec[K_VECS_PER_THREAD]; @@ -2059,7 +2059,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #pragma unroll for (int k_vec_i = 0; k_vec_i < K_VECS_PER_THREAD; ++k_vec_i) { - const int jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE); + int const jj = min(k_idx.y + k_vec_i * K_ELTS_PER_CHUNK, Dh - K_VEC_SIZE); int valid_time_now = min(time_now, kv_loop_length - 1); int beam_offset = beam_indices[valid_time_now]; if (POS_SHIFT && valid_time_now >= sink_token_len) @@ -2073,16 +2073,16 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske valid_time_now = pastKCache.getKVTokenIdx(valid_time_now); } } - const int seqIdx = batch_idx * beam_width + beam_offset; + int const seqIdx = batch_idx * beam_width + beam_offset; // Base pointer to k cache block for beam's batch, before offsetting with indirection buffer TKcache* k_cache_batch = reinterpret_cast(pastKCache.getKBlockPtr(seqIdx, valid_time_now)); int inBlockIdx = pastKCache.getKVLocalIdx(valid_time_now, hi_kv, Dh, jj); - k_vec[k_vec_i] = (*reinterpret_cast(&k_cache_batch[inBlockIdx])); + k_vec[k_vec_i] = (*reinterpret_cast(&k_cache_batch[inBlockIdx])); } // Is it active? 
- const bool is_active = time_now >= context_length && time_now < kv_loop_length; + bool const is_active = time_now >= context_length && time_now < kv_loop_length; if constexpr (IMPLICIT_REL_ATTN_BIAS) { @@ -2192,8 +2192,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #endif // defined MMHA_USE_HMMA // Decompose the thread index into warp and lane. - const auto warp = tidx / WARP_SIZE; - const auto lane = tidx % WARP_SIZE; + auto const warp = tidx / WARP_SIZE; + auto const lane = tidx % WARP_SIZE; // The warp leader writes the max to shared memory. if (lane == 0) @@ -2217,8 +2217,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske { // Trigger the stores to global memory. Qk_vec_k k_vec = *reinterpret_cast(&k_smem[qk_vec_idx]); - const auto k_idx = QK_VEC_SIZE * tidx; - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, k_idx); + auto const k_idx = QK_VEC_SIZE * tidx; + int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, k_idx); // The base pointer for the value in the cache buffer. Tcache* k_cache = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, cyclic_tlength)); @@ -2247,11 +2247,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske float sum = 0.f; // Each thread will handle one float (either qk_smem/logit). - const int logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; + int const logit_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; for (int ti = tidx; ti <= logit_loop_end; ti += THREADS_PER_BLOCK) { - const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; + int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; // For single-block mode, we don't need the mask since it has been skipped. if (!MULTI_BLOCK_FLAG) @@ -2289,10 +2289,10 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske #endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V float inv_sum = __fdividef(logit_scale, sum + 1.e-6f); - const int normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; + int const normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; for (int ti = tidx; ti <= normlization_loop_end; ti += THREADS_PER_BLOCK) { - const int time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; + int const time_now = MULTI_BLOCK_FLAG ? ti + c_tile_times_timesteps_per_block : ti; if (!MULTI_BLOCK_FLAG) { @@ -2315,11 +2315,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Put Values part below so we leverage __syncthreads // from the previous step - const auto v_idx = chunk_index(tidx); + auto const v_idx = chunk_index(tidx); // The value computed by this thread. - const auto vo = v_idx.x; + auto const vo = v_idx.x; // The hidden dimensions computed by this particular thread. - const auto vi = v_idx.y; + auto const vi = v_idx.y; // The number of values processed per iteration of the loop. constexpr unsigned V_PER_ITER{THREADS_PER_BLOCK / THREADS_PER_VALUE}; @@ -2337,8 +2337,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Trigger the loads from the V bias buffer. 
if (params.v_bias != nullptr) { - const auto v_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, vi, Dh); - v_bias = *reinterpret_cast(¶ms.v_bias[v_bias_offset]); + auto const v_bias_offset = tensorrt_llm::common::flat_index2(hi_kv, vi, Dh); + v_bias = *reinterpret_cast(¶ms.v_bias[v_bias_offset]); } if (DO_CROSS_ATTENTION) @@ -2370,7 +2370,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Handle both context and generation value cache without beam searching. // Explicit batching of LDGs (by V_LOOP_UNROLL) as it doesn't depend on indirection tables. // Take all previous cache as context when we have no beam searching in order to batch as many LDGs as possible. - const int context_length + int const context_length = DO_CROSS_ATTENTION ? kv_loop_length : (HAS_BEAMS ? beam0_context_length : kv_loop_length); int context_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : context_length; int generation_v_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length; @@ -2396,11 +2396,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } int rowIdx = batch_idx * beam_width; - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi); + int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi); // The base pointer for the value in the cache buffer. Tcache* v_cache_batch = reinterpret_cast(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx)); - v_vec_cache[v_loop] = *reinterpret_cast(&v_cache_batch[inBlockIdx]); + v_vec_cache[v_loop] = *reinterpret_cast(&v_cache_batch[inBlockIdx]); } #pragma unroll @@ -2411,7 +2411,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske int local_time_idx = ti + v_loop * V_PER_ITER; int time_idx = local_time_idx + (MULTI_BLOCK_FLAG ? c_tile_times_timesteps_per_block : 0); - const bool is_mask + bool const is_mask = (MULTI_BLOCK_FLAG && local_time_idx >= timesteps_per_block) || (time_idx >= context_length); // Load the logits from shared memory. @@ -2424,7 +2424,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Handle generation value cache with beam searching. if (HAS_BEAMS && !DO_CROSS_ATTENTION) { - const auto generation_start_ti + auto const generation_start_ti = MULTI_BLOCK_FLAG ? vo : (vo + (beam0_context_length / V_PER_ITER) * V_PER_ITER); // Only the last few blocks need to handle the generation value cache. if (!MULTI_BLOCK_FLAG || (c_tile + 1) * timesteps_per_block > beam0_context_length) @@ -2452,10 +2452,10 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske } } - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi); + int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(time_idx, hi_kv, Dh, vi); // The base pointer for the value in the cache buffer. Tcache* v_cache_batch = reinterpret_cast(kvCacheBuffer.getVBlockPtr(rowIdx, time_idx)); - V_vec_m v_vec = reinterpret_cast(&v_cache_batch[inBlockIdx])[0]; + V_vec_m v_vec = reinterpret_cast(&v_cache_batch[inBlockIdx])[0]; // Load the logits from shared memory. // Note that fma will convert 8bit vec to the accumulation data type (float by default). @@ -2470,19 +2470,19 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske __syncthreads(); // Get the c_tile_id that handles the current timestep. 
- const int ctile_idx = tlength / timesteps_per_block; + int const ctile_idx = tlength / timesteps_per_block; // One group of threads computes the product(s) for the current timestep. if (vo == kv_loop_length % V_PER_ITER && is_valid_vi && (!MULTI_BLOCK_FLAG || (c_tile == ctile_idx))) { - const int inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, vi); + int const inBlockIdx = kvCacheBuffer.getKVLocalIdx(cyclic_tlength, hi_kv, Dh, vi); // The base pointer for the value in the cache buffer. Tcache* v_cache_base = reinterpret_cast(kvCacheBuffer.getVBlockPtr(batch_beam_idx, cyclic_tlength)); V_vec_k v; if (DO_CROSS_ATTENTION) { - v = vec_conversion(*reinterpret_cast(&v_cache_base[inBlockIdx])); + v = vec_conversion(*reinterpret_cast(&v_cache_base[inBlockIdx])); } else { @@ -2490,21 +2490,21 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The stride between tokens. We may be able to always use params.stride. uint32_t v_stride = params.stride ? static_cast(params.stride) : (num_heads_kv * Dh); // The offset. - const auto v_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi_kv, vi, v_stride, Dh); + auto const v_offset = tensorrt_llm::common::flat_index_strided3(batch_beam_idx, hi_kv, vi, v_stride, Dh); if (load_qkv_quant) { using Packed_Int8_t = typename packed_type::value>::type; using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_quant_orig[2]; - const auto v_quant - = *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + auto const v_scaling = params.qkv_scale_quant_orig[2]; + auto const v_quant + = *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); convert_from_float(&v, mul(v_scaling, float_from_int8(v_quant))); } else { - v = *reinterpret_cast(¶ms.v[v_offset]); + v = *reinterpret_cast(¶ms.v[v_offset]); } } @@ -2516,7 +2516,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske if (do_ia3) { v = mul(v, - *reinterpret_cast( + *reinterpret_cast( ¶ms.ia3_value_weights[tensorrt_llm::common::flat_index2(ia3_ti_hi, vi, Dh)])); } } @@ -2584,17 +2584,17 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // The bottom warps update their values. if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); } __syncthreads(); } - const auto bhi = tensorrt_llm::common::flat_index2(batch_beam_idx, hi, num_heads); - const auto bhi_seq_len_tile = bhi * params.seq_len_tile; + auto const bhi = tensorrt_llm::common::flat_index2(batch_beam_idx, hi, num_heads); + auto const bhi_seq_len_tile = bhi * params.seq_len_tile; // Output the final values. 
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - const auto bhvi = tensorrt_llm::common::flat_index2(bhi, vi, Dh); + auto const bhvi = tensorrt_llm::common::flat_index2(bhi, vi, Dh); #ifdef MMHA_USE_FP32_ACCUM_FOR_OUT if (write_attention_quant) { @@ -2693,7 +2693,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske if (tidx < gridDim.z) { thread_partial_max = params.partial_max[bhi_seq_len_tile + tidx]; - const auto thread_partial_sum = params.partial_sum[bhi_seq_len_tile + tidx]; + auto const thread_partial_sum = params.partial_sum[bhi_seq_len_tile + tidx]; final_sum += __expf(thread_partial_max - final_max) * thread_partial_sum; } @@ -2708,7 +2708,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske // Shared memory to store partial outputs for each oi. -> size: gridDim.z * Dh * 4 Bytes. Reuse qk_smem. T* out_oi_smem = reinterpret_cast(smem_); - const auto o_idx = chunk_index(tidx); + auto const o_idx = chunk_index(tidx); // Init partial out for accumulation. V_vec_k zero_k; @@ -2716,10 +2716,10 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske V_vec_k thread_accumulated_out = zero_k; // The hidden dimensions computed by this particular thread. (refer to vi) - const auto oi = o_idx.y; + auto const oi = o_idx.y; // The partial output region this thread takes care of - const auto oo = o_idx.x; + auto const oo = o_idx.x; // Each thread may handle more than one partial output. for (int tile_idx = o_idx.x; tile_idx < gridDim.z; tile_idx += V_PER_ITER) @@ -2730,7 +2730,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske float thread_partial_max_for_out = params.partial_max[bhi_seq_len_tile + tile_idx]; // Load the partial outputs. V_vec_k thread_partial_out - = *reinterpret_cast(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi]); + = *reinterpret_cast(¶ms.partial_out[thread_partial_out_offset + bhi * Dh + oi]); // Apply the correction factor. 
Tk factor_compute; convert_from_float(&factor_compute, __expf(thread_partial_max_for_out - final_max)); @@ -2757,7 +2757,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske if (oo < midpoint && (Dh == Dh_MAX || oi < Dh)) { thread_accumulated_out - = add(thread_accumulated_out, *reinterpret_cast(&out_oi_smem[oo * Dh + oi])); + = add(thread_accumulated_out, *reinterpret_cast(&out_oi_smem[oo * Dh + oi])); } __syncthreads(); } @@ -2768,7 +2768,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske if (oo == 0 && (Dh == Dh_MAX || oi < Dh)) { - const auto inv_sum = __fdividef(1.f, final_sum + 1.e-6f); + auto const inv_sum = __fdividef(1.f, final_sum + 1.e-6f); Tk inv_sum_compute; convert_from_float(&inv_sum_compute, inv_sum); diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.cpp index d74c64ffc..55c3a7c3c 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.cpp @@ -27,15 +27,15 @@ namespace kernels { template <> -void DecoderXQAImpl::run(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) +void DecoderXQAImpl::run(XQAParams const& xqa_params, KVLinearBuffer& kv_linear_buffer, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) { runWithKVLinearBuffer(xqa_params, kv_linear_buffer, rotary_kernel_launch_cache, stream); } template <> -void DecoderXQAImpl::run(const XQAParams& xqa_params, KVBlockArray& kv_block_array, int2& rotary_kernel_launch_cache, - const cudaStream_t& stream) +void DecoderXQAImpl::run(XQAParams const& xqa_params, KVBlockArray& kv_block_array, int2& rotary_kernel_launch_cache, + cudaStream_t const& stream) { runWithKVBlockArray(xqa_params, kv_block_array, rotary_kernel_launch_cache, stream); } diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h index a9d6eebb8..3fddebea3 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImpl.h @@ -41,15 +41,15 @@ class DecoderXQAImpl { public: // Whether it is beneficial to use this XQA codepath. - virtual bool shouldUse(const XQAParams& xqaParams) = 0; + virtual bool shouldUse(XQAParams const& xqaParams) = 0; // Prepares for the kernel running. Must be called before calling run. - virtual void prepare(const XQAParams& xqa_params) = 0; + virtual void prepare(XQAParams const& xqa_params) = 0; // Run XQA kernel with KVCacheBuffer. // // Sub-classes should implement runWithKVLinearBuffer and runWithKVBlockArray. 
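Note (illustrative sketch, not part of the patch): the partial_max / partial_sum / partial_out logic in the multi-block MMHA hunks above is the standard way of merging per-tile softmax-attention results: rescale every tile by exp(tile_max - global_max), accumulate the rescaled exponent sums and value sums, then normalize once at the end. A hedged host-side sketch under an assumed per-tile contract (each tile stores its logit max m, its exp-sum s, and its unnormalized weighted value sum o; these names are illustrative, not the kernel's):

#include <cmath>
#include <cstdio>
#include <vector>

// Per-tile partial result. Assumed contract: m = max logit in the tile,
// s = sum of exp(logit - m), o = sum of exp(logit - m) * value.
struct PartialTile
{
    float m, s, o;
};

float combinePartialTiles(std::vector<PartialTile> const& tiles)
{
    float finalMax = -INFINITY;
    for (auto const& t : tiles)
        finalMax = std::fmax(finalMax, t.m);

    float finalSum = 0.f;
    float out = 0.f;
    for (auto const& t : tiles)
    {
        // Mirrors __expf(thread_partial_max - final_max) in the kernel.
        float const correction = std::exp(t.m - finalMax);
        finalSum += correction * t.s;
        out += correction * t.o;
    }
    // Mirrors inv_sum = __fdividef(1.f, final_sum + 1.e-6f).
    return out / (finalSum + 1.e-6f);
}

int main()
{
    // Two tiles: logits {1, 2} with values {10, 20}, and logit {3} with value {30}.
    std::vector<PartialTile> tiles{
        {2.f, std::exp(-1.f) + 1.f, std::exp(-1.f) * 10.f + 20.f},
        {3.f, 1.f, 30.f}};
    // Equals softmax({1, 2, 3}) dotted with {10, 20, 30}.
    std::printf("combined = %f\n", combinePartialTiles(tiles));
    return 0;
}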
template - void run(const XQAParams& xqa_params, KVCacheBuffer& kv_cache_buffer, int2& rotary_kernel_launch_cache, - const cudaStream_t& stream); + void run(XQAParams const& xqa_params, KVCacheBuffer& kv_cache_buffer, int2& rotary_kernel_launch_cache, + cudaStream_t const& stream); enum class ImplType { @@ -65,11 +65,11 @@ class DecoderXQAImpl { } - virtual void runWithKVLinearBuffer(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) + virtual void runWithKVLinearBuffer(XQAParams const& xqa_params, KVLinearBuffer& kv_linear_buffer, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) = 0; - virtual void runWithKVBlockArray(const XQAParams& xqa_params, KVBlockArray& kv_block_array, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) + virtual void runWithKVBlockArray(XQAParams const& xqa_params, KVBlockArray& kv_block_array, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) = 0; DecoderXQARunner* mRunner; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp index 912482073..12b77109f 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp @@ -45,7 +45,7 @@ struct XQAKernelLoadHashKey struct XQAKernelLoadHasher { - size_t operator()(const XQAKernelLoadHashKey& s) const + size_t operator()(XQAKernelLoadHashKey const& s) const { size_t key = s.data_type; key <<= 16; @@ -76,7 +76,7 @@ struct XQAKernelRuntimeHashKey struct XQAKernelRuntimeHasher { - size_t operator()(const XQAKernelRuntimeHashKey& s) const + size_t operator()(XQAKernelRuntimeHashKey const& s) const { size_t key = s.kv_data_type; key <<= 16; @@ -120,24 +120,24 @@ struct XQALaunchParam { uint32_t num_k_heads; void* output; - const void* qkv; + void const* qkv; KVCache kvCacheParams; std::optional beamSearchParams; uint32_t batch_size; - const float* kv_scale_quant_orig = nullptr; + float const* kv_scale_quant_orig = nullptr; void* scratch = nullptr; }; // Setup launch params. 
template -void buildXQALaunchParams(XQALaunchParam& launchParams, const XQAParams& params, KVCacheBuffer kv_cache_buffer) +void buildXQALaunchParams(XQALaunchParam& launchParams, XQAParams const& params, KVCacheBuffer kv_cache_buffer) { TLLM_CHECK_WITH_INFO( params.data_type == DATA_TYPE_FP16 || params.data_type == DATA_TYPE_BF16, "Only fp16 or bf16 supported now."); memset(&launchParams, 0, sizeof(XQALaunchParam)); launchParams.num_k_heads = params.num_kv_heads; launchParams.output = static_cast(params.output); - launchParams.qkv = static_cast(params.qkv); + launchParams.qkv = static_cast(params.qkv); launchParams.batch_size = params.batch_size; launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig; launchParams.scratch = params.workspaces; @@ -179,7 +179,7 @@ class XQAKernelList } for (unsigned int i = 0; i < mKernelMetaCount; ++i) { - const auto& kernelMeta = mKernelMeta[i]; + auto const& kernelMeta = mKernelMeta[i]; if (kernelMeta.mSM != mSM || kernelMeta.mDataType != mDataType) continue; @@ -220,7 +220,7 @@ class XQAKernelList } } - bool supportConfig(const XQAParams& xqaParams) const + bool supportConfig(XQAParams const& xqaParams) const { unsigned int head_size = xqaParams.head_size; int num_q_heads = xqaParams.num_q_heads; @@ -237,11 +237,11 @@ class XQAKernelList = {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize, xqaParams.paged_kv_cache ? static_cast(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache, xqaParams.multi_query_tokens}; - const auto findIter = mFunctions.find(hash_key); + auto const findIter = mFunctions.find(hash_key); return findIter != mFunctions.end(); } - bool mayHavePerfGain(const XQAParams& xqaParams, int multiprocessor_count) const + bool mayHavePerfGain(XQAParams const& xqaParams, int multiprocessor_count) const { // NOTE: only XQA supports multi_query_tokens (Medusa mode). if (mForceXQA || xqaParams.multi_query_tokens) @@ -261,8 +261,8 @@ class XQAKernelList } template - void run(const XQAParams& xqaParams, KVCacheBuffer& kv_cache_buffer, int2& rotary_kernel_launch_cache, - int multiprocessor_count, const cudaStream_t& stream) const + void run(XQAParams const& xqaParams, KVCacheBuffer& kv_cache_buffer, int2& rotary_kernel_launch_cache, + int multiprocessor_count, cudaStream_t const& stream) const { unsigned int head_size = xqaParams.head_size; int num_q_heads = xqaParams.num_q_heads; @@ -280,7 +280,7 @@ class XQAKernelList void const* xqa_q_input_ptr = xqaParams.output; invokeApplyBiasRopeUpdateKVCache(static_cast(const_cast(xqaParams.qkv)), static_cast(const_cast(xqaParams.output)), kv_cache_buffer, - static_cast(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr, nullptr, + static_cast(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr, nullptr, xqaParams.batch_size, xqaParams.generation_input_length, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length, xqaParams.batch_size * beam_width * xqaParams.generation_input_length, xqaParams.num_q_heads, xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim, @@ -302,13 +302,13 @@ class XQAKernelList kernel_num_q_heads_over_kv, kernel_m_tilesize, xqaParams.paged_kv_cache ? 
static_cast(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache, xqaParams.multi_query_tokens}; - const auto findIter = mFunctions.find(hash_key); + auto const findIter = mFunctions.find(hash_key); TLLM_CHECK_WITH_INFO(findIter != mFunctions.end(), "XQAKernelFunc not found."); - const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; + auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; const CUfunction func = findIter->second.mDeviceFunction; - const unsigned int shared_mem_bytes = findIter->second.mSharedMemBytes; + unsigned int const shared_mem_bytes = findIter->second.mSharedMemBytes; XQALaunchParam launchParams; buildXQALaunchParams(launchParams, xqaParams, kv_cache_buffer); @@ -373,7 +373,7 @@ class XQAKernelList sync_check_cuda_error(); } - static int computeMultiBlockCount(const XQAParams& xqaParams, int batch_size, int multiprocessor_count) + static int computeMultiBlockCount(XQAParams const& xqaParams, int batch_size, int multiprocessor_count) { int multi_block_count = 1; int num_kv_heads = xqaParams.num_kv_heads; @@ -401,10 +401,10 @@ class XQAKernelList tensorrt_llm::common::CUDADriverWrapper mDriver; Data_type mDataType; - const TKernelMeta* mKernelMeta; + TKernelMeta const* mKernelMeta; unsigned int mKernelMetaCount; unsigned int mSM; - std::unordered_map mModules; + std::unordered_map mModules; bool mForceXQA = false; @@ -421,14 +421,14 @@ class XQAKernelList class XQAKernelLoader { public: - const XQAKernelList* getXQAKernels(Data_type type, unsigned int sm) + XQAKernelList const* getXQAKernels(Data_type type, unsigned int sm) { static std::mutex s_mutex; std::lock_guard lg(s_mutex); XQAKernelLoadHashKey hash_key{type, sm}; - const auto findIter = mKernels.find(hash_key); + auto const findIter = mKernels.find(hash_key); if (findIter == mKernels.end()) { XQAKernelList* newKernel = new XQAKernelList{type, sm}; @@ -458,7 +458,7 @@ class XQAKernelLoader std::unordered_map, XQAKernelLoadHasher> mKernels; }; -inline const XQAKernelList* getXQAKernels(Data_type type, unsigned int sm) +inline XQAKernelList const* getXQAKernels(Data_type type, unsigned int sm) { return XQAKernelLoader::Get().getXQAKernels(type, sm); } @@ -468,10 +468,10 @@ inline const XQAKernelList* getXQAKernels(Data_type type, unsigned int sm) xqa_params, kv_cache_buffer, rotary_kernel_launch_cache, multi_processor_count, stream); template -void DecoderXQAImplPrecompiled::runDispatchBuffer(const XQAParams& xqa_params, KVCacheBuffer& kv_cache_buffer, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) +void DecoderXQAImplPrecompiled::runDispatchBuffer(XQAParams const& xqa_params, KVCacheBuffer& kv_cache_buffer, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) { - const XQAKernelList* xqa_kernel = getXQAKernels(mRunner->mDataType, tensorrt_llm::common::getSMVersion()); + XQAKernelList const* xqa_kernel = getXQAKernels(mRunner->mDataType, tensorrt_llm::common::getSMVersion()); int multi_processor_count = mRunner->mMultiProcessorCount; if (mRunner->mDataType == DATA_TYPE_FP16) { @@ -485,26 +485,26 @@ void DecoderXQAImplPrecompiled::runDispatchBuffer(const XQAParams& xqa_params, K #undef XQA_KERNEL_RUN -bool DecoderXQAImplPrecompiled::shouldUse(const XQAParams& xqaParams) +bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams) { - const XQAKernelList* xqa_kernel = getXQAKernels(mRunner->mDataType, tensorrt_llm::common::getSMVersion()); + XQAKernelList const* xqa_kernel = getXQAKernels(mRunner->mDataType, 
tensorrt_llm::common::getSMVersion()); return xqa_kernel->supportConfig(xqaParams) && xqa_kernel->mayHavePerfGain(xqaParams, mRunner->mMultiProcessorCount); } -void DecoderXQAImplPrecompiled::prepare(const XQAParams&) +void DecoderXQAImplPrecompiled::prepare(XQAParams const&) { // Intentionally do nothing. } -void DecoderXQAImplPrecompiled::runWithKVLinearBuffer(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) +void DecoderXQAImplPrecompiled::runWithKVLinearBuffer(XQAParams const& xqa_params, KVLinearBuffer& kv_linear_buffer, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) { runDispatchBuffer(xqa_params, kv_linear_buffer, rotary_kernel_launch_cache, stream); } -void DecoderXQAImplPrecompiled::runWithKVBlockArray(const XQAParams& xqa_params, KVBlockArray& kv_block_array, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) +void DecoderXQAImplPrecompiled::runWithKVBlockArray(XQAParams const& xqa_params, KVBlockArray& kv_block_array, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) { runDispatchBuffer(xqa_params, kv_block_array, rotary_kernel_launch_cache, stream); } diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h index 04380ea6a..d4beddfd2 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.h @@ -29,19 +29,19 @@ class DecoderXQAImplPrecompiled : public DecoderXQAImpl { } - bool shouldUse(const XQAParams& xqaParams) override; - void prepare(const XQAParams& xqa_params) override; + bool shouldUse(XQAParams const& xqaParams) override; + void prepare(XQAParams const& xqa_params) override; protected: - void runWithKVLinearBuffer(const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) override; - void runWithKVBlockArray(const XQAParams& xqa_params, KVBlockArray& kv_block_array, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream) override; + void runWithKVLinearBuffer(XQAParams const& xqa_params, KVLinearBuffer& kv_linear_buffer, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) override; + void runWithKVBlockArray(XQAParams const& xqa_params, KVBlockArray& kv_block_array, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream) override; private: template - void runDispatchBuffer(const XQAParams& xqa_params, KVCacheBuffer& kv_cache_buffer, - int2& rotary_kernel_launch_cache, const cudaStream_t& stream); + void runDispatchBuffer(XQAParams const& xqa_params, KVCacheBuffer& kv_cache_buffer, + int2& rotary_kernel_launch_cache, cudaStream_t const& stream); }; } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp index 734afa635..274264c8f 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp @@ -76,7 +76,7 @@ size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size) if (mMultiBlockMode) { int workspaces[4]; - const int max_num_request = max_batch_beam_size; + int const max_num_request = max_batch_beam_size; uint32_t const nbSeq = mNumKVHeads * 
max_num_request; uint32_t const nbSubSeq = kMaxNbCtaPerKVHeadFactor * nbSeq; int group_size = mNumHeads / mNumKVHeads; @@ -90,26 +90,26 @@ size_t DecoderXQARunner::getWorkspaceSize(int max_batch_beam_size) return workspace_size; } -bool DecoderXQARunner::shouldUseImpl(const XQAParams& xqaParams) +bool DecoderXQARunner::shouldUseImpl(XQAParams const& xqaParams) { return mImpl->shouldUse(xqaParams); } -void DecoderXQARunner::prepareForRun(const XQAParams& xqa_params) +void DecoderXQARunner::prepareForRun(XQAParams const& xqa_params) { return mImpl->prepare(xqa_params); } template -void DecoderXQARunner::run(const XQAParams& xqa_params, KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream) +void DecoderXQARunner::run(XQAParams const& xqa_params, KVCacheBuffer& kv_cache_buffer, cudaStream_t const& stream) { return mImpl->run(xqa_params, kv_cache_buffer, mLaunchGridBlockCache, stream); } template void DecoderXQARunner::run( - const XQAParams& xqa_params, KVLinearBuffer& kv_linear_buffer, const cudaStream_t& stream); + XQAParams const& xqa_params, KVLinearBuffer& kv_linear_buffer, cudaStream_t const& stream); template void DecoderXQARunner::run( - const XQAParams& xqa_params, KVBlockArray& kv_block_array, const cudaStream_t& stream); + XQAParams const& xqa_params, KVBlockArray& kv_block_array, cudaStream_t const& stream); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h index 808f23384..42ea7ae31 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.h @@ -85,7 +85,7 @@ class DecoderXQARunner * enqueueGeneration. */ template - bool shouldUse(const XQAParams& xqaParams, bool forConfigurePlugin) + bool shouldUse(XQAParams const& xqaParams, bool forConfigurePlugin) { if (!(xqaParams.data_type == DATA_TYPE_FP16 || xqaParams.data_type == DATA_TYPE_BF16)) { @@ -137,9 +137,9 @@ class DecoderXQARunner // OPTIMIZE: For the standard generation-phase MHA, there are still extra limitations. // NOTE: Medusa mode = Multi_query_tokens > 1. 
- const int nbQHeads = xqaParams.num_q_heads; - const int nbKVHeads = xqaParams.num_kv_heads; - const int nbQHeadsPerKV = nbQHeads / nbKVHeads; + int const nbQHeads = xqaParams.num_q_heads; + int const nbKVHeads = xqaParams.num_kv_heads; + int const nbQHeadsPerKV = nbQHeads / nbKVHeads; if (!xqaParams.multi_query_tokens) { if (nbQHeadsPerKV != 8 && nbQHeadsPerKV != 1) @@ -160,7 +160,7 @@ class DecoderXQARunner size_t getWorkspaceSize(int max_batch_beam_size); - void prepare(const XQAParams& xqa_params) + void prepare(XQAParams const& xqa_params) { if (!mPrepareCalled) { @@ -170,7 +170,7 @@ class DecoderXQARunner } template - void dispatch(const XQAParams& xqa_params, KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream) + void dispatch(XQAParams const& xqa_params, KVCacheBuffer& kv_cache_buffer, cudaStream_t const& stream) { if (!mPrepareCalled) { @@ -181,11 +181,11 @@ class DecoderXQARunner } private: - bool shouldUseImpl(const XQAParams& xqa_params); - void prepareForRun(const XQAParams& xqa_params); + bool shouldUseImpl(XQAParams const& xqa_params); + void prepareForRun(XQAParams const& xqa_params); template - void run(const XQAParams& xqa_params, KVCacheBuffer& kv_cache_buffer, const cudaStream_t& stream); + void run(XQAParams const& xqa_params, KVCacheBuffer& kv_cache_buffer, cudaStream_t const& stream); static constexpr int kMaxBeamWidth = 4; diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h index 10b22ac57..ada4f02ba 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h @@ -30,12 +30,12 @@ struct XQAParams XQADataType data_type = DATA_TYPE_FP16; XQADataType kv_cache_data_type = DATA_TYPE_FP16; void* output = nullptr; - const void* qkv = nullptr; - const int32_t* cache_indir = nullptr; - const float* kv_scale_orig_quant = nullptr; - const float* kv_scale_quant_orig = nullptr; - const int32_t* host_past_key_value_lengths = nullptr; - const int32_t* host_context_lengths = nullptr; + void const* qkv = nullptr; + int32_t const* cache_indir = nullptr; + float const* kv_scale_orig_quant = nullptr; + float const* kv_scale_quant_orig = nullptr; + int32_t const* host_past_key_value_lengths = nullptr; + int32_t const* host_context_lengths = nullptr; void* workspaces = nullptr; uint32_t batch_size = 0; int32_t beam_width = 0; @@ -43,12 +43,12 @@ struct XQAParams int32_t cyclic_attention_window_size = 0; int32_t sink_token_length = 0; int timestep = 0; - const void* qkv_bias; - const int32_t* sequence_lengths; // - const int32_t* context_lengths; // maybe not used now - const void* alibi_slopes; // maybe not used now - const int32_t* medusa_packed_mask; - const int* medusa_position_offsets; // rotary embedding. + void const* qkv_bias; + int32_t const* sequence_lengths; // + int32_t const* context_lengths; // maybe not used now + void const* alibi_slopes; // maybe not used now + int32_t const* medusa_packed_mask; + int const* medusa_position_offsets; // rotary embedding. // almost copy from GPTAttentionPluginCommon. // maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here. 
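Note (illustrative sketch, not part of the patch): the XQAKernelList / XQAKernelLoader hunks above follow a common pattern for precompiled-kernel dispatch: pack the runtime configuration into a hash key, look it up in an unordered_map of loaded kernels, and treat "key not found" as "configuration not supported, fall back to the generic path". A hedged sketch of that pattern; the struct, field, and function names here are illustrative, not the ones in the patch:

#include <cstdint>
#include <cstdio>
#include <unordered_map>

// Runtime configuration fields that select a precompiled kernel.
struct KernelKey
{
    uint16_t dataType;
    uint16_t headSize;
    uint16_t qHeadsPerKv;
    bool pagedKvCache;

    bool operator==(KernelKey const& o) const
    {
        return dataType == o.dataType && headSize == o.headSize && qHeadsPerKv == o.qHeadsPerKv
            && pagedKvCache == o.pagedKvCache;
    }
};

// Packs the fields into a single hash value by shifting, in the spirit of
// XQAKernelRuntimeHasher above.
struct KernelKeyHasher
{
    size_t operator()(KernelKey const& k) const
    {
        size_t key = k.dataType;
        key <<= 16;
        key ^= k.headSize;
        key <<= 16;
        key ^= k.qHeadsPerKv;
        key <<= 1;
        key ^= k.pagedKvCache;
        return key;
    }
};

using KernelFn = void (*)();
std::unordered_map<KernelKey, KernelFn, KernelKeyHasher> gKernels;

bool supportConfig(KernelKey const& key)
{
    return gKernels.find(key) != gKernels.end(); // mirrors mFunctions.find(hash_key)
}

int main()
{
    gKernels[{0, 128, 8, true}] = +[] { std::puts("run precompiled kernel"); };
    KernelKey const wanted{0, 128, 8, true};
    if (supportConfig(wanted))
        gKernels[wanted]();
    else
        std::puts("fall back to the generic MMHA path");
    return 0;
}

The real loader additionally guards its cache with a mutex and keys a first-level map by (data type, SM version), as the hunks show.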
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h index bbd95d058..58d24363f 100644 --- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h +++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h @@ -2541,27 +2541,27 @@ inline __device__ void zero(T& dst) //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ float update_rotary_base( - const int kv_seq_len, const int max_positions, const int embed_dim, const float base, const float scale) + int const kv_seq_len, int const max_positions, int const embed_dim, float const base, float const scale) { - const float b = (scale * kv_seq_len / max_positions) - (scale - 1); - const float p = static_cast(embed_dim) / (embed_dim - 2); + float const b = (scale * kv_seq_len / max_positions) - (scale - 1); + float const p = static_cast(embed_dim) / (embed_dim - 2); return base * __powf(b, p); } //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ float2 update_dynamic_scaling_rotary(float base, float scale, const int kv_seq_len, - const int max_positions, const int embed_dim, const bool dynamic_scaling) +inline __device__ float2 update_dynamic_scaling_rotary(float base, float scale, int const kv_seq_len, + int const max_positions, int const embed_dim, bool const dynamic_scaling) { - const float b = kv_seq_len * __fdividef(scale, max_positions) - (scale - 1); - const float p = __fdividef(embed_dim, embed_dim - 2); - const float updated_base = dynamic_scaling ? base * __powf(b, p) : base; - const float updated_scale = dynamic_scaling ? 1.0f : scale; + float const b = kv_seq_len * __fdividef(scale, max_positions) - (scale - 1); + float const p = __fdividef(embed_dim, embed_dim - 2); + float const updated_base = dynamic_scaling ? base * __powf(b, p) : base; + float const updated_scale = dynamic_scaling ? 
1.0f : scale; return {updated_base, updated_scale}; } inline __device__ void update_rotary_base_n_scale(float& base, float& scale, RotaryScalingType const scale_type, - const int rot_embed_dim, const int max_positions, const int seq_len) + int const rot_embed_dim, int const max_positions, int const seq_len) { // only update the base and/or scale if needed based on scale_type if (scale_type == RotaryScalingType::kDYNAMIC) @@ -2579,9 +2579,9 @@ inline __device__ void update_rotary_base_n_scale(float& base, float& scale, Rot } inline __device__ float2 rotary_embedding_coefficient( - const int zid, const int rot_embed_dim, const float base, const float scale, const float t_step) + int const zid, int const rot_embed_dim, float const base, float const scale, float const t_step) { - const float inv_freq = float(t_step * scale) / powf(base, zid / (float) rot_embed_dim); + float const inv_freq = float(t_step * scale) / powf(base, zid / (float) rot_embed_dim); return {cosf(inv_freq), sinf(inv_freq)}; } @@ -2627,7 +2627,7 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); + auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); q = rotary_embedding_transform(q, coef); } @@ -2638,7 +2638,7 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); + auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); q = rotary_embedding_transform(q, coef); k = rotary_embedding_transform(k, coef); } @@ -2652,9 +2652,9 @@ inline __device__ void apply_rotary_embedding( } Float4_& q_ = *reinterpret_cast(&q); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); q_.x = rotary_embedding_transform(q_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); q_.y = rotary_embedding_transform(q_.y, coef1); } @@ -2668,10 +2668,10 @@ inline __device__ void apply_rotary_embedding( Float4_& q_ = *reinterpret_cast(&q); Float4_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); q_.x = rotary_embedding_transform(q_.x, coef0); k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); q_.y = rotary_embedding_transform(q_.y, coef1); k_.y = rotary_embedding_transform(k_.y, coef1); } @@ -2686,16 +2686,16 @@ inline __device__ void apply_rotary_embedding( Float8_& q_ = *reinterpret_cast(&q); Float8_& k_ = *reinterpret_cast(&k); - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); q_.x = rotary_embedding_transform(q_.x, coef0); k_.x = rotary_embedding_transform(k_.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const 
coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); q_.y = rotary_embedding_transform(q_.y, coef1); k_.y = rotary_embedding_transform(k_.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); + auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); q_.z = rotary_embedding_transform(q_.z, coef2); k_.z = rotary_embedding_transform(k_.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); + auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); q_.w = rotary_embedding_transform(q_.w, coef3); k_.w = rotary_embedding_transform(k_.w, coef3); } @@ -2707,7 +2707,7 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); + auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); q = rotary_embedding_transform(q, coef); } @@ -2718,7 +2718,7 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); + auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); q = rotary_embedding_transform(q, coef); k = rotary_embedding_transform(k, coef); } @@ -2741,9 +2741,9 @@ inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_d { return; } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); } @@ -2754,10 +2754,10 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); k.y = rotary_embedding_transform(k.y, coef1); } @@ -2768,13 +2768,13 @@ inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_d { return; } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); + auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); q.z = rotary_embedding_transform(q.z, coef2); 
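Note (illustrative sketch, not part of the patch): rotary_embedding_coefficient() above computes the per-channel-pair rotation angle t_step * scale / base^(zid / rot_embed_dim) (named inv_freq in the kernel) and returns its cosine and sine; each apply_rotary_embedding overload then rotates adjacent channel pairs with that coefficient. rotary_embedding_transform itself is defined elsewhere in this header, so the pairwise rotation below is the standard RoPE rotation and should be read as an assumption, not a quote of the kernel:

#include <cmath>
#include <cstdio>
#include <utility>

// Coefficient follows rotary_embedding_coefficient() above.
std::pair<float, float> rotaryCoefficient(int zid, int rotEmbedDim, float base, float scale, float tStep)
{
    float const angle = (tStep * scale) / std::pow(base, zid / static_cast<float>(rotEmbedDim));
    return {std::cos(angle), std::sin(angle)};
}

// Standard RoPE rotation of one channel pair: (x, y) -> (x*cos - y*sin, x*sin + y*cos).
std::pair<float, float> rotatePair(std::pair<float, float> v, std::pair<float, float> coef)
{
    return {v.first * coef.first - v.second * coef.second, v.first * coef.second + v.second * coef.first};
}

int main()
{
    // Rotate the pair handled by thread tid = 3 at timestep 10,
    // with rot_embed_dim = 128, base = 10000, scale = 1 (no position scaling).
    auto const coef = rotaryCoefficient(2 * 3, 128, 10000.f, 1.f, 10.f);
    auto const [x, y] = rotatePair({1.f, 0.f}, coef);
    std::printf("rotated pair: (%f, %f)\n", x, y);
    return 0;
}

The update_rotary_base / update_dynamic_scaling_rotary hunks a little earlier feed this computation: with dynamic NTK scaling enabled, base is first enlarged to base * ((scale * kv_seq_len / max_positions) - (scale - 1))^(embed_dim / (embed_dim - 2)) and scale is then reset to 1.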
- const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); + auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); q.w = rotary_embedding_transform(q.w, coef3); } @@ -2785,16 +2785,16 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); + auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); q.z = rotary_embedding_transform(q.z, coef2); k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); + auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); q.w = rotary_embedding_transform(q.w, coef3); k.w = rotary_embedding_transform(k.w, coef3); } @@ -2807,7 +2807,7 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); + auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); q = rotary_embedding_transform(q, coef); } @@ -2818,7 +2818,7 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); + auto const coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, base, scale, t_step); q = rotary_embedding_transform(q, coef); k = rotary_embedding_transform(k, coef); } @@ -2830,9 +2830,9 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); } @@ -2843,10 +2843,10 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); k.y = rotary_embedding_transform(k.y, coef1); } @@ -2858,13 +2858,13 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, 
scale, t_step); + auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); + auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); q.z = rotary_embedding_transform(q.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); + auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); q.w = rotary_embedding_transform(q.w, coef3); } @@ -2875,16 +2875,16 @@ inline __device__ void apply_rotary_embedding( { return; } - const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); + auto const coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, base, scale, t_step); q.x = rotary_embedding_transform(q.x, coef0); k.x = rotary_embedding_transform(k.x, coef0); - const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); + auto const coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, base, scale, t_step); q.y = rotary_embedding_transform(q.y, coef1); k.y = rotary_embedding_transform(k.y, coef1); - const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); + auto const coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, base, scale, t_step); q.z = rotary_embedding_transform(q.z, coef2); k.z = rotary_embedding_transform(k.z, coef2); - const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); + auto const coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, base, scale, t_step); q.w = rotary_embedding_transform(q.w, coef3); k.w = rotary_embedding_transform(k.w, coef3); } @@ -3061,21 +3061,21 @@ inline __device__ void apply_rotary_embedding_gptneox(Vec_type& q, Vec_type& k, } } - const int half_rotary_dim = rotary_embedding_dim / 2; + int const half_rotary_dim = rotary_embedding_dim / 2; #pragma unroll for (int elt_id = 0; elt_id < VEC_SIZE; elt_id++) { // Pack two elements for calculation (only one if each the thread only gets one element) // Assume the head size (or rotary embedding) is multiple of 8. 
- const int rotary_emd_pos0_id + int const rotary_emd_pos0_id = (tidx * VEC_SIZE * PACKED_ELT_SIZE + elt_id * PACKED_ELT_SIZE + 0 - int(!first_half) * half_rotary_dim) * 2; - const int rotary_emd_pos1_id + int const rotary_emd_pos1_id = (tidx * VEC_SIZE * PACKED_ELT_SIZE + elt_id * PACKED_ELT_SIZE + 1 - int(!first_half) * half_rotary_dim) * 2; - const bool valid_rotary_pos = rotary_emd_pos1_id < rotary_embedding_dim; + bool const valid_rotary_pos = rotary_emd_pos1_id < rotary_embedding_dim; Packed_type q_ = reinterpret_cast(&q)[elt_id]; Packed_type q_pair_ = reinterpret_cast(&q_pair)[elt_id]; @@ -3356,7 +3356,7 @@ inline __device__ void convert_from_fp8(uint32_t* v, const fp8_2_t u) inline __device__ void convert_from_fp8(uint2* v, const fp8_4_t u) { uint32_t* v_ptr = reinterpret_cast(v); - const fp8_2_t* u_ptr = reinterpret_cast(&u); + fp8_2_t const* u_ptr = reinterpret_cast(&u); convert_from_fp8(v_ptr + 0, u_ptr[0]); convert_from_fp8(v_ptr + 1, u_ptr[1]); @@ -3367,7 +3367,7 @@ inline __device__ void convert_from_fp8(uint2* v, const fp8_4_t u) inline __device__ void convert_from_fp8(uint4* v, const fp8_8_t u) { uint32_t* v_ptr = reinterpret_cast(v); - const fp8_2_t* u_ptr = reinterpret_cast(&u); + fp8_2_t const* u_ptr = reinterpret_cast(&u); convert_from_fp8(v_ptr + 0, u_ptr[0]); convert_from_fp8(v_ptr + 1, u_ptr[1]); @@ -3403,7 +3403,7 @@ inline __device__ void convert_from_fp8(bf16_4_t* v, const fp8_4_t u) { __nv_bfloat162* v2 = reinterpret_cast<__nv_bfloat162*>(v); - const fp8_2_t* u2 = reinterpret_cast(&u); + fp8_2_t const* u2 = reinterpret_cast(&u); convert_from_fp8(v2, u2[0]); convert_from_fp8(v2 + 1, u2[1]); } @@ -3548,14 +3548,14 @@ inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const half u) inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const uint16_t u) { - v[0] = __nv_fp8_e4m3(reinterpret_cast(u)); + v[0] = __nv_fp8_e4m3(reinterpret_cast(u)); } //////////////////////////////////////////////////////////////////////////////////////////////////// inline __device__ void convert_to_fp8(fp8_2_t* v, const uint32_t u) { - v[0] = fp8_2_t(reinterpret_cast(u)); + v[0] = fp8_2_t(reinterpret_cast(u)); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -3594,7 +3594,7 @@ inline __device__ void convert_to_fp8(fp8_8_t* v, const uint4 u) //////////////////////////////////////////////////////////////////////////////////////////////////// -inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, const float u) +inline __device__ void convert_to_fp8(__nv_fp8_e4m3* v, float const u) { v[0] = __nv_fp8_e4m3(u); } @@ -3694,13 +3694,13 @@ inline __device__ int64_t cast_to_int8(Float8_ val) //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, const T* pointer, int idx, T_scale scale) +inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, T const* pointer, int idx, T_scale scale) { assert(false); // Not used. } template -inline __device__ void store_8bits_kv_cache_vec(T* pointer, const Vec_k& vec, int idx, T_scale scale) +inline __device__ void store_8bits_kv_cache_vec(T* pointer, Vec_k const& vec, int idx, T_scale scale) { assert(false); // Not used. 
} @@ -3708,11 +3708,11 @@ inline __device__ void store_8bits_kv_cache_vec(T* pointer, const Vec_k& vec, in //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, const int8_t* pointer, int idx, float scale) +inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, int8_t const* pointer, int idx, float scale) { using Packed_8bits_t = typename packed_type::value>::type; using Packed_Float_t = typename packed_type::value>::type; - const auto quant = *reinterpret_cast(&pointer[idx]); + auto const quant = *reinterpret_cast(&pointer[idx]); convert_from_float(vec, mul(scale, float_from_int8(quant))); } @@ -3721,15 +3721,15 @@ inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, const int8_t* pointer #ifdef ENABLE_FP8 template -inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, const __nv_fp8_e4m3* pointer, int idx) +inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, __nv_fp8_e4m3 const* pointer, int idx) { using Packed_8bits_t = typename packed_type<__nv_fp8_e4m3, num_elems::value>::type; - const auto quant = *reinterpret_cast(&pointer[idx]); + auto const quant = *reinterpret_cast(&pointer[idx]); convert_from_fp8(vec, quant); } template -inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, const __nv_fp8_e4m3* pointer, int idx, T_scale scale) +inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, __nv_fp8_e4m3 const* pointer, int idx, T_scale scale) { load_8bits_kv_cache_vec(vec, pointer, idx); vec[0] = mul(scale, vec[0]); @@ -3739,7 +3739,7 @@ inline __device__ void load_8bits_kv_cache_vec(Vec_k* vec, const __nv_fp8_e4m3* //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void store_8bits_kv_cache_vec(int8_t* pointer, const Vec_k& vec, int idx, float scale) +inline __device__ void store_8bits_kv_cache_vec(int8_t* pointer, Vec_k const& vec, int idx, float scale) { using Packed_8bits_t = typename packed_type::value>::type; using Packed_Float_t = typename packed_type::value>::type; @@ -3752,7 +3752,7 @@ inline __device__ void store_8bits_kv_cache_vec(int8_t* pointer, const Vec_k& ve #ifdef ENABLE_FP8 template -inline __device__ void store_8bits_kv_cache_vec(__nv_fp8_e4m3* pointer, const Vec_k& vec, int idx, T_scale scale) +inline __device__ void store_8bits_kv_cache_vec(__nv_fp8_e4m3* pointer, Vec_k const& vec, int idx, T_scale scale) { using Packed_8bits_t = typename packed_type<__nv_fp8_e4m3, num_elems::value>::type; Packed_8bits_t out_quant; @@ -3765,7 +3765,7 @@ inline __device__ void store_8bits_kv_cache_vec(__nv_fp8_e4m3* pointer, const Ve //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void convert_from_8bit_kv_cache(Vec_out* vec_o, const Vec_in& vec_i, T_scale scale) +inline __device__ void convert_from_8bit_kv_cache(Vec_out* vec_o, Vec_in const& vec_i, T_scale scale) { if constexpr (std::is_same::value) { @@ -3786,7 +3786,7 @@ inline __device__ void convert_from_8bit_kv_cache(Vec_out* vec_o, const Vec_in& } template -inline __device__ void convert_from_8bit_kv_cache(Vec_out* vec_o, const Vec_in& vec_i) +inline __device__ void convert_from_8bit_kv_cache(Vec_out* vec_o, Vec_in const& vec_i) { if constexpr (std::is_same::value) { @@ -4015,10 +4015,10 @@ __device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int } template -__device__ __inline__ void 
write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); +__device__ __inline__ void write_smem_transpose(Vec_T const& vec, T* smem, int transpose_idx, int smem_pitch); template <> -__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(float const& vec, float* smem, int transpose_idx, int smem_pitch) { return; } @@ -4026,7 +4026,7 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, i #ifdef ENABLE_BF16 template <> __device__ __inline__ void write_smem_transpose( - const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) + bf16_4_t const& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) { smem[transpose_idx] = vec.x.x; smem[transpose_idx + 1] = vec.y.x; @@ -4036,7 +4036,7 @@ __device__ __inline__ void write_smem_transpose( template <> __device__ __inline__ void write_smem_transpose( - const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) + bf16_8_t const& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) { smem[transpose_idx] = vec.x.x; smem[transpose_idx + 1] = vec.y.x; @@ -4060,7 +4060,7 @@ __device__ __inline__ void vec_from_smem_transpose(float4& vec, __nv_fp8_e4m3* s #endif // ENABLE_FP8 template <> -__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(uint4 const& vec, uint16_t* smem, int transpose_idx, int smem_pitch) { union { @@ -4089,7 +4089,7 @@ __device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem } template <> -__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(uint2 const& vec, uint16_t* smem, int transpose_idx, int smem_pitch) { union { @@ -4114,7 +4114,7 @@ __device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem } template <> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(uint32_t const& vec, uint16_t* smem, int transpose_idx, int smem_pitch) { union { @@ -4129,7 +4129,7 @@ __device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* s } template <> -__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(float4 const& vec, float* smem, int transpose_idx, int smem_pitch) { smem[transpose_idx] = vec.x; smem[transpose_idx + 1] = vec.z; @@ -4138,7 +4138,7 @@ __device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, } template <> -__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(uint32_t const& vec, half* smem, int transpose_idx, int smem_pitch) { union { @@ -4152,15 +4152,15 @@ __device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, } template <> -__device__ __inline__ void write_smem_transpose(const half2& vec, half* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(half2 const& vec, half* smem, int transpose_idx, int smem_pitch) { - return write_smem_transpose(*reinterpret_cast(&vec), smem, transpose_idx, smem_pitch); + return 
write_smem_transpose(*reinterpret_cast(&vec), smem, transpose_idx, smem_pitch); } #ifdef ENABLE_BF16 template <> __device__ __inline__ void write_smem_transpose( - const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) + __nv_bfloat162 const& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) { smem[transpose_idx] = vec.x; smem[smem_pitch + transpose_idx] = vec.y; @@ -4168,7 +4168,7 @@ __device__ __inline__ void write_smem_transpose( #endif template <> -__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) +__device__ __inline__ void write_smem_transpose(float2 const& vec, float* smem, int transpose_idx, int smem_pitch) { smem[transpose_idx] = vec.x; smem[smem_pitch + transpose_idx] = vec.y; @@ -4177,7 +4177,7 @@ __device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, #ifdef ENABLE_FP8 template <> __device__ __inline__ void write_smem_transpose( - const float4& vec, __nv_fp8_e4m3* smem, int transpose_idx, int smem_pitch) + float4 const& vec, __nv_fp8_e4m3* smem, int transpose_idx, int smem_pitch) { printf("[ERROR] still no have implementation for vec_from_smem_transpose under __nv_fp8_e4m3 \n"); } diff --git a/cpp/tensorrt_llm/kernels/decodingCommon.cu b/cpp/tensorrt_llm/kernels/decodingCommon.cu index dba6e5212..6a504e568 100644 --- a/cpp/tensorrt_llm/kernels/decodingCommon.cu +++ b/cpp/tensorrt_llm/kernels/decodingCommon.cu @@ -26,7 +26,7 @@ namespace tensorrt_llm namespace kernels { -__global__ void curandInitialize(curandState_t* state, const int* batchSlots, const int size, const uint64_t randomSeed) +__global__ void curandInitialize(curandState_t* state, int const* batchSlots, int const size, const uint64_t randomSeed) { int const idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < size) @@ -37,7 +37,7 @@ __global__ void curandInitialize(curandState_t* state, const int* batchSlots, co } void invokeCurandInitialize( - curandState_t* state, const int* batchSlots, const size_t batchSize, const uint64_t randomSeed, cudaStream_t stream) + curandState_t* state, int const* batchSlots, const size_t batchSize, const uint64_t randomSeed, cudaStream_t stream) { dim3 block(256); dim3 grid((int) (ceil(batchSize * 1.0 / 256))); @@ -45,7 +45,7 @@ void invokeCurandInitialize( } __global__ void curandBatchInitialize( - curandState_t* states, const int* batchSlots, const int size, const uint64_t* randomSeeds) + curandState_t* states, int const* batchSlots, int const size, uint64_t const* randomSeeds) { int const idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < size) @@ -55,8 +55,8 @@ __global__ void curandBatchInitialize( } } -void invokeCurandBatchInitialize(curandState_t* states, const int* batchSlots, const size_t batchSize, - const uint64_t* randomSeeds, cudaStream_t stream) +void invokeCurandBatchInitialize(curandState_t* states, int const* batchSlots, const size_t batchSize, + uint64_t const* randomSeeds, cudaStream_t stream) { dim3 block(256); dim3 grid((int) (ceil(batchSize * 1.0 / 256))); diff --git a/cpp/tensorrt_llm/kernels/decodingCommon.h b/cpp/tensorrt_llm/kernels/decodingCommon.h index 43bd05215..494990994 100644 --- a/cpp/tensorrt_llm/kernels/decodingCommon.h +++ b/cpp/tensorrt_llm/kernels/decodingCommon.h @@ -151,7 +151,7 @@ static_assert(FinishedState::finishedMaxLength().isFinishedMaxLength()); //! \param randomSeed seed to initialize states //! 
\param stream stream void invokeCurandInitialize( - curandState_t* state, const int* batchSlots, const size_t batchSize, uint64_t randomSeed, cudaStream_t stream); + curandState_t* state, int const* batchSlots, const size_t batchSize, uint64_t randomSeed, cudaStream_t stream); //! \brief Initialize batchSize curand states with given seed per request. //! @@ -160,8 +160,8 @@ void invokeCurandInitialize( //! \param batchSize number of states to initialize //! \param randomSeeds input buffer [maxBatchSize] with seeds //! \param stream stream -void invokeCurandBatchInitialize(curandState_t* states, const int* batchSlots, const size_t batchSize, - const uint64_t* randomSeeds, cudaStream_t stream); +void invokeCurandBatchInitialize(curandState_t* states, int const* batchSlots, const size_t batchSize, + uint64_t const* randomSeeds, cudaStream_t stream); //! \brief Applies mask, adds bias to logits and computes softmax values. //! Sets -MAX_FLT value for tokens in range [vocabSize; vocabSizePadded) to prevent them from being chosen. diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.cu b/cpp/tensorrt_llm/kernels/decodingKernels.cu index dae3dba5d..8f640c73c 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.cu +++ b/cpp/tensorrt_llm/kernels/decodingKernels.cu @@ -17,10 +17,19 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaTypeUtils.cuh" #include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/kernels/decodingKernels.h" +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#elif (CUDART_VERSION >= 11050) +#include +#else +#include "3rdparty/cub/cub.cuh" +#endif using namespace tensorrt_llm::common; +using namespace tensorrt_llm::runtime; namespace tensorrt_llm { @@ -33,12 +42,12 @@ __global__ void gatherTree(gatherTreeParam param) for (int batchbeamIdx = blockIdx.x * blockDim.x + threadIdx.x; batchbeamIdx < param.batchSize * param.beamWidth; batchbeamIdx += gridDim.x * blockDim.x) { - const int batch = batchbeamIdx / param.beamWidth; - const int beam = batchbeamIdx % param.beamWidth; - const int inputLen = param.inputLengths == nullptr ? 0 : param.inputLengths[batchbeamIdx]; + int const batch = batchbeamIdx / param.beamWidth; + int const beam = batchbeamIdx % param.beamWidth; + int const inputLen = param.inputLengths == nullptr ? 0 : param.inputLengths[batchbeamIdx]; - const int* parentIds = param.parentIds; - const int* stepIds = param.stepIds; + int const* parentIds = param.parentIds; + int const* stepIds = param.stepIds; // TODO optimize the reduce_max operation for large beamWidth int maxLen = -1; @@ -58,22 +67,22 @@ __global__ void gatherTree(gatherTreeParam param) maxLen = tmpLen; } } - const int maxSeqLenB = min(param.maxSeqLen, maxLen); + int const maxSeqLenB = min(param.maxSeqLen, maxLen); if (maxSeqLenB <= 0) { continue; } - const int initialTgtIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1; - const int initialParentIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1; + int const initialTgtIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1; + int const initialParentIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + maxSeqLenB - 1; param.outputIds[initialTgtIx] = __ldg(stepIds + initialParentIx); int parent = parentIds == nullptr ? 
0 : __ldg(parentIds + initialParentIx) % param.beamWidth; bool foundBad = false; for (int level = maxSeqLenB - 2; level >= 0; --level) { - const int levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + level; - const int levelParentIx = batch * param.beamWidth * param.maxSeqLen + parent * param.maxSeqLen + level; + int const levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + level; + int const levelParentIx = batch * param.beamWidth * param.maxSeqLen + parent * param.maxSeqLen + level; if (parent < 0 || parent > param.beamWidth) { param.outputIds[levelBeamIx] = param.endTokens[batch]; @@ -104,7 +113,7 @@ __global__ void gatherTree(gatherTreeParam param) int startStep = 1; for (int time = startStep; time < maxSeqLenB; ++time) { - const int levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + time; + int const levelBeamIx = batch * param.beamWidth * param.maxSeqLen + beam * param.maxSeqLen + time; if (finished) { param.outputIds[levelBeamIx] = param.endTokens[batch]; @@ -135,7 +144,7 @@ struct RankNorm float norm; }; -inline __device__ RankNorm swap(const RankNorm& rankNorm, int mask, int dir) +inline __device__ RankNorm swap(RankNorm const& rankNorm, int mask, int dir) { // Exchange the rank and norm inside the warp. RankNorm other; @@ -160,8 +169,8 @@ inline __device__ uint32_t bfe(uint32_t a, uint32_t start, uint32_t len = 1) __global__ void finalized(gatherTreeParam param) { - const int beamIdx = static_cast(threadIdx.x); - const int beamWidth{param.beamWidth}; + int const beamIdx = static_cast(threadIdx.x); + int const beamWidth{param.beamWidth}; extern __shared__ char array[]; int* sRank = (int*) (array); @@ -172,8 +181,8 @@ __global__ void finalized(gatherTreeParam param) if (beamIdx < beamWidth) { - const int idx = blockIdx.x * param.beamWidth + beamIdx; - const int numGeneratedToken{param.sequenceLengths[idx] - param.inputLengths[idx]}; + int const idx = blockIdx.x * param.beamWidth + beamIdx; + int const numGeneratedToken{param.sequenceLengths[idx] - param.inputLengths[idx]}; sNormedScores[beamIdx] = applyLengthPenalty(param.cumLogProbs[idx], numGeneratedToken, param.lengthPenalty); sLength[beamIdx] = param.sequenceLengths[idx]; sScores[beamIdx] = param.cumLogProbs[idx]; @@ -284,8 +293,8 @@ void invokeGatherTree(gatherTreeParam param) } __global__ void finalize(int* outputIds, int* sequenceLengths, float* cumLogProbs, float* outputLogProbs, - const int* topKOutputIds, const int* topKSequenceLengths, const float* scores, const float* topKCumLogProbs, - const float* topKLogProbs, const int* numBeams, const int* inputLengths, const int beamWidth, const int maxSeqLen) + int const* topKOutputIds, int const* topKSequenceLengths, float const* scores, float const* topKCumLogProbs, + float const* topKLogProbs, int const* numBeams, int const* inputLengths, int const beamWidth, int const maxSeqLen) { // outputIds: [bs, beamWidth, maxSeqLen] // sequenceLengths: [bs, beamWidth] @@ -306,7 +315,7 @@ __global__ void finalize(int* outputIds, int* sequenceLengths, float* cumLogProb int* sRank = (int*) (array); // [beamWidth] float* sScores = (float*) (sRank + beamWidth); // [2 * beamWidth] int* sSequenceLengths = (int*) (sScores + beamWidth * 2); // [beamWidth] - const int numBeam = numBeams[blockIdx.x]; + int const numBeam = numBeams[blockIdx.x]; if (threadIdx.x < numBeam) { sScores[threadIdx.x] = scores[blockIdx.x * beamWidth * 2 + threadIdx.x]; @@ -315,7 +324,7 @@ __global__ void finalize(int* outputIds, 
int* sequenceLengths, float* cumLogProb if (numBeam < 32) { - const int beamIdx = threadIdx.x; + int const beamIdx = threadIdx.x; RankNorm rankNorm; rankNorm.rank = beamIdx; rankNorm.norm = beamIdx < numBeam ? sScores[beamIdx] : -FLT_MAX; @@ -422,9 +431,9 @@ __global__ void finalize(int* outputIds, int* sequenceLengths, float* cumLogProb } void invokeFinalize(int* outputIds, int* sequenceLengths, float* cumLogProbs, float* outputLogProbs, - const int* topKOutputIds, const int* topKSequenceLengths, const float* scores, const float* topKCumLogProbs, - const float* topKLogProbs, const int* numBeams, const int* inputLengths, const int beamWidth, const int maxSeqLen, - const int batchSize, cudaStream_t stream) + int const* topKOutputIds, int const* topKSequenceLengths, float const* scores, float const* topKCumLogProbs, + float const* topKLogProbs, int const* numBeams, int const* inputLengths, int const beamWidth, int const maxSeqLen, + int const batchSize, cudaStream_t stream) { TLLM_LOG_DEBUG("%s %s start", __FILE__, __PRETTY_FUNCTION__); dim3 block(beamWidth * 2); @@ -435,7 +444,7 @@ void invokeFinalize(int* outputIds, int* sequenceLengths, float* cumLogProbs, fl topKLogProbs, numBeams, inputLengths, beamWidth, maxSeqLen); } -__global__ void initializeOutput(int* outputIds, const int* endIds, const int maxSeqLen) +__global__ void initializeOutput(int* outputIds, int const* endIds, int const maxSeqLen) { for (int i = threadIdx.x; i < maxSeqLen; i += blockDim.x) { @@ -443,26 +452,26 @@ __global__ void initializeOutput(int* outputIds, const int* endIds, const int ma } } -void invokeInitializeOutput(int* outputIds, const int* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream) +void invokeInitializeOutput(int* outputIds, int const* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream) { initializeOutput<<>>(outputIds, endIds, maxSeqLen); } -__global__ void copyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequenceLengths, const int* batchSlots, +__global__ void copyNextStepIds(int* nextStepIds, int** outputIdsPtr, int const* sequenceLengths, int const* batchSlots, int batchSize, int beamWidth, int maxSeqLen) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batchSize * beamWidth; index += blockDim.x * gridDim.x) { - const int batchIdx{index / beamWidth}; + int const batchIdx{index / beamWidth}; auto const batchSlot = batchSlots != nullptr ? 
batchSlots[batchIdx] : batchIdx; - const int beamIdx{index % beamWidth}; + int const beamIdx{index % beamWidth}; auto const batchBeamIdx = batchSlot * beamWidth + beamIdx; nextStepIds[batchBeamIdx] = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + sequenceLengths[batchBeamIdx] - 1]; } } -void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequenceLengths, const int* batchSlots, +void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, int const* sequenceLengths, int const* batchSlots, int batchSize, int beamWidth, int maxSeqLen, cudaStream_t stream) { dim3 block(min(256, batchSize * beamWidth)); @@ -471,8 +480,8 @@ void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequ nextStepIds, outputIdsPtr, sequenceLengths, batchSlots, batchSize, beamWidth, maxSeqLen); } -__global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, const int* sequenceLengths, - const int* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen) +__global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, int const* sequenceLengths, + int const* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen) { int index = blockIdx.x * blockDim.x + threadIdx.x; @@ -494,8 +503,8 @@ __global__ void transposeLogProbs(float* outputLogProbs, float* outputLogProbsTi } } -void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, const int* sequenceLengths, - const int* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen, cudaStream_t stream) +void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled, int const* sequenceLengths, + int const* batchSlots, int batchSize, int maxBatchSize, int beamWidth, int maxSeqLen, cudaStream_t stream) { dim3 block(256); dim3 grid(divUp(batchSize * beamWidth * maxSeqLen, block.x)); @@ -707,5 +716,139 @@ template void acceptDraftTokensByLogits(half* draftLogits, half** targetLogits, int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream); +__device__ __forceinline__ int4 reduceMaxInt4(int4 const& a, int4 const& b) +{ + return a.x >= b.x ? a : b; +} + +template +__global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, + SizeType* sequenceLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, + TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize, + SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, + SizeType maxTokensPerStep) +{ + auto const batchIdx = static_cast(blockIdx.x); + auto const batchSlot = batchSlots == nullptr ? 
batchIdx : batchSlots[batchIdx]; + auto& inputLength = sequenceLengths[batchSlot]; + auto const endId = endIds[batchSlot]; + auto const maxNumDraftTokens = maxNumHeads + 1; + + int4 partialMax{-1, -1, 0, 0}; + // Go over different paths and construct implicit sequences + for (auto pathIdx = static_cast(threadIdx.x); pathIdx < maxTokensPerStep; + pathIdx += static_cast(blockDim.x)) + { + auto acceptedLength = maxNumDraftTokens; + auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxTokensPerStep, maxNumDraftTokens); + bool hasEnd = false; + // Go along the path + for (SizeType ti = 0; ti < maxNumDraftTokens; ++ti) + { + auto const tokenId = paths[pathOffset + ti]; + // Break if path terminates + if (tokenId == -1) + { + acceptedLength = ti; + break; + } + auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId; + auto const draftTokenIdx = batchSlot * maxDraftSeqLen + inputLength + tokenId; + auto const draftToken = outputIds[draftTokenIdx]; + auto const targetToken = targetIds[targetTokenIdx]; + + // Check if draft tokens are the same as target tokens + bool const accepted = draftToken == targetToken; + hasEnd = targetToken == endId; + if (!accepted || hasEnd) + { + acceptedLength = hasEnd ? ti : ti + 1; + break; + } + } + // Get longest path of the thread + if (partialMax.x < acceptedLength) + { + partialMax.x = acceptedLength; + partialMax.y = pathIdx; + partialMax.z = hasEnd; + } + } + + // Get the longest path of the block (request) + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage tempStorage; + int4 total = BlockReduce(tempStorage).Reduce(partialMax, reduceMaxInt4); + + __shared__ int4 totalShared; + if (threadIdx.x == 0) + { + totalShared = total; + } + + __syncthreads(); + + auto const acceptedLength = totalShared.x; + auto const bestPathIdx = totalShared.y; + auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens); + for (auto ti = static_cast(threadIdx.x); ti < acceptedLength; ti += static_cast(blockDim.x)) + { + auto const tokenId = paths[pathOffset + ti]; + auto const targetSrcTokenIdx = batchSlot * maxTargetSeqLen + tokenId; + auto const draftDstTokenIdx = batchSlot * maxDraftSeqLen + inputLength + ti; + auto const targetToken = targetIds[targetSrcTokenIdx]; + // Copy accepted tokens to the sequence with draft tokens (outputIds === outputIds) + outputIds[draftDstTokenIdx] = targetToken; + } + + __syncthreads(); + + // Leading thread reconstructs winning path and sets new data + if (threadIdx.x == 0) + { + auto const hasEnd = totalShared.z; + // Set end condition + if (hasEnd) + { + finishedFinal[batchSlot].setFinishedEOS(); + } + // Make correction to the sequence length + inputLength += acceptedLength; + } + + // Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel + for (auto hi = static_cast(threadIdx.x); hi < maxNumHeads; hi += static_cast(blockDim.x)) + { + logitsPtrs[batchIdx * maxNumHeads + hi] + = medusaLogits + flat_index4(hi, batchIdx, acceptedLength, 0, maxBatchSize, maxTokensPerStep, vocabSize); + } +} + +template +void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, SizeType* sequenceLengths, + FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, + T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, + SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, 
SizeType maxTokensPerStep, + cudaStream_t stream) +{ + constexpr SizeType BLOCK_SIZE = 256; + dim3 block(BLOCK_SIZE); + dim3 grid(batchSize); + acceptDraftTokensByIdsWithPaths<<>>(outputIds, targetIds, sequenceLengths, + finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs, batchSize, vocabSize, maxBatchSize, + maxDraftSeqLen, maxTargetSeqLen, maxNumHeads, maxTokensPerStep); +} + +template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, + SizeType* sequenceLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, + TokenIdType const* endIds, float const* medusaLogits, float const** logitsPtrs, SizeType batchSize, + SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, + SizeType maxTokensPerStep, cudaStream_t stream); +template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, + SizeType* sequenceLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths, + TokenIdType const* endIds, half const* medusaLogits, half const** logitsPtrs, SizeType batchSize, + SizeType vocabSize, SizeType maxBatchSize, int32_t maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads, + SizeType maxTokensPerStep, cudaStream_t stream); + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/decodingKernels.h b/cpp/tensorrt_llm/kernels/decodingKernels.h index e2abc4630..ba2a891ca 100644 --- a/cpp/tensorrt_llm/kernels/decodingKernels.h +++ b/cpp/tensorrt_llm/kernels/decodingKernels.h @@ -18,6 +18,7 @@ #include "gptKernels.h" #include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/runtime/common.h" #include #include #include @@ -61,7 +62,7 @@ void invokeFinalize(int32_t* outputIds, int32_t* sequenceLengths, float* cumLogP int32_t maxSeqLen, int32_t batchSize, cudaStream_t stream); void invokeInitializeOutput( - int32_t* outputIds, const int32_t* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream); + int32_t* outputIds, int32_t const* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream); void invokeCopyNextStepIds(int32_t* nextStepIds, int32_t** outputIdsPtr, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream); @@ -133,10 +134,41 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti int32_t const* batchSlots, int32_t batch_size, int32_t max_batch_size, int32_t beam_width, int32_t max_seq_len, cudaStream_t stream); -void invokeAcceptTokens(int32_t const* draft_tokens, int32_t const* target_tokens, int32_t const* context_lengths, - int32_t const* nums_draft_tokens, int32_t* sequence_lengths, bool const* finished, bool* finished_final, - int32_t* finished_sum, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, int32_t max_draft_tokens, - cudaStream_t stream); - +//! \brief verifies draft medusa tokens given target tokens. Modifies outputIds tensor accordingly filling it with +//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according +//! to the next after the last accepted token. +//! +//! \param outputIds input/output buffer [maxBatchSize, maxDraftSeqLen], +//! input tokens followed by draft tokens to be verified. +//! After accepting tokens, gets overwritten such that input tokens are followed by the accepted tokens +//! 
and one additional token -- next after the last accepted. +//! \param targetIds input buffer [maxBatchSize, maxTargetSeqLen], tokens predicted from the target medusa head +//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens +//! Incrememnted according to the accepted length +//! \param finishedFinal input buffer [maxBatchSize], finished states per request +//! \param batchSlots input buffer [batchSize], address map from local index +//! to global index [0, batchSize] -> [0, maxBatchSize] +//! \param paths input buffer [maxBatchSize, maxTokensPerStep, maxNumHeads+1], +//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path. +//! \param endIds input buffer [maxBatchSize], EOS ids per request +//! \param medusaLogits input buffer [maxNumHeads, maxBatchSize, maxTokensPerStep, vocabSize], pointer +//! to the logits from medusa heads +//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the +//! respective rows of the medusaLogits for the next after the accepted token +//! \param batchSize current batch size +//! \param maxBatchSize maximum batch size +//! \param vocabSize vocab size +//! \param maxDraftSeqLen maximum sequence length of the sequence containing draft tokens +//! \param maxTargetSeqLen maximum sequence length predicted from target head +//! \param maxNumHeads maximum number of medusa heads +//! \param maxTokensPerStep maximum number of tokens per step configured in the system +//! \param stream stream +template +void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* targetIds, + runtime::SizeType* sequenceLengths, FinishedState* finishedFinal, runtime::SizeType const* batchSlots, + runtime::SizeType const* paths, runtime::TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, + runtime::SizeType batchSize, runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, + runtime::SizeType maxDraftSeqLen, runtime::SizeType maxTargetSeqLen, runtime::SizeType maxNumHeads, + runtime::SizeType maxTokensPerStep, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/gptKernels.cu b/cpp/tensorrt_llm/kernels/gptKernels.cu index 78f62822b..02fd6abd2 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.cu +++ b/cpp/tensorrt_llm/kernels/gptKernels.cu @@ -58,7 +58,7 @@ struct BlockPrefixCallbackOp template __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqOffsets( - int* seqOffsets, const int* seqLengths, int batchSize) + int* seqOffsets, int const* seqLengths, int batchSize) { // The implementation of the parallel scan in the thread block (see CUB for details). using BlockScan = cub::BlockScan; @@ -108,7 +108,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqOffsets( // // That kernel uses a grid of batchSize blocks. -__global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets, int maxSeqLength) +__global__ void computePaddingOffsets(int* paddingOffsets, int const* seqOffsets, int maxSeqLength) { // The index of the sequence in the batch. int batchIdx = blockIdx.x; @@ -133,7 +133,7 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets // This kernel computes the attention mask. We must compute this on-the-fly in the future. 
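The two hunks above touch computeSeqOffsets and computePaddingOffsets, which turn per-sequence lengths into the offsets used to address packed (padding-free) token buffers. As a rough host-side reference of that bookkeeping (the helper name is made up, and reading paddingOffsets as "padding slots accumulated before each real token" is an assumption, not taken from the kernels):

#include <vector>

// Sketch only: seqOffsets is the exclusive prefix sum of seqLengths; paddingOffsets
// is assumed to hold, for every real token, how much padding precedes it in the
// padded [batchSize, maxSeqLength] layout.
void buildDecoderInfoReference(std::vector<int> const& seqLengths, int maxSeqLength,
    std::vector<int>& seqOffsets, std::vector<int>& paddingOffsets)
{
    int const batchSize = static_cast<int>(seqLengths.size());
    seqOffsets.assign(batchSize + 1, 0);
    for (int b = 0; b < batchSize; ++b)
    {
        seqOffsets[b + 1] = seqOffsets[b] + seqLengths[b]; // exclusive prefix sum of lengths
    }

    paddingOffsets.clear();
    for (int b = 0; b < batchSize; ++b)
    {
        int const cumPadding = b * maxSeqLength - seqOffsets[b]; // padding from earlier sequences
        for (int t = 0; t < seqLengths[b]; ++t)
        {
            paddingOffsets.push_back(cumPadding);
        }
    }
}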
template -__global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const int* seqOffsets, int maxSeqLength, +__global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, int const* seqOffsets, int maxSeqLength, int attentionWindowSize, AttentionMaskType attentionMaskType) { // The index of the sequence in the batch. @@ -221,10 +221,10 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const } template -void invokeBuildDecoderInfo(const BuildDecoderInfoParams& params, cudaStream_t stream) +void invokeBuildDecoderInfo(BuildDecoderInfoParams const& params, cudaStream_t stream) { // Compute the sequence offsets. - const int THREADS_PER_BLOCK = 256; + int const THREADS_PER_BLOCK = 256; computeSeqOffsets <<<1, THREADS_PER_BLOCK, 0, stream>>>(params.seqQOffsets, params.seqQLengths, params.batchSize); if (params.seqKVLengths) @@ -240,7 +240,7 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams& params, cudaStream_ // Compute the attention mask, if needed. if (params.attentionMask != nullptr) { - const int MIN_BLOCKS = 512; + int const MIN_BLOCKS = 512; int blocksPerSeq = 16; while (blocksPerSeq * params.batchSize < MIN_BLOCKS) { @@ -252,16 +252,16 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams& params, cudaStream_ } } -template void invokeBuildDecoderInfo(const BuildDecoderInfoParams&, cudaStream_t); -template void invokeBuildDecoderInfo(const BuildDecoderInfoParams&, cudaStream_t); +template void invokeBuildDecoderInfo(BuildDecoderInfoParams const&, cudaStream_t); +template void invokeBuildDecoderInfo(BuildDecoderInfoParams const&, cudaStream_t); #ifdef ENABLE_BF16 -template void invokeBuildDecoderInfo(const BuildDecoderInfoParams<__nv_bfloat16>&, cudaStream_t); +template void invokeBuildDecoderInfo(BuildDecoderInfoParams<__nv_bfloat16> const&, cudaStream_t); #endif #ifdef ENABLE_FP8 -template void invokeBuildDecoderInfo(const BuildDecoderInfoParams<__nv_fp8_e4m3>&, cudaStream_t); +template void invokeBuildDecoderInfo(BuildDecoderInfoParams<__nv_fp8_e4m3> const&, cudaStream_t); #endif -__global__ void updatePaddingCountKernel(int* paddingPerSeq, const int* seqLengths, int maxSeqLength, int batchSize) +__global__ void updatePaddingCountKernel(int* paddingPerSeq, int const* seqLengths, int maxSeqLength, int batchSize) { for (int ii = threadIdx.x; ii < batchSize; ii += blockDim.x) diff --git a/cpp/tensorrt_llm/kernels/gptKernels.h b/cpp/tensorrt_llm/kernels/gptKernels.h index a8c410993..e2a9a197d 100644 --- a/cpp/tensorrt_llm/kernels/gptKernels.h +++ b/cpp/tensorrt_llm/kernels/gptKernels.h @@ -72,9 +72,9 @@ struct BuildDecoderInfoParams AttentionMaskDataType* attentionMask; // The Q length of each sequence in the batch. Shape: [batchSize]. - const int* seqQLengths; + int const* seqQLengths; // The KV length of each sequence in the batch. Shape: [batchSize]. - const int* seqKVLengths; + int const* seqKVLengths; // The number of sequences in the batch. 
int batchSize; @@ -92,7 +92,7 @@ struct BuildDecoderInfoParams }; template -void invokeBuildDecoderInfo(const BuildDecoderInfoParams& params, cudaStream_t stream); +void invokeBuildDecoderInfo(BuildDecoderInfoParams const& params, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/groupGemm.cu b/cpp/tensorrt_llm/kernels/groupGemm.cu index a37424386..86656f779 100644 --- a/cpp/tensorrt_llm/kernels/groupGemm.cu +++ b/cpp/tensorrt_llm/kernels/groupGemm.cu @@ -72,8 +72,8 @@ void groupedGemm_(std::vector problem_sizes, std::vect using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; - const int kAlignmentA = 8; - const int kAlignmentB = 8; + int const kAlignmentA = 8; + int const kAlignmentB = 8; int problem_count = problem_sizes.size(); diff --git a/cpp/tensorrt_llm/kernels/kvCacheUtils.h b/cpp/tensorrt_llm/kernels/kvCacheUtils.h index 5c54bc6e5..8acde13a8 100644 --- a/cpp/tensorrt_llm/kernels/kvCacheUtils.h +++ b/cpp/tensorrt_llm/kernels/kvCacheUtils.h @@ -77,7 +77,7 @@ struct KVBlockArray , mSinkTokens(sinkTokenLen) , data(nullptr) { - const float tokensPerBlockSeqLog2 = log2(mTokensPerBlock); + float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock); TLLM_CHECK_WITH_INFO( ceil(tokensPerBlockSeqLog2) == floor(tokensPerBlockSeqLog2), "tokensPerBlock must be power of 2"); // NOTE: pointer offset arithmetic offset is performed on int32_t (see this.getRowPtr). @@ -183,7 +183,7 @@ struct KVBlockArrayForContextFMHA , mTokensPerBlock(tokensPerBlock) , data(nullptr) { - const float tokensPerBlockSeqLog2 = log2(mTokensPerBlock); + float const tokensPerBlockSeqLog2 = log2(mTokensPerBlock); TLLM_CHECK_WITH_INFO( ceil(tokensPerBlockSeqLog2) == floor(tokensPerBlockSeqLog2), "tokensPerBlock must be power of 2"); // NOTE: pointer offset arithmetic offset is performed on int32_t (see this.getRowPtr). diff --git a/cpp/tensorrt_llm/kernels/layernormKernels.cu b/cpp/tensorrt_llm/kernels/layernormKernels.cu index 8145f5073..f10858709 100644 --- a/cpp/tensorrt_llm/kernels/layernormKernels.cu +++ b/cpp/tensorrt_llm/kernels/layernormKernels.cu @@ -26,7 +26,7 @@ namespace kernels { template -__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, const T* gamma, const T* beta, int i) +__inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_variance, T const* gamma, T const* beta, int i) { Tf ret = (val - s_mean) * s_variance * cuda_cast(gamma[i]); if (beta != nullptr) @@ -58,8 +58,8 @@ __inline__ __device__ Tf compute_layernorm(Tf val, float s_mean, float s_varianc * normed_output_quant. 
*/ template -__global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, - int tokens, int hidden_dim, const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, +__global__ void generalLayerNorm(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps, + int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, int8_t* normed_output_quant, bool use_shmem) { constexpr auto num_elems_T = num_elems::value; @@ -72,15 +72,15 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, __shared__ float s_mean; __shared__ float s_variance; - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; + int const tidx = threadIdx.x; + int const bidx = blockIdx.x; float mean = 0.0f; float variance = 0.0f; float local_sum = 0.0f; float local_var_sum = 0.0f; - const int n_elems = hidden_dim / num_elems_T; + int const n_elems = hidden_dim / num_elems_T; for (int i = tidx; i < n_elems; i += blockDim.x) { const T val = input[bidx * n_elems + i]; @@ -138,15 +138,15 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, __syncthreads(); } - const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; - const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; + bool const with_per_token_scaling = scale_orig_quant_per_token != nullptr; + bool const with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; const float_packed_t scale_orig_quant = cuda_cast(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f); T_scalar amax = 1e-6f; for (int i = tidx; i < n_elems; i += blockDim.x) { - const int index = bidx * n_elems + i; + int const index = bidx * n_elems + i; const float_packed_t val_f = cuda_cast(use_shmem ? shmem[i] : input[index]); const T val = cuda_cast(compute_layernorm(val_f, s_mean, s_variance, gamma, beta, i)); @@ -172,10 +172,10 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, if (with_per_token_scaling) { float abs_max_f = blockAllReduceMax(cuda_cast(amax)); - const float dynamic_per_token_scale = 127.f / abs_max_f; + float const dynamic_per_token_scale = 127.f / abs_max_f; for (int i = tidx; i < n_elems; i += blockDim.x) { - const int index = bidx * n_elems + i; + int const index = bidx * n_elems + i; float_packed_t val_f = cuda_cast(use_shmem ? 
shmem[i] : input[index]); if (!use_shmem) { @@ -193,8 +193,8 @@ __global__ void generalLayerNorm(const T* input, const T* gamma, const T* beta, } template -void dispatch_layernorm_type_square_method(const T* input, const T* gamma, const T* beta, T* normed_output, - const float eps, int tokens, int hidden_dim, const float* scale_orig_quant_per_tensor, +void dispatch_layernorm_type_square_method(T const* input, T const* gamma, T const* beta, T* normed_output, + float const eps, int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream) { @@ -208,8 +208,8 @@ void dispatch_layernorm_type_square_method(const T* input, const T* gamma, const } template -void dispatch_layernorm_type(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, - int tokens, int hidden_dim, const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, +void dispatch_layernorm_type(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps, + int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream, bool use_diff_of_squares) { @@ -228,8 +228,8 @@ void dispatch_layernorm_type(const T* input, const T* gamma, const T* beta, T* n } template -void invokeGeneralLayerNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, const int tokens, - const int hidden_dim, cudaStream_t stream, bool use_diff_of_squares, const float* scale, float* dynamic_scale, +void invokeGeneralLayerNorm(T* out, T const* input, T const* gamma, T const* beta, float const eps, int const tokens, + int const hidden_dim, cudaStream_t stream, bool use_diff_of_squares, float const* scale, float* dynamic_scale, int8_t* normed_output_quant) { dim3 grid(tokens); @@ -239,7 +239,7 @@ void invokeGeneralLayerNorm(T* out, const T* input, const T* gamma, const T* bet constexpr size_t vec_size = 2; const size_t shmem_size = hidden_dim * sizeof(T); - const bool use_vec_type = (hidden_dim % vec_size == 0) + bool const use_vec_type = (hidden_dim % vec_size == 0) && (std::is_same::value #ifdef ENABLE_BF16 || std::is_same::value @@ -249,8 +249,8 @@ void invokeGeneralLayerNorm(T* out, const T* input, const T* gamma, const T* bet if (use_vec_type) { using Tp = typename packed_as::type; - dispatch_layernorm_type(reinterpret_cast(input), reinterpret_cast(gamma), - reinterpret_cast(beta), reinterpret_cast(out), eps, tokens, hidden_dim, scale, + dispatch_layernorm_type(reinterpret_cast(input), reinterpret_cast(gamma), + reinterpret_cast(beta), reinterpret_cast(out), eps, tokens, hidden_dim, scale, dynamic_scale, normed_output_quant, grid, block, shmem_size, stream, use_diff_of_squares); } else diff --git a/cpp/tensorrt_llm/kernels/layernormKernels.h b/cpp/tensorrt_llm/kernels/layernormKernels.h index fd6715d19..20ba9d1cd 100644 --- a/cpp/tensorrt_llm/kernels/layernormKernels.h +++ b/cpp/tensorrt_llm/kernels/layernormKernels.h @@ -27,8 +27,8 @@ namespace kernels { template -void invokeGeneralLayerNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, const int tokens, - const int hidden_dim, cudaStream_t stream = 0, bool use_diff_of_squares = true, const float* scale = nullptr, +void invokeGeneralLayerNorm(T* out, T const* input, T const* gamma, T const* beta, float const 
eps, int const tokens, + int const hidden_dim, cudaStream_t stream = 0, bool use_diff_of_squares = true, float const* scale = nullptr, float* dynamic_scale = nullptr, int8_t* out_quant = nullptr); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/lookupKernels.cu b/cpp/tensorrt_llm/kernels/lookupKernels.cu index ccf1dde3a..8e3e36e1d 100644 --- a/cpp/tensorrt_llm/kernels/lookupKernels.cu +++ b/cpp/tensorrt_llm/kernels/lookupKernels.cu @@ -39,14 +39,14 @@ input IDs and get the correct results. * If the input ids is out of range it writes zero, otherwise it writes the correct embedding result. */ template -__global__ void lookup_kernel(T* output, const Idx* input, const T* weight, const Idx batch_size, const Idx offset, - const Idx size, const int n_embed) +__global__ void lookup_kernel(T* output, Idx const* input, T const* weight, const Idx batch_size, const Idx offset, + const Idx size, int const n_embed) { for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * n_embed; index += blockDim.x * gridDim.x) { - const int word_index = input[index / n_embed] - offset; - const int col_index = index % n_embed; + int const word_index = input[index / n_embed] - offset; + int const col_index = index % n_embed; T embedding; if (word_index < 0 || word_index >= size) { @@ -61,8 +61,8 @@ __global__ void lookup_kernel(T* output, const Idx* input, const T* weight, cons } template -void invokeLookUp(T* out, const Idx* input, const T* weight, const Idx batch_size, const Idx offset, const Idx size, - const int n_embed, cudaStream_t stream) +void invokeLookUp(T* out, Idx const* input, T const* weight, const Idx batch_size, const Idx offset, const Idx size, + int const n_embed, cudaStream_t stream) { dim3 grid(min(batch_size, 65536)); dim3 block(min(n_embed, 512)); diff --git a/cpp/tensorrt_llm/kernels/lookupKernels.h b/cpp/tensorrt_llm/kernels/lookupKernels.h index 9fe246f8d..e8c4b5fbb 100644 --- a/cpp/tensorrt_llm/kernels/lookupKernels.h +++ b/cpp/tensorrt_llm/kernels/lookupKernels.h @@ -26,8 +26,8 @@ namespace tensorrt_llm namespace kernels { template -void invokeLookUp(T* out, const Idx* input, const T* weight, const Idx batch_size, const Idx offset, const Idx size, - const int n_embed, cudaStream_t stream = 0); +void invokeLookUp(T* out, Idx const* input, T const* weight, const Idx batch_size, const Idx offset, const Idx size, + int const n_embed, cudaStream_t stream = 0); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu index c42fd5370..84bdf2f91 100644 --- a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu @@ -61,7 +61,7 @@ static constexpr int WARP_SIZE = 32; // in the softmax kernel when we extend this module to support expert-choice routing. 
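The moeSoftmax kernel that follows computes a numerically stable softmax over each row of expert logits: a block-wide max reduction, then a block-wide sum of exponentials, then a multiply by the reciprocal of that sum (the kernel's normalizing_factor). A minimal host-side sketch of the same math, with an illustrative function name:

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Reference for one row: subtract the row max for stability, exponentiate,
// then scale by the reciprocal of the sum.
std::vector<float> softmaxRowReference(std::vector<float> const& logits)
{
    float maxVal = -std::numeric_limits<float>::infinity();
    for (float x : logits)
    {
        maxVal = std::max(maxVal, x);
    }

    float sum = 0.f;
    std::vector<float> out(logits.size());
    for (size_t i = 0; i < logits.size(); ++i)
    {
        out[i] = std::exp(logits[i] - maxVal);
        sum += out[i];
    }
    float const normalizingFactor = 1.f / sum;
    for (float& x : out)
    {
        x *= normalizingFactor;
    }
    return out;
}

moeTopK then repeatedly arg-maxes such a row to pick the k routed experts, skipping experts that already won a previous iteration.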
template __launch_bounds__(TPB) __global__ - void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) + void moeSoftmax(float const* input, bool const* finished, float* output, int const num_cols) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmpStorage; @@ -69,7 +69,7 @@ __launch_bounds__(TPB) __global__ __shared__ float normalizing_factor; __shared__ float float_max; - const int thread_row_offset = blockIdx.x * num_cols; + int const thread_row_offset = blockIdx.x * num_cols; cub::Sum sum; float threadData(-FLT_MAX); @@ -82,11 +82,11 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; + int const idx = thread_row_offset + ii; threadData = max(input[idx], threadData); } - const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + float const maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); if (threadIdx.x == 0) { float_max = maxElem; @@ -97,11 +97,11 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; + int const idx = thread_row_offset + ii; threadData += exp((static_cast(input[idx]) - float_max)); } - const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + auto const Z = BlockReduce(tmpStorage).Reduce(threadData, sum); if (threadIdx.x == 0) { @@ -111,15 +111,15 @@ __launch_bounds__(TPB) __global__ for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { - const int idx = thread_row_offset + ii; - const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + int const idx = thread_row_offset + ii; + float const val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; output[idx] = val; } } template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +__launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax, bool const* finished, float* output, + int* indices, int* source_rows, int const num_experts, int const k, int const start_expert, int const end_expert) { using cub_kvp = cub::KeyValuePair; @@ -129,11 +129,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax cub_kvp thread_kvp; cub::ArgMax arg_max; - const int num_rows = gridDim.x; - const int block_row = blockIdx.x; + int const num_rows = gridDim.x; + int const block_row = blockIdx.x; - const bool row_is_active = finished ? !finished[block_row] : true; - const int thread_read_offset = blockIdx.x * num_experts; + bool const row_is_active = finished ? 
!finished[block_row] : true; + int const thread_read_offset = blockIdx.x * num_experts; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -142,13 +142,13 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax cub_kvp inp_kvp; for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_read_offset + expert; + int const idx = thread_read_offset + expert; inp_kvp.key = expert; inp_kvp.value = inputs_after_softmax[idx]; for (int prior_k = 0; prior_k < k_idx; ++prior_k) { - const int prior_winning_expert = indices[k * block_row + prior_k]; + int const prior_winning_expert = indices[k * block_row + prior_k]; if (prior_winning_expert == expert) { @@ -163,11 +163,11 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax if (threadIdx.x == 0) { // Ignore experts the node isn't responsible for with expert parallelism - const int expert = result_kvp.key; - const bool node_uses_expert = expert >= start_expert && expert < end_expert; - const bool should_process_row = row_is_active && node_uses_expert; + int const expert = result_kvp.key; + bool const node_uses_expert = expert >= start_expert && expert < end_expert; + bool const should_process_row = row_is_active && node_uses_expert; - const int idx = k * block_row + k_idx; + int const idx = k * block_row + k_idx; output[idx] = result_kvp.value; indices[idx] = should_process_row ? (expert - start_expert) : num_experts; assert(indices[idx] >= 0); @@ -193,8 +193,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, - int* source_rows, const int k, const int start_expert, const int end_expert) + void topkGatingSoftmax(float const* input, bool const* finished, float* output, int const num_rows, int* indices, + int* source_rows, int const k, int const start_expert, int const end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); @@ -226,31 +226,31 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. // This, each block processes a chunk of rows. We start by computing the start row for each block. - const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + int const cta_base_row = blockIdx.x * ROWS_PER_CTA; // Now, using the base row per thread block, we compute the base row per warp. - const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + int const warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; // The threads in a warp are split into sub-groups that will work on a row. // We compute row offset for each thread sub-group - const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; - const int thread_row = warp_base_row + thread_row_in_warp; + int const thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + int const thread_row = warp_base_row + thread_row_in_warp; // Threads with indices out of bounds should early exit here. if (thread_row >= num_rows) { return; } - const bool row_is_active = finished ? !finished[thread_row] : true; + bool const row_is_active = finished ? !finished[thread_row] : true; // We finally start setting up the read pointers for each thread. 
First, each thread jumps to the start of the // row it will read. - const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + float const* thread_row_ptr = input + thread_row * ELTS_PER_ROW; // Now, we compute the group each thread belong to in order to determine the first column to start loads. - const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; - const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; - const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + int const thread_group_idx = threadIdx.x % THREADS_PER_ROW; + int const first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + float const* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, // this can support all powers of 2 up to 16. @@ -259,7 +259,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // Finally, we pull in the data from global mem cutlass::Array row_chunk; AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); - const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); + AccessType const* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); #pragma unroll for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { @@ -304,7 +304,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the // argmax after computing the softmax. - const float reciprocal_row_sum = 1.f / row_sum; + float const reciprocal_row_sum = 1.f / row_sum; #pragma unroll for (int ii = 0; ii < VPT; ++ii) @@ -361,12 +361,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ if (thread_group_idx == 0) { // Add a guard to ignore experts not included by this node - const bool node_uses_expert = expert >= start_expert && expert < end_expert; - const bool should_process_row = row_is_active && node_uses_expert; + bool const node_uses_expert = expert >= start_expert && expert < end_expert; + bool const should_process_row = row_is_active && node_uses_expert; // The lead thread from each sub-group will write out the final results to global memory. (This will be a // single) thread per row of the input/output matrices. - const int idx = k * thread_row + k_idx; + int const idx = k * thread_row + k_idx; output[idx] = max_val; indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; source_rows[idx] = k_idx * num_rows + thread_row; @@ -375,13 +375,13 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // Finally, we clear the value in the thread with the current max if there is another iteration to run. if (k_idx + 1 < k) { - const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; - const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + int const ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + int const thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; // Only the thread in the group which produced the max will reset the "winning" value to -inf. if (thread_group_idx == thread_to_clear_in_group) { - const int offset_for_expert = expert % ELTS_PER_LDG; + int const offset_for_expert = expert % ELTS_PER_LDG; // Safe to set to any negative value since row_chunk values must be between 0 and 1. 
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; } @@ -405,8 +405,8 @@ struct TopkConstants } // namespace detail template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, - int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +void topkGatingSoftmaxLauncherHelper(float const* input, bool const* finished, float* output, int* indices, + int* source_row, int const num_rows, int const k, int const start_expert, int const end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; @@ -414,17 +414,17 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; - const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + int const num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + int const num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE, WARPS_PER_TB); topkGatingSoftmax<<>>( input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); } -void topkGatingSoftmaxKernelLauncher(const float* input, const bool* finished, float* output, - float* softmax_temp_output, int* indices, int* source_row, const int num_rows, const int num_experts, const int k, - const int start_expert, const int end_expert, cudaStream_t stream) +void topkGatingSoftmaxKernelLauncher(float const* input, bool const* finished, float* output, + float* softmax_temp_output, int* indices, int* source_row, int const num_rows, int const num_experts, int const k, + int const start_expert, int const end_expert, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; @@ -502,19 +502,19 @@ CubKeyValueSorter::CubKeyValueSorter() { } -CubKeyValueSorter::CubKeyValueSorter(const int num_experts) +CubKeyValueSorter::CubKeyValueSorter(int const num_experts) : num_experts_(num_experts) , num_bits_((int) log2(num_experts) + 1) { } -void CubKeyValueSorter::updateNumExperts(const int num_experts) +void CubKeyValueSorter::updateNumExperts(int const num_experts) { num_experts_ = num_experts; num_bits_ = (int) log2(num_experts) + 1; } -size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, const int num_experts) +size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, int const num_experts) { size_t num_bits = (int) log2(num_experts) + 1; size_t required_storage = 0; @@ -524,8 +524,8 @@ size_t CubKeyValueSorter::getWorkspaceSize(const size_t num_key_value_pairs, con return required_storage; } -void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, - const int* values_in, int* values_out, const size_t num_key_value_pairs, cudaStream_t stream) +void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, int const* keys_in, int* keys_out, + int const* values_in, int* values_out, const size_t num_key_value_pairs, cudaStream_t stream) { size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); size_t actual_ws_size = workspace_size; @@ -538,7 +538,7 @@ void CubKeyValueSorter::run(void* workspace, const size_t workspace_size, const // ============================== Infer GEMM sizes ================================= // TODO 
Could linear search be better for small # experts -__device__ inline int findTotalEltsLeqTarget(const int* sorted_indices, const int arr_length, const int target) +__device__ inline int findTotalEltsLeqTarget(int const* sorted_indices, int const arr_length, int const target) { int64_t low = 0, high = arr_length - 1, target_location = -1; while (low <= high) @@ -563,11 +563,11 @@ __device__ inline int findTotalEltsLeqTarget(const int* sorted_indices, const in // // "total_rows_before_expert" contains the index one past the last occurrence of the corresponding expert. // e.g. Index 0 is the start offset of expert 1, the final entry is the total number of active rows -__global__ void computeTotalRowsBeforeExpertKernel(const int* sorted_experts, const int sorted_experts_len, +__global__ void computeTotalRowsBeforeExpertKernel(int const* sorted_experts, int const sorted_experts_len, const int64_t num_experts, int64_t* total_rows_before_expert) { // First, compute the global tid. We only need 1 thread per expert. - const int expert = blockIdx.x * blockDim.x + threadIdx.x; + int const expert = blockIdx.x * blockDim.x + threadIdx.x; if (expert >= num_experts) { return; @@ -591,17 +591,17 @@ __global__ void computeTotalRowsBeforeExpertKernel(const int* sorted_experts, co // of the expanded index. template -__global__ void expandInputRowsKernel(const T* unpermuted_input, T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, - const int num_rows, const int64_t* num_dest_rows, const int cols) +__global__ void expandInputRowsKernel(T const* unpermuted_input, T* permuted_output, + int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, + int const num_rows, int64_t const* num_dest_rows, int const cols) { // Reverse permutation map. // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 // thread block will be responsible for all k summations. 
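The expandInputRowsKernel body that follows duplicates every input row k times into expert-sorted order and records the inverse (source -> dest) map used later for the k-way reduction. A host-side sketch of that mapping, ignoring the num_dest_rows early-exit and using illustrative names:

#include <vector>

// Sketch: each of the numRows inputs is copied k times into the permuted buffer;
// the inverse map is filled so the finalize pass can gather without atomics.
// The real kernel handles one destination row per thread block.
template <typename T>
void expandInputRowsReference(std::vector<T> const& input, std::vector<T>& permuted,
    std::vector<int> const& expandedDestToSource, std::vector<int>& expandedSourceToDest,
    int numRows, int cols)
{
    int const expandedRows = static_cast<int>(expandedDestToSource.size());
    permuted.assign(static_cast<size_t>(expandedRows) * cols, T{});
    expandedSourceToDest.assign(expandedRows, -1);

    for (int dest = 0; dest < expandedRows; ++dest)
    {
        int const expandedSource = expandedDestToSource[dest];
        expandedSourceToDest[expandedSource] = dest;    // inverse map for the finalize pass
        int const sourceRow = expandedSource % numRows; // fold the k duplicates back to the input row
        for (int c = 0; c < cols; ++c)
        {
            permuted[static_cast<size_t>(dest) * cols + c] = input[static_cast<size_t>(sourceRow) * cols + c];
        }
    }
}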
- const int expanded_dest_row = blockIdx.x; - const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + int const expanded_dest_row = blockIdx.x; + int const expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; if (threadIdx.x == 0) { expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_dest_row; @@ -610,9 +610,9 @@ __global__ void expandInputRowsKernel(const T* unpermuted_input, T* permuted_out if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { // Duplicate and permute rows - const int source_row = expanded_source_row % num_rows; + int const source_row = expanded_source_row % num_rows; - const T* source_row_ptr = unpermuted_input + source_row * cols; + T const* source_row_ptr = unpermuted_input + source_row * cols; T* dest_row_ptr = permuted_output + expanded_dest_row * cols; for (int tid = threadIdx.x; tid < cols; tid += blockDim.x) @@ -623,12 +623,12 @@ __global__ void expandInputRowsKernel(const T* unpermuted_input, T* permuted_out } template -void expandInputRowsKernelLauncher(const T* unpermuted_input, T* permuted_output, - const int* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, - const int num_rows, const int64_t* num_valid_tokens_ptr, const int cols, const int k, cudaStream_t stream) +void expandInputRowsKernelLauncher(T const* unpermuted_input, T* permuted_output, + int const* expanded_dest_row_to_expanded_source_row, int* expanded_source_row_to_expanded_dest_row, + int const num_rows, int64_t const* num_valid_tokens_ptr, int const cols, int const k, cudaStream_t stream) { - const int blocks = num_rows * k; - const int threads = std::min(cols, 1024); + int const blocks = num_rows * k; + int const threads = std::min(cols, 1024); auto func = (num_valid_tokens_ptr != nullptr) ? expandInputRowsKernel : expandInputRowsKernel; func<<>>(unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, cols); @@ -644,16 +644,16 @@ enum class ScaleMode : int // Final kernel to unpermute and scale // This kernel unpermutes the original data, does the k-way reduction and performs the final skip connection. 
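finalizeMoeRoutingKernel, defined next, inverts that expansion: for each original row it gathers its k expert outputs, scales each by its routing weight, optionally adds the expert bias, and in RENORMALIZE mode divides by the sum of the weights. A simplified single-row host sketch (skip connections and the num_valid check are left out; names are illustrative):

#include <vector>

// Sketch of one output row: k-way gather-and-reduce over the expanded buffer.
template <typename T>
void finalizeRowReference(std::vector<T> const& expandedPermutedRows, std::vector<float> const& scales,
    std::vector<int> const& expandedSourceToDest, std::vector<int> const& expertForSourceRow,
    T const* bias /* [numExperts, cols] or nullptr */, T* out, int originalRow, int numRows, int cols, int k,
    bool renormalize)
{
    float scaleSum = 0.f;
    std::vector<float> acc(cols, 0.f);
    for (int kIdx = 0; kIdx < k; ++kIdx)
    {
        int const expandedOriginalRow = originalRow + kIdx * numRows;
        int const expandedPermutedRow = expandedSourceToDest[expandedOriginalRow];
        float const rowScale = scales[static_cast<size_t>(originalRow) * k + kIdx];
        scaleSum += rowScale;
        int const expertIdx = expertForSourceRow[static_cast<size_t>(originalRow) * k + kIdx];
        for (int c = 0; c < cols; ++c)
        {
            float v = static_cast<float>(expandedPermutedRows[static_cast<size_t>(expandedPermutedRow) * cols + c]);
            if (bias != nullptr)
            {
                v += static_cast<float>(bias[static_cast<size_t>(expertIdx) * cols + c]);
            }
            acc[c] += rowScale * v;
        }
    }
    for (int c = 0; c < cols; ++c)
    {
        out[c] = static_cast<T>(renormalize ? acc[c] / scaleSum : acc[c]);
    }
}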
template -__global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, - const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, const int cols, const int k, const int64_t* num_valid_ptr) +__global__ void finalizeMoeRoutingKernel(T const* expanded_permuted_rows, T* reduced_unpermuted_output, T const* skip_1, + T const* skip_2, T const* bias, float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int const cols, int const k, int64_t const* num_valid_ptr) { - const int original_row = blockIdx.x; - const int num_rows = gridDim.x; - const auto offset = original_row * cols; + int const original_row = blockIdx.x; + int const num_rows = gridDim.x; + auto const offset = original_row * cols; T* reduced_row_ptr = reduced_unpermuted_output + offset; - const T* skip_1_row_ptr{}; - const T* skip_2_row_ptr{}; + T const* skip_1_row_ptr{}; + T const* skip_2_row_ptr{}; if (RESIDUAL_NUM >= 1) { @@ -671,11 +671,11 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red float row_rescale{0.f}; for (int k_idx = 0; k_idx < k; ++k_idx) { - const int expanded_original_row = original_row + k_idx * num_rows; - const int expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; + int const expanded_original_row = original_row + k_idx * num_rows; + int const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; const int64_t k_offset = original_row * k + k_idx; - const float row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; + float const row_scale = (SCALE_MODE == ScaleMode::NO_SCALE) ? 1.f : scales[k_offset]; if constexpr (SCALE_MODE == ScaleMode::RENORM_SCALE) { row_rescale = row_rescale + row_scale; @@ -687,11 +687,11 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red continue; } - const T* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; + T const* expanded_permuted_rows_row_ptr = expanded_permuted_rows + expanded_permuted_row * cols; - const int expert_idx = expert_for_source_row[k_offset]; + int const expert_idx = expert_for_source_row[k_offset]; - const T* bias_ptr = bias + expert_idx * cols; + T const* bias_ptr = bias + expert_idx * cols; const T bias_value = HAS_BIAS ? 
bias_ptr[tid] : T(0.f); thread_output = static_cast(thread_output) @@ -717,20 +717,20 @@ __global__ void finalizeMoeRoutingKernel(const T* expanded_permuted_rows, T* red } template -void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, T* reduced_unpermuted_output, - const T* skip_1, const T* skip_2, const T* bias, const float* scales, - const int* expanded_source_row_to_expanded_dest_row, const int* expert_for_source_row, const int num_rows, - const int cols, const int k, const int64_t* num_valid_ptr, MOEParallelismConfig parallelism_config, +void finalizeMoeRoutingKernelLauncherSelectBias(T const* expanded_permuted_rows, T* reduced_unpermuted_output, + T const* skip_1, T const* skip_2, T const* bias, float const* scales, + int const* expanded_source_row_to_expanded_dest_row, int const* expert_for_source_row, int const num_rows, + int const cols, int const k, int64_t const* num_valid_ptr, MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) { - const int blocks = num_rows; - const int threads = std::min(cols, 1024); + int const blocks = num_rows; + int const threads = std::min(cols, 1024); // Only add bias on rank 0 for tensor parallelism - const bool is_rank_0 = parallelism_config.tp_rank == 0; - const bool has_bias = bias != nullptr && is_rank_0; + bool const is_rank_0 = parallelism_config.tp_rank == 0; + bool const has_bias = bias != nullptr && is_rank_0; - const bool check_finished = num_valid_ptr != nullptr; + bool const check_finished = num_valid_ptr != nullptr; ScaleMode renorm_scales = ScaleMode::DEFAULT; if (normalization_mode == MOEExpertScaleNormalizationMode::RENORMALIZE) @@ -762,13 +762,13 @@ void finalizeMoeRoutingKernelLauncherSelectBias(const T* expanded_permuted_rows, } template -void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduced_unpermuted_output, const T* skip_1, - const T* skip_2, const T* bias, const float* scales, const int* expanded_source_row_to_expanded_dest_row, - const int* expert_for_source_row, const int num_rows, const int cols, const int k, const int64_t* num_valid_ptr, +void finalizeMoeRoutingKernelLauncher(T const* expanded_permuted_rows, T* reduced_unpermuted_output, T const* skip_1, + T const* skip_2, T const* bias, float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int const num_rows, int const cols, int const k, int64_t const* num_valid_ptr, MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) { // If we are not rank 0 we should not add any residuals because the allreduce would sum multiple copies - const bool is_rank_0 = parallelism_config.tp_rank == 0; + bool const is_rank_0 = parallelism_config.tp_rank == 0; if (skip_1 == nullptr || !is_rank_0) { assert(skip_2 == nullptr); @@ -794,10 +794,10 @@ void finalizeMoeRoutingKernelLauncher(const T* expanded_permuted_rows, T* reduce template __global__ void doGatedActivationKernel( - T* output, const T* gemm_result, const int64_t* num_valid_tokens_ptr, size_t inter_size) + T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, size_t inter_size) { - const int tid = threadIdx.x; - const int token = blockIdx.x; + int const tid = threadIdx.x; + int const token = blockIdx.x; if (num_valid_tokens_ptr && token >= *num_valid_tokens_ptr) { return; @@ -817,11 +817,11 @@ __global__ void doGatedActivationKernel( } template -void doGatedActivation(T* output, const T* 
gemm_result, const int64_t* num_valid_tokens_ptr, int inter_size, +void doGatedActivation(T* output, T const* gemm_result, int64_t const* num_valid_tokens_ptr, int inter_size, int num_tokens, ActivationType activation_type, cudaStream_t stream) { - const int blocks = num_tokens; - const int threads = std::min(inter_size, 1024); + int const blocks = num_tokens; + int const threads = std::min(inter_size, 1024); // TODO Instead of T use a vectored type if performance would benefit // TODO For some reason Volta fails on GELU_taylor here with Warp Illegal Instruction. @@ -832,8 +832,8 @@ void doGatedActivation(T* output, const T* gemm_result, const int64_t* num_valid } template -std::vector CutlassMoeFCRunner::getWorkspaceBufferSizes(const int num_rows, - const int hidden_size, const int inter_size, const int num_experts, const int num_experts_per_node, const int k, +std::vector CutlassMoeFCRunner::getWorkspaceBufferSizes(int const num_rows, + int const hidden_size, int const inter_size, int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type) const { const size_t num_moe_inputs = k * num_rows; @@ -842,7 +842,7 @@ std::vector CutlassMoeFCRunner::getWorkspaceBuffe const size_t glu_inter_elems = isGatedActivation(activation_type) ? (interbuf_elems * 2) : 0; int num_softmax_outs = 0; - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + bool const is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { num_softmax_outs = num_rows * num_experts; @@ -873,11 +873,11 @@ std::vector CutlassMoeFCRunner::getWorkspaceBuffe } template -size_t CutlassMoeFCRunner::getWorkspaceSize(const int num_rows, const int hidden_size, - const int inter_size, const int num_experts, const int k, ActivationType activation_type, +size_t CutlassMoeFCRunner::getWorkspaceSize(int const num_rows, int const hidden_size, + int const inter_size, int const num_experts, int const k, ActivationType activation_type, MOEParallelismConfig parallelism_config) const { - const int ep_size = parallelism_config.ep_size; + int const ep_size = parallelism_config.ep_size; TLLM_CHECK_WITH_INFO(num_experts % ep_size == 0, "Number of experts must be a multiple of tp size"); auto workspace = getWorkspaceBufferSizes( num_rows, hidden_size, inter_size, num_experts, num_experts / ep_size, k, activation_type); @@ -885,8 +885,8 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(const int num } template -void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, const int num_rows, const int hidden_size, - const int inter_size, const int num_experts, const int num_experts_per_node, const int k, +void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, int const num_rows, int const hidden_size, + int const inter_size, int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type) { auto workspace = getWorkspaceBufferSizes( @@ -906,7 +906,7 @@ void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, co total_rows_before_expert_ = (int64_t*) ws_sliced[4]; softmax_out_ = nullptr; - const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + bool const is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { softmax_out_ = (float*) ws_sliced[5]; @@ -920,25 +920,25 @@ void CutlassMoeFCRunner::configureWsPtrs(char* ws_ptr, co } template -void CutlassMoeFCRunner::runMoe(const void* input_activations_void, const float* 
gating_output, - const void* fc1_expert_weights_void, const void* fc1_scales_void, const void* fc1_expert_biases_void, - ActivationType fc1_activation_type, const void* fc2_expert_weights_void, const void* fc2_scales_void, - const void* fc2_expert_biases_void, const int num_rows, const int hidden_size, const int inter_size, - const int num_experts, const int k, char* workspace_ptr, void* final_output_void, void* fc2_result_void, - const bool* finished, const int active_rows, void* expert_scales_void, +void CutlassMoeFCRunner::runMoe(void const* input_activations_void, float const* gating_output, + void const* fc1_expert_weights_void, void const* fc1_scales_void, void const* fc1_expert_biases_void, + ActivationType fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_scales_void, + void const* fc2_expert_biases_void, int const num_rows, int const hidden_size, int const inter_size, + int const num_experts, int const k, char* workspace_ptr, void* final_output_void, void* fc2_result_void, + bool const* finished, int const active_rows, void* expert_scales_void, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) { static constexpr bool scales_required = std::is_same::value || std::is_same::value; - auto* input_activations = static_cast(input_activations_void); - auto* fc1_expert_weights = static_cast(fc1_expert_weights_void); - auto* fc1_scales = static_cast(fc1_scales_void); - auto* fc1_expert_biases = static_cast(fc1_expert_biases_void); - auto* fc2_expert_weights = static_cast(fc2_expert_weights_void); - auto* fc2_scales = static_cast(fc2_scales_void); - auto* fc2_expert_biases = static_cast(fc2_expert_biases_void); + auto* input_activations = static_cast(input_activations_void); + auto* fc1_expert_weights = static_cast(fc1_expert_weights_void); + auto* fc1_scales = static_cast(fc1_scales_void); + auto* fc1_expert_biases = static_cast(fc1_expert_biases_void); + auto* fc2_expert_weights = static_cast(fc2_expert_weights_void); + auto* fc2_scales = static_cast(fc2_scales_void); + auto* fc2_expert_biases = static_cast(fc2_expert_biases_void); auto* final_output = static_cast(final_output_void); auto* fc2_result = static_cast(fc2_result_void); auto* expert_scales = static_cast(expert_scales_void); @@ -965,9 +965,9 @@ void CutlassMoeFCRunner::runMoe(const void* input_activat TLLM_CHECK_WITH_INFO(fc2_scales == nullptr, "Scales are ignored for fp32/fp16/bf16 but received scale for FC2"); } - const int num_experts_per_node = num_experts / parallelism_config.ep_size; - const int start_expert = num_experts_per_node * parallelism_config.ep_rank; - const int end_expert = start_expert + num_experts_per_node; + int const num_experts_per_node = num_experts / parallelism_config.ep_size; + int const start_expert = num_experts_per_node * parallelism_config.ep_rank; + int const end_expert = start_expert + num_experts_per_node; configureWsPtrs( workspace_ptr, num_rows, hidden_size, inter_size, num_experts, num_experts_per_node, k, fc1_activation_type); @@ -977,21 +977,21 @@ void CutlassMoeFCRunner::runMoe(const void* input_activat sync_check_cuda_error(); sorter_.updateNumExperts(num_experts); - const int sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows, num_experts)); + int const sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows, num_experts)); sorter_.run((void*) sorter_ws_, 
sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, source_rows_, permuted_rows_, k * num_rows, stream); sync_check_cuda_error(); // Upper bound on number of expanded rows - const int expanded_active_expert_rows = k * active_rows; + int const expanded_active_expert_rows = k * active_rows; computeTotalRowsBeforeExpert( permuted_experts_, expanded_active_expert_rows, num_experts_per_node, total_rows_before_expert_, stream); sync_check_cuda_error(); - const bool needs_num_valid = finished || parallelism_config.ep_size > 1; - const int64_t* num_valid_tokens_ptr + bool const needs_num_valid = finished || parallelism_config.ep_size > 1; + int64_t const* num_valid_tokens_ptr = needs_num_valid ? total_rows_before_expert_ + num_experts_per_node - 1 : nullptr; expandInputRowsKernelLauncher(input_activations, permuted_data_, permuted_rows_, expanded_source_row_to_expanded_dest_row, num_rows, num_valid_tokens_ptr, hidden_size, k, stream); @@ -1035,11 +1035,11 @@ void CutlassMoeFCRunner::runMoe(const void* input_activat } template -void CutlassMoeFCRunner::computeTotalRowsBeforeExpert(const int* sorted_indices, - const int total_indices, const int num_experts, int64_t* total_rows_before_expert, cudaStream_t stream) +void CutlassMoeFCRunner::computeTotalRowsBeforeExpert(int const* sorted_indices, + int const total_indices, int const num_experts, int64_t* total_rows_before_expert, cudaStream_t stream) { - const int threads = std::min(1024, num_experts); - const int blocks = (num_experts + threads - 1) / threads; + int const threads = std::min(1024, num_experts); + int const blocks = (num_experts + threads - 1) / threads; computeTotalRowsBeforeExpertKernel<<>>( sorted_indices, total_indices, num_experts, total_rows_before_expert); diff --git a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h index ae5ab0cc1..79cc31a0a 100644 --- a/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h @@ -26,7 +26,7 @@ namespace tensorrt_llm::kernels { -static inline size_t pad_to_multiple_of_16(const size_t& input) +static inline size_t pad_to_multiple_of_16(size_t const& input) { static constexpr int ALIGNMENT = 16; return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); @@ -49,21 +49,21 @@ static inline size_t pad_to_multiple_of_16(const size_t& input) k - k value in topk */ template -void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, - int* indices, int* source_row, const int num_rows, const int num_experts, const int k, cudaStream_t stream); +void topk_gating_softmax_kernelLauncher(T const* input, bool const* finished, T* output, T* softmax_temp_out, + int* indices, int* source_row, int const num_rows, int const num_experts, int const k, cudaStream_t stream); class CubKeyValueSorter { public: CubKeyValueSorter(); - CubKeyValueSorter(const int num_experts); + CubKeyValueSorter(int const num_experts); - void updateNumExperts(const int num_experts); + void updateNumExperts(int const num_experts); - static size_t getWorkspaceSize(const size_t num_key_value_pairs, const int num_experts); + static size_t getWorkspaceSize(const size_t num_key_value_pairs, int const num_experts); - void run(void* workspace, const size_t workspace_size, const int* keys_in, int* keys_out, const int* values_in, + void run(void* workspace, const size_t workspace_size, int const* keys_in, int* keys_out, int const* values_in, int* values_out, const 
size_t num_key_value_pairs, cudaStream_t stream); private: @@ -120,28 +120,28 @@ struct MOEParallelismConfig return {1, 0, ep_size, ep_rank}; } - const int tp_size = 1; - const int tp_rank = 0; - const int ep_size = 1; - const int ep_rank = 0; + int const tp_size = 1; + int const tp_rank = 0; + int const ep_size = 1; + int const ep_rank = 0; }; class CutlassMoeFCRunnerInterface { public: virtual ~CutlassMoeFCRunnerInterface() = default; - virtual size_t getWorkspaceSize(const int num_rows, const int hidden_size, const int fc1_output_size, - const int num_experts, const int k, ActivationType activation_type, + virtual size_t getWorkspaceSize(int const num_rows, int const hidden_size, int const fc1_output_size, + int const num_experts, int const k, ActivationType activation_type, MOEParallelismConfig parallelism_config) const = 0; virtual void setTactic(std::optional gemm_config) = 0; virtual std::vector getTactics() = 0; - virtual void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights, - const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type, - const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows, - const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr, - void* final_output, void* fc2_result, const bool* finished, const int active_rows, void* expert_scales, + virtual void runMoe(void const* input_activations, float const* gating_output, void const* fc1_expert_weights, + void const* fc1_scales, void const* fc1_expert_biases, ActivationType fc1_activation_type, + void const* fc2_expert_weights, void const* fc2_scales, void const* fc2_expert_biases, int const num_rows, + int const hidden_size, int const inter_size, int const num_experts, int const k, char* workspace_ptr, + void* final_output, void* fc2_result, bool const* finished, int const active_rows, void* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) @@ -160,8 +160,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface CutlassMoeFCRunner() = default; ~CutlassMoeFCRunner() override = default; - size_t getWorkspaceSize(const int num_rows, const int hidden_size, const int fc1_output_size, const int num_experts, - const int k, ActivationType activation_type, MOEParallelismConfig parallelism_config) const override; + size_t getWorkspaceSize(int const num_rows, int const hidden_size, int const fc1_output_size, int const num_experts, + int const k, ActivationType activation_type, MOEParallelismConfig parallelism_config) const override; void setTactic(std::optional gemm_config) override { @@ -173,22 +173,22 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface return moe_gemm_runner_.getConfigs(); } - void runMoe(const void* input_activations, const float* gating_output, const void* fc1_expert_weights, - const void* fc1_scales, const void* fc1_expert_biases, ActivationType fc1_activation_type, - const void* fc2_expert_weights, const void* fc2_scales, const void* fc2_expert_biases, const int num_rows, - const int hidden_size, const int inter_size, const int num_experts, const int k, char* workspace_ptr, - void* final_output, void* fc2_result, const bool* finished, const int active_rows, void* expert_scales, + void runMoe(void const* input_activations, float const* gating_output, void const* 
fc1_expert_weights, + void const* fc1_scales, void const* fc1_expert_biases, ActivationType fc1_activation_type, + void const* fc2_expert_weights, void const* fc2_scales, void const* fc2_expert_biases, int const num_rows, + int const hidden_size, int const inter_size, int const num_experts, int const k, char* workspace_ptr, + void* final_output, void* fc2_result, bool const* finished, int const active_rows, void* expert_scales, int* expanded_source_row_to_expanded_dest_row, int* expert_for_source_row, MOEParallelismConfig parallelism_config, MOEExpertScaleNormalizationMode normalization_mode, cudaStream_t stream) override; private: - void computeTotalRowsBeforeExpert(const int* sorted_indices, const int total_indices, const int num_experts, + void computeTotalRowsBeforeExpert(int const* sorted_indices, int const total_indices, int const num_experts, int64_t* total_rows_before_expert, cudaStream_t stream); - std::vector getWorkspaceBufferSizes(const int num_rows, const int hidden_size, const int inter_size, - const int num_experts, const int num_experts_per_node, const int k, ActivationType activation_type) const; - void configureWsPtrs(char* ws_ptr, const int num_rows, const int hidden_size, const int inter_size, - const int num_experts, const int num_experts_per_node, const int k, ActivationType activation_type); + std::vector getWorkspaceBufferSizes(int const num_rows, int const hidden_size, int const inter_size, + int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type) const; + void configureWsPtrs(char* ws_ptr, int const num_rows, int const hidden_size, int const inter_size, + int const num_experts, int const num_experts_per_node, int const k, ActivationType activation_type); private: CubKeyValueSorter sorter_; @@ -215,8 +215,8 @@ class CutlassMoeFCRunner -void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs, - void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream); +void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_storage, int const temp_storage_size, + BeamHypotheses& beam_hyps, cudaStream_t stream); #define CASE_K(MAX_K) \ - topK_softMax_kernelLauncher( \ - log_probs, bias, finished, cum_log_probs, temp_storage, temp_storage_size, beam_hyps, stream); \ + topK_softMax_kernelLauncher(log_probs, bias, temp_storage, temp_storage_size, beam_hyps, stream); \ break; template -void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs, - void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream) +void invokeTopkSoftMax(T const* log_probs, T const* bias, void* temp_storage, int const temp_storage_size, + BeamHypotheses& beam_hyps, cudaStream_t stream) { int log_beam_width(0); int recursor(beam_hyps.beam_width - 1); @@ -44,9 +43,8 @@ void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* f switch (log_beam_width) { - // 0 < beam_width <= 4 - case 0: // 1, 2 - case 1: // 3, 4 + case 0: + case 1: // 0 < beam_width <= 4 CASE_K(4) case 2: // 4 < beam_width <= 8 CASE_K(8) @@ -66,13 +64,11 @@ void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* f #undef CASE_K -template void invokeTopkSoftMax(const float* log_probs, const float* bias, const FinishedState* finished, - float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, - cudaStream_t 
stream); +template void invokeTopkSoftMax(float const* log_probs, float const* bias, void* tmp_storage, + int const temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream); -template void invokeTopkSoftMax(const half* log_probs, const half* bias, const FinishedState* finished, - float* cum_log_probs, void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, - cudaStream_t stream); +template void invokeTopkSoftMax(half const* log_probs, half const* bias, void* tmp_storage, + int const temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels.h b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels.h index c42b61677..5b96f695d 100644 --- a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels.h +++ b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels.h @@ -24,8 +24,8 @@ namespace kernels { template -void invokeTopkSoftMax(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs, - void* tmp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream); +void invokeTopkSoftMax(T const* log_probs, T const* bias, void* tmp_storage, int const temp_storage_size, + BeamHypotheses& beam_hyps, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h index 97f451f92..aee2e59e5 100644 --- a/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h +++ b/cpp/tensorrt_llm/kernels/onlineSoftmaxBeamsearchKernels/onlineSoftmaxBeamsearchKernelsTemplate.h @@ -37,7 +37,7 @@ namespace kernels { #define DO_SPLIT_SMALL_TOP_K_SOFTMAX -static const int SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256; +static int const SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256; #define TOPK_FP16_STORAGE 0 @@ -52,12 +52,13 @@ __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float return log_prob / static_cast(powf(length, length_penalty)); } +/* +// Useless kernels, remove them? 
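apply_length_penalty above normalizes a cumulative log-probability by length raised to length_penalty, following the Hugging Face convention. A quick numeric sanity check with made-up values shows the effect of the exponent:

// Tiny check of log_prob / powf(length, length_penalty) (values are illustrative only).
#include <cmath>
#include <cstdio>

int main()
{
    float const log_prob = -6.0f;
    printf("%f\n", log_prob / powf(9.0f, 1.0f)); // -0.666667: penalty 1 divides by the length
    printf("%f\n", log_prob / powf(9.0f, 0.0f)); // -6.000000: penalty 0 leaves the raw score unchanged
    return 0;
}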
template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(int* topk_id, T* topk_val, int* id_buf) { - const int thread_id = threadIdx.x; - const int block_id = blockIdx.x; + int const thread_id = threadIdx.x; + int const block_id = blockIdx.x; TopK partial; if (thread_id == 0) @@ -71,7 +72,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ int index = block_id * MAX_K * MAX_K; for (int i = 0; i < MAX_K * MAX_K; i++) { - partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); + partial.insert(topk_val[index + i], topk_id[index + i]); } index = block_id * MAX_K; @@ -83,11 +84,11 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ } template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* __restrict topk_tmp_id_buf, - const T* __restrict topk_tmp_val_buf, int* __restrict id_buf, T* __restrict val_buf) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel( + int const* __restrict topk_id, T const* __restrict topk_val, int* __restrict id_buf, T* __restrict val_buf) { - const int thread_id = threadIdx.x; - const int block_id = blockIdx.x; + int const thread_id = threadIdx.x; + int const block_id = blockIdx.x; TopK partial; if (thread_id == 0) @@ -101,7 +102,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* int index = block_id * MAX_K * MAX_K; for (int i = 0; i < MAX_K * MAX_K; i++) { - partial.insert(topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); + partial.insert(topk_val[index + i], topk_id[index + i]); } index = block_id * MAX_K; @@ -112,24 +113,25 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topK_kernel(const int* } } } - +*/ template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __restrict topk_tmp_id_buf, - const T* __restrict topk_tmp_val_buf, float* __restrict cum_log_probs, const FinishedState* finished, - BeamHypotheses beam_hyps, const int candidate_size) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel( + int const* __restrict topk_id, T const* __restrict topk_val, BeamHypotheses beam_hyps, int const candidate_size) { - const int thread_id = threadIdx.x; - const int vector_id = blockIdx.x; - const int K{beam_hyps.beam_width}; - const int vocab_size{beam_hyps.vocab_size}; - const int global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id}; - const T MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; - - const float length_penalty{beam_hyps.length_penalties[global_batch_idx]}; - const int early_stopping{beam_hyps.early_stoppings[global_batch_idx]}; - const T diversity_rate{beam_hyps.diversity_rates[global_batch_idx]}; - float* output_log_probs{beam_hyps.log_probs_src}; - const int* sequence_lengths{beam_hyps.sequence_lengths_src}; + int const thread_id = threadIdx.x; + int const vector_id = blockIdx.x; + int const global_batch_idx{beam_hyps.ite * beam_hyps.local_batch_size + vector_id}; + int const K{beam_hyps.beam_width}; + int const vocab_size{beam_hyps.vocab_size}; + T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; + + float const diversity_rate{beam_hyps.diversity_rates[global_batch_idx]}; + float const length_penalty{beam_hyps.length_penalties[global_batch_idx]}; + int const early_stopping{beam_hyps.early_stoppings[global_batch_idx]}; + int const* input_lengths{beam_hyps.input_lengths}; + int const* sequence_lengths{beam_hyps.sequence_lengths_src}; + + float* __restrict cum_log_probs_src{beam_hyps.cum_log_probs_src}; // copy since it will be modified using cub_kvp = cub::KeyValuePair; using BlockReduce = cub::BlockReduce; @@ -142,9 +144,9 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* __shared__ int selected_beams; __shared__ int thread_requiring_update; - // reposition topk_tmp_id_buf, topk_tmp_val_buf to data for the current vector - topk_tmp_id_buf += vector_id * candidate_size; - topk_tmp_val_buf += vector_id * candidate_size; + // reposition topk_id, topk_val to data for the current vector + topk_id += vector_id * candidate_size; + topk_val += vector_id * candidate_size; if (thread_id == 0) { @@ -152,20 +154,21 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* } if (thread_id < K) { - old_cum_log_probs[thread_id] = cum_log_probs[vector_id * K + thread_id]; + old_cum_log_probs[thread_id] = cum_log_probs_src[vector_id * K + thread_id]; } __syncthreads(); if (beam_hyps.num_beams != nullptr) { - // initialize worst_score if this batch has no finished beam + // Beam search is enabled if (beam_hyps.num_beams[global_batch_idx] == 0 && thread_id == 0) { + // Initialize worst_score if this batch has no finished beam beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; } - // return if this batch has enough finished beams else if (beam_hyps.num_beams[global_batch_idx] == K) { + // Return if this batch has enough finished beams return; } } @@ -174,14 +177,13 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* cub::ArgMax arg_max; cub_kvp partial_topk{candidate_size - 1, -MAX_T_VAL}; - for (int elem_id = thread_id; elem_id < candidate_size; elem_id += THREADBLOCK_SIZE) + for (int id = thread_id; id < candidate_size; id += THREADBLOCK_SIZE) { - int i = beam_hyps.num_beams == nullptr ? elem_id % K : elem_id / 2 / K; - T elem = topk_tmp_val_buf[elem_id]; // use token score to do TopK - elem += diversity_rate * (T) i; - cub_kvp new_elem{elem_id, elem}; + int i = beam_hyps.num_beams == nullptr ? 
id % K : id / 2 / K; + T elem = topk_val[id] + static_cast(diversity_rate * i); // use token score for TopK + cub_kvp new_elem{id, elem}; partial_topk = arg_max(partial_topk, new_elem); - buf_s[elem_id] = elem; + buf_s[id] = elem; } __syncthreads(); @@ -212,38 +214,36 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* if (thread_id == 0) { - cum_log_probs += vector_id * K; - + // Adjust beams or select completed beams sequentially + // Reference (might be changed along HF in the future): + // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272 for (int i = 0; i < 2 * K; ++i) { - const int current_key = cta_topk[i].key; - const T current_value = cta_topk[i].value; + int const current_key = cta_topk[i].key; + T const current_value = cta_topk[i].value; + bool const is_end_token = topk_id[current_key] % vocab_size == beam_hyps.end_ids[vector_id]; - // Consider to add beam only if this token belongs to top K range and it is end_token - // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272 - if (i < K && beam_hyps.num_beams != nullptr - && topk_tmp_id_buf[current_key] % vocab_size == beam_hyps.end_ids[vector_id]) + if (i < K && beam_hyps.num_beams != nullptr && is_end_token) { - const int seq_len = sequence_lengths[vector_id * K + i] - beam_hyps.input_lengths[global_batch_idx]; - const int pad_if_not_finish = finished[vector_id * K + i].isFinished() ? 0 : 1; - const float normed_score - = apply_length_penalty(current_value, seq_len + pad_if_not_finish, length_penalty); - + // Consider to add beam only if this token is end_token and belongs to top K range + int const seq_len = sequence_lengths[vector_id * K + i] - input_lengths[global_batch_idx]; + int const pad = static_cast(!beam_hyps.finished[vector_id * K + i].isFinished()); + float const normed_score = apply_length_penalty(current_value, seq_len + pad, length_penalty); int beam_idx = beam_hyps.num_beams[global_batch_idx]; - // There are already K beams if (beam_idx == K) { - // The current score is worse than the worst one in beams + // There are already K beams if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { + // Current score is worse than the worst one in candidate beams // Stop considering new beams selected_beams = K; break; } - // The current score is better than the worst one in beams else { - // Find the beam index which score == min_normed_score and erase it. 
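The branch above maintains at most K finished hypotheses per batch entry: while the list is not full a new end-token hypothesis is always stored, and once it is full a candidate only enters if its length-normalized score beats the current worst, which it then replaces. A standalone sketch of that bookkeeping follows (the FinishedBeams type is hypothetical; the real kernel additionally tracks min_normed_scores and copies the token history of the accepted beam):

// Sketch of the "keep the K best finished hypotheses" policy (illustration only).
#include <algorithm>
#include <vector>

struct FinishedBeams
{
    int K;
    std::vector<float> normed_scores; // one length-normalized score per stored hypothesis

    bool tryInsert(float normed_score)
    {
        if (static_cast<int>(normed_scores.size()) < K)
        {
            normed_scores.push_back(normed_score); // not full yet, always accept
            return true;
        }
        auto worst = std::min_element(normed_scores.begin(), normed_scores.end());
        if (normed_score <= *worst)
        {
            return false; // no better than the worst stored beam, stop considering it
        }
        *worst = normed_score; // erase the worst hypothesis and store the new one in its slot
        return true;
    }
};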
+ // Current score is better than the worst one in candidate beams + // Find the beam index which score == min_normed_score and erase it for (int j = 0; j < K; j++) { if (beam_hyps.normed_scores[global_batch_idx * (K * 2) + j] @@ -264,111 +264,106 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(const int* } } } - const int tgt_id_offset + int const tgt_id_offset = ((vector_id + beam_hyps.ite * beam_hyps.local_batch_size) * (K * 2) + beam_idx) * (beam_hyps.max_seq_len); - int prev_id = (topk_tmp_id_buf[current_key] / vocab_size) % K; - const int current_step{sequence_lengths[vector_id * K + prev_id]}; + int prev_id = (topk_id[current_key] / vocab_size) % K; + int const current_step{sequence_lengths[vector_id * K + prev_id]}; beam_hyps.output_ids_tgt[tgt_id_offset + current_step] = beam_hyps.end_ids[vector_id]; - if (beam_hyps.log_probs != nullptr) { - beam_hyps.log_probs[tgt_id_offset + current_step] = (float) topk_tmp_val_buf[current_key] - - old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K]; + beam_hyps.log_probs[tgt_id_offset + current_step] + = (float) topk_val[current_key] - old_cum_log_probs[(topk_id[current_key] / vocab_size) % K]; } for (int j = current_step - 1; j >= 0; j--) { - const int src_idx = j * beam_hyps.batch_size * K + beam_hyps.ite * beam_hyps.local_batch_size * K + int const src_idx = j * beam_hyps.batch_size * K + beam_hyps.ite * beam_hyps.local_batch_size * K + vector_id * K + prev_id; beam_hyps.output_ids_tgt[tgt_id_offset + j] - = beam_hyps.output_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j]; - + = beam_hyps.output_ids_tgt_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j]; if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { beam_hyps.log_probs[tgt_id_offset + j] = beam_hyps.log_probs_src[src_idx]; } - - prev_id = beam_hyps.parent_ids_src_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j]; + prev_id = beam_hyps.parent_ids_tgt_ptr[vector_id][prev_id * beam_hyps.max_seq_len + j]; } - const int tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx; - + int const tgt_beam_idx = global_batch_idx * (K * 2) + beam_idx; beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = current_step; beam_hyps.normed_scores[tgt_beam_idx] = normed_score; beam_hyps.min_normed_scores[global_batch_idx] = min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); - beam_hyps.num_beams[global_batch_idx]++; - beam_hyps.cum_log_probs[tgt_beam_idx] = (float) topk_tmp_val_buf[current_key]; + beam_hyps.cum_log_probs[tgt_beam_idx] = (float) topk_val[current_key]; } - // This token is end_token but belongs to range K ~ 2K, just ignoe it - // TODO: eliminate this branch by rewriting condition of the else_if - else if (i >= K && beam_hyps.num_beams != nullptr - && topk_tmp_id_buf[current_key] % vocab_size == beam_hyps.end_ids[vector_id]) + else if (i < K || beam_hyps.num_beams != nullptr && !is_end_token) { - ; - } - // Beam search is disabled or this token is not end_token, we add it to the end of the unfinished sentence - else if (beam_hyps.num_beams != nullptr || beam_hyps.num_beams == nullptr && i < K) - { - const int current_step{sequence_lengths[vector_id * K + selected_beams]}; + // Condition of this branch + // 1. beam_hyps.num_beams == nullptr && i < K, i.e., beam search is disable + // 2. beam_hyps.num_beams != nullptr && i < K && is_end_token == false, i.e., add token at the end + // 3. 
beam_hyps.num_beams != nullptr && i >= K && is_end_token == false, i.e., add token at the end + int const current_step = sequence_lengths[vector_id * K + selected_beams]; beam_hyps.output_ids_tgt_ptr[vector_id][selected_beams * beam_hyps.max_seq_len + current_step] - = topk_tmp_id_buf[current_key]; - if (output_log_probs != nullptr) + = topk_id[current_key]; + if (beam_hyps.log_probs_src != nullptr) { - output_log_probs[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams] - = (float) topk_tmp_val_buf[current_key] - - old_cum_log_probs[(topk_tmp_id_buf[current_key] / vocab_size) % K]; + beam_hyps.log_probs_src[current_step * beam_hyps.batch_size * K + vector_id * K + selected_beams] + = (float) topk_val[current_key] - old_cum_log_probs[(topk_id[current_key] / vocab_size) % K]; } - cum_log_probs[selected_beams] = (float) topk_tmp_val_buf[current_key]; + cum_log_probs_src[vector_id * K + selected_beams] = (float) topk_val[current_key]; selected_beams++; } - __syncthreads(); + else + { + ; + // Condition of this branch, which we do nothing for it + // 1. beam_hyps.num_beams == nullptr && i >= K, i.e., beam search is disable + // 2. beam_hyps.num_beams != nullptr && i >= K && is_end_token == true, i.e., ignore the worse beams + } + if (selected_beams >= K) { break; } } } - // update beam_hyps.is_done for each batch - if (threadIdx.x == 0 && beam_hyps.num_beams != nullptr) + + // update beam_hyps.is_done + if (thread_id == 0 && beam_hyps.num_beams != nullptr) { - // no enough beams - if (beam_hyps.num_beams[blockIdx.x] < K) + if (beam_hyps.num_beams[vector_id] < K) { - beam_hyps.is_done[blockIdx.x] = false; + // no enough beams + beam_hyps.is_done[vector_id] = false; return; } + int seq_len = 0; float highest_attainable_score = 0.0f; switch (early_stopping) { case 1: - // enough beams with early stopping - beam_hyps.is_done[blockIdx.x] = true; + // enough beams with early_stopping + beam_hyps.is_done[vector_id] = true; return; case 0: - // enough beams without early stopping - highest_attainable_score = static_cast(apply_length_penalty(cum_log_probs[0], - sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty)); - beam_hyps.is_done[blockIdx.x] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; + // enough beams with non_early_stopping + seq_len = sequence_lengths[vector_id * K] - input_lengths[global_batch_idx]; + highest_attainable_score = apply_length_penalty(cum_log_probs_src[vector_id * K], seq_len, length_penalty); + beam_hyps.is_done[vector_id] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; return; default: - // early_stopping == "never" in HF, i.e., compute the best possible score depending on `length_penalty` + // early_stopping == "never" in HF, i.e., compute the best possible score depending on length_penalty // https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L990 if (length_penalty > 0.0f) { - highest_attainable_score = static_cast(apply_length_penalty(cum_log_probs[0], - beam_hyps.max_seq_len - beam_hyps.input_lengths[global_batch_idx], length_penalty)); - beam_hyps.is_done[blockIdx.x] - = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; + seq_len = beam_hyps.max_seq_len - input_lengths[global_batch_idx]; } else { - highest_attainable_score = static_cast(apply_length_penalty(cum_log_probs[0], - sequence_lengths[vector_id * K] - beam_hyps.input_lengths[global_batch_idx], length_penalty)); - 
beam_hyps.is_done[blockIdx.x] - = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; + seq_len = sequence_lengths[vector_id * K] - input_lengths[global_batch_idx]; } + highest_attainable_score = apply_length_penalty(cum_log_probs_src[vector_id * K], seq_len, length_penalty); + beam_hyps.is_done[vector_id] = beam_hyps.min_normed_scores[global_batch_idx] >= highest_attainable_score; return; } } @@ -399,7 +394,7 @@ struct TopKMD }; template -__device__ __forceinline__ TopKMD reduce_topk_md_op(const TopKMD& a, const TopKMD& b) +__device__ __forceinline__ TopKMD reduce_topk_md_op(TopKMD const& a, TopKMD const& b) { TopKMD res; res.md = reduce_md_op(a.md, b.md); @@ -408,14 +403,13 @@ __device__ __forceinline__ TopKMD reduce_topk_md_op(const TopKMD -__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(const T* __restrict log_probs, - const T* __restrict bias, const float* __restrict cum_log_probs, const FinishedState* __restrict finished, - int* __restrict topk_tmp_id_buf, T* __restrict topk_tmp_val_buf, int vocab_size, int K, - const int* __restrict end_ids) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_kernel(T const* __restrict log_probs, + T const* __restrict bias, float const* __restrict cum_log_probs, FinishedState const* __restrict finished, + int* __restrict topk_id, T* __restrict topk_val, int vocab_size, int K, int const* __restrict end_ids) { - const int thread_id = threadIdx.x; - const int vector_id = blockIdx.x; - const T MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; + int const thread_id = threadIdx.x; + int const vector_id = blockIdx.x; + T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -434,23 +428,22 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker if (finished[vector_id].isFinished()) { - for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) + for (int id = thread_id; id < vocab_size; id += THREADBLOCK_SIZE) { - float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; + float elem = (id == end_ids[vector_id / K]) ? 
MAX_T_VAL : -MAX_T_VAL; MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); - // if (elem_id > THREADBLOCK_SIZE * MAX_K && elem_id == E) break; + partial.topk.insert(elem, id); } } else { - for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) + for (int id = thread_id; id < vocab_size; id += THREADBLOCK_SIZE) { - float elem = log_probs[elem_id] + bias[elem_id]; + float elem = log_probs[id] + bias[id]; MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); + partial.topk.insert(elem, id); } } @@ -458,8 +451,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker if (thread_id == 0) { - topk_tmp_id_buf += vector_id * K; - topk_tmp_val_buf += vector_id * K; + topk_id += vector_id * K; + topk_val += vector_id * K; cum_log_probs += vector_id; // float d_total_inverse = __fdividef(1.0F, total.md.d); @@ -470,8 +463,8 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker float val = total.topk.u[i] - total.md.m - d_total_log; if (i < K) { - topk_tmp_id_buf[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id - topk_tmp_val_buf[i] = val + cum_log_probs[0]; + topk_id[i] = total.topk.p[i] + vector_id * vocab_size; // trtllm needs absolute id + topk_val[i] = val + cum_log_probs[0]; } } } @@ -479,34 +472,32 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_ker template __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_base( - const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished, - float* __restrict tmp_buffer, int vocab_size, int K, const int* __restrict end_ids) + T const* __restrict log_probs, T const* __restrict bias, FinishedState const* __restrict finished, + float* __restrict tmp_buffer, int vocab_size, int K, int const* __restrict end_ids) { - const int thread_id = threadIdx.x; - const int vector_id = blockIdx.x; - const T MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; - const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; + int const thread_id = threadIdx.x; + int const vector_id = blockIdx.x; + T const MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; + int const PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; // one threadblock has multiple sections per vocab_size - const int v_local = (vocab_size + gridDim.y - 1) / gridDim.y; - const int section_start = v_local * blockIdx.y; - const int section_end = std::min(section_start + v_local, vocab_size); + int const v_local = (vocab_size + gridDim.y - 1) / gridDim.y; + int const section_start = v_local * blockIdx.y; + int const section_end = std::min(section_start + v_local, vocab_size); #if TOPK_FP16_STORAGE == 1 typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + TopKMD<__half, MAX_K2> partial; #else typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + TopKMD partial; #endif + __shared__ typename BlockReduce::TempStorage temp_storage; __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; - // reposition log_probs to data for the current vector + // reposition log_probs to the data for the current vector log_probs += vector_id * vocab_size; -#if TOPK_FP16_STORAGE == 1 - TopKMD<__half, MAX_K2> partial; -#else - TopKMD partial; -#endif for (int i = 0; i < MAX_K2; ++i) { partial.topk.p[i] = -1; @@ -518,24 +509,24 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ if (finished[vector_id].isFinished()) { #pragma unroll 1 - for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) + for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) { - float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; + float elem = (id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); + partial.topk.insert(elem, id); } } else { #pragma unroll 1 - for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) + for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) { - T b = bias == nullptr ? (T) 0.0f : bias[elem_id]; - T elem = log_probs[elem_id] + b; + T b = bias == nullptr ? (T) 0.0f : bias[id]; + T elem = log_probs[id] + b; MD new_elem{elem, 1.0F}; partial.md = reduce_md_op(partial.md, new_elem); - partial.topk.insert(elem, elem_id); + partial.topk.insert(elem, id); } } @@ -556,26 +547,26 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ buf_s[2 * MAX_K2 + 1] = total.md.m; } __syncthreads(); - for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) + + for (int id = thread_id; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE) { - tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] - = buf_s[elem_id]; + tmp_buffer[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + id] = buf_s[id]; } } template __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_stage1_kernel_fast( - const T* __restrict log_probs, const T* __restrict bias, const FinishedState* __restrict finished, - float* __restrict t, int vocab_size, int K, const int* __restrict end_ids, const int v_local) + T const* __restrict log_probs, T const* __restrict bias, FinishedState const* __restrict finished, + float* __restrict t, int vocab_size, int K, int const* __restrict end_ids, int const v_local) { - const int thread_id = threadIdx.x; - const int vector_id = blockIdx.x; - const T MAX_T_VAL = (std::is_same::value) ? 
HALF_FLT_MAX : FLT_MAX; - const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; + int const thread_id = threadIdx.x; + int const vector_id = blockIdx.x; + T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; + int const PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; // one threadblock has multiple sections per vocab_size - const int section_start = v_local * blockIdx.y; - const int section_end = std::min(section_start + v_local, vocab_size); - const int valid_smem_length = section_end - section_start; + int const section_start = v_local * blockIdx.y; + int const section_end = std::min(section_start + v_local, vocab_size); + int const valid_smem_length = section_end - section_start; #if TOPK_FP16_STORAGE == 1 using cub_kvp = cub::KeyValuePair; @@ -589,6 +580,8 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ extern __shared__ char buf_smem_logprobs_[]; T* buf_smem_logprobs = reinterpret_cast(buf_smem_logprobs_); + __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; + __shared__ int thread_requiring_update; __shared__ union { @@ -596,10 +589,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ typename BlockReduceTopK::TempStorage topk_smem; } temp_storage; - __shared__ float buf_s[PACKED_TOP_KMD_SIZE]; - __shared__ int thread_requiring_update; - - // reposition log_probs to data for the current vector + // reposition log_probs to the data for the current vector log_probs += vector_id * vocab_size; cub::ArgMax arg_max; @@ -608,14 +598,14 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ if (finished[vector_id].isFinished()) { #pragma unroll 1 - for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) + for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) { - float elem = (elem_id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; - buf_smem_logprobs[elem_id - section_start] = elem; + float elem = (id == end_ids[vector_id / K]) ? MAX_T_VAL : -MAX_T_VAL; + buf_smem_logprobs[id - section_start] = elem; MD new_elem{elem, 1.0F}; partial_md = reduce_md_op(partial_md, new_elem); - const int smem_index = elem_id - section_start; + int const smem_index = id - section_start; cub_kvp new_elem_topk{smem_index, elem}; partial_topk = arg_max(partial_topk, new_elem_topk); buf_smem_logprobs[smem_index] = elem; @@ -624,14 +614,14 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ else { #pragma unroll 1 - for (int elem_id = section_start + thread_id; elem_id < section_end; elem_id += THREADBLOCK_SIZE) + for (int id = section_start + thread_id; id < section_end; id += THREADBLOCK_SIZE) { - T b = bias == nullptr ? (T) 0.0f : bias[elem_id]; - T elem = log_probs[elem_id] + b; + T b = bias == nullptr ? (T) 0.0f : bias[id]; + T elem = log_probs[id] + b; MD new_elem_md{elem, 1.0F}; partial_md = reduce_md_op(partial_md, new_elem_md); - const int smem_index = elem_id - section_start; + int const smem_index = id - section_start; cub_kvp new_elem_topk{smem_index, elem}; partial_topk = arg_max(partial_topk, new_elem_topk); buf_smem_logprobs[smem_index] = elem; @@ -655,7 +645,7 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ // Only one thread needs to update the old partial before the next block reduce. // No need to do this in the last iteration. 
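The comment above refers to the repeated block-wide ArgMax pattern used by the fast stage-1 kernel: the top 2*K entries are pulled out by running 2*K rounds of cub::BlockReduce with cub::ArgMax, and after each round only the single thread that owned the winner invalidates it and rebuilds its per-thread best. The following is a self-contained sketch of that pattern under stated assumptions (hypothetical kernel name, one thread block, candidates staged in dynamic shared memory), not the patched kernel itself:

// Top-K of n values in one block via K rounds of block-wide ArgMax (illustrative sketch).
#include <cfloat>
#include <cub/cub.cuh>

template <int BLOCK_SIZE, int K>
__global__ void blockTopKByRepeatedArgMax(float const* values, int n, int* out_idx, float* out_val)
{
    using KVP = cub::KeyValuePair<int, float>;
    using BlockReduceTopK = cub::BlockReduce<KVP, BLOCK_SIZE>;
    __shared__ typename BlockReduceTopK::TempStorage temp_storage;
    __shared__ int thread_requiring_update;
    extern __shared__ float smem_vals[]; // n floats of dynamic shared memory

    cub::ArgMax arg_max;

    // Stage the candidates in shared memory and build each thread's local best over its strided slice.
    KVP partial{0, -FLT_MAX};
    for (int i = threadIdx.x; i < n; i += BLOCK_SIZE)
    {
        smem_vals[i] = values[i];
        partial = arg_max(partial, KVP{i, smem_vals[i]});
    }
    __syncthreads();

    for (int k = 0; k < K; ++k)
    {
        // Block-wide argmax over the per-thread bests; the result is valid in thread 0 only.
        KVP total = BlockReduceTopK(temp_storage).Reduce(partial, arg_max);
        if (threadIdx.x == 0)
        {
            out_idx[k] = total.key;
            out_val[k] = total.value;
            smem_vals[total.key] = -FLT_MAX;                  // invalidate the winner
            thread_requiring_update = total.key % BLOCK_SIZE; // the thread that strided over it
        }
        __syncthreads();

        // Only the winner's owner rebuilds its local best; every other partial is still valid.
        // Skipped in the last round, exactly as in the kernel above.
        if (threadIdx.x == thread_requiring_update && k < K - 1)
        {
            partial = KVP{0, -FLT_MAX};
            for (int i = threadIdx.x; i < n; i += BLOCK_SIZE)
            {
                partial = arg_max(partial, KVP{i, smem_vals[i]});
            }
        }
    }
}

// Example launch (names assumed): blockTopKByRepeatedArgMax<256, 8><<<1, 256, n * sizeof(float)>>>(d_vals, n, d_idx, d_val);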
- if (thread_id == thread_requiring_update && i < (2 * K - 1)) + if (thread_id == thread_requiring_update && i < 2 * K - 1) { partial_topk.key = vocab_size - 1; partial_topk.value = -MAX_T_VAL; @@ -676,21 +666,22 @@ __launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beam_online_softmax_topk_ buf_s[2 * MAX_K2 + 1] = total_md.m; } __syncthreads(); - for (int elem_id = thread_id; elem_id < PACKED_TOP_KMD_SIZE; elem_id += THREADBLOCK_SIZE) + + for (int id = thread_id; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE) { - t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + elem_id] = buf_s[elem_id]; + t[blockIdx.x * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE + id] = buf_s[id]; } } template __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_stage2_kernel( - const float* __restrict temp_storage, const float* __restrict cum_log_probs, int* __restrict ids, - T* __restrict vals, int K, int parts_per_beam, const int vocab_size) + float const* __restrict temp_storage, float const* __restrict cum_log_probs, int* __restrict ids, + T* __restrict vals, int K, int parts_per_beam, int const vocab_size) { - const int vector_id = blockIdx.x; - const int thread_id = threadIdx.x; - const T MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; - const int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; + int const vector_id = blockIdx.x; + int const thread_id = threadIdx.x; + T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; + int const PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2; using cub_kvp = cub::KeyValuePair; using BlockReduceTopK = cub::BlockReduce; @@ -788,12 +779,12 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_online_softmax_topk_sta } template -void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, const float* cum_log_probs, int* ids, - T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, const int vocab_size) +void beam_online_softmax_topk_stage2_kernelLauncher(float const* temp_storage, float const* cum_log_probs, int* ids, + T* vals, int batch_size, int beam_width, int parts_per_beam, cudaStream_t stream, int const vocab_size) { // TODO: rewrite beam_online_softmax_topk_stage2_kernel to remove dependence // of constant block size in oreder to reduce compilation time - const int smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float); + int const smem_stage2_size = parts_per_beam * (2 * MAX_K2 + 2) * sizeof(float); if (parts_per_beam <= 32) { @@ -820,27 +811,28 @@ void beam_online_softmax_topk_stage2_kernelLauncher(const float* temp_storage, c } template -void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const FinishedState* finished, float* cum_log_probs, - void* temp_storage, const int temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream) +void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_storage, int const temp_storage_size, + BeamHypotheses& beam_hyps, cudaStream_t stream) { - const int batch_size{beam_hyps.local_batch_size}; - const int beam_width{beam_hyps.beam_width}; - const int vocab_size{beam_hyps.vocab_size}; - const int* end_ids{beam_hyps.end_ids}; + int const batch_size{beam_hyps.local_batch_size}; + int const beam_width{beam_hyps.beam_width}; + int const vocab_size{beam_hyps.vocab_size}; + int const* end_ids{beam_hyps.end_ids}; + float* cum_log_probs{beam_hyps.cum_log_probs_src}; + FinishedState const* finished{beam_hyps.finished}; - const int 
items_per_thread = 1; - const int block_sz = (MAX_K < 16) ? ((MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128) : 64; - // const int block_sz = SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE; + int const items_per_thread = 1; + int const block_sz = (MAX_K < 16) ? ((MAX_K < 8) ? SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE : 128) : 64; assert(temp_storage_size % 2 == 0); assert(temp_storage_size >= 2 * batch_size * beam_width * beam_width * 2); - // Beam search needs the sequence lengths of beams to apply length penalty. + // Input and current sequence lengths are needed for computation of length penalty assert(beam_hyps.length_penalties == nullptr || beam_hyps.sequence_lengths_src != nullptr); - const int topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4; - int* topk_tmp_id_buf = reinterpret_cast(temp_storage); - T* topk_tmp_val_buf = reinterpret_cast(topk_tmp_id_buf + topk_buf_offset); - float* tmp_buffer = reinterpret_cast(topk_tmp_val_buf + topk_buf_offset); + int const topk_buf_offset = ceil(batch_size * beam_width * beam_width * 2 / 4.) * 4; + int* topk_id = reinterpret_cast(temp_storage); + T* topk_val = reinterpret_cast(topk_id + topk_buf_offset); + float* tmp_buffer = reinterpret_cast(topk_val + topk_buf_offset); #ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX // First, we query the occupancy assuming we need no smem. The goal of this heuristic is to simply run @@ -859,14 +851,14 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish TLLM_CUDA_CHECK(cudaFuncGetAttributes( &attr, beam_online_softmax_topk_stage1_kernel_fast)); - const int constant_smem = attr.sharedSizeBytes; - const int max_dyn_smem_per_block = max_smem_per_block - constant_smem; + int const constant_smem = attr.sharedSizeBytes; + int const max_dyn_smem_per_block = max_smem_per_block - constant_smem; constexpr int max_parts = 128; TLLM_CHECK_WITH_INFO(vocab_size * sizeof(T) <= max_dyn_smem_per_block * max_parts, "Vocab size too large for split-k top-k beam search fast path."); - const int driver_smem_per_block = max_smem_per_sm - max_smem_per_block; - const int extra_smem = driver_smem_per_block + constant_smem; + int const driver_smem_per_block = max_smem_per_sm - max_smem_per_block; + int const extra_smem = driver_smem_per_block + constant_smem; int smem_per_block = max_smem_per_sm / max_active_blocks; int dyn_smem_size = smem_per_block - extra_smem; @@ -898,7 +890,7 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish dyn_smem_size = sizeof(T) * (vocab_size + voc_parts - 1) / voc_parts; dim3 grid(batch_size * beam_width, voc_parts); // dynamically allocate shared memory - const int voc_size_chunk = dyn_smem_size / sizeof(T); + int const voc_size_chunk = dyn_smem_size / sizeof(T); if (dyn_smem_size >= (48 << 10)) { @@ -913,7 +905,7 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish } else { - // use original stage 1 base kernel + // use stage 1 base kernel int voc_parts = 4; if (batch_size * beam_width < 256) { @@ -934,35 +926,32 @@ void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, const Finish #endif #ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX - beam_online_softmax_topk_stage2_kernelLauncher(tmp_buffer, cum_log_probs, topk_tmp_id_buf, - topk_tmp_val_buf, batch_size, beam_width, voc_parts, stream, vocab_size); + beam_online_softmax_topk_stage2_kernelLauncher( + tmp_buffer, cum_log_probs, topk_id, topk_val, batch_size, beam_width, voc_parts, stream, vocab_size); sync_check_cuda_error(); #else 
beam_online_softmax_topk_kernel - <<>>(log_probs, bias, cum_log_probs, finished, topk_tmp_id_buf, - topk_tmp_val_buf, vocab_size, beam_width, end_ids); + <<>>( + log_probs, bias, cum_log_probs, finished, topk_id, topk_val, vocab_size, beam_width, end_ids); #endif - // We need 2*MAX_K candidates because at most k candidates are finished, and - // we will not put them into next iteration - - const int candidates = beam_width * beam_width * 2; - const int smem_size_batch_topk = sizeof(T) * candidates; + // Keep 2*MAX_K candidates in case of k candidates finishes in one iteration + int const candidates = beam_width * beam_width * 2; + int const smem_size_batch_topk = sizeof(T) * candidates; if (smem_size_batch_topk >= (48 << 10)) { TLLM_CUDA_CHECK(cudaFuncSetAttribute( batch_topk_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_batch_topk)); } - batch_topk_kernel<<>>( - topk_tmp_id_buf, topk_tmp_val_buf, cum_log_probs, finished, beam_hyps, candidates); + batch_topk_kernel + <<>>(topk_id, topk_val, beam_hyps, candidates); sync_check_cuda_error(); } #define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \ - template void topK_softMax_kernelLauncher(const T* log_probs, const T* bias, \ - const FinishedState* finished, float* cum_log_probs, void* temp_storage, const int temp_storage_size, \ - BeamHypotheses& beam_hyps, cudaStream_t stream); + template void topK_softMax_kernelLauncher(T const* log_probs, T const* bias, void* temp_storage, \ + int const temp_storage_size, BeamHypotheses& beam_hyps, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.cu b/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.cu index ac78619c2..a8b5c1de7 100644 --- a/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.cu +++ b/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.cu @@ -29,8 +29,8 @@ static constexpr int kUpdateKVCacheKernelShmSize = 16384; template __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array kvCacheBuffers, - const int* seqAcceptedDraftTokenOffsets, const IndexType* packedAcceptedDraftTokensIndices, - const int32_t* pastKeyValueLengths, int rewindDraftTokenCommonCount, const int* rewindDraftTokenSeparateAdjustments, + int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices, + int32_t const* pastKeyValueLengths, int rewindDraftTokenCommonCount, int const* rewindDraftTokenSeparateAdjustments, int eltCountPerHead) { int seqIdx = blockIdx.x; @@ -123,9 +123,9 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array -void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers, - const int* seqAcceptedDraftTokenOffsets, const IndexType* packedAcceptedDraftTokensIndices, - const int32_t* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, +void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers, + int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices, + int32_t const* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, cudaStream_t stream) { // make sure launch buffer is enough @@ -149,7 +149,7 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers, kvCacheBufferArray[i] = kvCacheBuffers[i]; } void (*pKernelFunc)( - std::array, const int*, 
const IndexType*, const int32_t*, int, const int*, int) + std::array, int const*, IndexType const*, int32_t const*, int, int const*, int) = nullptr; switch (alignedBytes) { @@ -204,8 +204,8 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers, * @param stream : CUDA stream to use. */ template -void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int layerCount, int seqCount, +void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, cudaStream_t stream) { @@ -222,8 +222,8 @@ void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const } } -void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, +void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream) { @@ -240,8 +240,8 @@ void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffse rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream); } -void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray, +void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream) @@ -259,8 +259,8 @@ void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffset rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream); } -void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, +void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream) { @@ -269,8 +269,8 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDra rewindDraftTokenCount, nullptr, maxKVCacheLen, stream); } -void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray, +void 
updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream) { @@ -279,8 +279,8 @@ void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraf rewindDraftTokenCount, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream); } -void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, +void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream) { @@ -289,8 +289,8 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedD rewindDraftTokenCounts, maxKVCacheLen, stream); } -void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray, +void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream) { diff --git a/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h b/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h index 26ad4888d..c2c6a3e88 100644 --- a/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h +++ b/cpp/tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h @@ -42,8 +42,8 @@ using IndexType = int; * @param maxKVCacheLen : Maximum length of each KV cache * @param stream : CUDA stream to use. */ -void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, +void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream); @@ -64,8 +64,8 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDra * @param tokensPerBlock : Tokens per block of Block KV cache * @param stream : CUDA stream to use. 
*/ -void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray, +void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream); @@ -85,8 +85,8 @@ void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraf * @param maxKVCacheLen : Maximum length of each KV cache * @param stream : CUDA stream to use. */ -void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, +void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream); @@ -108,8 +108,8 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedD * @param tokensPerBlock : Tokens per block of Block KV cache * @param stream : CUDA stream to use. */ -void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray, +void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream); @@ -132,8 +132,8 @@ void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDr * @param maxKVCacheLen : Maximum length of each KV cache * @param stream : CUDA stream to use. */ -void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, +void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream); @@ -158,8 +158,8 @@ void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffse * @param tokensPerBlock : Tokens per block of Block KV cache * @param stream : CUDA stream to use. 
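As a caller-side note before the declaration that follows: the *CommonRewind wrappers above forward a single rewind count and a null per-sequence adjustment array to this general entry point. The sketch below is a hypothetical usage only; every name and constant in it is a placeholder, namespace qualifiers are omitted, and IndexType is simply int in this header:

    #include "tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h"

    // Hypothetical helper: after a speculative-decoding step, move the accepted draft tokens into
    // place in a block KV cache and rewind every sequence by the same number of draft tokens.
    void rewindAllSequences(int const* seqAcceptedDraftTokenOffsets, int const* packedAcceptedDraftTokensIndices,
        int32_t const* pastKeyValueLengths, int64_t* const* kvBlockPointerArray, cudaStream_t stream)
    {
        // Placeholder geometry; a real caller takes these from its model config and KV-cache manager.
        int const layerCount = 32, seqCount = 8, numKVHeads = 8, sizeInBytesPerKVHead = 256;
        int const rewindCount = 4, maxKVCacheLen = 4096, maxBlocksPerSeq = 64, tokensPerBlock = 64;
        updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
            pastKeyValueLengths, kvBlockPointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
            rewindCount, /*rewindDraftTokenSeparateAdjustments=*/nullptr, maxKVCacheLen, maxBlocksPerSeq,
            tokensPerBlock, stream);
    }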
*/ -void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets, - const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray, +void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets, + IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream); diff --git a/cpp/tensorrt_llm/kernels/penaltyKernels.cu b/cpp/tensorrt_llm/kernels/penaltyKernels.cu index db860271a..bb58fff92 100644 --- a/cpp/tensorrt_llm/kernels/penaltyKernels.cu +++ b/cpp/tensorrt_llm/kernels/penaltyKernels.cu @@ -23,6 +23,7 @@ #include "tensorrt_llm/kernels/penaltyKernels.h" using namespace tensorrt_llm::common; +using namespace tensorrt_llm::runtime; namespace tensorrt_llm { @@ -31,36 +32,47 @@ namespace kernels template __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases, - int32_t* penaltyWorkspace, int32_t const* penaltyWorkspacePrev, float const* temperatures, + TokenIdType* penaltyWorkspace, TokenIdType const* penaltyWorkspacePrev, float const* temperatures, float const* repetitionPenalties, float const* presencePenalties, float const* frequencyPenalties, - const bool accumulateVocab, int32_t const maxSeqLen, int32_t const vocabSize, int32_t const vocabSizePadded, - int32_t const** outputIdsPtr, int32_t const** parentIdsPtr, int32_t const* inputLengths, - int32_t const* sequenceLengths, int32_t const* minLengths, int32_t const* endIds, int32_t const* batchSlots) + bool accumulateVocab, SizeType maxSeqLen, SizeType vocabSize, SizeType vocabSizePadded, + TokenIdType const** outputIdsPtr, SizeType const** parentIdsPtr, SizeType const* inputLengths, + SizeType const* sequenceLengths, SizeType const* minLengths, TokenIdType const* endIds, SizeType const* batchSlots, + SizeType const* tokensPerStep) { - int32_t const beamWidth = gridDim.y; - int32_t const batchIdx = blockIdx.x; - int32_t const beamIdx = blockIdx.y; - int32_t const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx]; - int32_t const batchBeamIdx = batchIdx * beamWidth + beamIdx; - int32_t const batchSlotBeamIdx = batchSlot * beamWidth + beamIdx; - int32_t const inputLen = inputLengths == nullptr ? 0 : inputLengths[batchSlotBeamIdx]; - int32_t const currentStep = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlotBeamIdx]; + auto const beamWidth = static_cast(gridDim.y); + auto const maxTokensPerStep = static_cast(gridDim.z); + auto const batchIdx = static_cast(blockIdx.x); + auto const beamIdx = static_cast(blockIdx.y); + auto const stepIdx = static_cast(blockIdx.z); + auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx]; + auto const batchBeamStepIdx = (batchIdx * beamWidth + beamIdx) * maxTokensPerStep + stepIdx; + auto const batchSlotBeamIdx = batchSlot * beamWidth + beamIdx; + auto const inputLen = inputLengths == nullptr ? SizeType{0} : inputLengths[batchSlotBeamIdx]; + auto const currentStep = sequenceLengths == nullptr ? 
SizeType{0} : sequenceLengths[batchSlotBeamIdx]; T const* biasBase = biases + batchSlot * vocabSizePadded; + + if (tokensPerStep != nullptr && stepIdx >= tokensPerStep[batchSlot]) + { + return; + } + // Initialize or update the number of occurrences of tokens if (accumulateVocab) { - penaltyWorkspace += batchBeamIdx * vocabSize; + penaltyWorkspace += batchBeamStepIdx * vocabSize; if (currentStep <= inputLen) { // Context phase - for (int32_t index = threadIdx.x; index < vocabSize; index += blockDim.x) + for (auto index = static_cast(threadIdx.x); index < vocabSize; + index += static_cast(blockDim.x)) { penaltyWorkspace[index] = 0; } __syncthreads(); - for (int32_t step = threadIdx.x; step < inputLen; step += blockDim.x) + for (auto step = static_cast(threadIdx.x); step < inputLen; + step += static_cast(blockDim.x)) { // All beams in the context phase are identical - int32_t penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step]; + auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step]; if (penaltyIndex < vocabSize) { atomicAdd(&penaltyWorkspace[penaltyIndex], 1); @@ -71,9 +83,10 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, { // Generation phase if (beamWidth > 1) { - int32_t parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 2]; - penaltyWorkspacePrev += (batchIdx * beamWidth + parentBeam) * vocabSize; - for (int32_t index = threadIdx.x; index < vocabSize; index += blockDim.x) + auto parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 2]; + penaltyWorkspacePrev += ((batchIdx * beamWidth + parentBeam) * maxTokensPerStep + stepIdx) * vocabSize; + for (auto index = static_cast(threadIdx.x); index < vocabSize; + index += static_cast(blockDim.x)) { penaltyWorkspace[index] = penaltyWorkspacePrev[index]; } @@ -81,7 +94,7 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, } if (threadIdx.x == 0) { - int32_t penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1]; + auto penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1]; if (penaltyIndex < vocabSize) { penaltyWorkspace[penaltyIndex] += 1; @@ -90,9 +103,10 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, } __syncthreads(); } + // Apply bias and penalties - auto const inLogitsPtr = inputLogits[batchIdx] + beamIdx * vocabSizePadded; - auto outLogitsPtr = outputLogits + batchBeamIdx * vocabSizePadded; + auto const inLogitsPtr = inputLogits[batchIdx] + (beamIdx * maxTokensPerStep + stepIdx) * vocabSizePadded; + auto outLogitsPtr = outputLogits + batchBeamStepIdx * vocabSizePadded; const T MASK_VAL = (std::is_same::value) ? 
-HALF_FLT_MAX : -FLT_MAX; float invTemperature, repetitionPenalty, presencePenalty, frequencyPenalty; if (temperatures != nullptr) @@ -111,22 +125,23 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, { frequencyPenalty = frequencyPenalties[batchSlot]; } - for (int32_t index = threadIdx.x; index < vocabSizePadded; index += blockDim.x) + for (auto index = static_cast(threadIdx.x); index < vocabSizePadded; + index += static_cast(blockDim.x)) { if (index < vocabSize) { - float logit = (float) inLogitsPtr[index]; + auto logit = static_cast(inLogitsPtr[index]); // Bias if (biases != nullptr) { - logit += (float) biasBase[index]; + logit += static_cast(biasBase[index]); } // Temperature if (temperatures != nullptr) { logit *= invTemperature; } - int32_t numOccurences = penaltyWorkspace[index]; + SizeType numOccurences = penaltyWorkspace[index]; if (numOccurences > 0) { // Repetition @@ -164,21 +179,21 @@ __global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, } template -void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams& params) +void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams const& params) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); dim3 block(256); - dim3 grid(params.batchSize, params.beamWidth); + dim3 grid(params.batchSize, params.beamWidth, params.maxTokensPerStep); batchApplyPenalty<<>>(params.inputLogits, params.outputLogits, params.biases, params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties, params.presencePenalties, params.frequencyPenalties, params.accumulateVocab, params.maxSeqLen, params.vocabSize, params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths, - params.minLengths, params.endIds, params.batchSlots); + params.minLengths, params.endIds, params.batchSlots, params.tokensPerStep); } -template void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams& params); +template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams const& params); -template void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams& params); +template void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams const& params); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/penaltyKernels.h b/cpp/tensorrt_llm/kernels/penaltyKernels.h index 00ebe44aa..7b3da736e 100644 --- a/cpp/tensorrt_llm/kernels/penaltyKernels.h +++ b/cpp/tensorrt_llm/kernels/penaltyKernels.h @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/penaltyTypes.h" +#include "tensorrt_llm/runtime/common.h" namespace tensorrt_llm { @@ -30,31 +31,33 @@ struct InvokeBatchApplyPenaltyParams { T const* const* inputLogits; T* outputLogits; - const T* biases; - int* penaltyWorkspace; - const int* penaltyWorkspacePrev; - const float* temperatures; - const float* repetitionPenalties; - const float* presencePenalties; - const float* frequencyPenalties; - const bool accumulateVocab; - const size_t batchSize; - const int beamWidth; - const int maxSeqLen; - const size_t vocabSize; - const size_t vocabSizePadded; - const int** outputIdsPtr; - const int** parentIdsPtr; - const int* inputLengths; - const int* sequenceLengths; - const int* minLengths; - const int* endIds; - const int* batchSlots; + T const* biases; + runtime::TokenIdType* penaltyWorkspace; + runtime::TokenIdType const* penaltyWorkspacePrev; + float const* temperatures; + float const* repetitionPenalties; + 
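Stepping briefly outside the parameter struct: the repetition, presence, and frequency fields above feed the per-token adjustment performed in batchApplyPenalty. The scalar sketch below shows the conventional arithmetic for those three penalties; the actual per-vocab-entry code (applied after the temperature scaling shown earlier) is elided from this hunk, and the free function here is purely illustrative:

    // Conventional logit adjustment for a token that has already occurred numOccurrences times.
    inline float applyPenalties(float logit, int numOccurrences, float repetitionPenalty,
        float presencePenalty, float frequencyPenalty)
    {
        if (numOccurrences > 0)
        {
            // Repetition penalty: push previously seen tokens away from re-selection (sign-aware scale).
            logit = logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty;
            // Presence penalty: flat subtraction once the token has appeared at all.
            logit -= presencePenalty;
            // Frequency penalty: subtraction proportional to the occurrence count.
            logit -= frequencyPenalty * static_cast<float>(numOccurrences);
        }
        return logit;
    }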
float const* presencePenalties; + float const* frequencyPenalties; + bool const accumulateVocab; + size_t const batchSize; + runtime::SizeType const beamWidth; + runtime::SizeType const maxSeqLen; + size_t const vocabSize; + size_t const vocabSizePadded; + runtime::TokenIdType const** outputIdsPtr; + runtime::SizeType const** parentIdsPtr; + runtime::SizeType const* inputLengths; + runtime::SizeType const* sequenceLengths; + runtime::SizeType const* minLengths; + runtime::TokenIdType const* endIds; + runtime::SizeType const* batchSlots; + runtime::SizeType const maxTokensPerStep; + runtime::SizeType const* tokensPerStep; cudaStream_t stream; }; template -void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams& params); +void invokeBatchApplyPenalty(InvokeBatchApplyPenaltyParams const& params); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu index 5d1ce20cb..5f85fe565 100644 --- a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu +++ b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.cu @@ -25,7 +25,7 @@ struct Vec2Type<__nv_bfloat16> template __global__ void apply_per_channel_scale( - T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols) + T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols) { static constexpr int kElems = sizeof(AccessType) / sizeof(T_in); T_in scale[kElems], act_vec[kElems]; @@ -35,11 +35,11 @@ __global__ void apply_per_channel_scale( return; act += row_offset * kProcessRows * cols; smoothed_act += row_offset * kProcessRows * cols; - *reinterpret_cast(scale) = reinterpret_cast(per_channel_scale)[col_offset]; + *reinterpret_cast(scale) = reinterpret_cast(per_channel_scale)[col_offset]; #pragma unroll for (int i = 0; i < kProcessRows; ++i) { - *reinterpret_cast(act_vec) = reinterpret_cast(act + i * cols)[col_offset]; + *reinterpret_cast(act_vec) = reinterpret_cast(act + i * cols)[col_offset]; if constexpr ((std::is_same_v #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) || std::is_same_v @@ -80,7 +80,7 @@ __global__ void apply_per_channel_scale( template void apply_per_channel_scale_kernel_launcher_( - T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream = 0) + T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 0) { static constexpr int kElems = sizeof(AccessType) / sizeof(T_in); dim3 block(128); @@ -91,7 +91,7 @@ void apply_per_channel_scale_kernel_launcher_( template void apply_per_channel_scale_kernel_launcher( - T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream) + T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream) { int elems = rows * cols; if (elems < 2048 * 2048) diff --git a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h index 57ca57413..5ac02cff6 100644 --- a/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h +++ b/cpp/tensorrt_llm/kernels/preQuantScaleKernel.h @@ -35,7 +35,7 @@ namespace kernels template void apply_per_channel_scale_kernel_launcher( - T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream = 0); + T_out* smoothed_act, T_in const* act, T_in const* per_channel_scale, int rows, int cols, cudaStream_t stream = 
0); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/quantization.cu b/cpp/tensorrt_llm/kernels/quantization.cu index 5214fb02c..5ed12fd1a 100644 --- a/cpp/tensorrt_llm/kernels/quantization.cu +++ b/cpp/tensorrt_llm/kernels/quantization.cu @@ -27,11 +27,11 @@ namespace tensorrt_llm namespace kernels { -__global__ void quantizedKernel(char4* dst, const float4* src, const int64_t sizeDiv4, const float* scalePtr) +__global__ void quantizedKernel(char4* dst, float4 const* src, const int64_t sizeDiv4, float const* scalePtr) { for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x) { - const float scale = __ldg(scalePtr); + float const scale = __ldg(scalePtr); char4 tmp; const float4 floatTmp = __ldg(src + idx); tmp.x = cuda_cast(floatTmp.x * scale); @@ -42,18 +42,18 @@ __global__ void quantizedKernel(char4* dst, const float4* src, const int64_t siz } } -__global__ void quantizedKernel(char4* dst, const half2* src, const int64_t sizeDiv4, const float* scalePtr) +__global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t sizeDiv4, float const* scalePtr) { for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x) { - const float scale = __ldg(scalePtr); + float const scale = __ldg(scalePtr); char4 tmp; int srcId = idx << 1; - const uint2 h2 = __ldg(reinterpret_cast(src + srcId)); + const uint2 h2 = __ldg(reinterpret_cast(src + srcId)); - const half2 half2Tmp = reinterpret_cast(h2.x); - const half2 half2Tmp2 = reinterpret_cast(h2.y); + const half2 half2Tmp = reinterpret_cast(h2.x); + const half2 half2Tmp2 = reinterpret_cast(h2.y); tmp.x = cuda_cast(cuda_cast(half2Tmp.x) * scale); tmp.y = cuda_cast(cuda_cast(half2Tmp.y) * scale); @@ -65,7 +65,7 @@ __global__ void quantizedKernel(char4* dst, const half2* src, const int64_t size template void invokeQuantization( - int8_t* dst, const T* src, const int64_t size, const float* scalePtr, cudaStream_t stream, int maxGridSize) + int8_t* dst, T const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize) { TLLM_CHECK_WITH_INFO(size % 4 == 0, "[ERROR][invokeQuantization] size should be a multiple of 4.\n"); @@ -75,25 +75,25 @@ void invokeQuantization( dim3 block(64); if (std::is_same_v) { - quantizedKernel<<>>((char4*) dst, (const float4*) src, size / 4, scalePtr); + quantizedKernel<<>>((char4*) dst, (float4 const*) src, size / 4, scalePtr); } else if (std::is_same_v) { - quantizedKernel<<>>((char4*) dst, (const half2*) src, size / 4, scalePtr); + quantizedKernel<<>>((char4*) dst, (half2 const*) src, size / 4, scalePtr); } } template void invokeQuantization( - int8_t* dst, const float* src, const int64_t size, const float* scalePtr, cudaStream_t stream, int maxGridSize); + int8_t* dst, float const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize); template void invokeQuantization( - int8_t* dst, const half* src, const int64_t size, const float* scalePtr, cudaStream_t stream, int maxGridSize); + int8_t* dst, half const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize); template __global__ void perTokenQuantization( - int8_t* dst, const T* src, const int64_t numRows, const int64_t numCols, float* scalePtr) + int8_t* dst, T const* src, const int64_t numRows, const int64_t numCols, float* scalePtr) { - const T* srcRow = src + blockIdx.x * numCols; + T const* srcRow = src + blockIdx.x * numCols; int8_t* dstRow 
= dst + blockIdx.x * numCols; T localMax = 1e-6f; @@ -101,14 +101,14 @@ __global__ void perTokenQuantization( { localMax = cuda_max(localMax, cuda_abs(srcRow[i])); } - const float rowMax = blockAllReduceMax(cuda_cast(localMax)); + float const rowMax = blockAllReduceMax(cuda_cast(localMax)); if (threadIdx.x == 0) { scalePtr[blockIdx.x] = rowMax / 127.f; } - const float scaleOrigQuant = 127.f / rowMax; + float const scaleOrigQuant = 127.f / rowMax; for (int i = threadIdx.x; i < numCols; i += blockDim.x) { dstRow[i] = cuda_cast(cuda_cast(srcRow[i]) * scaleOrigQuant); @@ -117,7 +117,7 @@ __global__ void perTokenQuantization( template void invokePerTokenQuantization( - int8_t* dst, const T* src, const int64_t numRows, const int64_t numCols, float* scalePtr, cudaStream_t stream) + int8_t* dst, T const* src, const int64_t numRows, const int64_t numCols, float* scalePtr, cudaStream_t stream) { // each block is responsible for a single row const dim3 block(512); diff --git a/cpp/tensorrt_llm/kernels/quantization.h b/cpp/tensorrt_llm/kernels/quantization.h index afc5f48c4..e80c5b210 100644 --- a/cpp/tensorrt_llm/kernels/quantization.h +++ b/cpp/tensorrt_llm/kernels/quantization.h @@ -25,11 +25,11 @@ namespace kernels template void invokeQuantization( - int8_t* dst, const T* src, const int64_t size, const float* scalePtr, cudaStream_t stream = 0, int maxGirdSize = 0); + int8_t* dst, T const* src, const int64_t size, float const* scalePtr, cudaStream_t stream = 0, int maxGirdSize = 0); template void invokePerTokenQuantization( - int8_t* dst, const T* src, const int64_t numRows, const int64_t numCols, float* scalePtr, cudaStream_t stream = 0); + int8_t* dst, T const* src, const int64_t numRows, const int64_t numCols, float* scalePtr, cudaStream_t stream = 0); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/rmsnormKernels.cu b/cpp/tensorrt_llm/kernels/rmsnormKernels.cu index 326733605..c9d42396d 100644 --- a/cpp/tensorrt_llm/kernels/rmsnormKernels.cu +++ b/cpp/tensorrt_llm/kernels/rmsnormKernels.cu @@ -26,7 +26,7 @@ namespace kernels { template -__inline__ __device__ Tf compute_rmsnorm(Tf val, float s_variance, const T* gamma, const T* beta, int i) +__inline__ __device__ Tf compute_rmsnorm(Tf val, float s_variance, T const* gamma, T const* beta, int i) { Tf ret = val * s_variance * cuda_cast(gamma[i]); if (beta != nullptr) @@ -50,8 +50,8 @@ __inline__ __device__ Tf compute_rmsnorm(Tf val, float s_variance, const T* gamm * normed_output_quant. 
*/ template -__global__ void generalRmsNorm(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, - int tokens, int hidden_dim, const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, +__global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps, + int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, int8_t* normed_output_quant, bool use_shmem) { constexpr auto num_elems_T = num_elems::value; @@ -64,13 +64,13 @@ __global__ void generalRmsNorm(const T* input, const T* gamma, const T* beta, T* __shared__ float s_variance; - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; + int const tidx = threadIdx.x; + int const bidx = blockIdx.x; float variance = 0.0f; float local_var_sum = 0.0f; - const int n_elems = hidden_dim / num_elems_T; + int const n_elems = hidden_dim / num_elems_T; for (int i = tidx; i < n_elems; i += blockDim.x) { const T val = input[bidx * n_elems + i]; @@ -95,15 +95,15 @@ __global__ void generalRmsNorm(const T* input, const T* gamma, const T* beta, T* } __syncthreads(); - const bool with_per_token_scaling = scale_orig_quant_per_token != nullptr; - const bool with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; + bool const with_per_token_scaling = scale_orig_quant_per_token != nullptr; + bool const with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr; const float_packed_t scale_orig_quant = cuda_cast(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f); T_scalar amax = 1e-6f; for (int i = tidx; i < n_elems; i += blockDim.x) { - const int index = bidx * n_elems + i; + int const index = bidx * n_elems + i; const float_packed_t val_f = cuda_cast(use_shmem ? shmem[i] : input[index]); const T val = cuda_cast(compute_rmsnorm(val_f, s_variance, gamma, beta, i)); @@ -129,10 +129,10 @@ __global__ void generalRmsNorm(const T* input, const T* gamma, const T* beta, T* if (with_per_token_scaling) { float abs_max_f = blockAllReduceMax(cuda_cast(amax)); - const float dynamic_per_token_scale = 127.f / abs_max_f; + float const dynamic_per_token_scale = 127.f / abs_max_f; for (int i = tidx; i < n_elems; i += blockDim.x) { - const int index = bidx * n_elems + i; + int const index = bidx * n_elems + i; float_packed_t val_f = cuda_cast(use_shmem ? 
shmem[i] : input[index]); if (!use_shmem) { @@ -150,8 +150,8 @@ __global__ void generalRmsNorm(const T* input, const T* gamma, const T* beta, T* } template -void dispatch_rmsnorm_type_square_method(const T* input, const T* gamma, const T* beta, T* normed_output, - const float eps, int tokens, int hidden_dim, const float* scale_orig_quant_per_tensor, +void dispatch_rmsnorm_type_square_method(T const* input, T const* gamma, T const* beta, T* normed_output, + float const eps, int tokens, int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream) { @@ -165,8 +165,8 @@ void dispatch_rmsnorm_type_square_method(const T* input, const T* gamma, const T } template -void dispatch_rmsnorm_type(const T* input, const T* gamma, const T* beta, T* normed_output, const float eps, int tokens, - int hidden_dim, const float* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, +void dispatch_rmsnorm_type(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps, int tokens, + int hidden_dim, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token, int8_t* normed_output_quant, const dim3 grid, const dim3 block, const size_t shmem_size, cudaStream_t stream) { dispatch_rmsnorm_type_square_method(input, gamma, beta, normed_output, eps, tokens, hidden_dim, @@ -174,8 +174,8 @@ void dispatch_rmsnorm_type(const T* input, const T* gamma, const T* beta, T* nor } template -void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, const int tokens, - const int hidden_dim, cudaStream_t stream, const float* scale, float* dynamic_scale, int8_t* normed_output_quant) +void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta, float const eps, int const tokens, + int const hidden_dim, cudaStream_t stream, float const* scale, float* dynamic_scale, int8_t* normed_output_quant) { dim3 grid(tokens); dim3 block(min(hidden_dim, 1024)); @@ -184,7 +184,7 @@ void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, constexpr size_t vec_size = 2; const size_t shmem_size = hidden_dim * sizeof(T); - const bool use_vec_type = (hidden_dim % vec_size == 0) + bool const use_vec_type = (hidden_dim % vec_size == 0) && (std::is_same::value #ifdef ENABLE_BF16 || std::is_same::value @@ -194,8 +194,8 @@ void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, if (use_vec_type) { using Tp = typename packed_as::type; - dispatch_rmsnorm_type(reinterpret_cast(input), reinterpret_cast(gamma), - reinterpret_cast(beta), reinterpret_cast(out), eps, tokens, hidden_dim, scale, + dispatch_rmsnorm_type(reinterpret_cast(input), reinterpret_cast(gamma), + reinterpret_cast(beta), reinterpret_cast(out), eps, tokens, hidden_dim, scale, dynamic_scale, normed_output_quant, grid, block, shmem_size, stream); } else diff --git a/cpp/tensorrt_llm/kernels/rmsnormKernels.h b/cpp/tensorrt_llm/kernels/rmsnormKernels.h index 1529a2af9..14e0e0cbf 100644 --- a/cpp/tensorrt_llm/kernels/rmsnormKernels.h +++ b/cpp/tensorrt_llm/kernels/rmsnormKernels.h @@ -27,8 +27,8 @@ namespace kernels { template -void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, const int tokens, - const int hidden_dim, cudaStream_t stream = 0, const float* scale = nullptr, float* dynamic_scale = nullptr, +void invokeGeneralRmsNorm(T* out, T const* input, T const* 
gamma, T const* beta, float const eps, int const tokens, + int const hidden_dim, cudaStream_t stream = 0, float const* scale = nullptr, float* dynamic_scale = nullptr, int8_t* out_quant = nullptr); } // namespace kernels diff --git a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu index 39b899024..473d9d17a 100644 --- a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.cu @@ -24,7 +24,7 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/common/memoryUtils.h" -#include "tensorrt_llm/kernels/samplingTopPKernels.h" +#include "tensorrt_llm/kernels/samplingAirTopPKernels.h" #include #include @@ -34,6 +34,10 @@ namespace tensorrt_llm { namespace kernels { + +using IdxT = int; +using AccT = float; + template struct alignas(128) Counter { @@ -92,7 +96,7 @@ constexpr __host__ __device__ IntType alignTo(IntType a, IntType b) return ceilDiv(a, b) * b; } -//! \brief Calcute the number of buckets based on the number of bits per pass. +//! \brief Calculate the number of buckets based on the number of bits per pass. //! \tparam BitsPerPass. If BitsPerPass==11, the number of buckets is 2048. If BitsPerPass==8, the number of buckets is //! 256. template @@ -101,7 +105,7 @@ __host__ __device__ int constexpr calcNumBuckets() return 1 << BitsPerPass; } -//! \brief Calcute the number of passes based on the number of bits per pass. +//! \brief Calculate the number of passes based on the number of bits per pass. //! \tparam BitsPerPass. If BitsPerPass==11, the number of passes is 3. If BitsPerPass==8, the number of passes is 4. template __host__ __device__ int constexpr calcNumPasses() @@ -421,7 +425,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i * Replace histogram with its own prefix sum (step 2 in `airTopPSampling` description) */ template -__device__ void scan(volatile IdxT* histogram) +__device__ void scan(IdxT volatile* histogram) { int constexpr numBuckets = calcNumBuckets(); if constexpr (numBuckets >= BlockSize) @@ -638,7 +642,6 @@ __global__ void airTopPSampling(Counter* counters, AccT* histogra { finishedOutput[batchSlot] = finishState; } - ids[batchSlot][sequenceLengths[batchSlot]] = endIds[batchSlot]; } return; } @@ -890,21 +893,41 @@ unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt) } template -void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, +[[nodiscard]] std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize) +{ + int constexpr BitsPerPass = 11; + int constexpr numBuckets = calcNumBuckets(); + IdxT const bufLen = calcBufLen(vocabSize); + + size_t countersSize = sizeof(Counter) * batchSize; + size_t histogramsSize = sizeof(AccT) * numBuckets * batchSize; + size_t countHistogramsSize = sizeof(IdxT) * numBuckets * batchSize; + size_t buf1Size = sizeof(T) * bufLen * batchSize; + size_t idxBuf1Size = sizeof(IdxT) * bufLen * batchSize; + size_t buf2Size = sizeof(T) * bufLen * batchSize; + size_t idxBuf2Size = sizeof(IdxT) * bufLen * batchSize; + + std::vector sizes + = {countersSize, histogramsSize, countHistogramsSize, buf1Size, idxBuf1Size, buf2Size, idxBuf2Size}; + + return sizes; +} + +template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); +template std::vector getAirTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); + +template +void 
invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots) { - using IdxT = int; - using AccT = float; static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or float"); static_assert(std::is_same_v, "AccT needs to be float"); IdxT const vocabSize = vocabSizePadded; int constexpr BitsPerPass = 11; - int constexpr numBuckets = calcNumBuckets(); - IdxT const bufLen = calcBufLen(vocabSize); int constexpr SAMPLING_BLOCK_SIZE = 512; int constexpr THREADS_PER_CTA_TOP_P_INIT = 1024; @@ -916,18 +939,11 @@ void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** ou IdxT* idxBuf1 = nullptr; T* buf2 = nullptr; IdxT* idxBuf2 = nullptr; - std::vector sizes = {sizeof(*counters) * batchSize, sizeof(*histograms) * numBuckets * batchSize, - sizeof(*countHistograms) * numBuckets * batchSize, sizeof(*buf1) * bufLen * batchSize, - sizeof(*idxBuf1) * bufLen * batchSize, sizeof(*buf2) * bufLen * batchSize, - sizeof(*idxBuf2) * bufLen * batchSize}; - size_t totalSize = calcAlignedSize(sizes); - if (workspace == nullptr) - { - workspaceSize = totalSize; - return; - } + + auto const workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSize); + std::vector alignedPointers; - calcAlignedPointers(alignedPointers, workspace, sizes); + calcAlignedPointers(alignedPointers, workspace, workspaceSizes); counters = static_cast(alignedPointers[0]); histograms = static_cast(alignedPointers[1]); countHistograms = static_cast(alignedPointers[2]); @@ -960,37 +976,36 @@ void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** ou } } -template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, +template void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots); -template void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, +template void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots); template -void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, - FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, - int const* endIds, float const topP, 
cudaStream_t stream, int blockNum, bool const* skipDecode, - int32_t const* batchSlots) +void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, + float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots) { - invokeBatchAirTopPSampling(workspace, workspaceSize, outputIds, sequenceLength, finishedInput, finishedOutput, - cumLogProbs, outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, vocabSizePadded, endIds, topP, - nullptr, stream, blockNum, skipDecode, batchSlots); + invokeBatchAirTopPSampling(workspace, outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, + outputLogProbs, logProbs, curandstate, batchSize, maxBatchSize, vocabSizePadded, endIds, topP, nullptr, stream, + blockNum, skipDecode, batchSlots); } -template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, +template void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, float const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots); -template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, +template void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, half const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, int blockNum, @@ -999,5 +1014,15 @@ template void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int* template unsigned calcAirTopPBlockNum(int batchSize, int len, int smCnt); template unsigned calcAirTopPBlockNum(int batchSize, int len, int smCnt); +template +size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded) +{ + auto const workspaceSizes = getAirTopPWorkspaceSizes(batchSize, vocabSizePadded); + return calcAlignedSize(workspaceSizes, 256); +} + +template size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); +template size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h new file mode 100644 index 000000000..0d31a9094 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/samplingAirTopPKernels.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/kernels/decodingCommon.h"
+#include
+
+namespace tensorrt_llm
+{
+namespace kernels
+{
+
+//! \brief Given logProbs, performs top P sampling.
+//! Note that, unlike invokeTopPSampling() and invokeBatchTopPSampling(), the two functions invokeAirTopPSampling
+//! and invokeBatchAirTopPSampling are non-deterministic.
+//! Fills sampled tokens into outputIds. Computes sequenceLength, finished state, cumLogProbs in place.
+//! Sampling per request can be controlled using the skipDecode and topPs parameters.
+//! The required workspace size can be queried with getAirTopPWorkspaceSize().
+//!
+//! \param workspace pointer to the workspace. Has to be pre-allocated by caller. Function does not take ownership of
+//! the buffer.
+//! \param outputIds output buffer [batchSize][maxSeqLen]. Contains pointers to rows with output tokens per request.
+//! \param sequenceLength input/output buffer [batchSize]. Current sequence length of the request up to, but excluding
+//! endId token.
+//! \param finishedInput input buffer [batchSize]. Exit early if true.
+//! \param finishedOutput output buffer [batchSize]. Set flag if sequence has finished (if finished || outputId ==
+//! endId).
+//! \param cumLogProbs input/output buffer [batchSize]. Cumulative log probability of selected tokens. Ignored
+//! if nullptr.
+//! \param outputLogProbs output buffer [batchSize]. Log probs is the probability induced by the top-k
+//! sampling. We normalize the probability 'expLogit' of the selected token by the probability 's_sum' of a set of top-k
+//! tokens, meaning the logProb is the probability of the selected token, conditioned on the event that it is selected,
+//! i.e., log_prob = log P(i | i is in top-k) = log(expLogit / s_sum). Ignored if nullptr.
+//! \param logProbs input buffer [batchSize x vocabSizePadded]. Log probabilities of each token in the vocab.
+//! If cumLogProbs or outputLogProbs are specified, logProbs must contain **just** probabilities instead of log
+//! probabilities.
+//! \param curandstate input buffer [batchSize]. Curand states properly initialized using invokeCurandInitialize per
+//! request.
+//! \param batchSize batch size
+//! \param maxBatchSize max batch size
+//! \param vocabSizePadded size of padded vocab
+//! \param endIds input buffer [batchSize]. EOS token ids per request
+//! \param maxTopP maximum among all topPs P for topP sampling
+//! \param topPs input buffer [batchSize]. P for topP sampling per request. Supported P is in range (0.0; 1.0].
+//! If nullptr, maxTopP is used for all requests.
+//! \param stream cuda stream
+//! \param blockNum The appropriate block configuration calculated based on the number of multiprocessors, occupancy,
+//! batchSize and vocabSizePadded
+//! \param skipDecode input buffer [batchSize]. 
Flags whether to skip decoding per request +template +void invokeBatchAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, + int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, + bool const* skipDecode, int32_t const* batchSlots); + +//! \brief Specialization of invokeBatchAirTopPSampling with topPs=nullptr +template +void invokeAirTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, + curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, + float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, int32_t const* batchSlots); + +//! \brief Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize. +//! \tparam T the data type of value +//! \tparam IdxT the data type of index +//! \tparam AccT the data type of variables related to accumulation +//! \tparam BitsPerPass the number of bits for each pass. Can be 8 or 11. Use 11 for default. +//! \tparam BlockSize the block size +//! \param batchSize +//! \param len the number of candidates for each case +//! \param smCnt number of multiprocessors on device +template +unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt); + +//! \brief Returns workspace size in bytes needed for sampling Air TopP computation +//! \param batchSize batch size +//! \param vocabSizePadded size of padded vocab +template +[[nodiscard]] size_t getAirTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); + +} // namespace kernels +} // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu index 1c701cdc5..9fc1b982b 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.cu @@ -25,6 +25,7 @@ #endif #include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/common/stringUtils.h" #include "tensorrt_llm/kernels/samplingTopKKernels.h" @@ -36,45 +37,57 @@ namespace tensorrt_llm namespace kernels { -template -__global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* topKTmpIdBuf, T* topKTmpValBuf, - const FinishedState* finished, const int maxTopK, const int* topKs, const int vocabSize, const int* endIds, - const bool* skipDecode, const int* batchSlots) +template +__global__ void topKStage1(T const* __restrict logProbs, T const* const* __restrict logProbsPtrs, T* tmpLogProbs, + int32_t* topKTmpIdBuf, T* topKTmpValBuf, FinishedState const* finished, int32_t maxTopK, int32_t const* topKs, + int32_t vocabSize, int32_t const* endIds, bool const* skipDecode, int32_t const* batchSlots, + int32_t const* tokensPerStep, int32_t maxTokensPerStep) { typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage tempStorage; - int const tid = threadIdx.x; - int const bid = blockIdx.x; + auto const tid = static_cast(threadIdx.x); + auto const bid = static_cast(blockIdx.x); + auto const tokenIdx = static_cast(blockIdx.y); auto const batchId = 
bid / BLOCKS_PER_BEAM_; // row id for logProbs auto const batchSlot = batchSlots != nullptr ? batchSlots[batchId] : batchId; + if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot]) + { + return; + } + FinishedState const finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty(); if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding())) { return; } - const int blockLane = bid % BLOCKS_PER_BEAM_; // block id for a beam - const int k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; // batchId = batch index - const int logBufIndex = batchId * vocabSize; - const int tmpLogBufIndex = batchId * vocabSize; - const int tmpTopKBufIndex = batchId * BLOCKS_PER_BEAM_ * maxTopK + blockLane * k; + auto const logBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize; + auto logProbsSlot + = logProbsPtrs == nullptr ? logProbs + logBufIndex : logProbsPtrs[batchId * maxTokensPerStep + tokenIdx]; + + auto const blockLane = bid % BLOCKS_PER_BEAM_; // block id for a beam + auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; // batchId = batch index + + auto const tmpLogBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize; + auto const tmpTopKBufIndex = batchId * maxTokensPerStep * BLOCKS_PER_BEAM_ * maxTopK + + tokenIdx * BLOCKS_PER_BEAM_ * maxTopK + blockLane * k; TopK_2 partial; - const bool IS_FP16 = std::is_same::value; - const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + bool const IS_FP16 = std::is_same::value; + T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; if (finished != nullptr && finishState.isFinished()) { if (tid < k) { - const int index = tmpTopKBufIndex + tid; + auto const index = tmpTopKBufIndex + tid; if (blockLane == 0 && tid == 0) { - const int endId = endIds[batchSlot]; + auto const endId = endIds[batchSlot]; topKTmpIdBuf[index] = tmpLogBufIndex + endId; - topKTmpValBuf[index] = logProbs[logBufIndex + endId]; + topKTmpValBuf[index] = logProbsSlot[endId]; } else { @@ -85,20 +98,19 @@ __global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* to return; } - for (int elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) + for (auto elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { - int localIndex = elemId + tmpLogBufIndex; - int globalIndex = elemId + logBufIndex; - tmpLogProbs[localIndex] = logProbs[globalIndex]; + auto localIndex = elemId + tmpLogBufIndex; + tmpLogProbs[localIndex] = logProbsSlot[elemId]; } - for (int ite = 0; ite < k; ite++) + for (int32_t ite = 0; ite < k; ite++) { partial.init(); #pragma unroll - for (int elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) + for (auto elemId = tid + blockLane * BLOCK_SIZE_; elemId < vocabSize; elemId += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { - int index = elemId + tmpLogBufIndex; + auto index = elemId + tmpLogBufIndex; partial.insert(tmpLogProbs[index], index); } @@ -106,7 +118,7 @@ __global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* to if (tid == 0) { - const int index = tmpTopKBufIndex + ite; + auto const index = tmpTopKBufIndex + ite; topKTmpIdBuf[index] = total.p; topKTmpValBuf[index] = total.u; if (total.p >= 0) @@ -119,17 +131,19 @@ __global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* to } template -__global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* 
topKTmpValBuf, int** ids, - int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, const int maxTopK, const int* topKs, const float topP, const float* topPs, - curandState_t* curandstate, const int* endIds, const int vocabSize, const bool* skipDecode, const int* batchSlots, - int maxBatchSize, const bool normalizeLogProbs, const bool logitHasProbs) +__global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTmpValBuf, int** ids, + int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, int const maxTopK, int const* topKs, float const topP, float const* topPs, + curandState_t* curandstate, int const* endIds, int const vocabSize, bool const* skipDecode, int const* batchSlots, + int maxBatchSize, bool const normalizeLogProbs, bool const logitHasProbs, int const* tokensPerStep, + int const maxTokensPerStep, bool returnAllTopK) { bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; - int const tid = threadIdx.x; - auto const batchIdx = blockIdx.x; + auto const tid = static_cast(threadIdx.x); + auto const batchIdx = static_cast(blockIdx.x); + auto const tokenIdx = static_cast(blockIdx.y); auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx; FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty(); if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding())) @@ -137,17 +151,22 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm return; } - const int k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; - const float probThreshold = (topPs != nullptr) ? topPs[batchSlot] : topP; - const int size = k * BLOCKS_PER_BEAM_; - const int stride = maxTopK * BLOCKS_PER_BEAM_; + if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot]) + { + return; + } + + auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; + auto const probThreshold = (topPs != nullptr) ? 
topPs[batchSlot] : topP; + auto const size = k * BLOCKS_PER_BEAM_; + auto const stride = maxTopK * BLOCKS_PER_BEAM_; typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; __shared__ typename BlockReduce::TempStorage tempStorage; extern __shared__ char array[]; __shared__ float s_sum; - T* s_val = topKTmpValBuf + batchIdx * stride; - int* s_id = reinterpret_cast(array); + T* s_val = topKTmpValBuf + (batchIdx * maxTokensPerStep + tokenIdx) * stride; + auto* s_id = reinterpret_cast(array); if (tid == 0) { s_sum = 0.0f; @@ -163,7 +182,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm return; } - float* s_val2 = reinterpret_cast(s_id + k); + auto s_val2 = reinterpret_cast(s_id + k); float maxLogit; for (int ite = 0; ite < k; ite++) { @@ -199,43 +218,50 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm if (tid == 0) { - float randNum = (float) curand_uniform(curandstate + batchSlot) * probThreshold * s_sum; - for (int i = 0; i < k; i++) + auto randNum = static_cast(curand_uniform(curandstate + batchSlot) * probThreshold * s_sum); + for (int ki = 0; ki < k; ki++) { - float expLogit = s_val2[i]; + auto expLogit = s_val2[ki]; randNum = randNum - expLogit; - if (randNum <= 0.0f || i == k - 1) + if (randNum <= 0.0f || ki == k - 1 || returnAllTopK) { - int idx = s_id[i]; + auto idx = s_id[ki]; // If s_id is -1 here we force output token to the last from vocabulary to get vivid indicator of smth // going wrong for the debug - auto outputId = idx != -1 ? topKTmpIdBuf[batchIdx * stride + idx] % vocabSize : vocabSize - 1; + auto outputId = idx != -1 + ? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize + : vocabSize - 1; auto const curSeqLen = sequenceLengths[batchSlot]; - ids[batchSlot][curSeqLen] = outputId; - if (cumLogProbs != nullptr || outputLogProbs != nullptr) + auto outIdx = returnAllTopK ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx; + ids[batchSlot][outIdx] = outputId; + // cum log prob is not supported with returnAllTopK + if (!returnAllTopK) { - float logProb = logf(expLogit); - if (cumLogProbs != nullptr) - { - cumLogProbs[batchSlot] += logProb; - } - if (outputLogProbs != nullptr) + if (cumLogProbs != nullptr || outputLogProbs != nullptr) { - // 'outputLogProbs' is the probability induced by the top-k sampling: - // NOT normalized (same way as OpenAI does): - // log_prob = log P(i | i is in top-k) = log(expLogit) - // normalized: - // log_prob = log P(i | i is in top-k) = log(expLogit / sum) - outputLogProbs[curSeqLen * maxBatchSize + batchSlot] - = normalizeLogProbs ? logProb - logf(s_sum) : logProb; + auto logProb = logf(expLogit); + if (cumLogProbs != nullptr) + { + cumLogProbs[batchSlot] += logProb; + } + if (outputLogProbs != nullptr) + { + // 'outputLogProbs' is the probability induced by the top-k sampling: + // NOT normalized (same way as OpenAI does): + // log_prob = log P(i | i is in top-k) = log(expLogit) + // normalized: + // log_prob = log P(i | i is in top-k) = log(expLogit / sum) + outputLogProbs[curSeqLen * maxBatchSize + batchSlot] + = normalizeLogProbs ? 
logProb - logf(s_sum) : logProb; + } } + break; } - break; } } - if (sequenceLengths != nullptr && finishedOutput != nullptr) + if (maxTokensPerStep == 1 && !returnAllTopK && sequenceLengths != nullptr && finishedOutput != nullptr) { - const int seqLen = sequenceLengths[batchSlot]; + int const seqLen = sequenceLengths[batchSlot]; if (ids[batchSlot][seqLen] == endIds[batchSlot]) { finishedOutput[batchSlot].setFinishedEOS(); @@ -252,58 +278,59 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm } #define CASE_K(K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_, normalizeLogProbs) \ - topKStage1 \ - <<>>(logProbs, tempLogProbs, topKTmpIdBuf, \ - topKTmpValBuf, finishedInput, maxTopK, topKs, vocabSize, endIds, skipDecode, batchSlots); \ - topKStage2Sampling \ - <<>>(topKTmpIdBuf, \ - topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \ - topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, maxBatchSize, \ - normalizeLogProbs, logitsHasProbs); \ - break; + do \ + { \ + { \ + dim3 grid(batchSize* BLOCKS_PER_BEAM_, maxTokensPerStep); \ + dim3 block(BLOCK_SIZE_1_); \ + topKStage1<<>>(logProbs, logProbsPtrs, \ + tempLogProbs, topKTmpIdBuf, topKTmpValBuf, finishedInput, maxTopK, topKs, vocabSize, endIds, \ + skipDecode, batchSlots, tokensPerStep, maxTokensPerStep); \ + } \ + { \ + dim3 grid(batchSize, maxTokensPerStep); \ + dim3 block(BLOCK_SIZE_2_); \ + topKStage2Sampling \ + <<>>(topKTmpIdBuf, topKTmpValBuf, \ + ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, topKs, \ + topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, maxBatchSize, \ + normalizeLogProbs, logitsHasProbs, tokensPerStep, maxTokensPerStep, returnAllTopK); \ + } \ + } while (0) template -void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths, - const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs, - const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, - int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs) +void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtrs, int** ids, + int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, int const maxTopK, int const* topKs, float const topP, + float const* topPs, int const vocabSize, int const* endIds, int const* batchSlots, cudaStream_t stream, + int const batchSize, int maxBatchSize, int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, + bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); // Not allow an ambiguous inputs topP and topPs. 
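    // The refactored API moves workspace sizing to the caller: query getTopKWorkspaceSize<T>() (declared in
    // samplingTopKKernels.h), allocate that many bytes, and pass the buffer in as `workspace`. A minimal
    // caller-side sketch, assuming illustrative device buffers (probsDevice, idsPtrsDevice, seqLenDevice,
    // endIdsDevice, curandStates) and a `stream` that are not part of this file; curandStates is presumed to
    // have been initialized per request beforehand:
    //
    //   auto const workspaceBytes = getTopKWorkspaceSize<float>(batchSize, /* maxTokensPerStep */ 1, maxTopK, vocabSizePadded);
    //   void* samplingWorkspace{nullptr};
    //   check_cuda_error(cudaMalloc(&samplingWorkspace, workspaceBytes));
    //   invokeBatchTopKSampling(samplingWorkspace, probsDevice, /* logProbsPtrs */ nullptr, idsPtrsDevice,
    //       seqLenDevice, /* finishedInput */ nullptr, /* finishedOutput */ nullptr, /* cumLogProbs */ nullptr,
    //       /* outputLogProbs */ nullptr, curandStates, maxTopK, /* topKs */ nullptr, /* topP */ 1.0f,
    //       /* topPs */ nullptr, vocabSizePadded, endIdsDevice, /* batchSlots */ nullptr, stream, batchSize,
    //       maxBatchSize, /* tokensPerStep */ nullptr, /* maxTokensPerStep */ 1, /* skipDecode */ nullptr,
    //       /* normalizeLogProbs */ false, /* logitsHasProbs */ true, /* returnAllTopK */ false);
    //   check_cuda_error(cudaFree(samplingWorkspace));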
assert(topP == 1.0f || topPs == nullptr); - const int vocabSize = vocabSizePadded; - const int maxBlockPerBeam = 8; - int tempLogProbsBufSize = batchSize * vocabSize; // type float - int topKTmpIdsBufSize = batchSize * maxTopK * maxBlockPerBeam; // type int - int topKTmpValBuf_size = batchSize * maxTopK * maxBlockPerBeam; // type float - - // prevent memory misaligned address - tempLogProbsBufSize = (int) (ceil(tempLogProbsBufSize / 4.)) * 4; - topKTmpIdsBufSize = (int) (ceil(topKTmpIdsBufSize / 4.)) * 4; - topKTmpValBuf_size = (int) (ceil(topKTmpValBuf_size / 4.)) * 4; - - if (workspace == nullptr) - { - workspaceSize - = sizeof(T) * tempLogProbsBufSize + sizeof(int) * topKTmpIdsBufSize + sizeof(T) * topKTmpValBuf_size; - return; - } + auto const workspaceSizes = getTopKWorkspaceSizes(batchSize, maxTokensPerStep, maxTopK, vocabSize); if (maxTopK == 0) { return; } - T* tempLogProbs = (T*) workspace; - int* topKTmpIdBuf = (int*) (tempLogProbs + tempLogProbsBufSize); - T* topKTmpValBuf = (T*) (topKTmpIdBuf + topKTmpIdsBufSize); + std::vector alignedPointers; + calcAlignedPointers(alignedPointers, workspace, workspaceSizes); + + auto tempLogProbs = static_cast(alignedPointers[0]); + auto topKTmpIdBuf = static_cast(alignedPointers[1]); + auto topKTmpValBuf = static_cast(alignedPointers[2]); - int logMaxTopK(0); - int recursor(maxTopK - 1); + int32_t logMaxTopK{0}; + int32_t recursor{maxTopK - 1}; while (recursor >>= 1) + { ++logMaxTopK; + } + switch (logMaxTopK) { case 0: @@ -311,16 +338,20 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo case 2: case 3: // 0 < maxTopK <= 16 CASE_K(16, 128, 128, 8, normalizeLogProbs); + break; case 4: // 16 < maxTopK <= 32 CASE_K(32, 256, 128, 8, normalizeLogProbs); + break; case 5: // 32 < maxTopK <= 64 CASE_K(64, 256, 256, 8, normalizeLogProbs); + break; case 6: case 7: case 8: case 9: // 64 < maxTopK <= 1024 CASE_K(1024, 256, 256, 8, normalizeLogProbs); - default: throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", maxTopK)); + break; + default: TLLM_CHECK_WITH_INFO(false, "TopK kernel supports 1 <= k <= 1024 but got k=%d", maxTopK); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -328,43 +359,47 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo #undef CASE_K -template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids, - int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, - const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, - const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, - const bool logitsHasProbs); +template void invokeBatchTopKSampling(void* workspace, float const* logProbs, float const* const* logProbsPtrs, + int** ids, int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, + float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, int const maxTopK, int const* topKs, + float const topP, float const* topPs, int const vocabSizePadded, int const* endIds, int const* batchSlots, + cudaStream_t stream, int const batchSize, int maxBatchSize, int const* tokensPerStep, int const maxTokensPerStep, + bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); -template 
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids, - int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, - const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, - const int batchSize, int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, - const bool logitsHasProbs); +template void invokeBatchTopKSampling(void* workspace, half const* logProbs, half const* const* logProbsPtrs, int** ids, + int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, int const maxTopK, int const* topKs, float const topP, + float const* topPs, int const vocabSizePadded, int const* endIds, int const* batchSlots, cudaStream_t stream, + int const batchSize, int maxBatchSize, int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, + bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); template -void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths, - const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds, - const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, const bool* skipDecode, - const bool normalizeLogProbs, const bool logitsHasProbs) +void invokeTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtrs, int** ids, + int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, int const topK, float const topP, int const vocabSizePadded, + int const* endIds, int const* batchSlots, cudaStream_t stream, int const batchSize, int maxBatchSize, + int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, + bool logitsHasProbs, bool returnAllTopK) { - invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput, + invokeBatchTopKSampling(workspace, logProbs, logProbsPtrs, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, batchSlots, - stream, batchSize, maxBatchSize, skipDecode, normalizeLogProbs, logitsHasProbs); + stream, batchSize, maxBatchSize, tokensPerStep, maxTokensPerStep, skipDecode, normalizeLogProbs, logitsHasProbs, + returnAllTopK); } -template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids, - int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, - const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, - const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); - -template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids, - int* sequenceLengths, const FinishedState* finishedInput, FinishedState* 
finishedOutput, float* cumLogProbs, - float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, - const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, - const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); +template void invokeTopKSampling(void* workspace, float const* logProbs, float const* const* logProbsPtrs, int** ids, + int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, int const topK, float const topP, int const vocabSizePadded, + int const* endIds, int const* batchSlots, cudaStream_t stream, int const batchSize, int maxBatchSize, + int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, + bool logitsHasProbs, bool returnAllTopK); + +template void invokeTopKSampling(void* workspace, half const* logProbs, half const* const* logProbsPtrs, int** ids, + int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, int const topK, float const topP, int const vocabSizePadded, + int const* endIds, int const* batchSlots, cudaStream_t stream, int const batchSize, int maxBatchSize, + int const* tokensPerStep, int const maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, + bool logitsHasProbs, bool returnAllTopK); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h index e53d6d2e8..facdb55ed 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopKKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingTopKKernels.h @@ -16,9 +16,13 @@ */ #pragma once +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" #include +#include + namespace tensorrt_llm { namespace kernels @@ -32,11 +36,12 @@ namespace kernels //! //! \param workspace pointer to the workspace. Has to be pre-allocated by caller. Function does not take ownership of the //! buffer. -//! \param workspaceSize size of the workspace in bytes -//! \param logProbs input buffer [batchSize x vocabSizePadded]. +//! \param logProbs input buffer [batchSize, maxTokensPerStep, vocabSizePadded]. //! Log probabilities of each token in the vocab. If logitsHasProbs is true, //! logProbs must contain **just** probabilities instead of log probabilities. -//! \param outputIds output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request +//! \param logProbsPtr input buffer [batchSize][vocabSizePadded] array of pointers to logits. If nullptr, logProbs is used. +//! Only maxTokensPerStep == 1 is supported. +//! \param outputIds output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request //! \param sequenceLength input/output buffer [maxBatchSize]. Current sequence length of the request up to, but excluding endId token //! \param finishedInput input buffer [maxBatchSize]. If true, request exits early. //! \param finishedOutput output buffer [maxBatchSize]. Set flag if sequence has finished (if finished || outputId == endId). @@ -56,28 +61,62 @@ namespace kernels //! Supported P is in range (0.0, 1.0]. If nullptr, topP is used for all requests //! \param vocabSizePadded size of padded vocab //!
\param endIds input buffer [maxBatchSize]. EOS token ids per request -//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool +//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool. +//! Linear indexing (batchIdx) is used if nullptr. //! \param stream cuda stream //! \param batchSize batch size //! \param maxBatchSize maximum batch size +//! \param tokensPerStep input buffer [maxBatchSize], optional. Number of tokens per step for each request. +//! It is assumed that all requests have maxTokensPerStep tokens per step if nullptr. +//! \param maxTokensPerStep maximum number of tokens computed per step //! \param skipDecode input buffer [maxBatchSize]. Flags whether to skip decoding per request //! \param normalizeLogProbs when set to true, outputLogProbs are normalized to the TopK distribution //! \param logitsHasProbs flag indicating that logProbs already contains probabilities +//! \param returnAllTopK flag to return all selected TopK results // clang-format on template -void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths, - const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs, - const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, - int maxBatchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs); +void invokeBatchTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtr, int32_t** ids, + int32_t* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, const int32_t maxTopK, int32_t const* topKs, float const topP, + float const* topPs, const int32_t vocabSizePadded, int32_t const* endIds, int32_t const* batchSlots, + cudaStream_t stream, const int32_t batchSize, int maxBatchSize, int32_t const* tokensPerStep, + const int32_t maxTokensPerStep, bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, + bool returnAllTopK); //!
\brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr template -void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** outputIds, int* sequenceLength, - const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds, - const int* batchSlots, cudaStream_t stream, const int batchSize, int maxBatchSize, const bool* skipDecode, - const bool normalizeLogProbs, const bool logitsHasProbs); +void invokeTopKSampling(void* workspace, T const* logProbs, T const* const* logProbsPtr, int32_t** outputIds, + int32_t* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, + float* outputLogProbs, curandState_t* curandstate, const int32_t topK, float const topP, + const int32_t vocabSizePadded, int32_t const* endIds, int32_t const* batchSlots, cudaStream_t stream, + const int32_t batchSize, int maxBatchSize, int32_t const* tokensPerStep, const int32_t maxTokensPerStep, + bool const* skipDecode, bool normalizeLogProbs, bool logitsHasProbs, bool returnAllTopK); + +template +[[nodiscard]] std::vector getTopKWorkspaceSizes( + int32_t batchSize, int32_t maxTokensPerStep, int32_t maxTopK, int32_t vocabSizePadded) +{ + int32_t constexpr maxBlockPerBeam = 8; + auto const tempLogProbsBufSize = sizeof(T) * batchSize * maxTokensPerStep * vocabSizePadded; // type T + auto const topKTmpIdsBufSize + = sizeof(int32_t) * batchSize * maxTokensPerStep * maxTopK * maxBlockPerBeam; // type int + auto const topKTmpValBufSize = sizeof(T) * batchSize * maxTokensPerStep * maxTopK * maxBlockPerBeam; // type T + + return {tempLogProbsBufSize, topKTmpIdsBufSize, topKTmpValBufSize}; +} + +//! \brief Returns workspace size in bytes needed for sampling TopK computation +//! \param batchSize batch size +//! \param maxTokensPerStep maximum number of tokens computed per step +//! \param maxTopK maximum K among all topKs for topK sampling +//! \param vocabSizePadded size of padded vocab +template +[[nodiscard]] size_t getTopKWorkspaceSize( + int32_t batchSize, int32_t maxTokensPerStep, int32_t maxTopK, int32_t vocabSizePadded) +{ + auto const workspaceSizes = getTopKWorkspaceSizes(batchSize, maxTokensPerStep, maxTopK, vocabSizePadded); + return tensorrt_llm::common::calcAlignedSize(workspaceSizes, 256); +} } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu b/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu index b03e12e44..73b1e02bb 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingTopPKernels.cu @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef CUDART_VERSION #error CUDART_VERSION Undefined!
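/*
 * The TopK helpers above and the TopP helpers in this file follow the same workspace contract:
 * getTop{K,P}WorkspaceSizes() reports the byte count of each sub-buffer, calcAlignedSize() (assumed here to
 * come from common/memoryUtils.h, which both headers now include) folds those counts into one total using
 * 256-byte alignment for the caller to allocate, and calcAlignedPointers() later carves matching sub-buffer
 * pointers back out of that single allocation. A rough sketch of the contract, with `workspace` standing in
 * for a caller-provided buffer of the reported size:
 *
 *   auto const sizes = getTopKWorkspaceSizes<float>(batchSize, maxTokensPerStep, maxTopK, vocabSizePadded);
 *   auto const totalBytes = tensorrt_llm::common::calcAlignedSize(sizes, 256); // what the caller allocates
 *
 *   std::vector<void*> buffers;
 *   calcAlignedPointers(buffers, workspace, sizes);
 *   // buffers[0] -> tempLogProbs, buffers[1] -> topKTmpIdBuf, buffers[2] -> topKTmpValBuf
 */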
#elif (CUDART_VERSION >= 11050) @@ -24,6 +23,7 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/kernels/samplingTopPKernels.h" @@ -36,12 +36,12 @@ namespace kernels __global__ void topPInitialize( int* topPIdValBuf, int* topPOffsetBuf, int* beginTopPOffsetBuf, int const batchSize, int const vocabSize) { - int tid = threadIdx.x; - int bid = blockIdx.x; + auto const tid = static_cast(threadIdx.x); + auto const bid = static_cast(blockIdx.x); if (bid == 0) { - for (int i = tid; i < batchSize + 1; i += blockDim.x) + for (auto i = tid; i < batchSize + 1; i += static_cast(blockDim.x)) { // Inclusive sum of offsets to vocab rows topPOffsetBuf[i] = i * vocabSize; @@ -49,13 +49,13 @@ __global__ void topPInitialize( } } - int index = tid + bid * blockDim.x; + auto index = tid + bid * static_cast(blockDim.x); while (index < batchSize * vocabSize) { // Set value at {bi, vi} position to vi topPIdValBuf[index] = index % vocabSize; - index += blockDim.x * gridDim.x; + index += static_cast(blockDim.x * gridDim.x); } } @@ -68,16 +68,16 @@ void invokeTopPInitialize(int* topPIdValBuf, int* topPOffsetBuf, int* beginTopPO } template -__launch_bounds__(THREADBLOCK_SIZE) __global__ void topPBeamTopKKernel(const T* logProbs, // prob. - int* topKTmpIdBuf, T* topKTmpValBuf, const FinishedState* finishedInput, const int vocabSize, int* offsetBuf, - int* beginOffsetBuf, const float topP, const float* topPs, const bool* skipDecode, const int* batchSlots) +__launch_bounds__(THREADBLOCK_SIZE) __global__ void topPBeamTopKKernel(T const* logProbs, // prob. + int* topKTmpIdBuf, T* topKTmpValBuf, FinishedState const* finishedInput, int const vocabSize, int* offsetBuf, + int* beginOffsetBuf, float const topP, float const* topPs, bool const* skipDecode, int const* batchSlots) { /** * Kernel performs top 1 search and saves the token with largest probability if it exceeds probability threshold */ int constexpr MAX_K = 1; - int threadId = threadIdx.x; - int batchId = blockIdx.x; + auto const threadId = static_cast(threadIdx.x); + auto const batchId = static_cast(blockIdx.x); auto const batchSlot = batchSlots != nullptr ? batchSlots[batchId] : batchId; // Skip decoding kernel if configured @@ -211,8 +211,8 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i __shared__ float randNumS; - int const tid = threadIdx.x; - int const batchId = blockIdx.x; + auto const tid = static_cast(threadIdx.x); + auto const batchId = static_cast(blockIdx.x); auto const batchSlot = batchSlots != nullptr ? batchSlots[batchId] : batchId; // Skip kernel if this sampling method is not chosen const FinishedState finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty(); @@ -231,16 +231,15 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i finishedOutput[batchSlot] = finishState; } } - ids[batchSlot][sequenceLength[batchSlot]] = endIds[batchSlot]; return; } constexpr int WARP_SIZE = 32; constexpr int NUM_WARPS = blockSize / WARP_SIZE; - const int laneId = threadIdx.x % WARP_SIZE; - const int warpId = threadIdx.x / WARP_SIZE; - const float probThreshold = (topPs != nullptr) ? topPs[batchSlot] : topP; - const int currentStep = sequenceLength[batchSlot]; + int const laneId = threadIdx.x % WARP_SIZE; + int const warpId = threadIdx.x / WARP_SIZE; + float const probThreshold = (topPs != nullptr) ? 
topPs[batchSlot] : topP; + int const currentStep = sequenceLength[batchSlot]; // With P in (0.0; 1.0] we draw a random number P' in range (0.0; P] // We will sum all probs moving from the largest probability to the smallest and @@ -305,36 +304,53 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i } template -void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, - int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots) +std::vector getTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize) { - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + auto const sortedLogProbBufSize = sizeof(T) * batchSize * vocabSize; // type T + auto const sortedIdValsBufSize = sizeof(int32_t) * batchSize * vocabSize; // type int + + size_t cubTempStorageSize; + tensorrt_llm::common::check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, + cubTempStorageSize, static_cast(nullptr), static_cast(nullptr), static_cast(nullptr), + static_cast(nullptr), static_cast(vocabSize * batchSize), batchSize, + static_cast(nullptr), static_cast(nullptr), + 0, // begin_bit + sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8 + 0)); // cudaStream_t - int const vocabSize = vocabSizePadded; + return {cubTempStorageSize, sortedLogProbBufSize, sortedIdValsBufSize}; +} - size_t sortedLogProbBufSize = batchSize * vocabSize * sizeof(T); // type T - size_t sortedIdValsBufSize = batchSize * vocabSize * sizeof(int); // type int - sortedLogProbBufSize = divUp(sortedLogProbBufSize, 256) * 256; - sortedIdValsBufSize = divUp(sortedIdValsBufSize, 256) * 256; +template std::vector getTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); +template std::vector getTopPWorkspaceSizes(int32_t batchSize, int32_t vocabSize); - void* cubTempStorage = workspace; - T* sortedLogProbs = (T*) ((char*) cubTempStorage + cubTempStorageSize); - int* sortedIdVals = (int*) ((char*) sortedLogProbs + sortedLogProbBufSize); +template +size_t getTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded) +{ + auto const workspaceSizes = getTopPWorkspaceSizes(batchSize, vocabSizePadded); + return tensorrt_llm::common::calcAlignedSize(workspaceSizes, 256); +} - if (workspace == nullptr) - { - check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, cubTempStorageSize, logProbs, - (T*) nullptr, idVals, (int*) nullptr, vocabSize * batchSize, batchSize, beginOffsetBuf, offsetBuf + 1, - 0, // begin_bit - sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8 - stream)); // cudaStream_t - cubTempStorageSize = divUp(cubTempStorageSize, 256) * 256; - workspaceSize = sortedLogProbBufSize + sortedIdValsBufSize + cubTempStorageSize; - return; - } +template size_t getTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); +template size_t getTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); + +template +void invokeBatchTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int32_t const* idVals, + int* offsetBuf, int* 
beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSize, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, + bool const* skipDecode, int const* batchSlots) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const workspaceSizes = getTopPWorkspaceSizes(batchSize, vocabSize); + + std::vector alignedPointers; + calcAlignedPointers(alignedPointers, workspace, workspaceSizes); + + auto cubTempStorage = static_cast(alignedPointers[0]); + auto sortedLogProbs = static_cast(alignedPointers[1]); + auto sortedIdVals = static_cast(alignedPointers[2]); int constexpr BLOCK_SIZE = 256; // Performs Top K=1 search. @@ -344,11 +360,13 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub sync_check_cuda_error(); // Sort tokens by probability in descending order - check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(cubTempStorage, cubTempStorageSize, logProbs, - sortedLogProbs, idVals, sortedIdVals, vocabSize * batchSize, batchSize, beginOffsetBuf, offsetBuf + 1, - 0, // begin_bit - sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8 - stream)); // cudaStream_t + auto cubWorkspaceSize = workspaceSizes[0]; + check_cuda_error( + cub::DeviceSegmentedRadixSort::SortPairsDescending(cubTempStorage, cubWorkspaceSize, logProbs, sortedLogProbs, + idVals, sortedIdVals, static_cast(vocabSize * batchSize), batchSize, beginOffsetBuf, offsetBuf + 1, + 0, // begin_bit + static_cast(sizeof(T) * 8), // end_bit = sizeof(KeyT) * 8 + stream)); // cudaStream_t int constexpr SAMPLING_BLOCK_SIZE = 256; dim3 grid(batchSize); @@ -361,43 +379,41 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, - int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, - float* cumLogProbs, float* outputLogProbs, float const* logProbs, int const* idVals, int* offsetBuf, - int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, - size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, - bool const* skipDecode, int const* batchSlots); +template void invokeBatchTopPSampling(void* workspace, int** outputIds, int* sequenceLength, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + float const* logProbs, int32_t const* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, + int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, + float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); -template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, - int** outputIds, int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, - float* cumLogProbs, float* outputLogProbs, half const* logProbs, int const* idVals, int* offsetBuf, - int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, - size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, - bool const* skipDecode, int const* batchSlots); +template void invokeBatchTopPSampling(void* workspace, int** outputIds, int* sequenceLength, + 
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + half const* logProbs, int32_t const* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, + int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const maxTopP, + float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); template -void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, - int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots) +void invokeTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int32_t const* idVals, + int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const topP, cudaStream_t stream, bool const* skipDecode, + int const* batchSlots) { - invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedInput, - finishedOutput, cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate, - batchSize, maxBatchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode, batchSlots); + invokeBatchTopPSampling(workspace, outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, + outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate, batchSize, maxBatchSize, + vocabSizePadded, endIds, topP, nullptr, stream, skipDecode, batchSlots); } -template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, - int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, float const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); - -template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, - int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, half const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const topP, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); +template void invokeTopPSampling(void* workspace, int** outputIds, int* sequenceLength, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + float const* logProbs, int32_t const* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, + int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const topP, + 
cudaStream_t stream, bool const* skipDecode, int const* batchSlots); + +template void invokeTopPSampling(void* workspace, int** outputIds, int* sequenceLength, + FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, + half const* logProbs, int32_t const* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, + int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, float const topP, + cudaStream_t stream, bool const* skipDecode, int const* batchSlots); __global__ void computeToppDecay(float* runtimeTopP, float const* runtimeInitialTopP, int const** outputIds, float const* topPDecay, float const* topPMin, int32_t const* topPResetIds, int const* sequenceLengths, diff --git a/cpp/tensorrt_llm/kernels/samplingTopPKernels.h b/cpp/tensorrt_llm/kernels/samplingTopPKernels.h index 9a54d359d..16ea02a07 100644 --- a/cpp/tensorrt_llm/kernels/samplingTopPKernels.h +++ b/cpp/tensorrt_llm/kernels/samplingTopPKernels.h @@ -15,6 +15,8 @@ */ #pragma once +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" #include @@ -38,12 +40,9 @@ void invokeTopPInitialize(int* topPIdValBuf, int* topPOffsetBuf, int* beginTopPO //! \brief Given logProbs, performs top P sampling. Fills sampled tokens to outputIds. //! Computes sequenceLength, finished state, cumLogProbs inplace. //! Sampling per request can be controlled using skipDecode and topPs parameters. -//! Function sets workspaceSize and cubTempStorageSize and exits early if workspace is nullptr. //! //! \param workspace pointer to the workspace. Has to be pre-allocated by caller. Function does not take ownership of //! the buffer. -//! \param workspaceSize size of the workspace in bytes. -//! \param cubTempStorageSize workspace size for cub radix sort. //! \param outputIds output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request. //! \param sequenceLength input/output buffer [maxBatchSize]. Current sequence length of the request up to, but excluding endId token. //! \param finishedInput input buffer [maxBatchSize]. Exit early if true. @@ -75,84 +74,19 @@ void invokeTopPInitialize(int* topPIdValBuf, int* topPOffsetBuf, int* beginTopPO //! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool // clang-format on template -void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, - int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); +void invokeBatchTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int32_t const* idVals, + int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, + bool const* skipDecode, int const* batchSlots); //! 
\brief Specialization of invokeBatchTopPSampling with topPs=nullptr template -void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds, - int* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, - float* outputLogProbs, T const* logProbs, int const* idVals, int* offsetBuf, int* beginOffsetBuf, - curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, int const* endIds, - float const topPp, cudaStream_t stream, bool const* skipDecode, int const* batchSlots); - -//! \brief Given logProbs, performs top P sampling. -//! Note different from invokeTopPSampling() and invokeBatchTopPSampling() there two functions invokeAirTopPSampling -//! and invokeBatchAirTopPSampling is non-deterministic. -//! Fills sampled tokens to outputIds. Computes sequenceLength, finished state, cumLogProbs inplace. -//! Sampling per request can be controlled using skipDecode and topPs parameters. -//! Function sets workspaceSize and exits early if workspace is nullptr. -//! -//! \param workspace pointer to the workspace. Has to be pre-allocated by caller. Function does not take ownership of -//! the buffer. -//! \param workspaceSize size of the workspace in bytes. -//! \param outputIds output buffer [batchSize][maxSeqLen]. Contains pointers to rows with output tokens per request. -//! \param sequenceLength input/output buffer [batchSize]. Current sequence length of the request up to, but excluding -//! endId token. -//! \param finishedInput input buffer[batchSize].Exit early if true. -//! \param finishedOutput output buffer [batchSize]. Set flag if sequence has finished (if finished || outputId == -//! endId). -//! \param cumLogProbs input/output buffer [batchSize]. Cumulative log probability of selected tokens. Ignored -//! if nullptr. -//! \param outputLogProbs output buffer [batchSize]. Log probs is the probability induced by the top-k -//! sampling. We normalize the probability 'expLogit' of the selected token by the probability 's_sum' of a set of top-k -//! tokens, meaning the logProb is the probability of the selected token, conditioned on the event that it is selected, -//! i.e., log_prob = log P(i | i is in top-k) = log(expLogit / s_sum). Ignored if nullptr. -//! \param logProbs input buffer [batchSize x vocabSizePadded]. Log probabilities of each token in the vocab. -//! If cumLogProbs or outputLogProbs are specified, logProbs must contain **just** probabilities instead of log -//! probabilities. -//! \param curandstate input buffer [batchSize]. Curand states properly initialized using invokeCurandInitialize per -//! request. -//! \param batchSize batch size -//! \param maxBatchSize max batch size -//! \param vocabSizePadded size of padded vocab -//! \param endIds input buffer [batchSize]. EOS token ids per request -//! \param maxTopP maximum among all topPs P for topP sampling -//! \param topPs input buffer [batchSize]. P for topP sampling per request. Supported P is in range (0.0; 1.0]. -//! If nullptr maxTopP is used for all requests. -//! \param stream cuda stream -//! \param blockNum The appropriate block configuration calculated based on the number of multiprocessors, occupancy, -//! batchSize and vocabSizePadded -//! \param skipDecode input buffer [batchSize]. 
Flags whether to skip decoding per request -template -void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, - FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, - int const* endIds, float const maxTopP, float const* topPs, cudaStream_t stream, int blockNum, - bool const* skipDecode, int32_t const* batchSlots); - -//! \brief Specialization of invokeBatchAirTopPSampling with topPs=nullptr -template -void invokeAirTopPSampling(void* workspace, size_t& workspaceSize, int** outputIds, int* sequenceLength, - FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, - T const* logProbs, curandState_t* curandstate, int const batchSize, int maxBatchSize, size_t const vocabSizePadded, - int const* endIds, float const topP, cudaStream_t stream, int blockNum, bool const* skipDecode, - int32_t const* batchSlots); - -//! \brief Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize. -//! \tparam T the data type of value -//! \tparam IdxT the data type of index -//! \tparam AccT the data type of variables related to accumulation -//! \tparam BitsPerPass the number of bits for each pass. Can be 8 or 11. Use 11 for default. -//! \tparam BlockSize the block size -//! \param batchSize -//! \param len the number of candidates for each case -//! \param smCnt number of multiprocessors on device -template -unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt); +void invokeTopPSampling(void* workspace, int** outputIds, int* sequenceLength, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, T const* logProbs, int32_t const* idVals, + int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, int const batchSize, int maxBatchSize, + size_t const vocabSizePadded, int const* endIds, float const topPp, cudaStream_t stream, bool const* skipDecode, + int const* batchSlots); //! \brief Compute the topp decay by https://arxiv.org/pdf/2206.04624.pdf //! In short, the formula is @@ -172,5 +106,11 @@ void invokeComputeToppDecay(float* runtimeTopP, float const* runtimeInitialTopP, float const* topPDecay, float const* topPMin, int32_t const* topPResetIds, int const* sequenceLengths, int const* batchSlots, int const localBatchSize, cudaStream_t stream); +//! \brief Returns workspace size in bytes needed for sampling TopP computation +//! \param batchSize batch size +//! 
\param vocabSizePadded size of padded vocab +template +[[nodiscard]] size_t getTopPWorkspaceSize(int32_t batchSize, int32_t vocabSizePadded); + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/selectiveScan.cu b/cpp/tensorrt_llm/kernels/selectiveScan.cu index 0aae423ec..8ecde3a4f 100644 --- a/cpp/tensorrt_llm/kernels/selectiveScan.cu +++ b/cpp/tensorrt_llm/kernels/selectiveScan.cu @@ -70,7 +70,7 @@ template (params.out_ptr); - weight_t* state = reinterpret_cast(params.x_ptr); + input_t* state = reinterpret_cast(params.x_ptr); input_t* x = reinterpret_cast(params.u_ptr); input_t* dt = reinterpret_cast(params.delta_ptr); weight_t* A = reinterpret_cast(params.A_ptr); @@ -99,12 +99,12 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa __shared__ weight_t sh_D[CHANNELS_PER_BLOCK]; __shared__ weight_t sh_dt_bias[CHANNELS_PER_BLOCK]; - const int channel = blockIdx.x * blockDim.x + threadIdx.x; - const int sample = blockIdx.y; // batch id + int const channel = blockIdx.x * blockDim.x + threadIdx.x; + int const sample = blockIdx.y; // batch id - const int seq_loops = (num_tokens + SEQ_UNROLL - 1) / SEQ_UNROLL; + int const seq_loops = (num_tokens + SEQ_UNROLL - 1) / SEQ_UNROLL; - const int input_matrix_row_id = sample * num_tokens; + int const input_matrix_row_id = sample * num_tokens; if (threadIdx.y == 1) { @@ -300,7 +300,7 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa // Write the new state back out to the cache for (int i = 0; i < DSTATE; i++) { - weight_t* my_state = &state[sample * num_channels * DSTATE]; + input_t* my_state = &state[sample * num_channels * DSTATE]; int offset = i * num_channels + channel; convertAndStore(&my_state[offset], state_reg[i]); } @@ -313,8 +313,8 @@ void invokeSelectiveScan(SSMParamsBase& params, cudaStream_t stream) int samples = params.batch; int channels = params.dim; - const int threads = 128; - const int blocks = (channels + threads - 1) / threads; + int const threads = 128; + int const blocks = (channels + threads - 1) / threads; dim3 block(threads, 2); dim3 grid(blocks, samples); TLLM_CHECK((channels % block.x) == 0); @@ -343,7 +343,7 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams { input_t* output = reinterpret_cast(params.out_ptr); - weight_t* state = reinterpret_cast(params.x_ptr); + input_t* state = reinterpret_cast(params.x_ptr); input_t* x = reinterpret_cast(params.u_ptr); input_t* dt = reinterpret_cast(params.delta_ptr); weight_t* A = reinterpret_cast(params.A_ptr); @@ -355,12 +355,12 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams bool dt_softplus = params.delta_softplus; int num_channels = params.dim; - const int channel = blockIdx.x * blockDim.x + threadIdx.x; + int const channel = blockIdx.x * blockDim.x + threadIdx.x; if (channel >= num_channels) return; - const int sample = blockIdx.y; + int const sample = blockIdx.y; - weight_t* my_state = &state[sample * num_channels * DSTATE]; + input_t* my_state = &state[sample * num_channels * DSTATE]; input_t* my_output = &output[sample * num_channels]; float rA[DSTATE]; @@ -424,8 +424,8 @@ void invokeSelectiveScanUpdate(SSMParamsBase& params, cudaStream_t stream) int samples = params.batch; int channels = params.dim; - const int threads = 128; - const int blocks = (channels + threads - 1) / threads; + int const threads = 128; + int const blocks = (channels + threads - 1) / threads; dim3 block(threads, 1); dim3 grid(blocks, 
samples); diff --git a/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu b/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu index 7ad2e2ada..b4c49453f 100644 --- a/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu +++ b/cpp/tensorrt_llm/kernels/splitkGroupGemm.cu @@ -81,8 +81,8 @@ void splitkGroupedGemm_(std::vector problem_sizes, std using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; - const int kAlignmentA = 8; - const int kAlignmentB = 8; + int const kAlignmentA = 8; + int const kAlignmentB = 8; int problem_count = problem_sizes.size(); diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu index 44c6062cf..edb49a8ba 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.cu @@ -37,28 +37,28 @@ __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int d } template -__global__ void addQKVBiasIA3Transpose(T* q_out, T* k_out, T* v_out, const T* __restrict q_in, - const T* __restrict bias_q, const T* __restrict k_in, const T* __restrict bias_k, const T* __restrict v_in, - const T* __restrict bias_v, const int* ia3_tasks, const T* ia3_key_weights, const T* ia3_value_weights, - const int batch_size, const int seq_len, const int head_num, const int size_per_head) +__global__ void addQKVBiasIA3Transpose(T* q_out, T* k_out, T* v_out, T const* __restrict q_in, + T const* __restrict bias_q, T const* __restrict k_in, T const* __restrict bias_k, T const* __restrict v_in, + T const* __restrict bias_v, int const* ia3_tasks, T const* ia3_key_weights, T const* ia3_value_weights, + int const batch_size, int const seq_len, int const head_num, int const size_per_head) { - const int n = head_num * size_per_head; - const int batch_id = blockIdx.x; - const int word_id = blockIdx.y; - const int row_id = batch_id * seq_len + word_id; + int const n = head_num * size_per_head; + int const batch_id = blockIdx.x; + int const word_id = blockIdx.y; + int const row_id = batch_id * seq_len + word_id; - const bool use_ia3 = ia3_tasks != nullptr; - const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0; - const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); - const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); + bool const use_ia3 = ia3_tasks != nullptr; + int const ia3_task = use_ia3 ? 
ia3_tasks[batch_id] : 0; + bool const use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); + bool const use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) { - const int head_id = col_id / size_per_head; - const int size_id = col_id % size_per_head; - const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + int const head_id = col_id / size_per_head; + int const size_id = col_id % size_per_head; + int const target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + word_id * size_per_head + size_id; - const int src_id = row_id * n + col_id; + int const src_id = row_id * n + col_id; T q = ldg(&q_in[src_id]); q_out[target_id] = add(q, ldg(&bias_q[col_id])); @@ -80,28 +80,28 @@ __global__ void addQKVBiasIA3Transpose(T* q_out, T* k_out, T* v_out, const T* __ } template -__global__ void QKVIA3Transpose(T* q_out, T* k_out, T* v_out, const T* __restrict q_in, const T* __restrict k_in, - const T* __restrict v_in, const int* ia3_tasks, const T* __restrict ia3_key_weights, - const T* __restrict ia3_value_weights, const int batch_size, const int seq_len, const int head_num, - const int size_per_head) +__global__ void QKVIA3Transpose(T* q_out, T* k_out, T* v_out, T const* __restrict q_in, T const* __restrict k_in, + T const* __restrict v_in, int const* ia3_tasks, T const* __restrict ia3_key_weights, + T const* __restrict ia3_value_weights, int const batch_size, int const seq_len, int const head_num, + int const size_per_head) { - const int n = head_num * size_per_head; - const int batch_id = blockIdx.x; - const int word_id = blockIdx.y; - const int row_id = batch_id * seq_len + word_id; + int const n = head_num * size_per_head; + int const batch_id = blockIdx.x; + int const word_id = blockIdx.y; + int const row_id = batch_id * seq_len + word_id; - const bool use_ia3 = ia3_tasks != nullptr; - const int ia3_task = use_ia3 ? ia3_tasks[batch_id] : 0; - const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); - const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); + bool const use_ia3 = ia3_tasks != nullptr; + int const ia3_task = use_ia3 ? 
ia3_tasks[batch_id] : 0; + bool const use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); + bool const use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); for (int col_id = threadIdx.x; col_id < n; col_id += blockDim.x) { - const int head_id = col_id / size_per_head; - const int size_id = col_id % size_per_head; - const int target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + int const head_id = col_id / size_per_head; + int const size_id = col_id % size_per_head; + int const target_id = batch_id * (head_num * seq_len * size_per_head) + head_id * seq_len * size_per_head + word_id * size_per_head + size_id; - const int src_id = row_id * n + col_id; + int const src_id = row_id * n + col_id; q_out[target_id] = ldg(&q_in[src_id]); @@ -122,11 +122,11 @@ __global__ void QKVIA3Transpose(T* q_out, T* k_out, T* v_out, const T* __restric } template -void invokeAddQKVBiasIA3Transpose(T* q_buf, T* k_buf, T* v_buf, T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, - const T* bias_V, const int batch_size, const int seq_len, const int head_num, const int size_per_head, - const int* ia3_tasks, const T* ia3_key_weights, const T* ia3_value_weights, cudaStream_t stream) +void invokeAddQKVBiasIA3Transpose(T* q_buf, T* k_buf, T* v_buf, T* Q, T const* bias_Q, T* K, T const* bias_K, T* V, + T const* bias_V, int const batch_size, int const seq_len, int const head_num, int const size_per_head, + int const* ia3_tasks, T const* ia3_key_weights, T const* ia3_value_weights, cudaStream_t stream) { - const int k = head_num * size_per_head; + int const k = head_num * size_per_head; dim3 grid(batch_size, seq_len); bool is_add_bias = bias_Q != nullptr; if (sizeof(T) == 4 || k % 2 != 0) @@ -178,9 +178,9 @@ INSTANTIATE_ADDQKVBIASIA3_TRANSPOSE(__nv_bfloat16); #undef INSTANTIATEADDQKVBIASTRANSPOSE template -__global__ void softmax_kernel(T* attn_score, const T_IN* qk, const T* attn_mask, const T* linear_bias_slopes, +__global__ void softmax_kernel(T* attn_score, const T_IN* qk, T const* attn_mask, T const* linear_bias_slopes, const int64_t batch_size, const int64_t head_num, const int64_t q_length, const int64_t k_length, - const float qk_scale) + float const qk_scale) { // attn_score, [batch_size, num_heads, q_length, k_length] // qk, [batch_size, num_heads, q_length, k_length] @@ -192,7 +192,7 @@ __global__ void softmax_kernel(T* attn_score, const T_IN* qk, const T* attn_mask __shared__ float s_mean, s_max; - const float linear_bias_slope = linear_bias_slopes != nullptr ? (float) linear_bias_slopes[hi] : 0.0f; + float const linear_bias_slope = linear_bias_slopes != nullptr ? (float) linear_bias_slopes[hi] : 0.0f; // Loop along with Q dimension. 
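    // Presumed structure of the loop below: blockIdx.x strides over query rows qi, while the remaining grid
    // dimensions select the batch/head pair (the hi index above); for each row, the threads of the block scan
    // the key dimension, applying qk_scale, the attention mask and the optional linear_bias_slope term
    // (an ALiBi-style positional bias), with the shared s_max / s_mean values holding the row maximum and
    // normalization term for a numerically stable softmax.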
for (int qi = blockIdx.x; qi < q_length; qi += gridDim.x) @@ -258,7 +258,7 @@ __global__ void softmax_kernel(T* attn_score, const T_IN* qk, const T* attn_mask } template -__global__ void softmax_kernel_h2(T* attn_score, const T* qk_buf, const T* attn_mask, const T* linear_bias_slopes, +__global__ void softmax_kernel_h2(T* attn_score, T const* qk_buf, T const* attn_mask, T const* linear_bias_slopes, const int64_t batch_size, const int64_t head_num, const int64_t q_length, const int64_t k_length, const T qk_scale) { // attn_score, [batch_size, num_heads, q_length, k_length] @@ -360,7 +360,7 @@ __global__ void softmax_kernel_h2(T* attn_score, const T* qk_buf, const T* attn_ } template -__global__ void softmax_kernel_h2_v2(T* attn_score, const T* qk_buf, const T* attn_mask, const T* linear_bias_slopes, +__global__ void softmax_kernel_h2_v2(T* attn_score, T const* qk_buf, T const* attn_mask, T const* linear_bias_slopes, const int64_t batch_size, const int64_t head_num, const int64_t q_length, const int64_t k_length, const T scalar) { // attn_score, [batch_size, num_heads, q_length, k_length] @@ -728,8 +728,8 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<__nv_bfloat16, __nv_bfloat16>& param #undef LAUNCH_MASKED_SOFTMAX_ template -__global__ void transpose(const T* src, T* dst, const int batch_size, const int seq_len, const int head_num, - const int size_per_head, const float* scale, int int8_mode) +__global__ void transpose(T const* src, T* dst, int const batch_size, int const seq_len, int const head_num, + int const size_per_head, float const* scale, int int8_mode) { int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -756,20 +756,20 @@ __global__ void transpose(const T* src, T* dst, const int batch_size, const int } template <> -__global__ void transpose(const float* src, float* dst, const int batch_size, const int seq_len, const int head_num, - const int size_per_head, const float* scale, int int8_mode) +__global__ void transpose(float const* src, float* dst, int const batch_size, int const seq_len, int const head_num, + int const size_per_head, float const* scale, int int8_mode) { int batch_id = blockIdx.x / (head_num * seq_len); int seq_id = blockIdx.x % seq_len; int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len; - const int target_id = batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head + int const target_id = batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head + head_id * size_per_head + threadIdx.x; - const int src_id = blockIdx.x * size_per_head + threadIdx.x; + int const src_id = blockIdx.x * size_per_head + threadIdx.x; if (int8_mode == 2) { - const float scale_val = *scale; + float const scale_val = *scale; reinterpret_cast(dst)[target_id] = cuda_cast(src[src_id] * scale_val); } else @@ -779,8 +779,8 @@ __global__ void transpose(const float* src, float* dst, const int batch_size, co } template -void invokeTransposeQKV(T* dst, T* src, const int batch_size, const int seq_len, const int head_num, - const int size_per_head, const float* scale, const int int8_mode, cudaStream_t stream) +void invokeTransposeQKV(T* dst, T* src, int const batch_size, int const seq_len, int const head_num, + int const size_per_head, float const* scale, int const int8_mode, cudaStream_t stream) { dim3 grid, block; if (sizeof(T) == 2) @@ -820,7 +820,7 @@ void invokeTransposeQKV(T* dst, T* src, const int batch_size, const int seq_len, } else { - const int seq_per_block = 1; + int const seq_per_block = 1; grid.x = batch_size * head_num * 
seq_len / seq_per_block; block.x = seq_per_block * size_per_head; transpose @@ -839,28 +839,28 @@ INSTANTIATE_TRANSPOSE_QKV(__nv_bfloat16); #undef INSTANTIATE_TRANSPOSE_QKV template -__global__ void add_QKV_bias_rebuild_padding_ia3(const T* Q, const T* bias_Q, const T* K, const T* bias_K, const T* V, - const T* bias_V, T* q_buf_, T* k_buf_, T* v_buf_, const int* ia3_tasks, const T* ia3_key_weights, - const T* ia3_value_weights, const int batch_size, const int seq_len, const int head_num, const int size_per_head, - const int* mask_offset) +__global__ void add_QKV_bias_rebuild_padding_ia3(T const* Q, T const* bias_Q, T const* K, T const* bias_K, T const* V, + T const* bias_V, T* q_buf_, T* k_buf_, T* v_buf_, int const* ia3_tasks, T const* ia3_key_weights, + T const* ia3_value_weights, int const batch_size, int const seq_len, int const head_num, int const size_per_head, + int const* mask_offset) { - const int bid = blockIdx.x; + int const bid = blockIdx.x; - const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len; - const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; - const int n = head_num * size_per_head; + int const tgt_batch_id = (bid + mask_offset[bid]) / seq_len; + int const tgt_seq_id = (bid + mask_offset[bid]) % seq_len; + int const n = head_num * size_per_head; - const bool use_ia3 = ia3_tasks != nullptr; - const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0; - const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); - const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); + bool const use_ia3 = ia3_tasks != nullptr; + int const ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0; + bool const use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); + bool const use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - const int tgt_head_id = idx / size_per_head; - const int tgt_hidden_id = idx % size_per_head; + int const tgt_head_id = idx / size_per_head; + int const tgt_hidden_id = idx % size_per_head; - const int src_id = bid * n + idx; - const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head + int const src_id = bid * n + idx; + int const tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head + tgt_seq_id * size_per_head + tgt_hidden_id; q_buf_[tgt_id] = add(ldg(&Q[src_id]), ldg(&bias_Q[idx])); @@ -882,27 +882,27 @@ __global__ void add_QKV_bias_rebuild_padding_ia3(const T* Q, const T* bias_Q, co } template -__global__ void rebuild_padding_ia3(const T* Q, const T* K, const T* V, T* q_buf_, T* k_buf_, T* v_buf_, - const int* ia3_tasks, const T* ia3_key_weights, const T* ia3_value_weights, const int batch_size, const int seq_len, - const int head_num, const int size_per_head, const int* mask_offset) +__global__ void rebuild_padding_ia3(T const* Q, T const* K, T const* V, T* q_buf_, T* k_buf_, T* v_buf_, + int const* ia3_tasks, T const* ia3_key_weights, T const* ia3_value_weights, int const batch_size, int const seq_len, + int const head_num, int const size_per_head, int const* mask_offset) { - const int bid = blockIdx.x; + int const bid = blockIdx.x; - const int tgt_batch_id = (bid + mask_offset[bid]) / seq_len; - const int tgt_seq_id = (bid + mask_offset[bid]) % seq_len; - const int n = head_num * size_per_head; + int const tgt_batch_id = (bid + mask_offset[bid]) / seq_len; + int const tgt_seq_id = (bid + mask_offset[bid]) % seq_len; + int const n = head_num * size_per_head; - const bool 
use_ia3 = ia3_tasks != nullptr; - const int ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0; - const bool use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); - const bool use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); + bool const use_ia3 = ia3_tasks != nullptr; + int const ia3_task = use_ia3 ? ia3_tasks[tgt_batch_id] : 0; + bool const use_ia3_key = use_ia3 && (ia3_key_weights != nullptr); + bool const use_ia3_value = use_ia3 && (ia3_value_weights != nullptr); for (int idx = threadIdx.x; idx < n; idx += blockDim.x) { - const int tgt_head_id = idx / size_per_head; - const int tgt_hidden_id = idx % size_per_head; + int const tgt_head_id = idx / size_per_head; + int const tgt_hidden_id = idx % size_per_head; - const int src_id = bid * n + idx; - const int tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head + int const src_id = bid * n + idx; + int const tgt_id = tgt_batch_id * head_num * seq_len * size_per_head + tgt_head_id * seq_len * size_per_head + tgt_seq_id * size_per_head + tgt_hidden_id; q_buf_[tgt_id] = ldg(&Q[src_id]); @@ -924,10 +924,10 @@ __global__ void rebuild_padding_ia3(const T* Q, const T* K, const T* V, T* q_buf } template -void invokeAddQKVBiasIA3RebuildPadding(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf, - T* k_buf, T* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, - const int valid_word_num, const int* mask_offset, const int* ia3_tasks, const T* ia3_key_weights, - const T* ia3_value_weights, cudaStream_t stream) +void invokeAddQKVBiasIA3RebuildPadding(T* Q, T const* bias_Q, T* K, T const* bias_K, T* V, T const* bias_V, T* q_buf, + T* k_buf, T* v_buf, int const batch_size, int const seq_len, int const head_num, int const size_per_head, + int const valid_word_num, int const* mask_offset, int const* ia3_tasks, T const* ia3_key_weights, + T const* ia3_value_weights, cudaStream_t stream) { #ifdef ENABLE_BF16 bool is_half2 = (std::is_same::value || std::is_same::value) && (size_per_head % 2 == 0); @@ -1006,20 +1006,20 @@ INSTANTIATE_ADDQKVBIASIA3_REBUILD_PADDING(__nv_bfloat16); #undef INSTANTIATEADDQKVBIASREBUILDPADDING template -__global__ void transpose_remove_padding(const T* src, T* dst, const int batch_size, const int seq_len, - const int head_num, const int size_per_head, const int* mask_offset, const float* scale, const int int8_mode) +__global__ void transpose_remove_padding(T const* src, T* dst, int const batch_size, int const seq_len, + int const head_num, int const size_per_head, int const* mask_offset, float const* scale, int const int8_mode) { // TODO: optimize this kernel? 
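// Illustrative host-side sketch (hypothetical helper, not part of this patch): the
// rebuild-padding kernels above assume mask_offset[i] holds the number of padded
// slots that precede the i-th valid (packed) token, so bid + mask_offset[bid]
// recovers the token's position in the padded [batch_size, seq_len] layout, which
// is then split into (tgt_batch_id, tgt_seq_id).
#include <vector>
std::vector<int> buildMaskOffset(std::vector<int> const& seq_lens, int seq_len)
{
    std::vector<int> mask_offset;
    int padding_so_far = 0;
    for (size_t b = 0; b < seq_lens.size(); ++b)
    {
        for (int s = 0; s < seq_lens[b]; ++s)
        {
            mask_offset.push_back(padding_so_far); // valid token: record preceding padding
        }
        padding_so_far += seq_len - seq_lens[b];   // padded tail of this sequence
    }
    return mask_offset; // one entry per valid token
}
// Example: seq_len = 4, seq_lens = {2, 3} -> mask_offset = {0, 0, 2, 2, 2};
// packed token 2 maps to padded index 2 + 2 = 4, i.e. (batch 1, position 0).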
// do remove_sequence_length_padding - const int bid = blockIdx.x; // batch * seq_len or valid_word_num + int const bid = blockIdx.x; // batch * seq_len or valid_word_num - const int src_batch_id = (bid + mask_offset[bid]) / seq_len; - const int src_seq_id = (bid + mask_offset[bid]) % seq_len; + int const src_batch_id = (bid + mask_offset[bid]) / seq_len; + int const src_seq_id = (bid + mask_offset[bid]) % seq_len; - const int dst_seq_id = bid; + int const dst_seq_id = bid; - const int src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head; - const int dst_offset_base = dst_seq_id * head_num * size_per_head; + int const src_offset_base = src_batch_id * seq_len * head_num * size_per_head + src_seq_id * size_per_head; + int const dst_offset_base = dst_seq_id * head_num * size_per_head; using Int8_Packed_T = typename packed_as::value>::type; using Float_Packed_T = typename packed_as::value>::type; @@ -1028,8 +1028,8 @@ __global__ void transpose_remove_padding(const T* src, T* dst, const int batch_s for (int idx = threadIdx.x; idx < head_num * size_per_head; idx += blockDim.x) { - const int head_id = idx / size_per_head; - const int hidden_id = idx % size_per_head; + int const head_id = idx / size_per_head; + int const hidden_id = idx % size_per_head; const T src_elem = ldg(&src[src_offset_base + head_id * seq_len * size_per_head + hidden_id]); if (int8_mode == 2) { @@ -1104,9 +1104,9 @@ INSTANTIATE_TRANSPOSE_ATTENTION_OUT_REMOVE_PADDING(__nv_bfloat16); #undef INSTANTIATE_TRANSPOSE_ATTENTION_OUT_REMOVE_PADDING template -__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* __restrict qkv_bias, - const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int token_num, - const int head_num, const int kv_head_num, const int size_per_head, const float* scale, const int int8_mode) +__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, T* QKV, T const* __restrict qkv_bias, + int const* seq_lens, int const* padding_offset, int const batch_size, int const seq_len, int const token_num, + int const head_num, int const kv_head_num, int const size_per_head, float const* scale, int const int8_mode) { // QKV: [token_num, hidden + 2 * kv_head_num * size_per_head] // qkv_bias: [hidden + 2 * kv_head_num * size_per_head] @@ -1114,20 +1114,20 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, // k_buf, v_buf: [batch, kv_head_num, seq_len, size_per_head] // For cross attention where q/k/v buffer could be nullptr, writing to split buffer is suppressed when null T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; - const bool has_padding = padding_offset == nullptr; - const int hidden = head_num * size_per_head; // hidden dim Q - const int n = hidden + 2 * kv_head_num * size_per_head; + bool const has_padding = padding_offset == nullptr; + int const hidden = head_num * size_per_head; // hidden dim Q + int const n = hidden + 2 * kv_head_num * size_per_head; for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < token_num * n; index += gridDim.x * blockDim.x) { - const int bias_id = index % n; + int const bias_id = index % n; - const int token_idx = index / n; - const int token_padded_idx = token_idx + (has_padding ? 
0 : padding_offset[token_idx]); - const int target_batch_id = token_padded_idx / seq_len; - const int actual_seq_len = seq_lens[target_batch_id]; - const int seq_id = token_padded_idx % seq_len; - const bool valid_seq = seq_id < actual_seq_len || !has_padding; + int const token_idx = index / n; + int const token_padded_idx = token_idx + (has_padding ? 0 : padding_offset[token_idx]); + int const target_batch_id = token_padded_idx / seq_len; + int const actual_seq_len = seq_lens[target_batch_id]; + int const seq_id = token_padded_idx % seq_len; + bool const valid_seq = seq_id < actual_seq_len || !has_padding; int qkv_id; int head_id; @@ -1172,7 +1172,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, { if (int8_mode == 2) { - val = cuda_cast(cuda_cast(reinterpret_cast(QKV)[index]) * scale[qkv_id]); + val = cuda_cast(cuda_cast(reinterpret_cast(QKV)[index]) * scale[qkv_id]); } else { @@ -1186,9 +1186,9 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, // Write to split QKV buffer if (head_num == kv_head_num || qkv_id == 0) // QKV or Q when MQA/GQA { - const int target_batch_stride = head_num * seq_len * size_per_head; - const int target_head_stride = seq_len * size_per_head; - const int target_seq_stride = size_per_head; + int const target_batch_stride = head_num * seq_len * size_per_head; + int const target_head_stride = seq_len * size_per_head; + int const target_seq_stride = size_per_head; if (qkv_ptr[qkv_id]) qkv_ptr[qkv_id][target_batch_id * target_batch_stride + head_id * target_head_stride + seq_id * target_seq_stride + size_id] @@ -1196,9 +1196,9 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, } else if (head_num != kv_head_num && qkv_id > 0) // KV when MQA/GQA { - const int target_batch_stride = kv_head_num * seq_len * size_per_head; - const int target_head_stride = seq_len * size_per_head; - const int target_seq_stride = size_per_head; + int const target_batch_stride = kv_head_num * seq_len * size_per_head; + int const target_head_stride = seq_len * size_per_head; + int const target_seq_stride = size_per_head; if (qkv_ptr[qkv_id]) qkv_ptr[qkv_id][target_batch_id * target_batch_stride + head_id * target_head_stride + seq_id * target_seq_stride + size_id] @@ -1237,10 +1237,10 @@ struct Vec_t<__nv_bfloat16> #endif template -__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* __restrict qkv_bias, - const int* seq_lens, const int* padding_offset, const int batch_size, const int seq_len, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base, - RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, +__global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, T* QKV, T const* __restrict qkv_bias, + int const* seq_lens, int const* padding_offset, int const batch_size, int const seq_len, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base, + RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type) { // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head], and @@ -1264,50 +1264,50 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, 
constexpr int vec_size = Vec_t::size; using Vec_t = typename Vec_t::Type; - const int token_idx = blockIdx.x; - const int token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 0 : padding_offset[token_idx]; - const int tgt_token_idx = token_idx + token_padding_offset; - const bool has_padding = padding_offset == nullptr; + int const token_idx = blockIdx.x; + int const token_padding_offset = (padding_offset == nullptr || token_idx < 0) ? 0 : padding_offset[token_idx]; + int const tgt_token_idx = token_idx + token_padding_offset; + bool const has_padding = padding_offset == nullptr; - const int batch_idx = tgt_token_idx / seq_len; - const int seq_idx = tgt_token_idx % seq_len; - const int actual_seq_len = seq_lens[batch_idx]; - const bool valid_seq = seq_idx < actual_seq_len || !has_padding; + int const batch_idx = tgt_token_idx / seq_len; + int const seq_idx = tgt_token_idx % seq_len; + int const actual_seq_len = seq_lens[batch_idx]; + bool const valid_seq = seq_idx < actual_seq_len || !has_padding; - const int head_idx = blockIdx.y; - const int tidx = threadIdx.x; + int const head_idx = blockIdx.y; + int const tidx = threadIdx.x; - const int total_seq_len = seq_len; + int const total_seq_len = seq_len; - const bool is_seq_masked = !valid_seq; - const bool is_head_size_masked = tidx * vec_size >= size_per_head; - const bool is_masked = is_head_size_masked || is_seq_masked; + bool const is_seq_masked = !valid_seq; + bool const is_head_size_masked = tidx * vec_size >= size_per_head; + bool const is_masked = is_head_size_masked || is_seq_masked; - const int hidden_size = head_num * size_per_head; - const int hidden_idx = head_idx * size_per_head + tidx * vec_size; - const int qheads_per_kv_head = head_num / kv_head_num; - const int kv_head_idx = head_idx / qheads_per_kv_head; - const int hidden_idx_kv = kv_head_idx * size_per_head + tidx * vec_size; - const int n = (head_num + 2 * kv_head_num) * size_per_head; + int const hidden_size = head_num * size_per_head; + int const hidden_idx = head_idx * size_per_head + tidx * vec_size; + int const qheads_per_kv_head = head_num / kv_head_num; + int const kv_head_idx = head_idx / qheads_per_kv_head; + int const hidden_idx_kv = kv_head_idx * size_per_head + tidx * vec_size; + int const n = (head_num + 2 * kv_head_num) * size_per_head; - const int dst_kv_seq_idx = seq_idx; - const int src_k_offset = hidden_size; - const int src_v_offset = hidden_size + kv_head_num * size_per_head; + int const dst_kv_seq_idx = seq_idx; + int const src_k_offset = hidden_size; + int const src_v_offset = hidden_size + kv_head_num * size_per_head; // NOTE: q has seq len excluding prefix prompt // head_num == kv_head_num: // src QKV: [batch, time, 3, head_num, size_per_head] // head_num != kv_head_num: // src QKV: [batch, time, head_num * size_per_head + 2 * kv_head_num * size_per_head] - const int src_q_idx = token_idx * n + hidden_idx; - const int src_k_idx = token_idx * n + src_k_offset + hidden_idx_kv; - const int src_v_idx = token_idx * n + src_v_offset + hidden_idx_kv; + int const src_q_idx = token_idx * n + hidden_idx; + int const src_k_idx = token_idx * n + src_k_offset + hidden_idx_kv; + int const src_v_idx = token_idx * n + src_v_offset + hidden_idx_kv; // destination offset. 
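// Sketch of the fused-QKV source layout read above (illustrative helper at element
// granularity rather than vec_size; assumes head_num is a multiple of kv_head_num).
// Each token row stores [Q: head_num * D | K: kv_head_num * D | V: kv_head_num * D],
// and under MQA/GQA a group of query heads shares one K/V head.
struct FusedQkvOffsets
{
    int q, k, v;
};

inline FusedQkvOffsets fusedQkvOffsets(
    int token_idx, int head_idx, int elem_idx, int head_num, int kv_head_num, int size_per_head)
{
    int const n = (head_num + 2 * kv_head_num) * size_per_head;   // row stride per token
    int const kv_head_idx = head_idx / (head_num / kv_head_num);  // grouped query heads share a kv head
    int const hidden_idx = head_idx * size_per_head + elem_idx;
    int const hidden_idx_kv = kv_head_idx * size_per_head + elem_idx;
    int const src_k_offset = head_num * size_per_head;
    int const src_v_offset = src_k_offset + kv_head_num * size_per_head;
    return {token_idx * n + hidden_idx,                  // Q element
        token_idx * n + src_k_offset + hidden_idx_kv,    // K element
        token_idx * n + src_v_offset + hidden_idx_kv};   // V element
}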
- const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len + int const dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len + seq_idx * size_per_head + tidx * vec_size; - const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num + int const dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num + kv_head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head + tidx * vec_size; Vec_t q, k, v, zero; @@ -1327,15 +1327,15 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, // load q,k,v and add bias if (!is_masked) { - q = *reinterpret_cast(&QKV[src_q_idx]); - k = *reinterpret_cast(&QKV[src_k_idx]); - v = *reinterpret_cast(&QKV[src_v_idx]); + q = *reinterpret_cast(&QKV[src_q_idx]); + k = *reinterpret_cast(&QKV[src_k_idx]); + v = *reinterpret_cast(&QKV[src_v_idx]); if (ADD_BIAS) { - q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); - k_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_k_offset]); - v_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_v_offset]); + q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); + k_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_k_offset]); + v_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_v_offset]); q = mmha::add(q, q_bias); k = mmha::add(k, k_bias); @@ -1353,15 +1353,15 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, } case PositionEmbeddingType::kROPE_GPT_NEOX: { - const bool do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim; + bool const do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim; T* q_smem = reinterpret_cast(smem_); T* k_smem = q_smem + rotary_embedding_dim; - const int half_rotary_dim = rotary_embedding_dim / 2; - const int half_idx = (tidx * vec_size) / half_rotary_dim; - const int intra_half_idx = (tidx * vec_size) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? + int const half_rotary_dim = rotary_embedding_dim / 2; + int const half_idx = (tidx * vec_size) / half_rotary_dim; + int const intra_half_idx = (tidx * vec_size) % half_rotary_dim; + int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? 
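// Background for the kROPE_GPT_NEOX path above: NeoX-style RoPE rotates element i
// together with element i + rotary_dim / 2 ("rotate half"), so the two members of a
// pair are not adjacent in a thread's contiguous vector; the shared-memory staging
// lets threads exchange halves. A minimal scalar reference of that pairing, assuming
// the usual inverse-frequency schedule (the exact convention in the kernel may differ):
#include <cmath>
void ropeNeoxReference(float* x, int rotary_dim, int position, float base /* e.g. 10000.f */)
{
    int const half = rotary_dim / 2;
    for (int i = 0; i < half; ++i)
    {
        float const inv_freq = std::pow(base, -2.0f * i / rotary_dim);
        float const c = std::cos(position * inv_freq);
        float const s = std::sin(position * inv_freq);
        float const x0 = x[i];
        float const x1 = x[i + half]; // paired element lives half a head dimension away
        x[i] = x0 * c - x1 * s;
        x[i + half] = x0 * s + x1 * c;
    }
}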
if (do_rotary) { @@ -1371,7 +1371,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, __syncthreads(); - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + int const transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; constexpr int tidx_factor = vec_size / 2; if (do_rotary) { @@ -1439,18 +1439,18 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf, rotary_embedding_max_positions, position_embedding_type); template -void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, const float* scale, - const int int8_mode, cudaStream_t stream) +void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, T const* qkv_bias, int const* seq_lens, + int const* padding_offset, int const batch_size, int const seq_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale, + int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, float const* scale, + int const int8_mode, cudaStream_t stream) { // [bs, seq_len, 3, head, Dh] if (rotary_embedding_dim == 0) { - const int m = token_num; - const int n = head_num * size_per_head; + int const m = token_num; + int const n = head_num * size_per_head; dim3 block(384); dim3 grid((int) (ceil(1.0 * m * n / 384))); @@ -1500,45 +1500,45 @@ INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(__nv_bfloat16); #undef INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE template -__global__ void transpose_4d(T* dst, T* src, const int dim0, const int dim1, const int dim2, const int dim3, - const int dim0_leading_dim, const int ite) +__global__ void transpose_4d(T* dst, T* src, int const dim0, int const dim1, int const dim2, int const dim3, + int const dim0_leading_dim, int const ite) { // transpose from [dim0, dim1, dim2, dim3] to [dim2, X, dim1, dim3] // where the dimension of X is dim0_leading_dim, and offset is ite * dim0 for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * dim3; i += blockDim.x * gridDim.x) { int index = i; - const int d3 = index % dim3; + int const d3 = index % dim3; index = (index - d3) / dim3; - const int d2 = index % dim2; + int const d2 = index % dim2; index = (index - d2) / dim2; - const int d1 = index % dim1; + int const d1 = index % dim1; index = (index - d1) / dim1; - const int d0 = index % dim0; + int const d0 = index % dim0; index = (index - d0) / dim0; dst[d2 * dim0_leading_dim * dim1 * dim3 + (d0 + dim0 * ite) * dim1 * dim3 + d1 * dim3 + d3] = src[i]; } } template <> -__global__ void transpose_4d(half* dst, half* src, const int dim0, const int dim1, const int dim2, const int dim3, - const int dim0_leading_dim, const int ite) +__global__ void transpose_4d(half* dst, half* src, int const dim0, int const dim1, int const dim2, int const dim3, + int const dim0_leading_dim, int const ite) { half2* dst_ptr = (half2*) dst; half2* src_ptr = (half2*) src; - 
const int half_dim3 = dim3 / 2; + int const half_dim3 = dim3 / 2; // transpose from [dim0, dim1, dim2, half_dim3] to [dim2, dim0, dim1, half_dim3] // where the dimension of X is dim0_leading_dim, and offset is ite * dim0 for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < dim0 * dim1 * dim2 * half_dim3; i += blockDim.x * gridDim.x) { int index = i; - const int d3 = index % half_dim3; + int const d3 = index % half_dim3; index = (index - d3) / half_dim3; - const int d2 = index % dim2; + int const d2 = index % dim2; index = (index - d2) / dim2; - const int d1 = index % dim1; + int const d1 = index % dim1; index = (index - d1) / dim1; - const int d0 = index % dim0; + int const d0 = index % dim0; index = (index - d0) / dim0; dst_ptr[d2 * dim0_leading_dim * dim1 * half_dim3 + (d0 + dim0 * ite) * dim1 * half_dim3 + d1 * half_dim3 + d3] = src_ptr[i]; @@ -1546,8 +1546,8 @@ __global__ void transpose_4d(half* dst, half* src, const int dim0, const int dim } template -void invokeTranspose4d(T* dst, T* src, const int local_batch_size, const int seq_len, const int size_per_head, - const int local_hidden_units, const int local_head_num, const int batch_size, const int ite, cudaStream_t stream) +void invokeTranspose4d(T* dst, T* src, int const local_batch_size, int const seq_len, int const size_per_head, + int const local_hidden_units, int const local_head_num, int const batch_size, int const ite, cudaStream_t stream) { transpose_4d<<>>( dst, src, local_batch_size, local_head_num, seq_len, size_per_head, batch_size, ite); @@ -1562,9 +1562,9 @@ INSTANTIATE_TRANSPOSE_4D(half); #undef INSTANTIATE_TRANSPOSE_4D template -__global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCacheBuffer kvCacheBuffer, - const int headNum, const int sizePerHead, const int seqLen, const int attentionWindowSize, - const float* kvScaleOrigQuant, const int* sequence_lengths) +__global__ void transpose4dBatchMajorKVCache(T const* kSrc, T const* vSrc, KVCacheBuffer kvCacheBuffer, + int const headNum, int const sizePerHead, int const seqLen, int const attentionWindowSize, + float const* kvScaleOrigQuant, int const* sequence_lengths) { // We allow only fp32/fp16/bf16 as input types static_assert(sizeof(T) == 4 || sizeof(T) == 2, ""); @@ -1574,14 +1574,14 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac using T_dst = T_cache; using T_src = typename mmha::packed_type::type; - const int batchIdx = blockIdx.y; - const int headIdx = blockIdx.z; + int const batchIdx = blockIdx.y; + int const headIdx = blockIdx.z; // idx is over output dimension L * sizePerHead / x for values int idx = blockIdx.x * blockDim.x + threadIdx.x; // threadIdx.y 0 handles k, while threadIdx.y 1 handles v. - const bool handle_k = (threadIdx.y == 0); - const int sizePerHeadDivX = sizePerHead / X_ELEMS; + bool const handle_k = (threadIdx.y == 0); + int const sizePerHeadDivX = sizePerHead / X_ELEMS; if (idx >= sizePerHeadDivX * seqLen) { @@ -1592,9 +1592,9 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac int tokenIdx = idx / sizePerHeadDivX; // Apply cyclic kv cache if tokenIdx >= max_attention_window_size. // which means we will drop the tokens in the beginning if seqLen > max_attention_window_size. 
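// The lower bound computed just below keeps only the most recent attentionWindowSize
// tokens of each sequence, so the cyclic KV cache never holds more than the window.
// A host-side sketch of that predicate (the cyclic slot mapping itself is handled by
// KVCacheBuffer and is not reproduced here):
inline bool keptInCyclicKvCache(int token_idx, int sequence_length, int attention_window_size)
{
    int const lower_bound = sequence_length > attention_window_size
        ? sequence_length - attention_window_size
        : 0;                                        // max(seq_len - window, 0)
    return token_idx >= lower_bound && token_idx < sequence_length;
}
// Example: sequence_length = 10, attention_window_size = 6 -> tokens 4..9 are written,
// tokens 0..3 (the oldest) are dropped.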
- const int tokenIdxLowerBound = max(sequence_lengths[batchIdx] - attentionWindowSize, 0); + int const tokenIdxLowerBound = max(sequence_lengths[batchIdx] - attentionWindowSize, 0); // Get channel index - const int channelIdx = idx % sizePerHeadDivX; + int const channelIdx = idx % sizePerHeadDivX; if (tokenIdx >= sequence_lengths[batchIdx] || tokenIdx < tokenIdxLowerBound) { return; @@ -1611,7 +1611,7 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac // 16 byte loads will handle "x" dimension const size_t srcOffset = (batchIdx * headNum + headIdx) * sizePerHead * seqLen; - auto valSrc = reinterpret_cast((handle_k ? kSrc : vSrc) + srcOffset); + auto valSrc = reinterpret_cast((handle_k ? kSrc : vSrc) + srcOffset); T_src val = valSrc[idx]; if (ENABLE_8BITS_CACHE) @@ -1635,9 +1635,9 @@ __global__ void transpose4dBatchMajorKVCache(const T* kSrc, const T* vSrc, KVCac } template -void invokeTranspose4dBatchMajor(const T* kSrc, const T* vSrc, KVCacheBuffer& kvTable, const int localBatchSize, - const int seqLen, const int attentionWindowSize, const int sizePerHead, const int localHeadNum, - const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream) +void invokeTranspose4dBatchMajor(T const* kSrc, T const* vSrc, KVCacheBuffer& kvTable, int const localBatchSize, + int const seqLen, int const attentionWindowSize, int const sizePerHead, int const localHeadNum, + const KvCacheDataType cache_type, float const* kvScaleOrigQuant, int const* sequence_lengths, cudaStream_t stream) { // Block handles both K and V tile. dim3 blockSz(128, 2); @@ -1685,19 +1685,19 @@ INSTANTIATE_TRANSPOSE_4D_BATCH_MAJOR(__nv_bfloat16); #undef INSTANTIATE_TRANSPOSE_4D_BATCH_MAJOR template -__global__ void addRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, const int batch_size, - const int head_num, const int seq_len, int max_seq_len, bool implicit, int num_buckets, int max_distance, +__global__ void addRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, int const batch_size, + int const head_num, int const seq_len, int max_seq_len, bool implicit, int num_buckets, int max_distance, bool bidirectional) { - const int seq_i = blockIdx.x; - const int batch_id = blockIdx.y / head_num; - const int head_id = blockIdx.y % head_num; - const int rel_attn_table_stride = num_buckets; // num_buckets could be modified below, save it beforehand + int const seq_i = blockIdx.x; + int const batch_id = blockIdx.y / head_num; + int const head_id = blockIdx.y % head_num; + int const rel_attn_table_stride = num_buckets; // num_buckets could be modified below, save it beforehand for (int seq_j = threadIdx.x; seq_j < seq_len; seq_j += blockDim.x) { - const int qk_index + int const qk_index = batch_id * head_num * seq_len * seq_len + head_id * seq_len * seq_len + seq_i * seq_len + seq_j; if (implicit) @@ -1730,15 +1730,15 @@ __global__ void addRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_ } else { - const int bias_index = head_id * max_seq_len * max_seq_len + seq_i * max_seq_len + seq_j; + int const bias_index = head_id * max_seq_len * max_seq_len + seq_i * max_seq_len + seq_j; qk_buf[qk_index] = (T) add((T) relative_attention_bias[bias_index], qk_buf[qk_index]); } } } template -void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, const int batch_size, - const int head_num, const int seq_len, const int max_seq_len, cudaStream_t stream, bool implicit, int 
num_buckets, +void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, int const batch_size, + int const head_num, int const seq_len, int const max_seq_len, cudaStream_t stream, bool implicit, int num_buckets, int max_distance, bool bidirectional) { // qk_buf: [batch_size, head_num, seq_len, seq_len] @@ -1764,11 +1764,11 @@ INSTANTIATE_ADD_RELATIVE_ATTENTION_BIAS_UNALIGNED(float, __nv_bfloat16); #undef INSTANTIATE_ADD_RELATIVE_ATTENTION_BIAS_UNALIGNED template -__global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCacheBuffer, const int sizePerHead, - const int timestep, const int beam_width, const int maxKCacheLen, const int sinkTokenLen, - const float* kScaleQuantOrig, const int* sequence_lengths, const int* input_lengths, const int rotary_embedding_dim, +__global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCacheBuffer, int const sizePerHead, + int const timestep, int const beam_width, int const maxKCacheLen, int const sinkTokenLen, + float const* kScaleQuantOrig, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, - const int rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type) + int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type) { // We allow only fp32/fp16/bf16 as the data types to apply rotary static_assert(sizeof(T) == 4 || sizeof(T) == 2, ""); @@ -1785,35 +1785,35 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa using Vec_k = typename mmha::packed_type::type; using Vec_k_cache = typename mmha::packed_type::type; using T_dst = T; - const int sizePerHeadDivX = sizePerHead / vec_size; + int const sizePerHeadDivX = sizePerHead / vec_size; // The start token idx for the cyclic part in k cache - const int cyclic_k_cache_start_idx + int const cyclic_k_cache_start_idx = (timestep <= maxKCacheLen) ? sinkTokenLen : sinkTokenLen + timestep - maxKCacheLen; // The token idx int token_idx = (kvCacheBuffer.isSinkToken(blockIdx.x)) ? blockIdx.x : cyclic_k_cache_start_idx + blockIdx.x - sinkTokenLen; // The position idx - const int token_pos_idx = blockIdx.x; + int const token_pos_idx = blockIdx.x; // Head - const int head_idx = blockIdx.y; + int const head_idx = blockIdx.y; // The batch beam idx - const int batch_beam_idx = blockIdx.z; + int const batch_beam_idx = blockIdx.z; // The beam idx - const int beam_idx = batch_beam_idx % beam_width; + int const beam_idx = batch_beam_idx % beam_width; // Thread idx - const int tidx = threadIdx.x; + int const tidx = threadIdx.x; // The actual sequence length excluding the paddings. // minus 1 because it includes the current timestep while tlength denotes the past token length. - const int tlength = sequence_lengths[batch_beam_idx] - 1; + int const tlength = sequence_lengths[batch_beam_idx] - 1; // The context length - const int inlength = input_lengths[batch_beam_idx]; + int const inlength = input_lengths[batch_beam_idx]; // The k cache valid token length - const int cache_length = (tlength > maxKCacheLen) ? maxKCacheLen : tlength; + int const cache_length = (tlength > maxKCacheLen) ? 
maxKCacheLen : tlength; // Mask out the tokens exceed the real total length and tokens in the context phase with beam_idx>0 - const bool valid_seq = token_idx < tlength && !(token_idx < inlength && beam_idx > 0); - const bool is_head_size_masked = tidx * vec_size >= sizePerHead; + bool const valid_seq = token_idx < tlength && !(token_idx < inlength && beam_idx > 0); + bool const is_head_size_masked = tidx * vec_size >= sizePerHead; // Dequant scales for 8bits k cache float k_scale_quant_orig = (ENABLE_8BITS_CACHE ? kScaleQuantOrig[0] : 1.0f); @@ -1834,7 +1834,7 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa Vec_k_cache k_cache; T_cache* k_cache_batch = reinterpret_cast(kvCacheBuffer.getKBlockPtr(batch_beam_idx, token_kv_idx)); int inBlockIdx_r = kvCacheBuffer.getKVLocalIdx(token_kv_idx, head_idx, sizePerHead, tidx * vec_size); - k_cache = *reinterpret_cast(&k_cache_batch[inBlockIdx_r]); + k_cache = *reinterpret_cast(&k_cache_batch[inBlockIdx_r]); if constexpr (INT8_K_CACHE) { using Packed_Float_t = typename mmha::packed_type::type; @@ -1863,14 +1863,14 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa } case PositionEmbeddingType::kROPE_GPT_NEOX: { - const bool do_rotary = vec_size * tidx < rotary_embedding_dim; + bool const do_rotary = vec_size * tidx < rotary_embedding_dim; T* k_smem = reinterpret_cast(smem_); - const int half_rotary_dim = rotary_embedding_dim / 2; - const int half_idx = (tidx * vec_size) / half_rotary_dim; - const int intra_half_idx = (tidx * vec_size) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? + int const half_rotary_dim = rotary_embedding_dim / 2; + int const half_idx = (tidx * vec_size) / half_rotary_dim; + int const intra_half_idx = (tidx * vec_size) % half_rotary_dim; + int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? if (do_rotary) { @@ -1879,7 +1879,7 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa __syncthreads(); - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + int const transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; constexpr int tidx_factor = vec_size / 2; if (do_rotary) { @@ -1908,15 +1908,15 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa template void invokeShiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCacheBuffer, const KvCacheDataType cache_type, - const int sizePerHead, const int timestep, const int batch_beam, const int kv_head_num, const int beam_width, - const int maxKCacheLen, const int sinkTokenLen, const float* kScaleQuantOrig, const int* sequence_lengths, - const int* input_lengths, const int rotary_embedding_dim, float rotary_embedding_base, - RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, + int const sizePerHead, int const timestep, int const batch_beam, int const kv_head_num, int const beam_width, + int const maxKCacheLen, int const sinkTokenLen, float const* kScaleQuantOrig, int const* sequence_lengths, + int const* input_lengths, int const rotary_embedding_dim, float rotary_embedding_base, + RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, cudaStream_t stream) { // Block handles K tile. - const int token_num_in_k = (timestep <= maxKCacheLen) ? 
timestep : maxKCacheLen; - const int vec_size = 16u / sizeof(T); + int const token_num_in_k = (timestep <= maxKCacheLen) ? timestep : maxKCacheLen; + int const vec_size = 16u / sizeof(T); dim3 block((sizePerHead / vec_size + 31) / 32 * 32); dim3 grid(token_num_in_k, kv_head_num, batch_beam); size_t smem_size diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h index 091bbaea9..b3467ee71 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels.h @@ -25,9 +25,9 @@ namespace kernels { template -void invokeAddQKVBiasIA3Transpose(T* q_buf, T* k_buf, T* v_buf, T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, - const T* bias_V, const int batch_size, const int seq_len, const int head_num, const int size_per_head, - const int* ia3_tasks, const T* ia3_key_weights, const T* ia3_value_weights, cudaStream_t stream); +void invokeAddQKVBiasIA3Transpose(T* q_buf, T* k_buf, T* v_buf, T* Q, T const* bias_Q, T* K, T const* bias_K, T* V, + T const* bias_V, int const batch_size, int const seq_len, int const head_num, int const size_per_head, + int const* ia3_tasks, T const* ia3_key_weights, T const* ia3_value_weights, cudaStream_t stream); template struct MaskedSoftmaxParam @@ -35,7 +35,7 @@ struct MaskedSoftmaxParam // Common parameters. T* attention_score = nullptr; // (batch_size, head_num, q_length, k_length) const T_IN* qk = nullptr; // (batch_size, head_num, q_length, k_length) - const T* attention_mask = nullptr; // (batch_size, q_length, k_length) + T const* attention_mask = nullptr; // (batch_size, q_length, k_length) int batch_size = 0; int q_length = 0; int k_length = 0; @@ -44,7 +44,7 @@ struct MaskedSoftmaxParam // Optional parameters that depend on the type of attention. // The slopes of the linear position bias of ALiBi. 
- const T* linear_bias_slopes = nullptr; // (head_num,), optional + T const* linear_bias_slopes = nullptr; // (head_num,), optional }; enum class KvCacheDataType @@ -58,77 +58,77 @@ template void invokeMaskedSoftmax(MaskedSoftmaxParam& param, cudaStream_t stream); template -void invokeTransposeQKV(T* dst, T* src, const int batch_size, const int seq_len, const int head_num, - const int size_per_head, const float* scale, const int int8_mode, cudaStream_t stream); +void invokeTransposeQKV(T* dst, T* src, int const batch_size, int const seq_len, int const head_num, + int const size_per_head, float const* scale, int const int8_mode, cudaStream_t stream); template -void invokeAddQKVBiasIA3RebuildPadding(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf, - T* k_buf, T* v_buf, const int batch_size, const int seq_len, const int head_num, const int size_per_head, - const int valid_word_num, const int* mask_offset, const int* ia3_tasks, const T* ia3_key_weights, - const T* ia3_value_weights, cudaStream_t stream); +void invokeAddQKVBiasIA3RebuildPadding(T* Q, T const* bias_Q, T* K, T const* bias_K, T* V, T const* bias_V, T* q_buf, + T* k_buf, T* v_buf, int const batch_size, int const seq_len, int const head_num, int const size_per_head, + int const valid_word_num, int const* mask_offset, int const* ia3_tasks, T const* ia3_key_weights, + T const* ia3_value_weights, cudaStream_t stream); template -void invokeTransposeAttentionOutRemovePadding(T* src, T* dst, const int valid_word_num, const int batch_size, - const int seq_len, const int head_num, const int size_per_head, const int* mask_offset, const float* scale, - const int int8_mode, cudaStream_t stream); +void invokeTransposeAttentionOutRemovePadding(T* src, T* dst, int const valid_word_num, int const batch_size, + int const seq_len, int const head_num, int const size_per_head, int const* mask_offset, float const* scale, + int const int8_mode, cudaStream_t stream); template -void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, - PositionEmbeddingType const position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream); +void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, T const* qkv_bias, int const* seq_lens, + int const* padding_offset, int const batch_size, int const seq_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, + PositionEmbeddingType const position_embedding_type, float const* scale, int const int8_mode, cudaStream_t stream); template -void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const T* qkv_bias, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, cudaStream_t stream) +void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, T const* qkv_bias, int const* seq_lens, + int const* 
padding_offset, int const batch_size, int const seq_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, cudaStream_t stream) { invokeAddFusedQKVBiasTranspose(q_buf, k_buf, v_buf, QKV, qkv_bias, seq_lens, padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, 0, false, (float*) nullptr, 0, stream); } template -void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, const int* seq_lens, - const int* padding_offset, const int batch_size, const int seq_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, - PositionEmbeddingType const position_embedding_type, const float* scale, const int int8_mode, cudaStream_t stream) +void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, int const* seq_lens, + int const* padding_offset, int const batch_size, int const seq_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, + PositionEmbeddingType const position_embedding_type, float const* scale, int const int8_mode, cudaStream_t stream) { - invokeAddFusedQKVBiasTranspose(q_buf, k_buf, v_buf, QKV, (const T*) nullptr, seq_lens, padding_offset, batch_size, + invokeAddFusedQKVBiasTranspose(q_buf, k_buf, v_buf, QKV, (T const*) nullptr, seq_lens, padding_offset, batch_size, seq_len, token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, scale, int8_mode, stream); } template -void invokeTranspose4dBatchMajor(const T* k_src, const T* v_src, KVCacheBuffer& kvTable, const int local_batch_size, - const int seq_len, const int max_attention_window_size, const int size_per_head, const int local_head_num, - const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const int* sequence_lengths, cudaStream_t stream); +void invokeTranspose4dBatchMajor(T const* k_src, T const* v_src, KVCacheBuffer& kvTable, int const local_batch_size, + int const seq_len, int const max_attention_window_size, int const size_per_head, int const local_head_num, + const KvCacheDataType cache_type, float const* kvScaleOrigQuant, int const* sequence_lengths, cudaStream_t stream); // NOTE: this kernel is in-place, QKV will be modified, if other kernels need that, may need copy or use before it. 
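// Given the in-place NOTE above, a caller that still needs the original fused QKV
// afterwards could snapshot it on the same stream before launching. A minimal sketch
// under that assumption (qkv_backup and the element count are the caller's own):
#include <cuda_runtime.h>
template <typename T>
cudaError_t backupQkvBeforeInPlaceUpdate(
    T* qkv_backup, T const* qkv, size_t num_elements, cudaStream_t stream)
{
    // Device-to-device async copy, ordered before the in-place kernel on `stream`.
    return cudaMemcpyAsync(
        qkv_backup, qkv, num_elements * sizeof(T), cudaMemcpyDeviceToDevice, stream);
}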
template -void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, - const int* kv_seq_lens, const int* padding_offset, const int batch_size, const int seq_len, - const int cyclic_kv_cache_len, const int sink_token_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, - const int* medusa_position_offsets, const bool position_shift_enabled, const float* scale, const int int8_mode, - const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const bool enable_paged_kv_fmha, - const int beam_width, int2& grid_block_cache, cudaStream_t stream); +void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens, + int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len, + int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale, + int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, + int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, int const int8_mode, + const KvCacheDataType cache_type, float const* kvScaleOrigQuant, bool const enable_paged_kv_fmha, + int const beam_width, int2& grid_block_cache, cudaStream_t stream); template -void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, const int batch_size, - const int head_num, const int seq_len, const int max_seq_len, cudaStream_t stream, bool implicit = false, +void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, int const batch_size, + int const head_num, int const seq_len, int const max_seq_len, cudaStream_t stream, bool implicit = false, int num_buckets = 0, int max_distance = 0, bool bidirectional = true); template void invokeShiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCacheBuffer, const KvCacheDataType cache_type, - const int sizePerHead, const int timestep, const int batch_beam, const int kv_head_num, const int beam_width, - const int maxKCacheLen, const int sinkTokenLen, const float* kScaleQuantOrig, const int* sequence_lengths, - const int* input_lengths, const int rotary_embedding_dim, float rotary_embedding_base, - RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, + int const sizePerHead, int const timestep, int const batch_beam, int const kv_head_num, int const beam_width, + int const maxKCacheLen, int const sinkTokenLen, float const* kScaleQuantOrig, int const* sequence_lengths, + int const* input_lengths, int const rotary_embedding_dim, float rotary_embedding_base, + RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, cudaStream_t stream); } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h 
b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h index a9352e3bf..e74ce6f49 100644 --- a/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h +++ b/cpp/tensorrt_llm/kernels/unfusedAttentionKernels/unfusedAttentionKernels_2_template.h @@ -158,13 +158,13 @@ struct Rotary_vec_t<__nv_bfloat16, 256> template -__global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBuffer, const T* __restrict qkv_bias, - const int* seq_lens, const int* kv_seq_lens, const int* padding_offset, const float* kvScaleOrigQuant, - const int num_tokens, const int batch_size, const int seq_len, const int cyclic_kv_cache_len, - const int sink_token_len, const int head_num, const int kv_head_num, const int qheads_per_kv_head, - const int size_per_head, const int rotary_embedding_dim, float rotary_embedding_base, - RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, - PositionEmbeddingType const position_embedding_type, const int* medusa_position_offsets, const int beam_width) +__global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBuffer, T const* __restrict qkv_bias, + int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, float const* kvScaleOrigQuant, + int const num_tokens, int const batch_size, int const seq_len, int const cyclic_kv_cache_len, + int const sink_token_len, int const head_num, int const kv_head_num, int const qheads_per_kv_head, + int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base, + RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, + PositionEmbeddingType const position_embedding_type, int const* medusa_position_offsets, int const beam_width) { // This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head] // Extract the Q input when using paged KV FMHA. @@ -191,28 +191,28 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu constexpr int VEC_SIZE = Rotary_vec_t::size; using Vec_type = typename Rotary_vec_t::Type; using Packed_type = typename Rotary_vec_t::Packed_type; - const bool has_padding = padding_offset == nullptr; + bool const has_padding = padding_offset == nullptr; constexpr bool ENABLE_8BITS_CACHE = sizeof(T_cache) == 1; - const int sizePerHeadDivX = size_per_head / VEC_SIZE; + int const sizePerHeadDivX = size_per_head / VEC_SIZE; using T_dst = T_cache; - const int head_idx = blockIdx.y; + int const head_idx = blockIdx.y; // Block size is always 32 in the x dimension. int tidx = threadIdx.x; // The half head dimension for remapping. 
// 32 threads in one warp // (first rotary threads + first no rotary threads) = first 16 threads // (second rotary threads + second no rotary threads) = second 16 threads - const int half_within_bound_dim = size_per_head / 2; - const int half_rotary_embedding_dim = rotary_embedding_dim / 2; - const int half_rotary_embedding_threads = rotary_embedding_dim / (2 * VEC_SIZE); - const int half_non_rotary_embedding_threads = (size_per_head - rotary_embedding_dim) / (2 * VEC_SIZE); + int const half_within_bound_dim = size_per_head / 2; + int const half_rotary_embedding_dim = rotary_embedding_dim / 2; + int const half_rotary_embedding_threads = rotary_embedding_dim / (2 * VEC_SIZE); + int const half_non_rotary_embedding_threads = (size_per_head - rotary_embedding_dim) / (2 * VEC_SIZE); // Remap to the correct half head size when head size is not power of 2. // This is mianly designed for the gptneox_style_rotary_embedding (which rotates the half embedding.) // The first 16 threads will handle the first half head size. - const bool first_half = tidx < HALF_WARP_SIZE; - const int second_half = !first_half; + bool const first_half = tidx < HALF_WARP_SIZE; + int const second_half = !first_half; int rotary_local_tidx = (tidx - second_half * HALF_WARP_SIZE); @@ -227,16 +227,16 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu : (rotary_local_tidx + half_rotary_embedding_threads + second_half * half_non_rotary_embedding_threads)); - const int hidden_size = head_num * size_per_head; - const int hidden_idx = head_idx * size_per_head + tidx * VEC_SIZE; - const int kv_head_idx = head_idx / qheads_per_kv_head; - const int hidden_idx_kv = kv_head_idx * size_per_head + tidx * VEC_SIZE; - const int n = (head_num + 2 * kv_head_num) * size_per_head; - const int src_k_offset = hidden_size; - const int src_v_offset = hidden_size + kv_head_num * size_per_head; + int const hidden_size = head_num * size_per_head; + int const hidden_idx = head_idx * size_per_head + tidx * VEC_SIZE; + int const kv_head_idx = head_idx / qheads_per_kv_head; + int const hidden_idx_kv = kv_head_idx * size_per_head + tidx * VEC_SIZE; + int const n = (head_num + 2 * kv_head_num) * size_per_head; + int const src_k_offset = hidden_size; + int const src_v_offset = hidden_size + kv_head_num * size_per_head; // Dynamic scaling of rotary embedding. - const bool dynamic_scale = rotary_scale_type == RotaryScalingType::kDYNAMIC; + bool const dynamic_scale = rotary_scale_type == RotaryScalingType::kDYNAMIC; for (int token_idx = blockIdx.x * blockDim.y + threadIdx.y; token_idx < num_tokens; token_idx += gridDim.x * blockDim.y) @@ -244,20 +244,20 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu // The index of the token in the batch. It includes "virtual" padding (even if the input is not padded) // such that the sequence index and the position in the sequence can be obtained using the max. // sequence length as: - const int token_padding_offset = (has_padding || IS_GENERATE) ? 0 : padding_offset[token_idx]; - const int global_token_idx = token_idx + token_padding_offset; - const int batch_beam_idx = global_token_idx / seq_len; + int const token_padding_offset = (has_padding || IS_GENERATE) ? 0 : padding_offset[token_idx]; + int const global_token_idx = token_idx + token_padding_offset; + int const batch_beam_idx = global_token_idx / seq_len; // TODO: optimize this for generation by using anther dimension of grid. 
- const int seq_idx = global_token_idx % seq_len; - const int final_kv_seq_len = (!IS_GENERATE) ? kv_seq_lens[batch_beam_idx] : 0; - const int actual_seq_len = seq_lens[batch_beam_idx]; + int const seq_idx = global_token_idx % seq_len; + int const final_kv_seq_len = (!IS_GENERATE) ? kv_seq_lens[batch_beam_idx] : 0; + int const actual_seq_len = seq_lens[batch_beam_idx]; // Chunked attention: takes past_kv_sequence_length into consideration. - const int token_idx_in_seq + int const token_idx_in_seq = (!IS_GENERATE) ? (final_kv_seq_len - actual_seq_len) + seq_idx : (actual_seq_len - seq_len + seq_idx); - const bool valid_seq = IS_GENERATE || (token_idx_in_seq < actual_seq_len || !has_padding); + bool const valid_seq = IS_GENERATE || (token_idx_in_seq < actual_seq_len || !has_padding); // NOTE: only Medusa needs the position offsets. // In the generation phase, we assume all sequences should have the same input length. - const int rotary_position = medusa_position_offsets != nullptr && IS_GENERATE + int const rotary_position = medusa_position_offsets != nullptr && IS_GENERATE ? (medusa_position_offsets[seq_idx + batch_beam_idx * seq_len] + actual_seq_len - seq_len) : token_idx_in_seq; @@ -265,10 +265,10 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu // we have already updated the scale in host if it is linear scale. float2 updated_base_scale = mmha::update_dynamic_scaling_rotary(rotary_embedding_base, rotary_embedding_scale, actual_seq_len, rotary_embedding_max_positions, rotary_embedding_dim, dynamic_scale); - const float updated_base = updated_base_scale.x; - const float updated_scale = updated_base_scale.y; + float const updated_base = updated_base_scale.x; + float const updated_scale = updated_base_scale.y; - const bool is_masked = !valid_seq || tidx < 0; + bool const is_masked = !valid_seq || tidx < 0; // head_num == kv_head_num: // src QKV: [batch, time, 3, head_num, size_per_head] @@ -286,15 +286,15 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu // load q,k,v and add bias if (!is_masked) { - q = *reinterpret_cast(&QKV[src_q_idx]); - k = *reinterpret_cast(&QKV[src_k_idx]); - v = *reinterpret_cast(&QKV[src_v_idx]); + q = *reinterpret_cast(&QKV[src_q_idx]); + k = *reinterpret_cast(&QKV[src_k_idx]); + v = *reinterpret_cast(&QKV[src_v_idx]); if constexpr (ADD_BIAS) { - q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); - k_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_k_offset]); - v_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_v_offset]); + q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); + k_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_k_offset]); + v_bias = *reinterpret_cast(&qkv_bias[hidden_idx_kv + src_v_offset]); q = mmha::add(q, q_bias); k = mmha::add(k, k_bias); @@ -332,12 +332,12 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu } } - const int channelIdx{tidx}; - const int tokenIdxLowerBound = max(actual_seq_len - cyclic_kv_cache_len + sink_token_len, sink_token_len); - const bool valid_kv_cache_pos + int const channelIdx{tidx}; + int const tokenIdxLowerBound = max(actual_seq_len - cyclic_kv_cache_len + sink_token_len, sink_token_len); + bool const valid_kv_cache_pos = kvCacheBuffer.data != nullptr // In KV-cache-less mode. 
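// Editor's note: the valid_kv_cache_pos predicate being assembled here keeps
// the attention-sink tokens plus the most recent window of the cyclic KV
// cache and skips everything in between. A host-side restatement of that
// test (kernel variable names reused for clarity); cyclicSlot is an assumed
// stand-in for KVCacheBuffer::getKVTokenIdx, whose implementation is not
// shown in this diff:

#include <algorithm>

inline bool keepInKvCache(int token_idx_in_seq, int actual_seq_len, int cyclic_kv_cache_len, int sink_token_len)
{
    int const tokenIdxLowerBound
        = std::max(actual_seq_len - cyclic_kv_cache_len + sink_token_len, sink_token_len);
    return token_idx_in_seq >= tokenIdxLowerBound || token_idx_in_seq < sink_token_len;
}

// One plausible slot assignment for a kept token: sinks stay pinned at the
// front, later tokens wrap around the remaining window.
inline int cyclicSlot(int token_idx_in_seq, int cyclic_kv_cache_len, int sink_token_len)
{
    if (token_idx_in_seq < sink_token_len)
    {
        return token_idx_in_seq;
    }
    int const window = cyclic_kv_cache_len - sink_token_len;
    return sink_token_len + (token_idx_in_seq - sink_token_len) % window;
}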
No need to store KV values && (token_idx_in_seq >= tokenIdxLowerBound || token_idx_in_seq < sink_token_len); - const int token_kv_idx = kvCacheBuffer.getKVTokenIdx(token_idx_in_seq); + int const token_kv_idx = kvCacheBuffer.getKVTokenIdx(token_idx_in_seq); if (!is_masked) { @@ -412,21 +412,21 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu position_embedding_type, medusa_position_offsets, beam_width); template -void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, - const int* kv_seq_lens, const int* padding_offset, const int batch_size, const int seq_len, - const int cyclic_kv_cache_len, const int sink_token_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, - const int* medusa_position_offsets, const bool position_shift_enabled, const float* scale, - const float* kvScaleOrigQuant, const int int8_mode, const bool enable_paged_kv_fmha, const int beam_width, +void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens, + int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len, + int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale, + int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, + int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, + float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha, int const beam_width, int2& grid_block_cache, cudaStream_t stream) { - const bool add_bias = qkv_bias != nullptr; - const bool store_contiguous_qkv = !enable_paged_kv_fmha; + bool const add_bias = qkv_bias != nullptr; + bool const store_contiguous_qkv = !enable_paged_kv_fmha; // Update scale if scale_type == RotaryScalingType::kLINEAR. - const float updated_rotary_embedding_scale + float const updated_rotary_embedding_scale = rotary_scale_type == RotaryScalingType::kLINEAR ? 
1.0f / rotary_embedding_scale : rotary_embedding_scale; if (add_bias) @@ -482,14 +482,14 @@ void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, const T* qkv_b } template -void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTable, const T* qkv_bias, - const int* seq_lens, const int* kv_seq_lens, const int* padding_offset, const int batch_size, const int seq_len, - const int cyclic_kv_cache_len, const int sink_token_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, - const int* medusa_position_offsets, const bool position_shift_enabled, const float* scale, - const float* kvScaleOrigQuant, const int int8_mode, const bool enable_paged_kv_fmha, const int beam_width, +void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, + int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len, + int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale, + int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, + int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, + float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha, int const beam_width, int2& grid_block_cache, cudaStream_t stream) { TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO @@ -541,15 +541,15 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTab } template -void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, - const int* kv_seq_lens, const int* padding_offset, const int batch_size, const int seq_len, - const int cyclic_kv_cache_len, const int sink_token_len, const int token_num, const int head_num, - const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, - const RotaryScalingType rotary_scale_type, const float rotary_embedding_scale, - const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, - const int* medusa_position_offsets, const bool position_shift_enabled, const float* scale, const int int8_mode, - const KvCacheDataType cache_type, const float* kvScaleOrigQuant, const bool enable_paged_kv_fmha, - const int beam_width, int2& grid_block_cache, cudaStream_t stream) +void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens, + int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len, + int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num, + int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base, + const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale, + int const rotary_embedding_max_positions, const PositionEmbeddingType 
position_embedding_type, + int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, int const int8_mode, + const KvCacheDataType cache_type, float const* kvScaleOrigQuant, bool const enable_paged_kv_fmha, + int const beam_width, int2& grid_block_cache, cudaStream_t stream) { // Block handles both K and V tile. constexpr int x = (sizeof(T) == 4) ? 4 : 8; diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h index ef5aad7e1..95b8e9a77 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/common.h @@ -66,25 +66,25 @@ struct WeightOnlyParams using ActType = void; using WeiType = uint8_t; - const uint8_t* qweight; - const ActType* scales; - const ActType* zeros; - const ActType* in; - const ActType* act_scale; - const ActType* bias; + uint8_t const* qweight; + ActType const* scales; + ActType const* zeros; + ActType const* in; + ActType const* act_scale; + ActType const* bias; ActType* out; - const int m; - const int n; - const int k; - const int group_size; + int const m; + int const n; + int const k; + int const group_size; WeightOnlyQuantType quant_type; WeightOnlyType weight_only_type; WeightOnlyActivationFunctionType act_func_type; WeightOnlyActivationType act_type; - WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in, - const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, - const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type, + WeightOnlyParams(uint8_t const* _qweight, ActType const* _scales, ActType const* _zeros, ActType const* _in, + ActType const* _act_scale, ActType const* _bias, ActType* _out, int const _m, int const _n, int const _k, + int const _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type, const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type) : qweight(_qweight) , scales(_scales) diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h index d0f9eb34f..1aabeb8ee 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/enabled.h @@ -82,7 +82,7 @@ bool isEnabledForArch(int arch) inline bool isWeightOnlyBatchedGemvEnabled(WeightOnlyQuantType qtype) { - const int arch = tensorrt_llm::common::getSMVersion(); + int const arch = tensorrt_llm::common::getSMVersion(); if (qtype == WeightOnlyQuantType::Int4b) { return isEnabledForArch(arch); diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h index 91a36afbc..72eab1a4e 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernel.h @@ -239,14 +239,14 @@ struct WeightOnlyScaleLoader static constexpr int kGroupSize = WeightOnlyProperties::kGroupSize; private: - const ElemType* _scales; - const ElemType* _zeros; + ElemType const* _scales; + ElemType const* _zeros; int _stride; int _offset; public: __device__ __forceinline__ WeightOnlyScaleLoader( - const ElemType* scales, const ElemType* zeros, int initial_offset, int stride) + ElemType const* scales, ElemType const* zeros, int initial_offset, int stride) : _scales(scales) , _zeros(zeros) , 
_stride(stride) @@ -293,8 +293,8 @@ struct WeightOnlyScaleLoader template class ActOp, bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> -__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros, - const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k) +__device__ void weight_only_batched_gemv(uint8_t const* qweight, ActType const* scales, ActType const* zeros, + ActType const* in, ActType const* act_scale, ActType const* bias, ActType* out, int const n, int const k) { static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0)); using ActType2 = typename ActTypeDetails::Vec2; @@ -309,11 +309,11 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* constexpr int Interleave = Details::kInterleave; constexpr int WarpSize = 32; constexpr int Num = Batch * NPerBlock; - const int tid = threadIdx.x; - const int bid = blockIdx.x; - const int n_start_id = bid * NPerBlock * Interleave; + int const tid = threadIdx.x; + int const bid = blockIdx.x; + int const n_start_id = bid * NPerBlock * Interleave; // Calculate the n-dimensional index of the data processed by the current thread in the interleave tile - const int interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave; + int const interleave_n_id = (tid / Details::kThreadsNumPerTile) % Interleave; qweight += n_start_id * k / Details::kElemsPerByte; ScaleLoader scale_loader(scales, zeros, n_start_id + interleave_n_id, n); @@ -469,8 +469,8 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* template class ActOp, bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize> -__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros, - const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k) +__global__ void weight_only_batched_gemv_wrapper(uint8_t const* qweight, ActType const* scales, ActType const* zeros, + ActType const* in, ActType const* act_scale, ActType const* bias, ActType* out, int const n, int const k) { if constexpr (std::is_same_v) { @@ -490,7 +490,7 @@ template struct WeightOnlyBatchedGemvKernelLauncher { - static void run(const WeightOnlyParams& params, cudaStream_t stream) + static void run(WeightOnlyParams const& params, cudaStream_t stream) { if (params.act_type == WeightOnlyActivationType::FP16) { @@ -502,18 +502,18 @@ struct WeightOnlyBatchedGemvKernelLauncher { weight_only_batched_gemv_wrapper<<>>(params.qweight, - reinterpret_cast(params.scales), reinterpret_cast(params.zeros), - reinterpret_cast(params.in), reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, + reinterpret_cast(params.scales), reinterpret_cast(params.zeros), + reinterpret_cast(params.in), reinterpret_cast(params.act_scale), + reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, params.k); } else { weight_only_batched_gemv_wrapper<<>>(params.qweight, - reinterpret_cast(params.scales), reinterpret_cast(params.zeros), - reinterpret_cast(params.in), reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, + reinterpret_cast(params.scales), reinterpret_cast(params.zeros), + reinterpret_cast(params.in), reinterpret_cast(params.act_scale), + reinterpret_cast(params.bias), reinterpret_cast(params.out), params.n, params.k); } } @@ 
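// Editor's note: the GEMV above walks a uint8_t weight buffer that packs
// kElemsPerByte quantized values per byte (two for int4) and rescales them
// with per-group `scales`/`zeros` every `group_size` elements, which is why
// qweight advances by n_start_id * k / kElemsPerByte. A host-side sketch of
// group-wise int4 dequantization for one output column; the nibble order, the
// offset of 8 that recentres the unsigned nibble, and the additive zero-point
// convention are assumptions, not taken from the kernel:

#include <cstdint>
#include <vector>

inline std::vector<float> dequantInt4Groupwise(std::vector<uint8_t> const& qweight, // k / 2 bytes
    std::vector<float> const& scales, std::vector<float> const& zeros,              // k / group_size entries
    int k, int group_size)
{
    std::vector<float> w(static_cast<size_t>(k));
    for (int i = 0; i < k; ++i)
    {
        uint8_t const byte = qweight[static_cast<size_t>(i) / 2];
        int const q = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4); // low nibble first (assumed)
        int const g = i / group_size;
        w[static_cast<size_t>(i)] = (static_cast<float>(q) - 8.0f) * scales[g] + zeros[g];
    }
    return w;
}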
-528,22 +528,22 @@ struct WeightOnlyBatchedGemvKernelLauncher { weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock, Batch, BlockSize><<>>(params.qweight, - reinterpret_cast(params.scales), - reinterpret_cast(params.zeros), - reinterpret_cast(params.in), - reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), + reinterpret_cast<__nv_bfloat16 const*>(params.scales), + reinterpret_cast<__nv_bfloat16 const*>(params.zeros), + reinterpret_cast<__nv_bfloat16 const*>(params.in), + reinterpret_cast<__nv_bfloat16 const*>(params.act_scale), + reinterpret_cast<__nv_bfloat16 const*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), params.n, params.k); } else { weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock, Batch, BlockSize><<>>(params.qweight, - reinterpret_cast(params.scales), - reinterpret_cast(params.zeros), - reinterpret_cast(params.in), - reinterpret_cast(params.act_scale), - reinterpret_cast(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), + reinterpret_cast<__nv_bfloat16 const*>(params.scales), + reinterpret_cast<__nv_bfloat16 const*>(params.zeros), + reinterpret_cast<__nv_bfloat16 const*>(params.in), + reinterpret_cast<__nv_bfloat16 const*>(params.act_scale), + reinterpret_cast<__nv_bfloat16 const*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out), params.n, params.k); } } diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu index 398a54ee8..75ef3f158 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.cu @@ -25,12 +25,12 @@ template struct WeightOnlyBatchedGemvKernelLauncher { - static void run(const WeightOnlyParams& params, cudaStream_t stream); + static void run(WeightOnlyParams const& params, cudaStream_t stream); }; template class ActOp, int N_PER_BLOCK, int BATCH, int BLOCK_SIZE> -void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream) +void select_zero_bias(WeightOnlyParams const& params, cudaStream_t stream) { if (params.zeros && params.bias) { @@ -55,7 +55,7 @@ void select_zero_bias(const WeightOnlyParams& params, cudaStream_t stream) } template -void select_activation(const WeightOnlyParams& params, cudaStream_t stream) +void select_activation(WeightOnlyParams const& params, cudaStream_t stream) { switch (params.act_func_type) { @@ -86,7 +86,7 @@ void select_activation(const WeightOnlyParams& params, cudaStream_t stream) } template -void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream) +void select_quant_type(WeightOnlyParams const& params, cudaStream_t stream) { if (params.quant_type == WeightOnlyQuantType::Int4b) { @@ -103,7 +103,7 @@ void select_quant_type(const WeightOnlyParams& params, cudaStream_t stream) } template -void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t stream) +void select_groupwise_weight_only(WeightOnlyParams const& params, cudaStream_t stream) { if (params.weight_only_type == WeightOnlyType::GroupWise && params.group_size == 64) { @@ -119,7 +119,7 @@ void select_groupwise_weight_only(const WeightOnlyParams& params, cudaStream_t s } } -void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream) +void weight_only_batched_gemv_launcher(WeightOnlyParams const& params, 
cudaStream_t stream) { assert(params.act_func_type == WeightOnlyActivationFunctionType::Identity); assert(params.weight_only_type == WeightOnlyType::GroupWise diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h index 69877631a..65498c612 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelLauncher.h @@ -20,6 +20,6 @@ namespace tensorrt_llm { namespace kernels { -void weight_only_batched_gemv_launcher(const WeightOnlyParams& params, cudaStream_t stream); +void weight_only_batched_gemv_launcher(WeightOnlyParams const& params, cudaStream_t stream); } } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h index b5cd6764e..560a7cfc1 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h @@ -56,8 +56,8 @@ struct ConverterI4ToF16 static constexpr uint32_t hfma_bias_rep = 0xD480E408; static constexpr uint32_t hfma_scale_rep = 0x2C003C00; - const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); - const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + half2 const& hfma_bias = reinterpret_cast(hfma_bias_rep); + half2 const& hfma_scale = reinterpret_cast(hfma_scale_rep); #pragma unroll for (int ii = 0; ii < 4; ++ii) { @@ -322,8 +322,8 @@ __global__ void kernel(typename Details::ActDataType* act, half* act_scale, uint static constexpr int CtaK = StepK * Threads; static_assert(CtaN % 2 == 0); - const int m_tile_id = blockIdx.x, n_tile_id = blockIdx.y, tid = threadIdx.x; - const int m_offset = m_tile_id * CtaM, n_offset = n_tile_id * CtaN; + int const m_tile_id = blockIdx.x, n_tile_id = blockIdx.y, tid = threadIdx.x; + int const m_offset = m_tile_id * CtaM, n_offset = n_tile_id * CtaN; act += m_offset * k; weight += n_offset * k / 2; diff --git a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h index e53814525..60d92fc5b 100644 --- a/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h +++ b/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/utility.h @@ -45,7 +45,7 @@ __inline__ __device__ float tanh_opt(float x) asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); return r; #else - const float exp_val = -1.f * fabs(2 * x); + float const exp_val = -1.f * fabs(2 * x); return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); #endif } @@ -53,9 +53,9 @@ __inline__ __device__ float tanh_opt(float x) template struct GeluActivation { - static __device__ __forceinline__ T apply(const T& val) + static __device__ __forceinline__ T apply(T const& val) { - const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val)))); + float const cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val)))); return val * cdf; } }; @@ -63,7 +63,7 @@ struct GeluActivation template struct ReluActivation { - static __device__ __forceinline__ T apply(const T& val) + static __device__ __forceinline__ T apply(T const& val) { return val > static_cast(0.0f) ? 
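// Editor's note: GeluActivation above uses the usual tanh approximation of
// GELU, where 0.7978845608... is sqrt(2/pi). A standalone host sketch, purely
// illustrative, comparing it with the exact erf form:

#include <cmath>

inline float geluTanh(float x)
{
    float const cdf = 0.5f * (1.0f + std::tanh(0.7978845608028654f * (x + 0.044715f * x * x * x)));
    return x * cdf;
}

inline float geluExact(float x)
{
    return 0.5f * x * (1.0f + std::erf(x * 0.7071067811865475f)); // erf(x / sqrt(2))
}
// The two forms agree closely over the useful range, which is why the cheaper
// tanh form (and the tanh.approx PTX path above it) is used in the kernels.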
val : static_cast(0.0f); } @@ -72,7 +72,7 @@ struct ReluActivation template struct IdentityActivation { - static __device__ __forceinline__ T apply(const T& val) + static __device__ __forceinline__ T apply(T const& val) { return val; } @@ -81,11 +81,11 @@ struct IdentityActivation template __device__ __forceinline__ void load(T0* dst, T1* src, size_t offset = 0) { - *reinterpret_cast(dst) = *(reinterpret_cast(src) + offset); + *reinterpret_cast(dst) = *(reinterpret_cast(src) + offset); } template -__device__ __forceinline__ void assign(T* dst, const AssignType& val) +__device__ __forceinline__ void assign(T* dst, AssignType const& val) { *reinterpret_cast(dst) = val; } @@ -93,7 +93,7 @@ __device__ __forceinline__ void assign(T* dst, const AssignType& val) template __device__ __forceinline__ void store(T0* src, T1* dst, size_t offset = 0) { - *(reinterpret_cast(dst) + offset) = *reinterpret_cast(src); + *(reinterpret_cast(dst) + offset) = *reinterpret_cast(src); } } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu index b18c3e902..2aeec3864 100644 --- a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu +++ b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.cu @@ -28,16 +28,16 @@ namespace tensorrt_llm namespace layers { -__global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids, - const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim, +__global__ void update_indir_cache_kernel(int* tgt_indir_cache, int const* src_indir_cache, int const** parent_ids, + FinishedState const* finished, int const* sequence_lengths, int const* input_lengths, int batch_dim, int local_batch_size, int beam_width, int max_attention_window, int sink_token_length, int max_seq_len) { int time_step = threadIdx.x + blockIdx.x * blockDim.x; int bb_id = threadIdx.y + blockIdx.y * blockDim.y; // should be just blockIdx.y? - const int current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1 - const int input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]}; - const int batch_id = bb_id / beam_width; - const int beam_id = bb_id % beam_width; + int const current_step{sequence_lengths[bb_id] - 1}; // the sequence_lengths is updated, need to minus 1 + int const input_length{input_lengths == nullptr ? 0 : input_lengths[bb_id]}; + int const batch_id = bb_id / beam_width; + int const beam_id = bb_id % beam_width; // Exit when the batch_beam or timestep is out of the bound. // Assume that KV Cache is shared and fixed for context part, // so we don't need to update the indices for context part. @@ -54,7 +54,7 @@ __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_i } // for the parent_ids, we will still keep it for all past tokens (i.e. max_seq_len) - const int src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step]; + int const src_beam = parent_ids[batch_id][beam_id * max_seq_len + current_step]; // for the indir tables, we have the cyclic kv cache. const uint32_t tgt_offset @@ -65,8 +65,8 @@ __global__ void update_indir_cache_kernel(int* tgt_indir_cache, const int* src_i tgt_indir_cache[tgt_offset] = (time_step == current_step) ? 
beam_id : src_indir_cache[src_offset]; } -void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int** parent_ids, - const FinishedState* finished, const int* sequence_lengths, const int* input_lengths, int batch_dim, +void update_indir_cache_kernelLauncher(int* tgt_indir_cache, int const* src_indir_cache, int const** parent_ids, + FinishedState const* finished, int const* sequence_lengths, int const* input_lengths, int batch_dim, int local_batch_size, int beam_width, int max_seq_len, int max_attention_window, int sink_token_length, cudaStream_t stream) { @@ -136,17 +136,17 @@ void BaseBeamSearchLayer::forward(BeamSearchOutputParams& outputs, ForwardPar TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__); Tensor& output_ids_ptr = outputs.output_ids_ptr; - const auto batch_size = static_cast(output_ids_ptr.shape[0]); - const auto beam_width = static_cast(output_ids_ptr.shape[1]); - const auto max_seq_len = static_cast(output_ids_ptr.shape[2]); + auto const batch_size = static_cast(output_ids_ptr.shape[0]); + auto const beam_width = static_cast(output_ids_ptr.shape[1]); + auto const max_seq_len = static_cast(output_ids_ptr.shape[2]); TLLM_CHECK_WITH_INFO(params.ite == 0, "Pipeline Parallelism is not supported yet !"); - const int ite = params.ite; - auto* const input_lengths = params.input_lengths ? params.input_lengths->template getPtr() : nullptr; + int const ite = params.ite; + auto* const input_lengths = params.input_lengths ? params.input_lengths->template getPtr() : nullptr; int* sequence_length = (outputs.sequence_length) ? outputs.sequence_length->template getPtr() : nullptr; Tensor const& logits = params.logits; - const auto local_batch_size = logits.shape[0]; + auto const local_batch_size = logits.shape[0]; invokeSoftMax(outputs, params); sync_check_cuda_error(); @@ -154,9 +154,9 @@ void BaseBeamSearchLayer::forward(BeamSearchOutputParams& outputs, ForwardPar if (beam_width > 1) { update_indir_cache_kernelLauncher(outputs.tgt_cache_indirection.template getPtr(), - params.src_cache_indirection.template getPtr(), - outputs.parent_ids_ptr.template getPtr(), - reinterpret_cast( + params.src_cache_indirection.template getPtr(), + outputs.parent_ids_ptr.template getPtr(), + reinterpret_cast( outputs.finished->template getPtr()), sequence_length, input_lengths, batch_size, local_batch_size, beam_width, max_seq_len, params.max_attention_window, params.sink_token_length, mStream); diff --git a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h index ba327f2a8..4b92a534d 100644 --- a/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h +++ b/cpp/tensorrt_llm/layers/baseBeamSearchLayer.h @@ -113,8 +113,8 @@ class BaseBeamSearchLayer : public BaseLayer void freeBuffer(); }; -void update_indir_cache_kernelLauncher(int* tgt_indir_cache, const int* src_indir_cache, const int* beam_ids, - const tensorrt_llm::kernels::FinishedState* finished, int batch_dim, int beam_width, int max_seq_len, int ite, +void update_indir_cache_kernelLauncher(int* tgt_indir_cache, int const* src_indir_cache, int const* beam_ids, + tensorrt_llm::kernels::FinishedState const* finished, int batch_dim, int beam_width, int max_seq_len, int ite, cudaStream_t stream); } // namespace layers diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp index 8a92038a1..9476a11c4 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.cpp @@ -351,8 +351,8 @@ void 
DynamicDecodeLayer::forward(OutputParams& outputs, ForwardParams const& std::vector batchSlotsVec(batchSize); std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0); - auto batchSlotsHost = params.batch_slots ? params.batch_slots->template getPtr() : batchSlotsVec.data(); - auto batchSlots = params.batch_slots ? params.batch_slots->template getPtr() : nullptr; + auto batchSlotsHost = params.batch_slots ? params.batch_slots->template getPtr() : batchSlotsVec.data(); + auto batchSlots = params.batch_slots ? params.batch_slots->template getPtr() : nullptr; mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen; prepareIdsPtrs(outputs, batchSlotsHost, batchSize, beamWidth, maxSeqLen); @@ -410,12 +410,12 @@ void DynamicDecodeLayer::layersForward(Tensor& logits, OutputParams& outputs, // Because we still not support batch beam search now, so we need to compute // one by one if there are different runtime arguments. const size_t dynamic_decode_batch_size = mHasDiffRuntimeArgs ? 1 : localBatchSize; - const int dynamic_decode_total_iteration = localBatchSize / dynamic_decode_batch_size; + int const dynamic_decode_total_iteration = localBatchSize / dynamic_decode_batch_size; for (uint32_t dynamic_ite = 0; dynamic_ite < dynamic_decode_total_iteration; ++dynamic_ite) { - const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beamWidth; - const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * mVocabSizePadded; + int const dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beamWidth; + int const dynamic_decode_vocab_size_units_offset = dynamic_id_offset * mVocabSizePadded; auto const logits_offset = logits.slice( {dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, dynamic_decode_vocab_size_units_offset); @@ -527,7 +527,7 @@ void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams if (params.input_lengths) { auto& input_lengths = params.input_lengths.value(); - inputLengths = input_lengths.template getPtr(); + inputLengths = input_lengths.template getPtr(); } auto* embeddingBias = params.embedding_bias ? 
params.embedding_bias->template getPtr() : nullptr; #define GET_PENALTIES(capital_name, penalty_name, type) \ @@ -545,14 +545,16 @@ void DynamicDecodeLayer::applyPenalties(OutputParams& outputs, ForwardParams #undef GET_PENALTIES + constexpr int32_t maxTokensPerStep = 1; + int32_t* tokensPerStep = nullptr; InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(logitsPtrsHostData), mRuntimeLogitsDevice, embeddingBias, mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice, temperatures, repetitionPenalties, presencePenalties, frequencyPenalties, (mUseRepetitionPenalty || mUsePresencePenalty || mUseFrequencyPenalty), batchSize, static_cast(beamWidth), static_cast(maxSeqLen), mVocabSize, mVocabSizePadded, - outputs.output_ids_ptr.template getPtr(), outputs.parent_ids_ptr.template getPtr(), - inputLengths, outputs.sequence_length->template getPtr(), minLengths, - params.end_ids.template getPtr(), batchSlots, mStream}; + outputs.output_ids_ptr.template getPtr(), outputs.parent_ids_ptr.template getPtr(), + inputLengths, outputs.sequence_length->template getPtr(), minLengths, + params.end_ids.template getPtr(), batchSlots, maxTokensPerStep, tokensPerStep, mStream}; invokeBatchApplyPenalty(penaltyParams); sync_check_cuda_error(); @@ -581,14 +583,14 @@ void DynamicDecodeLayer::banRepeatNGrams(Tensor& logits, OutputParams& output auto const max_step = params.step; if (params.no_repeat_ngram_size) { - const int* noRepeatNgramSizeBuf = params.no_repeat_ngram_size.value().template getPtr(); + int const* noRepeatNgramSizeBuf = params.no_repeat_ngram_size.value().template getPtr(); - invokeBanRepeatNgram(logits.template getPtr(), outputs.output_ids_ptr.template getPtr(), + invokeBanRepeatNgram(logits.template getPtr(), outputs.output_ids_ptr.template getPtr(), reinterpret_cast( params.finished.value_or(Tensor{}).template getPtr()), - outputs.parent_ids_ptr.template getPtr(), batchSlots, + outputs.parent_ids_ptr.template getPtr(), batchSlots, outputs.sequence_length->template getPtr(), batchSize, beamWidth, maxSeqLen, - params.no_repeat_ngram_size.value().template getPtr(), vocabSizePadded, max_step, stream); + params.no_repeat_ngram_size.value().template getPtr(), vocabSizePadded, max_step, stream); } TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } @@ -654,7 +656,7 @@ void DynamicDecodeLayer::checkMaxLengthStopCriteria(OutputParams& outputs, Fo invokeLengthCriterion( reinterpret_cast(outputs.finished->template getPtr()), outputs.finished_sum ? 
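// Editor's note: the InvokeBatchApplyPenaltyParams initializer above gains two
// new fields, maxTokensPerStep (fixed to 1 here) and a tokensPerStep pointer
// (nullptr), alongside the existing temperature / repetition / presence /
// frequency penalty tensors. The penalty kernel itself is not shown in this
// diff, so the host-side sketch below applies the commonly used definitions
// of those penalties to a single sequence, as an illustration only:

#include <unordered_map>
#include <vector>

inline void applyPenaltiesHost(std::vector<float>& logits, std::vector<int> const& outputIds, float temperature,
    float repetitionPenalty, float presencePenalty, float frequencyPenalty, int endId, int generatedLen,
    int minLength)
{
    std::unordered_map<int, int> counts; // how often each token id was already produced
    for (int id : outputIds)
    {
        ++counts[id];
    }

    for (size_t v = 0; v < logits.size(); ++v)
    {
        float x = logits[v] / temperature;
        auto const it = counts.find(static_cast<int>(v));
        if (it != counts.end())
        {
            x = x > 0.f ? x / repetitionPenalty : x * repetitionPenalty; // repetition: shrink seen tokens
            x -= presencePenalty;                                        // presence: flat penalty once seen
            x -= frequencyPenalty * static_cast<float>(it->second);      // frequency: scales with the count
        }
        logits[v] = x;
    }
    if (generatedLen < minLength)
    {
        logits[static_cast<size_t>(endId)] = -1e20f; // min-length: forbid EOS for now
    }
}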
outputs.finished_sum->template getPtr() : nullptr, - params.sequence_limit_length->template getPtr(), + params.sequence_limit_length->template getPtr(), outputs.sequence_length->template getPtr(), batchSlots, batchSize, beamWidth, stream); sync_check_cuda_error(); } diff --git a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h index 0efbf4b42..af84465bf 100644 --- a/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h +++ b/cpp/tensorrt_llm/layers/dynamicDecodeLayer.h @@ -117,18 +117,18 @@ class DynamicDecodeLayer : public BaseLayer std::optional> logits_vec; // [batch_size], on gpu // optional parameters - std::optional finished; // [batch_size * beam_width], optional + std::optional finished; // [batch_size * beam_width] std::optional src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] - the k/v cache // index for beam search, mandatory for beam search, on gpu std::optional sequence_limit_length; // [batch_size], on gpu std::optional embedding_bias; // [vocab_size_padded], on gpu std::optional input_lengths; // [batch_size, beam_width], on gpu - std::optional bad_words_ptr; // [2, bad_words_length] or [batch_size, 2, bad_words_length], on gpu - std::optional bad_words_lengths; // [batch_size], on gpu - std::optional stop_words_ptr; // [batch_size][2, stop_words_length], on gpu - std::optional stop_words_lengths; // [batch_size], on gpu - std::optional no_repeat_ngram_size; // [batch_size], optional - std::optional batch_slots; // [batch_size], optional, in pinned memory + std::optional bad_words_ptr; // [batch_size][2, bad_words_length], on gpu + std::optional bad_words_lengths; // [batch_size], on gpu + std::optional stop_words_ptr; // [batch_size][2, stop_words_length], on gpu + std::optional stop_words_lengths; // [batch_size], on gpu + std::optional no_repeat_ngram_size; // [batch_size] + std::optional batch_slots; // [batch_size] in pinned memory }; class OutputParams @@ -140,25 +140,24 @@ class DynamicDecodeLayer : public BaseLayer } // mandatory parameters - tc::Tensor output_ids; // [batch_size, beam_width. 
max_seq_len] + tc::Tensor output_ids; // [batch_size, beam_width, max_seq_len] tc::Tensor newTokens; // [batch_size, beam_width] // optional parameters - std::optional finished; // [batch_size * beam_width], optional - std::optional finished_sum; // [1], optional, in pinned host memory + std::optional finished; // [batch_size * beam_width] + std::optional finished_sum; // [1] in pinned host memory std::optional cum_log_probs; // [batch_size * beam_width], necessary in beam search std::optional parent_ids; // [max_seq_len, batch_size * beam_width], necessary in beam search - std::optional sequence_length; // [batch_size * beam_width], optional - std::optional - output_log_probs_tiled; // [request_output_length, batch_size, beam_width], must be float*, optional - std::optional - output_log_probs; // [batch_size, beam_width, request_output_length], must be float*, optional - std::optional - tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len], the k/v cache index for beam search - std::shared_ptr - beamHypotheses; // a special structure which maintains some pointers of beam search - - tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] - tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] + std::optional sequence_length; // [batch_size * beam_width] + std::optional output_log_probs_tiled; // [request_output_length, batch_size, beam_width] + // must be float* + std::optional output_log_probs; // [batch_size, beam_width, request_output_length] + // must be float* + std::optional tgt_cache_indirection; // [local_batch_size, beam_width, max_seq_len] + // the k/v cache index for beam search + std::shared_ptr beamHypotheses; // structure maintains some pointers of beam search + + tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] + tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len] }; void forward(OutputParams& outputs, ForwardParams const& params); diff --git a/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu b/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu index d85e5d971..0cabf6e15 100644 --- a/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu +++ b/cpp/tensorrt_llm/layers/onlineBeamSearchLayer.cu @@ -27,41 +27,42 @@ namespace tensorrt_llm namespace layers { -static const int SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128; -static const int MAX_K = 4; +static int const SMALL_TOP_K_SOFTMAX_MAX_VOC_PARTS = 128; +static int const MAX_K = 4; template -__global__ void update_kernel(FinishedState* finished, BeamHypotheses beam_hyps) +__global__ void update_kernel(BeamHypotheses beam_hyps) { - const int beam_width{beam_hyps.beam_width}; - const int ite{beam_hyps.ite}; - const int local_batch_size{beam_hyps.local_batch_size}; - const int max_seq_len{beam_hyps.max_seq_len}; - const int vocab_size{beam_hyps.vocab_size}; - const int end_id{beam_hyps.end_ids[blockIdx.x]}; + int const beam_width{beam_hyps.beam_width}; + int const ite{beam_hyps.ite}; + int const local_batch_size{beam_hyps.local_batch_size}; + int const max_seq_len{beam_hyps.max_seq_len}; + int const vocab_size{beam_hyps.vocab_size}; + int const end_id{beam_hyps.end_ids[blockIdx.x]}; int* num_beams{beam_hyps.num_beams}; int* sequence_lengths{beam_hyps.sequence_lengths_src}; int** output_ids_ptr{beam_hyps.output_ids_tgt_ptr}; int** parent_ids_ptr{beam_hyps.parent_ids_tgt_ptr}; + FinishedState* finished{beam_hyps.finished}; extern 
__shared__ char s_buf[]; // intermediate result int* s_sequence_lengths = reinterpret_cast(s_buf); for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) { - const auto batch_beam_idx = blockIdx.x * beam_width + beam_idx; + auto const batch_beam_idx = blockIdx.x * beam_width + beam_idx; s_sequence_lengths[beam_idx] = sequence_lengths[batch_beam_idx]; } __syncthreads(); for (int beam_idx = threadIdx.x; beam_idx < beam_width; beam_idx += blockDim.x) { - const auto batch_beam_idx = blockIdx.x * beam_width + beam_idx; - const int current_step{s_sequence_lengths[beam_idx]}; + auto const batch_beam_idx = blockIdx.x * beam_width + beam_idx; + int const current_step{s_sequence_lengths[beam_idx]}; // Increase the seq_len even if the request has finished. // On the following iteration we check if the sequence has finished before - const auto finish_state = finished[batch_beam_idx]; + auto const finish_state = finished[batch_beam_idx]; if (!finish_state.isFinished()) { s_sequence_lengths[beam_idx]++; @@ -88,11 +89,11 @@ __global__ void update_kernel(FinishedState* finished, BeamHypotheses beam_hyps) } } -void invokeUpdate(FinishedState* finished, BeamHypotheses& beam_hyps, cudaStream_t stream) +void invokeUpdate(BeamHypotheses& beam_hyps, cudaStream_t stream) { dim3 grid(beam_hyps.local_batch_size); dim3 block(min(beam_hyps.beam_width, 1024)); - update_kernel<<>>(finished, beam_hyps); + update_kernel<<>>(beam_hyps); } template @@ -117,21 +118,24 @@ template void OnlineBeamSearchLayer::invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params) { TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__); - auto* finished - = reinterpret_cast(outputs.finished->template getPtr()); BeamHypotheses beam_hyps; if (outputs.beamHypotheses) { beam_hyps = *outputs.beamHypotheses; - // Some of beam_hyps members have been initialized before function invokeSoftMax - beam_hyps.end_ids = params.end_ids.template getPtr(); - beam_hyps.log_probs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() : nullptr; - beam_hyps.output_ids_src_ptr = outputs.output_ids_ptr.template getPtr(); + beam_hyps.end_ids = params.end_ids.template getPtr(); + beam_hyps.finished + = reinterpret_cast(outputs.finished->template getPtr()); + beam_hyps.cum_log_probs_src = outputs.cum_log_probs->template getPtr(); + beam_hyps.log_probs_src + = (outputs.output_log_probs) ? 
outputs.output_log_probs->template getPtr() : nullptr; + beam_hyps.sequence_lengths_src = outputs.sequence_length->template getPtr(); beam_hyps.output_ids_tgt_ptr = outputs.output_ids_ptr.template getPtr(); - beam_hyps.parent_ids_src_ptr = outputs.parent_ids_ptr.template getPtr(); beam_hyps.parent_ids_tgt_ptr = outputs.parent_ids_ptr.template getPtr(); - beam_hyps.sequence_lengths_src = outputs.sequence_length->template getPtr(); + + beam_hyps.diversity_rates = diversity_rates_buf_; + beam_hyps.length_penalties = length_penalties_buf_; + beam_hyps.early_stoppings = early_stoppings_buf_; beam_hyps.batch_size = static_cast(outputs.output_ids_ptr.shape[0]); beam_hyps.beam_width = static_cast(outputs.output_ids_ptr.shape[1]); @@ -139,17 +143,15 @@ void OnlineBeamSearchLayer::invokeSoftMax(BeamSearchOutputParams& outputs, So beam_hyps.local_batch_size = params.logits.shape[0]; beam_hyps.max_seq_len = static_cast(outputs.output_ids_ptr.shape[2]); beam_hyps.vocab_size = vocab_size_padded_; - beam_hyps.diversity_rates = diversity_rates_buf_; - beam_hyps.length_penalties = length_penalties_buf_; - beam_hyps.early_stoppings = early_stoppings_buf_; } - invokeTopkSoftMax(params.logits.template getPtr(), (const T*) (nullptr), finished, - outputs.cum_log_probs->template getPtr(), topk_softmax_workspace_, topk_softmax_workspace_size_, - beam_hyps, mStream); + T const* logits = params.logits.template getPtr(); + T const* bias = static_cast(nullptr); + + invokeTopkSoftMax(logits, bias, topk_softmax_workspace_, topk_softmax_workspace_size_, beam_hyps, mStream); sync_check_cuda_error(); - invokeUpdate(finished, beam_hyps, mStream); + invokeUpdate(beam_hyps, mStream); sync_check_cuda_error(); } diff --git a/cpp/tensorrt_llm/layers/samplingLayer.cpp b/cpp/tensorrt_llm/layers/samplingLayer.cpp index d7d55f945..db68ba20c 100644 --- a/cpp/tensorrt_llm/layers/samplingLayer.cpp +++ b/cpp/tensorrt_llm/layers/samplingLayer.cpp @@ -159,8 +159,8 @@ void SamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& inp auto const batchSize = inputs.logits.shape[0]; auto logits = inputs.logits.template getPtr(); - auto endIds = inputs.end_ids.template getPtr(); - auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; + auto endIds = inputs.end_ids.template getPtr(); + auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr() : nullptr; float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() : nullptr; @@ -170,7 +170,7 @@ void SamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& inp std::vector batchSlotsVec(batchSize); std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0); - auto batchSlotsHost = inputs.batch_slots ? inputs.batch_slots->template getPtr() : batchSlotsVec.data(); + auto batchSlotsHost = inputs.batch_slots ? 
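// Editor's note: invokeTopkSoftMax above now receives only the logits, an
// optional bias and the workspace, with everything else (finished states,
// cum_log_probs_src, sequence lengths, diversity/length-penalty buffers)
// carried inside BeamHypotheses. Conceptually it fuses a log-softmax over the
// padded vocabulary with a per-beam top-K selection. A host-side sketch of
// that scoring step only; beam bookkeeping, diversity rates, length penalty
// and early stopping are deliberately omitted:

#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

// Returns the best `k` (score, tokenId) pairs for one beam, where
// score = cumLogProb + log_softmax(logits)[tokenId]. Assumes k <= vocab size.
inline std::vector<std::pair<float, int>> expandBeam(std::vector<float> const& logits, float cumLogProb, int k)
{
    float const maxLogit = *std::max_element(logits.begin(), logits.end());
    float sumExp = 0.f;
    for (float l : logits)
    {
        sumExp += std::exp(l - maxLogit);
    }
    float const logZ = maxLogit + std::log(sumExp);

    std::vector<std::pair<float, int>> cand(logits.size());
    for (size_t v = 0; v < logits.size(); ++v)
    {
        cand[v] = {cumLogProb + (logits[v] - logZ), static_cast<int>(v)};
    }
    std::partial_sort(cand.begin(), cand.begin() + k, cand.end(),
        [](std::pair<float, int> const& a, std::pair<float, int> const& b) { return a.first > b.first; });
    cand.resize(static_cast<size_t>(k));
    return cand;
}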
inputs.batch_slots->template getPtr() : batchSlotsVec.data(); bool skipTopK = !mDecodingMode.isTopK(); if (!skipTopK) diff --git a/cpp/tensorrt_llm/layers/topKSamplingLayer.cu b/cpp/tensorrt_llm/layers/topKSamplingLayer.cu index 029dff340..bcbab8366 100644 --- a/cpp/tensorrt_llm/layers/topKSamplingLayer.cu +++ b/cpp/tensorrt_llm/layers/topKSamplingLayer.cu @@ -37,7 +37,7 @@ namespace layers template __global__ void setupTopKRuntimeArgs(int batchSize, uint32_t topK, uint32_t* topKs, int topKsSize, float topP, - float* topPs, int topPsSize, bool* skipDecode, const int* batchSlots) + float* topPs, int topPsSize, bool* skipDecode, int const* batchSlots) { int index = blockIdx.x * blockDim.x + threadIdx.x; for (int bi = index; bi < batchSize; bi += gridDim.x * blockDim.x) @@ -73,9 +73,7 @@ template void TopKSamplingLayer::allocateBuffer(size_t const batchSize) { TLLM_LOG_TRACE(__PRETTY_FUNCTION__); - invokeTopKSampling(nullptr, mSamplingWorkspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, mMaxBatchSize, - nullptr, mNormalizeLogProbs, false); + mSamplingWorkspaceSize = getTopKWorkspaceSize(batchSize, 1, TOP_K_MAX, mVocabSizePadded); std::array deviceBufferSizes; deviceBufferSizes[0] = sizeof(uint32_t) * batchSize; @@ -190,8 +188,8 @@ void TopKSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& auto const batchSize = inputs.logits.shape[0]; auto logits = inputs.logits.template getPtr(); - auto endIds = inputs.end_ids.template getPtr(); - auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; + auto endIds = inputs.end_ids.template getPtr(); + auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; auto curandStatesDevice = inputs.curand_states; auto samplingWorkspaceDevice = inputs.sampling_workspace; auto const probsComputed = inputs.probs_computed; @@ -210,11 +208,12 @@ void TopKSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr() : nullptr; int* sequenceLength = (outputs.sequence_length) ? 
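// Editor's note: allocateBuffer above now asks getTopKWorkspaceSize(batchSize,
// 1, TOP_K_MAX, mVocabSizePadded) for mSamplingWorkspaceSize instead of doing
// a dry-run call to invokeTopKSampling with null pointers. For context, a
// host-side sketch of plain top-K sampling, i.e. the per-sequence operation
// the batched kernel implements; the batch-slot and skip-decode plumbing is
// omitted and this is not the kernel's algorithm:

#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

inline int sampleTopK(std::vector<float> const& logits, int k, std::mt19937& rng)
{
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
        [&logits](int a, int b) { return logits[a] > logits[b]; });

    std::vector<float> weights(static_cast<size_t>(k));
    float const maxLogit = logits[static_cast<size_t>(idx[0])];
    for (int i = 0; i < k; ++i)
    {
        weights[static_cast<size_t>(i)] = std::exp(logits[static_cast<size_t>(idx[i])] - maxLogit);
    }
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    return idx[static_cast<size_t>(pick(rng))]; // token id drawn from the renormalized top-K
}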
outputs.sequence_length->template getPtr() : nullptr; - invokeBatchTopKSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, logits, + invokeBatchTopKSampling(samplingWorkspaceDevice, logits, static_cast(nullptr), outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, - outputLogProbs, curandStatesDevice, (int) mRuntimeMaxTopK, (int*) (mRuntimeTopKDevice), 1.0f, - mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mMaxBatchSize, mSkipDecodeDevice, - mNormalizeLogProbs, probsComputed); + outputLogProbs, curandStatesDevice, static_cast(mRuntimeMaxTopK), + reinterpret_cast(mRuntimeTopKDevice), 1.0f, mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, + mStream, batchSize, mMaxBatchSize, nullptr, 1, mSkipDecodeDevice, mNormalizeLogProbs, probsComputed, + /* return all Top-K*/ false); sync_check_cuda_error(); } diff --git a/cpp/tensorrt_llm/layers/topKSamplingLayer.h b/cpp/tensorrt_llm/layers/topKSamplingLayer.h index 4e27803cc..fdcdf02cc 100644 --- a/cpp/tensorrt_llm/layers/topKSamplingLayer.h +++ b/cpp/tensorrt_llm/layers/topKSamplingLayer.h @@ -45,7 +45,7 @@ class TopKSamplingLayer : public BaseSamplingLayer void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override; void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override; - const bool* getSkipDecodeHost() const + bool const* getSkipDecodeHost() const { return mSkipDecodeHost; } diff --git a/cpp/tensorrt_llm/layers/topPSamplingLayer.cu b/cpp/tensorrt_llm/layers/topPSamplingLayer.cu index 135de4b4f..958b4bd08 100644 --- a/cpp/tensorrt_llm/layers/topPSamplingLayer.cu +++ b/cpp/tensorrt_llm/layers/topPSamplingLayer.cu @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/common/reduceKernelUtils.cuh" #include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/kernels/samplingAirTopPKernels.h" #include "tensorrt_llm/kernels/samplingTopKKernels.h" #include "tensorrt_llm/kernels/samplingTopPKernels.h" #include "tensorrt_llm/layers/topPSamplingLayer.h" @@ -35,7 +36,7 @@ namespace layers { static __global__ void setTopPRuntimeArgs(int batchSize, uint32_t topK, uint32_t* topKs, int topKsSize, float topP, - float* topPs, int topPsSize, bool* skipDecode, const int* batchSlots, float* initialTopPBuf) + float* topPs, int topPsSize, bool* skipDecode, int const* batchSlots, float* initialTopPBuf) { /** * @brief Setup the runtime arguments for topp, broadcasting top_p to top_ps @@ -69,30 +70,11 @@ void TopPSamplingLayer::allocateBuffer(size_t batchSize) TLLM_LOG_TRACE(__PRETTY_FUNCTION__); if (mIsDeterministic) { - invokeTopPSampling(nullptr, // workspace - mSamplingWorkspaceSize, mCubTempStorageSize, - nullptr, // output_ids - nullptr, // sequence_length - nullptr, // finished_input_buffer - nullptr, // finished_output_buffer - nullptr, // cum_log_probs - nullptr, // output_log_probs - nullptr, // log_probs - mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mMaxBatchSize, - mVocabSizePadded, nullptr, 0.f, mStream, nullptr, nullptr); + mSamplingWorkspaceSize = getTopPWorkspaceSize(batchSize, mVocabSizePadded); } else { - invokeAirTopPSampling(nullptr, mSamplingWorkspaceSize, - nullptr, // output_ids - nullptr, // sequence_length - nullptr, // finished_input_buffer - nullptr, // finished_output_buffer - nullptr, // cum_log_probs - nullptr, // output_log_probs - nullptr, // log_probs) - nullptr, batchSize, mMaxBatchSize, mVocabSizePadded, 
nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr, - nullptr); + mSamplingWorkspaceSize = getAirTopPWorkspaceSize(batchSize, mVocabSizePadded); } std::array deviceBufferSizes; @@ -285,8 +267,8 @@ void TopPSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& // Probabilities must be already computed instead of logits auto probs = inputs.logits.template getPtr(); - auto endIds = inputs.end_ids.template getPtr(); - auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; + auto endIds = inputs.end_ids.template getPtr(); + auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr() : nullptr; auto curandStatesDevice = inputs.curand_states; auto samplingWorkspaceDevice = inputs.sampling_workspace; @@ -313,23 +295,22 @@ void TopPSamplingLayer::forward(DecodingOutputParams& outputs, ForwardParams& if (mIsDeterministic) { - invokeBatchTopPSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, mCubTempStorageSize, - outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, - outputLogProbs, probs, mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice, - batchSize, mMaxBatchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, - mSkipDecodeDevice, batchSlots); + invokeBatchTopPSampling(samplingWorkspaceDevice, outputs.output_ids_ptr.template getPtr(), + sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, probs, mTopPIdValsDevice, + mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice, batchSize, mMaxBatchSize, mVocabSizePadded, + endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mSkipDecodeDevice, batchSlots); sync_check_cuda_error(); invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice, - outputs.output_ids_ptr.template getPtr(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice, + outputs.output_ids_ptr.template getPtr(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice, sequenceLength, batchSlots, batchSize, mStream); sync_check_cuda_error(); } else { - invokeBatchAirTopPSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, - outputs.output_ids_ptr.template getPtr(), sequenceLength, finishedInput, finishedOutput, cumLogProbs, - outputLogProbs, probs, curandStatesDevice, batchSize, mMaxBatchSize, mVocabSizePadded, endIds, - mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots); + invokeBatchAirTopPSampling(samplingWorkspaceDevice, outputs.output_ids_ptr.template getPtr(), + sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, probs, curandStatesDevice, + batchSize, mMaxBatchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, + mAirTopPBlockNum, mSkipDecodeDevice, batchSlots); sync_check_cuda_error(); } } diff --git a/cpp/tensorrt_llm/layers/topPSamplingLayer.h b/cpp/tensorrt_llm/layers/topPSamplingLayer.h index e0bda8eab..dd485d218 100644 --- a/cpp/tensorrt_llm/layers/topPSamplingLayer.h +++ b/cpp/tensorrt_llm/layers/topPSamplingLayer.h @@ -45,7 +45,7 @@ class TopPSamplingLayer : public BaseSamplingLayer void setup(std::size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override; void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override; - const bool* getSkipDecodeHost() const + bool const* getSkipDecodeHost() const { return mSkipDecodeHost; } @@ -65,7 +65,6 @@ class TopPSamplingLayer : public BaseSamplingLayer int32_t* mBeginTopPOffsetDevice = 
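// Editor's note: both the deterministic and the AIR Top-P branches above now
// obtain their workspace size from getTopPWorkspaceSize / getAirTopPWorkspaceSize
// instead of a null-pointer dry run, and the forward() calls drop the explicit
// workspace-size arguments. For context, a host-side sketch of top-P (nucleus)
// sampling itself, the per-sequence operation these batched kernels implement;
// the top-P decay/min/reset-id handling is omitted:

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

inline int sampleTopP(std::vector<float> const& probs /* already softmax-normalized */, float topP, std::mt19937& rng)
{
    std::vector<int> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::sort(idx.begin(), idx.end(), [&probs](int a, int b) { return probs[a] > probs[b]; });

    // Keep the smallest prefix of the sorted vocabulary whose mass reaches topP.
    float cum = 0.f;
    size_t cut = 0;
    while (cut < idx.size() && cum < topP)
    {
        cum += probs[static_cast<size_t>(idx[cut++])];
    }

    std::uniform_real_distribution<float> uni(0.f, cum);
    float r = uni(rng);
    for (size_t i = 0; i < cut; ++i)
    {
        r -= probs[static_cast<size_t>(idx[i])];
        if (r <= 0.f)
        {
            return idx[i];
        }
    }
    return idx[cut - 1];
}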
nullptr; bool* mSkipDecodeDevice = nullptr; bool* mSkipDecodeHost = nullptr; - size_t mCubTempStorageSize; bool mIsDeterministic = true; int mAirTopPBlockNum; diff --git a/cpp/tensorrt_llm/plugins/CMakeLists.txt b/cpp/tensorrt_llm/plugins/CMakeLists.txt index 8991a9dc6..dce57145e 100755 --- a/cpp/tensorrt_llm/plugins/CMakeLists.txt +++ b/cpp/tensorrt_llm/plugins/CMakeLists.txt @@ -114,7 +114,6 @@ target_link_libraries( ${PLUGIN_SHARED_TARGET} ${CUBLAS_LIB} ${CUBLASLT_LIB} - ${CUDNN_LIB} nvinfer ${CUDA_DRV_LIB} ${CUDA_RT_LIB} diff --git a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp index 831b48433..69ac51a15 100644 --- a/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp @@ -93,7 +93,7 @@ bool pluginsInitialized = false; extern "C" { - bool initTrtLlmPlugins(void* logger, const char* libNamespace) + bool initTrtLlmPlugins(void* logger, char const* libNamespace) { if (pluginsInitialized) return true; diff --git a/cpp/tensorrt_llm/plugins/api/tllmPlugin.h b/cpp/tensorrt_llm/plugins/api/tllmPlugin.h index 5b536a467..6b75f4f24 100644 --- a/cpp/tensorrt_llm/plugins/api/tllmPlugin.h +++ b/cpp/tensorrt_llm/plugins/api/tllmPlugin.h @@ -47,7 +47,7 @@ class LoggerFinder : public nvinfer1::ILoggerFinder extern "C" { // This function is used for explicitly registering the TRT-LLM plugins and the default logger. - bool initTrtLlmPlugins(void* logger, const char* libNamespace = tensorrt_llm::plugins::api::kDefaultNamespace); + bool initTrtLlmPlugins(void* logger, char const* libNamespace = tensorrt_llm::plugins::api::kDefaultNamespace); // The functions below are used by TensorRT to when loading a shared plugin library with automatic registering. // see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#generating-plugin-library diff --git a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp index a476798a5..546fb3483 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.cpp @@ -27,8 +27,8 @@ namespace tc = tensorrt_llm::common; using tensorrt_llm::plugins::BertAttentionPluginCreator; using tensorrt_llm::plugins::BertAttentionPlugin; -static const char* BERT_ATTENTION_PLUGIN_VERSION{"1"}; -static const char* BERT_ATTENTION_PLUGIN_NAME{"BertAttention"}; +static char const* BERT_ATTENTION_PLUGIN_VERSION{"1"}; +static char const* BERT_ATTENTION_PLUGIN_NAME{"BertAttention"}; PluginFieldCollection BertAttentionPluginCreator::mFC{}; std::vector BertAttentionPluginCreator::mPluginAttributes; @@ -71,9 +71,9 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s } // Parameterized constructor -BertAttentionPlugin::BertAttentionPlugin(const void* data, size_t length) +BertAttentionPlugin::BertAttentionPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mNumHeads); read(d, mHeadSize); read(d, mQScaling); @@ -101,7 +101,7 @@ nvinfer1::IPluginV2DynamicExt* BertAttentionPlugin::clone() const noexcept } nvinfer1::DimsExprs BertAttentionPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { 
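// Editor's note: the parameterized BertAttentionPlugin constructor above
// consumes its serialized byte stream with a small read(d, field) helper and
// keeps a copy of the start pointer (a), presumably for a later bounds check;
// the helper is defined in the plugin common headers and is not part of this
// diff. A typical implementation (assumed, shown only for illustration)
// advances a cursor over trivially-copyable fields:

#include <cstring>

template <typename T>
void readPod(char const*& buffer, T& val)
{
    std::memcpy(&val, buffer, sizeof(T)); // copy one POD field out of the stream
    buffer += sizeof(T);                  // and advance the cursor
}

template <typename T>
void writePod(char*& buffer, T const& val)
{
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}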
TLLM_CHECK(outputIndex == 0); auto ret = inputs[0]; @@ -110,7 +110,7 @@ nvinfer1::DimsExprs BertAttentionPlugin::getOutputDimensions( } bool BertAttentionPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { // inputs: [0] qkv, [1] input_lengths, [2] max_input_length (optional), [3] relative_attention_bias (optional) // outputs: [X] hidden_states @@ -142,20 +142,20 @@ bool BertAttentionPlugin::supportsFormatCombination( } } -void BertAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void BertAttentionPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t BertAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t BertAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { // if remove padding, inputs[0] "qkv_hidden_states" dim is [num_tokens, 3*hidden_dim] which doesn't have shape // info should get max_batch_size and max_input_length from inputs[1] "input_lengths" and input[2] // "max_input_length" - const int batch_size = mRemovePadding ? inputs[1].dims.d[0] : inputs[0].dims.d[0]; - const int input_seq_len = mRemovePadding ? inputs[2].dims.d[0] : inputs[0].dims.d[1]; - const int local_hidden_units_ = inputs[0].dims.d[mRemovePadding ? 1 : 2] / 3; + int const batch_size = mRemovePadding ? inputs[1].dims.d[0] : inputs[0].dims.d[0]; + int const input_seq_len = mRemovePadding ? inputs[2].dims.d[0] : inputs[0].dims.d[1]; + int const local_hidden_units_ = inputs[0].dims.d[mRemovePadding ? 1 : 2] / 3; auto const size = tensorrt_llm::runtime::BufferDataType(inputs[0].type).getSize(); @@ -170,7 +170,7 @@ size_t BertAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* i = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_len * input_seq_len; const size_t padding_offset_size = sizeof(int) * batch_size * input_seq_len; - const int NUM_BUFFERS = 10; + int const NUM_BUFFERS = 10; size_t workspaces[NUM_BUFFERS]; workspaces[0] = CUBLAS_WORKSPACE_SIZE; workspaces[1] = attention_mask_size; @@ -187,8 +187,8 @@ size_t BertAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* i } template -int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { @@ -203,17 +203,17 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc // if remove padding, inputs[0] dim is [num_tokens] which doesn't have workspace info // should get max_batch_size from inputs[1] and max_input_length from plugin attribute - const int batch_size = mRemovePadding ? inputDesc[1].dims.d[0] : inputDesc[0].dims.d[0]; - const int input_seq_len = mRemovePadding ? 
inputDesc[2].dims.d[0] : inputDesc[0].dims.d[1]; - const int num_tokens = mRemovePadding ? inputDesc[0].dims.d[0] : batch_size * input_seq_len; - const int request_batch_size = batch_size; - const int request_seq_len = input_seq_len; - const int local_hidden_units_ = inputDesc[0].dims.d[mRemovePadding ? 1 : 2] / 3; - const float q_scaling = mQScaling; - - const T* attention_input = reinterpret_cast(inputs[0]); - const int* input_lengths = reinterpret_cast(inputs[1]); - const T* relative_attn_table = mRelativeAttention ? reinterpret_cast(inputs[3]) : nullptr; + int const batch_size = mRemovePadding ? inputDesc[1].dims.d[0] : inputDesc[0].dims.d[0]; + int const input_seq_len = mRemovePadding ? inputDesc[2].dims.d[0] : inputDesc[0].dims.d[1]; + int const num_tokens = mRemovePadding ? inputDesc[0].dims.d[0] : batch_size * input_seq_len; + int const request_batch_size = batch_size; + int const request_seq_len = input_seq_len; + int const local_hidden_units_ = inputDesc[0].dims.d[mRemovePadding ? 1 : 2] / 3; + float const q_scaling = mQScaling; + + T const* attention_input = reinterpret_cast(inputs[0]); + int const* input_lengths = reinterpret_cast(inputs[1]); + T const* relative_attn_table = mRelativeAttention ? reinterpret_cast(inputs[3]) : nullptr; T* context_buf_ = (T*) (outputs[0]); auto cublasHandle = mCublasWrapper->getCublasHandle(); @@ -276,16 +276,16 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc invokeBuildDecoderInfo(params, stream); sync_check_cuda_error(); - const auto gemm_data_type = tc::CudaDataType::value; - const int attention_seq_len_1 = request_seq_len; // q length - const int attention_seq_len_2 = request_seq_len; // kv length + auto const gemm_data_type = tc::CudaDataType::value; + int const attention_seq_len_1 = request_seq_len; // q length + int const attention_seq_len_2 = request_seq_len; // kv length // If the model has relative attentiona bias, q scaling should be applied in QK gemm stage and use 1 in // softamax stage (because to get softmax[scale(Q*K) + rel pos bias] here, q_scaling can't be applied during // softmax phase by qk_scale); otherwise, use 1 in gemm stage and apply scaling in softmax stage - const float qk_scale + float const qk_scale = 1.0f / (sqrtf(mHeadSize * 1.0f) * q_scaling); // q_scaling in denominator. by default q_scaling =1.0f - const float qk_scale_gemm = mRelativeAttention ? qk_scale : 1.0f; + float const qk_scale_gemm = mRelativeAttention ? qk_scale : 1.0f; const T qk_scale_softmax = static_cast(mRelativeAttention ? 
1.0f : qk_scale); T* linear_bias_slopes = nullptr; @@ -402,22 +402,22 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc return 0; } -template int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +template int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); -template int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +template int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); #ifdef ENABLE_BF16 -template int BertAttentionPlugin::enqueueImpl<__nv_bfloat16>(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +template int BertAttentionPlugin::enqueueImpl<__nv_bfloat16>(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); #endif -int BertAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int BertAttentionPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (mType == DataType::kHALF) @@ -439,7 +439,7 @@ int BertAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType BertAttentionPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0); return inputTypes[0]; @@ -447,12 +447,12 @@ nvinfer1::DataType BertAttentionPlugin::getOutputDataType( // IPluginV2 Methods -const char* BertAttentionPlugin::getPluginType() const noexcept +char const* BertAttentionPlugin::getPluginType() const noexcept { return BERT_ATTENTION_PLUGIN_NAME; } -const char* BertAttentionPlugin::getPluginVersion() const noexcept +char const* BertAttentionPlugin::getPluginVersion() const noexcept { return BERT_ATTENTION_PLUGIN_VERSION; } @@ -540,24 +540,24 @@ BertAttentionPluginCreator::BertAttentionPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* BertAttentionPluginCreator::getPluginName() const noexcept +char const* BertAttentionPluginCreator::getPluginName() const noexcept { return BERT_ATTENTION_PLUGIN_NAME; } -const char* BertAttentionPluginCreator::getPluginVersion() const noexcept +char const* BertAttentionPluginCreator::getPluginVersion() const noexcept { return BERT_ATTENTION_PLUGIN_VERSION; } -const PluginFieldCollection* BertAttentionPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* BertAttentionPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* BertAttentionPluginCreator::createPlugin(const char* 
name, const PluginFieldCollection* fc) noexcept +IPluginV2* BertAttentionPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; int num_heads, head_size; ContextFMHAType context_fmha_type; bool qk_half_accum; @@ -569,51 +569,51 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(const char* name, const Plug // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "num_heads")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - num_heads = static_cast(*(static_cast(fields[i].data))); + num_heads = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "head_size")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - head_size = static_cast(*(static_cast(fields[i].data))); + head_size = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "q_scaling")) { TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32); - q_scaling = static_cast(*(static_cast(fields[i].data))); + q_scaling = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "enable_qk_half_accum")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - qk_half_accum = static_cast(*(static_cast(fields[i].data))); + qk_half_accum = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "context_fmha_type")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - context_fmha_type = static_cast(*(static_cast(fields[i].data))); + context_fmha_type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "do_relative_attention")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - do_relative_attention = static_cast(*(static_cast(fields[i].data))); + do_relative_attention = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "max_distance")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - max_distance = static_cast(*(static_cast(fields[i].data))); + max_distance = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "remove_padding")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - remove_padding = static_cast(*(static_cast(fields[i].data))); + remove_padding = static_cast(*(static_cast(fields[i].data))); } } try @@ -623,7 +623,7 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(const char* name, const Plug obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -631,7 +631,7 @@ IPluginV2* BertAttentionPluginCreator::createPlugin(const char* name, const Plug } IPluginV2* BertAttentionPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call BertAttentionPlugin::destroy() @@ -641,7 +641,7 @@ IPluginV2* BertAttentionPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git 
a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h index 00929c748..eb7d42c98 100644 --- a/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/bertAttentionPlugin/bertAttentionPlugin.h @@ -38,34 +38,34 @@ class BertAttentionPlugin : public BasePlugin tensorrt_llm::kernels::ContextFMHAType context_fmha_type, nvinfer1::DataType type, bool do_relative_attention = false, int max_distance = 0, bool remove_padding = false); - BertAttentionPlugin(const void* data, size_t length); + BertAttentionPlugin(void const* data, size_t length); ~BertAttentionPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; template - int enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -102,16 +102,16 @@ class BertAttentionPluginCreator : public BaseCreator public: BertAttentionPluginCreator(); - const char* getPluginName() 
const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.cpp b/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.cpp index 029c638fe..2aab6b367 100644 --- a/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.cpp @@ -22,12 +22,12 @@ namespace tensorrt_llm::plugins { -void caughtError(const std::exception& e) +void caughtError(std::exception const& e) { TLLM_LOG_EXCEPTION(e); } -void logError(const char* msg, const char* file, const char* fn, int line) +void logError(char const* msg, char const* file, char const* fn, int line) { TLLM_LOG_ERROR("Parameter check failed at: %s::%s::%d, condition: %s", file, fn, line, msg); } diff --git a/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.h b/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.h index 2280306f7..d8d8af1ef 100644 --- a/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.h +++ b/cpp/tensorrt_llm/plugins/common/checkMacrosPlugin.h @@ -22,8 +22,8 @@ namespace tensorrt_llm::plugins { -void logError(const char* msg, const char* file, const char* fn, int line); +void logError(char const* msg, char const* file, char const* fn, int line); -void caughtError(const std::exception& e); +void caughtError(std::exception const& e); } // namespace tensorrt_llm::plugins diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp index 2961b2649..bcb714021 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.cpp @@ -30,7 +30,7 @@ GemmPluginProfiler::GemmPluginPro mMNKProfileMap = std::make_shared(); // set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings - const auto skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS"); + auto const skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS"); mSkip = (skipEnv != NULL && std::stoi(skipEnv)); if (mSkip) { @@ -42,13 +42,13 @@ GemmPluginProfiler::GemmPluginPro template void GemmPluginProfiler::serialize( - char*& buffer, const GemmIdType& gemmId) const + char*& buffer, GemmIdType const& gemmId) const { auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); // Save number of profiles for given GEMM ID write(buffer, static_cast(mProfileMap->size())); - for (const auto& pair : *mProfileMap) + for (auto const& pair : *mProfileMap) { // Save pair of M to the best GEMM config write(buffer, pair); @@ -57,7 +57,7 @@ void GemmPluginProfiler::serializ template void GemmPluginProfiler::deserialize( - const char*& data, GemmDims& dims, const GemmIdType& gemmId) + char const*& data, GemmDims& dims, GemmIdType const& gemmId) { // NOTE: this mutex is not needed since each thread owns its 
private map, but will put here for // consistency @@ -85,7 +85,7 @@ void GemmPluginProfiler::deserial template size_t GemmPluginProfiler::getSerializationSize( - const GemmIdType& gemmId) const + GemmIdType const& gemmId) const { reader_lock lock(mMNKProfileMap->mutex); return sizeof(int) + // size of the tactics map @@ -95,7 +95,7 @@ size_t GemmPluginProfiler::getSer template void GemmPluginProfiler::profileTactics( - const RunnerPtr& runner, const nvinfer1::DataType& type, const GemmDims& dims, const GemmIdType& gemmId) + RunnerPtr const& runner, nvinfer1::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId) { writer_lock lock(mMNKProfileMap->mutex); @@ -107,7 +107,7 @@ void GemmPluginProfiler::profileT mRunner = runner; mType = type; - const int maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M); + int const maxM = std::min(nextPowerOfTwo(dims.maxM), MAX_PROFILE_M); computeTmpSize(maxM, dims.n, dims.k); if (!mMNKProfileMap->existsMProfileMap(gemmId)) @@ -137,7 +137,7 @@ void GemmPluginProfiler::profileT // Allocate tmp data to run GEMMs allocateTmpData(); - const int startMinMRounded = nextPowerOfTwo(dims.minM); + int const startMinMRounded = nextPowerOfTwo(dims.minM); for (int m = startMinMRounded; m < maxM; m *= 2) { profileTactics(m, dims.n, dims.k); @@ -150,7 +150,7 @@ void GemmPluginProfiler::profileT template std::optional GemmPluginProfiler::getBestConfig( - int m, const GemmIdType& gemmId) const + int m, GemmIdType const& gemmId) const { reader_lock lock(mMNKProfileMap->mutex); @@ -159,7 +159,7 @@ std::optional GemmPluginProfilergetMProfileMap(gemmId)->at(mRounded); } @@ -168,20 +168,20 @@ template ::allocateTmpData() { TLLM_CHECK_WITH_INFO(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0"); - const auto status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes); + auto const status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes); TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling."); } template void GemmPluginProfiler::freeTmpData() { - const auto status = cudaFree(mWorkspaceTmp); + auto const status = cudaFree(mWorkspaceTmp); TLLM_CHECK_WITH_INFO(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling."); } template std::optional GemmPluginProfiler::profileTacticsForProblem( - int m, int n, int k, const std::vector& tactics) + int m, int n, int k, std::vector const& tactics) { TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -192,7 +192,7 @@ std::optional GemmPluginProfiler::max(); try { @@ -204,7 +204,7 @@ std::optional GemmPluginProfiler GemmPluginProfiler float GemmPluginProfiler::profileTacticForProblem( - int m, int n, int k, const Config& tactic) + int m, int n, int k, Config const& tactic) { constexpr int warmup = 5; constexpr int runs = 10; diff --git a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h index 9389439e2..b0fbdde8a 100644 --- a/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h +++ b/cpp/tensorrt_llm/plugins/common/gemmPluginProfiler.h @@ -72,7 +72,7 @@ class GemmIdCore int k; nvinfer1::DataType dtype; - GemmIdCore(int n_, int k_, const nvinfer1::DataType& dtype_) + GemmIdCore(int n_, int k_, nvinfer1::DataType const& dtype_) : n(n_) , k(k_) , dtype(dtype_) @@ -86,12 +86,12 @@ class GemmIdCore { } - bool operator==(const GemmIdCore& id) const + bool operator==(GemmIdCore const& id) const { return isEqual(id); } - friend std::ostream& operator<<(std::ostream& out, const 
GemmIdCore& id) + friend std::ostream& operator<<(std::ostream& out, GemmIdCore const& id) { out << "(N;K)=(" << id.n << ";" << id.k << "),"; out << " type=" << static_cast(id.dtype); @@ -99,7 +99,7 @@ class GemmIdCore } protected: - bool isEqual(const GemmIdCore& id) const + bool isEqual(GemmIdCore const& id) const { return n == id.n && k == id.k && dtype == id.dtype; } @@ -108,7 +108,7 @@ class GemmIdCore // Hash of GemmId struct GemmIdCoreHash { - std::size_t operator()(const GemmIdCore& id) const + std::size_t operator()(GemmIdCore const& id) const { auto h1 = std::hash{}(id.n); auto h2 = std::hash{}(id.k); @@ -123,7 +123,7 @@ class GemmIdCublas : public GemmIdCore bool transA{}; bool transB{}; - GemmIdCublas(int n_, int k_, const nvinfer1::DataType& dtype_, bool transA_, bool transB_) + GemmIdCublas(int n_, int k_, nvinfer1::DataType const& dtype_, bool transA_, bool transB_) : GemmIdCore(n_, k_, dtype_) , transA(transA_) , transB(transB_) @@ -132,12 +132,12 @@ class GemmIdCublas : public GemmIdCore GemmIdCublas() {} - bool operator==(const GemmIdCublas& id) const + bool operator==(GemmIdCublas const& id) const { return isEqual(id) && transA == id.transA && transB == id.transB; } - friend std::ostream& operator<<(std::ostream& out, const GemmIdCublas& id) + friend std::ostream& operator<<(std::ostream& out, GemmIdCublas const& id) { out << "(N;K)=(" << id.n << ";" << id.k << "),"; out << " type=" << static_cast(id.dtype); @@ -150,7 +150,7 @@ class GemmIdCublas : public GemmIdCore // Hash of GemmIdCublas struct GemmIdCublasHash { - std::size_t operator()(const GemmIdCublas& id) const + std::size_t operator()(GemmIdCublas const& id) const { auto h1 = std::hash{}(id.n); auto h2 = std::hash{}(id.k); @@ -184,20 +184,20 @@ class GemmPluginProfiler // Map from GEMM Id to profile for particular GEMM std::unordered_map profileMap; - bool existsMProfileMap(const GemmIdType& id) + bool existsMProfileMap(GemmIdType const& id) { - const auto iter = profileMap.find(id); + auto const iter = profileMap.find(id); return iter != profileMap.end(); } - void createMProfileMap(const GemmIdType& id) + void createMProfileMap(GemmIdType const& id) { profileMap[id] = std::make_shared(); } - MProfileMapPtr getMProfileMap(const GemmIdType& id) + MProfileMapPtr getMProfileMap(GemmIdType const& id) { - const auto iter = profileMap.find(id); + auto const iter = profileMap.find(id); if (iter == profileMap.end()) { std::ostringstream msg; @@ -212,15 +212,15 @@ class GemmPluginProfiler GemmPluginProfiler(); - void serialize(char*& buffer, const GemmIdType& gemmId) const; + void serialize(char*& buffer, GemmIdType const& gemmId) const; - void deserialize(const char*& data, GemmDims& dims, const GemmIdType& gemmId); - size_t getSerializationSize(const GemmIdType& gemmId) const; + void deserialize(char const*& data, GemmDims& dims, GemmIdType const& gemmId); + size_t getSerializationSize(GemmIdType const& gemmId) const; void profileTactics( - const RunnerPtr& runner, const nvinfer1::DataType& type, const GemmDims& dims, const GemmIdType& gemmId); + RunnerPtr const& runner, nvinfer1::DataType const& type, GemmDims const& dims, GemmIdType const& gemmId); - void setSelectionTactics(const MNKProfileMapPtr& map) + void setSelectionTactics(MNKProfileMapPtr const& map) { mMNKProfileMap = map; } @@ -235,14 +235,14 @@ class GemmPluginProfiler mSkip = mSkip || skip; } - std::optional getBestConfig(int m, const GemmIdType& gemmId) const; + std::optional getBestConfig(int m, GemmIdType const& gemmId) const; protected: - virtual void 
runTactic(int m, int n, int k, const Config& tactic, char* workspace, const cudaStream_t& stream) = 0; + virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; virtual void computeTmpSize(int maxM, int n, int k) = 0; - virtual bool checkTactic(int m, int n, int k, const Config& tactic) const + virtual bool checkTactic(int m, int n, int k, Config const& tactic) const { return true; } @@ -256,9 +256,9 @@ class GemmPluginProfiler void freeTmpData(); - std::optional profileTacticsForProblem(int m, int n, int k, const std::vector& tactics); + std::optional profileTacticsForProblem(int m, int n, int k, std::vector const& tactics); - float profileTacticForProblem(int m, int n, int k, const Config& tactic); + float profileTacticForProblem(int m, int n, int k, Config const& tactic); int nextPowerOfTwo(int v) const { diff --git a/cpp/tensorrt_llm/plugins/common/plugin.h b/cpp/tensorrt_llm/plugins/common/plugin.h index ca9c5a42d..33a24d5f1 100644 --- a/cpp/tensorrt_llm/plugins/common/plugin.h +++ b/cpp/tensorrt_llm/plugins/common/plugin.h @@ -46,7 +46,7 @@ namespace tensorrt_llm::plugins class BasePlugin : public nvinfer1::IPluginV2DynamicExt { public: - void setPluginNamespace(const char* libNamespace) noexcept override + void setPluginNamespace(char const* libNamespace) noexcept override { mNamespace = libNamespace; } @@ -63,7 +63,7 @@ class BasePlugin : public nvinfer1::IPluginV2DynamicExt class BaseCreator : public nvinfer1::IPluginCreator { public: - void setPluginNamespace(const char* libNamespace) noexcept override + void setPluginNamespace(char const* libNamespace) noexcept override { mNamespace = libNamespace; } @@ -79,7 +79,7 @@ class BaseCreator : public nvinfer1::IPluginCreator // Write values into buffer template -void write(char*& buffer, const T& val) +void write(char*& buffer, T const& val) { std::memcpy(buffer, &val, sizeof(T)); buffer += sizeof(T); @@ -87,7 +87,7 @@ void write(char*& buffer, const T& val) // Read values from buffer template -void read(const char*& buffer, T& val) +void read(char const*& buffer, T& val) { std::memcpy(&val, buffer, sizeof(T)); buffer += sizeof(T); diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp index bf145e34d..4a94bd0a7 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.cpp @@ -27,8 +27,8 @@ using tensorrt_llm::plugins::CublasGemmWrapperPtr; using tensorrt_llm::plugins::read; using tensorrt_llm::plugins::write; -static const char* GEMM_PLUGIN_VERSION{"1"}; -static const char* GEMM_PLUGIN_NAME{"Gemm"}; +static char const* GEMM_PLUGIN_VERSION{"1"}; +static char const* GEMM_PLUGIN_NAME{"Gemm"}; PluginFieldCollection GemmPluginCreator::mFC{}; std::vector GemmPluginCreator::mPluginAttributes; @@ -45,9 +45,9 @@ void getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& ldc = N; } -void runGemm(const int M, const int N, const int K, const bool transA, const bool transB, const nvinfer1::DataType type, - const CublasGemmWrapperPtr& cublasWrapperPtr, const void* act, const void* weight, void* output, - const std::optional& heuristic, void* workspace, cudaStream_t stream) +void runGemm(int const M, int const N, int const K, bool const transA, bool const transB, const nvinfer1::DataType type, + CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act, void const* weight, void* output, + std::optional const& heuristic, void* workspace, 
cudaStream_t stream) { cublasWrapperPtr->setStream(stream); cublasWrapperPtr->setWorkspace(workspace); @@ -63,7 +63,7 @@ void runGemm(const int M, const int N, const int K, const bool transA, const boo } void CublasLtGemmPluginProfiler::runTactic( - int m, int n, int k, const CublasLtGemmPluginProfiler::Config& tactic, char* workspace, const cudaStream_t& stream) + int m, int n, int k, CublasLtGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) { size_t dataSize = sizeof(half); if (mType == DataType::kFLOAT) @@ -81,7 +81,7 @@ void CublasLtGemmPluginProfiler::runTactic( runGemm(m, n, k, mTransA, mTransB, mType, mRunner, actPtr, weightPtr, outputPtr, {tactic}, workspacePtr, stream); } -bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, const Config& tactic) const +bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, Config const& tactic) const { cublasOperation_t transa, transb; int M = m, N = n, K = k; @@ -90,7 +90,7 @@ bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, const Config& mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc); - const auto checkResult = mRunner->checkTactic(transa, transb, m, n, k, lda, ldb, ldc, tactic.algo); + auto const checkResult = mRunner->checkTactic(transa, transb, m, n, k, lda, ldb, ldc, tactic.algo); mRunner->destroyDescriptors(); @@ -120,14 +120,14 @@ std::vector CublasLtGemmPluginProfiler::getT getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K); mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc); - const auto heruistics = mRunner->getTactics(transa, transb, m, n, k, lda, ldb, ldc); + auto const heruistics = mRunner->getTactics(transa, transb, m, n, k, lda, ldb, ldc); mRunner->destroyDescriptors(); return heruistics; } GemmPlugin::GemmPlugin( - int transA, int transB, nvinfer1::DataType type, bool useFp8, const GemmPlugin::PluginProfilerPtr& pluginProfiler) + int transA, int transB, nvinfer1::DataType type, bool useFp8, GemmPlugin::PluginProfilerPtr const& pluginProfiler) : mTransA(transA) , mTransB(transB) , mType(type) @@ -139,10 +139,10 @@ GemmPlugin::GemmPlugin( } // Parameterized constructor -GemmPlugin::GemmPlugin(const void* data, size_t length, const GemmPlugin::PluginProfilerPtr& pluginProfiler) +GemmPlugin::GemmPlugin(void const* data, size_t length, GemmPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mTransA); read(d, mTransB); read(d, mType); @@ -218,14 +218,14 @@ nvinfer1::IPluginV2DynamicExt* GemmPlugin::clone() const noexcept } nvinfer1::DimsExprs GemmPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { try { TLLM_CHECK(nbInputs == 2); TLLM_CHECK(outputIndex == 0); - const int nbDimsA = inputs[0].nbDims; - const int nbDimsB = inputs[1].nbDims; + int const nbDimsA = inputs[0].nbDims; + int const nbDimsB = inputs[1].nbDims; DimsExprs ret; ret.nbDims = nbDimsA + nbDimsB - 2; @@ -259,7 +259,7 @@ nvinfer1::DimsExprs GemmPlugin::getOutputDimensions( } return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -267,9 +267,9 @@ nvinfer1::DimsExprs GemmPlugin::getOutputDimensions( } bool 
GemmPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { - const auto& desc = inOut[pos]; + auto const& desc = inOut[pos]; if (desc.format != TensorFormat::kLINEAR) { return false; @@ -283,7 +283,7 @@ bool GemmPlugin::supportsFormatCombination( return desc.type == mType || desc.type == nvinfer1::DataType::kFLOAT; } -int32_t computeMDimension(bool transA, const int32_t nbDims, const int32_t* dims) +int32_t computeMDimension(bool transA, const int32_t nbDims, int32_t const* dims) { int32_t M = 1; if (transA) @@ -303,7 +303,7 @@ int32_t computeMDimension(bool transA, const int32_t nbDims, const int32_t* dims return M; } -int32_t computeNDimension(bool transB, const int32_t nbDims, const int32_t* dims) +int32_t computeNDimension(bool transB, const int32_t nbDims, int32_t const* dims) { int32_t N = 1; if (transB) @@ -323,16 +323,16 @@ int32_t computeNDimension(bool transB, const int32_t nbDims, const int32_t* dims return N; } -void GemmPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void GemmPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { - const int nbDimsA = in[0].max.nbDims; - const int nbDimsB = in[1].max.nbDims; + int const nbDimsA = in[0].max.nbDims; + int const nbDimsB = in[1].max.nbDims; - const auto minM = computeMDimension(mTransA, nbDimsA, in[0].min.d); - const auto maxM = computeMDimension(mTransA, nbDimsA, in[0].max.d); - const auto N = computeNDimension(mTransB, nbDimsB, in[1].max.d); - const auto K = mTransA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1]; + auto const minM = computeMDimension(mTransA, nbDimsA, in[0].min.d); + auto const maxM = computeMDimension(mTransA, nbDimsA, in[0].max.d); + auto const N = computeNDimension(mTransB, nbDimsB, in[1].max.d); + auto const K = mTransA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1]; if (!mDims.isInitialized()) { @@ -344,14 +344,14 @@ void GemmPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, in mOutputType = out[0].desc.type; } -size_t GemmPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t GemmPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return CUBLAS_WORKSPACE_SIZE; } -int GemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs // mat1 [M, K] (mTransA = False) @@ -361,11 +361,11 @@ int GemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf setGemmConfig(); - const int nbDimsA = inputDesc[0].dims.nbDims; - const int nbDimsB = inputDesc[1].dims.nbDims; - const auto M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d); - const auto N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d); - const int K = mTransA ? 
inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; + int const nbDimsA = inputDesc[0].dims.nbDims; + int const nbDimsB = inputDesc[1].dims.nbDims; + auto const M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d); + auto const N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d); + int const K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId); runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], bestTactic, workspace, @@ -375,7 +375,7 @@ int GemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf // IPluginV2Ext Methods nvinfer1::DataType GemmPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0); return inputTypes[0]; @@ -383,12 +383,12 @@ nvinfer1::DataType GemmPlugin::getOutputDataType( // IPluginV2 Methods -const char* GemmPlugin::getPluginType() const noexcept +char const* GemmPlugin::getPluginType() const noexcept { return GEMM_PLUGIN_NAME; } -const char* GemmPlugin::getPluginVersion() const noexcept +char const* GemmPlugin::getPluginVersion() const noexcept { return GEMM_PLUGIN_VERSION; } @@ -445,50 +445,50 @@ GemmPluginCreator::GemmPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* GemmPluginCreator::getPluginName() const noexcept +char const* GemmPluginCreator::getPluginName() const noexcept { return GEMM_PLUGIN_NAME; } -const char* GemmPluginCreator::getPluginVersion() const noexcept +char const* GemmPluginCreator::getPluginVersion() const noexcept { return GEMM_PLUGIN_VERSION; } -const PluginFieldCollection* GemmPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* GemmPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* GemmPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; int transA, transB; nvinfer1::DataType type; int useFp8; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "transa")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - transA = static_cast(*(static_cast(fields[i].data))); + transA = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "transb")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - transB = static_cast(*(static_cast(fields[i].data))); + transB = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "use_fp8")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - useFp8 = static_cast(*(static_cast(fields[i].data))); + useFp8 = static_cast(*(static_cast(fields[i].data))); } } try @@ -501,14 +501,14 @@ IPluginV2* GemmPluginCreator::createPlugin(const char* name, const PluginFieldCo obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } return nullptr; } -IPluginV2* 
GemmPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept +IPluginV2* GemmPluginCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call GemmPlugin::destroy() @@ -522,7 +522,7 @@ IPluginV2* GemmPluginCreator::deserializePlugin(const char* name, const void* se obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h index c6a24d341..75b5b8bd3 100644 --- a/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h +++ b/cpp/tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h @@ -48,11 +48,11 @@ class CublasLtGemmPluginProfiler } protected: - void runTactic(int m, int n, int k, const Config& tactic, char* workspace, const cudaStream_t& stream) override; + void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; void computeTmpSize(int maxM, int n, int k) override; - bool checkTactic(int m, int n, int k, const Config& tactic) const override; + bool checkTactic(int m, int n, int k, Config const& tactic) const override; std::vector getTactics(int m, int n, int k) const override; @@ -71,32 +71,32 @@ class GemmPlugin : public BasePlugin GemmPlugin() = delete; - GemmPlugin(int transA, int transB, nvinfer1::DataType type, bool useFp8, const PluginProfilerPtr& profiler); + GemmPlugin(int transA, int transB, nvinfer1::DataType type, bool useFp8, PluginProfilerPtr const& profiler); - GemmPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler); + GemmPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler); ~GemmPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, 
cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -133,16 +133,16 @@ class GemmPluginCreator : public BaseCreator public: GemmPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: GemmPluginProfilerManager gemmPluginProfileManager; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp index 6e3249fec..dc9812450 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.cpp @@ -66,13 +66,13 @@ struct SATypeConverter template struct FusedQKVMaskedAttentionDispatchParams { - const T* qkv_buf; - const T* qkv_bias; - const T* relative_attention_bias; - const int* cache_indir; + T const* qkv_buf; + T const* qkv_bias; + T const* relative_attention_bias; + int const* cache_indir; T* context_buf; - const bool* finished; - const int* sequence_lengths; + bool const* finished; + int const* sequence_lengths; int max_batch_size; int inference_batch_size; int beam_width; @@ -89,16 +89,16 @@ struct FusedQKVMaskedAttentionDispatchParams int max_attention_window; int cyclic_attention_window_size; int sink_token_length; - const int* input_lengths; + int const* input_lengths; int step; float q_scaling; int relative_attention_bias_stride; - const T* linear_bias_slopes; - const int* ia3_tasks; - const T* ia3_key_weights; - const T* ia3_value_weights; - const float* qkv_scale_out; - const float* attention_out_scale; + T const* linear_bias_slopes; + int const* ia3_tasks; + T const* ia3_key_weights; + T const* ia3_value_weights; + float const* qkv_scale_out; + float const* attention_out_scale; bool mUnfuseQkvGemm; tc::QuantMode quant_option; bool multi_block_mode; @@ -108,14 +108,14 @@ struct FusedQKVMaskedAttentionDispatchParams float* partial_sum; float* partial_max; int* block_counter; - const float* kv_scale_orig_quant; - const float* kv_scale_quant_orig; + float const* kv_scale_orig_quant; + float const* kv_scale_quant_orig; tc::QuantMode kv_cache_quant_mode; int multi_processor_count; KVCacheBuffer kv_block_array; KVLinearBuffer shift_k_cache_buffer; bool 
cross_attention = false; - const int* memory_length_per_sample = nullptr; + int const* memory_length_per_sample = nullptr; int max_distance = 0; }; @@ -158,7 +158,7 @@ struct ConvertMMHAToXQAParamsHelper<__nv_bfloat16, KVBlockArray> template bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams, - const EnqueueGenerationParams& generationsParams, bool forConfigurePlugin) + EnqueueGenerationParams const& generationsParams, bool forConfigurePlugin) { bool retval = ConvertMMHAToXQAParamsHelper::supported; if (!retval) @@ -242,7 +242,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel template void fusedQKV_masked_attention_dispatch(Multihead_attention_params& params, - const FusedQKVMaskedAttentionDispatchParams& input_params, cudaStream_t stream) + FusedQKVMaskedAttentionDispatchParams const& input_params, cudaStream_t stream) { using DataType = typename SATypeConverter::Type; @@ -253,9 +253,9 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params(input_params.qkv_bias); - params.k_bias = reinterpret_cast(input_params.qkv_bias) + hidden_units; - params.v_bias = reinterpret_cast(input_params.qkv_bias) + hidden_units + hidden_units_kv; + params.q_bias = reinterpret_cast(input_params.qkv_bias); + params.k_bias = reinterpret_cast(input_params.qkv_bias) + hidden_units; + params.v_bias = reinterpret_cast(input_params.qkv_bias) + hidden_units + hidden_units_kv; } else { @@ -268,9 +268,9 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params(input_params.context_buf); // Set the input buffers. - params.q = reinterpret_cast(input_params.qkv_buf); - params.k = reinterpret_cast(input_params.qkv_buf) + hidden_units; - params.v = reinterpret_cast(input_params.qkv_buf) + hidden_units + hidden_units_kv; + params.q = reinterpret_cast(input_params.qkv_buf); + params.k = reinterpret_cast(input_params.qkv_buf) + hidden_units; + params.v = reinterpret_cast(input_params.qkv_buf) + hidden_units + hidden_units_kv; params.int8_kv_cache = input_params.kv_cache_quant_mode.hasInt8KvCache(); params.fp8_kv_cache = input_params.kv_cache_quant_mode.hasFp8KvCache(); @@ -305,20 +305,20 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params(input_params.relative_attention_bias); + params.relative_attention_bias = reinterpret_cast(input_params.relative_attention_bias); params.relative_attention_bias_stride = input_params.relative_attention_bias_stride; params.max_distance = input_params.max_distance; // The slope of linear position bias per head, e.g., ALiBi. 
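// --- Illustrative sketch, not part of the patch ---
// The dispatch code above derives the Q, K and V pointers from a single fused QKV
// buffer by offsetting with hidden_units = num_heads * head_size and
// hidden_units_kv = num_kv_heads * head_size. The standalone helper below restates
// only that pointer arithmetic; the names (FusedQkvView, splitFusedQkv) and the
// assumed [Q | K | V] layout are illustrative, not the plugin's actual API.
#include <cstddef>

template <typename T>
struct FusedQkvView
{
    T const* q;
    T const* k;
    T const* v;
};

template <typename T>
FusedQkvView<T> splitFusedQkv(T const* qkv, int numHeads, int numKvHeads, int headSize)
{
    // Assumed layout: [Q: num_heads * head_size | K: num_kv_heads * head_size | V: num_kv_heads * head_size]
    std::size_t const hiddenUnits = static_cast<std::size_t>(numHeads) * headSize;
    std::size_t const hiddenUnitsKv = static_cast<std::size_t>(numKvHeads) * headSize;
    return FusedQkvView<T>{qkv, qkv + hiddenUnits, qkv + hiddenUnits + hiddenUnitsKv};
}
// --- End of sketch ---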
if (input_params.linear_bias_slopes != nullptr) { - params.linear_bias_slopes = reinterpret_cast(input_params.linear_bias_slopes); + params.linear_bias_slopes = reinterpret_cast(input_params.linear_bias_slopes); } params.input_lengths = input_params.input_lengths; params.ia3_tasks = input_params.ia3_tasks; - params.ia3_key_weights = reinterpret_cast(input_params.ia3_key_weights); - params.ia3_value_weights = reinterpret_cast(input_params.ia3_value_weights); + params.ia3_key_weights = reinterpret_cast(input_params.ia3_key_weights); + params.ia3_value_weights = reinterpret_cast(input_params.ia3_value_weights); if (input_params.quant_option.hasStaticActivationScaling()) { @@ -464,7 +464,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, } } -const int GPTAttentionPluginCommon::getHeadSize(bool checkInit) const +int const GPTAttentionPluginCommon::getHeadSize(bool checkInit) const { if (checkInit) { @@ -474,9 +474,9 @@ const int GPTAttentionPluginCommon::getHeadSize(bool checkInit) const } // Parameterized constructor -GPTAttentionPluginCommon::GPTAttentionPluginCommon(const void* data, size_t length) +GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; unsigned int kvCacheQuantMode; read(d, mLayerIdx); @@ -529,15 +529,15 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t int32_t input_seq_length, int32_t max_attention_window, int32_t cross_qkv_length, int32_t max_num_tokens) const noexcept { - const int local_hidden_units_qo = mNumHeads * getHeadSize(); - const int local_hidden_units_kv = mNumKVHeads * getHeadSize(); - const bool chunked_context_support = mEnableContextFMHA && mPagedKVCache && mPagedContextFMHA; + int const local_hidden_units_qo = mNumHeads * getHeadSize(); + int const local_hidden_units_kv = mNumKVHeads * getHeadSize(); + bool const chunked_context_support = mEnableContextFMHA && mPagedKVCache && mPagedContextFMHA; auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize(); size_t context_workspace_size = 0; - const int batch_size = nbReq; + int const batch_size = nbReq; const size_t attention_mask_size = mEnableContextFMHA ? 0 : size * batch_size * input_seq_length * (isCrossAttention() ? cross_qkv_length : input_seq_length); @@ -564,7 +564,7 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t ? 
batch_size * 2 * TMA_DESC_SIZE_IN_BYTE * tc::divUp(max_attention_window, mTokensPerBlock) : 0; - const int NUM_BUFFERS = 12; + int const NUM_BUFFERS = 12; size_t workspaces[NUM_BUFFERS]; workspaces[0] = CUBLAS_WORKSPACE_SIZE; workspaces[1] = attention_mask_size; @@ -585,15 +585,15 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration( nvinfer1::DataType type, int32_t total_num_seq, int32_t max_attention_window) const noexcept { - const int local_hidden_units_qo = mNumHeads * getHeadSize(); - const int local_hidden_units_kv = mNumKVHeads * getHeadSize(); + int const local_hidden_units_qo = mNumHeads * getHeadSize(); + int const local_hidden_units_kv = mNumKVHeads * getHeadSize(); auto const size = tensorrt_llm::runtime::BufferDataType(type).getSize(); size_t context_workspace_size = 0; size_t generation_workspace_size = 0; - const int batch_beam = total_num_seq; + int const batch_beam = total_num_seq; int32_t const maxSeqLenTile = std::max(getMaxNumSeqLenTile(batch_beam), (int) tc::divUp(mMultiProcessorCount, mNumHeads)); @@ -605,7 +605,7 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForGeneration( ? 0 : size * batch_beam * mNumHeads * mHeadSize * max_attention_window; - const int NUM_BUFFERS = 5; + int const NUM_BUFFERS = 5; size_t workspaces[NUM_BUFFERS]; workspaces[0] = partial_out_size; workspaces[1] = partial_sum_size; @@ -638,20 +638,20 @@ int GPTAttentionPluginCommon::getMaxNumSeqLenTile(int batch_beam_size) const } template -int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams& params, cudaStream_t stream) +int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams const& params, cudaStream_t stream) { - const int num_heads = mNumHeads; - const int num_kv_heads = mNumKVHeads; - const int head_size = getHeadSize(); - const int local_hidden_units_qo = num_heads * head_size; - const int local_hidden_units_kv = num_kv_heads * head_size; + int const num_heads = mNumHeads; + int const num_kv_heads = mNumKVHeads; + int const head_size = getHeadSize(); + int const local_hidden_units_qo = num_heads * head_size; + int const local_hidden_units_kv = num_kv_heads * head_size; const PositionEmbeddingType position_embedding_type = mPositionEmbeddingType; - const float q_scaling = mQScaling; - const bool* finished = nullptr; - const bool has_ia3 = false; + float const q_scaling = mQScaling; + bool const* finished = nullptr; + bool const has_ia3 = false; KVCacheBuffer kv_cache_buffer; - const auto elem_size = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T); + auto const elem_size = mKVCacheQuantMode.hasKvCacheQuant() ? 
sizeof(int8_t) : sizeof(T); int64_t* host_kv_cache_block_ptrs = nullptr; if (mPagedKVCache) { @@ -670,16 +670,16 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams(params.key_value_cache); } - const auto quant_option = tc::QuantMode::fromDescription(); - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; + auto const quant_option = tc::QuantMode::fromDescription(); + float const* qkv_scale_out = nullptr; + float const* attention_out_scale = nullptr; - const int* ia3_tasks = nullptr; - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; + int const* ia3_tasks = nullptr; + T const* ia3_key_weights = nullptr; + T const* ia3_value_weights = nullptr; - const bool multi_block_mode = false; - const int max_seq_len_tile = 0; + bool const multi_block_mode = false; + int const max_seq_len_tile = 0; T* partial_out = nullptr; float* partial_sum = nullptr; float* partial_max = nullptr; @@ -731,7 +731,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams(params.workspace); @@ -805,15 +805,15 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams::value; - const int attention_seq_len_1 = params.input_seq_length; // q length - const int attention_seq_len_2 = isCrossAttention() ? params.cross_qkv_length : params.input_seq_length; // kv length + int const attention_seq_len_1 = params.input_seq_length; // q length + int const attention_seq_len_2 = isCrossAttention() ? params.cross_qkv_length : params.input_seq_length; // kv length // If the model has relative attentiona bias, q scaling should be applied in QK gemm stage and use 1 in // softamax stage (because to get softmax[scale(Q*K) + rel pos bias] here, q_scaling can't be applied during // softmax phase by qk_scale); otherwise, use 1 in gemm stage and apply scaling in softmax stage - const float qk_scale + float const qk_scale = 1.0f / (sqrtf(getHeadSize() * 1.0f) * q_scaling); // q_scaling in denominator. by default q_scaling =1.0f - const float qk_scale_gemm = isRelativePosition() ? qk_scale : 1.0f; + float const qk_scale_gemm = isRelativePosition() ? qk_scale : 1.0f; const T qk_scale_softmax = static_cast(isRelativePosition() ? 
1.0f : qk_scale); // in context phase, currently FMHA runner has two restrictions: @@ -822,7 +822,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParamssetup(params.batch_size, params.input_seq_length, attention_window_size, params.num_tokens, isALiBi(), isAliBiWithScale(), mTpSize, mTpRank); @@ -922,10 +922,10 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams(qk_buf_float_) : static_cast(qk_buf_); if (mNumKVHeads == 1) // MQA @@ -974,12 +974,12 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams(qk_buf_float_ + qk_offset) : static_cast(qk_buf_ + qk_offset); mCublasWrapper->stridedBatchedGemm(CUBLAS_OP_T, CUBLAS_OP_N, @@ -1092,7 +1092,7 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams( - const EnqueueContextParams& params, cudaStream_t stream); + EnqueueContextParams const& params, cudaStream_t stream); template int GPTAttentionPluginCommon::enqueueContext( - const EnqueueContextParams& params, cudaStream_t stream); + EnqueueContextParams const& params, cudaStream_t stream); #ifdef ENABLE_BF16 template int GPTAttentionPluginCommon::enqueueContext<__nv_bfloat16, KVLinearBuffer>( - const EnqueueContextParams<__nv_bfloat16, KVLinearBuffer>& params, cudaStream_t stream); + EnqueueContextParams<__nv_bfloat16, KVLinearBuffer> const& params, cudaStream_t stream); #endif template int GPTAttentionPluginCommon::enqueueContext( - const EnqueueContextParams& params, cudaStream_t stream); + EnqueueContextParams const& params, cudaStream_t stream); template int GPTAttentionPluginCommon::enqueueContext( - const EnqueueContextParams& params, cudaStream_t stream); + EnqueueContextParams const& params, cudaStream_t stream); #ifdef ENABLE_BF16 template int GPTAttentionPluginCommon::enqueueContext<__nv_bfloat16, KVBlockArray>( - const EnqueueContextParams<__nv_bfloat16, KVBlockArray>& params, cudaStream_t stream); + EnqueueContextParams<__nv_bfloat16, KVBlockArray> const& params, cudaStream_t stream); #endif bool GPTAttentionPluginCommon::mForceMultiBlockWarned = false; template int GPTAttentionPluginCommon::enqueueGeneration( - const EnqueueGenerationParams& params, cudaStream_t stream) + EnqueueGenerationParams const& params, cudaStream_t stream) { - const int step = params.past_kv_length + 1; + int const step = params.past_kv_length + 1; - const int num_heads = mNumHeads; - const int num_kv_heads = mNumKVHeads; - const int head_size = getHeadSize(); - const int local_hidden_units_qo = num_heads * head_size; - const int local_hidden_units_kv = num_kv_heads * head_size; + int const num_heads = mNumHeads; + int const num_kv_heads = mNumKVHeads; + int const head_size = getHeadSize(); + int const local_hidden_units_qo = num_heads * head_size; + int const local_hidden_units_kv = num_kv_heads * head_size; const PositionEmbeddingType position_embedding_type = mPositionEmbeddingType; - const float q_scaling = mQScaling; - const T* relative_attention_bias = isRelativePosition() ? params.relative_attention_bias : nullptr; - const int relative_attention_bias_stride = isRelativePosition() ? params.relative_attention_bias_stride : 0; - const int max_distance = mMaxDistance; - const bool* finished = nullptr; - const bool has_ia3 = false; + float const q_scaling = mQScaling; + T const* relative_attention_bias = isRelativePosition() ? params.relative_attention_bias : nullptr; + int const relative_attention_bias_stride = isRelativePosition() ? 
params.relative_attention_bias_stride : 0; + int const max_distance = mMaxDistance; + bool const* finished = nullptr; + bool const has_ia3 = false; - const auto quant_option = tc::QuantMode::fromDescription(); - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; + auto const quant_option = tc::QuantMode::fromDescription(); + float const* qkv_scale_out = nullptr; + float const* attention_out_scale = nullptr; - const int* ia3_tasks = nullptr; - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; + int const* ia3_tasks = nullptr; + T const* ia3_key_weights = nullptr; + T const* ia3_value_weights = nullptr; int32_t const batch_beam = params.beam_width * params.num_requests; KVCacheBuffer kv_cache_buffer; - const auto elem_size = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T); + auto const elem_size = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T); if (useKVCache()) { if (mPagedKVCache) @@ -1229,7 +1229,7 @@ int GPTAttentionPluginCommon::enqueueGeneration( } int timestep = params.past_kv_length; - const int max_timesteps = mCrossAttention ? params.cyclic_attention_window_size + int const max_timesteps = mCrossAttention ? params.cyclic_attention_window_size : std::min(timestep, params.cyclic_attention_window_size); int estimated_min_multi_block_count = estimate_min_multi_block_count(max_timesteps, mMaxSharedMemoryPerBlockOptin - 2048); @@ -1248,7 +1248,7 @@ int GPTAttentionPluginCommon::enqueueGeneration( // Runtime check to see the actual number of blocks per sequence we need. int32_t const max_num_seq_len_tiles = std::max(getMaxNumSeqLenTile(batch_beam), estimated_min_multi_block_count); int32_t const min_num_seq_len_tiles = std::max(1, estimated_min_multi_block_count); - const bool enable_multi_block + bool const enable_multi_block = (mMultiBlockMode && max_num_seq_len_tiles > 1) || estimated_min_multi_block_count > 1; const size_t partial_out_size = enable_multi_block ? 
sizeof(T) * batch_beam * mNumHeads * mHeadSize * max_num_seq_len_tiles : 0; @@ -1364,29 +1364,29 @@ int GPTAttentionPluginCommon::enqueueGeneration( } template int GPTAttentionPluginCommon::enqueueGeneration( - const EnqueueGenerationParams& params, cudaStream_t stream); + EnqueueGenerationParams const& params, cudaStream_t stream); template int GPTAttentionPluginCommon::enqueueGeneration( - const EnqueueGenerationParams& params, cudaStream_t stream); + EnqueueGenerationParams const& params, cudaStream_t stream); #ifdef ENABLE_BF16 template int GPTAttentionPluginCommon::enqueueGeneration<__nv_bfloat16, KVLinearBuffer>( - const EnqueueGenerationParams<__nv_bfloat16, KVLinearBuffer>& params, cudaStream_t stream); + EnqueueGenerationParams<__nv_bfloat16, KVLinearBuffer> const& params, cudaStream_t stream); #endif template int GPTAttentionPluginCommon::enqueueGeneration( - const EnqueueGenerationParams& params, cudaStream_t stream); + EnqueueGenerationParams const& params, cudaStream_t stream); template int GPTAttentionPluginCommon::enqueueGeneration( - const EnqueueGenerationParams& params, cudaStream_t stream); + EnqueueGenerationParams const& params, cudaStream_t stream); #ifdef ENABLE_BF16 template int GPTAttentionPluginCommon::enqueueGeneration<__nv_bfloat16, KVBlockArray>( - const EnqueueGenerationParams<__nv_bfloat16, KVBlockArray>& params, cudaStream_t stream); + EnqueueGenerationParams<__nv_bfloat16, KVBlockArray> const& params, cudaStream_t stream); #endif template -void GPTAttentionPluginCommon::prepareEnqueueGeneration(const EnqueueGenerationParams& params) +void GPTAttentionPluginCommon::prepareEnqueueGeneration(EnqueueGenerationParams const& params) { // self attn XQAParams xqaParams{}; @@ -1399,25 +1399,25 @@ void GPTAttentionPluginCommon::prepareEnqueueGeneration(const EnqueueGenerationP } template void GPTAttentionPluginCommon::prepareEnqueueGeneration( - const EnqueueGenerationParams& params); + EnqueueGenerationParams const& params); template void GPTAttentionPluginCommon::prepareEnqueueGeneration( - const EnqueueGenerationParams& params); + EnqueueGenerationParams const& params); #ifdef ENABLE_BF16 template void GPTAttentionPluginCommon::prepareEnqueueGeneration<__nv_bfloat16, KVLinearBuffer>( - const EnqueueGenerationParams<__nv_bfloat16, KVLinearBuffer>& params); + EnqueueGenerationParams<__nv_bfloat16, KVLinearBuffer> const& params); #endif template void GPTAttentionPluginCommon::prepareEnqueueGeneration( - const EnqueueGenerationParams& params); + EnqueueGenerationParams const& params); template void GPTAttentionPluginCommon::prepareEnqueueGeneration( - const EnqueueGenerationParams& params); + EnqueueGenerationParams const& params); #ifdef ENABLE_BF16 template void GPTAttentionPluginCommon::prepareEnqueueGeneration<__nv_bfloat16, KVBlockArray>( - const EnqueueGenerationParams<__nv_bfloat16, KVBlockArray>& params); + EnqueueGenerationParams<__nv_bfloat16, KVBlockArray> const& params); #endif int GPTAttentionPluginCommon::initialize() noexcept @@ -1595,7 +1595,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon() mFC.fields = mPluginAttributes.data(); } -const PluginFieldCollection* GPTAttentionPluginCreatorCommon::getFieldNames() noexcept +PluginFieldCollection const* GPTAttentionPluginCreatorCommon::getFieldNames() noexcept { return &mFC; } diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h index 9c754fb40..65f1ab8bd 100644 --- 
a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommon.h @@ -49,13 +49,13 @@ class GPTAttentionPluginCommon : public BasePlugin bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_cache = true, bool is_medusa_enabled = false); - GPTAttentionPluginCommon(const void* data, size_t length); + GPTAttentionPluginCommon(void const* data, size_t length); ~GPTAttentionPluginCommon() override = default; template - int enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); //! This is called on every trt Engine creation int initialize() noexcept override; @@ -74,7 +74,7 @@ class GPTAttentionPluginCommon : public BasePlugin static size_t getCommonSerializationSize() noexcept; void serializeCommon(void* buffer) const noexcept; - const int getHeadSize(bool checkInit = true) const; + int const getHeadSize(bool checkInit = true) const; protected: int getMaxNumSeqLenTile(int batch_beam_size = 1) const; @@ -112,7 +112,7 @@ class GPTAttentionPluginCommon : public BasePlugin int32_t max_blocks_per_sequence; void* workspace; // optional when relative position - const T* relative_attention_bias = nullptr; + T const* relative_attention_bias = nullptr; int relative_attention_bias_stride = 0; // optional when cross attention T const* cross_qkv = nullptr; @@ -122,7 +122,7 @@ class GPTAttentionPluginCommon : public BasePlugin }; template - int enqueueContext(const EnqueueContextParams& params, cudaStream_t stream); + int enqueueContext(EnqueueContextParams const& params, cudaStream_t stream); template struct EnqueueGenerationParams @@ -154,27 +154,27 @@ class GPTAttentionPluginCommon : public BasePlugin void* workspace; int32_t const* host_past_key_value_lengths; // optional when relative position - const T* relative_attention_bias = nullptr; + T const* relative_attention_bias = nullptr; int relative_attention_bias_stride = 0; // optional when cross attention int32_t const* encoder_input_lengths = nullptr; int32_t const* host_context_lengths = nullptr; // optional when medusa is used. - const bool* medusa_mask = nullptr; - const int32_t* medusa_packed_mask = nullptr; - const int32_t* medusa_position_offsets = nullptr; + bool const* medusa_mask = nullptr; + int32_t const* medusa_packed_mask = nullptr; + int32_t const* medusa_position_offsets = nullptr; }; template - int enqueueGeneration(const EnqueueGenerationParams& params, cudaStream_t stream); + int enqueueGeneration(EnqueueGenerationParams const& params, cudaStream_t stream); // Called in configurePlugin(). 
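// A minimal standalone sketch of the scaling rule spelled out in the enqueueContext()
// comment above: the 1/(sqrt(head_size) * q_scaling) factor is applied in exactly one
// of the two stages, in the QK GEMM when a relative position bias is added before the
// softmax, otherwise inside the softmax itself. Names below are illustrative only, not
// the plugin's API.
#include <cmath>

struct QkScales
{
    float gemm;    // applied when computing Q*K^T
    float softmax; // applied inside the softmax kernel
};

inline QkScales selectQkScales(int headSize, float qScaling, bool hasRelativeBias)
{
    // q_scaling sits in the denominator; by default q_scaling == 1.0f.
    float const qkScale = 1.0f / (std::sqrt(static_cast<float>(headSize)) * qScaling);
    // Exactly one stage applies the scale, so gemm * softmax always equals qkScale.
    return hasRelativeBias ? QkScales{qkScale, 1.0f} : QkScales{1.0f, qkScale};
}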
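// A minimal standalone sketch of the style rule these hunks apply throughout:
// "east const", i.e. the const qualifier written to the right of what it qualifies.
// The two spellings name identical types, which is all that changes in the
// signature-only hunks above and below.
#include <type_traits>

static_assert(std::is_same_v<const char*, char const*>,
    "pointer to const char: same type, different spelling");
static_assert(std::is_same_v<const void* const*, void const* const*>,
    "the enqueue() 'inputs' parameter type is likewise unchanged");

template <typename T>
constexpr bool eastConstIsANoOp = std::is_same_v<const T&, T const&>;
static_assert(eastConstIsANoOp<int>, "references to const are also unaffected");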
template - void prepareEnqueueGeneration(const EnqueueGenerationParams& params); + void prepareEnqueueGeneration(EnqueueGenerationParams const& params); template bool convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams, - const EnqueueGenerationParams& generationsParams, bool forConfigurePlugin); + EnqueueGenerationParams const& generationsParams, bool forConfigurePlugin); bool isRelativePosition() const { @@ -276,10 +276,10 @@ class GPTAttentionPluginCreatorCommon : public BaseCreator public: GPTAttentionPluginCreatorCommon(); - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; template - T* deserializePluginImpl(const char* name, const void* serialData, size_t serialLength) noexcept; + T* deserializePluginImpl(char const* name, void const* serialData, size_t serialLength) noexcept; protected: std::vector mPluginAttributes; diff --git a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommonImpl.h b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommonImpl.h index 35dc7d15b..51462cee6 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommonImpl.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionCommon/gptAttentionCommonImpl.h @@ -35,7 +35,7 @@ T* GPTAttentionPluginCommon::cloneImpl() const noexcept template T* GPTAttentionPluginCreatorCommon::deserializePluginImpl( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call GPTAttentionPluginCommon::destroy() @@ -45,7 +45,7 @@ T* GPTAttentionPluginCreatorCommon::deserializePluginImpl( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp index 9dd606d38..2083ff2b0 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp @@ -34,8 +34,8 @@ using namespace tensorrt_llm::common; using tensorrt_llm::plugins::GPTAttentionPluginCreator; using tensorrt_llm::plugins::GPTAttentionPlugin; -static const char* GPT_ATTENTION_PLUGIN_VERSION{"1"}; -static const char* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"}; +static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"}; +static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"}; GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_heads, int head_size, int unidirectional, float q_scaling, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type, @@ -58,13 +58,13 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_ initEntryIdx(); } -GPTAttentionPlugin::GPTAttentionPlugin(const void* data, size_t length) +GPTAttentionPlugin::GPTAttentionPlugin(void const* data, size_t length) : GPTAttentionPluginCommon(data, length) { initEntryIdx(); } -bool GPTAttentionPlugin::isEntryUsed(const IdxEntry& entry) const +bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const { switch (entry) { @@ -107,7 +107,7 @@ void GPTAttentionPlugin::initEntryIdx() } } -GPTAttentionPlugin::IndexType GPTAttentionPlugin::getIdx(const IdxEntry& entry) const +GPTAttentionPlugin::IndexType 
GPTAttentionPlugin::getIdx(IdxEntry const& entry) const { TLLM_CHECK_WITH_INFO( isEntryUsed(entry), common::fmtstr("getIdx() should not be used with entry %lu\n", static_cast(entry))); @@ -127,7 +127,7 @@ static int getPackedTensorHiddenDimIndex(bool removePadding) // NOTE: generation input length might be larger than one in the Medusa mode. int GPTAttentionPlugin::getGenerationInputSequenceLength( - const nvinfer1::PluginTensorDesc* inputDesc, int32_t localNbSeq, int32_t localNbTokens) const + nvinfer1::PluginTensorDesc const* inputDesc, int32_t localNbSeq, int32_t localNbTokens) const { if (mRemovePadding) { @@ -150,7 +150,7 @@ int GPTAttentionPlugin::getGenerationInputSequenceLength( // present_key_value_pool (optional if mPagedKVCache is false) [batch_size, 2, local_num_kv_heads, max_seq_len, // head_size] nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { TLLM_CHECK(outputIndex == 0 || (!mPagedKVCache && useKVCache() && outputIndex == 1)); if (outputIndex == 0) @@ -164,7 +164,7 @@ nvinfer1::DimsExprs GPTAttentionPlugin::getOutputDimensions( } bool GPTAttentionPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { if (pos == getIdx(IdxEntry::CONTEXT_LENGTHS) || pos == getIdx(IdxEntry::REQUEST_TYPES) || pos == getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW) || pos == getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH) @@ -221,23 +221,23 @@ bool GPTAttentionPlugin::supportsFormatCombination( } template -void GPTAttentionPlugin::configurePluginImpl(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { TLLM_CHECK(mHeadSize > 0); - const int beamWidth = useKVCache() ? in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1] : 1; + int const beamWidth = useKVCache() ? in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[1] : 1; // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same // unless each layer has different attention window sizes. // the kv_cache capacity. int max_encoder_context_len = isCrossAttention() ? in[getIdx(IdxEntry::CROSS_QKV_LENGTH)].desc.dims.d[0] : 0; - const int max_attention_window_size = isCrossAttention() + int const max_attention_window_size = isCrossAttention() ? max_encoder_context_len : (useKVCache() ? 
in[getIdx(IdxEntry::CACHE_INDIR)].desc.dims.d[2] : 0); - const int cyclic_attention_window_size = max_attention_window_size; + int const cyclic_attention_window_size = max_attention_window_size; - const int num_requests = 256; - const int sink_token_length = 0; + int const num_requests = 256; + int const sink_token_length = 0; EnqueueGenerationParams enqueueParams{/*attention_input=*/nullptr, /*qkv_bias=*/nullptr, @@ -261,8 +261,8 @@ void GPTAttentionPlugin::configurePluginImpl(const nvinfer1::DynamicPluginTensor } template -void GPTAttentionPlugin::configurePluginDispatchKVCacheType(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void GPTAttentionPlugin::configurePluginDispatchKVCacheType(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { if (mPagedKVCache) { @@ -274,8 +274,8 @@ void GPTAttentionPlugin::configurePluginDispatchKVCacheType(const nvinfer1::Dyna } } -void GPTAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void GPTAttentionPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { if (mType == nvinfer1::DataType::kHALF) { @@ -293,20 +293,20 @@ void GPTAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc #endif } -size_t GPTAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t GPTAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { - const int max_context_length = mMaxContextLength; - const int cross_qkv_length = isCrossAttention() ? inputs[getIdx(IdxEntry::CROSS_QKV_LENGTH)].dims.d[0] : 0; - const int nbReq = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; + int const max_context_length = mMaxContextLength; + int const cross_qkv_length = isCrossAttention() ? inputs[getIdx(IdxEntry::CROSS_QKV_LENGTH)].dims.d[0] : 0; + int const nbReq = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; auto const type = inputs[getIdx(IdxEntry::QKV_TENSOR)].type; - const int max_kv_cache_length + int const max_kv_cache_length = isCrossAttention() ? cross_qkv_length : (useKVCache() ? 
inputs[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0); - const int max_num_tokens = inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0]; + int const max_num_tokens = inputs[getIdx(IdxEntry::QKV_TENSOR)].dims.d[0]; size_t const context_workspace_size = getWorkspaceSizeForContext( type, nbReq, max_context_length, max_kv_cache_length, cross_qkv_length, max_num_tokens); - const int total_num_seq = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; + int const total_num_seq = inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; size_t const generation_workspace_size = getWorkspaceSizeForGeneration(type, total_num_seq, max_kv_cache_length); size_t attention_input_workspace_size = 0; @@ -333,8 +333,8 @@ static size_t getStride(nvinfer1::Dims const& dims, int n) } template -int GPTAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { int32_t const nbSeq = inputDesc[getIdx(IdxEntry::CONTEXT_LENGTHS)].dims.d[0]; @@ -394,8 +394,8 @@ int GPTAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, template int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t tokenIdxBeg, int32_t localNbTokens, - const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) + nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { // relative_attention_bias [head_num, max_seq_len, max_seq_len] (optional in relative position) // or [head_num, num_buckets] (optional in implicit relative attention) @@ -404,13 +404,13 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 // cross_qkv_length [int] max encoder input context length (optional in cross attention mode) // encoder_input_lengths [batch_size] raw sequence lengths (optional in cross attention mode) - const T* attention_input = static_cast(inputs[getIdx(IdxEntry::QKV_TENSOR)]) + T const* attention_input = static_cast(inputs[getIdx(IdxEntry::QKV_TENSOR)]) + inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[getPackedTensorHiddenDimIndex(mRemovePadding)] * size_t(tokenIdxBeg); - const T* qkv_bias = nullptr; + T const* qkv_bias = nullptr; if (mQKVBiasEnabled) { - qkv_bias = reinterpret_cast(inputs[getIdx(IdxEntry::QKV_BIAS_TENSOR)]); + qkv_bias = reinterpret_cast(inputs[getIdx(IdxEntry::QKV_BIAS_TENSOR)]); } auto const reqTypeInBatchPtr = static_cast(inputs[getIdx(IdxEntry::REQUEST_TYPES)]) + seqIdxBeg; @@ -448,9 +448,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 attention_input = attention_input_qkv + hidden_units * tokenIdxBeg; } - const int* context_q_lengths = reinterpret_cast(inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)]) + seqIdxBeg; - const int* sequence_kv_length = useKVCache() - ? static_cast(inputs[getIdx(IdxEntry::SEQUENCE_LENGTH)]) + seqIdxBeg + int const* context_q_lengths = reinterpret_cast(inputs[getIdx(IdxEntry::CONTEXT_LENGTHS)]) + seqIdxBeg; + int const* sequence_kv_length = useKVCache() + ? 
static_cast(inputs[getIdx(IdxEntry::SEQUENCE_LENGTH)]) + seqIdxBeg : context_q_lengths; // Note we still need context length during generation for MMHA optimization. int32_t const max_context_q_len = [&]() @@ -473,29 +473,29 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 // -- max_seq_len: max allowed len of decoder output, i.e. final results // -- max_encoder_context_len: len of encoder input (in cross attn). Also called encoder_input_seq_length - const int beamWidth = useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[1] : 1; + int const beamWidth = useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[1] : 1; // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same // unless each layer has different attention window sizes. // the kv_cache capacity. - const int max_attention_window_size = isCrossAttention() + int const max_attention_window_size = isCrossAttention() ? max_encoder_context_len : (useKVCache() ? inputDesc[getIdx(IdxEntry::CACHE_INDIR)].dims.d[2] : 0); // The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens. // Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity. - const int cyclic_attention_window_size = isCrossAttention() + int const cyclic_attention_window_size = isCrossAttention() ? max_encoder_context_len - : reinterpret_cast(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[mLayerIdx]; - const int sink_token_length = reinterpret_cast(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0]; + : reinterpret_cast(inputs[getIdx(IdxEntry::HOST_MAX_ATTENTION_WINDOW)])[mLayerIdx]; + int const sink_token_length = reinterpret_cast(inputs[getIdx(IdxEntry::HOST_SINK_TOKEN_LENGTH)])[0]; - const float* kv_scale_orig_quant = nullptr; - const float* kv_scale_quant_orig = nullptr; + float const* kv_scale_orig_quant = nullptr; + float const* kv_scale_quant_orig = nullptr; if (useKVCache() && mKVCacheQuantMode.hasKvCacheQuant()) { assert(inputDesc[getIdx(IdxEntry::KV_CACHE_QUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT); assert(inputDesc[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT); - kv_scale_orig_quant = reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_QUANTIZATION_SCALE)]); - kv_scale_quant_orig = reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)]); + kv_scale_orig_quant = reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_QUANTIZATION_SCALE)]); + kv_scale_quant_orig = reinterpret_cast(inputs[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)]); } int max_blocks_per_sequence = 0; @@ -536,10 +536,10 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 } } - const T* alibi_slopes = isALiBi() ? static_cast(inputs[getIdx(IdxEntry::ALIBI_SLOPES)]) : nullptr; + T const* alibi_slopes = isALiBi() ? 
static_cast(inputs[getIdx(IdxEntry::ALIBI_SLOPES)]) : nullptr; - const int* medusa_packed_mask = nullptr; - const int* medusa_position_offsets = nullptr; + int const* medusa_packed_mask = nullptr; + int const* medusa_position_offsets = nullptr; int num_medusa_tokens = 0; if (mIsMedusaEnabled) { @@ -548,15 +548,15 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 num_medusa_tokens = inputDesc[getIdx(IdxEntry::MEDUSA_PACKED_MASK)].dims.d[1] - 1; if (num_medusa_tokens > 0) { - medusa_packed_mask = static_cast(inputs[getIdx(IdxEntry::MEDUSA_PACKED_MASK)]) + medusa_packed_mask = static_cast(inputs[getIdx(IdxEntry::MEDUSA_PACKED_MASK)]) + seqIdxBeg * getStride(inputDesc[getIdx(IdxEntry::MEDUSA_PACKED_MASK)].dims, 0); - medusa_position_offsets = static_cast(inputs[getIdx(IdxEntry::MEDUSA_POSITION_OFFSETS)]) + medusa_position_offsets = static_cast(inputs[getIdx(IdxEntry::MEDUSA_POSITION_OFFSETS)]) + seqIdxBeg * getStride(inputDesc[getIdx(IdxEntry::MEDUSA_POSITION_OFFSETS)].dims, 0); } } int32_t const* max_context_kv_len_list = useKVCache() - ? static_cast(inputs[getIdx(IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS)]) + seqIdxBeg + ? static_cast(inputs[getIdx(IdxEntry::HOST_PAST_KEY_VALUE_LENGTHS)]) + seqIdxBeg : nullptr; int32_t const max_context_kv_len = useKVCache() ? *std::max_element(max_context_kv_len_list, max_context_kv_len_list + localNbSeq) @@ -566,8 +566,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 { TLLM_CHECK(max_context_q_len <= mMaxContextLength); - const int batch_size = localNbSeq; - const int request_batch_size = batch_size; + int const batch_size = localNbSeq; + int const request_batch_size = batch_size; // num of total tokens (without paddings when remove paddings). int num_encoder_tokens = 0; if (isCrossAttention()) @@ -590,16 +590,16 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 if (isRelativePosition()) { enqueue_params.relative_attention_bias - = static_cast(inputs[getIdx(IdxEntry::RELATIVE_ATTENTION_BIAS)]); + = static_cast(inputs[getIdx(IdxEntry::RELATIVE_ATTENTION_BIAS)]); enqueue_params.relative_attention_bias_stride = inputDesc[getIdx(IdxEntry::RELATIVE_ATTENTION_BIAS)].dims.d[1]; // max_seq_len or num_buckets } if (isCrossAttention()) { - enqueue_params.cross_qkv = static_cast(inputs[getIdx(IdxEntry::CROSS_QKV)]); + enqueue_params.cross_qkv = static_cast(inputs[getIdx(IdxEntry::CROSS_QKV)]); enqueue_params.cross_qkv_length = max_encoder_context_len; enqueue_params.encoder_input_lengths - = reinterpret_cast(inputs[getIdx(IdxEntry::ENCODER_INPUT_LENGTH)]) + seqIdxBeg; + = reinterpret_cast(inputs[getIdx(IdxEntry::ENCODER_INPUT_LENGTH)]) + seqIdxBeg; enqueue_params.num_encoder_tokens = num_encoder_tokens; } @@ -612,12 +612,12 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 TLLM_CHECK(batch_beam % beamWidth == 0); int32_t const num_requests = batch_beam / beamWidth; - const int* cache_indir - = beamWidth == 1 ? nullptr : reinterpret_cast(inputs[getIdx(IdxEntry::CACHE_INDIR)]); - const int* host_context_lengths - = mRemovePadding ? reinterpret_cast(inputs[getIdx(IdxEntry::HOST_CONTEXT_LENGTH)]) : nullptr; + int const* cache_indir + = beamWidth == 1 ? nullptr : reinterpret_cast(inputs[getIdx(IdxEntry::CACHE_INDIR)]); + int const* host_context_lengths + = mRemovePadding ? 
reinterpret_cast(inputs[getIdx(IdxEntry::HOST_CONTEXT_LENGTH)]) : nullptr; - const int input_seq_length = getGenerationInputSequenceLength(inputDesc, localNbSeq, localNbTokens); + int const input_seq_length = getGenerationInputSequenceLength(inputDesc, localNbSeq, localNbTokens); auto qkvDims = inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims; TLLM_CHECK_WITH_INFO(input_seq_length == 1 || mIsMedusaEnabled, "Only Medusa mode supports input length > 1 in the generation phase, input_seq_length=%d, " @@ -634,14 +634,14 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 if (isRelativePosition()) { enqueue_params.relative_attention_bias - = static_cast(inputs[getIdx(IdxEntry::RELATIVE_ATTENTION_BIAS)]); + = static_cast(inputs[getIdx(IdxEntry::RELATIVE_ATTENTION_BIAS)]); enqueue_params.relative_attention_bias_stride = inputDesc[getIdx(IdxEntry::RELATIVE_ATTENTION_BIAS)].dims.d[1]; // max_seq_len or num_buckets } if (isCrossAttention()) { enqueue_params.encoder_input_lengths - = reinterpret_cast(inputs[getIdx(IdxEntry::ENCODER_INPUT_LENGTH)]) + seqIdxBeg; + = reinterpret_cast(inputs[getIdx(IdxEntry::ENCODER_INPUT_LENGTH)]) + seqIdxBeg; } if (mIsMedusaEnabled) { @@ -656,8 +656,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32 } template -int GPTAttentionPlugin::enqueueDispatchKVCacheType(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int GPTAttentionPlugin::enqueueDispatchKVCacheType(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { if (mPagedKVCache) @@ -671,8 +671,8 @@ int GPTAttentionPlugin::enqueueDispatchKVCacheType(const nvinfer1::PluginTensorD return 0; } -int GPTAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int GPTAttentionPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (mType == nvinfer1::DataType::kHALF) @@ -694,7 +694,7 @@ int GPTAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType GPTAttentionPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0 || (!mPagedKVCache && index == 1)); if (index == 0) @@ -709,12 +709,12 @@ nvinfer1::DataType GPTAttentionPlugin::getOutputDataType( // IPluginV2 Methods -const char* GPTAttentionPlugin::getPluginType() const noexcept +char const* GPTAttentionPlugin::getPluginType() const noexcept { return GPT_ATTENTION_PLUGIN_NAME; } -const char* GPTAttentionPlugin::getPluginVersion() const noexcept +char const* GPTAttentionPlugin::getPluginVersion() const noexcept { return GPT_ATTENTION_PLUGIN_VERSION; } @@ -745,22 +745,22 @@ GPTAttentionPluginCreator::GPTAttentionPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* GPTAttentionPluginCreator::getPluginName() const noexcept +char const* GPTAttentionPluginCreator::getPluginName() const noexcept { return GPT_ATTENTION_PLUGIN_NAME; } -const char* 
GPTAttentionPluginCreator::getPluginVersion() const noexcept +char const* GPTAttentionPluginCreator::getPluginVersion() const noexcept { return GPT_ATTENTION_PLUGIN_VERSION; } -const PluginFieldCollection* GPTAttentionPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* GPTAttentionPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { PluginFieldParser p{fc->nbFields, fc->fields}; @@ -799,7 +799,7 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -807,7 +807,7 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(const char* name, const Plugi } IPluginV2* GPTAttentionPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call GPTAttentionPlugin::destroy() @@ -817,7 +817,7 @@ IPluginV2* GPTAttentionPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h index 9b7a1dc62..90a5cb48f 100644 --- a/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +++ b/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h @@ -83,46 +83,46 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_cache = true, bool is_medusa_enabled = false); - GPTAttentionPlugin(const void* data, size_t length); + GPTAttentionPlugin(void const* data, size_t length); ~GPTAttentionPlugin() override = default; // IPluginV2DynamicExt Methods - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; template - int enqueueImpl(const 
nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); template - int enqueueDispatchKVCacheType(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, + int enqueueDispatchKVCacheType(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); template - void configurePluginImpl(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept; + void configurePluginImpl(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept; template - void configurePluginDispatchKVCacheType(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; + void configurePluginDispatchKVCacheType(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; //! This is called on every trt ExecutionContext creation by TRT @@ -141,8 +141,8 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon private: template int enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32_t tokenIdxBeg, int32_t localNbTokens, - const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); using IndexType = std::int32_t; @@ -176,13 +176,13 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon ENUM_SIZE, }; - bool isEntryUsed(const IdxEntry& entry) const; + bool isEntryUsed(IdxEntry const& entry) const; void initEntryIdx(); - IndexType getIdx(const IdxEntry& entry) const; + IndexType getIdx(IdxEntry const& entry) const; // Get generation input sequence length (might be larger than 1 in the Medusa mode). 
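// The workspace-size hunks fill a fixed array of per-buffer byte counts (NUM_BUFFERS
// entries starting at CUBLAS_WORKSPACE_SIZE) and then combine them into a single
// allocation. The helper that does the combining is not shown in this diff; a minimal
// standalone sketch of one common approach, padding each buffer to a fixed alignment
// so sub-buffers can later be carved out of one workspace pointer. The 256-byte
// alignment is an assumption for illustration only.
#include <cstddef>

constexpr std::size_t kWsAlignment = 256;

constexpr std::size_t alignUp(std::size_t n, std::size_t a)
{
    return ((n + a - 1) / a) * a; // same shape as the tc::divUp(n, a) * a idiom
}

constexpr std::size_t totalWorkspaceSize(std::size_t const* sizes, int count)
{
    std::size_t total = 0;
    for (int i = 0; i < count; ++i)
    {
        total += alignUp(sizes[i], kWsAlignment);
    }
    return total;
}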
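// Every creator touched by this patch builds its plugin inside try/catch and returns
// nullptr on failure, because the IPluginCreator entry points are declared noexcept.
// A minimal standalone sketch of that shape; FakePlugin and reportError are
// illustrative stand-ins, not the TensorRT API.
#include <cstddef>
#include <cstdio>
#include <exception>
#include <stdexcept>

struct FakePlugin
{
    FakePlugin(void const* data, std::size_t length)
    {
        if (data == nullptr || length == 0)
            throw std::runtime_error("empty serialization buffer");
    }
};

inline void reportError(std::exception const& e) noexcept
{
    std::fprintf(stderr, "plugin creation failed: %s\n", e.what());
}

inline FakePlugin* deserializeOrNull(void const* data, std::size_t length) noexcept
{
    try
    {
        return new FakePlugin(data, length); // constructor may throw
    }
    catch (std::exception const& e)
    {
        reportError(e); // plays the role of caughtError(e) in the real creators
    }
    return nullptr;
}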
int getGenerationInputSequenceLength( - const nvinfer1::PluginTensorDesc* inputDesc, int32_t localNbSeq, int32_t localNbTokens) const; + nvinfer1::PluginTensorDesc const* inputDesc, int32_t localNbSeq, int32_t localNbTokens) const; }; class GPTAttentionPluginCreator : public GPTAttentionPluginCreatorCommon @@ -190,16 +190,16 @@ class GPTAttentionPluginCreator : public GPTAttentionPluginCreatorCommon public: GPTAttentionPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; }; } // namespace tensorrt_llm::plugins diff --git a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp index ad2f0c39f..b4f3e46dc 100644 --- a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.cpp @@ -21,17 +21,17 @@ using namespace nvinfer1; using tensorrt_llm::plugins::IdentityPluginCreator; using tensorrt_llm::plugins::IdentityPlugin; -static const char* IDENTITY_PLUGIN_VERSION{"1"}; -static const char* IDENTITY_PLUGIN_NAME{"Identity"}; +static char const* IDENTITY_PLUGIN_VERSION{"1"}; +static char const* IDENTITY_PLUGIN_NAME{"Identity"}; PluginFieldCollection IdentityPluginCreator::mFC{}; std::vector IdentityPluginCreator::mPluginAttributes; IdentityPlugin::IdentityPlugin() {} // Parameterized constructor -IdentityPlugin::IdentityPlugin(const void* data, size_t length) +IdentityPlugin::IdentityPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; TLLM_CHECK_WITH_INFO(d == a + length, "Expected length (%d) != real length (%d). 
This is often " "caused by using different TensorRT-LLM version to build " @@ -48,17 +48,17 @@ nvinfer1::IPluginV2DynamicExt* IdentityPlugin::clone() const noexcept } nvinfer1::DimsExprs IdentityPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { return inputs[outputIndex]; } bool IdentityPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { assert(0 <= pos && pos < 2); - const PluginTensorDesc& input = inOut[0]; - const PluginTensorDesc& output = inOut[1]; + PluginTensorDesc const& input = inOut[0]; + PluginTensorDesc const& output = inOut[1]; switch (pos) { case 0: return input.format == nvinfer1::TensorFormat::kLINEAR; @@ -67,19 +67,19 @@ bool IdentityPlugin::supportsFormatCombination( return false; } -void IdentityPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void IdentityPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t IdentityPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t IdentityPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int IdentityPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int IdentityPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { size_t count = 1; for (int i = 0; i < inputDesc[0].dims.nbDims; ++i) @@ -95,7 +95,7 @@ int IdentityPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const n // IPluginV2Ext Methods nvinfer1::DataType IdentityPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -103,12 +103,12 @@ nvinfer1::DataType IdentityPlugin::getOutputDataType( // IPluginV2 Methods -const char* IdentityPlugin::getPluginType() const noexcept +char const* IdentityPlugin::getPluginType() const noexcept { return IDENTITY_PLUGIN_NAME; } -const char* IdentityPlugin::getPluginVersion() const noexcept +char const* IdentityPlugin::getPluginVersion() const noexcept { return IDENTITY_PLUGIN_VERSION; } @@ -148,22 +148,22 @@ IdentityPluginCreator::IdentityPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* IdentityPluginCreator::getPluginName() const noexcept +char const* IdentityPluginCreator::getPluginName() const noexcept { return IDENTITY_PLUGIN_NAME; } -const char* IdentityPluginCreator::getPluginVersion() const noexcept +char const* IdentityPluginCreator::getPluginVersion() const noexcept { return IDENTITY_PLUGIN_VERSION; } -const 
PluginFieldCollection* IdentityPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* IdentityPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* IdentityPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* IdentityPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { try { @@ -171,7 +171,7 @@ IPluginV2* IdentityPluginCreator::createPlugin(const char* name, const PluginFie obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -179,7 +179,7 @@ IPluginV2* IdentityPluginCreator::createPlugin(const char* name, const PluginFie } IPluginV2* IdentityPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call IdentityPlugin::destroy() @@ -189,7 +189,7 @@ IPluginV2* IdentityPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.h b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.h index 9cbc20aac..9ab10601a 100644 --- a/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.h +++ b/cpp/tensorrt_llm/plugins/identityPlugin/identityPlugin.h @@ -30,30 +30,30 @@ class IdentityPlugin : public BasePlugin public: IdentityPlugin(); - IdentityPlugin(const void* data, size_t length); + IdentityPlugin(void const* data, size_t length); ~IdentityPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const 
nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -70,16 +70,16 @@ class IdentityPluginCreator : public BaseCreator public: IdentityPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp index 831f3a9f0..9824795b8 100644 --- a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.cpp @@ -23,8 +23,8 @@ using namespace tensorrt_llm::common; using tensorrt_llm::plugins::LayernormQuantizationPluginCreator; using tensorrt_llm::plugins::LayernormQuantizationPlugin; -static const char* LAYERNORM_QUANTIZATION_PLUGIN_VERSION{"1"}; -static const char* LAYERNORM_QUANTIZATION_PLUGIN_NAME{"LayernormQuantization"}; +static char const* LAYERNORM_QUANTIZATION_PLUGIN_VERSION{"1"}; +static char const* LAYERNORM_QUANTIZATION_PLUGIN_NAME{"LayernormQuantization"}; PluginFieldCollection LayernormQuantizationPluginCreator::mFC{}; std::vector LayernormQuantizationPluginCreator::mPluginAttributes; @@ -38,9 +38,9 @@ LayernormQuantizationPlugin::LayernormQuantizationPlugin( } // Parameterized constructor -LayernormQuantizationPlugin::LayernormQuantizationPlugin(const void* data, size_t length) +LayernormQuantizationPlugin::LayernormQuantizationPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mEps); read(d, mUseDiffOfSquares); read(d, mDynActScaling); @@ -61,7 +61,7 @@ nvinfer1::IPluginV2DynamicExt* LayernormQuantizationPlugin::clone() const noexce } nvinfer1::DimsExprs LayernormQuantizationPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { if (outputIndex == 0) { @@ -82,7 +82,7 @@ nvinfer1::DimsExprs LayernormQuantizationPlugin::getOutputDimensions( ret.d[ret.nbDims - 1] = exprBuilder.constant(1); return ret; } - catch (const 
std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -90,9 +90,9 @@ nvinfer1::DimsExprs LayernormQuantizationPlugin::getOutputDimensions( } bool LayernormQuantizationPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { - const int totalPoses = 6 + static_cast(mDynActScaling); + int const totalPoses = 6 + static_cast(mDynActScaling); TLLM_CHECK(0 <= pos && pos < totalPoses); TLLM_CHECK(nbInputs == 4); if (pos < nbInputs) @@ -114,19 +114,19 @@ bool LayernormQuantizationPlugin::supportsFormatCombination( return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR); } -void LayernormQuantizationPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void LayernormQuantizationPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t LayernormQuantizationPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t LayernormQuantizationPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int LayernormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int LayernormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -143,25 +143,25 @@ int LayernormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* input { m *= inputDesc[0].dims.d[i]; } - const int n = inputDesc[1].dims.d[0]; + int const n = inputDesc[1].dims.d[0]; - const float* scale = reinterpret_cast(inputs[3]); + float const* scale = reinterpret_cast(inputs[3]); int8_t* output = reinterpret_cast(outputs[0]); float* dynamic_scale = mDynActScaling ? 
reinterpret_cast(outputs[1]) : nullptr; if (mType == DataType::kHALF) { - const half* input = reinterpret_cast(inputs[0]); - const half* weight = reinterpret_cast(inputs[1]); - const half* bias = reinterpret_cast(inputs[2]); + half const* input = reinterpret_cast(inputs[0]); + half const* weight = reinterpret_cast(inputs[1]); + half const* bias = reinterpret_cast(inputs[2]); invokeGeneralLayerNorm( (half*) nullptr, input, weight, bias, mEps, m, n, stream, mUseDiffOfSquares, scale, dynamic_scale, output); } else if (mType == DataType::kFLOAT) { - const float* input = reinterpret_cast(inputs[0]); - const float* weight = reinterpret_cast(inputs[1]); - const float* bias = reinterpret_cast(inputs[2]); + float const* input = reinterpret_cast(inputs[0]); + float const* weight = reinterpret_cast(inputs[1]); + float const* bias = reinterpret_cast(inputs[2]); invokeGeneralLayerNorm( (float*) nullptr, input, weight, bias, mEps, m, n, stream, mUseDiffOfSquares, scale, dynamic_scale, output); } @@ -171,7 +171,7 @@ int LayernormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* input // IPluginV2Ext Methods nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0)); if (index == 0) @@ -185,12 +185,12 @@ nvinfer1::DataType LayernormQuantizationPlugin::getOutputDataType( // IPluginV2 Methods -const char* LayernormQuantizationPlugin::getPluginType() const noexcept +char const* LayernormQuantizationPlugin::getPluginType() const noexcept { return LAYERNORM_QUANTIZATION_PLUGIN_NAME; } -const char* LayernormQuantizationPlugin::getPluginVersion() const noexcept +char const* LayernormQuantizationPlugin::getPluginVersion() const noexcept { return LAYERNORM_QUANTIZATION_PLUGIN_VERSION; } @@ -242,24 +242,24 @@ LayernormQuantizationPluginCreator::LayernormQuantizationPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* LayernormQuantizationPluginCreator::getPluginName() const noexcept +char const* LayernormQuantizationPluginCreator::getPluginName() const noexcept { return LAYERNORM_QUANTIZATION_PLUGIN_NAME; } -const char* LayernormQuantizationPluginCreator::getPluginVersion() const noexcept +char const* LayernormQuantizationPluginCreator::getPluginVersion() const noexcept { return LAYERNORM_QUANTIZATION_PLUGIN_VERSION; } -const PluginFieldCollection* LayernormQuantizationPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* LayernormQuantizationPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* LayernormQuantizationPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* LayernormQuantizationPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; float eps; nvinfer1::DataType type; bool useDiffOfSquares; @@ -267,26 +267,26 @@ IPluginV2* LayernormQuantizationPluginCreator::createPlugin(const char* name, co // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "eps")) { TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32); - eps = static_cast(*(static_cast(fields[i].data))); + eps = static_cast(*(static_cast(fields[i].data))); } 
else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "dyn_act_scaling")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - dynamicActivationScaling = static_cast(*(static_cast(fields[i].data))); + dynamicActivationScaling = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "use_diff_of_squares")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - useDiffOfSquares = static_cast(*(static_cast(fields[i].data))); + useDiffOfSquares = static_cast(*(static_cast(fields[i].data))); } } try @@ -295,7 +295,7 @@ IPluginV2* LayernormQuantizationPluginCreator::createPlugin(const char* name, co obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -303,7 +303,7 @@ IPluginV2* LayernormQuantizationPluginCreator::createPlugin(const char* name, co } IPluginV2* LayernormQuantizationPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call LayernormQuantizationPlugin::destroy() @@ -313,7 +313,7 @@ IPluginV2* LayernormQuantizationPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h index 22f33bd94..f8c362dc4 100644 --- a/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h +++ b/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h @@ -31,30 +31,30 @@ class LayernormQuantizationPlugin : public BasePlugin LayernormQuantizationPlugin( float eps, bool useDiffOfSquares, bool dynamicActivationScaling, nvinfer1::DataType type); - LayernormQuantizationPlugin(const void* data, size_t length); + LayernormQuantizationPlugin(void const* data, size_t length); ~LayernormQuantizationPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void 
configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -76,16 +76,16 @@ class LayernormQuantizationPluginCreator : public BaseCreator public: LayernormQuantizationPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.cpp b/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.cpp index 5570f15bb..45653a05e 100644 --- a/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.cpp @@ -27,8 +27,8 @@ using namespace tensorrt_llm::common; using tensorrt_llm::plugins::LookupPluginCreator; using tensorrt_llm::plugins::LookupPlugin; -static const char* LOOKUP_PLUGIN_VERSION{"1"}; -static const char* LOOKUP_PLUGIN_NAME{"Lookup"}; +static char const* LOOKUP_PLUGIN_VERSION{"1"}; +static char const* LOOKUP_PLUGIN_NAME{"Lookup"}; PluginFieldCollection LookupPluginCreator::mFC{}; std::vector LookupPluginCreator::mPluginAttributes; @@ -39,9 +39,9 @@ LookupPlugin::LookupPlugin(nvinfer1::DataType type, int rank) } // Parameterized constructor -LookupPlugin::LookupPlugin(const void* data, size_t length) +LookupPlugin::LookupPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mType); read(d, mRank); TLLM_CHECK_WITH_INFO(d == a + length, @@ -61,15 +61,15 @@ nvinfer1::IPluginV2DynamicExt* LookupPlugin::clone() const noexcept } nvinfer1::DimsExprs LookupPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, 
nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { try { TLLM_CHECK(nbInputs == 2); TLLM_CHECK(outputIndex == 0); DimsExprs ret; - const int nbDimsInput = inputs[0].nbDims; - const int nbDimsWeight = inputs[1].nbDims; + int const nbDimsInput = inputs[0].nbDims; + int const nbDimsWeight = inputs[1].nbDims; ret.nbDims = nbDimsInput + 1; for (int i = 0; i < nbDimsInput; ++i) @@ -80,7 +80,7 @@ nvinfer1::DimsExprs LookupPlugin::getOutputDimensions( return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -88,7 +88,7 @@ nvinfer1::DimsExprs LookupPlugin::getOutputDimensions( } bool LookupPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { bool res = false; switch (pos) @@ -103,19 +103,19 @@ bool LookupPlugin::supportsFormatCombination( return res; } -void LookupPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void LookupPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t LookupPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t LookupPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int LookupPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int LookupPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs // input [batchSize] @@ -129,27 +129,27 @@ int LookupPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvi batchSize *= inputDesc[0].dims.d[i]; } - const int localVocabSize = inputDesc[1].dims.d[0]; - const int hidden = inputDesc[1].dims.d[inputDesc[1].dims.nbDims - 1]; - const int* input = reinterpret_cast(inputs[0]); + int const localVocabSize = inputDesc[1].dims.d[0]; + int const hidden = inputDesc[1].dims.d[inputDesc[1].dims.nbDims - 1]; + int const* input = reinterpret_cast(inputs[0]); int offset = mRank * localVocabSize; if (mType == DataType::kHALF) { - const half* weight = reinterpret_cast(inputs[1]); + half const* weight = reinterpret_cast(inputs[1]); half* output = reinterpret_cast(outputs[0]); invokeLookUp(output, input, weight, batchSize, offset, localVocabSize, hidden, stream); } else if (mType == DataType::kFLOAT) { - const float* weight = reinterpret_cast(inputs[1]); + float const* weight = reinterpret_cast(inputs[1]); float* output = reinterpret_cast(outputs[0]); invokeLookUp(output, input, weight, batchSize, offset, localVocabSize, hidden, stream); } else if (mType == DataType::kBF16) { - const __nv_bfloat16* weight = reinterpret_cast(inputs[1]); + __nv_bfloat16 const* weight = reinterpret_cast<__nv_bfloat16 const*>(inputs[1]); __nv_bfloat16* output = reinterpret_cast<__nv_bfloat16*>(outputs[0]); invokeLookUp<__nv_bfloat16, int>(output, input, 
weight, batchSize, offset, localVocabSize, hidden, stream); } @@ -159,7 +159,7 @@ int LookupPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvi // IPluginV2Ext Methods nvinfer1::DataType LookupPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0); return inputTypes[1]; @@ -167,12 +167,12 @@ nvinfer1::DataType LookupPlugin::getOutputDataType( // IPluginV2 Methods -const char* LookupPlugin::getPluginType() const noexcept +char const* LookupPlugin::getPluginType() const noexcept { return LOOKUP_PLUGIN_NAME; } -const char* LookupPlugin::getPluginVersion() const noexcept +char const* LookupPlugin::getPluginVersion() const noexcept { return LOOKUP_PLUGIN_VERSION; } @@ -220,39 +220,39 @@ LookupPluginCreator::LookupPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* LookupPluginCreator::getPluginName() const noexcept +char const* LookupPluginCreator::getPluginName() const noexcept { return LOOKUP_PLUGIN_NAME; } -const char* LookupPluginCreator::getPluginVersion() const noexcept +char const* LookupPluginCreator::getPluginVersion() const noexcept { return LOOKUP_PLUGIN_VERSION; } -const PluginFieldCollection* LookupPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* LookupPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* LookupPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* LookupPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; nvinfer1::DataType type; int rank; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "rank")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - rank = static_cast(*(static_cast(fields[i].data))); + rank = static_cast(*(static_cast(fields[i].data))); } } try @@ -261,7 +261,7 @@ IPluginV2* LookupPluginCreator::createPlugin(const char* name, const PluginField obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -269,7 +269,7 @@ IPluginV2* LookupPluginCreator::createPlugin(const char* name, const PluginField } IPluginV2* LookupPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call LookupPlugin::destroy() @@ -279,7 +279,7 @@ IPluginV2* LookupPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h b/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h index 035264715..a0722c8c9 100644 --- a/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h +++ b/cpp/tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h @@ -32,30 +32,30 @@ class LookupPlugin : public 
BasePlugin LookupPlugin(nvinfer1::DataType type, int rank); - LookupPlugin(const void* data, size_t length); + LookupPlugin(void const* data, size_t length); ~LookupPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -75,16 +75,16 @@ class LookupPluginCreator : public BaseCreator public: LookupPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp 
b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp index d825b3a4d..9d81c0ec9 100644 --- a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.cpp @@ -35,8 +35,8 @@ using tensorrt_llm::plugins::CublasGemmWrapperPtr; using tensorrt_llm::plugins::read; using tensorrt_llm::plugins::write; -static const char* LORA_PLUGIN_VERSION{"1"}; -static const char* LORA_PLUGIN_NAME{"Lora"}; +static char const* LORA_PLUGIN_VERSION{"1"}; +static char const* LORA_PLUGIN_NAME{"Lora"}; PluginFieldCollection LoraPluginCreator::mFC{}; std::vector LoraPluginCreator::mPluginAttributes; @@ -55,9 +55,9 @@ void _getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int } // TODO should reuse the function in gemmPlugin -void _runGemm(const int M, const int N, const int K, const bool transA, const bool transB, - const nvinfer1::DataType type, const CublasGemmWrapperPtr& cublasWrapperPtr, const void* act, const void* weight, - void* output, const std::optional& heuristic, void* workspace, cudaStream_t stream) +void _runGemm(int const M, int const N, int const K, bool const transA, bool const transB, + const nvinfer1::DataType type, CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act, void const* weight, + void* output, std::optional const& heuristic, void* workspace, cudaStream_t stream) { cublasWrapperPtr->setStream(stream); cublasWrapperPtr->setWorkspace(workspace); @@ -73,7 +73,7 @@ void _runGemm(const int M, const int N, const int K, const bool transA, const bo } LoraPlugin::LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, int transA, int transB, - int num_lora_modules, nvinfer1::DataType type, const LoraPlugin::PluginProfilerPtr& pluginProfiler, + int num_lora_modules, nvinfer1::DataType type, LoraPlugin::PluginProfilerPtr const& pluginProfiler, bool remove_input_padding, int max_context_length, int max_low_rank) : mInHiddenSize(in_hidden_size) , mTransA(transA) @@ -92,11 +92,11 @@ LoraPlugin::LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, in } // Parameterized constructor -LoraPlugin::LoraPlugin(const void* data, size_t length, const LoraPlugin::PluginProfilerPtr& pluginProfiler) +LoraPlugin::LoraPlugin(void const* data, size_t length, LoraPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mInHiddenSize); read(d, mTransA); read(d, mTransB); @@ -175,13 +175,13 @@ nvinfer1::IPluginV2DynamicExt* LoraPlugin::clone() const noexcept } nvinfer1::DimsExprs LoraPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); try { TLLM_CHECK(outputIndex < mNumLoraModules); - const int nbDimsA = inputs[getInputTensorIdx()].nbDims; + int const nbDimsA = inputs[getInputTensorIdx()].nbDims; DimsExprs ret; ret.nbDims = nbDimsA; @@ -210,7 +210,7 @@ nvinfer1::DimsExprs LoraPlugin::getOutputDimensions( ret.d[ret.nbDims - 1] = outHiddenSize; return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -218,7 +218,7 @@ nvinfer1::DimsExprs LoraPlugin::getOutputDimensions( } bool LoraPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, 
int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); if (pos == getHostRequestTypesIdx()) @@ -243,7 +243,7 @@ bool LoraPlugin::supportsFormatCombination( } } -int32_t _computeMDimension(bool transA, const int32_t nbDims, const int32_t* dims) +int32_t _computeMDimension(bool transA, const int32_t nbDims, int32_t const* dims) { int32_t M = 1; if (transA) @@ -263,7 +263,7 @@ int32_t _computeMDimension(bool transA, const int32_t nbDims, const int32_t* dim return M; } -int32_t _computeNDimension(bool transB, const int32_t nbDims, const int32_t* dims) +int32_t _computeNDimension(bool transB, const int32_t nbDims, int32_t const* dims) { int32_t N = 1; if (transB) @@ -283,17 +283,17 @@ int32_t _computeNDimension(bool transB, const int32_t nbDims, const int32_t* dim return N; } -void LoraPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void LoraPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); - const int nbDimsA = in[0].max.nbDims; - const int nbDimsB = in[1].max.nbDims; + int const nbDimsA = in[0].max.nbDims; + int const nbDimsB = in[1].max.nbDims; - const auto minM = _computeMDimension(mTransA, nbDimsA, in[0].min.d); - const auto maxM = _computeMDimension(mTransA, nbDimsA, in[0].max.d); - const auto N = _computeNDimension(mTransB, nbDimsB, in[1].max.d); - const auto K = mTransA ? in[0].max.d[0] : in[0].max.d[nbDimsA - 1]; + auto const minM = _computeMDimension(mTransA, nbDimsA, in[0].min.d); + auto const maxM = _computeMDimension(mTransA, nbDimsA, in[0].max.d); + auto const N = _computeNDimension(mTransB, nbDimsB, in[1].max.d); + auto const K = mTransA ? 
in[0].max.d[0] : in[0].max.d[nbDimsA - 1]; if (!mDims.isInitialized()) { @@ -328,11 +328,11 @@ int64_t getGemmWorkSpaceSize( getSplitkGroupedGemmWorkSpaceSize(nbReq, maxContextLength, maxLoraModuleNum, maxLowRank, splitKSlices)); } -size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t LoraPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); - const int nbReq = inputs[getLoraRanksIdx()].dims.d[0]; + int const nbReq = inputs[getLoraRanksIdx()].dims.d[0]; auto const type = inputs[getInputTensorIdx()].type; auto const typeSize = tensorrt_llm::runtime::BufferDataType(type).getSize(); @@ -341,8 +341,8 @@ size_t LoraPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, in + getGroupedGemmParamsWorkSpaceSize(nbReq * mNumLoraModules); } -void runCublasGemmEx(const int M, const int N, const int K, const bool transA, const bool transB, const void* act, - const void* weight, void* output, cublasHandle_t cublas_handle) +void runCublasGemmEx(int const M, int const N, int const K, bool const transA, bool const transB, void const* act, + void const* weight, void* output, cublasHandle_t cublas_handle) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); float a = 1.0f; @@ -359,8 +359,8 @@ void runCublasGemmEx(const int M, const int N, const int K, const bool transA, c TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); // inputs @@ -390,7 +390,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf void* groupGemmParamsWorkSpace = static_cast(lowRankWorkSpace) + getLowRankWorkSpaceSize(batch_size, mMaxContextLength, mNumLoraModules, mMaxLowRank, typeSize); - const int nbDimsA = inputDesc[0].dims.nbDims; + int const nbDimsA = inputDesc[0].dims.nbDims; for (int loraModuleIdx = 0; loraModuleIdx < mNumLoraModules; loraModuleIdx++) { size_t size = 1; @@ -434,22 +434,22 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf : (mRemoveInputPadding ? host_context_lengths[batchIdx] : inputDesc[0].dims.d[1]); } - const auto lora_rank = lora_ranks[0]; + auto const lora_rank = lora_ranks[0]; auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId); - const auto N = lora_rank; + auto const N = lora_rank; if (N > 0) { TLLM_CHECK_WITH_INFO(N <= mMaxLowRank, fmtstr("Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank)); - const auto K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size - const auto N2 = outputDesc[loraModuleIdx].dims.d[nbDimsA - 1]; + auto const K = mTransA ? 
inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size + auto const N2 = outputDesc[loraModuleIdx].dims.d[nbDimsA - 1]; // [M, K] -> [M, N] -> [M, N2] void* lora_in_weight = reinterpret_cast(lora_weights_ptr[0]); void* lora_out_weight = reinterpret_cast(lora_weights_ptr[1]); - const void* input = inputs[0]; + void const* input = inputs[0]; void* output = outputs[loraModuleIdx]; _runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, input, lora_in_weight, lowRankWorkSpace, @@ -497,8 +497,8 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf int handled_token_num = 0; while (batchIdx < batch_size) { - const auto lora_rank = lora_ranks[batchIdx]; - const auto N = lora_rank; + auto const lora_rank = lora_ranks[batchIdx]; + auto const N = lora_rank; int count = 0; size_t M = 0; while (batchIdx + count < batch_size && lora_rank == lora_ranks[batchIdx + count] @@ -517,7 +517,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf TLLM_CHECK_WITH_INFO(N <= mMaxLowRank, fmtstr( "Invalid low_rank (%d). low_rank must be smaller than mMaxLowRank (%d)", N, mMaxLowRank)); - const auto K + auto const K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1]; // input hidden size cutlass::gemm::GemmCoord problem(M, N, K); @@ -535,7 +535,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf + handled_token_num * mMaxLowRank) * typeSize)); - const auto N2 = outputDesc[loraModuleIdx].dims.d[nbDimsA - 1]; + auto const N2 = outputDesc[loraModuleIdx].dims.d[nbDimsA - 1]; cutlass::gemm::GemmCoord problem_2(M, N2, N); problem_sizes_2.push_back(problem_2); ptrA_2.push_back(static_cast(static_cast(lowRankWorkSpace) @@ -571,7 +571,7 @@ int LoraPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf // IPluginV2Ext Methods nvinfer1::DataType LoraPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); TLLM_CHECK(index < mNumLoraModules); @@ -580,13 +580,13 @@ nvinfer1::DataType LoraPlugin::getOutputDataType( // IPluginV2 Methods -const char* LoraPlugin::getPluginType() const noexcept +char const* LoraPlugin::getPluginType() const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); return LORA_PLUGIN_NAME; } -const char* LoraPlugin::getPluginVersion() const noexcept +char const* LoraPlugin::getPluginVersion() const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); return LORA_PLUGIN_VERSION; @@ -657,29 +657,29 @@ LoraPluginCreator::LoraPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* LoraPluginCreator::getPluginName() const noexcept +char const* LoraPluginCreator::getPluginName() const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); return LORA_PLUGIN_NAME; } -const char* LoraPluginCreator::getPluginVersion() const noexcept +char const* LoraPluginCreator::getPluginVersion() const noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); return LORA_PLUGIN_VERSION; } -const PluginFieldCollection* LoraPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* LoraPluginCreator::getFieldNames() noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); return &mFC; } -IPluginV2* LoraPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* LoraPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) 
noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; nvinfer1::DataType type; int num_lora_modules; int in_hidden_size, transA, transB; @@ -689,59 +689,59 @@ IPluginV2* LoraPluginCreator::createPlugin(const char* name, const PluginFieldCo // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "in_hidden_size")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - in_hidden_size = static_cast(*(static_cast(fields[i].data))); + in_hidden_size = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "transa")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - transA = static_cast(*(static_cast(fields[i].data))); + transA = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "transb")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - transB = static_cast(*(static_cast(fields[i].data))); + transB = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "remove_input_padding")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - remove_input_padding = static_cast(*(static_cast(fields[i].data))); + remove_input_padding = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "max_context_length")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - max_context_length = static_cast(*(static_cast(fields[i].data))); + max_context_length = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "max_low_rank")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - max_low_rank = static_cast(*(static_cast(fields[i].data))); + max_low_rank = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "num_lora_modules")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - num_lora_modules = static_cast(*(static_cast(fields[i].data))); + num_lora_modules = static_cast(*(static_cast(fields[i].data))); } } std::vector out_hidden_sizes; out_hidden_sizes.resize(num_lora_modules); for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; for (int j = 0; j < num_lora_modules; j++) { if (!strcmp(attrName, fmtstr("out_hidden_size_%d", j).c_str())) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - out_hidden_sizes.at(j) = static_cast(*(static_cast(fields[i].data))); + out_hidden_sizes.at(j) = static_cast(*(static_cast(fields[i].data))); } } } @@ -756,14 +756,14 @@ IPluginV2* LoraPluginCreator::createPlugin(const char* name, const PluginFieldCo obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } return nullptr; } -IPluginV2* LoraPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept +IPluginV2* LoraPluginCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept { TLLM_LOG_DEBUG("%s", __PRETTY_FUNCTION__); // This object will be deleted when the network is destroyed, which will @@ -778,7 +778,7 @@ IPluginV2* LoraPluginCreator::deserializePlugin(const char* name, const void* se 
obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h index 0f80047a3..a20c4a21a 100644 --- a/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h +++ b/cpp/tensorrt_llm/plugins/loraPlugin/loraPlugin.h @@ -39,33 +39,33 @@ class LoraPlugin : public BasePlugin LoraPlugin() = delete; LoraPlugin(int in_hidden_size, std::vector out_hidden_sizes, int transA, int transB, int num_lora_modules, - nvinfer1::DataType type, const PluginProfilerPtr& profiler, bool remove_input_padding, int max_context_length, + nvinfer1::DataType type, PluginProfilerPtr const& profiler, bool remove_input_padding, int max_context_length, int max_low_rank); - LoraPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler); + LoraPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler); ~LoraPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -124,7 +124,7 @@ class LoraPlugin : public BasePlugin int mMaxContextLength; int mMaxLowRank; int mNumLoraModules; - const int mSplitKSlices = 16; + int const mSplitKSlices = 16; // @fixme: seems this is shared across 
multiple clones. // If we deep copy the wrapper inside clone(), then we may avoid the mutex inside the wrapper? @@ -141,16 +141,16 @@ class LoraPluginCreator : public BaseCreator public: LoraPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: GemmPluginProfilerManager gemmPluginProfileManager; diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index e24d2ebd1..576ebce1f 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -31,8 +31,8 @@ using tensorrt_llm::plugins::MixtureOfExpertsPlugin; using tensorrt_llm::plugins::read; using tensorrt_llm::plugins::write; -static const char* MIXTURE_OF_EXPERTS_PLUGIN_VERSION{"1"}; -static const char* MIXTURE_OF_EXPERTS_PLUGIN_NAME{"MixtureOfExperts"}; +static char const* MIXTURE_OF_EXPERTS_PLUGIN_VERSION{"1"}; +static char const* MIXTURE_OF_EXPERTS_PLUGIN_NAME{"MixtureOfExperts"}; nvinfer1::PluginFieldCollection MixtureOfExpertsPluginCreator::mFC{}; std::vector MixtureOfExpertsPluginCreator::mPluginAttributes; @@ -60,7 +60,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(int number_of_experts, int top_k, init(); } -tensorrt_llm::plugins::MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(const MixtureOfExpertsPlugin& other) +tensorrt_llm::plugins::MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(MixtureOfExpertsPlugin const& other) : mMOERunner() , mNumExperts(other.mNumExperts) , mK(other.mK) @@ -94,11 +94,11 @@ size_t MixtureOfExpertsPlugin::getSerializationSize() const noexcept } MixtureOfExpertsPlugin::MixtureOfExpertsPlugin( - const void* data, size_t length, MixtureOfExpertsPluginProfilerPtr plugin_profiler_ptr) + void const* data, size_t length, MixtureOfExpertsPluginProfilerPtr plugin_profiler_ptr) : mPluginProfiler(plugin_profiler_ptr) { - const char* d = reinterpret_cast(data); - const char* a = d; + char const* d = reinterpret_cast(data); + char const* a = d; read(d, mNumExperts); read(d, mK); read(d, mExpertHiddenSize); @@ -208,14 +208,14 @@ nvinfer1::IPluginV2DynamicExt* MixtureOfExpertsPlugin::clone() const noexcept } nvinfer1::DimsExprs MixtureOfExpertsPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { assert(outputIndex == getOutputTensorIndex()); return inputs[getInputTensorIndex()]; } bool MixtureOfExpertsPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, 
nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { TLLM_CHECK(0 <= pos && pos < getNbInputs() + getNbOutputs()); TLLM_CHECK(nbInputs == getNbInputs()); @@ -246,23 +246,23 @@ bool MixtureOfExpertsPlugin::supportsFormatCombination( return false; } -void MixtureOfExpertsPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void MixtureOfExpertsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { auto in_tensor = in[getInputTensorIndex()]; - const auto minM + auto const minM = std::accumulate(in_tensor.min.d, in_tensor.min.d + in_tensor.min.nbDims - 1, 1, std::multiplies()); - const auto maxM + auto const maxM = std::accumulate(in_tensor.max.d, in_tensor.max.d + in_tensor.max.nbDims - 1, 1, std::multiplies()); auto weights_1 = in[getExpertWeights1Index()]; auto weights_2 = in[getExpertWeights2Index()]; int inner_dim_idx = getGemmShapeInnerDimIndex(); - const int maxK = weights_1.max.d[inner_dim_idx]; - const int maxN = weights_2.max.d[inner_dim_idx]; - const int minK = weights_1.min.d[inner_dim_idx]; - const int minN = weights_2.min.d[inner_dim_idx]; + int const maxK = weights_1.max.d[inner_dim_idx]; + int const maxN = weights_2.max.d[inner_dim_idx]; + int const minK = weights_1.min.d[inner_dim_idx]; + int const minN = weights_2.min.d[inner_dim_idx]; TLLM_CHECK_WITH_INFO(minN == maxN, "Variable out channels is not allowed"); TLLM_CHECK_WITH_INFO(minK == maxK, "Variable in channels is not allowed"); @@ -320,7 +320,7 @@ auto MixtureOfExpertsPlugin::setupWorkspace(void* base_ptr, int num_tokens) cons return info; } -int MixtureOfExpertsPlugin::getNumTokens(const nvinfer1::PluginTensorDesc* input_tensors) const +int MixtureOfExpertsPlugin::getNumTokens(nvinfer1::PluginTensorDesc const* input_tensors) const { int ndim = input_tensors[getInputTensorIndex()].dims.nbDims; TLLM_CHECK_WITH_INFO( @@ -333,10 +333,10 @@ int MixtureOfExpertsPlugin::getNumTokens(const nvinfer1::PluginTensorDesc* input return num_tokens; } -size_t MixtureOfExpertsPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t MixtureOfExpertsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { - const int num_tokens = getNumTokens(inputs); + int const num_tokens = getNumTokens(inputs); return setupWorkspace(nullptr, num_tokens).size; } @@ -354,12 +354,12 @@ MOEParallelismConfig MixtureOfExpertsPlugin::getParallelismConfig() const return {}; } -int MixtureOfExpertsPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace_ptr, +int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace_ptr, cudaStream_t stream) noexcept { - const int num_tokens = getNumTokens(inputDesc); - const int num_not_finished = num_tokens; // TODO Take this as an input + int const num_tokens = getNumTokens(inputDesc); + int const num_not_finished = num_tokens; // TODO Take this as an input auto parallelism_config = getParallelismConfig(); auto workspace = 
setupWorkspace(workspace_ptr, num_tokens); @@ -389,7 +389,7 @@ int MixtureOfExpertsPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, TLLM_CHECK(w2_desc.dims.d[outer_dim_idx] * packed_elements == mExpertHiddenSize); mMOERunner->setTactic(mPluginProfiler->getBestConfig(num_tokens, mGemmId)); - mMOERunner->runMoe(inputs[getInputTensorIndex()], static_cast(inputs[getRoutingTensorIndex()]), + mMOERunner->runMoe(inputs[getInputTensorIndex()], static_cast(inputs[getRoutingTensorIndex()]), inputs[getExpertWeights1Index()], hasExpertQuantScales() ? inputs[getExpertQuantScale1Index()] : nullptr, hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType, inputs[getExpertWeights2Index()], hasExpertQuantScales() ? inputs[getExpertQuantScale2Index()] : nullptr, @@ -397,7 +397,7 @@ int MixtureOfExpertsPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, mNumExperts, mK, static_cast(workspace.workspace), // Outputs outputs[getOutputTensorIndex()], workspace.fc2_output, - hasFinishedTensor() ? static_cast(inputs[getFinishedTensorIndex()]) : nullptr, num_not_finished, + hasFinishedTensor() ? static_cast(inputs[getFinishedTensorIndex()]) : nullptr, num_not_finished, workspace.scale_probs, static_cast(workspace.src_to_dest_map), static_cast(workspace.selected_experts), parallelism_config, mNormalizationMode, stream); @@ -406,7 +406,7 @@ int MixtureOfExpertsPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType MixtureOfExpertsPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == getOutputTensorIndex()); assert(inputTypes[getInputTensorIndex()] == mType); @@ -414,12 +414,12 @@ nvinfer1::DataType MixtureOfExpertsPlugin::getOutputDataType( } // IPluginV2 Methods -const char* MixtureOfExpertsPlugin::getPluginType() const noexcept +char const* MixtureOfExpertsPlugin::getPluginType() const noexcept { return MIXTURE_OF_EXPERTS_PLUGIN_NAME; } -const char* MixtureOfExpertsPlugin::getPluginVersion() const noexcept +char const* MixtureOfExpertsPlugin::getPluginVersion() const noexcept { return MIXTURE_OF_EXPERTS_PLUGIN_VERSION; } @@ -438,29 +438,29 @@ void MixtureOfExpertsPlugin::destroy() noexcept delete this; } -void MixtureOfExpertsPlugin::setPluginNamespace(const char* libNamespace) noexcept +void MixtureOfExpertsPlugin::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } -const char* MixtureOfExpertsPlugin::getPluginNamespace() const noexcept +char const* MixtureOfExpertsPlugin::getPluginNamespace() const noexcept { return mNamespace.c_str(); } /////////////// -const char* MixtureOfExpertsPluginCreator::getPluginName() const noexcept +char const* MixtureOfExpertsPluginCreator::getPluginName() const noexcept { return MIXTURE_OF_EXPERTS_PLUGIN_NAME; } -const char* MixtureOfExpertsPluginCreator::getPluginVersion() const noexcept +char const* MixtureOfExpertsPluginCreator::getPluginVersion() const noexcept { return MIXTURE_OF_EXPERTS_PLUGIN_VERSION; } -const nvinfer1::PluginFieldCollection* MixtureOfExpertsPluginCreator::getFieldNames() noexcept +nvinfer1::PluginFieldCollection const* MixtureOfExpertsPluginCreator::getFieldNames() noexcept { return &mFC; } @@ -495,9 +495,9 @@ MixtureOfExpertsPluginCreator::MixtureOfExpertsPluginCreator() } IPluginV2* MixtureOfExpertsPluginCreator::createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) 
noexcept + char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept { - const nvinfer1::PluginField* fields = fc->fields; + nvinfer1::PluginField const* fields = fc->fields; int mNumExperts{}; int mK{}; int mExpertHiddenSize{}; @@ -514,7 +514,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin( int mNormalizationMode{}; // Read configurations from each fields - using MapPair = std::pair>; + using MapPair = std::pair>; const std::array input_map{ MapPair{"number_of_experts", std::ref(mNumExperts)}, MapPair{"top_k", std::ref(mK)}, @@ -533,13 +533,13 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin( }; for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; - for (const auto& item : input_map) + char const* attrName = fields[i].name; + for (auto const& item : input_map) { if (!strcmp(item.first, attrName)) { TLLM_CHECK(fields[i].type == nvinfer1::PluginFieldType::kINT32); - item.second.get() = static_cast(*(static_cast(fields[i].data))); + item.second.get() = static_cast(*(static_cast(fields[i].data))); } } } @@ -556,7 +556,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -564,7 +564,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin( } IPluginV2* MixtureOfExpertsPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call MixtureOfExpertsPlugin::destroy() @@ -577,26 +577,26 @@ IPluginV2* MixtureOfExpertsPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } return nullptr; } -void MixtureOfExpertsPluginCreator::setPluginNamespace(const char* libNamespace) noexcept +void MixtureOfExpertsPluginCreator::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } -const char* MixtureOfExpertsPluginCreator::getPluginNamespace() const noexcept +char const* MixtureOfExpertsPluginCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } std::vector MixtureOfExpertsGemmProfiler::getProfilerWorkspaces(int maxM) { - const auto& plugin = *mRunner; + auto const& plugin = *mRunner; size_t num_tokens = maxM; @@ -634,13 +634,13 @@ void MixtureOfExpertsGemmProfiler::computeTmpSize(int maxM, int n, int k) this->setTmpWorkspaceSizeInBytes(bytes); } -void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, const MixtureOfExpertsGemmProfiler::Config& tactic, +void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, MixtureOfExpertsGemmProfiler::Config const& tactic, char* workspace_ptr_char, cudaStream_t const& stream) { assert(mRunner); auto& plugin = *mRunner; auto parallelism_config = plugin.getParallelismConfig(); - const int num_tokens = m; + int const num_tokens = m; int8_t* workspace_ptr = reinterpret_cast(workspace_ptr_char); auto workspaces = getProfilerWorkspaces(m); @@ -659,17 +659,17 @@ void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, const MixtureO // Routing goes first as we need to manually initialise it in initTmpData, everything else can be uninit // If we didn't init routing all the values could go to one expert, causing the profile to be unreliable (e.g. 
for // expert parallelism) - const float* routing = static_cast(getNext()); - - const void* input = getNext(); - const void* weights_1 = getNext(); - const void* scale_1 = getNext(); - const void* bias_1 = getNext(); - const void* weights_2 = getNext(); - const void* scale_2 = getNext(); - const void* bias_2 = getNext(); + float const* routing = static_cast(getNext()); + + void const* input = getNext(); + void const* weights_1 = getNext(); + void const* scale_1 = getNext(); + void const* bias_1 = getNext(); + void const* weights_2 = getNext(); + void const* scale_2 = getNext(); + void const* bias_2 = getNext(); void* output = getNext(); - const bool* finished = nullptr; // No finished, we want to benchmark all tokens + bool const* finished = nullptr; // No finished, we want to benchmark all tokens auto workspace = plugin.setupWorkspace(getNext(), num_tokens); diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h index 26cdde393..4da5236c3 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h @@ -43,14 +43,14 @@ struct GemmIDMoe tensorrt_llm::common::QuantMode quant_mode; tensorrt_llm::kernels::MOEParallelismMode parallelism_mode{}; - bool operator==(const GemmIDMoe& id) const + bool operator==(GemmIDMoe const& id) const { return id.num_experts == num_experts && id.moe_k == moe_k && id.hidden == hidden && id.inter == inter && id.actfn == actfn && id.dtype == dtype && id.wdtype == wdtype && id.quant_mode == quant_mode && id.parallelism_mode == parallelism_mode; } - friend std::ostream& operator<<(std::ostream& out, const GemmIDMoe& id) + friend std::ostream& operator<<(std::ostream& out, GemmIDMoe const& id) { out << "experts, k, hidden, inter, actfn, dtype, weight type, parallelism mode=" << id.num_experts << "," << id.moe_k << "," << id.hidden << "," << id.inter << "," << static_cast(id.actfn) << "," @@ -63,7 +63,7 @@ struct GemmIDMoe // Hash of GemmIDMoe struct GemmIDMoeHash { - std::size_t operator()(const GemmIDMoe& id) const + std::size_t operator()(GemmIDMoe const& id) const { size_t hash = std::hash{}(id.num_experts); hash ^= std::hash{}(id.moe_k); @@ -90,8 +90,8 @@ class MixtureOfExpertsPlugin : public nvinfer1::IPluginV2DynamicExt tensorrt_llm::common::QuantMode quant_mode, bool use_finished, bool use_bias, int tp_size, int tp_rank, MOEParallelismMode parallelism_mode, MOEExpertScaleNormalizationMode normalization_mode, MixtureOfExpertsPluginProfilerPtr plugin_profiler_ptr); - MixtureOfExpertsPlugin(const void* data, size_t length, MixtureOfExpertsPluginProfilerPtr plugin_profiler_ptr); - MixtureOfExpertsPlugin(const MixtureOfExpertsPlugin&); + MixtureOfExpertsPlugin(void const* data, size_t length, MixtureOfExpertsPluginProfilerPtr plugin_profiler_ptr); + MixtureOfExpertsPlugin(MixtureOfExpertsPlugin const&); void init(); @@ -99,24 +99,24 @@ class MixtureOfExpertsPlugin : public nvinfer1::IPluginV2DynamicExt // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept 
override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override { @@ -128,8 +128,8 @@ class MixtureOfExpertsPlugin : public nvinfer1::IPluginV2DynamicExt size_t getSerializationSize() const noexcept override; void serialize(void* buffer) const noexcept override; void destroy() noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; - const char* getPluginNamespace() const noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; private: friend class MixtureOfExpertsGemmProfiler; @@ -169,7 +169,7 @@ class MixtureOfExpertsPlugin : public nvinfer1::IPluginV2DynamicExt size_t size{}; }; - int getNumTokens(const nvinfer1::PluginTensorDesc* input_tensor) const; + int getNumTokens(nvinfer1::PluginTensorDesc const* input_tensor) const; WorkspaceInfo setupWorkspace(void* base_ptr, int num_tokens) const; kernels::MOEParallelismConfig getParallelismConfig() const; @@ -288,7 +288,7 @@ class MixtureOfExpertsGemmProfiler protected: using Config = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - void runTactic(int m, int n, int k, const Config& tactic, char* workspace, const cudaStream_t& stream) override; + void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; void computeTmpSize(int maxM, int n, int k) override; std::vector getTactics(int m, int n, int k) const override; void initTmpData(int maxM, int n, int k, char* workspace, size_t size, cudaStream_t stream) override; @@ -301,20 +301,20 @@ class MixtureOfExpertsPluginCreator : public nvinfer1::IPluginCreator public: MixtureOfExpertsPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char 
const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; - const char* getPluginNamespace() const noexcept override; + char const* getPluginNamespace() const noexcept override; private: GemmPluginProfilerManager moePluginProfiler; diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.cpp index dccc36925..55a9fd16c 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.cpp @@ -22,8 +22,8 @@ using namespace nvinfer1; using tensorrt_llm::plugins::AllgatherPluginCreator; using tensorrt_llm::plugins::AllgatherPlugin; -static const char* ALLGATHER_PLUGIN_VERSION{"1"}; -static const char* ALLGATHER_PLUGIN_NAME{"AllGather"}; +static char const* ALLGATHER_PLUGIN_VERSION{"1"}; +static char const* ALLGATHER_PLUGIN_NAME{"AllGather"}; PluginFieldCollection AllgatherPluginCreator::mFC{}; std::vector AllgatherPluginCreator::mPluginAttributes; @@ -34,9 +34,9 @@ AllgatherPlugin::AllgatherPlugin(std::set group, nvinfer1::DataType type) } // Parameterized constructor -AllgatherPlugin::AllgatherPlugin(const void* data, size_t length) +AllgatherPlugin::AllgatherPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mType); mGroup.clear(); int groupItem = 0; @@ -61,7 +61,7 @@ nvinfer1::IPluginV2DynamicExt* AllgatherPlugin::clone() const noexcept } nvinfer1::DimsExprs AllgatherPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { auto ret = inputs[0]; auto groupSize = exprBuilder.constant(mGroup.size()); @@ -70,25 +70,25 @@ nvinfer1::DimsExprs AllgatherPlugin::getOutputDimensions( } bool AllgatherPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR); } -void AllgatherPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void AllgatherPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t AllgatherPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t 
AllgatherPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int AllgatherPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int AllgatherPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (isBuilding()) { @@ -108,7 +108,7 @@ int AllgatherPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const // IPluginV2Ext Methods nvinfer1::DataType AllgatherPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -116,12 +116,12 @@ nvinfer1::DataType AllgatherPlugin::getOutputDataType( // IPluginV2 Methods -const char* AllgatherPlugin::getPluginType() const noexcept +char const* AllgatherPlugin::getPluginType() const noexcept { return ALLGATHER_PLUGIN_NAME; } -const char* AllgatherPlugin::getPluginVersion() const noexcept +char const* AllgatherPlugin::getPluginVersion() const noexcept { return ALLGATHER_PLUGIN_VERSION; } @@ -183,34 +183,34 @@ AllgatherPluginCreator::AllgatherPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* AllgatherPluginCreator::getPluginName() const noexcept +char const* AllgatherPluginCreator::getPluginName() const noexcept { return ALLGATHER_PLUGIN_NAME; } -const char* AllgatherPluginCreator::getPluginVersion() const noexcept +char const* AllgatherPluginCreator::getPluginVersion() const noexcept { return ALLGATHER_PLUGIN_VERSION; } -const PluginFieldCollection* AllgatherPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* AllgatherPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* AllgatherPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* AllgatherPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; std::set group; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "group")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - const auto* r = static_cast(fields[i].data); + auto const* r = static_cast(fields[i].data); for (int j = 0; j < fields[i].length; ++j) { group.insert(*r); @@ -220,7 +220,7 @@ IPluginV2* AllgatherPluginCreator::createPlugin(const char* name, const PluginFi else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } @@ -230,7 +230,7 @@ IPluginV2* AllgatherPluginCreator::createPlugin(const char* name, const PluginFi obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -238,7 +238,7 @@ IPluginV2* AllgatherPluginCreator::createPlugin(const char* name, const PluginFi } IPluginV2* AllgatherPluginCreator::deserializePlugin( - const 
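In AllgatherPluginCreator::createPlugin above, the static_cast template arguments have been lost in this rendering of the patch. The underlying pattern is a loop over the PluginFieldCollection that pulls each named attribute out of a type-erased data pointer; a sketch with plausible casts restored (the exact target types are an assumption inferred from the kINT32 field checks shown, and the TLLM_CHECK calls are omitted):

    #include <cstring>
    #include <set>
    #include <NvInfer.h> // nvinfer1::PluginField, PluginFieldCollection, DataType

    // Illustrative only: parse the "group" and "type_id" attributes the way the
    // creator above does, reading through fields[i].data with restored casts.
    inline void parseAllgatherFields(
        nvinfer1::PluginFieldCollection const* fc, std::set<int>& group, nvinfer1::DataType& type)
    {
        nvinfer1::PluginField const* fields = fc->fields;
        for (int i = 0; i < fc->nbFields; ++i)
        {
            char const* attrName = fields[i].name;
            if (!std::strcmp(attrName, "group"))
            {
                // "group" is a list of INT32 ranks; length is the element count.
                auto const* r = static_cast<int const*>(fields[i].data);
                for (int j = 0; j < fields[i].length; ++j)
                {
                    group.insert(*r);
                    ++r;
                }
            }
            else if (!std::strcmp(attrName, "type_id"))
            {
                // "type_id" carries the nvinfer1::DataType as a single INT32.
                type = static_cast<nvinfer1::DataType>(*static_cast<int const*>(fields[i].data));
            }
        }
    }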
char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call AllgatherPlugin::destroy() @@ -248,7 +248,7 @@ IPluginV2* AllgatherPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h index 6cbb93992..ac8e723f7 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h @@ -30,30 +30,30 @@ class AllgatherPlugin : public BasePlugin public: AllgatherPlugin(std::set group, nvinfer1::DataType type); - AllgatherPlugin(const void* data, size_t length); + AllgatherPlugin(void const* data, size_t length); ~AllgatherPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -72,16 +72,16 @@ class AllgatherPluginCreator : public BaseCreator public: AllgatherPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* 
getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp index f32579196..9d6764c22 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp @@ -26,8 +26,8 @@ using tensorrt_llm::plugins::AllreducePluginCreator; using tensorrt_llm::plugins::AllreducePlugin; using tensorrt_llm::kernels::AllReduceStrategyType; -static const char* ALLREDUCE_PLUGIN_VERSION{"1"}; -static const char* ALLREDUCE_PLUGIN_NAME{"AllReduce"}; +static char const* ALLREDUCE_PLUGIN_VERSION{"1"}; +static char const* ALLREDUCE_PLUGIN_NAME{"AllReduce"}; PluginFieldCollection AllreducePluginCreator::mFC{}; std::vector AllreducePluginCreator::mPluginAttributes; @@ -41,9 +41,9 @@ AllreducePlugin::AllreducePlugin( } // Parameterized constructor -AllreducePlugin::AllreducePlugin(const void* data, size_t length) +AllreducePlugin::AllreducePlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mType); read(d, mStrategy); read(d, mCounter); @@ -70,13 +70,13 @@ nvinfer1::IPluginV2DynamicExt* AllreducePlugin::clone() const noexcept } nvinfer1::DimsExprs AllreducePlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { return inputs[0]; } bool AllreducePlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { if (mStrategy == AllReduceStrategyType::RING) { @@ -97,20 +97,20 @@ bool AllreducePlugin::supportsFormatCombination( } } -void AllreducePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void AllreducePlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t AllreducePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t AllreducePlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } AllReduceStrategyType AllreducePlugin::selectImplementation(size_t messageSize, int worldSize) noexcept { - const auto maxWorkspaceSize = 
utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(worldSize); + auto const maxWorkspaceSize = utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(worldSize); if (messageSize > maxWorkspaceSize) { @@ -138,8 +138,8 @@ AllReduceStrategyType AllreducePlugin::selectImplementation(size_t messageSize, return AllReduceStrategyType::TWOSHOT; } -int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (isBuilding()) { @@ -191,7 +191,7 @@ int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const myRank = myRank % nRanks; auto params = tensorrt_llm::kernels::AllReduceParams::deserialize( - reinterpret_cast(inputs[1]), nRanks, myRank, mCounter); + reinterpret_cast(inputs[1]), nRanks, myRank, mCounter); cudaMemcpyAsync( params.peer_comm_buffer_ptrs[myRank], inputs[0], size * sizePerElem, cudaMemcpyDeviceToDevice, stream); @@ -204,7 +204,7 @@ int AllreducePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const // IPluginV2Ext Methods nvinfer1::DataType AllreducePlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -212,12 +212,12 @@ nvinfer1::DataType AllreducePlugin::getOutputDataType( // IPluginV2 Methods -const char* AllreducePlugin::getPluginType() const noexcept +char const* AllreducePlugin::getPluginType() const noexcept { return ALLREDUCE_PLUGIN_NAME; } -const char* AllreducePlugin::getPluginVersion() const noexcept +char const* AllreducePlugin::getPluginVersion() const noexcept { return ALLREDUCE_PLUGIN_VERSION; } @@ -304,24 +304,24 @@ AllreducePluginCreator::AllreducePluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* AllreducePluginCreator::getPluginName() const noexcept +char const* AllreducePluginCreator::getPluginName() const noexcept { return ALLREDUCE_PLUGIN_NAME; } -const char* AllreducePluginCreator::getPluginVersion() const noexcept +char const* AllreducePluginCreator::getPluginVersion() const noexcept { return ALLREDUCE_PLUGIN_VERSION; } -const PluginFieldCollection* AllreducePluginCreator::getFieldNames() noexcept +PluginFieldCollection const* AllreducePluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* AllreducePluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; std::set group; nvinfer1::DataType type; AllReduceStrategyType strategy; @@ -329,11 +329,11 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "group")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - const auto* r = static_cast(fields[i].data); + auto const* r = static_cast(fields[i].data); for (int j = 0; j < fields[i].length; ++j) { 
group.insert(*r); @@ -343,17 +343,17 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "strategy")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - strategy = static_cast(*static_cast(fields[i].data)); + strategy = static_cast(*static_cast(fields[i].data)); } else if (!strcmp(attrName, "counter")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - counter = *static_cast(fields[i].data); + counter = *static_cast(fields[i].data); } } @@ -363,7 +363,7 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -371,7 +371,7 @@ IPluginV2* AllreducePluginCreator::createPlugin(const char* name, const PluginFi } IPluginV2* AllreducePluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call AllreducePlugin::destroy() @@ -381,7 +381,7 @@ IPluginV2* AllreducePluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h index b02e63c14..90c51b015 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.h @@ -34,30 +34,30 @@ class AllreducePlugin : public BasePlugin AllreducePlugin( std::set group, nvinfer1::DataType type, kernels::AllReduceStrategyType strategy, int32_t counter); - AllreducePlugin(const void* data, size_t length); + AllreducePlugin(void const* data, size_t length); ~AllreducePlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t 
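AllreducePlugin::selectImplementation above chooses a kernel strategy from the message size and world size: messages larger than the custom all-reduce workspace fall back to the ring/NCCL path, smaller ones use the fused custom kernels. The actual thresholds are elided in these hunks, so the cut-over values below are placeholders only; the sketch illustrates the shape of the decision, not the real numbers:

    #include <cstddef>

    enum class AllReduceStrategy
    {
        RING,    // NCCL ring fallback for large messages
        ONESHOT, // single-kernel custom all-reduce (assumed small-message path)
        TWOSHOT  // two-phase custom all-reduce (assumed medium-message path)
    };

    // Placeholder for utils::customAllReduceUtils::getMaxRequiredWorkspaceSize().
    inline size_t maxCustomWorkspaceBytes(int worldSize)
    {
        return static_cast<size_t>(worldSize) * (1u << 20); // assumed 1 MiB per rank
    }

    inline AllReduceStrategy selectStrategy(size_t messageBytes, int worldSize)
    {
        if (messageBytes > maxCustomWorkspaceBytes(worldSize))
        {
            return AllReduceStrategy::RING; // too large for the custom kernels' workspace
        }
        // Placeholder cut-over point between the two custom kernels.
        size_t const oneShotLimit = 64u * 1024u * worldSize;
        return messageBytes <= oneShotLimit ? AllReduceStrategy::ONESHOT : AllReduceStrategy::TWOSHOT;
    }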
getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -81,16 +81,16 @@ class AllreducePluginCreator : public BaseCreator public: AllreducePluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp index 7a092e35e..54ed7be5a 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.cpp @@ -24,8 +24,8 @@ using namespace nvinfer1; using tensorrt_llm::plugins::RecvPluginCreator; using tensorrt_llm::plugins::RecvPlugin; -static const char* RECV_PLUGIN_VERSION{"1"}; -static const char* RECV_PLUGIN_NAME{"Recv"}; +static char const* RECV_PLUGIN_VERSION{"1"}; +static char const* RECV_PLUGIN_NAME{"Recv"}; PluginFieldCollection RecvPluginCreator::mFC{}; std::vector RecvPluginCreator::mPluginAttributes; @@ -36,9 +36,9 @@ RecvPlugin::RecvPlugin(int srcRank, nvinfer1::DataType type) } // Parameterized constructor -RecvPlugin::RecvPlugin(const void* data, size_t length) +RecvPlugin::RecvPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mType); read(d, mSrcRank); TLLM_CHECK_WITH_INFO(d == a + length, @@ -57,30 +57,30 @@ nvinfer1::IPluginV2DynamicExt* RecvPlugin::clone() const noexcept } nvinfer1::DimsExprs RecvPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { return inputs[0]; } bool RecvPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept 
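The parameterized constructor RecvPlugin(void const* data, size_t length) above, like its Send/Allgather/Allreduce siblings earlier in the patch, deserializes its state by walking a raw byte buffer and then asserting that exactly length bytes were consumed. The repository's read() helper is not shown in these hunks, so the sketch below uses a local stand-in with the behaviour it appears to have (copy sizeof(T) bytes and advance the cursor):

    #include <cassert>
    #include <cstddef>
    #include <cstring>

    // Assumed behaviour of the repository's read(d, value) helper.
    template <typename T>
    void readValue(char const*& cursor, T& value)
    {
        std::memcpy(&value, cursor, sizeof(T));
        cursor += sizeof(T);
    }

    struct RecvState
    {
        int dataType; // stands in for nvinfer1::DataType mType
        int srcRank;  // mSrcRank
    };

    inline RecvState deserializeRecv(void const* data, size_t length)
    {
        char const* d = static_cast<char const*>(data);
        char const* a = d; // remember the start to verify the consumed size

        RecvState state{};
        readValue(d, state.dataType);
        readValue(d, state.srcRank);

        // Mirrors TLLM_CHECK_WITH_INFO(d == a + length, ...): a mismatch usually means
        // the engine was serialized by a different TensorRT-LLM version.
        assert(d == a + length);
        return state;
    }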
+ int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR); } -void RecvPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void RecvPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t RecvPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t RecvPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int RecvPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int RecvPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (isBuilding()) { @@ -98,7 +98,7 @@ int RecvPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf // IPluginV2Ext Methods nvinfer1::DataType RecvPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -106,12 +106,12 @@ nvinfer1::DataType RecvPlugin::getOutputDataType( // IPluginV2 Methods -const char* RecvPlugin::getPluginType() const noexcept +char const* RecvPlugin::getPluginType() const noexcept { return RECV_PLUGIN_NAME; } -const char* RecvPlugin::getPluginVersion() const noexcept +char const* RecvPlugin::getPluginVersion() const noexcept { return RECV_PLUGIN_VERSION; } @@ -174,39 +174,39 @@ RecvPluginCreator::RecvPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* RecvPluginCreator::getPluginName() const noexcept +char const* RecvPluginCreator::getPluginName() const noexcept { return RECV_PLUGIN_NAME; } -const char* RecvPluginCreator::getPluginVersion() const noexcept +char const* RecvPluginCreator::getPluginVersion() const noexcept { return RECV_PLUGIN_VERSION; } -const PluginFieldCollection* RecvPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* RecvPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* RecvPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* RecvPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; int srcRank; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "src_rank")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - srcRank = static_cast(*(static_cast(fields[i].data))); + srcRank = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = 
static_cast(*(static_cast(fields[i].data))); } } @@ -216,14 +216,14 @@ IPluginV2* RecvPluginCreator::createPlugin(const char* name, const PluginFieldCo obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } return nullptr; } -IPluginV2* RecvPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept +IPluginV2* RecvPluginCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call RecvPlugin::destroy() @@ -233,7 +233,7 @@ IPluginV2* RecvPluginCreator::deserializePlugin(const char* name, const void* se obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.h index 4eefc61c8..5c8eedfb5 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/recvPlugin.h @@ -29,30 +29,30 @@ class RecvPlugin : public BasePlugin public: RecvPlugin(int srcRank, nvinfer1::DataType type); - RecvPlugin(const void* data, size_t length); + RecvPlugin(void const* data, size_t length); ~RecvPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* 
getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -71,16 +71,16 @@ class RecvPluginCreator : public BaseCreator public: RecvPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.cpp index 1910c526c..d3537f6bb 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.cpp @@ -23,8 +23,8 @@ using namespace nvinfer1; using tensorrt_llm::plugins::ReduceScatterPluginCreator; using tensorrt_llm::plugins::ReduceScatterPlugin; -static const char* REDUCE_SCATTER_PLUGIN_VERSION{"1"}; -static const char* REDUCE_SCATTER_PLUGIN_NAME{"ReduceScatter"}; +static char const* REDUCE_SCATTER_PLUGIN_VERSION{"1"}; +static char const* REDUCE_SCATTER_PLUGIN_NAME{"ReduceScatter"}; PluginFieldCollection ReduceScatterPluginCreator::mFC{}; std::vector ReduceScatterPluginCreator::mPluginAttributes; @@ -35,9 +35,9 @@ ReduceScatterPlugin::ReduceScatterPlugin(std::set group, nvinfer1::DataType } // Parameterized constructor -ReduceScatterPlugin::ReduceScatterPlugin(const void* data, size_t length) +ReduceScatterPlugin::ReduceScatterPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mType); mGroup.clear(); int groupItem = 0; @@ -62,7 +62,7 @@ nvinfer1::IPluginV2DynamicExt* ReduceScatterPlugin::clone() const noexcept } nvinfer1::DimsExprs ReduceScatterPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { auto output = inputs[0]; output.d[0] @@ -71,24 +71,24 @@ nvinfer1::DimsExprs ReduceScatterPlugin::getOutputDimensions( } bool ReduceScatterPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR); } -void ReduceScatterPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void ReduceScatterPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t 
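The getOutputDimensions overrides for AllGather (earlier in the patch) and ReduceScatter (above) encode the shape algebra of the two collectives: AllGather scales the leading dimension up by the group size, ReduceScatter divides it by the group size. The IExprBuilder calls themselves are elided in these hunks, so treat the exact operations below as an assumption about how that is typically expressed:

    #include <NvInfer.h>

    using namespace nvinfer1;

    // AllGather: output[0] = input[0] * groupSize, other dims unchanged.
    inline DimsExprs allgatherOutputDims(DimsExprs const& input, int groupSize, IExprBuilder& exprBuilder)
    {
        DimsExprs out = input;
        out.d[0] = exprBuilder.operation(
            DimensionOperation::kPROD, *input.d[0], *exprBuilder.constant(groupSize));
        return out;
    }

    // ReduceScatter: output[0] = input[0] / groupSize (the leading dim must be divisible).
    inline DimsExprs reduceScatterOutputDims(DimsExprs const& input, int groupSize, IExprBuilder& exprBuilder)
    {
        DimsExprs out = input;
        out.d[0] = exprBuilder.operation(
            DimensionOperation::kFLOOR_DIV, *input.d[0], *exprBuilder.constant(groupSize));
        return out;
    }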
ReduceScatterPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t ReduceScatterPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int ReduceScatterPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int ReduceScatterPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (isBuilding()) @@ -109,7 +109,7 @@ int ReduceScatterPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType ReduceScatterPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -117,12 +117,12 @@ nvinfer1::DataType ReduceScatterPlugin::getOutputDataType( // IPluginV2 Methods -const char* ReduceScatterPlugin::getPluginType() const noexcept +char const* ReduceScatterPlugin::getPluginType() const noexcept { return REDUCE_SCATTER_PLUGIN_NAME; } -const char* ReduceScatterPlugin::getPluginVersion() const noexcept +char const* ReduceScatterPlugin::getPluginVersion() const noexcept { return REDUCE_SCATTER_PLUGIN_VERSION; } @@ -184,34 +184,34 @@ ReduceScatterPluginCreator::ReduceScatterPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* ReduceScatterPluginCreator::getPluginName() const noexcept +char const* ReduceScatterPluginCreator::getPluginName() const noexcept { return REDUCE_SCATTER_PLUGIN_NAME; } -const char* ReduceScatterPluginCreator::getPluginVersion() const noexcept +char const* ReduceScatterPluginCreator::getPluginVersion() const noexcept { return REDUCE_SCATTER_PLUGIN_VERSION; } -const PluginFieldCollection* ReduceScatterPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* ReduceScatterPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* ReduceScatterPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* ReduceScatterPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; std::set group; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "group")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - const auto* r = static_cast(fields[i].data); + auto const* r = static_cast(fields[i].data); for (int j = 0; j < fields[i].length; ++j) { group.insert(*r); @@ -221,7 +221,7 @@ IPluginV2* ReduceScatterPluginCreator::createPlugin(const char* name, const Plug else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } @@ -231,7 +231,7 @@ IPluginV2* ReduceScatterPluginCreator::createPlugin(const char* name, const Plug obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const 
std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -239,7 +239,7 @@ IPluginV2* ReduceScatterPluginCreator::createPlugin(const char* name, const Plug } IPluginV2* ReduceScatterPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call ReduceScatterPlugin::destroy() @@ -249,7 +249,7 @@ IPluginV2* ReduceScatterPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h index 3fac75030..10f28d2e9 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h @@ -29,30 +29,30 @@ class ReduceScatterPlugin : public BasePlugin public: ReduceScatterPlugin(std::set group, nvinfer1::DataType type); - ReduceScatterPlugin(const void* data, size_t length); + ReduceScatterPlugin(void const* data, size_t length); ~ReduceScatterPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept 
override; void terminate() noexcept override; @@ -71,16 +71,16 @@ class ReduceScatterPluginCreator : public BaseCreator public: ReduceScatterPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp b/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp index 92f1f1bc0..162a2b1a2 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.cpp @@ -25,8 +25,8 @@ using namespace nvinfer1; using tensorrt_llm::plugins::SendPluginCreator; using tensorrt_llm::plugins::SendPlugin; -static const char* SEND_PLUGIN_VERSION{"1"}; -static const char* SEND_PLUGIN_NAME{"Send"}; +static char const* SEND_PLUGIN_VERSION{"1"}; +static char const* SEND_PLUGIN_NAME{"Send"}; PluginFieldCollection SendPluginCreator::mFC{}; std::vector SendPluginCreator::mPluginAttributes; @@ -37,9 +37,9 @@ SendPlugin::SendPlugin(int tgtRank, nvinfer1::DataType type) } // Parameterized constructor -SendPlugin::SendPlugin(const void* data, size_t length) +SendPlugin::SendPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mType); read(d, mTgtRank); TLLM_CHECK_WITH_INFO(d == a + length, @@ -58,30 +58,30 @@ nvinfer1::IPluginV2DynamicExt* SendPlugin::clone() const noexcept } nvinfer1::DimsExprs SendPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { return inputs[0]; } bool SendPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR); } -void SendPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void SendPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t SendPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t SendPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int SendPlugin::enqueue(const nvinfer1::PluginTensorDesc* 
inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept +int SendPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (isBuilding()) { @@ -99,7 +99,7 @@ int SendPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinf // IPluginV2Ext Methods nvinfer1::DataType SendPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -107,12 +107,12 @@ nvinfer1::DataType SendPlugin::getOutputDataType( // IPluginV2 Methods -const char* SendPlugin::getPluginType() const noexcept +char const* SendPlugin::getPluginType() const noexcept { return SEND_PLUGIN_NAME; } -const char* SendPlugin::getPluginVersion() const noexcept +char const* SendPlugin::getPluginVersion() const noexcept { return SEND_PLUGIN_VERSION; } @@ -177,39 +177,39 @@ SendPluginCreator::SendPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* SendPluginCreator::getPluginName() const noexcept +char const* SendPluginCreator::getPluginName() const noexcept { return SEND_PLUGIN_NAME; } -const char* SendPluginCreator::getPluginVersion() const noexcept +char const* SendPluginCreator::getPluginVersion() const noexcept { return SEND_PLUGIN_VERSION; } -const PluginFieldCollection* SendPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* SendPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* SendPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* SendPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; int tgtRank; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "tgt_rank")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - tgtRank = static_cast(*(static_cast(fields[i].data))); + tgtRank = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } @@ -219,14 +219,14 @@ IPluginV2* SendPluginCreator::createPlugin(const char* name, const PluginFieldCo obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } return nullptr; } -IPluginV2* SendPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept +IPluginV2* SendPluginCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call SendPlugin::destroy() @@ -236,7 +236,7 @@ IPluginV2* SendPluginCreator::deserializePlugin(const char* name, const void* se obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git 
a/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.h b/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.h index 2c2e8f375..0d36b0ebf 100644 --- a/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.h +++ b/cpp/tensorrt_llm/plugins/ncclPlugin/sendPlugin.h @@ -28,30 +28,30 @@ class SendPlugin : public BasePlugin public: SendPlugin(int tgtRank, nvinfer1::DataType type); - SendPlugin(const void* data, size_t length); + SendPlugin(void const* data, size_t length); ~SendPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -70,16 +70,16 @@ class SendPluginCreator : public BaseCreator public: SendPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const 
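The Send/Recv plugin enqueue bodies are elided in the hunks above, but conceptually they wrap NCCL point-to-point transfers between pipeline ranks, keyed by the tgt_rank/src_rank attributes parsed in the creators. As an illustration of the underlying host-side NCCL calls (an assumption about what the plugins forward to, not code from this patch):

    #include <cuda_runtime.h>
    #include <nccl.h>

    // Rank `peer` receives what this rank sends; both calls enqueue asynchronously on `stream`.
    inline ncclResult_t sendTensor(
        void const* buffer, size_t count, ncclDataType_t dtype, int peer, ncclComm_t comm, cudaStream_t stream)
    {
        return ncclSend(buffer, count, dtype, peer, comm, stream);
    }

    inline ncclResult_t recvTensor(
        void* buffer, size_t count, ncclDataType_t dtype, int peer, ncclComm_t comm, cudaStream_t stream)
    {
        return ncclRecv(buffer, count, dtype, peer, comm, stream);
    }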
void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp index a8d519fad..44df2e3a2 100644 --- a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.cpp @@ -22,17 +22,17 @@ using namespace tensorrt_llm::kernels; using tensorrt_llm::plugins::QuantizePerTokenPluginCreator; using tensorrt_llm::plugins::QuantizePerTokenPlugin; -static const char* QUANTIZE_PER_TOKEN_PLUGIN_VERSION{"1"}; -static const char* QUANTIZE_PER_TOKEN_PLUGIN_NAME{"QuantizePerToken"}; +static char const* QUANTIZE_PER_TOKEN_PLUGIN_VERSION{"1"}; +static char const* QUANTIZE_PER_TOKEN_PLUGIN_NAME{"QuantizePerToken"}; PluginFieldCollection QuantizePerTokenPluginCreator::mFC{}; std::vector QuantizePerTokenPluginCreator::mPluginAttributes; QuantizePerTokenPlugin::QuantizePerTokenPlugin() {} // Parameterized constructor -QuantizePerTokenPlugin::QuantizePerTokenPlugin(const void* data, size_t length) +QuantizePerTokenPlugin::QuantizePerTokenPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; TLLM_CHECK_WITH_INFO(d == a + length, "Expected length (%d) != real length (%d). This is often " "caused by using different TensorRT-LLM version to build " @@ -49,7 +49,7 @@ nvinfer1::IPluginV2DynamicExt* QuantizePerTokenPlugin::clone() const noexcept } nvinfer1::DimsExprs QuantizePerTokenPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { try { @@ -71,7 +71,7 @@ nvinfer1::DimsExprs QuantizePerTokenPlugin::getOutputDimensions( // [M(*), 1] dynamic per token scales return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -79,7 +79,7 @@ nvinfer1::DimsExprs QuantizePerTokenPlugin::getOutputDimensions( } bool QuantizePerTokenPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { switch (pos) { @@ -100,19 +100,19 @@ bool QuantizePerTokenPlugin::supportsFormatCombination( } } -void QuantizePerTokenPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void QuantizePerTokenPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t QuantizePerTokenPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t QuantizePerTokenPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int QuantizePerTokenPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* 
const* outputs, void* workspace, +int QuantizePerTokenPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -131,12 +131,12 @@ int QuantizePerTokenPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, if (inputDesc[0].type == DataType::kFLOAT) { invokePerTokenQuantization(reinterpret_cast(outputs[0]), - reinterpret_cast(inputs[0]), m, k, reinterpret_cast(outputs[1]), stream); + reinterpret_cast(inputs[0]), m, k, reinterpret_cast(outputs[1]), stream); } else { invokePerTokenQuantization(reinterpret_cast(outputs[0]), - reinterpret_cast(inputs[0]), m, k, reinterpret_cast(outputs[1]), stream); + reinterpret_cast(inputs[0]), m, k, reinterpret_cast(outputs[1]), stream); } return 0; @@ -144,7 +144,7 @@ int QuantizePerTokenPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType QuantizePerTokenPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(nbInputs == 1); TLLM_CHECK(index < 2); @@ -153,12 +153,12 @@ nvinfer1::DataType QuantizePerTokenPlugin::getOutputDataType( // IPluginV2 Methods -const char* QuantizePerTokenPlugin::getPluginType() const noexcept +char const* QuantizePerTokenPlugin::getPluginType() const noexcept { return QUANTIZE_PER_TOKEN_PLUGIN_NAME; } -const char* QuantizePerTokenPlugin::getPluginVersion() const noexcept +char const* QuantizePerTokenPlugin::getPluginVersion() const noexcept { return QUANTIZE_PER_TOKEN_PLUGIN_VERSION; } @@ -202,22 +202,22 @@ QuantizePerTokenPluginCreator::QuantizePerTokenPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* QuantizePerTokenPluginCreator::getPluginName() const noexcept +char const* QuantizePerTokenPluginCreator::getPluginName() const noexcept { return QUANTIZE_PER_TOKEN_PLUGIN_NAME; } -const char* QuantizePerTokenPluginCreator::getPluginVersion() const noexcept +char const* QuantizePerTokenPluginCreator::getPluginVersion() const noexcept { return QUANTIZE_PER_TOKEN_PLUGIN_VERSION; } -const PluginFieldCollection* QuantizePerTokenPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* QuantizePerTokenPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* QuantizePerTokenPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* QuantizePerTokenPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { try { @@ -225,7 +225,7 @@ IPluginV2* QuantizePerTokenPluginCreator::createPlugin(const char* name, const P obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -233,7 +233,7 @@ IPluginV2* QuantizePerTokenPluginCreator::createPlugin(const char* name, const P } IPluginV2* QuantizePerTokenPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call QuantizePerTokenPlugin::destroy() @@ -243,7 +243,7 @@ IPluginV2* QuantizePerTokenPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception 
const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h index c10f0bc77..0bfc25fc5 100644 --- a/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h +++ b/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h @@ -32,30 +32,30 @@ class QuantizePerTokenPlugin : public BasePlugin public: QuantizePerTokenPlugin(); - QuantizePerTokenPlugin(const void* data, size_t length); + QuantizePerTokenPlugin(void const* data, size_t length); ~QuantizePerTokenPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -72,16 +72,16 @@ class QuantizePerTokenPluginCreator : public BaseCreator public: QuantizePerTokenPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; 
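// Aside (not part of the original patch): every hunk in this change only moves
// `const` from the left of a type to the right ("east const"); the declared
// types are unchanged, so the overrides above still match the
// nvinfer1::IPluginV2DynamicExt virtual signatures. A minimal, self-contained
// sketch (standard library only, illustrative names) showing the two spellings
// name the same type:
#include <exception>
#include <type_traits>

static_assert(std::is_same_v<const char*, char const*>, "west/east const: same type");
static_assert(std::is_same_v<const void* const*, void const* const*>, "same type");
static_assert(std::is_same_v<const std::exception&, std::exception const&>, "same type");

int main()
{
    return 0; // nothing to do at run time; the checks above are compile-time
}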
+ nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp index b318371c1..8ebcc5022 100644 --- a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.cpp @@ -22,17 +22,17 @@ using namespace tensorrt_llm::kernels; using tensorrt_llm::plugins::QuantizeTensorPluginCreator; using tensorrt_llm::plugins::QuantizeTensorPlugin; -static const char* QUANTIZE_TENSOR_PLUGIN_VERSION{"1"}; -static const char* QUANTIZE_TENSOR_PLUGIN_NAME{"QuantizeTensor"}; +static char const* QUANTIZE_TENSOR_PLUGIN_VERSION{"1"}; +static char const* QUANTIZE_TENSOR_PLUGIN_NAME{"QuantizeTensor"}; PluginFieldCollection QuantizeTensorPluginCreator::mFC{}; std::vector QuantizeTensorPluginCreator::mPluginAttributes; QuantizeTensorPlugin::QuantizeTensorPlugin() {} // Parameterized constructor -QuantizeTensorPlugin::QuantizeTensorPlugin(const void* data, size_t length) +QuantizeTensorPlugin::QuantizeTensorPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; TLLM_CHECK_WITH_INFO(d == a + length, "Expected length (%d) != real length (%d). This is often " "caused by using different TensorRT-LLM version to build " @@ -47,7 +47,7 @@ nvinfer1::IPluginV2DynamicExt* QuantizeTensorPlugin::clone() const noexcept } nvinfer1::DimsExprs QuantizeTensorPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { try { @@ -56,7 +56,7 @@ nvinfer1::DimsExprs QuantizeTensorPlugin::getOutputDimensions( // Quantized input return inputs[0]; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -64,7 +64,7 @@ nvinfer1::DimsExprs QuantizeTensorPlugin::getOutputDimensions( } bool QuantizeTensorPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { switch (pos) { @@ -85,19 +85,19 @@ bool QuantizeTensorPlugin::supportsFormatCombination( } } -void QuantizeTensorPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void QuantizeTensorPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t QuantizeTensorPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t QuantizeTensorPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int QuantizeTensorPlugin::enqueue(const nvinfer1::PluginTensorDesc* 
inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int QuantizeTensorPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -114,13 +114,13 @@ int QuantizeTensorPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, if (inputDesc[0].type == DataType::kFLOAT) { - invokeQuantization(reinterpret_cast(outputs[0]), reinterpret_cast(inputs[0]), - numElts, reinterpret_cast(inputs[1]), stream, mProp.maxGridSize[0]); + invokeQuantization(reinterpret_cast(outputs[0]), reinterpret_cast(inputs[0]), + numElts, reinterpret_cast(inputs[1]), stream, mProp.maxGridSize[0]); } else { - invokeQuantization(reinterpret_cast(outputs[0]), reinterpret_cast(inputs[0]), - numElts, reinterpret_cast(inputs[1]), stream, mProp.maxGridSize[0]); + invokeQuantization(reinterpret_cast(outputs[0]), reinterpret_cast(inputs[0]), + numElts, reinterpret_cast(inputs[1]), stream, mProp.maxGridSize[0]); } return 0; @@ -128,7 +128,7 @@ int QuantizeTensorPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType QuantizeTensorPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(nbInputs == 2); TLLM_CHECK(index == 0); @@ -137,12 +137,12 @@ nvinfer1::DataType QuantizeTensorPlugin::getOutputDataType( // IPluginV2 Methods -const char* QuantizeTensorPlugin::getPluginType() const noexcept +char const* QuantizeTensorPlugin::getPluginType() const noexcept { return QUANTIZE_TENSOR_PLUGIN_NAME; } -const char* QuantizeTensorPlugin::getPluginVersion() const noexcept +char const* QuantizeTensorPlugin::getPluginVersion() const noexcept { return QUANTIZE_TENSOR_PLUGIN_VERSION; } @@ -189,22 +189,22 @@ QuantizeTensorPluginCreator::QuantizeTensorPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* QuantizeTensorPluginCreator::getPluginName() const noexcept +char const* QuantizeTensorPluginCreator::getPluginName() const noexcept { return QUANTIZE_TENSOR_PLUGIN_NAME; } -const char* QuantizeTensorPluginCreator::getPluginVersion() const noexcept +char const* QuantizeTensorPluginCreator::getPluginVersion() const noexcept { return QUANTIZE_TENSOR_PLUGIN_VERSION; } -const PluginFieldCollection* QuantizeTensorPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* QuantizeTensorPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* QuantizeTensorPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* QuantizeTensorPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { try { @@ -212,7 +212,7 @@ IPluginV2* QuantizeTensorPluginCreator::createPlugin(const char* name, const Plu obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -220,7 +220,7 @@ IPluginV2* QuantizeTensorPluginCreator::createPlugin(const char* name, const Plu } IPluginV2* QuantizeTensorPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call 
QuantizeTensorPlugin::destroy() @@ -230,7 +230,7 @@ IPluginV2* QuantizeTensorPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h index ec1d33785..6f1ce864e 100644 --- a/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h +++ b/cpp/tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h @@ -32,30 +32,30 @@ class QuantizeTensorPlugin : public BasePlugin public: QuantizeTensorPlugin(); - QuantizeTensorPlugin(const void* data, size_t length); + QuantizeTensorPlugin(void const* data, size_t length); ~QuantizeTensorPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -73,16 +73,16 @@ class QuantizeTensorPluginCreator : public BaseCreator public: QuantizeTensorPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept 
override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp index 0d95cd864..f9df16562 100644 --- a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.cpp @@ -23,8 +23,8 @@ using namespace tensorrt_llm::common; using tensorrt_llm::plugins::RmsnormQuantizationPluginCreator; using tensorrt_llm::plugins::RmsnormQuantizationPlugin; -static const char* RMSNORM_QUANTIZATION_PLUGIN_VERSION{"1"}; -static const char* RMSNORM_QUANTIZATION_PLUGIN_NAME{"RmsnormQuantization"}; +static char const* RMSNORM_QUANTIZATION_PLUGIN_VERSION{"1"}; +static char const* RMSNORM_QUANTIZATION_PLUGIN_NAME{"RmsnormQuantization"}; PluginFieldCollection RmsnormQuantizationPluginCreator::mFC{}; std::vector RmsnormQuantizationPluginCreator::mPluginAttributes; @@ -36,9 +36,9 @@ RmsnormQuantizationPlugin::RmsnormQuantizationPlugin(float eps, bool dynamicActi } // Parameterized constructor -RmsnormQuantizationPlugin::RmsnormQuantizationPlugin(const void* data, size_t length) +RmsnormQuantizationPlugin::RmsnormQuantizationPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mEps); read(d, mDynActScaling); read(d, mType); @@ -58,7 +58,7 @@ nvinfer1::IPluginV2DynamicExt* RmsnormQuantizationPlugin::clone() const noexcept } nvinfer1::DimsExprs RmsnormQuantizationPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { if (outputIndex == 0) { @@ -79,7 +79,7 @@ nvinfer1::DimsExprs RmsnormQuantizationPlugin::getOutputDimensions( ret.d[ret.nbDims - 1] = exprBuilder.constant(1); return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -87,9 +87,9 @@ nvinfer1::DimsExprs RmsnormQuantizationPlugin::getOutputDimensions( } bool RmsnormQuantizationPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { - const int totalPoses = 6 + static_cast(mDynActScaling); + int const totalPoses = 6 + static_cast(mDynActScaling); TLLM_CHECK(0 <= pos && pos < totalPoses); TLLM_CHECK(nbInputs == 4); if (pos < nbInputs) @@ -111,19 +111,19 @@ bool RmsnormQuantizationPlugin::supportsFormatCombination( return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR); } -void RmsnormQuantizationPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int 
nbOutputs) noexcept +void RmsnormQuantizationPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t RmsnormQuantizationPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t RmsnormQuantizationPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } -int RmsnormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int RmsnormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -140,24 +140,24 @@ int RmsnormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDe { m *= inputDesc[0].dims.d[i]; } - const int n = inputDesc[1].dims.d[0]; + int const n = inputDesc[1].dims.d[0]; - const float* scale = reinterpret_cast(inputs[3]); + float const* scale = reinterpret_cast(inputs[3]); int8_t* output = reinterpret_cast(outputs[0]); float* dynamic_scale = mDynActScaling ? reinterpret_cast(outputs[1]) : nullptr; if (mType == DataType::kHALF) { - const half* input = reinterpret_cast(inputs[0]); - const half* weight = reinterpret_cast(inputs[1]); - const half* bias = reinterpret_cast(inputs[2]); + half const* input = reinterpret_cast(inputs[0]); + half const* weight = reinterpret_cast(inputs[1]); + half const* bias = reinterpret_cast(inputs[2]); invokeGeneralRmsNorm((half*) nullptr, input, weight, bias, mEps, m, n, stream, scale, dynamic_scale, output); } else if (mType == DataType::kFLOAT) { - const float* input = reinterpret_cast(inputs[0]); - const float* weight = reinterpret_cast(inputs[1]); - const float* bias = reinterpret_cast(inputs[2]); + float const* input = reinterpret_cast(inputs[0]); + float const* weight = reinterpret_cast(inputs[1]); + float const* bias = reinterpret_cast(inputs[2]); invokeGeneralRmsNorm((float*) nullptr, input, weight, bias, mEps, m, n, stream, scale, dynamic_scale, output); } @@ -166,7 +166,7 @@ int RmsnormQuantizationPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDe // IPluginV2Ext Methods nvinfer1::DataType RmsnormQuantizationPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert((mDynActScaling && index < 2) || (!mDynActScaling && index == 0)); if (index == 0) @@ -180,12 +180,12 @@ nvinfer1::DataType RmsnormQuantizationPlugin::getOutputDataType( // IPluginV2 Methods -const char* RmsnormQuantizationPlugin::getPluginType() const noexcept +char const* RmsnormQuantizationPlugin::getPluginType() const noexcept { return RMSNORM_QUANTIZATION_PLUGIN_NAME; } -const char* RmsnormQuantizationPlugin::getPluginVersion() const noexcept +char const* RmsnormQuantizationPlugin::getPluginVersion() const noexcept { return RMSNORM_QUANTIZATION_PLUGIN_VERSION; } @@ -235,45 +235,45 @@ RmsnormQuantizationPluginCreator::RmsnormQuantizationPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* RmsnormQuantizationPluginCreator::getPluginName() const noexcept +char const* 
RmsnormQuantizationPluginCreator::getPluginName() const noexcept
 {
     return RMSNORM_QUANTIZATION_PLUGIN_NAME;
 }
-const char* RmsnormQuantizationPluginCreator::getPluginVersion() const noexcept
+char const* RmsnormQuantizationPluginCreator::getPluginVersion() const noexcept
 {
     return RMSNORM_QUANTIZATION_PLUGIN_VERSION;
 }
-const PluginFieldCollection* RmsnormQuantizationPluginCreator::getFieldNames() noexcept
+PluginFieldCollection const* RmsnormQuantizationPluginCreator::getFieldNames() noexcept
 {
     return &mFC;
 }
-IPluginV2* RmsnormQuantizationPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
+IPluginV2* RmsnormQuantizationPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
 {
-    const PluginField* fields = fc->fields;
+    PluginField const* fields = fc->fields;
     float eps;
     nvinfer1::DataType type;
     bool dynamicActivationScaling;
     // Read configurations from each fields
     for (int i = 0; i < fc->nbFields; ++i)
     {
-        const char* attrName = fields[i].name;
+        char const* attrName = fields[i].name;
         if (!strcmp(attrName, "eps"))
         {
             TLLM_CHECK(fields[i].type == PluginFieldType::kFLOAT32);
-            eps = static_cast<float>(*(static_cast<const float*>(fields[i].data)));
+            eps = static_cast<float>(*(static_cast<float const*>(fields[i].data)));
         }
         else if (!strcmp(attrName, "type_id"))
         {
             TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
-            type = static_cast<nvinfer1::DataType>(*(static_cast<const nvinfer1::DataType*>(fields[i].data)));
+            type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
         }
         else if (!strcmp(attrName, "dyn_act_scaling"))
         {
             TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
-            dynamicActivationScaling = static_cast<bool>(*(static_cast<const bool*>(fields[i].data)));
+            dynamicActivationScaling = static_cast<bool>(*(static_cast<bool const*>(fields[i].data)));
         }
     }
     try
@@ -282,7 +282,7 @@ IPluginV2* RmsnormQuantizationPluginCreator::createPlugin(const char* name, cons
         obj->setPluginNamespace(mNamespace.c_str());
         return obj;
     }
-    catch (const std::exception& e)
+    catch (std::exception const& e)
     {
         caughtError(e);
     }
@@ -290,7 +290,7 @@ IPluginV2* RmsnormQuantizationPluginCreator::createPlugin(const char* name, cons
 }
 IPluginV2* RmsnormQuantizationPluginCreator::deserializePlugin(
-    const char* name, const void* serialData, size_t serialLength) noexcept
+    char const* name, void const* serialData, size_t serialLength) noexcept
 {
     // This object will be deleted when the network is destroyed, which will
     // call RmsnormQuantizationPlugin::destroy()
@@ -300,7 +300,7 @@ IPluginV2* RmsnormQuantizationPluginCreator::deserializePlugin(
         obj->setPluginNamespace(mNamespace.c_str());
         return obj;
     }
-    catch (const std::exception& e)
+    catch (std::exception const& e)
     {
         caughtError(e);
     }
diff --git a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.h b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.h
index aa3d4fc71..c284b18a5 100644
--- a/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.h
+++ b/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin/rmsnormQuantizationPlugin.h
@@ -30,30 +30,30 @@ class RmsnormQuantizationPlugin : public BasePlugin
 public:
     RmsnormQuantizationPlugin(float eps, bool dynamicActivationScaling, nvinfer1::DataType type);
-    RmsnormQuantizationPlugin(const void* data, size_t length);
+    RmsnormQuantizationPlugin(void const* data, size_t length);
     ~RmsnormQuantizationPlugin() override = default;
     // IPluginV2DynamicExt Methods
     nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
-    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs*
inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -74,16 +74,16 @@ class RmsnormQuantizationPluginCreator : public BaseCreator public: RmsnormQuantizationPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp index c7e512fd9..9db457af4 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.cpp @@ -24,8 +24,8 @@ using namespace tensorrt_llm::common; using 
tensorrt_llm::plugins::SelectiveScanPluginCreator; using tensorrt_llm::plugins::SelectiveScanPlugin; -static const char* SELECTIVE_SCAN_PLUGIN_VERSION{"1"}; -static const char* SELECTIVE_SCAN_PLUGIN_NAME{"SelectiveScan"}; +static char const* SELECTIVE_SCAN_PLUGIN_VERSION{"1"}; +static char const* SELECTIVE_SCAN_PLUGIN_NAME{"SelectiveScan"}; PluginFieldCollection SelectiveScanPluginCreator::mFC{}; std::vector SelectiveScanPluginCreator::mPluginAttributes; @@ -45,9 +45,9 @@ SelectiveScanPlugin::SelectiveScanPlugin( } // Parameterized constructor -SelectiveScanPlugin::SelectiveScanPlugin(const void* data, size_t length) +SelectiveScanPlugin::SelectiveScanPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; read(d, mDim); read(d, mDState); read(d, mIsVariableB); @@ -72,7 +72,7 @@ nvinfer1::IPluginV2DynamicExt* SelectiveScanPlugin::clone() const noexcept // output_tensor: [batch_size, seq_len, dim] // state: [batch_size, dstate, dim] nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { if (outputIndex == 0) { @@ -82,7 +82,7 @@ nvinfer1::DimsExprs SelectiveScanPlugin::getOutputDimensions( } bool SelectiveScanPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { if (pos == getHostRequestTypesIdx()) { @@ -98,20 +98,20 @@ bool SelectiveScanPlugin::supportsFormatCombination( } } -void SelectiveScanPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void SelectiveScanPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t SelectiveScanPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t SelectiveScanPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return 0; } void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch, const size_t dim, const size_t seqLen, - const size_t dstate, const bool isVariableB, const bool isVariableC, void* statePtr, const void* x, - const void* delta, const void* deltaBias, const void* A, const void* B, const void* C, const void* D, const void* z, + const size_t dstate, bool const isVariableB, bool const isVariableC, void* statePtr, void const* x, + void const* delta, void const* deltaBias, void const* A, void const* B, void const* C, void const* D, void const* z, void* out, bool deltaSoftplus) { // Reset the parameters @@ -141,8 +141,8 @@ void SelectiveScanPlugin::setSSMParams(SSMParamsBase& params, const size_t batch } template -int SelectiveScanPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* 
outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { // inputs @@ -183,8 +183,8 @@ int SelectiveScanPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc return 0; } -int SelectiveScanPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int SelectiveScanPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (mType == DataType::kHALF) @@ -206,7 +206,7 @@ int SelectiveScanPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType SelectiveScanPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { if (index == 0) { @@ -220,12 +220,12 @@ nvinfer1::DataType SelectiveScanPlugin::getOutputDataType( // IPluginV2 Methods -const char* SelectiveScanPlugin::getPluginType() const noexcept +char const* SelectiveScanPlugin::getPluginType() const noexcept { return SELECTIVE_SCAN_PLUGIN_NAME; } -const char* SelectiveScanPlugin::getPluginVersion() const noexcept +char const* SelectiveScanPlugin::getPluginVersion() const noexcept { return SELECTIVE_SCAN_PLUGIN_VERSION; } @@ -281,60 +281,60 @@ SelectiveScanPluginCreator::SelectiveScanPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* SelectiveScanPluginCreator::getPluginName() const noexcept +char const* SelectiveScanPluginCreator::getPluginName() const noexcept { return SELECTIVE_SCAN_PLUGIN_NAME; } -const char* SelectiveScanPluginCreator::getPluginVersion() const noexcept +char const* SelectiveScanPluginCreator::getPluginVersion() const noexcept { return SELECTIVE_SCAN_PLUGIN_VERSION; } -const PluginFieldCollection* SelectiveScanPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* SelectiveScanPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* SelectiveScanPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* SelectiveScanPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; int dim, dstate; bool isVariableB, isVariableC, deltaSoftplus; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "dim")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - dim = static_cast(*(static_cast(fields[i].data))); + dim = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "dstate")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - dstate = static_cast(*(static_cast(fields[i].data))); + dstate = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "is_variable_B")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - isVariableB = static_cast(*(static_cast(fields[i].data))); + isVariableB = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "is_variable_C")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - isVariableC = static_cast(*(static_cast(fields[i].data))); + isVariableC = 
static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "delta_softplus")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT8); - deltaSoftplus = static_cast(*(static_cast(fields[i].data))); + deltaSoftplus = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } try @@ -343,7 +343,7 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(const char* name, const Plug obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -351,7 +351,7 @@ IPluginV2* SelectiveScanPluginCreator::createPlugin(const char* name, const Plug } IPluginV2* SelectiveScanPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call SelectiveScanPlugin::destroy() @@ -361,7 +361,7 @@ IPluginV2* SelectiveScanPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h index 679f630d3..ad9062fe3 100644 --- a/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h +++ b/cpp/tensorrt_llm/plugins/selectiveScanPlugin/selectiveScanPlugin.h @@ -49,33 +49,33 @@ class SelectiveScanPlugin : public BasePlugin SelectiveScanPlugin( int dim, int dstate, bool isVariableB, bool isVariableC, bool deltaSoftplus, nvinfer1::DataType type); - SelectiveScanPlugin(const void* data, size_t length); + SelectiveScanPlugin(void const* data, size_t length); ~SelectiveScanPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int 
enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; template - int enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -144,11 +144,11 @@ class SelectiveScanPlugin : public BasePlugin void setSSMParams(tensorrt_llm::kernels::SSMParamsBase& params, // sizes - const size_t batch, const size_t dim, const size_t seqLen, const size_t dstate, const bool isVariableB, - const bool isVariableC, + const size_t batch, const size_t dim, const size_t seqLen, const size_t dstate, bool const isVariableB, + bool const isVariableC, // device pointers - void* statePtr, const void* x, const void* delta, const void* deltaBias, const void* A, const void* B, - const void* C, const void* D, const void* z, void* out, bool deltaSoftplus); + void* statePtr, void const* x, void const* delta, void const* deltaBias, void const* A, void const* B, + void const* C, void const* D, void const* z, void* out, bool deltaSoftplus); private: int mDim; @@ -164,16 +164,16 @@ class SelectiveScanPluginCreator : public BaseCreator public: SelectiveScanPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp index 7e0263db4..9402cb875 100644 --- a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.cpp @@ -26,13 +26,13 @@ using tensorrt_llm::plugins::SmoothQuantGemmPluginProfiler; using tensorrt_llm::plugins::read; using tensorrt_llm::plugins::write; -static const char* 
SQ_GEMM_PLUGIN_VERSION{"1"}; -static const char* SQ_GEMM_PLUGIN_NAME{"SmoothQuantGemm"}; +static char const* SQ_GEMM_PLUGIN_VERSION{"1"}; +static char const* SQ_GEMM_PLUGIN_NAME{"SmoothQuantGemm"}; PluginFieldCollection SmoothQuantGemmPluginCreator::mFC{}; std::vector SmoothQuantGemmPluginCreator::mPluginAttributes; -void SmoothQuantGemmPluginProfiler::runTactic(int m, int n, int k, const SmoothQuantGemmPluginProfiler::Config& tactic, - char* workspace, const cudaStream_t& stream) +void SmoothQuantGemmPluginProfiler::runTactic(int m, int n, int k, SmoothQuantGemmPluginProfiler::Config const& tactic, + char* workspace, cudaStream_t const& stream) { int8_t* aTmp = reinterpret_cast(workspace); int8_t* bTmp = nextWorkspacePtr(aTmp, m * k * sizeof(int8_t)); @@ -44,7 +44,7 @@ void SmoothQuantGemmPluginProfiler::runTactic(int m, int n, int k, const SmoothQ char* workspaceTmp = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(alphaColTmp), n * sizeof(float))); - const int wsSize = mRunner->getWorkspaceSize(m, n, k); + int const wsSize = mRunner->getWorkspaceSize(m, n, k); mRunner->gemm( aTmp, bTmp, mQuantMode, alphaColTmp, alphaRowTmp, cTmp, m, n, k, tactic, workspaceTmp, wsSize, stream); @@ -70,7 +70,7 @@ std::vector SmoothQuantGemmPluginProfiler } SmoothQuantGemmPlugin::SmoothQuantGemmPlugin( - QuantMode quantMode, nvinfer1::DataType type, const SmoothQuantGemmPlugin::PluginProfilerPtr& pluginProfiler) + QuantMode quantMode, nvinfer1::DataType type, SmoothQuantGemmPlugin::PluginProfilerPtr const& pluginProfiler) : mQuantMode(quantMode) , mPluginProfiler(pluginProfiler) { @@ -79,10 +79,10 @@ SmoothQuantGemmPlugin::SmoothQuantGemmPlugin( // Parameterized constructor SmoothQuantGemmPlugin::SmoothQuantGemmPlugin( - const void* data, size_t length, const SmoothQuantGemmPlugin::PluginProfilerPtr& pluginProfiler) + void const* data, size_t length, SmoothQuantGemmPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; bool perChannelScaling = false, perTokenScaling = false; nvinfer1::DataType type; unsigned int quantMode; @@ -137,13 +137,13 @@ nvinfer1::IPluginV2DynamicExt* SmoothQuantGemmPlugin::clone() const noexcept } nvinfer1::DimsExprs SmoothQuantGemmPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { try { TLLM_CHECK(nbInputs == 4); TLLM_CHECK(outputIndex == 0); - const int nbDimsA = inputs[0].nbDims; + int const nbDimsA = inputs[0].nbDims; TLLM_CHECK(nbDimsA >= 2); DimsExprs ret; ret.nbDims = nbDimsA; @@ -154,7 +154,7 @@ nvinfer1::DimsExprs SmoothQuantGemmPlugin::getOutputDimensions( ret.d[nbDimsA - 1] = inputs[1].d[0]; return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -162,7 +162,7 @@ nvinfer1::DimsExprs SmoothQuantGemmPlugin::getOutputDimensions( } bool SmoothQuantGemmPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { switch (pos) { @@ -188,16 +188,16 @@ bool SmoothQuantGemmPlugin::supportsFormatCombination( } } -void SmoothQuantGemmPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const 
nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void SmoothQuantGemmPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { - const auto minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); - const auto maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); + auto const minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); + auto const maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); - const int maxK = in[0].max.d[in[0].max.nbDims - 1]; - const int maxN = in[1].max.d[0]; - const int minK = in[0].min.d[in[0].min.nbDims - 1]; - const int minN = in[1].min.d[0]; + int const maxK = in[0].max.d[in[0].max.nbDims - 1]; + int const maxN = in[1].max.d[0]; + int const minK = in[0].min.d[in[0].min.nbDims - 1]; + int const minN = in[1].min.d[0]; TLLM_CHECK_WITH_INFO(minN == maxN, "Variable out channels is not allowed"); TLLM_CHECK_WITH_INFO(minK == maxK, "Variable in channels is not allowed"); @@ -211,14 +211,14 @@ void SmoothQuantGemmPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorD m_workspaceMaxSize = m_sqGemmRunner->getWorkspaceSize(maxM, maxN, maxK); } -size_t SmoothQuantGemmPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t SmoothQuantGemmPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return m_workspaceMaxSize; } -int SmoothQuantGemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int SmoothQuantGemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -233,14 +233,14 @@ int SmoothQuantGemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, { m *= inputDesc[0].dims.d[ii]; } - const int n = inputDesc[1].dims.d[0]; - const int k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; - const int wsSize = m_sqGemmRunner->getWorkspaceSize(m, n, k); + int const n = inputDesc[1].dims.d[0]; + int const k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; + int const wsSize = m_sqGemmRunner->getWorkspaceSize(m, n, k); - const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); + auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); TLLM_CHECK_WITH_INFO(bestTactic, "No valid SQ GEMM tactic"); - m_sqGemmRunner->gemm(reinterpret_cast(inputs[0]), reinterpret_cast(inputs[1]), - mQuantMode, reinterpret_cast(inputs[3]), reinterpret_cast(inputs[2]), + m_sqGemmRunner->gemm(reinterpret_cast(inputs[0]), reinterpret_cast(inputs[1]), + mQuantMode, reinterpret_cast(inputs[3]), reinterpret_cast(inputs[2]), reinterpret_cast(outputs[0]), m, n, k, *bestTactic, reinterpret_cast(workspace), wsSize, stream); return 0; @@ -248,7 +248,7 @@ int SmoothQuantGemmPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, // IPluginV2Ext Methods nvinfer1::DataType SmoothQuantGemmPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, 
nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0); return mType; @@ -256,12 +256,12 @@ nvinfer1::DataType SmoothQuantGemmPlugin::getOutputDataType( // IPluginV2 Methods -const char* SmoothQuantGemmPlugin::getPluginType() const noexcept +char const* SmoothQuantGemmPlugin::getPluginType() const noexcept { return SQ_GEMM_PLUGIN_NAME; } -const char* SmoothQuantGemmPlugin::getPluginVersion() const noexcept +char const* SmoothQuantGemmPlugin::getPluginVersion() const noexcept { return SQ_GEMM_PLUGIN_VERSION; } @@ -322,44 +322,44 @@ SmoothQuantGemmPluginCreator::SmoothQuantGemmPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* SmoothQuantGemmPluginCreator::getPluginName() const noexcept +char const* SmoothQuantGemmPluginCreator::getPluginName() const noexcept { return SQ_GEMM_PLUGIN_NAME; } -const char* SmoothQuantGemmPluginCreator::getPluginVersion() const noexcept +char const* SmoothQuantGemmPluginCreator::getPluginVersion() const noexcept { return SQ_GEMM_PLUGIN_VERSION; } -const PluginFieldCollection* SmoothQuantGemmPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* SmoothQuantGemmPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* SmoothQuantGemmPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* SmoothQuantGemmPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; bool perTokenScaling, perChannelScaling; nvinfer1::DataType type; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "has_per_channel_scaling")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - perChannelScaling = static_cast(*(static_cast(fields[i].data))); + perChannelScaling = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "has_per_token_scaling")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - perTokenScaling = static_cast(*(static_cast(fields[i].data))); + perTokenScaling = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } try @@ -372,7 +372,7 @@ IPluginV2* SmoothQuantGemmPluginCreator::createPlugin(const char* name, const Pl obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -380,7 +380,7 @@ IPluginV2* SmoothQuantGemmPluginCreator::createPlugin(const char* name, const Pl } IPluginV2* SmoothQuantGemmPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call SmoothQuantGemmPlugin::destroy() @@ -392,7 +392,7 @@ IPluginV2* SmoothQuantGemmPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h index 
170a0904d..702f1effb 100644 --- a/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h +++ b/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin/smoothQuantGemmPlugin.h @@ -38,13 +38,13 @@ class SmoothQuantGemmPluginProfiler : public GemmPluginProfiler gemmPluginProfileManager; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp index 769cf476d..eb33af803 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.cpp @@ -35,16 +35,16 @@ static constexpr int FP8_ALPHA = int(1) << 3; using tensorrt_llm::plugins::read; using tensorrt_llm::plugins::write; -static const char* WOQ_GROUPWISE_MATMUL_PLUGIN_VERSION{"1"}; -static const char* WOQ_GROUPWISE_MATMUL_PLUGIN_NAME{"WeightOnlyGroupwiseQuantMatmul"}; +static char const* WOQ_GROUPWISE_MATMUL_PLUGIN_VERSION{"1"}; +static char const* WOQ_GROUPWISE_MATMUL_PLUGIN_NAME{"WeightOnlyGroupwiseQuantMatmul"}; PluginFieldCollection WeightOnlyGroupwiseQuantMatmulPluginCreator::mFC{}; std::vector WeightOnlyGroupwiseQuantMatmulPluginCreator::mPluginAttributes; void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k, - const WeightOnlyGroupwiseQuantGemmPluginProfiler::Config& tactic, char* workspace, const cudaStream_t& stream) + WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) { // Quantized weights are packed in FP16 format (INT4*4 -> FP16) - const int originalN = n * FP16_INT4_RATIO; + int const originalN = n * FP16_INT4_RATIO; half* actPtr = reinterpret_cast(workspace); cutlass::uint4b_t* weightPtr = reinterpret_cast( nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half))); @@ -68,7 +68,7 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k, biasesPtr = nullptr; } - const int wsSize = mRunner->getWorkspaceSize(m, n, k); + int const wsSize = mRunner->getWorkspaceSize(m, n, k); mRunner->gemm(actPtr, weightPtr, inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); @@ -77,7 +77,7 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k, void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) { // Quantized weights are packed in FP16 format (INT4*4 -> FP16) - const int originalN = n * FP16_INT4_RATIO; + int const originalN = n * FP16_INT4_RATIO; std::vector workspaces = { maxM * k * sizeof(half), // A k * n * sizeof(float), // B @@ -98,7 +98,7 @@ std::vector WeightOnlyGroupw } WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(nvinfer1::DataType type, int quant_algo, - int group_size, const WeightOnlyGroupwiseQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler) + int group_size, WeightOnlyGroupwiseQuantMatmulPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { init(type, quant_algo, group_size); @@ -106,10 +106,10 @@ WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(nvinf // Parameterized constructor WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin( - const void* data, size_t length, const WeightOnlyGroupwiseQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler) + void const* 
data, size_t length, WeightOnlyGroupwiseQuantMatmulPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; nvinfer1::DataType type; int quant_algo = 0; int group_size = 0; @@ -241,7 +241,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configGemm() } nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { // inputs @@ -257,8 +257,8 @@ nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions( { TLLM_CHECK(nbInputs == mAlphaInputIdx + 1); TLLM_CHECK(outputIndex == 0); - const int nbDimsA = inputs[0].nbDims; - const int nbDimsB = inputs[mWeightInputIdx].nbDims; + int const nbDimsA = inputs[0].nbDims; + int const nbDimsB = inputs[mWeightInputIdx].nbDims; TLLM_CHECK(nbDimsA >= 2); TLLM_CHECK(nbDimsB == 2); DimsExprs ret; @@ -273,7 +273,7 @@ nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions( return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -281,7 +281,7 @@ nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions( } bool WeightOnlyGroupwiseQuantMatmulPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { if (pos < mAlphaInputIdx + 2) { @@ -307,19 +307,19 @@ bool WeightOnlyGroupwiseQuantMatmulPlugin::supportsFormatCombination( } } -void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { - const auto minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); - const auto maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); + auto const minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); + auto const maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); - const int maxK = in[0].max.d[in[0].max.nbDims - 1]; + int const maxK = in[0].max.d[in[0].max.nbDims - 1]; // Quantized weights are packed in FP16 format (INT4*4 -> FP16) - const int maxN = in[mWeightInputIdx].max.d[1] * FP16_INT4_RATIO; + int const maxN = in[mWeightInputIdx].max.d[1] * FP16_INT4_RATIO; - const auto K = maxK; - const auto N = maxN / FP16_INT4_RATIO; + auto const K = maxK; + auto const N = maxN / FP16_INT4_RATIO; if (!mDims.isInitialized()) { @@ -332,14 +332,14 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(const nvinfer1::Dynam m_workspaceMaxSize = smoothedActSize + m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(maxM, maxN, maxK); } -size_t WeightOnlyGroupwiseQuantMatmulPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t 
WeightOnlyGroupwiseQuantMatmulPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return m_workspaceMaxSize; } -int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -358,8 +358,8 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe { m *= inputDesc[0].dims.d[ii]; } - const int n = inputDesc[mWeightInputIdx].dims.d[1]; - const int k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; + int const n = inputDesc[mWeightInputIdx].dims.d[1]; + int const k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; int smVersion = getSMVersion(); bool use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled; @@ -370,9 +370,9 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe #endif bool use_pre_quant_scale = mQuantAlgo & PRE_QUANT_SCALE; - const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast(inputs[mZerosInputIdx]) : nullptr; - const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast(inputs[mBiasesInputIdx]) : nullptr; - const half* act_ptr = reinterpret_cast(inputs[0]); + half const* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast(inputs[mZerosInputIdx]) : nullptr; + half const* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast(inputs[mBiasesInputIdx]) : nullptr; + half const* act_ptr = reinterpret_cast(inputs[0]); float alpha = 1.0; if (mQuantAlgo & FP8_ALPHA) { @@ -382,20 +382,20 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe if (use_pre_quant_scale && !use_cuda_kernel) { // Apply pre-quant per channel scale on activations - act_ptr = reinterpret_cast(workspace); + act_ptr = reinterpret_cast(workspace); if (mType == nvinfer1::DataType::kHALF) { if (mQuantAlgo & FP8_ALPHA) { tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher( - reinterpret_cast<__nv_fp8_e4m3*>(workspace), reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + reinterpret_cast<__nv_fp8_e4m3*>(workspace), reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); } else { tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher( - reinterpret_cast(workspace), reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + reinterpret_cast(workspace), reinterpret_cast(inputs[0]), + reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); } } #if defined(ENABLE_BF16) @@ -404,14 +404,14 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe if (mQuantAlgo & FP8_ALPHA) { tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<__nv_bfloat16, __nv_fp8_e4m3>( - reinterpret_cast<__nv_fp8_e4m3*>(workspace), reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + reinterpret_cast<__nv_fp8_e4m3*>(workspace), reinterpret_cast<__nv_bfloat16 const*>(inputs[0]), + reinterpret_cast<__nv_bfloat16 const*>(inputs[mPreQuantScaleInputIdx]), m, k, stream); } else { 
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<__nv_bfloat16, __nv_bfloat16>( - reinterpret_cast<__nv_bfloat16*>(workspace), reinterpret_cast(inputs[0]), - reinterpret_cast(inputs[mPreQuantScaleInputIdx]), m, k, stream); + reinterpret_cast<__nv_bfloat16*>(workspace), reinterpret_cast<__nv_bfloat16 const*>(inputs[0]), + reinterpret_cast<__nv_bfloat16 const*>(inputs[mPreQuantScaleInputIdx]), m, k, stream); } } #endif @@ -444,16 +444,16 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe // Use CUDA kernels for small batch size // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - const void* pre_quant_scale_ptr = nullptr; + void const* pre_quant_scale_ptr = nullptr; if (use_pre_quant_scale) pre_quant_scale_ptr = inputs[mPreQuantScaleInputIdx]; - void* cuda_kernel_act_ptr = const_cast(reinterpret_cast(inputs[0])); - void* cuda_kernel_act_scale_ptr = const_cast(reinterpret_cast(pre_quant_scale_ptr)); - void* cuda_kernel_weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); - void* cuda_kernel_scales_ptr = const_cast(reinterpret_cast(inputs[mScalesInputIdx])); - void* cuda_kernel_zeros_ptr = const_cast(reinterpret_cast(zeros_ptr)); - void* cuda_kernel_bias_ptr = const_cast(reinterpret_cast(biases_ptr)); - void* cuda_kernel_out_ptr = const_cast(reinterpret_cast(outputs[0])); + void* cuda_kernel_act_ptr = const_cast(reinterpret_cast(inputs[0])); + void* cuda_kernel_act_scale_ptr = const_cast(reinterpret_cast(pre_quant_scale_ptr)); + void* cuda_kernel_weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); + void* cuda_kernel_scales_ptr = const_cast(reinterpret_cast(inputs[mScalesInputIdx])); + void* cuda_kernel_zeros_ptr = const_cast(reinterpret_cast(zeros_ptr)); + void* cuda_kernel_bias_ptr = const_cast(reinterpret_cast(biases_ptr)); + void* cuda_kernel_out_ptr = const_cast(reinterpret_cast(outputs[0])); tensorrt_llm::kernels::weight_only::Params params{cuda_kernel_act_ptr, cuda_kernel_act_scale_ptr, cuda_kernel_weight_ptr, cuda_kernel_scales_ptr, cuda_kernel_zeros_ptr, cuda_kernel_bias_ptr, @@ -464,11 +464,11 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe else { // Use cutlass kernels for large batch size - const int ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); + int const ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); - int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); + int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); - const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); + auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic"); m_weightOnlyGroupwiseGemmRunner->gemm(act_ptr, weight_ptr, inputs[mScalesInputIdx], zeros_ptr, biases_ptr, alpha, outputs[0], m, real_n, k, mGroupSize, *bestTactic, @@ -483,10 +483,10 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe // Use CUDA kernels for small batch size // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel // when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. 
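// [Editorial note: an illustrative sketch, not part of this patch.]
// The enqueue() code around this hunk dispatches between two weight-only GEMM
// back ends: when the flattened token count m is below SMALL_M_FAST_PATH and
// the hand-written CUDA kernel is enabled (mCudaKernelEnabled), the plugin
// launches that kernel directly; otherwise it runs the CUTLASS-based runner
// with the tactic the GemmPluginProfiler selected for this m
// (mPluginProfiler->getBestConfig(m, mGemmId)). The sketch below only mirrors
// that control flow; WeightOnlyDispatcher, runCudaKernel, runCutlassGemm and
// getBestConfig here are hypothetical stand-ins, not TensorRT-LLM APIs.
#include <functional>
#include <optional>
#include <stdexcept>

struct GemmTactic
{
    int configId; // opaque id of a profiled CUTLASS configuration
};

struct WeightOnlyDispatcher
{
    int smallMFastPath;     // plays the role of SMALL_M_FAST_PATH
    bool cudaKernelEnabled; // plays the role of mCudaKernelEnabled
    std::function<void(int, int, int)> runCudaKernel;
    std::function<void(int, int, int, GemmTactic)> runCutlassGemm;
    std::function<std::optional<GemmTactic>(int)> getBestConfig;

    void enqueue(int m, int n, int k) const
    {
        if (m < smallMFastPath && cudaKernelEnabled)
        {
            // Small-batch fast path: launch the dedicated weight-only CUDA kernel.
            runCudaKernel(m, n, k);
            return;
        }
        // Large-batch path: look up the best profiled tactic and run CUTLASS.
        auto const tactic = getBestConfig(m);
        if (!tactic)
        {
            throw std::runtime_error("No valid weight only GEMM tactic");
        }
        runCutlassGemm(m, n, k, *tactic);
    }
};
// [End of editorial note.]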
- const void* pre_quant_scale = nullptr; + void const* pre_quant_scale = nullptr; if (use_pre_quant_scale) pre_quant_scale = inputs[mPreQuantScaleInputIdx]; - tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[mWeightInputIdx]), + tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[mWeightInputIdx]), inputs[mScalesInputIdx], zeros_ptr, act_ptr, pre_quant_scale, biases_ptr, outputs[0], m, real_n, k, mGroupSize, tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, tensorrt_llm::kernels::WeightOnlyType::GroupWise, @@ -496,11 +496,11 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe else { // Use cutlass kernels for large batch size - const int ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); + int const ws_bytes = m_weightOnlyGroupwiseGemmRunner->getWorkspaceSize(m, n, k); - int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); + int32_t* weight_ptr = const_cast(reinterpret_cast(inputs[mWeightInputIdx])); - const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); + auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only groupwise GEMM tactic(It is usually caused by the failure to execute all " "candidate " @@ -518,7 +518,7 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe // IPluginV2Ext Methods nvinfer1::DataType WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0); return mType; @@ -526,12 +526,12 @@ nvinfer1::DataType WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDataType( // IPluginV2 Methods -const char* WeightOnlyGroupwiseQuantMatmulPlugin::getPluginType() const noexcept +char const* WeightOnlyGroupwiseQuantMatmulPlugin::getPluginType() const noexcept { return WOQ_GROUPWISE_MATMUL_PLUGIN_NAME; } -const char* WeightOnlyGroupwiseQuantMatmulPlugin::getPluginVersion() const noexcept +char const* WeightOnlyGroupwiseQuantMatmulPlugin::getPluginVersion() const noexcept { return WOQ_GROUPWISE_MATMUL_PLUGIN_VERSION; } @@ -589,46 +589,46 @@ WeightOnlyGroupwiseQuantMatmulPluginCreator::WeightOnlyGroupwiseQuantMatmulPlugi mFC.fields = mPluginAttributes.data(); } -const char* WeightOnlyGroupwiseQuantMatmulPluginCreator::getPluginName() const noexcept +char const* WeightOnlyGroupwiseQuantMatmulPluginCreator::getPluginName() const noexcept { return WOQ_GROUPWISE_MATMUL_PLUGIN_NAME; } -const char* WeightOnlyGroupwiseQuantMatmulPluginCreator::getPluginVersion() const noexcept +char const* WeightOnlyGroupwiseQuantMatmulPluginCreator::getPluginVersion() const noexcept { return WOQ_GROUPWISE_MATMUL_PLUGIN_VERSION; } -const PluginFieldCollection* WeightOnlyGroupwiseQuantMatmulPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* WeightOnlyGroupwiseQuantMatmulPluginCreator::getFieldNames() noexcept { return &mFC; } IPluginV2* WeightOnlyGroupwiseQuantMatmulPluginCreator::createPlugin( - const char* name, const PluginFieldCollection* fc) noexcept + char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; nvinfer1::DataType type; int QuantAlgo; int GroupSize; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* 
attrName = fields[i].name; if (!strcmp(attrName, "quant_algo")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - QuantAlgo = static_cast(*(static_cast(fields[i].data))); + QuantAlgo = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "group_size")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - GroupSize = static_cast(*(static_cast(fields[i].data))); + GroupSize = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } try @@ -640,7 +640,7 @@ IPluginV2* WeightOnlyGroupwiseQuantMatmulPluginCreator::createPlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -648,7 +648,7 @@ IPluginV2* WeightOnlyGroupwiseQuantMatmulPluginCreator::createPlugin( } IPluginV2* WeightOnlyGroupwiseQuantMatmulPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call weightOnlyGroupwiseQuantMatmulPlugin::destroy() @@ -660,7 +660,7 @@ IPluginV2* WeightOnlyGroupwiseQuantMatmulPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h index c2df11dd4..5c664f5fb 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h +++ b/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin/weightOnlyGroupwiseQuantMatmulPlugin.h @@ -61,7 +61,7 @@ class WeightOnlyGroupwiseQuantGemmPluginProfiler } protected: - void runTactic(int m, int n, int k, const Config& tactic, char* workspace, const cudaStream_t& stream) override; + void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) override; void computeTmpSize(int maxM, int n, int k) override; @@ -80,32 +80,32 @@ class WeightOnlyGroupwiseQuantMatmulPlugin : public BasePlugin WeightOnlyGroupwiseQuantMatmulPlugin() = delete; WeightOnlyGroupwiseQuantMatmulPlugin( - nvinfer1::DataType type, int quant_algo, int group_size, const PluginProfilerPtr& profiler); + nvinfer1::DataType type, int quant_algo, int group_size, PluginProfilerPtr const& profiler); - WeightOnlyGroupwiseQuantMatmulPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler); + WeightOnlyGroupwiseQuantMatmulPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler); ~WeightOnlyGroupwiseQuantMatmulPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept 
override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -153,16 +153,16 @@ class WeightOnlyGroupwiseQuantMatmulPluginCreator : public BaseCreator public: WeightOnlyGroupwiseQuantMatmulPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: GemmPluginProfilerManager gemmPluginProfileManager; diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp index e2586652e..e4f57f31e 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.cpp @@ -26,15 +26,15 @@ using tensorrt_llm::plugins::WeightOnlyQuantGemmPluginProfiler; using tensorrt_llm::plugins::read; using tensorrt_llm::plugins::write; -static const char* WOQ_MATMUL_PLUGIN_VERSION{"1"}; -static const char* WOQ_MATMUL_PLUGIN_NAME{"WeightOnlyQuantMatmul"}; +static char const* 
WOQ_MATMUL_PLUGIN_VERSION{"1"}; +static char const* WOQ_MATMUL_PLUGIN_NAME{"WeightOnlyQuantMatmul"}; PluginFieldCollection WeightOnlyQuantMatmulPluginCreator::mFC{}; std::vector WeightOnlyQuantMatmulPluginCreator::mPluginAttributes; void WeightOnlyQuantGemmPluginProfiler::runTactic(int m, int n, int k, - const WeightOnlyQuantGemmPluginProfiler::Config& tactic, char* workspace, const cudaStream_t& stream) + WeightOnlyQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) { - const int originalN = n * getWeightTypeMultiplier(mWeightTypeId); + int const originalN = n * getWeightTypeMultiplier(mWeightTypeId); half* actPtr = reinterpret_cast(workspace); int8_t* weightPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half))); @@ -45,7 +45,7 @@ void WeightOnlyQuantGemmPluginProfiler::runTactic(int m, int n, int k, char* workspacePtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(outputPtr), m * originalN * sizeof(half))); - const int wsSize = mRunner->getWorkspaceSize(m, n, k); + int const wsSize = mRunner->getWorkspaceSize(m, n, k); if (mWeightTypeId == WeightTypeId::INT8) { @@ -60,7 +60,7 @@ void WeightOnlyQuantGemmPluginProfiler::runTactic(int m, int n, int k, void WeightOnlyQuantGemmPluginProfiler::computeTmpSize(int maxM, int n, int k) { - const int originalN = n * getWeightTypeMultiplier(mWeightTypeId); + int const originalN = n * getWeightTypeMultiplier(mWeightTypeId); std::vector workspaces = { maxM * k * sizeof(half), // A originalN * k * sizeof(int8_t), // B @@ -79,7 +79,7 @@ std::vector WeightOnlyQuantGemmPlugin } WeightOnlyQuantMatmulPlugin::WeightOnlyQuantMatmulPlugin(nvinfer1::DataType type, WeightTypeId weightTypeId, - const WeightOnlyQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler) + WeightOnlyQuantMatmulPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { init(type, weightTypeId); @@ -87,10 +87,10 @@ WeightOnlyQuantMatmulPlugin::WeightOnlyQuantMatmulPlugin(nvinfer1::DataType type // Parameterized constructor WeightOnlyQuantMatmulPlugin::WeightOnlyQuantMatmulPlugin( - const void* data, size_t length, const WeightOnlyQuantMatmulPlugin::PluginProfilerPtr& pluginProfiler) + void const* data, size_t length, WeightOnlyQuantMatmulPlugin::PluginProfilerPtr const& pluginProfiler) : mPluginProfiler(pluginProfiler) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; nvinfer1::DataType type; WeightTypeId weightTypeId; read(d, type); @@ -178,7 +178,7 @@ void WeightOnlyQuantMatmulPlugin::configGemm() } nvinfer1::DimsExprs WeightOnlyQuantMatmulPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { // input [m1, m2, m3, ... 
, k] // weight [k, n] for int8, [k, n/2] for int4 @@ -187,8 +187,8 @@ nvinfer1::DimsExprs WeightOnlyQuantMatmulPlugin::getOutputDimensions( { TLLM_CHECK(nbInputs == 3); TLLM_CHECK(outputIndex == 0); - const int nbDimsA = inputs[0].nbDims; - const int nbDimsB = inputs[1].nbDims; + int const nbDimsA = inputs[0].nbDims; + int const nbDimsB = inputs[1].nbDims; TLLM_CHECK(nbDimsA >= 2); TLLM_CHECK(nbDimsB == 2); DimsExprs ret; @@ -209,7 +209,7 @@ nvinfer1::DimsExprs WeightOnlyQuantMatmulPlugin::getOutputDimensions( } return ret; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -217,7 +217,7 @@ nvinfer1::DimsExprs WeightOnlyQuantMatmulPlugin::getOutputDimensions( } bool WeightOnlyQuantMatmulPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { switch (pos) { @@ -242,17 +242,17 @@ bool WeightOnlyQuantMatmulPlugin::supportsFormatCombination( } } -void WeightOnlyQuantMatmulPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void WeightOnlyQuantMatmulPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { - const auto minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); - const auto maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); + auto const minM = std::accumulate(in[0].min.d, in[0].min.d + in[0].min.nbDims - 1, 1, std::multiplies()); + auto const maxM = std::accumulate(in[0].max.d, in[0].max.d + in[0].max.nbDims - 1, 1, std::multiplies()); - const int maxK = in[0].max.d[in[0].max.nbDims - 1]; - const int maxN = in[1].max.d[1] * getWeightTypeMultiplier(mWeightTypeId); + int const maxK = in[0].max.d[in[0].max.nbDims - 1]; + int const maxN = in[1].max.d[1] * getWeightTypeMultiplier(mWeightTypeId); - const auto K = maxK; - const auto N = maxN / getWeightTypeMultiplier(mWeightTypeId); + auto const K = maxK; + auto const N = maxN / getWeightTypeMultiplier(mWeightTypeId); if (!mDims.isInitialized()) { @@ -264,14 +264,14 @@ void WeightOnlyQuantMatmulPlugin::configurePlugin(const nvinfer1::DynamicPluginT m_workspaceMaxSize = m_weightOnlyGemmRunner->getWorkspaceSize(maxM, maxN, maxK); } -size_t WeightOnlyQuantMatmulPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t WeightOnlyQuantMatmulPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { return m_workspaceMaxSize; } -int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int WeightOnlyQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { // inputs @@ -286,10 +286,10 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input { m *= inputDesc[0].dims.d[ii]; } - const int n = inputDesc[1].dims.d[1]; - const int k = 
inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; + int const n = inputDesc[1].dims.d[1]; + int const k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1]; - const bool use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled; + bool const use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled; #if defined(ENABLE_BF16) TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16, "No valid weightOnlyQuantMatmul configuration"); @@ -323,7 +323,7 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input // Use CUDA kernels for small batch size // The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass // kernel when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights. - tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[1]), inputs[2], nullptr, + tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast(inputs[1]), inputs[2], nullptr, inputs[0], nullptr, nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type, tensorrt_llm::kernels::WeightOnlyType::PerChannel, tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type}; @@ -331,9 +331,9 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input } else { - const int ws_size = m_weightOnlyGemmRunner->getWorkspaceSize(m, real_n, k); + int const ws_size = m_weightOnlyGemmRunner->getWorkspaceSize(m, real_n, k); - const auto& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); + auto const& bestTactic = mPluginProfiler->getBestConfig(m, mGemmId); TLLM_CHECK_WITH_INFO(bestTactic, "No valid weight only per-channel GEMM tactic(It is usually caused by the failure to execute all candidate " "configurations of the CUTLASS kernel, please pay attention to the warning information when building the " @@ -348,7 +348,7 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input // IPluginV2Ext Methods nvinfer1::DataType WeightOnlyQuantMatmulPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { TLLM_CHECK(index == 0); return mType; @@ -356,12 +356,12 @@ nvinfer1::DataType WeightOnlyQuantMatmulPlugin::getOutputDataType( // IPluginV2 Methods -const char* WeightOnlyQuantMatmulPlugin::getPluginType() const noexcept +char const* WeightOnlyQuantMatmulPlugin::getPluginType() const noexcept { return WOQ_MATMUL_PLUGIN_NAME; } -const char* WeightOnlyQuantMatmulPlugin::getPluginVersion() const noexcept +char const* WeightOnlyQuantMatmulPlugin::getPluginVersion() const noexcept { return WOQ_MATMUL_PLUGIN_VERSION; } @@ -416,39 +416,39 @@ WeightOnlyQuantMatmulPluginCreator::WeightOnlyQuantMatmulPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* WeightOnlyQuantMatmulPluginCreator::getPluginName() const noexcept +char const* WeightOnlyQuantMatmulPluginCreator::getPluginName() const noexcept { return WOQ_MATMUL_PLUGIN_NAME; } -const char* WeightOnlyQuantMatmulPluginCreator::getPluginVersion() const noexcept +char const* WeightOnlyQuantMatmulPluginCreator::getPluginVersion() const noexcept { return WOQ_MATMUL_PLUGIN_VERSION; } -const PluginFieldCollection* WeightOnlyQuantMatmulPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* WeightOnlyQuantMatmulPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* 
WeightOnlyQuantMatmulPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* WeightOnlyQuantMatmulPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; nvinfer1::DataType type; WeightTypeId weightTypeId; // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "weight_type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - weightTypeId = static_cast(*(static_cast(fields[i].data))); + weightTypeId = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { TLLM_CHECK(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } try @@ -460,7 +460,7 @@ IPluginV2* WeightOnlyQuantMatmulPluginCreator::createPlugin(const char* name, co obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } @@ -468,7 +468,7 @@ IPluginV2* WeightOnlyQuantMatmulPluginCreator::createPlugin(const char* name, co } IPluginV2* WeightOnlyQuantMatmulPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call WeightOnlyQuantMatmulPlugin::destroy() @@ -480,7 +480,7 @@ IPluginV2* WeightOnlyQuantMatmulPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { caughtError(e); } diff --git a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h index 964c1f0ed..aac4cd6e0 100644 --- a/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h +++ b/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin/weightOnlyQuantMatmulPlugin.h @@ -67,7 +67,7 @@ class WeightOnlyQuantGemmPluginProfiler : public GemmPluginProfiler; WeightOnlyQuantMatmulPlugin() = delete; - WeightOnlyQuantMatmulPlugin(nvinfer1::DataType type, WeightTypeId weightTypeId, const PluginProfilerPtr& profiler); + WeightOnlyQuantMatmulPlugin(nvinfer1::DataType type, WeightTypeId weightTypeId, PluginProfilerPtr const& profiler); - WeightOnlyQuantMatmulPlugin(const void* data, size_t length, const PluginProfilerPtr& profiler); + WeightOnlyQuantMatmulPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler); ~WeightOnlyQuantMatmulPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; 
- size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; @@ -145,16 +145,16 @@ class WeightOnlyQuantMatmulPluginCreator : public BaseCreator public: WeightOnlyQuantMatmulPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; private: GemmPluginProfilerManager gemmPluginProfileManager; diff --git a/cpp/tensorrt_llm/pybind/batch_manager/gptManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/gptManager.cpp index b3e2286b2..58b72f6a5 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/gptManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/gptManager.cpp @@ -37,7 +37,7 @@ namespace tensorrt_llm::pybind::batch_manager GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth, tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback const& getInferenceRequestsCb, - SendResponseCallback const& sendResponseCb, const tb::PollStopSignalCallback& pollStopSignalCb, + SendResponseCallback const& sendResponseCb, tb::PollStopSignalCallback const& pollStopSignalCb, tb::ReturnBatchManagerStatsCallback const& returnBatchManagerStatsCb, tb::TrtGptModelOptionalParams const& optionalParams, std::optional 
terminateReqId) { @@ -87,7 +87,7 @@ tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback co tb::SendResponseCallback callbackAdapter(SendResponseCallback const& callback) { - return [callback](uint64_t id, std::list const& cppTensors, bool isOk, const std::string& errMsg) + return [callback](uint64_t id, std::list const& cppTensors, bool isOk, std::string const& errMsg) { std::list pythonList{}; for (const auto& cppNamedTensor : cppTensors) @@ -103,7 +103,7 @@ void GptManager::initBindings(py::module_& m) py::class_(m, "GptManager") .def(py::init>(), + tb::ReturnBatchManagerStatsCallback, tb::TrtGptModelOptionalParams const&, std::optional>(), py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"), py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr, py::arg("return_batch_manager_stats_cb") = nullptr, diff --git a/cpp/tensorrt_llm/pybind/batch_manager/gptManager.h b/cpp/tensorrt_llm/pybind/batch_manager/gptManager.h index df5d233e2..ef8e38e19 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/gptManager.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/gptManager.h @@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::batch_manager { using GetInferenceRequestsCallback = std::function(int32_t)>; -using SendResponseCallback = std::function const&, bool, const std::string&)>; +using SendResponseCallback = std::function const&, bool, std::string const&)>; tensorrt_llm::batch_manager::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback const& callback); tensorrt_llm::batch_manager::SendResponseCallback callbackAdapter(SendResponseCallback const& callback); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp index 55317c498..96bd90fff 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp @@ -28,7 +28,7 @@ #ifdef _WIN32 // FIXME: THPStream_Wrap seems not to be present in libtorch_python.so on Windows -PyObject* THPStream_Wrap(const c10::Stream& stream) +PyObject* THPStream_Wrap(c10::Stream const& stream) { TLLM_THROW("Stream conversion in not yet supported on Windows."); return nullptr; @@ -147,7 +147,7 @@ void InferenceRequest::initBindings(py::module_& m) .def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming) .def_property_readonly("request_id", &InferenceRequest::getRequestId) .def(py::pickle( - [](const InferenceRequest& p) { // __getstate__ + [](InferenceRequest const& p) { // __getstate__ return py::bytearray(p.serialize()); }, [](py::bytearray const& t) { // __setstate__ diff --git a/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.cpp b/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.cpp index 262624e6b..9526d1699 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.cpp @@ -29,7 +29,7 @@ namespace tb = tensorrt_llm::batch_manager; namespace tensorrt_llm::pybind::batch_manager { -NamedTensor::NamedTensor(const tb::NamedTensor& cppNamedTensor) +NamedTensor::NamedTensor(tb::NamedTensor const& cppNamedTensor) : Base(runtime::Torch::tensor(cppNamedTensor.tensor), cppNamedTensor.name) { } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.h b/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.h index afe257235..9a0bf661d 100644 --- 
a/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/namedTensor.h @@ -39,7 +39,7 @@ class NamedTensor : public tensorrt_llm::batch_manager::GenericNamedTensor #include #include +#include #include #include #include @@ -62,6 +64,17 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) tpr::GenerationInput::initBindings(m); tpr::GenerationOutput::initBindings(m); + auto kvCacheConfigGetstate = [](tbk::KvCacheConfig const& config) + { + return py::make_tuple(config.maxTokens, config.maxAttentionWindow, config.sinkTokenLength, + config.freeGpuMemoryFraction, config.enableBlockReuse, config.useUvm); + }; + auto kvCacheConfigSetstate = [](py::tuple t) + { + return tbk::KvCacheConfig(t[0].cast>(), t[1].cast>(), + t[2].cast>(), t[3].cast>(), t[4].cast(), + t[5].cast()); + }; py::class_(m, "KvCacheConfig") .def(py::init, std::optional, std::optional, std::optional, bool>(), @@ -72,7 +85,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("max_attention_window", &tbk::KvCacheConfig::maxAttentionWindow) .def_readwrite("sink_token_length", &tbk::KvCacheConfig::sinkTokenLength) .def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction) - .def_readwrite("enable_block_reuse", &tbk::KvCacheConfig::enableBlockReuse); + .def_readwrite("enable_block_reuse", &tbk::KvCacheConfig::enableBlockReuse) + .def(py::pickle(kvCacheConfigGetstate, kvCacheConfigSetstate)) + .def("__eq__", &tbk::KvCacheConfig::operator==); py::class_(m, "GptSessionConfig") .def(py::init(), py::arg("max_batch_size"), py::arg("max_beam_width"), @@ -251,11 +266,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_property_readonly("pipeline_parallelism", &tr::GptJsonConfig::getPipelineParallelism) .def_property_readonly("world_size", &tr::GptJsonConfig::getWorldSize) .def("engine_filename", - py::overload_cast( + py::overload_cast( &tr::GptJsonConfig::engineFilename, py::const_), py::arg("world_config"), py::arg("model")) .def("engine_filename", - py::overload_cast(&tr::GptJsonConfig::engineFilename, py::const_), + py::overload_cast(&tr::GptJsonConfig::engineFilename, py::const_), py::arg("world_config")); py::class_(m, "GptSession") @@ -333,6 +348,20 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .value("MAX_UTILIZATION", tbb::SchedulerPolicy::MAX_UTILIZATION) .value("GUARANTEED_NO_EVICT", tbb::SchedulerPolicy::GUARANTEED_NO_EVICT); + auto gptModelParamsGetstate = [&kvCacheConfigGetstate](tb::TrtGptModelOptionalParams const& params) + { + auto kvCacheState = kvCacheConfigGetstate(params.kvCacheConfig); + return py::make_tuple(kvCacheState, params.enableTrtOverlap, params.deviceIds, params.normalizeLogProbs, + params.enableChunkedContext, params.decodingMode); + }; + auto gptModelParamsSetstate = [&kvCacheConfigSetstate](py::tuple t) + { + auto kvCacheConfig = kvCacheConfigSetstate(t[0]); + return tb::TrtGptModelOptionalParams(kvCacheConfig, t[1].cast(), + t[2].cast>>(), t[3].cast(), t[4].cast(), + t[5].cast>()); + }; + py::class_(m, "TrtGptModelOptionalParams") .def(py::init(), py::arg_v("kv_cache_config", tbk::KvCacheConfig{}, "KvCacheConfig()"), @@ -342,7 +371,9 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .def_readwrite("device_ids", &tb::TrtGptModelOptionalParams::deviceIds) .def_readwrite("enable_chunked_context", &tb::TrtGptModelOptionalParams::enableChunkedContext) .def_readwrite("normalize_log_probs", &tb::TrtGptModelOptionalParams::normalizeLogProbs) - .def_readwrite("decoding_mode", &tb::TrtGptModelOptionalParams::decodingMode); + .def_readwrite("decoding_mode", 
&tb::TrtGptModelOptionalParams::decodingMode) + .def(py::pickle(gptModelParamsGetstate, gptModelParamsSetstate)) + .def("__eq__", &tb::TrtGptModelOptionalParams::operator==); tpb::GptManager::initBindings(m); diff --git a/cpp/tensorrt_llm/pybind/utils/pathCaster.h b/cpp/tensorrt_llm/pybind/utils/pathCaster.h index 3cba9e39e..571be82ad 100644 --- a/cpp/tensorrt_llm/pybind/utils/pathCaster.h +++ b/cpp/tensorrt_llm/pybind/utils/pathCaster.h @@ -35,18 +35,18 @@ struct PathCaster { private: - static PyObject* unicode_from_fs_native(const std::string& w) + static PyObject* unicode_from_fs_native(std::string const& w) { return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size())); } - static PyObject* unicode_from_fs_native(const std::wstring& w) + static PyObject* unicode_from_fs_native(std::wstring const& w) { return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size())); } public: - static handle cast(const T& path, return_value_policy, handle) + static handle cast(T const& path, return_value_policy, handle) { if (auto py_str = unicode_from_fs_native(path.native())) { diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index 3e8810daa..b3052f4d8 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -19,6 +19,7 @@ #include "tensorrt_llm/common/cudaAllocator.h" #include "tensorrt_llm/common/tensorConversion.h" #include "tensorrt_llm/kernels/decodingKernels.h" +#include "tensorrt_llm/kernels/parallelDecoding/kvCacheUpdateKernels.h" #include "tensorrt_llm/layers/dynamicDecodeLayer.h" #include @@ -218,22 +219,14 @@ typename tl::DynamicDecodeLayer::OutputParams prepareOutputs( } outputParams.beamHypotheses = std::make_shared(); - if (output.beamHypotheses.outputIdsTgt) - { - outputParams.beamHypotheses->output_ids_tgt = bufferCast(*output.beamHypotheses.outputIdsTgt); - } - if (output.beamHypotheses.sequenceLengthsTgt) + if (output.beamHypotheses.isDone) { - outputParams.beamHypotheses->sequence_lengths_tgt = bufferCast(*output.beamHypotheses.sequenceLengthsTgt); + outputParams.beamHypotheses->is_done = bufferCast(*output.beamHypotheses.isDone); } if (output.beamHypotheses.cumLogProbs) { outputParams.beamHypotheses->cum_log_probs = bufferCast(*output.beamHypotheses.cumLogProbs); } - if (output.beamHypotheses.normedScores) - { - outputParams.beamHypotheses->normed_scores = bufferCast(*output.beamHypotheses.normedScores); - } if (output.beamHypotheses.logProbs) { outputParams.beamHypotheses->log_probs = bufferCast(*output.beamHypotheses.logProbs); @@ -242,13 +235,21 @@ typename tl::DynamicDecodeLayer::OutputParams prepareOutputs( { outputParams.beamHypotheses->min_normed_scores = bufferCast(*output.beamHypotheses.minNormedScores); } + if (output.beamHypotheses.normedScores) + { + outputParams.beamHypotheses->normed_scores = bufferCast(*output.beamHypotheses.normedScores); + } if (output.beamHypotheses.numBeams) { outputParams.beamHypotheses->num_beams = bufferCast(*output.beamHypotheses.numBeams); } - if (output.beamHypotheses.isDone) + if (output.beamHypotheses.outputIdsTgt) { - outputParams.beamHypotheses->is_done = bufferCast(*output.beamHypotheses.isDone); + outputParams.beamHypotheses->output_ids_tgt = bufferCast(*output.beamHypotheses.outputIdsTgt); + } + if (output.beamHypotheses.sequenceLengthsTgt) + { + outputParams.beamHypotheses->sequence_lengths_tgt = bufferCast(*output.beamHypotheses.sequenceLengthsTgt); } if (inputLengths) { @@ -316,6 +317,8 @@ void GptDecoder::forwardAsync(DecodingOutput& 
output, DecodingInput const& in auto outputParams = prepareOutputs(output, input.lengths, mLogProbsTiled); mDynamicDecodeLayer->forward(outputParams, forwardParams); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } // this should be similar to gatherTree in cpp/tensorrt_llm/thop/gatherTreeOp.cpp @@ -370,7 +373,7 @@ void GptDecoder::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de // This is where transpose is done tensorrt_llm::kernels::invokeInsertUnfinishedPath(beamHypotheses, - reinterpret_cast( + reinterpret_cast( bufferCast(*decodingOutput.finished)), bufferCast(*decodingOutput.cumLogProbs), batchSize, beamWidth, stream.get()); sync_check_cuda_error(); @@ -432,7 +435,7 @@ void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor tensorrt_llm::kernels::invokeAcceptDraftTokensByIds(bufferCast(draftTokenIds), bufferCast(targetTokenIds), bufferCast(contextLengths), bufferCast(numDraftTokens), bufferCast(sequenceLengths), - reinterpret_cast( + reinterpret_cast( bufferCast(finishedVec)), reinterpret_cast( bufferCast(finishedFinal)), @@ -492,3 +495,29 @@ void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } + +void IGptDecoder::updateKVCacheBasedOnAcceptedTokens(ITensor const& acceptedOffsets, ITensor const& packedAcceptedIds, + ITensor const& pointerArray, ITensor const& pastKeyValueLengths, GptModelConfig const& modelConfig, + WorldConfig const& worldConfig, BufferManager::CudaStreamPtr stream, SizeType rewindDraftTokenCount, + SizeType maxAttentionWindow, SizeType maxBlocksPerSeq, nvinfer1::DataType dtype) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto const numLayers = modelConfig.getNbLayers(worldConfig.getPipelineParallelism()); + auto const numKvHeads = modelConfig.getNbKvHeads(); + auto const tokensPerBlock = modelConfig.getTokensPerBlock(); + auto const sizeInBytesPerKVHead = modelConfig.getSizePerHead() * BufferDataType(dtype).getSize(); + auto* const* pointerArrayPtr = reinterpret_cast(bufferCast(pointerArray)); + auto const seqCount = acceptedOffsets.getShape().d[0] - 1; + TLLM_CHECK_WITH_INFO(seqCount > 0, "Number of offsets must be larger than 0"); + + tensorrt_llm::kernels::parallel_decoding::updateKVBlockArrayDraftTokenLocation( + bufferCast(acceptedOffsets), bufferCast(packedAcceptedIds), + bufferCast(pastKeyValueLengths), pointerArrayPtr, numLayers, seqCount, numKvHeads, + sizeInBytesPerKVHead, rewindDraftTokenCount, nullptr, maxAttentionWindow, maxBlocksPerSeq, tokensPerBlock, + stream->get()); + + sync_check_cuda_error(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp index 651cc73a4..073c5e7c7 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatch.cpp @@ -128,7 +128,10 @@ GptDecoderBatch::GptDecoderBatch( dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNED, TRTDataType::value); dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + mNextDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + mNextDraftTokenLengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + + TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__); } void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, @@ -229,6 
+232,9 @@ void GptDecoderBatch::setup(DecodingMode const& mode, SizeType maxBatchSize, Siz const_cast(*dInput.stopWordsLens).reshape(ITensor::makeShape({maxBatchSize})); auto const numOfDecoders = fusedDecoder ? 1 : maxBatchSize; + mNextDraftTokens->reshape(ITensor::makeShape({maxBatchSize, mMaxTokensPerStep - 1})); + mNextDraftTokenLengths->reshape(ITensor::makeShape({maxBatchSize})); + mStreams.resize(maxBatchSize); mDecoders.resize(numOfDecoders); mDecodingInputs.resize(maxBatchSize); @@ -417,41 +423,53 @@ void GptDecoderBatch::newRequest( dOutput->beamHypotheses.init(manager, endId); } - auto generatedTokensPerStep = request.generatedTokensPerStep(); + auto generatedTokensPerStep = request.generatedTokensPerStep; if (generatedTokensPerStep > 1) { TLLM_CHECK(beamWidth == 1); - auto numDraftTokens = generatedTokensPerStep - 1; - TensorPtr draftTokensReqBatchSlice = std::move(ITensor::slice(mDraftTokenIds, batchIdx, 1)); - draftTokensReqBatchSlice->squeeze(0); - TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens); - TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens})); - manager.copy(*draftTokensView, *draftTokensReqTokensSlice); mAcceptByLogits[batchIdx] = false; - if (request.draftLogits.has_value()) + auto const numDraftTokens = generatedTokensPerStep - 1; + // If draft tokens are given with context at decoder setup it is target model in speculative decoding + if (request.draftTokens) { - TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value()); - mAcceptByLogits[batchIdx] = true; - - TensorPtr draftLogitsReqBatchSlice = std::move(ITensor::slice(mDraftLogits, batchIdx, 1)); - draftLogitsReqBatchSlice->squeeze(0); - TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens); - manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice); - } - - auto numDraftTokensView = ITensor::slice(mNumDraftTokens, batchIdx, localBatchSize); - kernels::invokeFill(*numDraftTokensView, numDraftTokens, *stream); + if (request.draftLogits.has_value()) + { + TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value()); + mAcceptByLogits[batchIdx] = true; - auto const curandStatesView = ITensor::slice(mCurandStates, batchIdx, localBatchSize); - auto curandState = reinterpret_cast(bufferCast(*curandStatesView)); - if (samplingConfig.randomSeed.has_value()) - { - tk::invokeCurandInitialize( - curandState, nullptr, localBatchSize, samplingConfig.randomSeed.value()[0], stream->get()); + TensorPtr draftLogitsReqBatchSlice = std::move(ITensor::slice(mDraftLogits, batchIdx, 1)); + draftLogitsReqBatchSlice->squeeze(0); + TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens); + manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice); + } + TensorPtr draftTokensReqBatchSlice = std::move(ITensor::slice(mDraftTokenIds, batchIdx, 1)); + draftTokensReqBatchSlice->squeeze(0); + TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens); + TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens})); + manager.copy(*draftTokensView, *draftTokensReqTokensSlice); + + auto const curandStatesView = ITensor::slice(mCurandStates, batchIdx, localBatchSize); + auto curandState = reinterpret_cast(bufferCast(*curandStatesView)); + if (samplingConfig.randomSeed.has_value()) + { + tk::invokeCurandInitialize( + curandState, nullptr, localBatchSize, 
samplingConfig.randomSeed.value()[0], stream->get()); + } + else + { + tk::invokeCurandInitialize(curandState, nullptr, localBatchSize, 0, stream->get()); + } + auto numDraftTokensView = ITensor::slice(mNumDraftTokens, batchIdx, localBatchSize); + kernels::invokeFill(*numDraftTokensView, numDraftTokens, *stream); } - else + else // Medusa { - tk::invokeCurandInitialize(curandState, nullptr, localBatchSize, 0, stream->get()); + auto nextDraftTokenLengthsView = ITensor::slice(mNextDraftTokenLengths, batchIdx, localBatchSize); + kernels::invokeFill(*nextDraftTokenLengthsView, numDraftTokens, *stream); + TensorPtr nextDraftTokens = ITensor::slice(mNextDraftTokens, batchIdx, localBatchSize); + manager.setZero(*nextDraftTokens); + // FIXME(nkorobov): skip decoding draft tokens at medusa for now + generatedTokensPerStep = 1; } } @@ -653,8 +671,8 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync( { // These params are only used for testing. Thus, can be per batch instead of per request auto const& samplingConfig = decoder.getSamplingConfig(); - const bool useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value(); - const float randomAcceptanceThreshold + bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value(); + float const randomAcceptanceThreshold = useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0]; TensorPtr batchSlotsAcceptLogitsStepSlice = std::move(ITensor::slice(mBatchSlotsAcceptLogits, si, 1)); diff --git a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp index 7879e6883..ba3816660 100644 --- a/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp +++ b/cpp/tensorrt_llm/runtime/gptJsonConfig.cpp @@ -58,12 +58,12 @@ std::optional parseJsonFieldOptional(Json const& json, std::string_vi { value = json.at(name).template get(); } - catch (const nlohmann::json::out_of_range& e) + catch (nlohmann::json::out_of_range const& e) { TLLM_LOG_WARNING(e.what()); TLLM_LOG_WARNING("Optional value for parameter %s will not be set.", std::string(name).c_str()); } - catch (const nlohmann::json::type_error& e) + catch (nlohmann::json::type_error const& e) { TLLM_LOG_WARNING(e.what()); TLLM_LOG_WARNING("Optional value for parameter %s will not be set.", std::string(name).c_str()); diff --git a/cpp/tensorrt_llm/runtime/gptSession.cpp b/cpp/tensorrt_llm/runtime/gptSession.cpp index 894dcd167..eac6d8687 100644 --- a/cpp/tensorrt_llm/runtime/gptSession.cpp +++ b/cpp/tensorrt_llm/runtime/gptSession.cpp @@ -231,7 +231,7 @@ void GptSession::createCustomAllReduceWorkspace( for (size_t memIdx = 0; memIdx < mIpcMemoryHandles.size(); memIdx++) { - const auto& memCommPtrs = mIpcMemoryHandles[memIdx]->getCommPtrsTensor(); + auto const& memCommPtrs = mIpcMemoryHandles[memIdx]->getCommPtrsTensor(); for (SizeType tpIdx = 0; tpIdx < mWorldConfig.getTensorParallelism(); tpIdx++) { commPtrsData[memIdx * mWorldConfig.getTensorParallelism() + tpIdx] = memCommPtrs[tpIdx]; @@ -354,7 +354,7 @@ void GptSession::kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId, TLLM_CHECK(mKvCacheManager); auto contextLengthsHost = mBuffers.at(microBatchId)->contextLengthsHost; TLLM_CHECK(contextLengthsHost); - const auto* const contextLengthsPtr = bufferCast(*contextLengthsHost); + auto const* const contextLengthsPtr = bufferCast(*contextLengthsHost); auto const contextLengthsSize = static_cast(contextLengthsHost->getSize()); for (SizeType batchIdx = 0; batchIdx < contextLengthsSize; ++batchIdx) { diff 
--git a/cpp/tensorrt_llm/runtime/ipcUtils.cpp b/cpp/tensorrt_llm/runtime/ipcUtils.cpp index e8f55756f..fee181cc7 100644 --- a/cpp/tensorrt_llm/runtime/ipcUtils.cpp +++ b/cpp/tensorrt_llm/runtime/ipcUtils.cpp @@ -22,7 +22,7 @@ namespace tensorrt_llm::runtime void setPeerAccess(WorldConfig const& worldConfig, bool enable) { - const auto srcNode = worldConfig.getTensorParallelRank(); + auto const srcNode = worldConfig.getTensorParallelRank(); for (SizeType destNode = 0; destNode < worldConfig.getTensorParallelism(); destNode++) { @@ -42,7 +42,7 @@ void setPeerAccess(WorldConfig const& worldConfig, bool enable) { cudaDeviceDisablePeerAccess(destNode); } - const auto error = cudaGetLastError(); + auto const error = cudaGetLastError(); if (error != cudaErrorPeerAccessAlreadyEnabled && error != cudaErrorPeerAccessNotEnabled) { TLLM_CUDA_CHECK(error); @@ -66,8 +66,8 @@ void IpcMemory::allocateIpcMemory() cudaIpcMemHandle_t localHandle; TLLM_CUDA_CHECK(cudaIpcGetMemHandle(&localHandle, mBufferPtr)); - const auto tpRank = mWorldConfig.getTensorParallelRank(); - const auto ppRank = mWorldConfig.getPipelineParallelRank(); + auto const tpRank = mWorldConfig.getTensorParallelRank(); + auto const ppRank = mWorldConfig.getPipelineParallelRank(); auto const comm = COMM_SESSION.split(ppRank, tpRank); std::vector serialHandles(CUDA_IPC_HANDLE_SIZE * mWorldConfig.getTensorParallelism(), 0); comm.allgather(&localHandle.reserved, serialHandles.data(), CUDA_IPC_HANDLE_SIZE, mpi::MpiType::kBYTE); diff --git a/cpp/tensorrt_llm/runtime/loraUtils.cpp b/cpp/tensorrt_llm/runtime/loraUtils.cpp index 8a8f3fdba..7baf8652b 100644 --- a/cpp/tensorrt_llm/runtime/loraUtils.cpp +++ b/cpp/tensorrt_llm/runtime/loraUtils.cpp @@ -20,8 +20,8 @@ namespace tensorrt_llm::runtime::lora { -void loraValidateRequestTensorDims(const std::optional& optReqLoraWeights, - const std::optional& optReqLoraConfig) +void loraValidateRequestTensorDims(std::optional const& optReqLoraWeights, + std::optional const& optReqLoraConfig) { TLLM_CHECK_WITH_INFO(optReqLoraWeights.has_value() && optReqLoraConfig.has_value(), "Request for LoRA inference must have both lora_weights and lora_keys"); @@ -50,8 +50,8 @@ void loraValidateRequestTensorDims(const std::optional& optR keys->getShape().d[2] == expectedLoraConfigValues, "Expected dim2 of lora_keys to have a size of 3"); } -void loraValidateRequestTensors(const std::optional& optReqLoraWeights, - const std::optional& optReqLoraConfig, runtime::GptModelConfig const& modelConfig, +void loraValidateRequestTensors(std::optional const& optReqLoraWeights, + std::optional const& optReqLoraConfig, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig) { SizeType constexpr expectedLoraConfigValues = 3; diff --git a/cpp/tensorrt_llm/runtime/loraUtils.h b/cpp/tensorrt_llm/runtime/loraUtils.h index 9c4a3d4b8..039b1d269 100644 --- a/cpp/tensorrt_llm/runtime/loraUtils.h +++ b/cpp/tensorrt_llm/runtime/loraUtils.h @@ -27,10 +27,10 @@ SizeType constexpr kLORA_CONFIG_ADAPTER_SIZE_OFF = 2; SizeType constexpr kLORA_NUM_WEIGHTS_POINTERS = 2; -void loraValidateRequestTensorDims(const std::optional& optReqLoraWeights, - const std::optional& optReqLoraConfig); +void loraValidateRequestTensorDims(std::optional const& optReqLoraWeights, + std::optional const& optReqLoraConfig); -void loraValidateRequestTensors(const std::optional& optReqLoraWeights, - const std::optional& optReqLoraConfig, runtime::GptModelConfig const& modelConfig, +void loraValidateRequestTensors(std::optional const& 
optReqLoraWeights, + std::optional const& optReqLoraConfig, runtime::GptModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig); } // namespace tensorrt_llm::runtime::lora diff --git a/cpp/tensorrt_llm/runtime/ncclCommunicator.h b/cpp/tensorrt_llm/runtime/ncclCommunicator.h index 76dc38701..754f3f1d6 100644 --- a/cpp/tensorrt_llm/runtime/ncclCommunicator.h +++ b/cpp/tensorrt_llm/runtime/ncclCommunicator.h @@ -42,8 +42,8 @@ class NcclCommunicator ~NcclCommunicator(); // no copy - NcclCommunicator(const NcclCommunicator&) = delete; - NcclCommunicator& operator=(const NcclCommunicator&) = delete; + NcclCommunicator(NcclCommunicator const&) = delete; + NcclCommunicator& operator=(NcclCommunicator const&) = delete; void send(IBuffer const& buf, int peer, CudaStream const& stream) const { diff --git a/cpp/tensorrt_llm/runtime/promptTuningParams.cpp b/cpp/tensorrt_llm/runtime/promptTuningParams.cpp index 7b4b8b870..12f4ac39b 100644 --- a/cpp/tensorrt_llm/runtime/promptTuningParams.cpp +++ b/cpp/tensorrt_llm/runtime/promptTuningParams.cpp @@ -20,8 +20,8 @@ namespace tensorrt_llm::runtime { void PromptTuningParams::fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, - const SizeType numContextRequests, const std::vector& reqBeamWidths, - const std::vector& reqPromptLengths, BufferManager const& manager, bool packedInput) + const SizeType numContextRequests, std::vector const& reqBeamWidths, + std::vector const& reqPromptLengths, BufferManager const& manager, bool packedInput) { auto const& tasksHostShape = tasksHost->getShape(); TLLM_CHECK_WITH_INFO(tasksHostShape.nbDims == 1, "tasksHost expected to have dimension [batchSize]"); diff --git a/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp b/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp index b4b773541..17b230fbf 100644 --- a/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp +++ b/cpp/tensorrt_llm/runtime/runtimeBuffers.cpp @@ -940,8 +940,8 @@ void RuntimeBuffers::getRuntimeBuffers(TensorMap& inputBuffers, TensorMap& outpu TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -std::vector RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeType& batchSize, - const SizeType& maxInputLength, const SizeType* pInputLengths, bool useGptAttentionPlugin, bool usePackedInput) +std::vector RuntimeBuffers::getPositionIdsContextPhaseGlm(SizeType const& batchSize, + SizeType const& maxInputLength, SizeType const* pInputLengths, bool useGptAttentionPlugin, bool usePackedInput) { TLLM_CHECK(pInputLengths != nullptr); @@ -991,8 +991,8 @@ std::vector RuntimeBuffers::getPositionIdsContextPhaseGlm(const SizeTy return positionIdsVec; } -std::vector RuntimeBuffers::getPositionIdsGenerationPhaseGlm(const SizeType& batchSize, - const SizeType& beamSize, const SizeType& step, const SizeType* pInputLengths, bool useGptAttentionPlugin, +std::vector RuntimeBuffers::getPositionIdsGenerationPhaseGlm(SizeType const& batchSize, + SizeType const& beamSize, SizeType const& step, SizeType const* pInputLengths, bool useGptAttentionPlugin, bool usePackedInput) { TLLM_CHECK(pInputLengths != nullptr); diff --git a/cpp/tensorrt_llm/runtime/runtimeBuffers.h b/cpp/tensorrt_llm/runtime/runtimeBuffers.h index ea830002d..ae5fdf503 100644 --- a/cpp/tensorrt_llm/runtime/runtimeBuffers.h +++ b/cpp/tensorrt_llm/runtime/runtimeBuffers.h @@ -177,13 +177,13 @@ class RuntimeBuffers // Some tensors are properly tiled, some are just reshaped. 
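The dominant change in the runtime hunks above and below is purely stylistic: const is moved to the right of the type it qualifies, so const T& becomes T const& and const auto becomes auto const. Both spellings declare exactly the same type; the patch only standardizes on the east-const form. A minimal standalone sketch of the equivalence, using hypothetical helper names that are not part of this patch:

#include <cstddef>
#include <string>
#include <vector>

// West-const spelling (the form being replaced throughout this patch).
static std::size_t countWest(const std::vector<std::string>& words, const std::string& key)
{
    std::size_t count = 0;
    for (const auto& w : words)
    {
        if (w == key)
        {
            ++count;
        }
    }
    return count;
}

// East-const spelling (the form the patch adopts); identical semantics and codegen.
static std::size_t countEast(std::vector<std::string> const& words, std::string const& key)
{
    std::size_t count = 0;
    for (auto const& w : words)
    {
        if (w == key)
        {
            ++count;
        }
    }
    return count;
}

int main()
{
    std::vector<std::string> const words{"lora", "nccl", "lora"};
    return countWest(words, "lora") == countEast(words, "lora") ? 0 : 1; // returns 0: both count 2
}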
void tile(BufferManager& manager, GptModelConfig const& modelConfig, WorldConfig const& worldConfig); - static std::vector getPositionIdsContextPhaseGlm(const SizeType& batchSize, - const SizeType& maxInputLength, const SizeType* pInputLengths, const bool useGptAttentionPlugin, - const bool usePackedInput); + static std::vector getPositionIdsContextPhaseGlm(SizeType const& batchSize, + SizeType const& maxInputLength, SizeType const* pInputLengths, bool const useGptAttentionPlugin, + bool const usePackedInput); - static std::vector getPositionIdsGenerationPhaseGlm(const SizeType& batchSize, const SizeType& beamSize, - const SizeType& step, const SizeType* pInputLengths, const bool useGptAttentionPlugin, - const bool usePackedInput); + static std::vector getPositionIdsGenerationPhaseGlm(SizeType const& batchSize, SizeType const& beamSize, + SizeType const& step, SizeType const* pInputLengths, bool const useGptAttentionPlugin, + bool const usePackedInput); }; } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/runtimeKernels.cu b/cpp/tensorrt_llm/runtime/runtimeKernels.cu index 8bd79e050..44c817ede 100644 --- a/cpp/tensorrt_llm/runtime/runtimeKernels.cu +++ b/cpp/tensorrt_llm/runtime/runtimeKernels.cu @@ -60,6 +60,7 @@ void invokeFill(IBuffer& buffer, T const value, CudaStream const& stream) } // template instantiation +template void invokeFill(IBuffer&, std::int64_t, CudaStream const&); template void invokeFill(IBuffer&, std::int32_t, CudaStream const&); template void invokeFill(IBuffer&, std::int8_t, CudaStream const&); template void invokeFill(IBuffer&, std::uint8_t, CudaStream const&); @@ -111,7 +112,7 @@ template void invokeFillBatch(IBuffer&, IBuffer const&, std::size_ namespace { template -__global__ void copyBatch(const uint8_t* srcData, uint8_t* dstData, std::int32_t const* srcOffsets, +__global__ void copyBatch(uint8_t const* srcData, uint8_t* dstData, std::int32_t const* srcOffsets, std::int32_t const* dstOffsets, std::int32_t const* sizes, std::int32_t const dataTypeSize) { constexpr auto VEC_ELTS = static_cast(sizeof(VecT)); @@ -127,7 +128,7 @@ __global__ void copyBatch(const uint8_t* srcData, uint8_t* dstData, std::int32_t for (; srcIdx < srcEndIdx; srcIdx += stride, dstIdx += stride) { - *reinterpret_cast(&dstData[dstIdx]) = *reinterpret_cast(&srcData[srcIdx]); + *reinterpret_cast(&dstData[dstIdx]) = *reinterpret_cast(&srcData[srcIdx]); } } } // namespace @@ -135,7 +136,7 @@ __global__ void copyBatch(const uint8_t* srcData, uint8_t* dstData, std::int32_t void invokeCopyBatch(IBuffer const& srcBuffer, IBuffer& dstBuffer, IBuffer const& srcOffsets, IBuffer const& dstOffsets, IBuffer const& sizes, std::size_t maxStride, CudaStream const& stream) { - auto srcDataPtr = reinterpret_cast(srcBuffer.data()); + auto srcDataPtr = reinterpret_cast(srcBuffer.data()); auto dstDataPtr = reinterpret_cast(dstBuffer.data()); auto srcOffsetsPtr = bufferCast(srcOffsets); auto dstOffsetsPtr = bufferCast(dstOffsets); @@ -1118,7 +1119,7 @@ void gatherLastTokenLogits(ITensor& output, ITensor const& input, ITensor const& // block copies a `vocabSizePadded` length logits tensor from the "inputLogits (microBatchSize, beamWidth, // vocabSizePadded)" to the "outputGenerationLogits (batchSize, beamWidth, outputLen, vocabSizePadded)" template -__global__ void mergeLogitsFragmentsKernel(T* output, T** fragmentsVector, const int outputLen, int firstBatchSlotIdx, +__global__ void mergeLogitsFragmentsKernel(T* output, T** fragmentsVector, int const outputLen, int firstBatchSlotIdx, int 
microBatchSize, int beamWidth, int vocabSizePadded, int stepOffset) { // output: shape: [batchSize, beamWidth, outputLen, vocabSize] @@ -1137,13 +1138,13 @@ __global__ void mergeLogitsFragmentsKernel(T* output, T** fragmentsVector, const int mbeamIdx = blockIdx.x % beamWidth; // The output pointer - const unsigned int outputOffset + unsigned int const outputOffset = (absoluteBatchSlotIdx * beamWidth * outputLen + mbeamIdx * outputLen + curStep + stepOffset) * vocabSizePadded; T* outputPtr = &output[outputOffset]; - const unsigned int inputOffset = (relativeBatchSlotIdx * beamWidth + mbeamIdx) * vocabSizePadded; + unsigned int const inputOffset = (relativeBatchSlotIdx * beamWidth + mbeamIdx) * vocabSizePadded; // The input pointer. T const* inputPtr = &fragmentsVector[curStep][inputOffset]; diff --git a/cpp/tensorrt_llm/runtime/tllmBuffers.h b/cpp/tensorrt_llm/runtime/tllmBuffers.h index c7f402ca6..0d83ca189 100644 --- a/cpp/tensorrt_llm/runtime/tllmBuffers.h +++ b/cpp/tensorrt_llm/runtime/tllmBuffers.h @@ -378,7 +378,7 @@ void MemoryPool::allocateImpl(MemoryPool::PointerType* ptr, MemoryPo // Finds first free segment providing sufficient space auto it = std::find_if(mMemorySegments.begin(), mMemorySegments.end(), - [alignedRequest](const auto& ms) { return ms.tag == nullptr && ms.size >= alignedRequest; }); + [alignedRequest](auto const& ms) { return ms.tag == nullptr && ms.size >= alignedRequest; }); if (it == mMemorySegments.end()) { @@ -421,7 +421,7 @@ void MemoryPool::deallocateImpl(PointerType tag, SizeType n) { std::lock_guard lock(mLock); auto it = std::find_if(mMemorySegments.begin(), mMemorySegments.end(), - [&tag](const MemorySegment& segment) { return segment.tag == tag; }); + [&tag](MemorySegment const& segment) { return segment.tag == tag; }); TLLM_CHECK_WITH_INFO(it != mMemorySegments.end(), "MemoryPool free: Requested tag %p could not be found", tag); diff --git a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu index 7533a1a97..ef1e3a190 100644 --- a/cpp/tensorrt_llm/runtime/utils/debugUtils.cu +++ b/cpp/tensorrt_llm/runtime/utils/debugUtils.cu @@ -22,7 +22,7 @@ namespace { -__global__ void checkTensorNanKernel(const float* data, std::size_t size, int* foundNan) +__global__ void checkTensorNanKernel(float const* data, std::size_t size, int* foundNan) { auto tidx = blockIdx.x * blockDim.x + threadIdx.x; @@ -47,18 +47,18 @@ namespace tc = tensorrt_llm::common; namespace tensorrt_llm::runtime::utils { -void invokeCheckTensorNanKernel(const float* data, std::size_t size, int* foundNan, cudaStream_t stream) +void invokeCheckTensorNanKernel(float const* data, std::size_t size, int* foundNan, cudaStream_t stream) { constexpr uint32_t kThreadsPerCta = 256; checkTensorNanKernel<<>>(data, size, foundNan); } -bool tensorHasNan(const IBuffer& tensor, BufferManager& manager) +bool tensorHasNan(IBuffer const& tensor, BufferManager& manager) { auto foundNan = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32); auto foundNanPtr = bufferCast(*foundNan); foundNanPtr[0] = 0; - const auto size = tensor.getSize(); + auto const size = tensor.getSize(); invokeCheckTensorNanKernel(bufferCast(tensor), size, foundNanPtr, manager.getStream().get()); manager.getStream().synchronize(); return static_cast(foundNanPtr[0]); diff --git a/cpp/tensorrt_llm/runtime/utils/debugUtils.h b/cpp/tensorrt_llm/runtime/utils/debugUtils.h index 92243f55c..9ebbfdee7 100644 --- a/cpp/tensorrt_llm/runtime/utils/debugUtils.h +++ 
b/cpp/tensorrt_llm/runtime/utils/debugUtils.h @@ -23,7 +23,7 @@ namespace tensorrt_llm::runtime namespace utils { -bool tensorHasNan(const IBuffer& tensor, BufferManager& manager); +bool tensorHasNan(IBuffer const& tensor, BufferManager& manager); } } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp b/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp index eab11d29b..6926ddc39 100644 --- a/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp +++ b/cpp/tensorrt_llm/runtime/utils/numpyUtils.cpp @@ -61,7 +61,7 @@ nvinfer1::DataType typeFromNumpyDesc(std::string type) void parseNpyIntro(FILE*& f_ptr, uint32_t& header_len, uint32_t& start_data) { - const char magic[] + char const magic[] = "\x93" "NUMPY"; char magic_test[sizeof(magic)] = "\0"; @@ -144,7 +144,7 @@ int parseNpyHeader(FILE*& f_ptr, uint32_t header_len, nvinfer1::DataType& type, } //! \brief Create new tensor from numpy file. -[[nodiscard]] ITensor::UniquePtr loadNpy(BufferManager& manager, const std::string& npyFile, const MemoryType where) +[[nodiscard]] ITensor::UniquePtr loadNpy(BufferManager& manager, std::string const& npyFile, const MemoryType where) { FILE* f_ptr = fopen(npyFile.c_str(), "rb"); if (f_ptr == nullptr) @@ -180,7 +180,7 @@ int parseNpyHeader(FILE*& f_ptr, uint32_t header_len, nvinfer1::DataType& type, return tensor; } -void saveNpy(BufferManager& manager, ITensor const& tensor, const std::string& filename) +void saveNpy(BufferManager& manager, ITensor const& tensor, std::string const& filename) { // Save tensor to NPY 1.0 format (see https://numpy.org/neps/nep-0001-npy-format.html) auto const tensorSize = tensor.getSize(); @@ -209,7 +209,7 @@ void saveNpy(BufferManager& manager, ITensor const& tensor, const std::string& f return; } - const char magic[] + char const magic[] = "\x93" "NUMPY"; const uint8_t npy_major = 1; diff --git a/cpp/tensorrt_llm/runtime/utils/numpyUtils.h b/cpp/tensorrt_llm/runtime/utils/numpyUtils.h index 6d804ed6c..cba2a2de3 100644 --- a/cpp/tensorrt_llm/runtime/utils/numpyUtils.h +++ b/cpp/tensorrt_llm/runtime/utils/numpyUtils.h @@ -25,9 +25,9 @@ namespace tensorrt_llm::runtime::utils { //! \brief Create new tensor from numpy file. -[[nodiscard]] ITensor::UniquePtr loadNpy(BufferManager& manager, const std::string& npyFile, const MemoryType where); +[[nodiscard]] ITensor::UniquePtr loadNpy(BufferManager& manager, std::string const& npyFile, const MemoryType where); //! \brief Save tensor to numpy file. 
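The debugUtils changes just above are also const-placement only; the check itself is unchanged: launch a kernel that scans a float buffer and writes an integer flag if any element is NaN, then read the flag back on the host (tensorHasNan). Below is a host-only sketch of the same scan for reference; hostTensorHasNan is a hypothetical name, and this is not the CUDA kernel the runtime uses:

#include <cmath>
#include <cstdio>
#include <vector>

// Walk the buffer once and raise an integer flag on the first NaN, mirroring
// the int* foundNan flag that checkTensorNanKernel writes on the device.
static bool hostTensorHasNan(std::vector<float> const& data)
{
    int foundNan = 0;
    for (float const v : data)
    {
        if (std::isnan(v))
        {
            foundNan = 1;
            break;
        }
    }
    return static_cast<bool>(foundNan);
}

int main()
{
    std::vector<float> const clean{0.f, 1.f, 2.f};
    std::vector<float> const broken{0.f, std::nanf(""), 2.f};
    std::printf("clean: %d, broken: %d\n", static_cast<int>(hostTensorHasNan(clean)),
        static_cast<int>(hostTensorHasNan(broken)));
    return 0;
}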
-void saveNpy(BufferManager& manager, ITensor const& tensor, const std::string& filename); +void saveNpy(BufferManager& manager, ITensor const& tensor, std::string const& filename); } // namespace tensorrt_llm::runtime::utils diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp index b709a8667..755fb5b38 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.cpp @@ -31,7 +31,7 @@ namespace torch_ext template FtDynamicDecode::FtDynamicDecode(const size_t max_batch_size, const size_t max_beam_width, const size_t vocab_size, - const size_t vocab_size_padded, const int tensor_para_size, const int pipeline_para_size) + const size_t vocab_size_padded, int const tensor_para_size, int const pipeline_para_size) : vocab_size_(vocab_size) , vocab_size_padded_(vocab_size_padded) , finished_sum_(tr::BufferManager::pinned( @@ -108,20 +108,17 @@ void FtDynamicDecode::setup(size_t batch_size, size_t beam_width, th::optiona th::optional top_p_decay_opt, th::optional top_p_min_opt, th::optional top_p_reset_ids_opt) { - // unused: length_penalty_opt, beam_search_diversity_rate_opt, early_stopping_opt - auto stream = at::cuda::getCurrentCUDAStream().stream(); dynamic_decode_layer_->setStream(stream); SetupParams setupParams; - - safeInsert(runtime_top_k_opt, setupParams.runtime_top_k); - safeInsert(runtime_top_p_opt, setupParams.runtime_top_p); safeInsert(temperature_opt, setupParams.temperature); safeInsert(repetition_penalty_opt, setupParams.repetition_penalty); safeInsert(presence_penalty_opt, setupParams.presence_penalty); safeInsert(frequency_penalty_opt, setupParams.frequency_penalty); safeInsert(min_length_opt, setupParams.min_length); + safeInsert(runtime_top_k_opt, setupParams.runtime_top_k); + safeInsert(runtime_top_p_opt, setupParams.runtime_top_p); safeInsert(random_seed_opt, setupParams.randomSeed); safeInsert(top_p_decay_opt, setupParams.top_p_decay); safeInsert(top_p_min_opt, setupParams.top_p_min); @@ -134,17 +131,15 @@ void FtDynamicDecode::setup(size_t batch_size, size_t beam_width, th::optiona } template -void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, hidden_size) - int step, int max_input_length, int max_attention_window, int sink_token_length, uint64_t ite, int local_batch_size, - th::Tensor end_id, th::optional embedding_bias_opt, th::optional input_lengths_opt, +void FtDynamicDecode::forward(th::Tensor& logits, int step, int max_input_length, int max_attention_window, + int sink_token_length, uint64_t ite, int local_batch_size, th::Tensor end_id, + th::optional embedding_bias_opt, th::optional input_lengths_opt, th::optional sequence_limit_length_opt, th::optional stop_words_list_ptrs_opt, th::optional stop_words_lens_opt, int32_t max_stop_words_len, th::optional bad_words_list_ptrs_opt, th::optional bad_words_lens_opt, int32_t max_bad_words_len, th::optional no_repeat_ngram_size_opt, - th::optional src_cache_indirection_opt, - // Outputs - th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, - th::optional finished_input, th::optional finished_output, + th::optional src_cache_indirection_opt, th::Tensor& output_token_ids, th::Tensor& newTokens, + th::Tensor& should_stop, th::optional finished_input, th::optional finished_output, th::optional sequence_lengths_opt, th::optional cum_log_probs_opt, th::optional output_log_probs_opt, th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, th::optional tgt_cache_indirection_opt, @@ 
-153,32 +148,36 @@ void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, th::optional beam_hyps_log_probs_opt, th::optional beam_hyps_min_normed_scores_opt, th::optional beam_hyps_num_beams_opt, th::optional beam_hyps_is_done_opt, bool use_beam_hyps) - { - auto const& logits_converted = convert_tensor(logits); - auto const& end_ids_converted = convert_tensor(end_id); typename tensorrt_llm::layers::DynamicDecodeLayer::ForwardParams forwardParams{step, static_cast(ite), - max_input_length, max_attention_window, sink_token_length, local_batch_size, end_ids_converted}; + max_input_length, max_attention_window, sink_token_length, local_batch_size, convert_tensor(end_id)}; + + forwardParams.logits = convert_tensor(logits); - forwardParams.logits = logits_converted; - safeUpdate(src_cache_indirection_opt, forwardParams.src_cache_indirection); - safeUpdate(sequence_limit_length_opt, forwardParams.sequence_limit_length); safeUpdate(embedding_bias_opt, forwardParams.embedding_bias); safeUpdate(input_lengths_opt, forwardParams.input_lengths); - safeUpdate(bad_words_list_ptrs_opt, forwardParams.bad_words_ptr); - safeUpdate(bad_words_lens_opt, forwardParams.bad_words_lengths); - forwardParams.max_bad_words_len = max_bad_words_len; + safeUpdate(sequence_limit_length_opt, forwardParams.sequence_limit_length); safeUpdate(stop_words_list_ptrs_opt, forwardParams.stop_words_ptr); safeUpdate(stop_words_lens_opt, forwardParams.stop_words_lengths); forwardParams.max_stop_words_len = max_stop_words_len; + safeUpdate(bad_words_list_ptrs_opt, forwardParams.bad_words_ptr); + safeUpdate(bad_words_lens_opt, forwardParams.bad_words_lengths); + forwardParams.max_bad_words_len = max_bad_words_len; safeUpdate(no_repeat_ngram_size_opt, forwardParams.no_repeat_ngram_size); - safeUpdate(finished_input, forwardParams.finished); + safeUpdate(src_cache_indirection_opt, forwardParams.src_cache_indirection); auto const& output_ids_converted = convert_tensor(output_token_ids); typename tensorrt_llm::layers::DynamicDecodeLayer::OutputParams outputParams{output_ids_converted}; outputParams.newTokens = std::move(convert_tensor(newTokens)); - + safeUpdate(finished_input, forwardParams.finished); safeUpdate(finished_output, outputParams.finished); + safeUpdate(sequence_lengths_opt, outputParams.sequence_length); + safeUpdate(cum_log_probs_opt, outputParams.cum_log_probs); + safeUpdate(output_log_probs_opt, outputParams.output_log_probs); + safeUpdate(output_log_probs_tiled_opt, outputParams.output_log_probs_tiled); + safeUpdate(parent_ids_opt, outputParams.parent_ids); + safeUpdate(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection); + std::int32_t* finished_sum_host = nullptr; if (forwardParams.sequence_limit_length && outputParams.finished.has_value()) { @@ -189,24 +188,19 @@ void FtDynamicDecode::forward(th::Tensor& logits, // (batch_size, beam_width, finished_sum_host[bi] = 0; } } - safeUpdate(sequence_lengths_opt, outputParams.sequence_length); - safeUpdate(parent_ids_opt, outputParams.parent_ids); - safeUpdate(cum_log_probs_opt, outputParams.cum_log_probs); - safeUpdate(output_log_probs_opt, outputParams.output_log_probs); - safeUpdate(output_log_probs_tiled_opt, outputParams.output_log_probs_tiled); - safeUpdate(tgt_cache_indirection_opt, outputParams.tgt_cache_indirection); if (use_beam_hyps) { outputParams.beamHypotheses = std::make_shared(); - safeUpdatePtr(beam_hyps_output_ids_tgt_opt, outputParams.beamHypotheses->output_ids_tgt); - safeUpdatePtr(beam_hyps_sequence_lengths_tgt_opt, 
outputParams.beamHypotheses->sequence_lengths_tgt); + safeUpdatePtr(beam_hyps_is_done_opt, outputParams.beamHypotheses->is_done); safeUpdatePtr(beam_hyps_cum_log_probs_opt, outputParams.beamHypotheses->cum_log_probs); - safeUpdatePtr(beam_hyps_normed_scores_opt, outputParams.beamHypotheses->normed_scores); safeUpdatePtr(beam_hyps_log_probs_opt, outputParams.beamHypotheses->log_probs); safeUpdatePtr(beam_hyps_min_normed_scores_opt, outputParams.beamHypotheses->min_normed_scores); + safeUpdatePtr(beam_hyps_normed_scores_opt, outputParams.beamHypotheses->normed_scores); safeUpdatePtr(beam_hyps_num_beams_opt, outputParams.beamHypotheses->num_beams); - safeUpdatePtr(beam_hyps_is_done_opt, outputParams.beamHypotheses->is_done); + safeUpdatePtr(beam_hyps_output_ids_tgt_opt, outputParams.beamHypotheses->output_ids_tgt); + safeUpdatePtr(beam_hyps_sequence_lengths_tgt_opt, outputParams.beamHypotheses->sequence_lengths_tgt); + // TODO: move the assignment below into onlineBeamSearchLayer.cu safeUpdatePtr(input_lengths_opt, outputParams.beamHypotheses->input_lengths); } @@ -269,7 +263,6 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional // TODO: Revise DynamicDecodeLayer and make the decode arguments consistent. CHECK_OPTIONAL_CPU_INPUT(runtime_top_k_opt, torch::kInt32); CHECK_OPTIONAL_CPU_INPUT(runtime_top_p_opt, torch::kFloat); - CHECK_OPTIONAL_CPU_INPUT(temperature_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(repetition_penalty_opt, torch::kFloat); CHECK_OPTIONAL_CPU_INPUT(presence_penalty_opt, torch::kFloat); @@ -283,53 +276,55 @@ void DynamicDecodeOp::setup(int64_t batch_size, int64_t beam_width, th::optional CHECK_OPTIONAL_INPUT(top_p_min_opt, torch::kFloat); CHECK_OPTIONAL_INPUT(top_p_reset_ids_opt, torch::kInt32); + // TODO: add a parameter "return_normed_score" to return normed_cum_log_probs / cum_log_probs + dynamic_decode_->setup(static_cast(batch_size), static_cast(beam_width), runtime_top_k_opt, runtime_top_p_opt, temperature_opt, repetition_penalty_opt, presence_penalty_opt, frequency_penalty_opt, min_length_opt, length_penalty_opt, early_stopping_opt, beam_search_diversity_rate_opt, random_seed_opt, top_p_decay_opt, top_p_min_opt, top_p_reset_ids_opt); } -th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max_input_length, - int64_t max_attention_window, int64_t sink_token_length, int64_t ite, int64_t local_batch_size, th::Tensor end_id, - th::optional embedding_bias_opt, - th::optional input_lengths_opt, // length of input contexts. - th::optional sequence_limit_length_opt, th::optional stop_words_list_ptrs_opt, - th::optional stop_words_lens_opt, int64_t max_stop_words_len, - th::optional bad_words_list_ptrs_opt, th::optional bad_words_lens_opt, - int64_t max_bad_words_len, th::optional no_repeat_ngram_size_opt, - th::optional src_cache_indirection_opt, - // output buffers. - th::Tensor output_token_ids, th::Tensor newTokens, th::optional finished_input, - th::optional finished_output, - th::optional seuqence_lengths_opt, // length of the current sequences. 
- th::optional cum_log_probs_opt, th::optional output_log_probs_opt, - th::optional output_log_probs_tiled_opt, th::optional parent_ids_opt, - th::optional tgt_cache_indirection_opt, th::optional beam_hyps_output_ids_tgt_opt, - th::optional beam_hyps_sequence_lengths_tgt_opt, th::optional beam_hyps_cum_log_probs_opt, - th::optional beam_hyps_normed_scores_opt, th::optional beam_hyps_log_probs_opt, - th::optional beam_hyps_min_normed_scores_opt, th::optional beam_hyps_num_beams_opt, - th::optional beam_hyps_is_done_opt, bool use_beam_hyps) +th::Tensor DynamicDecodeOp::forward( // BS: batch_size, BM: beam_width, mSL: max_seq_length + th::Tensor logits, // [BS, BM, vocab_size_padded], T + int64_t step, // + int64_t max_input_length, // + int64_t max_attention_window, // + int64_t sink_token_length, // + int64_t ite, // + int64_t local_batch_size, // + th::Tensor end_id, // [BS*BM], int + th::optional embedding_bias_opt, // [vocab_size_padded], T + th::optional input_lengths_opt, // [BS*BM], int, length of input contexts + th::optional sequence_limit_length_opt, // [BS, 1], int + th::optional stop_words_list_ptrs_opt, // [BS][2, stop_words_length], int64 + th::optional stop_words_lens_opt, // [BS], int + int64_t max_stop_words_len, // + th::optional bad_words_list_ptrs_opt, // [BS][2, bad_words_length], int64 + th::optional bad_words_lens_opt, // [BS], int + int64_t max_bad_words_len, // + th::optional no_repeat_ngram_size_opt, // [BS], int + th::optional src_cache_indirection_opt, // [local_BS, BM, mSL], int + th::Tensor output_token_ids, // [BS, BM, mSL], int ? [mSL, BS, BM] + th::Tensor newTokens, // [BS, BM, 1], int + th::optional finished_input, // [BS, BM], uint8 + th::optional finished_output, // [BS, BM], uint8 + th::optional seuqence_lengths_opt, // [BS*BM], int, length of the current sequences + th::optional cum_log_probs_opt, // [BS, BM], float + th::optional output_log_probs_opt, // [BS, BM, mSL], float ? [mSL, BS, BM] + th::optional output_log_probs_tiled_opt, // [mSL, BS, BM], float ? [BS, BM, mSL] + th::optional parent_ids_opt, // [BS, BM, mSL], int ? 
[mSL, BS, BM] + th::optional tgt_cache_indirection_opt, // [local_BS, BM, memory_length], int + th::optional beam_hyps_output_ids_tgt_opt, // [BS, BM*2, mSL], int + th::optional beam_hyps_sequence_lengths_tgt_opt, // [BS, BM*2], int + th::optional beam_hyps_cum_log_probs_opt, // [BS, BM*2], float + th::optional beam_hyps_normed_scores_opt, // [BS, BM*2], float + th::optional beam_hyps_log_probs_opt, // [BS, BM*2, mSL], float + th::optional beam_hyps_min_normed_scores_opt, // [BS], float + th::optional beam_hyps_num_beams_opt, // [BS], int + th::optional beam_hyps_is_done_opt, // [BS], bool + bool use_beam_hyps // +) { - // Input Arguments: - // logits: [batch_size, beam_width, vocab_size_padded], T - // end_id: [batch_size], int, optional - // embedding_bias: [vocab_size_padded], T, optional - // input_lengths: [batch_size * beam_width], int, optional - // sequence_limit_length: [batch_size], int, optional - // stop_words_list_ptrs: [batch_size][2, stop_words_length], int, optional - // stop_words_lens_ptrs: [batch_size], int, optional - // bad_words_list_ptrs: [batch_size][2, bad_words_length], int, optional - // bad_words_lens: [batch_size], int, optional - // src_cache_indirection: [local_batch_size, beam_width, memory_length], - // int, optional output_token_ids: [max_seq_length, batch_size, - // beam_width], int finished: [batch_size * beam_width], bool, optional - // sequence_lengths: [batch_size * beam_width], int, optional - // cum_log_probs: [batch_size * beam_width], float, optional - // output_log_probs: [gen_length, batch_size, beam_width], float, optional - // parent_ids: [gen_length, batch_size, beam_width], float, optional - // tgt_cache_indirection: [local_batch_size, beam_width, memory_length], - // float, optional - CHECK_INPUT(logits, scalar_type_); TLLM_CHECK_WITH_INFO(logits.dim() == 3, "logits is of shape (batch_size, beam_width, vocab_size_padded), but got dim=%d shape=%s", (int) logits.dim(), @@ -339,7 +334,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max static_cast(logits.size(2))); CHECK_INPUT(end_id, torch::kInt32); - + CHECK_OPTIONAL_INPUT(embedding_bias_opt, scalar_type_); CHECK_OPTIONAL_INPUT(input_lengths_opt, torch::kInt32); CHECK_OPTIONAL_INPUT(sequence_limit_length_opt, torch::kInt32); CHECK_OPTIONAL_INPUT(stop_words_list_ptrs_opt, torch::kInt64); @@ -350,6 +345,7 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32); CHECK_INPUT(output_token_ids, torch::kInt32); + CHECK_INPUT(newTokens, torch::kInt32); CHECK_OPTIONAL_INPUT(finished_input, torch::kUInt8); CHECK_OPTIONAL_INPUT(finished_output, torch::kUInt8); CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32); diff --git a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h index be1e42264..38c0c3d94 100644 --- a/cpp/tensorrt_llm/thop/dynamicDecodeOp.h +++ b/cpp/tensorrt_llm/thop/dynamicDecodeOp.h @@ -67,7 +67,7 @@ class FtDynamicDecode : public IFtDynamicDecode using SetupParams = typename tensorrt_llm::layers::DynamicDecodeLayer::SetupParams; FtDynamicDecode(const size_t max_batch_size, const size_t max_beam_width, const size_t vocab_size, - const size_t vocab_size_padded, const int tensor_para_size, const int pipeline_para_size); + const size_t vocab_size_padded, int const tensor_para_size, int const pipeline_para_size); void setup(size_t batch_size, size_t beam_width, th::optional runtime_top_k_opt, th::optional runtime_top_p_opt, 
th::optional temperature_opt, diff --git a/cpp/tensorrt_llm/thop/fp8Op.cpp b/cpp/tensorrt_llm/thop/fp8Op.cpp index 231046a4d..e13bb78c8 100644 --- a/cpp/tensorrt_llm/thop/fp8Op.cpp +++ b/cpp/tensorrt_llm/thop/fp8Op.cpp @@ -61,7 +61,7 @@ std::vector e4m3_quantize_helper(Tensor input, QuantizeMode quantize_mod scale_shape.assign(input.dim(), 1); } - const auto is_cuda = input.is_cuda(); + auto const is_cuda = input.is_cuda(); input = input.cuda(); Tensor quantized_input @@ -75,19 +75,19 @@ std::vector e4m3_quantize_helper(Tensor input, QuantizeMode quantize_mod if (input.scalar_type() == at::ScalarType::Float) { - invokeComputeScalesAndQuantizeMatrix(quantized_input_ptr, get_ptr(scales), get_ptr(input), + invokeComputeScalesAndQuantizeMatrix(quantized_input_ptr, get_ptr(scales), get_ptr(input), input.numel(), input.size(-1), quantize_mode, stream); } else if (input.scalar_type() == at::ScalarType::Half) { - invokeComputeScalesAndQuantizeMatrix(quantized_input_ptr, get_ptr(scales), get_ptr(input), + invokeComputeScalesAndQuantizeMatrix(quantized_input_ptr, get_ptr(scales), get_ptr(input), input.numel(), input.size(-1), quantize_mode, stream); } #ifdef ENABLE_BF16 else if (input.scalar_type() == at::ScalarType::BFloat16) { invokeComputeScalesAndQuantizeMatrix(quantized_input_ptr, get_ptr<__nv_bfloat16>(scales), - get_ptr(input), input.numel(), input.size(-1), quantize_mode, stream); + get_ptr<__nv_bfloat16 const>(input), input.numel(), input.size(-1), quantize_mode, stream); } #endif else @@ -136,7 +136,7 @@ Tensor e4m3_dequantize_helper(Tensor input, Tensor scales, QuantizeMode quantize TORCH_CHECK(scales.size(i) == 1); } - const auto w_is_cuda = input.is_cuda(); + auto const w_is_cuda = input.is_cuda(); input = input.cuda(); scales = scales.cuda(); diff --git a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp index cc523411e..8b7c8ad01 100644 --- a/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp +++ b/cpp/tensorrt_llm/thop/weightOnlyQuantOp.cpp @@ -197,20 +197,20 @@ std::vector symmetric_quantize_helper( if (weight.scalar_type() == at::ScalarType::Float) { symmetric_quantize(processed_quantized_weight_ptr, unprocessed_quantized_weight_ptr, - get_ptr(scales), get_ptr(weight), {num_experts, num_rows, num_cols}, ft_quant_type, + get_ptr(scales), get_ptr(weight), {num_experts, num_rows, num_cols}, ft_quant_type, force_interleave); } else if (weight.scalar_type() == at::ScalarType::Half) { symmetric_quantize(processed_quantized_weight_ptr, unprocessed_quantized_weight_ptr, - get_ptr(scales), get_ptr(weight), {num_experts, num_rows, num_cols}, ft_quant_type, + get_ptr(scales), get_ptr(weight), {num_experts, num_rows, num_cols}, ft_quant_type, force_interleave); } #ifdef ENABLE_BF16 else if (weight.scalar_type() == at::ScalarType::BFloat16) { symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(processed_quantized_weight_ptr, - unprocessed_quantized_weight_ptr, get_ptr<__nv_bfloat16>(scales), get_ptr(weight), + unprocessed_quantized_weight_ptr, get_ptr<__nv_bfloat16>(scales), get_ptr<__nv_bfloat16 const>(weight), {num_experts, num_rows, num_cols}, ft_quant_type, force_interleave); } #endif diff --git a/cpp/tests/kernels/banRepeatNGramsKernelsTest.cpp b/cpp/tests/kernels/banRepeatNGramsKernelsTest.cpp index 5dd058063..2740936fb 100644 --- a/cpp/tests/kernels/banRepeatNGramsKernelsTest.cpp +++ b/cpp/tests/kernels/banRepeatNGramsKernelsTest.cpp @@ -49,30 +49,30 @@ class BanRepeatNgramKernelsTest : public testing::Test void TearDown() override {} - void initData(const 
std::vector>& outputIds, const std::vector& nGramSizes) + void initData(std::vector> const& outputIds, std::vector const& nGramSizes) { auto const ptrType = TRTDataType::value; SizeType const batchSize = outputIds.size(); auto const maxBatchSize = 2 * batchSize; - mLogits = mBufferManager->pinned(ITensor::makeShape({batchSize, mVocabSizePadded}), nvinfer1::DataType::kFLOAT); + mLogits = BufferManager::pinned(ITensor::makeShape({batchSize, mVocabSizePadded}), nvinfer1::DataType::kFLOAT); mSequenceLengths - = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, mBeamWidth}), nvinfer1::DataType::kINT32); - mFinished = mBufferManager->pinned( + = BufferManager::pinned(ITensor::makeShape({maxBatchSize, mBeamWidth}), nvinfer1::DataType::kINT32); + mFinished = BufferManager::pinned( ITensor::makeShape({maxBatchSize, mBeamWidth}), TRTDataType::value); - mOutputIds = mBufferManager->pinned( + mOutputIds = BufferManager::pinned( ITensor::makeShape({maxBatchSize, mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); - mOutputIdsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, mBeamWidth}), ptrType); + mOutputIdsPtr = BufferManager::pinned(ITensor::makeShape({maxBatchSize, mBeamWidth}), ptrType); - mParentIds = mBufferManager->pinned( + mParentIds = BufferManager::pinned( ITensor::makeShape({maxBatchSize, mBeamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); - mParentIdsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, mBeamWidth}), ptrType); + mParentIdsPtr = BufferManager::pinned(ITensor::makeShape({maxBatchSize, mBeamWidth}), ptrType); - mNGramSizes = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mNGramSizes = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mBatchSlots = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); + mBatchSlots = BufferManager::pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < batchSize; ++bi) @@ -139,7 +139,7 @@ class BanRepeatNgramKernelsTest : public testing::Test } void verifyBanRepeatNGramResults( - const std::vector& nGramSizes, const std::vector& expectedLastId) + std::vector const& nGramSizes, std::vector const& expectedLastId) { auto const batchSize = expectedLastId.size(); @@ -161,8 +161,8 @@ class BanRepeatNgramKernelsTest : public testing::Test } } - void runBanRepeatNGramTest(const std::vector>& outputIds, - const std::vector& nGramSizes, const std::vector& expectedLastId) + void runBanRepeatNGramTest(std::vector> const& outputIds, + std::vector const& nGramSizes, std::vector const& expectedLastId) { auto const batchSize = expectedLastId.size(); int32_t maxStep = 0; diff --git a/cpp/tests/kernels/decodingKernelTest.cpp b/cpp/tests/kernels/decodingKernelTest.cpp index 4b46686f5..0f36d14c1 100644 --- a/cpp/tests/kernels/decodingKernelTest.cpp +++ b/cpp/tests/kernels/decodingKernelTest.cpp @@ -19,13 +19,18 @@ #include +#include "tensorrt_llm/common/memoryUtils.h" #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/kernels/decodingKernels.h" #include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/runtimeKernels.h" #include #include +#include namespace tk = tensorrt_llm::kernels; +namespace tc = tensorrt_llm::common; +namespace trk = tensorrt_llm::runtime::kernels; using namespace tensorrt_llm::runtime; @@ -71,7 +76,7 @@ std::vector calculateGaussianKernel(float sigma, int size) } 
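The decodingKernelTest.cpp diff that begins here reworks a test fixture whose helpers appear in the hunks below: probsToLogits applies the elementwise transform log(p / (1 - p)) with an epsilon clamp, and the test then calls softmax on the result to emulate the kernels' numerical round trip (see the later comment "Do softmax conversion back to emulate kernels accuracy"). A standalone restatement of that math on a hand-written three-token distribution follows; the max-subtracted softmax here is a standard formulation, since the test's own softmax body is not shown in these hunks:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Elementwise logit transform with an epsilon clamp, as in the test helper.
static std::vector<float> probsToLogits(std::vector<float> const& probs)
{
    constexpr float eps = 1e-6f;
    std::vector<float> logits(probs.size());
    for (std::size_t i = 0; i < probs.size(); ++i)
    {
        auto const p = std::max(eps, probs[i]);
        logits[i] = std::log(p / (1.f - p));
    }
    return logits;
}

// Standard max-subtracted softmax mapping logits back to a normalized distribution.
static std::vector<float> softmax(std::vector<float> const& logits)
{
    auto const maxLogit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.f;
    for (std::size_t i = 0; i < logits.size(); ++i)
    {
        probs[i] = std::exp(logits[i] - maxLogit);
        sum += probs[i];
    }
    for (auto& p : probs)
    {
        p /= sum;
    }
    return probs;
}

int main()
{
    std::vector<float> const peak{0.05f, 0.85f, 0.10f};
    auto const roundTrip = softmax(probsToLogits(peak));
    for (std::size_t i = 0; i < peak.size(); ++i)
    {
        std::printf("token %zu: in=%.3f out=%.3f\n", i, peak[i], roundTrip[i]);
    }
    return 0;
}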
template -void applyGaussianFilter(T* result, const float* input, int n, float sigma) +void applyGaussianFilter(T* result, float const* input, int n, float sigma) { int size = static_cast(std::ceil(6.f * sigma)); size = (size % 2 == 0) ? size + 1 : size; @@ -98,22 +103,22 @@ void applyGaussianFilter(T* result, const float* input, int n, float sigma) } } -template void applyGaussianFilter(float* result, const float* input, int n, float sigma); -template void applyGaussianFilter(__half* result, const float* input, int n, float sigma); +template void applyGaussianFilter(float* result, float const* input, int n, float sigma); +template void applyGaussianFilter(__half* result, float const* input, int n, float sigma); template -void probsToLogits(const T* probs, T* logits, SizeType n) +void probsToLogits(T const* probs, T* logits, SizeType n) { constexpr float eps = 1e-6f; for (SizeType ni = 0; ni < n; ++ni) { - const auto prob = std::max(eps, static_cast(probs[ni])); + auto const prob = std::max(eps, static_cast(probs[ni])); logits[ni] = std::log(prob / (1.f - prob)); } } template -void softmax(const T* logits, T* probs, int n) +void softmax(T const* logits, T* probs, int n) { float epsilon = 1e-6f; @@ -139,8 +144,71 @@ void softmax(const T* logits, T* probs, int n) } } -template void probsToLogits(const float* probs, float* logits, SizeType n); -template void probsToLogits(const __half* probs, __half* logits, SizeType n); +template void probsToLogits(float const* probs, float* logits, SizeType n); +template void probsToLogits(__half const* probs, __half* logits, SizeType n); + +enum AcceptKernelMode +{ + BY_IDS, + BY_LOGITS, + BY_IDS_WITH_PATH +}; + +struct DecodingKernelTestParam +{ + SizeType mBatchSize{128}; + SizeType mMaxBatchSize{2 * mBatchSize}; + SizeType mBeamWidth{1}; + SizeType mMaxSeqLen{16}; + SizeType mVocabSize{32}; + SizeType mMaxDraftTokens{8}; + SizeType mMaxNumHeads{0}; + SizeType mMaxDraftSeqPerStep{1}; + AcceptKernelMode mAcceptMode{AcceptKernelMode::BY_IDS}; + + DecodingKernelTestParam& setBatchSize(SizeType bs) + { + mBatchSize = bs; + mMaxBatchSize = 2 * mBatchSize; + return *this; + } + + DecodingKernelTestParam& setVocabSize(SizeType vs) + { + mVocabSize = vs; + return *this; + } + + DecodingKernelTestParam& setMaxSeqLen(SizeType msl) + { + mMaxSeqLen = msl; + return *this; + } + + DecodingKernelTestParam& setMaxDraftTokens(SizeType dt) + { + mMaxDraftTokens = dt; + return *this; + } + + DecodingKernelTestParam& setMaxNumHeads(SizeType mnh) + { + mMaxNumHeads = mnh; + return *this; + } + + DecodingKernelTestParam& setMaxDraftSeqPerStep(SizeType tps) + { + mMaxDraftSeqPerStep = tps; + return *this; + } + + DecodingKernelTestParam& setAcceptMode(AcceptKernelMode const& mode) + { + mAcceptMode = mode; + return *this; + } +}; template class DecodingKernelsTest : public testing::Test @@ -152,210 +220,428 @@ class DecodingKernelsTest : public testing::Test { mStream = std::make_shared(); mBufferManager = std::make_shared(mStream); - - cudaMalloc(&mCurandStates, sizeof(curandState_t) * maxBatchSize); } - void TearDown() override - { - cudaFree(mCurandStates); - } + void TearDown() override {} - void initData(SizeType seed) + void createBuffers() { auto const dataType = TRTDataType::value; auto const ptrType = TRTDataType::value; - std::mt19937 generator(seed); - std::uniform_int_distribution contextLenDistr(0, maxSeqLen - maxDraftTokens); - std::uniform_int_distribution numDraftTokensDistr(1, maxDraftTokens); - std::uniform_int_distribution vocabDistr(1, vocabSize - 1); - 
std::uniform_real_distribution acceptTokenDistr(0.f, 1.f); - - mDraftTokens = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize, beamWidth, maxDraftTokens}), nvinfer1::DataType::kINT32); - mTargetTokens = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize, beamWidth, maxSeqLen}), nvinfer1::DataType::kINT32); - - mDraftLogits = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize * beamWidth, maxDraftTokens, vocabSize}), dataType); - mTargetLogits = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize * beamWidth, maxDraftTokens, vocabSize}), dataType); - mTargetLogitsPtrs = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), ptrType); - mRefTargetLogits = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize * beamWidth, maxDraftTokens, vocabSize}), dataType); - - mDraftProbs = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize * beamWidth, maxDraftTokens, vocabSize}), dataType); - mTargetProbs = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize * beamWidth, maxDraftTokens, vocabSize}), dataType); - - mNumsDraftTokens = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mSequenceLengths - = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32); - mContextLengths - = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32); - mFinishedSteps = mBufferManager->pinned(ITensor::makeShape({maxDraftTokens + 1, maxBatchSize, beamWidth}), + mDraftTokens + = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqlen}), nvinfer1::DataType::kINT32); + mTargetTokens + = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kINT32); + mNumsDraftTokens = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep}), nvinfer1::DataType::kINT32); + mSequenceLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mContextLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mDraftContextLengths = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mFinishedSteps = BufferManager::pinned(ITensor::makeShape({mMaxDraftTokens + 1, mMaxBatchSize}), TRTDataType::value); - mFinishedFinal = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize, beamWidth}), TRTDataType::value); - mFinishedSum = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mFinishedFinal = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize}), TRTDataType::value); + mFinishedSum = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mBatchSlots = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); + mPaths = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxDraftSeqPerStep, mMaxDraftTokens}), nvinfer1::DataType::kINT32); + mEndIds = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + + mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + + mCurandStates = mBufferManager->gpu( + ITensor::makeShape({mMaxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8); + + mAcceptedLen.resize(mMaxBatchSize); + mOutputLen.resize(mMaxBatchSize); + mAcceptedFinished.resize(mMaxBatchSize, tk::FinishedState::empty()); + + // Buffers only for Logits comparison + if 
(mAcceptMode == AcceptKernelMode::BY_LOGITS) + { + mDraftLogits = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType); + mTargetLogits = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType); + mTargetLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), ptrType); + mRefTargetLogits = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType); + + mDraftProbs = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType); + mTargetProbs = BufferManager::pinned( + ITensor::makeShape({mMaxBatchSize, mMaxTotalDraftTokens, mVocabSize}), dataType); + } - mAcceptedLen.resize(maxBatchSize * beamWidth); - mOutputLen.resize(maxBatchSize * beamWidth); - for (SizeType bi = 0; bi < maxBatchSize * beamWidth; ++bi) + if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH) { - mAcceptedFinished.emplace_back(tk::FinishedState::empty()); + mMedusaLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mMaxNumHeads}), ptrType); } + } + + void initData(SizeType seed) + { + std::mt19937 generator(seed); + std::uniform_int_distribution contextLenDistr(0, std::max(mMaxSeqLen - mMaxTotalDraftTokens, 0)); + std::uniform_int_distribution draftContextLenDistr( + 0, std::max(mMaxDraftSeqlen - mMaxTotalDraftTokens, 0)); + std::uniform_int_distribution numTotalDraftTokensDistr(1, mMaxTotalDraftTokens); + std::uniform_int_distribution numDraftTokensDistr(0, mMaxDraftTokens); + std::uniform_int_distribution vocabDistr(1, mVocabSize - 1); + std::uniform_real_distribution acceptTokenDistr(0.f, 1.f); + + trk::invokeFill(*mPaths, int32_t{-1}, *mStream); + trk::invokeFill(*mFinishedFinal, tk::FinishedState::UnderlyingType{0}, *mStream); - auto sequenceLengthsPtr = bufferCast(*mSequenceLengths); - auto contextLengthsPtr = bufferCast(*mContextLengths); - auto numsDraftTokensPtr = bufferCast(*mNumsDraftTokens); - auto draftTokensPtr = bufferCast(*mDraftTokens); - auto targetTokensPtr = bufferCast(*mTargetTokens); + auto sequenceLengthsPtr = BufferRange(*mSequenceLengths); + auto contextLengthsPtr = BufferRange(*mContextLengths); + auto draftContextLengthsPtr = BufferRange(*mDraftContextLengths); + auto numsDraftTokensPtr = BufferRange(*mNumsDraftTokens); + auto draftTokensPtr = BufferRange(*mDraftTokens); + auto targetTokensPtr = BufferRange(*mTargetTokens); auto finishedStepsPtr = reinterpret_cast(bufferCast(*mFinishedSteps)); - auto finishedFinalPtr - = reinterpret_cast(bufferCast(*mFinishedFinal)); - auto finishedSumPtr = bufferCast(*mFinishedSum); - - auto draftProbsPtr = bufferCast(*mDraftProbs); - auto targetProbsPtr = bufferCast(*mTargetProbs); + auto pathsPtr = BufferRange(*mPaths); + auto endIdsPtr = BufferRange(*mEndIds); - auto draftLogitsPtr = bufferCast(*mDraftLogits); - auto targetLogitsPtr = bufferCast(*mTargetLogits); - auto targetLogitsPtrsPtr = BufferRange(*mTargetLogitsPtrs); - auto refTargetLogitsPtr = bufferCast(*mRefTargetLogits); - auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto batchSlotsPtr = BufferRange(*mBatchSlots); - tk::invokeCurandInitialize(mCurandStates, nullptr, maxBatchSize, seed, this->mStream->get()); + tk::invokeCurandInitialize(reinterpret_cast(bufferCast(*mCurandStates)), nullptr, + mMaxBatchSize, seed, this->mStream->get()); - // Init number of draft tokens - for (SizeType bi = 0; bi < maxBatchSize; ++bi) + auto generateAvoidingValues + = [&vocabDistr, 
&generator](std::uniform_int_distribution& distr, + std::unordered_set const& tokensToAvoid, SizeType maxTries = -1, SizeType defaultValue = -1) { - numsDraftTokensPtr[bi] = numDraftTokensDistr(generator); - } + // Avoid generating endId. + auto token = distr(generator); + SizeType tries = 0; + while (tokensToAvoid.count(token) != 0 && ((maxTries >= 0 && tries < maxTries) || maxTries < 0)) + { + token = distr(generator); + tries++; + } + if (tries == maxTries) + { + token = defaultValue; + } + return token; + }; - for (SizeType bi = 0; bi < batchSize; ++bi) + // Init batch slots + for (SizeType bi = 0; bi < mBatchSize; ++bi) { batchSlotsPtr[bi] = 2 * bi; } - for (SizeType bi = 0; bi < maxBatchSize * beamWidth; ++bi) + // Init end ids + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) { - const SizeType batchIdx = bi / beamWidth; - // Randomly init context len + endIdsPtr[bi] = generateAvoidingValues(vocabDistr, {mPadId}); + TLLM_LOG_DEBUG("bi %d endIdsPtr[bi] %d", bi, endIdsPtr[bi]); + + // Randomly init context len for target and draft contextLengthsPtr[bi] = contextLenDistr(generator); + draftContextLengthsPtr[bi] = draftContextLenDistr(generator); + } - // Sequence len is at most numsDraftTokensPtr[bi] away from context len (it can be closer if e.g. endId is - // generated) - std::uniform_int_distribution realDraftTokensDistr(0, numsDraftTokensPtr[batchIdx]); - const auto realLen = realDraftTokensDistr(generator); - sequenceLengthsPtr[bi] = contextLengthsPtr[bi] + realLen; + std::fill(draftTokensPtr.begin(), draftTokensPtr.begin() + mMaxBatchSize * mMaxDraftSeqlen, mPadId); + std::fill(targetTokensPtr.begin(), targetTokensPtr.begin() + mMaxBatchSize * mMaxSeqLen, mPadId); + std::fill(pathsPtr.begin(), pathsPtr.begin() + mMaxBatchSize * mMaxDraftSeqPerStep * mMaxDraftTokens, -1); - // Initialize finished states - for (int i = 0; i < realLen; ++i) + // Generate paths + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) + { + auto const numTotalDraftTokens = std::min(mMaxDraftTokens, numTotalDraftTokensDistr(generator)); + std::uniform_int_distribution pathIdDistr(0, numTotalDraftTokens); + for (SizeType pi = 0; pi < mMaxDraftSeqPerStep; ++pi) { - finishedStepsPtr[i * maxBatchSize * beamWidth + bi] = tk::FinishedState::empty(); + std::unordered_set pathIds; + auto const numDraftTokensAtStep = numDraftTokensDistr(generator); + numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + pi] = numDraftTokensAtStep; + + for (SizeType ti = 0; ti < numDraftTokensAtStep; ++ti) + { + auto const pathIdx = tc::flat_index3(bi, pi, ti, mMaxDraftSeqPerStep, mMaxDraftTokens); + // Single linear path for BY_IDS and BY_LOGITS modes + auto const pathId = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH + ? 
generateAvoidingValues(pathIdDistr, pathIds, mMaxDraftTokens * 5, -1) + : ti; + pathsPtr[pathIdx] = pathId; + pathIds.insert(pathId); + } + TLLM_LOG_DEBUG("bi %d pi %d numsDraftTokensPtr[bi] %d", bi, pi, numDraftTokensAtStep); } - for (int i = realLen; i <= numsDraftTokensPtr[batchIdx]; ++i) + } + + for (SizeType ti = 0; ti < mMaxDraftSeqPerStep; ++ti) + { + std::vector targetPredictedLen(mMaxBatchSize); + std::vector targetAcceptedLen(mMaxBatchSize); + + // Init number of draft tokens + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) { - finishedStepsPtr[i * maxBatchSize * beamWidth + bi] = tk::FinishedState::finished(); + // It can be shorter than num of draft tokens due to the EOS generation + std::uniform_int_distribution realDraftTokensDistr( + 0, numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]); + targetPredictedLen[bi] = realDraftTokensDistr(generator); + // Accept ~ half of the tokens on avergae + std::poisson_distribution targetAcceptedDistr(targetPredictedLen[bi] / 2); + targetAcceptedLen[bi] = std::min(targetAcceptedDistr(generator), targetPredictedLen[bi]); + + TLLM_LOG_DEBUG( + "bi %d ti %d targetPredictedLen[bi] %d targetAcceptedLen[bi] %d draftContextLengthsPtr[bi] %d", bi, + ti, targetPredictedLen[bi], targetAcceptedLen[bi], draftContextLengthsPtr[bi]); } - // Init helper vector with max value - mAcceptedLen[bi] = sequenceLengthsPtr[bi]; - mOutputLen[bi] = sequenceLengthsPtr[bi]; - mAcceptedFinished[bi] = finishedStepsPtr[realLen * maxBatchSize * beamWidth + bi]; - } - // Fill token arrays - for (SizeType bi = 0; bi < maxBatchSize * beamWidth; ++bi) - { - // Draft: [d0, d1, d2, ... for numsDraftTokensPtr[bi] ... , dN] - // Target: [vocabSize - 1, vocabSize - 1, ... for contextLengthsPtr[bi] ... vocabSize - 1, - // t0, t1, t2, ... for numsDraftTokensPtr[bi] ... , tN, - // vocabSize - 1, vocabSize - 1, .. to maxSeqLen] - for (SizeType si = 0; si < contextLengthsPtr[bi]; ++si) + // Fill draft tokens + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) { - targetTokensPtr[bi * maxSeqLen + si] = vocabSize - 1; + for (SizeType si = 0; si < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++si) + { + auto const pathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens); + auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + pathsPtr[pathIdx]; + // Avoid generating endId. We'll insert in manually later if needed. + draftTokensPtr[draftTokenIdx] = generateAvoidingValues(vocabDistr, {mPadId, endIdsPtr[bi]}); + TLLM_LOG_DEBUG("bi %d ti %d si %d pathId %d draftToken %d", bi, ti, si, pathsPtr[pathIdx], + draftTokensPtr[draftTokenIdx]); + } } - for (SizeType si = contextLengthsPtr[bi]; si < sequenceLengthsPtr[bi]; ++si) + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) { - const auto draftToken = vocabDistr(generator); - const auto draftTokenIdx = si - contextLengthsPtr[bi]; - const auto targetToken - = acceptTokenDistr(generator) < 1.f / (draftTokenIdx + 1e-6) ? 
draftToken : vocabDistr(generator); - draftTokensPtr[bi * maxDraftTokens + draftTokenIdx] = draftToken; - targetTokensPtr[bi * maxSeqLen + si] = targetToken; - if (draftToken != targetToken) + sequenceLengthsPtr[bi] = contextLengthsPtr[bi] + targetPredictedLen[bi]; + + // Initialize finished states + for (int di = 0; di < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++di) { - mAcceptedLen[bi] = std::min(mAcceptedLen[bi], std::min(si, maxSeqLen)); - mOutputLen[bi] = std::min(mOutputLen[bi], std::min(si + 1, maxSeqLen)); - mAcceptedFinished[bi] = finishedStepsPtr[draftTokenIdx * maxBatchSize * beamWidth + bi]; + finishedStepsPtr[di * mMaxBatchSize + bi] + = (di < targetPredictedLen[bi]) ? tk::FinishedState::empty() : tk::FinishedState::finished(); } + + // Init helper vectors + mAcceptedLen[bi] = contextLengthsPtr[bi] + std::max(targetAcceptedLen[bi], 0); + mOutputLen[bi] = std::min(sequenceLengthsPtr[bi], std::min(mAcceptedLen[bi] + 1, mMaxSeqLen)); + mAcceptedFinished[bi] = finishedStepsPtr[std::max(targetAcceptedLen[bi], 0) * mMaxBatchSize + bi]; + + TLLM_LOG_DEBUG( + "bi %d ti %d contextLengthsPtr[bi] %d sequenceLengthsPtr[bi] %d mAcceptedLen[bi] %d mOutputLen[bi] " + "%d", + bi, ti, contextLengthsPtr[bi], sequenceLengthsPtr[bi], mAcceptedLen[bi], mOutputLen[bi]); } - for (SizeType si = sequenceLengthsPtr[bi]; si < maxSeqLen; ++si) + // Fill token arrays + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) { - targetTokensPtr[bi * maxSeqLen + si] = vocabSize - 1; + // Draft: [padId, padId, for draftContextLengthsPtr[bi] ... padId, + // d0, d1, d2, ... for numsDraftTokensPtr[bi] ... , dK, + // padId, padId, .. to mMaxDraftSeqlen] + // Target: [padId, padId, ... for contextLengthsPtr[bi] ... padId, + // d0, d1, d2, ... for targetAcceptedLen[bi], + // ti (!= di), ti+1 (!= di+1), ... for (targetPredictedLen[bi] - targetAcceptedLen[bi]), + // EOS, EOS, EOS, ... for (numsDraftTokensPtr[bi] - targetPredictedLen[bi]) + // padId, padId, .. 
to mMaxSeqLen] + for (SizeType si = 0; si < numsDraftTokensPtr[bi * mMaxDraftSeqPerStep + ti]; ++si) + { + auto const pathIdx = tc::flat_index3(bi, ti, si, mMaxDraftSeqPerStep, mMaxDraftTokens); + auto const pathId = pathsPtr[pathIdx]; + if (pathId == -1) + { + continue; + } + auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + pathId; + auto const targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + pathId; + auto targetToken = mPadId; + + if (0 <= si && si < targetAcceptedLen[bi]) + { + // Use draft token up to the accepted len + targetToken = draftTokensPtr[draftTokenIdx]; + } + else if (targetAcceptedLen[bi] <= si && si < targetPredictedLen[bi]) + { + // Do not use draft token token up to the generated len + targetToken = generateAvoidingValues( + vocabDistr, {mPadId, endIdsPtr[bi], draftTokensPtr[draftTokenIdx]}); + } + else if (targetPredictedLen[bi] <= si && si < numsDraftTokensPtr[bi]) + { + // Fill with EOS from generated len to the draft len + targetToken = endIdsPtr[bi]; + } + targetTokensPtr[targetTokenIdx] = targetToken; + TLLM_LOG_DEBUG("bi %d ti %d si %d pathId %d targetToken %d", bi, ti, si, pathId, targetToken); + } } + } + + if (mAcceptMode == AcceptKernelMode::BY_LOGITS) + { + initDataAndReferenceAcceptByLogits(); + } + + if (mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH) + { + initDataAndReferenceAcceptByIdsWithPaths(); + } + } - for (SizeType si = sequenceLengthsPtr[bi] - contextLengthsPtr[bi]; si < maxDraftTokens; ++si) + void initDataAndReferenceAcceptByIdsWithPaths() + { + auto const dataType = TRTDataType::value; + auto const ptrType = TRTDataType::value; + + auto pathsPtr = BufferRange(*mPaths); + auto endIdsPtr = BufferRange(*mEndIds); + auto contextLengthsPtr = BufferRange(*mContextLengths); + auto draftContextLengthsPtr = BufferRange(*mDraftContextLengths); + auto draftTokensPtr = BufferRange(*mDraftTokens); + auto targetTokensPtr = BufferRange(*mTargetTokens); + + trk::invokeFill(*mMedusaLogitsPtrs, int64_t{0}, *mStream); + + mAcceptedLen.resize(mMaxBatchSize); + mAcceptedPathIdx.resize(mMaxBatchSize); + mRefAcceptedTokens.resize(mMaxBatchSize); + mFinishedByIdsPaths.resize(mMaxBatchSize); + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) + { + SizeType maxAcceptedLen = -1; + SizeType maxAcceptedPath = -1; + bool maxFinished = false; + std::vector maxAcceptedTokens; + for (SizeType ti = 0; ti < mMaxDraftSeqPerStep; ++ti) { - draftTokensPtr[bi * maxDraftTokens + si] = 0; + std::vector acceptedTokens; + SizeType curAcceptedLen = mMaxDraftTokens; + SizeType curAcceptedPath = -1; + bool curFinished = false; + for (SizeType di = 0; di < mMaxDraftTokens; ++di) + { + auto const pathIdx = tc::flat_index3(bi, ti, di, mMaxDraftSeqPerStep, mMaxDraftTokens); + auto const pathId = pathsPtr[pathIdx]; + if (pathId == -1) + { + curAcceptedLen = di; + curAcceptedPath = ti; + curFinished = false; + break; + } + auto const draftTokenIdx = bi * mMaxDraftSeqlen + draftContextLengthsPtr[bi] + pathId; + auto const targetTokenIdx = bi * mMaxSeqLen + contextLengthsPtr[bi] + pathId; + auto const draftToken = draftTokensPtr[draftTokenIdx]; + auto const targetToken = targetTokensPtr[targetTokenIdx]; + bool const hasEnd = targetToken == endIdsPtr[bi]; + if (!hasEnd) + { + acceptedTokens.push_back(targetToken); + } + if (draftToken != targetToken || hasEnd) + { + auto const curLen = hasEnd ? 
di : di + 1; + curAcceptedLen = curLen; + curAcceptedPath = ti; + curFinished = hasEnd; + break; + } + } + if (curAcceptedLen > maxAcceptedLen) + { + maxAcceptedLen = curAcceptedLen; + maxAcceptedPath = curAcceptedPath; + maxAcceptedTokens = acceptedTokens; + maxFinished = curFinished; + } } + mAcceptedLen[bi] = maxAcceptedLen; + mAcceptedPathIdx[bi] = maxAcceptedPath; + mRefAcceptedTokens[bi] = maxAcceptedTokens; + mFinishedByIdsPaths[bi] = maxFinished; + TLLM_LOG_DEBUG("bi %d maxAcceptedLen %d maxAcceptedPath %d", bi, maxAcceptedLen, maxAcceptedPath); + std::ostringstream ss; + for (auto& tk : maxAcceptedTokens) + { + ss << tk << " "; + } + TLLM_LOG_DEBUG(ss.str().c_str()); + } + mDraftContextLengthsCopy = mBufferManager->copyFrom(*mDraftContextLengths, MemoryType::kCPU); + } + + void initDataAndReferenceAcceptByLogits() + { + auto contextLengthsPtr = BufferRange(*mContextLengths); + auto numsDraftTokensPtr = BufferRange(*mNumsDraftTokens); + auto draftTokensPtr = BufferRange(*mDraftTokens); + auto targetTokensPtr = BufferRange(*mTargetTokens); + + auto draftProbsPtr = BufferRange(*mDraftProbs); + auto targetProbsPtr = BufferRange(*mTargetProbs); + + auto draftLogitsPtr = BufferRange(*mDraftLogits); + auto targetLogitsPtr = BufferRange(*mTargetLogits); + auto targetLogitsPtrsPtr = BufferRange(*mTargetLogitsPtrs); + auto refTargetLogitsPtr = BufferRange(*mRefTargetLogits); + auto batchSlotsPtr = BufferRange(*mBatchSlots); + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) + { // Init draft and target logits and probabilities for (SizeType si = 0; si < numsDraftTokensPtr[bi]; ++si) { - std::vector peakDraftProb(vocabSize, 0.f); - std::vector peakTargetProb(vocabSize, 0.f); + std::vector peakDraftProb(mVocabSize, 0.f); + std::vector peakTargetProb(mVocabSize, 0.f); - auto const targetToken = targetTokensPtr[bi * maxSeqLen + contextLengthsPtr[bi] + si] % vocabSize; - auto const draftToken = draftTokensPtr[bi * maxDraftTokens + si] % vocabSize; + auto const targetToken = targetTokensPtr[bi * mMaxSeqLen + contextLengthsPtr[bi] + si] % mVocabSize; + auto const draftToken = draftTokensPtr[bi * mMaxDraftTokens + si] % mVocabSize; peakDraftProb[draftToken] = 1.f; peakTargetProb[targetToken] = 1.f; - auto const logitsOffset = bi * beamWidth * maxDraftTokens * vocabSize + si * beamWidth * vocabSize; + auto const logitsOffset = bi * mMaxDraftTokens * mVocabSize + si * mVocabSize; // Emulate some distribution around target token - applyGaussianFilter(draftProbsPtr + logitsOffset, peakDraftProb.data(), peakDraftProb.size(), 1.0f); - applyGaussianFilter(targetProbsPtr + logitsOffset, peakTargetProb.data(), peakTargetProb.size(), 1.0f); + applyGaussianFilter( + draftProbsPtr.begin() + logitsOffset, peakDraftProb.data(), peakDraftProb.size(), 1.0f); + applyGaussianFilter( + targetProbsPtr.begin() + logitsOffset, peakTargetProb.data(), peakTargetProb.size(), 1.0f); // Probabilities to logits - probsToLogits(draftProbsPtr + logitsOffset, draftLogitsPtr + logitsOffset, vocabSize); - probsToLogits(targetProbsPtr + logitsOffset, targetLogitsPtr + logitsOffset, vocabSize); + probsToLogits(draftProbsPtr.begin() + logitsOffset, draftLogitsPtr.begin() + logitsOffset, mVocabSize); + probsToLogits( + targetProbsPtr.begin() + logitsOffset, targetLogitsPtr.begin() + logitsOffset, mVocabSize); // Do softmax conversion back to emulate kernels accuracy - softmax(draftLogitsPtr + logitsOffset, draftProbsPtr + logitsOffset, vocabSize); - softmax(targetLogitsPtr + logitsOffset, targetProbsPtr + logitsOffset, vocabSize); + 
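// Illustrative host-side sketch (not part of the patch) of the acceptance rule the
// BY_IDS_WITH_PATH reference above encodes: walk every candidate path, accept draft
// tokens while they match the target tokens and no end id is seen, and keep the
// longest accepted prefix over all paths. All names below are hypothetical.
#include <cstdint>
#include <vector>

struct PathAcceptance
{
    int32_t acceptedLen;          // number of accepted tokens on the best path
    int32_t bestPath;             // index of the best path, -1 if none
    std::vector<int32_t> tokens;  // accepted target tokens
    bool finished;                // true if the best path stopped on the end id
};

// paths[p][d] is a position into the draft/target token buffers, or -1 for an unused slot.
inline PathAcceptance acceptLongestPath(std::vector<std::vector<int32_t>> const& paths,
    std::vector<int32_t> const& draftTokens, std::vector<int32_t> const& targetTokens, int32_t endId)
{
    PathAcceptance best{-1, -1, {}, false};
    for (int32_t p = 0; p < static_cast<int32_t>(paths.size()); ++p)
    {
        std::vector<int32_t> accepted;
        bool finished = false;
        for (int32_t pos : paths[p])
        {
            if (pos == -1)
            {
                break; // path ends here
            }
            auto const target = targetTokens[pos];
            bool const hasEnd = (target == endId);
            if (!hasEnd)
            {
                accepted.push_back(target);
            }
            if (draftTokens[pos] != target || hasEnd)
            {
                finished = hasEnd; // a mismatch ends acceptance; the end id also finishes the request
                break;
            }
        }
        if (static_cast<int32_t>(accepted.size()) > best.acceptedLen)
        {
            best = {static_cast<int32_t>(accepted.size()), p, accepted, finished};
        }
    }
    return best;
}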
softmax(draftLogitsPtr.begin() + logitsOffset, draftProbsPtr.begin() + logitsOffset, mVocabSize); + softmax(targetLogitsPtr.begin() + logitsOffset, targetProbsPtr.begin() + logitsOffset, mVocabSize); } + } - for (SizeType si = 0; si < maxDraftTokens; ++si) + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) + { + for (SizeType si = 0; si < mMaxDraftTokens; ++si) { - auto const logitsOffset = bi * beamWidth * maxDraftTokens * vocabSize + si * beamWidth * vocabSize; + auto const logitsOffset = bi * mMaxDraftTokens * mVocabSize + si * mVocabSize; auto const outputLen = mOutputLen[bi] - contextLengthsPtr[bi]; auto const acceptedLen = mAcceptedLen[bi] - contextLengthsPtr[bi]; if (si < acceptedLen) { - std::memcpy( - refTargetLogitsPtr + logitsOffset, targetLogitsPtr + logitsOffset, vocabSize * sizeof(T)); + auto logitsStart = targetLogitsPtr.begin() + logitsOffset; + std::copy(logitsStart, logitsStart + mVocabSize, refTargetLogitsPtr.begin() + logitsOffset); } else if (si == acceptedLen) { // When token is not accepted, correct probabilities and compute updated logits float sumProb = 1e-6f; - for (SizeType vi = 0; vi < vocabSize; ++vi) + for (SizeType vi = 0; vi < mVocabSize; ++vi) { - const auto correctedProb = std::max( + auto const correctedProb = std::max( static_cast(targetProbsPtr[logitsOffset + vi] - draftProbsPtr[logitsOffset + vi]), 0.f); sumProb += correctedProb; } - for (SizeType vi = 0; vi < vocabSize; ++vi) + for (SizeType vi = 0; vi < mVocabSize; ++vi) { auto prob = std::max(static_cast( targetProbsPtr[logitsOffset + vi] - draftProbsPtr[logitsOffset + vi]), @@ -370,23 +656,65 @@ class DecodingKernelsTest : public testing::Test } } } - for (SizeType bi = 0; bi < batchSize; ++bi) + for (SizeType bi = 0; bi < mBatchSize; ++bi) { - targetLogitsPtrsPtr[bi] = targetLogitsPtr + batchSlotsPtr[bi] * maxDraftTokens * beamWidth * vocabSize; + targetLogitsPtrsPtr[bi] = targetLogitsPtr.begin() + batchSlotsPtr[bi] * mMaxDraftTokens * mVocabSize; } } - void verifyAcceptByIdsResults(SizeType seed) + void callAcceptByIds() + { + tk::invokeAcceptDraftTokensByIds(bufferCast(*mDraftTokens), bufferCast(*mTargetTokens), + bufferCast(*mContextLengths), bufferCast(*mNumsDraftTokens), + bufferCast(*mSequenceLengths), + reinterpret_cast(bufferCast(*mFinishedSteps)), + reinterpret_cast(bufferCast(*mFinishedFinal)), + bufferCast(*mFinishedSum), bufferCast(*mBatchSlots), mBatchSize, mMaxBatchSize, + mBeamWidth, mMaxSeqLen, mMaxDraftTokens, mStream->get()); + } + + void callAcceptByLogits() + { + tk::acceptDraftTokensByLogits(bufferCast(*mDraftLogits), + reinterpret_cast(bufferCast(*mTargetLogitsPtrs)), bufferCast(*mDraftProbs), + bufferCast(*mTargetProbs), bufferCast(*mNumsDraftTokens), + reinterpret_cast(bufferCast(*mFinishedSteps)), + reinterpret_cast(bufferCast(*mCurandStates)), bufferCast(*mBatchSlots), + mBatchSize, mMaxBatchSize, mBeamWidth, mVocabSize, mVocabSize, mMaxDraftTokens, false, 0.9f, + mStream->get()); + } + + void callAcceptByIdsWithPaths() { - mStream->synchronize(); + tk::acceptDraftTokensByIdsWithPaths(bufferCast(*mDraftTokens), bufferCast(*mTargetTokens), + bufferCast(*mDraftContextLengths), + reinterpret_cast(bufferCast(*mFinishedFinal)), + bufferCast(*mBatchSlots), bufferCast(*mPaths), bufferCast(*mEndIds), + static_cast(nullptr), reinterpret_cast(bufferCast(*mMedusaLogitsPtrs)), + mBatchSize, mVocabSize, mMaxBatchSize, mMaxDraftSeqlen, mMaxTotalDraftTokens, mMaxNumHeads, + mMaxDraftSeqPerStep, mStream->get()); + } + + void callTestedKernel() + { + switch (mAcceptMode) + { + case 
AcceptKernelMode::BY_IDS: callAcceptByIds(); break; + case AcceptKernelMode::BY_LOGITS: callAcceptByLogits(); break; + case AcceptKernelMode::BY_IDS_WITH_PATH: callAcceptByIdsWithPaths(); break; + default: TLLM_CHECK(false); // Should never be here + } + } + void verifyAcceptByIdsResults(SizeType seed) + { auto finishedFinalPtr = reinterpret_cast(bufferCast(*mFinishedFinal)); - auto sequenceLengthsPtr = bufferCast(*mSequenceLengths); - auto finishedSumPtr = bufferCast(*mFinishedSum); - auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto sequenceLengthsPtr = BufferRange(*mSequenceLengths); + auto finishedSumPtr = BufferRange(*mFinishedSum); + auto batchSlotsPtr = BufferRange(*mBatchSlots); // Verify seqLen for accepted tokens - for (SizeType bi = 0; bi < batchSize; ++bi) + for (SizeType bi = 0; bi < mBatchSize; ++bi) { auto const batchSlot = batchSlotsPtr[bi]; EXPECT_EQ(mOutputLen[batchSlot], sequenceLengthsPtr[batchSlot]) << " bi " << bi << " seed " << seed; @@ -400,29 +728,26 @@ class DecodingKernelsTest : public testing::Test void verifyAcceptByLogitsResults(SizeType seed) { - mStream->synchronize(); - auto finishedStepsPtr = reinterpret_cast(bufferCast(*mFinishedSteps)); - auto contextLengthsPtr = bufferCast(*mContextLengths); - auto outLogitsPtr = bufferCast(*mTargetLogits); - auto refLogitsPtr = bufferCast(*mRefTargetLogits); - auto numsDraftTokensPtr = bufferCast(*mNumsDraftTokens); - auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto contextLengthsPtr = BufferRange(*mContextLengths); + auto outLogitsPtr = BufferRange(*mTargetLogits); + auto refLogitsPtr = BufferRange(*mRefTargetLogits); + auto numsDraftTokensPtr = BufferRange(*mNumsDraftTokens); + auto batchSlotsPtr = BufferRange(*mBatchSlots); - for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi) + for (SizeType bi = 0; bi < mBatchSize; ++bi) { auto const batchSlot = batchSlotsPtr[bi]; for (SizeType si = 0; si < numsDraftTokensPtr[batchSlot]; ++si) { - auto const outFinishedState = finishedStepsPtr[si * maxBatchSize * beamWidth + batchSlot]; - auto const logitsOffset - = batchSlot * beamWidth * maxDraftTokens * vocabSize + si * beamWidth * vocabSize; + auto const outFinishedState = finishedStepsPtr[si * mMaxBatchSize + batchSlot]; + auto const logitsOffset = batchSlot * mMaxDraftTokens * mVocabSize + si * mVocabSize; if (si <= mAcceptedLen[batchSlot] - contextLengthsPtr[batchSlot]) { EXPECT_FALSE(outFinishedState.isSkipDecoding()) << " bi: " << bi << " si: " << si << " seed: " << seed; - for (SizeType vi = 0; vi < vocabSize; ++vi) + for (SizeType vi = 0; vi < mVocabSize; ++vi) { auto const outLogit = static_cast(outLogitsPtr[logitsOffset + vi]); auto const refLogit = static_cast(refLogitsPtr[logitsOffset + vi]); @@ -448,29 +773,103 @@ class DecodingKernelsTest : public testing::Test } } - void runAcceptByIdsTest(SizeType seed) + void verifyAcceptByIdsWithPathsResults(SizeType seed) { - initData(seed); - tk::invokeAcceptDraftTokensByIds(bufferCast(*mDraftTokens), bufferCast(*mTargetTokens), - bufferCast(*mContextLengths), bufferCast(*mNumsDraftTokens), - bufferCast(*mSequenceLengths), - reinterpret_cast(bufferCast(*mFinishedSteps)), - reinterpret_cast(bufferCast(*mFinishedFinal)), - bufferCast(*mFinishedSum), bufferCast(*mBatchSlots), batchSize, maxBatchSize, beamWidth, - maxSeqLen, maxDraftTokens, mStream->get()); - verifyAcceptByIdsResults(seed); + auto medusaLogitsPtrsPtr = BufferRange(*mMedusaLogitsPtrs); + auto batchSlotsPtr = BufferRange(*mBatchSlots); + auto draftContextLengths = BufferRange(*mDraftContextLengths); + 
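// Illustrative host-side sketch (not part of the patch) of the corrected distribution the
// BY_LOGITS reference above builds for the first rejected draft position:
//   p'(v) = max(p_target(v) - p_draft(v), 0) / sum_w max(p_target(w) - p_draft(w), 0),
// which is then converted back to logits for comparison. Names below are hypothetical.
#include <algorithm>
#include <vector>

inline std::vector<float> correctedTargetProbs(
    std::vector<float> const& targetProbs, std::vector<float> const& draftProbs)
{
    float sum = 1e-6f; // small epsilon, mirroring the reference's guard against division by zero
    std::vector<float> corrected(targetProbs.size());
    for (size_t v = 0; v < targetProbs.size(); ++v)
    {
        corrected[v] = std::max(targetProbs[v] - draftProbs[v], 0.f);
        sum += corrected[v];
    }
    for (auto& p : corrected)
    {
        p /= sum;
    }
    return corrected;
}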
auto draftContextLengthsInit = BufferRange(*mDraftContextLengthsCopy); + auto draftTokensPtr = BufferRange(*mDraftTokens); + auto finishedFinalPtr + = reinterpret_cast(bufferCast(*mFinishedFinal)); + + for (SizeType bi = 0; bi < mBatchSize; ++bi) + { + auto const batchSlot = batchSlotsPtr[bi]; + auto const bestPathIdx = mAcceptedPathIdx[batchSlot]; + auto const acceptedLen = mAcceptedLen[batchSlot]; + auto acceptedTokens = mRefAcceptedTokens[batchSlot]; + + for (int32_t hi = 0; hi < mMaxNumHeads; ++hi) + { + auto refOffset + = tc::flat_index4(hi, bi, acceptedLen, 0, mMaxBatchSize, mMaxDraftSeqPerStep, mVocabSize); + auto outOffset + = static_cast(medusaLogitsPtrsPtr[bi * mMaxNumHeads + hi] - static_cast(nullptr)); + EXPECT_EQ(outOffset, refOffset) << " bi: " << bi << " hi: " << hi << " seed: " << seed; + } + EXPECT_EQ(draftContextLengths[batchSlot], draftContextLengthsInit[batchSlot] + acceptedLen) + << " bi: " << bi << " seed: " << seed << " out: " << draftContextLengths[batchSlot] + << " ref: " << draftContextLengthsInit[batchSlot] + acceptedLen; + + for (SizeType ti = 0; ti < acceptedLen; ++ti) + { + ASSERT_EQ(mRefAcceptedTokens[batchSlot].size(), acceptedLen) + << " bi: " << bi << " ti: " << ti << " seed: " << seed; + EXPECT_EQ(draftTokensPtr[batchSlot * mMaxDraftSeqlen + draftContextLengthsInit[batchSlot] + ti], + mRefAcceptedTokens[batchSlot][ti]) + << " bi: " << bi << " ti: " << ti << " seed: " << seed; + } + EXPECT_EQ(finishedFinalPtr[batchSlot].isFinished(), mFinishedByIdsPaths[batchSlot]) + << " bi: " << bi << " seed: " << seed; + } } - void runAcceptByLogitsTest(SizeType seed) + void verifyResult(SizeType seed) { - initData(seed); - tk::acceptDraftTokensByLogits(bufferCast(*mDraftLogits), - reinterpret_cast(bufferCast(*mTargetLogitsPtrs)), bufferCast(*mDraftProbs), - bufferCast(*mTargetProbs), bufferCast(*mNumsDraftTokens), - reinterpret_cast(bufferCast(*mFinishedSteps)), - mCurandStates, bufferCast(*mBatchSlots), batchSize, maxBatchSize, beamWidth, vocabSize, vocabSize, - maxDraftTokens, false, 0.9f, mStream->get()); - verifyAcceptByLogitsResults(seed); + switch (mAcceptMode) + { + case AcceptKernelMode::BY_IDS: verifyAcceptByIdsResults(seed); break; + case AcceptKernelMode::BY_LOGITS: verifyAcceptByLogitsResults(seed); break; + case AcceptKernelMode::BY_IDS_WITH_PATH: verifyAcceptByIdsWithPathsResults(seed); break; + default: TLLM_CHECK(false); // Should never be here + } + } + + void runTest(DecodingKernelTestParam const& params) + { + mAcceptMode = params.mAcceptMode; + + mBatchSize = params.mBatchSize; + mMaxBatchSize = params.mMaxBatchSize; + mBeamWidth = params.mBeamWidth; + mVocabSize = params.mVocabSize; + mMaxDraftTokens = params.mMaxDraftTokens; + + mMaxNumHeads = params.mMaxNumHeads; + if (mMaxNumHeads > 1 && mAcceptMode != AcceptKernelMode::BY_IDS_WITH_PATH) + { + GTEST_SKIP() << "MaxNumHeads > 1 is only supported for AcceptKernelMode::BY_IDS_WITH_PATH"; + } + + mMaxDraftSeqPerStep = params.mMaxDraftSeqPerStep; + if (mMaxDraftSeqPerStep > 1 && mAcceptMode != AcceptKernelMode::BY_IDS_WITH_PATH) + { + GTEST_SKIP() << "MaxDraftSeqPerStep > 1 is only supported for AcceptKernelMode::BY_IDS_WITH_PATH"; + } + + mMaxTotalDraftTokens = mMaxDraftSeqPerStep * mMaxDraftTokens; + mPadId = mVocabSize - 1; + + mMaxDraftSeqlen = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH ? params.mMaxSeqLen : mMaxDraftTokens; + mMaxSeqLen = mAcceptMode == AcceptKernelMode::BY_IDS_WITH_PATH ? 
mMaxTotalDraftTokens : params.mMaxSeqLen; + + createBuffers(); + + for (SizeType seed = 0; seed < mSeeds; ++seed) + { + TLLM_LOG_DEBUG("Seed %d", seed); + + initData(seed); + + mStream->synchronize(); + + callTestedKernel(); + + mStream->synchronize(); + + verifyResult(seed); + } } protected: @@ -491,23 +890,39 @@ class DecodingKernelsTest : public testing::Test TensorPtr mNumsDraftTokens; TensorPtr mSequenceLengths; TensorPtr mContextLengths; + TensorPtr mDraftContextLengthsCopy; + TensorPtr mDraftContextLengths; TensorPtr mFinishedSteps; TensorPtr mFinishedFinal; TensorPtr mFinishedSum; TensorPtr mBatchSlots; - std::vector mAcceptedLen; - std::vector mOutputLen; - std::vector mAcceptedFinished; + TensorPtr mPaths; + TensorPtr mEndIds; + TensorPtr mMedusaLogitsPtrs; - curandState_t* mCurandStates; + TensorPtr mCurandStates; - static constexpr SizeType batchSize{128}; - static constexpr SizeType maxBatchSize{2 * batchSize}; - static constexpr SizeType beamWidth{1}; - static constexpr SizeType maxSeqLen{16}; - static constexpr SizeType vocabSize{32}; - static constexpr SizeType maxDraftTokens{8}; + std::vector mAcceptedLen; + std::vector mOutputLen; + std::vector mAcceptedFinished; + std::vector mAcceptedPathIdx; + std::vector> mRefAcceptedTokens; + std::vector mFinishedByIdsPaths; + + SizeType mBatchSize; + SizeType mMaxBatchSize; + SizeType mBeamWidth; + SizeType mMaxSeqLen; + SizeType mVocabSize; + SizeType mMaxDraftTokens; + SizeType mMaxTotalDraftTokens; + SizeType mMaxDraftSeqlen; + SizeType mMaxNumHeads; + SizeType mMaxDraftSeqPerStep; + AcceptKernelMode mAcceptMode; + SizeType mPadId; + static constexpr SizeType mSeeds = 64; }; template class DecodingKernelsTest; @@ -517,22 +932,71 @@ typedef testing::Types FloatAndHalfTypes; TYPED_TEST_SUITE(DecodingKernelsTest, FloatAndHalfTypes); -TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsKernel) +TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsKernelSmall) { - constexpr SizeType seeds = 64; - for (SizeType seed = 0; seed < seeds; ++seed) - { - this->runAcceptByIdsTest(seed); - } + this->runTest(DecodingKernelTestParam() + .setBatchSize(1) + .setMaxSeqLen(16) + .setVocabSize(32) + .setMaxDraftTokens(8) + .setMaxDraftSeqPerStep(1) + .setAcceptMode(AcceptKernelMode::BY_IDS)); } -TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByLogitsKernel) +TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsKernelLarge) { - constexpr SizeType seeds = 64; - for (SizeType seed = 0; seed < seeds; ++seed) - { - this->runAcceptByLogitsTest(seed); - } + this->runTest(DecodingKernelTestParam() + .setBatchSize(128) + .setMaxSeqLen(128) + .setVocabSize(52000) + .setMaxDraftTokens(8) + .setMaxDraftSeqPerStep(1) + .setAcceptMode(AcceptKernelMode::BY_IDS)); +} + +TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByLogitsKernelSmall) +{ + this->runTest(DecodingKernelTestParam() + .setBatchSize(1) + .setMaxSeqLen(16) + .setVocabSize(32) + .setMaxDraftTokens(8) + .setMaxDraftSeqPerStep(1) + .setAcceptMode(AcceptKernelMode::BY_LOGITS)); +} + +TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByLogitsKernelLarge) +{ + this->runTest(DecodingKernelTestParam() + .setBatchSize(64) + .setMaxSeqLen(64) + .setVocabSize(4000) + .setMaxDraftTokens(8) + .setMaxDraftSeqPerStep(1) + .setAcceptMode(AcceptKernelMode::BY_LOGITS)); +} + +TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsWithPathsKernelSmall) +{ + this->runTest(DecodingKernelTestParam() + .setBatchSize(1) + .setMaxSeqLen(128) + .setVocabSize(32) + .setMaxDraftTokens(5) + .setMaxDraftSeqPerStep(4) + 
.setMaxNumHeads(4) + .setAcceptMode(AcceptKernelMode::BY_IDS_WITH_PATH)); } +TYPED_TEST(DecodingKernelsTest, acceptDraftTokensByIdsWithPathsKernelLarge) +{ + this->runTest(DecodingKernelTestParam() + .setBatchSize(128) + .setMaxSeqLen(1024) + .setVocabSize(4000) + .setMaxDraftTokens(8) + .setMaxDraftSeqPerStep(64) + .setMaxNumHeads(7) + .setAcceptMode(AcceptKernelMode::BY_IDS_WITH_PATH)); +} } // end of namespace diff --git a/cpp/tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/kernels/mixtureOfExpertsTest.cu index 254c0814a..612e99fe9 100644 --- a/cpp/tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/kernels/mixtureOfExpertsTest.cu @@ -164,7 +164,7 @@ protected: size_t workspace_size = mMoERunner.getWorkspaceSize( mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mActType, parallelism_config); - const auto stream = mStream->get(); + auto const stream = mStream->get(); mWorkspace = allocBuffer(workspace_size); check_cuda_error(cudaMemsetAsync(mWorkspace, 0xD5, workspace_size, stream)); @@ -273,8 +273,8 @@ protected: { if (parallelism_config.tp_size > 1) { - const int tp_size = parallelism_config.tp_size; - const int tp_rank = parallelism_config.tp_rank; + int const tp_size = parallelism_config.tp_size; + int const tp_rank = parallelism_config.tp_rank; const size_t matrix_size = mHiddenSize * mInterSize / tp_size; @@ -321,7 +321,7 @@ protected: // Clear the buffers to blank so we can assume zero if not written resetOutBuffers(); - const auto [weight1_ptr, weight2_ptr, bias1_ptr, bias2_ptr] = getWeights(parallelism_config); + auto const [weight1_ptr, weight2_ptr, bias1_ptr, bias2_ptr] = getWeights(parallelism_config); auto stream = mStream->get(); mMoERunner.setTactic(std::nullopt); @@ -333,18 +333,18 @@ protected: } template - std::vector getDataFromDevice(const T* in, size_t length) + std::vector getDataFromDevice(T const* in, size_t length) { std::vector data(length); - const auto stream = mStream->get(); + auto const stream = mStream->get(); check_cuda_error(cudaMemcpyAsync(data.data(), in, length * sizeof(T), cudaMemcpyDeviceToHost, stream)); check_cuda_error(cudaStreamSynchronize(mStream->get())); return data; } - auto maskSelectedExpertsForTP(const std::vector& vector, int tp_size, int tp_rank) + auto maskSelectedExpertsForTP(std::vector const& vector, int tp_size, int tp_rank) { std::vector result; int num_experts_per_node = mNumExperts / tp_size; @@ -415,8 +415,8 @@ protected: return calcMLPVal(input, expert_id, mUseBias); } - void comparePermuted(const std::vector& expected_experts, const std::vector& expected_permutation, - const std::vector& input_data) + void comparePermuted(std::vector const& expected_experts, std::vector const& expected_permutation, + std::vector const& input_data) { auto states = getDataFromDevice(mExpertOutput, mTotalTokens * mK * mHiddenSize); @@ -427,11 +427,11 @@ protected: { // Permutation has the position of the first copy of all token, // followed by the position of the second copy of all tokens etc. - const int permuted_position = expected_permutation[k_idx * mTotalTokens + token_id]; + int const permuted_position = expected_permutation[k_idx * mTotalTokens + token_id]; // Expected experts has all the selected experts for token one, // followed by all the selected experts for token two etc. 
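// Illustrative sketch (not part of the patch) of the two index layouts the MoE permutation
// check below relies on: the permutation map stores all first copies of every token, then
// all second copies, and so on, while the expected-expert list is stored per token.
// Hypothetical names.
#include <vector>

// Position of the k-th routed copy of `token` in the permuted buffer.
inline int permutedPosition(std::vector<int> const& permutation, int totalTokens, int token, int k)
{
    return permutation[k * totalTokens + token];
}

// Expert chosen for the k-th copy of `token`.
inline int selectedExpert(std::vector<int> const& expectedExperts, int topK, int token, int k)
{
    return expectedExperts[token * topK + k];
}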
- const int expert_id = expected_experts[token_id * mK + k_idx]; + int const expert_id = expected_experts[token_id * mK + k_idx]; // Compare the copied tokens with the projection applied for (int hidden_id = 0; hidden_id < mHiddenSize; hidden_id++) @@ -445,7 +445,7 @@ protected: } } - std::vector softmax(const std::vector& expected_probs) + std::vector softmax(std::vector const& expected_probs) { std::vector result; // All values we test are 0-1 so we can skip the normalization step @@ -467,7 +467,7 @@ protected: return result; } - void compareSoftmax(const std::vector& expected_experts, const std::vector& expected_probs, + void compareSoftmax(std::vector const& expected_experts, std::vector const& expected_probs, std::vector scale_probs = {}) { if (scale_probs.empty()) @@ -489,7 +489,7 @@ protected: } } - void renormScales(DataType* probs, const int* experts) + void renormScales(DataType* probs, int const* experts) { if (mNormMode == MOEExpertScaleNormalizationMode::NONE) return; @@ -505,8 +505,8 @@ protected: } } - void compareFinal(const std::vector& expected_experts, const std::vector& expected_probs, - const std::vector& input_data, std::vector final_results = {}) + void compareFinal(std::vector const& expected_experts, std::vector const& expected_probs, + std::vector const& input_data, std::vector final_results = {}) { if (final_results.empty()) final_results = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize); @@ -536,7 +536,7 @@ protected: void BasicPermuteTest(int k = 1); - std::vector calcPermuteMapExpertParallel(const std::vector& expected_experts); + std::vector calcPermuteMapExpertParallel(std::vector const& expected_experts); void ExpertParallelTest(int k = 1); void TensorParallelTest(int k = 1); @@ -546,7 +546,7 @@ BufferManager::CudaStreamPtr MixtureOfExpertsTest::mStream{}; std::unique_ptr MixtureOfExpertsTest::mBufferManager{}; int MixtureOfExpertsTest::mDeviceCount{}; -const int DEFAULT_HIDDEN_SIZE = 4; +int const DEFAULT_HIDDEN_SIZE = 4; void MixtureOfExpertsTest::BasicPermuteTest(int k) { @@ -658,7 +658,7 @@ TEST_F(MixtureOfExpertsTest, Finished) compareFinal(selected_expert, probs, hidden_states); } -std::vector MixtureOfExpertsTest::calcPermuteMapExpertParallel(const std::vector& expected_experts) +std::vector MixtureOfExpertsTest::calcPermuteMapExpertParallel(std::vector const& expected_experts) { std::vector map(expected_experts.size()); auto getInterleavedIndex = [this](int i) { return (i % mK) * mTotalTokens + i / mK; }; @@ -713,7 +713,7 @@ void MixtureOfExpertsTest::ExpertParallelTest(int k) auto selected_expert = getDataFromDevice(mSelectedExpert, num_tokens * k); // Experts should only be selected when we are on the right node // Note the index is [0,num_experts_per_node), so we offset the experts by the start for this node - const int start_expert = i * (mNumExperts / parallelism); + int const start_expert = i * (mNumExperts / parallelism); std::transform(selected_expert.begin(), selected_expert.end(), selected_expert.begin(), [&](int val) { return val == mNumExperts ? 
mNumExperts : val + start_expert; }); auto masked_expected_experts = maskSelectedExpertsForTP(expected_experts, parallelism, i); diff --git a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp index 39a8db935..f52c74fc6 100644 --- a/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp +++ b/cpp/tests/kernels/sampling/samplingAirTopPTest.cpp @@ -41,25 +41,13 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest using SamplingKernelTest::mBufferManager; private: - size_t getWorkspaceSize(const SamplingKernelTestParam& params) override + size_t getWorkspaceSize(SamplingKernelTestParam const& params) override { - auto const maxBatchSize = 2 * params.batchSize; - size_t sampling_workspace_size_; - tk::invokeAirTopPSampling(nullptr, sampling_workspace_size_, - nullptr, // output_ids - nullptr, // sequence_length - nullptr, // finished_input_buffer - nullptr, // finished_output_buffer - nullptr, // cum_log_probs - nullptr, // output_log_probs - nullptr, // log_probs) - this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, nullptr, this->mMaxTopP, - this->mStream->get(), 0, nullptr, nullptr); - return sampling_workspace_size_; + return tensorrt_llm::kernels::getAirTopPWorkspaceSize(params.batchSize, params.vocabSize); } - void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize, - tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override + void callTestedFunction( + SamplingKernelTestParam const& params, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { // Calculate the number of blocks based on the number of multiprocessors, batchSize and vocabSize. int dev; @@ -70,7 +58,7 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest int blockNum = tk::calcAirTopPBlockNum(params.batchSize, params.vocabSize, smCnt); // Perform batched TopP sampling - tk::invokeBatchAirTopPSampling(workspaceDevice->data(), workspaceSize, bufferCast(*this->mIdsPtrHost), + tk::invokeBatchAirTopPSampling(workspaceDevice->data(), bufferCast(*this->mIdsPtrHost), bufferCast(*this->mSeqLengthsDevice), reinterpret_cast( bufferCast(*this->mFinishedDevice)), @@ -81,9 +69,10 @@ class AirTopPSamplingKernelTest : public SamplingKernelTest // log-prob if cum_log_probs or output_log_probs are // provided. It's because the sampling layer already // preprocesses log_prob_buf when those are provided. - bufferCast(*this->mProbsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize, - params.vocabSize, bufferCast(*this->mEndIdsDevice), this->mMaxTopP, - hasDiffRuntimeArgs ? 
bufferCast(*this->mTopPsDevice) : nullptr, this->mStream->get(), blockNum, + bufferCast(*this->mProbsDevice), + reinterpret_cast(bufferCast(*this->mCurandStatesDevice)), params.batchSize, + maxBatchSize, params.vocabSize, bufferCast(*this->mEndIdsDevice), this->mMaxTopP, + bufferCast(*this->mTopPsDevice), this->mStream->get(), blockNum, bufferCast(*this->mSkipDecodeDevice), bufferCast(*this->mBatchSlots)); } }; @@ -92,29 +81,27 @@ TYPED_TEST_SUITE(AirTopPSamplingKernelTest, FloatAndHalfTypes); TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessSmallP) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f)); }; TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessLargeP) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f)); }; TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessAncestral) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f)); }; TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessLargeVocabSmallP) { - this->runTest( - SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f).setOutputLen(16)); + this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f)); }; TYPED_TEST(AirTopPSamplingKernelTest, CorrectnessLargeVocabLargeP) { - this->runTest( - SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f).setOutputLen(16)); + this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f)); }; class AirTopPSamplingKernelUtilsTest : public SamplingKernelTest diff --git a/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp b/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp index 393ac3ba1..71204138c 100644 --- a/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp +++ b/cpp/tests/kernels/sampling/samplingPenaltyTest.cpp @@ -24,6 +24,7 @@ namespace namespace tc = tensorrt_llm::common; namespace tk = tensorrt_llm::kernels; +namespace trk = tensorrt_llm::runtime::kernels; using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr; @@ -36,6 +37,7 @@ struct TemperatureTestParam int32_t vocabSize; TensorPtr temperatures; int32_t temperaturesSize; + int32_t maxTokensPerStep{1}; TemperatureTestParam& setBatchSize(int32_t bs) { @@ -61,6 +63,12 @@ struct TemperatureTestParam return *this; } + TemperatureTestParam& setMaxTokensPerStep(int32_t ts) + { + maxTokensPerStep = ts; + return *this; + } + std::string toString() const { return tc::fmtstr("TemperatureTestParam[batch=%d, vocab=%d, temperatures=%s]", batchSize, vocabSize, @@ -74,9 +82,10 @@ size_t padVocabSize(size_t vocabSize, size_t pad = 8) } template -void initLogitsAndBias(T* logits, T* bias, const size_t batchSize, const size_t vocabSize, const size_t vocabSizePadded) +void initLogitsAndBias(T* logits, T* bias, size_t const batchSize, size_t const maxTokensPerStep, + size_t const vocabSize, size_t const vocabSizePadded) { - initRandom(logits, batchSize * vocabSizePadded, -5.0f, 5.0f); + initRandom(logits, batchSize * maxTokensPerStep * vocabSizePadded, -5.0f, 5.0f); if (bias != nullptr) { initRandom(bias, batchSize * vocabSizePadded, -5.0f, 5.0f); @@ -84,15 +93,19 @@ void 
initLogitsAndBias(T* logits, T* bias, const size_t batchSize, const size_t bool is_half = sizeof(T) == 2; for (size_t i = 0; i < batchSize; ++i) { - for (size_t j = 0; j < vocabSizePadded; ++j) + for (size_t t = 0; t < maxTokensPerStep; ++t) { - if (j >= vocabSize) + for (size_t j = vocabSize; j < vocabSizePadded; ++j) { - logits[i * vocabSizePadded + j] = static_cast(is_half ? -HALF_FLT_MAX : -FLT_MAX); - if (bias != nullptr && i == 0) - { - bias[i * vocabSizePadded + j] = (T) 0.0f; - } + logits[(i * maxTokensPerStep + t) * vocabSizePadded + j] + = static_cast(is_half ? -HALF_FLT_MAX : -FLT_MAX); + } + } + for (size_t j = vocabSize; j < vocabSizePadded; ++j) + { + if (bias != nullptr && i == 0) + { + bias[i * vocabSizePadded + j] = (T) 0.0f; } } } @@ -109,6 +122,7 @@ class TemperaturePenaltyTest : public SamplingKernelTest int32_t mBatchSize; int32_t mVocabSize; int32_t mVocabSizePadded; + int32_t mMaxTokensPerStep; using SamplingKernelTest::mBufferManager; using SamplingKernelTest::mStream; @@ -119,12 +133,13 @@ class TemperaturePenaltyTest : public SamplingKernelTest TensorPtr mLogitsRefHost; TensorPtr mLogitsPtrs; TensorPtr mPenaltyWorkspaceDevice; + TensorPtr mTokensPerStep; TensorPtr mBiasHost; TensorPtr mBiasDevice; TensorPtr mTemperaturesDevice; TensorPtr mBatchSlots; - void subsetup(const TemperatureTestParam& param) + void subsetup(TemperatureTestParam const& param) { auto const dataType = TRTDataType::value; auto const ptrType = TRTDataType::value; @@ -133,20 +148,30 @@ class TemperaturePenaltyTest : public SamplingKernelTest mMaxBatchSize = 2 * mBatchSize; mVocabSize = param.vocabSize; mVocabSizePadded = padVocabSize(mVocabSize); + mMaxTokensPerStep = param.maxTokensPerStep; + + mLogitsHost + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsRefHost + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mOutLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType); - mLogitsHost = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsRefHost = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mOutLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsPtrs = mBufferManager->pinned(ITensor::makeShape({mBatchSize}), ptrType); + mPenaltyWorkspaceDevice = mBufferManager->gpu( + ITensor::makeShape({mMaxBatchSize, mMaxTokensPerStep, mVocabSizePadded}), nvinfer1::DataType::kINT32); - mPenaltyWorkspaceDevice - = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), nvinfer1::DataType::kINT32); + mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mBiasHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); + mBiasHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); mBiasDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mVocabSizePadded}), dataType); - mBatchSlots = 
mBufferManager->pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + + trk::invokeFill(*mLogitsRefHost, T{0.0f}, *mStream); + trk::invokeFill(*mOutLogitsDevice, T{0.0f}, *mStream); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < mBatchSize; ++bi) @@ -154,15 +179,22 @@ class TemperaturePenaltyTest : public SamplingKernelTest batchSlotsPtr[bi] = 2 * bi; } - initLogitsAndBias( - bufferCast(*mLogitsHost), bufferCast(*mBiasHost), mBatchSize, mVocabSize, mVocabSizePadded); + initLogitsAndBias(bufferCast(*mLogitsHost), bufferCast(*mBiasHost), mBatchSize, mMaxTokensPerStep, + mVocabSize, mVocabSizePadded); mBufferManager->copy(*mLogitsHost, *mLogitsDevice); auto logitsPtrs = BufferRange(*mLogitsPtrs); for (SizeType bi = 0; bi < mBatchSize; ++bi) { - logitsPtrs[bi] = bufferCast(*mLogitsDevice) + (mBatchSize - bi - 1) * mVocabSizePadded; + logitsPtrs[bi] + = bufferCast(*mLogitsDevice) + (mBatchSize - bi - 1) * mMaxTokensPerStep * mVocabSizePadded; + } + + auto tokensPerStepPtr = bufferCast(*mTokensPerStep); + for (SizeType bi = 0; bi < mMaxBatchSize; ++bi) + { + tokensPerStepPtr[bi] = (std::rand() % mMaxTokensPerStep) + 1; } mBufferManager->copy(*mBiasHost, *mBiasDevice); @@ -179,31 +211,39 @@ class TemperaturePenaltyTest : public SamplingKernelTest bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; auto const batchSlotsPtr = bufferCast(*mBatchSlots); + auto const tokensPerStepPtr = bufferCast(*mTokensPerStep); for (int32_t bi = 0; bi < mBatchSize; ++bi) { - for (int32_t vi = 0; vi < mVocabSizePadded; ++vi) + auto const batchSlot = batchSlotsPtr[bi]; + for (int32_t ti = 0; ti < tokensPerStepPtr[batchSlot]; ++ti) { - auto const srcIdx = (mBatchSize - bi - 1) * mVocabSizePadded + vi; - auto const dstIdx = bi * mVocabSizePadded + vi; - outLogits[dstIdx] = inLogits[srcIdx]; + for (int32_t vi = 0; vi < mVocabSizePadded; ++vi) + { + auto const srcIdx = ((mBatchSize - bi - 1) * mMaxTokensPerStep + ti) * mVocabSizePadded + vi; + auto const dstIdx = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded + vi; + outLogits[dstIdx] = inLogits[srcIdx]; + } } } for (size_t bi = 0; bi < mBatchSize; ++bi) { auto const batchSlot = batchSlotsPtr[bi]; - auto temperature = temperatures[batchSlot]; - ASSERT_GT(temperature, 0.0f) << "temperature should be positive but got " << temperature; - for (size_t j = 0; j < mVocabSizePadded; ++j) + for (int32_t ti = 0; ti < tokensPerStepPtr[batchSlot]; ++ti) { - size_t index = bi * mVocabSizePadded + j; - auto logit = static_cast(outLogits[index]); - if (j < mVocabSize && bias != nullptr) + auto temperature = temperatures[batchSlot]; + ASSERT_GT(temperature, 0.0f) << "temperature should be positive but got " << temperature; + for (size_t j = 0; j < mVocabSizePadded; ++j) { - logit += static_cast(bias[batchSlot * mVocabSizePadded + j]); + size_t index = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded + j; + auto logit = static_cast(outLogits[index]); + if (j < mVocabSize && bias != nullptr) + { + logit += static_cast(bias[batchSlot * mVocabSizePadded + j]); + } + outLogits[index] = j < mVocabSize ? static_cast(logit / temperature) : -MAX_T_VAL; } - outLogits[index] = j < mVocabSize ? 
static_cast(logit / temperature) : -MAX_T_VAL; } } } @@ -213,22 +253,21 @@ class TemperaturePenaltyTest : public SamplingKernelTest { subsetup(param); // Do test - InvokeBatchApplyPenaltyParams penalty_params{reinterpret_cast(bufferCast(*mLogitsPtrs)), + InvokeBatchApplyPenaltyParams penaltyParams{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), bufferCast(*mBiasDevice), bufferCast(*mPenaltyWorkspaceDevice), nullptr, bufferCast(*mTemperaturesDevice), nullptr, nullptr, nullptr, false, static_cast(mBatchSize), 1, 1, static_cast(mVocabSize), static_cast(mVocabSizePadded), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - bufferCast(*mBatchSlots), mStream->get()}; - tk::invokeBatchApplyPenalty(penalty_params); + bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), mStream->get()}; + tk::invokeBatchApplyPenalty(penaltyParams); auto logitsOutHost = mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU); mStream->synchronize(); computeReference(bufferCast(*mLogitsHost), bufferCast(*mLogitsRefHost), bufferCast(*mBiasHost), bufferCast(*param.temperatures), param.temperaturesSize); - bool passed = checkResult(param.toString(), bufferCast(*logitsOutHost), bufferCast(*mLogitsRefHost), - mBatchSize * mVocabSizePadded); + mBatchSize * mMaxTokensPerStep * mVocabSizePadded); EXPECT_TRUE(passed); } }; @@ -239,8 +278,7 @@ TYPED_TEST(TemperaturePenaltyTest, NoPenalty) { int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; - TensorPtr temperaturesHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr temperaturesHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*temperaturesHost)[i] = 1.0f; @@ -256,8 +294,7 @@ TYPED_TEST(TemperaturePenaltyTest, LessThanOne) { int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; - TensorPtr temperaturesHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr temperaturesHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*temperaturesHost)[i] = 0.53f; @@ -273,8 +310,7 @@ TYPED_TEST(TemperaturePenaltyTest, GreaterThaneOne) { int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; - TensorPtr temperaturesHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr temperaturesHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*temperaturesHost)[i] = 2.01f; @@ -290,8 +326,7 @@ TYPED_TEST(TemperaturePenaltyTest, Mixed) { int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; - TensorPtr temperaturesHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr temperaturesHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*temperaturesHost)[i] = 0.53f + 0.2f * i; @@ -307,8 +342,7 @@ TYPED_TEST(TemperaturePenaltyTest, LargeVocab) { int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; - TensorPtr temperaturesHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr temperaturesHost = 
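// Illustrative host-side sketch (not part of the patch) of the temperature penalty the
// reference above applies to a single row of logits: add the (optional) bias, divide by
// the temperature, and mask padded vocabulary entries with -MAX_T_VAL. Hypothetical names.
#include <cfloat>
#include <vector>

inline std::vector<float> applyTemperature(std::vector<float> const& logits,
    std::vector<float> const& bias, float temperature, int vocabSize, int vocabSizePadded)
{
    std::vector<float> out(vocabSizePadded);
    for (int v = 0; v < vocabSizePadded; ++v)
    {
        if (v < vocabSize)
        {
            float const logit = logits[v] + (v < static_cast<int>(bias.size()) ? bias[v] : 0.f);
            out[v] = logit / temperature;
        }
        else
        {
            out[v] = -FLT_MAX; // padded entries are masked out
        }
    }
    return out;
}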
BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*temperaturesHost)[i] = 0.53f + 0.2f * i; @@ -320,6 +354,23 @@ TYPED_TEST(TemperaturePenaltyTest, LargeVocab) .setTemperatures(temperaturesHost)); } +TYPED_TEST(TemperaturePenaltyTest, LargeVocabTokensPerStep) +{ + int32_t batchSize = 6; + int32_t maxBatchSize = 2 * batchSize; + TensorPtr temperaturesHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + for (int32_t i = 0; i < maxBatchSize; ++i) + { + bufferCast(*temperaturesHost)[i] = 1.f; // 0.53f + 0.2f * i; + } + this->runTest(TemperatureTestParam() + .setBatchSize(batchSize) + .setVocabSize(8) + .setTemperaturesSize(maxBatchSize) + .setTemperatures(temperaturesHost) + .setMaxTokensPerStep(4)); +} + struct RepetitionPenaltyTestCase { int32_t batchSize; @@ -331,6 +382,7 @@ struct RepetitionPenaltyTestCase int32_t repetitionPenaltiesSize; int32_t presencePenaltiesSize; int32_t frequencyPenaltiesSize; + int32_t maxTokensPerStep{1}; RepetitionPenaltyTestCase& setBatchSize(int32_t bs) { @@ -386,6 +438,12 @@ struct RepetitionPenaltyTestCase return *this; } + RepetitionPenaltyTestCase& setMaxTokensPerStep(int32_t ts) + { + maxTokensPerStep = ts; + return *this; + } + std::string toString() const { return tc::fmtstr( @@ -409,6 +467,7 @@ class RepetitionPenaltyTest : public SamplingKernelTest int32_t mVocabSizePadded; int32_t mMaxInputLength; int32_t mSequenceLength; + int32_t mMaxTokensPerStep; using SamplingKernelTest::mBufferManager; using SamplingKernelTest::mStream; @@ -420,6 +479,8 @@ class RepetitionPenaltyTest : public SamplingKernelTest TensorPtr mLogitsPtrs; TensorPtr mPenaltyWorkspaceDevice; + TensorPtr mTokensPerStep; + TensorPtr mOutputIdsHost; TensorPtr mOutputIdsDevice; @@ -448,31 +509,38 @@ class RepetitionPenaltyTest : public SamplingKernelTest mVocabSizePadded = padVocabSize(mVocabSize); mMaxInputLength = param.maxInputLength; mSequenceLength = 2 * mMaxInputLength; // input + output + mMaxTokensPerStep = param.maxTokensPerStep; + + mLogitsHost + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsRefHost + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mOutLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType); - mLogitsHost = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsRefHost = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mOutLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsPtrs = mBufferManager->pinned(ITensor::makeShape({mBatchSize}), ptrType); + mPenaltyWorkspaceDevice = mBufferManager->gpu( + ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize}), nvinfer1::DataType::kINT32); - mPenaltyWorkspaceDevice - = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSize}), nvinfer1::DataType::kINT32); + mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), 
nvinfer1::DataType::kINT32); mOutputIdsHost - = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize, mSequenceLength}), nvinfer1::DataType::kINT32); + = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize, mSequenceLength}), nvinfer1::DataType::kINT32); mOutputIdsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize, mSequenceLength}), nvinfer1::DataType::kINT32); - mSeqLengthHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mSeqLengthHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mSeqLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mContextLengthHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mContextLengthHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mContextLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mIdsPtrHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), ptrType); - mIdsPtrDevice = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), ptrType); + mIdsPtrHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), ptrType); + mIdsPtrDevice = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), ptrType); - mBatchSlots = mBufferManager->pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < mBatchSize; ++bi) @@ -480,8 +548,8 @@ class RepetitionPenaltyTest : public SamplingKernelTest batchSlotsPtr[bi] = 2 * bi; } - initLogitsAndBias( - bufferCast(*mLogitsHost), static_cast(nullptr), mBatchSize, mVocabSize, mVocabSizePadded); + initLogitsAndBias(bufferCast(*mLogitsHost), static_cast(nullptr), mBatchSize, mMaxTokensPerStep, + mVocabSize, mVocabSizePadded); initRandomInt(bufferCast(*mOutputIdsHost), mSequenceLength * mMaxBatchSize, 0, mVocabSize); initRandomInt(bufferCast(*mSeqLengthHost), mMaxBatchSize, 1, mSequenceLength); for (size_t i = 0; i < mMaxBatchSize; ++i) @@ -489,11 +557,16 @@ class RepetitionPenaltyTest : public SamplingKernelTest bufferCast(*mContextLengthHost)[i] = bufferCast(*mSeqLengthHost)[i]; } + trk::invokeFill(*mLogitsRefHost, T{0.0f}, *mStream); + trk::invokeFill(*mOutLogitsDevice, T{0.0f}, *mStream); + auto idsPtrHostPtr = BufferRange(*mIdsPtrHost); auto outputIdsDevicePtr = bufferCast(*mOutputIdsDevice); + auto tokensPerStepPtr = bufferCast(*mTokensPerStep); for (SizeType bi = 0; bi < mMaxBatchSize; bi++) { idsPtrHostPtr[bi] = outputIdsDevicePtr + bi * mSequenceLength; + tokensPerStepPtr[bi] = (std::rand() % mMaxTokensPerStep) + 1; } mBufferManager->copy(*mLogitsHost, *mLogitsDevice); @@ -505,7 +578,8 @@ class RepetitionPenaltyTest : public SamplingKernelTest auto logitsPtrs = BufferRange(*mLogitsPtrs); for (SizeType bi = 0; bi < mBatchSize; ++bi) { - logitsPtrs[bi] = bufferCast(*mLogitsDevice) + (mBatchSize - bi - 1) * mVocabSizePadded; + logitsPtrs[bi] + = bufferCast(*mLogitsDevice) + (mBatchSize - bi - 1) * mMaxTokensPerStep * mVocabSizePadded; } ASSERT_EQ(param.repetitionPenaltiesSize, mMaxBatchSize) << "Invalid test configuration."; @@ -529,40 +603,49 @@ class RepetitionPenaltyTest : public SamplingKernelTest int32_t const frequencyPenaltiesSize) { std::vector penalized(mVocabSize); - auto batchSlotsPtr = 
bufferCast(*mBatchSlots); + auto const batchSlotsPtr = bufferCast(*mBatchSlots); + auto const tokensPerStepPtr = bufferCast(*mTokensPerStep); for (int32_t bi = 0; bi < mBatchSize; ++bi) { - for (int32_t vi = 0; vi < mVocabSizePadded; ++vi) + auto const batchSlot = batchSlotsPtr[bi]; + for (int32_t ti = 0; ti < tokensPerStepPtr[batchSlot]; ++ti) { - auto const srcIdx = (mBatchSize - bi - 1) * mVocabSizePadded + vi; - auto const dstIdx = bi * mVocabSizePadded + vi; - outLogits[dstIdx] = inLogits[srcIdx]; + for (int32_t vi = 0; vi < mVocabSizePadded; ++vi) + { + auto const srcIdx = ((mBatchSize - bi - 1) * mMaxTokensPerStep + ti) * mVocabSizePadded + vi; + auto const dstIdx = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded + vi; + outLogits[dstIdx] = inLogits[srcIdx]; + } } } for (int32_t bi = 0; bi < mBatchSize; ++bi) { auto const batchSlot = batchSlotsPtr[bi]; - float repetitionPenalty - = repetitionPenaltiesSize > 1 ? repetitionPenalties[batchSlot] : repetitionPenalties[0]; - float presencePenalty = presencePenaltiesSize > 1 ? presencePenalties[batchSlot] : presencePenalties[0]; - float frequencyPenalty = frequencyPenaltiesSize > 1 ? frequencyPenalties[batchSlot] : frequencyPenalties[0]; - - std::fill(penalized.begin(), penalized.end(), false); - size_t offset = bi * mVocabSizePadded; - auto const step = sequenceLengths[batchSlot]; - for (int32_t t = 0; t < step; ++t) + for (int32_t ti = 0; ti < tokensPerStepPtr[batchSlot]; ++ti) { - auto tokenId = outputIds[batchSlot * mSequenceLength + t]; - if (!penalized[tokenId]) + float repetitionPenalty + = repetitionPenaltiesSize > 1 ? repetitionPenalties[batchSlot] : repetitionPenalties[0]; + float presencePenalty = presencePenaltiesSize > 1 ? presencePenalties[batchSlot] : presencePenalties[0]; + float frequencyPenalty + = frequencyPenaltiesSize > 1 ? frequencyPenalties[batchSlot] : frequencyPenalties[0]; + + std::fill(penalized.begin(), penalized.end(), false); + size_t offset = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded; + auto const step = sequenceLengths[batchSlot]; + for (int32_t t = 0; t < step; ++t) { - auto logit = static_cast(outLogits[offset + tokenId]); - outLogits[offset + tokenId] = static_cast( - (logit < 0.0f ? logit * repetitionPenalty : logit / repetitionPenalty) - presencePenalty); - penalized[tokenId] = true; + auto tokenId = outputIds[batchSlot * mSequenceLength + t]; + if (!penalized[tokenId]) + { + auto logit = static_cast(outLogits[offset + tokenId]); + outLogits[offset + tokenId] = static_cast( + (logit < 0.0f ? 
logit * repetitionPenalty : logit / repetitionPenalty) - presencePenalty); + penalized[tokenId] = true; + } + outLogits[offset + tokenId] -= frequencyPenalty; } - outLogits[offset + tokenId] -= frequencyPenalty; } } } @@ -571,6 +654,7 @@ class RepetitionPenaltyTest : public SamplingKernelTest void runTest(RepetitionPenaltyTestCase param) { subsetup(param); + InvokeBatchApplyPenaltyParams penalty_params{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), nullptr, bufferCast(*mPenaltyWorkspaceDevice), nullptr, nullptr, bufferCast(*mRepetitionPenaltiesDevice), bufferCast(*mPresencePenaltiesDevice), @@ -578,7 +662,7 @@ class RepetitionPenaltyTest : public SamplingKernelTest static_cast(mVocabSize), static_cast(mVocabSizePadded), reinterpret_cast(bufferCast(*mIdsPtrDevice)), nullptr, bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), nullptr, nullptr, - bufferCast(*mBatchSlots), mStream->get()}; + bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), mStream->get()}; tk::invokeBatchApplyPenalty(penalty_params); auto logitsOutHost = mBufferManager->copyFrom(*mOutLogitsDevice, MemoryType::kCPU); @@ -592,7 +676,7 @@ class RepetitionPenaltyTest : public SamplingKernelTest mStream->synchronize(); bool passed = checkResult(param.toString(), bufferCast(*logitsOutHost), bufferCast(*mLogitsRefHost), - mBatchSize * mVocabSizePadded); + mBatchSize * mMaxTokensPerStep * mVocabSizePadded); EXPECT_TRUE(passed); } }; @@ -604,11 +688,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchNoPenalty) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; @@ -632,11 +716,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionLessThanOne) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53f; @@ -660,11 +744,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionGreaterThaneOne) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = 
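// Illustrative host-side sketch (not part of the patch) of the combined penalty rule the
// reference above applies over the previously generated token ids:
//   repetition: scale the logit once per distinct id (multiply if negative, divide if positive),
//   presence:   subtract once per distinct id that occurred,
//   frequency:  subtract once per occurrence of the id.
// Hypothetical names.
#include <vector>

inline void applyRepetitionPenalties(std::vector<float>& logits, std::vector<int> const& outputIds,
    float repetitionPenalty, float presencePenalty, float frequencyPenalty)
{
    std::vector<bool> penalized(logits.size(), false);
    for (int tokenId : outputIds)
    {
        if (!penalized[tokenId])
        {
            float const logit = logits[tokenId];
            logits[tokenId] = (logit < 0.f ? logit * repetitionPenalty : logit / repetitionPenalty)
                - presencePenalty;
            penalized[tokenId] = true;
        }
        logits[tokenId] -= frequencyPenalty;
    }
}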
BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 2.01f; @@ -688,11 +772,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchRepetitionMixed) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; @@ -716,11 +800,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceMixed) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; @@ -744,11 +828,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchPresenceHasDefaultValueZero2) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; @@ -772,11 +856,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyMixed) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), 
nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; @@ -800,11 +884,11 @@ TYPED_TEST(RepetitionPenaltyTest, BatchFrequencyHasDefaultValueZero2) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; @@ -828,11 +912,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionPresence) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; @@ -856,11 +940,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeRepetitionFrequency) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; @@ -884,11 +968,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypePresenceFrequency) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost 
- = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 1.0f; @@ -912,11 +996,11 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull) int32_t batchSize = 6; int32_t maxBatchSize = 2 * batchSize; TensorPtr repetitionPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr presencePenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); TensorPtr frequencyPenaltyHost - = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); for (int32_t i = 0; i < maxBatchSize; ++i) { bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; @@ -935,11 +1019,41 @@ TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFull) .setFrequencyPenaltiesSize(maxBatchSize)); } +TYPED_TEST(RepetitionPenaltyTest, PenaltyTypeFullTokensPerStep) +{ + int32_t batchSize = 6; + int32_t maxBatchSize = 2 * batchSize; + TensorPtr repetitionPenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr presencePenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + TensorPtr frequencyPenaltyHost + = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + for (int32_t i = 0; i < maxBatchSize; ++i) + { + bufferCast(*repetitionPenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*presencePenaltyHost)[i] = 0.53 + i * 0.2f; + bufferCast(*frequencyPenaltyHost)[i] = 0.53 + i * 0.2f; + } + this->runTest(RepetitionPenaltyTestCase() + .setBatchSize(batchSize) + .setVocabSize(4) + .setMaxInputLength(5) + .setRepetitionPenalties(repetitionPenaltyHost) + .setPresencePenalties(presencePenaltyHost) + .setFrequencyPenalties(frequencyPenaltyHost) + .setRepetitionPenaltiesSize(maxBatchSize) + .setPresencePenaltiesSize(maxBatchSize) + .setFrequencyPenaltiesSize(maxBatchSize) + .setMaxTokensPerStep(4)); +} + struct MinLengthPenaltyTestParams { int32_t batchSize; int32_t vocabSize; int32_t maxSeqLength; + int32_t maxTokensPerStep{1}; MinLengthPenaltyTestParams& setBatchSize(int32_t bs) { @@ -959,10 +1073,16 @@ struct MinLengthPenaltyTestParams return *this; } + MinLengthPenaltyTestParams& setMaxTokensPerStep(int32_t ts) + { + maxTokensPerStep = ts; + return *this; + } + std::string toString() const { - return tc::fmtstr( - "MinLengthPenaltyTestParams[batch=%d, vocab=%d, maxSeqLen=%d]", batchSize, vocabSize, maxSeqLength); + return tc::fmtstr("MinLengthPenaltyTestParams[batch=%d, vocab=%d, maxSeqLen=%d, maxTokensPerStep=%d]", + batchSize, vocabSize, maxSeqLength, 
maxTokensPerStep); } }; @@ -977,6 +1097,7 @@ class MinLengthPenaltyTest : public SamplingKernelTest int32_t mVocabSizePadded; int32_t mMaxInputLength; int32_t mSequenceLength; + int32_t mMaxTokensPerStep; using SamplingKernelTest::mBufferManager; using SamplingKernelTest::mStream; @@ -988,6 +1109,8 @@ class MinLengthPenaltyTest : public SamplingKernelTest TensorPtr mLogitsPtrs; TensorPtr mPenaltyWorkspaceDevice; + TensorPtr mTokensPerStep; + TensorPtr mContextLengthHost; TensorPtr mContextLengthDevice; @@ -1013,29 +1136,36 @@ class MinLengthPenaltyTest : public SamplingKernelTest mVocabSizePadded = padVocabSize(mVocabSize); mMaxInputLength = param.maxSeqLength; mSequenceLength = 2 * mMaxInputLength; // input + output + mMaxTokensPerStep = param.maxTokensPerStep; + + mLogitsHost + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsRefHost + = BufferManager::pinned(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mOutLogitsDevice + = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSizePadded}), dataType); + mLogitsPtrs = BufferManager::pinned(ITensor::makeShape({mBatchSize}), ptrType); - mLogitsHost = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsRefHost = mBufferManager->pinned(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mOutLogitsDevice = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSizePadded}), dataType); - mLogitsPtrs = mBufferManager->pinned(ITensor::makeShape({mBatchSize}), ptrType); + mPenaltyWorkspaceDevice = mBufferManager->gpu( + ITensor::makeShape({mBatchSize, mMaxTokensPerStep, mVocabSize}), nvinfer1::DataType::kINT32); - mPenaltyWorkspaceDevice - = mBufferManager->gpu(ITensor::makeShape({mBatchSize, mVocabSize}), nvinfer1::DataType::kINT32); + mTokensPerStep = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mSeqLengthHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mSeqLengthHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mSeqLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mContextLengthHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mContextLengthHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mContextLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mMinLengthHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mMinLengthHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mMinLengthDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mEndIdsHost = mBufferManager->pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); + mEndIdsHost = BufferManager::pinned(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); mEndIdsDevice = mBufferManager->gpu(ITensor::makeShape({mMaxBatchSize}), nvinfer1::DataType::kINT32); - mBatchSlots = 
mBufferManager->pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); + mBatchSlots = BufferManager::pinned(ITensor::makeShape({mBatchSize}), nvinfer1::DataType::kINT32); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < mBatchSize; ++bi) @@ -1043,22 +1173,27 @@ class MinLengthPenaltyTest : public SamplingKernelTest batchSlotsPtr[bi] = 2 * bi; } - initLogitsAndBias( - bufferCast(*mLogitsHost), static_cast(nullptr), mBatchSize, mVocabSize, mVocabSizePadded); + initLogitsAndBias(bufferCast(*mLogitsHost), static_cast(nullptr), mBatchSize, mMaxTokensPerStep, + mVocabSize, mVocabSizePadded); initRandomInt(bufferCast(*mContextLengthHost), mMaxBatchSize, 0, mMaxInputLength); initRandomInt(bufferCast(*mMinLengthHost), mMaxBatchSize, 1, mMaxInputLength); initRandomInt(bufferCast(*mEndIdsHost), mMaxBatchSize, 0, mVocabSize); + trk::invokeFill(*mLogitsRefHost, T{0.0f}, *mStream); + trk::invokeFill(*mOutLogitsDevice, T{0.0f}, *mStream); + auto seqLengthHostPtr = bufferCast(*mSeqLengthHost); auto contextLengthHostPtr = bufferCast(*mContextLengthHost); auto minLengthHostPtr = bufferCast(*mMinLengthHost); + auto tokensPerStepPtr = bufferCast(*mTokensPerStep); for (SizeType bi = 0; bi < mMaxBatchSize; bi++) { // Current generated seq len is randomly either smaller than min length or larger - const auto generatedSeqLen = std::max(0, + auto const generatedSeqLen = std::max(0, std::min( static_cast(minLengthHostPtr[bi] + 2 * std::pow(-1, std::rand() % 2)), mMaxInputLength)); seqLengthHostPtr[bi] = contextLengthHostPtr[bi] + generatedSeqLen; + tokensPerStepPtr[bi] = (std::rand() % mMaxTokensPerStep) + 1; } mBufferManager->copy(*mLogitsHost, *mLogitsDevice); @@ -1070,7 +1205,8 @@ class MinLengthPenaltyTest : public SamplingKernelTest auto logitsPtrs = BufferRange(*mLogitsPtrs); for (SizeType bi = 0; bi < mBatchSize; ++bi) { - logitsPtrs[bi] = bufferCast(*mLogitsDevice) + (mBatchSize - bi - 1) * mVocabSizePadded; + logitsPtrs[bi] + = bufferCast(*mLogitsDevice) + (mBatchSize - bi - 1) * mMaxTokensPerStep * mVocabSizePadded; } } @@ -1079,26 +1215,34 @@ class MinLengthPenaltyTest : public SamplingKernelTest { bool const IS_FP16 = std::is_same::value; T const MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; - auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto const batchSlotsPtr = bufferCast(*mBatchSlots); + auto const tokensPerStepPtr = bufferCast(*mTokensPerStep); for (int32_t bi = 0; bi < mBatchSize; ++bi) { - for (int32_t vi = 0; vi < mVocabSizePadded; ++vi) + auto const batchSlot = batchSlotsPtr[bi]; + for (int32_t ti = 0; ti < tokensPerStepPtr[batchSlot]; ++ti) { - auto const srcIdx = (mBatchSize - bi - 1) * mVocabSizePadded + vi; - auto const dstIdx = bi * mVocabSizePadded + vi; - outLogits[dstIdx] = inLogits[srcIdx]; + for (int32_t vi = 0; vi < mVocabSizePadded; ++vi) + { + auto const srcIdx = ((mBatchSize - bi - 1) * mMaxTokensPerStep + ti) * mVocabSizePadded + vi; + auto const dstIdx = (bi * mMaxTokensPerStep + ti) * mVocabSizePadded + vi; + outLogits[dstIdx] = inLogits[srcIdx]; + } } } for (int32_t bi = 0; bi < mBatchSize; ++bi) { auto const batchSlot = batchSlotsPtr[bi]; - auto const generatedSeqLen = sequenceLengths[batchSlot] - contextLengths[batchSlot]; - auto const endId = endIds[batchSlot]; - if (generatedSeqLen < minSeqLen[batchSlot]) + for (int32_t ti = 0; ti < tokensPerStepPtr[batchSlot]; ++ti) { - outLogits[bi * mVocabSizePadded + endId] = -MAX_T_VAL; + auto const generatedSeqLen = sequenceLengths[batchSlot] - contextLengths[batchSlot]; + auto const endId = endIds[batchSlot]; + if (generatedSeqLen < minSeqLen[batchSlot]) + { + outLogits[(bi * mMaxTokensPerStep + ti) * mVocabSizePadded + endId] = -MAX_T_VAL; + } } } } @@ -1107,13 +1251,14 @@ class MinLengthPenaltyTest : public SamplingKernelTest void runTest(MinLengthPenaltyTestParams param) { subsetup(param); + InvokeBatchApplyPenaltyParams penalty_params{reinterpret_cast(bufferCast(*mLogitsPtrs)), bufferCast(*mOutLogitsDevice), nullptr, bufferCast(*mPenaltyWorkspaceDevice), nullptr, nullptr, nullptr, nullptr, nullptr, false, static_cast(mBatchSize), 1, mSequenceLength, static_cast(mVocabSize), static_cast(mVocabSizePadded), nullptr, nullptr, bufferCast(*mContextLengthDevice), bufferCast(*mSeqLengthDevice), bufferCast(*mMinLengthDevice), bufferCast(*mEndIdsDevice), - bufferCast(*mBatchSlots), mStream->get()}; + bufferCast(*mBatchSlots), mMaxTokensPerStep, bufferCast(*mTokensPerStep), mStream->get()}; tk::invokeBatchApplyPenalty(penalty_params); mStream->synchronize(); @@ -1127,7 +1272,7 @@ class MinLengthPenaltyTest : public SamplingKernelTest mStream->synchronize(); bool passed = checkResult(param.toString(), bufferCast(*logitsOutHost), bufferCast(*mLogitsRefHost), - mBatchSize * mVocabSizePadded); + mBatchSize * mMaxTokensPerStep * mVocabSizePadded); EXPECT_TRUE(passed); } }; @@ -1144,4 +1289,10 @@ TYPED_TEST(MinLengthPenaltyTest, BatchMaxSeqLen64) this->runTest(MinLengthPenaltyTestParams().setBatchSize(16).setVocabSize(51200).setMaxSeqLength(64)); } +TYPED_TEST(MinLengthPenaltyTest, BatchMaxSeqLen64TokensPerStep) +{ + this->runTest( + MinLengthPenaltyTestParams().setBatchSize(16).setVocabSize(51200).setMaxSeqLength(64).setMaxTokensPerStep(4)); +} + } // namespace diff --git a/cpp/tests/kernels/sampling/samplingTest.cpp b/cpp/tests/kernels/sampling/samplingTest.cpp index 281194d84..448096d6b 100644 --- a/cpp/tests/kernels/sampling/samplingTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTest.cpp @@ -41,290 +41,345 @@ void SamplingKernelTest::TearDown() } template -void SamplingKernelTest::allocateBuffers( - int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t outputLen) +void SamplingKernelTest::allocateBuffers(SamplingKernelTestParam const& param) { + auto 
const batchSize = param.batchSize; + auto const maxBatchSize = 2 * batchSize; + auto const vocabSize = param.vocabSize; + auto const maxTokensPerStep = param.maxTokensPerStep; + auto const dataType = TRTDataType::value; auto const ptrType = TRTDataType::value; // Allocate GPU data - mSeqLengthsHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mSeqLengthsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); mSeqLengthsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mFinishedHost = mBufferManager->pinned( + mFinishedHost = BufferManager::pinned( ITensor::makeShape({maxBatchSize}), TRTDataType::value); mFinishedDevice = mBufferManager->gpu( ITensor::makeShape({maxBatchSize}), TRTDataType::value); - mOutputIdsHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, maxSeqLen}), nvinfer1::DataType::kINT32); - mOutputIdsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, maxSeqLen}), nvinfer1::DataType::kINT32); + mOutputIdsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kINT32); + mOutputIdsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, mMaxSeqLen}), nvinfer1::DataType::kINT32); - mProbsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, vocabSize}), dataType); - mProbsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, vocabSize}), dataType); + mProbsHost = BufferManager::pinned(ITensor::makeShape({batchSize, maxTokensPerStep, vocabSize}), dataType); + mProbsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, maxTokensPerStep, vocabSize}), dataType); + mProbsPtrsDevice + = BufferManager::pinned(ITensor::makeShape({batchSize, maxTokensPerStep}), nvinfer1::DataType::kINT64); mCumLogProbsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); mOutputLogProbsDevice - = mBufferManager->gpu(ITensor::makeShape({maxSeqLen, maxBatchSize}), nvinfer1::DataType::kFLOAT); + = mBufferManager->gpu(ITensor::makeShape({mMaxSeqLen, maxBatchSize}), nvinfer1::DataType::kFLOAT); mZeroParentIdsDevice - = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, maxSeqLen}), nvinfer1::DataType::kINT32); + = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32); mTopPIdValsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize, vocabSize}), nvinfer1::DataType::kINT32); mBeginOffsetsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize + 1}), nvinfer1::DataType::kINT32); mEndOffsetsDevice = mBufferManager->gpu(ITensor::makeShape({batchSize + 1}), nvinfer1::DataType::kINT32); - mLogitsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, vocabSize}), dataType); - mLogProbsHost = mBufferManager->pinned(ITensor::makeShape({batchSize, vocabSize}), dataType); - mIdsPtrHost = mBufferManager->pinned(ITensor::makeShape({2 * maxBatchSize}), ptrType); + mLogitsHost = BufferManager::pinned(ITensor::makeShape({batchSize, maxTokensPerStep, vocabSize}), dataType); + mLogProbsHost = BufferManager::pinned(ITensor::makeShape({batchSize, maxTokensPerStep, vocabSize}), dataType); + mIdsPtrHost = BufferManager::pinned(ITensor::makeShape({2 * maxBatchSize}), ptrType); - mEndIdsHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mEndIdsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); mEndIdsDevice = 
mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mTopPsHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + mTopPsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); mTopPsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); - mTopKsHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mTopKsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); mTopKsDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mSkipDecodeHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kBOOL); + mSkipDecodeHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kBOOL); mSkipDecodeDevice = mBufferManager->gpu(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kBOOL); - mBatchSlots = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); + mTokensPerStep = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + + mBatchSlots = BufferManager::pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); + + mExpectedCumLogProbsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); - mExpectedCumLogProbsHost = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kFLOAT); + mCurandStatesDevice + = mBufferManager->gpu(ITensor::makeShape({maxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8); } template -void SamplingKernelTest::setupBuffers(int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t maxSeqLen, - int32_t outputLen, int32_t topK, float topP, bool useSkipDecode, bool hasDiffRuntimeArgs, std::mt19937& gen, - std::uniform_int_distribution<>& endIdsDistr) +void SamplingKernelTest::setupBuffers(SamplingKernelTestParam const& param) { + auto const batchSize = param.batchSize; + auto const maxBatchSize = 2 * batchSize; + auto const vocabSize = param.vocabSize; + auto const maxTokensPerStep = param.maxTokensPerStep; + + auto const topK = param.topK; + auto const topP = param.topP; + + std::mt19937 gen(42); + auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto probsPtr = BufferRange(*mProbsPtrsDevice); + auto probsDevicePtr = bufferCast(*mProbsDevice); for (SizeType bi = 0; bi < batchSize; ++bi) { batchSlotsPtr[bi] = 2 * bi; + for (SizeType ti = 0; ti < maxTokensPerStep; ++ti) + { + probsPtr[bi * maxTokensPerStep + ti] = probsDevicePtr + bi * maxTokensPerStep * vocabSize + ti * vocabSize; + } } // Allocate and init curand states - cudaMalloc(&mCurandStatesDevice, sizeof(curandState_t) * maxBatchSize); - tk::invokeCurandInitialize(mCurandStatesDevice, batchSlotsPtr, batchSize, mSeed, mStream->get()); + tk::invokeCurandInitialize(reinterpret_cast(bufferCast(*mCurandStatesDevice)), + batchSlotsPtr, batchSize, mSeed, mStream->get()); - std::uniform_real_distribution<> skipDecodeDist(0, 1); // uniform distribution between 0 and 1 - std::uniform_real_distribution<> topPDist(0, 1); // uniform distribution between 0 and 1 - std::uniform_int_distribution<> topKDist(1, std::min(1024, vocabSize)); + std::uniform_int_distribution<> endIdsDistr( + 0, vocabSize - 1); // -1 because uniform_int_distribution generates closed interval + std::uniform_real_distribution<> skipDecodeDist(0, 1); + 
std::uniform_real_distribution<> topPDist(0, topP); + std::uniform_int_distribution<> topKDist(1, topK); + std::uniform_int_distribution<> tokensPerStepDist(1, maxTokensPerStep); + std::uniform_int_distribution<> seqLenDist(0, mMaxSeqLen - maxTokensPerStep); + std::uniform_real_distribution<> logProbDist(-3.f, 3.f); + std::uniform_real_distribution<> finishedDist(0, 1); // Init by zero. - trk::invokeFill(*mSeqLengthsDevice, int32_t{0}, *mStream); trk::invokeFill(*mFinishedDevice, uint8_t{0}, *mStream); - trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream); trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream); trk::invokeFill(*mZeroParentIdsDevice, int32_t{0}, *mStream); trk::invokeFill(*mOutputIdsDevice, int32_t{0}, *mStream); - std::fill_n(bufferCast(*mExpectedCumLogProbsHost), maxBatchSize, 0); // Init topK, topP and endIds for each request in batch auto skipDecodeHostPtr = bufferCast(*mSkipDecodeHost); auto topPsHostPtr = bufferCast(*mTopPsHost); auto topKsHostPtr = bufferCast(*mTopKsHost); auto endIdsHostPtr = bufferCast(*mEndIdsHost); + auto tokensPerStepPtr = bufferCast(*mTokensPerStep); + auto finishedHostPtr + = reinterpret_cast(bufferCast(*mFinishedHost)); for (SizeType bi = 0; bi < maxBatchSize; ++bi) { endIdsHostPtr[bi] = endIdsDistr(gen); - skipDecodeHostPtr[bi] = useSkipDecode ? skipDecodeDist(gen) > 0.8 : false; - topKsHostPtr[bi] = hasDiffRuntimeArgs ? topKDist(gen) : topK; - topPsHostPtr[bi] = hasDiffRuntimeArgs ? topPDist(gen) : topP; + skipDecodeHostPtr[bi] = skipDecodeDist(gen) > 0.8; + topPsHostPtr[bi] = topPDist(gen); + topKsHostPtr[bi] = topKDist(gen); + tokensPerStepPtr[bi] = tokensPerStepDist(gen); + finishedHostPtr[bi] = finishedDist(gen) > 0.8 ? tk::FinishedState::finished() : tk::FinishedState::empty(); } - mMaxTopK = *std::max_element(topKsHostPtr, topKsHostPtr + maxBatchSize); - mMaxTopP = *std::max_element(topPsHostPtr, topPsHostPtr + maxBatchSize); + mMaxTopK = topK; + mMaxTopP = topP; + + TLLM_CHECK(mMaxTopK * maxTokensPerStep <= mMaxSeqLen); // Setup pointers to output ids for each request in batch auto idsPtrHostPtr = BufferRange(*mIdsPtrHost); auto outputIdsDevicePtr = bufferCast(*mOutputIdsDevice); auto zeroParentIdsDevicePtr = bufferCast(*mZeroParentIdsDevice); + auto seqLensHostPtr = bufferCast(*mSeqLengthsHost); + auto logProbHostPtr = bufferCast(*mExpectedCumLogProbsHost); for (SizeType bi = 0; bi < maxBatchSize; bi++) { - idsPtrHostPtr[bi] = outputIdsDevicePtr + bi * maxSeqLen; + idsPtrHostPtr[bi] = outputIdsDevicePtr + bi * mMaxSeqLen; + idsPtrHostPtr[maxBatchSize + bi] = zeroParentIdsDevicePtr + bi * mMaxSeqLen; } + for (SizeType bi = 0; bi < maxBatchSize; bi++) { - idsPtrHostPtr[maxBatchSize + bi] = zeroParentIdsDevicePtr + bi * maxSeqLen; + seqLensHostPtr[bi] = seqLenDist(gen); + logProbHostPtr[bi] = logProbDist(gen); } mBufferManager->copy(*mEndIdsHost, *mEndIdsDevice); mBufferManager->copy(*mSkipDecodeHost, *mSkipDecodeDevice); mBufferManager->copy(*mTopPsHost, *mTopPsDevice); mBufferManager->copy(*mTopKsHost, *mTopKsDevice); -} + mBufferManager->copy(*mSeqLengthsHost, *mSeqLengthsDevice); + mBufferManager->copy(*mExpectedCumLogProbsHost, *mCumLogProbsDevice); + mBufferManager->copy(*mFinishedHost, *mFinishedDevice); -template -void SamplingKernelTest::verifyCurrentStep(int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, - int32_t maxSeqLen, int32_t step, bool greedySearch, bool useSkipDecode, bool hasDiffRuntimeArgs, - std::vector& refFinished, std::vector& refSeqLength, - std::vector const& finishedCurrentStep) -{ - 
auto const batchSlotsPtr = bufferCast(*mBatchSlots); - auto const outputIdsHostPtr = bufferCast(*mOutputIdsHost); - auto const seqLengthsHostPtr = bufferCast(*mSeqLengthsHost); - auto const finishedHostPtr - = reinterpret_cast(bufferCast(*mFinishedHost)); - auto const logProbsHostPtr = bufferCast(*mLogProbsHost); - auto const endIdsHostPtr = bufferCast(*mEndIdsHost); - auto const skipDecodeHostPtr = bufferCast(*mSkipDecodeHost); - auto const expectedCumLogProbsHostPtr = bufferCast(*mExpectedCumLogProbsHost); + // Init logits randomly + auto logitsHostPtr = bufferCast(*mLogitsHost); + initRandom(logitsHostPtr, batchSize * maxTokensPerStep * vocabSize, -3.0f, 3.0f); - for (SizeType bi = 0; bi < batchSize; ++bi) + // Only in greedy search we can guarantee the selected token and stop by condition + if (topK == 1) { - auto const batchSlot = batchSlotsPtr[bi]; - // Set reference finished state to true if we finished before or at current step - bool const generatedEOS = outputIdsHostPtr[batchSlot * maxSeqLen + step] == endIdsHostPtr[batchSlot]; - bool finishedThisStep = finishedCurrentStep[batchSlot].isFinished() || generatedEOS; - refFinished[batchSlot] = generatedEOS ? tk::FinishedState::finishedEOS() : refFinished[batchSlot]; - - if (!refFinished[batchSlot].isFinished()) - { - // Increase reference seq len excluding the EOS token - refSeqLength[batchSlot]++; - } - - // If decoding for this batch is skipped ignore cumLog calculation - if (!skipDecodeHostPtr[batchSlot]) + for (SizeType bi = 0; bi < batchSize; ++bi) { - // Check seq len correctness - EXPECT_EQ(seqLengthsHostPtr[batchSlot], refSeqLength[batchSlot]); - // Only in greedy search we can guarantee the selected token and stop by condition - if (greedySearch) + auto const batchSlot = batchSlotsPtr[bi]; + for (int32_t ti = 0; ti < maxTokensPerStep; ++ti) { - EXPECT_EQ(finishedHostPtr[batchSlot].isFinished(), refFinished[batchSlot].isFinished()); + // Set logit of the endId for the finished request to the value above others + // NOTE that we can guarantee finish only in greedy search + logitsHostPtr[(bi * maxTokensPerStep + ti) * vocabSize + endIdsHostPtr[batchSlot]] = 4.0f; } + } + } - // Check the range of the returned token ([0, vocabSize)) - auto const outputId = outputIdsHostPtr[batchSlot * maxSeqLen + step]; - EXPECT_TRUE((outputId >= 0) && (outputId < vocabSize)); - int idx = bi * vocabSize + outputId; + // Compute probabilities for each token + computeProb(bufferCast(*mProbsHost), logitsHostPtr, batchSize * maxTokensPerStep, vocabSize); + mBufferManager->copy(*mProbsHost, *mProbsDevice); +} - // Compute reference cumLogProb by summing all logProbs up to the stop token - expectedCumLogProbsHostPtr[batchSlot] - += step < refSeqLength[batchSlot] || finishedThisStep ? 
(float) logProbsHostPtr[idx] : 0.0f; - // If sequence has just finished at this step - if (finishedHostPtr[batchSlot].isFinished() && step < seqLengthsHostPtr[batchSlot]) - { - // Check that finished tokens is endId - EXPECT_EQ(outputIdsHostPtr[batchSlot * maxSeqLen + step], endIdsHostPtr[batchSlot]) - << "step: " << step << " b: " << bi << " hasDiffRuntimeArgs: " << hasDiffRuntimeArgs - << " useSkipDecode: " << useSkipDecode; - } - // TODO(nkorobov): check correctness with K>1 - } +template +std::vector SamplingKernelTest::computeTopKTopPVariants( + int32_t bi, int32_t batchSlot, int32_t ti, int32_t maxTokensPerStep, int32_t vocabSize) +{ + std::vector allowedTokens; + auto probsPtr = bufferCast(*mProbsHost) + (bi * maxTokensPerStep + ti) * vocabSize; + std::vector indices(vocabSize); + std::iota(indices.begin(), indices.end(), 0); + std::sort( + indices.begin(), indices.end(), [probsPtr](SizeType i1, SizeType i2) { return probsPtr[i1] > probsPtr[i2]; }); + + auto topK = bufferCast(*mTopKsHost)[batchSlot]; + auto topP = bufferCast(*mTopPsHost)[batchSlot]; + + allowedTokens.insert(allowedTokens.begin(), indices.begin(), indices.begin() + topK); + float totalProb = 0.f; + SizeType idx = 0; + while (totalProb < topP && idx < vocabSize) + { + allowedTokens.push_back(indices[idx]); + totalProb += static_cast(probsPtr[indices[idx++]]); } + return allowedTokens; } template -void SamplingKernelTest::runTest(const SamplingKernelTestParam& param, bool hasDiffRuntimeArgs, bool useSkipDecode) +void SamplingKernelTest::verifyResult(SamplingKernelTestParam const& param) { auto const batchSize = param.batchSize; - auto const maxBatchSize = 2 * batchSize; auto const vocabSize = param.vocabSize; - auto const outputLen = param.outputLen; - auto const maxSeqLen = outputLen; + auto const maxTokensPerStep = param.maxTokensPerStep; - auto const topK = param.topK; - auto const topP = param.topP; - - bool const greedySearch = topK == 1 && hasDiffRuntimeArgs == false && useSkipDecode == false; + auto const outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, MemoryType::kCPU); + auto const seqLenHost = mBufferManager->copyFrom(*mSeqLengthsDevice, MemoryType::kCPU); + auto const finishedHost = mBufferManager->copyFrom(*mFinishedDevice, MemoryType::kCPU); + auto const cumLogProbsHost = mBufferManager->copyFrom(*mCumLogProbsDevice, MemoryType::kCPU); - std::mt19937 gen(42); - std::uniform_real_distribution<> finishedDist(0, 1); // uniform distribution between 0 and 1 - std::uniform_int_distribution<> endIdsDistr( - 0, vocabSize - 1); // -1 because uniform_int_distribution generates closed interval + // Synchronize to get valid data on Host + mStream->synchronize(); - // Allocate buffers - allocateBuffers(batchSize, maxBatchSize, vocabSize, maxSeqLen, outputLen); + // Compute reference. 
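// A minimal standalone sketch of the acceptance set that computeTopKTopPVariants builds above:
// the K most probable tokens plus the shortest probability-sorted prefix whose cumulative mass
// reaches P. The helper name and plain std types are illustrative only, not part of this patch.
#include <algorithm>
#include <numeric>
#include <vector>

std::vector<int> allowedTopKTopP(std::vector<float> const& probs, int k, float p)
{
    std::vector<int> order(probs.size());
    std::iota(order.begin(), order.end(), 0);
    // Token ids sorted by descending probability.
    std::sort(order.begin(), order.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    // Top-K part (assumes k <= probs.size(), as in the tests).
    std::vector<int> allowed(order.begin(), order.begin() + k);

    // Top-P (nucleus) part; duplicates with the top-K part are harmless for a membership test
    // like the one performed in verifyResult.
    float cum = 0.f;
    for (int i = 0; i < static_cast<int>(order.size()) && cum < p; ++i)
    {
        allowed.push_back(order[i]);
        cum += probs[order[i]];
    }
    return allowed;
}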
+ computeLogProb(bufferCast(*mLogProbsHost), bufferCast(*mLogitsHost), batchSize * maxTokensPerStep, vocabSize); - // Setup buffers - setupBuffers(batchSize, maxBatchSize, vocabSize, maxSeqLen, outputLen, topK, topP, useSkipDecode, - hasDiffRuntimeArgs, gen, endIdsDistr); auto const batchSlotsPtr = bufferCast(*mBatchSlots); - // Allocate internal state holders for reference - std::vector refSeqLength(maxBatchSize); - std::vector refFinished(maxBatchSize, tk::FinishedState::empty()); + auto const outputIdsHostPtr = bufferCast(*outputIdsHost); + auto const seqLengthsHostPtr = bufferCast(*seqLenHost); + auto const finishedHostPtr + = reinterpret_cast(bufferCast(*finishedHost)); - // retrieve the workspace size of the sampling kernel. - auto const workspaceSize = getWorkspaceSize(param); - TensorPtr workspaceDevice - = mBufferManager->gpu(ITensor::makeShape({static_cast(workspaceSize)}), nvinfer1::DataType::kINT8); + auto const outputIdsOrigHostPtr = bufferCast(*mOutputIdsHost); + auto const seqLengthsOrigHostPtr = bufferCast(*mSeqLengthsHost); + auto const finishedOrigHostPtr + = reinterpret_cast(bufferCast(*mFinishedHost)); + + auto const logProbsHostPtr = bufferCast(*mLogProbsHost); + auto const endIdsHostPtr = bufferCast(*mEndIdsHost); + auto const skipDecodeHostPtr = bufferCast(*mSkipDecodeHost); + auto const tokensPerStepPtr = bufferCast(*mTokensPerStep); + auto const expectedCumLogProbsHostPtr = bufferCast(*mExpectedCumLogProbsHost); - for (size_t step = 0; step < outputLen; ++step) + for (SizeType bi = 0; bi < batchSize; ++bi) { - // Init logits randomly - auto logitsHostPtr = bufferCast(*mLogitsHost); - auto endIdsHostPtr = bufferCast(*mEndIdsHost); - initRandom(logitsHostPtr, batchSize * vocabSize, -3.0f, 3.0f); - - std::vector finishedCurrentStep(maxBatchSize, tk::FinishedState::empty()); - // Only in greedy search we can guarantee the selected token and stop by condition - if (greedySearch) + auto const batchSlot = batchSlotsPtr[bi]; + auto const tokensPerStep = tokensPerStepPtr[batchSlot]; + for (SizeType ti = 0; ti < tokensPerStep; ++ti) { - for (SizeType bi = 0; bi < batchSize; ++bi) + auto kResults = param.returnAllTopK ? bufferCast(*mTopKsHost)[batchSlot] : 1; + + for (SizeType ki = 0; ki < kResults; ++ki) { - auto const batchSlot = batchSlotsPtr[bi]; - // Randomly decide if the sequence finishes at current step - finishedCurrentStep[batchSlot] - = (refFinished[batchSlot].isFinished() == false && finishedDist(gen) < 0.1) - ? tk::FinishedState::finishedEOS() - : tk::FinishedState::empty(); - - if (finishedCurrentStep[batchSlot].isFinished()) + // Set reference finished state to true if we finished before or at current step + auto const idsIdx = param.returnAllTopK ? 
ti * mMaxTopK + ki : seqLengthsOrigHostPtr[batchSlot] + ti; + auto const outputId = outputIdsHostPtr[batchSlot * mMaxSeqLen + idsIdx]; + // Check the range of the returned token ([0, vocabSize)) + EXPECT_TRUE((outputId >= 0) && (outputId < vocabSize)); + bool const generatedEOS = outputId == endIdsHostPtr[batchSlot]; + + // If decoding for this batch is skipped ignore cumLog calculation + if (!skipDecodeHostPtr[batchSlot] && !finishedOrigHostPtr[batchSlot].isFinished() + && !finishedOrigHostPtr[batchSlot].isSkipDecoding()) { - // Set logit of the endId for the finished request to the value above others - // NOTE that we can guarantee finish only in greedy search - logitsHostPtr[bi * vocabSize + endIdsHostPtr[batchSlot]] = 4.0f; + if (maxTokensPerStep == 1 && !param.returnAllTopK) + { + if (generatedEOS) + { + EXPECT_EQ(seqLengthsHostPtr[batchSlot], seqLengthsOrigHostPtr[batchSlot]); + EXPECT_TRUE(finishedHostPtr[batchSlot].isFinished()); + } + else + { + EXPECT_EQ(seqLengthsHostPtr[batchSlot], seqLengthsOrigHostPtr[batchSlot] + tokensPerStep); + EXPECT_EQ( + finishedHostPtr[batchSlot].isFinished(), finishedOrigHostPtr[batchSlot].isFinished()); + } + } + + auto topKTopPVariants = computeTopKTopPVariants(bi, batchSlot, ti, maxTokensPerStep, vocabSize); + + bool found = false; + for (auto const& var : topKTopPVariants) + { + if (outputId == var) + { + found = true; + break; + } + } + EXPECT_TRUE(found) << "Incorrect output id token"; + + // Compute reference cumLogProb by summing all logProbs up to the stop token + expectedCumLogProbsHostPtr[batchSlot] + += static_cast(logProbsHostPtr[bi * vocabSize + outputId]); + } + else + { + // Check that tensors are not modified + auto const idsIdx = batchSlot * mMaxSeqLen + seqLengthsOrigHostPtr[batchSlot] + ti; + EXPECT_EQ(outputId, outputIdsOrigHostPtr[idsIdx]); + EXPECT_EQ(seqLengthsHostPtr[batchSlot], seqLengthsOrigHostPtr[batchSlot]); + EXPECT_EQ(finishedHostPtr[batchSlot].isFinished(), finishedOrigHostPtr[batchSlot].isFinished()); } } } - - // Compute probobilities for each token - computeProb(bufferCast(*mProbsHost), bufferCast(*mLogitsHost), batchSize, vocabSize); - mBufferManager->copy(*mProbsHost, *mProbsDevice); - - // Call tested function sampling - callTestedFunction(param, hasDiffRuntimeArgs, workspaceSize, workspaceDevice); - - mBufferManager->copy(*mOutputIdsDevice, *mOutputIdsHost); - mBufferManager->copy(*mSeqLengthsDevice, *mSeqLengthsHost); - mBufferManager->copy(*mFinishedDevice, *mFinishedHost); - - // Synchronize to get valid data on Host - mStream->synchronize(); - - // Compute reference. 
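// The idsIdx arithmetic above encodes two output layouts: when all top-K candidates are
// returned, step ti stores its K candidates contiguously in rows of width maxTopK; otherwise
// the single sampled token of step ti is appended after the request's current sequence length.
// Both are then offset by the request's row in the ids buffer. A small helper mirroring those
// expressions (hypothetical name, not an API of the kernels):
inline int flatOutputIdIndex(
    bool returnAllTopK, int batchSlot, int maxSeqLen, int ti, int ki, int maxTopK, int seqLen)
{
    int const idsIdx = returnAllTopK ? ti * maxTopK + ki : seqLen + ti;
    return batchSlot * maxSeqLen + idsIdx;
}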
- computeLogProb(bufferCast(*mLogProbsHost), bufferCast(*mLogitsHost), batchSize, vocabSize); - - verifyCurrentStep(batchSize, maxBatchSize, vocabSize, maxSeqLen, step, greedySearch, useSkipDecode, - hasDiffRuntimeArgs, refFinished, refSeqLength, finishedCurrentStep); } - auto const cumLogProbsHost = mBufferManager->copyFrom(*mCumLogProbsDevice, MemoryType::kCPU); - - mStream->synchronize(); - for (int32_t bi = 0; bi < batchSize; ++bi) + // Cum log probs is not supported for multiple tokens per step or all top K return + if (maxTokensPerStep == 1 && !param.returnAllTopK) { - auto batchSlotsPtr = bufferCast(*mBatchSlots); - auto const batchSlot = batchSlotsPtr[bi]; - bool passed = checkResult(param.toString(), bufferCast(*cumLogProbsHost) + batchSlot, - bufferCast(*mExpectedCumLogProbsHost) + batchSlot, 1); - EXPECT_TRUE(passed); + for (int32_t bi = 0; bi < batchSize; ++bi) + { + auto batchSlotsPtr = bufferCast(*mBatchSlots); + auto const batchSlot = batchSlotsPtr[bi]; + bool passed = checkResult("cum log probs", bufferCast(*cumLogProbsHost) + batchSlot, + bufferCast(*mExpectedCumLogProbsHost) + batchSlot, 1); + EXPECT_TRUE(passed); + } } - - cudaFree(mCurandStatesDevice); } template -void SamplingKernelTest::runTest(const SamplingKernelTestParam& param) +void SamplingKernelTest::runTest(SamplingKernelTestParam const& param) { - runTest(param, false, false); // Single params, do not skip decoders - runTest(param, true, false); // Different params, do not skip decoders - runTest(param, false, true); // Single params, skip some decoders - runTest(param, true, true); // Different params, skip some decoders + // Allocate buffers + allocateBuffers(param); + + // Setup buffers + setupBuffers(param); + + // Retrieve the workspace size of the sampling kernel. + auto const workspaceSize = getWorkspaceSize(param); + TensorPtr workspaceDevice + = mBufferManager->gpu(ITensor::makeShape({static_cast(workspaceSize)}), nvinfer1::DataType::kINT8); + + // Call tested function sampling + callTestedFunction(param, workspaceDevice); + + // Verify results + verifyResult(param); } template class SamplingKernelTest; diff --git a/cpp/tests/kernels/sampling/samplingTest.h b/cpp/tests/kernels/sampling/samplingTest.h index cacee4e66..102c1caea 100644 --- a/cpp/tests/kernels/sampling/samplingTest.h +++ b/cpp/tests/kernels/sampling/samplingTest.h @@ -22,6 +22,7 @@ #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/kernels/penaltyKernels.h" +#include "tensorrt_llm/kernels/samplingAirTopPKernels.h" #include "tensorrt_llm/kernels/samplingTopKKernels.h" #include "tensorrt_llm/kernels/samplingTopPKernels.h" #include "tensorrt_llm/runtime/bufferManager.h" @@ -120,7 +121,7 @@ bool checkResult(std::string name, T* out, T* ref, size_t size) /////////////////////////////////// Tests ////////////////////////////////////////// template -void computeProb(T* probs, const T* logits, int batchSize, int vocabSize) +void computeProb(T* probs, T const* logits, int batchSize, int vocabSize) { // Compute the log probability from logits. // logits = batchSize x vocabSize. @@ -153,7 +154,7 @@ void computeProb(T* probs, const T* logits, int batchSize, int vocabSize) } template -void computeLogProb(T* logprobs, const T* logits, int batchSize, int vocabSize) +void computeLogProb(T* logprobs, T const* logits, int batchSize, int vocabSize) { // Compute the log probability from logits. // logits = batchSize x vocabSize. 
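// computeProb/computeLogProb supply the host-side reference the tests compare against. The
// usual numerically stable formulation subtracts the row maximum before exponentiating; a
// minimal sketch over a [batchSize x vocabSize] buffer, assuming float logits (the tests
// template this over float and half) and an illustrative function name:
#include <algorithm>
#include <cmath>

void referenceLogSoftmax(float* logprobs, float const* logits, int batchSize, int vocabSize)
{
    for (int bi = 0; bi < batchSize; ++bi)
    {
        float const* row = logits + bi * vocabSize;
        float const maxLogit = *std::max_element(row, row + vocabSize);
        float sumExp = 0.f;
        for (int vi = 0; vi < vocabSize; ++vi)
        {
            sumExp += std::exp(row[vi] - maxLogit);
        }
        float const logZ = maxLogit + std::log(sumExp);
        for (int vi = 0; vi < vocabSize; ++vi)
        {
            // log softmax: logit - logsumexp(row); exponentiate this to get the probability.
            logprobs[bi * vocabSize + vi] = row[vi] - logZ;
        }
    }
}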
@@ -189,11 +190,13 @@ struct SamplingKernelTestParam { int32_t batchSize; int32_t vocabSize; - uint32_t topK; - float topP; - int32_t outputLen; - bool normalizeLogProbs = false; - bool logitsHasProbs = true; + uint32_t topK{1}; + float topP{0.f}; + bool normalizeLogProbs{false}; + bool logitsHasProbs{true}; + int32_t maxTokensPerStep{1}; + bool returnAllTopK{false}; + bool useLogitsPtrs{false}; SamplingKernelTestParam& setBatchSize(int32_t bs) { @@ -219,16 +222,29 @@ struct SamplingKernelTestParam return *this; } - SamplingKernelTestParam& setOutputLen(int32_t ol) + SamplingKernelTestParam& setMaxTokensPerStep(int32_t ts) { - outputLen = ol; + maxTokensPerStep = ts; + return *this; + } + + SamplingKernelTestParam& setReturnAllTopK() + { + returnAllTopK = true; + return *this; + } + + SamplingKernelTestParam& setUseLogitsPtrs() + { + useLogitsPtrs = true; return *this; } std::string toString() const { - return tensorrt_llm::common::fmtstr("SamplingKernelTestParam[batch=%d, vocab=%d, k=%u, p=%3.1f, output_len=%d]", - batchSize, vocabSize, topK, topP, outputLen); + return tensorrt_llm::common::fmtstr( + "SamplingKernelTestParam[batch=%d, vocab=%d, k=%u, p=%3.1f, tokens_per_step=%d]", batchSize, vocabSize, + topK, topP, maxTokensPerStep); } }; @@ -241,34 +257,28 @@ class SamplingKernelTest : public testing::Test void SetUp() override; void TearDown() override; - void runTest(const SamplingKernelTestParam& param); + void runTest(SamplingKernelTestParam const& param); protected: - virtual size_t getWorkspaceSize(const SamplingKernelTestParam& param) + virtual size_t getWorkspaceSize(SamplingKernelTestParam const& param) { throw std::logic_error("Not implemented"); }; - virtual void callTestedFunction( - const SamplingKernelTestParam& param, bool hasDiffRuntimeArgs, size_t workspaceSize, TensorPtr& workspaceDevice) + virtual void callTestedFunction(SamplingKernelTestParam const& param, TensorPtr& workspaceDevice) { throw std::logic_error("Not implemented"); } - void allocateBuffers( - int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t outputLen); +private: + void allocateBuffers(SamplingKernelTestParam const& param); - void setupBuffers(int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t outputLen, - int32_t topK, float topP, bool useSkipDecode, bool hasDiffRuntimeArgs, std::mt19937& gen, - std::uniform_int_distribution<>& endIdsDistr); + void setupBuffers(SamplingKernelTestParam const& param); - void verifyCurrentStep(int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t maxSeqLen, int32_t step, - bool greedySearch, bool useSkipDecode, bool hasDiffRuntimeArgs, - std::vector& refFinished, std::vector& refSeqLength, - const std::vector& finishedCurrentStep); + void verifyResult(SamplingKernelTestParam const& param); -private: - void runTest(const SamplingKernelTestParam& param, bool hasDiffRuntimeArgs, bool useSkipDecode); + std::vector computeTopKTopPVariants( + int32_t bi, int32_t batchSlot, int32_t ti, int32_t tokensPerStep, int32_t vocabSize); protected: std::shared_ptr mBufferManager; @@ -288,6 +298,7 @@ class SamplingKernelTest : public testing::Test TensorPtr mProbsHost; TensorPtr mProbsDevice; + TensorPtr mProbsPtrsDevice; TensorPtr mCumLogProbsDevice; TensorPtr mOutputLogProbsDevice; @@ -312,14 +323,17 @@ class SamplingKernelTest : public testing::Test TensorPtr mSkipDecodeHost; TensorPtr mSkipDecodeDevice; + TensorPtr mTokensPerStep; + TensorPtr mBatchSlots; TensorPtr mExpectedCumLogProbsHost; + 
TensorPtr mCurandStatesDevice; + int32_t mMaxTopK; + static constexpr int32_t mMaxSeqLen = 2048; float mMaxTopP; - - curandState_t* mCurandStatesDevice; }; } // namespace tensorrt_llm::tests::kernels::sampling diff --git a/cpp/tests/kernels/sampling/samplingTopKTest.cpp b/cpp/tests/kernels/sampling/samplingTopKTest.cpp index 54fac9fb9..4cfc41317 100644 --- a/cpp/tests/kernels/sampling/samplingTopKTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTopKTest.cpp @@ -17,6 +17,7 @@ #error "Define TOP_LEVEL_DIR" #endif +#include "tensorrt_llm/common/tllmException.h" #include "tests/kernels/sampling/samplingTest.h" #include @@ -35,44 +36,41 @@ class TopKSamplingKernelTest : public SamplingKernelTest { protected: - const int32_t endId = 0; + int32_t const endId = 0; using SamplingKernelTest::mSeed; using SamplingKernelTest::mStream; using SamplingKernelTest::mBufferManager; - size_t getWorkspaceSize(const SamplingKernelTestParam& params) override + size_t getWorkspaceSize(SamplingKernelTestParam const& params) override { - auto const maxBatchSize = 2 * params.batchSize; - size_t workspaceSize; - tk::invokeTopKSampling(nullptr, workspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, this->mMaxTopK, 1.0f, params.vocabSize, nullptr, nullptr, this->mStream->get(), params.batchSize, - maxBatchSize, nullptr, true, false); - return workspaceSize; + return tk::getTopKWorkspaceSize(params.batchSize, params.maxTokensPerStep, this->mMaxTopK, params.vocabSize); } - void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize, - tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override + void callTestedFunction( + SamplingKernelTestParam const& params, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { auto const maxBatchSize = 2 * params.batchSize; // Perform batched TopK sampling - tk::invokeBatchTopKSampling(workspaceDevice->data(), workspaceSize, + tk::invokeBatchTopKSampling(workspaceDevice->data(), // Note that the kernel needs vocab probs instead of // log-prob if cum_log_probs or output_log_probs are // provided. It's because the sampling layer already // preprocesses log_prob_buf when those are provided. - bufferCast(*this->mProbsDevice), bufferCast(*this->mIdsPtrHost), - bufferCast(*this->mSeqLengthsDevice), + params.useLogitsPtrs ? nullptr : bufferCast(*this->mProbsDevice), + params.useLogitsPtrs ? reinterpret_cast(bufferCast(*this->mProbsPtrsDevice)) + : nullptr, + bufferCast(*this->mIdsPtrHost), bufferCast(*this->mSeqLengthsDevice), reinterpret_cast( bufferCast(*this->mFinishedDevice)), reinterpret_cast( bufferCast(*this->mFinishedDevice)), bufferCast(*this->mCumLogProbsDevice), bufferCast(*this->mOutputLogProbsDevice), - this->mCurandStatesDevice, this->mMaxTopK, - hasDiffRuntimeArgs ? bufferCast(*this->mTopKsDevice) : nullptr, params.topP, - hasDiffRuntimeArgs ? 
bufferCast(*this->mTopPsDevice) : nullptr, params.vocabSize, - bufferCast(*this->mEndIdsDevice), bufferCast(*this->mBatchSlots), this->mStream->get(), - params.batchSize, maxBatchSize, bufferCast(*this->mSkipDecodeDevice), params.normalizeLogProbs, - params.logitsHasProbs); + reinterpret_cast(bufferCast(*this->mCurandStatesDevice)), this->mMaxTopK, + bufferCast(*this->mTopKsDevice), params.topP, bufferCast(*this->mTopPsDevice), + params.vocabSize, bufferCast(*this->mEndIdsDevice), bufferCast(*this->mBatchSlots), + this->mStream->get(), params.batchSize, maxBatchSize, bufferCast(*this->mTokensPerStep), + params.maxTokensPerStep, bufferCast(*this->mSkipDecodeDevice), params.normalizeLogProbs, + params.logitsHasProbs, params.returnAllTopK); } }; @@ -80,44 +78,66 @@ TYPED_TEST_SUITE(TopKSamplingKernelTest, FloatAndHalfTypes); TYPED_TEST(TopKSamplingKernelTest, CorrectnessGreedy) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(1).setTopP(1.0f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(1).setTopP(1.0f)); }; TYPED_TEST(TopKSamplingKernelTest, CorrectnessGreedyLarge) { - this->runTest( - SamplingKernelTestParam().setBatchSize(16).setVocabSize(51200).setTopK(1).setTopP(1.0f).setOutputLen(8)); + this->runTest(SamplingKernelTestParam().setBatchSize(16).setVocabSize(51200).setTopK(1).setTopP(1.0f)); }; TYPED_TEST(TopKSamplingKernelTest, CorrectnessAncestral) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(4).setTopP(1.0f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(4).setTopP(1.0f)); }; TYPED_TEST(TopKSamplingKernelTest, CorrectnessLargeK63) { - this->runTest( - SamplingKernelTestParam().setBatchSize(16).setVocabSize(51200).setTopK(63).setTopP(1.0f).setOutputLen(8)); + this->runTest(SamplingKernelTestParam().setBatchSize(16).setVocabSize(51200).setTopK(63).setTopP(1.0f)); }; TYPED_TEST(TopKSamplingKernelTest, CorrectnessLargeK1024) { - this->runTest( - SamplingKernelTestParam().setBatchSize(16).setVocabSize(51200).setTopK(1024).setTopP(1.0f).setOutputLen(8)); + this->runTest(SamplingKernelTestParam().setBatchSize(16).setVocabSize(51200).setTopK(1024).setTopP(1.0f)); }; TYPED_TEST(TopKSamplingKernelTest, CorrectnessTopKTopP) { - this->runTest( - SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(63).setTopP(0.3f).setOutputLen(8)); + this->runTest(SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(63).setTopP(0.3f)); }; TYPED_TEST(TopKSamplingKernelTest, NotSupportedLargerThanK1024) { EXPECT_THROW( - this->runTest( - SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(1025).setTopP(1.0f).setOutputLen(8)), - std::domain_error); + this->runTest(SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(1025).setTopP(1.0f)), + tensorrt_llm::common::TllmException); +}; + +TYPED_TEST(TopKSamplingKernelTest, CorrectnessTopKMaxTokensPerStep) +{ + this->runTest( + SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(63).setTopP(1.0f).setMaxTokensPerStep(4)); }; +TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllTopK) +{ + this->runTest(SamplingKernelTestParam() + .setBatchSize(16) + .setVocabSize(50) + .setTopK(10) + .setTopP(1.0f) + .setMaxTokensPerStep(4) + .setReturnAllTopK()); +}; + +TYPED_TEST(TopKSamplingKernelTest, CorrectnessLogitsPtrs) +{ + this->runTest(SamplingKernelTestParam() + .setBatchSize(16) + .setVocabSize(50) + 
.setTopK(10) + .setTopP(1.0f) + .setMaxTokensPerStep(4) + .setUseLogitsPtrs()); +}; } // end of namespace diff --git a/cpp/tests/kernels/sampling/samplingTopPTest.cpp b/cpp/tests/kernels/sampling/samplingTopPTest.cpp index 2e7b8c555..d59efbb32 100644 --- a/cpp/tests/kernels/sampling/samplingTopPTest.cpp +++ b/cpp/tests/kernels/sampling/samplingTopPTest.cpp @@ -41,45 +41,15 @@ class TopPSamplingKernelTest : public SamplingKernelTest using SamplingKernelTest::mBufferManager; private: - size_t getWorkspaceSize(const SamplingKernelTestParam& params) override + size_t getWorkspaceSize(SamplingKernelTestParam const& params) override { - auto const maxBatchSize = 2 * params.batchSize; - size_t workspaceSize; - size_t cubTempStorageSize; - tk::invokeBatchTopPSampling(nullptr, // workspace - workspaceSize, cubTempStorageSize, - nullptr, // output_ids - nullptr, // sequence_length - nullptr, // finished_buffer - nullptr, // finished_buffer - nullptr, // cum_log_probs - nullptr, // output_log_probs - nullptr, // log_probs - bufferCast(*this->mTopPIdValsDevice), bufferCast(*this->mEndOffsetsDevice), - bufferCast(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize, - params.vocabSize, nullptr, this->mMaxTopP, bufferCast(*this->mTopPsDevice), this->mStream->get(), - nullptr, nullptr); - return workspaceSize; + return tensorrt_llm::kernels::getTopPWorkspaceSize(params.batchSize, params.vocabSize); } - void callTestedFunction(const SamplingKernelTestParam& params, bool hasDiffRuntimeArgs, size_t workspaceSize, - tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override + void callTestedFunction( + SamplingKernelTestParam const& params, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { auto const maxBatchSize = 2 * params.batchSize; - size_t cubTempStorageSize; - tk::invokeBatchTopPSampling(nullptr, // workspace - workspaceSize, cubTempStorageSize, - nullptr, // output_ids - nullptr, // sequence_length - nullptr, // finished_buffer - nullptr, // finished_buffer - nullptr, // cum_log_probs - nullptr, // output_log_probs - nullptr, // log_probs - bufferCast(*this->mTopPIdValsDevice), bufferCast(*this->mEndOffsetsDevice), - bufferCast(*this->mBeginOffsetsDevice), this->mCurandStatesDevice, params.batchSize, maxBatchSize, - params.vocabSize, nullptr, this->mMaxTopP, bufferCast(*this->mTopPsDevice), this->mStream->get(), - nullptr, nullptr); // Perform batched TopP sampling tk::invokeTopPInitialize(bufferCast(*this->mTopPIdValsDevice), @@ -87,8 +57,8 @@ class TopPSamplingKernelTest : public SamplingKernelTest params.batchSize, params.vocabSize, this->mStream->get()); // Perform batched TopP sampling - tk::invokeBatchTopPSampling(workspaceDevice->data(), workspaceSize, cubTempStorageSize, - bufferCast(*this->mIdsPtrHost), bufferCast(*this->mSeqLengthsDevice), + tk::invokeBatchTopPSampling(workspaceDevice->data(), bufferCast(*this->mIdsPtrHost), + bufferCast(*this->mSeqLengthsDevice), reinterpret_cast( bufferCast(*this->mFinishedDevice)), reinterpret_cast( @@ -100,10 +70,10 @@ class TopPSamplingKernelTest : public SamplingKernelTest // preprocesses log_prob_buf when those are provided. bufferCast(*this->mProbsDevice), bufferCast(*this->mTopPIdValsDevice), bufferCast(*this->mEndOffsetsDevice), bufferCast(*this->mBeginOffsetsDevice), - this->mCurandStatesDevice, params.batchSize, maxBatchSize, params.vocabSize, - bufferCast(*this->mEndIdsDevice), this->mMaxTopP, - hasDiffRuntimeArgs ? 
bufferCast(*this->mTopPsDevice) : nullptr, this->mStream->get(), - bufferCast(*this->mSkipDecodeDevice), bufferCast(*this->mBatchSlots)); + reinterpret_cast(bufferCast(*this->mCurandStatesDevice)), params.batchSize, + maxBatchSize, params.vocabSize, bufferCast(*this->mEndIdsDevice), this->mMaxTopP, + bufferCast(*this->mTopPsDevice), this->mStream->get(), bufferCast(*this->mSkipDecodeDevice), + bufferCast(*this->mBatchSlots)); } }; @@ -111,29 +81,27 @@ TYPED_TEST_SUITE(TopPSamplingKernelTest, FloatAndHalfTypes); TYPED_TEST(TopPSamplingKernelTest, CorrectnessSmallP) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f)); }; TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeP) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f)); }; TYPED_TEST(TopPSamplingKernelTest, CorrectnessAncestral) { - this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f).setOutputLen(1)); + this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f)); }; TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabSmallP) { - this->runTest( - SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f).setOutputLen(16)); + this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f)); }; TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabLargeP) { - this->runTest( - SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f).setOutputLen(16)); + this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f)); }; class TopPSamplingKernelUtilsTest : public SamplingKernelTest @@ -145,25 +113,25 @@ TEST_F(TopPSamplingKernelUtilsTest, invokeTopPInitialize) const int32_t batchSize = 8; const int32_t vocabSize = 256; - const auto topPIdValsDevice + auto const topPIdValsDevice = this->mBufferManager->gpu(ITensor::makeShape({batchSize, vocabSize}), nvinfer1::DataType::kINT32); - const auto beginOffsetsDevice + auto const beginOffsetsDevice = this->mBufferManager->gpu(ITensor::makeShape({batchSize + 1}), nvinfer1::DataType::kINT32); - const auto endOffsetsDevice + auto const endOffsetsDevice = this->mBufferManager->gpu(ITensor::makeShape({batchSize + 1}), nvinfer1::DataType::kINT32); tk::invokeTopPInitialize(bufferCast(*topPIdValsDevice), bufferCast(*endOffsetsDevice), bufferCast(*beginOffsetsDevice), batchSize, vocabSize, this->mStream->get()); - const auto topPIdValsHost = this->mBufferManager->copyFrom(*topPIdValsDevice, MemoryType::kCPU); - const auto endOffsetsHost = this->mBufferManager->copyFrom(*endOffsetsDevice, MemoryType::kCPU); - const auto beginOffsetsHost = this->mBufferManager->copyFrom(*beginOffsetsDevice, MemoryType::kCPU); + auto const topPIdValsHost = this->mBufferManager->copyFrom(*topPIdValsDevice, MemoryType::kCPU); + auto const endOffsetsHost = this->mBufferManager->copyFrom(*endOffsetsDevice, MemoryType::kCPU); + auto const beginOffsetsHost = this->mBufferManager->copyFrom(*beginOffsetsDevice, MemoryType::kCPU); this->mStream->synchronize(); - const auto topPIdValsHostPtr = bufferCast(*topPIdValsHost); - const auto endOffsetsHostPtr = bufferCast(*endOffsetsHost); - const auto beginOffsetsHostPtr = 
bufferCast(*beginOffsetsHost); + auto const topPIdValsHostPtr = bufferCast(*topPIdValsHost); + auto const endOffsetsHostPtr = bufferCast(*endOffsetsHost); + auto const beginOffsetsHostPtr = bufferCast(*beginOffsetsHost); for (int32_t bi = 0; bi < batchSize + 1; ++bi) { diff --git a/cpp/tests/kernels/sampling/samplingUtilsTest.cu b/cpp/tests/kernels/sampling/samplingUtilsTest.cu index d15d753ef..04f42207e 100644 --- a/cpp/tests/kernels/sampling/samplingUtilsTest.cu +++ b/cpp/tests/kernels/sampling/samplingUtilsTest.cu @@ -31,7 +31,7 @@ namespace static float constexpr HALF_FLT_MAX = 65504.F; -__global__ void generateRandomNumber(int32_t* vals, curandState_t* states, const int batch_size) +__global__ void generateRandomNumber(int32_t* vals, curandState_t* states, int const batch_size) { int idx = threadIdx.x; if (idx < batch_size) @@ -196,17 +196,16 @@ public: int32_t const vocabSize = 51000; int32_t const vocabSizePadded = tc::divUp(vocabSize, 256) * 256; - auto logitsHost - = this->mBufferManager->pinned(ITensor::makeShape({batchSize, beamWidth, vocabSizePadded}), dataType); - auto logitsHostPtrs = this->mBufferManager->pinned(ITensor::makeShape({batchSize}), ptrType); - auto refLogitsHost = this->mBufferManager->pinned( + auto logitsHost = BufferManager::pinned(ITensor::makeShape({batchSize, beamWidth, vocabSizePadded}), dataType); + auto logitsHostPtrs = BufferManager::pinned(ITensor::makeShape({batchSize}), ptrType); + auto refLogitsHost = BufferManager::pinned( ITensor::makeShape({batchSize, beamWidth, vocabSizePadded}), nvinfer1::DataType::kFLOAT); - auto biasHost = this->mBufferManager->pinned(ITensor::makeShape({vocabSize}), dataType); - auto endIdsHost = this->mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - auto finishedHost = this->mBufferManager->pinned( + auto biasHost = BufferManager::pinned(ITensor::makeShape({vocabSize}), dataType); + auto endIdsHost = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + auto finishedHost = BufferManager::pinned( ITensor::makeShape({beamWidth, maxBatchSize}), TRTDataType::value); - auto batchSlots = this->mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); + auto batchSlots = BufferManager::pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); auto batchSlotsPtr = bufferCast(*batchSlots); for (SizeType bi = 0; bi < batchSize; ++bi) diff --git a/cpp/tests/kernels/shiftKCacheKernelTest.cu b/cpp/tests/kernels/shiftKCacheKernelTest.cu index bba4abea8..d502e2970 100644 --- a/cpp/tests/kernels/shiftKCacheKernelTest.cu +++ b/cpp/tests/kernels/shiftKCacheKernelTest.cu @@ -44,11 +44,11 @@ struct SATypeConverter }; template -__global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, const int sizePerHead, - const int beam_width, const int* token_read_idxs, const int* token_write_idxs, const int* token_pos_idxs, - const int* token_seq_idxs, const int* sequence_lengths, const int* input_lengths, const int rotary_embedding_dim, +__global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, int const sizePerHead, + int const beam_width, int const* token_read_idxs, int const* token_write_idxs, int const* token_pos_idxs, + int const* token_seq_idxs, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim, float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, - const int rotary_embedding_max_positions, 
PositionEmbeddingType const position_embedding_type) + int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type) { // We allow only fp32/fp16/bf16 as the data types to apply rotary static_assert(sizeof(T) == 4 || sizeof(T) == 2, ""); @@ -57,29 +57,29 @@ __global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, // Each thread will handle 16 bytes. constexpr int vec_size = 16u / sizeof(T); using Vec_k = typename mmha::packed_type::type; - const int sizePerHeadDivX = sizePerHead / vec_size; + int const sizePerHeadDivX = sizePerHead / vec_size; // The position idx - const int token_idx = token_seq_idxs[blockIdx.x]; - const int token_read_idx = token_read_idxs[blockIdx.x]; - const int token_write_idx = token_write_idxs[blockIdx.x]; - const int token_pos_idx = token_pos_idxs[blockIdx.x]; + int const token_idx = token_seq_idxs[blockIdx.x]; + int const token_read_idx = token_read_idxs[blockIdx.x]; + int const token_write_idx = token_write_idxs[blockIdx.x]; + int const token_pos_idx = token_pos_idxs[blockIdx.x]; // Head - const int head_idx = blockIdx.y; + int const head_idx = blockIdx.y; // The batch beam idx - const int batch_beam_idx = blockIdx.z; + int const batch_beam_idx = blockIdx.z; // The beam idx - const int beam_idx = batch_beam_idx % beam_width; + int const beam_idx = batch_beam_idx % beam_width; // Thread idx - const int tidx = threadIdx.x; + int const tidx = threadIdx.x; // The actual sequence length excluding the paddings. - const int tlength = sequence_lengths[batch_beam_idx] - 1; + int const tlength = sequence_lengths[batch_beam_idx] - 1; // The context length - const int inlength = input_lengths[batch_beam_idx]; + int const inlength = input_lengths[batch_beam_idx]; // Mask out the tokens exceed the real total length and tokens in the context phase with beam_idx>0 - const bool valid_seq = token_idx < tlength && !(token_idx < inlength && beam_idx > 0); - const bool is_head_size_masked = tidx * vec_size >= sizePerHead; + bool const valid_seq = token_idx < tlength && !(token_idx < inlength && beam_idx > 0); + bool const is_head_size_masked = tidx * vec_size >= sizePerHead; if (!valid_seq || is_head_size_masked) { @@ -90,7 +90,7 @@ __global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, Vec_k k; T* k_cache = reinterpret_cast(kCacheRead.getKBlockPtr(batch_beam_idx, token_read_idx)); int inBlockIdx_r = kCacheRead.getKVLocalIdx(token_read_idx, head_idx, sizePerHead, tidx * vec_size); - k = *reinterpret_cast(&k_cache[inBlockIdx_r]); + k = *reinterpret_cast(&k_cache[inBlockIdx_r]); // Apply position embedding switch (position_embedding_type) @@ -103,14 +103,14 @@ __global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, } case PositionEmbeddingType::kROPE_GPT_NEOX: { - const bool do_rotary = vec_size * tidx < rotary_embedding_dim; + bool const do_rotary = vec_size * tidx < rotary_embedding_dim; T* k_smem = reinterpret_cast(smem_); - const int half_rotary_dim = rotary_embedding_dim / 2; - const int half_idx = (tidx * vec_size) / half_rotary_dim; - const int intra_half_idx = (tidx * vec_size) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? + int const half_rotary_dim = rotary_embedding_dim / 2; + int const half_idx = (tidx * vec_size) / half_rotary_dim; + int const intra_half_idx = (tidx * vec_size) % half_rotary_dim; + int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts? 
if (do_rotary) { @@ -119,7 +119,7 @@ __global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, __syncthreads(); - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + int const transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; constexpr int tidx_factor = vec_size / 2; if (do_rotary) { @@ -146,15 +146,15 @@ __global__ void applyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, } template -void invokeApplyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, const int sizePerHead, const int batch_beam, - const int kv_head_num, const int beam_width, const int* token_read_idxs, const int* token_write_idxs, - const int* token_pos_idxs, const int* token_seq_idxs, const int token_num, const int* sequence_lengths, - const int* input_lengths, const int rotary_embedding_dim, float rotary_embedding_base, - RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, const int rotary_embedding_max_positions, +void invokeApplyRoPE(KVCacheBuffer kCacheRead, KVLinearBuffer kCacheWrite, int const sizePerHead, int const batch_beam, + int const kv_head_num, int const beam_width, int const* token_read_idxs, int const* token_write_idxs, + int const* token_pos_idxs, int const* token_seq_idxs, int const token_num, int const* sequence_lengths, + int const* input_lengths, int const rotary_embedding_dim, float rotary_embedding_base, + RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, cudaStream_t stream) { // Block handles K tile. - const int vec_size = 16u / sizeof(T); + int const vec_size = 16u / sizeof(T); dim3 block((sizePerHead / vec_size + 31) / 32 * 32); dim3 grid(token_num, kv_head_num, batch_beam); size_t smem_size @@ -186,10 +186,10 @@ public: void TearDown() override {} void initData(int32_t batchSize, int32_t beamWidth, int32_t numHeads, int32_t maxAttentionWindow, int32_t headSize, - bool pagedKvCache, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, const std::vector& seqLengths, - const std::vector& inputLengths, const std::vector& tokenReadIdxs, - const std::vector& tokenWriteIdxs, const std::vector& tokenPosIdxs, - const std::vector& tokenSeqIdxs) + bool pagedKvCache, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, std::vector const& seqLengths, + std::vector const& inputLengths, std::vector const& tokenReadIdxs, + std::vector const& tokenWriteIdxs, std::vector const& tokenPosIdxs, + std::vector const& tokenSeqIdxs) { // allocate buffer mSeqLengthsHost = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); @@ -303,8 +303,8 @@ public: } float compareResults(KVLinearBuffer kCacheOut, KVLinearBuffer kCacheRef, int32_t batchBeam, int32_t beamWidth, - int32_t numHeads, int32_t headSize, int32_t validTokenNum, const int32_t* seqLengths, - const int32_t* inputLengths, const int32_t* tokenWriteIdxs, const int32_t* tokenSeqIdxs) + int32_t numHeads, int32_t headSize, int32_t validTokenNum, int32_t const* seqLengths, + int32_t const* inputLengths, int32_t const* tokenWriteIdxs, int32_t const* tokenSeqIdxs) { mBufferManager->copy(*mOutputDataDevice, *mOutputDataHost); mBufferManager->copy(*mRefOutputDataDevice, *mRefOutputDataHost); @@ -316,16 +316,16 @@ public: float tot_diff = 0.f; for (SizeType bi = 0; bi < batchBeam; ++bi) { - const int tlength = seqLengths[bi] - 1; - const int inlength = inputLengths[bi]; - const int beam_idx = bi % beamWidth; + int const 
tlength = seqLengths[bi] - 1; + int const inlength = inputLengths[bi]; + int const beam_idx = bi % beamWidth; for (SizeType hi = 0; hi < numHeads; ++hi) { for (SizeType ti = 0; ti < validTokenNum; ++ti) { - const int token_seq_idx = tokenSeqIdxs[ti]; - const int token_write_idx = tokenWriteIdxs[ti]; - const bool valid_seq = token_seq_idx < tlength && !(token_seq_idx < inlength && beam_idx > 0); + int const token_seq_idx = tokenSeqIdxs[ti]; + int const token_write_idx = tokenWriteIdxs[ti]; + bool const valid_seq = token_seq_idx < tlength && !(token_seq_idx < inlength && beam_idx > 0); if (!valid_seq) { continue; @@ -350,7 +350,7 @@ public: void runTest(int32_t batchSize, int32_t beamWidth, int32_t numHeads, int32_t headSize, int32_t maxAttentionWindow, int32_t sinkTokenLength, int32_t pastKCacheLength, int32_t validTokenNum, bool pagedKvCache, int32_t maxBlocksPerSeq, int32_t tokensPerBlock, int rotaryEmbeddingDim, float rotaryEmbeddingBase, - RotaryScalingType const rotaryScaleType, float rotaryEmbeddingScale, const int rotaryEmbeddingMaxPositions, + RotaryScalingType const rotaryScaleType, float rotaryEmbeddingScale, int const rotaryEmbeddingMaxPositions, PositionEmbeddingType const positionEmbeddingType) { // Synchronize @@ -358,7 +358,7 @@ public: // get kv cache const int32_t batchBeam = batchSize * beamWidth; - const auto elemSize = sizeof(T); + auto const elemSize = sizeof(T); KVLinearBuffer shiftKCacheBuffer = KVLinearBuffer(batchBeam, 1, maxAttentionWindow, numHeads * headSize * elemSize, maxAttentionWindow, sinkTokenLength, true); @@ -587,7 +587,7 @@ TYPED_TEST(ShiftKCacheKernelTest, CyclicShiftKCacheSink) std::vector tokenPosIdxs = {0, 1, 2, 3}; std::vector tokenSeqIdxs = {0, 1, 2, 3}; - const int cyclicLength = maxAttentionWindow - sinkTokenLength; + int const cyclicLength = maxAttentionWindow - sinkTokenLength; for (SizeType idx = pastKCacheLength - cyclicLength; idx < pastKCacheLength; ++idx) { tokenReadIdxs.push_back(sinkTokenLength + bubbleLength + (idx - sinkTokenLength) % cyclicLength); diff --git a/cpp/tests/kernels/stopCriteriaKernelsTest.cpp b/cpp/tests/kernels/stopCriteriaKernelsTest.cpp index 57a31357b..3a66fa444 100644 --- a/cpp/tests/kernels/stopCriteriaKernelsTest.cpp +++ b/cpp/tests/kernels/stopCriteriaKernelsTest.cpp @@ -57,31 +57,31 @@ class StopCriteriaKernelsTest : public testing::Test std::uniform_int_distribution seqLenDistr(0, mMaxSeqLen); mSequenceLengths - = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32); - mSequenceLengthLimits = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mFinished = mBufferManager->pinned( + = BufferManager::pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32); + mSequenceLengthLimits = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mFinished = BufferManager::pinned( ITensor::makeShape({maxBatchSize, beamWidth}), TRTDataType::value); - mFinishedSum = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mFinishedSum = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mOutputIds = mBufferManager->pinned( + mOutputIds = BufferManager::pinned( ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); mOutputIdsPtr - = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64); + = 
BufferManager::pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64); - mParentIds = mBufferManager->pinned( + mParentIds = BufferManager::pinned( ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); mParentIdsPtr - = mBufferManager->pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64); + = BufferManager::pinned(ITensor::makeShape({maxBatchSize, beamWidth}), nvinfer1::DataType::kINT64); - mRefOutputIds = mBufferManager->pinned( + mRefOutputIds = BufferManager::pinned( ITensor::makeShape({maxBatchSize, beamWidth, mMaxSeqLen}), nvinfer1::DataType::kINT32); - mStopWords = mBufferManager->pinned( - ITensor::makeShape({maxBatchSize, 2, maxStopWordsLen}), nvinfer1::DataType::kINT32); - mStopWordsPtr = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT64); - mStopWordsLen = mBufferManager->pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); + mStopWords + = BufferManager::pinned(ITensor::makeShape({maxBatchSize, 2, maxStopWordsLen}), nvinfer1::DataType::kINT32); + mStopWordsPtr = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT64); + mStopWordsLen = BufferManager::pinned(ITensor::makeShape({maxBatchSize}), nvinfer1::DataType::kINT32); - mBatchSlots = mBufferManager->pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); + mBatchSlots = BufferManager::pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32); auto batchSlotsPtr = bufferCast(*mBatchSlots); for (SizeType bi = 0; bi < batchSize; ++bi) @@ -229,7 +229,7 @@ class StopCriteriaKernelsTest : public testing::Test } } - bool isSubsequence(const SizeType* sequence, SizeType n, const std::vector& subsequence) + bool isSubsequence(SizeType const* sequence, SizeType n, std::vector const& subsequence) { auto it = std::search(sequence, sequence + n, subsequence.begin(), subsequence.end()); return it != sequence + n; @@ -278,10 +278,10 @@ class StopCriteriaKernelsTest : public testing::Test std::vector>> const& stopWords, SizeType batchSize, SizeType beamWidth) { SizeType maxStopWordsLen = 0; - for (const auto& batchStopWords : stopWords) + for (auto const& batchStopWords : stopWords) { SizeType stopWordsLen = 0; - for (const auto& words : batchStopWords) + for (auto const& words : batchStopWords) { stopWordsLen += words.size(); } @@ -311,7 +311,7 @@ class StopCriteriaKernelsTest : public testing::Test tk::invokeLengthCriterion( reinterpret_cast(bufferCast(*mFinished)), bufferCast(*mFinishedSum), - reinterpret_cast(bufferCast(*mSequenceLengthLimits)), + reinterpret_cast(bufferCast(*mSequenceLengthLimits)), bufferCast(*mSequenceLengths), bufferCast(*mBatchSlots), batchSize, beamWidth, mStream->get()); diff --git a/cpp/tests/layers/baseSamplingLayerTest.cpp b/cpp/tests/layers/baseSamplingLayerTest.cpp index e49a21c05..9b3fbbc6e 100644 --- a/cpp/tests/layers/baseSamplingLayerTest.cpp +++ b/cpp/tests/layers/baseSamplingLayerTest.cpp @@ -167,7 +167,7 @@ bool BaseSamplingLayerTest::checkResult(int32_t* outputIds, std::vector expts = expectedIds.at(i); - const auto outputId = outputIds[batchSlot * mMaxSeqLen + s]; + auto const outputId = outputIds[batchSlot * mMaxSeqLen + s]; if (expts.count(outputId) == 0) { if (failures < 10) @@ -214,7 +214,7 @@ void BaseSamplingLayerTest::runTest( mStream->synchronize(); } - const auto outputIdsHost = mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU); + auto const outputIdsHost = 
mBufferManager->copyFrom(*mOutputIdsDevice, tensorrt_llm::runtime::MemoryType::kCPU); mStream->synchronize(); diff --git a/cpp/tests/layers/baseSamplingLayerTest.h b/cpp/tests/layers/baseSamplingLayerTest.h index a9adfce79..82c4807ba 100644 --- a/cpp/tests/layers/baseSamplingLayerTest.h +++ b/cpp/tests/layers/baseSamplingLayerTest.h @@ -43,7 +43,7 @@ namespace tensorrt_llm::tests::layers::sampling constexpr float EPSILON = 1e-20f; template -void computeProb(T* probs, const T* logits, int batchSize, int vocabSize) +void computeProb(T* probs, T const* logits, int batchSize, int vocabSize) { // Compute the log probability from logits. // logits = batchSize x vocabSize. diff --git a/cpp/tests/layers/dynamicDecodeLayerTest.cpp b/cpp/tests/layers/dynamicDecodeLayerTest.cpp index e06eb9601..f62570476 100644 --- a/cpp/tests/layers/dynamicDecodeLayerTest.cpp +++ b/cpp/tests/layers/dynamicDecodeLayerTest.cpp @@ -260,10 +260,10 @@ template SizeType DynamicDecodeLayerTest::getMaxWordsLen(std::vector>> const& inputWords) { SizeType maxWordsLen = 0; - for (const auto& batchWords : inputWords) + for (auto const& batchWords : inputWords) { SizeType wordsLen = 0; - for (const auto& words : batchWords) + for (auto const& words : batchWords) { wordsLen += words.size(); } @@ -379,7 +379,7 @@ typename DynamicDecodeLayer::OutputParams DynamicDecodeLayerTest::createOu template void DynamicDecodeLayerTest::batchCopy(int32_t step) { - const auto logitsHost = ITensor::wrap(mTestLogitsInit.data() + step * mVocabSizePadded, + auto const logitsHost = ITensor::wrap(mTestLogitsInit.data() + step * mVocabSizePadded, std::is_same_v ? nvinfer1::DataType::kFLOAT : nvinfer1::DataType::kHALF, ITensor::makeShape({1, mVocabSizePadded})); for (int32_t bi = 0; bi < mBatchSize; ++bi) @@ -408,7 +408,7 @@ bool DynamicDecodeLayerTest::checkResult(int32_t* outputIds, std::vector expts = expectedIds.at(i + step * stride); auto bid = batchSlot; - const auto outputId = outputIds[bid * leadingDim + s]; + auto const outputId = outputIds[bid * leadingDim + s]; if (expts.count(outputId) == 0) { if (failures < 10) diff --git a/cpp/tests/resources/scripts/generate_expected_medusa_output.py b/cpp/tests/resources/scripts/generate_expected_medusa_output.py index d059b6c9c..2bf964ab9 100755 --- a/cpp/tests/resources/scripts/generate_expected_medusa_output.py +++ b/cpp/tests/resources/scripts/generate_expected_medusa_output.py @@ -53,7 +53,8 @@ def generate_output(engine: str, output_name: str, max_output_len: int = 8): def generate_outputs(): print(f'Generating outputs for Medusa FP16') generate_output(engine='fp16-plugin-packed-paged', - output_name='output_tokens_fp16_plugin_packed_paged') + output_name='output_tokens_long_fp16_plugin_packed_paged', + max_output_len=128) if __name__ == '__main__': diff --git a/cpp/tests/resources/scripts/test_cpp.py b/cpp/tests/resources/scripts/test_cpp.py index 3705da9c1..b4530e19f 100755 --- a/cpp/tests/resources/scripts/test_cpp.py +++ b/cpp/tests/resources/scripts/test_cpp.py @@ -389,8 +389,7 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, benchmark_exe_dir = build_dir / "benchmarks" gpt_engine_dir = resources_dir / "models" / "rt_engine" / "gpt2" benchmark = [ - str(benchmark_exe_dir / "gptSessionBenchmark"), "--model", "gpt", - "--engine_dir", + str(benchmark_exe_dir / "gptSessionBenchmark"), "--engine_dir", str(gpt_engine_dir / "fp16-plugin" / "tp1-pp1-gpu"), "--batch_size", "8", "--input_output_len", "10,20", "--duration", "10" ] @@ -424,8 +423,7 @@ def 
run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, run_command(prepare_dataset, cwd=root_dir, timeout=300) benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), "--model", "gpt", - "--engine_dir", + str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", str(gpt_engine_dir / "fp16-plugin-packed-paged" / "tp1-pp1-gpu"), "--type", "IFB", "--dataset", str(data_dir / tokens_f) @@ -433,8 +431,7 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, run_command(benchmark, cwd=root_dir, timeout=600) benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), "--model", "gpt", - "--engine_dir", + str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", str(gpt_engine_dir / "fp16-plugin-packed-paged" / "tp1-pp1-gpu"), "--type", "V1", "--dataset", str(data_dir / tokens_f) @@ -442,8 +439,7 @@ def run_benchmarks(python_exe: str, root_dir: _pl.Path, build_dir: _pl.Path, run_command(benchmark, cwd=root_dir, timeout=600) benchmark = [ - str(benchmark_exe_dir / "gptManagerBenchmark"), "--model", "gpt", - "--engine_dir", + str(benchmark_exe_dir / "gptManagerBenchmark"), "--engine_dir", str(gpt_engine_dir / "fp16-plugin-packed-paged" / "tp1-pp1-gpu"), "--type", "IFB", "--static_emulated_batch_size", "50", "--dataset", str(data_dir / "prepared_dummy_cnn.json") diff --git a/cpp/tests/runtime/gptDecoderBatchTest.cpp b/cpp/tests/runtime/gptDecoderBatchTest.cpp index 9f62f7b25..d7f50c66c 100644 --- a/cpp/tests/runtime/gptDecoderBatchTest.cpp +++ b/cpp/tests/runtime/gptDecoderBatchTest.cpp @@ -102,6 +102,7 @@ std::vector prepareRequests(SizeType batchSize, SizeType std::vector draftTokens(generatedTokensPerSteps[batchIdx] - 1); std::fill(draftTokens.begin(), draftTokens.begin() + acceptedTokensPerStep[batchIdx], 1023); requests.back().draftTokens = manager.copyFrom(draftTokens, MemoryType::kGPU); + requests.back().generatedTokensPerStep = generatedTokensPerSteps[batchIdx]; } requests.back().computeCumLogProbs = computeLogProbs; requests.back().computeLogProbs = computeLogProbs; @@ -523,7 +524,7 @@ struct BeamConfig using ParamType = std::tuple; -std::string generateTestName(const testing::TestParamInfo& info) +std::string generateTestName(testing::TestParamInfo const& info) { std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"}; BeamConfig const beamConfig = std::get<1>(info.param); @@ -631,7 +632,7 @@ INSTANTIATE_TEST_SUITE_P(DecoderTest, ParamDraftTest, DraftConfig{4, {1, 2, 3}, {0, 0, 1}} )), - [](const testing::TestParamInfo& info) + [](testing::TestParamInfo const& info) { std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"}; BeamConfig const beamConfig = std::get<1>(info.param); diff --git a/cpp/tests/runtime/gptDecoderTest.cpp b/cpp/tests/runtime/gptDecoderTest.cpp index 967ab2122..f4a9bfcef 100644 --- a/cpp/tests/runtime/gptDecoderTest.cpp +++ b/cpp/tests/runtime/gptDecoderTest.cpp @@ -199,7 +199,7 @@ TEST_P(ParamTest, Test) INSTANTIATE_TEST_SUITE_P(DecoderTest, ParamTest, testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF), testing::Values(1, 3)), - [](const testing::TestParamInfo& info) + [](testing::TestParamInfo const& info) { std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? 
"Float" : "Half"}; auto const beamWidth = std::get<1>(info.param); diff --git a/cpp/tests/runtime/gptSessionTest.cpp b/cpp/tests/runtime/gptSessionTest.cpp index fa0446014..98f5d5dbb 100644 --- a/cpp/tests/runtime/gptSessionTest.cpp +++ b/cpp/tests/runtime/gptSessionTest.cpp @@ -199,7 +199,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model if (isChatGlmTest) { ASSERT_TRUE(fs::exists(DATA_PATH / modelName)); - const int batchSize = *batchSizes.begin(); + int const batchSize = *batchSizes.begin(); fileNameSuffix = std::string("-BS") + std::to_string(batchSize) + "-BM" + std::to_string(beamWidth) + std::string(".npy"); inputPath = DATA_PATH / modelName / (std::string("inputId") + fileNameSuffix); @@ -239,7 +239,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model auto const modelConfig = json.getModelConfig(); verifyModelConfig(modelConfig, modelSpec); - const int worldSize = modelSpec.mTPSize * modelSpec.mPPSize; + int const worldSize = modelSpec.mTPSize * modelSpec.mPPSize; auto const worldConfig = WorldConfig::mpi(worldSize, modelSpec.mTPSize, modelSpec.mPPSize); auto enginePath = modelPath / json.engineFilename(worldConfig); @@ -297,9 +297,9 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model std::srand(42); if (modelSpec.mRandomEndId) { - const auto endIdRow = std::rand() % nbGivenInputs; - const auto endIdBeam = std::rand() % beamWidth; - const auto endIdCol = givenInputLengths[endIdRow] + minLength + std::rand() % (maxNewTokens - minLength); + auto const endIdRow = std::rand() % nbGivenInputs; + auto const endIdBeam = std::rand() % beamWidth; + auto const endIdCol = givenInputLengths[endIdRow] + minLength + std::rand() % (maxNewTokens - minLength); auto const endIdIndex = tc::flat_index2((endIdRow * beamWidth + endIdBeam), endIdCol, maxSeqLength); endId = expectedOutputData[endIdIndex]; } @@ -357,7 +357,7 @@ void testGptSession(fs::path const& modelPath, ModelSpec const& modelSpec, Model std::vector inputLengthsHost(batchSize); for (SizeType i = 0; i < batchSize; ++i) { - const int inputIdx = i % nbGivenInputs; + int const inputIdx = i % nbGivenInputs; inputLengthsHost[i] = givenInputLengths[inputIdx]; } auto inputLengths = bufferManager.copyFrom(inputLengthsHost, ITensor::makeShape({batchSize}), MemoryType::kGPU); @@ -519,7 +519,7 @@ auto constexpr kBatchSizes = {1, 8}; using ParamType = std::tuple; -std::string generateTestName(const testing::TestParamInfo& info) +std::string generateTestName(testing::TestParamInfo const& info) { auto const modelSpec = std::get<1>(info.param); std::string name{modelSpec.mDataType == nvinfer1::DataType::kFLOAT ? 
"Float" : "Half"}; diff --git a/cpp/tests/runtime/runtimeKernelTest.cpp b/cpp/tests/runtime/runtimeKernelTest.cpp index ddc2c170d..559c01db9 100644 --- a/cpp/tests/runtime/runtimeKernelTest.cpp +++ b/cpp/tests/runtime/runtimeKernelTest.cpp @@ -938,7 +938,7 @@ void testCopyBatch(SizeType stride, BufferManager& manager, CudaStream& stream) { for (SizeType ci = 0; ci < stride; ++ci) { - const auto idx = row * stride + ci; + auto const idx = row * stride + ci; srcBufferHostPtr[idx] = idx; } } @@ -968,18 +968,18 @@ void testCopyBatch(SizeType stride, BufferManager& manager, CudaStream& stream) { if (idx < numIndices && ci < sizesPtr[idx]) { - const auto refIdx = srcOffsetsPtr[idx] + ci; - const auto ref = srcBufferHostPtr[refIdx]; + auto const refIdx = srcOffsetsPtr[idx] + ci; + auto const ref = srcBufferHostPtr[refIdx]; - const auto outIdx = dstOffsetsPtr[idx] + ci; - const auto out = dstBufferHostPtr[outIdx]; + auto const outIdx = dstOffsetsPtr[idx] + ci; + auto const out = dstBufferHostPtr[outIdx]; EXPECT_EQ(ref, out) << "Error at index row: " << idx << " column: " << ci << " for stride " << stride; } else { - const auto outIdx = idx * stride + ci; - const auto out = dstBufferHostPtr[outIdx]; + auto const outIdx = idx * stride + ci; + auto const out = dstBufferHostPtr[outIdx]; EXPECT_EQ(0, out) << "Error at index row: " << idx << " column: " << ci << " for stride " << stride; } diff --git a/cpp/tests/runtime/tllmBuffersTest.cpp b/cpp/tests/runtime/tllmBuffersTest.cpp index 5d8f49982..e9f4a7a5e 100644 --- a/cpp/tests/runtime/tllmBuffersTest.cpp +++ b/cpp/tests/runtime/tllmBuffersTest.cpp @@ -417,7 +417,7 @@ TEST_F(TllmBuffersTest, PinnedPoolAllocator) GTEST_SKIP(); using MemPool = MemoryPool; - auto expectedSize = [](const auto& tensor) + auto expectedSize = [](auto const& tensor) { auto s = tensor()->getSizeInBytes(); constexpr auto alignment = MemPool::kAlignment; diff --git a/cpp/tests/runtime/tllmRuntimeTest.cpp b/cpp/tests/runtime/tllmRuntimeTest.cpp index ed67e9bd7..863f87c75 100644 --- a/cpp/tests/runtime/tllmRuntimeTest.cpp +++ b/cpp/tests/runtime/tllmRuntimeTest.cpp @@ -51,7 +51,7 @@ std::unique_ptr buildMnistEngine(trt::ILogger& logger) { EXPECT_TRUE(fs::exists(MNIST_MODEL_PATH)); auto builder = makeUnique(trt::createInferBuilder(logger)); - const auto explicitBatch = 1U << static_cast(trt::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); + auto const explicitBatch = 1U << static_cast(trt::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); auto network = makeUnique(builder->createNetworkV2(explicitBatch)); auto parser = makeUnique(nvonnxparser::createParser(*network, logger)); auto const parsingSuccess = parser->parseFromFile( diff --git a/cpp/tests/runtime/transposeKVKernelTest.cpp b/cpp/tests/runtime/transposeKVKernelTest.cpp index 949365cf5..89a864773 100644 --- a/cpp/tests/runtime/transposeKVKernelTest.cpp +++ b/cpp/tests/runtime/transposeKVKernelTest.cpp @@ -48,19 +48,19 @@ void randomInitVector(std::vector& vec, float range) template void randomInitVector(std::vector& vec, float scale); template void randomInitVector(std::vector& vec, float scale); -std::vector pointerArrayFromPageTable(const std::unordered_map& pageTable, void* memoryPool, +std::vector pointerArrayFromPageTable(std::unordered_map const& pageTable, void* memoryPool, int32_t batchSize, int32_t blocksPerSeq, int32_t blockSizeInBytes, int32_t blocksPerPool) { - const auto pointerArrayElts = pageTable.size(); + auto const pointerArrayElts = pageTable.size(); std::vector pointers(2 * pointerArrayElts); for (int i = 
0; i < pointerArrayElts; ++i) { - const int pageIdx = pageTable.find(i)->second; + int const pageIdx = pageTable.find(i)->second; auto kPtr = reinterpret_cast(reinterpret_cast(memoryPool) + pageIdx * blockSizeInBytes); auto vPtr = reinterpret_cast( reinterpret_cast(memoryPool) + pageIdx * blockSizeInBytes + blocksPerPool * blockSizeInBytes); - const int batchIdx = i / batchSize; - const int seqIdx = i % blocksPerSeq; + int const batchIdx = i / batchSize; + int const seqIdx = i % blocksPerSeq; pointers[batchIdx * blocksPerSeq * 2 + 0 * blocksPerSeq + seqIdx] = kPtr; pointers[batchIdx * blocksPerSeq * 2 + 1 * blocksPerSeq + seqIdx] = vPtr; } @@ -76,8 +76,8 @@ T_DST castTo(T value) template <> int8_t castTo(float value) { - const auto clipped = std::min(127.f, std::max(value, -128.f)); - const auto rounded = std::round(clipped); + auto const clipped = std::min(127.f, std::max(value, -128.f)); + auto const rounded = std::round(clipped); return static_cast(rounded); } @@ -95,7 +95,7 @@ float castTo(__nv_fp8_e4m3 value) template void verifyKVTransposed(int batchSize, int headsNum, int dimsPerHead, int seqLen, int maxSeqLen, KVCacheBuffer& buffer, - const std::vector& refKCacheVec, const std::vector& vTransposedCacheVec, bool b8bitKVCache, + std::vector const& refKCacheVec, std::vector const& vTransposedCacheVec, bool b8bitKVCache, float kvScaleOrigQuant) { for (int bi = 0; bi < batchSize; ++bi) @@ -112,10 +112,10 @@ void verifyKVTransposed(int batchSize, int headsNum, int dimsPerHead, int seqLen for (int xi = 0; xi < X_ELEMS; ++xi) { - const int refKVIdx = bi * headsNum * seqLen * dimsPerHead + hi * seqLen * dimsPerHead + int const refKVIdx = bi * headsNum * seqLen * dimsPerHead + hi * seqLen * dimsPerHead + li * dimsPerHead + di * X_ELEMS + xi; - const int kVIdx = buffer.getKVLocalIdx(li, hi, dimsPerHead, di * X_ELEMS + xi); + int const kVIdx = buffer.getKVLocalIdx(li, hi, dimsPerHead, di * X_ELEMS + xi); T refK = refKCacheVec[refKVIdx]; T refV = vTransposedCacheVec[refKVIdx]; @@ -128,14 +128,14 @@ void verifyKVTransposed(int batchSize, int headsNum, int dimsPerHead, int seqLen const T_DST castedRefK = castTo(refK); const T_DST castedRefV = castTo(refV); - const auto outK = blockKPtr[kVIdx]; - const auto outV = blockVPtr[kVIdx]; + auto const outK = blockKPtr[kVIdx]; + auto const outV = blockVPtr[kVIdx]; // Since EXPECT_EQ does not support fp8, casting to float to compare - const float outK_float = castTo(outK); - const float outV_float = castTo(outV); - const float castedRefK_float = castTo(castedRefK); - const float castedRefV_float = castTo(castedRefV); + float const outK_float = castTo(outK); + float const outV_float = castTo(outV); + float const castedRefK_float = castTo(castedRefK); + float const castedRefV_float = castTo(castedRefV); EXPECT_EQ(outK_float, castedRefK_float); EXPECT_EQ(outV_float, castedRefV_float); } @@ -175,14 +175,14 @@ void testTransposeBatch4dPaged(bool multiQueryMode, bool int8KVCache, bool fp8KV maxAttentionWindow, sinkTokenLen, onlyKorV); // Allocate for pointer array - const auto pointerArrayElts = maxSeq * maxBlocksPerSeq; - const auto pointerArraySize = 2 * pointerArrayElts * sizeof(void*); + auto const pointerArrayElts = maxSeq * maxBlocksPerSeq; + auto const pointerArraySize = 2 * pointerArrayElts * sizeof(void*); cudaMalloc(&blockArray.data, pointerArraySize); cudaMemset(blockArray.data, 0, pointerArraySize); // Allocate for kv cache block pool - const auto blocksPerPool = maxBlocksPerSeq * maxSeq; - const auto kvPoolSize = 2 * blockSizeBytes * blocksPerPool; 
+ auto const blocksPerPool = maxBlocksPerSeq * maxSeq; + auto const kvPoolSize = 2 * blockSizeBytes * blocksPerPool; void* kvMemoryPool; cudaMalloc(&kvMemoryPool, kvPoolSize); cudaMemset(kvMemoryPool, 0, kvPoolSize); @@ -206,7 +206,7 @@ void testTransposeBatch4dPaged(bool multiQueryMode, bool int8KVCache, bool fp8KV } // Init array of pointer from page table - const auto pointers = pointerArrayFromPageTable( + auto const pointers = pointerArrayFromPageTable( mapIndicesTable, kvMemoryPool, maxSeq, maxBlocksPerSeq, blockSizeBytes, blocksPerPool); cudaMemcpy(blockArray.data, pointers.data(), pointerArraySize, cudaMemcpyHostToDevice); @@ -286,8 +286,8 @@ void testTransposeBatch4dContiguous(bool multiQueryMode, bool int8KVCache, bool batchSize, 1, maxSeqLen, dimsPerHead * headsNum * sizeof(T_DST), maxAttentionWindow, sinkTokenLen, onlyKorV); // Allocate for kv cache pool - const auto kvPoolElts = 2 * batchSize * maxSeqLen * dimsPerHead * headsNum; - const auto kvPoolSize = kvPoolElts * sizeof(T_DST); + auto const kvPoolElts = 2 * batchSize * maxSeqLen * dimsPerHead * headsNum; + auto const kvPoolSize = kvPoolElts * sizeof(T_DST); cudaMalloc(&kvLinearBuffer.data, kvPoolSize); cudaMemset(kvLinearBuffer.data, 0, kvPoolSize); diff --git a/docker/Dockerfile.multi b/docker/Dockerfile.multi index abb7dd45c..ba1ed4328 100644 --- a/docker/Dockerfile.multi +++ b/docker/Dockerfile.multi @@ -1,6 +1,6 @@ # Multi-stage Dockerfile ARG BASE_IMAGE=nvcr.io/nvidia/pytorch -ARG BASE_TAG=23.12-py3 +ARG BASE_TAG=24.01-py3 ARG DEVEL_IMAGE=devel FROM ${BASE_IMAGE}:${BASE_TAG} as base diff --git a/docker/Makefile b/docker/Makefile index 5fca9d3cd..237e16b31 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -21,6 +21,7 @@ endif STAGE ?= # Set this to define a custom image name and tag IMAGE_WITH_TAG ?= $(IMAGE_NAME)$(if $(STAGE),/$(STAGE)):$(IMAGE_TAG) +PUSH_TO_STAGING ?= 1 DOCKER_BUILD_OPTS ?= --pull DOCKER_BUILD_ARGS ?= DOCKER_PROGRESS ?= auto @@ -50,6 +51,11 @@ define add_local_user .. 
endef +# Rewrite `/tensorrt-llm:` in image tag with `/tensorrt-llm-staging:` to avoid directly overwriting +define rewrite_tag +$(shell echo $(IMAGE_WITH_TAG) | sed "s/\/tensorrt-llm:/\/tensorrt-llm-staging:/g") +endef + %_build: @echo "Building docker image: $(IMAGE_WITH_TAG)" DOCKER_BUILDKIT=1 docker build $(DOCKER_BUILD_OPTS) $(DOCKER_BUILD_ARGS) \ @@ -75,8 +81,16 @@ endef $(call add_local_user,$(IMAGE_WITH_TAG)) %_push: %_build - @echo "Pushing docker image: $(IMAGE_WITH_TAG)" - docker push $(IMAGE_WITH_TAG)$(IMAGE_TAG_SUFFIX) + @if [ $(PUSH_TO_STAGING) = 0 ]; then \ + echo "Pushing docker image: $(IMAGE_WITH_TAG)"; \ + docker push $(IMAGE_WITH_TAG)$(IMAGE_TAG_SUFFIX); \ + fi + @if [ $(PUSH_TO_STAGING) = 1 ]; then \ + echo "Rewriting docker tag: $(IMAGE_WITH_TAG) to $(call rewrite_tag)"; \ + docker tag $(IMAGE_WITH_TAG)$(IMAGE_TAG_SUFFIX) $(call rewrite_tag)$(IMAGE_TAG_SUFFIX); \ + echo "Pushing docker image: $(call rewrite_tag)"; \ + docker push $(call rewrite_tag)$(IMAGE_TAG_SUFFIX); \ + fi %_pull: @echo "Pulling docker image: $(IMAGE_WITH_TAG)" @@ -148,6 +162,7 @@ old-cuda_%: NCCL_VERSION = 2.18.3-1+cuda12.1 old-cuda_%: CUBLAS_VERSION = 12.1.3.1-1 trtllm_%: STAGE = release +trtllm_%: PUSH_TO_STAGING := 0 trtllm_%: DEVEL_IMAGE = $(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"') trtllm_%: IMAGE_NAME = $(shell grep 'IMAGE_NAME = ' ../jenkins/BuildDockerImage.groovy | grep -o '".*"' | tr -d '"') trtllm_%: IMAGE_TAG = $(shell git rev-parse --abbrev-ref HEAD | tr '/' '_') diff --git a/docker/common/install_base.sh b/docker/common/install_base.sh index 5e2608f35..e1dce4934 100644 --- a/docker/common/install_base.sh +++ b/docker/common/install_base.sh @@ -24,7 +24,8 @@ init_ubuntu() { python3-dev \ python3-pip \ python-is-python3 \ - wget + wget \ + pigz if ! 
command -v mpirun &> /dev/null; then DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends openmpi-bin libopenmpi-dev fi @@ -92,7 +93,7 @@ install_devtoolset_centos() { YUM_CUDA=${CUDA_VERSION/./-} # Consistent with manylinux2014 centos-7 based version yum -y install vim wget git-lfs rh-git227 devtoolset-10 libffi-devel - yum -y install openmpi3 openmpi3-devel + yum -y install openmpi3 openmpi3-devel pigz echo "source scl_source enable rh-git227" >> "${ENV}" echo "source scl_source enable devtoolset-10" >> "${DEVTOOLSET_ENV_FILE}" echo "source ${DEVTOOLSET_ENV_FILE}" >> "${ENV}" diff --git a/docker/common/install_pytorch.sh b/docker/common/install_pytorch.sh index 3c9473555..93fc8d3ea 100644 --- a/docker/common/install_pytorch.sh +++ b/docker/common/install_pytorch.sh @@ -2,7 +2,10 @@ set -ex -TORCH_VERSION="2.1.0" +# Use latest stable version from https://pypi.org/project/torch/#history +# and closest to the version specified in +# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html#rel-24-01 +TORCH_VERSION="2.1.2" SYSTEM_ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') prepare_environment() { diff --git a/docker/common/install_tensorrt.sh b/docker/common/install_tensorrt.sh index 7c2de9660..487106b08 100644 --- a/docker/common/install_tensorrt.sh +++ b/docker/common/install_tensorrt.sh @@ -2,11 +2,14 @@ set -ex -TRT_VER="9.2.0.5" +# Use https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-01.html#rel-24-01 +TRT_VER="9.3.0.1" CUDA_VER="12.3" -CUDNN_VER="8.9.4.25-1+cuda12.2" +CUDNN_VER="8.9.7.29-1+cuda12.2" +# v2.19.4 doesn't exist in https://developer.download.nvidia.cn/compute/cuda/repos/ NCCL_VER="2.19.3-1+cuda12.3" CUBLAS_VER="12.3.4.1-1" +NVRTC_VER="12.3.107-1" for i in "$@"; do case $i in @@ -44,10 +47,16 @@ install_ubuntu_requirements() { if [[ $(apt list --installed | grep libcublas) ]]; then apt-get remove --purge -y --allow-change-held-packages libcublas* fi + if [[ $(apt list --installed | grep cuda-nvrtc-dev) ]]; then + apt-get remove --purge -y --allow-change-held-packages cuda-nvrtc-dev* + fi CUBLAS_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') apt-get install -y --no-install-recommends libcudnn8=${CUDNN_VER} libcudnn8-dev=${CUDNN_VER} apt-get install -y --no-install-recommends libnccl2=${NCCL_VER} libnccl-dev=${NCCL_VER} apt-get install -y --no-install-recommends libcublas-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} libcublas-dev-${CUBLAS_CUDA_VERSION}=${CUBLAS_VER} + # NVRTC static library doesn't exist in NGC PyTorch container. 
+ NVRTC_CUDA_VERSION=$(echo $CUDA_VER | sed 's/\./-/g') + apt-get install -y --no-install-recommends cuda-nvrtc-dev-${NVRTC_CUDA_VERSION}=${NVRTC_VER} apt-get clean rm -rf /var/lib/apt/lists/* } @@ -75,7 +84,7 @@ install_tensorrt() { if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi - RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.2.0/tensorrt-${TRT_VER}.${OS}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz; + RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/9.3.0/tensorrt-${TRT_VER}.${OS}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz; fi wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar tar -xf /tmp/TensorRT.tar -C /usr/local/ diff --git a/docs/source/performance.md b/docs/source/performance.md index 397617289..aea75d45e 100644 --- a/docs/source/performance.md +++ b/docs/source/performance.md @@ -391,7 +391,7 @@ do in_out_dims=$(echo $in_out | awk -F':' '{ print $2 }') echo "BS: $batch_size, ISL/OSL: $in_out_dims" - ./cpp/build/benchmarks/gptSessionBenchmark --model gptj --engine_dir /tmp/engines/gptj/ --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims + ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir /tmp/engines/gptj/ --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims done ``` @@ -405,7 +405,7 @@ do in_out_dims=$(echo $in_out | awk -F':' '{ print $2 }') echo "BS: $batch_size, ISL/OSL: $in_out_dims" - ./cpp/build/benchmarks/gptSessionBenchmark --model gptj --engine_dir /tmp/engines/gptj/ --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims + ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir /tmp/engines/gptj/ --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims done ``` @@ -461,7 +461,7 @@ do in_out_dims=$(echo $in_out | awk -F':' '{ print $2 }') echo "BS: $batch_size, ISL/OSL: $in_out_dims" - ./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/7b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims + ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir /tmp/engines/llama/7b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims done ``` #### First Token Latency Benchmark @@ -474,7 +474,7 @@ do in_out_dims=$(echo $in_out | awk -F':' '{ print $2 }') echo "BS: $batch_size, ISL/OSL: $in_out_dims" - ./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/7b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims + ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir /tmp/engines/llama/7b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims done ``` @@ -536,7 +536,7 @@ do in_out_dims=$(echo $in_out | awk -F':' '{ print $2 }') echo "BS: $batch_size, ISL/OSL: $in_out_dims" - mpirun -n 4 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/70b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims + mpirun -n 4 --allow-run-as-root --oversubscribe 
./cpp/build/benchmarks/gptSessionBenchmark --engine_dir /tmp/engines/llama/70b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims done ``` @@ -551,7 +551,7 @@ do in_out_dims=$(echo $in_out | awk -F':' '{ print $2 }') echo "BS: $batch_size, ISL/OSL: $in_out_dims" - mpirun -n 4 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/70b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims + mpirun -n 4 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir /tmp/engines/llama/70b --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len $in_out_dims done ``` @@ -624,9 +624,9 @@ do --strongly_typed # Throughput benchmark - mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model falcon --engine_dir $engine_path --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len "${isl},${osl}" + mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir $engine_path --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len "${isl},${osl}" # Time to first token benchmark - mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model falcon --engine_dir $engine_path --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len "${isl},1" + mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --engine_dir $engine_path --warm_up 1 --batch_size $batch_size --duration 0 --num_runs 5 --input_output_len "${isl},1" # The Falcon-180b engine is quite large, remove after the benchmark to free up space # Remove this line if you'd like to save the engines. diff --git a/docs/source/precision.md b/docs/source/precision.md index c31b7b7a9..c11c10f00 100644 --- a/docs/source/precision.md +++ b/docs/source/precision.md @@ -145,6 +145,17 @@ This release of TensorRT-LLM contains the following examples: | Whisper | Y | Y | Y | . | . | Y | Y | . | . | +The list of supported multi-modal examples is: + +| BLIP2-OPT | Y | Y | Y | . | . | . | . | . | . | +| BLIP2-T5 | Y | Y | Y | . | . | . | . | . | . | +| LLaVA | Y | Y | Y | Y | Y | Y | Y | Y | Y | +| VILA | Y | Y | Y | Y | Y | Y | Y | Y | Y | +| Nougat | Y | Y | Y | . | . | . | . | . | . | + +Note: The vision component of multi-modal models uses FP16 by default. +The language component decides which quantization methods are supported by a given multi-modal model. + ## Technical Detail: The `QuantMode` Flags The quantization method is controlled by the diff --git a/examples/baichuan/README.md b/examples/baichuan/README.md index ba2242144..45f3b4d20 100644 --- a/examples/baichuan/README.md +++ b/examples/baichuan/README.md @@ -44,8 +44,8 @@ TensorRT-LLM Baichuan builds TensorRT engine(s) from HF checkpoint. If no checkp # 7B models should always enable `gpt_attention_plugin`` since RoPE is only # supported with GPTAttention plugin now. # Try gemm_plugin to prevent accuracy issue. 
-trtllm-build --checkpoint_dir ./trt_ckpt/baichuan_v1_13b/ \ - --output_dir ./trt_engines/baichuan_v1_13b/ \ +trtllm-build --checkpoint_dir ./tmp/baichuan_v1_13b/trt_ckpts/fp16/1-gpu/ \ + --output_dir ./tmp/baichuan_v1_13b/trt_engines/fp16/1-gpu/ \ --gemm_plugin float16 \ --max_batch_size=32 \ --max_input_len=1024 \ @@ -60,20 +60,20 @@ Here're some examples for checkpoint conversion that take `v1_13b` as example: python convert_checkpoint.py --model_version v1_13b \ --model_dir baichuan-inc/Baichuan-13B-Chat \ --dtype float16 \ - --output_dir ./tmp/baichuan_v1_13b/trt_engines/fp16/1-gpu/ + --output_dir ./tmp/baichuan_v1_13b/trt_ckpts/fp16/1-gpu/ # Convert the Baichuan V1 13B model using a single GPU and BF16. python convert_checkpoint.py --model_version v1_13b \ --model_dir baichuan-inc/Baichuan-13B-Chat \ --dtype bfloat16 \ - --output_dir ./tmp/baichuan_v1_13b/trt_engines/bf16/1-gpu/ + --output_dir ./tmp/baichuan_v1_13b/trt_ckpts/bf16/1-gpu/ # Convert the Baichuan V1 13B model using a single GPU and apply INT8 weight-only quantization. python convert_checkpoint.py --model_version v1_13b \ --model_dir baichuan-inc/Baichuan-13B-Chat \ --dtype float16 \ --use_weight_only \ - --output_dir ./tmp/baichuan_v1_13b/trt_engines/int8_weight_only/1-gpu/ + --output_dir ./tmp/baichuan_v1_13b/trt_ckpts/int8_weight_only/1-gpu/ # Convert the Baichuan V1 13B model using a single GPU and apply INT4 weight-only quantization. python convert_checkpoint.py --model_version v1_13b \ @@ -81,13 +81,13 @@ python convert_checkpoint.py --model_version v1_13b \ --dtype float16 \ --use_weight_only \ --weight_only_precision int4 \ - --output_dir ./tmp/baichuan_v1_13b/trt_engines/int4_weight_only/1-gpu/ + --output_dir ./tmp/baichuan_v1_13b/trt_ckpts/int4_weight_only/1-gpu/ # Convert Baichuan V1 13B using 2-way tensor parallelism. python convert_checkpoint.py --model_version v1_13b \ --model_dir baichuan-inc/Baichuan-13B-Chat \ --dtype float16 \ - --output_dir ./tmp/baichuan_v1_13b/trt_engines/fp16/1-gpu/ \ + --output_dir ./tmp/baichuan_v1_13b/trt_ckpts/fp16/1-gpu/ \ --world_size 2 \ --tp_size 2 ``` @@ -164,7 +164,7 @@ To run the GPTQ Baichuan example, the following steps are required: --group_size 64 \ --world_size 2 \ --tp_size 2 \ - --output_dir ./tmp/baichuan_v2_13b/trt_engines/int4_gptq_gs64/2-gpu/ + --output_dir ./tmp/baichuan_v2_13b/trt_ckpts/int4_gptq_gs64/2-gpu/ ``` The quantized model checkpoint is saved for future TensorRT-LLM engine build directly with the `trtllm-build` command mentioned above. 
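Editorial note on the Baichuan example above: the README change separates the conversion output (`./tmp/baichuan_v1_13b/trt_ckpts/...`) from the engine output (`./tmp/baichuan_v1_13b/trt_engines/...`), so the two-step flow is convert first, then build. The sketch below simply chains the two commands shown in the README with `subprocess`; the wrapper itself and the exact directory layout are illustrative assumptions, while the script name and flags are taken verbatim from the example.

```python
# Illustrative only: chains the convert + build commands from the README above.
# Paths mirror the ./tmp/baichuan_v1_13b/{trt_ckpts,trt_engines} layout used in this patch.
import subprocess

CKPT_DIR = "./tmp/baichuan_v1_13b/trt_ckpts/fp16/1-gpu/"
ENGINE_DIR = "./tmp/baichuan_v1_13b/trt_engines/fp16/1-gpu/"

def convert_and_build():
    # Step 1: HF weights -> TensorRT-LLM checkpoint (convert_checkpoint.py in the example dir).
    subprocess.run([
        "python", "convert_checkpoint.py",
        "--model_version", "v1_13b",
        "--model_dir", "baichuan-inc/Baichuan-13B-Chat",
        "--dtype", "float16",
        "--output_dir", CKPT_DIR,
    ], check=True)
    # Step 2: checkpoint -> engine(s) via trtllm-build, as shown in the README.
    subprocess.run([
        "trtllm-build",
        "--checkpoint_dir", CKPT_DIR,
        "--output_dir", ENGINE_DIR,
        "--gemm_plugin", "float16",
    ], check=True)

if __name__ == "__main__":
    convert_and_build()
```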
diff --git a/examples/baichuan/convert_checkpoint.py b/examples/baichuan/convert_checkpoint.py index 04ac95f17..68193c017 100644 --- a/examples/baichuan/convert_checkpoint.py +++ b/examples/baichuan/convert_checkpoint.py @@ -1216,7 +1216,6 @@ def process_and_assign_weight(prefix, v, tp_dim=-1): 'quantization': { 'quant_algo': quant_algo, 'kv_cache_quant_algo': kv_cache_quant_algo, - 'sq_use_plugin': True, 'group_size': args.group_size, }, 'mapping': { @@ -1234,25 +1233,29 @@ def process_and_assign_weight(prefix, v, tp_dim=-1): with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: json.dump(config, f, indent=4) + hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, + trust_remote_code=True, + torch_dtype="auto") + + if args.smoothquant is not None or args.int8_kv_cache: + act_range = {} + baichuan_smoother = {} + act_range = capture_activation_range( + hf_model.cuda(), + AutoTokenizer.from_pretrained(args.model_dir, + use_fast=False, + trust_remote_code=True)) + if args.smoothquant is not None: + smooth_baichuan_model(hf_model, act_range, args.smoothquant, + baichuan_smoother) + def covert_and_save(rank): mapping = Mapping(world_size=world_size, rank=rank, tp_size=args.tp_size, pp_size=args.pp_size) - hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, - trust_remote_code=True, - torch_dtype="auto") + if args.smoothquant is not None or args.int8_kv_cache: - act_range = {} - baichuan_smoother = {} - act_range = capture_activation_range( - hf_model.cuda(), - AutoTokenizer.from_pretrained(args.model_dir, - use_fast=False, - trust_remote_code=True)) - if args.smoothquant is not None: - smooth_baichuan_model(hf_model, act_range, args.smoothquant, - baichuan_smoother) weights = convert_hf_baichuan_sq(hf_model, mapping, rank, args.dtype, args.per_channel, args.per_token, args.int8_kv_cache, @@ -1272,7 +1275,6 @@ def covert_and_save(rank): dtype=args.dtype, use_weight_only=args.use_weight_only, plugin_weight_only_quant_type=plugin_weight_only_quant_type) - del hf_model safetensors.torch.save_file( weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) @@ -1296,6 +1298,7 @@ def covert_and_save(rank): exceptions ) == 0, "Checkpoint conversion failed, please check error log." 
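Editorial note on the convert_checkpoint.py hunk above: the Hugging Face model load and the SmoothQuant/INT8-KV activation-range capture are hoisted out of the per-rank `covert_and_save` body, so the expensive work runs once, every rank reuses the same model and calibration data, and `del hf_model` happens only after all ranks have been written. A minimal sketch of that structure, with hypothetical stand-in helpers (`load_model`, `capture_ranges`, `convert_rank`) for the real `AutoModelForCausalLM.from_pretrained`, `capture_activation_range`, and per-rank conversion calls:

```python
# "Load once, convert per rank" pattern from the refactor above (stand-ins only).

def load_model():
    return {"weights": "..."}          # placeholder for the HF model

def capture_ranges(model):
    return {}                          # placeholder for activation-range calibration

def convert_rank(model, act_range, rank):
    print(f"converting rank {rank} using the shared model and calibration")

def convert_all(world_size, use_smooth_quant=False):
    model = load_model()               # loaded once, not once per rank
    act_range = capture_ranges(model) if use_smooth_quant else None
    for rank in range(world_size):     # per-rank work only slices and saves weights
        convert_rank(model, act_range, rank)
    del model                          # freed only after every rank is written

if __name__ == "__main__":
    convert_all(world_size=2, use_smooth_quant=True)
```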
+ del hf_model tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') diff --git a/examples/baichuan/requirements.txt b/examples/baichuan/requirements.txt index c329614f4..d849018f9 100644 --- a/examples/baichuan/requirements.txt +++ b/examples/baichuan/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.15.0 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/bloom/convert_checkpoint.py b/examples/bloom/convert_checkpoint.py index 83041a754..586d654fd 100644 --- a/examples/bloom/convert_checkpoint.py +++ b/examples/bloom/convert_checkpoint.py @@ -1090,8 +1090,6 @@ def main(): 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, } - if args.smoothquant: - config['quantization']['sq_use_plugin'] = True with (args.output_dir / 'config.json').open('w') as f: json.dump(config, f, indent=4) diff --git a/examples/bloom/requirements.txt b/examples/bloom/requirements.txt index ba54c3ef5..756b9635c 100644 --- a/examples/bloom/requirements.txt +++ b/examples/bloom/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/chatglm/convert_checkpoint.py b/examples/chatglm/convert_checkpoint.py index 0ae5baf2f..e1bdc6913 100644 --- a/examples/chatglm/convert_checkpoint.py +++ b/examples/chatglm/convert_checkpoint.py @@ -285,195 +285,6 @@ def get_tllm_linear_weight( return results -def convert_hf_chatglm(hf_model: AutoModel, - hf_config: AutoConfig, - chatglm_version: str, - mapping: Mapping, - dtype: str = 'float32', - use_parallel_embedding: bool = False, - sharding_dim: int = 0, - share_embedding_table: bool = False, - use_weight_only: bool = False, - plugin_weight_only_quant_type: torch.dtype = torch.int8): - weights = {} - tik = time.time() - - model_params = dict(hf_model.named_parameters()) - dtype = getattr(torch, dtype) - num_attention_heads = hf_config.num_attention_heads - hidden_size = hf_config.hidden_size - vocab_size = hf_config.vocab_size - num_kv_heads = getattr(hf_config, 'num_kv_heads', num_attention_heads) - num_hidden_layers = hf_config.num_layers - - layers_range = mapping.pp_layers(num_hidden_layers) - for l in layers_range: - if chatglm_version in ['glm', 'chatglm']: - prefix = f'transformer.layers.{l}' - elif chatglm_version in ['chatglm2', 'chatglm3']: - prefix = f'transformer.encoder.layers.{l}' - tllm_prex = f'transformer.layers.{l-layers_range[0]}' - - if chatglm_version in ['glm', 'chatglm']: - qkv_weight, qkv_bias = get_weight_and_bias( - model_params, f'{prefix}.attention.query_key_value', dtype) - elif chatglm_version in ['chatglm2', 'chatglm3']: - qkv_weight, qkv_bias = get_weight_and_bias( - model_params, f'{prefix}.self_attention.query_key_value', dtype) - - qkv_w = split_qkv(qkv_weight, - mapping.tp_size, - mapping.tp_rank, - hidden_size, - num_attention_heads, - num_kv_heads=num_kv_heads) - if qkv_bias is None: - qkv_b = None - else: - qkv_b = split_qkv(qkv_bias, - mapping.tp_size, - mapping.tp_rank, - hidden_size, - num_attention_heads, - num_kv_heads=num_kv_heads) - weights.update( - get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', qkv_b, - use_weight_only, - plugin_weight_only_quant_type)) - - if chatglm_version in ['glm', 'chatglm']: - attn_dense_weight, attn_dense_bias = get_weight_and_bias( - model_params, 
f'{prefix}.attention.dense', dtype) - else: - attn_dense_weight, attn_dense_bias = get_weight_and_bias( - model_params, f'{prefix}.self_attention.dense', dtype) - - attn_dense_w = split(attn_dense_weight, - mapping.tp_size, - mapping.tp_rank, - dim=1) - weights.update( - get_tllm_linear_weight(attn_dense_w, f'{tllm_prex}.attention.dense', - attn_dense_bias, use_weight_only, - plugin_weight_only_quant_type)) - - mlp_fc_weight, mlp_fc_bias = get_weight_and_bias( - model_params, f'{prefix}.mlp.dense_h_to_4h', dtype) - if chatglm_version in ['glm', 'chatglm']: - mlp_fc_w = split(mlp_fc_weight, - mapping.tp_size, - mapping.tp_rank, - dim=0) - if mlp_fc_bias is None: - mlp_fc_b = None - else: - mlp_fc_b = split(mlp_fc_bias, - mapping.tp_size, - mapping.tp_rank, - dim=0) - - elif chatglm_version in ['chatglm2', 'chatglm3']: - mlp_fc_w = swap_and_split_mlp(mlp_fc_weight, mapping.tp_size, - mapping.tp_rank) - - if mlp_fc_bias is None: - mlp_fc_b = None - else: - mlp_fc_b = swap_and_split_mlp(mlp_fc_bias, mapping.tp_size, - mapping.tp_rank) - - weights.update( - get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', mlp_fc_b, - use_weight_only, - plugin_weight_only_quant_type)) - - mlp_proj_weight, mlp_proj_bias = get_weight_and_bias( - model_params, f'{prefix}.mlp.dense_4h_to_h', dtype) - - mlp_proj_w = split(mlp_proj_weight, - mapping.tp_size, - mapping.tp_rank, - dim=1) - weights.update( - get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', - mlp_proj_bias, use_weight_only, - plugin_weight_only_quant_type)) - - input_ln_weight, input_ln_bias = get_weight_and_bias( - model_params, f'{prefix}.input_layernorm', dtype) - weights[f'{tllm_prex}.input_layernorm.weight'] = input_ln_weight - if input_ln_bias is not None: - weights[f'{tllm_prex}.input_layernorm.bias'] = input_ln_bias - - post_ln_weight, post_ln_bias = get_weight_and_bias( - model_params, f'{prefix}.post_attention_layernorm', dtype) - weights[f'{tllm_prex}.post_layernorm.weight'] = post_ln_weight - if post_ln_bias is not None: - weights[f'{tllm_prex}.post_layernorm.bias'] = post_ln_bias - - if mapping.is_first_pp_rank(): - if chatglm_version == 'glm': - embed_w = get_weight(model_params, 'word_embeddings', dtype) - pos_embed_w = get_weight(model_params, - 'transformer.position_embeddings', dtype) - weights['transformer.position_embedding.weight'] = pos_embed_w - block_embed_w = get_weight(model_params, - 'transformer.block_position_embeddings', - dtype) - weights['transformer.block_embedding.weight'] = block_embed_w - elif chatglm_version == 'chatglm': - embed_w = get_weight(model_params, 'transformer.word_embeddings', - dtype) - elif chatglm_version in ['chatglm2', 'chatglm3']: - embed_w = get_weight(model_params, - 'transformer.embedding.word_embeddings', dtype) - - if not use_parallel_embedding: - weights['transformer.vocab_embedding.weight'] = embed_w - else: - if sharding_dim == 0: - assert vocab_size % mapping.tp_size == 0 - else: - assert hidden_size % mapping.tp_size == 0 - weights['transformer.vocab_embedding.weight'] = split( - embed_w, mapping.tp_size, mapping.tp_rank, sharding_dim) - - if mapping.is_last_pp_rank(): - if chatglm_version == 'glm': - lm_head_weight = get_weight(model_params, 'word_embeddings', - dtype).clone() - elif chatglm_version == 'chatglm': - lm_head_weight = get_weight(model_params, - 'transformer.word_embeddings', - dtype).clone() - elif chatglm_version in ['chatglm2', 'chatglm3']: - lm_head_weight = get_weight(model_params, - 'transformer.output_layer', dtype) - assert not share_embedding_table - 
- if not share_embedding_table: - weights['lm_head.weight'] = split(lm_head_weight, - mapping.tp_size, - mapping.tp_rank, - dim=0) - - if chatglm_version in ['glm', 'chatglm']: - ln_f_w, ln_f_b = get_weight_and_bias(model_params, - 'transformer.final_layernorm', - dtype) - elif chatglm_version in ['chatglm2', 'chatglm3']: - ln_f_w, ln_f_b = get_weight_and_bias( - model_params, 'transformer.encoder.final_layernorm', dtype) - weights['transformer.ln_f.weight'] = ln_f_w - if ln_f_b is not None: - weights['transformer.ln_f.bias'] = ln_f_b - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') - return weights - - @torch.no_grad() def apply_smoothing( scales, @@ -835,20 +646,22 @@ def get_tllm_linear_sq_weight(vals, return results -def convert_hf_chatglm_sq(hf_model: AutoModel, - hf_config: AutoConfig, - chatglm_version: str, - mapping: Mapping, - dtype: str = 'float32', - use_parallel_embedding: bool = False, - sharding_dim: int = 0, - share_embedding_table: bool = False, - per_channel=False, - per_token=False, - int8_kv_cache=False, - act_range=None, - smoother=None): - assert mapping.world_size == 1 +def convert_hf_chatglm(hf_model: AutoModel, + hf_config: AutoConfig, + chatglm_version: str, + mapping: Mapping, + dtype: str = 'float32', + use_parallel_embedding: bool = False, + sharding_dim: int = 0, + share_embedding_table: bool = False, + use_weight_only: bool = False, + plugin_weight_only_quant_type: str = 'int8', + use_smooth_quant: bool = False, + per_channel=False, + per_token=False, + int8_kv_cache=False, + act_range=None, + smoother=None): weights = {} tik = time.time() @@ -868,6 +681,7 @@ def convert_hf_chatglm_sq(hf_model: AutoModel, prefix = f'transformer.encoder.layers.{l}' tllm_prex = f'transformer.layers.{l-layers_range[0]}' + # Attention QKV if chatglm_version in ['glm', 'chatglm']: qkv_weight, qkv_bias = get_weight_and_bias( model_params, f'{prefix}.attention.query_key_value', dtype) @@ -878,37 +692,63 @@ def convert_hf_chatglm_sq(hf_model: AutoModel, qkv_act_range = act_range.get( f'{prefix}.self_attention.query_key_value') - qkv_vals_int8 = generate_int8(qkv_weight.t().numpy(), - qkv_act_range, - is_qkv=True, - multi_query_mode=True) - weights.update( - get_tllm_linear_sq_weight( - vals=qkv_vals_int8, - prefix=f'{tllm_prex}.attention.qkv.', - shape=[1, qkv_weight.size(0)], - is_qkv=True, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.input_layernorm.scale_to_int', - smoother_value=None, - smoother_shape=None)) - - if qkv_bias is not None: - qkv_b = split_qkv(qkv_bias, + if use_smooth_quant: + qkv_vals_int8 = generate_int8(qkv_weight.t().numpy(), + qkv_act_range, + is_qkv=True, + multi_query_mode=True) + weights.update( + get_tllm_linear_sq_weight( + vals=qkv_vals_int8, + prefix=f'{tllm_prex}.attention.qkv.', + shape=[1, qkv_weight.size(0)], + is_qkv=True, + per_token=per_token, + per_channel=per_channel, + last_prefix=f'{tllm_prex}.input_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None)) + if qkv_bias is not None: + qkv_b = split_qkv(qkv_bias, + mapping.tp_size, + mapping.tp_rank, + hidden_size, + num_attention_heads, + num_kv_heads=num_kv_heads) + weights[f'{tllm_prex}.attention.qkv.bias'] = qkv_b + else: + qkv_w = split_qkv(qkv_weight, mapping.tp_size, mapping.tp_rank, hidden_size, num_attention_heads, num_kv_heads=num_kv_heads) - weights[f'{tllm_prex}.attention.qkv.bias'] = qkv_b + if qkv_bias is None: + qkv_b = None + else: + qkv_b = split_qkv(qkv_bias, 
+ mapping.tp_size, + mapping.tp_rank, + hidden_size, + num_attention_heads, + num_kv_heads=num_kv_heads) + + weights.update( + get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', + qkv_b, use_weight_only, + plugin_weight_only_quant_type)) if int8_kv_cache: + qkv_vals_int8 = generate_int8(qkv_weight.t().numpy(), + qkv_act_range, + is_qkv=True, + multi_query_mode=True) weights[ f'{tllm_prex}.attention.kv_cache_scaling_factor'] = torch.from_numpy( np.array([qkv_vals_int8['scale_y_quant_orig']], dtype=np.float32)).contiguous() + # Attention dense if chatglm_version in ['glm', 'chatglm']: attn_dense_weight, attn_dense_bias = get_weight_and_bias( model_params, f'{prefix}.attention.dense', dtype) @@ -920,96 +760,143 @@ def convert_hf_chatglm_sq(hf_model: AutoModel, dense_act_range = act_range.get(f'{prefix}.self_attention.dense') dense_smoother = smoother.get(f'{prefix}.self_attention.dense') - dense_vals_int8 = generate_int8(attn_dense_weight.t().numpy(), - dense_act_range, - is_qkv=False, - multi_query_mode=True) - weights.update( - get_tllm_linear_sq_weight( - vals=dense_vals_int8, - prefix=f'{tllm_prex}.attention.dense.', - shape=[1, hidden_size], - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.attention.quantization_scaling_factor', - smoother_value=dense_smoother, - smoother_shape=[1, hidden_size])) - - if attn_dense_bias is not None: - weights[f'{tllm_prex}.attention.dense.bias'] = attn_dense_bias - + if use_smooth_quant: + dense_vals_int8 = generate_int8(attn_dense_weight.t().numpy(), + dense_act_range, + is_qkv=False, + multi_query_mode=True) + weights.update( + get_tllm_linear_sq_weight( + vals=dense_vals_int8, + prefix=f'{tllm_prex}.attention.dense.', + shape=[1, hidden_size], + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix= + f'{tllm_prex}.attention.quantization_scaling_factor', + smoother_value=dense_smoother, + smoother_shape=[1, hidden_size])) + if attn_dense_bias is not None: + weights[f'{tllm_prex}.attention.dense.bias'] = attn_dense_bias + else: + attn_dense_w = split(attn_dense_weight, + mapping.tp_size, + mapping.tp_rank, + dim=1) + weights.update( + get_tllm_linear_weight(attn_dense_w, + f'{tllm_prex}.attention.dense', + attn_dense_bias, use_weight_only, + plugin_weight_only_quant_type)) + + # MLP FC mlp_fc_weight, mlp_fc_bias = get_weight_and_bias( model_params, f'{prefix}.mlp.dense_h_to_4h', dtype) - fc_act_range = act_range.get(f'{prefix}.mlp.dense_h_to_4h') - fc_vals_int8 = generate_int8(mlp_fc_weight.t().numpy(), - fc_act_range, - is_qkv=False, - multi_query_mode=True) - - cur_weights = get_tllm_linear_sq_weight( - vals=fc_vals_int8, - prefix=f'{tllm_prex}.mlp.fc.', - shape=[1, mlp_fc_weight.size(0)], - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.post_layernorm.scale_to_int', - smoother_value=None, - smoother_shape=None, - ) - cur_weights[f'{tllm_prex}.mlp.fc.weight'] = swap_and_split_mlp( - cur_weights[f'{tllm_prex}.mlp.fc.weight'], - mapping.tp_size, - mapping.tp_rank, - dim=0, - ) - if per_channel: - cur_weights[ - f'{tllm_prex}.mlp.fc.per_channel_scale'] = swap_and_split_mlp( - cur_weights[f'{tllm_prex}.mlp.fc.per_channel_scale'], - mapping.tp_size, - mapping.tp_rank, - dim=1, - ) - weights.update(cur_weights) - if chatglm_version in ['glm', 'chatglm']: - if mlp_fc_bias is not None: - mlp_fc_b = split(mlp_fc_bias, + if use_smooth_quant: + fc_act_range = act_range.get(f'{prefix}.mlp.dense_h_to_4h') + fc_vals_int8 = 
generate_int8(mlp_fc_weight.t().numpy(), + fc_act_range, + is_qkv=False, + multi_query_mode=True) + cur_weights = get_tllm_linear_sq_weight( + vals=fc_vals_int8, + prefix=f'{tllm_prex}.mlp.fc.', + shape=[1, mlp_fc_weight.size(0)], + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=f'{tllm_prex}.post_layernorm.scale_to_int', + smoother_value=None, + smoother_shape=None, + ) + cur_weights[f'{tllm_prex}.mlp.fc.weight'] = swap_and_split_mlp( + cur_weights[f'{tllm_prex}.mlp.fc.weight'], + mapping.tp_size, + mapping.tp_rank, + dim=0, + ) + if per_channel: + cur_weights[ + f'{tllm_prex}.mlp.fc.per_channel_scale'] = swap_and_split_mlp( + cur_weights[f'{tllm_prex}.mlp.fc.per_channel_scale'], + mapping.tp_size, + mapping.tp_rank, + dim=1, + ) + weights.update(cur_weights) + if chatglm_version in ['glm', 'chatglm']: + if mlp_fc_bias is not None: + mlp_fc_b = split(mlp_fc_bias, + mapping.tp_size, + mapping.tp_rank, + dim=0) + weights[f'{tllm_prex}.mlp.fc.bias'] = mlp_fc_b + elif chatglm_version in ['chatglm2', 'chatglm3']: + if mlp_fc_bias is not None: + mlp_fc_b = swap_and_split_mlp(mlp_fc_bias, mapping.tp_size, + mapping.tp_rank) + weights[f'{tllm_prex}.mlp.fc.bias'] = mlp_fc_b + else: + if chatglm_version in ['glm', 'chatglm']: + mlp_fc_w = split(mlp_fc_weight, mapping.tp_size, mapping.tp_rank, dim=0) - weights[f'{tllm_prex}.mlp.fc.bias'] = mlp_fc_b - elif chatglm_version in ['chatglm2', 'chatglm3']: - if mlp_fc_bias is not None: - mlp_fc_b = swap_and_split_mlp(mlp_fc_bias, mapping.tp_size, + if mlp_fc_bias is None: + mlp_fc_b = None + else: + mlp_fc_b = split(mlp_fc_bias, + mapping.tp_size, + mapping.tp_rank, + dim=0) + elif chatglm_version in ['chatglm2', 'chatglm3']: + mlp_fc_w = swap_and_split_mlp(mlp_fc_weight, mapping.tp_size, mapping.tp_rank) - weights[f'{tllm_prex}.mlp.fc.bias'] = mlp_fc_b - + if mlp_fc_bias is None: + mlp_fc_b = None + else: + mlp_fc_b = swap_and_split_mlp(mlp_fc_bias, mapping.tp_size, + mapping.tp_rank) + weights.update( + get_tllm_linear_weight(mlp_fc_w, f'{tllm_prex}.mlp.fc', + mlp_fc_b, use_weight_only, + plugin_weight_only_quant_type)) + + # MLP Proj mlp_proj_weight, mlp_proj_bias = get_weight_and_bias( model_params, f'{prefix}.mlp.dense_4h_to_h', dtype) - proj_act_range = act_range.get(f'{prefix}.mlp.dense_4h_to_h') - proj_smoother = smoother.get(f'{prefix}.mlp.dense_4h_to_h') - proj_vals_int8 = generate_int8(mlp_proj_weight.t().numpy(), - proj_act_range, - is_qkv=False, - multi_query_mode=True) - - weights.update( - get_tllm_linear_sq_weight( - vals=proj_vals_int8, - prefix=f'{tllm_prex}.mlp.proj.', - shape=[1, hidden_size], - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=f'{tllm_prex}.mlp.quantization_scaling_factor', - smoother_value=proj_smoother, - smoother_shape=[1, hf_config.ffn_hidden_size])) - if mlp_proj_bias is not None: - weights[f'{tllm_prex}.mlp.proj.bias'] = mlp_proj_bias + if use_smooth_quant: + proj_act_range = act_range.get(f'{prefix}.mlp.dense_4h_to_h') + proj_smoother = smoother.get(f'{prefix}.mlp.dense_4h_to_h') + proj_vals_int8 = generate_int8(mlp_proj_weight.t().numpy(), + proj_act_range, + is_qkv=False, + multi_query_mode=True) + weights.update( + get_tllm_linear_sq_weight( + vals=proj_vals_int8, + prefix=f'{tllm_prex}.mlp.proj.', + shape=[1, hidden_size], + is_qkv=False, + per_token=per_token, + per_channel=per_channel, + last_prefix=f'{tllm_prex}.mlp.quantization_scaling_factor', + smoother_value=proj_smoother, + smoother_shape=[1, hf_config.ffn_hidden_size])) + if mlp_proj_bias is 
not None: + weights[f'{tllm_prex}.mlp.proj.bias'] = mlp_proj_bias + else: + mlp_proj_w = split(mlp_proj_weight, + mapping.tp_size, + mapping.tp_rank, + dim=1) + weights.update( + get_tllm_linear_weight(mlp_proj_w, f'{tllm_prex}.mlp.proj', + mlp_proj_bias, use_weight_only, + plugin_weight_only_quant_type)) input_ln_weight, input_ln_bias = get_weight_and_bias( model_params, f'{prefix}.input_layernorm', dtype) @@ -1127,7 +1014,6 @@ def convert_hf_chatglm_sq(hf_model: AutoModel, 'quantization': { 'quant_algo': None, 'kv_cache_quant_algo': None, - "sq_use_plugin": True, }, 'mapping': { 'world_size': world_size, @@ -1179,70 +1065,60 @@ def convert_hf_chatglm_sq(hf_model: AutoModel, else: plugin_weight_only_quant_type = None + hf_model = AutoModel.from_pretrained( + args.model_dir, + trust_remote_code=True, + torch_dtype="auto", + device_map="auto" if chatglm_version != 'glm' else None) + + act_range = {} + # smoother for query_key_value.dense and mlp.proj + model_smoother = {} + if args.smoothquant is not None or args.int8_kv_cache: + os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( + "TOKENIZERS_PARALLELISM", "false") + tokenizer = AutoTokenizer.from_pretrained( + args.model_dir, + trust_remote_code=True, + ) + dataset = load_dataset( + "cnn_dailymail", + '3.0.0', + split="validation", + ) + + act_range = capture_activation_range(hf_model, + tokenizer, + dataset, + num_samples=64) + if args.smoothquant is not None: + smooth_chatglm_model(hf_model, act_range, args.smoothquant, + model_smoother) + def covert_and_save(rank): mapping = Mapping(world_size=world_size, rank=rank, tp_size=args.tp_size, pp_size=args.pp_size) - hf_model = AutoModel.from_pretrained( - args.model_dir, - trust_remote_code=True, - torch_dtype="auto", - device_map="auto" if chatglm_version != 'glm' else None) - - if args.smoothquant is not None or args.int8_kv_cache: - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - tokenizer = AutoTokenizer.from_pretrained( - args.model_dir, - trust_remote_code=True, - ) - dataset = load_dataset( - "cnn_dailymail", - '3.0.0', - split="validation", - ) - - act_range = capture_activation_range(hf_model, - tokenizer, - dataset, - num_samples=64) - model_smoother = { - } # smoother for query_key_value.dense and mlp.proj - if args.smoothquant is not None: - smooth_chatglm_model(hf_model, act_range, args.smoothquant, - model_smoother) - weights = convert_hf_chatglm_sq( - hf_model, - hf_config, - chatglm_version, - mapping, - dtype=args.dtype, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - per_channel=args.per_channel, - per_token=args.per_token, - int8_kv_cache=args.int8_kv_cache, - act_range=act_range, - smoother=model_smoother, - ) - - else: - weights = convert_hf_chatglm( - hf_model, - hf_config, - chatglm_version, - mapping, - dtype=args.dtype, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - use_weight_only=args.use_weight_only, - plugin_weight_only_quant_type=plugin_weight_only_quant_type) - - del hf_model + weights = convert_hf_chatglm( + hf_model, + hf_config, + chatglm_version, + mapping, + dtype=args.dtype, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_weight_only=args.use_weight_only, + 
plugin_weight_only_quant_type=plugin_weight_only_quant_type, + use_smooth_quant=args.smoothquant is not None, + per_channel=args.per_channel, + per_token=args.per_token, + int8_kv_cache=args.int8_kv_cache, + act_range=act_range, + smoother=model_smoother, + ) safetensors.torch.save_file( weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) @@ -1266,6 +1142,7 @@ def covert_and_save(rank): exceptions ) == 0, "Checkpoint conversion failed, please check error log." + del hf_model tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') diff --git a/examples/chatglm/requirements.txt b/examples/chatglm/requirements.txt index dd2c60a15..ce3d5bf47 100644 --- a/examples/chatglm/requirements.txt +++ b/examples/chatglm/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 evaluate~=0.4.1 protobuf diff --git a/examples/cpp_library/main.cpp b/examples/cpp_library/main.cpp index e9f084111..20372f227 100644 --- a/examples/cpp_library/main.cpp +++ b/examples/cpp_library/main.cpp @@ -25,7 +25,7 @@ int main(int argc, char* argv[]) class TRTLogger : public nvinfer1::ILogger { public: - void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override + void log(nvinfer1::ILogger::Severity severity, char const* msg) noexcept override { if (severity <= nvinfer1::ILogger::Severity::kERROR) std::cerr << "[TensorRT-LLM ERR]: " << msg << std::endl; @@ -42,7 +42,7 @@ int main(int argc, char* argv[]) /* =============== initLibNvInferPlugins =============== */ - typedef bool (*initLibNvInferPlugins_sig)(void*, const void*); + typedef bool (*initLibNvInferPlugins_sig)(void*, void const*); auto initLibNvInferPlugins = getTrtLLMFunction( /*libFileSoName=*/libname, @@ -51,7 +51,7 @@ int main(int argc, char* argv[]) std::cout << std::endl; std::string libNamespace = "tensorrt_llm"; - const char* libNamespace_cstr = libNamespace.data(); + char const* libNamespace_cstr = libNamespace.data(); bool status1 = initLibNvInferPlugins(trtLogger, libNamespace_cstr); std::cout << "Success Status: " << status1 << std::endl << std::endl; diff --git a/examples/cpp_library/tensorrt_llm_libutils.h b/examples/cpp_library/tensorrt_llm_libutils.h index fcdf4a089..aa60444ee 100644 --- a/examples/cpp_library/tensorrt_llm_libutils.h +++ b/examples/cpp_library/tensorrt_llm_libutils.h @@ -33,7 +33,7 @@ tSymbolSignature getTrtLLMFunction(std::string libFileSoName, std::string symbol void* handle = dlopen(libFileSoName.c_str(), RTLD_LAZY | RTLD_GLOBAL); // 2. Check for errors - const char* dl_error1 = dlerror(); + char const* dl_error1 = dlerror(); if (!handle) { throw std::runtime_error("Cannot open library: " + std::string(dl_error1)); @@ -46,7 +46,7 @@ tSymbolSignature getTrtLLMFunction(std::string libFileSoName, std::string symbol *(void**) (&symbolFctn) = dlsym(handle, symbol.c_str()); // 4. Check for errors - const char* dl_error2 = dlerror(); + char const* dl_error2 = dlerror(); if (dl_error2) { dlclose(handle); diff --git a/examples/enc_dec/README.md b/examples/enc_dec/README.md index 2f7eea2c6..fa7b2777c 100644 --- a/examples/enc_dec/README.md +++ b/examples/enc_dec/README.md @@ -208,7 +208,7 @@ python build.py --model_type bart \ --max_beam_width 1 ``` -* Run the engine, setting `--lora_dir` and `--lora_task_uids`. `lora_task_uids` should be set as a list of uids which length equals to batch size. 
The following example is for batch size = 2: +* Run the engine, setting `--lora_dir` and `--lora_task_uids`. `--lora_task_uids` should be set as a list of uids which length equals to batch size. The following example is for batch size = 2: ```bash python run.py \ @@ -221,6 +221,19 @@ python run.py \ --lora_task_uids 0 0 ``` +* Run with multi-loRA, append `--lora_dir` with other lora directories and set `--lora_task_uids` according to the index of the lora directories. Set to "-1" to run with the base model: + +```bash +python run.py \ + --engine_dir tmp/trt_engines/bart-large-cnn/1-gpu/float16/tp1/ \ + --engine_name bart-large-cnn \ + --model_name tmp/hf_models/bart-large-cnn \ + --max_new_token=64 \ + --num_beams=1 \ + --lora_dir tmp/hf_models/bart-large-cnn-samsum-lora/ ... \ + --lora_task_uids 0 -1 -1 0 0 -1 +``` + ### Reminders - Flan-T5 models have known issues regarding FP16 precision and using BF16 precision is recommended, regardless of TRT-LLM. While we are working on improving FP16 results, please stay with FP32 or BF16 precision for Flan-T5 family. diff --git a/examples/enc_dec/bart/convert.py b/examples/enc_dec/bart/convert.py index 60cbac0e8..44ff04207 100644 --- a/examples/enc_dec/bart/convert.py +++ b/examples/enc_dec/bart/convert.py @@ -15,7 +15,7 @@ VisionEncoderDecoderModel) from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy -from tensorrt_llm.runtime.lora_manager import LoraConfig +from tensorrt_llm.lora_manager import LoraConfig LOGGER = logging.getLogger(__name__) diff --git a/examples/enc_dec/build.py b/examples/enc_dec/build.py index 027c8e162..b3b3c0b27 100644 --- a/examples/enc_dec/build.py +++ b/examples/enc_dec/build.py @@ -267,6 +267,12 @@ def parse_arguments(component): default=False, choices=['float16', 'float32', 'bfloat16'], help='Activates the lora plugin which enables embedding sharing.') + parser.add_argument( + '--skip_cross_qkv', + action='store_true', + help= + 'Skip redundant cross qkv computation by using TensorRT IfConditional switch (experimental).' + ) # parse cmdline args args = parser.parse_args() @@ -391,7 +397,8 @@ def build_rank_engine(builder: Builder, rescale_before_lm_head=args.rescale_before_lm_head, dtype=dtype, logits_dtype=args.logits_dtype, - fp16_clamping=fp16_clamping) + fp16_clamping=fp16_clamping, + skip_cross_qkv=args.skip_cross_qkv) if args.weight_from_pytorch_ckpt: assert args.tp_size == 1, "Loading from framework model via memory is for demonstration purpose. For multi-GPU inference, please use loading from binary for better performance." 
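
As a side note on the multi-LoRA invocation shown in the enc_dec README hunk above: the sketch below (not part of this patch) illustrates how the `--lora_task_uids` values can be interpreted against the list passed to `--lora_dir`, assuming uid `i` selects the i-th LoRA directory and `"-1"` selects the base model. The helper name `resolve_lora_dirs` is hypothetical and only meant to make the uid-to-directory mapping concrete.

```python
# Minimal sketch (not part of this patch): interpreting --lora_task_uids against
# the --lora_dir list. Assumes uid "i" selects the i-th LoRA directory and
# "-1" selects the base model; the helper name is hypothetical.
from typing import List, Optional


def resolve_lora_dirs(lora_dirs: List[str],
                      lora_task_uids: List[str]) -> List[Optional[str]]:
    """Return, per batch entry, the LoRA directory to apply (None = base model)."""
    resolved = []
    for uid in lora_task_uids:
        resolved.append(None if uid == "-1" else lora_dirs[int(uid)])
    return resolved


if __name__ == "__main__":
    # Mirrors the README example: one adapter, batch size 6.
    print(resolve_lora_dirs(["tmp/hf_models/bart-large-cnn-samsum-lora/"],
                            ["0", "-1", "-1", "0", "0", "-1"]))
```
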
@@ -533,7 +540,8 @@ def build(rank, args): hf_modules_to_trtllm_modules=args.hf_modules_to_trtllm_modules if args.use_lora_plugin else None, trtllm_modules_to_hf_modules=args.trtllm_modules_to_hf_modules - if args.use_lora_plugin else None) + if args.use_lora_plugin else None, + skip_cross_qkv=args.skip_cross_qkv) engine_name = get_engine_name(args.engine_name, args.dtype, args.tp_size, args.pp_size, cur_rank) diff --git a/examples/enc_dec/run.py b/examples/enc_dec/run.py index df85b3d22..c130550f1 100644 --- a/examples/enc_dec/run.py +++ b/examples/enc_dec/run.py @@ -29,7 +29,8 @@ import tensorrt_llm from tensorrt_llm import logger from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch -from tensorrt_llm.runtime import LoraManager, ModelConfig, SamplingConfig +from tensorrt_llm.lora_manager import LoraManager +from tensorrt_llm.runtime import ModelConfig, SamplingConfig def get_engine_name(model, dtype, tp_size, pp_size, rank): @@ -83,6 +84,7 @@ def read_config(config_path: Path): num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size cross_attention = builder_config["cross_attention"] + skip_cross_qkv = builder_config.get('skip_cross_qkv', False) has_position_embedding = builder_config["has_position_embedding"] has_token_type_embedding = builder_config["has_token_type_embedding"] use_custom_all_reduce = config['plugin_config'].get('use_custom_all_reduce', @@ -120,6 +122,7 @@ def read_config(config_path: Path): 'hf_modules_to_trtllm_modules'), trtllm_modules_to_hf_modules=builder_config.get( 'trtllm_modules_to_hf_modules'), + skip_cross_qkv=skip_cross_qkv, ) return model_config, tp_size, pp_size, gpus_per_node, dtype @@ -145,7 +148,7 @@ def parse_arguments(): parser.add_argument("--compare_hf_fp32", help="Compare results with HuggingFace FP32", action='store_true') - parser.add_argument('--lora_dir', type=str, default=None) + parser.add_argument('--lora_dir', type=str, default=None, nargs="+") parser.add_argument('--lora_task_uids', type=str, default=None, nargs="+") return parser.parse_args() @@ -228,7 +231,7 @@ def engine_setup(component): # TODO: this is only for bart self.encoder_lora_manager.load_from_hf_bart( component='encoder', - model_dirs=[lora_dir], + model_dirs=lora_dir, model_config=self.encoder_model_config, runtime_mapping=self.encoder_runtime_mapping, ) @@ -252,7 +255,7 @@ def engine_setup(component): # TODO: this is only for bart self.decoder_lora_manager.load_from_hf_bart( component='decoder', - model_dirs=[lora_dir], + model_dirs=lora_dir, model_config=self.decoder_model_config, runtime_mapping=self.decoder_runtime_mapping, ) @@ -378,6 +381,7 @@ def encoder_run(self, device=self.device).contiguous() if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None: + batch_size = input_lengths.size(0) missing_qkv_modules = [] if any(x in self.encoder_model_config.lora_target_modules for x in ["attn_q", "attn_k", "attn_v"]): @@ -391,7 +395,7 @@ def encoder_run(self, missing_qkv_modules): lora_ranks = [] lora_ptrs = [] - for batch_idx in range(input_ids.shape[0]): + for batch_idx in range(batch_size): lora_uid = self.lora_task_uids[batch_idx] if lora_uid is not None and lora_uid != "-1" and self.encoder_lora_manager.uid_to_low_ranks( lora_uid)[layer_idx][lora_module] != 0: @@ -415,8 +419,12 @@ def encoder_run(self, }) inputs.update({ 'host_request_types': - torch.IntTensor([0] * input_ids.shape[0]).to('cpu'), + torch.IntTensor([0] * batch_size).to('cpu'), }) + if self.encoder_model_config.remove_input_padding: + inputs.update({ + 
'host_context_lengths': input_lengths.to('cpu'), + }) # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape self.encoder_session.set_shapes(inputs) @@ -575,8 +583,8 @@ def generate( num_beams=num_beams, min_length=1, return_dict=return_dict) - sampling_config.update(output_cum_log_probs=False, - output_log_probs=False) + sampling_config.update(output_cum_log_probs=return_dict, + output_log_probs=return_dict) # decoder autoregressive generation self.decoder_session.setup( @@ -748,9 +756,12 @@ def test_fairseq_models(args): MBartForConditionalGeneration), 'Unsupported model!' if args.lora_dir is not None: + assert len(args.lora_dir + ) >= 1, "At least one lora model dir is required" + # we can only test single lora with HF from peft import PeftModel hf_model = PeftModel.from_pretrained( - hf_model, args.lora_dir).to('cuda').eval() + hf_model, args.lora_dir[0]).to('cuda').eval() tik = time.time() hf_gen_output = hf_model.generate( diff --git a/examples/falcon/requirements.txt b/examples/falcon/requirements.txt index 7888a6752..59eb000e3 100644 --- a/examples/falcon/requirements.txt +++ b/examples/falcon/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 transformers>=4.31.0 datasets~=2.14.5 evaluate~=0.4.1 diff --git a/examples/gemma/README.md b/examples/gemma/README.md index bcad67ced..a99118757 100644 --- a/examples/gemma/README.md +++ b/examples/gemma/README.md @@ -11,15 +11,15 @@ - [Run inference](#run-inference) - [Specific commands](#specific-commands) - [Run Gemma 2B](#run-gemma-2b) - - [Run inference under bfloat16 for keras checkpoint](#run-inference-under-bfloat16-for-keras-checkpoint) + - [Run inference under bfloat16 for HF checkpoint](#run-inference-under-bfloat16-for-hf-checkpoint) - [Run inference under FP8 for keras checkpoint](#run-inference-under-fp8-for-keras-checkpoint) - - [Run inference under SmoothQuant for jax checkpoint](#run-2b-inference-under-smoothquant-for-jax-checkpoint) + - [Run 2B inference under SmoothQuant for jax checkpoint](#run-2b-inference-under-smoothquant-for-jax-checkpoint) - [Run inference under weight only for jax checkpoint](#run-inference-under-weight-only-for-jax-checkpoint) - [Run inference under INT8 KV caches for jax checkpoint](#run-inference-under-int8-kv-caches-for-jax-checkpoint) - [Run Gemma 7B](#run-gemma-7b) - [Run inference under bfloat16 for torch checkpoint](#run-inference-under-bfloat16-for-torch-checkpoint) - [Run inference under FP8 for jax checkpoint](#run-inference-under-fp8-for-jax-checkpoint) - - [Run inference under SmoothQuant for jax checkpoint](#run-7b-inference-under-smoothquant-for-jax-checkpoint) + - [Run 7B inference under SmoothQuant for jax checkpoint](#run-7b-inference-under-smoothquant-for-jax-checkpoint) - [Run inference under weight only for keras checkpoint](#run-inference-under-weight-only-for-keras-checkpoint) - [Run inference under INT8 KV caches for keras checkpoint](#run-inference-under-int8-kv-caches-for-keras-checkpoint) - [Run AMMO Quantization](#run-ammo-quantization) @@ -31,7 +31,7 @@ ## Support Matrix * FP32/FP16/BF16/INT8 Weight-Only/INT4 Weight-Only/SmoothQuant/FP8 * For SmoothQuant, TRT-LLM only supports FP16 higher precision now. 
- * checkpoint type: Jax, Torch, Keras + * checkpoint type: Jax, Torch, Keras, Huggingface (HF) * STRONGLY TYPED * python runtime and triton backend @@ -138,16 +138,16 @@ In this section, we demonstrate the scripts to convert checkpoint, building engi ### Run Gemma 2B -#### Run inference under bfloat16 for keras checkpoint +#### Run inference under bfloat16 for HF checkpoint ```bash -CKPT_PATH=/tmp/models/gemma_keras/keras/gemma_2b_en/ -UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_2b_en_tensorrt_llm/bf16/tp1/ -ENGINE_PATH=/tmp/gemma/2B/bf16/1-gpu/ -VOCAB_FILE_PATH=/tmp/models/gemma_nv/checkpoints/tmp_vocab.model +CKPT_PATH=/tmp/models/hf/gemma/gemma-2b/ +UNIFIED_CKPT_PATH=/tmp/ckpt/hf/gemma/2b/1-gpu/ +ENGINE_PATH=/tmp/engines/gemma/2B/bf16/1-gpu/ +VOCAB_FILE_PATH=/tmp/models/hf/gemma/gemma-2b/ -python3 ./convert_checkpoint.py \ - --ckpt-type keras \ +python3 ./examples/gemma/convert_checkpoint.py \ + --ckpt-type hf \ --model-dir ${CKPT_PATH} \ --dtype bfloat16 \ --world-size 1 \ @@ -162,19 +162,19 @@ trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \ --output_dir ${ENGINE_PATH} python3 ../summarize.py --test_trt_llm \ - --vocab_file ${VOCAB_FILE_PATH} \ + --tokenizer_dir ${VOCAB_FILE_PATH} \ --engine_dir ${ENGINE_PATH} \ --batch_size 8 \ --max_ite 5 -[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.96612286567688 sec) -[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2510) -[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 632.8598697034137) -[02/08/2024-05:04:13] [TRT-LLM] [I] TensorRT-LLM beam 0 result -[02/08/2024-05:04:13] [TRT-LLM] [I] rouge1 : 20.40970022875146 -[02/08/2024-05:04:13] [TRT-LLM] [I] rouge2 : 5.512437888775742 -[02/08/2024-05:04:13] [TRT-LLM] [I] rougeL : 15.135998543979978 -[02/08/2024-05:04:13] [TRT-LLM] [I] rougeLsum : 17.250431908889873 +[03/05/2024-02:24:39] [TRT-LLM] [I] TensorRT-LLM (total latency: 3.0897433757781982 sec) +[03/05/2024-02:24:39] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 2141) +[03/05/2024-02:24:39] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 692.9378073221881) +[03/05/2024-02:24:39] [TRT-LLM] [I] TensorRT-LLM beam 0 result +[03/05/2024-02:24:39] [TRT-LLM] [I] rouge1 : 21.042873132085678 +[03/05/2024-02:24:39] [TRT-LLM] [I] rouge2 : 6.322669223228836 +[03/05/2024-02:24:39] [TRT-LLM] [I] rougeL : 16.450116567540338 +[03/05/2024-02:24:39] [TRT-LLM] [I] rougeLsum : 18.836567173262736 ``` #### Run inference under FP8 for keras checkpoint diff --git a/examples/gemma/convert_checkpoint.py b/examples/gemma/convert_checkpoint.py index 57c892a14..cfe077aba 100644 --- a/examples/gemma/convert_checkpoint.py +++ b/examples/gemma/convert_checkpoint.py @@ -19,6 +19,7 @@ import utils.transformer from datasets import load_dataset from easydict import EasyDict +from transformers import AutoConfig, AutoModelForCausalLM import tensorrt_llm from tensorrt_llm._utils import torch_to_numpy @@ -34,7 +35,7 @@ def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-type", type=str, - choices=["jax", "keras", "torch"]) + choices=["jax", "keras", "torch", "hf"]) parser.add_argument("--model-dir", type=pathlib.Path, required=True) parser.add_argument("--output-model-dir", type=pathlib.Path, required=True) parser.add_argument("--world-size", @@ -307,7 +308,85 @@ def flatten_params(self, params): return f_params -CKPT_PARSER = {'jax': JAXParser, 'keras': KerasParser, 'torch': TorchParser} +class HfParser: + + def load_parameters(self, checkpoint_path: pathlib.Path): + hf_model = 
AutoModelForCausalLM.from_pretrained( + checkpoint_path, + device_map='auto', + torch_dtype='auto', + trust_remote_code=True, + ) + model_params = dict(hf_model.named_parameters()) + return model_params + + def embedding_weights(self, ckpt_params): + return ckpt_params['model.embed_tokens.weight'] + + def get_config(self, checkpoint_path, ckpt_params, num_embed): + hf_config = AutoConfig.from_pretrained( + checkpoint_path, trust_remote_code=True).to_dict() + config_new = {} + config_new["num_layers"] = hf_config["num_hidden_layers"] + config_new["num_embed"] = hf_config["vocab_size"] + config_new["embed_dim"] = hf_config["hidden_size"] + config_new["hidden_dim"] = hf_config["intermediate_size"] + config_new["num_heads"] = hf_config["num_attention_heads"] + config_new["head_dim"] = hf_config["head_dim"] + config_new["num_kv_heads"] = hf_config["num_key_value_heads"] + return EasyDict(config_new) + + def rename_to_trt_llm(self, name: str): + """Rename a gemma parameter name by the corresponding TRT-LLM style name.""" + prefix = "transformer" + sub_patterns = ( + (r"model.embed_tokens.weight", r"vocab_embedding.weight"), + (r"model.layers.(\d+).input_layernorm.weight", + r"layers.\1.input_layernorm.weight"), + (r"model.layers.(\d+).self_attn.q_proj.weight", + r"layers.\1.attention.qkv.weight"), + (r"model.layers.(\d+).self_attn.k_proj.weight", + None), # merged with above + (r"model.layers.(\d+).self_attn.v_proj.weight", + None), # merged with above + (r"model.layers.(\d+).self_attn.o_proj.weight", + r"layers.\1.attention.dense.weight"), + (r"model.layers.(\d+).mlp.gate_proj.weight", + r"layers.\1.mlp.fc.weight"), + (r"model.layers.(\d+).mlp.up_proj.weight", + None), # merged with above + (r"model.layers.(\d+).mlp.down_proj.weight", + r"layers.\1.mlp.proj.weight"), + (r"model.layers.(\d+).post_attention_layernorm.weight", + r"layers.\1.post_layernorm.weight"), + (r"model.norm.weight", r"ln_f.weight"), + ) + + for source, target in sub_patterns: + if re.match(source, name): + if target is None: + return target + else: + name = re.sub(source, target, name) + return ".".join((prefix, name)) + else: + raise ValueError(f"Don't know how to rename {prefix}.{name}") + + def flatten_params(self, params): + f_params = {} + for k, v in params.items(): + if v.dtype == torch.bfloat16: + v = v.float() + f_params[k] = torch_to_numpy(v) + return f_params + + +CKPT_PARSER = { + 'jax': JAXParser, + 'keras': KerasParser, + 'torch': TorchParser, + 'hf': HfParser +} def split(v, tp_size, idx, dim=0): @@ -556,6 +635,51 @@ def convert_from_checkpoint( else: add_trt_llm_weight(weights, trt_llm_name, qkv_param, trt_llm_config.dtype) + elif "q_proj" in name: + gqa_mode = trt_llm_config.num_attention_heads != trt_llm_config.num_key_value_heads + + if gqa_mode: + # initial shape: (num_heads * head_dim, hidden_size) + q_param = param + q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0) + + k_name = name.replace("q_proj", "k_proj") + k_param = model_params[k_name] + + v_name = name.replace("q_proj", "v_proj") + v_param = model_params[v_name] + else: + # initial shape: (num_heads * head_dim, hidden_size) + q_param = param + q_param = split_matrix_tp(q_param, tp_size, tp_rank, dim=0) + + k_name = name.replace("q_proj", "k_proj") + k_param = model_params[k_name] + k_param = split_matrix_tp(k_param, tp_size, tp_rank, dim=0) + + v_name = name.replace("q_proj", "v_proj") + v_param = model_params[v_name] + v_param = split_matrix_tp(v_param, tp_size, tp_rank, dim=0) + + qkv_param = np.concatenate([q_param, k_param, 
v_param], axis=0) + qkv_param = qkv_param.reshape(qkv_param.shape[0], -1) + + # If int8 kv enabled, weight-only quantization will be done later. + if trt_llm_config.quant_mode.is_weight_only() and not trt_llm_config.quant_mode.has_per_group_scaling() and \ + not trt_llm_config.quant_mode.has_int8_kv_cache(): + qkv_param_quantized, qkv_param_scales = quantize( + qkv_param, trt_llm_config.quant_mode) + add_trt_llm_weight(weights, trt_llm_name, + qkv_param_quantized) + add_trt_llm_weight( + weights, + trt_llm_name.replace(".weight", ".per_channel_scale"), + qkv_param_scales, + trt_llm_config.dtype, + ) + else: + add_trt_llm_weight(weights, trt_llm_name, qkv_param, + trt_llm_config.dtype) elif "attention.dense.weight" in trt_llm_name: # initial shape: (num_heads, head_dim, hidden_size) if len(param.shape) == 3: @@ -588,6 +712,12 @@ def convert_from_checkpoint( "mlp.gate_proj", "mlp.up_proj")] fc_param = fc_param.transpose(1, 0) gate_param = gate_param.transpose(1, 0) + elif isinstance(ckpt_parser, HfParser): + # initial shape: (intermediate_size, hidden_size) + fc_param, gate_param = param, model_params[name.replace( + "mlp.gate_proj", "mlp.up_proj")] + fc_param = fc_param.transpose(1, 0) + gate_param = gate_param.transpose(1, 0) else: # initial shape: (2, hidden_size, intermediate_size) fc_param, gate_param = param[0], param[1] @@ -632,7 +762,8 @@ def convert_from_checkpoint( add_trt_llm_weight(weights, trt_llm_name, gate_param, trt_llm_config.dtype) elif "mlp.proj.weight" in trt_llm_name: - if not isinstance(ckpt_parser, TorchParser): + if not isinstance(ckpt_parser, TorchParser) and not isinstance( + ckpt_parser, HfParser): # initial shape: (intermediate_size, hidden_size) param = param.transpose(1, 0) param = split_matrix_tp(param, tp_size, tp_rank, dim=1) @@ -650,7 +781,8 @@ def convert_from_checkpoint( else: add_trt_llm_weight(weights, trt_llm_name, param, trt_llm_config.dtype) - elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name: + elif "embedder.input_embedding" in name or "reversible_embedding" in name or "embedder.weight" in name \ + or "embed_tokens.weight" in name: if not trt_llm_config.share_embedding_table: # TODO: safetensor doesn't allow to save a shared tensor. 
# Currently, we clone the weight but to save the disk, it @@ -796,7 +928,6 @@ def main(): quant_algo = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN' elif args.per_token and not args.per_channel: quant_algo = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN' - quant_kwargs.update(sq_use_plugin=True) quant_kwargs.update(quant_algo=quant_algo, kv_cache_quant_algo=kv_cache_quant_algo) @@ -806,6 +937,16 @@ def main(): pre_quant_scale=True, exclude_modules=["lm_head"]) + quant_config = tensorrt_llm.models.modeling_utils.QuantizationConfig() + quant_config.quant_algo = quant_kwargs['quant_algo'] + quant_config.kv_cache_quant_algo = quant_kwargs['kv_cache_quant_algo'] + if args.use_weight_only_with_precision and args.use_weight_only_with_precision.endswith( + "awq"): + quant_config.group_size = 128 + quant_config.has_zero_point = quant_kwargs['has_zero_point'] + quant_config.pre_quant_scale = quant_kwargs['pre_quant_scale'] + quant_config.exclude_modules = quant_kwargs['exclude_modules'] + trt_llm_config = tensorrt_llm.models.modeling_utils.PretrainedConfig( architecture="GemmaForCausalLM", dtype=args.dtype or ckpt_params_dtype, @@ -824,7 +965,7 @@ def main(): world_size=args.world_size, tp_size=args.world_size, pp_size=1, - quantization=quant_kwargs, + quantization=quant_config, ) trt_llm_config_dict = trt_llm_config.to_dict() diff --git a/examples/gemma/requirements.txt b/examples/gemma/requirements.txt index 2a670e21f..3661de075 100644 --- a/examples/gemma/requirements.txt +++ b/examples/gemma/requirements.txt @@ -1,4 +1,6 @@ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 flax~=0.8.0 jax[cuda12_pip]~=0.4.19; platform_system != "Windows" jax~=0.4.19; platform_system == "Windows" diff --git a/examples/gpt/nemo_lora_convert.py b/examples/gpt/nemo_lora_convert.py index 0b2989c6a..67fd350fb 100644 --- a/examples/gpt/nemo_lora_convert.py +++ b/examples/gpt/nemo_lora_convert.py @@ -26,8 +26,7 @@ from utils.nemo import unpack_nemo_ckpt from tensorrt_llm._utils import str_dtype_to_torch, to_json_file, torch_to_numpy -from tensorrt_llm.runtime.lora_manager import (LoraConfig, - get_all_nemo_lora_weights) +from tensorrt_llm.lora_manager import LoraConfig, get_all_nemo_lora_weights log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" logging.basicConfig(format=log_format) diff --git a/examples/gpt/requirements.txt b/examples/gpt/requirements.txt index 02895a917..3d0c299a3 100644 --- a/examples/gpt/requirements.txt +++ b/examples/gpt/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/gptj/convert_checkpoint.py b/examples/gptj/convert_checkpoint.py index a2ce1dbd7..251637689 100644 --- a/examples/gptj/convert_checkpoint.py +++ b/examples/gptj/convert_checkpoint.py @@ -331,14 +331,16 @@ def main(): if args.model_dir is None: return + hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, + trust_remote_code=True, + torch_dtype="auto") + def covert_and_save(rank): mapping = Mapping(world_size=world_size, rank=rank, tp_size=args.tp_size, pp_size=args.pp_size) - hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, - trust_remote_code=True, - torch_dtype="auto") + weights = convert_hf_gptj( hf_model, hf_config, @@ -346,7 +348,6 @@ def covert_and_save(rank): dtype=args.dtype, use_weight_only=args.use_weight_only, 
plugin_weight_only_quant_type=plugin_weight_only_quant_type) - del hf_model safetensors.torch.save_file( weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) @@ -370,6 +371,7 @@ def covert_and_save(rank): exceptions ) == 0, "Checkpoint conversion failed, please check error log." + del hf_model tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') diff --git a/examples/gptneox/requirements.txt b/examples/gptneox/requirements.txt index fa85d0a33..648a50ba9 100644 --- a/examples/gptneox/requirements.txt +++ b/examples/gptneox/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 rouge_score~=0.1.2 evaluate~=0.4.1 diff --git a/examples/hf_lora_convert.py b/examples/hf_lora_convert.py index 0a3415323..efd566bc7 100644 --- a/examples/hf_lora_convert.py +++ b/examples/hf_lora_convert.py @@ -27,7 +27,7 @@ import torch from tensorrt_llm._utils import str_dtype_to_torch -from tensorrt_llm.runtime.lora_manager import LoraConfig +from tensorrt_llm.lora_manager import LoraConfig log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s" logging.basicConfig(format=log_format) diff --git a/examples/high-level-api/README.md b/examples/high-level-api/README.md index ddbf78dd6..b4a1c227b 100644 --- a/examples/high-level-api/README.md +++ b/examples/high-level-api/README.md @@ -104,6 +104,20 @@ config.quant_config.quantize_lm_head = True llm = LLM(config) ``` +## Auto parallel + +By simply enabling `parallel_config.auto_parallel` in the ModelConfig, TensorRT-LLM can parallelize the model automatically. For example, setting `parallel_config.world_size` to perform a 2-way parallelism: + +``` python +from tensorrt_llm import LLM, ModelConfig + +config = ModelConfig(model_dir=) +config.parallel_config.auto_parallel = True +config.parallel_config.world_size = 2 + +llm = LLM(config) +``` + ## Asynchronous generation With the high-level API, you can also perform asynchronous generation with the `generate_async` method. For example: @@ -118,6 +132,29 @@ async for output in llm.generate_async(, streaming=True): When the `streaming` flag is set to `True`, the `generate_async` method will return a generator that yields the token results as soon as they are available. Otherwise, it will return a generator that yields the final results only. +## Future-like generation result +The result of the `generate_async` methods is a Future-like object, it doesn't block the thread unless the `.result()` is called. + +```python +# This will not block the main thread +generation = llm.generate_async() +# Do something else here +# call .result() to explicitly block the main thread and wait for the result when needed +output = generation.result() +``` + +The `.result()` method works like the [result](https://docs.python.org/zh-cn/3/library/asyncio-future.html#asyncio.Future.result) method in the Python Future, you can specify a timeout to wait for the result. + +```python +output = generation.result(timeout=10) +``` + +There is an async version, where the `.aresult()` is used. + +```python +generation = llm.generate_async() +output = await generation.aresult() +``` ## Customization @@ -129,10 +166,10 @@ llm = LLM(config, tokenizer=) The LLM() workflow should use your tokenizer instead. 
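
To make the Future-like behaviour described in the high-level-API README hunk above concrete, here is a minimal sketch (not part of this patch) that combines `generate_async` with a blocking `.result(timeout=...)` call. The model directory is a placeholder, and the sketch only assumes the interface shown in the README (`ModelConfig(model_dir)`, `generate_async` on a single prompt, `output.text`).

```python
# Minimal sketch (not part of this patch) of the Future-like result pattern
# described above. The model directory below is a placeholder.
from tensorrt_llm import LLM, ModelConfig

config = ModelConfig("/path/to/llama")  # hypothetical model directory
llm = LLM(config)

generation = llm.generate_async("What is LLM?")  # returns immediately, does not block
# ... the main thread is free to do other work here ...
output = generation.result(timeout=10)  # block for at most 10 seconds
print(output.text)
```
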
-It is also possible to input token IDs directly without Tokenizers with the following code: +It is also possible to input token IDs directly without Tokenizers with the following code, note that the result will be also IDs without text since the tokenizer is not used. ``` python -llm = LLM(config, enable_tokenizer=False) +llm = LLM(config) for output in llm.generate([32, 12]): ... ``` diff --git a/examples/high-level-api/llm_examples.py b/examples/high-level-api/llm_examples.py index 9a9ee2673..0e7da7b49 100644 --- a/examples/high-level-api/llm_examples.py +++ b/examples/high-level-api/llm_examples.py @@ -8,7 +8,7 @@ import torch from tensorrt_llm import LLM, ModelConfig -from tensorrt_llm.hlapi.llm import SamplingConfig +from tensorrt_llm.hlapi.llm import KvCacheConfig, SamplingConfig from tensorrt_llm.hlapi.utils import get_device_count # NOTE, Currently, the following examples are only available for LLaMA models. @@ -86,7 +86,8 @@ def run_llm_generate_async_example(prompts: List[str], config = ModelConfig(llama_model_dir) config.parallel_config.tp_size = tp_size - llm = LLM(config, kvcache_free_gpu_memory_fraction=0.4) + llm = LLM(config, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) async def task(prompt: str): outputs = [] @@ -135,6 +136,84 @@ def run_llm_with_quantization(prompts: List[str], print(output) +def run_llm_with_async_future(prompts: List[str], llama_model_dir: str): + config = ModelConfig(llama_model_dir) + llm = LLM(config, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4)) + + # The result of generate() is similar to a Future, it won't block the main thread, call .result() to explicitly wait for the result + for generation in llm.generate_async(prompts): + # .result() is a blocking call, call it when you want to wait for the result + output = generation.result() + print(output.text) + + # Similar to .result(), there is an async version of .result(), which is .aresult(), and it works with the generate_async(). + async def task(prompt: str): + generation = llm.generate_async(prompt, streaming=False) + output = await generation.aresult() + print(output.text) + + async def main(): + tasks = [task(prompt) for prompt in prompts] + await asyncio.gather(*tasks) + + asyncio.run(main()) + + +def run_llm_with_auto_parallel(prompts: List[str], + llama_model_dir: str, + world_size: int = 1): + ''' Running LLM with auto parallel enabled. ''' + if get_device_count() < world_size: + print( + "Skip the example for auto parallel!!! Since the number of GPUs is less than required" + ) + return + if world_size > 1: + print(f'Running LLM with Auto Parallel on {world_size} GPUs.') + + config = ModelConfig(llama_model_dir) + config.parallel_config.auto_parallel = True + config.parallel_config.world_size = world_size + + llm = LLM(config) + + for output in llm.generate(prompts): + print(output) + + +def run_llm_with_auto_parallel_async(prompts: List[str], + llama_model_dir: str, + world_size: int = 1, + streaming: bool = False): + ''' Running LLM asynchronously with auto parallel enabled. ''' + if get_device_count() < world_size: + print( + "Skip the example for auto parallel!!! 
Since the number of GPUs is less than required" + ) + return + if world_size > 1: + print(f'Running LLM with Auto Parallel on {world_size} GPUs.') + + config = ModelConfig(llama_model_dir) + config.parallel_config.auto_parallel = True + config.parallel_config.world_size = world_size + + llm = LLM(config) + + async def task(prompt: str): + outputs = [] + async for output in llm.generate_async(prompt, streaming=streaming): + outputs.append(output.text) + print(' '.join(outputs)) + + async def main(): + tasks = [task(prompt) for prompt in prompts] + await asyncio.gather(*tasks) + + asyncio.run(main()) + + def _parse_arguments(): parser = ArgumentParser() parser.add_argument('--task', type=str, choices=_get_functions()) @@ -147,6 +226,7 @@ def _parse_arguments(): default=None) parser.add_argument('--quant_type', type=str, choices=['int4_awq', 'fp8']) parser.add_argument('--prompt', type=str, default="What is LLM?") + parser.add_argument('--world_size', type=int, default=1) parser.add_argument('--tp_size', type=int, default=1) parser.add_argument('--streaming', action='store_true') return parser.parse_args() @@ -182,9 +262,17 @@ def _get_functions(): streaming=args.streaming), run_llm_with_quantization=lambda: run_llm_with_quantization( [args.prompt], args.hf_model_dir, args.quant_type), + run_llm_with_auto_parallel=lambda: run_llm_with_auto_parallel( + [args.prompt], args.hf_model_dir, args.world_size), + run_llm_with_auto_parallel_async=lambda: + run_llm_with_auto_parallel_async([args.prompt], + args.hf_model_dir, + args.world_size, + streaming=args.streaming), run_llm_without_tokenizer_from_tllm_engine=lambda: run_llm_without_tokenizer_from_tllm_engine(args.dump_engine_dir), - ) + run_llm_with_async_future=lambda: run_llm_with_async_future( + [args.prompt], args.hf_model_dir)) print(f'Running {args.task} ...') diff --git a/examples/high-level-api/run_auto_parallel_examples.sh b/examples/high-level-api/run_auto_parallel_examples.sh new file mode 100644 index 000000000..46bbc24c3 --- /dev/null +++ b/examples/high-level-api/run_auto_parallel_examples.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -ex + +PROMPT="Tell a story" +LLAMA_MODEL_DIR=$1 +WORLD_SIZE=${2:-2} + +dir=$(dirname "$0") + +python3 $dir/llm_examples.py --task run_llm_with_auto_parallel \ + --prompt="$PROMPT" \ + --world_size=$WORLD_SIZE \ + --hf_model_dir=$LLAMA_MODEL_DIR + +python3 $dir/llm_examples.py --task run_llm_with_auto_parallel_async \ + --prompt="$PROMPT" \ + --world_size=$WORLD_SIZE \ + --hf_model_dir=$LLAMA_MODEL_DIR diff --git a/examples/high-level-api/run_examples.sh b/examples/high-level-api/run_examples.sh old mode 100644 new mode 100755 index 6b8153621..f7ef57a10 --- a/examples/high-level-api/run_examples.sh +++ b/examples/high-level-api/run_examples.sh @@ -32,3 +32,7 @@ python3 llm_examples.py --task run_llm_generate_async_example \ --hf_model_dir=$LLAMA_MODEL_DIR \ --streaming \ --tp_size=2 + +python3 llm_examples.py --task run_llm_with_async_future \ + --prompt="$PROMPT" \ + --hf_model_dir=$LLAMA_MODEL_DIR diff --git a/examples/high-level-api/run_quant_examples.sh b/examples/high-level-api/run_quant_examples.sh old mode 100644 new mode 100755 diff --git a/examples/internlm/README.md b/examples/internlm/README.md index e7f373280..e6fcde2e4 100644 --- a/examples/internlm/README.md +++ b/examples/internlm/README.md @@ -4,8 +4,11 @@ This document shows how to build and run InternLM 7B / 20B models in TensorRT-LL ## Overview -The TensorRT-LLM InternLM implementation can be found in 
[tensorrt_llm/models/internlm/model.py](../../tensorrt_llm/models/internlm/model.py). The TensorRT-LLM InternLM example code is located in [`examples/internlm`](./). There is one main file: +The TensorRT-LLM InternLM implementation is based on the LLaMA model. The implementation can +be found in [tensorrt_llm/models/llama/model.py](../../tensorrt_llm/models/llama/model.py). +The TensorRT-LLM InternLM example code lies in [`examples/llama`](./): +* [`convert_checkpoint.py`](../llama/convert_checkpoint.py) converts the Huggingface Model of Skywork into TensorRT-LLM checkpoint. * [`convert_checkpoint.py`] to to convert a checkpoint from the [HuggingFace (HF) Transformers](https://github.com/huggingface/transformers) format to the TensorRT-LLM format In addition, there are two shared files in the parent folder [`examples`](../) for inference and evaluation: @@ -40,6 +43,7 @@ Here're some examples: # Build a single-GPU float16 engine from HF weights. # gpt_attention_plugin is necessary in InternLM. # Try use_gemm_plugin to prevent accuracy issue. +cd examples/llama # Convert the InternLM 7B model using a single GPU and FP16. python convert_checkpoint.py --model_dir ./internlm-chat-7b/ \ @@ -97,6 +101,8 @@ and then export the scaling factors needed for INT8 KV cache inference. Example: ```bash +cd examples/llama + # For 7B models python convert_checkpoint.py --model_dir ./internlm-chat-7b \ --output_dir ./internlm-chat-7b/smooth_internlm/int8_kv_cache/ \ @@ -113,6 +119,8 @@ trtllm-build --checkpoint_dir ./internlm-chat-7b/smooth_internlm/int8_kv_cache/ ```bash +cd examples/llama + # For 20B models python convert_checkpoint.py --model_dir ./internlm-chat-20b \ --output_dir ./internlm-chat-20b/smooth_internlm/int8_kv_cache/ \ @@ -158,6 +166,8 @@ Unlike the FP16 build where the HF weights are processed and loaded into the Ten Example: ```bash +cd examples/llama + # For 7B models python convert_checkpoint.py --model_dir ./internlm-chat-7b --output_dir ./internlm-chat-7b/smooth_internlm/sq0.5/ --dtype float16 --smoothquant 0.5 # Build the engine @@ -166,6 +176,8 @@ trtllm-build --checkpoint_dir ./internlm-chat-7b/smooth_internlm/sq0.5/ \ --gemm_plugin float16 # For 20B models +cd examples/llama + python convert_checkpoint.py --model_dir ./internlm-chat-20b --output_dir ./internlm-chat-20b/smooth_internlm/sq0.5/ --dtype float16 --smoothquant 0.5 trtllm-build --checkpoint_dir ./internlm-chat-20b/smooth_internlm/sq0.5/ \ --output_dir ./engine_outputs \ @@ -183,6 +195,8 @@ Examples of build invocations: ```bash # Build model for SmoothQuant in the _per_token_ + _per_channel_ mode +cd examples/llama + # 7B model python convert_checkpoint.py --model_dir ./internlm-chat-7b --output_dir ./internlm-chat-7b/smooth_internlm/sq0.5/ --dtype float16 --smoothquant 0.5 --per_channel --per_token diff --git a/examples/internlm/convert_checkpoint.py b/examples/internlm/convert_checkpoint.py deleted file mode 100644 index 87892d3d8..000000000 --- a/examples/internlm/convert_checkpoint.py +++ /dev/null @@ -1,1497 +0,0 @@ -import argparse -import copy -import functools -import json -import os -import time -import traceback -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path - -import numpy as np -import safetensors -import torch -import torch.nn as nn -from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer -from transformers.pytorch_utils import Conv1D - -import 
tensorrt_llm -from tensorrt_llm.layers import MoeConfig -from tensorrt_llm.logger import logger -from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.llama.weight import (load_from_gptq_llama, - load_from_hf_checkpoint, - load_from_meta_llama) -from tensorrt_llm.models.modeling_utils import PretrainedConfig -from tensorrt_llm.runtime.lora_manager import LoraConfig - - -def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument('--meta_ckpt_dir', type=str, default=None) - parser.add_argument('--tp_size', - type=int, - default=1, - help='N-way tensor parallelism size') - parser.add_argument('--pp_size', - type=int, - default=1, - help='N-way pipeline parallelism size') - parser.add_argument('--dtype', - type=str, - default='float16', - choices=['float32', 'bfloat16', 'float16']) - parser.add_argument('--vocab_size', type=int, default=32000) - parser.add_argument('--n_positions', type=int, default=2048) - parser.add_argument('--n_layer', type=int, default=32) - - parser.add_argument( - '--use_weight_only', - default=False, - action="store_true", - help='Quantize weights for the various GEMMs to INT4/INT8.' - 'See --weight_only_precision to set the precision') - parser.add_argument( - '--weight_only_precision', - const='int8', - type=str, - nargs='?', - default='int8', - choices=['int8', 'int4', 'int4_gptq'], - help= - 'Define the precision for the weights when using weight-only quantization.' - 'You must also use --use_weight_only for that argument to have an impact.' - ) - parser.add_argument( - "--smoothquant", - "-sq", - type=float, - default=None, - help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)" - " to Smoothquant the model, and output int8 weights." - " A good first try is 0.5. Must be in [0, 1]") - parser.add_argument( - '--per_channel', - action="store_true", - default=False, - help= - 'By default, we use a single static scaling factor for the GEMM\'s result. ' - 'per_channel instead uses a different static scaling factor for each channel. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--per_token', - action="store_true", - default=False, - help= - 'By default, we use a single static scaling factor to scale activations in the int8 range. ' - 'per_token chooses at run time, and for each token, a custom scaling factor. ' - 'The latter is usually more accurate, but a little slower.') - parser.add_argument( - '--int8_kv_cache', - default=False, - action="store_true", - help= - 'By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV' - ) - parser.add_argument( - '--ammo_quant_ckpt_path', - type=str, - default=None, - help='Path of a quantized model checkpoint in .npz format') - - parser.add_argument( - '--per_group', - default=False, - action="store_true", - help= - 'By default, we use a single static scaling factor to scale weights in the int4 range. ' - 'per_group chooses at run time, and for each group, a custom scaling factor. 
' - 'The flag is built for GPTQ/AWQ quantization.') - - parser.add_argument('--load_by_shard', - action='store_true', - help='Load a pretrained model shard-by-shard.') - parser.add_argument('--hidden_act', type=str, default='silu') - - parser.add_argument('--rotary_base', type=float, default=10000.0) - parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None) - - parser.add_argument('--group_size', - type=int, - default=128, - help='Group size used in GPTQ/AWQ quantization.') - - parser.add_argument("--storage-type", - "-t", - type=str, - default="fp32", - choices=["fp32", "fp16"]) - parser.add_argument("--dataset-cache-dir", - type=str, - default=None, - help="cache dir to load the hugging face dataset") - parser.add_argument("--load-model-on-cpu", action="store_true") - parser.add_argument("--convert-model-on-cpu", action="store_true") - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=0, - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) - parser.add_argument( - '--use_embedding_sharing', - action="store_true", - default=False, - help= - 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' - 'Note: the flag might not take effect when the criteria are not met.') - parser.add_argument('--use_prompt_tuning', - action="store_true", - default=False) - parser.add_argument('--output_dir', - type=str, - default='tllm_checkpoint', - help='The path to save the TensorRT-LLM checkpoint') - parser.add_argument( - '--workers', - type=int, - default=1, - help='The number of workers for converting checkpoint in parallel') - parser.add_argument( - '--moe_num_experts', - default=0, - type=int, - help='Specify the number of experts to use for MOE layers') - parser.add_argument( - '--moe_top_k', - default=0, - type=int, - help= - 'Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set' - ) - parser.add_argument( - '--moe_tp_mode', - default=MoeConfig.ParallelismMode.TENSOR_PARALLEL, - type=int, - help= - 'Controls how to distribute experts in TP. Check layers/moe.py for accepted values', - ) - parser.add_argument( - '--moe_renorm_mode', - default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE, - type=int, - help= - 'Controls renormalization after gate logits. Check layers/moe.py for accepted values', - ) - parser.add_argument('--enable_pos_shift', - default=False, - action='store_true', - help='Enable position shift for streamingllm method') - parser.add_argument( - '--dense_context_fmha', - default=False, - action='store_true', - help= - 'Enable dense fmha in context phase, otherwise sliding window attention.' - 'If dense_context_fmha=False, the sliding window size is the max attention window size.' - ) - parser.add_argument('--hf_lora_dir', type=str, default=None) - parser.add_argument( - '--lora_target_modules', - nargs='+', - default=None, - choices=[ - "attn_qkv", - "attn_q", - "attn_k", - "attn_v", - "attn_dense", - "mlp_h_to_4h", - "mlp_gate", - "mlp_4h_to_h", - ], - help= - "Add lora in which modules. Only be activated when use_lora_plugin is enabled." 
- ) - parser.add_argument( - '--max_lora_rank', - type=int, - default=64, - help='maximum lora rank for different lora modules. ' - 'It is used to compute the workspace size of lora plugin.') - args = parser.parse_args() - return args - - -def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False): - """ - This function has two purposes: - - compute quantized weights, scaled either per-tensor or per-column - - compute scaling factors - - Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ. - CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W. - CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor. - - Here is the list of what we need (T means per-tensor, C per-column): - - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T) - - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T) - - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C) - - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32) - to quant range (int8) (used for CUBLAS) (T, C) - - Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too, - but then the model would change depending on the number of GPUs used. - - For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it - as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V. - For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns. - """ - weights = weights.detach().cpu().numpy() - - # compute weight scaling factors for fp->int8 and int8->fp - if is_qkv and not multi_query_mode: - scale_w_orig_quant_t = 127. / act_range["w"].reshape(3, -1).max( - dim=-1, keepdims=True)[0].cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].reshape(3, - -1).cpu().numpy() - elif is_qkv and multi_query_mode: - hidden_dim = weights.shape[0] - local_dim = act_range["w"].shape[0] - kv_dim = (local_dim - hidden_dim) // 2 - scale_w_q = act_range["w"][0:hidden_dim] - scale_w_k = act_range["w"][hidden_dim:hidden_dim + kv_dim] - scale_w_v = act_range["w"][-kv_dim:] - - scale_w_qkv_t = torch.concat([ - scale_w_q.max(dim=0, keepdim=True)[0], - scale_w_k.max(dim=0, keepdim=True)[0], - scale_w_v.max(dim=0, keepdim=True)[0] - ]) - - scale_w_orig_quant_t = 127. / scale_w_qkv_t.cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() - else: - scale_w_orig_quant_t = 127. / act_range["w"].max().cpu().numpy() - scale_w_orig_quant_c = 127. / act_range["w"].cpu().numpy() - scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t - scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c - - scale_w_orig_quant_c = scale_w_orig_quant_c.astype(np.float32) - scale_w_orig_quant_t = scale_w_orig_quant_t.astype(np.float32) - - # compute the rest of needed scaling factors - scale_x_orig_quant_t = np.array(127. / act_range["x"].max().item()) - scale_y_orig_quant_t = np.array(127. / act_range["y"].max().item()) - scale_y_quant_orig_t = np.array(act_range["y"].max().item() / 127.) 
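The scaling-factor bookkeeping described in the docstring above is easier to follow with concrete numbers. Below is a minimal NumPy sketch for a single non-QKV GEMM; the variable names mirror `generate_int8`, but the calibration values are invented for illustration.

```python
import numpy as np

# Hypothetical calibration results: a weight matrix and activation amax values.
W = np.random.randn(64, 64).astype(np.float32)
amax_x, amax_y = 3.0, 40.0                      # |X| and |Y| ranges from calibration
amax_w_t = np.abs(W).max()                      # per-tensor weight range (T)
amax_w_c = np.abs(W).max(axis=0)                # per-column weight ranges (C)

# fp -> int8 ("orig_quant") and int8 -> fp ("quant_orig") scales.
scale_w_orig_quant_t = 127.0 / amax_w_t
scale_w_orig_quant_c = 127.0 / amax_w_c
scale_x_orig_quant_t = 127.0 / amax_x
scale_y_orig_quant_t = 127.0 / amax_y
scale_y_quant_orig_t = amax_y / 127.0

# Accumulator (int32) -> int8 scale used by the cuBLAS path: scale_y / (scale_x * scale_w).
scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * scale_w_orig_quant_t)

# Per-tensor int8 weights, rounded and clipped to the symmetric range.
weight_int8 = np.clip(np.round(W * scale_w_orig_quant_t), -127, 127).astype(np.int8)
```

The per-column variants repeat the same arithmetic with `amax_w_c`, which is what the CUTLASS path consumes according to the docstring.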
- scale_y_accum_quant_t = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_t) - scale_y_accum_quant_c = scale_y_orig_quant_t / (scale_x_orig_quant_t * - scale_w_orig_quant_c) - if is_qkv and not multi_query_mode: - scale_y_accum_quant_t = np.broadcast_to(scale_y_accum_quant_t, - scale_w_orig_quant_c.shape) - scale_w_quant_orig_t = np.broadcast_to(scale_w_quant_orig_t, - scale_w_orig_quant_c.shape) - if is_qkv and multi_query_mode: - scale_q_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[0], - scale_w_q.shape) - scale_k_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[1], - scale_w_k.shape) - scale_v_y_accum_t = np.broadcast_to(scale_y_accum_quant_t[2], - scale_w_v.shape) - scale_y_accum_quant_t = np.concatenate( - [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]) - scale_w_quant_orig_t = np.concatenate([ - np.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape), - np.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape), - np.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape) - ]) - - to_i8 = lambda x: x.round().clip(-127, 127).astype(np.int8) - - if is_qkv and multi_query_mode: - weight_int8 = to_i8(weights / scale_w_quant_orig_t) - else: - weight_int8 = to_i8(weights * scale_w_orig_quant_t) - return { - "weight.int8": weight_int8, - "weight.int8.col": to_i8(weights * scale_w_orig_quant_c), - "scale_x_orig_quant": scale_x_orig_quant_t.astype(np.float32), - "scale_w_quant_orig": scale_w_quant_orig_t.astype(np.float32), - "scale_w_quant_orig.col": scale_w_quant_orig_c.astype(np.float32), - "scale_y_accum_quant": scale_y_accum_quant_t.astype(np.float32), - "scale_y_accum_quant.col": scale_y_accum_quant_c.astype(np.float32), - "scale_y_quant_orig": scale_y_quant_orig_t.astype(np.float32), - } - - -@torch.no_grad() -def apply_smoothing(scales, - gemm_weights, - layernorm_weights=None, - layernorm_bias=None, - dtype=torch.float32, - layernorm_1p=False): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - - if layernorm_weights is not None: - assert layernorm_weights.numel() == scales.numel() - layernorm_weights.div_(scales).to(dtype) - if layernorm_bias is not None: - assert layernorm_bias.numel() == scales.numel() - layernorm_bias.div_(scales).to(dtype) - if layernorm_1p: - layernorm_weights += (1 / scales) - 1 - - for gemm in gemm_weights: - gemm.mul_(scales.view(1, -1)).to(dtype) - - -@torch.no_grad() -def smooth_gemm(gemm_weights, - act_scales, - layernorm_weights=None, - layernorm_bias=None, - alpha=0.5, - weight_scales=None): - if not isinstance(gemm_weights, list): - gemm_weights = [gemm_weights] - orig_dtype = gemm_weights[0].dtype - - for gemm in gemm_weights: - # gemm_weights are expected to be transposed - assert gemm.shape[1] == act_scales.numel() - - if weight_scales is None: - weight_scales = torch.cat( - [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], - dim=0) - weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) - scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5) - - apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, - orig_dtype) - - return scales - - -@torch.no_grad() -def smooth_gemm_fc1_gate(fc1_weights, - gate_weights, - act_scales, - layernorm_weights=None, - layernorm_bias=None, - alpha=0.5, - weight_scales=None): - gemm_weights = [] - if not isinstance(fc1_weights, list): - fc1_weights = [fc1_weights] - if not isinstance(gate_weights, list): - gate_weights = 
[gate_weights] - - for i in range(len(fc1_weights)): - gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0) - gemm_weights.append(gemm_weight) - - orig_dtype = gemm_weights[0].dtype - - for gemm in gemm_weights: - # gemm_weights are expected to be transposed - assert gemm.shape[1] == act_scales.numel() - - if weight_scales is None: - weight_scales = torch.cat( - [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], - dim=0) - weight_scales = weight_scales.max(dim=0)[0] - weight_scales.to(float).clamp(min=1e-5) - scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) / - weight_scales.pow(1 - alpha)).clamp(min=1e-5) - - apply_smoothing(scales, fc1_weights + gate_weights, layernorm_weights, - layernorm_bias, orig_dtype) - - return scales - - -@torch.no_grad() -def smooth_internlm_model(model, scales, alpha, internlm_qkv_para, - internlm_smoother): - # Smooth the activation and weights with smoother = $\diag{s}$ - for name, module in model.named_modules(): - if not module.__class__.__name__ == "InternLMDecoderLayer": - continue - # qkv_proj - layer_name_q = name + ".self_attn.q_proj" - layer_name_k = name + ".self_attn.k_proj" - layer_name_v = name + ".self_attn.v_proj" - layer_name_qkv = name + ".self_attn.qkv_proj" - - weight = torch.cat([ - module.self_attn.q_proj.weight, module.self_attn.k_proj.weight, - module.self_attn.v_proj.weight - ], - dim=0) - - smoother = smooth_gemm(weight, scales[layer_name_q]["x"], - module.input_layernorm.weight, None, alpha) - - scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother - scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0] - scales[layer_name_qkv]["y"] = torch.cat([ - scales[layer_name_q]["y"], scales[layer_name_k]["y"], - scales[layer_name_v]["y"] - ], - dim=0) - - # see transpose_weights function - internlm_qkv_para[layer_name_qkv] = weight.transpose(0, 1) - - # ================================================================= - layer_name = name + ".self_attn.o_proj" - smoother = smooth_gemm(module.self_attn.o_proj.weight, - scales[layer_name]["x"], None, None, alpha) - internlm_smoother[layer_name] = smoother.float() - - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max( - dim=1)[0] - - # ================================================================== - fc1_layer_name = name + ".mlp.gate_proj" - gate_layer_name = name + ".mlp.up_proj" - - smoother = smooth_gemm_fc1_gate(module.mlp.gate_proj.weight, - module.mlp.up_proj.weight, - scales[fc1_layer_name]["x"], - module.post_attention_layernorm.weight, - None, alpha) - - scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother - scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max( - dim=1)[0] - - scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother - scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max( - dim=1)[0] - - # ================================================================== - layer_name = name + ".mlp.down_proj" - smoother = smooth_gemm(module.mlp.down_proj.weight, - scales[layer_name]["x"], None, None, alpha) - internlm_smoother[layer_name] = smoother.float() - scales[layer_name]["x"] = scales[layer_name]["x"] / smoother - scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max( - dim=1)[0] - - -@torch.no_grad() -def capture_activation_range(model, - tokenizer, - dataset, - num_samples=512, - seq_len=512): - model.eval() - device = next(model.parameters()).device - act_scales = 
defaultdict(lambda: {"x": None, "y": None, "w": None}) - - tokenizer.pad_token = tokenizer.eos_token - - def stat_tensor(name, tensor, act_scales, key): - hidden_dim = tensor.shape[-1] - tensor = tensor.view(-1, hidden_dim).abs().detach() - comming_max = torch.max(tensor, dim=0)[0].float() - - if act_scales[name][key] is None: - act_scales[name][key] = comming_max - else: - act_scales[name][key] = torch.max(act_scales[name][key], - comming_max) - - def stat_input_hook(m, x, y, name): - if isinstance(x, tuple): - x = x[0] - stat_tensor(name, x, act_scales, "x") - stat_tensor(name, y, act_scales, "y") - - if act_scales[name]["w"] is None: - act_scales[name]["w"] = m.weight.abs().clip(1e-8, - None).max(dim=1)[0] - - hooks = [] - for name, m in model.named_modules(): - if isinstance(m, nn.Linear) or isinstance(m, Conv1D): - hooks.append( - m.register_forward_hook( - functools.partial(stat_input_hook, name=name))) - - for i in tqdm(range(num_samples), desc="calibrating model"): - datapoint = dataset['train'][i:i + 1] - line = copy.copy(datapoint['article']) - line[0] = line[0] + ' TL;DR: ' - line[0] = line[0].strip() - line[0] = line[0].replace(" n't", "n't") - input_ids = tokenizer(line, - return_tensors="pt", - max_length=seq_len, - padding=True, - truncation=True).input_ids.to(device) - model(input_ids) - for h in hooks: - h.remove() - return act_scales - - -def split(v, tp_size, idx, dim=0): - if tp_size == 1: - return v - if len(v.shape) == 1: - return torch.chunk(v, tp_size)[idx].contiguous() - else: - return torch.chunk(v, tp_size, dim=dim)[idx].contiguous() - - -def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank): - """ - Splits the QKV matrix according to tensor parallelism - """ - v = v.reshape(3, n_hidden, n_hidden) - split_v = split(v, tensor_parallel, rank, dim=1) - split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden) - return split_v.contiguous() - - -def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank): - """ - Splits the QKV bias according to tensor parallelism - """ - v = v.reshape(3, n_hidden) - split_v = split(v, tensor_parallel, rank, dim=1) - split_v = split_v.reshape(3 * (n_hidden // tensor_parallel)) - return split_v.contiguous() - - -def split_matrix_tp(v, tensor_parallel, rank, dim): - return split(v, tensor_parallel, rank, dim=dim) - - -def get_weight(config, prefix, dtype): - if config[prefix + '.weight'].dtype != dtype: - config[prefix + '.weight'].data = config[prefix + '.weight'].to(dtype) - return config[prefix + '.weight'] - - -def get_bias(config, prefix, dtype): - if config[prefix + '.bias'].dtype != dtype: - config[prefix + '.bias'].data = config[prefix + '.bias'].to(dtype) - return config[prefix + '.bias'] - - -def get_weight_and_bias(config, prefix, dtype): - return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype) - - -def get_tllm_linear_weight(weight, - prefix, - bias=None, - use_weight_only=False, - plugin_weight_only_quant_type=torch.int8, - postfix='weight'): - results = {} - if use_weight_only: - v = weight.t().contiguous() - processed_torch_weights, torch_weight_scales = \ - torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix( - v.cpu(), plugin_weight_only_quant_type) - results[prefix + postfix] = processed_torch_weights - results[prefix + 'per_channel_scale'] = torch_weight_scales - else: - results[prefix + postfix] = weight.contiguous() - - if bias is not None: - results[prefix + 'bias'] = bias - - return results - - -def dup_kv_weight(v, num_head, tp_size): - assert tp_size % 
num_head == 0 - reps = tp_size // num_head - head_size = v.shape[0] // num_head - v = v.reshape(num_head, head_size, - -1)[:, None, :, :].expand(num_head, reps, head_size, - v.shape[1]) - return v.reshape(num_head * reps * head_size, -1).clone().detach() - - -def get_tllm_linear_sq_weight(vals, - prefix, - shape, - tensor_parallel, - is_qkv=False, - per_token=False, - per_channel=False, - last_prefix=None, - bias=None, - smoother_value=None, - smoother_shape=None, - rank=0, - cat_dim=0, - multi_query_mode=False): - results = {} - - def multi_query_split(data, local_dim, head_size, tp_size, cur_rank): - q, k, v = np.split(data, [local_dim, local_dim + head_size], axis=-1) - q_split = np.split(q, tp_size, axis=-1) - k_split = np.split(k, tp_size, axis=-1) - v_split = np.split(v, tp_size, axis=-1) - return [ - np.concatenate((q_split[ii], k_split[ii], v_split[ii]), axis=-1) - for ii in range(tp_size) - ][cur_rank] - - col_shape = shape if (is_qkv or per_channel) else [1, 1] - - if per_token: - original_weights = vals["weight.int8.col"] - - local_dim = original_weights.shape[0] - head_size = (original_weights.shape[1] - local_dim) // 2 - if multi_query_mode: - cur_weights = multi_query_split(original_weights, local_dim, - head_size, tensor_parallel, rank) - else: - cur_weights = np.split(original_weights, - tensor_parallel, - axis=cat_dim)[rank] - if is_qkv: - hidden_dim = cur_weights.shape[0] - cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() - if smoother_value is None: - results[last_prefix] = torch.from_numpy( - np.array([1.0], dtype=np.float32)) - - if smoother_value is None: - if multi_query_mode: - cur_per_channel_value = multi_query_split( - vals["scale_w_quant_orig.col"], local_dim, head_size, - tensor_parallel, rank) - else: - cur_per_channel_value = np.split(vals["scale_w_quant_orig.col"], - tensor_parallel, - axis=cat_dim)[rank] - else: - cur_per_channel_value = vals["scale_w_quant_orig.col"] - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array(cur_per_channel_value, - dtype=np.float32).reshape(col_shape)).contiguous() - else: - original_weights = np.array(vals["weight.int8"]) - cur_weights = np.split(original_weights, tensor_parallel, - axis=cat_dim)[rank] - - if is_qkv: - hidden_dim = cur_weights.shape[0] - cur_weights = cur_weights.reshape(hidden_dim, -1) - results[prefix + - 'weight'] = torch.from_numpy(cur_weights).t().contiguous() - # 'weight'] = torch.from_numpy(cur_weights).t().contiguous() - - cur_per_channel_value = vals["scale_y_accum_quant"] - - results[prefix + 'per_channel_scale'] = torch.from_numpy( - np.array([cur_per_channel_value], - dtype=np.float32).reshape(col_shape)).contiguous() - - results[last_prefix] = torch.from_numpy( - np.array([vals['scale_x_orig_quant']], - dtype=np.float32)).contiguous() - - results[prefix + 'act_scale'] = torch.from_numpy( - np.array([[vals["scale_y_quant_orig"]]], - dtype=np.float32)).contiguous() - - if smoother_value is not None: - cur_smoother_value = np.split(smoother_value, - tensor_parallel, - axis=cat_dim)[rank] - results[prefix + 'smoother'] = cur_smoother_value.reshape( - smoother_shape).contiguous().to(torch.float32) - - if bias is not None: - results[prefix + 'bias'] = bias - - return results - - -class QkvWeightHelper: - """ A helper utility for loading QKV weights from sharded files. 
""" - - def __init__(self, config: PretrainedConfig): - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.tp_size = config.mapping.tp_size - self.tp_rank = config.mapping.tp_rank - self.is_mha = self.num_heads == self.num_kv_heads - self._qkv_weights = {} - - @staticmethod - def is_qkv_weight(name): - for k in ['q_proj', 'k_proj', 'v_proj']: - if 'self_attn' in name and k in name: - return True - return False - - def add_weight(self, i: int, name: str, weight: torch.Tensor): - if 'q_proj' in name: - tag = 'q' - elif 'k_proj' in name: - tag = 'k' - elif 'v_proj' in name: - tag = 'v' - else: - raise ValueError(f'Got an unexpected parameter of name {name}') - if i not in self._qkv_weights: - self._qkv_weights[i] = {} - self._qkv_weights[i][tag] = weight - - def is_qkv_prepared(self, layer_idx): - if layer_idx not in self._qkv_weights: - return False - weights = self._qkv_weights[layer_idx] - return 'q' in weights and 'k' in weights and 'v' in weights - - def split_qkv_weights(self, layer_idx): - if not self.is_qkv_prepared(layer_idx): - return None - weights = self._qkv_weights.pop(layer_idx) # to prevent memory leak. - q, k, v = (torch.tensor(weights[t]) for t in ['q', 'k', 'v']) - - if not self.is_mha: - head_size = self.hidden_size // self.num_heads - if self.num_kv_heads < self.tp_size: - # duplicate the KV heads up to tensor_parallel - k = dup_kv_weight(k, self.num_kv_heads, self.tp_size) - v = dup_kv_weight(v, self.num_kv_heads, self.tp_size) - assert k.shape[0] % (self.tp_size * head_size) == 0 - assert v.shape[0] % (self.tp_size * head_size) == 0 - wq = split(q, self.tp_size, self.tp_rank) - wk = split(k, self.tp_size, self.tp_rank) - wv = split(v, self.tp_size, self.tp_rank) - fused_qkv = torch.cat((wq, wk, wv), dim=0) - else: - qkv = torch.cat([q, k, v], dim=0) - qkv = qkv.reshape(3, q.shape[0], q.shape[1]) - fused_qkv = split(qkv, self.tp_size, self.tp_rank, dim=1) - fused_qkv = fused_qkv.reshape(3 * (q.shape[0] // self.tp_size), - q.shape[1]) - return fused_qkv - - -def convert_hf_internlm(hf_model, - mapping, - rank=0, - dtype='float32', - use_parallel_embedding=False, - sharding_dim=0, - use_weight_only=False, - share_embedding_table=False, - plugin_weight_only_quant_type=torch.int8, - use_smooth_quant=False, - per_channel=False, - per_token=False, - int8_kv_cache=False, - act_range=[], - qkv_para=[], - smoother=[], - moe_config=None, - lora_config=None): - - weights = {} - tik = time.time() - tensor_parallel = mapping.tp_size - model_params = dict(hf_model.named_parameters()) - dtype = getattr(torch, dtype) - num_attention_heads = hf_model.config.num_attention_heads - hidden_size = hf_model.config.hidden_size - intermediate_size = hf_model.config.intermediate_size - - num_key_value_heads = hf_model.config.num_attention_heads - mha_mode = (num_key_value_heads == num_attention_heads) - - num_hidden_layers = hf_model.config.num_hidden_layers - layers_range = mapping.pp_layers(num_hidden_layers) - - if moe_config and moe_config.has_moe(): - rank_experts = list(range(moe_config.num_experts)) - if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL: - rank_experts = mapping.ep_experts(moe_config.num_experts) - - for l in range(num_hidden_layers): - for suffix in ["w1", "w2", "w3"]: - model_params[f'model.layers.{l}.block_sparse_moe.experts.{suffix}.weight'] = \ - torch.stack(list(model_params[f'model.layers.{l}.block_sparse_moe.experts.{expert}.{suffix}.weight'] - for expert in 
rank_experts)) - w3 = model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w3.weight'] - w2 = model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w2.weight'] - w1 = model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w1.weight'] - if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL: - w3 = split(w3, mapping.tp_size, mapping.tp_rank, dim=1) - w2 = split(w2, mapping.tp_size, mapping.tp_rank, dim=2) - w1 = split(w1, mapping.tp_size, mapping.tp_rank, dim=1) - # concat w3 and w1 for gated expert - model_params[f'model.layers.{l}.block_sparse_moe.experts.w3w1.weight'] = \ - torch.concat([w3, w1], dim=-2) - model_params[ - f'model.layers.{l}.block_sparse_moe.experts.w2.weight'] = w2 - - for l in layers_range: - layer_idx = l - layers_range[0] - prefix = f'model.layers.{l}.' - tllm_prex = f'transformer.layers.{layer_idx}.' - - q_weight = get_weight(model_params, prefix + 'self_attn.q_proj', dtype) - k_weight = get_weight(model_params, prefix + 'self_attn.k_proj', dtype) - v_weight = get_weight(model_params, prefix + 'self_attn.v_proj', dtype) - - if not mha_mode: - head_size = hidden_size // num_attention_heads - if num_key_value_heads < tensor_parallel: - # duplicate the KV heads up to tensor_parallel - k_weight = dup_kv_weight(k_weight, num_key_value_heads, - tensor_parallel) - v_weight = dup_kv_weight(v_weight, num_key_value_heads, - tensor_parallel) - assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0 - assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0 - - wq = split(q_weight, mapping.tp_size, mapping.tp_rank) - wk = split(k_weight, mapping.tp_size, mapping.tp_rank) - wv = split(v_weight, mapping.tp_size, mapping.tp_rank) - - split_v = torch.concat((wq, wk, wv)) - - else: - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - - split_v = split_qkv_tp(qkv_weight, num_attention_heads, hidden_size, - tensor_parallel, mapping.tp_rank) - - if prefix + 'self_attn.q_proj.bias' in model_params: - # only used in 7B models - # assert not mha_mode, "MHA mode not used in internlm 7B models" - q_bias = get_bias(model_params, prefix + 'self_attn.q_proj', dtype) - k_bias = get_bias(model_params, prefix + 'self_attn.k_proj', dtype) - v_bias = get_bias(model_params, prefix + 'self_attn.v_proj', dtype) - qkv_bias = torch.cat((q_bias, k_bias, v_bias)) - split_bias_v = split_qkv_bias_tp(qkv_bias, num_attention_heads, - hidden_size, tensor_parallel, - mapping.tp_rank) - else: - split_bias_v = None - - if use_smooth_quant: - qkv_weight = qkv_para[prefix + 'self_attn.qkv_proj'] - - if not mha_mode: - hidden_size = qkv_weight.shape[0] - local_dim = hidden_size - head_size = (qkv_weight.shape[-1] - local_dim) // 2 - qkv_weight = qkv_weight.reshape(hidden_size, - local_dim + 2 * head_size) - else: - qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size) - - int8_weights = generate_int8(qkv_weight, - act_range.get(prefix + - 'self_attn.qkv_proj'), - is_qkv=True, - multi_query_mode=bool(not mha_mode)) - - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - tllm_prex + 'attention.qkv.', [ - 1, 3 * hidden_size // tensor_parallel - if mha_mode else hidden_size // tensor_parallel + - (hidden_size // num_key_value_heads) // - tensor_parallel * 2 - ], - tensor_parallel, - is_qkv=True, - bias=split_bias_v, - per_token=per_token, - per_channel=per_channel, - last_prefix=tllm_prex + 'input_layernorm.scale_to_int', - smoother_value=None, - smoother_shape=None, - rank=mapping.tp_rank, - cat_dim=-1, - multi_query_mode=bool(not mha_mode))) - 
else: - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'attention.qkv.', - split_bias_v, use_weight_only, - plugin_weight_only_quant_type)) - - if int8_kv_cache: - qkv_y = torch.cat([ - act_range.get(prefix + 'self_attn.q_proj')["y"], - act_range.get(prefix + 'self_attn.k_proj')["y"], - act_range.get(prefix + 'self_attn.v_proj')["y"] - ], - dim=0) - - int8_kv_scales = qkv_y.max() / 127. - - kv_cache_weights = {} - - kv_cache_weights[ - tllm_prex + - 'attention.kv_cache_scaling_factor'] = int8_kv_scales.reshape( - [1]) - - weights.update(kv_cache_weights) - - attn_dense_weight = get_weight(model_params, - prefix + 'self_attn.o_proj', dtype) - - if prefix + 'self_attn.o_proj.bias' in model_params: - attn_dense_bias = get_bias(model_params, - prefix + 'self_attn.o_proj', dtype) - else: - attn_dense_bias = None - split_v = split_matrix_tp(attn_dense_weight, - tensor_parallel, - mapping.tp_rank, - dim=1) - if use_smooth_quant: - attn_dense_weight = attn_dense_weight.t() - int8_weights = generate_int8( - attn_dense_weight, act_range.get(prefix + 'self_attn.o_proj')) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - tllm_prex + 'attention.dense.', [1, hidden_size], - tensor_parallel, - is_qkv=False, - bias=attn_dense_bias, - per_token=per_token, - per_channel=per_channel, - last_prefix=tllm_prex + - 'attention.quantization_scaling_factor', - smoother_value=smoother[(prefix + 'self_attn.o_proj')], - smoother_shape=[1, hidden_size // tensor_parallel], - rank=mapping.tp_rank, - cat_dim=0)) - else: - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'attention.dense.', - attn_dense_bias, use_weight_only, - plugin_weight_only_quant_type)) - - if moe_config and moe_config.has_moe(): - ## block_sparse_moe.experts.w2.weight - moe_experts_w2_weights = get_weight( - model_params, prefix + 'block_sparse_moe.experts.w2', dtype) - weights.update( - get_tllm_linear_weight(moe_experts_w2_weights, - tllm_prex + 'mlp.experts_weight_2', - None, - use_weight_only, - plugin_weight_only_quant_type, - postfix='')) - ##block_sparse_moe.experts.w3w1.weight - moe_experts_w3w1_weights = get_weight( - model_params, prefix + 'block_sparse_moe.experts.w3w1', dtype) - weights.update( - get_tllm_linear_weight(moe_experts_w3w1_weights, - tllm_prex + 'mlp.experts_weight_1', - None, - use_weight_only, - plugin_weight_only_quant_type, - postfix='')) - - moe_experts_gate_weights = get_weight( - model_params, prefix + 'block_sparse_moe.gate', dtype) - v = split(moe_experts_gate_weights, - mapping.tp_size, - mapping.tp_rank, - dim=-1) - - weights.update( - get_tllm_linear_weight(v.to(torch.float32), - tllm_prex + 'mlp.router.', None, - use_weight_only, - plugin_weight_only_quant_type)) - else: - mlp_gate_weight = get_weight(model_params, prefix + 'mlp.up_proj', - dtype) - split_v = split_matrix_tp(mlp_gate_weight, - tensor_parallel, - mapping.tp_rank, - dim=0) - if use_smooth_quant: - mlp_gate_weight = mlp_gate_weight.t() - int8_weights = generate_int8( - mlp_gate_weight, act_range.get(prefix + 'mlp.up_proj')) - - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - tllm_prex + 'mlp.gate.', - [1, intermediate_size // tensor_parallel], - tensor_parallel, - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=tllm_prex + 'post_layernorm.scale_to_int', - smoother_value=None, - smoother_shape=None, - rank=mapping.tp_rank, - cat_dim=-1)) - else: - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'mlp.gate.', - None, use_weight_only, - 
plugin_weight_only_quant_type)) - - mlp_fc_weight = get_weight(model_params, prefix + 'mlp.gate_proj', - dtype) - split_v = split_matrix_tp(mlp_fc_weight, - tensor_parallel, - mapping.tp_rank, - dim=0) - - if use_smooth_quant: - mlp_fc_weight = mlp_fc_weight.t() #verified - int8_weights = generate_int8( - mlp_fc_weight, act_range.get(prefix + 'mlp.gate_proj')) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - tllm_prex + 'mlp.fc.', - [1, intermediate_size // tensor_parallel], - tensor_parallel, - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=tllm_prex + 'post_layernorm.scale_to_int', - smoother_value=None, - smoother_shape=None, - rank=mapping.tp_rank, - cat_dim=-1)) - else: - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'mlp.fc.', None, - use_weight_only, - plugin_weight_only_quant_type)) - - mlp_proj_weight = get_weight(model_params, prefix + 'mlp.down_proj', - dtype) - split_v = split_matrix_tp(mlp_proj_weight, - tensor_parallel, - mapping.tp_rank, - dim=1) - - if use_smooth_quant: - mlp_proj_weight = mlp_proj_weight.t() - int8_weights = generate_int8( - mlp_proj_weight, act_range.get(prefix + 'mlp.down_proj')) - weights.update( - get_tllm_linear_sq_weight( - int8_weights, - tllm_prex + 'mlp.proj.', [1, hidden_size], - tensor_parallel, - is_qkv=False, - per_token=per_token, - per_channel=per_channel, - last_prefix=tllm_prex + - 'mlp.quantization_scaling_factor', - smoother_value=smoother[prefix + 'mlp.down_proj'], - smoother_shape=[ - 1, intermediate_size // tensor_parallel - ], - rank=mapping.tp_rank, - cat_dim=0)) - else: - weights.update( - get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.', - None, use_weight_only, - plugin_weight_only_quant_type)) - - # Layer norms do not use tensor parallelism - input_ln_weight = get_weight(model_params, prefix + 'input_layernorm', - dtype) - weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight - - post_ln_weight = get_weight(model_params, - prefix + 'post_attention_layernorm', dtype) - weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight - - v = get_weight(model_params, 'model.embed_tokens', dtype) - if lora_config.is_valid and lora_config.embedding_weight is not None: - v = lora_config.embedding_weight - if hf_model.config.tie_word_embeddings: - # lm_head.weight has the same weights as embedding - if mapping.is_last_pp_rank(): - weights['lm_head.weight'] = split(v, mapping.tp_size, - mapping.tp_rank) - - if use_parallel_embedding: - v = split_matrix_tp(v, mapping.tp_size, rank, dim=sharding_dim) - - if mapping.is_first_pp_rank(): - weights['transformer.vocab_embedding.weight'] = v - - lm_head_weights = get_weight(model_params, 'lm_head', dtype) - - if mapping.is_last_pp_rank(): - - if lora_config.is_valid and lora_config.lm_head_weight is not None: - - lm_head_weights = lora_config.lm_head_weight - - weights['lm_head.weight'] = split_matrix_tp(lm_head_weights, - tensor_parallel, - mapping.tp_rank, - dim=0) - - ln_f_w = get_weight(model_params, 'model.norm', dtype) - weights['transformer.ln_f.weight'] = ln_f_w - - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Weights loaded. Total time: {t}') - return weights - - -if __name__ == '__main__': - # TODO(qijun): Currently, the convert script depends on a torch op: - # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, - # which is included in tensorrt_llm Python package. Otherwise, the convert - # script does not need to import tensorrt_llm. 
Will remove it after reimplementing - # the op with PyTorch. - print(tensorrt_llm.__version__) - args = parse_arguments() - world_size = args.tp_size * args.pp_size - - tik = time.time() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - - hf_config = None - - if args.model_dir is not None: - hf_config = AutoConfig.from_pretrained(args.model_dir, - trust_remote_code=True) - - args.model_type = hf_config.model_type - args.n_head = hf_config.num_attention_heads - args.inter_size = hf_config.intermediate_size - args.n_layer = hf_config.num_hidden_layers - args.n_embd = hf_config.hidden_size - args.rms_norm_eps = hf_config.rms_norm_eps - args.vocab_size = hf_config.vocab_size - args.n_positions = hf_config.max_position_embeddings - args.bias = getattr(hf_config, "bias", False) - if hf_config.model_type == "mixtral": - # HF LLaMA-type models are implicitly using gated activation. - # With our MoE implementation, we must make it explicit - args.hidden_act = "swiglu" - args.moe_num_experts = getattr(hf_config, "num_local_experts", - args.moe_num_experts) - args.moe_top_k = getattr(hf_config, "num_experts_per_tok", - args.moe_top_k) - args.rotary_base = getattr(hf_config, "rope_theta", - args.rotary_base) - elif args.meta_ckpt_dir is not None: - with open(Path(args.meta_ckpt_dir, "params.json")) as fp: - meta_config: dict = json.load(fp) - args.n_embd = meta_config["dim"] - args.n_head = meta_config["n_heads"] - args.n_layer = meta_config["n_layers"] - args.n_kv_head = meta_config.get("n_kv_heads", args.n_head) - - if "hidden_dim" in meta_config: - args.inter_size = meta_config["hidden_dim"] - else: - args.multiple_of = meta_config.get("multiple_of", 1) - n_embd = int(4 * args.n_embd * 2 / 3) - args.ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1) - args.inter_size = args.multiple_of * ( - (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) - // args.multiple_of) - args.rms_norm_eps = meta_config["norm_eps"] - args.moe_num_experts = meta_config.get("moe", {}).get("num_experts", 0) - args.moe_top_k = meta_config.get("moe", {}).get("num_experts_per_tok", - 0) - - if args.moe_num_experts and args.moe_top_k == 0: - args.moe_top_k = 1 - args.moe_config = MoeConfig(args.moe_num_experts, args.moe_top_k, - args.moe_tp_mode, - args.moe_renorm_mode).validate() - - if args.rotary_scaling is not None: - # assert args.use_gpt_attention_plugin, "RoPE scaling is only supported through GPT attention plugin." 
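For readers unfamiliar with Meta's `params.json` conventions, the `inter_size` arithmetic a few lines above rounds the 2/3-scaled FFN width up to a multiple of `multiple_of`. A quick check with assumed LLaMA-7B-style values (not taken from this patch):

```python
# Assumed example values; a real params.json provides these fields.
dim = 4096
multiple_of = 256
ffn_dim_multiplier = 1

n_embd = int(4 * dim * 2 / 3)                                   # 10922
inter_size = multiple_of * (
    (int(n_embd * ffn_dim_multiplier) + multiple_of - 1) // multiple_of)
assert inter_size == 11008                                      # rounded up to a multiple of 256
```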
- rotary_scaling = { - "type": args.rotary_scaling[0], - "factor": float(args.rotary_scaling[1]) - } - assert rotary_scaling["type"] in ["linear", "dynamic"] - assert rotary_scaling["factor"] > 1.0 - args.rotary_scaling = rotary_scaling - - hf_modules_to_trtllm_modules = { - "q_proj": "attn_q", - "k_proj": "attn_k", - "v_proj": "attn_v", - "o_proj": "attn_dense", - "gate_proj": "mlp_h_to_4h", - "down_proj": "mlp_4h_to_h", - "up_proj": "mlp_gate" - } # lora modules on llama - - trtllm_modules_to_hf_modules = { - "attn_q": "q_proj", - "attn_k": "k_proj", - "attn_v": "v_proj", - "attn_dense": "o_proj", - "mlp_h_to_4h": "gate_proj", - "mlp_4h_to_h": "down_proj", - "mlp_gate": "up_proj", - } - - lora_config = LoraConfig.from_hf(args.hf_lora_dir, - hf_modules_to_trtllm_modules, - trtllm_modules_to_hf_modules) - - if lora_config.is_valid and lora_config.vocab_size != 0: - if args.lora_target_modules is None: - args.lora_target_modules = lora_config.lora_target_modules - - # the lora checkpoint might finetune the embedding - if lora_config.vocab_size != 0: - args.vocab_size = lora_config.vocab_size - - args.lora_config = lora_config - - config = { - 'architecture': hf_config.architectures[0] - if hf_config is not None else "LlamaForCausalLM", - 'dtype': args.dtype, - 'logits_dtype': 'float32', - 'num_hidden_layers': args.n_layer, - 'num_attention_heads': args.n_head, - 'hidden_size': args.n_embd, - 'intermediate_size': args.inter_size, - 'vocab_size': args.vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': args.n_positions, - 'hidden_act': args.hidden_act, - 'rotary_base': args.rotary_base, - 'rotary_scaling': args.rotary_scaling, - 'norm_epsilon': args.rms_norm_eps, - 'quantization': { - 'quant_algo': None, - 'kv_cache_quant_algo': None, - "sq_use_plugin": True, - }, - 'mapping': { - 'world_size': world_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - }, - 'use_parallel_embedding': args.use_parallel_embedding, - 'embedding_sharding_dim': args.embedding_sharding_dim, - 'share_embedding_table': args.use_embedding_sharing, - 'use_prompt_tuning': args.use_prompt_tuning, - 'moe_num_experts': args.moe_num_experts, - 'moe_top_k': args.moe_top_k, - 'moe_tp_mode': args.moe_tp_mode, - 'moe_normalization_mode': args.moe_renorm_mode, - 'enable_pos_shift': args.enable_pos_shift, - 'dense_context_fmha': args.dense_context_fmha, - 'max_lora_rank': args.max_lora_rank, - 'lora_target_modules': args.lora_target_modules, - 'hf_modules_to_trtllm_modules': - args.lora_config.hf_modules_to_trtllm_modules, - 'trtllm_modules_to_hf_modules': - args.lora_config.trtllm_modules_to_hf_modules, - 'attn_bias': args.bias, - } - - if args.use_weight_only: - if args.weight_only_precision == 'int8': - config['quantization']['quant_algo'] = 'W8A16' - elif args.weight_only_precision == 'int4': - config['quantization']['quant_algo'] = 'W4A16' - elif args.smoothquant: - if args.per_channel: - if args.per_token: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN' - else: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN' - else: - if args.per_token: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN' - else: - config['quantization'][ - 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PLUGIN' - - if args.int8_kv_cache: - config['quantization']['kv_cache_quant_algo'] = 'INT8' - - if args.weight_only_precision == 'int4_gptq': - config['quantization'].update({ - "group_size": args.group_size, - "has_zero_point": 
True, - "pre_quant_scale": False, - 'quant_algo': 'W4A16_GPTQ' - }) - - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - - if args.weight_only_precision == 'int8': - plugin_weight_only_quant_type = torch.int8 - elif args.weight_only_precision == 'int4': - plugin_weight_only_quant_type = torch.quint4x2 - - act_range = {} - internlm_qkv_para = {} - # smoother for inputs of self_attn.o_proj and mlp.down_proj - internlm_smoother = {} - model = None - if args.model_dir is not None: - - model = AutoModelForCausalLM.from_pretrained(args.model_dir, - device_map="auto", - torch_dtype="auto", - trust_remote_code=True) - - if args.smoothquant is not None or args.int8_kv_cache: - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - if args.load_model_on_cpu: - logger.warning( - "Note that running capture_activation_range on cpu would be very small." - ) - dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0') - - act_range = capture_activation_range( - model, - AutoTokenizer.from_pretrained(args.model_dir, - padding_side='left', - trust_remote_code=True), dataset) - if args.smoothquant is not None: - smooth_internlm_model(model, act_range, args.smoothquant, - internlm_qkv_para, internlm_smoother) - - convert_args = { - 'hf_model': model, - 'act_range': act_range, - 'internlm_qkv_para': internlm_qkv_para, - 'internlm_smoother': internlm_smoother - } - - def covert_and_save(rank, convert_args): - mapping = Mapping(world_size=world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size) - - if args.use_weight_only and args.weight_only_precision == 'int4_gptq': - weights = load_from_gptq_llama(args.ammo_quant_ckpt_path, - hf_config, - mapping, - dtype=args.dtype) - - elif args.meta_ckpt_dir is not None: - weights = load_from_meta_llama(args.meta_ckpt_dir, mapping, - PretrainedConfig.from_dict(config)) - - else: - if args.load_by_shard: - weights = load_from_hf_checkpoint( - args.model_dir, mapping, PretrainedConfig.from_dict(config), - args.lora_config) - else: - - weights = convert_hf_internlm( - convert_args['hf_model'], - mapping, - rank, - dtype=args.dtype, - use_weight_only=args.use_weight_only, - plugin_weight_only_quant_type=plugin_weight_only_quant_type, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - use_smooth_quant=args.smoothquant, - per_channel=args.per_channel, - per_token=args.per_token, - int8_kv_cache=args.int8_kv_cache, - act_range=convert_args['act_range'], - qkv_para=convert_args['internlm_qkv_para'], - smoother=convert_args['internlm_smoother'], - moe_config=args.moe_config, - lora_config=args.lora_config) - - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) - - if args.workers == 1: - - for rank in range(world_size): - covert_and_save(rank, convert_args) - else: - with ThreadPoolExecutor(max_workers=args.workers) as p: - futures = [ - p.submit(covert_and_save, rank, convert_args) - for rank in range(world_size) - ] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." 
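The worker handling above is a common pattern: run one conversion per rank, optionally in a thread pool, and fail once at the end if any rank raised. A condensed, generic sketch of that pattern, with `worker` standing in for the script's `covert_and_save`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def run_per_rank(worker, world_size, max_workers):
    """Run `worker(rank)` for every rank, serially or in a thread pool."""
    if max_workers == 1:
        for rank in range(world_size):
            worker(rank)
        return
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(worker, rank) for rank in range(world_size)]
        errors = []
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:    # collect every failure, then fail once
                errors.append(exc)
        assert not errors, "Checkpoint conversion failed, please check error log."
```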
- - tok = time.time() - t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) - print(f'Total time of converting checkpoints: {t}') diff --git a/examples/internlm/requirements.txt b/examples/internlm/requirements.txt index ed4f5ac05..fe334af0b 100644 --- a/examples/internlm/requirements.txt +++ b/examples/internlm/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets==2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/llama/README.md b/examples/llama/README.md index 4e1e203a9..13f44833b 100644 --- a/examples/llama/README.md +++ b/examples/llama/README.md @@ -74,6 +74,16 @@ trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16_wq \ --output_dir ./tmp/llama/7B/trt_engines/weight_only/1-gpu/ \ --gemm_plugin float16 +# Build LLaMA 7B using 2-way auto parallelism. +python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ + --output_dir ./tllm_checkpoint_1gpu_fp16 \ + --dtype float16 + +trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \ + --output_dir ./tmp/llama/7B/trt_engines/fp16/2-gpu/ \ + --gemm_plugin float16 \ + --world_size 2 + # Build LLaMA 7B using 2-way tensor parallelism. python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ --output_dir ./tllm_checkpoint_2gpu_tp2 \ @@ -570,8 +580,7 @@ python convert_checkpoint.py --model_dir /tmp/llama-v2-13b-hf \ --output_dir ./tllm_checkpoint_2gpu_lora \ --dtype float16 \ --tp_size 2 \ - --hf_lora_dir /tmp/chinese-llama-2-lora-13b \ - --use_fused_mlp + --hf_lora_dir /tmp/chinese-llama-2-lora-13b trtllm-build --checkpoint_dir ./tllm_checkpoint_2gpu_lora \ --output_dir /tmp/new_lora_13b/trt_engines/fp16/2-gpu/ \ @@ -692,19 +701,19 @@ We can observe that `luotuo-lora-7b-0.1` produces correct answers on the first s ## Run LLaMa with StreamingLLM -* Build engine. Set `--enable_pos_shift` to use positions in KV cache for RoPE, and set `--dense_context_fmha` to use dense context fmha in context phase. +* Build engine. Set `--pos_shift enable` to use positions in KV cache for RoPE, and set `--dense_context_fmha enable` to use dense context fmha in context phase. ```bash # Build the LLaMA 7B model with StreamingLLM feature using a single GPU and FP16. 
python convert_checkpoint.py --model_dir ./tmp/llama/7B/ \ --output_dir ./tllm_checkpoint_1gpu_streamlingllm \ - --dtype float16 \ - --dense_context_fmha \ - --enable_pos_shift + --dtype float16 trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_streamlingllm \ --output_dir ./tmp/llama/7B/trt_engines/fp16_StreamingLLM/1-gpu/ \ - --gemm_plugin float16 + --gemm_plugin float16 \ + --dense_context_fmha enable \ + --pos_shift enable ``` diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py index 5c7b0703f..ab5cd2f77 100644 --- a/examples/llama/convert_checkpoint.py +++ b/examples/llama/convert_checkpoint.py @@ -4,31 +4,16 @@ import time import traceback from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from typing import Optional import safetensors -import torch -from datasets import load_dataset -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer import tensorrt_llm from tensorrt_llm.layers import MoeConfig -from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping -from tensorrt_llm.models.llama.convert import (capture_activation_range, - convert_hf_llama, - smooth_llama_model) -from tensorrt_llm.models.llama.weight import (load_from_gptq_llama, - load_from_hf_checkpoint, - load_from_meta_llama) -from tensorrt_llm.models.modeling_utils import PretrainedConfig -from tensorrt_llm.runtime.lora_manager import LoraConfig - -try: - from transformers import LlavaConfig, LlavaForConditionalGeneration -except ImportError: - pass +from tensorrt_llm.models import LLaMAForCausalLM +from tensorrt_llm.models.llama.convert import (create_config_from_hugging_face, + from_hugging_face, quantize) +from tensorrt_llm.models.llama.weight import load_from_gptq_llama def parse_arguments(): @@ -134,18 +119,13 @@ def parse_arguments(): parser.add_argument('--hidden_act', type=str, default='silu') parser.add_argument('--rotary_base', type=float, default=10000.0) - parser.add_argument('--rotary_scaling', nargs=2, type=str, default=None) parser.add_argument('--group_size', type=int, default=128, - help='Group size used in GPTQ/AWQ quantization.') + help='Group size used in GPTQ quantization.' + ) # AWQ is only supported by quantize.py script - parser.add_argument("--storage-type", - "-t", - type=str, - default="fp32", - choices=["fp32", "fp16"]) parser.add_argument("--dataset-cache-dir", type=str, default=None, @@ -213,18 +193,7 @@ def parse_arguments(): help= 'Controls renormalization after gate logits. Check layers/moe.py for accepted values', ) - parser.add_argument('--enable_pos_shift', - default=False, - action='store_true', - help='Enable position shift for streamingllm method') - parser.add_argument( - '--dense_context_fmha', - default=False, - action='store_true', - help= - 'Enable dense fmha in context phase, otherwise sliding window attention.' - 'If dense_context_fmha=False, the sliding window size is the max attention window size.' - ) + parser.add_argument('--hf_lora_dir', type=str, default=None) parser.add_argument( '--lora_target_modules', @@ -249,20 +218,35 @@ def parse_arguments(): default=64, help='maximum lora rank for different lora modules. 
' 'It is used to compute the workspace size of lora plugin.') + parser.add_argument( + '--save_config_only', + action="store_true", + default=False, + help= + 'Only save the model config w/o read and converting weights, be careful, this is for debug only' + ) + args = parser.parse_args() return args -def update_quantization_from_args(config: dict, args: argparse.Namespace): - '''update the given config dict in-place based on the command line args +def args_to_quantization(args: argparse.Namespace): + '''return config dict with quantization info based on the command line args ''' + config = { + 'quantization': { + 'quant_algo': None, + 'kv_cache_quant_algo': None, + 'exclude_modules': ['lm_head'], + } + } + if args.use_weight_only: if args.weight_only_precision == 'int8': config['quantization']['quant_algo'] = 'W8A16' elif args.weight_only_precision == 'int4': config['quantization']['quant_algo'] = 'W4A16' elif args.smoothquant: - config['quantization']['sq_use_plugin'] = True if args.per_channel: if args.per_token: config['quantization'][ @@ -278,9 +262,6 @@ def update_quantization_from_args(config: dict, args: argparse.Namespace): config['quantization'][ 'quant_algo'] = 'W8A8_SQ_PER_TENSOR_PLUGIN' - if args.use_weight_only and args.moe_config.has_moe(): - config['quantization']['exclude_modules'].append('router') - if args.int8_kv_cache: config['quantization']['kv_cache_quant_algo'] = 'INT8' @@ -291,303 +272,204 @@ def update_quantization_from_args(config: dict, args: argparse.Namespace): "pre_quant_scale": False, 'quant_algo': 'W4A16_GPTQ' }) - - -def create_config_from_args(args: argparse.Namespace, - lora_config: Optional[LoraConfig] = None): - config = { - 'architecture': args.architecture, - 'dtype': args.dtype, - 'logits_dtype': 'float32', - 'num_hidden_layers': args.n_layer, - 'num_attention_heads': args.n_head, - 'hidden_size': args.n_embd, - 'intermediate_size': args.inter_size, - 'num_key_value_heads': args.n_kv_head, - 'vocab_size': args.vocab_size, - 'position_embedding_type': 'rope_gpt_neox', - 'max_position_embeddings': args.n_positions, - 'hidden_act': args.hidden_act, - 'rotary_base': args.rotary_base, - 'rotary_scaling': args.rotary_scaling, - 'norm_epsilon': args.rms_norm_eps, - 'quantization': { - 'quant_algo': None, - 'kv_cache_quant_algo': None, - "sq_use_plugin": False, - 'exclude_modules': ['lm_head'], - }, - 'mapping': { - 'world_size': args.tp_size * args.pp_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - }, - 'use_parallel_embedding': args.use_parallel_embedding, - 'embedding_sharding_dim': args.embedding_sharding_dim, - 'share_embedding_table': args.use_embedding_sharing, - 'use_prompt_tuning': args.use_prompt_tuning, - 'moe_num_experts': args.moe_num_experts, - 'moe_top_k': args.moe_top_k, - 'moe_tp_mode': args.moe_tp_mode, - 'moe_normalization_mode': args.moe_renorm_mode, - 'enable_pos_shift': args.enable_pos_shift, - 'dense_context_fmha': args.dense_context_fmha, - } - if lora_config is not None: - config.update({ - 'max_lora_rank': - args.max_lora_rank, - 'lora_target_modules': - lora_config.lora_target_modules, - 'hf_modules_to_trtllm_modules': - lora_config.hf_modules_to_trtllm_modules, - 'trtllm_modules_to_hf_modules': - lora_config.trtllm_modules_to_hf_modules, - 'disable_weight_only_quant_plugin': - args.disable_weight_only_quant_plugin - }) - # the lora checkpoint might finetune the embedding - if lora_config.vocab_size != 0: - config['vocab_size'] = lora_config.vocab_size - update_quantization_from_args(config, args) return config -def 
create_lora_config(args: argparse.Namespace): - '''update args based on lora dir - ''' - hf_modules_to_trtllm_modules = { - "q_proj": "attn_q", - "k_proj": "attn_k", - "v_proj": "attn_v", - "o_proj": "attn_dense", - "gate_proj": "mlp_h_to_4h", - "down_proj": "mlp_4h_to_h", - "up_proj": "mlp_gate" - } # lora modules on llama - - trtllm_modules_to_hf_modules = { - "attn_q": "q_proj", - "attn_k": "k_proj", - "attn_v": "v_proj", - "attn_dense": "o_proj", - "mlp_h_to_4h": "gate_proj", - "mlp_4h_to_h": "down_proj", - "mlp_gate": "up_proj", - } +def has_any_quant(args): + config = args_to_quantization(args) + return config['quantization']['quant_algo'] is not None or config[ + 'quantization']['kv_cache_quant_algo'] is not None - lora_config = LoraConfig.from_hf(args.hf_lora_dir, - hf_modules_to_trtllm_modules, - trtllm_modules_to_hf_modules) - - if lora_config.is_valid: - if args.lora_target_modules is not None: - # command line options is preferred over the modules in the lora dir - lora_config.lora_target_modules = args.lora_target_modules - # can be invalid - return lora_config - - -def smooth_quant(model, args): - assert model is not None - act_range = {} - llama_qkv_para = {} - # smoother for inputs of self_attn.o_proj and mlp.down_proj - llama_smoother = {} - - os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( - "TOKENIZERS_PARALLELISM", "false") - if args.load_model_on_cpu: - logger.warning( - "Note that running capture_activation_range on cpu would be very small." - ) - dataset = load_dataset("ccdv/cnn_dailymail", - '3.0.0', - cache_dir=args.dataset_cache_dir) - - act_range = capture_activation_range( - model, - AutoTokenizer.from_pretrained(args.model_dir, - trust_remote_code=True, - use_fast=False, - padding_side='left'), dataset) - if args.smoothquant is not None: - smooth_llama_model(model, act_range, args.smoothquant, llama_qkv_para, - llama_smoother) - return act_range, llama_qkv_para, llama_smoother +def create_config_from_args(args: argparse.Namespace): + config = {} + mapping = Mapping(world_size=args.tp_size * args.pp_size, + tp_size=args.tp_size, + pp_size=args.pp_size) -def main(): - # TODO(qijun): Currently, the convert script depends on a torch op: - # torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix, - # which is included in tensorrt_llm Python package. Otherwise, the convert - # script does not need to import tensorrt_llm. Will remove it after reimplementing - # the op with PyTorch. - print(tensorrt_llm.__version__) - args = parse_arguments() + # Need to convert the cli args to the kay-value pairs and override them in the generate config dict. + # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now, + # before the refactor is done. 
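`args_to_quantization` above boils down to a small decision table from CLI flags to `quant_algo` strings. The sketch below reproduces that table in simplified form; it omits the KV-cache and GPTQ group-size fields and folds the `int4_gptq` special case into the weight-only branch for brevity.

```python
def pick_quant_algo(use_weight_only=False, weight_only_precision='int8',
                    smoothquant=None, per_channel=False, per_token=False):
    """Return the quant_algo string the converter would select (simplified)."""
    if use_weight_only:
        if weight_only_precision == 'int8':
            return 'W8A16'
        if weight_only_precision == 'int4':
            return 'W4A16'
        if weight_only_precision == 'int4_gptq':
            return 'W4A16_GPTQ'
    elif smoothquant is not None:
        return {
            (True, True): 'W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN',
            (True, False): 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN',
            (False, True): 'W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN',
            (False, False): 'W8A8_SQ_PER_TENSOR_PLUGIN',
        }[(per_channel, per_token)]
    return None

assert pick_quant_algo(use_weight_only=True, weight_only_precision='int4') == 'W4A16'
assert pick_quant_algo(smoothquant=0.5, per_channel=True) == 'W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN'
```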
+ override_fields = {'moe_tp_mode': args.moe_tp_mode} + override_fields.update(args_to_quantization(args)) + override_fields.update(args_to_build_options(args)) - world_size = args.tp_size * args.pp_size - - tik = time.time() - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - hf_config = None - if args.model_dir is not None: - hf_config = AutoConfig.from_pretrained(args.model_dir, - trust_remote_code=True) - if hf_config.model_type == "llava": - # LLaVA = Vision model + Llama LLM - # We load a llava config and use its' text config as llama config - hf_config = LlavaConfig.from_pretrained(args.model_dir).text_config - hf_config.model_type = "llava" # Replace llama with llava - - args.model_type = hf_config.model_type - args.n_head = hf_config.num_attention_heads - args.inter_size = hf_config.intermediate_size - args.n_layer = hf_config.num_hidden_layers - args.n_embd = hf_config.hidden_size - args.n_kv_head = hf_config.num_key_value_heads - args.rms_norm_eps = hf_config.rms_norm_eps - args.vocab_size = hf_config.vocab_size - args.n_positions = hf_config.max_position_embeddings - args.rotary_scaling = getattr(hf_config, "rope_scaling", None) - args.rotary_base = getattr(hf_config, "rope_theta", args.rotary_base) - args.vocab_size = getattr(hf_config, "vocab_size", args.vocab_size) - if hf_config.model_type == "mixtral": - # HF LLaMA-type models are implicitly using gated activation. - # With our MoE implementation, we must make it explicit - args.hidden_act = "swiglu" - args.moe_num_experts = getattr(hf_config, "num_local_experts", - args.moe_num_experts) - args.moe_top_k = getattr(hf_config, "num_experts_per_tok", - args.moe_top_k) - - args.architecture = hf_config.architectures[0] - - elif args.meta_ckpt_dir is not None: - with open(Path(args.meta_ckpt_dir, "params.json")) as fp: - meta_config: dict = json.load(fp) - args.n_embd = meta_config["dim"] - args.n_head = meta_config["n_heads"] - args.n_layer = meta_config["n_layers"] - args.n_kv_head = meta_config.get("n_kv_heads", args.n_head) - - if "hidden_dim" in meta_config: - args.inter_size = meta_config["hidden_dim"] - else: - args.multiple_of = meta_config.get("multiple_of", 1) - n_embd = int(4 * args.n_embd * 2 / 3) - args.ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1) - args.inter_size = args.multiple_of * ( - (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) - // args.multiple_of) - args.rms_norm_eps = meta_config["norm_eps"] - args.moe_num_experts = meta_config.get("moe", {}).get("num_experts", 0) - args.moe_top_k = meta_config.get("moe", {}).get("num_experts_per_tok", - 0) - args.architecture = "LlamaForCausalLM" - else: # fake checkpoint for testing only - args.n_kv_head = args.n_kv_head or args.n_head - args.architecture = "LlamaForCausalLM" + assert args.model_dir is not None + kwargs = { + 'hf_lora_dir': args.hf_lora_dir, + 'lora_target_modules': args.lora_target_modules, + 'max_lora_rank': args.max_lora_rank, + } + config = create_config_from_hugging_face(args.model_dir, + args.dtype, + mapping, + override_fields=override_fields, + **kwargs) + return config - if args.moe_num_experts and args.moe_top_k == 0: - args.moe_top_k = 1 - args.moe_config = MoeConfig(args.moe_num_experts, args.moe_top_k, - args.moe_tp_mode, - args.moe_renorm_mode).validate() - lora_config = create_lora_config(args) - config = create_config_from_args(args, lora_config) +def convert_and_save_meta(args, rank): + mapping = Mapping(world_size=args.tp_size * args.pp_size, + tp_size=args.tp_size, + 
pp_size=args.pp_size, + rank=rank) + override_fields = {'moe_tp_mode': args.moe_tp_mode} + override_fields.update(args_to_quantization(args)) + override_fields.update(args_to_build_options(args)) + + assert not has_any_quant( + args + ), "quantization from meta checkpoint or empty model were never supported" + assert not args.hf_lora_dir, "lora is only supported when loading from hf model dir for now" + kwargs = {} + assert args.meta_ckpt_dir is not None + llama = LLaMAForCausalLM.from_meta_ckpt(args.meta_ckpt_dir, + args.dtype, + mapping, + override_fileds=override_fields, + **kwargs) + llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) + + +def args_to_build_options(args): + return { + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, + 'use_prompt_tuning': args.use_prompt_tuning, + 'disable_weight_only_quant_plugin': + args.disable_weight_only_quant_plugin + } - with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - if args.model_dir is None and args.meta_ckpt_dir is None: - return - act_range = {} - llama_qkv_para = {} - # smoother for inputs of self_attn.o_proj and mlp.down_proj - llama_smoother = {} - model = None - if args.model_dir is not None: - - if args.model_type == "llava": - hf_llava = LlavaForConditionalGeneration.from_pretrained( - args.model_dir, torch_dtype="auto") - model = hf_llava.language_model - else: - model = AutoModelForCausalLM.from_pretrained( - args.model_dir, - device_map='auto' if not args.load_model_on_cpu else 'cpu', - torch_dtype='auto' if not args.smoothquant else torch.float16, - trust_remote_code=True, - ) - if args.smoothquant is not None or args.int8_kv_cache: - act_range, llama_qkv_para, llama_smoother = smooth_quant( - model, args) - - def covert_and_save(rank): - mapping = Mapping(world_size=world_size, - rank=rank, - tp_size=args.tp_size, - pp_size=args.pp_size) - - if args.use_weight_only and args.weight_only_precision == 'int4_gptq': - weights = load_from_gptq_llama(args.ammo_quant_ckpt_path, - args.n_layer, - args.vocab_size, - mapping, - dtype=args.dtype) - - elif args.meta_ckpt_dir is not None: - weights = load_from_meta_llama(args.meta_ckpt_dir, mapping, - PretrainedConfig.from_dict(config)) - else: - if args.load_by_shard: - weights = load_from_hf_checkpoint( - args.model_dir, mapping, PretrainedConfig.from_dict(config), - lora_config) +def from_cli_args(args): + config = {} + mapping = Mapping(world_size=args.tp_size * args.pp_size, + tp_size=args.tp_size, + pp_size=args.pp_size) + architecture = "LlamaForCausalLM" + n_layer = args.n_layer + n_head = args.n_head + n_embd = args.n_embd + inter_size = args.inter_size + n_kv_head = args.n_kv_head if args.n_kv_head is not None else n_head # default to MHA + vocab_size = args.vocab_size + n_positions = args.n_positions + hidden_act = args.hidden_act + rotary_base = args.rotary_base + rms_norm_eps = args.rms_norm_eps + moe_num_experts = args.moe_num_experts + moe_top_k = args.moe_top_k + moe_tp_mode = args.moe_tp_mode + config['moe_normalization_mode'] = args.moe_renorm_mode + # config values from reading model config + config.update({ + 'architecture': architecture, + 'dtype': args.dtype, + 'logits_dtype': 'float32', + 'num_hidden_layers': n_layer, + 'num_attention_heads': n_head, + 'hidden_size': n_embd, + 'intermediate_size': inter_size, + 'num_key_value_heads': n_kv_head, + 'vocab_size': vocab_size, + 
'position_embedding_type': 'rope_gpt_neox', + 'max_position_embeddings': n_positions, + 'hidden_act': hidden_act, + 'rotary_base': rotary_base, + 'norm_epsilon': rms_norm_eps, + 'moe_num_experts': moe_num_experts, + 'moe_top_k': moe_top_k, + 'moe_tp_mode': moe_tp_mode, + 'mapping': { + 'world_size': mapping.tp_size * mapping.pp_size, + 'tp_size': mapping.tp_size, + 'pp_size': mapping.pp_size + } + }) + config.update(args_to_build_options(args)) + return config - else: - if args.weight_only_precision == 'int8': - plugin_weight_only_quant_type = torch.int8 - elif args.weight_only_precision == 'int4': - plugin_weight_only_quant_type = torch.quint4x2 - weights = convert_hf_llama( - model, - mapping, - vocab_size=args.vocab_size, - dtype=args.dtype, - use_weight_only=args.use_weight_only, - use_gemm_woq_plugin=not args. - disable_weight_only_quant_plugin, - plugin_weight_only_quant_type=plugin_weight_only_quant_type, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim, - share_embedding_table=args.use_embedding_sharing, - use_smooth_quant=args.smoothquant, - per_channel=args.per_channel, - per_token=args.per_token, - int8_kv_cache=args.int8_kv_cache, - act_range=act_range, - qkv_para=llama_qkv_para, - smoother=llama_smoother, - moe_config=args.moe_config, - lora_config=lora_config) - - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) - - if args.workers == 1: +def convert_and_save_hf(args): + model_dir = args.model_dir + load_model_on_cpu = args.load_model_on_cpu + load_by_shard = args.load_by_shard + world_size = args.tp_size * args.pp_size + # Need to convert the cli args to key-value pairs and override them in the generated config dict. + # Ideally these fields would be moved out of the config and passed into the build API; keep them here for compatibility for now, + # until the refactor is done.
+ override_fields = {'moe_tp_mode': args.moe_tp_mode} + override_fields.update(args_to_quantization(args)) + override_fields.update(args_to_build_options(args)) + assert model_dir is not None + + if args.smoothquant is not None or args.int8_kv_cache: + assert not args.load_by_shard, "When using quantization, TRT-LLM needs to load the whole HF model, so load-by-shard is not supported" + assert not args.load_model_on_cpu, "When using quantization, TRT-LLM needs to load the model onto the GPU" + mapping = Mapping( + world_size=world_size, + rank=-1, # intentionally set to -1 to avoid accidental use + tp_size=args.tp_size, + pp_size=args.pp_size) + quantize(args.dtype, + args.model_dir, + args.output_dir, + mapping, + override_fields=override_fields, + dataset_cache_dir=args.dataset_cache_dir, + smoothquant_val=args.smoothquant, + int8_kv_cache=args.int8_kv_cache, + hf_lora_dir=args.hf_lora_dir, + lora_target_modules=args.lora_target_modules, + max_lora_rank=args.max_lora_rank) + else: for rank in range(world_size): - covert_and_save(rank) + mapping = Mapping(world_size=world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size) + #TODO: change to LLaMAForCausalLM.from_hugging_face after refactor is done + llama = from_hugging_face( + LLaMAForCausalLM, + model_dir, + args.dtype, + mapping=mapping, + load_by_shard=load_by_shard, + load_model_on_cpu=load_model_on_cpu, + override_fields=override_fields, + hf_lora_dir=args.hf_lora_dir, + lora_target_modules=args.lora_target_modules, + max_lora_rank=args.max_lora_rank) + llama.save_checkpoint(args.output_dir, save_config=(rank == 0)) + + +def convert_and_save_gptq(args, rank): + config = create_config_from_args(args) + if rank == 0: + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + mapping = Mapping(world_size=config['mapping']['tp_size'] * + config['mapping']['pp_size'], + rank=rank, + tp_size=config['mapping']['tp_size'], + pp_size=config['mapping']['pp_size']) + weights = load_from_gptq_llama(args.ammo_quant_ckpt_path, + config['num_hidden_layers'], + config['vocab_size'], + mapping, + dtype=config['dtype']) + safetensors.torch.save_file( + weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + + +def execute(workers, func, args): + if workers == 1: + for rank, f in enumerate(func): + f(args, rank) else: - with ThreadPoolExecutor(max_workers=args.workers) as p: - futures = [ - p.submit(covert_and_save, rank) for rank in range(world_size) - ] + with ThreadPoolExecutor(max_workers=workers) as p: + futures = [p.submit(f, args, rank) for rank, f in enumerate(func)] exceptions = [] for future in as_completed(futures): try: @@ -599,6 +481,37 @@ def covert_and_save(rank): exceptions ) == 0, "Checkpoint conversion failed, please check error log." + +def main(): + print(tensorrt_llm.__version__) + args = parse_arguments() + + # change the default to be consistent with the CLI help.
+ if args.moe_num_experts and args.moe_top_k == 0: + args.moe_top_k = 1 + world_size = args.tp_size * args.pp_size + tik = time.time() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + ####### save config + if (args.model_dir is None and args.meta_ckpt_dir is None): + config = from_cli_args(args) + with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: + json.dump(config, f, indent=4) + return + elif args.meta_ckpt_dir is not None: + execute(args.workers, [convert_and_save_meta] * world_size, args) + elif args.weight_only_precision == 'int4_gptq': + assert args.model_dir is not None + assert args.ammo_quant_ckpt_path is not None + execute(args.workers, [convert_and_save_gptq] * world_size, args) + else: # all other non-gptq paths from hf model + assert args.model_dir is not None + assert args.ammo_quant_ckpt_path is None, "only gptq weights only needs this option" + convert_and_save_hf(args) + tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') diff --git a/examples/llama/requirements.txt b/examples/llama/requirements.txt index 55f57b4de..c2524c4cb 100644 --- a/examples/llama/requirements.txt +++ b/examples/llama/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets==2.14.6 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/mamba/requirements.txt b/examples/mamba/requirements.txt index 7e55f4d98..b6b028069 100644 --- a/examples/mamba/requirements.txt +++ b/examples/mamba/requirements.txt @@ -1,2 +1,4 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 mamba-ssm==1.1.1 causal-conv1d==1.1.1 # 1.1.2 needs torch 2.2, while TRT-LLM sticks to pre 2.2 diff --git a/examples/medusa/convert_checkpoint.py b/examples/medusa/convert_checkpoint.py index d7872ab42..8645ee9b6 100644 --- a/examples/medusa/convert_checkpoint.py +++ b/examples/medusa/convert_checkpoint.py @@ -177,18 +177,6 @@ def parse_arguments(): default=1, help='The number of workers for converting checkpoint in parallel') - parser.add_argument('--enable_pos_shift', - default=False, - action='store_true', - help='Enable position shift for streamingllm method') - parser.add_argument( - '--dense_context_fmha', - default=False, - action='store_true', - help= - 'Enable dense fmha in context phase, otherwise sliding window attention.' - 'If dense_context_fmha=False, the sliding window size is the max attention window size.' 
- ) parser.add_argument('--num_medusa_heads', type=int, default=4) parser.add_argument( '--fixed_num_medusa_heads', @@ -1140,7 +1128,6 @@ def convert_hf_llama(hf_model, 'quantization': { 'quant_algo': None, 'kv_cache_quant_algo': None, - "sq_use_plugin": True, }, 'mapping': { 'world_size': world_size, @@ -1151,8 +1138,6 @@ def convert_hf_llama(hf_model, 'embedding_sharding_dim': args.embedding_sharding_dim, 'share_embedding_table': args.use_embedding_sharing, 'use_prompt_tuning': args.use_prompt_tuning, - 'enable_pos_shift': args.enable_pos_shift, - 'dense_context_fmha': args.dense_context_fmha, 'max_draft_len': args.max_medusa_token_len, 'num_medusa_heads': args.num_medusa_heads, 'num_medusa_layers': args.num_medusa_layers diff --git a/examples/medusa/requirements.txt b/examples/medusa/requirements.txt index 180e5f0be..c14ed53cd 100644 --- a/examples/medusa/requirements.txt +++ b/examples/medusa/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 rouge_score~=0.1.2 sentencepiece~=0.1.99 diff --git a/examples/mixtral/requirements.txt b/examples/mixtral/requirements.txt index a1ee1e577..598fa22ca 100644 --- a/examples/mixtral/requirements.txt +++ b/examples/mixtral/requirements.txt @@ -1,2 +1,4 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 transformers==4.36.1 accelerate==0.25.0 diff --git a/examples/mpt/convert_checkpoint.py b/examples/mpt/convert_checkpoint.py index 1444e4468..38061a280 100644 --- a/examples/mpt/convert_checkpoint.py +++ b/examples/mpt/convert_checkpoint.py @@ -42,6 +42,31 @@ def parse_arguments(): type=str, default='float32', choices=['float16', 'float32']) + parser.add_argument( + '--use_parallel_embedding', + action="store_true", + default=False, + help= + 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' + ) + parser.add_argument( + '--embedding_sharding_dim', + type=int, + default=0, + choices=[0, 1], + help= + 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' + 'To shard it along hidden dimension, set embedding_sharding_dim=1' + 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' + ) + parser.add_argument( + '--use_embedding_sharing', + action="store_true", + default=False, + help= + 'Try to reduce the engine size by sharing the embedding lookup table between two layers.' 
+ 'Note: the flag might not take effect when the criteria are not met.') + parser.add_argument( "--calibrate_kv_cache", "-kv", @@ -617,6 +642,9 @@ def convert_hf_mpt_legacy(hf_model, mapping, rank=0, dtype='float32', + use_parallel_embedding: bool = False, + sharding_dim: int = 0, + share_embedding_table: bool = False, use_weight_only=False, plugin_weight_only_quant_type='int8', use_smooth_quant=False, @@ -633,6 +661,7 @@ def convert_hf_mpt_legacy(hf_model, dtype = getattr(torch, dtype) num_attention_heads = hf_model.config.n_heads hidden_size = hf_model.config.d_model + vocab_size = hf_model.config.vocab_size num_key_value_heads = hf_config.attn_config['kv_n_heads'] if 'kv_n_heads' in hf_config.attn_config \ else hf_config.n_heads multi_query_mode = (num_key_value_heads != num_attention_heads) @@ -794,13 +823,22 @@ def convert_hf_mpt_legacy(hf_model, embed_w = get_weight(model_params, 'transformer.wte', dtype) if mapping.is_first_pp_rank(): # Embedding - weights['transformer.vocab_embedding.weight'] = embed_w + if not use_parallel_embedding: + weights['transformer.vocab_embedding.weight'] = embed_w + else: + if sharding_dim == 0: + assert vocab_size % mapping.tp_size == 0 + else: + assert hidden_size % mapping.tp_size == 0 + weights['transformer.vocab_embedding.weight'] = split_matrix( + embed_w, mapping.tp_size, mapping.tp_rank, sharding_dim) if mapping.is_last_pp_rank(): # lm_head weight and bias - weights['lm_head.weight'] = split_matrix(embed_w.clone(), - mapping.tp_size, - mapping.tp_rank, - dim=0) + if not share_embedding_table: + weights['lm_head.weight'] = split_matrix(embed_w.clone(), + mapping.tp_size, + mapping.tp_rank, + dim=0) ln_f_w = get_weight(model_params, 'transformer.norm_f', dtype) # ln_f weight and bias weights['transformer.ln_f.weight'] = ln_f_w @@ -815,6 +853,9 @@ def convert_hf_mpt(hf_model: MptForCausalLM, hf_config: AutoConfig, mapping: Mapping, dtype: str = 'float32', + use_parallel_embedding: bool = False, + sharding_dim: int = 0, + share_embedding_table: bool = False, use_weight_only: bool = False, plugin_weight_only_quant_type: torch.dtype = torch.int8): @@ -827,7 +868,8 @@ def convert_hf_mpt(hf_model: MptForCausalLM, num_head = hf_config.n_heads num_kv_heads = hf_config.attn_config['kv_n_heads'] if 'kv_n_heads' in hf_config.attn_config \ else hf_config.n_heads - num_hidden = hf_config.d_model + hidden_size = hf_config.d_model + vocab_size = hf_config.vocab_size layers_range = mapping.pp_layers(num_hidden_layers) for l in layers_range: @@ -835,7 +877,7 @@ def convert_hf_mpt(hf_model: MptForCausalLM, tllm_prex = f'transformer.layers.{l-layers_range[0]}' # Attention QKV (no bias) qkv_w = get_weight(model_params, f'{prefix}.attn.Wqkv', dtype) - qkv_w = split_qkv_tp(qkv_w, num_head, num_kv_heads, num_hidden, + qkv_w = split_qkv_tp(qkv_w, num_head, num_kv_heads, hidden_size, mapping.tp_size, mapping.tp_rank) weights.update( get_tllm_linear_weight(qkv_w, f'{tllm_prex}.attention.qkv', None, @@ -884,13 +926,22 @@ def convert_hf_mpt(hf_model: MptForCausalLM, embed_w = get_weight(model_params, 'transformer.wte', dtype) if mapping.is_first_pp_rank(): # Embedding - weights['transformer.vocab_embedding.weight'] = embed_w + if not use_parallel_embedding: + weights['transformer.vocab_embedding.weight'] = embed_w + else: + if sharding_dim == 0: + assert vocab_size % mapping.tp_size == 0 + else: + assert hidden_size % mapping.tp_size == 0 + weights['transformer.vocab_embedding.weight'] = split_matrix( + embed_w, mapping.tp_size, mapping.tp_rank, sharding_dim) if 
mapping.is_last_pp_rank(): # lm_head weight and bias - weights['lm_head.weight'] = split_matrix(embed_w.clone(), - mapping.tp_size, - mapping.tp_rank, - dim=0) + if not share_embedding_table: + weights['lm_head.weight'] = split_matrix(embed_w.clone(), + mapping.tp_size, + mapping.tp_rank, + dim=0) ln_f_w = get_weight(model_params, 'transformer.norm_f', dtype) # ln_f weight and bias weights['transformer.ln_f.weight'] = ln_f_w @@ -956,10 +1007,12 @@ def convert_hf_mpt(hf_model: MptForCausalLM, 'num_key_value_heads': num_kv_heads, 'position_embedding_type': 'alibi', 'hidden_act': 'gelu', + 'use_parallel_embedding': args.use_parallel_embedding, + 'embedding_sharding_dim': args.embedding_sharding_dim, + 'share_embedding_table': args.use_embedding_sharing, 'quantization': { 'quant_algo': quant_algo, 'kv_cache_quant_algo': kv_cache_quant_algo, - 'sq_use_plugin': True, }, 'mapping': { 'world_size': world_size, @@ -974,42 +1027,61 @@ def convert_hf_mpt(hf_model: MptForCausalLM, with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: json.dump(config, f, indent=4) + hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, + trust_remote_code=True, + device_map="auto", + torch_dtype=getattr( + torch, args.dtype)) + + act_range = {} + mpt_qkv_para = {} + # smoother for inputs of self_attn.o_proj and mlp.down_proj + mpt_smoother = {} + if args.smoothquant is not None or args.calibrate_kv_cache: + dataset = load_dataset("ccdv/cnn_dailymail", + '3.0.0', + cache_dir=args.dataset_cache_dir) + act_range = capture_activation_range( + hf_model, + AutoTokenizer.from_pretrained(args.model_dir, padding_side='left'), + dataset) + if args.smoothquant is not None: + smooth_mpt_model(hf_model, act_range, args.smoothquant, + mpt_qkv_para, mpt_smoother) + def covert_and_save(rank): mapping = Mapping(world_size=world_size, rank=rank, tp_size=args.tp_size, pp_size=args.pp_size) - hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, - trust_remote_code=True, - device_map="auto", - torch_dtype=getattr( - torch, args.dtype)) - act_range = {} - mpt_qkv_para = {} - # smoother for inputs of self_attn.o_proj and mlp.down_proj - mpt_smoother = {} + if args.smoothquant is not None or args.calibrate_kv_cache: - dataset = load_dataset("ccdv/cnn_dailymail", - '3.0.0', - cache_dir=args.dataset_cache_dir) - act_range = capture_activation_range( - hf_model, - AutoTokenizer.from_pretrained(args.model_dir, - padding_side='left'), dataset) - if args.smoothquant is not None: - smooth_mpt_model(hf_model, act_range, args.smoothquant, - mpt_qkv_para, mpt_smoother) weights = convert_hf_mpt_legacy( - hf_model, mapping, rank, args.dtype, args.use_weight_only, - plugin_weight_only_quant_type, args.smoothquant is not None, - args.per_channel, args.per_token, args.calibrate_kv_cache, - act_range, mpt_qkv_para, mpt_smoother) + hf_model, + mapping, + rank, + dtype=args.dtype, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + share_embedding_table=args.use_embedding_sharing, + use_weight_only=args.use_weight_only, + plugin_weight_only_quant_type=plugin_weight_only_quant_type, + use_smooth_quant=(args.smoothquant is not None), + per_channel=args.per_channel, + per_token=args.per_token, + int8_kv_cache=args.calibrate_kv_cache, + act_range=act_range, + qkv_para=mpt_qkv_para, + smoother=mpt_smoother) else: weights = convert_hf_mpt( hf_model, hf_config, mapping, dtype=args.dtype, + use_parallel_embedding=args.use_parallel_embedding, + sharding_dim=args.embedding_sharding_dim, + 
share_embedding_table=args.use_embedding_sharing, use_weight_only=args.use_weight_only, plugin_weight_only_quant_type=plugin_weight_only_quant_type) @@ -1035,6 +1107,7 @@ def covert_and_save(rank): exceptions ) == 0, "Checkpoint conversion failed, please check error log." + del hf_model tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) print(f'Total time of converting checkpoints: {t}') diff --git a/examples/mpt/requirements.txt b/examples/mpt/requirements.txt index fb205c8dd..9f15a5f00 100644 --- a/examples/mpt/requirements.txt +++ b/examples/mpt/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/multimodal/README.md b/examples/multimodal/README.md index 769a333fc..bb582e1fb 100644 --- a/examples/multimodal/README.md +++ b/examples/multimodal/README.md @@ -1,3 +1,4 @@ + # Multi-Modal This document shows how to run multimodal pipelines with TensorRT-LLM, e.g. from image+text input modalities to text output. @@ -6,6 +7,12 @@ Multimodal models' LLM part has an additional parameter `--max_multimodal_len` c We first describe how to run each model on a single GPU. We then provide general guidelines on using tensor parallelism for LLM part of the pipeline. +- [BLIP2-T5](#blip2-t5) +- [BLIP2-OPT](#blip2-opt) +- [LLaVA and VILA](#llava-and-vila) +- [Nougat](#nougat) +- [Enabling tensor parallelism for multi-GPU](#enabling-tensor-parallelism-for-multi-gpu) + ## BLIP2-T5 1. Download Huggingface weights and convert original checkpoint to TRT-LLM checkpoint format @@ -46,7 +53,7 @@ We first describe how to run each model on a single GPU. We then provide general 3. Build TensorRT engines for visual components ```bash - python build_visual_engine.py --model_name ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} --max_batch_size 8 + python build_visual_engine.py --model_type ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} --max_batch_size 8 ``` The built engines are located in `./visual_engines/${MODEL_NAME}`. @@ -96,7 +103,7 @@ OPT pipeline needs few minor changes from T5 pipeline --max_input_len 924 \ --max_output_len 100 - python build_visual_engine.py --model_name ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} + python build_visual_engine.py --model_type ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} python run.py \ --blip2_encoder \ @@ -134,14 +141,32 @@ OPT pipeline needs few minor changes from T5 pipeline **NOTE:** INT8/INT4 option is not supported for BLIP2-T5, because quantization support has not been added for encoder-decoder models yet. -## LLaVA +## LLaVA and VILA + +[LLaVA](https://github.com/haotian-liu/LLaVA) and [VILA](https://github.com/Efficient-Large-Model/VILA) are both visual language models (VLM) that can be deployed in TensorRT-LLM with many quantization options. -1. Download Huggingface model weights. This model has both LLM and visual components +1. Download Huggingface model weights. These models have both visual and LLM components unlike BLIP2 example which downloads only LLM components from Huggingface. 
+ For LLaVA, ```bash - export MODEL_NAME="llava-1.5-7b-hf" - git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} + export MODEL_NAME="llava-1.5-7b-hf" + git clone https://huggingface.co/llava-hf/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} + ``` + + For VILA, we need a few more steps until it is added to HF model zoo + ```bash + export MODEL_NAME="vila-7B" + git clone https://huggingface.co/Efficient-Large-Model/${MODEL_NAME} tmp/hf_models/${MODEL_NAME} + # clone original VILA repo (please clone to the same level of ${MODEL_NAME} directory) + git clone https://github.com/Efficient-Large-Model/VILA.git tmp/hf_models/VILA + # reuse LLaVA's preprocessor + wget https://huggingface.co/llava-hf/llava-1.5-7b-hf/resolve/main/preprocessor_config.json -P tmp/hf_models/${MODEL_NAME}/ + # turn off delay_load to allow model component access + sed -i 's/delay_load=True/delay_load=False/g' tmp/hf_models/VILA/llava/model/llava_arch.py + # line manipulation to enable AWQ. otherwise need to replace HF's llama implementation + sed -i '/vision_tower = self.get_vision_tower()/a \ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)' tmp/hf_models/VILA/llava/model/llava_arch.py + sed -i 's/seqlens_in_batch=sorted_seqlens_in_batch/#seqlens_in_batch=sorted_seqlens_in_batch/g' tmp/hf_models/VILA/llava/model/language_model/llava_llama.py ``` 2. Generate TRT-LLM engine for LLaMA following example in `examples/llama/README.md` @@ -156,16 +181,18 @@ OPT pipeline needs few minor changes from T5 pipeline --checkpoint_dir tmp/trt_models/${MODEL_NAME}/fp16/1-gpu \ --output_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ --gemm_plugin float16 \ + --use_fused_mlp \ --max_batch_size 1 \ --max_input_len 2048 \ --max_output_len 512 \ --max_multimodal_len 576 # 1 (max_batch_size) * 576 (num_visual_features) ``` + Note: do not use `--use_fused_mlp` flag in quantization mode. 3. Build TensorRT engines for visual components ```bash - python build_visual_engine.py --model_name ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} + python build_visual_engine.py --model_path tmp/hf_models/${MODEL_NAME} --model_type llava # or "--model_type vila" for VILA ``` 4. Add `--decoder-llm` argument to inference script, since LLaMA is a decoder-only LLM. @@ -173,14 +200,15 @@ OPT pipeline needs few minor changes from T5 pipeline ```bash python run.py \ --max_new_tokens 30 \ - --input_text "Question: which city is this? Answer:" \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ --visual_engine_dir visual_engines/${MODEL_NAME} \ --llm_engine_dir trt_engines/${MODEL_NAME}/fp16/1-gpu \ - --decoder_llm + --decoder_llm \ + --input_text "Question: which city is this? Answer:" # or "Please describe the traffic condition." for VILA ``` + Note: use `--run_profiling` for performance measurement, use `--check_accuracy` for accuracy check. -5. INT8/INT4 weight-only quantization for LLaMA can be enabled as follows (take `INT4` as an example, while `INT8` is the default precision for weight-only quantization): +5. 
(Optional) INT8/INT4 weight-only quantization for LLaMA can be enabled as follows (take `INT4` as an example, while `INT8` is the default precision for weight-only quantization): ```bash python ../llama/convert_checkpoint.py \ --model_dir tmp/hf_models/${MODEL_NAME} \ @@ -194,7 +222,7 @@ OPT pipeline needs few minor changes from T5 pipeline --output_dir trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu \ --gemm_plugin float16 \ --max_batch_size 1 \ - --max_input_len 924 \ + --max_input_len 1024 \ --max_output_len 100 \ --max_multimodal_len 576 ``` @@ -202,9 +230,28 @@ OPT pipeline needs few minor changes from T5 pipeline The built engines lie in `trt_engines/${MODEL_NAME}/int4_weightonly/1-gpu`. You should use this directory as `--llm_engine_dir` argument to `run.py` -6. One can use LLaVA with other quantization options, like SmoothQuant and INT4 AWQ, that are supported by LLaMA. +6. (Optional) One can also use LLaVA/VILA with other quantization options, like SmoothQuant and INT4 AWQ, that are supported by LLaMA. Instructions in [../llama/README.md](../llama/README.md) to enable SmoothQuant and INT4 AWQ can be re-used to generate - quantized TRT engines for LLM component of LLaVA. + quantized TRT engines for LLM component of LLaVA/VILA. + + For example, + ```bash + python ../quantization/quantize.py \ + --model_dir tmp/hf_models/${MODEL_NAME} \ + --output_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \ + --dtype float16 \ + --qformat int4_awq \ + --calib_size 32 + + trtllm-build \ + --checkpoint_dir tmp/trt_models/${MODEL_NAME}/int4_awq/1-gpu \ + --output_dir trt_engines/${MODEL_NAME}/int4_awq/1-gpu \ + --gemm_plugin float16 \ + --max_batch_size 1 \ + --max_input_len 1024 \ + --max_output_len 100 \ + --max_multimodal_len 576 + ``` ## Nougat @@ -245,7 +292,7 @@ OPT pipeline needs few minor changes from T5 pipeline 3. Generate TensorRT engines for visual components and combine everything into final pipeline. ```bash - python build_visual_engine.py --model_name ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} + python build_visual_engine.py --model_type nougat --model_path tmp/hf_models/${MODEL_NAME} python run.py \ --hf_model_dir tmp/hf_models/${MODEL_NAME} \ @@ -253,6 +300,7 @@ OPT pipeline needs few minor changes from T5 pipeline --llm_engine_dir trt_engines/${MODEL_NAME}/1-gpu/bfloat16/tp1 \ --nougat ``` + Note: Nougat models usually do not need a text prompt. 
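   Why no prompt is needed: for encoder-decoder pipelines such as Nougat, generation is seeded from the decoder start token taken from the Hugging Face config rather than from input text. The sketch below mirrors the `decoder_input_ids` handling in `run.py`; the checkpoint name is only an illustration.

   ```python
   # Sketch: derive the decoder start token for an enc-dec model (mirrors run.py).
   import torch
   from transformers import AutoConfig

   config = AutoConfig.from_pretrained("facebook/nougat-base")  # illustrative checkpoint
   decoder_start_id = config.decoder_start_token_id  # set for T5-style models
   if decoder_start_id is None:
       decoder_start_id = config.decoder.bos_token_id  # Nougat uses an mBART decoder

   batch_size = 1
   decoder_input_ids = torch.IntTensor([[decoder_start_id]]).repeat((batch_size, 1))
   ```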
## Enabling tensor parallelism for multi-GPU @@ -282,7 +330,7 @@ The full set of commands to enable 2-way tensor parallelism for LLaVA is: --max_output_len 512 \ --max_multimodal_len 576 - python build_visual_engine.py --model_name ${MODEL_NAME} --model_path tmp/hf_models/${MODEL_NAME} + python build_visual_engine.py --model_type llava --model_path tmp/hf_models/${MODEL_NAME} mpirun -n 2 --allow-run-as-root \ python run.py \ diff --git a/examples/multimodal/build_visual_engine.py b/examples/multimodal/build_visual_engine.py index 77cd5abe1..22389d9ed 100644 --- a/examples/multimodal/build_visual_engine.py +++ b/examples/multimodal/build_visual_engine.py @@ -1,6 +1,7 @@ import argparse import os import shutil +import sys from time import time # isort: off @@ -16,7 +17,7 @@ def export_visual_wrapper_onnx(visual_wrapper, image, output_dir): logger.log(trt.Logger.INFO, "Exporting onnx") - os.mkdir(f'{output_dir}/onnx') + os.makedirs(f'{output_dir}/onnx', exist_ok=True) torch.onnx.export(visual_wrapper, image, f'{output_dir}/onnx/visual_encoder.onnx', @@ -80,11 +81,8 @@ def build_trt_engine(img_height, img_width, output_dir, max_batch_size): def build_blip2_engine(args): - model_type = 'Salesforce/blip2-' + args.model_name + model_type = 'Salesforce/blip2-' + args.model_type processor = Blip2Processor.from_pretrained(model_type) - model = Blip2ForConditionalGeneration.from_pretrained( - model_type, torch_dtype=torch.float16) - model.to(args.device) raw_image = Image.new('RGB', [10, 10]) # dummy image prompt = "Question: what is this? Answer:" @@ -94,12 +92,12 @@ def build_blip2_engine(args): class Blip2VisionWrapper(torch.nn.Module): - def __init__(self, model): + def __init__(self, vision_model, qformer, projector, query_tokens): super().__init__() - self.vision_model = model.vision_model - self.qformer = model.qformer - self.projector = model.language_projection - self.query_tokens = model.query_tokens + self.vision_model = vision_model + self.qformer = qformer + self.projector = projector + self.query_tokens = query_tokens def forward(self, image): features = self.vision_model(image)[0] @@ -108,7 +106,11 @@ def forward(self, image): return_dict=True) return self.projector(qformer_output.last_hidden_state) - wrapper = Blip2VisionWrapper(model) + model = Blip2ForConditionalGeneration.from_pretrained( + model_type, torch_dtype=torch.float16) + wrapper = Blip2VisionWrapper(model.vision_model, model.qformer, + model.language_projection, model.query_tokens) + wrapper.to(args.device) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine(image.shape[2], image.shape[3], args.output_dir, @@ -138,9 +140,8 @@ def forward(self, image): model = LlavaForConditionalGeneration.from_pretrained( args.model_path, torch_dtype=torch.float16) - model.to(args.device) - wrapper = LlavaVisionWrapper(model.vision_tower, - model.multi_modal_projector, + wrapper = LlavaVisionWrapper(model.vision_tower.to(args.device), + model.multi_modal_projector.to(args.device), model.config.vision_feature_layer) export_visual_wrapper_onnx(wrapper, image, args.output_dir) @@ -148,6 +149,39 @@ def forward(self, image): args.max_batch_size) +def build_vila_engine(args): + # Note: VILA model is not in public HF model zoo yet. 
We need to explicitly import from the git repo + sys.path.append(args.model_path + "/../VILA") + from llava.model import LlavaLlamaForCausalLM + + processor = AutoProcessor.from_pretrained(args.model_path) + raw_image = Image.new('RGB', [10, 10]) # dummy image + image = processor(text="dummy", images=raw_image, + return_tensors="pt")['pixel_values'].to( + args.device, torch.float16) + + class VilaVisionWrapper(torch.nn.Module): + + def __init__(self, tower, projector): + super().__init__() + self.tower = tower + self.projector = projector + + def forward(self, image): + features = self.tower(image) + return self.projector(features) + + model = LlavaLlamaForCausalLM.from_pretrained(args.model_path, + torch_dtype=torch.float16) + wrapper = VilaVisionWrapper( + model.get_model().get_vision_tower().to(args.device), + model.get_model().mm_projector.to(args.device)) + + export_visual_wrapper_onnx(wrapper, image, args.output_dir) + build_trt_engine(image.shape[2], image.shape[3], args.output_dir, + args.max_batch_size) + + def build_nougat_engine(args): processor = NougatProcessor.from_pretrained(args.model_path) raw_image = Image.new('RGB', [10, 10]) # dummy image @@ -174,20 +208,20 @@ def forward(self, image): if __name__ == '__main__': - logger = trt.Logger(trt.Logger.ERROR) + logger = trt.Logger(trt.Logger.INFO) parser = argparse.ArgumentParser() - parser.add_argument('--model_name', + parser.add_argument('--model_type', type=str, default=None, - help="Model name") + help="Model type") parser.add_argument('--model_path', type=str, default=None, help="Huggingface repo or local directory with weights") parser.add_argument('--output_dir', type=str, - default='visual_engines', + default=None, help="Directory where visual TRT engines are saved") parser.add_argument('--max_batch_size', type=int, @@ -197,15 +231,18 @@ def forward(self, image): args.device = torch.device("cuda") if torch.cuda.is_available() else "cpu" - args.output_dir = args.output_dir + "/" + args.model_name + if args.output_dir is None: + args.output_dir = 'visual_engines/%s' % (args.model_path.split('/')[-1]) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) - if args.model_name in ['opt-2.7b', 'flan-t5-xl']: + if args.model_type in ['opt-2.7b', 'flan-t5-xl']: build_blip2_engine(args) - elif 'llava' in args.model_name: + elif args.model_type == 'llava': build_llava_engine(args) - elif 'nougat' in args.model_name: + elif args.model_type == 'vila': + build_vila_engine(args) + elif args.model_type == 'nougat': build_nougat_engine(args) else: - raise RuntimeError(f"Invalid model name {args.model_name}") + raise RuntimeError(f"Invalid model type {args.model_type}") diff --git a/examples/multimodal/run.py b/examples/multimodal/run.py index 93750de55..03ebf4d9f 100644 --- a/examples/multimodal/run.py +++ b/examples/multimodal/run.py @@ -55,7 +55,7 @@ def parse_arguments(): help='Run nougat pipeline') parser.add_argument('--input_text', type=str, - default='Question: which city is this? 
Answer:', + default=None, help='Text prompt to LLM') parser.add_argument('--num_beams', type=int, @@ -144,21 +144,14 @@ def init_llm(self): self.model_config = self.model.encoder_model_config self.runtime_mapping = self.model.encoder_runtime_mapping - config = AutoConfig.from_pretrained(self.args.hf_model_dir) - decoder_start_id = config.decoder_start_token_id - if decoder_start_id is None: - decoder_start_id = self.tokenizer.bos_token_id - - decoder_input_ids = torch.IntTensor([[decoder_start_id] - ]).to(self.device) - batch_size = self.args.batch_size - self.decoder_input_ids = decoder_input_ids.repeat((batch_size, 1)) - - def generate(self, pre_prompt, post_prompt, image, max_new_tokens): - profiler.start("Generate") - profiler.start("Vision") + def generate(self, pre_prompt, post_prompt, image, decoder_input_ids, + max_new_tokens, warmup): + if not warmup: + profiler.start("Generate") + profiler.start("Vision") visual_features, visual_atts = self.get_visual_features(image) - profiler.stop("Vision") + if not warmup: + profiler.stop("Vision") pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", @@ -178,11 +171,11 @@ def generate(self, pre_prompt, post_prompt, image, max_new_tokens): input_ids, ptuning_args = self.setup_fake_prompts( visual_features, pre_input_ids, post_input_ids, input_lengths) - if self.args.decoder_llm and tensorrt_llm.mpi_rank() == 0: + if warmup and self.args.decoder_llm and tensorrt_llm.mpi_rank() == 0: prompt_table = ptuning_args[0] prompt_table = torch.stack([prompt_table]) np.save('prompt_table.npy', torch_to_numpy(prompt_table)) - tensorrt_llm.mpi_barrier() # Sync before reading prompt_table file + if warmup: return None profiler.start("LLM") if self.args.decoder_llm: @@ -212,7 +205,7 @@ def generate(self, pre_prompt, post_prompt, image, max_new_tokens): output_ids = self.model.generate( input_ids, - self.decoder_input_ids, + decoder_input_ids, max_new_tokens, num_beams=self.args.num_beams, bos_token_id=self.tokenizer.bos_token_id, @@ -293,6 +286,7 @@ def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, return input_ids, ptuning_args def ptuning_setup(self, prompt_table, input_ids, input_lengths): + hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size if prompt_table is not None: task_vocab_size = torch.tensor( [prompt_table.shape[1]], @@ -302,7 +296,6 @@ def ptuning_setup(self, prompt_table, input_ids, input_lengths): (prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2])) - hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size assert prompt_table.shape[ 1] == hidden_size, "Prompt table dimensions do not match hidden size" @@ -324,7 +317,11 @@ def ptuning_setup(self, prompt_table, input_ids, input_lengths): def load_test_image(model_name): - if "nougat" in model_name: + if "vila" in model_name: + img_url = 'https://github.com/Efficient-Large-Model/VILA/raw/main/demo_images/av.png' + image = Image.open(requests.get(img_url, + stream=True).raw).convert('RGB') + elif "nougat" in model_name: filepath = hf_hub_download( repo_id="hf-internal-testing/fixtures_docvqa", filename="nougat_paper.png", @@ -351,6 +348,9 @@ def load_test_image(model_name): else: model_type = 'Salesforce/blip2-flan-t5-xl' + if args.input_text is None: + args.input_text = "Question: which city is this? 
Answer:" + processor = Blip2Processor.from_pretrained(model_type) image = processor(image, args.input_text, return_tensors="pt")['pixel_values'] @@ -361,17 +361,29 @@ def load_test_image(model_name): processor = NougatProcessor.from_pretrained(args.hf_model_dir) image = processor(image, return_tensors="pt")['pixel_values'] + # Nougat doesn't need text prompt (mBART use single token to start generation), just leave a dummy one here + if args.input_text is None: + args.input_text = "Question: which city is this? Answer:" + pre_prompt = args.input_text post_prompt = None else: + # LLaVA and VILA + if "llava" in args.hf_model_dir: + pre_prompt = "USER:\n" + if args.input_text is None: + args.input_text = "Question: which city is this? Answer:" + elif "vila" in args.hf_model_dir: + pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: " + if args.input_text is None: + args.input_text = "Please describe the traffic condition." + post_prompt = args.input_text + " ASSISTANT:" + processor = AutoProcessor.from_pretrained(args.hf_model_dir) image = processor(text=args.input_text, images=image, return_tensors="pt")['pixel_values'] - pre_prompt = "USER:\n" - post_prompt = args.input_text + " ASSISTANT:" - # Repeat inputs to match batch size pre_prompt = [pre_prompt] * args.batch_size post_prompt = [post_prompt] * args.batch_size @@ -380,30 +392,67 @@ def load_test_image(model_name): model = MultiModalModel(args) image = image.to(model.device) - num_iters = 100 if args.run_profiling else 1 + # Generate decoder_input_ids for enc-dec models + # Custom prompts can be added as: + # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids + if args.decoder_llm: + decoder_input_ids = None + else: + config = AutoConfig.from_pretrained(args.hf_model_dir) + decoder_start_id = config.decoder_start_token_id # T5 + if decoder_start_id is None: + decoder_start_id = config.decoder.bos_token_id # Nougat + + decoder_input_ids = torch.IntTensor([[decoder_start_id]]) + decoder_input_ids = decoder_input_ids.repeat((args.batch_size, 1)) + + model.generate(pre_prompt, + post_prompt, + image, + decoder_input_ids, + args.max_new_tokens, + warmup=True) + tensorrt_llm.mpi_barrier() + + num_iters = 20 if args.run_profiling else 1 for _ in range(num_iters): - stripped_text = model.generate(pre_prompt, post_prompt, image, - args.max_new_tokens) + stripped_text = model.generate(pre_prompt, + post_prompt, + image, + decoder_input_ids, + args.max_new_tokens, + warmup=False) if runtime_rank == 0: logger.info("---------------------------------------------------------") if not args.nougat: logger.info(f"\n[Q] {args.input_text}") - logger.info(f"\n[A] {stripped_text}") - - if args.check_accuracy and not args.nougat: - assert stripped_text[0][0].lower() == 'singapore' + logger.info(f"\n[A] {stripped_text[0]}") + + if args.num_beams == 1: + output_ids = model.tokenizer(stripped_text[0][0], + add_special_tokens=False)['input_ids'] + logger.info(f"Generated {len(output_ids)} tokens") + + if args.check_accuracy: + for i in range(args.batch_size - 1): + if not (stripped_text[i] == stripped_text[i + 1]): + logger.info(f"Output {i} and {i + 1} do not match") + assert False + if not args.nougat: + if "vila" in args.hf_model_dir: + assert stripped_text[0][0].lower( + ) == 'the traffic condition in the image is quite busy, with multiple cars and bicycles sharing the road. 
there are also pedestrians walking on' + else: + assert stripped_text[0][0].lower() == 'singapore' if args.run_profiling: - vision_latency = profiler.elapsed_time_in_sec("Vision") / num_iters - logger.info( - f'TensorRT vision encoder latency: {vision_latency} sec') - - llm_latency = profiler.elapsed_time_in_sec("LLM") / num_iters - logger.info(f'TensorRT-LLM LLM latency: {llm_latency} sec') - - generate_latency = profiler.elapsed_time_in_sec( - "Generate") / num_iters - logger.info(f'Generate latency: {generate_latency} sec') + msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec( + name) / num_iters + logger.info('Latencies per batch (msec)') + logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision'))) + logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM'))) + logger.info('Multimodal generate: %.1f' % + (msec_per_batch('Generate'))) logger.info("---------------------------------------------------------") diff --git a/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp b/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp index 87677d86e..004373d94 100644 --- a/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp +++ b/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.cpp @@ -32,8 +32,8 @@ using namespace nvinfer1; using openai_triton::plugin::TritonFlashAttentionPluginCreator; using openai_triton::plugin::TritonFlashAttentionPlugin; -static const char* TRITON_FLASH_ATTENTION_PLUGIN_VERSION{"1"}; -static const char* TRITON_FLASH_ATTENTION_PLUGIN_NAME{"TritonFlashAttention"}; +static char const* TRITON_FLASH_ATTENTION_PLUGIN_VERSION{"1"}; +static char const* TRITON_FLASH_ATTENTION_PLUGIN_NAME{"TritonFlashAttention"}; PluginFieldCollection TritonFlashAttentionPluginCreator::mFC{}; std::vector TritonFlashAttentionPluginCreator::mPluginAttributes; @@ -42,7 +42,7 @@ namespace openai_triton::plugin // Write values into buffer template -void writeArg(char*& buffer, const T& val) +void writeArg(char*& buffer, T const& val) { std::memcpy(buffer, &val, sizeof(T)); buffer += sizeof(T); @@ -50,7 +50,7 @@ void writeArg(char*& buffer, const T& val) // Read values from buffer template -void readArg(const char*& buffer, T& val) +void readArg(char const*& buffer, T& val) { std::memcpy(&val, buffer, sizeof(T)); buffer += sizeof(T); @@ -79,9 +79,9 @@ TritonFlashAttentionPlugin::TritonFlashAttentionPlugin( } // Parameterized constructor -TritonFlashAttentionPlugin::TritonFlashAttentionPlugin(const void* data, size_t length) +TritonFlashAttentionPlugin::TritonFlashAttentionPlugin(void const* data, size_t length) { - const char *d = reinterpret_cast(data), *a = d; + char const *d = reinterpret_cast(data), *a = d; readArg(d, mNumHeads); readArg(d, mHeadSize); readArg(d, mSoftmaxScale); @@ -98,7 +98,7 @@ nvinfer1::IPluginV2DynamicExt* TritonFlashAttentionPlugin::clone() const noexcep } nvinfer1::DimsExprs TritonFlashAttentionPlugin::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept + int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { // Output shape. 
// output tensor [batchSize, seqLen, mNumHeads, head_size] @@ -107,7 +107,7 @@ nvinfer1::DimsExprs TritonFlashAttentionPlugin::getOutputDimensions( } bool TritonFlashAttentionPlugin::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept { // In this example, inputs: Q, K, V, outputs: Out assert(nbInputs + nbOutputs == 4); @@ -125,19 +125,19 @@ bool TritonFlashAttentionPlugin::supportsFormatCombination( return is_valid; } -void TritonFlashAttentionPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept +void TritonFlashAttentionPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept { } -size_t TritonFlashAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept +size_t TritonFlashAttentionPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept { // Set workspace size if needed. In this example, we need for L and m buffers. - const auto Q = inputs[0]; - const int batchSize = Q.dims.d[0]; - const int seqLen = Q.dims.d[2]; - const int numBuffers = 2; + auto const Q = inputs[0]; + int const batchSize = Q.dims.d[0]; + int const seqLen = Q.dims.d[2]; + int const numBuffers = 2; size_t workspaces[numBuffers]; workspaces[0] = sizeof(float) * batchSize * mNumHeads * seqLen; workspaces[1] = sizeof(float) * batchSize * mNumHeads * seqLen; @@ -155,8 +155,8 @@ size_t TritonFlashAttentionPlugin::getWorkspaceSize(const nvinfer1::PluginTensor } template -int TritonFlashAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int TritonFlashAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) { assert(inputDesc[0].dims.d[1] == mNumHeads && inputDesc[0].dims.d[3] == mHeadSize); @@ -172,9 +172,9 @@ int TritonFlashAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* in float* L = reinterpret_cast(workspace); float* M = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(L), bufSize)); - const T* Q = reinterpret_cast(inputs[0]); - const T* K = reinterpret_cast(inputs[1]); - const T* V = reinterpret_cast(inputs[2]); + T const* Q = reinterpret_cast(inputs[0]); + T const* K = reinterpret_cast(inputs[1]); + T const* V = reinterpret_cast(inputs[2]); // Launch a cuda kernel generated by Triton AoT. 
int res = 0; @@ -193,8 +193,8 @@ int TritonFlashAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* in return res; } -int TritonFlashAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, void* const* outputs, void* workspace, +int TritonFlashAttentionPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept { if (mType == DataType::kHALF) @@ -210,7 +210,7 @@ int TritonFlashAttentionPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputD // IPluginV2Ext Methods nvinfer1::DataType TritonFlashAttentionPlugin::getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept { assert(index == 0); return inputTypes[0]; @@ -218,12 +218,12 @@ nvinfer1::DataType TritonFlashAttentionPlugin::getOutputDataType( // IPluginV2 Methods -const char* TritonFlashAttentionPlugin::getPluginType() const noexcept +char const* TritonFlashAttentionPlugin::getPluginType() const noexcept { return TRITON_FLASH_ATTENTION_PLUGIN_NAME; } -const char* TritonFlashAttentionPlugin::getPluginVersion() const noexcept +char const* TritonFlashAttentionPlugin::getPluginVersion() const noexcept { return TRITON_FLASH_ATTENTION_PLUGIN_VERSION; } @@ -269,12 +269,12 @@ void TritonFlashAttentionPlugin::destroy() noexcept delete this; } -void TritonFlashAttentionPlugin::setPluginNamespace(const char* libNamespace) noexcept +void TritonFlashAttentionPlugin::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } -const char* TritonFlashAttentionPlugin::getPluginNamespace() const noexcept +char const* TritonFlashAttentionPlugin::getPluginNamespace() const noexcept { return mNamespace.c_str(); } @@ -293,24 +293,24 @@ TritonFlashAttentionPluginCreator::TritonFlashAttentionPluginCreator() mFC.fields = mPluginAttributes.data(); } -const char* TritonFlashAttentionPluginCreator::getPluginName() const noexcept +char const* TritonFlashAttentionPluginCreator::getPluginName() const noexcept { return TRITON_FLASH_ATTENTION_PLUGIN_NAME; } -const char* TritonFlashAttentionPluginCreator::getPluginVersion() const noexcept +char const* TritonFlashAttentionPluginCreator::getPluginVersion() const noexcept { return TRITON_FLASH_ATTENTION_PLUGIN_VERSION; } -const PluginFieldCollection* TritonFlashAttentionPluginCreator::getFieldNames() noexcept +PluginFieldCollection const* TritonFlashAttentionPluginCreator::getFieldNames() noexcept { return &mFC; } -IPluginV2* TritonFlashAttentionPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept +IPluginV2* TritonFlashAttentionPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept { - const PluginField* fields = fc->fields; + PluginField const* fields = fc->fields; int numHeads = 0; int headSize = 0; float softmaxScale = 1.0f; @@ -318,26 +318,26 @@ IPluginV2* TritonFlashAttentionPluginCreator::createPlugin(const char* name, con // Read configurations from each fields for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; + char const* attrName = fields[i].name; if (!strcmp(attrName, "num_heads")) { assert(fields[i].type == PluginFieldType::kINT32); - numHeads = static_cast(*(static_cast(fields[i].data))); + numHeads = 
static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "head_size")) { assert(fields[i].type == PluginFieldType::kINT32); - headSize = static_cast(*(static_cast(fields[i].data))); + headSize = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "softmax_scale")) { assert(fields[i].type == PluginFieldType::kFLOAT32); - softmaxScale = static_cast(*(static_cast(fields[i].data))); + softmaxScale = static_cast(*(static_cast(fields[i].data))); } else if (!strcmp(attrName, "type_id")) { assert(fields[i].type == PluginFieldType::kINT32); - type = static_cast(*(static_cast(fields[i].data))); + type = static_cast(*(static_cast(fields[i].data))); } } try @@ -346,7 +346,7 @@ IPluginV2* TritonFlashAttentionPluginCreator::createPlugin(const char* name, con obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { std::cerr << "Caught exception: " << e.what() << std::endl; } @@ -354,7 +354,7 @@ IPluginV2* TritonFlashAttentionPluginCreator::createPlugin(const char* name, con } IPluginV2* TritonFlashAttentionPluginCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept + char const* name, void const* serialData, size_t serialLength) noexcept { // This object will be deleted when the network is destroyed, which will // call TritonFlashAttentionPlugin::destroy() @@ -364,19 +364,19 @@ IPluginV2* TritonFlashAttentionPluginCreator::deserializePlugin( obj->setPluginNamespace(mNamespace.c_str()); return obj; } - catch (const std::exception& e) + catch (std::exception const& e) { std::cerr << "Caught exception: " << e.what() << std::endl; } return nullptr; } -void TritonFlashAttentionPluginCreator::setPluginNamespace(const char* libNamespace) noexcept +void TritonFlashAttentionPluginCreator::setPluginNamespace(char const* libNamespace) noexcept { mNamespace = libNamespace; } -const char* TritonFlashAttentionPluginCreator::getPluginNamespace() const noexcept +char const* TritonFlashAttentionPluginCreator::getPluginNamespace() const noexcept { return mNamespace.c_str(); } diff --git a/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.h b/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.h index f9d51a06d..cd95eb48e 100644 --- a/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.h +++ b/examples/openai_triton/manual_plugin/TritonFlashAttentionPlugin.h @@ -34,42 +34,42 @@ class TritonFlashAttentionPlugin : public nvinfer1::IPluginV2DynamicExt public: TritonFlashAttentionPlugin(int numHeads, int headSize, float softmaxScale, nvinfer1::DataType type); - TritonFlashAttentionPlugin(const void* data, size_t length); + TritonFlashAttentionPlugin(void const* data, size_t length); ~TritonFlashAttentionPlugin() override = default; // IPluginV2DynamicExt Methods nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override; bool supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; - size_t getWorkspaceSize(const 
nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; + int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs, + nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override; + int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override; template - int enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); + int enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream); // IPluginV2Ext Methods nvinfer1::DataType getOutputDataType( - int index, const nvinfer1::DataType* inputTypes, int nbInputs) const noexcept override; + int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override; // IPluginV2 Methods - const char* getPluginType() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginType() const noexcept override; + char const* getPluginVersion() const noexcept override; int getNbOutputs() const noexcept override; int initialize() noexcept override; void terminate() noexcept override; size_t getSerializationSize() const noexcept override; void serialize(void* buffer) const noexcept override; void destroy() noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; - const char* getPluginNamespace() const noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + char const* getPluginNamespace() const noexcept override; private: const std::string mLayerName; @@ -89,20 +89,20 @@ class TritonFlashAttentionPluginCreator : public nvinfer1::IPluginCreator public: TritonFlashAttentionPluginCreator(); - const char* getPluginName() const noexcept override; + char const* getPluginName() const noexcept override; - const char* getPluginVersion() const noexcept override; + char const* getPluginVersion() const noexcept override; - const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override; - nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override; nvinfer1::IPluginV2* deserializePlugin( - const char* name, const void* serialData, size_t serialLength) noexcept override; + char const* name, void const* serialData, size_t serialLength) noexcept override; - void setPluginNamespace(const char* pluginNamespace) noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept 
override; - const char* getPluginNamespace() const noexcept override; + char const* getPluginNamespace() const noexcept override; private: static nvinfer1::PluginFieldCollection mFC; diff --git a/examples/openai_triton/manual_plugin/tritonPlugins.cpp b/examples/openai_triton/manual_plugin/tritonPlugins.cpp index 636132608..27b1ece08 100644 --- a/examples/openai_triton/manual_plugin/tritonPlugins.cpp +++ b/examples/openai_triton/manual_plugin/tritonPlugins.cpp @@ -40,7 +40,7 @@ class TritonPluginCreatorRegistry } template - void addPluginCreator(void* logger, const char* libNamespace) + void addPluginCreator(void* logger, char const* libNamespace) { // Make accesses to the plugin creator registry thread safe std::lock_guard lock(mRegistryLock); @@ -114,7 +114,7 @@ class TritonPluginCreatorRegistry }; template -void initializeTritonPlugin(void* logger, const char* libNamespace) +void initializeTritonPlugin(void* logger, char const* libNamespace) { TritonPluginCreatorRegistry::getInstance().addPluginCreator(logger, libNamespace); } @@ -125,7 +125,7 @@ void initializeTritonPlugin(void* logger, const char* libNamespace) extern "C" { - bool initOpenAiTritonPlugins(void* logger, const char* libNamespace) + bool initOpenAiTritonPlugins(void* logger, char const* libNamespace) { initializeTritonPlugin(logger, libNamespace); return true; diff --git a/examples/opt/requirements.txt b/examples/opt/requirements.txt index fb205c8dd..9f15a5f00 100644 --- a/examples/opt/requirements.txt +++ b/examples/opt/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/phi/convert_checkpoint.py b/examples/phi/convert_checkpoint.py index 491ae5b94..4f55ac39d 100644 --- a/examples/phi/convert_checkpoint.py +++ b/examples/phi/convert_checkpoint.py @@ -16,287 +16,51 @@ import json import os import time -import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed -import numpy as np import safetensors -import torch from transformers import AutoModelForCausalLM import tensorrt_llm -from tensorrt_llm._utils import pad_vocab_size, str_dtype_to_torch +from tensorrt_llm.models.phi.convert import convert_hf_config, convert_hf_phi - -def torch_split(v, tensor_parallel, idx, dim=0): - if tensor_parallel == 1: - return v - else: - return (torch.split(v, v.shape[dim] // tensor_parallel, - dim=dim)[idx]).contiguous() - - -def convert_hf_phi(hf_model, - rank=0, - tensor_parallel=1, - dtype='float32', - use_parallel_embedding=False, - sharding_dim=0): - - hf_model_phi_block_names = [ - "input_layernorm.weight", - "input_layernorm.bias", - "self_attn.dense.weight", - "self_attn.dense.bias", - "mlp.fc1.weight", - "mlp.fc1.bias", - "mlp.fc2.weight", - "mlp.fc2.bias", - ] - - tensorrt_llm_model_phi_block_names = [ - "input_layernorm.weight", - "input_layernorm.bias", - "attention.dense.weight", - "attention.dense.bias", - "mlp.fc.weight", - "mlp.fc.bias", - "mlp.proj.weight", - "mlp.proj.bias", - ] - - weights = {} - torch_dtype = str_dtype_to_torch(dtype) - hf_phi_state_dict = hf_model.state_dict() - - # Embedding - # [vocab_size, hidden_size] - v = hf_phi_state_dict.get('model.embed_tokens.weight').to(torch_dtype).cpu() - if use_parallel_embedding: - v = torch_split(v, tensor_parallel, rank, sharding_dim) - weights['transformer.vocab_embedding.weight'] = v - - # Decoder Layers - n_layer = hf_model.config.num_hidden_layers - for layer_idx in range(n_layer): - hf_prefix = 
f"model.layers.{layer_idx}." - tllm_prex = f'transformer.layers.{layer_idx}.' - - # MLPs - for idx, hf_attr in enumerate(hf_model_phi_block_names): - v = hf_phi_state_dict.get(hf_prefix + hf_attr).to(torch_dtype).cpu() - - if tensor_parallel > 1: - if 'self_attn.dense.weight' in hf_attr: - # [n=hidden_size, k=hidden_size] -> - # [n=hidden_size, k=hidden_size // tensor_parallel] - v = torch_split(v, tensor_parallel, rank, dim=1) - elif 'mlp.fc1.weight' in hf_attr: - # [hidden_size * 4, hidden_size] -> - # [hidden_size * 4 // tensor_parallel, hidden_size] - v = torch_split(v, tensor_parallel, rank, dim=0) - elif 'mlp.fc1.bias' in hf_attr: - # [hidden_size * 4] -> [hidden_size * 4 // tensor_parallel] - v = torch_split(v, tensor_parallel, rank, dim=0) - elif 'mlp.fc2.weight' in hf_attr: - # [hidden_size, hidden_size * 4] -> - # [hidden_size, hidden_size * 4 // tensor_parallel] - v = torch_split(v, tensor_parallel, rank, dim=1) - - tllm_attr = tensorrt_llm_model_phi_block_names[idx] - weights[f'{tllm_prex}{tllm_attr}'] = v - - # Attention QKV Linear - num_heads = hf_model.config.num_attention_heads - hidden_size = hf_model.config.hidden_size - hidden_size // num_heads - - # [(num_heads x q)|(num_heads x k)|(num_heads x v), hidden_size] - q_weights = hf_phi_state_dict.get(hf_prefix + "self_attn.q_proj.weight") - k_weights = hf_phi_state_dict.get(hf_prefix + "self_attn.k_proj.weight") - v_weights = hf_phi_state_dict.get(hf_prefix + "self_attn.v_proj.weight") - q_bias = hf_phi_state_dict.get(hf_prefix + "self_attn.q_proj.bias") - k_bias = hf_phi_state_dict.get(hf_prefix + "self_attn.k_proj.bias") - v_bias = hf_phi_state_dict.get(hf_prefix + "self_attn.v_proj.bias") - qkv_weights = torch.cat((q_weights, k_weights, v_weights), dim=0) - qkv_bias = torch.cat((q_bias, k_bias, v_bias), dim=0) - - qkv_weights = qkv_weights.reshape([hidden_size * 3, hidden_size]) - qkv_bias = qkv_bias.reshape([hidden_size * 3]) - - if tensor_parallel > 1: - qkv_weights = qkv_weights.reshape( - 3, hidden_size, hidden_size).to(torch_dtype).cpu() - qkv_weights = torch_split(qkv_weights, tensor_parallel, rank, - dim=1).reshape( - 3 * (hidden_size // tensor_parallel), - hidden_size) - - qkv_bias = qkv_bias.reshape(3, hidden_size).to(torch_dtype).cpu() - qkv_bias = torch_split(qkv_bias, tensor_parallel, rank, - dim=1).reshape( - 3 * (hidden_size // tensor_parallel)) - - weights[ - f"{tllm_prex}attention.qkv.weight"] = qkv_weights.contiguous() - weights[f"{tllm_prex}attention.qkv.bias"] = qkv_bias.contiguous() - else: - weights[f"{tllm_prex}attention.qkv.weight"] = qkv_weights.to( - torch_dtype).cpu() - weights[f"{tllm_prex}attention.qkv.bias"] = qkv_bias.to( - torch_dtype).cpu() - - # Final Layer Norm - v = hf_phi_state_dict.get('model.final_layernorm.weight') - weights["transformer.ln_f.weight"] = v.to(torch_dtype).cpu() - - v = hf_phi_state_dict.get('model.final_layernorm.bias') - weights["transformer.ln_f.bias"] = v.to(torch_dtype).cpu() - - # LM Head - v = hf_phi_state_dict.get('lm_head.weight').to(torch_dtype).cpu() - if tensor_parallel > 1: - # [vocab_size, hidden_size] -> - # [vocab_size // tensor_parallel, hidden_size] - if v.shape[0] % tensor_parallel != 0: - # padding - vocab_size_padded = pad_vocab_size(v.shape[0], tensor_parallel) - pad_width = vocab_size_padded - v.shape[0] - v = np.pad(v, ((0, pad_width), (0, 0)), - 'constant', - constant_values=0) - - v = torch_split(v, tensor_parallel, rank, dim=0) - weights["lm_head.weight"] = v - - v = hf_phi_state_dict.get('lm_head.bias').to(torch_dtype).cpu() - if 
tensor_parallel > 1: - v = torch_split(v, tensor_parallel, rank, dim=0) - weights["lm_head.bias"] = v - - return weights +__all__ = ['convert_hf_phi', 'convert_hf_config'] def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument('--model_dir', type=str, default=None) - parser.add_argument('--tp_size', - type=int, - default=1, - help='N-way tensor parallelism size') - parser.add_argument('--pp_size', - type=int, - default=1, - help='N-way pipeline parallelism size') parser.add_argument('--dtype', type=str, default='float16', choices=['float32', 'bfloat16', 'float16']) - parser.add_argument( - '--use_parallel_embedding', - action="store_true", - default=False, - help= - 'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled' - ) - parser.add_argument( - '--embedding_sharding_dim', - type=int, - default=0, - choices=[0, 1], - help= - 'By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). ' - 'To shard it along hidden dimension, set embedding_sharding_dim=1' - 'Note: embedding sharing is only enabled when embedding_sharding_dim = 0' - ) parser.add_argument('--output_dir', type=str, default='tllm_checkpoint', help='The path to save the TensorRT-LLM checkpoint') - parser.add_argument( - '--workers', - type=int, - default=1, - help='The number of workers for converting checkpoint in parallel') args = parser.parse_args() return args if __name__ == '__main__': - # TODO(qijun): Currently, the convert script depends on a torch op: - # torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix, - # which is included in tensorrt_llm Python package. Otherwise, the convert - # script does not need to import tensorrt_llm. Will remove it after reimplementing - # the op with PyTorch. print(tensorrt_llm.__version__) args = parse_arguments() - world_size = args.tp_size * args.pp_size - assert args.pp_size == 1, "Pipeline parallelism is not supported." 
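Note: the removed conversion code above sharded each weight per tensor-parallel rank by slicing along the appropriate dimension (attention dense and mlp.fc2 along the input dimension, mlp.fc1 and lm_head along the output/vocab dimension). The snippet below is a minimal, illustrative sketch of that slicing pattern in plain PyTorch; it is not part of this patch, and the helper name `shard` and the example shapes are invented for illustration.

```python
# Illustrative only -- not part of this patch. A minimal sketch of the
# per-rank split the removed torch_split helper performed: take the rank-th
# contiguous chunk of a weight along the sharded dimension.
import torch

def shard(weight: torch.Tensor, tp_size: int, rank: int, dim: int) -> torch.Tensor:
    if tp_size == 1:
        return weight
    # Assumes weight.shape[dim] is divisible by tp_size, as the removed code did.
    chunk = weight.shape[dim] // tp_size
    return weight.narrow(dim, rank * chunk, chunk).contiguous()

w = torch.randn(8, 4)
print(shard(w, tp_size=2, rank=1, dim=0).shape)  # torch.Size([4, 4])
```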
tik = time.time() - if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir, torch_dtype="auto", trust_remote_code=True) - hf_config = hf_model.config - config = { - 'architecture': hf_config.architectures[0], - 'dtype': args.dtype, - 'num_hidden_layers': hf_config.num_hidden_layers, - 'num_attention_heads': hf_config.num_key_value_heads, - 'partial_rotary_factor': hf_config.partial_rotary_factor, - 'rope_theta': hf_config.rope_theta, - 'hidden_size': hf_config.hidden_size, - 'intermediate_size': hf_config.intermediate_size, - 'vocab_size': hf_config.vocab_size, - 'max_position_embeddings': hf_config.max_position_embeddings, - 'hidden_act': hf_config.hidden_act, - 'mapping': { - 'world_size': world_size, - 'tp_size': args.tp_size, - 'pp_size': args.pp_size, - }, - 'use_parallel_embedding': False, - 'embedding_sharding_dim': args.embedding_sharding_dim, - 'share_embedding_table': False, - } + trtllm_config = convert_hf_config(hf_model.config, args) with open(os.path.join(args.output_dir, 'config.json'), 'w') as f: - json.dump(config, f, indent=4) - - def covert_and_save(rank): - weights = convert_hf_phi( - hf_model, - rank, - world_size, - dtype=args.dtype, - use_parallel_embedding=args.use_parallel_embedding, - sharding_dim=args.embedding_sharding_dim) - safetensors.torch.save_file( - weights, os.path.join(args.output_dir, f'rank{rank}.safetensors')) + json.dump(trtllm_config, f, indent=4) - if args.workers == 1: - for rank in range(world_size): - covert_and_save(rank) - else: - with ThreadPoolExecutor(max_workers=args.workers) as p: - futures = [ - p.submit(covert_and_save, rank) for rank in range(world_size) - ] - exceptions = [] - for future in as_completed(futures): - try: - future.result() - except Exception as e: - traceback.print_exc() - exceptions.append(e) - assert len( - exceptions - ) == 0, "Checkpoint conversion failed, please check error log." + trtllm_weights = convert_hf_phi(hf_model, dtype=args.dtype) + safetensors.torch.save_file( + trtllm_weights, os.path.join(args.output_dir, f'rank0.safetensors')) tok = time.time() t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) diff --git a/examples/phi/requirements.txt b/examples/phi/requirements.txt index 2c61236d2..6896c3fa4 100644 --- a/examples/phi/requirements.txt +++ b/examples/phi/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.14.5 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/quantization/quantize.py b/examples/quantization/quantize.py index 757811ecd..9608e9a3c 100644 --- a/examples/quantization/quantize.py +++ b/examples/quantization/quantize.py @@ -1,333 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
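For reference, the simplified phi flow above delegates config and weight translation to `tensorrt_llm.models.phi.convert`. The sketch below mirrors the refactored script when driving those helpers directly; the `Namespace` fields and the model path are assumptions for illustration, not values taken from this patch.

```python
# Sketch only: mirrors the refactored convert_checkpoint.py flow shown above.
# The Namespace fields and model path are assumed, not taken from the patch.
import json
import os
from argparse import Namespace

import safetensors.torch
from transformers import AutoModelForCausalLM

from tensorrt_llm.models.phi.convert import convert_hf_config, convert_hf_phi

args = Namespace(model_dir="path/to/phi-model",  # placeholder
                 dtype="float16",
                 output_dir="tllm_checkpoint")
os.makedirs(args.output_dir, exist_ok=True)

hf_model = AutoModelForCausalLM.from_pretrained(args.model_dir,
                                                torch_dtype="auto",
                                                trust_remote_code=True)

# Write the TensorRT-LLM checkpoint config and the single-rank weights file.
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
    json.dump(convert_hf_config(hf_model.config, args), f, indent=4)

safetensors.torch.save_file(
    convert_hf_phi(hf_model, dtype=args.dtype),
    os.path.join(args.output_dir, "rank0.safetensors"))
```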
-""" -Adapted from examples/quantization/hf_ptq.py -""" - import argparse -import copy -import json -import random -import time - -import ammo.torch.quantization as atq -import numpy as np -import torch -from ammo.torch.export import export_model_config -from datasets import load_dataset -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer - -RAND_SEED = 1234 -MAX_SEQ_LEN = 2048 - -EMPTY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "enable": False, - }, - "*input_quantizer": { - "enable": False - }, - "*lm_head*": { - "enable": False - }, - "*output_layer*": { - "enable": False - }, - "default": { - "enable": False - }, - }, - "algorithm": "max", -} - -KV_CACHE_CFG = { - "*.query_key_value.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.Wqkv.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.W_pack.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.c_attn.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.k_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.v_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, -} - -QUANT_CFG_CHOICES = { - "int8_sq": atq.INT8_SMOOTHQUANT_CFG, - "fp8": atq.FP8_DEFAULT_CFG, - "int4_awq": atq.INT4_AWQ_CFG, - "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, - "int8_wo": EMPTY_CFG, - "int4_wo": EMPTY_CFG, - "full_prec": EMPTY_CFG, -} - -MODEL_NAME_PATTERN_MAP = { - "GPT2": "gpt2", - "Xverse": "llama", - "Llama": "llama", - "Mistral": "llama", - "GPTJ": "gptj", - "FalconForCausalLM": "falcon", - "RWForCausalLM": "falcon", - "baichuan": "baichuan", - "MPT": "mpt", - "Bloom": "bloom", - "ChatGLM": "chatglm", - "QWen": "qwen", - "Gemma": "gemma", -} - - -def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None): - print(f"Initializing tokenizer from {ckpt_path}") - tokenizer = AutoTokenizer.from_pretrained( - ckpt_path, - model_max_length=max_seq_len, - padding_side="left", - trust_remote_code=True, - ) - if model_type and model_type == "qwen": - # qwen use token id 151643 as pad and eos tokens - tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) - tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) - - # can't set attribute 'pad_token' for "" - if tokenizer.pad_token != "": - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - assert tokenizer.pad_token is not None, f"Pad token for {model_type} cannot be set!" 
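Context for the removed `KV_CACHE_CFG` above: it enables 8-bit quantizers on the outputs of the attention K/V projections, and when `--kv_cache_dtype fp8` was requested the removed `main()` (later in this diff) switched each entry's `num_bits` from `8` to the tuple `(4, 3)`, i.e. FP8 E4M3. The following is a self-contained sketch of that adjustment; the helper name `to_fp8_kv_cfg` is invented, and only two of the original patterns are reproduced.

```python
# Illustrative sketch, not part of the patch: reproduce the fp8 adjustment
# the removed main() applied to the KV cache quantizer config above.
import copy

KV_CACHE_CFG = {
    "*.k_proj.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
    "*.v_proj.output_quantizer": {"num_bits": 8, "axis": None, "enable": True},
}

def to_fp8_kv_cfg(cfg):
    """Return a copy with every quantizer switched from int8 to FP8 (E4M3)."""
    cfg = copy.deepcopy(cfg)
    for quantizer in cfg.values():
        quantizer["num_bits"] = (4, 3)
    return cfg

print(to_fp8_kv_cfg(KV_CACHE_CFG)["*.k_proj.output_quantizer"]["num_bits"])  # (4, 3)
```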
- - return tokenizer - - -def get_model(ckpt_path, dtype="fp16", device="cuda"): - print(f"Initializing model from {ckpt_path}") - if dtype == "bf16" or dtype == "bfloat16": - dtype = torch.bfloat16 - elif dtype == "fp16" or dtype == "float16": - dtype = torch.float16 - elif dtype == "fp32" or dtype == "float32": - dtype = torch.float32 - else: - raise NotImplementedError(f"Unknown dtype {dtype}") - - model_kwargs = {"torch_dtype": "auto"} - model = AutoModelForCausalLM.from_pretrained(ckpt_path, - device_map="auto", - **model_kwargs, - trust_remote_code=True) - model.eval() - - model_dtype = next(model.parameters()).dtype - if dtype != model_dtype: - print( - f"[TensorRT-LLM][WARNING] The manually set model data type is {dtype}, " - f"but the data type of the HuggingFace model is {model_dtype}.") - - return model - - -def get_model_type(model): - for k, v in MODEL_NAME_PATTERN_MAP.items(): - if k.lower() in type(model).__name__.lower(): - return v - return None - - -def get_calib_dataloader(data="cnn_dailymail", - tokenizer=None, - batch_size=1, - calib_size=512, - block_size=512, - device=None): - print("Loading calibration dataset") - if data == "pileval": - dataset = load_dataset( - "json", - data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", - split="train") - dataset = dataset["text"][:calib_size] - elif data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") - dataset = dataset["article"][:calib_size] - else: - raise NotImplementedError - - batch_encoded = tokenizer.batch_encode_plus(dataset, - return_tensors="pt", - padding=True, - truncation=True, - max_length=block_size) - if device: - batch_encoded = batch_encoded.to(device) - batch_encoded = batch_encoded["input_ids"] - - calib_dataloader = DataLoader(batch_encoded, - batch_size=batch_size, - shuffle=False) - - return calib_dataloader - - -def quantize_model(model, quant_cfg, calib_dataloader=None): - - def calibrate_loop(): - if calib_dataloader is None: - return - """Adjusts weights and scaling factors based on selected algorithms.""" - for idx, data in enumerate(calib_dataloader): - print(f"Calibrating batch {idx}") - model(data) - - print("Starting quantization...") - start_time = time.time() - atq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - end_time = time.time() - print("Quantization done. Total time used: {:.2f} s.".format(end_time - - start_time)) - - return model - - -def main(args): - if not torch.cuda.is_available(): - raise EnvironmentError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - model = get_model(args.model_dir, args.dtype, args.device) - model_type = get_model_type(model) - tokenizer = get_tokenizer(args.model_dir, model_type=model_type) - - if args.qformat in ["full_prec", "int8_wo", "int4_wo" - ] and args.kv_cache_dtype is None: - print(f"No quantization applied, export {args.dtype} model") - else: - if "awq" in args.qformat: - if args.calib_size > 32: - print( - f"AWQ calibration could take longer with calib_size = {args.calib_size}, Using" - " calib_size=32 instead") - args.calib_size = 32 - print( - "\nAWQ calibration could take longer than other calibration methods. Please" - " increase the batch size to speed up the calibration process. 
Batch size can be" - " set by adding the argument --batch_size to the command line.\n" - ) - - calib_dataloader = get_calib_dataloader( - tokenizer=tokenizer, - batch_size=args.batch_size, - calib_size=args.calib_size, - device=args.device, - ) - - if args.qformat in QUANT_CFG_CHOICES: - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - else: - raise ValueError(f"Unsupported quantization format: {args.qformat}") - - if "awq" in args.qformat: - quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) - weight_quantizer = quant_cfg["quant_cfg"][ - "*weight_quantizer"] # type: ignore - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = args.awq_block_size - - if args.kv_cache_dtype is not None: - if args.kv_cache_dtype == "fp8": - for value in KV_CACHE_CFG.values(): - value.update({"num_bits": (4, 3)}) # type: ignore - quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore - - print(quant_cfg) - - model = quantize_model(model, quant_cfg, calib_dataloader) - - with torch.inference_mode(): - if model_type is None: - print( - f"Unknown model type {type(model).__name__}. Continue exporting..." - ) - model_type = f"unknown:{type(model).__name__}" - - export_path = args.output_dir - start_time = time.time() - - if args.qformat == "int4_awq" and model_type == "qwen": - torch.save(model.state_dict(), export_path) - else: - export_npz = (model_type not in [ - 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan', 'gemma' - ]) - export_model_config(model, - model_type, - getattr(torch, args.dtype), - export_dir=export_path, - inference_tensor_parallel=args.tp_size, - inference_pipeline_parallel=args.pp_size, - export_tensorrt_llm_config=(not export_npz), - export_npz=export_npz) - - # Workaround for wo quantization - if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: - with open(f"{export_path}/config.json", 'r') as f: - tensorrt_llm_config = json.load(f) - if args.qformat == "int8_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' - elif args.qformat == "int4_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' - else: - tensorrt_llm_config["quantization"]["quant_algo"] = None - with open(f"{export_path}/config.json", "w") as f: - json.dump(tensorrt_llm_config, f, indent=4) - - end_time = time.time() - print( - "Quantized model exported to {} \nTotal time used {:.2f} s.".format( - export_path, end_time - start_time)) +from tensorrt_llm.quantization import quantize_and_export if __name__ == "__main__": + DEFAULT_RAND_SEED = 1234 + DEFAULT_MAX_SEQ_LEN = 2048 parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("--model_dir", help="Specify where the HuggingFace model is", @@ -343,6 +20,17 @@ def main(args): "full_prec" ], ) + parser.add_argument( + "--seed", + help="Seed the generate random numbers, the value will be used to call" + "random.seed(value) and numpy.random.seed(value)", + type=int, + default=DEFAULT_RAND_SEED) + parser.add_argument("--max_seq_length", + help="Max sequence length to init the tokenizers", + type=int, + default=DEFAULT_MAX_SEQ_LEN) + parser.add_argument("--batch_size", help="Batch size for calibration.", type=int, @@ -361,4 +49,16 @@ def main(args): choices=["int8", "fp8", None]) args = parser.parse_args() - main(args) + quantize_and_export(model_dir=args.model_dir, + dtype=args.dtype, + output_dir=args.output_dir, + device=args.device, + tp_size=args.tp_size, + pp_size=args.pp_size, + qformat=args.qformat, + kv_cache_dtype=args.kv_cache_dtype, + 
calib_size=args.calib_size, + batch_size=args.batch_size, + awq_block_size=args.awq_block_size, + seed=args.seed, + max_seq_length=args.max_seq_length) diff --git a/examples/quantization/requirements.txt b/examples/quantization/requirements.txt index ce63abe54..a2a9ceb9c 100644 --- a/examples/quantization/requirements.txt +++ b/examples/quantization/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets>=2.14.4 nemo-toolkit[all]<=1.20.0,>=1.18.0 rouge_score~=0.1.2 diff --git a/examples/qwen/build.py b/examples/qwen/build.py index 1183862e9..442676d00 100644 --- a/examples/qwen/build.py +++ b/examples/qwen/build.py @@ -236,14 +236,7 @@ def parse_arguments(): 'For FP8 PTQ, the downside is slight reduction of accuracy because one of the quantization scaling factors are discarded ' '(0.45734 vs 0.45755 for LLaMA-v2 7B using ammo/examples/hf/instruct_eval/mmlu.py).' ) - parser.add_argument( - '--dense_context_fmha', - default=False, - action='store_true', - help= - 'Enable dense fmha in context phase, otherwise sliding window attention.' - 'If dense_context_fmha=False, the sliding window size is the max attention window size.' - ) + # Arguments related to the quantization of the model. parser.add_argument( '--use_smooth_quant', @@ -541,8 +534,7 @@ def get_model_object(args, mapping, trt_dtype=None): quant_mode=args.quant_mode, rms_norm_eps=args.rms_norm_eps, use_fused_mlp=args.use_fused_mlp, - use_prompt_tuning=args.max_prompt_embedding_table_size > 0, - dense_context_fmha=args.dense_context_fmha) + use_prompt_tuning=args.max_prompt_embedding_table_size > 0) quantize_kwargs = {} if args.use_smooth_quant or args.use_weight_only: if args.weight_only_precision == 'int4_awq': diff --git a/examples/qwen/requirements.txt b/examples/qwen/requirements.txt index 3c9f325ff..d20e2764b 100644 --- a/examples/qwen/requirements.txt +++ b/examples/qwen/requirements.txt @@ -1,7 +1,8 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 -transformers~=4.36.1 transformers-stream-generator sentencepiece~=0.1.99 tiktoken diff --git a/examples/qwenvl/README.md b/examples/qwenvl/README.md index 3df2d4eb9..a517990e1 100644 --- a/examples/qwenvl/README.md +++ b/examples/qwenvl/README.md @@ -11,7 +11,7 @@ ``` The ONNX and TensorRT engine will be generated under `./onnx/visual_encoder` and `./plan/visual_encoder` respectively. 
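For readers following the visual-encoder build described above, the sketch below illustrates the general ONNX-to-TensorRT step that a script such as `vit_onnx_trt.py --only_trt` performs, using only the standard TensorRT Python API. It is a generic illustration, not the repository's implementation; the file paths and the FP16 flag are assumptions.

```python
# Generic illustration (not the repository's vit_onnx_trt.py): build a
# TensorRT engine from an existing ONNX file with the TensorRT Python API.
import tensorrt as trt

ONNX_PATH = "./onnx/visual_encoder/visual_encoder.onnx"       # placeholder path
PLAN_PATH = "./plan/visual_encoder/visual_encoder_fp16.plan"  # placeholder path

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open(ONNX_PATH, "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(parser.get_error(0))

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # assumes the visual encoder runs in fp16

engine = builder.build_serialized_network(network, config)
with open(PLAN_PATH, "wb") as f:
    f.write(engine)
```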
-- If you already have an OONX file under `./onnx/visual_encoder` and want to build a TensorRT engine with it, run: +- If you already have an ONNX file under `./onnx/visual_encoder` and want to build a TensorRT engine with it, run: ```bash python3 vit_onnx_trt.py --pretrained_model_path ./Qwen-VL-Chat --only_trt ``` diff --git a/examples/qwenvl/requirements.txt b/examples/qwenvl/requirements.txt index 36193f165..f7422ffb5 100644 --- a/examples/qwenvl/requirements.txt +++ b/examples/qwenvl/requirements.txt @@ -1,7 +1,8 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.16.0 evaluate~=0.4.1 rouge_score~=0.1.2 -transformers~=4.36.1 transformers-stream-generator sentencepiece~=0.1.99 tiktoken diff --git a/examples/qwenvl/run.py b/examples/qwenvl/run.py index 4340c662d..f7c8c000c 100644 --- a/examples/qwenvl/run.py +++ b/examples/qwenvl/run.py @@ -271,6 +271,7 @@ def generate_for_qwenvl( prompt_table=None, tasks=None, task_vocab_size=None, + num_beams=1, ): input_ids = None input_lengths = None @@ -283,9 +284,12 @@ def generate_for_qwenvl( max_input_length = torch.max(input_lengths).item() max_new_tokens = min(max_new_tokens, self.global_max_input_len - max_input_length) - self.decoder.setup(batch_size=input_lengths.size(0), - max_context_length=max_input_length, - max_new_tokens=max_new_tokens) + self.decoder.setup( + batch_size=input_lengths.size(0), + max_context_length=max_input_length, + max_new_tokens=max_new_tokens, + beam_width=num_beams, + ) profiler.start("QWen") run_time = 1 for _ in range(run_time): @@ -303,6 +307,7 @@ def qwen_infer(self, images_path, input_text, max_new_tokens, + num_beams=1, history=None): if images_path is None: content_list = [] @@ -337,7 +342,8 @@ def qwen_infer(self, input_vit, dtype, self.config.hidden_size, None, input_ids) output_ids, Qwen_time = self.generate_for_qwenvl( - input_ids, max_new_tokens, prompt_table, tasks, task_vocab_size) + input_ids, max_new_tokens, prompt_table, tasks, task_vocab_size, + num_beams) runtime_rank = tensorrt_llm.mpi_rank() input_lengths = torch.tensor([input_ids.size(1)], @@ -488,4 +494,5 @@ def vit_process(image_path, engine_dir, stream): args.images_path, args.input_text, args.max_new_tokens, + args.num_beams, history=[]) diff --git a/examples/qwenvl/run_chat.py b/examples/qwenvl/run_chat.py index 25e0dd896..37becbb24 100644 --- a/examples/qwenvl/run_chat.py +++ b/examples/qwenvl/run_chat.py @@ -118,7 +118,8 @@ def exist_cooridinate(input): run_i = run_i + 1 output_text = qinfer.qwen_infer(image_embeds, None, query, - args.max_new_tokens, history) + args.max_new_tokens, args.num_beams, + history) if args.display: if exist_cooridinate(output_text): cooridinate_extract_show(output_text, history, qinfer.tokenizer, diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt index 405599c4f..606faaeea 100644 --- a/examples/server/requirements.txt +++ b/examples/server/requirements.txt @@ -1,2 +1,4 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 uvicorn fastapi diff --git a/examples/server/test_executor.py b/examples/server/test_executor.py index 33d8af88b..da076f9ed 100644 --- a/examples/server/test_executor.py +++ b/examples/server/test_executor.py @@ -51,5 +51,5 @@ def test_sync_generation(): print("We have sent the requests: ", [id(f) for f in futures]) for future in executor.wait_first_completed(futures): print( - f"Request {id(future)} has finished: {future.wait_completion(timeout=0).text}" + f"Request {id(future)} has finished: 
{future.result(timeout=0).text}" ) diff --git a/examples/skywork/requirements.txt b/examples/skywork/requirements.txt index 72dfa8deb..626b31531 100644 --- a/examples/skywork/requirements.txt +++ b/examples/skywork/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 datasets~=2.16.1 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/examples/utils.py b/examples/utils.py index bf19c28b7..f5bee0f1a 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -19,7 +19,7 @@ from transformers import AutoTokenizer, T5Tokenizer -import tensorrt_llm +from tensorrt_llm.builder import get_engine_version # TODO(enweiz): Update for refactored models DEFAULT_HF_MODEL_DIRS = { @@ -47,7 +47,7 @@ def read_model_name(engine_dir: str): - engine_version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir) + engine_version = get_engine_version(engine_dir) with open(Path(engine_dir) / "config.json", 'r') as f: config = json.load(f) diff --git a/examples/whisper/requirements.txt b/examples/whisper/requirements.txt index cb5230668..67ec9df76 100644 --- a/examples/whisper/requirements.txt +++ b/examples/whisper/requirements.txt @@ -1,3 +1,5 @@ +--extra-index-url https://pypi.nvidia.com +tensorrt_llm==0.9.0.dev2024031200 tiktoken datasets kaldialign diff --git a/requirements-windows.txt b/requirements-windows.txt index 0955add40..4acc11967 100644 --- a/requirements-windows.txt +++ b/requirements-windows.txt @@ -11,6 +11,9 @@ onnx>=1.12.0 polygraphy psutil pynvml>=11.5.0 +pulp +pandas +h5py pywin32 sentencepiece>=0.1.99 tensorrt==9.2.0.post12.dev5 diff --git a/requirements.txt b/requirements.txt index 51f7178e6..7cd6d86e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,9 +11,12 @@ onnx>=1.12.0 polygraphy psutil pynvml>=11.5.0 +pulp +pandas +h5py sentencepiece>=0.1.99 -tensorrt==9.2.0.post12.dev5 -torch<=2.2.0a +tensorrt==9.3.0.post12.dev1 +torch>=2.1.0a,<=2.2.0a # https://github.com/pytorch/pytorch/blob/v2.1.2/version.txt still uses 2.1.0a0. 
nvidia-ammo~=0.7.0; platform_machine=="x86_64" transformers==4.38.2 wheel diff --git a/scripts/replace_version.sh b/scripts/replace_version.sh new file mode 100644 index 000000000..ead15ff7c --- /dev/null +++ b/scripts/replace_version.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +DIR="tensorrt_llm/examples" +SOURCE="tensorrt_llm==0.9.0.dev0" +TARGET="tensorrt_llm==0.9.0.dev1" + +find "$DIR" -type f -name "requirements.txt" | while read -r file; do + sed -i "s/$SOURCE/$TARGET/g" "$file" +done diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py index 63c09a398..19efed50d 100644 --- a/tensorrt_llm/__init__.py +++ b/tensorrt_llm/__init__.py @@ -41,6 +41,7 @@ def _add_trt_llm_dll_directory(): from ._utils import mpi_barrier # NOQA from ._utils import str_dtype_to_torch # NOQA from ._utils import mpi_rank, mpi_world_size, str_dtype_to_trt +from .auto_parallel import AutoParallelConfig, auto_parallel from .builder import Builder, BuilderConfig from .functional import Tensor, constant from .hlapi.llm import LLM, ModelConfig @@ -73,6 +74,8 @@ def _add_trt_llm_dll_directory(): 'Module', 'functional', 'models', + 'auto_parallel', + 'AutoParallelConfig', 'quantization', 'tools', 'LLM', diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 3360e52c9..bf6c43edb 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -20,7 +20,7 @@ import weakref from functools import partial from pathlib import Path, PosixPath -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import yaml @@ -407,3 +407,20 @@ def unpack_nemo_weights(nemo_archive_path): raise Exception(err_str) return yaml.safe_load(model_config), torch.load( model_weights, map_location=torch.device("cpu")) + + +def set_obj_attrs( + obj: torch.Tensor, + ojb_attrs: Optional[Dict[str, Any]], +): + """Set attributes on a object. + + This method is used to set attributes on a object. This method + will not overwrite existing attributes. 
+ """ + if ojb_attrs is None: + return + for key, value in ojb_attrs.items(): + assert not hasattr( + obj, key), (f"Overwriting existing tensor attribute: {key}") + setattr(obj, key, value) diff --git a/tensorrt_llm/auto_parallel/__init__.py b/tensorrt_llm/auto_parallel/__init__.py new file mode 100644 index 000000000..7dcd31c8b --- /dev/null +++ b/tensorrt_llm/auto_parallel/__init__.py @@ -0,0 +1,6 @@ +from .auto_parallel import AutoParallelConfig, auto_parallel + +__all__ = [ + 'auto_parallel', + 'AutoParallelConfig', +] diff --git a/tensorrt_llm/auto_parallel/auto_parallel.py b/tensorrt_llm/auto_parallel/auto_parallel.py new file mode 100644 index 000000000..d7b79b8dd --- /dev/null +++ b/tensorrt_llm/auto_parallel/auto_parallel.py @@ -0,0 +1,263 @@ +import gc +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import tensorrt as trt +import torch +from filelock import FileLock + +from tensorrt_llm.functional import DimRange, Tensor +from tensorrt_llm.logger import logger +from tensorrt_llm.network import Network, net_guard + +from .config import AutoParallelConfig +from .device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh +from .node_graph import NodeGraph +from .parallelization import ParallelConfig, parallelize +from .pipeline_graph import PipelineGraph +from .simplifier import GraphConfig, Simplifier, StageType +from .utils import current_flags + + +def to_network(graph: PipelineGraph, network: Network): + logger.debug("Converting graph to network") + trt_network = graph.as_trt() + trt_network.name = network.trt_network.name + new_network = Network() + new_network._init(trt_network) + new_network._dtype = network._dtype + new_network._plugin_config = network._plugin_config + new_network._unfilled_weights = graph._unfilled_weights + new_network._auto_parallel_config = graph._auto_parallel_config + with net_guard(network): + for i in range(trt_network.num_inputs): + input = trt_network.get_input(i) + tensor = Tensor(is_network_input=False) + if input.name in network._inputs: + profiles = network._inputs[input.name].profiles + elif len(network._inputs) == 0: + profiles = [] + else: + shape = input.shape + num_profiles = len(list(network._inputs.values())[0].profiles) + profile = DimRange(shape, [None] * len(shape)) + profiles = [profile] * num_profiles + tensor.profiles = profiles + tensor.trt_tensor = input + new_network._inputs[input.name] = tensor + return new_network + + +def find_solution( + node_graph: NodeGraph, + graph_config: GraphConfig, + lmesh: LogicalDeviceMesh, + memory_budget: int, + flags: list, + device: int, + dump_path: str, +) -> ParallelConfig: + torch.cuda.set_device(device) + with current_flags(*flags): + cost_graph = node_graph.get_cost_graph(lmesh) + num_stages = graph_config.num_stages + if num_stages == 1: + stage_types = [None] + elif num_stages == 2: + stage_types = [StageType.START, StageType.END] + else: + stage_types = [StageType.START, StageType.BLOCK, StageType.END] + + best_config, best_solution = None, None + for stage_type in stage_types: + if stage_type is not None: + node_graph.set_slowest_stage(stage_type, graph_config) + solution = node_graph.find_solution( + cost_graph, + memory_budget, + ) + cost = solution.total_cost + if best_config is None or cost < best_config.cost: + best_config = ParallelConfig() + best_config.graph_config = graph_config + best_config.lmesh = lmesh + best_config.cost = cost + best_config.graph_strategy = solution.node_best_strategy + best_config.stage_type = stage_type + 
best_solution = solution + if dump_path is not None: + lock = FileLock(f"{dump_path}/path.lock", thread_local=False) + vlz_name = f"{dump_path}/solution." + if graph_config.num_micro_batches != 1: + vlz_name += f"mbs{graph_config.num_micro_batches}." + if graph_config.num_stages != 1: + vlz_name += f"stages{graph_config.num_stages}." + vlz_name += lmesh.cluster_key + with lock: + node_graph.visualize_solution( + best_solution, + vlz_name, + ignore_shape_io=True, + ) + return best_config + + +def infer_builder_flags(network): + fp16_enabled = False + bf16_enabled = False + int8_enabled = False + fp8_enabled = False + + def check_dtype(tensor): + nonlocal fp16_enabled + nonlocal bf16_enabled + nonlocal int8_enabled + nonlocal fp8_enabled + if tensor.dtype == trt.DataType.HALF: + fp16_enabled = True + elif tensor.dtype == trt.DataType.BF16: + bf16_enabled = True + elif tensor.dtype == trt.DataType.INT8: + int8_enabled = True + elif tensor.dtype == trt.DataType.FP8: + fp8_enabled = True + + trt_network = network.trt_network + for i in range(trt_network.num_inputs): + input = trt_network.get_input(i) + check_dtype(input) + for i in range(trt_network.num_layers): + layer = trt_network.get_layer(i) + for j in range(layer.num_outputs): + output = layer.get_output(j) + check_dtype(output) + + builder_flags = 0 + if fp16_enabled: + builder_flags |= 1 << int(trt.BuilderFlag.FP16) + builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + if bf16_enabled: + builder_flags |= 1 << int(trt.BuilderFlag.BF16) + builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + if int8_enabled: + builder_flags |= 1 << int(trt.BuilderFlag.INT8) + if fp8_enabled: + builder_flags |= 1 << int(trt.BuilderFlag.FP8) + builder_flags |= 1 << int(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + return builder_flags + + +def auto_parallel(network: Network, config: AutoParallelConfig): + debug_mode = config.debug_mode + memory_budget = config.get_cluster_info( + ).memory_budget_per_device * 1024 * 1024 * 1024 + enable_pipeline_parallelism = config.enable_pipeline_parallelism + if config.world_size < config.gpus_per_node: + num_hosts = 1 + num_devices_per_host = config.world_size + else: + assert config.world_size % config.gpus_per_node == 0 + num_hosts = config.world_size // config.gpus_per_node + num_devices_per_host = config.gpus_per_node + parallel_config_cache = config.parallel_config_cache + dump_path = config.dump_path if debug_mode else None + fill_weights = config.fill_weights + + if num_hosts == 1 and num_devices_per_host == 1: + return [network] + + if dump_path is not None: + if not os.path.exists(dump_path): + os.makedirs(dump_path) + + builder_flags = config.builder_flags or infer_builder_flags(network) + flags = [builder_flags, network.strongly_typed] + with current_flags(*flags): + simplifier = Simplifier(network, config) + network_hash = simplifier.get_network_hash() + + best_config = None + if parallel_config_cache is not None and Path( + parallel_config_cache).exists(): + parallel_config = ParallelConfig.from_file(parallel_config_cache) + if (ParallelConfig.VERSION == parallel_config.version + and network_hash == parallel_config.network_hash + and config == parallel_config.auto_parallel_config): + logger.info( + f"use cache of parallel config from {parallel_config_cache}" + ) + best_config = parallel_config + + if best_config is None: + num_devices = num_hosts * num_devices_per_host + phy_ids = [[ + i + j * num_devices_per_host + for i in range(num_devices_per_host) + ] for j in 
range(num_hosts)] + phy_mesh = PhysicalDeviceMesh(phy_ids, config) + if enable_pipeline_parallelism: + num_micro_batches_list = simplifier.list_all_num_micro_batches() + else: + num_micro_batches_list = [1] + + jobs = [] + for num_micro_batches in num_micro_batches_list: + simplifier.infer_shapes(num_micro_batches) + if enable_pipeline_parallelism: + pipeline_configs = phy_mesh.list_all_pipeline_configs() + else: + pipeline_configs = [(1, num_devices)] + for num_stages, num_devices_per_stage in pipeline_configs: + # TODO: add fallback path that allows num_micro_batches >= num_stages + # if no solution satisfies memory budget + if num_micro_batches < num_stages: + continue + simplified_graph, graph_config = simplifier.simplify_graph( + phy_mesh, + num_stages, + num_devices_per_stage, + ) + if simplified_graph is None: + continue + node_graph = NodeGraph(simplified_graph) + node_graph.assign_cost_weights(graph_config) + lmeshes = graph_config.stage_phy_meshes[ + 0].get_logical_meshes() + for lmesh in lmeshes: + jobs.append( + (node_graph, graph_config, lmesh, memory_budget * + (num_devices / num_devices_per_stage))) + + try: + with ThreadPoolExecutor() as executor: + best_config = sorted( + executor.map( + lambda x: find_solution( + *x, + flags, + torch.cuda.current_device(), + dump_path, + ), + jobs, + ), + key=lambda x: x.cost, + )[0] + finally: + phy_mesh.close() + + if parallel_config_cache is not None: + best_config.network_hash = network_hash + best_config.auto_parallel_config = config + best_config.save(parallel_config_cache) + + new_graphs = parallelize(simplifier, best_config) + + networks = [to_network(new_graph, network) for new_graph in new_graphs] + if debug_mode and fill_weights: + networks[0]._fill_weights() + + gc.collect() + torch.cuda.empty_cache() + + return networks diff --git a/tensorrt_llm/auto_parallel/config.py b/tensorrt_llm/auto_parallel/config.py new file mode 100644 index 000000000..b54410165 --- /dev/null +++ b/tensorrt_llm/auto_parallel/config.py @@ -0,0 +1,393 @@ +import json +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional, Union + +import torch + +from .utils import BaseEnum + + +class DictConversion: + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + obj = cls() + fields = obj.__dataclass_fields__ + for key, value in config.items(): + assert hasattr(obj, key) + field_cls = fields[key].type + if (isinstance(field_cls, type) + and issubclass(field_cls, DictConversion) + and isinstance(value, dict)): + value = field_cls.from_dict(value) + setattr(obj, key, value) + return obj + + def to_dict(self): + return asdict(self) + + @classmethod + def from_json_file(cls, file): + with open(file) as f: + return cls.from_dict(json.load(f)) + + def set_defaults(self, **kwargs): + for key, default in kwargs.items(): + value = getattr(self, key) + if (value is None + or (isinstance(value, (list, dict)) and len(value) == 0)): + setattr(self, key, default) + + +@dataclass +class MathThroughput(DictConversion): + int4: int = 0 # Tflops + int8: int = 0 # Tflops + fp8: int = 0 # Tflops + float16: int = 0 # Tflops + bfloat16: int = 0 # Tflops + float32: int = 0 # Tflops + + +@dataclass +class ClusterInfo(DictConversion): + inter_node_bw_per_device: int = 25 # GBps + intra_node_bw_per_device: int = 0 # GBps + inter_node_latency: int = 10 # us + intra_node_latency: int = 10 # us + intra_node_sharp: bool = False + inter_node_sharp: bool = True + + memory_bw: int = 0 # GBps + memory_budget_per_device: int = 0 # GB + + 
math_throughput: MathThroughput = field(default_factory=MathThroughput) + + memory_efficiency: float = 1.0 + math_efficiency: float = 1.0 + communication_efficiency: float = 1.0 + + +_math_throughputs = { + "A100": MathThroughput( + int8=624, + float16=312, + bfloat16=312, + float32=156, + ), +} + +_bandwidths = { + "PCIe-3": 16, + "PCIe-4": 32, + "PCIe-5": 64, +} + +_cluster_infos = { + # from https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf + "A100-SXM-80GB": + ClusterInfo( + intra_node_bw_per_device=300, + memory_bw=2039, + memory_budget_per_device=80, + math_throughput=_math_throughputs["A100"], + ), + "A100-SXM-40GB": + ClusterInfo( + intra_node_bw_per_device=300, + memory_bw=1555, + memory_budget_per_device=40, + math_throughput=_math_throughputs["A100"], + ), + "A100-PCIe-80GB": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=1935, + memory_budget_per_device=80, + math_throughput=_math_throughputs["A100"], + ), + "A100-PCIe-40GB": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=1555, + memory_budget_per_device=40, + math_throughput=_math_throughputs["A100"], + ), + # from https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet + "H100-SXM": + ClusterInfo( + inter_node_bw_per_device=50, + intra_node_bw_per_device=450, + intra_node_sharp=True, + memory_bw=3350, + memory_budget_per_device=80, + math_throughput=MathThroughput( + int8=1979, + fp8=1979, + float16=989, + bfloat16=989, + float32=495, + ), + ), + "H100-PCIe": + ClusterInfo( + inter_node_bw_per_device=50, + intra_node_bw_per_device=_bandwidths["PCIe-5"], + memory_bw=2000, + memory_budget_per_device=80, + math_throughput=MathThroughput( + int8=1513, + fp8=1513, + float16=756, + bfloat16=756, + float32=378, + ), + ), + # from https://images.nvidia.cn/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf + "V100-PCIe-16GB": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-3"], + memory_bw=900, + memory_budget_per_device=16, + math_throughput=MathThroughput(float32=112), + ), + "V100-PCIe-32GB": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-3"], + memory_bw=900, + memory_budget_per_device=32, + math_throughput=MathThroughput(float32=112), + ), + "V100-SMX-16GB": + ClusterInfo( + intra_node_bw_per_device=150, + memory_bw=900, + memory_budget_per_device=16, + math_throughput=MathThroughput(float32=125), + ), + "V100-SMX-32GB": + ClusterInfo( + intra_node_bw_per_device=150, + memory_bw=900, + memory_budget_per_device=32, + math_throughput=MathThroughput(float32=125), + ), + "V100S-PCIe": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-3"], + memory_bw=1134, + memory_budget_per_device=32, + math_throughput=MathThroughput(float32=130), + ), + # from https://images.nvidia.cn/content/Solutions/data-center/a40/nvidia-a40-datasheet.pdf + "A40": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=696, + memory_budget_per_device=48, + math_throughput=MathThroughput( + int4=600, + int8=300, + float16=150, + bfloat16=150, + float32=75, + ), + ), + # from https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf + "A30": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=933, + memory_budget_per_device=24, + math_throughput=MathThroughput( + int4=661, + int8=330, + float16=165, + bfloat16=165, + float32=82, + ), + ), + # from 
https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/datasheet-new/nvidia-a10-datasheet.pdf + "A10": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=600, + memory_budget_per_device=24, + math_throughput=MathThroughput( + int4=500, + int8=250, + float16=125, + bfloat16=125, + float32=62.5, + ), + ), + "A10G": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=600, + memory_budget_per_device=24, + math_throughput=MathThroughput( + int4=280, + int8=140, + float16=70, + bfloat16=70, + float32=35, + ), + ), + # from https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413 + "L40S": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=864, + memory_budget_per_device=48, + math_throughput=MathThroughput( + int4=733, + int8=733, + fp8=733, + float16=362, + bfloat16=362, + float32=183, + ), + ), + # from https://images.nvidia.cn/content/Solutions/data-center/vgpu-L40-datasheet.pdf + "L40": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=864, + memory_budget_per_device=48, + math_throughput=MathThroughput( + int4=724, + int8=362, + fp8=362, + float16=181, + bfloat16=181, + float32=90, + ), + ), + # from https://nvdam.widen.net/s/rvq98gbwsw/l4-datasheet-2595652 + "L4": + ClusterInfo( + intra_node_bw_per_device=_bandwidths["PCIe-4"], + memory_bw=300, + memory_budget_per_device=24, + math_throughput=MathThroughput( + int8=242, + fp8=242, + float16=120, + bfloat16=120, + float32=60, + ), + ), +} + + +def infer_cluster_key() -> str: + + def is_sxm(): + return "SXM" in device_name + + def is_80gb(): + return "80GB" in device_name + + def is_32gb(): + return "32GB" in device_name + + device_name = torch.cuda.get_device_name(torch.cuda.current_device()) + + if "A100" in device_name: + if is_sxm(): + if is_80gb(): + return "A100-SXM-80GB" + else: + return "A100-SXM-40GB" + else: + if is_80gb(): + return "A100-PCIe-80GB" + else: + return "A100-PCIe-40GB" + elif "A10G" in device_name: + return "A10G" + elif "A10" in device_name: + return "A10" + elif "A30" in device_name: + return "A30" + elif "A40" in device_name: + return "A40" + elif "H100" in device_name: + if is_sxm(): + return "H100-SXM" + else: + return "H100-PCIe" + elif "L40S" in device_name: + return "L40S" + elif "L40" in device_name: + return "L40" + elif "L4" in device_name: + return "L4" + elif "V100S" in device_name: + return "V100S-PCIe" + elif "V100" in device_name: + if is_sxm(): + if is_32gb(): + return "V100-SXM-32GB" + else: + return "V100-SXM-16GB" + else: + if is_32gb(): + return "V100-PCIe-32GB" + else: + return "V100-PCIe-16GB" + return None + + +class CostModel(str, BaseEnum): + ALPHA_BETA = "alpha_beta" + PROFILE = "profile" + S_CURVE = "s_curve" + # Zero cost model is for test purpose. 
+ # Use zero cost model for communication will make solver prefer sharding + # Use zero cost model for computation will make solver prefer replication + ZERO = "zero" + + +@dataclass +class AutoParallelConfig(DictConversion): + # cluster configuration + world_size: int = 1 + gpus_per_node: int = 8 + cluster_key: str = None + cluster_info: Optional[ClusterInfo] = None + + # cost model configuration + sharding_cost_model: str = CostModel.ALPHA_BETA + comm_cost_model: str = CostModel.ALPHA_BETA + + # strategy configuration + enable_pipeline_parallelism: bool = False + enable_shard_unbalanced_shape: bool = False + enable_shard_dynamic_shape: bool = False + enable_reduce_scatter: bool = True + + # parallelization configuration + builder_flags: Optional[int] = None + debug_mode: bool = False + infer_shape: bool = True + validation_mode: bool = False + same_buffer_io: Dict[str, str] = field(default_factory=dict) + same_spec_io: Dict[str, str] = field(default_factory=dict) + sharded_io_allowlist: List[str] = field(default_factory=list) + fast_reduce: bool = True + fill_weights: bool = False + + # debug configuration + parallel_config_cache: Optional[str] = None + profile_cache: Optional[str] = None + dump_path: Optional[str] = None + debug_outputs: Union[List[str], str] = field(default_factory=list) + + def get_cluster_info(self) -> ClusterInfo: + return self.cluster_info or _cluster_infos[self.cluster_key] + + @property + def enabled(self) -> bool: + return self.world_size > 1 diff --git a/tensorrt_llm/auto_parallel/device_mesh.py b/tensorrt_llm/auto_parallel/device_mesh.py new file mode 100644 index 000000000..84c8d0ce2 --- /dev/null +++ b/tensorrt_llm/auto_parallel/device_mesh.py @@ -0,0 +1,612 @@ +import os +import re +from abc import ABC, abstractmethod +from typing import List + +import h5py +import numpy as np +from filelock import FileLock + +from .config import AutoParallelConfig, CostModel +from .tensor_parallel.shape_consistency import ShapeConsistencyManager + + +class ProfileDB(ABC): + """A database that stores profiling results for multiple device mesh + shapes.""" + + @abstractmethod + def query(self, cluster_key, data_key): + ... + + @abstractmethod + def update(self, cluster_key, data_key, mesh_result): + ... 
+ + def close(self): + pass + + +class MemDB(ProfileDB): + + def __init__(self): + self.data = {} + + def query(self, cluster_key, data_key): + key = (cluster_key, data_key) + mesh_result = self.data.get(key, None) + if mesh_result is None: + return None + else: + return mesh_result[0] + + def update(self, cluster_key, data_key, mesh_result): + key = (cluster_key, data_key) + self.data[key] = mesh_result + + +class Hdf5DB(ProfileDB): + + def __init__(self, name): + self.name = name + lock_name = self.name + ".lock" + self.lock = FileLock(lock_name, thread_local=False) + + def query(self, cluster_key, data_key): + file_name = f"{self.name}.hdf5" + key = str((cluster_key, data_key)) + self.lock.acquire() + mesh_result = None + with h5py.File(file_name, 'a') as f: + if key in f: + self.lock.release() + mesh_result = f[key] + return mesh_result[0] + else: + return None + + def update(self, cluster_key, data_key, mesh_result): + key = str((cluster_key, data_key)) + file_name = f"{self.name}.hdf5" + with h5py.File(file_name, 'a') as f: + f[key] = mesh_result + + def close(self): + self.lock.release(force=True) + + +class LogicalDeviceMesh(object): + + def __init__(self, + phy_mesh_shape, + mesh_shape, + phy_ids, + config: AutoParallelConfig, + alpha, + beta, + sharp, + prof_database=None, + shape_consistency_manager=None, + host_ips=None): + self.phy_mesh_shape = phy_mesh_shape + self.mesh_shape = mesh_shape + self.phy_ids = phy_ids + self.host_ips = host_ips + self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join( + [str(i) for i in mesh_shape])) + self.prof_min_max_size = [1, 2**34] + self.prof_comm_dtypes = [ + "int8", "uint8", "int32", "uint32", "int64", "uint64", "float16", + "float32", "float64", "bfloat16" + ] + self.devices_group = { + (0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1], + (1, ): [self.phy_ids, self.mesh_shape[1]], + (0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0] + } + self.prof_database = prof_database + self.shape_consistency_manager = shape_consistency_manager + self.config = config + self.cluster_info = config.get_cluster_info() + self.hw_alpha = alpha + self.hw_beta = beta + self.hw_sharp = sharp + self.algo_alpha_beta = self._estimate_algo_alpha_beta() + self.comm_op_to_nccl_test_func_name = { + 'all_reduce': 'all_reduce_perf_mpi', + 'all_gather': 'all_gather_perf_mpi', + 'all_to_all': 'alltoall_perf_mpi', + 'reduce_scatter': 'reduce_scatter_perf_mpi', + 'split': 'split', + } + + @property + def size(self) -> int: + return self.phy_ids.size + + def _estimate_algo_alpha_beta(self): + ret = {} + ar_alpha, ar_beta = {}, {} + ag_alpha, ag_beta = {}, {} + rs_alpha, rs_beta = {}, {} + a2a_alpha, a2a_beta = {}, {} + phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape + if phy_num_hosts == 1 or phy_num_devices_per_host == 1: + for dims in [(0, ), (1, ), (0, 1), (1, 0)]: + num_devices = 1 + for dim in dims: + num_devices = self.mesh_shape[dim] * num_devices + if num_devices != 1: + ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[ + 0] else self.hw_alpha[0] * num_devices / 2 / ( + num_devices - 1) + ar_beta[dims] = self.hw_beta[0] + ag_alpha[dims] = self.hw_alpha[0] * num_devices / ( + num_devices - 1) + ag_beta[dims] = self.hw_beta[0] + rs_alpha[dims] = self.hw_alpha[0] * num_devices / ( + num_devices - 1) + rs_beta[dims] = self.hw_beta[0] + a2a_alpha[dims] = self.hw_alpha[0] * num_devices / ( + num_devices - 1) + a2a_beta[dims] = self.hw_beta[0] + # phy and logical have the same mesh shape if num_hosts > 1 and 
num_devices_per_host > 1 + else: + for dims in [(0, ), (1, ), (0, 1), (1, 0)]: + num_devices = 1 + for dim in dims: + num_devices = self.mesh_shape[dim] * num_devices + if num_devices != 1: + if len(dims) == 1: + dim = dims[0] + ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[ + dim] else self.hw_alpha[dim] * num_devices / 2 / ( + num_devices - 1) + ar_beta[dims] = self.hw_beta[dim] + ag_alpha[dims] = self.hw_alpha[dim] * num_devices / ( + num_devices - 1) + ag_beta[dims] = self.hw_beta[dim] + rs_alpha[dims] = self.hw_alpha[dim] * num_devices / ( + num_devices - 1) + rs_beta[dims] = self.hw_beta[dim] + a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / ( + num_devices - 1) + a2a_beta[dims] = self.hw_beta[dim] + elif len(dims) == 2: # two level communication + num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host + inter_node_col_alpha = self.hw_alpha[ + 0] * num_devices_per_host + inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[ + 0] else inter_node_col_alpha * num_hosts / 2 / ( + num_hosts - 1) + intra_node_ar_alpha = self.hw_alpha[1] + intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[ + 1] else intra_node_ar_alpha * num_devices_per_host / 2 / ( + num_devices_per_host - 1) + ar_alpha[dims] = min(inter_node_ar_alpha, + intra_node_ar_alpha) + ar_beta[dims] = max(self.hw_beta) + ag_alpha[dims] = min( + inter_node_col_alpha * num_hosts / (num_hosts - 1), + self.hw_alpha[1] * num_devices_per_host / + (num_devices_per_host - 1)) + ag_beta[dims] = max(self.hw_beta) + rs_alpha[dims] = ag_alpha[dims] + rs_beta[dims] = ag_beta[dims] + a2a_alpha[dims] = min( + num_hosts * self.hw_alpha[0] / (num_hosts - 1), + self.hw_alpha[1] * num_hosts) + a2a_beta[dims] = max(self.hw_beta) + else: + pass + ret['all_to_all'] = [a2a_alpha, a2a_beta] + ret['all_reduce'] = [ar_alpha, ar_beta] + ret['all_gather'] = [ag_alpha, ag_beta] + ret['reduce_scatter'] = [rs_alpha, rs_beta] + ret['p2p_cross_device'] = [ + self.cluster_info.intra_node_bw_per_device, + self.cluster_info.intra_node_latency + ] + ret['p2p_cross_host'] = [ + self.cluster_info.inter_node_bw_per_device, + self.cluster_info.inter_node_latency + ] + return ret + + #[ToDo][KDuan] stub functions here + def _profile_split(self, min_max_comm_size): + comm_size, elapsed_time = [], [] + size = min_max_comm_size[0] + while size <= min_max_comm_size[1]: + time = size * 2 / self.cluster_info.memory_bw + comm_size.append(size) + elapsed_time.append(time) + size = size * 2 + return np.array([comm_size, elapsed_time]) + + def _prase_nccl_test_results(self, f_nccl_test_out_log): + '''[ToDo][KDuan] There is some dtye that may not been supported by nccl test, using default dtype (float)''' + start_parse = False + comm_size, elapsed_time = [], [] + try: + with open(f_nccl_test_out_log, 'r') as lines: + for line in lines: + if start_parse: + prof_data = re.split(r"[ ]+", line.strip()) + if len(prof_data) != 13: + continue + comm_size.append(float(prof_data[0])) + elapsed_time.append(float(prof_data[5])) + if 'GB/s' in line and 'us' in line: + start_parse = True + except Exception: + print(f'failed to parse {f_nccl_test_out_log}') + return comm_size, elapsed_time + + def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group, + func_name, step, workload_key): + + if func_name == 'split': + if 2 == step: + return self._profile_split(min_max_comm_size) + else: + return None + workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}' + os.makedirs(workspace_dir, exist_ok=True) + outfile, errfile = 
workspace_dir + '/profile.out', workspace_dir + '/profile.err' + if 1 == step: + num_nodes = len(self.host_ips) + num_gpus = self.mesh_shape[0] * self.mesh_shape[1] + ntasks_per_node = num_gpus // num_nodes + nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format( + device_group[1], func_name, min_max_comm_size[0], + min_max_comm_size[1], dtype, 2) + sbatch_command = '#!/bin/bash\n' + sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition']) + sbatch_command += '#SBATCH -A {}\n'.format(self.config['account']) + sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname']) + sbatch_command += '#SBATCH -N {}\n'.format(num_nodes) + sbatch_command += '#SBATCH -t {}\n'.format(self.config['time']) + sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format( + ntasks_per_node) + sbatch_command += '#SBATCH --exclusive\n' + sbatch_command += '#SBATCH --mem=0\n' + sbatch_command += '#SBATCH --network=sharp\n' + sbatch_command += '#SBATCH --mail-type=FAIL\n' + srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format( + num_nodes, ntasks_per_node, outfile, errfile, + self.config['container']) + command = sbatch_command + srun_command + nccl_test_command + with open(workspace_dir + '/workload.sub', 'w') as f: + f.write(command) + with open('./preprofiling_step1.sh', 'a') as f: + f.write(f'sbatch {workspace_dir}/workload.sub\n') + return None + + else: + comm_size, elapsed_time = self._prase_nccl_test_results(outfile) + if len(comm_size) < 2: + assert 0, 'the profiling for {} was failed at step1, please try again'.format( + workload_key) + else: + print(workload_key, comm_size, elapsed_time) + return np.array([comm_size, elapsed_time]) + + def _profile_single_comm_perf(self, device_group, comm_op, step, data_key): + results = {} + func_name = self.comm_op_to_nccl_test_func_name[comm_op] + for dtype in self.prof_comm_dtypes: + size_time = self._profile_with_nccl_test( + self.prof_min_max_size, dtype, device_group, func_name, step, + data_key + f'_dtype{dtype}') + results[dtype] = size_time + return results + + def profile_all_comms_perf(self, step): + if self.mesh_shape == (1, 1): + return None + mesh_results = self.prof_database.query(self.cluster_key, + self.mesh_shape) + if mesh_results: + return mesh_results + + mesh_results = {} + data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}' + for comm_op in [ + 'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter', + 'split' + ]: + comm_perf = {} + for dim, device_group in self.devices_group.items(): + # don't need to profile for mesh dim == 1 + if len(dim) == 1 and self.mesh_shape[dim[0]] == 1: + continue + + comm_perf[dim] = self._profile_single_comm_perf( + device_group, comm_op, step, data_key + + '_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim)))) + mesh_results[comm_op] = comm_perf + if 2 == step: + self.prof_database.update(self.cluster_key, self.mesh_shape, + mesh_results) + + return mesh_results + + def _model_comm_cost_from_s_curve(self, size_time_array, realsize): + assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\ + 'the comm_size: {} is not in the profile range: [{}{}]'\ + .format(realsize, size_time_array[0][0], size_time_array[0][-1]) + return np.interp(realsize, size_time_array[0], size_time_array[1]) + + def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes): + elapsed_time = 0.0 + if 'split' == 
comm_op: + elapsed_time = size_in_bytes * 2 / ( + self.cluster_info.memory_bw * + self.cluster_info.memory_efficiency) * 1e-3 + else: + dict_alpha, dict_beta = self.algo_alpha_beta[comm_op] + alpha, beta = dict_alpha[dim_key], dict_beta[dim_key] + elapsed_time = (size_in_bytes / + (alpha * self.cluster_info.communication_efficiency) + * 1e-3) + beta + return elapsed_time + + def _input_size_to_comm_size(self, comm_op, dims, input_size): + ret = input_size + if 'all_gather' == comm_op: + for dim in dims: + ret = ret * self.mesh_shape[dim] + return ret + + def estimate_comm_cost(self, comm_op, dim, input_size, dtype): + + size = self._input_size_to_comm_size(comm_op, dim, input_size) + if self.config.comm_cost_model == CostModel.S_CURVE: + mesh_perf = self.prof_database.query(self.cluster_key, + self.mesh_shape) + assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format( + self.mesh_shape) + comm_op_perf = mesh_perf.get(comm_op, None) + assert comm_op_perf is not None, '{} is not profiled'.format( + comm_op) + elapsed_time = self._model_comm_cost_from_s_curve( + comm_op_perf[tuple(dim)][dtype], size) + return elapsed_time + elif self.config.comm_cost_model == CostModel.ALPHA_BETA: + elapsed_time = self._model_comm_cost_from_alpha_beta( + comm_op, tuple(dim), size) + elif self.config.comm_cost_model == CostModel.PROFILE: + assert False, 'Unsupported profile based communication cost model now' + elif self.config.comm_cost_model == CostModel.ZERO: + elapsed_time = 0.0 + + return elapsed_time # us + + +class PhysicalDeviceMesh(object): + + def __init__(self, + phy_devices_id, + config: AutoParallelConfig, + prof_database=None, + shape_consistency_manager=None, + host_ips=None): + self.phy_devices_id = np.array(phy_devices_id) + self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape + self.host_ips = host_ips + if host_ips is None: + self.host_ips = [''] * self.num_hosts + self.config = config + self.cluster_info = config.get_cluster_info() + self.prof_database: ProfileDB = prof_database + self.shape_consistency_manager = shape_consistency_manager + if self.config.comm_cost_model not in CostModel: + raise ValueError( + f'unsupported communication cost model: {self.config.comm_cost_model}' + ) + if self.config.sharding_cost_model not in CostModel: + raise ValueError( + f'unsupported sharding cost model: {self.config.sharding_cost_model}' + ) + if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE: + if self.prof_database is None: + profile_cache = config.profile_cache + if profile_cache is None: + self.prof_database = MemDB() + else: + self.prof_database = Hdf5DB(profile_cache) + elif self.config.comm_cost_model == CostModel.ALPHA_BETA: + assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method' + assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method' + if self.config.sharding_cost_model == CostModel.ALPHA_BETA: + assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method' + + if not shape_consistency_manager: + self.shape_consistency_manager = ShapeConsistencyManager() + + @property + def size(self) -> int: + return self.phy_devices_id.size + + def close(self): + if self.prof_database is not None: + self.prof_database.close() + + def split_pipeline_meshes( + self, num_stages, + num_devices_per_stage) -> List["PhysicalDeviceMesh"]: + sub_meshes = [] + if 
num_devices_per_stage <= self.num_devices_per_host: + assert self.num_devices_per_host % num_devices_per_stage == 0, \ + "num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\ + .format(self.num_devices_per_host, num_devices_per_stage) + num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage + num_clusters = self.num_hosts * num_clusters_per_host + assert num_stages % num_clusters == 0, \ + "num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters) + for mesh_id in range(num_stages): + cluster_id = mesh_id % num_clusters + cluster_col = cluster_id % num_clusters_per_host + cluster_row = cluster_id // num_clusters_per_host + sub_devices_id = [ + self.phy_devices_id[cluster_row][cluster_col * + num_devices_per_stage:( + (cluster_col + 1) * + num_devices_per_stage)] + ] + sub_meshes.append( + PhysicalDeviceMesh(sub_devices_id, self.config, + self.prof_database, + self.shape_consistency_manager, + [self.host_ips[cluster_row]])) + else: + assert num_devices_per_stage % self.num_devices_per_host == 0, \ + "num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\ + .format(num_devices_per_stage, self.num_devices_per_host) + num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host + assert self.num_hosts % num_host_per_cluster == 0, \ + "num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster) + num_clusters = self.num_hosts // num_host_per_cluster + for mesh_id in range(num_stages): + cluster_id = mesh_id % num_clusters + cluster_row = cluster_id * num_host_per_cluster + sub_devices_id = self.phy_devices_id[cluster_row:( + cluster_row + num_host_per_cluster)] + host_ips = self.host_ips[cluster_row:(cluster_row + + num_host_per_cluster)] + sub_meshes.append( + PhysicalDeviceMesh(sub_devices_id, self.config, + self.prof_database, + self.shape_consistency_manager, + host_ips)) + return sub_meshes + + def _profile_logical_meshes(self, logical_meshes, step): + for lmesh in logical_meshes: + lmesh.profile_all_comms_perf(step) + + def as_logical_mesh(self) -> LogicalDeviceMesh: + alpha = [ + self.cluster_info.inter_node_bw_per_device, + self.cluster_info.intra_node_bw_per_device + ] + beta = [ + self.cluster_info.inter_node_latency, + self.cluster_info.intra_node_latency + ] + sharp = [ + self.cluster_info.inter_node_sharp, + self.cluster_info.intra_node_sharp + ] + return LogicalDeviceMesh( + self.phy_devices_id.shape, + self.phy_devices_id.shape, + self.phy_devices_id, + self.config, + alpha, + beta, + sharp, + self.prof_database, + self.shape_consistency_manager, + self.host_ips, + ) + + def get_logical_meshes(self): + logical_meshes = [] + # (1, 2) -> (1, 2) + # (1, 4) -> (2, 2) + # (1, 8) -> (2, 4) + # (1, 16) -> (2, 8), (4, 4) + # (1, 32) -> (2, 16), (4, 8) + # (1, 48) -> (2, 24), (3, 16), (4, 12), (6, 8) + # (1, 64) -> (2, 32), (4, 16), (8, 8) + # we will traverse logical shape's axis in sharding spec, thus (2, 8) contains (8, 2) + # we will merge logical shapes' axis, thus (2, 8) contains (1, 16) and (16, 1) + if self.num_hosts == 1: + alpha = [self.cluster_info.intra_node_bw_per_device] + beta = [self.cluster_info.intra_node_latency] + sharp = [self.cluster_info.intra_node_sharp] + for i in range(2, self.num_devices_per_host): + if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host: + lmesh_shape = (i, self.num_devices_per_host // i) + lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape) + logical_meshes.append( + LogicalDeviceMesh(self.phy_devices_id.shape, + 
lmesh_shape, lmesh_phy_ids, + self.config, alpha, beta, sharp, + self.prof_database, + self.shape_consistency_manager, + self.host_ips)) + # (8, 1) -> (2, 4) + # (16, 1) -> (2, 8), (4, 4) + elif self.num_devices_per_host == 1: + alpha = [self.cluster_info.inter_node_bw_per_device] + beta = [self.cluster_info.inter_node_latency] + sharp = [self.cluster_info.inter_node_sharp] + for i in range(2, self.num_hosts): + if self.num_hosts % i == 0 and i * i <= self.num_hosts: + lmesh_shape = (i, self.num_hosts // i) + lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape) + logical_meshes.append( + LogicalDeviceMesh(self.phy_devices_id.shape, + lmesh_phy_ids, self.config, alpha, + beta, sharp, self.prof_database, + self.shape_consistency_manager, + self.host_ips)) + # (2, 1) -> (2, 1) + # (2, 8) -> (2, 8) + # (1, 2) -> (1, 2) + # (1, 3) -> (1, 3) + # (1, 5) -> (1, 5) + if 0 == len(logical_meshes): + logical_meshes.append(self.as_logical_mesh()) + return logical_meshes + + ''' + we assume we can evenly split the pipeline and deviceMesh + ''' + + def _list_all_sub_meshes(self): + sub_meshes = [] + for num_devices_per_stage in range(1, self.num_devices_per_host + 1): + if self.num_devices_per_host % num_devices_per_stage == 0: + num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage + sub_meshes.append( + self.split_pipeline_meshes(num_stages, + num_devices_per_stage)[0]) + for num_hosts_per_stage in range(2, self.num_hosts + 1): + if self.num_hosts % num_hosts_per_stage == 0: + num_stages = self.num_hosts // num_hosts_per_stage + sub_meshes.append( + self.split_pipeline_meshes( + num_stages, + num_hosts_per_stage * self.num_devices_per_host)[0]) + return sub_meshes + + def list_all_pipeline_configs(self): + configs = [] + for num_devices_per_stage in range(1, self.num_devices_per_host + 1): + if self.num_devices_per_host % num_devices_per_stage == 0: + num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage + configs.append((num_stages, num_devices_per_stage)) + for num_hosts_per_stage in range(2, self.num_hosts + 1): + if self.num_hosts % num_hosts_per_stage == 0: + num_stages = self.num_hosts // num_hosts_per_stage + configs.append( + (num_stages, + num_hosts_per_stage * self.num_devices_per_host)) + return configs + + def profile_s_curve(self, step): + sub_phy_device_meshes = self._list_all_sub_meshes() + for phy_mesh in sub_phy_device_meshes: + lmeshes = phy_mesh.get_logical_meshes() + self._profile_logical_meshes(lmeshes, step) + if 2 == step: + self.save_profile_database() + + def profile_alpha_beta(self): + alpha = [250, 25] + beta = [100, 100] + return alpha, beta diff --git a/tensorrt_llm/auto_parallel/node_graph.py b/tensorrt_llm/auto_parallel/node_graph.py new file mode 100644 index 000000000..503925c9d --- /dev/null +++ b/tensorrt_llm/auto_parallel/node_graph.py @@ -0,0 +1,347 @@ +from typing import List + +import pandas as pd +import tensorrt as trt + +from .pipeline_graph import PipelineGraph +from .runtime_profiling import RuntimeProfiler +from .simplifier import GraphConfig, StageType +from .solver import CostGraph, Solver +from .tensor_parallel.activation_node import Activation +from .tensor_parallel.assertion_node import Assertion +from .tensor_parallel.cast_node import Cast +from .tensor_parallel.concatenation_node import Concatenation +from .tensor_parallel.constant_node import Constant +from .tensor_parallel.elementwise_node import ElementWise +from .tensor_parallel.fill_node import Fill +from .tensor_parallel.gather_node 
import Gather +from .tensor_parallel.identity_node import Identity +from .tensor_parallel.input_node import InputNode +from .tensor_parallel.matmul_node import MatrixMultiply +from .tensor_parallel.node import Node +from .tensor_parallel.normalization_node import Normalization +from .tensor_parallel.output_node import OuputNode +from .tensor_parallel.p2p_node import P2PNode, P2PType +from .tensor_parallel.plugin_node import PluginNode +from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin +from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin +from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin +from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin +from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin, + RMSnormPlugin) +from .tensor_parallel.reduce_node import Reduce +from .tensor_parallel.select_node import Select +from .tensor_parallel.shape_node import Shape +from .tensor_parallel.shuffle_node import Shuffle +from .tensor_parallel.slice_node import Slice +from .tensor_parallel.softmax_node import SoftMax +from .tensor_parallel.unary_node import Unary + +LAYER_TYPE_2_NODE_TYPE = { + trt.LayerType.ACTIVATION: Activation, + trt.LayerType.ASSERTION: Assertion, + trt.LayerType.CAST: Cast, + trt.LayerType.CONCATENATION: Concatenation, + trt.LayerType.CONSTANT: Constant, + trt.LayerType.ELEMENTWISE: ElementWise, + trt.LayerType.FILL: Fill, + trt.LayerType.GATHER: Gather, + trt.LayerType.IDENTITY: Identity, + trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply, + trt.LayerType.NORMALIZATION: Normalization, + trt.LayerType.PLUGIN_V2: PluginNode, + trt.LayerType.REDUCE: Reduce, + trt.LayerType.SELECT: Select, + trt.LayerType.SHAPE: Shape, + trt.LayerType.SHUFFLE: Shuffle, + trt.LayerType.SLICE: Slice, + trt.LayerType.SOFTMAX: SoftMax, + trt.LayerType.UNARY: Unary, +} +# TODO: BertAttention/All Quant plugins +PLUGIN_LAYER_TYPE_2_NODE_TYPE = { + 'GPTAttention': GPTAttentionPlugin, + 'Gemm': GemmPlugin, + 'Layernorm': LayernormPlugin, + 'Rmsnorm': RMSnormPlugin, + 'Lookup': LookupPlugin, + 'Identity': IdentityPlugin, +} + + +class NodeGraph: + + def __init__(self, graph: PipelineGraph): + self._nodes = {} + + # construct nodes + for input in graph.inputs: + self._nodes[input.name] = InputNode(input) + for layer in graph.layers: + layer.to_base_class() + if "p2p_type" in layer.attrs: + self._nodes[layer.name] = P2PNode(layer) + elif layer.type == trt.LayerType.PLUGIN_V2: + layer.to_subclass() + plugin_type = layer.as_trt().plugin.plugin_type + layer.to_base_class() + if plugin_type in PLUGIN_LAYER_TYPE_2_NODE_TYPE: + node = PLUGIN_LAYER_TYPE_2_NODE_TYPE[plugin_type](layer) + else: + node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer) + self._nodes[layer.name] = node + else: + node = LAYER_TYPE_2_NODE_TYPE[layer.type](layer) + self._nodes[layer.name] = node + for output in graph.outputs: + self._nodes[output.name] = OuputNode(output) + for node in self.nodes: + node.post_init(self) + node.node_runtime_profiler = RuntimeProfiler() + + def get_node(self, name): + return self._nodes[name] + + @property + def nodes(self) -> List[Node]: + return [*self._nodes.values()] + + def assign_cost_weights(self, graph_config: GraphConfig): + layer_mapping = graph_config.graph_mapping.layer_mapping + for layer_name in layer_mapping.values(): + node = self.get_node(layer_name) + node.sharding_weight += 1 + node.resharding_weight += 1 + same_spec_layer_mapping = graph_config.graph_mapping.same_spec_layer_mapping + for 
same_spec_layer_name, layer_name in same_spec_layer_mapping.items(): + node = self.get_node(layer_name) + same_spec_node = self.get_node(same_spec_layer_name) + same_spec_node.sharding_weight = node.sharding_weight + same_spec_node.resharding_weight = node.resharding_weight + + def set_slowest_stage(self, stage_type: StageType, + graph_config: GraphConfig): + num_micro_batches = graph_config.num_micro_batches + block_per_stage = graph_config.num_blocks // graph_config.num_stages + block_pipeline_weight = block_per_stage * (num_micro_batches - 1) + for node in self.nodes: + node.pipeline_weight = 0 + node.cost_level = -1 + if node.stage_type == StageType.START: + if stage_type == StageType.START: + node.pipeline_weight = num_micro_batches - 1 + node.cost_level = 1 + else: + node.cost_level = 0 + if stage_type == StageType.START and node.in_start_block: + node.pipeline_weight = block_pipeline_weight + if node.stage_type == StageType.END: + if stage_type == StageType.END: + node.pipeline_weight = num_micro_batches - 1 + node.cost_level = 1 + else: + node.cost_level = 0 + if stage_type == StageType.END and node.in_end_block: + node.pipeline_weight = block_pipeline_weight + if isinstance(node, P2PNode): + if (graph_config.has_cross_host + and node.p2p_type == P2PType.CROSS_HOST) or ( + not graph_config.has_cross_host + and node.p2p_type == P2PType.CROSS_DEVICE): + if stage_type == StageType.BLOCK: + node.pipeline_weight += num_micro_batches - 1 + node.cost_level = 1 + else: + node.cost_level = 0 + elif (graph_config.has_cross_device + and node.p2p_type == P2PType.CROSS_DEVICE) or ( + not graph_config.has_cross_device + and node.p2p_type == P2PType.CROSS_HOST): + node.pipeline_weight += num_micro_batches - 1 + if stage_type == StageType.BLOCK and node.in_slowest_block: + node.pipeline_weight = block_pipeline_weight + + def get_cost_graph(self, lmesh): + leaf_strategies = [] + for node in self.nodes: + if node.is_replicated: + node.set_strategy(None, lmesh) + else: + node.collect_strategies(lmesh) + for node in self.nodes: + strategies_vector = node.update_resharding_cost() + if len(strategies_vector) != 0: + leaf_strategies.append(strategies_vector) + cost_graph = CostGraph(leaf_strategies) + return cost_graph + + def find_solution(self, cost_graph, memory_budget): + solver = Solver(cost_graph, memory_budget=memory_budget) + solution = solver.find_solution()[1] + + graph_strategy = solution.node_best_strategy + for node_name, strategy in graph_strategy.items(): + node = self._nodes[node_name] + for idx, pre_node in enumerate(node.predecessor_nodes): + if pre_node is None: + continue + if pre_node.node_name not in strategy.best_resharding_cost: + continue + strategy.best_resharding_cost[ + idx] = strategy.best_resharding_cost[pre_node.node_name] + strategy.node_names[idx] = pre_node.node_name + for key in list(strategy.best_resharding_cost.keys()): + if isinstance(key, str): + del strategy.best_resharding_cost[key] + + return solution + + def visualize(self, name='pp_graph'): + with open(name + '.dot', 'w') as f: + f.write("digraph {\n") + ''' + f.write(" // Value Nodes\n") + for name, tensor in self._tensors.items(): + f.write(" \"{}\" [fillcolor = \"green\", label = \"{}\", shape = \"box\", style = \"filled\"];\n".format(name, tensor.shape)) + ''' + f.write(" // Operation Nodes\n") + for name, node in self._nodes.items(): + fillcolor = 'white' + if 'MATRIX_MULTIPLY' in name: + fillcolor = 'green' + label = name + if len(node.outputs) > 0: + label = name + '\\n' + str(node.outputs[0].shape) + 
f.write( + " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n" + .format(name, fillcolor, label)) + f.write(" // Edges\n") + for name, node in self._nodes.items(): + for successor_node in node.successor_nodes: + if successor_node: + f.write(" \"{}\" ->\"{}\";\n".format( + name, successor_node.node_name)) + f.write(" }\n") + + def visualize_solution(self, + solution, + fname='pp_graph_solution', + ignore_shape_io=True): + with open(fname + '.dot', 'w') as f: + names, costs, block_ids = [], [], [] + f.write("digraph {\n") + f.write(" // Operation Nodes\n") + for name, node in self._nodes.items(): + if ignore_shape_io and node.layer is not None and node.layer.is_shape_io: + continue + cost = 0.0 + fillcolor = 'white' + if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name: + fillcolor = 'orange' + elif '_same_spec' in name: + fillcolor = 'gray' + elif 'p2p_block' in name: + fillcolor = 'blue' + elif 'PLUGIN' in name: + fillcolor = 'yellow' + + shape = 'box' + if 'output_node' == node.node_type or 'input_node' == node.node_type: + shape = 'ellipse' + fillcolor = 'green' + + label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}' + if len(node.inputs) > 0: + for idx, input in enumerate(node.inputs): + if not input: + continue + label = label + f'\\ninput{idx}_' + str( + input.shape) + f'_{input.dtype_str_size[0]}_' + if node.node_name in solution.node_best_strategy: + best_strategy = solution.node_best_strategy[ + node.node_name] + shard_seq = str( + best_strategy.sharding_specs[f'input{idx}']. + sharding_sequence) + label = label + shard_seq + if idx not in best_strategy.best_resharding_cost: + continue + rcosts = best_strategy.best_resharding_cost[idx][0] + comm_action_sequence, resharding_cost = rcosts[ + 1], rcosts[2] + if len(comm_action_sequence) > 0: + label = label + '|' + for commspec in comm_action_sequence: + comm = [ + commspec.comm_pattern, commspec.gather_dim, + commspec.shard_dim, + commspec.logical_process_axis + ] + label = label + '->' + str(comm) + if resharding_cost > 0: + label = label + '_rcost{:.2}'.format( + resharding_cost) + cost = cost + resharding_cost + if len(node.outputs) > 0: + best_strategy = None + for idx, output in enumerate(node.outputs): + label = label + f'\\noutput{idx}_' + str( + output.shape) + f'_{output.dtype_str_size[0]}' + if node.node_name in solution.node_best_strategy: + best_strategy = solution.node_best_strategy[ + node.node_name] + shard_seq = str( + best_strategy.sharding_specs[f'output{idx}']. 
+ sharding_sequence) + comm = None + if f'output{idx}' in best_strategy.communication_actions: + commspec = best_strategy.communication_actions[ + f'output{idx}'] + comm = [ + commspec.comm_pattern, commspec.gather_dim, + commspec.shard_dim, + commspec.logical_process_axis + ] + label = label + '_' + shard_seq + if comm: + label = label + f' | {comm}' + if best_strategy: + cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost + label = label + '| scost{:.2}'.format( + best_strategy.sharding_cost) + if best_strategy.communication_cost > 0: + label = label + ' | ccost{:.2}'.format( + best_strategy.communication_cost) + names.append(name) + costs.append(cost) + block_ids.append([ + node.building_block_id, node.cost_level, + node.sharding_weight + node.pipeline_weight, + node.same_spec_id + ]) + f.write( + " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n" + .format(name, fillcolor, label, shape)) + f.write(" // Edges\n") + for name, node in self._nodes.items(): + if ignore_shape_io and node.layer is not None and node.layer.is_shape_io: + continue + for successor_node in node.successor_nodes: + if successor_node: + if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io: + continue + f.write(" \"{}\" ->\"{}\";\n".format( + name, successor_node.node_name)) + f.write(" }\n") + df = pd.DataFrame.from_dict({ + 'node': + names, + 'cost': + costs, + 'block_id': [block[0] for block in block_ids], + 'cost_level': [block[1] for block in block_ids], + 'sharding_weight': [block[2] for block in block_ids], + 'same_spec_id': [block[3] for block in block_ids] + }) + df['weight_cost'] = df['sharding_weight'] * df['cost'] + df.to_csv(fname + '.csv') diff --git a/tensorrt_llm/auto_parallel/parallelization.py b/tensorrt_llm/auto_parallel/parallelization.py new file mode 100644 index 000000000..cc598befa --- /dev/null +++ b/tensorrt_llm/auto_parallel/parallelization.py @@ -0,0 +1,2311 @@ +import contextlib +import copy +import itertools +import pickle +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, ClassVar, Dict, List, Sequence, Set, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +from filelock import FileLock + +from tensorrt_llm._utils import trt_dtype_to_np, trt_dtype_to_torch +from tensorrt_llm.functional import AllReduceStrategy +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.network import (PluginInfo, delete_plugin_info, get_np_weight, + get_plugin_info, set_plugin_info) +from tensorrt_llm.plugin import (TRT_LLM_PLUGIN_NAMESPACE, + current_all_reduce_helper, + init_all_reduce_helper) +from tensorrt_llm.plugin.plugin import CustomAllReduceHelper +from tensorrt_llm.version import __version__ + +from .config import AutoParallelConfig +from .device_mesh import LogicalDeviceMesh +from .pipeline_graph import Layer, PipelineGraph, Tensor +from .shape_info import (ShapeInfo, get_per_layer_graph, get_shape_layers, + infer_per_layer_shapes) +from .simplifier import GraphConfig, GraphMapping, Simplifier, StageType +from .tensor_parallel.comm_spec import CommSpec +from .tensor_parallel.plugin_nodes.gpt_attention_node import ( + GPTAttentionPlugin, IdxEntry, IdxEntryParser) +from .tensor_parallel.sharding_spec import ShardingSpec, get_sharding_sequence +from .tensor_parallel.sharding_strategy import ShardingStrategy +from .utils import (get_builder_flags, get_updated_plugin, 
to_base_class_layer, + to_subclass_layer, to_trt_weights) + + +@dataclass +class ParallelConfig: + VERSION: ClassVar[str] = __version__ + + version: str = VERSION + network_hash: str = None + auto_parallel_config: AutoParallelConfig = None + graph_config: GraphConfig = None + lmesh: LogicalDeviceMesh = None + cost: float = None + graph_strategy: Dict[str, ShardingStrategy] = None + stage_type: StageType = None + + def save(self, filename): + with open(filename, 'wb') as file: + pickle.dump(self, file) + + @staticmethod + def from_file(filename) -> "ParallelConfig": + with open(filename, "rb") as file: + return pickle.load(file) + + def print_graph_strategy(self, file=None): + for index, (node_name, + strategy) in enumerate(self.graph_strategy.items()): + print(f'\n[{index}]: node_name = {node_name}', file=file) + strategy.print_strategy(best_resharding_cost_only=True, file=file) + + +def desimplify_strategy( + graph: PipelineGraph, + graph_strategy: Dict[str, ShardingStrategy], + graph_mapping: GraphMapping, +): + for strategy in graph_strategy.values(): + for name, commspec in list(strategy.communication_actions.items()): + strategy.communication_actions[name] = [commspec] + strategy.sharding_specs[ + f"{name}_after_comm"] = strategy.sharding_specs[name] + + # insert same spec layers' communication actions after + # its producer's communication actions + same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping + for same_spec_layer_name in same_spec_layer_mapping.keys(): + same_spec_strategy = graph_strategy[same_spec_layer_name] + same_spec_commspecs = same_spec_strategy.best_resharding_cost[0][0][1] + if len(same_spec_commspecs) == 0: + continue + output_name = same_spec_layer_name[:-len("_same_spec")] + output = graph.get_tensor(output_name) + layer_name = output.producer.name + output_index = output.output_index + strategy = graph_strategy[layer_name] + commspecs = strategy.communication_actions.get(f"output{output_index}", + []) + commspecs.extend(same_spec_commspecs) + strategy.communication_actions[f"output{output_index}"] = commspecs + strategy.sharding_specs[ + f"output{output_index}_after_comm"] = same_spec_strategy.sharding_specs[ + "output0"] + + layer_mapping = graph_mapping.layer_mapping + for removed_layer_name, layer_name in layer_mapping.items(): + if layer_name in graph_strategy: + strategy = copy.copy(graph_strategy[layer_name]) + layer = graph.get_layer(removed_layer_name) + if layer is not None: + strategy.node_names = strategy.node_names.copy() + for index, name in list(strategy.node_names.items()): + input = layer.get_input(index) + node_name = input.name if input.producer is None else input.producer.name + strategy.node_names[index] = node_name + graph_strategy[removed_layer_name] = strategy + + +@dataclass +class SplitInfo: + input_dim: Union[int, trt.ITensor] + partition: int + + def __deepcopy__(self, memo) -> "SplitInfo": + return SplitInfo(self.input_dim, self.partition) + + +@dataclass +class TensorInfo: + name: str = None + split_infos: Dict[int, SplitInfo] = field(default_factory=dict) + + def set_split_info(self, dim, split_info): + self.split_infos[dim] = split_info + + def __deepcopy__(self, memo) -> "TensorInfo": + return TensorInfo(self.name, copy.deepcopy(self.split_infos)) + + +@dataclass +class TensorContext: + info_by_device: Dict[int, TensorInfo] = field(default_factory=dict) + device_dims_for_shape: Set[int] = field(default_factory=set) + + def update_name_mapping(self, device_id, new_name): + if device_id not in self.info_by_device: + 
self.info_by_device[device_id] = TensorInfo() + self.info_by_device[device_id].name = new_name + + def set_split_info(self, device_id, dim, split_info): + if device_id not in self.info_by_device: + self.info_by_device[device_id] = TensorInfo() + self.info_by_device[device_id].set_split_info(dim, split_info) + + def set_split_infos(self, device_id, split_infos: Dict[int, SplitInfo]): + if device_id not in self.info_by_device: + self.info_by_device[device_id] = TensorInfo() + self.info_by_device[device_id].split_infos = split_infos + + def __deepcopy__(self, memo) -> "TensorContext": + return TensorContext(copy.deepcopy(self.info_by_device), + set(self.device_dims_for_shape)) + + +@dataclass +class LayerUpdate: + updated_attrs: Dict[str, Any] = field(default_factory=dict) + updated_inputs: Dict[int, trt.ITensor] = field(default_factory=dict) + split_info_updated: bool = False + + @staticmethod + def none() -> "LayerUpdate": + return LayerUpdate() + + +@dataclass +class GraphContext: + tensor_contexts: Dict[str, TensorContext] = field(default_factory=dict) + + def get_name(self, tensor_name, device_id): + if tensor_name not in self.tensor_contexts: + return None + if device_id not in self.tensor_contexts[tensor_name].info_by_device: + return None + return self.tensor_contexts[tensor_name].info_by_device[device_id].name + + def update_name_mapping(self, tensor_name, device_id, new_name): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].update_name_mapping( + device_id, new_name) + + def get_name_mapping(self, device_id, prefix: str) -> Dict[str, str]: + name_mapping = {} + for tensor_name in self.tensor_contexts.keys(): + new_name = self.get_name(tensor_name, device_id) + if new_name is not None: + name_mapping[f"{prefix}{tensor_name}"] = new_name + return name_mapping + + def add_device_dims_for_shape(self, tensor_name: str, + device_dims: Sequence[int]): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].device_dims_for_shape.update( + device_dims) + + def get_device_dims_for_shape(self, tensor_name: str): + if tensor_name not in self.tensor_contexts: + return set() + return self.tensor_contexts[tensor_name].device_dims_for_shape + + def get_split_infos(self, tensor_name, device_id): + if tensor_name not in self.tensor_contexts: + return None + if device_id not in self.tensor_contexts[tensor_name].info_by_device: + return None + return self.tensor_contexts[tensor_name].info_by_device[ + device_id].split_infos + + def set_split_info(self, tensor_name, device_id, dim, split_info): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].set_split_info( + device_id, dim, split_info) + + def set_split_infos(self, tensor_name, device_id, + split_infos: Dict[int, SplitInfo]): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].set_split_infos( + device_id, split_infos) + + def update_layer_context(self, wrapped_layer: Layer, + layer_update: LayerUpdate, + local_context: "GraphContext", device_id: int, + device_ids: np.ndarray, + sharding_specs: Dict[str, ShardingSpec]): + layer = wrapped_layer.as_trt() + for i in range(layer.num_outputs): + output = layer.get_output(i) + new_name = local_context.get_name(output.name, device_id) + if new_name is not None: + 
self.update_name_mapping(output.name, device_id, new_name) + if layer_update.split_info_updated: + for i in range(layer.num_outputs): + output = layer.get_output(i) + split_infos = local_context.get_split_infos( + output.name, device_id) + if split_infos is not None: + self.set_split_infos(output.name, device_id, split_infos) + return + split_info_by_device_dim = {} + for i in range(layer.num_inputs): + input = layer.get_input(i) + if input is None: + continue + sharding_spec = sharding_specs[f"input{i}"] + split_infos = local_context.get_split_infos(input.name, device_id) + if split_infos is None: + continue + for dim, split_info in split_infos.items(): + device_dim = tuple(sharding_spec.dim_partition_dict[dim]) + split_info_by_device_dim[device_dim] = split_info + for i in range(layer.num_outputs): + output = layer.get_output(i) + sharding_spec = sharding_specs[f"output{i}"] + for dim, device_dim in sharding_spec.dim_partition_dict.items(): + split_info = split_info_by_device_dim.get(tuple(device_dim)) + if split_info is None: + if device_dim == [0, 1] or device_dim == [1, 0]: + if (0, ) in split_info_by_device_dim and ( + 1, ) in split_info_by_device_dim: + split_info = SplitInfo( + split_info_by_device_dim[(0, )].input_dim * + split_info_by_device_dim[(1, )].input_dim, + split_info_by_device_dim[(0, )].partition * + split_info_by_device_dim[(1, )].partition, + ) + assert split_info is not None + partition = get_partition(device_dim, device_ids) + if split_info.input_dim != output.shape[dim]: + assert output.shape[ + dim] > 0 and output.shape[dim] % partition == 0 + output_split_info = SplitInfo(output.shape[dim], partition) + self.set_split_info(output.name, device_id, dim, + output_split_info) + + def get_local_context(self, layer: trt.ILayer) -> "GraphContext": + local_context = GraphContext() + for i in range(layer.num_inputs): + input = layer.get_input(i) + if input is None: + continue + local_context.tensor_contexts[input.name] = copy.deepcopy( + self.tensor_contexts[input.name]) + return local_context + + def get_local_context_for_output(self, + output: trt.ITensor) -> "GraphContext": + local_context = GraphContext() + local_context.tensor_contexts[output.name] = copy.deepcopy( + self.tensor_contexts[output.name]) + return local_context + + def merge_context(self, context: "GraphContext"): + self.tensor_contexts.update(context.tensor_contexts) + + +@dataclass +class ShardContext: + graph_context: GraphContext + layer: Layer + nditer: np.nditer + device_ids: np.ndarray + strategy: ShardingStrategy + + +def get_partition(device_dim, device_ids): + if device_dim == [0]: + partition = device_ids.shape[0] + elif device_dim == [1]: + partition = device_ids.shape[1] + else: + assert device_dim == [0, 1] or device_dim == [1, 0] + partition = device_ids.size + return partition + + +def get_index(device_dim, iter): + if device_dim == [0]: + index = iter.multi_index[0] + elif device_dim == [1]: + index = iter.multi_index[1] + else: + assert device_dim == [0, 1] or device_dim == [1, 0] + index = iter.iterindex + return index + + +def get_full_sharding_spec(sharding_spec): + return ShardingSpec(sharding_spec.device_mesh, + sharding_spec.data_type_size, + sharding_spec.entire_shape, + sharding_spec.max_entire_shape, + sharding_spec.raw_shape, + dim_partition_dict={}) + + +def get_comm_action_sequence(from_sharding_sepc, to_sharding_sepc): + comm_action_sequence = from_sharding_sepc.device_mesh.shape_consistency_manager.shape_consistency( + from_sharding_sepc, to_sharding_sepc)[1] + # TODO: 
should merged by shape_consistency + if len(comm_action_sequence) == 2: + if comm_action_sequence[0].comm_pattern == comm_action_sequence[ + 1].comm_pattern == "all_gather": + if comm_action_sequence[0].gather_dim == comm_action_sequence[ + 1].gather_dim: + comm_action_sequence = [ + CommSpec( + comm_action_sequence[0].comm_pattern, + comm_action_sequence[0].sharding_spec, + comm_action_sequence[0].gather_dim, + comm_action_sequence[0].shard_dim, [[ + *comm_action_sequence[0].logical_process_axis[0], + *comm_action_sequence[1].logical_process_axis[0] + ]], comm_action_sequence[0].mix_gather, + comm_action_sequence[0].forward_only) + ] + assert len(comm_action_sequence[0].logical_process_axis[0]) <= 2 + assert len(comm_action_sequence) <= 1 + return comm_action_sequence + + +class GraphGroup(ABC): + + @staticmethod + def from_graph( + graph: PipelineGraph, + config: ParallelConfig, + auto_parallel_config: AutoParallelConfig, + ) -> "GraphGroup": + if auto_parallel_config.debug_mode: + return PrefixedGraphGroup(graph, config, auto_parallel_config) + else: + return DistributedGraphGroup(graph, config, auto_parallel_config) + + @property + @abstractmethod + def auto_parallel_config(self) -> AutoParallelConfig: + ... + + @abstractmethod + def add_input(self, tensor, device_ids, strategy: ShardingStrategy): + ... + + @abstractmethod + def add_layer(self, layer, device_ids, strategy: ShardingStrategy): + ... + + @abstractmethod + def add_output(self, tensor, device_ids, sharding_spec: ShardingSpec): + ... + + @abstractmethod + def get_network(self, device_id) -> trt.INetworkDefinition: + ... + + @abstractmethod + def get_graph(self, device_id) -> PipelineGraph: + ... + + @property + @abstractmethod + def full_graph(self) -> PipelineGraph: + ... + + @abstractmethod + def get_prefix(self, device_id) -> str: + ... + + @abstractmethod + def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: + ... + + @abstractmethod + def get_values(self, device_id) -> Dict[str, List[int]]: + ... + + @abstractmethod + def add_all_reduce_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_reduce_tensors): + ... + + @abstractmethod + def add_all_gather_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_gather_tensors): + ... + + @abstractmethod + def register_layer(self, + layer, + base_name, + input_name, + output_name=None, + device_id=None, + keep_tensor_name=False) -> Layer: + ... 
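+    # The concrete helpers below are shared by every GraphGroup implementation
+    # (the debug-mode PrefixedGraphGroup and the DistributedGraphGroup created
+    # by from_graph). add_comm() first drops logical process axes whose
+    # partition size is 1, then applies the collective over the full mesh or
+    # once per mesh row/column, depending on which axes remain; _add_comm()
+    # dispatches each CommSpec to the matching primitive (add_split,
+    # add_all_gather, add_all_reduce, add_reduce_scatter, or add_all_to_all).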
+ + def get_tensor(self, context: GraphContext, tensor_name: str, + device_id: int) -> Tensor: + name = context.get_name(tensor_name, device_id) + return self.get_graph(device_id).get_tensor(name) + + def add_comm(self, + context: GraphContext, + input_name, + device_ids, + commspec, + output_name=None, + is_singleton=False): + remove_index = [] + for i, device_dim in enumerate(commspec.logical_process_axis): + partition = get_partition(device_dim, device_ids) + if partition == 1: + remove_index.append(i) + if len(remove_index) > 0: + if commspec.comm_pattern in ["all_gather", "all_to_all"]: + commspec.gather_dim = [ + dim for i, dim in enumerate(commspec.gather_dim) + if i not in remove_index + ] + if commspec.comm_pattern in [ + "split", "reduce_scatter", "all_to_all" + ]: + commspec.shard_dim = [ + dim for i, dim in enumerate(commspec.shard_dim) + if i not in remove_index + ] + commspec.logical_process_axis = [ + dim for i, dim in enumerate(commspec.logical_process_axis) + if i not in remove_index + ] + flatten_device_dim = list( + itertools.chain.from_iterable(commspec.logical_process_axis)) + if flatten_device_dim == []: + return + if flatten_device_dim == [0, 1] or flatten_device_dim == [1, 0]: + self._add_comm(context, input_name, device_ids, commspec, + output_name, is_singleton) + elif flatten_device_dim == [0]: + for i in range(device_ids.shape[1]): + self._add_comm(context, input_name, device_ids[:, i:i + 1], + commspec, output_name, is_singleton) + elif flatten_device_dim == [1]: + for i in range(device_ids.shape[0]): + self._add_comm(context, input_name, device_ids[i:i + 1, :], + commspec, output_name, is_singleton) + else: + raise RuntimeError( + f"Invalid flatten device_dim: {flatten_device_dim}") + + def _add_comm(self, + context: GraphContext, + input_name, + device_ids, + commspec, + output_name=None, + is_singleton=False): + comm_pattern = commspec.comm_pattern + if comm_pattern == "split": + self.add_split(context, input_name, output_name, device_ids, + commspec.shard_dim, commspec.logical_process_axis) + elif comm_pattern == "all_gather": + self.add_all_gather(context, input_name, output_name, device_ids, + commspec.gather_dim, + commspec.logical_process_axis, is_singleton) + elif comm_pattern == "all_reduce": + self.add_all_reduce(context, input_name, output_name, device_ids) + elif comm_pattern == "reduce_scatter": + self.add_reduce_scatter(context, input_name, output_name, + device_ids, commspec.shard_dim, + commspec.logical_process_axis) + elif comm_pattern == "all_to_all": + self.add_all_to_all(context, input_name, output_name, device_ids, + commspec.gather_dim, commspec.shard_dim, + commspec.logical_process_axis) + else: + raise NotImplementedError + + def add_all_reduce(self, context: GraphContext, input_name, output_name, + device_ids): + builder_flags = get_builder_flags() + if builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0: + dtype = trt.DataType.HALF + elif builder_flags & (1 << int(trt.BuilderFlag.BF16)) != 0: + dtype = trt.DataType.BF16 + else: + dtype = trt.DataType.FLOAT + fast_reduce = self.auto_parallel_config.fast_reduce + if fast_reduce: + logger.debug(f"all_reduce with {dtype} after {input_name}") + + to_reduce_tensors = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + input_dtype = input_tensor.dtype + if fast_reduce: + to_reduce_tensor = 
self.cast( + network, + input_tensor, + dtype, + layer_info, + ) + else: + to_reduce_tensor = input_tensor + to_reduce_tensors.append(to_reduce_tensor) + self.add_all_reduce_layer(context, input_name, output_name, device_ids, + to_reduce_tensors) + if fast_reduce and input_dtype != dtype: + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor( + context, + input_name, + device_id, + ).as_trt() + output_tensor = self.cast( + network, + input_tensor, + input_dtype, + layer_info, + ) + context.update_name_mapping( + input_name, + device_id, + output_tensor.name, + ) + + def add_reduce_scatter(self, context: GraphContext, input_name, output_name, + device_ids, shard_dims, device_dims): + self.add_all_reduce(context, input_name, output_name, device_ids) + self.add_split(context, input_name, output_name, device_ids, shard_dims, + device_dims) + + # TODO: use native all_to_all operation + def add_all_to_all(self, context: GraphContext, input_name, output_name, + device_ids, gather_dims, shard_dims, device_dims): + self.add_all_gather(context, input_name, output_name, device_ids, + gather_dims, device_dims) + self.add_split(context, input_name, output_name, device_ids, shard_dims, + device_dims) + + def get_item(self, network, tensor, index, layer_info): + get_item_layer = network.add_slice(tensor, [index], [1], [1]) + self.register_layer(get_item_layer, f"get_item{index}", *layer_info) + return get_item_layer.get_output(0) + + def get_shape(self, network, tensor, layer_info): + shape_layer = network.add_shape(tensor) + self.register_layer(shape_layer, "shape", *layer_info) + return shape_layer.get_output(0) + + def concat(self, network, tensors, layer_info): + concat_layer = network.add_concatenation(tensors) + self.register_layer(concat_layer, "concat", *layer_info) + return concat_layer.get_output(0) + + def flatten(self, network, tensor, layer_info): + shuffle_layer = network.add_shuffle(tensor) + shuffle_layer.reshape_dims = [-1] + shuffle_layer.zero_is_placeholder = False + self.register_layer(shuffle_layer, "flatten", *layer_info) + return shuffle_layer.get_output(0) + + def reshape(self, network, tensor, reshape_dims, layer_info): + reshape_layer = network.add_shuffle(tensor) + reshape_layer.set_input(1, reshape_dims) + reshape_layer.zero_is_placeholder = False + self.register_layer(reshape_layer, "reshape", *layer_info) + return reshape_layer.get_output(0) + + def cast(self, network, tensor, dtype, layer_info): + if tensor.dtype == dtype: + return tensor + cast_layer = network.add_cast(tensor, dtype) + self.register_layer(cast_layer, "cast", *layer_info) + return cast_layer.get_output(0) + + def const_int(self, network, name, value, layer_info): + const_layer = network.add_constant([1], np.array([value], + dtype=np.int32)) + self.register_layer(const_layer, name, *layer_info) + return const_layer.get_output(0) + + def get_dim_size(self, network, tensor, dim, layer_info, shape_tensor=None): + raw_shape = tensor.shape + dim_size = raw_shape[dim] + if dim_size != -1: + return dim_size + else: + if shape_tensor is None: + shape_tensor = self.get_shape(network, tensor, layer_info) + return self.get_item(network, shape_tensor, dim, layer_info) + + def add_split(self, context: GraphContext, input_name, output_name, + device_ids, shard_dims, device_dims): + it = np.nditer(device_ids, flags=['multi_index']) + for device_id in it: + device_id = 
device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + raw_input_shape = input_tensor.shape + start = [] + output_dims = [] + stride = [] + input_shape_tensor = self.get_shape(network, input_tensor, + layer_info) + for dim in range(len(raw_input_shape)): + stride.append(1) + if dim not in shard_dims: + start.append(0) + output_dims.append( + self.get_item(network, input_shape_tensor, dim, + layer_info)) + else: + start.append(None) + output_dims.append(None) + + for dim, device_dim in zip(shard_dims, device_dims): + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, it) + input_dim = raw_input_shape[dim] + assert input_dim != -1 + assert input_dim % partition == 0 + quotient = input_dim // partition + start[dim] = index * quotient + output_dims[dim] = self.const_int(network, f"output_dim{dim}", + quotient, layer_info) + context.set_split_info(input_name, device_id, dim, + SplitInfo(input_dim, partition)) + output_dims_tensor = self.concat(network, output_dims, layer_info) + split_layer = network.add_slice(input_tensor, start, [], stride) + split_layer.set_input(2, output_dims_tensor) + wrapped_layer = self.register_layer(split_layer, "split", + *layer_info) + wrapped_layer.attrs["strategy"] = get_sharding_sequence( + len(raw_input_shape), + shard_dims, + device_dims, + ) + + output_tensor = split_layer.get_output(0) + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_gather(self, + context: GraphContext, + input_name, + output_name, + device_ids, + gather_dims, + device_dims, + is_singleton=False): + to_gather_tensors = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + to_gather_tensor = self.flatten(network, input_tensor, layer_info) + to_gather_tensors.append(to_gather_tensor) + + all_gather_layers = self.add_all_gather_layer( + context, + input_name, + output_name, + device_ids, + to_gather_tensors, + ) + + if len(device_dims) == 1: + gather_indices = [0] + elif len(device_dims) == 2 and device_dims[0] == [1]: + gather_indices = [1, 0] + else: + gather_indices = [0, 1] + + for device_id, all_gather_layer in zip(np.nditer(device_ids), + all_gather_layers): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + permutation = [] + gathered_dims = [] + output_dims = [] + partitions = [] + raw_input_shape = input_tensor.shape + + wrapped_layer = self.get_graph(device_id).get_layer( + all_gather_layer.name) + wrapped_layer.attrs["strategy"] = get_sharding_sequence( + len(raw_input_shape), + gather_dims, + device_dims, + ) + + input_shape_layer = network.add_shape(input_tensor) + self.register_layer(input_shape_layer, "input_shape", *layer_info) + input_shape_tensor = input_shape_layer.get_output(0) + split_infos = context.get_split_infos(input_name, device_id) + for index in gather_indices: + gather_dim = gather_dims[index] + device_dim = device_dims[index] + partition = get_partition(device_dim, device_ids) + assert partition == split_infos[gather_dim].partition + partitions.append( + self.const_int(network, f"partition_num{gather_dim}", + partition, 
layer_info)) + for dim in range(len(raw_input_shape)): + if dim in gather_dims: + gather_index = gather_dims.index(dim) + device_dim = device_dims[gather_index] + permutation.append(gather_indices.index(gather_index)) + permutation.append(dim + len(gather_dims)) + if dim not in split_infos: + output_dim_layer = network.add_slice( + input_shape_tensor, [dim], [1], [1]) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + dim_tensor = output_dim_layer.get_output(0) + output_dims.append(dim_tensor) + gathered_dims.append(dim_tensor) + else: + input_dim = split_infos[dim].input_dim + partition = split_infos[dim].partition + assert input_dim != -1 + assert input_dim % partition == 0 + quotient = input_dim // partition + output_dims.append( + self.const_int(network, f"output_dim{dim}", quotient, + layer_info)) + if dim in gather_dims: + gathered_dims.append( + self.const_int(network, f"gathered_dim{dim}", + quotient * partition, layer_info)) + del split_infos[dim] + else: + gathered_dims.append(output_dim_layer.get_output(0)) + + reshape_dims_for_transpose_layer = network.add_concatenation( + [*partitions, *output_dims]) + self.register_layer(reshape_dims_for_transpose_layer, + "reshape_dims_for_transpose", *layer_info) + reshape_dims_tensor = reshape_dims_for_transpose_layer.get_output(0) + transpose_layer = network.add_shuffle( + all_gather_layer.get_output(0)) + transpose_layer.set_input(1, reshape_dims_tensor) + transpose_layer.second_transpose = permutation + transpose_layer.zero_is_placeholder = False + self.register_layer(transpose_layer, "transpose", *layer_info) + + reshape_dims_for_reshape_layer = network.add_concatenation( + gathered_dims) + self.register_layer(reshape_dims_for_reshape_layer, + "reshape_dims_for_reshape", *layer_info) + reshape_dims_tensor = reshape_dims_for_reshape_layer.get_output(0) + output_tensor = self.reshape( + network, + transpose_layer.get_output(0), + reshape_dims_tensor, + layer_info, + ) + context.update_name_mapping(input_name, device_id, + output_tensor.name) + if is_singleton: + break + + def register_unfilled_weights(self, graph, layer): + if (layer.name in self.full_graph._unfilled_weights + and layer.name not in graph._unfilled_weights): + weights, values = self.full_graph._unfilled_weights[layer.name] + graph._register_unfilled_weights( + layer.name, + weights, + values, + ) + + def shard_constant(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + device_id = context.nditer.value.item() + device_ids = context.device_ids + layer = context.layer.as_trt() + graph = self.get_graph(device_id) + if len(shard_dims) == 0: + self.register_unfilled_weights(graph, layer) + return LayerUpdate(split_info_updated=True) + flatten_device_dim = list( + itertools.chain.from_iterable(shard_dims.values())) + output_name = layer.get_output(0).name + output_dtype = layer.get_output(0).dtype + output_shape = layer.shape + output_dims = [] + weight_index = [] + for dim in range(len(output_shape)): + output_dim = output_shape[dim] + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, context.nditer) + assert output_dim % partition == 0 + quotient = output_dim // partition + output_dims.append(quotient) + weight_index.append( + slice(index * quotient, (index + 1) * quotient)) + context.graph_context.set_split_info( + output_name, device_id, dim, + SplitInfo(output_dim, 
partition)) + else: + output_dims.append(output_dim) + weight_index.append(slice(None)) + if layer.name in self.full_graph._unfilled_weights: + values = self.full_graph._unfilled_weights[layer.name][1] + else: + values = layer.weights + if isinstance(values, trt.Weights): + values = values.numpy() + # TODO: remove this WAR after https://nvbugs/4359151 fixed. + if isinstance(values, trt.Weights): + network = context.layer.graph.as_trt() + values = get_np_weight(network, layer.name) + if values is not None: + values = values.reshape(layer.shape) + assert values.size == np.prod(layer.shape) + sharded_values = values[tuple(weight_index)] + assert sharded_values.size * get_partition( + flatten_device_dim, device_ids) == np.prod(layer.shape) + else: + sharded_values = None + dtype = trt_dtype_to_np(output_dtype) + sharded_weights = np.empty(tuple(output_dims), dtype) + graph._register_unfilled_weights( + f"device{device_id}_{layer.name}", + sharded_weights, + sharded_values, + ) + sharded_weights = to_trt_weights(sharded_weights) + return LayerUpdate( + updated_attrs=dict( + shape=trt.Dims(output_dims), + weights=sharded_weights, + ), + split_info_updated=True, + ) + + def shard_fill(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + if len(shard_dims) == 0: + return LayerUpdate(split_info_updated=True) + device_id = context.nditer.value.item() + device_ids = context.device_ids + layer = context.layer.as_trt() + output_name = layer.get_output(0).name + output_shape = layer.shape + output_dims = [] + for dim in range(len(output_shape)): + output_dim = output_shape[dim] + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + assert output_dim % partition == 0 + quotient = output_dim // partition + output_dims.append(quotient) + context.graph_context.set_split_info( + output_name, device_id, dim, + SplitInfo(output_dim, partition)) + else: + output_dims.append(output_dim) + return LayerUpdate( + updated_attrs=dict(shape=trt.Dims(output_dims), ), + split_info_updated=True, + ) + + def update_shape(self, context: ShardContext): + if not context.layer.is_shape_io: + return + layer = context.layer.as_trt() + input_name = layer.get_input(0).name + output_name = layer.get_output(0).name + device_id = context.nditer.value.item() + layer_info = (output_name, None, device_id) + split_infos = context.graph_context.get_split_infos( + input_name, device_id) + if len(split_infos) == 0: + return + network = self.get_network(device_id) + shape_tensor = self.get_tensor(context.graph_context, output_name, + device_id).as_trt() + output_dims = [] + for dim in range(len(context.layer.get_input(0).shape)): + if dim not in split_infos: + output_dim_layer = network.add_slice(shape_tensor, [dim], [1], + [1]) + else: + input_dim = split_infos[dim].input_dim + output_dim_layer = network.add_constant([1], + np.array( + [input_dim], + dtype=np.int32)) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + output_dims.append(output_dim_layer.get_output(0)) + new_shape_layer = network.add_concatenation(output_dims) + self.register_layer(new_shape_layer, "new_shape", *layer_info) + new_shape_tensor = new_shape_layer.get_output(0) + context.graph_context.update_name_mapping(output_name, device_id, + new_shape_tensor.name) + + def shard_slice(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = 
sharding_spec.dim_partition_dict + if len(shard_dims) == 0: + return LayerUpdate.none() + device_id = context.nditer.value.item() + network = self.get_network(device_id) + device_ids = context.device_ids + layer = context.layer.as_trt() + output_dims = [] + updated_attrs = {} + updated_inputs = {} + + if layer.num_inputs >= 3: + raw_output_shape = layer.get_output(0).shape + input_name = layer.get_input(2).name + layer_info = (input_name, layer.name, device_id) + shape_tensor = self.get_tensor(context.graph_context, input_name, + device_id).as_trt() + for dim in range(len(raw_output_shape)): + output_dim_layer = network.add_slice(shape_tensor, [dim], [1], + [1]) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + partition_num_tensor = self.const_int( + network, f"partition_num{dim}", partition, layer_info) + quotient_layer = network.add_elementwise( + output_dim_layer.get_output(0), partition_num_tensor, + trt.ElementWiseOperation.FLOOR_DIV) + self.register_layer(quotient_layer, f"quotient{dim}", + *layer_info) + output_dim = self.cast(network, + quotient_layer.get_output(0), + trt.DataType.INT32, layer_info) + output_dims.append(output_dim) + else: + output_dims.append(output_dim_layer.get_output(0)) + output_dims_layer = network.add_concatenation(output_dims) + self.register_layer(output_dims_layer, "output_dims", *layer_info) + updated_inputs[2] = output_dims_layer.get_output(0) + else: + output_shape = layer.shape + for dim in range(len(output_shape)): + output_dim = output_shape[dim] + assert output_dim != -1 + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + assert output_dim % partition == 0 + quotient = output_dim // partition + output_dims.append(quotient) + else: + output_dims.append(output_dim) + updated_attrs["shape"] = trt.Dims(output_dims) + return LayerUpdate(updated_attrs, updated_inputs) + + def shard_shuffle(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + if len(shard_dims) == 0: + return LayerUpdate.none() + device_id = context.nditer.value.item() + network = self.get_network(device_id) + device_ids = context.device_ids + layer = context.layer.as_trt() + updated_attrs = {} + updated_inputs = {} + updated_reshape_dims = {} + second_transpose = layer.second_transpose + + if layer.num_inputs >= 2: + raw_output_shape = layer.get_output(0).shape + input_name = layer.get_input(1).name + layer_info = (input_name, layer.name, device_id) + reshape_dims_tensor = self.get_tensor(context.graph_context, + input_name, device_id) + reshape_dims = context.layer.get_input(1).value + reshape_dims_tensor = reshape_dims_tensor.as_trt() + for dim in range(len(raw_output_shape)): + if second_transpose is not None: + reshape_dim = second_transpose[dim] + else: + reshape_dim = dim + output_dim_layer = network.add_slice(reshape_dims_tensor, + [reshape_dim], [1], [1]) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + output_dim = reshape_dims[reshape_dim] + if dim in shard_dims and output_dim != -1: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + partition_num_tensor = self.const_int( + network, f"partition_num{dim}", partition, layer_info) + quotient_layer = network.add_elementwise( + output_dim_layer.get_output(0), partition_num_tensor, + 
trt.ElementWiseOperation.FLOOR_DIV) + self.register_layer(quotient_layer, f"quotient{dim}", + *layer_info) + updated_reshape_dims[reshape_dim] = self.cast( + network, + quotient_layer.get_output(0), + trt.DataType.INT32, + layer_info, + ) + else: + updated_reshape_dims[ + reshape_dim] = output_dim_layer.get_output(0) + updated_reshape_dims = list( + map(lambda x: x[1], sorted(updated_reshape_dims.items()))) + reshape_dims_layer = network.add_concatenation(updated_reshape_dims) + self.register_layer(reshape_dims_layer, "reshape_dims", *layer_info) + updated_inputs[1] = reshape_dims_layer.get_output(0) + else: + reshape_dims = layer.reshape_dims + if reshape_dims.__len__() < 0: + return LayerUpdate.none() + for dim in range(len(reshape_dims)): + if second_transpose is not None: + reshape_dim = second_transpose[dim] + else: + reshape_dim = dim + output_dim = reshape_dims[reshape_dim] + if dim in shard_dims and output_dim != -1: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + quotient = output_dim // partition + updated_reshape_dims[reshape_dim] = quotient + else: + updated_reshape_dims[reshape_dim] = output_dim + updated_reshape_dims = list( + map(lambda x: x[1], sorted(updated_reshape_dims.items()))) + updated_attrs["reshape_dims"] = trt.Dims(updated_reshape_dims) + return LayerUpdate(updated_attrs, updated_inputs) + + def shard_gpt_attention(self, context: ShardContext): + layer = context.layer.as_trt() + plugin_info = get_plugin_info( + self.full_graph.as_trt(), + layer.name, + ) + parser = IdxEntryParser(plugin_info) + head_dim = 1 if parser.remove_input_padding else 2 + sharding_spec = context.strategy.sharding_specs[ + f"input{parser.get_index(IdxEntry.QKV_TENSOR)}"] + shard_dims = sharding_spec.dim_partition_dict + if head_dim not in shard_dims: + return LayerUpdate.none() + device_id = context.nditer.value.item() + network = self.get_network(device_id) + device_ids = context.device_ids + updated_attrs = {} + updated_inputs = {} + device_dim = shard_dims[head_dim] + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, context.nditer) + if parser.is_entry_used(IdxEntry.K_TENSOR): + kv_sharding_spec = context.strategy.sharding_specs[ + f"input{parser.get_index(IdxEntry.K_TENSOR)}"] + kv_shard_dims = kv_sharding_spec.dim_partition_dict + if head_dim in kv_shard_dims: + kv_device_dim = kv_shard_dims[head_dim] + kv_partition = get_partition(kv_device_dim, device_ids) + else: + kv_partition = 1 + else: + kv_partition = 1 + num_heads = plugin_info.pfc_as_ndarray["num_heads"].copy() + num_kv_heads = plugin_info.pfc_as_ndarray["num_kv_heads"].copy() + tp_size = plugin_info.pfc_as_ndarray["tp_size"].copy() + tp_rank = plugin_info.pfc_as_ndarray["tp_rank"].copy() + num_kv_heads = num_kv_heads // kv_partition + num_heads = num_heads // partition + tp_size[0] = partition + tp_rank[0] = index + + new_plugin, new_plugin_info = get_updated_plugin( + plugin_info, + dict( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + tp_size=tp_size, + tp_rank=tp_rank, + )) + prefix = self.get_prefix(device_id) + new_layer_name = f"{prefix}{layer.name}" + set_plugin_info(network, new_layer_name, new_plugin_info) + updated_attrs["plugin"] = new_plugin + return LayerUpdate(updated_attrs, updated_inputs) + + def shard_lookup(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["input1"] + shard_dims = sharding_spec.dim_partition_dict + if 0 not in shard_dims: + return LayerUpdate.none() + layer = context.layer.as_trt() + 
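# Rebuild the Lookup plugin with a per-device rank; the rank field is the only attribute that changes per shard. +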
plugin_info = get_plugin_info( + self.full_graph.as_trt(), + layer.name, + ) + device_id = context.nditer.value.item() + network = self.get_network(device_id) + updated_attrs = {} + device_dim = shard_dims[0] + index = get_index(device_dim, context.nditer) + rank = plugin_info.pfc_as_ndarray["rank"].copy() + rank[0] = index + + new_plugin, new_plugin_info = get_updated_plugin( + plugin_info, dict(rank=rank, )) + prefix = self.get_prefix(device_id) + new_layer_name = f"{prefix}{layer.name}" + set_plugin_info(network, new_layer_name, new_plugin_info) + updated_attrs["plugin"] = new_plugin + return LayerUpdate(updated_attrs) + + +class GraphGroupBase(GraphGroup): + + def __init__( + self, + full_graph: PipelineGraph, + config: ParallelConfig, + auto_parallel_config: AutoParallelConfig, + ) -> None: + self._full_graph = full_graph + self.config = config + self._auto_parallel_config = auto_parallel_config + self.infer_shape = auto_parallel_config.infer_shape + self.global_context = GraphContext() + self.shape_cache = {} + self.suffix = 0 + self.current_block_id = -1 + + @property + def auto_parallel_config(self) -> AutoParallelConfig: + return self._auto_parallel_config + + @property + def full_graph(self) -> PipelineGraph: + return self._full_graph + + def register_layer(self, + layer, + base_name, + input_name, + output_name=None, + device_id=None, + keep_tensor_name=False) -> Layer: + layer_name = f"{base_name}_{input_name}" + if device_id is not None: + layer_name = f"{self.get_prefix(device_id)}{layer_name}" + if output_name is not None: + layer_name = f"{layer_name}_to_{output_name}" + suffix = self.suffix + self.suffix += 1 + layer_name = f"{layer_name}_{suffix}" + if layer.type == trt.LayerType.PLUGIN_V2: + network = self.get_network(device_id) + plugin_info = get_plugin_info(network, layer.name) + if plugin_info is not None: + set_plugin_info(network, layer_name, plugin_info) + delete_plugin_info(network, layer.name) + layer.name = layer_name + if not keep_tensor_name: + for i in range(layer.num_outputs): + output_tensor = layer.get_output(i) + assert output_tensor.shape.__len__() >= 0 + output_tensor.name = f"{layer.name}_output_{i}" + wrapped_layer = self.get_graph(device_id).register_layer(layer) + if self.current_block_id != -1: + wrapped_layer.attrs["block_id"] = self.current_block_id + wrapped_layer.attrs["role"] = "helper" + if self.infer_shape: + infer_per_layer_shapes( + layer, + self.get_shapes(device_id), + self.get_values(device_id), + self.shape_cache, + is_shape_io=True, + ) + wrapped_layer.assign_shapes( + self.get_shapes(device_id), + self.get_values(device_id), + ) + return wrapped_layer + + def add_layer(self, wrapped_layer: Layer, device_ids, + strategy: ShardingStrategy): + layer = wrapped_layer.as_trt() + local_context = self.global_context.get_local_context(layer) + self.current_block_id = wrapped_layer.attrs["block_id"] + + for i, input in enumerate(wrapped_layer.inputs): + if input is None: + continue + if i not in strategy.best_resharding_cost: + continue + comm_action_sequence = strategy.best_resharding_cost[i][0][1] + for commspec in comm_action_sequence: + self.add_comm(local_context, + input.name, + device_ids, + commspec, + output_name=layer.name) + + it = np.nditer(device_ids, flags=['multi_index']) + for device_id in it: + device_id = device_id.item() + + layer_type = layer.type + to_subclass_layer(layer) + shard_context = ShardContext( + local_context, + wrapped_layer, + it, + device_ids, + strategy, + ) + if layer_type == trt.LayerType.CONSTANT: + 
layer_update = self.shard_constant(shard_context) + elif layer_type == trt.LayerType.FILL: + layer_update = self.shard_fill(shard_context) + elif layer_type == trt.LayerType.SLICE: + layer_update = self.shard_slice(shard_context) + elif layer_type == trt.LayerType.SHUFFLE: + layer_update = self.shard_shuffle(shard_context) + elif layer_type == trt.LayerType.PLUGIN_V2: + if layer.plugin.plugin_type == "GPTAttention": + layer_update = self.shard_gpt_attention(shard_context) + elif layer.plugin.plugin_type == "Lookup": + layer_update = self.shard_lookup(shard_context) + else: + layer_update = LayerUpdate.none() + else: + layer_update = LayerUpdate.none() + to_base_class_layer(layer) + + for i, updated_input in layer_update.updated_inputs.items(): + input_name = layer.get_input(i).name + local_context.update_name_mapping(input_name, device_id, + updated_input.name) + + prefix = self.get_prefix(device_id) + new_wrapped_layer = self.get_graph(device_id).add_layer( + layer, + prefix=prefix, + input_mapping=local_context.get_name_mapping(device_id, + prefix=prefix), + updated_attrs=layer_update.updated_attrs, + ) + new_wrapped_layer.attrs["strategy"] = strategy.name + new_wrapped_layer.attrs["block_id"] = self.current_block_id + new_layer = new_wrapped_layer.as_trt() + + if self.infer_shape: + infer_per_layer_shapes( + new_layer, + self.get_shapes(device_id), + self.get_values(device_id), + self.shape_cache, + is_shape_io=wrapped_layer.is_shape_io, + ) + new_wrapped_layer.assign_shapes( + self.get_shapes(device_id), + self.get_values(device_id), + ) + + for i in range(layer.num_outputs): + output_tensor = new_layer.get_output(i) + assert output_tensor.shape.__len__() >= 0 + local_context.update_name_mapping( + layer.get_output(i).name, device_id, output_tensor.name) + + if layer.type == trt.LayerType.SHAPE: + self.update_shape(shard_context) + + self.global_context.update_layer_context( + wrapped_layer, + layer_update, + local_context, + device_id, + device_ids, + strategy.sharding_specs, + ) + + for i in range(layer.num_outputs): + commspecs = strategy.communication_actions.get(f"output{i}") + if commspecs is None: + continue + output = layer.get_output(i) + for commspec in commspecs: + self.add_comm( + self.global_context, + output.name, + device_ids, + commspec, + ) + + self.current_block_id = -1 + + +class DistributedGraphGroup(GraphGroupBase): + + def __init__( + self, + full_graph: PipelineGraph, + config: ParallelConfig, + auto_parallel_config: AutoParallelConfig, + ) -> None: + super().__init__(full_graph, config, auto_parallel_config) + self.graphs = {} + self.io_tensor_shards = {} + self.shapes_by_device = {} + self.values_by_device = {} + self.use_custom_all_reduce = False + phy_mesh = config.graph_config.phy_mesh + device_ids = phy_mesh.phy_devices_id + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = PipelineGraph.create_graph() + graph._auto_parallel_config = { + "io_shards": {}, + "mapping": + Mapping( + world_size=device_ids.size, + rank=device_id, + gpus_per_node=device_ids.shape[1], + tp_size=device_ids.size // config.graph_config.num_stages, + pp_size=config.graph_config.num_stages, + ), + } + self.graphs[device_id] = graph + self.shapes_by_device[device_id] = {} + self.values_by_device[device_id] = {} + + @contextlib.contextmanager + def disable_infer_shape(self): + infer_shape = self.infer_shape + self.infer_shape = False + yield + self.infer_shape = infer_shape + + def get_network(self, device_id) -> trt.INetworkDefinition: + return 
self.graphs[device_id].as_trt() + + def get_graph(self, device_id) -> PipelineGraph: + return self.graphs[device_id] + + def get_prefix(self, device_id) -> str: + return "" + + def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: + return self.shapes_by_device[device_id] + + def get_values(self, device_id) -> Dict[str, List[int]]: + return self.values_by_device[device_id] + + def add_reduce_scatter(self, context: GraphContext, input_name, output_name, + device_ids, shard_dims, device_dims): + builder_flags = get_builder_flags() + if builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0: + dtype = trt.DataType.HALF + elif builder_flags & (1 << int(trt.BuilderFlag.BF16)) != 0: + dtype = trt.DataType.BF16 + else: + dtype = trt.DataType.FLOAT + fast_reduce = self.auto_parallel_config.fast_reduce + if fast_reduce: + logger.debug(f"reduce_scatter with {dtype} after {input_name}") + + it = np.nditer(device_ids, flags=['multi_index']) + for device_id in it: + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + raw_input_shape = input_tensor.shape + input_shape_tensor = self.get_shape(network, input_tensor, + layer_info) + if shard_dims != [0]: + permutation = list(range(len(raw_input_shape))) + for dim in shard_dims: + permutation.remove(dim) + permutation = shard_dims + permutation + transpose_layer = network.add_shuffle(input_tensor) + transpose_layer.second_transpose = permutation + self.register_layer(transpose_layer, "input_transpose", + *layer_info) + input_tensor = transpose_layer.get_output(0) + flatten_tensor = self.flatten(network, input_tensor, layer_info) + input_dtype = flatten_tensor.dtype + if fast_reduce: + to_reduce_tensor = self.cast( + network, + flatten_tensor, + dtype, + layer_info, + ) + else: + to_reduce_tensor = flatten_tensor + + reduce_scatter_plg_creator = trt.get_plugin_registry( + ).get_plugin_creator('ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert reduce_scatter_plg_creator is not None + + group = trt.PluginField( + "group", + np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), + trt.PluginFieldType.INT32) + pf_type = trt.PluginField( + "type_id", np.array([int(to_reduce_tensor.dtype)], np.int32), + trt.PluginFieldType.INT32) + + pfc = trt.PluginFieldCollection([group, pf_type]) + rs_plug = reduce_scatter_plg_creator.create_plugin( + "reduce_scatter", pfc) + + reduce_scatter_layer = network.add_plugin_v2([to_reduce_tensor], + rs_plug) + plugin_info = PluginInfo(reduce_scatter_plg_creator, + "reduce_scatter", pfc) + set_plugin_info(network, reduce_scatter_layer.name, plugin_info) + with self.disable_infer_shape(): + wrapped_tensor = self.register_layer( + reduce_scatter_layer, + "reduce_scatter", + *layer_info, + ).get_output(0) + reduce_scatter_tensor = reduce_scatter_layer.get_output(0) + if self.infer_shape: + shape = self.shapes_by_device[device_id][to_reduce_tensor.name] + assert len(shape) == 1 + output_shape = (shape[0] // device_ids.size, ) + self.shapes_by_device[device_id][ + reduce_scatter_tensor.name] = output_shape + wrapped_tensor.shape = output_shape + if fast_reduce: + reduce_scatter_tensor = self.cast( + network, + reduce_scatter_tensor, + input_dtype, + layer_info, + ) + + start = [] + output_dims = [] + stride = [] + for dim in range(len(raw_input_shape)): + stride.append(1) + if dim not in shard_dims: + start.append(0) + output_dims.append( + self.get_item(network, 
input_shape_tensor, dim, + layer_info)) + else: + start.append(None) + output_dims.append(None) + + for dim, device_dim in zip(shard_dims, device_dims): + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, it) + input_dim = raw_input_shape[dim] + assert input_dim != -1 + assert input_dim % partition == 0 + quotient = input_dim // partition + start[dim] = index * quotient + output_dims[dim] = self.const_int(network, f"output_dim{dim}", + quotient, layer_info) + context.set_split_info(input_name, device_id, dim, + SplitInfo(input_dim, partition)) + if shard_dims != [0]: + output_dims = [ + output_dims[permutation[i]] for i in range(len(output_dims)) + ] + output_dims_tensor = self.concat(network, output_dims, layer_info) + output_tensor = self.reshape( + network, + reduce_scatter_tensor, + output_dims_tensor, + layer_info, + ) + if shard_dims != [0]: + transpose_layer = network.add_shuffle(output_tensor) + transpose_layer.second_transpose = permutation + self.register_layer(transpose_layer, "output_transpose", + *layer_info) + output_tensor = transpose_layer.get_output(0) + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_reduce_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_reduce_tensors): + if self.use_custom_all_reduce: + all_reduce_instance_id = current_all_reduce_helper().gen_id() + for device_id, to_reduce_tensor in zip(np.nditer(device_ids), + to_reduce_tensors): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + graph = self.get_graph(device_id) + if self.use_custom_all_reduce: + strategy = AllReduceStrategy.AUTO + else: + strategy = AllReduceStrategy.RING + allreduce_plg_creator = trt.get_plugin_registry( + ).get_plugin_creator('AllReduce', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert allreduce_plg_creator is not None + + group = trt.PluginField( + "group", + np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), + trt.PluginFieldType.INT32) + pf_type = trt.PluginField( + "type_id", np.array([int(to_reduce_tensor.dtype)], np.int32), + trt.PluginFieldType.INT32) + pf_strategy = trt.PluginField("strategy", + np.array([int(strategy)], np.int8), + trt.PluginFieldType.INT8) + pfc = [group, pf_type, pf_strategy] + if self.use_custom_all_reduce: + pf_counter = trt.PluginField( + "counter", + np.array([all_reduce_instance_id], np.int32), + trt.PluginFieldType.INT32, + ) + pfc.append(pf_counter) + + pfc = trt.PluginFieldCollection(pfc) + ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc) + + inputs = [to_reduce_tensor] + if self.use_custom_all_reduce: + workspace = graph.get_input("all_reduce_workspace").as_trt() + inputs.append(workspace) + all_reduce_layer = network.add_plugin_v2(inputs, ar_plug) + plugin_info = PluginInfo(allreduce_plg_creator, "allreduce", pfc) + set_plugin_info(network, all_reduce_layer.name, plugin_info) + with self.disable_infer_shape(): + wrapped_tensor = self.register_layer( + all_reduce_layer, + "all_reduce", + *layer_info, + ).get_output(0) + output_tensor = all_reduce_layer.get_output(0) + if self.infer_shape: + shape = self.shapes_by_device[device_id][to_reduce_tensor.name] + self.shapes_by_device[device_id][output_tensor.name] = shape + wrapped_tensor.shape = shape + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_gather_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_gather_tensors): + all_gather_layers 
= [] + for device_id, to_gather_tensor in zip(np.nditer(device_ids), + to_gather_tensors): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + + allgather_plg_creator = trt.get_plugin_registry( + ).get_plugin_creator('AllGather', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert allgather_plg_creator is not None + + group = trt.PluginField( + "group", + np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), + trt.PluginFieldType.INT32) + pf_type = trt.PluginField( + "type_id", np.array([int(to_gather_tensor.dtype)], np.int32), + trt.PluginFieldType.INT32) + pfc = trt.PluginFieldCollection([group, pf_type]) + allgather = allgather_plg_creator.create_plugin("allgather", pfc) + + all_gather_layer = network.add_plugin_v2([to_gather_tensor], + allgather) + plugin_info = PluginInfo(allgather_plg_creator, "allgather", pfc) + set_plugin_info(network, all_gather_layer.name, plugin_info) + with self.disable_infer_shape(): + wrapped_tensor = self.register_layer( + all_gather_layer, + "all_gather", + *layer_info, + ).get_output(0) + if self.infer_shape: + output_tensor = all_gather_layer.get_output(0) + shape = self.shapes_by_device[device_id][to_gather_tensor.name] + assert len(shape) == 1 + output_shape = (shape[0] * device_ids.size, ) + self.shapes_by_device[device_id][ + output_tensor.name] = output_shape + wrapped_tensor.shape = output_shape + all_gather_layers.append(all_gather_layer) + return all_gather_layers + + def set_shard_num(self, tensor_name, dim, shard_num): + for graph in self.graphs.values(): + io_shards = graph._auto_parallel_config["io_shards"] + if tensor_name not in io_shards: + io_shards[tensor_name] = {} + io_shards[tensor_name][dim] = shard_num + + def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy): + context = self.global_context + sharding_spec = strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + for dim, device_dim in shard_dims.items(): + partition = get_partition(device_dim, device_ids) + self.set_shard_num(tensor.name, dim, partition) + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = self.get_graph(device_id) + new_input = graph.add_input(tensor.as_trt()) + shape = [*tensor.shape] + if len(shard_dims) != 0: + output_shape = [*tensor.raw_shape] + for dim, device_dim in shard_dims.items(): + partition = get_partition(device_dim, device_ids) + output_dim = output_shape[dim] + assert output_dim != -1 + assert output_dim % partition == 0 + quotient = output_dim // partition + output_shape[dim] = quotient + shape[dim] = quotient + assert tensor.value is None + context.set_split_info(tensor.name, device_id, dim, + SplitInfo(output_dim, partition)) + new_input.raw_shape = output_shape + context.update_name_mapping(tensor.name, device_id, tensor.name) + if self.infer_shape: + self.shapes_by_device[device_id][tensor.name] = tuple(shape) + new_input.shape = tuple(shape) + if tensor.value is not None: + self.values_by_device[device_id][tensor.name] = tensor.value + new_input.value = tensor.value + + def add_output(self, tensor: Tensor, device_ids, + strategy: ShardingStrategy): + comm_action_sequence = strategy.best_resharding_cost[0][0][1] + for commspec in comm_action_sequence: + self.add_comm(self.global_context, tensor.name, device_ids, + commspec) + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = self.get_graph(device_id) + output_name = tensor.name + new_output_name = 
self.global_context.get_name( + output_name, device_id) + if new_output_name != output_name: + suffix = self.suffix + self.suffix += 1 + original_name = f"original_{output_name}_{suffix}" + original_tensor = graph.get_tensor(output_name) + original_tensor.as_trt().name = original_name + output_tensor = graph.get_tensor(new_output_name) + output_tensor.as_trt().name = output_name + graph._tensors[original_name] = original_tensor + graph._tensors[output_name] = output_tensor + del graph._tensors[new_output_name] + else: + output_tensor = graph.get_tensor(output_name) + trt_output = output_tensor.as_trt() + if trt_output.is_shape_tensor: + graph.add_output_shape(trt_output) + else: + graph.add_output(trt_output) + + shard_dims = strategy.sharding_specs["input0"].dim_partition_dict + for dim, device_dim in shard_dims.items(): + partition = get_partition(device_dim, device_ids) + self.set_shard_num(tensor.name, dim, partition) + + +class PrefixedGraphGroup(GraphGroupBase): + + def __init__( + self, + full_graph: PipelineGraph = None, + config: ParallelConfig = None, + auto_parallel_config: AutoParallelConfig = None, + ) -> None: + auto_parallel_config = auto_parallel_config or dict( + infer_shape=False, + validation_mode=False, + ) + super().__init__(full_graph, config, auto_parallel_config) + self.validation_mode = auto_parallel_config.validation_mode + if not self.infer_shape: + self.validation_mode = False + self.prefixed_graph = PipelineGraph.create_graph() + if self.validation_mode: + self.layer_mapping = config.graph_config.graph_mapping.layer_mapping + self.graph_strategy = config.graph_strategy + self.shapes = {} + self.values = {} + self.timing_cache = None + + def get_network(self, device_id) -> trt.INetworkDefinition: + return self.prefixed_graph.as_trt() + + def get_graph(self, device_id) -> PipelineGraph: + return self.prefixed_graph + + def get_prefix(self, device_id) -> str: + return f"device{device_id}_" + + def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: + return self.shapes + + def get_values(self, device_id) -> Dict[str, List[int]]: + return self.values + + def add_all_reduce_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_reduce_tensors): + reshaped_tensors = [] + for device_id, to_reduce_tensor in zip(np.nditer(device_ids), + to_reduce_tensors): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + reshape_dims_tensor = self.concat( + network, + [ + self.get_shape(network, to_reduce_tensor, layer_info), + self.const_int(network, "expanded_dim", 1, layer_info) + ], + layer_info, + ) + reshaped_tensor = self.reshape( + network, + to_reduce_tensor, + reshape_dims_tensor, + layer_info, + ) + reshaped_tensors.append(reshaped_tensor) + + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + input_tensor = self.get_tensor(context, input_name, 0).as_trt() + num_dims = len(input_tensor.shape) + network = self.get_network(device_id) + concat_layer = network.add_concatenation(reshaped_tensors) + concat_layer.axis = num_dims + self.register_layer(concat_layer, "concat", *layer_info) + reduce_layer = network.add_reduce(concat_layer.get_output(0), + trt.ReduceOperation.SUM, + axes=1 << num_dims, + keep_dims=False) + dtype = to_reduce_tensors[0].dtype + reduce_layer.precision = dtype + reduce_layer.set_output_type(0, dtype) + self.register_layer(reduce_layer, "reduce", *layer_info) + output_tensor = 
reduce_layer.get_output(0) + + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_gather_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_gather_tensors): + all_gather_layers = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + all_gather_layer = network.add_concatenation(to_gather_tensors) + all_gather_layer.axis = 0 + self.register_layer(all_gather_layer, "all_gather", *layer_info) + all_gather_layers.append(all_gather_layer) + return all_gather_layers + + def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy): + + def add_identity(): + identity_layer = network.add_identity(input.as_trt()) + return identity_layer + + input = self.prefixed_graph.add_input(tensor.as_trt()) + if self.infer_shape: + self.shapes[tensor.name] = tensor.shape + input.shape = tensor.shape + if tensor.value is not None: + self.values[tensor.name] = tensor.value + input.value = tensor.value + network = self.get_network(None) + if self.validation_mode: + identity_layer = add_identity() + identity_layer.get_output(0).name = f"ref_{tensor.name}" + layer_info = (tensor.name, None, None) + self.register_layer(identity_layer, + "identity", + *layer_info, + keep_tensor_name=True) + input.attrs["strategy"] = strategy.name + sharding_spec = strategy.sharding_specs["output0"] + pre_sharding_sepc = get_full_sharding_spec(sharding_spec) + comm_action_sequence = get_comm_action_sequence(pre_sharding_sepc, + sharding_spec) + context = self.global_context + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (tensor.name, None, device_id) + context.update_name_mapping(tensor.name, device_id, tensor.name) + if len(comm_action_sequence + ) == 0 and not tensor.as_trt().is_shape_tensor: + identity_layer = add_identity() + self.register_layer(identity_layer, "identity", *layer_info) + context.update_name_mapping( + tensor.name, + device_id, + identity_layer.get_output(0).name, + ) + for commspec in comm_action_sequence: + self.add_comm(context, tensor.name, device_ids, commspec) + + def get_graph_in_range(self, graph_group, src_layer, layer_range, + device_ids, shapes, values): + src_network = self.prefixed_graph.as_trt() + graph = graph_group.prefixed_graph + network = graph.as_trt() + input_mapping = {} + for device_id in np.nditer(device_ids): + device_id = device_id.item() + for i in range(src_layer.num_inputs): + src_input = src_layer.get_input(i) + if src_input is not None: + input = self.get_tensor( + self.global_context, + src_input.name, + device_id, + ).as_trt() + if graph.get_input(src_input.name) is not None: + new_input = graph_group.get_tensor( + graph_group.global_context, + src_input.name, + device_id, + ).as_trt() + input_mapping[input.name] = new_input.name + continue + if graph.get_tensor(input.name) is not None: + continue + shape = shapes[input.name] + assert input.name in values + value = values[input.name] + weights = np.asarray(value, + dtype=trt_dtype_to_np(input.dtype)) + weights = to_trt_weights(weights) + input_layer = network.add_constant(shape, weights) + new_input = input_layer.get_output(0) + new_input.name = input.name + graph.register_layer(input_layer) + for i in layer_range: + layer = src_network.get_layer(i) + graph.add_layer(layer, input_mapping=input_mapping) + + def add_layer_singleton(self, output, device_ids, sharding_spec): + assert 
self.prefixed_graph.get_tensor(output.name) is None + network = self.prefixed_graph.as_trt() + full_sharding_sepc = get_full_sharding_spec(sharding_spec) + comm_action_sequence = get_comm_action_sequence(sharding_spec, + full_sharding_sepc) + output_context = self.global_context.get_local_context_for_output( + output) + if len(comm_action_sequence) != 0: + for commspec in comm_action_sequence[:-1]: + self.add_comm(output_context, output.name, device_ids, commspec) + self.add_comm( + output_context, + output.name, + device_ids, + comm_action_sequence[-1], + is_singleton=True, + ) + device_id = next(np.nditer(device_ids)).item() + layer_info = (output.name, None, device_id) + output_tensor = self.get_tensor(output_context, output.name, + device_id).as_trt() + singleton_layer = network.add_identity(output_tensor) + singleton_layer.get_output(0).name = output.name + self.register_layer(singleton_layer, + "singleton", + *layer_info, + keep_tensor_name=True) + + def add_layer(self, wrapped_layer: Layer, device_ids, + strategy: ShardingStrategy): + graph = self.prefixed_graph + network = graph.as_trt() + start_layer_id = network.num_layers + + super().add_layer(wrapped_layer, device_ids, strategy) + + layer = wrapped_layer.as_trt() + + if self.validation_mode: + is_shape = (wrapped_layer.is_shape_io + or layer.type == trt.LayerType.SHAPE) + + if not is_shape: + self.current_block_id = wrapped_layer.attrs["block_id"] + for i, wrapped_output in enumerate(wrapped_layer.outputs): + if wrapped_output.is_graph_output: + continue + output = wrapped_output.as_trt() + output_name = f"output{i}" + if strategy.communication_actions.get( + output_name) is not None: + output_name += "_after_comm" + sharding_spec = strategy.sharding_specs[output_name] + self.add_layer_singleton(output, device_ids, sharding_spec) + self.current_block_id = -1 + end_layer_id = network.num_layers + + is_skip = (is_shape or layer.type == trt.LayerType.CONSTANT + or layer.name in self.layer_mapping) + sharded = False + for sharding_spec in strategy.sharding_specs.values(): + if len(sharding_spec.dim_partition_dict) > 0: + sharded = True + break + if not sharded: + is_skip = True + + ref_layer = graph.add_layer(layer, prefix="ref_") + ref_layer.attrs["strategy"] = strategy.name + ref_layer.attrs["block_id"] = wrapped_layer.attrs["block_id"] + if layer.type == trt.LayerType.CONSTANT: + self.register_unfilled_weights(graph, layer) + + if is_skip: + return + + logger.debug(f"validating layer {layer.name}") + + layer_type = layer.type + generated_input_values = {} + to_subclass_layer(layer) + if layer_type == trt.LayerType.PLUGIN_V2: + if layer.plugin.plugin_type == "GPTAttention": + sharding_specs = {} + for name, sharding_spec in strategy.sharding_specs.items(): + sharding_specs[name] = get_full_sharding_spec( + sharding_spec) + plugin_info = get_plugin_info( + self.full_graph.as_trt(), + layer.name, + ) + generated_input_values = GPTAttentionPlugin.parameter_generator( + sharding_specs, plugin_info) + to_base_class_layer(layer) + + validation_graph_group = PrefixedGraphGroup() + validation_graph = validation_graph_group.prefixed_graph + validation_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping + extra_input_values = {} + validation_shapes = {} + for i, wrapped_input in enumerate(wrapped_layer.inputs): + if wrapped_input is None: + continue + input = wrapped_input.as_trt() + validation_shapes[input.name] = wrapped_input.shape + if wrapped_input.value is None: + if i in generated_input_values: + extra_input_value = 
generated_input_values[i] + else: + extra_input_value = torch.empty( + tuple(wrapped_input.shape), + dtype=trt_dtype_to_torch(input.dtype), + device=torch.cuda.current_device(), + ) + if torch.is_floating_point(extra_input_value): + extra_input_value.normal_() + # extra_input_value[:] = random.choice([2, 3, 5, 7]) + extra_input_values[input.name] = extra_input_value + self.values[input.name] = extra_input_value + if wrapped_input.producer is not None: + node_name = wrapped_input.producer.name + output_index = wrapped_input.output_index + else: + node_name = wrapped_input.name + output_index = 0 + sharding_spec = self.graph_strategy[ + node_name].sharding_specs[f"output{output_index}"] + validation_graph_group.add_input( + wrapped_input, + device_ids, + ShardingStrategy( + sharding_specs={"output0": sharding_spec}), + ) + validation_graph.get_input( + input.name).raw_shape = wrapped_input.shape + + self.get_graph_in_range( + validation_graph_group, + layer, + range(start_layer_id, end_layer_id), + device_ids, + self.shapes, + self.values, + ) + + for i, wrapped_output in enumerate(wrapped_layer.outputs): + output = wrapped_output.as_trt() + if wrapped_output.is_graph_output: + output_name = f"output{i}" + if strategy.communication_actions.get( + output_name) is not None: + output_name += "_after_comm" + sharding_spec = strategy.sharding_specs[output_name] + validation_graph_group.global_context.merge_context( + self.global_context.get_local_context_for_output( + output)) + validation_graph_group.add_layer_singleton( + output, device_ids, sharding_spec) + validation_graph.add_output(output) + validation_shapes[output.name] = wrapped_output.shape + if not self.timing_cache: + self.timing_cache = network.builder.create_builder_config( + ).create_timing_cache(b"") + logger.debug(f"run validation graph for layer {layer.name}") + validation_runner = validation_graph.get_runner( + validation_shapes, + self.values, + timing_cache=self.timing_cache, + opt_level=0, + ) + values = validation_runner.run() + refer_input_values = {} + for wrapped_input in wrapped_layer.inputs: + if wrapped_input is None: + continue + if wrapped_input.value is not None: + refer_input_values[wrapped_input.name] = wrapped_input.value + refer_graph, output_mapping = get_per_layer_graph( + layer, + validation_shapes, + refer_input_values, + is_shape_io=False, + ) + refer_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping + for proxy_output, output in output_mapping.items(): + validation_shapes[proxy_output] = validation_shapes[output] + logger.debug(f"run refer graph for layer {layer.name}") + refer_runner = refer_graph.get_runner( + validation_shapes, + self.values, + timing_cache=self.timing_cache, + opt_level=0, + ) + refer_outputs = refer_runner.run() + for name, refer_output in refer_outputs.items(): + if name in output_mapping: + refer_output = refer_output.bool() + output = values[name] + # ∣output−refer_output∣ <= atol+rtol*∣refer_output∣ + atol = 1e-02 + rtol = 1e-02 + if not torch.allclose( + output, + refer_output, + rtol=rtol, + atol=atol, + equal_nan=True, + ): + size = output.nelement() + diff = (output - refer_output).abs() + diff_index = (~torch.isnan(diff)) & ( + diff > (atol + rtol * refer_output.abs())) + diff_output = diff[diff_index] + diff_size = diff_output.nelement() + logger.warning( + f"output {name} of {layer.name} is not accurate after parallelization. " + f"{diff_size} out of {size} elements ({diff_size / size * 100:.2f}%) are not close. 
" + f"max: {diff_output.max():.5f}, mean: {diff_output.float().mean():.5f}, std: {diff_output.float().std():.5f}. " + f"mean of reference: {refer_output.float().mean():.5f}, mean of output: {output.float().mean():.5f}." + ) + for name in extra_input_values.keys(): + del self.values[name] + + def add_output(self, tensor: Tensor, device_ids, + strategy: ShardingStrategy): + trt_output = tensor.as_trt() + comm_action_sequence = strategy.best_resharding_cost[0][0][1] + for commspec in comm_action_sequence: + self.add_comm(self.global_context, tensor.name, device_ids, + commspec) + self.add_layer_singleton(trt_output, device_ids, + strategy.sharding_specs["input0"]) + if trt_output.is_shape_tensor: + output = self.prefixed_graph.add_output_shape(trt_output) + else: + output = self.prefixed_graph.add_output(trt_output) + output.attrs["strategy"] = strategy.name + + def assign_shapes(self, shape_info: ShapeInfo): + if self.validation_mode: + shapes = { + f"ref_{name}": shape + for name, shape in shape_info.shapes.items() + } + values = { + f"ref_{name}": value + for name, value in shape_info.values.items() + } + self.shapes.update(shapes) + self.values.update(values) + shape_layers = get_shape_layers(self.prefixed_graph.as_trt()) + shape_info = ShapeInfo(self.shapes, self.values, shape_layers) + self.prefixed_graph.assign_shapes(shape_info) + + +def parallelize( + simplifier: Simplifier, + config: ParallelConfig, +): + auto_parallel_config = simplifier.config + debug_mode = auto_parallel_config.debug_mode + dump_path = auto_parallel_config.dump_path + debug_outputs = auto_parallel_config.debug_outputs + + simplifier.infer_shapes(config.graph_config.num_micro_batches) + network = simplifier.network + graph = simplifier.graph + phy_mesh = config.graph_config.phy_mesh + # TODO: test device_ids = [[0]] + device_ids = phy_mesh.phy_devices_id + stage_phy_meshes = config.graph_config.stage_phy_meshes + block_to_stage = config.graph_config.graph_mapping.block_to_stage + graph_strategy = config.graph_strategy + desimplify_strategy( + graph, + graph_strategy, + config.graph_config.graph_mapping, + ) + graph_group = GraphGroup.from_graph(graph, config, auto_parallel_config) + + use_custom_all_reduce = simplifier.llm_network.plugin_config.use_custom_all_reduce + if use_custom_all_reduce and not debug_mode: + graph_group.use_custom_all_reduce = True + init_all_reduce_helper() + tp_size = phy_mesh.size // config.graph_config.num_stages + shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size, ) + workspace = graph.as_trt().add_input( + name="all_reduce_workspace", + dtype=trt.int64, + shape=shape, + ) + tensor = graph.register_input(workspace) + tensor.shape = shape + graph_strategy["all_reduce_workspace"] = ShardingStrategy( + sharding_specs={ + "output0": + ShardingSpec( + device_mesh=phy_mesh.as_logical_mesh(), + data_type_size=tensor.dtype_str_size, + data_shape=shape, + max_data_shape=shape, + raw_data_shape=shape, + dim_partition_dict={}, + ) + }) + + if dump_path is not None: + lock = FileLock(f"{dump_path}/path.lock", thread_local=False) + with lock: + with open(f'{dump_path}/sharded_graph.log', 'w+') as file: + config.print_graph_strategy(file) + + for input in graph.inputs: + graph_group.add_input(input, device_ids, graph_strategy[input.name]) + for block in simplifier.blocks: + stage_id = block_to_stage[block.block_id] + stage_phy_mesh = stage_phy_meshes[stage_id] + stage_device_ids = stage_phy_mesh.phy_devices_id.reshape( + config.lmesh.mesh_shape) + for i in block.sorted_layer_ids: + layer = 
graph.get_layer(network.get_layer(i).name) + layer.attrs["block_id"] = block.block_id + graph_group.add_layer( + layer, + stage_device_ids, + graph_strategy[layer.name], + ) + for output in graph.outputs: + graph_group.add_output(output, device_ids, graph_strategy[output.name]) + + if debug_mode: + new_graph = graph_group.prefixed_graph + debug_outputs = debug_outputs or [] + if isinstance(debug_outputs, str): + if debug_outputs == 'validation': + debug_outputs = [] + for tensor in new_graph.tensors: + if tensor.name.startswith('ref_'): + original_name = tensor.name[4:] + original_tensor = new_graph.get_tensor(original_name) + if original_tensor is not None: + if not original_tensor.is_graph_io: + debug_outputs.append(tensor.name) + debug_outputs.append(original_name) + if original_tensor.is_graph_output: + debug_outputs.append(tensor.name) + else: + pattern = debug_outputs + debug_outputs = [] + for tensor in new_graph.tensors: + if tensor.as_trt().is_shape_tensor: + continue + if tensor.producer is not None: + layer = tensor.producer + if layer.type == trt.LayerType.SHAPE: + continue + if re.match(pattern, tensor.name): + debug_outputs.append(tensor.name) + for output_name in debug_outputs: + trt_output = new_graph.get_tensor(output_name).as_trt() + if trt_output.is_shape_tensor: + output = new_graph.add_output_shape(trt_output) + else: + output = new_graph.add_output(trt_output) + graph_group.assign_shapes(simplifier.shape_info) + if dump_path is not None: + with lock: + new_graph.to_dot( + f'{dump_path}/sharded_graph.dot', + per_device=True, + per_block=True, + # ignore_shape_io=True, + extra_attrs=['strategy'], + ) + return [new_graph] + else: + graphs = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = graph_group.graphs[device_id] + graphs.append(graph) + return graphs diff --git a/tensorrt_llm/auto_parallel/pipeline_graph.py b/tensorrt_llm/auto_parallel/pipeline_graph.py new file mode 100644 index 000000000..f021334a7 --- /dev/null +++ b/tensorrt_llm/auto_parallel/pipeline_graph.py @@ -0,0 +1,1024 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional + +import numpy as np +import tensorrt as trt +import torch + +from tensorrt_llm._utils import trt_dtype_to_str, trt_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.network import get_plugin_info, set_plugin_info +from tensorrt_llm.runtime.session import Session + +from .utils import (get_builder_flags, get_sorted_layer_ids, get_strongly_typed, + get_trt_network, set_trt_network, to_base_class_layer, + to_subclass_layer) + + +class Tensor: + + def __init__(self, graph: "PipelineGraph"): + self._graph = graph + self._trt = None + self._shape = None + self._max_shape = None + self._value = None + self.producer: Layer = None + self.output_index = None + self.consumers = [] + self.graph_input_index = -1 + self.graph_output_index = -1 + self.attrs = {} + + @staticmethod + def from_trt(graph: "PipelineGraph", trt_tensor: trt.ITensor): + tensor = Tensor(graph) + tensor._trt = trt_tensor + return tensor + + def as_trt(self) -> trt.ITensor: + return self._trt + + def copy(self) -> "Tensor": + tensor = Tensor(self._graph) + tensor._trt = self._trt + tensor._shape = self._shape + tensor._max_shape = self._max_shape + tensor._value = self._value + tensor.producer = self.producer + tensor.output_index = self.output_index + tensor.consumers = [*self.consumers] + tensor.graph_input_index = self.graph_input_index + tensor.graph_output_index = self.graph_output_index 
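+ # attrs is shallow-copied so the duplicate can be annotated without mutating the original tensor's attributes.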
+ tensor.attrs = self.attrs.copy() + return tensor + + @property + def graph(self) -> "PipelineGraph": + return self._graph + + @property + def name(self) -> str: + return self._trt.name + + @name.setter + def name(self, name: str): + old_name = self._trt.name + if name != old_name: + self._trt.name = name + self.graph._tensors[name] = self + del self.graph._tensors[old_name] + if self.is_graph_input: + self.graph._inputs[name] = self + del self.graph._inputs[old_name] + elif self.is_graph_output: + self.graph._outputs[name] = self + del self.graph._outputs[old_name] + + @property + def shape(self): + return self._shape + + @property + def max_shape(self): + return self._max_shape + + @property + def raw_shape(self): + assert isinstance(self._trt, trt.ITensor) + return self._trt.shape + + @shape.setter + def shape(self, shape): + self._shape = shape + + @max_shape.setter + def max_shape(self, max_shape): + self._max_shape = max_shape + + @raw_shape.setter + def raw_shape(self, raw_shape): + assert isinstance(self._trt, trt.ITensor) + self._trt.shape = raw_shape + + @property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + @property + def dtype(self): + return self._trt.dtype + + @property + def broadcast_across_batch(self): + return self._trt.broadcast_across_batch + + @property + def dtype_size(self): + return self.dtype.itemsize + + @property + def dtype_str(self): + return trt_dtype_to_str(self.dtype) + + @property + def dtype_str_size(self): + return [trt_dtype_to_str(self.dtype), self.dtype.itemsize] + + @property + def is_graph_input(self) -> bool: + return self.graph_input_index != -1 + + @property + def is_graph_output(self) -> bool: + return self.graph_output_index != -1 + + @property + def is_graph_io(self) -> bool: + return self.is_graph_input or self.is_graph_output + + +class Layer: + + def __init__(self, graph): + self._graph = graph + self._trt = None + self._index = None + self._inputs = [] + self._outputs = [] + self._is_shape_io = False + self.attrs = {} + + @staticmethod + def from_trt(graph, trt_layer, index): + layer = Layer(graph) + layer._trt = trt_layer + layer._index = index + for i in range(trt_layer.num_inputs): + input = trt_layer.get_input(i) + if input is not None: + layer._inputs.append(graph.get_tensor(input.name)) + layer._inputs[i].consumers.append((layer, i)) + else: + layer._inputs.append(None) + for i in range(trt_layer.num_outputs): + output = trt_layer.get_output(i) + layer._outputs.append(graph.get_tensor(output.name)) + layer._outputs[i].producer = layer + layer._outputs[i].output_index = i + set_trt_network(trt_layer, graph.as_trt()) + return layer + + def as_trt(self) -> trt.ILayer: + return self._trt + + @property + def graph(self) -> "PipelineGraph": + return self._graph + + @property + def name(self) -> str: + return self._trt.name + + @name.setter + def name(self, name: str): + old_name = self._trt.name + if name != old_name: + self._trt.name = name + self.graph._layers[name] = self + del self.graph._layers[old_name] + + @property + def type(self) -> trt.LayerType: + return self._trt.type + + @property + def index(self) -> int: + return self._index + + @property + def inputs(self) -> List[Tensor]: + return self._inputs + + @property + def outputs(self) -> List[Tensor]: + return self._outputs + + def get_input(self, index: int) -> Tensor: + return self._inputs[index] + + def get_output(self, index: int) -> Tensor: + return self._outputs[index] + + @property + def num_inputs(self) -> 
int: + return self._trt.num_inputs + + @property + def num_outputs(self) -> int: + return self._trt.num_outputs + + @property + def is_shape_io(self) -> bool: + return self._is_shape_io + + def to_subclass(self): + to_subclass_layer(self._trt) + + def to_base_class(self): + to_base_class_layer(self._trt) + + def assign_shapes(self, shapes, values): + for output in self.outputs: + output.shape = shapes[output.name] + output.value = values.get(output.name) + + +@dataclass +class GraphRunner: + session: Session + inputs: Dict[str, torch.Tensor] + outputs: Dict[str, torch.Tensor] + stream: torch.Stream + + def run(self): + cuda_stream = self.stream.cuda_stream + assert self.session.run(self.inputs, self.outputs, cuda_stream) + self.stream.synchronize() + return self.outputs + + +class PipelineGraph: + + def __init__(self): + self._trt = None + self._inputs: Dict[str, Tensor] = {} + self._outputs: Dict[str, Tensor] = {} + self._layers: Dict[str, Layer] = {} + self._tensors: Dict[str, Tensor] = {} + self._io_buffer_mapping = {} + self._unfilled_weights = {} + self._auto_parallel_config = None + + @staticmethod + def create_graph(): + graph = PipelineGraph() + trt_builder = trt.Builder(logger.trt_logger) + explicit_batch_flag = 0 + # Explicit batch flag will be deprecated in TRT 10 + if "EXPLICIT_BATCH" in trt.NetworkDefinitionCreationFlag.__members__.keys( + ): + explicit_batch_flag = 1 << int( + trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + if get_strongly_typed(): + network = trt_builder.create_network( + explicit_batch_flag + | (1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED))) + else: + network = trt_builder.create_network(explicit_batch_flag) + graph._trt = network + return graph + + def _register_unfilled_weights(self, layer_name, weights, values): + self._unfilled_weights[layer_name] = (weights, values) + + def _add_tensor(self, tensor, old_tensor, prefix): + if prefix is not None: + tensor.name = prefix + old_tensor.name + else: + tensor.name = old_tensor.name + tensor.location = old_tensor.location + if old_tensor.dynamic_range is not None: + tensor.dynamic_range = old_tensor.dynamic_range + if tensor.is_network_input: + tensor.shape = old_tensor.shape + for i in range(len(old_tensor.shape)): + name = old_tensor.get_dimension_name(i) + if name is not None: + tensor.set_dimension_name(i, name) + return self._register_tensor(tensor) + + def _register_tensor(self, tensor): + wrapped_tensor = Tensor.from_trt(self, tensor) + assert tensor.name not in self._tensors + self._tensors[tensor.name] = wrapped_tensor + return wrapped_tensor + + def add_input(self, tensor, prefix=None): + tensor_name = tensor.name + if prefix is not None: + tensor_name = prefix + tensor_name + input = self._trt.add_input(tensor_name, tensor.dtype, tensor.shape) + new_tensor = self._add_tensor(input, tensor, prefix) + new_tensor.graph_input_index = len(self._inputs) + self._inputs[tensor_name] = new_tensor + return new_tensor + + def register_input(self, tensor, index=None): + if index is None: + index = self.num_inputs - 1 + assert self._trt.get_input(index).name == tensor.name + wrapped_input = self._register_tensor(tensor) + wrapped_input.graph_input_index = index + self._inputs[tensor.name] = wrapped_input + return wrapped_input + + def add_output(self, tensor, prefix=None): + tensor_name = tensor.name + if prefix is not None: + tensor_name = prefix + tensor_name + output = self.get_tensor(tensor_name) + output.graph_output_index = len(self._outputs) + trt_output = output.as_trt() + 
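# Mark the tensor as a network output and force its dtype to match the original tensor. +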
self._trt.mark_output(trt_output) + trt_output.dtype = tensor.dtype + self._outputs[tensor_name] = output + return output + + def add_output_shape(self, tensor, prefix=None): + tensor_name = tensor.name + if prefix is not None: + tensor_name = prefix + tensor_name + output = self.get_tensor(tensor_name) + trt_output = output.as_trt() + self._trt.mark_output_for_shapes(trt_output) + trt_output.dtype = tensor.dtype + self._outputs[tensor_name] = output + return output + + def add_layer( + self, + layer, + input_mapping=None, + prefix=None, + updated_attrs=None, + ) -> Layer: + + def get_input(i): + name = layer.get_input(i).name + if prefix is not None: + name = prefix + name + if input_mapping is not None and name in input_mapping: + name = input_mapping[name] + return self.get_tensor(name).as_trt() + + network = self._trt + layer_type = layer.type + to_subclass_layer(layer) + if layer_type == trt.LayerType.ACTIVATION: + trt_input = get_input(0) + new_layer = network.add_activation(trt_input, layer.type) + new_layer.alpha = layer.alpha + new_layer.beta = layer.beta + elif layer_type == trt.LayerType.CONCATENATION: + trt_inputs = [get_input(i) for i in range(layer.num_inputs)] + new_layer = network.add_concatenation(trt_inputs) + new_layer.axis = layer.axis + elif layer_type == trt.LayerType.CONSTANT: + new_layer = network.add_constant(layer.shape, layer.weights) + elif layer_type == trt.LayerType.ELEMENTWISE: + new_layer = network.add_elementwise(get_input(0), get_input(1), + layer.op) + elif layer_type == trt.LayerType.FILL: + if layer.num_inputs >= 1 and layer.get_input(0) is not None: + shape_input = get_input(0) + shape = [1] + else: + shape_input = None + shape = layer.shape + new_layer = network.add_fill(shape, layer.operation, layer.to_type) + if shape_input is not None: + new_layer.set_input(0, shape_input) + if layer.num_inputs >= 1 and layer.get_input(0) is not None: + new_layer.set_input(0, get_input(0)) + if layer.num_inputs >= 2 and layer.get_input(1) is not None: + new_layer.set_input(1, get_input(1)) + else: + new_layer.alpha = layer.alpha + if layer.num_inputs >= 3 and layer.get_input(2) is not None: + new_layer.set_input(2, get_input(2)) + else: + new_layer.beta = layer.beta + elif layer_type == trt.LayerType.GATHER: + trt_input = get_input(0) + trt_indices = get_input(1) + new_layer = network.add_gather_v2(trt_input, trt_indices, + layer.mode) + new_layer.axis = layer.axis + new_layer.num_elementwise_dims = layer.num_elementwise_dims + new_layer.mode = layer.mode + elif layer_type == trt.LayerType.MATRIX_MULTIPLY: + new_layer = network.add_matrix_multiply(get_input(0), layer.op0, + get_input(1), layer.op1) + elif layer_type == trt.LayerType.REDUCE: + new_layer = network.add_reduce(get_input(0), layer.op, layer.axes, + layer.keep_dims) + elif layer_type == trt.LayerType.SELECT: + trt_condition = get_input(0) + trt_then = get_input(1) + trt_else = get_input(2) + new_layer = network.add_select(trt_condition, trt_then, trt_else) + elif layer_type == trt.LayerType.SHUFFLE: + new_layer = network.add_shuffle(get_input(0)) + new_layer.first_transpose = layer.first_transpose + new_layer.second_transpose = layer.second_transpose + new_layer.zero_is_placeholder = layer.zero_is_placeholder + if layer.num_inputs >= 2: + trt_reshape_dims_tensor = get_input(1) + new_layer.set_input(1, trt_reshape_dims_tensor) + else: + new_layer.reshape_dims = layer.reshape_dims + elif layer_type == trt.LayerType.SLICE: + if layer.num_inputs >= 2 and layer.get_input(1) is not None: + trt_start = 
get_input(1) + start = [] + else: + trt_start = None + start = layer.start + if layer.num_inputs >= 3 and layer.get_input(2) is not None: + trt_shape = get_input(2) + shape = [] + else: + trt_shape = None + shape = layer.shape + if layer.num_inputs >= 4 and layer.get_input(3) is not None: + trt_stride = get_input(3) + stride = [] + else: + trt_stride = None + stride = layer.stride + new_layer = network.add_slice(get_input(0), start, shape, stride) + new_layer.mode = layer.mode + if trt_start is not None: + new_layer.set_input(1, trt_start) + if trt_shape is not None: + new_layer.set_input(2, trt_shape) + if trt_stride is not None: + new_layer.set_input(3, trt_stride) + elif layer_type == trt.LayerType.SOFTMAX: + new_layer = network.add_softmax(get_input(0)) + new_layer.axes = layer.axes + elif layer_type == trt.LayerType.UNARY: + new_layer = network.add_unary(get_input(0), layer.op) + elif layer_type == trt.LayerType.SHAPE: + new_layer = network.add_shape(get_input(0)) + elif layer_type == trt.LayerType.ASSERTION: + new_layer = network.add_assertion(get_input(0), layer.message) + elif layer_type == trt.LayerType.CAST: + new_layer = network.add_cast(get_input(0), layer.to_type) + elif layer_type == trt.LayerType.NORMALIZATION: + trt_input = get_input(0) + trt_scale = get_input(1) + trt_bias = get_input(2) + new_layer = network.add_normalization(trt_input, trt_scale, + trt_bias, layer.axes) + new_layer.epsilon = layer.epsilon + new_layer.num_groups = layer.num_groups + new_layer.compute_precision = layer.compute_precision + elif layer_type == trt.LayerType.IDENTITY: + new_layer = network.add_identity(get_input(0)) + elif layer_type == trt.LayerType.PLUGIN_V2: + plugin = layer.plugin + updated = False + if (updated_attrs is not None + and updated_attrs.get("plugin") is not None): + plugin = updated_attrs["plugin"] + updated = True + updated_attrs = None + new_layer = network.add_plugin_v2( + [get_input(i) for i in range(layer.num_inputs)], + plugin, + ) + else: + raise NotImplementedError( + "Unsupported layer type: {}".format(layer_type)) + + if updated_attrs is not None: + for attr_name, attr_value in updated_attrs.items(): + setattr(new_layer, attr_name, attr_value) + + to_base_class_layer(layer) + to_base_class_layer(new_layer) + layer_index = network.num_layers - 1 + layer_name = layer.name + if prefix is not None: + layer_name = prefix + layer_name + new_layer.name = layer_name + if layer.precision_is_set: + new_layer.precision = layer.precision + for i in range(layer.num_outputs): + if layer.output_type_is_set(i): + new_layer.set_output_type(i, layer.get_output_type(i)) + output = new_layer.get_output(i) + self._add_tensor(output, layer.get_output(i), prefix) + wrapped_layer = Layer.from_trt(self, new_layer, layer_index) + assert layer_name not in self._layers + self._layers[layer_name] = wrapped_layer + if layer_type == trt.LayerType.PLUGIN_V2: + if not updated: + plugin_info = get_plugin_info(get_trt_network(layer), + layer.name) + set_plugin_info(self.as_trt(), new_layer.name, plugin_info) + return wrapped_layer + + def register_layer(self, layer, index=None): + if index is None: + index = self.num_layers - 1 + assert self._trt.get_layer(index).name == layer.name + to_base_class_layer(layer) + for i in range(layer.num_outputs): + output = layer.get_output(i) + self._register_tensor(output) + wrapped_layer = Layer.from_trt(self, layer, index) + assert layer.name not in self._layers + self._layers[layer.name] = wrapped_layer + to_subclass_layer(layer) + return wrapped_layer + + def 
get_runner( + self, + shapes=None, + values=None, + profile=None, + timing_cache=None, + opt_level=None, + ) -> GraphRunner: + shapes = shapes or {} + values = values or {} + inputs = {} + outputs = {} + for input in self.inputs: + if input is not None: + value = values.get(input.name) + if value is None: + value = input.value + if value is not None: + if not isinstance(value, torch.Tensor): + value = torch.tensor( + value, + dtype=trt_dtype_to_torch(input.dtype), + device='cpu', + ) + inputs[input.name] = value + else: + shape = shapes.get(input.name) + if shape is None: + shape = input.shape + assert shape is not None + inputs[input.name] = torch.empty( + tuple(shape), + dtype=trt_dtype_to_torch(input.dtype), + device=torch.cuda.current_device(), + ) + if torch.is_floating_point(inputs[input.name]): + inputs[input.name].normal_() + # inputs[input.name][:] = random.choice([2, 3, 5, 7]) + for output in self.outputs: + if output.as_trt().is_shape_tensor: + continue + if output.name in self._io_buffer_mapping: + input_name = self._io_buffer_mapping[output.name] + if input_name in inputs: + outputs[output.name] = inputs[input_name] + continue + value = values.get(output.name) + if value is not None and isinstance(value, torch.Tensor): + outputs[output.name] = value + else: + shape = shapes.get(output.name) + if shape is None: + shape = output.shape + assert shape is not None + outputs[output.name] = torch.empty( + tuple(shape), + dtype=trt_dtype_to_torch(output.dtype), + device=torch.cuda.current_device(), + ) + network = self.as_trt() + config = network.builder.create_builder_config() + if opt_level is not None: + config.builder_optimization_level = opt_level + config.flags = get_builder_flags() + profile = profile or network.builder.create_optimization_profile() + profile_index = config.add_optimization_profile(profile) + if timing_cache is not None: + config.set_timing_cache(timing_cache, ignore_mismatch=False) + plan = network.builder.build_serialized_network(network, config) + if plan is None: + logger.error('Engine building failed, please check the error log.') + session = Session.from_serialized_engine(plan) + stream = torch.cuda.current_stream() + cuda_stream = stream.cuda_stream + context = session.context + context.set_optimization_profile_async(profile_index, cuda_stream) + runner = GraphRunner(session, inputs, outputs, stream) + return runner + + def run( + self, + shapes=None, + values=None, + profile=None, + timing_cache=None, + opt_level=None, + ): + return self.get_runner( + shapes, + values, + profile, + timing_cache, + opt_level, + ).run() + + def duplicate_graph(self): + graph = PipelineGraph.create_graph() + network = self.as_trt() + for i in range(network.num_inputs): + input = network.get_input(i) + graph.add_input(input) + sorted_layer_ids = get_sorted_layer_ids(network) + for i in sorted_layer_ids: + layer = network.get_layer(i) + graph.add_layer(layer) + for i in range(network.num_outputs): + output = network.get_output(i) + if output.is_shape_tensor: + graph.add_output_shape(output) + else: + graph.add_output(output) + return graph + + @staticmethod + def from_trt(trt_network): + graph = PipelineGraph() + graph._trt = trt_network + + # construct inputs and tensors + for i in range(trt_network.num_inputs): + trt_input = trt_network.get_input(i) + tensor = Tensor.from_trt(graph, trt_input) + tensor.graph_input_index = i + graph._tensors[tensor.name] = tensor + graph._inputs[tensor.name] = tensor + for i in range(trt_network.num_layers): + trt_layer = 
trt_network.get_layer(i) + for i in range(trt_layer.num_outputs): + trt_output = trt_layer.get_output(i) + tensor = Tensor.from_trt(graph, trt_output) + graph._tensors[tensor.name] = tensor + + # construct layers and outputs + for i in range(trt_network.num_layers): + layer = Layer.from_trt(graph, trt_network.get_layer(i), i) + graph._layers[layer.name] = layer + for i in range(trt_network.num_outputs): + tensor_name = trt_network.get_output(i).name + output_tensor = graph._tensors[tensor_name] + output_tensor.graph_output_index = i + graph._outputs[tensor_name] = output_tensor + + return graph + + def assign_shapes(self, shape_info=None, is_partial=False): + if shape_info is None: + for tensor in self.tensors: + tensor.shape = tensor.raw_shape + return + for tensor in self.tensors: + if tensor.name in shape_info.shapes: + tensor.shape = shape_info.shapes[tensor.name] + elif not is_partial: + raise ValueError(f"Cannot find shape for tensor: {tensor.name}") + if shape_info.max_shapes is not None: + if tensor.name in shape_info.max_shapes: + tensor.max_shape = shape_info.max_shapes[tensor.name] + elif not is_partial: + raise ValueError( + f"Cannot find max shape for tensor: {tensor.name}") + if tensor.name in shape_info.values: + tensor.value = shape_info.values[tensor.name] + for layer in self.layers: + if layer.name in shape_info.shape_layers: + layer._is_shape_io = True + + def infer_shapes(self, profile=None): + from .shape_info import get_shape_info + + shape_info = get_shape_info(self._trt, profile) + self.assign_shapes(shape_info) + + def as_trt(self) -> trt.INetworkDefinition: + return self._trt + + def get_input(self, name: str) -> Tensor: + return self._inputs.get(name) + + def is_input(self, name: str) -> bool: + return name in self._inputs + + @property + def inputs(self) -> List[Tensor]: + return [*self._inputs.values()] + + @property + def num_inputs(self) -> int: + return self._trt.num_inputs + + def get_output(self, name: str) -> Tensor: + return self._outputs.get(name) + + def is_output(self, name: str) -> bool: + return name in self._outputs + + @property + def outputs(self) -> List[Tensor]: + return [*self._outputs.values()] + + @property + def num_outputs(self) -> int: + return self._trt.num_outputs + + def get_tensor(self, name: str) -> Tensor: + return self._tensors.get(name) + + @property + def tensors(self) -> List[Tensor]: + return [*self._tensors.values()] + + def get_layer(self, name: str) -> Layer: + return self._layers.get(name) + + @property + def layers(self) -> List[Layer]: + return [*self._layers.values()] + + @property + def sorted_layers(self) -> List[Layer]: + sorted_layer_ids = get_sorted_layer_ids(self.as_trt()) + return [ + self.get_layer(self.as_trt().get_layer(layer_id).name) + for layer_id in sorted_layer_ids + ] + + @property + def num_layers(self) -> int: + return self._trt.num_layers + + def to_dot(self, + path=None, + per_device=False, + per_block=False, + ignore_shape_io=False, + no_style=False, + extra_attrs=None) -> Optional[str]: + ''' + Get a graphviz representation of the graph. 
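+
+        Example (illustrative sketch, not part of the original change; it assumes an
+        already-built ``trt.INetworkDefinition`` named ``trt_network`` and that the
+        graphviz package is installed), mirroring the ``trt_to_dot`` helper below:
+
+            graph = PipelineGraph.from_trt(trt_network)
+            graph.assign_shapes()
+            dot_source = graph.to_dot(no_style=True)   # returns the graphviz source
+            graph.to_dot(path='pipeline_graph.dot')    # or save it to a file instead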
+ + Parameters: + path: the path to save the graphviz file, if not provided, will return the graphviz source code + ''' + try: + import graphviz + except ImportError: + logger.error( + "Failed to import graphviz, please install graphviz to enable PipelineGraph.to_dot()" + ) + return + + extra_attrs = extra_attrs or [] + + graph = graphviz.Digraph() + input_block_graph = graphviz.Digraph(name='cluster_inputs') + output_block_graph = graphviz.Digraph(name='cluster_outputs') + device_graphs = {} + block_graphs = {} + block_graph_mapping = [] + tensor_names = set() + layer_names = set() + + common_style = dict(fontname='Arial', ) + node_style = dict( + **common_style, + style='rounded,filled,bold', + ) + tensor_style = dict( + **node_style, + shape='ellipse', + fillcolor='white', + ) + input_tensor_style = {**tensor_style, 'fillcolor': 'green'} + output_tensor_style = {**tensor_style, 'fillcolor': 'lightgreen'} + layer_style = dict( + **node_style, + shape='box', + fillcolor='white', + ) + shape_layer_style = {**layer_style, 'fillcolor': 'grey'} + helper_layer_style = {**layer_style, 'fillcolor': 'lightgrey'} + graph_style = dict( + **common_style, + style='rounded', + penwidth='5', + fontsize='28', + ) + device_graph_style = dict( + **graph_style, + color='cornflowerblue', + ) + block_graph_style = dict( + **graph_style, + color='darkcyan', + ) + input_block_style = dict( + **graph_style, + color='green', + ) + output_block_style = dict( + **graph_style, + color='lightgreen', + ) + if no_style: + device_graph_style = {} + block_graph_style = {} + input_block_style = {} + output_block_style = {} + input_block_graph.attr(label='inputs', **input_block_style) + output_block_graph.attr(label='outputs', **output_block_style) + + def get_tensor_labels(tensor): + labels = [] + if tensor.value is not None: + labels.append(f"value={tensor.value}") + else: + labels.append(f"dtype={tensor.dtype.name}{tensor.shape}") + for attr_name in extra_attrs: + if attr_name in tensor.attrs: + labels.append(f"{attr_name}={tensor.attrs[attr_name]}") + return labels + + def get_device_graph(name): + if per_device and name.startswith('device'): + device_name = name.split('_')[0] + if device_name not in device_graphs: + device_graph = graphviz.Digraph(name='cluster_' + + device_name) + device_graph.attr(label=device_name, **device_graph_style) + device_graphs[device_name] = device_graph + return device_graphs[device_name] + return None + + def get_block_graph(layer, current_graph): + if per_block and 'block_id' in layer.attrs: + block_label = f"block{layer.attrs['block_id']}" + if current_graph.name is not None: + graph_label = current_graph.name[len('cluster_'):] + else: + graph_label = '' + block_name = f"{graph_label}{block_label}" + if block_name not in block_graphs: + block_graph = graphviz.Digraph(name='cluster_' + block_name) + block_graph.attr(label=block_label, **block_graph_style) + block_graphs[block_name] = block_graph + block_graph_mapping.append((current_graph, block_graph)) + return block_graphs[block_name] + return current_graph + + for name, tensor in self._tensors.items(): + style = tensor_style + if tensor.is_graph_input: + style = input_tensor_style + current_graph = input_block_graph + elif tensor.is_graph_output: + style = output_tensor_style + current_graph = output_block_graph + elif tensor.producer.num_outputs == 1: + continue + else: + current_graph = get_device_graph(name) or graph + current_graph = get_block_graph(tensor.producer, current_graph) + if no_style: + style = {} + labels = [name, 
*get_tensor_labels(tensor)] + content = "\n".join(labels) + current_graph.node(name, content, **style) + tensor_names.add(name) + + for layer in self.sorted_layers: + name = layer.name + + style = layer_style + if layer.is_shape_io: + if ignore_shape_io: + continue + style = shape_layer_style + elif layer.attrs.get("role", None) == "helper": + style = helper_layer_style + fillcolor = None + plugin_type = None + if layer.type == trt.LayerType.PLUGIN_V2: + fillcolor = 'yellow' + layer.to_subclass() + plugin_type = layer.as_trt().plugin.plugin_type + layer.to_base_class() + if layer.type == trt.LayerType.MATRIX_MULTIPLY or plugin_type == 'Gemm': + fillcolor = 'orange' + if fillcolor is not None: + style = {**style, 'fillcolor': fillcolor} + if no_style: + style = {} + + layer_attrs = {} + layer_type = layer.type + layer.to_subclass() + if layer_type == trt.LayerType.CONSTANT: + if not layer.is_shape_io: + if trt.volume(layer.get_output(0).shape) <= 8: + weights = layer.as_trt().weights + if isinstance(weights, trt.Weights): + weights = weights.numpy() + value = np.array2string( + weights, + formatter={'float_kind': lambda x: f"{x:.2e}"}) + layer_attrs['value'] = value + elif layer_type == trt.LayerType.SHUFFLE: + for attr_name in ['first_transpose', 'second_transpose']: + attr_value = getattr(layer.as_trt(), attr_name) + if tuple(attr_value) != (0, 1, 2, 3, 4, 5, 6, 7): + tensor = layer.get_input( + 0 + ) if attr_name == 'first_transpose' else layer.get_output( + 0) + layer_attrs[attr_name] = tuple( + attr_value)[:len(tensor.shape)] + if layer.num_inputs < 2: + attr_value = layer.as_trt().reshape_dims + layer_attrs['reshape_dims'] = attr_value + elif layer_type == trt.LayerType.SLICE: + if layer.num_inputs < 2 or layer.get_input(1) is None: + layer_attrs['start'] = layer.as_trt().start + if layer.num_inputs < 4 or layer.get_input(3) is None: + attr_value = layer.as_trt().stride + if attr_value != tuple( + [1] * len(layer.get_output(0).shape)): + layer_attrs['stride'] = attr_value + layer.to_base_class() + + if layer.is_shape_io: + labels = [layer.type.name] + else: + labels = [name, layer.type.name] + for key, value in layer_attrs.items(): + labels.append(f"{key}={value}") + for attr_name in extra_attrs: + if attr_name in layer.attrs: + labels.append(f"{attr_name}={layer.attrs[attr_name]}") + if layer.num_outputs == 1: + output = layer.get_output(0) + if output.name != f'{layer.name}_output_0': + labels.append(f"output={output.name}") + labels.extend(get_tensor_labels(output)) + content = "\n".join(labels) + + current_graph = get_device_graph(name) or graph + current_graph = get_block_graph(layer, current_graph) + current_graph.node(name, content, **style) + layer_names.add(name) + + for index, input in enumerate(layer.inputs): + if input is not None: + if input.is_graph_input or input.producer.num_outputs > 1: + if input.name in tensor_names: + graph.edge(input.name, name, str(index)) + else: + if input.producer.name in layer_names: + graph.edge(input.producer.name, name, str(index)) + if layer.num_outputs > 1 or (layer.num_outputs == 1 and + layer.get_output(0).is_graph_output): + for index, output in enumerate(layer.outputs): + graph.edge(name, output.name, str(index)) + + graph.subgraph(input_block_graph) + graph.subgraph(output_block_graph) + for parent_graph, block_graph in block_graph_mapping: + parent_graph.subgraph(block_graph) + for device_graph in device_graphs.values(): + graph.subgraph(device_graph) + + if not path: + return graph.source + graph.save(path) + + @staticmethod + def 
trt_to_dot(trt_network, path=None): + graph = PipelineGraph.from_trt(trt_network) + graph.assign_shapes() + dot = graph.to_dot(no_style=True) + if path is not None: + with open(path, "w") as f: + f.write(dot) + else: + return dot diff --git a/tensorrt_llm/auto_parallel/runtime_profiling.py b/tensorrt_llm/auto_parallel/runtime_profiling.py new file mode 100644 index 000000000..8f6c8d9cc --- /dev/null +++ b/tensorrt_llm/auto_parallel/runtime_profiling.py @@ -0,0 +1,150 @@ +import numpy as np +import tensorrt as trt +import torch + +from tensorrt_llm.logger import logger +from tensorrt_llm.network import get_plugin_info + +from .shape_info import get_per_layer_graph +from .utils import get_cache_key, get_trt_network, get_updated_plugin + + +class NvtxProfiler(object): + + def __init__(self, nvtx_name, enable=True): + self.nvtx_name = nvtx_name + self.enable = enable + + def __enter__(self): + if self.enable: + torch.cuda.nvtx.range_push(self.nvtx_name) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + torch.cuda.nvtx.range_pop() + + +class LayerProfiler(trt.IProfiler): + + def __init__(self): + trt.IProfiler.__init__(self) + self.layer_count = 0 + self.time = 0 + + def report_layer_time(self, layer_name, ms): + logger.debug(f'{layer_name=}, {self.layer_count=}, time = {ms} ms') + self.time += ms + self.layer_count += 1 + + +class RuntimeProfiler(object): + + def __init__(self): + self.timing_cache = None + + def _profile(self, layer, layer_attrs, shapes, values, io_buffer_mapping): + is_plugin = layer.type == trt.LayerType.PLUGIN_V2 + if is_plugin and len(layer_attrs) > 0: + plugin_info = get_plugin_info( + get_trt_network(layer), + layer.name, + ) + new_plugin, _ = get_updated_plugin(plugin_info, layer_attrs) + layer_attrs = {"plugin": new_plugin} + graph, output_mapping = get_per_layer_graph(layer, shapes, values, + layer_attrs) + graph._io_buffer_mapping = io_buffer_mapping + network = graph.as_trt() + if network.num_outputs > 0 and np.all([ + network.get_output(i).is_shape_tensor + for i in range(network.num_outputs) + ]): + return 0.0 + for proxy_output, output in output_mapping.items(): + shapes[proxy_output] = shapes[output] + if not self.timing_cache: + self.timing_cache = network.builder.create_builder_config( + ).create_timing_cache(b"") + runner = graph.get_runner( + shapes, + values, + timing_cache=self.timing_cache, + ) + context = runner.session.context + context.profiler = LayerProfiler() + runner.run() + profiler_time_first_run = context.profiler.time + runner.run() + return (context.profiler.time - profiler_time_first_run) * 1000.0 + + def runtime_profile(self, layer, layer_attrs, input_values, strategy, + device_mesh): + logger.debug(f"start to profile layer {layer.name}") + shapes = {} + values = {} + dtypes = {} + trt_layer = layer.as_trt() + + sharding_sequences = () + for i in range(layer.num_inputs): + input = trt_layer.get_input(i) + if input is not None: + shapes[input.name] = strategy.sharding_specs[ + f'input{i}'].get_sharded_shape_per_device() + dtypes[input.name] = input.dtype + sharding_sequences += (str( + strategy.sharding_specs[f"input{i}"].sharding_sequence), ) + if i in input_values: + values[input.name] = input_values[i] + else: + value = layer.get_input(i).value + if value is not None: + values[input.name] = value + else: + sharding_sequences += (None, ) + + for i in range(layer.num_outputs): + output = trt_layer.get_output(i) + if f'output{i}' in strategy.communication_actions: + shapes[output.name] = strategy.communication_actions[ + 
f'output{i}'].sharding_spec.get_sharded_shape_per_device() + else: + shapes[output.name] = strategy.sharding_specs[ + f'output{i}'].get_sharded_shape_per_device() + dtypes[output.name] = output.dtype + sharding_sequences += (str( + strategy.sharding_specs[f"output{i}"].sharding_sequence), ) + data_key = get_cache_key( + trt_layer, + shapes, + values, + dtypes=dtypes, + updated_attrs=layer_attrs, + ) + data_key += (sharding_sequences, ) + elapsed_time = device_mesh.prof_database.query( + device_mesh.cluster_key, + data_key, + ) + if elapsed_time: + logger.debug( + f'runtime profiling cache hit {data_key}: {elapsed_time} us') + return elapsed_time + with NvtxProfiler(f'{layer.name}_{data_key}', enable=True): + elapsed_time = self._profile( + layer.as_trt(), + layer_attrs, + shapes, + values, + layer.graph._io_buffer_mapping, + ) + logger.debug( + f'runtime profiling cache miss {data_key}: {elapsed_time} us') + + device_mesh.prof_database.update( + device_mesh.cluster_key, + data_key, + (elapsed_time, strategy.alpha_beta_cost), + ) + + return elapsed_time diff --git a/tensorrt_llm/auto_parallel/shape_info.py b/tensorrt_llm/auto_parallel/shape_info.py new file mode 100644 index 000000000..1034d422b --- /dev/null +++ b/tensorrt_llm/auto_parallel/shape_info.py @@ -0,0 +1,337 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Set + +import numpy as np +import tensorrt as trt +import torch + +from tensorrt_llm._utils import trt_dtype_to_np, trt_dtype_to_str +from tensorrt_llm.logger import logger + +from .pipeline_graph import PipelineGraph +from .utils import (get_builder_flags, get_cache_key, get_sorted_layer_ids, + set_trt_network, to_base_class_layer, to_subclass_layer, + to_trt_weights) + + +class ShapeType(Enum): + MIN = 0 + OPT = 1 + MAX = 2 + + +def get_shape_layers(trt_network): + shape_layers = set() + for i in range(trt_network.num_layers): + layer = trt_network.get_layer(i) + if (layer.num_inputs > 0 and np.all([ + layer.get_input(j).is_shape_tensor + for j in range(layer.num_inputs) + if layer.get_input(j) is not None + ])) or (layer.num_outputs > 0 and np.all([ + layer.get_output(j).is_shape_tensor + for j in range(layer.num_outputs) + ])): + shape_layers.add(layer.name) + return shape_layers + + +def get_layers_in_shape_network(trt_network, shape_layers, sorted_layer_ids): + layers = set() + shape_tensors = set() + for layer_id in reversed(sorted_layer_ids): + layer = trt_network.get_layer(layer_id) + in_shape_network = False + if layer.name in shape_layers: + in_shape_network = True + else: + for j in range(layer.num_outputs): + output = layer.get_output(j) + if output.name in shape_tensors: + in_shape_network = True + break + if in_shape_network: + layers.add(layer.name) + for j in range(layer.num_inputs): + input = layer.get_input(j) + if input is not None: + shape_tensors.add(input.name) + return layers + + +def get_shape_network(trt_network, + shapes, + values, + sorted_layer_ids, + profile=None, + shape_type: ShapeType = ShapeType.OPT): + shape_layers = get_shape_layers(trt_network) + layers_in_shape_network = get_layers_in_shape_network( + trt_network, shape_layers, sorted_layer_ids) + shape_graph = PipelineGraph.create_graph() + shape_network = shape_graph.as_trt() + shape_builder = shape_network.builder + shape_profile = shape_builder.create_optimization_profile() + for i in range(trt_network.num_inputs): + input = trt_network.get_input(i) + shapes[input.name] = input.shape + new_input = shape_graph.add_input(input) + if profile is 
not None: + if -1 in input.shape: + shape = profile.get_shape(input.name) + shape = shape[shape_type.value] + shapes[input.name] = shape + new_input.raw_shape = shape + if input.is_shape_tensor: + shape_values = profile.get_shape_input(input.name) + value = shape_values[shape_type.value] + values[input.name] = value + shape_profile.set_shape_input(input.name, value, value, value) + output_mapping = {} + for layer_id in sorted_layer_ids: + layer = trt_network.get_layer(layer_id) + if layer.name in shape_layers: + new_layer = shape_graph.add_layer(layer) + for i in range(layer.num_outputs): + output = layer.get_output(i) + if output.dtype != trt.DataType.BOOL: + shape_graph.add_output_shape(output) + else: + proxy_layer = shape_network.add_identity( + new_layer.as_trt().get_output(i)) + proxy_output = proxy_layer.get_output(0) + proxy_output.dtype = trt.DataType.INT32 + shape_graph.register_layer(proxy_layer) + shape_graph.add_output_shape(proxy_output) + output_mapping[proxy_output.name] = output.name + elif layer.name in layers_in_shape_network: + if layer.type == trt.LayerType.CONSTANT: + shape_graph.add_input(layer.get_output(0)) + else: + shape_graph.add_layer(layer) + return shape_network, shape_profile, shape_layers, output_mapping + + +def get_per_layer_graph( + layer, + shapes, + values, + updated_attrs=None, + is_shape_io: bool = None, +): + graph = PipelineGraph.create_graph() + network = graph.as_trt() + is_shape_layer = layer.num_inputs != 0 + for i in range(layer.num_inputs): + input = layer.get_input(i) + if input is not None: + shape = shapes[input.name] + if (values.get(input.name) is not None + and not isinstance(values[input.name], torch.Tensor)): + value = values[input.name] + weights = np.asarray(value, dtype=trt_dtype_to_np(input.dtype)) + weights = to_trt_weights(weights) + input_layer = network.add_constant(shape, weights) + new_input = input_layer.get_output(0) + new_input.name = input.name + graph.register_layer(input_layer) + elif graph.get_input(input.name) is None: + new_input = graph.add_input(input) + new_input.raw_shape = shapes[input.name] + is_shape_layer = False + new_layer = graph.add_layer( + layer, + updated_attrs=updated_attrs, + ) + output_mapping = {} + if layer.type == trt.LayerType.SHAPE: + is_shape_layer = True + if layer.num_inputs == 0: + is_shape_layer = False + if is_shape_io is not None: + is_shape_layer = is_shape_io + for i in range(layer.num_outputs): + output = layer.get_output(i) + value = values.get(output.name) + if value is not None and isinstance(value, torch.Tensor): + is_output_shape = False + elif is_shape_layer: + is_output_shape = True + else: + is_output_shape = False + if is_output_shape: + if output.dtype == trt.DataType.BOOL: + proxy_layer = network.add_cast( + new_layer.as_trt().get_output(i), + trt.DataType.INT32, + ) + proxy_output = proxy_layer.get_output(0) + graph.register_layer(proxy_layer) + output_mapping[proxy_output.name] = output.name + output = proxy_output + graph.add_output_shape(output) + else: + graph.add_output(output) + return graph, output_mapping + + +def infer_shapes(network, shapes, values, profile=None): + if network.num_outputs == 0: + return + builder = network.builder + config = builder.create_builder_config() + config.builder_optimization_level = 0 + config.flags = get_builder_flags() + profile = profile or builder.create_optimization_profile() + config.add_optimization_profile(profile) + plan = builder.build_serialized_network(network, config) + if plan is None: + raise RuntimeError( + 'Engine 
building failed when inferring shapes, please check the error log.' + ) + runtime = trt.Runtime(logger.trt_logger) + engine = runtime.deserialize_cuda_engine(plan) + context = engine.create_execution_context() + for i in range(network.num_inputs): + input = network.get_input(i) + if input.is_shape_tensor: + value = values[input.name] + context.set_shape_input(engine[input.name], value) + context.infer_shapes() + assert context.all_binding_shapes_specified + for i in range(network.num_outputs): + output = network.get_output(i) + shape = context.get_tensor_shape(output.name) + # if len(shape) == 0: + # shape = trt.Dims([1]) + shapes[output.name] = shape + if output.is_shape_tensor: + if shape == [0]: + values[output.name] = [] + else: + values[output.name] = context.get_shape(engine[output.name]) + + +@dataclass +class ShapeInfo: + shapes: Dict[str, trt.Dims] + values: Dict[str, List[int]] + shape_layers: Set[str] + max_shapes: Dict[str, trt.Dims] = None + + +def set_constant_value(layer, values): + to_subclass_layer(layer) + output_name = layer.get_output(0).name + weights = layer.weights + if isinstance(weights, trt.Weights): + weights = weights.numpy() + values[output_name] = list(weights) + to_base_class_layer(layer) + + +def infer_per_layer_shapes( + layer: trt.ILayer, + shapes, + values, + cache=None, + is_shape_io=False, +): + if layer.type == trt.LayerType.CONSTANT: + to_subclass_layer(layer) + output_name = layer.get_output(0).name + shape = layer.shape + shapes[output_name] = shape + if is_shape_io: + set_constant_value(layer, values) + to_base_class_layer(layer) + return + elif layer.type == trt.LayerType.SHAPE: + input_name = layer.get_input(0).name + output_name = layer.get_output(0).name + shape = [*shapes[input_name]] + shapes[output_name] = trt.Dims([len(shape)]) + values[output_name] = shape + return + if cache is not None: + cache_key = get_cache_key(layer, shapes, values) + if cache_key in cache: + output_shapes, output_values = cache[cache_key] + for i in range(layer.num_outputs): + output = layer.get_output(i) + shapes[output.name] = output_shapes[i] + if output_values[i] is not None: + values[output.name] = output_values[i] + return + logger.debug(f"infer shapes for layer {layer.name}") + graph, output_mapping = get_per_layer_graph(layer, shapes, values) + try: + infer_shapes(graph.as_trt(), shapes, values) + except RuntimeError as e: + dtypes = [ + trt_dtype_to_str(layer.get_input(i).dtype) + for i in range(layer.num_inputs) + ] + layer_info = (f"type={cache_key[0]}, " + f"attrs={dict(cache_key[1])}, " + f"dtypes={dtypes}, " + f"shapes={list(cache_key[2])}, " + f"values={list(cache_key[3])}") + raise RuntimeError( + f"infer shapes failed for layer {layer.name} ({layer_info})") from e + for proxy_output, output in output_mapping.items(): + shapes[output] = shapes[proxy_output] + del shapes[proxy_output] + if proxy_output in values: + values[output] = [*map(bool, values[proxy_output])] + del values[proxy_output] + if cache is not None: + logger.debug( + f"shape inference cache miss, layer: {layer.name}, cache key: {cache_key}" + ) + output_shapes = [] + output_values = [] + for i in range(layer.num_outputs): + output = layer.get_output(i) + output_shapes.append(shapes[output.name]) + output_values.append(values.get(output.name)) + cache[cache_key] = (output_shapes, output_values) + + +def get_shape_info(trt_network, profile, shape_type: ShapeType = ShapeType.OPT): + shapes = {} + values = {} + sorted_layer_ids = get_sorted_layer_ids(trt_network) + infer_shape_layers = 
False + + shape_network, shape_profile, shape_layers, output_mapping = get_shape_network( + trt_network, + shapes, + values, + sorted_layer_ids, + profile=profile, + shape_type=shape_type) + try: + infer_shapes(shape_network, shapes, values, shape_profile) + for proxy_output, output in output_mapping.items(): + shapes[output] = shapes[proxy_output] + values[output] = [*map(bool, values[proxy_output])] + del shapes[proxy_output] + del values[proxy_output] + except RuntimeError: + infer_shape_layers = True + + cache = {} + for layer_id in sorted_layer_ids: + layer = trt_network.get_layer(layer_id) + is_shape_io = layer.name in shape_layers + if is_shape_io and not infer_shape_layers: + continue + set_trt_network(layer, trt_network) + infer_per_layer_shapes(layer, + shapes, + values, + cache, + is_shape_io=is_shape_io) + return ShapeInfo(shapes, values, shape_layers) diff --git a/tensorrt_llm/auto_parallel/simplifier.py b/tensorrt_llm/auto_parallel/simplifier.py new file mode 100644 index 000000000..2aa308146 --- /dev/null +++ b/tensorrt_llm/auto_parallel/simplifier.py @@ -0,0 +1,835 @@ +import math +import re +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Tuple + +import numpy as np + +from tensorrt_llm.network import Network + +from .config import AutoParallelConfig +from .device_mesh import PhysicalDeviceMesh +from .pipeline_graph import PipelineGraph +from .shape_info import ShapeInfo, ShapeType, get_shape_info +from .tensor_parallel.p2p_node import P2PType +from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger + + +class StageType(Enum): + START = 0 + BLOCK = 1 + END = 2 + + +class BuildingBlock: + + def __init__(self, graph, layer_range) -> None: + self.graph = graph + self.layer_range = layer_range + self.network = graph.as_trt() + self.owned_inputs = {} + self.is_edges_collected = False + self.intra_edges = [] + self.src_inter_edges = [] + self.dst_inter_edges = [] + self.relative_src_inter_edges = [] + self.relative_dst_inter_edges = [] + self.relative_inter_edges = set() + self.edge_hash = None + self.outputs = None + self.type_id = -1 + self.block_id = -1 + self.p2p_type = None + self.is_superset = False + self.is_subset = False + self.sorted_layer_ids = [] + + def collect_edges(self): + if self.is_edges_collected: + return + for layer_index in self.layer_range: + trt_layer = self.network.get_layer(layer_index) + layer = self.graph.get_layer(trt_layer.name) + layer_offset = layer.index - self.layer_range.start + for input_index, input in enumerate(layer.inputs): + if input is not None: + if input.is_graph_input: + is_owned = input.graph_input_index in self.owned_inputs + if not is_owned and np.all([ + layer.index in self.layer_range or np.all([ + output.as_trt().is_shape_tensor + for output in layer.outputs + ]) for layer, _ in input.consumers + ]): + self.owned_inputs[input.graph_input_index] = len( + self.owned_inputs) + is_owned = True + if is_owned: + self.intra_edges.append( + (-1, self.owned_inputs[input.graph_input_index], + layer_offset, input_index)) + else: + self.dst_inter_edges.append( + (-1, input.graph_input_index, layer_offset, + input_index)) + else: + src_layer_index = input.producer.index + if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop: + self.dst_inter_edges.append( + (src_layer_index, input.output_index, + layer_offset, input_index)) + else: + src_layer_offset = src_layer_index - self.layer_range.start + self.intra_edges.append( + (src_layer_offset, 
input.output_index, + layer_offset, input_index)) + for output_index, output in enumerate(layer.outputs): + for dst_layer, dst_input_index in output.consumers: + dst_layer_index = dst_layer.index + if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop: + self.src_inter_edges.append( + (layer_offset, output_index, dst_layer_index, + dst_input_index)) + self.edge_hash = tuple(self.intra_edges) + self.outputs = sorted( + set((edge[0], edge[1]) for edge in self.src_inter_edges)) + self.is_edges_collected = True + + def collect_relative_inter_edges(self, layer_to_block): + self.collect_edges() + for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges: + if src_layer_index in layer_to_block: + src_block = layer_to_block[src_layer_index] + src_layer_offset = src_layer_index - src_block.layer_range.start + dst = (self.type_id, dst_layer_index, dst_input_index) + self.relative_dst_inter_edges.append( + (src_block.type_id, src_layer_offset, src_output_index, + *dst)) + else: + self.relative_dst_inter_edges.append( + (-1, src_layer_index, src_output_index, self.type_id, + dst_layer_index, dst_input_index)) + self.relative_inter_edges = set(self.relative_dst_inter_edges + + self.outputs) + + def get_input_names(self): + self.collect_edges() + input_tensor_names = [] + for edge in self.dst_inter_edges: + layer_index = edge[0] + output_index = edge[1] + if layer_index == -1: + tensor_name = self.network.get_input(output_index).name + else: + tensor_name = self.network.get_layer(layer_index).get_output( + output_index).name + input_tensor_names.append(tensor_name) + return input_tensor_names + + def get_input_mapping(self, last_blocks): + input_mapping = {} + for tensor_name, relative_edge in zip(self.get_input_names(), + self.relative_dst_inter_edges): + type_id = relative_edge[0] + output_index = relative_edge[2] + if type_id >= 0: + last_block = last_blocks[type_id] + layer_offset = relative_edge[1] + mapped_layer_index = last_block.layer_range.start + layer_offset + mapped_tensor_name = self.network.get_layer( + mapped_layer_index).get_output(output_index).name + input_mapping[tensor_name] = mapped_tensor_name + else: + input_mapping[tensor_name] = tensor_name + return input_mapping + + +@dataclass +class GraphMapping: + layer_mapping: Dict[int, int] = None + block_mapping: Dict[int, int] = None + p2p_types: Dict[int, P2PType] = None + p2p_tensors: Dict[int, List[str]] = None + block_to_stage: Dict[int, int] = None + same_spec_layer_mapping: Dict[str, str] = None + + +@dataclass +class GraphConfig: + num_micro_batches: int = 1 + num_blocks: int = 1 + num_stages: int = 1 + has_cross_device: bool = False + has_cross_host: bool = False + graph_mapping: GraphMapping = None + phy_mesh: PhysicalDeviceMesh = None + stage_phy_meshes: List[PhysicalDeviceMesh] = None + + +class Simplifier: + + def __init__(self, network: Network, config: AutoParallelConfig): + self.config = config + self.sharded_io_allowlist = config.sharded_io_allowlist + self.same_buffer_io = config.same_buffer_io + self.same_spec_io = config.same_spec_io.copy() + for key, value in self.same_buffer_io.items(): + if key not in self.same_spec_io: + self.same_spec_io[key] = value + + self.llm_network = network + self.network = network.trt_network + self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map + self.graph = self.get_graph() + self.init_layer_hash() + + module_tree = self.get_module_tree() + building_blocks = 
self.collect_building_blocks(module_tree) + blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks) + self.blocks_by_edge_hash = self.get_blocks_by_edge_hash( + blocks_by_module_hash) + self.layer_to_block = self.get_layer_to_block() + self.blocks = self.get_all_blocks() + self.backbone_blocks = self.get_backbone_blocks() + self.graph_mapping_for_shape = self.get_graph_mapping_for_shape() + self.graph_for_shape = self.create_simplified_graph_for_shape() + self.shape_info = None + self.num_micro_batches = None + + def infer_shapes(self, num_micro_batches): + if self.num_micro_batches == num_micro_batches: + return + with silent_trt_logger(): + self.shape_info = self.get_full_shape_info(num_micro_batches) + self.graph.assign_shapes(self.shape_info) + self.num_micro_batches = num_micro_batches + + def list_all_num_micro_batches(self): + opt_batch_size = self.get_opt_batch_size() + candidates = [] + for num_micro_batches in range(1, self.get_opt_batch_size() + 1): + if opt_batch_size % num_micro_batches == 0: + candidates.append(num_micro_batches) + return candidates + + def get_graph(self): + graph = PipelineGraph.from_trt(self.network) + graph._unfilled_weights = self.llm_network._unfilled_weights.copy() + graph._io_buffer_mapping + for input in graph.inputs: + input_name = input.name + for pattern, repl in self.same_buffer_io.items(): + if re.match(pattern, input_name): + output_name = re.sub(pattern, repl, input_name) + output = graph.get_output(output_name) + if output is not None: + graph._io_buffer_mapping[output_name] = input_name + return graph + + def get_opt_batch_size(self): + input_tensors = self.llm_network._inputs + num_profiles = len(list(input_tensors.values())[0].profiles) + opt_batch_sizes = [] + for i in range(num_profiles): + for input_tensor in input_tensors.values(): + shape_profile = input_tensor.profiles[i] + opt_shape = shape_profile.opt + for j in range(len(input_tensor.shape)): + name = input_tensor.trt_tensor.get_dimension_name(j) + if name == 'batch_size': + opt_batch_sizes.append(opt_shape[j]) + return min(opt_batch_sizes) + + def get_module_hash(self, layer_range): + module_hash = () + for i in layer_range: + assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}" + layer_name = self.network.get_layer(i).name + layer = self.graph.get_layer(layer_name) + module_hash += (layer.attrs["hash"], ) + return module_hash + + def get_network_hash(self) -> str: + return str(self.get_module_hash(range(self.network.num_layers))) + + def collect_building_blocks(self, module_tree): + building_blocks = {} + queue = [] + for tree in module_tree["children"].values(): + queue.append(tree) + while len(queue) > 0: + while len(queue) > 0: + tree = queue.pop(0) + module_name = tree["name"] + if module_name is None: + for child in tree["children"].values(): + queue.append(child) + continue + layer_range = self.module_to_layer_range_map[module_name] + module_hash = self.get_module_hash(layer_range) + if module_hash in building_blocks: + building_blocks[module_hash].append(tree) + else: + building_blocks[module_hash] = [tree] + for module_hash in [*building_blocks.keys()]: + if len(building_blocks[module_hash]) == 1: + tree = building_blocks[module_hash][0] + for child in tree["children"].values(): + queue.append(child) + del building_blocks[module_hash] + blocks_by_module_hash = { + module_hash: [ + BuildingBlock(self.graph, + self.module_to_layer_range_map[tree["name"]]) + for tree in trees + ] + for module_hash, 
trees in building_blocks.items() + } + building_blocks = [] + for block_list in blocks_by_module_hash.values(): + for block in block_list: + building_blocks.append(block) + building_blocks = sorted(building_blocks, + key=lambda x: x.layer_range.start) + if len(building_blocks) >= 2: + for block, next_block in zip(building_blocks[:-1], + building_blocks[1:]): + block.layer_range = range(block.layer_range.start, + next_block.layer_range.start) + return building_blocks + + def get_all_blocks(self): + building_blocks = [] + for block_list in self.blocks_by_edge_hash.values(): + for block in block_list: + building_blocks.append(block) + building_blocks = sorted(building_blocks, + key=lambda x: x.layer_range.start) + all_blocks = [] + current_layer_index = 0 + block_id = 0 + for block in building_blocks: + assert current_layer_index <= block.layer_range.start + if current_layer_index < block.layer_range.start: + new_block = BuildingBlock( + self.graph, + range(current_layer_index, block.layer_range.start)) + new_block.block_id = block_id + block_id += 1 + all_blocks.append(new_block) + block.block_id = block_id + block_id += 1 + all_blocks.append(block) + current_layer_index = block.layer_range.stop + if current_layer_index < self.graph.num_layers: + new_block = BuildingBlock( + self.graph, range(current_layer_index, self.graph.num_layers)) + new_block.block_id = block_id + all_blocks.append(new_block) + sorted_layer_ids = get_sorted_layer_ids(self.network) + for block in all_blocks: + block.collect_relative_inter_edges(self.layer_to_block) + for layer_id in sorted_layer_ids: + if layer_id in block.layer_range: + block.sorted_layer_ids.append(layer_id) + return all_blocks + + def get_backbone_blocks(self): + sorted_blocks = sorted( + self.blocks_by_edge_hash.values(), + key=lambda blocks: (len(blocks), len(blocks[0].layer_range)), + ) + if len(sorted_blocks) == 0: + return [] + else: + return sorted_blocks[-1] + + def get_blocks_by_module_hash(self, blocks): + blocks_by_module_hash = {} + for block in blocks: + module_hash = self.get_module_hash(block.layer_range) + if module_hash not in blocks_by_module_hash: + blocks_by_module_hash[module_hash] = [] + blocks_by_module_hash[module_hash].append(block) + for module_hash in [*blocks_by_module_hash.keys()]: + if len(blocks_by_module_hash[module_hash]) == 1: + del blocks_by_module_hash[module_hash] + return blocks_by_module_hash + + def get_module_tree(self): + module_tree = {"children": {}, "name": None} + for module_name in self.module_to_layer_range_map.keys(): + full_name = module_name.split('.') + current_tree = module_tree["children"] + for depth, name in enumerate(full_name): + if name not in current_tree: + current_tree[name] = {"children": {}, "name": None} + if depth == len(full_name) - 1: + current_tree[name]["name"] = module_name + else: + current_tree = current_tree[name]["children"] + return module_tree + + def get_blocks_by_edge_hash(self, blocks_by_module_hash): + blocks_by_edge_hash = {} + for block_list in blocks_by_module_hash.values(): + for block in block_list: + block.collect_edges() + edge_hash = block.edge_hash + if edge_hash not in blocks_by_edge_hash: + blocks_by_edge_hash[edge_hash] = [] + blocks_by_edge_hash[edge_hash].append(block) + for edge_hash in [*blocks_by_edge_hash.keys()]: + if len(blocks_by_edge_hash[edge_hash]) == 1: + del blocks_by_edge_hash[edge_hash] + else: + block_list = blocks_by_edge_hash[edge_hash] + blocks_by_edge_hash[edge_hash] = sorted( + block_list, key=lambda x: x.layer_range.start) + for type_id, 
block_list in enumerate(blocks_by_edge_hash.values()): + for block in block_list: + block.type_id = type_id + return blocks_by_edge_hash + + def get_layer_to_block(self): + layer_to_block = {} + for block_list in self.blocks_by_edge_hash.values(): + for block in block_list: + for layer_index in block.layer_range: + layer_to_block[layer_index] = block + return layer_to_block + + def clean_blocks(self): + for block in self.blocks: + block.p2p_type = None + block.is_superset = False + block.is_subset = False + + def mark_p2p_type(self, phy_mesh, stage_phy_meshes, + graph_config: GraphConfig): + if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1: + return + assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0 + block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes) + + for block in self.backbone_blocks: + block.p2p_type = None + for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]): + next_stage_phy_mesh = stage_phy_meshes[stage_index + 1] + last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1] + next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten( + )[0] + num_devices_per_host = phy_mesh.num_devices_per_host + next_block = self.backbone_blocks[(stage_index + 1) * + block_per_stage] + if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host: + next_block.p2p_type = P2PType.CROSS_HOST + graph_config.has_cross_host = True + else: + next_block.p2p_type = P2PType.CROSS_DEVICE + graph_config.has_cross_device = True + + def get_graph_mapping(self): + layer_mapping = {} + block_mapping = {} + p2p_types = {} + p2p_tensors = {} + for block_list in self.blocks_by_edge_hash.values(): + superset_blocks = [] + superset_block_index = {} + for block in block_list: + block_added = False + for index, superset_block in enumerate(list(superset_blocks)): + if block.p2p_type == superset_block.p2p_type: + if block.relative_inter_edges.issubset( + superset_block.relative_inter_edges): + block.is_subset = True + block.is_superset = False + superset_block_index[id(block)] = index + block_added = True + break + elif superset_block.relative_inter_edges.issubset( + block.relative_inter_edges): + superset_block.is_subset = True + superset_block.is_superset = False + block.is_subset = False + block.is_superset = True + superset_blocks[index] = block + superset_block_index[id(block)] = index + block_added = True + break + if not block_added: + block.is_subset = False + block.is_superset = True + superset_blocks.append(block) + superset_block_index[id(block)] = len(superset_blocks) - 1 + for block in block_list: + assert not (block.is_subset and block.is_superset) + if block.is_subset: + superset_block = superset_blocks[superset_block_index[id( + block)]] + block_mapping[block.block_id] = superset_block.block_id + owned_inputs = map( + lambda x: x[0], + sorted(block.owned_inputs.items(), key=lambda x: x[1])) + superset_owned_inputs = map( + lambda x: x[0], + sorted(superset_block.owned_inputs.items(), + key=lambda x: x[1])) + for from_input_id, to_input_id in zip( + owned_inputs, superset_owned_inputs): + from_input_name = self.network.get_input( + from_input_id).name + to_input_name = self.network.get_input(to_input_id).name + layer_mapping[from_input_name] = to_input_name + for from_layer_id, to_layer_id in zip( + block.layer_range, superset_block.layer_range): + from_layer = self.network.get_layer(from_layer_id) + to_layer = self.network.get_layer(to_layer_id) + layer_mapping[from_layer.name] = to_layer.name + for i in 
range(from_layer.num_outputs): + from_output = from_layer.get_output(i) + if from_output.is_network_output: + to_output = to_layer.get_output(i) + layer_mapping[from_output.name] = to_output.name + if block.p2p_type is not None: + p2p_types[block.block_id] = block.p2p_type + p2p_tensors[block.block_id] = [ + *set(block.get_input_names()) + ] + for from_name, to_name in zip( + block.get_input_names(), + superset_block.get_input_names()): + layer_mapping[ + f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}" + stage_id = 0 + block_to_stage = {} + for block in self.blocks: + if block.p2p_type is not None: + stage_id += 1 + block_to_stage[block.block_id] = stage_id + return GraphMapping( + layer_mapping, + block_mapping, + p2p_types, + p2p_tensors, + block_to_stage, + ) + + def create_simplified_graph(self, graph_config: GraphConfig): + new_graph = PipelineGraph.create_graph() + new_graph._io_buffer_mapping = self.graph._io_buffer_mapping + layer_mapping = graph_config.graph_mapping.layer_mapping + + for i in range(self.network.num_inputs): + trt_input = self.network.get_input(i) + if trt_input.name not in layer_mapping: + new_graph.add_input(trt_input) + + last_blocks = {} + same_spec_mapping = {} + same_spec_layer_mapping = {} + shape_mapping = {} + building_block_id = 0 + same_spec_ids = {} + same_spec_count = 0 + for block in self.blocks: + if not block.is_subset: + stage_type = None + if not block.is_superset: + if block.block_id == 0: + stage_type = StageType.START + elif block.block_id == len(self.blocks) - 1: + stage_type = StageType.END + input_mapping = block.get_input_mapping(last_blocks) + for from_name, to_name in [*input_mapping.items()]: + if to_name in same_spec_mapping: + input_mapping[from_name] = same_spec_mapping[to_name] + if to_name in layer_mapping: + input_mapping[from_name] = layer_mapping[to_name] + if block.is_superset and block.p2p_type is not None: + for from_name, to_name in [*input_mapping.items()]: + output_tensor = new_graph.get_tensor(to_name) + p2p_layer = new_graph.as_trt().add_identity( + output_tensor.as_trt()) + p2p_layer.name = f"p2p_block{block.block_id}_{from_name}" + p2p_tensor = p2p_layer.get_output(0) + p2p_tensor.name = f"{p2p_layer.name}_output" + wrapped_layer = new_graph.register_layer(p2p_layer) + wrapped_layer.attrs[ + "building_block_id"] = building_block_id + wrapped_layer.attrs["p2p_type"] = block.p2p_type + input_mapping[from_name] = p2p_tensor.name + shape_mapping[p2p_tensor.name] = from_name + building_block_id += 1 + for i in block.sorted_layer_ids: + layer = self.network.get_layer(i) + wrapped_layer = new_graph.add_layer( + layer, + input_mapping=input_mapping, + ) + wrapped_layer.attrs["building_block_id"] = building_block_id + wrapped_layer.attrs["stage_type"] = stage_type + if block.is_superset: + last_blocks[block.type_id] = block + + if block.type_id in same_spec_ids: + same_spec_id = same_spec_ids[block.type_id] + update_same_spec_count = False + else: + same_spec_id = same_spec_count + same_spec_ids[block.type_id] = same_spec_id + update_same_spec_count = True + count = same_spec_id + for i, (layer_offset, + output_index) in enumerate(block.outputs): + layer = self.network.get_layer(block.layer_range.start + + layer_offset) + tensor_name = layer.get_output(output_index).name + output_tensor = new_graph.get_tensor(tensor_name) + same_spec_layer = new_graph.as_trt().add_identity( + output_tensor.as_trt()) + same_spec_layer.name = f"{tensor_name}_same_spec" + same_spec_tensor = 
same_spec_layer.get_output(0) + same_spec_tensor.name = f"{same_spec_layer.name}_output" + wrapped_layer = new_graph.register_layer( + same_spec_layer) + wrapped_layer.attrs[ + "building_block_id"] = building_block_id + wrapped_layer.attrs["same_spec_id"] = count + count += 1 + same_spec_mapping[tensor_name] = same_spec_tensor.name + same_spec_layer_mapping[ + same_spec_layer.name] = layer.name + shape_mapping[same_spec_tensor.name] = tensor_name + for i, graph_input_index in enumerate( + block.owned_inputs.keys()): + input_name = self.network.get_input( + graph_input_index).name + input_tensor = new_graph.get_input(input_name) + input_tensor.attrs["same_spec_id"] = count + count += 1 + if update_same_spec_count: + same_spec_count = count + building_block_id += 1 + graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping + + if len(self.backbone_blocks) >= 2: + start_block = self.backbone_blocks[0] + if start_block.is_subset: + start_block = self.blocks[graph_config.graph_mapping. + block_mapping[start_block.block_id]] + for i in start_block.layer_range: + layer_name = self.network.get_layer(i).name + layer = new_graph.get_layer(layer_name) + layer.attrs["in_start_block"] = True + end_block = self.backbone_blocks[-1] + if end_block.is_subset: + end_block = self.blocks[graph_config.graph_mapping. + block_mapping[end_block.block_id]] + for i in end_block.layer_range: + layer_name = self.network.get_layer(i).name + layer = new_graph.get_layer(layer_name) + layer.attrs["in_end_block"] = True + slowest_p2p_type = None + if graph_config.has_cross_host: + slowest_p2p_type = P2PType.CROSS_HOST + elif graph_config.has_cross_device: + slowest_p2p_type = P2PType.CROSS_DEVICE + if slowest_p2p_type is not None: + for block in self.blocks: + if block.is_superset and block.p2p_type == slowest_p2p_type: + for i in block.layer_range: + layer_name = self.network.get_layer(i).name + layer = new_graph.get_layer(layer_name) + layer.attrs["in_slowest_block"] = True + + for i in range(self.network.num_outputs): + trt_output = self.network.get_output(i) + output = self.graph.get_output(trt_output.name) + if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[ + output.producer.index].is_subset: + continue + if trt_output.is_shape_tensor: + new_output = new_graph.add_output_shape(trt_output) + else: + new_output = new_graph.add_output(trt_output) + sharded_io = False + for pattern in self.sharded_io_allowlist: + if re.match(pattern, new_output.name): + sharded_io = True + break + if not sharded_io: + new_output.producer.attrs["is_replicated"] = True + + for input in new_graph.inputs: + input_name = input.name + sharded_io = False + for pattern in self.sharded_io_allowlist: + if re.match(pattern, input_name): + sharded_io = True + break + if not sharded_io: + input.attrs["is_replicated"] = True + for pattern, repl in self.same_spec_io.items(): + if re.match(pattern, input_name): + output_name = re.sub(pattern, repl, input_name) + output = new_graph.get_output(output_name) + if output is not None: + if "same_spec_id" in input.attrs: + same_spec_id = input.attrs["same_spec_id"] + else: + same_spec_id = same_spec_count + same_spec_count += 1 + input.attrs["same_spec_id"] = same_spec_id + output.attrs["same_spec_id"] = same_spec_id + if math.prod(self.graph.get_input( + input_name).shape) < math.prod( + self.graph.get_output(output_name).shape): + input.attrs["no_memory_footprint"] = True + else: + output.attrs["no_memory_footprint"] = True + + return 
new_graph, shape_mapping + + def enrich_shape_info(self, shape_mapping): + shapes = self.shape_info.shapes.copy() + max_shapes = self.shape_info.max_shapes.copy() + values = self.shape_info.values.copy() + shape_layers = self.shape_info.shape_layers + for from_name, to_name in shape_mapping.items(): + if to_name in shapes: + shapes[from_name] = shapes[to_name] + if to_name in max_shapes: + max_shapes[from_name] = max_shapes[to_name] + if to_name in values: + values[from_name] = values[to_name] + shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes) + return shape_info + + def simplify_graph( + self, phy_mesh: PhysicalDeviceMesh, num_stages: int, + num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]: + num_blocks = len(self.backbone_blocks) + if num_blocks % num_stages != 0: + return None, None + graph_config = GraphConfig() + graph_config.num_micro_batches = self.num_micro_batches + graph_config.num_blocks = num_blocks + graph_config.num_stages = num_stages + graph_config.phy_mesh = phy_mesh + stage_phy_meshes = phy_mesh.split_pipeline_meshes( + num_stages, num_devices_per_stage) + graph_config.stage_phy_meshes = stage_phy_meshes + with silent_trt_logger(): + self.clean_blocks() + self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config) + graph_config.graph_mapping = self.get_graph_mapping() + new_graph, shape_mapping = self.create_simplified_graph( + graph_config) + shape_info = self.enrich_shape_info(shape_mapping) + new_graph.assign_shapes(shape_info) + return new_graph, graph_config + + def get_graph_mapping_for_shape(self): + layer_mapping = {} + tensor_mapping = {} + for block_list in self.blocks_by_edge_hash.values(): + head_block = block_list[0] + for block in block_list[1:]: + for from_layer_id, to_layer_id in zip(block.layer_range, + head_block.layer_range): + from_layer = self.network.get_layer(from_layer_id) + to_layer = self.network.get_layer(to_layer_id) + layer_mapping[from_layer.name] = to_layer.name + for i in range(from_layer.num_outputs): + tensor_mapping[from_layer.get_output( + i).name] = to_layer.get_output(i).name + return layer_mapping, tensor_mapping + + def create_simplified_graph_for_shape(self): + new_graph = PipelineGraph.create_graph() + + for i in range(self.network.num_inputs): + trt_input = self.network.get_input(i) + new_graph.add_input(trt_input) + + head_blocks = {} + removed_blocks = set() + removed_layers = set() + for block_list in self.blocks_by_edge_hash.values(): + head_block = block_list[0] + head_blocks[head_block.type_id] = head_block + for block in block_list[1:]: + removed_blocks.add(id(block)) + for layer_index in block.layer_range: + removed_layers.add(layer_index) + + for block in self.blocks: + if not id(block) in removed_blocks: + input_mapping = block.get_input_mapping(head_blocks) + for i in block.sorted_layer_ids: + layer = self.network.get_layer(i) + new_graph.add_layer( + layer, + input_mapping=input_mapping, + ) + + for i in range(self.network.num_outputs): + trt_output = self.network.get_output(i) + output = self.graph.get_output(trt_output.name) + if output.producer is not None and output.producer.index in removed_layers: + continue + if trt_output.is_shape_tensor: + new_graph.add_output_shape(trt_output) + else: + new_graph.add_output(trt_output) + + return new_graph + + def get_full_shape_info(self, num_micro_batches): + layer_mapping, tensor_mapping = self.graph_mapping_for_shape + optimization_profiles = self.llm_network._generate_optimization_profiles( + ) + if len(optimization_profiles) > 0: + 
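+            # Illustrative comment (not in the original patch): when profiles are
+            # available, the last generated optimization profile is used below;
+            # get_shape_info() is then run twice, once for the default OPT shapes and
+            # once with ShapeType.MAX, and the MAX shapes are attached to the returned
+            # ShapeInfo as max_shapes.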
optimization_profile = optimization_profiles[-1] + else: + optimization_profile = None + shape_info = get_shape_info(self.graph_for_shape.as_trt(), + optimization_profile) + max_shape_info = get_shape_info(self.graph_for_shape.as_trt(), + optimization_profile, + shape_type=ShapeType.MAX) + shape_info.max_shapes = max_shape_info.shapes + for removed_tensor_name, tensor_name in tensor_mapping.items(): + shape_info.shapes[removed_tensor_name] = shape_info.shapes[ + tensor_name] + shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[ + tensor_name] + if tensor_name in shape_info.values: + shape_info.values[removed_tensor_name] = shape_info.values[ + tensor_name] + for removed_layer_name, layer_name in layer_mapping.items(): + if layer_name in shape_info.shape_layers: + shape_info.shape_layers.add(removed_layer_name) + return shape_info + + def init_layer_hash(self): + with silent_trt_logger(): + optimization_profiles = self.llm_network._generate_optimization_profiles( + ) + if len(optimization_profiles) > 0: + optimization_profile = optimization_profiles[-1] + else: + optimization_profile = None + shape_info = get_shape_info(self.network, optimization_profile) + dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors} + for layer in self.graph.layers: + layer_hash = get_cache_key( + layer.as_trt(), + shape_info.shapes, + shape_info.values, + dtypes, + ) + layer.attrs["hash"] = layer_hash diff --git a/tensorrt_llm/auto_parallel/solver.py b/tensorrt_llm/auto_parallel/solver.py new file mode 100644 index 000000000..7dd195117 --- /dev/null +++ b/tensorrt_llm/auto_parallel/solver.py @@ -0,0 +1,641 @@ +"""This code is adapted from Alpa https://github.com/alpa-projects/alpa/ with some changes. +""" +import multiprocessing +import time +import warnings +from collections import defaultdict + +import numpy as np +import pulp +from pulp import LpMinimize, LpProblem, LpVariable, lpDot, lpSum + +from ..logger import logger + + +class Solution: + + def __init__(self, leaf_strategies, s_val, e_val, edge_pairs, + node_index_dict, total_cost): + self.leaf_strategies = leaf_strategies + self.nodes = [ + strategies_vector.node for strategies_vector in self.leaf_strategies + ] + self.s_val = s_val + self.e_val = e_val + self.total_cost = total_cost + self.edge_pairs = list(np.reshape(edge_pairs, (-1, 2))) + self.node_index_dict = node_index_dict + self.index_node_dict = {} + for node, index in self.node_index_dict.items(): + self.index_node_dict[index] = node + self.node_best_strategy = {} + self._annotate_strategy() + + def _annotate_strategy(self): + self.node_best_strategy = {} + for index, node in enumerate(self.nodes): + best_strategy_id = self.s_val[index] + best_strategy = self.leaf_strategies[index][best_strategy_id] + self.node_best_strategy[node.node_name] = best_strategy + + for edge_idx, edge_pair in enumerate(self.edge_pairs): + src_node = self.index_node_dict[edge_pair[0]] + dst_node = self.index_node_dict[edge_pair[1]] + src_node_index = self.node_index_dict[src_node] + for dst_pre_node in dst_node.predecessor_nodes: + if dst_pre_node is None: + continue + if src_node.node_name == dst_pre_node.node_name: + self.node_best_strategy[ + dst_node.node_name].best_resharding_cost[ + src_node.node_name] = [ + self.node_best_strategy[dst_node.node_name]. 
+ resharding_costs[src_node.node_name][ + self.s_val[src_node_index]] + ] + + def print_solution(self): + for index, node in enumerate(self.nodes): + best_strategy = self.node_best_strategy[node.node_name] + print(f'\n[{index}]: node_name = {node.node_name}') + best_strategy.print_strategy(best_resharding_cost_only=True) + print(f'solution total cost = {self.total_cost}') + + +class CostGraph: + ''' + A graph data structure to simplify the edge cost graph. It has two main functions: + 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in + CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list. + 2. To reduce the searching space, we merge computationally-trivial operators, such as + element-wise operators, transpose, and reduction, into their following nodes. The merging information will + be given by the StrategiesVector depending on the type of target node and following nodes. + + Argument: + leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph. + simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True) + ''' + + def __init__(self, leaf_strategies): + self.leaf_strategies = leaf_strategies + self.nodes = [ + strategies_vector.node for strategies_vector in leaf_strategies + ] + # stores number of strategies in each node + self.node_strategies_vector = {} + for node, strategies_vector in zip(self.nodes, self.leaf_strategies): + self.node_strategies_vector[node] = strategies_vector + # extra_node_costs will store the extra costs introduced by merging nodes + self.extra_node_costs = {} + self.following_dict = {} + self._build_cost_graph() + + def _remove_invalid_node(self, node, attr_name): + remove_list = [] + target_node_list = getattr(node, attr_name, []) + for target_node in target_node_list: + if target_node not in self.nodes: + remove_list.append(target_node) + for element in remove_list: + target_node_list.remove(element) + + def _build_cost_graph(self): + ''' + This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be + set to node. + ''' + self.edge_costs = {} + for dst_node, strategies_vector in zip(self.nodes, + self.leaf_strategies): + # build edge_cost + for src_node in dst_node.predecessor_nodes: + if src_node is None: + continue + if src_node not in self.nodes: + continue + node_pair = (src_node, dst_node) + edge_cost = {} + for i in range(len(strategies_vector)): + for j in range(len(self.node_strategies_vector[src_node])): + resharding_cost = strategies_vector[i].resharding_costs[ + src_node.node_name][j][-1] + edge_cost[(j, i)] = resharding_cost + self.edge_costs[node_pair] = edge_cost + + def get_edge_cost(self, src_node, dst_node): + return self.edge_costs[(src_node, dst_node)] + + +class Solver: + INFINITY_COST = 1e13 + + def __init__(self, + cost_graph: CostGraph, + memory_budget: float = -1.0, + solution_numbers: int = 1, + memory_increasing_coefficient: float = 1.3, + verbose=False): + ''' + Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph. + Argument: + graph: The computing graph to be optimized. + strategies_constructor: It will provide all the possible strategies for each node in the computing graph. + cost_graph: A graph data structure to simplify the edge cost graph. 
+            graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
+            memory_budget: Memory constraint for the solution.
+            solution_numbers: If solution_numbers is larger than one, the solver will produce a series of solutions based on different memory budgets.
+            memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budgets.
+        '''
+        self.cost_graph = cost_graph
+        self.leaf_strategies = cost_graph.leaf_strategies
+        self.nodes = cost_graph.nodes
+        self.memory_budget = memory_budget
+        self.solution_numbers = solution_numbers
+        if self.solution_numbers > 1:
+            self.memory_increasing_coefficient = memory_increasing_coefficient
+        else:
+            self.memory_increasing_coefficient = 1
+        # Temporarily we use all nodes as the liveness list; we count the backward memory cost together with
+        # the forward memory cost in the node memory cost, and no activation checkpointing is used in this phase.
+        # self.liveness_list = self.graph_analyser.liveness_analysis()
+        self.liveness_list = self.nodes
+        self.node_index_dict = self._generate_node_index_dict()
+        # The last solution vector of auto sharding.
+        self.last_s_val = None
+        # The last objective value of the best ILP solution.
+        self.last_objective = None
+        self.verbose = verbose
+
+    def _generate_node_index_dict(self):
+        node_index_dict = {}
+        for index, node in enumerate(self.nodes):
+            node_index_dict[node] = index
+        return node_index_dict
+
+    def _prepare_data_for_solver(self):
+        '''
+        Extract information from the components for the solver.
+        '''
+        node_nums = len(self.leaf_strategies)
+        memory_budget = self.memory_budget
+
+        # prepare strategies_len
+        strategies_len = []
+        for node in self.nodes:
+            strategies_len.append(
+                len(self.cost_graph.node_strategies_vector[node]))
+        strategies_len = np.array(strategies_len)
+
+        # prepare edge_pairs and resharding costs
+        edge_pairs = []
+        resharding_costs = []
+        edge_cost_level = []
+        edge_resharding_weights = []
+        for pairs, edge_cost in self.cost_graph.edge_costs.items():
+            src_node = pairs[0]
+            dst_node = pairs[1]
+            src_node_index = self.node_index_dict[src_node]
+            dst_node_index = self.node_index_dict[dst_node]
+            edge_pairs.append(src_node_index)
+            edge_pairs.append(dst_node_index)
+            edge_cost_level.append(
+                (dst_node.building_block_id, dst_node.cost_level))
+            for i in range(strategies_len[src_node_index]):
+                for j in range(strategies_len[dst_node_index]):
+                    resharding_costs.append(edge_cost[(i, j)])
+            edge_resharding_weights.append(dst_node.resharding_weight +
+                                           dst_node.pipeline_weight)
+        edge_pairs = np.array(edge_pairs)
+        resharding_costs = np.array(resharding_costs)
+        edge_resharding_weights = np.array(edge_resharding_weights)
+        # prepare compute_costs, communication_costs and memory_costs
+        compute_costs = []
+        communication_costs = []
+        memory_costs = []
+        peak_act_memory_costs, constant_memory_costs = [], []
+        node_sharding_weights = []
+        for node, strategies_vector in zip(self.nodes, self.leaf_strategies):
+            for index, strategy in enumerate(strategies_vector):
+                compute_cost = strategy.sharding_cost
+                origin_communication_cost = strategy.communication_cost
+                memory_cost = strategy.const_memory_footprint * node.sharding_weight
+                peak_act_memory = strategy.peak_memory_footprint
+                # extract the memory cost in float from MemoryCost item and sum them up
+                compute_costs.append(compute_cost)
+                # node in extra_node_costs means it has some extra communication
+                # cost from node merging,
so we need to add those extra communication + # cost into + + communication_costs.append(origin_communication_cost) + peak_act_memory_costs.append(peak_act_memory) + constant_memory_costs.append(memory_cost) + node_sharding_weights.append(node.sharding_weight + + node.pipeline_weight) + + compute_costs = np.array(compute_costs) + communication_costs = np.array(communication_costs) + memory_costs = np.array([constant_memory_costs, peak_act_memory_costs]) + node_sharding_weights = np.array(node_sharding_weights) + same_spec_nodes_dict = defaultdict(list) + node_cost_level = [] + for idx, node in enumerate(self.nodes): + if node.same_spec_id >= 0: + same_spec_nodes_dict[node.same_spec_id].append(idx) + node_cost_level.append((node.building_block_id, node.cost_level)) + # omit initial value for nodes + s_init_np = None + following_nodes = [-1 for i in range(node_nums)] + liveness_set = self.nodes + alias_set = [] + alias_convert_costs = None + return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, node_sharding_weights, edge_resharding_weights, same_spec_nodes_dict, node_cost_level, edge_cost_level, alias_convert_costs, s_init_np, self.verbose + + def _call_solver_serialized_args(self, + node_nums, + memory_budget, + strategies_len, + following_nodes, + edge_pairs, + alias_set, + liveness_set, + compute_costs, + communication_costs, + memory_costs, + resharding_costs, + node_sharding_weights, + edge_resharding_weights, + same_spec_nodes_dict, + node_cost_level, + edge_cost_level, + alias_convert_costs, + s_init_np=None, + verbose=True): + """ + Call the solver with serialized arguments. + """ + + time.time() + + for x in [ + strategies_len, edge_pairs, compute_costs, communication_costs, + memory_costs, resharding_costs, node_sharding_weights, + edge_resharding_weights + ]: + assert isinstance(x, np.ndarray) + assert len(strategies_len) == node_nums, "strategies_len" + + def get_non_zero_index(binary_vector): + """ + Get the index of non-zero item in a vector. + """ + ct = 0 + ret = None + for i, elem in enumerate(binary_vector): + if pulp.value(elem): + ret = i + ct += 1 + + assert ct == 1 + return ret + + # 0. 
Unpack flatten numpy arrays + s_follow = following_nodes + s_alias = alias_set + + E = edge_pairs.reshape((-1, 2)) # noqa + r = [] + pt = 0 + edge_set = set() + for (i, j) in E: + prod_length = strategies_len[i] * strategies_len[j] + + if (i, j) in edge_set: + raise ValueError(f"Duplicated edges: {(i, j)}") + + edge_set.add((i, j)) + r.append(resharding_costs[pt:pt + prod_length]) + pt += prod_length + assert pt == len(resharding_costs) + + ###################### + # omit alias set now # + ###################### + + # A = alias_set.reshape((-1, 2)) # noqa + # for (i, j) in A: + # prod_length = strategies_len[i] * strategies_len[j] + # v.append(alias_convert_costs[pt:pt + prod_length]) + # pt += prod_length + # assert pt == len(alias_convert_costs) + + # L = [] # noqa + # pt = node_nums + # for i in range(node_nums): + # length = liveness_set[i] + # L.append(liveness_set[pt:pt + length]) + # pt += length + # assert pt == len(liveness_set) + pt = 0 + + c = [] + d = [] + m = [] + peak_m = [] + pt = 0 + for i in range(node_nums): + length = strategies_len[i] + c.append(compute_costs[pt:pt + length]) + d.append(communication_costs[pt:pt + length]) + m.append(memory_costs[0][pt:pt + length]) + peak_m.append(memory_costs[1][pt:pt + length]) + pt += length + assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}" + assert pt == len( + communication_costs), f"{pt} == {len(communication_costs)}" + assert pt == len(memory_costs[0]), f"{pt} == {len(memory_costs[0])}" + + # 1. Create variables + + ############################# + # create variables for node # + ############################# + s = [] + num_nodes = 0 + reverse_follow_backpatch = [] + for i in range(node_nums): + if s_follow[i] < 0: + if strategies_len[i] == 1: + s.append([1]) + else: + if i not in s_alias: + num_nodes += 1 + s.append( + LpVariable.matrix(f"s[{i}]", + (range(strategies_len[i]), ), + cat="Binary")) + else: + s.append(s[s_alias[i]]) + else: + if s_follow[i] < len(s): + s.append(s[s_follow[i]]) + else: + s.append(None) + reverse_follow_backpatch.append(i) + + for i in reverse_follow_backpatch: + s[i] = s[s_follow[i]] + + ############################# + # create variables for edge # + ############################# + e = [] + num_edges = 0 + map_edge_to_idx = {} + for (idx, (i, j)) in enumerate(E): + if len(s[i]) == 1: + e.append(s[j]) + elif len(s[j]) == 1: + e.append(s[i]) + else: + if i in s_alias and j in s_alias and ( + s_alias[i], s_alias[j]) in map_edge_to_idx: + e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]]) + else: + num_edges += 1 + e.append( + LpVariable.matrix(f"e[{i},{j}]", + (range(len(s[i]) * len(s[j])), ), + cat="Binary")) + assert len(e[idx]) == len(r[idx]) + map_edge_to_idx[(i, j)] = idx + for element in s: + assert len(element) > 0 + # 2. Set initial value + ###################################### + # set a initial value for warm start # + ###################################### + if s_init_np is not None: + s_init = s_init_np.reshape((-1, 3)) + for (idx, value, fix) in s_init: + for i in range(len(s[idx])): + s[idx][i].setInitialValue(i == value) + if fix: + s[idx][i].fixValue() + + # 3. 
Objective + prob = LpProblem("myProblem", LpMinimize) + ################################################################### + # computing the node cost(computing cost and communication cost) # + ################################################################### + obj = 0 + block_cost_level_dict = {} + for i in range(node_nums): + assert len(s[i]) == len(c[i]) + assert len(s[i]) == len(d[i]) + obj += (lpDot(s[i], c[i]) + + lpDot(s[i], d[i])) * node_sharding_weights[i] + cost_level = node_cost_level[i] + if -1 != cost_level[1]: + if cost_level in block_cost_level_dict: + block_cost_level_dict[cost_level] += lpDot( + s[i], c[i]) + lpDot(s[i], d[i]) + else: + block_cost_level_dict[cost_level] = lpDot( + s[i], c[i]) + lpDot(s[i], d[i]) + + ############################################# + # computing the edge cost(resharding cost) # + ############################################# + + for i in range(len(E)): + assert len(e[i]) == len(r[i]) + obj += lpDot(e[i], r[i]) * edge_resharding_weights[i] + cost_level = edge_cost_level[i] + if -1 != cost_level[1]: + if cost_level in block_cost_level_dict: + block_cost_level_dict[cost_level] += lpDot(e[i], r[i]) + else: + block_cost_level_dict[cost_level] = lpDot(e[i], r[i]) + prob += obj + if len(block_cost_level_dict) >= 2: + block_cost_levels = [key for key in block_cost_level_dict.keys()] + for i in range(len(block_cost_levels)): + for j in range(i + 1, len(block_cost_levels)): + if block_cost_levels[i][1] > block_cost_levels[j][1]: + prob += block_cost_level_dict[ + block_cost_levels[i]] >= block_cost_level_dict[ + block_cost_levels[j]] + 1e-6 + elif block_cost_levels[i][1] < block_cost_levels[j][1]: + prob += block_cost_level_dict[ + block_cost_levels[j]] >= block_cost_level_dict[ + block_cost_levels[i]] + 1e-6 + # 4. Constraints + # (a). 
specified by `cat="Binary"` + + # (b) + ################################################# + # make sure each node only choose one strategy # + ################################################# + for i in range(node_nums): + if s_follow[i] < 0: + prob += lpSum(s[i]) == 1 + + # (c) + ################################################# + # force to constrain some nodes have the same sharding specs # + ################################################# + for spec_id, same_spec_nodes_id in same_spec_nodes_dict.items(): + num_same_spec_nodes = len(same_spec_nodes_id) + if num_same_spec_nodes >= 2: + src_node_s = s[same_spec_nodes_id[0]] + num_specs = len(src_node_s) + for i in range(1, num_same_spec_nodes): + dst_node_s = s[same_spec_nodes_id[i]] + assert len( + dst_node_s + ) == num_specs, f'unmatched num_specs when force node {same_spec_nodes_id[0]} and {same_spec_nodes_id[i]} the same specs' + for j in range(num_specs): + prob += (src_node_s[j] == dst_node_s[j]) + + # (c) + ################################################# + # compute memory consumption with liveness set # + ################################################# + if memory_budget > 0: + # calculate the constant memory + mem = 0 + for node in liveness_set: + if node not in self.node_index_dict: + continue + node_index = self.node_index_dict[node] + mem += lpSum(s[node_index][j] * m[node_index][j] + for j in range(len(s[node_index]))) + # calculate the peak activation memory + for node in liveness_set: + if node not in self.node_index_dict: + continue + node_index = self.node_index_dict[node] + cur_peak_mem = lpSum(s[node_index][j] * peak_m[node_index][j] + for j in range(len(s[node_index]))) + total_mem = mem + cur_peak_mem + prob += total_mem <= memory_budget + + # (d). specified by `cat="Binary"` + + for (idx, (i, j)) in enumerate(E): + if strategies_len[i] == 1 or strategies_len[j] == 1: + continue + + # (e) + prob += lpSum(e[idx]) == 1 + + # (f) + for row in range(len(s[i])): + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] + for col in range(0, C)) <= s[i][row] + + # (g) + for col in range(len(s[j])): + R = len(s[i]) # noqa + C = len(s[j]) # noqa + prob += lpSum(e[idx][row * C + col] + for row in range(0, R)) <= s[j][col] + + if prob.objective.isNumericalConstant(): + objective = float(pulp.value(prob.objective)) + status = pulp.LpStatusOptimal + else: + msg = verbose + time_limit = 600 + solver = pulp.PULP_CBC_CMD( + mip=True, + msg=msg, + timeLimit=time_limit, + threads=multiprocessing.cpu_count(), + ) + prob.solve(solver) + + status = prob.status + objective = pulp.value(prob.objective) + objective = float( + objective) if objective is not None else self.INFINITY_COST + + if prob.status in [pulp.LpStatusInfeasible]: + objective = self.INFINITY_COST + + # Get and check results + s_val = np.full((node_nums, ), -1, dtype=np.int32) + for i in range(node_nums): + s_val[i] = get_non_zero_index(s[i]) + + e_val = np.full((len(E), ), -1, dtype=np.int32) + for (idx, (i, j)) in enumerate(E): + e_val[idx] = get_non_zero_index(e[idx]) + i_spec_index = e_val[idx] // len(s[j]) + j_spec_index = e_val[idx] % len(s[j]) + assert i_spec_index == s_val[i], f"e_val[{i}][{j}]" + assert j_spec_index == s_val[j], f"e_val[{i}][{j}]" + if verbose and r[idx][e_val[idx]] > 0: + print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}") + + self.last_s_val = list(s_val) + # self._recover_merged_node_strategy() + self.last_objective = objective + + if objective >= self.INFINITY_COST: + warnings.warn( + f"Cannot find an optimized solution given memory 
budget {self.memory_budget}. Please consider:\n" + \
+                f"1. increase the memory budget if possible\n" + \
+                f"2. enlarge the mesh shape if possible\n" + \
+                f"3. decrease the maximum parameters (e.g., max_batch_size, max_seq_len) in the building config")
+        if memory_budget > 0:
+            # calculate the constant memory
+            mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                j = self.last_s_val[node_index]
+                mem += m[node_index][j]
+            max_peak_mem = 0
+            for node in liveness_set:
+                if node not in self.node_index_dict:
+                    continue
+                node_index = self.node_index_dict[node]
+                j = self.last_s_val[node_index]
+                cur_peak_mem = peak_m[node_index][j]
+                max_peak_mem = max(max_peak_mem, cur_peak_mem)
+            logger.debug(
+                f'constant_mem = {mem}, peak_mem = {max_peak_mem}, memory_budget = {memory_budget}'
+            )
+
+        solution = Solution(self.leaf_strategies, self.last_s_val, e_val,
+                            edge_pairs, self.node_index_dict,
+                            self.last_objective)
+        return status, solution
+
+    def find_solution(self):
+        """
+        Call the solver with serialized arguments and handle Python errors. Additionally,
+        we can produce a series of solutions with different memory budgets.
+        """
+        if self.solution_numbers == 1:
+            args = self._prepare_data_for_solver()
+            ret = self._call_solver_serialized_args(*args)
+
+            return ret
+
+        origin_memory_budget = self.memory_budget
+        memory_budget_list = [
+            origin_memory_budget * self.memory_increasing_coefficient**i
+            for i in range(self.solution_numbers)
+        ]
+        ret_list = []
+        for memory_budget in memory_budget_list:
+            self.memory_budget = memory_budget
+            args = self._prepare_data_for_solver()
+            ret = self._call_solver_serialized_args(*args)
+            ret_list.append(ret)
+
+        return ret_list
diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py b/tensorrt_llm/auto_parallel/tensor_parallel/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py
new file mode 100644
index 000000000..a5fd51310
--- /dev/null
+++ b/tensorrt_llm/auto_parallel/tensor_parallel/activation_node.py
@@ -0,0 +1,41 @@
+import copy
+
+from .node import Node
+from .sharding_strategy import StrategiesVector
+
+
+class Activation(Node):
+
+    def _collect_strategies(self, device_mesh):
+        dim_partition_list = []
+        dim_size = len(self.op_data['input0'].shape)
+        dim_partition_list.append({})
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_1d_sharding([0, 1], dim_size))
+        dim_partition_list.extend(
+            self._enumerate_all_possible_2d_sharding([0], [1], dim_size))
+        strategies_vector = StrategiesVector(self)
+        for dim_partition_dict in dim_partition_list:
+            in0_partition_dict = dim_partition_dict
+            out_partition_dict = copy.deepcopy(dim_partition_dict)
+            dim_partition_dict_mapping = {
+                "input0": in0_partition_dict,
+                "output0": out_partition_dict,
+            }
+            sharding_spec_mapping = self._to_sharding_spec_mapping(
+                dim_partition_dict_mapping, device_mesh)
+            if 0 == len(sharding_spec_mapping):
+                continue
+            name = '{} = {}'.format(
+                sharding_spec_mapping['output0'].sharding_sequence,
+                sharding_spec_mapping['input0'].sharding_sequence)
+            sharding_strategy = self._get_sharding_strategy(
+                name=name,
+                sharding_spec_mapping=sharding_spec_mapping,
+
communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py new file mode 100644 index 000000000..73e100559 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/assertion_node.py @@ -0,0 +1,34 @@ +import copy + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Assertion(Node): + + def _collect_strategies(self, device_mesh): + predecessor = self.predecessor_nodes[0] # one input for softmax node + strategies_vector = StrategiesVector(self) + for idx, strategy in enumerate(predecessor.strategies_vector): + global_input_name = self.op_data[ + 'input0'].name # current node's local name input0 -> global name xxx + prenode_local_name = predecessor.global_to_local_op_name[ + global_input_name] # global name xxx -> pre node local output name + dim_partition_dict = copy.deepcopy( + strategy.sharding_specs[prenode_local_name].dim_partition_dict) + in0_partition_dict = dim_partition_dict + dim_partition_dict_mapping = { + "input0": in0_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + return strategies_vector + name = ' {}'.format( + sharding_spec_mapping['input0'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py new file mode 100644 index 000000000..b58083ea6 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/cast_node.py @@ -0,0 +1,45 @@ +import copy + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Cast(Node): + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['input0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + in0_partition_dict = dim_partition_dict + out_partition_dict = copy.deepcopy(dim_partition_dict) + dim_partition_dict_mapping = { + "input0": in0_partition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} = {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + return strategies_vector + + def _update_memory_cost(self, strategies): + pass diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py b/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py new file mode 100644 index 
000000000..3f164ee99 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/comm_spec.py @@ -0,0 +1,58 @@ +__all__ = [ + 'CommSpec', +] + + +class CommSpec: + + def __init__(self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + mix_gather=False, + forward_only=True): + self.comm_pattern = comm_pattern + self.sharding_spec = sharding_spec + self.gather_dim = gather_dim + self.shard_dim = shard_dim + self.logical_process_axis = logical_process_axis + self.device_mesh = self.sharding_spec.device_mesh + self.mix_gather = mix_gather + self.forward_only = forward_only + if self.gather_dim: + assert len(self.gather_dim) == len( + self.logical_process_axis + ), f'unmatched gather dim {self.gather_dim} and logical process axis {self.logical_process_axis}' + if self.shard_dim: + assert len(self.shard_dim) == len( + self.logical_process_axis + ), f'unmatched shard dim {self.shard_dim} and logical process axis {self.logical_process_axis}' + if self.gather_dim and self.shard_dim: + assert len(self.shard_dim) == len( + self.gather_dim + ), f'unmatched gather dim {self.gather_dim} and shard dim {self.shard_dim}' + + def get_comm_cost(self): + ''' + For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to + compute the communication cost. + For shard operation, it is an on-chip operation, so the communication cost is zero. + ''' + comm_size = self.sharding_spec.get_sharded_size_per_device() + dtype = self.sharding_spec.dtype + + # reduce list_of_list to list + comm_dims = sum(self.logical_process_axis, []) + comm_cost = self.device_mesh.estimate_comm_cost(self.comm_pattern, + comm_dims, comm_size, + dtype) + return comm_cost + + def get_mem_cost(self): + return self.device_mesh.shape_consistency_manager.mem_cost([self]) + + def get_max_mem_cost(self): + return self.device_mesh.shape_consistency_manager.mem_cost( + [self], mem_pattern='max') diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py new file mode 100644 index 000000000..a9225fd40 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/concatenation_node.py @@ -0,0 +1,56 @@ +import copy + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Concatenation(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + batch_dims = [i for i in range(len(self.get_output(0).shape))] + self.axis = layer.as_trt().axis + batch_dims.remove(self.axis) + self._generate_bcast_dims(batch_dims, self.get_output(0).shape) + layer.to_base_class() + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + + dim_partition_dict_mapping = {} + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + if self.axis in dim_partition_dict: + dim_partition_dict.pop(self.axis) + for idx in range(self.num_inputs): + in_partition_dict = copy.deepcopy(dim_partition_dict) + dim_partition_dict_mapping[f'input{idx}'] = 
in_partition_dict + out_partition_dict = dim_partition_dict + dim_partition_dict_mapping['output0'] = out_partition_dict + + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} = {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, self.axis, [ + sharding_spec_mapping[f'input{idx}'].sharding_sequence + for idx in range(self.num_inputs) + ]) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py new file mode 100644 index 000000000..ed43331fe --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/constant_node.py @@ -0,0 +1,45 @@ +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Constant(Node): + + def _update_memory_cost(self, strategies): + super()._update_memory_cost(strategies) + for strategy in strategies: + strategy.inout_memory_footprint = 0.0 + strategy.peak_memory_footprint = 0.0 + strategy.const_memory_footprint = strategy.sharding_specs[ + 'output0'].get_max_sharded_size_per_device() + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = {'output0': dim_partition_dict} + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + sharding_seq = sharding_spec_mapping['output0'].sharding_sequence + sharding_strategy = self._get_sharding_strategy( + name=f'constant-op {sharding_seq}', + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + return strategies_vector + + def _profile_sharding_cost(self, strategy, device_mesh): + return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py new file mode 100644 index 000000000..6c2af5aaa --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/elementwise_node.py @@ -0,0 +1,49 @@ +from .node import Node +from .sharding_strategy import StrategiesVector + + +class ElementWise(Node): + + def __init__(self, layer): + super().__init__(layer) + batch_dims = [i for i in range(len(self.get_output(0).shape))] + self._generate_bcast_dims(batch_dims, self.get_output(0).shape) + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + 
self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + in0_partition_dict = self._recover_bcast_partition_dict( + dim_partition_dict, self.op_data['input0']) + in1_partition_dict = self._recover_bcast_partition_dict( + dim_partition_dict, self.op_data['input1']) + out_partition_dict = dim_partition_dict + dim_partition_dict_mapping = { + "input0": in0_partition_dict, + "input1": in1_partition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} = {} {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence, + sharding_spec_mapping['input1'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py new file mode 100644 index 000000000..3be5f79ac --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/fill_node.py @@ -0,0 +1,59 @@ +import tensorrt as trt + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Fill(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + self.operation = layer.as_trt().operation + layer.to_base_class() + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + if self.num_inputs == 0 and self.operation != trt.FillOperation.LINSPACE: + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = {'output0': dim_partition_dict} + for i in range(self.num_inputs): + dim_partition_dict_mapping[f'input{i}'] = {} + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + sharding_seq = sharding_spec_mapping['output0'].sharding_sequence + sharding_strategy = self._get_sharding_strategy( + name=f'fill-op {sharding_seq}', + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + return strategies_vector + + def _profile_sharding_cost(self, strategy, device_mesh): + updated_layer_attrs = {} + updated_input_values = {} + shape = strategy.sharding_specs['output0'].get_sharded_shape_per_device( + ) + if self.layer.num_inputs >= 1: + updated_input_values[0] = shape + else: + updated_layer_attrs['shape'] = shape + elapsed_time = self.node_runtime_profiler.runtime_profile( + self.layer, updated_layer_attrs, updated_input_values, strategy, + device_mesh) + return elapsed_time diff --git 
a/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py new file mode 100644 index 000000000..8b9a6f088 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/gather_node.py @@ -0,0 +1,196 @@ +import copy + +import tensorrt as trt + +from .comm_spec import CommSpec +from .node import Node +from .sharding_spec import DimSpec +from .sharding_strategy import StrategiesVector + + +class Gather(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + self.mode = layer.as_trt().mode + self.axis = layer.as_trt().axis + self.num_elementwise_dims = layer.as_trt().num_elementwise_dims + self.input_id = 0 + self.indice_id = 1 + self.support_vocab_tp = False + layer.to_base_class() + + def _update_memory_cost(self, strategies): + for strategy in strategies: + # for gather node, it input0's read = output0's write + inout_memory_footprint = ( + strategy.sharding_specs['output0'].get_sharded_size_per_device( + ) * 2 + + strategy.sharding_specs['input1'].get_sharded_size_per_device()) + strategy.inout_memory_footprint = inout_memory_footprint + strategy.peak_memory_footprint = ( + strategy.sharding_specs['output0']. + get_max_sharded_size_per_device() + strategy. + sharding_specs['input0'].get_max_sharded_size_per_device() + + strategy.sharding_specs['input1']. + get_max_sharded_size_per_device()) + + def _collect_strategies(self, device_mesh): + if self.mode == trt.GatherMode.DEFAULT: + return self._default_gather_strategies(device_mesh) + elif self.mode == trt.GatherMode.ELEMENT: + return self._element_gather_strategies(device_mesh) + elif self.mode == trt.GatherMode.ND: + assert 0, 'unsupport gatherND' + else: + assert 0, f'unsupport gather mode {self.mode}' + + def _element_gather_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + if self.axis in dim_partition_dict: + dim_partition_dict.pop(self.axis) + + dim_partition_dict_mapping = { + 'input0': dim_partition_dict, + 'input1': copy.deepcopy(dim_partition_dict), + 'output0': copy.deepcopy(dim_partition_dict), + } + + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} = {} {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence, self.axis, + sharding_spec_mapping['input1'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + return strategies_vector + + # for plugin, indice is input0, and weight is input1, which is different from gather node + def _default_gather_strategies(self, device_mesh): + + def add_sharding_strategy(dim_partition_dict_mapping, + vocab_tp_dim=None): + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 
len(sharding_spec_mapping) > 0: + name = '{} = {} {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence, + self.axis, self.num_elementwise_dims, + sharding_spec_mapping['input1'].sharding_sequence) + communication_action_mapping = {} + if vocab_tp_dim is not None: + name += f'_allreduce{DimSpec(vocab_tp_dim)}' + output0_comm_action = CommSpec( + comm_pattern='all_reduce', + sharding_spec=sharding_spec_mapping['output0'], + logical_process_axis=[vocab_tp_dim], + ) + communication_action_mapping[ + 'output0'] = output0_comm_action + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategies_vector.append(sharding_strategy) + + input_id, indice_id = self.input_id, self.indice_id + strategies_vector = StrategiesVector(self) + input_size = len(self.op_data[f'input{input_id}'].shape) + indice_size = len(self.op_data[f'input{indice_id}'].shape) + output_dim = input_size + indice_size - 1 - self.num_elementwise_dims + for strategy in self.predecessor_nodes[input_id].strategies_vector: + # current node's local name input0 -> global name xxx + global_input_name = self.op_data[f'input{input_id}'].name + # global name xxx -> pre node local output name + prenode_local_name = self.predecessor_nodes[ + input_id].global_to_local_op_name[global_input_name] + input_dim_partition_dict = copy.deepcopy( + strategy.sharding_specs[prenode_local_name].dim_partition_dict) + + vocab_tp_dim = input_dim_partition_dict.pop(self.axis, None) + + input_mesh_dims = [] + for dim, mesh_dims in input_dim_partition_dict.items(): + input_mesh_dims += mesh_dims + input_mesh_dims = set(input_mesh_dims) + + for idx_strategy in self.predecessor_nodes[ + indice_id].strategies_vector: + # current node's local name input0 -> global name xxx + global_indice_name = self.op_data[f'input{indice_id}'].name + # global name xxx -> pre node local output name + prenode_local_name = self.predecessor_nodes[ + indice_id].global_to_local_op_name[global_indice_name] + indice_dim_partition_dict = copy.deepcopy( + idx_strategy.sharding_specs[prenode_local_name]. 
+ dim_partition_dict) + + for dim, indice_mesh_dims in idx_strategy.sharding_specs[ + prenode_local_name].dim_partition_dict.items(): + for indice_mesh_dim in indice_mesh_dims: + if indice_mesh_dim in input_mesh_dims: + indice_dim_partition_dict.pop(dim) + break + + out_partition_dict = {} + + for dim in range(output_dim): + if dim < self.axis: + if dim in input_dim_partition_dict: + out_partition_dict[dim] = \ + input_dim_partition_dict[dim] + elif dim >= self.axis and dim < self.axis + indice_size - self.num_elementwise_dims: + indice_dim = dim - self.axis + self.num_elementwise_dims + if indice_dim in indice_dim_partition_dict: + out_partition_dict[dim] = \ + indice_dim_partition_dict[indice_dim] + else: + input_dim = dim - (indice_size - + self.num_elementwise_dims) + 1 + if input_dim in input_dim_partition_dict: + out_partition_dict[dim] = \ + input_dim_partition_dict[input_dim] + + dim_partition_dict_mapping = { + f"input{input_id}": input_dim_partition_dict, + f"input{indice_id}": indice_dim_partition_dict, + "output0": out_partition_dict, + } + add_sharding_strategy(dim_partition_dict_mapping) + + if self.support_vocab_tp and vocab_tp_dim is not None: + vocab_tp_dim_partition_dict = { + **input_dim_partition_dict, + self.axis: vocab_tp_dim, + } + dim_partition_dict_mapping = { + f"input{input_id}": vocab_tp_dim_partition_dict, + f"input{indice_id}": indice_dim_partition_dict, + "output0": out_partition_dict, + } + add_sharding_strategy(dim_partition_dict_mapping, + vocab_tp_dim) + + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py new file mode 100644 index 000000000..470978b8e --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/identity_node.py @@ -0,0 +1,56 @@ +import copy + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Identity(Node): + + def _update_memory_cost(self, strategies): + if not self.is_fake: + super()._update_memory_cost(strategies) + else: + # fake nodes for building block/PP connection + pass + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['input0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + strategies_vector = StrategiesVector(self) + # dim_partition_dict can be the same as previous node if solver's time is a problem + for dim_partition_dict in dim_partition_list: + in0_partition_dict = dim_partition_dict + out_partition_dict = copy.deepcopy(dim_partition_dict) + dim_partition_dict_mapping = { + "input0": in0_partition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} = {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector + + def _profile_sharding_cost(self, strategy, 
device_mesh): + # if same spec id is not 0, identify node is used as same spec id node + if self.same_spec_id == -1: + return super()._profile_sharding_cost(strategy, device_mesh) + else: + return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py new file mode 100644 index 000000000..f8e24cd49 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/input_node.py @@ -0,0 +1,79 @@ +from .node import Node +from .sharding_strategy import StrategiesVector + + +class InputNode(Node): + + def _update_memory_cost(self, strategies): + for strategy in strategies: + if not self.no_memory_footprint: + strategy.const_memory_footprint = strategy.sharding_specs[ + 'output0'].get_max_sharded_size_per_device() + + def __init__(self, tensor): + self._layer = None + self.is_shape_io = False + self._inputs = [] + self._outputs = [] + self.predecessor_nodes = [] + self.predecessor_nodes_out_index = {} + self.successor_nodes = [] + self.op_data = {} + self.global_to_local_op_name = {} + self.is_replicated = tensor.attrs.get("is_replicated", False) + self.same_spec_id = tensor.attrs.get("same_spec_id", -1) + self.no_memory_footprint = tensor.attrs.get("no_memory_footprint", + False) + self.building_block_id = -1 + self.cost_level = -1 + self.stage_type = None + self.in_start_block = None + self.in_end_block = None + self.in_slowest_block = None + output = tensor.copy() + self._outputs.append(output) + self.op_data['output0'] = output + self.global_to_local_op_name[output.name] = 'output0' + + self.sharding_weight = 1.0 + self.resharding_weight = 1.0 + self.pipeline_weight = 0 + self.node_name = tensor.name + self.node_type = 'input_node' + self.num_inputs = 0 + self.num_outputs = 1 + self.dtype = tensor.dtype + self.strategies_vector = [] + self.node_runtime_profiler = None + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = {'output0': dim_partition_dict} + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + sharding_seq = sharding_spec_mapping['output0'].sharding_sequence + sharding_strategy = self._get_sharding_strategy( + name=f'input-op {sharding_seq}', + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + return strategies_vector + + def _profile_sharding_cost(self, strategy, device_mesh): + return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py new file mode 100644 index 000000000..2f2835e20 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/matmul_node.py @@ -0,0 +1,798 @@ +import copy +import operator +from functools import reduce + +import tensorrt as trt + +from ..device_mesh import LogicalDeviceMesh +from ..utils import get_builder_flags +from .comm_spec import 
CommSpec +from .node import Node +from .sharding_spec import DimSpec +from .sharding_strategy import StrategiesVector + + +class MatrixMultiply(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2] + self._generate_bcast_dims(batch_dims, self.get_output(0).shape) + self.op0_transpose = layer.as_trt().op0 == trt.MatrixOperation.TRANSPOSE + self.op1_transpose = layer.as_trt().op1 == trt.MatrixOperation.TRANSPOSE + self.num_out_dims = len(self.get_output(0).shape) + dtypes_str = [ + self.get_input(0).dtype_str, + self.get_input(1).dtype_str, + self.get_output(0).dtype_str + ] + dtypes_size = [ + self.get_input(0).dtype_size, + self.get_input(1).dtype_size, + self.get_output(0).dtype_size + ] + min_idx = dtypes_size.index(min(dtypes_size)) + self.dtype = dtypes_str[min_idx] + layer.to_base_class() + + def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1, device_mesh): + in0_split_dim = -1 if self.op0_transpose else -2 + in1_split_dim = -2 if self.op1_transpose else -1 + name = (f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} = ' + f'{DimSpec(mesh_dim_0)}R x R{DimSpec(mesh_dim_1)}') + dim_partition_dict_mapping = { + "input0": { + in0_split_dim: mesh_dim_0 + }, + "input1": { + in1_split_dim: mesh_dim_1 + }, + "output0": { + -2: mesh_dim_0, + -1: mesh_dim_1 + }, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + strategy = self._get_sharding_strategy(name = name, \ + sharding_spec_mapping = sharding_spec_mapping, \ + communication_action_mapping = {}) + return strategy + + def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1, + device_mesh): + # handle the case SR = SS x SR + name = ( + f'{DimSpec(mesh_dim_0)}R = ' + f'{DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)} x {DimSpec(mesh_dim_1)}R' + f'_allreduce{DimSpec(mesh_dim_1)}') + in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1] + in1_split_dim = -1 if self.op1_transpose else -2 + # get sharding spec mapping + dim_partition_dict_mapping = { + "input0": { + in0_split_dim[0]: mesh_dim_0, + in0_split_dim[1]: mesh_dim_1 + }, + "input1": { + in1_split_dim: mesh_dim_1 + }, + "output0": { + -2: mesh_dim_0 + }, + } + + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + # get communication action mapping + communication_action_mapping = {} + output0_comm_action = CommSpec( + comm_pattern='all_reduce', + sharding_spec=sharding_spec_mapping['output0'], + logical_process_axis=[mesh_dim_1], + ) + communication_action_mapping['output0'] = output0_comm_action + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def _split_both_contract_rs(self, name, rs_dim, rs_mesh_dim, src_spec, + dim_partition_dict_mapping, device_mesh): + output0_comm_action = CommSpec( + comm_pattern='reduce_scatter', + sharding_spec=src_spec, + shard_dim=[rs_dim], + logical_process_axis=[rs_mesh_dim], + ) + rs_out_partition_dict_mapping = copy.deepcopy( + dim_partition_dict_mapping) + rs_out_partition_dict_mapping["output0"][rs_dim] = rs_mesh_dim + rs_out_sharding_spec_mapping = self._to_sharding_spec_mapping( + rs_out_partition_dict_mapping, device_mesh) + if len(rs_out_sharding_spec_mapping) == 0: + return None + + communication_action_mapping = {} + 
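+        # The matmul output entering this helper is still a partial sum over
+        # `rs_mesh_dim` (the contraction dimension is sharded on that mesh axis).
+        # The reduce_scatter registered below completes that reduction and
+        # simultaneously shards dimension `rs_dim` over `rs_mesh_dim`, so each
+        # device keeps only its slice of the fully reduced output instead of the
+        # full copy that an all_reduce variant would keep.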
communication_action_mapping['output0'] = output0_comm_action + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=rs_out_sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def _split_lhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1, + device_mesh): + # handle the case SS = SS x SR -> reduce_scatter + in0_split_dim = [-1, -2] if self.op0_transpose else [-2, -1] + in1_split_dim = -1 if self.op1_transpose else -2 + # get sharding spec mapping + dim_partition_dict_mapping = { + "input0": { + in0_split_dim[0]: mesh_dim_0, + in0_split_dim[1]: mesh_dim_1 + }, + "input1": { + in1_split_dim: mesh_dim_1 + }, + "output0": { + -2: mesh_dim_0, + }, + } + mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(mm_out_sharding_spec_mapping) == 0: + return [] + strategies = [] + for rs_dim in range(self.num_out_dims): + if rs_dim != self.num_out_dims - 2: + name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ + 'R' + ] * self.num_out_dims, ['R'] * self.num_out_dims + name_in0[-2], name_in0[-1] = str(DimSpec(mesh_dim_0)), str( + DimSpec(mesh_dim_1)) + name_in1[-2] = str(DimSpec(mesh_dim_1)) + name_out0[-2], name_out0[rs_dim] = str( + DimSpec(mesh_dim_0)), str(DimSpec(mesh_dim_1)) + name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( + name_in1), ', '.join(name_out0) + name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]' + f'_reducescatter{(rs_dim, DimSpec(mesh_dim_1))}') + ret = self._split_both_contract_rs( + name, rs_dim, mesh_dim_1, + mm_out_sharding_spec_mapping['output0'], + dim_partition_dict_mapping, device_mesh) + if ret: + strategies.append(ret) + return strategies + + def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1, + device_mesh): + name = ( + f'R{DimSpec(mesh_dim_1)} = ' + f'R{DimSpec(mesh_dim_0)} x {DimSpec(mesh_dim_0)}{DimSpec(mesh_dim_1)}' + f'_allreduce{DimSpec(mesh_dim_0)}') + in0_split_dim = -2 if self.op0_transpose else -1 + in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1] + # get sharding specs + dim_partition_dict_mapping = { + "input0": { + in0_split_dim: mesh_dim_0 + }, + "input1": { + in1_split_dim[0]: mesh_dim_0, + in1_split_dim[1]: mesh_dim_1 + }, + "output0": { + -1: mesh_dim_1 + }, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + # get communication actions + communication_action_mapping = {} + output0_comm_action = CommSpec( + comm_pattern='all_reduce', + sharding_spec=sharding_spec_mapping['output0'], + logical_process_axis=[mesh_dim_0], + ) + communication_action_mapping['output0'] = output0_comm_action + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def _split_rhs_space_both_contract_rs(self, mesh_dim_0, mesh_dim_1, + device_mesh): + in0_split_dim = -2 if self.op0_transpose else -1 + in1_split_dim = [-1, -2] if self.op1_transpose else [-2, -1] + # get sharding specs + dim_partition_dict_mapping = { + "input0": { + in0_split_dim: mesh_dim_0 + }, + "input1": { + in1_split_dim[0]: mesh_dim_0, + in1_split_dim[1]: mesh_dim_1 + }, + "output0": { + -1: mesh_dim_1 + }, + } + mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(mm_out_sharding_spec_mapping) == 0: + return [] + strategies = [] + for rs_dim in 
range(self.num_out_dims): + if rs_dim != self.num_out_dims - 1: + name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ + 'R' + ] * self.num_out_dims, ['R'] * self.num_out_dims + name_in1[-2], name_in1[-1] = str(DimSpec(mesh_dim_0)), str( + DimSpec(mesh_dim_1)) + name_in0[-1] = str(DimSpec(mesh_dim_0)) + name_out0[-1], name_out0[rs_dim] = str( + DimSpec(mesh_dim_1)), str(DimSpec(mesh_dim_0)) + name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( + name_in1), ', '.join(name_out0) + name = (f'[{name_out0}] = [{name_in0}] x [{name_in1}]' + f'_reducescatter{(rs_dim, DimSpec(mesh_dim_0))}') + ret = self._split_both_contract_rs( + name, rs_dim, mesh_dim_0, + mm_out_sharding_spec_mapping['output0'], + dim_partition_dict_mapping, device_mesh) + if ret: + strategies.append(ret) + return strategies + + def _recompute_split_both_contract(self, mesh_dim, device_mesh): + name = (f'RR = R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R' + f'_allreduce{DimSpec(mesh_dim)}') + in0_split_dim = -2 if self.op0_transpose else -1 + in1_split_dim = -1 if self.op1_transpose else -2 + dim_partition_dict_mapping = { + "input0": { + in0_split_dim: mesh_dim + }, + "input1": { + in1_split_dim: mesh_dim + }, + "output0": {}, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + + # get communication action + communication_action_mapping = {} + output0_comm_action = CommSpec( + comm_pattern='all_reduce', + sharding_spec=sharding_spec_mapping['output0'], + logical_process_axis=[mesh_dim], + ) + communication_action_mapping['output0'] = output0_comm_action + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def _recompute_split_both_contract_rs(self, mesh_dim, device_mesh): + name = (f'{DimSpec(mesh_dim)}R = ' + f'R{DimSpec(mesh_dim)} x {DimSpec(mesh_dim)}R' + f'_reducescatter0_{DimSpec(mesh_dim)}') + in0_split_dim = -2 if self.op0_transpose else -1 + in1_split_dim = -1 if self.op1_transpose else -2 + dim_partition_dict_mapping = { + "input0": { + in0_split_dim: mesh_dim + }, + "input1": { + in1_split_dim: mesh_dim + }, + "output0": {}, + } + mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(mm_out_sharding_spec_mapping) == 0: + return [] + + strategies = [] + for rs_dim in range(self.num_out_dims): + name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ + 'R' + ] * self.num_out_dims, ['R'] * self.num_out_dims + name_in0[-1], name_in1[-2], name_out0[rs_dim] = str( + DimSpec(mesh_dim)), str(DimSpec(mesh_dim)), str( + DimSpec(mesh_dim)) + name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( + name_in1), ', '.join(name_out0) + name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim))}' + ret = self._split_both_contract_rs( + name, rs_dim, mesh_dim, mm_out_sharding_spec_mapping['output0'], + dim_partition_dict_mapping, device_mesh) + if ret: + strategies.append(ret) + return strategies + + def _split_rhs_space_only(self, mesh_dim, device_mesh): + name = f'R{DimSpec(mesh_dim)} = RR x R{DimSpec(mesh_dim)}' + in1_split_dim = -2 if self.op1_transpose else -1 + # get sharding spec + dim_partition_dict_mapping = { + "input0": {}, + "input1": { + in1_split_dim: mesh_dim + }, + "output0": { + -1: mesh_dim + }, + } + # We don't have to do anything special for bias here, because + # the bias is 
already the same sharding spec as the output0. + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _split_lhs_space_only(self, mesh_dim, device_mesh): + name = f'{DimSpec(mesh_dim)}R = {DimSpec(mesh_dim)}R x RR' + in0_split_dim = -1 if self.op0_transpose else -2 + # get sharding spec + dim_partition_dict_mapping = { + "input0": { + in0_split_dim: mesh_dim + }, + "input1": {}, + "output0": { + -2: mesh_dim + }, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _non_split(self, device_mesh): + name = 'RR = RR x RR' + # get sharding spec + dim_partition_dict_mapping = { + "input0": {}, + "input1": {}, + "output0": {}, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _split_one_batch_dim(self, batch_dim, mesh_dim, device_mesh): + name = ( + f'{DimSpec(mesh_dim)}b{batch_dim}RR = ' + f'{DimSpec(mesh_dim)}b{batch_dim}RR x {DimSpec(mesh_dim)}b{batch_dim}RR' + ) + in0_data = self.op_data['input0'] + in1_data = self.op_data['input1'] + + batch_partition_dict = {batch_dim: mesh_dim} + in0_parition_dict = self._recover_bcast_partition_dict( + batch_partition_dict, in0_data) + in1_parition_dict = self._recover_bcast_partition_dict( + batch_partition_dict, in1_data) + out_partition_dict = {batch_dim: mesh_dim} + # TODO:[KDuan] Double check if MatrixMultiplication's output has bcast in dim + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "input1": in1_parition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _split_two_batch_dims(self, batch_dim0, batch_dim1, mesh_dim0, + mesh_dim1, device_mesh): + name = ( + f'{DimSpec(mesh_dim0)}b{batch_dim0}{DimSpec(mesh_dim1)}b{batch_dim1}RR = ' + f'{DimSpec(mesh_dim0)}b{batch_dim0}RR x {DimSpec(mesh_dim1)}b{batch_dim1}RR' + ) + in0_data = self.op_data['input0'] + in1_data = self.op_data['input1'] + + in0_parition_dict = {} + if batch_dim0 not in in0_data.attrs["broadcast_dims"]: + in0_parition_dict[batch_dim0] = mesh_dim0 + if batch_dim1 not in in0_data.attrs["broadcast_dims"]: + in0_parition_dict[batch_dim1] = mesh_dim1 + + in1_parition_dict = {} + if batch_dim0 not in in1_data.attrs["broadcast_dims"]: + in1_parition_dict[batch_dim0] = mesh_dim0 + if batch_dim1 not in in1_data.attrs["broadcast_dims"]: + in1_parition_dict[batch_dim1] = mesh_dim1 + + batch_partition_dict = {batch_dim0: mesh_dim0, batch_dim1: mesh_dim1} + in0_parition_dict = self._recover_bcast_partition_dict( + batch_partition_dict, in0_data) + in1_parition_dict = self._recover_bcast_partition_dict( + batch_partition_dict, in1_data) + out_partition_dict = {batch_dim0: 
mesh_dim0, batch_dim1: mesh_dim1} + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "input1": in1_parition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _split_batch_dim_lhs_space(self, batch_dim, mesh_dim0, mesh_dim1, + device_mesh): + + name = ( + f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R = ' + f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R x {DimSpec(mesh_dim0)}b{batch_dim}RR' + ) + in0_data = self.op_data['input0'] + in1_data = self.op_data['input1'] + in0_parition_dict = {batch_dim: mesh_dim0} + in1_parition_dict = {batch_dim: mesh_dim0} + in0_lhs_split_dim = -1 if self.op0_transpose else -2 + in0_parition_dict[in0_lhs_split_dim] = mesh_dim1 + + in0_parition_dict = self._recover_bcast_partition_dict( + in0_parition_dict, in0_data) + in1_parition_dict = self._recover_bcast_partition_dict( + in1_parition_dict, in1_data) + out_partition_dict = {batch_dim: mesh_dim0, -2: mesh_dim1} + + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "input1": in1_parition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _split_batch_dim_rhs_space(self, batch_dim, mesh_dim0, mesh_dim1, + device_mesh): + + name = ( + f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} = ' + f'{DimSpec(mesh_dim0)}b{batch_dim}RR x {DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)}' + ) + in0_data = self.op_data['input0'] + in1_data = self.op_data['input1'] + in0_parition_dict = {batch_dim: mesh_dim0} + in1_parition_dict = {batch_dim: mesh_dim0} + + in1_rhs_split_dim = -2 if self.op1_transpose else -1 + in1_parition_dict[in1_rhs_split_dim] = mesh_dim1 + + in0_parition_dict = self._recover_bcast_partition_dict( + in0_parition_dict, in0_data) + in1_parition_dict = self._recover_bcast_partition_dict( + in1_parition_dict, in1_data) + out_partition_dict = {batch_dim: mesh_dim0, -1: mesh_dim1} + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "input1": in1_parition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + + def _split_batch_dim_both_contract(self, batch_dim, mesh_dim0, mesh_dim1, + device_mesh): + + name = ( + f'{DimSpec(mesh_dim0)}b{batch_dim}RR = ' + f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x ' + f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}' + ) + in0_data = self.op_data['input0'] + in1_data = self.op_data['input1'] + in0_parition_dict = {batch_dim: mesh_dim0} + in1_parition_dict = {batch_dim: mesh_dim0} + + in0_contract_dim = -2 if self.op0_transpose else -1 + in1_contract_dim = -1 if self.op1_transpose else -2 + in0_parition_dict[in0_contract_dim] = mesh_dim1 + in1_parition_dict[in1_contract_dim] = mesh_dim1 + + in0_parition_dict = 
self._recover_bcast_partition_dict( + in0_parition_dict, in0_data) + in1_parition_dict = self._recover_bcast_partition_dict( + in1_parition_dict, in1_data) + out_partition_dict = {batch_dim: mesh_dim0} + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "input1": in1_parition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(sharding_spec_mapping) == 0: + return None + + # get communication actions + communication_action_mapping = {} + output0_comm_action = CommSpec( + comm_pattern='all_reduce', + sharding_spec=sharding_spec_mapping['output0'], + logical_process_axis=[mesh_dim1], + ) + communication_action_mapping['output0'] = output0_comm_action + return self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + + def _split_batch_dim_both_contract_rs(self, batch_dim, mesh_dim0, mesh_dim1, + device_mesh): + + name = ( + f'{DimSpec(mesh_dim0)}b{batch_dim}RR = ' + f'{DimSpec(mesh_dim0)}b{batch_dim}R{DimSpec(mesh_dim1)} x ' + f'{DimSpec(mesh_dim0)}b{batch_dim}{DimSpec(mesh_dim1)}R_AR{mesh_dim1}' + ) + in0_data = self.op_data['input0'] + in1_data = self.op_data['input1'] + in0_parition_dict = {batch_dim: mesh_dim0} + in1_parition_dict = {batch_dim: mesh_dim0} + + in0_contract_dim = -2 if self.op0_transpose else -1 + in1_contract_dim = -1 if self.op1_transpose else -2 + in0_parition_dict[in0_contract_dim] = mesh_dim1 + in1_parition_dict[in1_contract_dim] = mesh_dim1 + + in0_parition_dict = self._recover_bcast_partition_dict( + in0_parition_dict, in0_data) + in1_parition_dict = self._recover_bcast_partition_dict( + in1_parition_dict, in1_data) + out_partition_dict = {batch_dim: mesh_dim0} + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "input1": in1_parition_dict, + "output0": out_partition_dict, + } + mm_out_sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if len(mm_out_sharding_spec_mapping) == 0: + return [] + + strategies = [] + for rs_dim in range(self.num_out_dims): + if rs_dim != batch_dim: + name_in0, name_in1, name_out0 = ['R'] * self.num_out_dims, [ + 'R' + ] * self.num_out_dims, ['R'] * self.num_out_dims + name_in0[batch_dim], name_in0[-1] = str( + DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1)) + name_in1[batch_dim], name_in1[-2] = str( + DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1)) + name_in1[batch_dim], name_out0[rs_dim] = str( + DimSpec(mesh_dim0)), str(DimSpec(mesh_dim1)) + name_in0, name_in1, name_out0 = ', '.join(name_in0), ', '.join( + name_in1), ', '.join(name_out0) + name = f'[{name_out0}] = [{name_in0}] x [{name_in1}]_reducescatter{(rs_dim, DimSpec(mesh_dim1))}' + ret = self._split_both_contract_rs( + name, rs_dim, mesh_dim1, + mm_out_sharding_spec_mapping['output0'], + dim_partition_dict_mapping, device_mesh) + if ret: + strategies.append(ret) + return strategies + + def _dp_strategies(self, device_mesh): + strategies = [] + # S0R = S0R x RR + strategies.append(self._split_lhs_space_only([0], device_mesh)) + # S1R = S1R x RR + strategies.append(self._split_lhs_space_only([1], device_mesh)) + # S01R = S01R x RR + strategies.append(self._split_lhs_space_only([0, 1], device_mesh)) + return strategies + + def _tp_strategies(self, device_mesh: LogicalDeviceMesh): + strategies = [] + # RR = RS x SR _ AR + strategies.append(self._recompute_split_both_contract([0], device_mesh)) + 
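# ([0], [1] and [0, 1] name the logical mesh axes used to shard the contracted
+ # dimension; e.g. [0, 1] splits K across every device of a 2-D mesh and the
+ # partial matmul results are then all-reduced over both axes) +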
strategies.append(self._recompute_split_both_contract([1], device_mesh)) + strategies.append( + self._recompute_split_both_contract([0, 1], device_mesh)) + + if device_mesh.config.enable_reduce_scatter: + # RS x SR _ reduce scatter + strategies.extend( + self._recompute_split_both_contract_rs([0], device_mesh)) + strategies.extend( + self._recompute_split_both_contract_rs([1], device_mesh)) + strategies.extend( + self._recompute_split_both_contract_rs([0, 1], device_mesh)) + + # RS = RR x RS + strategies.append(self._split_rhs_space_only([0], device_mesh)) + strategies.append(self._split_rhs_space_only([1], device_mesh)) + strategies.append(self._split_rhs_space_only([0, 1], device_mesh)) + + # RS = RS x SS _ AR + strategies.append( + self._split_rhs_space_both_contract([0], [1], device_mesh)) + strategies.append( + self._split_rhs_space_both_contract([1], [0], device_mesh)) + + if device_mesh.config.enable_reduce_scatter: + # RS x SS _ reduce scatter + strategies.extend( + self._split_rhs_space_both_contract_rs([0], [1], device_mesh)) + strategies.extend( + self._split_rhs_space_both_contract_rs([1], [0], device_mesh)) + + return strategies + + def _mix_strategies(self, device_mesh): + strategies = [] + + # SR = SS x SR_AR + strategies.append( + self._split_lhs_space_both_contract([0], [1], device_mesh)) + strategies.append( + self._split_lhs_space_both_contract([1], [0], device_mesh)) + if device_mesh.config.enable_reduce_scatter: + # RS x SS _ reduce scatter + strategies.extend( + self._split_lhs_space_both_contract_rs([0], [1], device_mesh)) + strategies.extend( + self._split_lhs_space_both_contract_rs([1], [0], device_mesh)) + # SS = SR x RS + strategies.append(self._split_lhs_space_rhs_space([0], [1], + device_mesh)) + strategies.append(self._split_lhs_space_rhs_space([0], [1], + device_mesh)) + + # RR = RR x RR + strategies.append(self._non_split(device_mesh)) + return strategies + + def _bmm_strategies(self, device_mesh: LogicalDeviceMesh): + strategies = [] + bmm_dim = len(self.op_data['output0'].shape) + if bmm_dim >= 3: + for batch_dim in range(0, bmm_dim - 2): + strategies.append( + self._split_one_batch_dim(batch_dim, [0], device_mesh)) + strategies.append( + self._split_one_batch_dim(batch_dim, [1], device_mesh)) + strategies.append( + self._split_one_batch_dim(batch_dim, [0, 1], device_mesh)) + + strategies.append( + self._split_batch_dim_lhs_space(batch_dim, [0], [1], + device_mesh)) + strategies.append( + self._split_batch_dim_lhs_space(batch_dim, [1], [0], + device_mesh)) + + strategies.append( + self._split_batch_dim_rhs_space(batch_dim, [0], [1], + device_mesh)) + strategies.append( + self._split_batch_dim_rhs_space(batch_dim, [1], [0], + device_mesh)) + + strategies.append( + self._split_batch_dim_both_contract(batch_dim, [0], [1], + device_mesh)) + strategies.append( + self._split_batch_dim_both_contract(batch_dim, [1], [0], + device_mesh)) + if device_mesh.config.enable_reduce_scatter: + strategies.extend( + self._split_batch_dim_both_contract_rs( + batch_dim, [0], [1], device_mesh)) + strategies.extend( + self._split_batch_dim_both_contract_rs( + batch_dim, [1], [0], device_mesh)) + if bmm_dim >= 4: + for batch_dim0 in range(0, bmm_dim - 2): + for batch_dim1 in range(0, bmm_dim - 2): + if batch_dim0 != batch_dim1: + strategies.append( + self._split_two_batch_dims( + batch_dim0, batch_dim1, [0], [1], + device_mesh)) + + return strategies + + def _collect_strategies(self, device_mesh): + strategies_vector = StrategiesVector(self) + dp_strategies = 
self._dp_strategies(device_mesh) + tp_strategies = self._tp_strategies(device_mesh) + mix_strategies = self._mix_strategies(device_mesh) + bmm_strategies = self._bmm_strategies(device_mesh) + strategies_vector.extend(dp_strategies) + strategies_vector.extend(tp_strategies) + strategies_vector.extend(mix_strategies) + strategies_vector.extend(bmm_strategies) + return strategies_vector + + def is_fp16(self): + builder_flags = get_builder_flags() + return builder_flags & (1 << int(trt.BuilderFlag.FP16)) != 0 + + def _get_math_time(self, strategy, device_mesh): + shape_in0 = strategy.sharding_specs[ + 'input0'].get_sharded_shape_per_device() + shape_out = strategy.sharding_specs[ + 'output0'].get_sharded_shape_per_device() + m, n = shape_out[-2], shape_out[-1] + batches = shape_out[:-2] + k = shape_in0[-2] if self.op0_transpose else shape_in0[-1] + macs_shape = batches + [m, n, k] + macs = reduce(operator.mul, macs_shape, 1) * 2 + config = device_mesh.config + cluster_info = device_mesh.cluster_info + dtype = self.dtype + # For fp16 matmul ops that use_fp32_acc=True. + # They are mistaken for fp32 ops since all of their IO tensors use fp32 dtype. + if self.is_fp16() and self.dtype == "float32": + dtype = "float16" + math_throughput_tflops = getattr(cluster_info.math_throughput, dtype) + assert math_throughput_tflops != 0, \ + "Undefined {} math throughput of cluster {}".format(dtype, config.cluster_key) + math_time = macs / math_throughput_tflops * 1e-6 * cluster_info.math_efficiency + return math_time + + def _update_memory_cost(self, strategies): + super()._update_memory_cost(strategies) + # For fp16 matmul ops that use_fp32_acc=True. + # Their memory footprints are calculated based on fp32 IO tensors. + # Actually they will use fp16 IO tensors after fused. + # So we divide all the memory footprints by 2. 
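+ # (hypothetical example: a [4096, 4096] activation booked as fp32 is 64 MiB,
+ # but the fused fp16 tensor actually moved at runtime is only ~32 MiB)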
+ if self.is_fp16() and self.dtype == "float32": + for strategy in strategies: + strategy.inout_memory_footprint /= 2 + strategy.peak_memory_footprint /= 2 + strategy.comm_buff_memory_footprint /= 2 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/node.py b/tensorrt_llm/auto_parallel/tensor_parallel/node.py new file mode 100644 index 000000000..5ea867041 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/node.py @@ -0,0 +1,376 @@ +from abc import ABC + +from ..config import CostModel +from ..device_mesh import LogicalDeviceMesh +from .comm_spec import CommSpec +from .sharding_spec import ShardingSpec +from .sharding_strategy import ShardingStrategy, StrategiesVector + + +class Node(ABC): + + def __init__(self, layer): + self._layer = layer + self.is_shape_io = self._layer.is_shape_io + self._inputs = [] + self._outputs = [] + self.predecessor_nodes = [] + self.predecessor_nodes_out_index = {} + self.successor_nodes = [] + self.op_data = {} + self.global_to_local_op_name = {} + self.num_inputs = 0 + self.is_replicated = layer.attrs.get("is_replicated", False) + self.same_spec_id = layer.attrs.get("same_spec_id", -1) + self.is_fake = self.same_spec_id != -1 + self.building_block_id = layer.attrs.get("building_block_id", -1) + self.cost_level = -1 + self.stage_type = layer.attrs.get("stage_type", None) + self.in_start_block = layer.attrs.get("in_start_block", False) + self.in_end_block = layer.attrs.get("in_end_block", False) + self.in_slowest_block = layer.attrs.get("in_slowest_block", False) + for i, input in enumerate(layer.inputs): + if input is None: + self._inputs.append(None) + self.op_data[f'input{i}'] = None + continue + input = input.copy() + input.attrs["broadcast_dims"] = [] + self._inputs.append(input) + self.op_data[f'input{i}'] = input + self.global_to_local_op_name[input.name] = f'input{i}' + + for i, output in enumerate(layer.outputs): + output = output.copy() + output.attrs["broadcast_dims"] = [] + self._outputs.append(output) + self.op_data[f'output{i}'] = output + self.global_to_local_op_name[output.name] = f'output{i}' + + self.sharding_weight = 1.0 + self.resharding_weight = 1.0 + self.pipeline_weight = 0 + self.node_name = layer.name + self.node_type = 'normal_node' + self.num_inputs = layer.num_inputs + self.num_outputs = layer.num_outputs + self.dtype = layer.as_trt().precision + self.strategies_vector = [] + self.node_runtime_profiler = None + + def post_init(self, graph): + for input in self.inputs: + if input is None: + self.predecessor_nodes.append(None) + continue + if input.producer is None: + predecessor_node = graph.get_node(input.name) + self.predecessor_nodes.append(predecessor_node) + self.predecessor_nodes_out_index[predecessor_node] = 0 + predecessor_node.successor_nodes.append(self) + else: + predecessor_node = graph.get_node(input.producer.name) + self.predecessor_nodes.append(predecessor_node) + self.predecessor_nodes_out_index[ + predecessor_node] = input.output_index + predecessor_node.successor_nodes.append(self) + + @property + def layer(self): + return self._layer + + def get_input(self, index): + return self._inputs[index] + + @property + def inputs(self): + return self._inputs + + def get_output(self, index): + return self._outputs[index] + + @property + def outputs(self): + return self._outputs + + def collect_strategies(self, device_mesh): + strategies_vector = self._collect_strategies(device_mesh) + strategies_vector = self._post_process(strategies_vector) + self._update_sharding_cost(strategies_vector, device_mesh) + 
self.strategies_vector = strategies_vector + return self.strategies_vector + + def _set_strategy(self, strategy, device_mesh): + strategies_vector = StrategiesVector(self) + if strategy is None: + dim_partition_dict_mapping = {} + for i in range(self.num_inputs): + dim_partition_dict_mapping[f'input{i}'] = {} + for i in range(self.num_outputs): + dim_partition_dict_mapping[f'output{i}'] = {} + + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + assert 0 != len( + sharding_spec_mapping + ), f'failed to set default(all Replicate) strategy for node {self.node_name}' + name = 'RRs' + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + else: + sharding_specs_map = strategy.sharding_specs + comm_specs_map = strategy.communication_actions + dim_partition_dict_mapping = {} + for op_name, sharding_spec in sharding_specs_map.items(): + dim_partition_dict_mapping[ + op_name] = sharding_spec.dim_partition_dict + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + assert 0 != len( + sharding_spec_mapping + ), f'failed to set strategy for node {self.node_name}' + comm_specs_mapping = {} + if len(comm_specs_map) > 0: + for op_name, comm_spec in comm_specs_map.items(): + comm_specs_mapping[op_name] = CommSpec( + comm_pattern=comm_spec.comm_pattern, + sharding_spec=sharding_spec_mapping[op_name], + logical_process_axis=comm_spec.logical_process_axis, + ) + strategies_vector.append( + self._get_sharding_strategy( + name=strategy.name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=comm_specs_mapping)) + return strategies_vector + + def set_strategy(self, strategy, device_mesh): + strategies_vector = self._set_strategy(strategy, device_mesh) + strategies_vector = self._post_process(strategies_vector) + self._update_sharding_cost(strategies_vector, device_mesh) + self.strategies_vector = strategies_vector + return self.strategies_vector + + def update_resharding_cost(self): + self._update_resharding_cost(self.strategies_vector) + return self.strategies_vector + + def _to_sharding_spec_mapping(self, dim_partition_dict_mapping, + device_mesh): + results = {} + for op_data_name, dim_partition_dict in dim_partition_dict_mapping.items( + ): + if op_data_name in self.op_data: + op_data = self.op_data[op_data_name] + + def _to_sharding_spec(op_data, dim_partition_dict): + sharding_spec = ShardingSpec( + device_mesh, + op_data.dtype_str_size, [*op_data.shape], + [*op_data.max_shape], [*op_data.raw_shape], + dim_partition_dict=dim_partition_dict) + if sharding_spec.sanity_check(): + return sharding_spec + else: + return None + + sharding_spec = _to_sharding_spec(op_data, dim_partition_dict) + if sharding_spec: + results[op_data_name] = sharding_spec + else: + return {} + return results + + def _get_sharding_strategy(self, name, sharding_spec_mapping, + communication_action_mapping): + return ShardingStrategy( + name=name, + sharding_specs=sharding_spec_mapping, + communication_actions=communication_action_mapping, + ) + + def _remove_duplicated_strategy(self, strategies_vector): + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + for strategy in remove_list: + strategies_vector.remove(strategy) + + def 
_post_process(self, strategies_vector): + # TODO:[KDuan] deal with transpose and dimension 1 problem in ClossalAI, which have been processed before + for i in range(len(strategies_vector) - 1, -1, -1): + if strategies_vector[i] is None: + strategies_vector.pop(i) + + self._remove_duplicated_strategy(strategies_vector) + return strategies_vector + + def _profile_sharding_cost(self, strategy, device_mesh: LogicalDeviceMesh): + elapsed_time = self.node_runtime_profiler.runtime_profile( + self.layer, {}, {}, strategy, device_mesh) + return elapsed_time + + def _model_sharding_cost_from_s_curve(self, strategy, + device_mesh: LogicalDeviceMesh): + ''' + [ToDo][KDuan] preprofile the s_curve + ''' + sharding_cost = 0.0 + return sharding_cost + + # this method might be overwritten by some Ops + def _get_math_time(self, strategy, device_mesh: LogicalDeviceMesh): + return 0.0 + + # this method might be overwritten by some Ops + def _get_memory_time(self, strategy, device_mesh: LogicalDeviceMesh): + memory_time = (strategy.inout_memory_footprint / + device_mesh.cluster_info.memory_bw * 1e-3 * + device_mesh.cluster_info.memory_efficiency) + return memory_time + + def _model_sharding_cost_from_alpha_beta(self, strategy, + device_mesh: LogicalDeviceMesh): + math_time = self._get_math_time(strategy, device_mesh) + mem_time = self._get_memory_time(strategy, device_mesh) + return max(math_time, mem_time) + + def _get_communication_cost(self, strategy): + total_comm_cost = 0.0 + for op_data_name, comm_spec in strategy.communication_actions.items(): + comm_cost = comm_spec.get_comm_cost() + total_comm_cost = total_comm_cost + comm_cost + return total_comm_cost + + def _update_sharding_cost(self, strategies, device_mesh: LogicalDeviceMesh): + self._update_memory_cost(strategies) + + if device_mesh.config.sharding_cost_model == CostModel.ALPHA_BETA: + for strategy in strategies: + strategy.sharding_cost = self._model_sharding_cost_from_alpha_beta( + strategy, device_mesh) + elif device_mesh.config.sharding_cost_model == CostModel.S_CURVE: + for strategy in strategies: + strategy.sharding_cost = self._model_sharding_cost_from_s_curve( + strategy, device_mesh) + elif device_mesh.config.sharding_cost_model == CostModel.PROFILE: + for strategy in strategies: + strategy.alpha_beta_cost = self._model_sharding_cost_from_alpha_beta( + strategy, device_mesh) + if self.is_shape_io: + strategy.sharding_cost = strategy.alpha_beta_cost + else: + strategy.sharding_cost = self._profile_sharding_cost( + strategy, device_mesh) + elif device_mesh.config.sharding_cost_model == CostModel.ZERO: + for strategy in strategies: + strategy.sharding_cost = 0.0 + else: + assert False, 'unsupport sharding cost model option: {}'.format( + device_mesh.config.sharding_cost_model) + + for strategy in strategies: + strategy.communication_cost = self._get_communication_cost(strategy) + + def _compute_resharding_cost(self, pre_sharding_sepc, cur_sharding_spec, + op_data): + transform_path, comm_action_sequence, resharding_cost = cur_sharding_spec.device_mesh.shape_consistency_manager.shape_consistency( + pre_sharding_sepc, cur_sharding_spec) + return (transform_path, comm_action_sequence, resharding_cost) + + def _update_resharding_cost(self, strategies): + for strategy in strategies: + resharding_costs = {} + for pre_node, out_index in self.predecessor_nodes_out_index.items(): + if pre_node is None: + continue + pre_node_out_data_name = pre_node.get_output(out_index).name + pre_node_out_data_lname = pre_node.global_to_local_op_name[ + 
pre_node_out_data_name] + if pre_node_out_data_name not in self.global_to_local_op_name: + print(f"pre_node_out_data_name = {pre_node_out_data_name}") + continue + cur_node_inp_data_lname = self.global_to_local_op_name[ + pre_node_out_data_name] + cur_sharding_spec = strategy.sharding_specs[ + cur_node_inp_data_lname] + + pre_node_out_sharding_specs = [] + for pre_strategy in pre_node.strategies_vector: + pre_node_out_sharding_specs.append( + pre_strategy.sharding_specs[pre_node_out_data_lname]) + + if pre_node not in resharding_costs: + resharding_costs[pre_node.node_name] = [] + for prev_sharding_spec in pre_node_out_sharding_specs: + resharding_cost = self._compute_resharding_cost( + prev_sharding_spec, cur_sharding_spec, + self.op_data[cur_node_inp_data_lname]) + resharding_costs[pre_node.node_name].append(resharding_cost) + strategy.resharding_costs = resharding_costs + + def _enumerate_all_possible_1d_sharding(self, mesh_dim, dim_size): + dim_partition_list = [] + for i in range(dim_size): + dim_partition_list.append({i: mesh_dim}) + return dim_partition_list + + def _enumerate_all_possible_2d_sharding(self, mesh_dim0, mesh_dim1, + dim_size): + dim_partition_list = [] + for i in range(dim_size): + for j in range(dim_size): + if i != j: + dim_partition_list.append({i: mesh_dim0, j: mesh_dim1}) + return dim_partition_list + + def _update_memory_cost(self, strategies): + for strategy in strategies: + inout_memory_footprint, max_inout_memory_footprint = 0.0, 0.0 + for spec in strategy.sharding_specs.values(): + inout_memory_footprint += spec.get_sharded_size_per_device() + max_inout_memory_footprint += spec.get_max_sharded_size_per_device( + ) + + # the communication happens + comm_buffer_footprint, max_comm_buffer_footprint = 0.0, 0.0 + for comm_spec in strategy.communication_actions.values(): + comm_buffer_footprint += comm_spec.get_mem_cost() + max_comm_buffer_footprint += comm_spec.get_max_mem_cost() + + # when doing the output0 comm action, the input buffer should be released, the buffer is used to estimate the memory time + # rather than memory usage + strategy.inout_memory_footprint = inout_memory_footprint + + strategy.comm_buff_memory_footprint = comm_buffer_footprint + strategy.peak_memory_footprint = max(max_inout_memory_footprint, + max_comm_buffer_footprint) + + # The const memory (weight) is recorded in constant layers and should be accumulated + strategy.const_memory_footprint = 0.0 + + def _generate_bcast_dims(self, batch_dims, out_data_shape): + for output in self.outputs: + if output.broadcast_across_batch: + for bs in batch_dims: + if output.shape[ + bs] == 1 and output.shape[bs] != out_data_shape[bs]: + output.attrs["broadcast_dims"].append(bs) + + def _recover_bcast_partition_dict(self, partition_dict, op_data): + ret = {} + for data_dim, mesh_dim in partition_dict.items(): + if data_dim not in op_data.attrs[ + "broadcast_dims"] and data_dim + len( + op_data.shape) not in op_data.attrs[ + "broadcast_dims"] and op_data.shape[data_dim] != 1: + ret[data_dim] = mesh_dim + return ret diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py new file mode 100644 index 000000000..3cfed5023 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/normalization_node.py @@ -0,0 +1,60 @@ +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Normalization(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + self.axes = 
layer.as_trt().axes + self.weight_bias_dim_base = 0 + layer.to_base_class() + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['input0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + shard_reduction_axes = False + for dim in range(len(self.get_input(0).shape)): + if (self.axes & (1 << dim)) and dim in dim_partition_dict: + shard_reduction_axes = True + break + if shard_reduction_axes: + continue + dim_partition_dict_mapping = { + "input0": dim_partition_dict, + "output0": dim_partition_dict, + } + if self.num_inputs >= 2: + dim_partition_dict_mapping['input1'] = {} + if self.num_inputs >= 3: + dim_partition_dict_mapping['input2'] = {} + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} = {} scale {}, bias {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence, + sharding_spec_mapping['input1'].sharding_sequence + if self.num_inputs >= 2 else 'None', + sharding_spec_mapping['input2'].sharding_sequence + if self.num_inputs >= 3 else 'None', + ) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py new file mode 100644 index 000000000..09d500797 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/output_node.py @@ -0,0 +1,79 @@ +from .node import Node +from .sharding_strategy import StrategiesVector + + +class OuputNode(Node): + + def _update_memory_cost(self, strategies): + for strategy in strategies: + if not self.no_memory_footprint: + strategy.const_memory_footprint = strategy.sharding_specs[ + 'input0'].get_max_sharded_size_per_device() + + def __init__(self, tensor): + self._layer = None + self.is_shape_io = False + self._inputs = [] + self._outputs = [] + self.predecessor_nodes = [] + self.predecessor_nodes_out_index = {} + self.successor_nodes = [] + self.op_data = {} + self.global_to_local_op_name = {} + self.is_replicated = tensor.attrs.get("is_replicated", False) + self.same_spec_id = tensor.attrs.get("same_spec_id", -1) + self.no_memory_footprint = tensor.attrs.get("no_memory_footprint", + False) + self.building_block_id = -1 + self.cost_level = -1 + self.stage_type = None + self.in_start_block = None + self.in_end_block = None + self.in_slowest_block = None + input = tensor.copy() + self._inputs.append(input) + self.op_data['input0'] = input + self.global_to_local_op_name[input.name] = 'input0' + + self.sharding_weight = 1.0 + self.resharding_weight = 1.0 + self.pipeline_weight = 0 + self.node_name = tensor.name + self.node_type = 'output_node' + self.num_inputs = 0 + self.num_outputs = 1 + self.dtype = tensor.dtype + self.strategies_vector = [] + self.node_runtime_profiler = None + + def _collect_strategies(self, 
device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['input0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + dim_partition_dict_mapping = {'input0': dim_partition_dict} + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + sharding_seq = sharding_spec_mapping['input0'].sharding_sequence + sharding_strategy = self._get_sharding_strategy( + name=f'output-op {sharding_seq}', + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + + return strategies_vector + + def _profile_sharding_cost(self, strategy, device_mesh): + return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py new file mode 100644 index 000000000..8042a89d4 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/p2p_node.py @@ -0,0 +1,67 @@ +import copy +from enum import Enum + +from .comm_spec import CommSpec +from .identity_node import Identity +from .sharding_strategy import StrategiesVector + + +class P2PType(Enum): + CROSS_DEVICE = 0 + CROSS_HOST = 1 + + +class P2PNode(Identity): + + def __init__(self, layer): + super().__init__(layer) + self.p2p_type = layer.attrs["p2p_type"] + self.is_fake = True + + def _collect_strategies(self, device_mesh): + # one input for softmax node + predecessor = self.predecessor_nodes[0] + strategies_vector = StrategiesVector(self) + for idx, strategy in enumerate(predecessor.strategies_vector): + # current node's local name input0 -> global name xxx + global_input_name = self.op_data['input0'].name + # global name xxx -> pre node local output name + prenode_local_name = predecessor.global_to_local_op_name[ + global_input_name] + dim_partition_dict = copy.deepcopy( + strategy.sharding_specs[prenode_local_name].dim_partition_dict) + in0_partition_dict = dim_partition_dict + out_partition_dict = copy.deepcopy(dim_partition_dict) + dim_partition_dict_mapping = { + "input0": in0_partition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + + logical_process_axis = [ + ['p2p_cross_device'] + ] if self.p2p_type == P2PType.CROSS_DEVICE else [['p2p_cross_host']] + # get communication action mapping + communication_action_mapping = {} + output0_comm_action = CommSpec( + comm_pattern='peer_to_peer', + sharding_spec=sharding_spec_mapping['output0'], + logical_process_axis=logical_process_axis, + ) + communication_action_mapping['output0'] = output0_comm_action + + name = '{} = {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + sharding_spec_mapping['input0'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategies_vector.append(sharding_strategy) + return strategies_vector + + def 
_profile_sharding_cost(self, strategy, device_mesh): + return 0.0 diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py new file mode 100644 index 000000000..42a2eeb8a --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_node.py @@ -0,0 +1,35 @@ +from tensorrt_llm.network import PluginInfo, get_plugin_info + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class PluginNode(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + self.plugin = layer.as_trt().plugin + self.plugin_type: str = self.plugin.plugin_type + self.plugin_info: PluginInfo = get_plugin_info(layer.graph.as_trt(), + layer.name) + layer.to_base_class() + + def _default_strategy(self, device_mesh): + strategies_vector = StrategiesVector(self) + dim_partition_dict_mapping = {} + for idx in range(self.num_inputs): + dim_partition_dict_mapping[f'input{idx}'] = {} + for idx in range(self.num_outputs): + dim_partition_dict_mapping[f'output{idx}'] = {} + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + return strategies_vector + name = '{}_all_replicate'.format(self.plugin_type) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py new file mode 100644 index 000000000..fc52372b9 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gemm_node.py @@ -0,0 +1,27 @@ +import tensorrt as trt + +from tensorrt_llm._utils import trt_dtype_to_str + +from ..matmul_node import MatrixMultiply +from ..plugin_node import PluginNode + + +class GemmPlugin(MatrixMultiply, PluginNode): + + def __init__(self, layer): + PluginNode.__init__(self, layer) + batch_dims = [i for i in range(len(self.get_output(0).shape))][:-2] + self._generate_bcast_dims(batch_dims, self.get_output(0).shape) + pfc_as_list = self.plugin_info.pfc_as_list + self.op0_transpose = (pfc_as_list['transa'][0] == 1) + self.op1_transpose = (pfc_as_list['transb'][0] == 1) + self.num_out_dims = len(self.get_output(0).shape) + self.dtype = trt_dtype_to_str(trt.DataType(pfc_as_list['type_id'][0])) + + def _collect_strategies(self, device_mesh): + strategies_vector = MatrixMultiply._collect_strategies( + self, device_mesh) + return strategies_vector + + def _get_math_time(self, strategy, device_mesh): + return MatrixMultiply._get_math_time(self, strategy, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py new file mode 100644 index 000000000..70439d0df --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/gpt_attention_node.py @@ -0,0 +1,379 @@ +from enum import Enum, auto + +import torch + +from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.quantization.mode import QuantMode + +from ..plugin_node import PluginNode +from ..sharding_strategy 
import StrategiesVector + + +# WARNING: Must in sync with IdxEntry in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.h +class IdxEntry(Enum): + QKV_TENSOR = auto() + K_TENSOR = auto() + V_TENSOR = auto() + SEQUENCE_LENGTH = auto() + HOST_PAST_KEY_VALUE_LENGTHS = auto() + HOST_MAX_ATTENTION_WINDOW = auto() + HOST_SINK_TOKEN_LENGTH = auto() + CONTEXT_LENGTHS = auto() + CACHE_INDIR = auto() + REQUEST_TYPES = auto() + KV_CACHE_BLOCK_POINTERS = auto() + HOST_KV_CACHE_BLOCK_POINTERS = auto() + PAST_KEY_VALUE = auto() + KV_CACHE_QUANTIZATION_SCALE = auto() + KV_CACHE_DEQUANTIZATION_SCALE = auto() + ALIBI_SLOPES = auto() + RELATIVE_ATTENTION_BIAS = auto() + CROSS_QKV = auto() + CROSS_QKV_LENGTH = auto() + ENCODER_INPUT_LENGTH = auto() + HOST_CONTEXT_LENGTH = auto() + QKV_BIAS_TENSOR = auto() + MEDUSA_PACKED_MASK = auto() + MEDUSA_POSITION_OFFSETS = auto() + + +class IdxEntryParser: + + def __init__(self, plugin_info): + self.num_kv_heads = plugin_info.pfc_as_list['num_kv_heads'][0] + self.unfuse_qkv_gemm = bool( + plugin_info.pfc_as_list['unfuse_qkv_gemm'][0]) + self.use_cache = bool(plugin_info.pfc_as_list['use_cache'][0]) + self.paged_kv_cache = bool(plugin_info.pfc_as_list['paged_kv_cache'][0]) + self.do_cross_attention = bool( + plugin_info.pfc_as_list['do_cross_attention'][0]) + self.remove_input_padding = bool( + plugin_info.pfc_as_list['remove_input_padding'][0]) + self.qkv_bias_enabled = bool( + plugin_info.pfc_as_list['qkv_bias_enabled'][0]) + self.kv_cache_quant_mode = QuantMode( + plugin_info.pfc_as_list['kv_cache_quant_mode'][0]) + self.position_embedding_type = PositionEmbeddingType( + plugin_info.pfc_as_list['position_embedding_type'][0]) + self.is_medusa_enabled = bool( + plugin_info.pfc_as_list['is_medusa_enabled'][0]) + self.init_entry_to_index() + + # WARNING: Must in sync with GPTAttentionPlugin::isEntryUsed in cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp + def is_entry_used(self, entry: IdxEntry) -> bool: + if entry == IdxEntry.QKV_TENSOR: + return True + elif entry == IdxEntry.K_TENSOR: + return self.unfuse_qkv_gemm + elif entry == IdxEntry.V_TENSOR: + return self.unfuse_qkv_gemm + elif entry == IdxEntry.SEQUENCE_LENGTH: + return self.use_cache + elif entry == IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS: + return self.use_cache + elif entry == IdxEntry.HOST_MAX_ATTENTION_WINDOW: + return True + elif entry == IdxEntry.HOST_SINK_TOKEN_LENGTH: + return True + elif entry == IdxEntry.CONTEXT_LENGTHS: + return True + elif entry == IdxEntry.CACHE_INDIR: + return self.use_cache + elif entry == IdxEntry.REQUEST_TYPES: + return True + elif entry == IdxEntry.KV_CACHE_BLOCK_POINTERS: + return self.use_cache and self.paged_kv_cache + elif entry == IdxEntry.HOST_KV_CACHE_BLOCK_POINTERS: + return self.use_cache and self.paged_kv_cache + elif entry == IdxEntry.PAST_KEY_VALUE: + return self.use_cache and not self.paged_kv_cache + elif entry == IdxEntry.KV_CACHE_QUANTIZATION_SCALE: + return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant( + ) + elif entry == IdxEntry.KV_CACHE_DEQUANTIZATION_SCALE: + return self.use_cache and self.kv_cache_quant_mode.has_kv_cache_quant( + ) + elif entry == IdxEntry.ALIBI_SLOPES: + return self.position_embedding_type.is_alibi() + elif entry == IdxEntry.RELATIVE_ATTENTION_BIAS: + return self.position_embedding_type == PositionEmbeddingType.relative + elif entry == IdxEntry.CROSS_QKV: + return self.do_cross_attention + elif entry == IdxEntry.CROSS_QKV_LENGTH: + return self.do_cross_attention + elif entry == 
IdxEntry.ENCODER_INPUT_LENGTH: + return self.do_cross_attention + elif entry == IdxEntry.HOST_CONTEXT_LENGTH: + return self.remove_input_padding + elif entry == IdxEntry.QKV_BIAS_TENSOR: + return self.qkv_bias_enabled + elif entry == IdxEntry.MEDUSA_PACKED_MASK: + return self.is_medusa_enabled + elif entry == IdxEntry.MEDUSA_POSITION_OFFSETS: + return self.is_medusa_enabled + else: + return False + + def init_entry_to_index(self): + self.entry_to_index = {} + index = 0 + for entry in IdxEntry: + if self.is_entry_used(entry): + self.entry_to_index[entry] = index + index += 1 + + def get_index(self, entry: IdxEntry) -> int: + if entry not in self.entry_to_index: + raise Exception( + f"Entry {entry} is not existed in gpt attention plugin layer {self.layer.name}" + ) + return self.entry_to_index[entry] + + +def get_partition(device_dim, device_ids): + if device_dim == [0]: + partition = device_ids.shape[0] + elif device_dim == [1]: + partition = device_ids.shape[1] + else: + assert device_dim == [0, 1] or device_dim == [1, 0] + partition = device_ids.size + return partition + + +class GPTAttentionPlugin(PluginNode): + + def __init__(self, layer): + super().__init__(layer) + self.parser = IdxEntryParser(self.plugin_info) + assert self.num_inputs == len( + self.parser.entry_to_index + ), f'the plugin inputs number {self.num_inputs} is invalid' + assert self.num_outputs == ( + 2 if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE) else + 1), f'the plugin outputs number {self.num_outputs} has been changed' + + def _tp_strategy(self, device_mesh): + strategies_vector = StrategiesVector(self) + head_dim = 1 if self.parser.remove_input_padding else 2 + # TODO: allow mesh_dim = [0] or [1] + # for mesh_dim in ([0], [1], [0, 1]): + for mesh_dim in ([0, 1], ): + if self.parser.num_kv_heads != 1: + # MHA or GQA + # TODO: allow to duplicate kv when #kv_head < #partition + q_pdict = { + head_dim: mesh_dim + } # split in heads/hidden dimension + k_pdict = { + head_dim: mesh_dim + } # split in heads/hidden dimension + v_pdict = { + head_dim: mesh_dim + } # split in heads/hidden dimension + pastkv_pdict = {2: mesh_dim} # split in heads dimension + present_kv_pdict = {2: mesh_dim} # split in heads dimension + else: + # MQA + q_pdict = { + head_dim: mesh_dim + } # split in heads/hidden dimension + k_pdict = {} # RR + v_pdict = {} # RR + pastkv_pdict = {} # RR + present_kv_pdict = {} # RR + + out0_pdict = {head_dim: mesh_dim} + + dim_partition_dict_mapping = { + f'input{self.parser.get_index(IdxEntry.QKV_TENSOR)}': q_pdict, + f'input{self.parser.get_index(IdxEntry.K_TENSOR)}': k_pdict, + f'input{self.parser.get_index(IdxEntry.V_TENSOR)}': v_pdict, + 'output0': out0_pdict, + } + if self.parser.is_entry_used(IdxEntry.PAST_KEY_VALUE): + dim_partition_dict_mapping[ + f'input{self.parser.get_index(IdxEntry.PAST_KEY_VALUE)}'] = pastkv_pdict + dim_partition_dict_mapping['output1'] = present_kv_pdict + for i in range(self.num_inputs): + if f'input{i}' not in dim_partition_dict_mapping: + dim_partition_dict_mapping[f'input{i}'] = {} + for i in range(self.num_outputs): + if f'output{i}' not in dim_partition_dict_mapping: + dim_partition_dict_mapping[f'output{i}'] = {} + + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = 'gptAttentionPlugin_tp_strategy' + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + 
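# each mesh_dim that produced a valid sharding spec yields one TP strategy: the
+ # QKV activation is split on its head (hidden) dimension; for MHA/GQA the
+ # separate K/V inputs and the past/present KV cache are split on their head
+ # dimension too, while MQA keeps them replicated +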
strategies_vector.append(sharding_strategy) + return strategies_vector + + def _dp_strategy(self, device_mesh): + strategies_vector = StrategiesVector(self) + for mesh_dim in ([0], [1], [0, 1]): + dim_partition_dict_mapping = {} + for i in range(self.num_inputs): + dim_partition_dict_mapping[f'input{i}'] = {0: mesh_dim} + for i in range(self.num_outputs): + dim_partition_dict_mapping[f'output{i}'] = {0: mesh_dim} + + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = 'gptAttentionPlugin_dp_strategy' + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector + + def _collect_strategies(self, device_mesh): + if device_mesh.size == 1: + default_strategies = self._default_strategy(device_mesh) + else: + # Avoid to use all-replicate strategy for mesh size > 1 + # since the CPP runtime does not support it for gpt attention plugin + default_strategies = StrategiesVector(self) + for idx, strategy in enumerate(default_strategies): + strategy.name = 'gptAttentionPlugin_' + strategy.name + f'{idx}' + if self.parser.unfuse_qkv_gemm: + tp_strategies = self._tp_strategy(device_mesh) + default_strategies.extend(tp_strategies) + # if we don't split the batch dim, it should be default strategis + # elif we split the batch dim, it should be dp_strategies + # we can use above information to distinguish the two kinds of strategy + if not self.parser.remove_input_padding: + dp_strategies = self._dp_strategy(device_mesh) + default_strategies.extend(dp_strategies) + return default_strategies + + @staticmethod + def parameter_generator(sharding_specs, plugin_info): + + def get_shape(entry): + return sharding_specs[ + f'input{parser.get_index(entry)}'].get_sharded_shape_per_device( + ) + + parser = IdxEntryParser(plugin_info) + updated_input_values = {} + batch_size = get_shape(IdxEntry.CONTEXT_LENGTHS)[0] + if parser.use_cache: + beams_width = get_shape(IdxEntry.CACHE_INDIR)[1] + max_seq_length = get_shape(IdxEntry.CACHE_INDIR)[2] + elif not parser.remove_input_padding: + max_seq_length = get_shape(IdxEntry.QKV_BIAS_TENSOR)[1] + else: + max_seq_length = 1 + host_request_types = torch.full( + (batch_size, ), + 1, + dtype=torch.int32, + device='cpu', + ) + updated_input_values[parser.get_index( + IdxEntry.REQUEST_TYPES)] = host_request_types + context_lengths = torch.full( + (batch_size, ), + max_seq_length - 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + updated_input_values[parser.get_index( + IdxEntry.CONTEXT_LENGTHS)] = context_lengths + host_max_attention_window_sizes = torch.tensor( + [max_seq_length], + dtype=torch.int32, + device='cpu', + ) + updated_input_values[parser.get_index( + IdxEntry.HOST_MAX_ATTENTION_WINDOW + )] = host_max_attention_window_sizes + host_sink_token_length = torch.tensor( + [0], + dtype=torch.int32, + device='cpu', + ) + updated_input_values[parser.get_index( + IdxEntry.HOST_SINK_TOKEN_LENGTH)] = host_sink_token_length + if parser.use_cache: + sequence_length = torch.full((batch_size, ), + max_seq_length, + dtype=torch.int32, + device=torch.cuda.current_device()) + updated_input_values[parser.get_index( + IdxEntry.SEQUENCE_LENGTH)] = sequence_length + host_past_key_value_length = torch.full((batch_size, ), + max_seq_length - 1, + dtype=torch.int32, + device='cpu') + 
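# register it under whatever input slot HOST_PAST_KEY_VALUE_LENGTHS occupies in
+ # this particular plugin build (the index depends on which entries are enabled) +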
updated_input_values[parser.get_index( + IdxEntry.HOST_PAST_KEY_VALUE_LENGTHS + )] = host_past_key_value_length + cache_indirections = torch.full( + (batch_size, beams_width, max_seq_length), + 0, + dtype=torch.int32, + device=torch.cuda.current_device()) + updated_input_values[parser.get_index( + IdxEntry.CACHE_INDIR)] = cache_indirections + if parser.remove_input_padding: + host_context_lengths = torch.full(get_shape( + IdxEntry.HOST_CONTEXT_LENGTH), + max_seq_length - 1, + dtype=torch.int32, + device='cpu') + updated_input_values[parser.get_index( + IdxEntry.HOST_CONTEXT_LENGTH)] = host_context_lengths + return updated_input_values + + def _profile_sharding_cost(self, strategy, device_mesh): + sharding_spec = strategy.sharding_specs[ + f"input{self.parser.get_index(IdxEntry.QKV_TENSOR)}"] + shard_dims = sharding_spec.dim_partition_dict + device_ids = device_mesh.phy_ids + if 2 in shard_dims: + device_dim = shard_dims[2] + partition = get_partition(device_dim, device_ids) + else: + partition = 1 + if self.parser.is_entry_used(IdxEntry.K_TENSOR): + kv_sharding_spec = strategy.sharding_specs[ + f"input{self.parser.get_index(IdxEntry.K_TENSOR)}"] + kv_shard_dims = kv_sharding_spec.dim_partition_dict + if 2 in kv_shard_dims: + kv_device_dim = kv_shard_dims[2] + kv_partition = get_partition(kv_device_dim, device_ids) + else: + kv_partition = 1 + else: + kv_partition = 1 + num_heads = self.plugin_info.pfc_as_ndarray["num_heads"].copy() + num_kv_heads = self.plugin_info.pfc_as_ndarray["num_kv_heads"].copy() + tp_size = self.plugin_info.pfc_as_ndarray["tp_size"].copy() + tp_rank = self.plugin_info.pfc_as_ndarray["tp_rank"].copy() + num_kv_heads = num_kv_heads // kv_partition + num_heads = num_heads // partition + tp_size[0] = partition + tp_rank[0] = 0 + + updated_layer_attrs = { + 'tp_size': tp_size, + 'tp_rank': tp_rank, + 'num_heads': num_heads, + 'num_kv_heads': num_kv_heads + } + updated_input_values = self.parameter_generator(strategy.sharding_specs, + self.plugin_info) + elapsed_time = self.node_runtime_profiler.runtime_profile( + self.layer, updated_layer_attrs, updated_input_values, strategy, + device_mesh) + return elapsed_time diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py new file mode 100644 index 000000000..c94b49347 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/identity_node.py @@ -0,0 +1,11 @@ +from ..identity_node import Identity +from ..plugin_node import PluginNode + + +class IdentityPlugin(Identity, PluginNode): + + def __init__(self, layer): + PluginNode.__init__(self, layer) + + def _collect_strategies(self, device_mesh): + return Identity._collect_strategies(self, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py new file mode 100644 index 000000000..38fc9dd19 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/look_up_node.py @@ -0,0 +1,19 @@ +import tensorrt as trt + +from ..gather_node import Gather +from ..plugin_node import PluginNode + + +class LookupPlugin(Gather, PluginNode): + + def __init__(self, layer): + PluginNode.__init__(self, layer) + self.mode = trt.GatherMode.DEFAULT + self.axis = 0 + self.num_elementwise_dims = 0 + self.input_id = 1 + self.indice_id = 0 + self.support_vocab_tp = True + + def _collect_strategies(self, device_mesh): + return 
Gather._collect_strategies(self, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py new file mode 100644 index 000000000..88e0b08f9 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/plugin_nodes/normalization_node.py @@ -0,0 +1,28 @@ +from ..normalization_node import Normalization +from ..plugin_node import PluginNode + + +class LayernormPlugin(Normalization, PluginNode): + + def __init__(self, layer): + PluginNode.__init__(self, layer) + # this is only true for LLM models, because layer norm only affects the hidden dim + hidden_dim = len(self.op_data['input0'].shape) - 1 + self.axes = 1 << hidden_dim + self.weight_bias_dim_base = hidden_dim + + def _collect_strategies(self, device_mesh): + return Normalization._collect_strategies(self, device_mesh) + + +class RMSnormPlugin(Normalization, PluginNode): + + def __init__(self, layer): + PluginNode.__init__(self, layer) + # this is only true for LLM models, because rms norm only affects the hidden dim + hidden_dim = len(self.op_data['input0'].shape) - 1 + self.axes = 1 << hidden_dim + self.weight_bias_dim_base = hidden_dim + + def _collect_strategies(self, device_mesh): + return Normalization._collect_strategies(self, device_mesh) diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py new file mode 100644 index 000000000..22f708faa --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/reduce_node.py @@ -0,0 +1,73 @@ +from tensorrt_llm._utils import trt_axes_to_dim + +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Reduce(Node): + + def __init__(self, layer): + super().__init__(layer) + layer.to_subclass() + self.reduce_dims = trt_axes_to_dim(layer.as_trt().axes) + self.sum_mapping_dict = {} + num_input_dims = len(self.get_input(0).shape) + if layer.as_trt().keep_dims: + for i in range(num_input_dims): + self.sum_mapping_dict[i] = i + else: + output_index = 0 + for i in range(num_input_dims): + if i not in self.reduce_dims: + self.sum_mapping_dict[i] = output_index + output_index += 1 + assert output_index == len(self.get_output(0).shape) + layer.to_base_class() + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['input0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + recover_dims = [] + out_partition_dict = {} + for dim in dim_partition_dict.keys(): + if dim in self.reduce_dims: + recover_dims.append(dim) + elif dim in self.sum_mapping_dict: + out_partition_dict[ + self.sum_mapping_dict[dim]] = dim_partition_dict[dim] + else: + assert 0, f'dim {dim} is not in sum_dims or sum_mapping_dict' + + for dim in recover_dims: + dim_partition_dict.pop(dim) + + in0_parition_dict = dim_partition_dict + dim_partition_dict_mapping = { + "input0": in0_parition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0
== len(sharding_spec_mapping): + continue + name = '{} = {}'.format( + sharding_spec_mapping['output0'].sharding_sequence, + self.reduce_dims, + sharding_spec_mapping['input0'].sharding_sequence) + sharding_strategy = self._get_sharding_strategy( + name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping={}) + strategies_vector.append(sharding_strategy) + return strategies_vector diff --git a/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py b/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py new file mode 100644 index 000000000..ce74a1295 --- /dev/null +++ b/tensorrt_llm/auto_parallel/tensor_parallel/select_node.py @@ -0,0 +1,56 @@ +from .node import Node +from .sharding_strategy import StrategiesVector + + +class Select(Node): + + def __init__(self, layer): + super().__init__(layer) + batch_dims = [i for i in range(len(self.get_output(0).shape))] + self._generate_bcast_dims(batch_dims, self.get_output(0).shape) + + def _collect_strategies(self, device_mesh): + dim_partition_list = [] + dim_size = len(self.op_data['output0'].shape) + dim_partition_list.append({}) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_1d_sharding([0, 1], dim_size)) + dim_partition_list.extend( + self._enumerate_all_possible_2d_sharding([0], [1], dim_size)) + + strategies_vector = StrategiesVector(self) + for dim_partition_dict in dim_partition_list: + # the three inputs are condition, true tensor and false tensor + in0_partition_dict = self._recover_bcast_partition_dict( + dim_partition_dict, self.op_data['input0']) + in1_partition_dict = self._recover_bcast_partition_dict( + dim_partition_dict, self.op_data['input1']) + in2_partition_dict = self._recover_bcast_partition_dict( + dim_partition_dict, self.op_data['input2']) + out_partition_dict = dim_partition_dict + dim_partition_dict_mapping = { + "input0": in0_partition_dict, + "input1": in1_partition_dict, + "input2": in2_partition_dict, + "output0": out_partition_dict, + } + sharding_spec_mapping = self._to_sharding_spec_mapping( + dim_partition_dict_mapping, device_mesh) + if 0 == len(sharding_spec_mapping): + continue + name = '{} =