forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gptSession.h
370 lines (306 loc) · 13.2 KB
/
gptSession.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"

#include <NvInferRuntime.h>

#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
// Forward declarations of batch-manager types that `GptSession` refers to (as a friend class and
// through `std::shared_ptr` / raw pointers), so their full definitions are not needed in this header.
namespace tensorrt_llm::batch_manager
{
class TrtGptModelV1;

namespace kv_cache_manager
{
class KVCacheManager;
} // namespace kv_cache_manager
} // namespace tensorrt_llm::batch_manager
namespace tensorrt_llm::runtime
{
namespace utils
{
std::vector<uint8_t> loadEngine(std::string const& enginePath);
}
class IpcMemory;
class IStatefulGptDecoder;
class NcclCommunicator;
class RuntimeBuffers;
class TllmRuntime;
//! @brief Session that runs the generation loop (`generate`) for a compiled TensorRT engine
//! described by a `GptModelConfig` and a `WorldConfig`.
class GptSession
{
    using KvCacheManager = batch_manager::kv_cache_manager::KVCacheManager;
    using KvCacheConfig = batch_manager::kv_cache_manager::KvCacheConfig;
    using TensorPtr = runtime::ITensor::SharedPtr;
    using TokenGeneratedCallback = std::function<void(SizeType step, bool finished)>;

public:
    using LoggerPtr = std::shared_ptr<nvinfer1::ILogger>;

    //! @brief Configuration for session execution and buffer sizes.
    //! `generate` may be called with batch size and beam width smaller than the configured parameters.
    //! @details `maxBatchSize` will be divided by the number of micro batches to initialize each batch buffer.
    class Config
    {
    public:
        Config(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxSequenceLength)
            : maxBatchSize{maxBatchSize}
            , maxBeamWidth{maxBeamWidth}
            , maxSequenceLength{maxSequenceLength}
        {
        }

        // The maximum number of sequences in a batch
        SizeType maxBatchSize;
        // The maximum width of the beams in beam-search
        SizeType maxBeamWidth;
        // The maximum sequence length the session is configured for.
        // NOTE(review): previously documented as "the length of the longest input sequence"; presumably this
        // bounds prompt length plus generated tokens -- confirm against callers.
        SizeType maxSequenceLength;
        // Whether the session will use a different decoder per request.
        // It must be set to `true` when running in-flight batching
        bool decoderPerRequest{false};
        // Whether the session will use CUDA graphs for the engine execution in generation phase
        bool cudaGraphMode{false};
        // KV cache configuration (presumably forwarded to `createKvCacheManager` -- confirm).
        KvCacheConfig kvCacheConfig{};
        // The micro batch size to be used in context phase.
        // Batches entered in `GptSession::generate` will be split into smaller micro batches of this size.
        std::optional<SizeType> ctxMicroBatchSize = std::nullopt;
        // The micro batch size to be used in generation phase.
        // Batches entered in `GptSession::generate` will be split into smaller micro batches of this size.
        std::optional<SizeType> genMicroBatchSize = std::nullopt;
        // Optional decoding mode (see `DecodingMode`).
        std::optional<DecodingMode> decodingMode = std::nullopt;
        // Whether log probabilities are normalized (compare `GptSession::getNormalizeLogProbs()`).
        bool normalizeLogProbs = true;
    };

    //! @brief Optional profiler class to profile the generation phase of an inference request
    class GenerationProfiler
    {
    public:
        // Use a constexpr variable to resolve the ambiguous match for overloaded CudaEvent constructor
        static constexpr unsigned int flags{cudaEventDefault};

        GenerationProfiler()
            : start(flags)
            , end(flags)
        {
        }

        CudaEvent const& getStart() const
        {
            return start;
        }

        CudaEvent const& getEnd() const
        {
            return end;
        }

        //! @brief Wait for both events and return the elapsed time between them in milliseconds.
        //! @details Errors from `cudaEventElapsedTime` are reported through `TLLM_CUDA_CHECK`.
        float getElapsedTimeMs()
        {
            start.synchronize();
            end.synchronize();

            float result{0.F};
            TLLM_CUDA_CHECK(::cudaEventElapsedTime(&result, start.get(), end.get()));
            return result;
        }

    private:
        CudaEvent start;
        CudaEvent end;
    };

    //! @param sessionConfig Configuration of the session,
    //! @param modelConfig Description of the model,
    //! @param worldConfig Description of the environment,
    //! @param engineBuffer The compiled TensorRT engine (const void*),
    //! @param engineSize The size in bytes of the TensorRT engine (size_t),
    //! @param logger The optional logger.
    GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
        void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);

