forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gptModelConfig.h
484 lines (401 loc) · 12.2 KB
/
gptModelConfig.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/medusaModule.h"

#include <NvInferRuntime.h>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>
namespace tensorrt_llm::runtime
{
//! Hyper-parameters specific to Mamba (state-space) models; attached to a GptModelConfig via setMambaConfig().
//! Every field defaults to 0.
struct MambaConfig
{
    //! Presumably the SSM state dimension (`d_state` in the reference Mamba code) — TODO confirm with the config parser.
    SizeType dState{0};
    //! Presumably the conv1d kernel width (`d_conv`) — TODO confirm.
    SizeType dConv{0};
    //! Presumably the inner-dimension expansion factor (`expand`) — TODO confirm.
    SizeType expand{0};
};
class GptModelConfig
{
public:
enum class ModelVariant : std::int32_t
{
kGpt = 0,
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
kMamba = 2, // https://github.com/state-spaces/mamba
};
explicit GptModelConfig(
SizeType vocabSize, SizeType nbLayers, SizeType nbHeads, SizeType hiddenSize, nvinfer1::DataType dtype)
: mVocabSize(vocabSize)
, mNbLayers(nbLayers)
, mNbHeads(nbHeads)
, mNbKvHeads(nbHeads)
, mHiddenSize(hiddenSize)
, mSizePerHead(mHiddenSize / mNbHeads)
, mDataType(dtype)
, mUseGptAttentionPlugin(false)
, mUseMambaConv1dPlugin(false)
, mInputPacked{false}
, mPagedKvCache{false}
, mPagedState{false}
, mTokensPerBlock{64}
, mQuantMode{common::QuantMode::none()}
, mMaxBatchSize(0)
, mMaxBeamWidth(0)
, mMaxInputLen(0)
, mMaxSequenceLen(0)
, mMaxNumTokens(std::nullopt)
, mComputeContextLogits(false)
, mComputeGenerationLogits(false)
, mModelVariant(ModelVariant::kGpt)
, mUseCustomAllReduce(false)
, mMaxPromptEmbeddingTableSize(0)
, mMaxDraftLen(0)
, mUseContextFMHAForGeneration(false)
, mPagedContextFMHA(false)
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mMedusaModule(std::nullopt)
{
}
[[nodiscard]] SizeType constexpr getVocabSize() const noexcept
{
return mVocabSize;
}
[[nodiscard]] SizeType constexpr getVocabSizePadded(SizeType worldSize) const noexcept
{
return (mVocabSize + worldSize - 1) / worldSize * worldSize;
}
[[nodiscard]] SizeType constexpr getNbLayers(SizeType pipelineParallelism = 1) const
{
TLLM_CHECK(mNbLayers % pipelineParallelism == 0);
return mNbLayers / pipelineParallelism;
}
[[nodiscard]] SizeType constexpr getNbHeads() const noexcept
{
return mNbHeads;
}
[[nodiscard]] SizeType constexpr getNbKvHeads() const noexcept
{
return mNbKvHeads;
}
void constexpr setNbKvHeads(SizeType nbKvHeads) noexcept
{
mNbKvHeads = nbKvHeads;
}
[[nodiscard]] SizeType constexpr getHiddenSize() const noexcept
{
return mHiddenSize;
}
[[nodiscard]] SizeType constexpr getSizePerHead() const noexcept
{
return mSizePerHead;
}
void constexpr setSizePerHead(SizeType sizePerHead) noexcept
{
mSizePerHead = sizePerHead;
}
[[nodiscard]] nvinfer1::DataType constexpr getDataType() const noexcept
{
return mDataType;
}
[[nodiscard]] bool constexpr useGptAttentionPlugin() const noexcept
{
return mUseGptAttentionPlugin;
}
void constexpr useGptAttentionPlugin(bool useGptAttentionPlugin) noexcept
{
mUseGptAttentionPlugin = useGptAttentionPlugin;
}
[[nodiscard]] bool constexpr useMambaConv1dPlugin() const noexcept
{
return mUseMambaConv1dPlugin;
}
void constexpr useMambaConv1dPlugin(bool useMambaConv1dPlugin) noexcept
{
mUseMambaConv1dPlugin = useMambaConv1dPlugin;
}
[[nodiscard]] bool constexpr usePackedInput() const noexcept
{
return mInputPacked;
}
void constexpr usePackedInput(bool inputPacked) noexcept
{
mInputPacked = inputPacked;
}
[[nodiscard]] bool constexpr usePagedKvCache() const noexcept
{
return mPagedKvCache;
}
void constexpr usePagedKvCache(bool pagedKvCache) noexcept
{
mPagedKvCache = pagedKvCache;
}
[[nodiscard]] bool constexpr usePagedState() const noexcept
{
return mPagedState;
}
void constexpr usePagedState(bool pagedState) noexcept
{
mPagedState = pagedState;
}
[[nodiscard]] SizeType constexpr getTokensPerBlock() const noexcept
{
return mTokensPerBlock;
}
void constexpr setTokensPerBlock(SizeType TokensPerBlock) noexcept
{
mTokensPerBlock = TokensPerBlock;
}
[[nodiscard]] common::QuantMode constexpr getQuantMode() const noexcept
{
return mQuantMode;
}
void constexpr setQuantMode(common::QuantMode QuantMode) noexcept
{
mQuantMode = QuantMode;
}
[[nodiscard]] bool constexpr supportsInflightBatching() const noexcept
{
return (isTransformerBased() && mUseGptAttentionPlugin && mInputPacked && mPagedKvCache)
|| (isSsmBased() && mUseMambaConv1dPlugin && mInputPacked && mPagedState);
}
[[nodiscard]] SizeType constexpr getMaxBatchSize() const noexcept
{
return mMaxBatchSize;
}
void constexpr setMaxBatchSize(SizeType maxBatchSize) noexcept
{
mMaxBatchSize = maxBatchSize;
}
[[nodiscard]] SizeType constexpr getMaxBeamWidth() const noexcept
{
return mMaxBeamWidth;
}
void constexpr setMaxBeamWidth(SizeType maxBeamWidth) noexcept
{
mMaxBeamWidth = maxBeamWidth;
}
[[nodiscard]] SizeType constexpr getMaxInputLen() const noexcept
{
return mMaxInputLen;
}
void constexpr setMaxInputLen(SizeType maxInputLen) noexcept
{
mMaxInputLen = maxInputLen;
}
[[nodiscard]] SizeType constexpr getMaxSequenceLen() const noexcept
{
return mMaxSequenceLen;
}
void constexpr setMaxSequenceLen(SizeType maxSequenceLen) noexcept
{
mMaxSequenceLen = maxSequenceLen;
}
[[nodiscard]] std::optional<SizeType> constexpr getMaxNumTokens() const noexcept
{
return mMaxNumTokens;
}
void constexpr setMaxNumTokens(std::optional<SizeType> maxNumTokens) noexcept
{
mMaxNumTokens = maxNumTokens;
}
[[nodiscard]] bool constexpr usePromptTuning() const noexcept
{
return mMaxPromptEmbeddingTableSize > 0;
}
[[nodiscard]] SizeType constexpr getMaxPromptEmbeddingTableSize() const noexcept
{
return mMaxPromptEmbeddingTableSize;
}
void constexpr setMaxPromptEmbeddingTableSize(SizeType maxPromptEmbeddingTableSize) noexcept
{
mMaxPromptEmbeddingTableSize = maxPromptEmbeddingTableSize;
}
[[nodiscard]] bool constexpr computeContextLogits() const noexcept
{
return mComputeContextLogits;
}
void constexpr computeContextLogits(bool computeContextLogits) noexcept
{
mComputeContextLogits = computeContextLogits;
}
[[nodiscard]] bool constexpr computeGenerationLogits() const noexcept
{
return mComputeGenerationLogits;
}
void constexpr computeGenerationLogits(bool computeGenerationLogits) noexcept
{
mComputeGenerationLogits = computeGenerationLogits;
}
[[nodiscard]] ModelVariant getModelVariant() const
{
return mModelVariant;
}
void setModelVariant(ModelVariant modelVariant)
{
mModelVariant = modelVariant;
}
[[nodiscard]] bool constexpr useCustomAllReduce() const noexcept
{
return mUseCustomAllReduce;
}
void constexpr useCustomAllReduce(bool customAllReduce) noexcept
{
mUseCustomAllReduce = customAllReduce;
}
void constexpr setMaxDraftLen(SizeType maxDraftLen) noexcept
{
mMaxDraftLen = maxDraftLen;
}
[[nodiscard]] SizeType getMaxDraftLen() const
{
return mMaxDraftLen;
}
[[nodiscard]] SizeType constexpr getMaxTokensPerStep() const noexcept
{
return mMaxDraftLen + 1;
}
void constexpr setUseContextFMHAForGeneration(bool useContextFMHAForGeneration) noexcept
{
mUseContextFMHAForGeneration = useContextFMHAForGeneration;
}
[[nodiscard]] bool constexpr getContextFMHAForGeneration() const noexcept
{
return mUseContextFMHAForGeneration;
}
void constexpr setPagedContextFMHA(bool pagedContextFMHA) noexcept
{
mPagedContextFMHA = pagedContextFMHA;
}
[[nodiscard]] bool constexpr getPagedContextFMHA() const noexcept
{
return mPagedContextFMHA;
}
[[nodiscard]] bool constexpr useLoraPlugin() const noexcept
{
return mUseLoraPlugin;
}
void constexpr useLoraPlugin(bool useLoraPlugin) noexcept
{
mUseLoraPlugin = useLoraPlugin;
}
std::vector<LoraModule> const& getLoraModules() const noexcept
{
return mLoraModules;
}
void setLoraModules(std::vector<LoraModule> const& loraModules) noexcept
{
mLoraModules = loraModules;
}
[[nodiscard]] SizeType constexpr getMlpHiddenSize() const noexcept
{
return mMlpHiddenSize;
}
void constexpr setMlpHiddenSize(SizeType mlpHiddenSize) noexcept
{
mMlpHiddenSize = mlpHiddenSize;
}
[[nodiscard]] SizeType constexpr getMaxLoraRank() const noexcept
{
return mMaxLoraRank;
}
void constexpr setMaxLoraRank(SizeType maxLoraRank) noexcept
{
mMaxLoraRank = maxLoraRank;
}
[[nodiscard]] bool constexpr useMedusa() const noexcept
{
return mMedusaModule.has_value();
}
[[nodiscard]] std::optional<MedusaModule> getMedusaModule() const noexcept
{
return mMedusaModule;
}
void setMedusaModule(MedusaModule const& medusaModule) noexcept
{
mMedusaModule = medusaModule;
}
[[nodiscard]] nvinfer1::DataType getKvDataType() const noexcept
{
if (getQuantMode().hasFp8KvCache())
{
return nvinfer1::DataType::kFP8;
}
else if (getQuantMode().hasInt8KvCache())
{
return nvinfer1::DataType::kINT8;
}
else
{
return getDataType();
}
}
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
{
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm;
}
[[nodiscard]] bool hasMambaConfig() const noexcept
{
return mMambaConfig.has_value();
}
[[nodiscard]] std::optional<MambaConfig> getMambaConfig() const noexcept
{
return mMambaConfig;
}
void setMambaConfig(MambaConfig const& mambaConfig) noexcept
{
mMambaConfig = mambaConfig;
}
[[nodiscard]] bool constexpr isSsmBased() const noexcept
{
return mModelVariant == ModelVariant::kMamba;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
SizeType mNbHeads;
SizeType mNbKvHeads;
SizeType mHiddenSize;
SizeType mSizePerHead;
nvinfer1::DataType mDataType;
bool mUseGptAttentionPlugin;
bool mUseMambaConv1dPlugin;
bool mInputPacked;
bool mPagedKvCache;
bool mPagedState;
SizeType mTokensPerBlock;
common::QuantMode mQuantMode;
SizeType mMaxBatchSize;
SizeType mMaxBeamWidth;
SizeType mMaxInputLen;
SizeType mMaxSequenceLen;
std::optional<SizeType> mMaxNumTokens;
bool mComputeContextLogits;
bool mComputeGenerationLogits;
ModelVariant mModelVariant;
bool mUseCustomAllReduce;
SizeType mMaxPromptEmbeddingTableSize;
SizeType mMaxDraftLen;
bool mUseContextFMHAForGeneration;
bool mPagedContextFMHA;
bool mUseLoraPlugin;
std::vector<LoraModule> mLoraModules;
SizeType mMlpHiddenSize;
SizeType mMaxLoraRank;
std::optional<MedusaModule> mMedusaModule;
std::optional<MambaConfig> mMambaConfig;
};
} // namespace tensorrt_llm::runtime