    //! @brief Convenience overload taking the compiled engine as a byte vector.
    GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
        std::vector<uint8_t> const& engineBuffer, LoggerPtr logger = nullptr)
        : GptSession(
            sessionConfig, modelConfig, worldConfig, engineBuffer.data(), engineBuffer.size(), std::move(logger))
    {
    }

    //! @brief Convenience overload that loads the compiled engine from `engineFile` via `utils::loadEngine`.
    GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
        std::string const& engineFile, LoggerPtr logger = nullptr)
        : GptSession(sessionConfig, modelConfig, worldConfig, utils::loadEngine(engineFile), std::move(logger))
    {
    }

    [[nodiscard]] nvinfer1::ILogger& getLogger() const;

    [[nodiscard]] BufferManager const& getBufferManager() const;

    [[nodiscard]] GptModelConfig const& getModelConfig() const
    {
        return mModelConfig;
    }

    [[nodiscard]] WorldConfig const& getWorldConfig() const
    {
        return mWorldConfig;
    }

    [[nodiscard]] int getDevice() const noexcept
    {
        return mDevice;
    }

    [[nodiscard]] bool getNormalizeLogProbs() const noexcept
    {
        return mNormalizeLogProbs;
    }

    [[nodiscard]] nvinfer1::IEngineInspector& getEngineInspector() const;

    [[nodiscard]] nvinfer1::DataType getLogitDataType() const;

    //! @brief This function performs the generation loop.
    //! @details Given input tensors to read from, output tensors to populate, that member function
    //! will run the generation loop until it reaches the maximum number of tokens that
    //! can be produced or each sequence has reached completion (due to the production
    //! of "end-of-sequence" or a word in the list of "stop words"). The pseudo-code of
    //! that function looks like (member function names were changed to keep the
    //! presentation simple):
    //!
    //! ```cpp
    //! // Have all the sequences in the batch reached completion?
    //! bool allFinished = false;
    //!
    //! // Until all sequences are finished or the number of steps reaches the limit...
    //! for (int step = 0; !allFinished && step < maxNewTokens; ++step) {
    //!
    //! // Trigger the computation of the logits...
    //! computeLogits(...);
    //!
    //! // Run the sampling to produce a token (for each active sequence) from the logits.
    //! allFinished = generateTokensFromLogits(...);
    //!
    //! // Callback to stream the output tokens while the generation loop continues.
    //! onTokenGenerated(...);
    //! }
    //! ```
    //! @param outputs Output tensors populated by the generation loop.
    //! @param inputs Input tensors to read from.
    //! @param samplingConfig Sampling parameters used when producing tokens.
    //! @param generationProfiler Optional profiler for the generation phase; may be `nullptr`.
    void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig,
        std::shared_ptr<GenerationProfiler> const generationProfiler = nullptr);

private:
    //! @brief True when CUDA graph executors have been created for this session.
    [[nodiscard]] bool useCudaGraphs() const
    {
        return !mCudaGraphInstances.empty();
    }

    void generateBatched(std::vector<GenerationOutput>& microBatchesOutputs,
        std::vector<GenerationInput> const& microBatchesInputs, SamplingConfig const& samplingConfig,
        TokenGeneratedCallback const& onTokenGenerated, std::shared_ptr<GenerationProfiler> const generationProfiler);

    void setup(Config const& sessionConfig);

    void createContexts();
    void createBuffers(SizeType numMicroBatches);
    void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
        SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches,
        DecodingMode const& decodingMode);
    void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
        SizeType sinkTokenLength, SizeType maxSequenceLength, KvCacheConfig const& config);
    void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);

    void executeContextStep(std::vector<GenerationInput> const& generationBatchesInputs,
        std::vector<SizeType> const& generationBatchesOffsets, KvCacheManager const* kvCacheManager);
    SizeType executeGenerationStep(SizeType step, std::vector<GenerationInput> const& microBatchesInputs,
        std::vector<GenerationOutput>& microBatchesOutputs, std::vector<SizeType> const& microBatchOffsets,
        KvCacheManager* kvCacheManager, std::vector<bool>& microBatchesFinished);

    //! @brief Execute decoder on last PP rank, receive decoder output on other PP ranks.
    void decoderStepAsync(SizeType decoderStep, SizeType microBatchId);

    //! @brief Synchronize with the decoder and return the `shouldStop` flag.
    bool shouldStopSync(SizeType batchSize, SizeType beamWidth, SizeType microBatchId);

    //! @brief Collect final output ids and log probs on last PP rank and send them to first PP rank.
    //! @details Receives are asynchronous on host, so synchronization is required before access.
    void finalize(SizeType microBatchId);

    void kvCacheAddSequences(SizeType beamWidth, SizeType microBatchId, SizeType firstBatchIdx);

    //! @brief Populate outputIds and return reference to newTokens tensor
    ITensor::SharedPtr initDecoder(ITensor& outputIds, GenerationInput const& inputs, GenerationOutput const& outputs,
        SamplingConfig const& samplingConfig, SizeType microBatchId) const;

    TokenGeneratedCallback createOnTokenGeneratedCallback(GenerationOutput& outputs);

    //! @brief Holds one CUDA graph executable instance (`cudaGraphExec_t`).
    //! @details The destructor calls `clear()`, which presumably destroys the instance. Copies would therefore
    //! share and double-release the raw handle, so the type is move-only.
    class CudaGraphExecutor
    {
    public:
        CudaGraphExecutor() = default;

        CudaGraphExecutor(CudaGraphExecutor const&) = delete;
        CudaGraphExecutor& operator=(CudaGraphExecutor const&) = delete;

        // Transfer ownership of the handle; the moved-from executor is left empty.
        CudaGraphExecutor(CudaGraphExecutor&& other) noexcept
            : mInstance{std::exchange(other.mInstance, nullptr)}
        {
        }

        // Swapping hands our previous instance (if any) to `other`, whose destructor then releases it.
        CudaGraphExecutor& operator=(CudaGraphExecutor&& other) noexcept
        {
            std::swap(mInstance, other.mInstance);
            return *this;
        }

        ~CudaGraphExecutor()
        {
            // Destructors must not throw: log and swallow errors from `clear()`.
            try
            {
                clear();
            }
            catch (std::exception const& e)
            {
                TLLM_LOG_EXCEPTION(e);
            }
        }

        bool hasInstance() const
        {
            return mInstance != nullptr;
        }

        void clear();
        void prepareNextGraph(TllmRuntime const& runtime, SizeType nextContextId);
        void launch(CudaStream const& stream);

    private:
        void create(cudaGraph_t const& graph);
        bool update(cudaGraph_t const& graph);
        void uploadToStream(CudaStream const& stream);

        // Null means "no instance"; initialized so default-initialized executors are never indeterminate.
        cudaGraphExec_t mInstance{nullptr};
    };

    class MicroBatchConfig
    {
    public:
        MicroBatchConfig()
            : numCtxBatches{1}
            , numGenBatches{1}
            , ctxBatchSize{0}
            , genBatchSize{0}
        {
        }

        explicit MicroBatchConfig(SizeType maxBatchSize, SizeType pipelineParallelism,
            std::optional<SizeType> genMicroBatchSize, std::optional<SizeType> ctxMicroBatchSize);

        //! @brief Number of context micro batches per generation micro batch (integer division).
        constexpr SizeType numCtxPerGen() const
        {
            return numCtxBatches / numGenBatches;
        }

        //! @details flip-flop between 2 graph instances for each generation batch.
        constexpr SizeType getGenGraphId(SizeType flipFlopId, SizeType generationBatchId) const
        {
            return flipFlopId * numGenBatches + generationBatchId;
        }

        SizeType numCtxBatches;
        SizeType numGenBatches;
        SizeType ctxBatchSize;
        SizeType genBatchSize;
    };

    friend class batch_manager::TrtGptModelV1;

private:
    GptModelConfig const mModelConfig;
    WorldConfig const mWorldConfig;
    // Device ordinal reported by `getDevice()`; -1 until assigned.
    int mDevice{-1};
    std::shared_ptr<NcclCommunicator> mPipelineComm;
    std::shared_ptr<CudaStream> mCommStream;
    CudaEvent mCommEvent{};

    // tensor parallelism with custom allreduce plugin
    ITensor::SharedPtr mCommPtrs;
    std::vector<std::shared_ptr<IpcMemory>> mIpcMemoryHandles;

    SizeType mDecoderMaxSequenceLength{};
    SizeType mDecoderMaxAttentionWindow{};
    SizeType mDecoderSinkTokenLength{};

    LoggerPtr mLogger;
    std::shared_ptr<TllmRuntime> mRuntime;
    std::shared_ptr<KvCacheManager> mKvCacheManager;

    MicroBatchConfig mMicroBatchConfig;
    // for each micro batch
    std::vector<std::shared_ptr<IStatefulGptDecoder>> mDecoders;
    std::vector<std::shared_ptr<RuntimeBuffers>> mBuffers;
    std::vector<CudaEvent> mReceivedEvents;

    bool mCudaGraphMode{false};
    // ping-pong instances (indexed via `MicroBatchConfig::getGenGraphId`)
    std::vector<CudaGraphExecutor> mCudaGraphInstances;

    bool mNormalizeLogProbs = true;
};
} // namespace tensorrt_llm::runtime