forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
functional.py
5039 lines (4000 loc) · 180 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import weakref
from collections import OrderedDict
from enum import IntEnum, IntFlag, auto
from functools import partial
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
# isort: off
import torch
import tensorrt as trt
# isort: on
from . import graph_rewriting as gw
from ._common import default_net, default_trtnet, precision
from ._utils import (bf16_array, dim_resolve_negative, dim_to_trt_axes,
fp16_array, fp32_array, int32_array, np_dtype_to_trt,
str_dtype_to_trt, torch_to_numpy, trt_dtype_to_np,
trt_dtype_to_torch)
from .network import PluginInfo, set_np_weight, set_plugin_info
from .plugin import TRT_LLM_PLUGIN_NAMESPACE, current_all_reduce_helper
from .quantization import QuantMode
class DimRange(object):
    '''
    One DimRange object stores the ranges of all the dimensions of one tensor in one optimization profile.

    For example, tensor has 2 dimensions. Then the data members are:
        self.min = [dim 0 min, dim 1 min]
        self.opt = [dim 0 opt, dim 1 opt]
        self.max = [dim 0 max, dim 1 max]
    For static dimension, it has min==opt==max, thus an entry of the `shape`
    param in the ctor can be a plain integer.
    '''

    def __init__(self, shape: List[Union[int, List[int], Tuple[int, int, int]]],
                 names: List[str]):
        '''
        Parameters:
            shape: a list with length N, each element is an integer or a 3-elements tuple/list of int,
                where N is the number of dimensions for a tensor.
                When one element is an integer, it means that dimension is static.
                Otherwise, when one element is a tuple/list, it means the dimension is dynamic.
                The 3 elements in one tuple/list is ordered by (min, opt, max), and this function asserts
                0 <= min <= opt <= max.

                Example, for a 3 rank tensor, with 1st dimension being static and has value 16, and second dimension being dynamic with
                min/opt/max values being 1/8/32, and 3rd dimension being static and has value 8.
                The shape parameter could be:
                    [16, (1, 8, 32), 8]
                It has same semantics of
                    [(16, 16, 16), (1, 8, 32), (8, 8, 8)]

            names: a list with length N holding the name of each dimension, in
                the same order as `shape`.

        Raises:
            AssertionError: if `shape` and `names` differ in length, or a
                tuple/list entry is not an ordered (min, opt, max) triple.
            AttributeError: if an entry is neither an int nor a tuple/list.
        '''
        self.min = []
        self.opt = []
        self.max = []
        self.dimension_names = names
        # The messages below must be f-strings, otherwise the `{...=}`
        # placeholders are printed verbatim instead of the offending values.
        assert len(names) == len(
            shape
        ), f"Expecting shape list and name list must have same length, got {shape=}, {names=}"
        for dim in shape:
            if isinstance(dim, (list, tuple)):
                assert len(dim) == 3 and 0 <= dim[0] <= dim[1] <= dim[2], \
                    f"Each dimension must specify a 3-elements tuple or list in the order of (min,opt,max), got {dim=}"
                self.min.append(dim[0])
                self.opt.append(dim[1])
                self.max.append(dim[2])
            elif isinstance(dim, int):
                # A static dimension is stored as min == opt == max.
                self.min.append(dim)
                self.opt.append(dim)
                self.max.append(dim)
            else:
                raise AttributeError(
                    f'Dimension should be [min, opt, max] (dynamic shape) or int (specific value). Got {type(dim)}'
                )

    def __eq__(self, __value: object) -> bool:
        '''Two ranges are equal when names and min/opt/max lists all match.'''
        return isinstance(__value, DimRange) and \
            self.dimension_names == __value.dimension_names and \
            self.min == __value.min and self.opt == __value.opt and self.max == __value.max

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return f"{self.dimension_names=} {self.min=}, {self.opt=}, {self.max=})"

    def __hash__(self) -> int:
        # Hash the textual form so that objects comparing equal hash equal.
        return hash(str(self))
class Tensor(object):
    '''
    The class to represent dense tensors.

    A dense tensor is named, has a shape and contains typed elements. Each
    dimension of a tensor can either be static or dynamic. Static dimensions
    are known at engine compilation by TensorRT. Dynamic dimensions can take
    values determined at runtime. The tensor can be located on the host (CPU)
    or the device (GPU).
    '''

    def __init__(self,
                 name=None,
                 dtype=None,
                 shape=None,
                 dim_range=None,
                 is_network_input=True,
                 location=trt.TensorLocation.DEVICE,
                 network=None,
                 trt_tensor=None):
        '''
        Parameters:
            name : str
                The name of the tensor.

            dtype : tensorrt.DataType
                The type of the elements of the tensor. See the TensorRT
                documentation for list of supported data types.

            shape : tensorrt.Dims
                The dimensions of the tensor. In TensorRT-LLM, tensors can have
                static or dynamic dimensions (it is possible to mix static and
                dynamic dimensions).  A static dimension is known when the
                TensorRT engine is built. A dynamic dimension can be set when
                the engine is executed. Use -1 for dynamic dimensions.

            dim_range : OrderedDict
                An ordered dictionary (the positions of the elements matter)
                that associates a name and a range of values to the dimensions.
                For a static dimension, the range must be limited to a single
                value. For a dynamic dimension, the range is defined by three
                values [min, opt, max] where min and max are, respectively, the
                smallest and largest possible values of that dimension.  The
                opt value is used by TensorRT to optimize the engine for the
                most common case.

                Assume there is N optimization profiles, each item dim_range dict is ordered by:
                    (dynamic dimension name : [profile 0 (min, opt, max), profile 1 (min, opt, max), ... profile N(min, opt, max)])
                or it's following when the dimension is static (can think as min==opt==max):
                    (static dimension name : [profile 0 value, profile 1 value, ... profile N value])
                For static dimension the profile 0-N value must be same, (TODO: can it be simplified to be only 1 value?)
                There is one key per dimension, and every value list must
                contain one entry per optimization profile.

            is_network_input : bool
                A boolean indicating if that tensor is an input of the network.
                Inputs must be provided by the user to run the engine.

            location : tensorrt.TensorLocation
                A flag to indicate where the tensor will be located. It can be
                on the host (CPU) or the device (GPU).

            network: Network
                A parent Network instance, that helps to find the users of this tensor.

            trt_tensor: trt.ITensor
                Construct with the ITensor instance directly, and no shape profiles are required.
        '''

        # Layout of self.profiles
        # Opt profile 0: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M
        # Opt profile 1: dim 0 (min, opt, max), dim 1 (min, opt, max) ... dim M
        # ...
        # Opt profile N: dim 0 ... dim M

        # So from the dim_range arg to self.profiles conversion, there is a layout transpose
        # dim_range arg is: {M dimension x N profiles}, while self.profiles layout is {N profiles x M dimensions}
        self.profiles = []

        self.is_tensor_wrapper = False  # specially for the graph rewriter

        # work as a wrapper for a trt.ITensor, this is used specially in the graph rewriter
        if trt_tensor is not None:
            self.is_tensor_wrapper = True
            assert network is not None
            self.trt_tensor = trt_tensor
            self._network = weakref.ref(network)
            assert not is_network_input, "is_network_input should be False when trt_tensor is not None"
            return

        # be cautious here, the weakref is critical to avoid circular referencing before Network and Tensor
        # using strong reference will likely cause significant peak memory increase, since Network objects
        # holds the weights data.
        self._network = weakref.ref(default_net())
        if is_network_input:
            if dim_range is not None:
                assert isinstance(dim_range, OrderedDict)
                assert len(
                    dim_range
                ) >= 1, f"Each input tensor shall have at least one dimension, tensor '{name}' found {dim_range=}"

                # Number of profile entries supplied for each dimension.
                found_profiles = [
                    len(ranges) for _, ranges in dim_range.items()
                ]
                assert all(
                    x == found_profiles[0] for x in found_profiles
                ), f"Expecting all the dimensions in the dim_range has same number of profiles, tensor '{name}' got {dim_range=}"

                num_opt_profile = found_profiles[0]
                assert num_opt_profile >= 1
                # Transpose {dimension -> per-profile ranges} into one
                # DimRange (all dimensions) per optimization profile.
                for i in range(num_opt_profile):
                    range_shape = []
                    dimension_names = []
                    for dim, ranges in dim_range.items():
                        assert isinstance(ranges, (list, tuple))
                        range_shape.append(ranges[i])
                        dimension_names.append(dim)
                    self.profiles.append(DimRange(range_shape, dimension_names))

            # Registering the input is expected to attach `trt_tensor`, which
            # the property setters below write through to.
            default_net()._add_input(self, name, dtype, shape, dim_range)
            self.name = name
            self.dtype = dtype
            self.shape = shape
            self.location = location

    @property
    def network(self):
        '''
        The parent network, dereferenced from the stored weak reference.
        '''
        return self._network()

    @property
    def name(self):
        '''
        The name of the tensor.
        '''
        return self.trt_tensor.name

    @name.setter
    def name(self, name):
        '''
        Set the name of the tensor. A None value is ignored.
        '''
        if name is not None:
            self.trt_tensor.name = name

    @property
    def dtype(self):
        '''
        The type of the elements in the tensor.
        '''
        return self.trt_tensor.dtype

    @dtype.setter
    def dtype(self, dtype):
        '''
        Set the type of the elements in the tensor. A None value is ignored.
        '''
        if dtype is not None:
            self.trt_tensor.dtype = dtype

    @property
    def shape(self):
        '''
        The shape of the tensor.
        '''
        return self.size()

    @shape.setter
    def shape(self, shape):
        '''
        Set the shape of the tensor. See __init__. A None value is ignored.
        '''
        if shape is not None:
            self.trt_tensor.shape = shape

    @property
    def location(self):
        '''
        The physical location of the tensor (on the host or the device).
        '''
        return self.trt_tensor.location

    @location.setter
    def location(self, location):
        '''
        Set the physical location of the tensor (on the host or the device).
        See __init__. A None value is ignored.
        '''
        if location is not None:
            self.trt_tensor.location = location

    def mark_output(self,
                    name: Optional[str] = None,
                    dtype: Optional[Union[str, trt.DataType]] = None):
        '''
        Mark a tensor as a network output.

        When a tensor is marked as an output, its content can be obtained after
        the execution of the TensorRT engine. The user is responsible for
        allocating buffers to store the output tensors when preparing the
        execution of the TensorRT engine.

        Parameters:
            name : Optional[str]
                The output name. Defaults to the current name of the tensor.

            dtype : Optional[Union[str, trt.DataType]]
                The output type, as a type name or a TensorRT data type.
                Defaults to the current type of the tensor.
        '''
        if name is None:
            name = self.name

        if dtype is None:
            dtype = self.dtype
        elif isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)
        else:
            assert isinstance(dtype, trt.DataType)

        default_net()._mark_output(self, name, dtype)

    def __add__(self, b):
        '''
        See functional.add.
        '''
        return add(self, b)

    def __radd__(self, b):
        '''
        See functional.add.
        '''
        return add(b, self)

    def __sub__(self, b):
        '''
        See functional.sub.
        '''
        return sub(self, b)

    def __rsub__(self, b):
        '''
        See functional.sub.
        '''
        return sub(b, self)

    def __mul__(self, b):
        '''
        See functional.mul.
        '''
        return mul(self, b)

    def __rmul__(self, b):
        '''
        See functional.mul.
        '''
        return mul(b, self)

    def __truediv__(self, b):
        '''
        See functional.div.
        '''
        return div(self, b)

    def __lt__(self, b):
        '''
        See functional.lt.
        '''
        return lt(self, b)

    def __gt__(self, b):
        '''
        See functional.gt.
        '''
        return gt(self, b)

    def __eq__(self, b):
        '''
        See functional.eq.

        For a graph-rewriter wrapper, this compares the hashes of the two
        objects instead of inserting an equality layer in the network.
        '''
        if self.is_tensor_wrapper:
            # for graph rewriter
            return hash(self) == hash(b)
        else:
            # for creating the network
            return eq(self, b)

    def __ge__(self, b):
        '''
        Maps to functional.gt or functional.eq.
        '''
        return op_or(self.__gt__(b), self.__eq__(b))

    def __le__(self, b):
        '''
        Maps to functional.lt or functional.eq.
        '''
        return op_or(self.__lt__(b), self.__eq__(b))

    def view(self, shape, zero_is_placeholder=True):
        '''
        See functional.view.
        '''
        return view(self, shape, zero_is_placeholder)

    def permute(self, dims):
        '''
        See functional.permute.
        '''
        return permute(self, dims)

    def transpose(self, dim0, dim1):
        '''
        See functional.transpose.
        '''
        return transpose(self, dim0, dim1)

    def mean(self, dim, keepdim=False):
        '''
        See functional.mean.
        '''
        return mean(self, dim, keepdim)

    def max(self, dim, keepdim=False):
        '''
        See functional.max.
        '''
        return max(self, dim, keepdim)

    def abs(self):
        '''
        See functional.abs.
        '''
        return abs(self)

    def sqrt(self):
        '''
        See functional.sqrt.
        '''
        return sqrt(self)

    def log(self):
        '''
        See functional.log.
        '''
        return log(self)

    def cast(self, dtype):
        '''
        See functional.cast.
        '''
        return cast(self, dtype)

    def size(self, dim=None):
        '''
        Returns the shape of the tensor if the dim parameter is None.
        Otherwise, returns a size of the dimension indicated by dim. The
        behavior is undefined if dim is negative or exceeds the rank of the
        tensor.
        '''
        if dim is None:
            return self.trt_tensor.shape

        return self.trt_tensor.shape[dim]

    def rank(self):
        '''
        Returns the rank (i.e. the number of dimensions) of the tensor.
        '''
        return len(self.trt_tensor.shape)

    def ndim(self):
        '''
        Returns the rank (i.e. the number of dimensions) of the tensor.
        '''
        return self.rank()

    def split(self, split_size_or_sections, dim=0):
        '''
        See functional.split.
        '''
        return split(self, split_size_or_sections, dim)

    def is_dynamic(self, dim=None):
        '''
        If the argument 'dim' is None, that function returns a boolean that
        indicates if the tensor contains a dynamic dimension (True) or not
        (False). In that case, the first dimension is excluded (as it usually
        corresponds to the batch size).  If the argument is an integer, that
        functions returns a boolean that indicates if the dimension 'dim' is
        dynamic (True) or not (False).
        '''
        if dim is not None:
            return self.trt_tensor.shape[dim] == -1

        # Dimension 0 is deliberately skipped, see the docstring.
        for i, s in enumerate(self.trt_tensor.shape):
            if i != 0 and s == -1:
                return True

        return False

    # graph writer related functions

    def get_parent(self):
        ''' Get the layer that produces this tensor. '''
        return self.network.get_tensor_parent(self)

    def get_users(self):
        ''' Get the layers that use this tensor as an input. '''
        return self.network.get_tensor_users(self)

    def replace_all_uses_with(self, new_tensor):
        '''
        Replace all uses of this tensor as an input to consumer layers.

        Every input slot of every user layer that currently holds this
        tensor is rebound to 'new_tensor', and the network is flagged as
        altered.
        '''

        self.network.is_graph_altered = True
        users = self.get_users()
        for user in users:
            inputs_changed = 0
            for i in range(user.num_inputs):
                if user.get_inputs(i)[0].trt_tensor is self.trt_tensor:
                    inputs_changed += 1
                    user.set_input(i, new_tensor.trt_tensor)
            assert inputs_changed >= 1, "Tensor not found in layer inputs"

            # update the FLayerMetadata as well
            flayer = gw.FLayerInfoMemo.instance().get(user.name)
            if flayer:
                flayer.replace_input_with(self, new_tensor)

    def is_trt_wrapper(self):
        '''
        Check if there is a trt.ITensor member inside, which is required for
        graph rewriter. In order to differentiate usages, it may be necessary
        to have an inheritance hierarchy.
        '''
        return hasattr(self, 'trt_tensor')

    def __hash__(self):
        # Identity of the wrapped trt.ITensor when there is one; otherwise
        # a constant shared by all tensors that have none attached yet.
        if self.is_trt_wrapper():
            return id(self.trt_tensor)
        else:
            return id(None)

    def __repr__(self):
        return f"TensorRT-LLM Tensor: {self.name=} {self.dtype=} {self.shape=}"
def _create_tensor(trt_tensor: trt.ITensor,
                   producer: trt.ILayer = None) -> Tensor:
    '''
    Wrap a TensorRT tensor into a TensorRT-LLM Tensor and bind it to the layer
    that produced it.

    Typical usage:

        # Insert a new layer in the network using the TensorRT API:
        layer = default_trtnet().add_<some_layer>(...)
        # Wrap the first output of that layer and link it to the layer.
        return _create_tensor(layer.get_output(0), layer)

    Besides wrapping, that function names the producer layer through the
    default network. When the default network has a dtype and is not strongly
    typed, it also assigns that dtype as the precision of the producer, except
    for SHAPE, CONSTANT, GATHER and CONCATENATION layers. Finally, if a
    functional layer is currently being recorded by the graph-rewriting memo,
    the name of the producer is stored in it.

    Parameters:
        trt_tensor : trt.ITensor
            The TensorRT tensor to wrap. It must not be None.

        producer : trt.ILayer = None
            The layer that produces 'trt_tensor'.

    Returns:
        The TensorRT-LLM tensor (functional.Tensor). The wrapped TensorRT
        tensor is available as the attribute 'trt_tensor' and the layer as the
        attribute 'producer'.
    '''
    assert trt_tensor is not None
    wrapped = Tensor(name=trt_tensor.name,
                     dtype=trt_tensor.dtype,
                     shape=trt_tensor.shape,
                     is_network_input=False)
    wrapped.trt_tensor = trt_tensor
    wrapped.producer = producer

    # This is the single place where the name travels from module space to
    # the TRT IR, hence the layer is named here.
    default_net()._set_layer_name(producer)

    uses_default_precision = (default_net().dtype is not None
                              and not default_net().strongly_typed)
    if uses_default_precision:
        exempt_layer_types = (trt.LayerType.SHAPE, trt.LayerType.CONSTANT,
                              trt.LayerType.GATHER,
                              trt.LayerType.CONCATENATION)
        if producer.type not in exempt_layer_types:
            producer.precision = default_net().dtype

    memo = gw.FLayerInfoMemo.instance()
    if memo.cur_flayer is not None:
        memo.cur_flayer.layer_name = producer.name

    return wrapped
def _add_plugin_info(layer, plugin_creator: trt.IPluginCreator,
                     plugin_name: str, pfc: trt.PluginFieldCollection) -> None:
    '''
    Record the plugin description of a layer in the default network.

    A PluginInfo is built from the creator, the plugin name and the field
    collection, and is registered under the name of 'layer' on the TensorRT
    network of the default network.
    '''
    info = PluginInfo(plugin_creator, plugin_name, pfc)
    trt_network = default_net().trt_network
    set_plugin_info(trt_network, layer.name, info)
class RotaryScalingType(IntEnum):
    '''
    Integer identifiers of the rotary scaling variants.
    '''
    none = 0
    linear = 1
    dynamic = 2
class PositionEmbeddingType(IntEnum):
    '''
    Integer identifiers of the position embedding variants, with helpers to
    classify a member and to convert from and to its name.
    '''
    learned_absolute = 0
    rope_gptj = 1
    rope_gpt_neox = 2
    alibi = 3
    alibi_with_scale = 4
    relative = 5
    chatglm = 6

    def is_rope(self) -> bool:
        ''' True for the two rope_* members. '''
        return (self == PositionEmbeddingType.rope_gptj
                or self == PositionEmbeddingType.rope_gpt_neox)

    def is_alibi(self) -> bool:
        ''' True for alibi and alibi_with_scale. '''
        return (self == PositionEmbeddingType.alibi
                or self == PositionEmbeddingType.alibi_with_scale)

    @staticmethod
    def choices() -> List[str]:
        ''' The names of all the members, in declaration order. '''
        return [member.name for member in PositionEmbeddingType]

    def __str__(self):
        return self.name

    @staticmethod
    def from_string(s):
        '''
        Look a member up by its name; an unknown name raises ValueError.
        '''
        try:
            member = PositionEmbeddingType[s]
        except KeyError:
            raise ValueError(f'Unsupported position embedding type: {s}')
        return member
class AttentionMaskType(IntEnum):
    '''
    Integer identifiers of the attention mask variants.
    '''
    padding = 0
    causal = 1
    bidirectional = 2
    # TODO: merge this mask into bidirectional
    bidirectionalglm = 3
class LayerNormType(IntEnum):
    '''
    Integer identifiers of the normalization variants.
    '''
    LayerNorm = 0
    RmsNorm = 1
    GroupNorm = 2
class LayerNormPositionType(IntEnum):
    '''
    Integer identifiers of the normalization placement variants.
    '''
    pre_layernorm = 0
    post_layernorm = 1
class MLPType(IntEnum):
    '''
    Integer identifiers of the MLP variants.
    '''
    MLP = 0
    GatedMLP = 1
    FusedGatedMLP = 2
def activation(input: Tensor, act_type: trt.ActivationType) -> Tensor:
    '''
    Insert an activation layer in the network.

    Parameters:
        input : Tensor
            The tensor the activation function is applied to.

        act_type : trt.ActivationType
            The kind of activation (RELU, TANH, SIGMOID, ...).

    The following closures are defined in functional.*:

        relu    for act_type=trt.ActivationType.RELU
        tanh    for act_type=trt.ActivationType.TANH
        sigmoid for act_type=trt.ActivationType.SIGMOID

    Returns:
        The tensor produced by the activation layer.
    '''
    trt_network = default_trtnet()
    act_layer = trt_network.add_activation(input.trt_tensor, act_type)
    activated = act_layer.get_output(0)
    return _create_tensor(activated, act_layer)
def clip(input: Tensor, alpha: float, beta: float) -> Tensor:
    '''
    Insert a CLIP activation that limits the values to [alpha, beta].

    Parameters:
        input : Tensor
            The tensor the clipping is applied to.

        alpha : float
            The lower bound of the CLIP function.

        beta : float
            The upper bound of the CLIP function.

    Returns:
        The tensor produced by the activation layer.
    '''
    trt_network = default_trtnet()
    clip_layer = trt_network.add_activation(input.trt_tensor,
                                            trt.ActivationType.CLIP)
    # The bounds are attributes of the activation layer itself.
    clip_layer.alpha = alpha
    clip_layer.beta = beta
    clipped = clip_layer.get_output(0)
    return _create_tensor(clipped, clip_layer)
# Ready-made closures over `activation`, one per activation kind; each one
# only takes the input tensor.
relu, tanh, sigmoid = (partial(activation, act_type=kind)
                       for kind in (trt.ActivationType.RELU,
                                    trt.ActivationType.TANH,
                                    trt.ActivationType.SIGMOID))
def silu(input: Tensor) -> Tensor:
    '''
    Insert a SiLU operation, computed as `x * sigmoid(x)`.

    Parameters:
        input : Tensor
            The tensor the activation function is applied to.

    Returns:
        The tensor holding the product of the input and its sigmoid.
    '''
    squashed = sigmoid(input)
    return input * squashed
def swiglu(input: Tensor) -> Tensor:
    '''
    Insert a SwiGLU operation, computed as `x * SiLU(gate)`.

    The input is cut into two chunks along the last dimension; the first one
    is 'x' and the second one is 'gate'. SiLU is applied to 'gate' and the
    result is multiplied by 'x'. The behavior is undefined if the last
    dimension is not even.

    Parameters:
        input : Tensor
            The tensor the activation function is applied to.

    Returns:
        The tensor produced by the final multiplication.
    '''
    halves = chunk(input, 2, dim=-1)
    value, gate = halves
    activated_gate = silu(gate)
    return activated_gate * value
def squared_relu(x: Tensor) -> Tensor:
    '''
    Insert a Squared ReLU operation: ReLU followed by a power of 2.0.

    Parameters:
        x : Tensor
            The tensor the activation function is applied to.

    Returns:
        The tensor produced by the power operation.
    '''
    rectified = relu(x)
    return pow(rectified, 2.0)
def cast(input: Tensor, dtype: Union[str, trt.DataType]):
    '''
    Insert a cast layer, unless the input already has the requested type.

    When the default network is not strongly typed, the output type of the
    layer is set explicitly. In that same case, a dynamic range of
    [-127, 127] is set on the layer input if the input is INT8 (automatic
    dequantization) and on the layer output if the target type is INT8
    (automatic quantization).

    Parameters:
        input : Tensor
            The tensor to cast.

        dtype : str or trt.DataType
            The type of the output tensor. A string is converted with
            str_dtype_to_trt; see _str_to_trt_dtype_dict in _utils.py for the
            supported type names.

    Returns:
        The input itself when no conversion is needed, otherwise the tensor
        produced by the inserted layer.

    Raises:
        TypeError: if 'dtype' is neither a str nor a trt.DataType.
    '''
    if isinstance(dtype, trt.DataType):
        target_dtype = dtype
    elif isinstance(dtype, str):
        target_dtype = str_dtype_to_trt(dtype)
    else:
        raise TypeError("%s is not supported" % type(dtype))

    # Nothing to insert when the types already agree.
    if input.dtype == target_dtype:
        return input

    cast_layer = default_trtnet().add_cast(input.trt_tensor, target_dtype)
    weakly_typed = not default_net().strongly_typed
    if weakly_typed:
        cast_layer.set_output_type(0, target_dtype)

    result = _create_tensor(cast_layer.get_output(0), cast_layer)

    if weakly_typed:
        int8_dtype = str_dtype_to_trt('int8')
        if input.dtype == int8_dtype:
            cast_layer.get_input(0).set_dynamic_range(-127, 127)
        if target_dtype == int8_dtype:
            cast_layer.get_output(0).set_dynamic_range(-127, 127)

    return result
def flip(input: Tensor, dims: Sequence[int]) -> Tensor:
    '''
    Reverses the order of an n-D tensor along given axis in dims.

    That flip operation maps to a TensorRT ISliceLayer. For the dimensions
    listed in dims it copies the elements from the last one to the first one
    (from (N-1) down to 0 with a step of -1). For the dimensions not in 'dims',
    it copies the elements from the first one to the last one (from 0 to N-1
    with a step of 1).

    Parameters:
        input : Tensor
            The input tensor on which the flip is applied. It must not be
            dynamic (see Tensor.is_dynamic).

        dims : list or tuple
            The axes to flip. Negative indices are supported. The sequence is
            only read, it is never modified.

    Returns:
        The tensor produced by the inserted layer.
    '''
    assert not input.is_dynamic()

    ndim = input.ndim()

    # Resolve negative axes into a fresh list: writing back into 'dims' would
    # fail for a tuple and would silently alter a list owned by the caller.
    axes = []
    for value in dims:
        assert -ndim <= value < ndim
        axes.append(value + ndim if value < 0 else value)

    # The same axis must not be listed twice (e.g. as 0 and as -ndim).
    flipped = set(axes)
    assert len(axes) == len(flipped)

    input_shape = input.size()
    start_values = [
        input_shape[i] - 1 if i in flipped else 0 for i in range(ndim)
    ]
    stride_values = [-1 if i in flipped else 1 for i in range(ndim)]

    layer = default_trtnet().add_slice(input.trt_tensor,
                                       start=start_values,
                                       shape=input_shape,
                                       stride=stride_values)

    return _create_tensor(layer.get_output(0), layer)
def interpolate(input: Tensor,
                size: Union[int, List[int]] = None,
                scale_factor: Union[float, List[float]] = None,
                mode: str = 'nearest',
                align_corners: bool = False,
                recompute_scale_factor: bool = False,
                antialias: bool = False) -> Tensor:
    '''
    Add a resize operation that changes the spatial size of a tensor.

    That operation maps to a TensorRT resize layer. The first two dimensions
    of the input are kept as they are; the remaining (input rank - 2)
    dimensions are resized, either to the sizes given by 'size' or to
    floor(scale * input size) when 'scale_factor' is used.

    Parameters:
        input : Tensor
            The input tensor. It must have a rank of 3, 4 or 5 and must not
            be dynamic (see Tensor.is_dynamic).

        size : int or list of int
            The output size of the resized dimensions. An integer is used for
            every resized dimension; a list must contain one value per resized
            dimension. Exactly one of 'size' and 'scale_factor' must be set.

        scale_factor : float or list of float
            The multiplier applied to the resized dimensions. A single number
            is used for every resized dimension.

        mode : str
            One of 'nearest', 'nearest-exact', 'linear' (rank 3 only),
            'bilinear' (rank 4 only), 'trilinear' (rank 5 only) or 'bicubic'.
            Both nearest modes map to NEAREST interpolation with the
            ASYMMETRIC coordinate transformation, the linear modes map to
            LINEAR and 'bicubic' maps to CUBIC with HALF_PIXEL.

        align_corners : bool
            For 'linear' and 'trilinear', selects the ALIGN_CORNERS coordinate
            transformation instead of HALF_PIXEL. It currently has no effect
            on 'bilinear', which always uses HALF_PIXEL.

        recompute_scale_factor : bool
            Accepted but not used by this implementation.

        antialias : bool
            Accepted but not used by this implementation.

    Returns:
        The tensor produced by the inserted layer.

    Raises:
        ValueError: if 'mode' is 'linear', 'bilinear' or 'trilinear' and the
            rank of the input does not match it.
    '''
    assert not input.is_dynamic()

    input_ndim = input.ndim()

    assert 2 < input_ndim < 6, "Only 3D, 4D and 5D input Tensors supported"
    assert (size is not None) ^ (
        scale_factor
        is not None), "Only one of out_shape or scales should be defined"

    assert mode in ('nearest', 'linear', 'bilinear', 'bicubic', 'trilinear',
                    'nearest-exact')

    if mode == 'trilinear' and input_ndim != 5:
        raise ValueError("trilinear only supports 5D tensor")

    if mode == "bilinear" and input_ndim != 4:
        raise ValueError("bilinear only supports 4D tensor")

    if mode == "linear" and input_ndim != 3:
        raise ValueError("linear only supports 3D tensor")

    num_resized_dims = input_ndim - 2
    input_shape = input.size()

    if scale_factor is not None:
        if isinstance(scale_factor, (float, int)):
            updated_scale = [scale_factor] * num_resized_dims
        else:
            updated_scale = scale_factor
        updated_shape = [
            int(math.floor(updated_scale[i - 2] *
                           input_shape[i])) if i > 1 else input_shape[i]
            for i in range(input_ndim)
        ]
    else:
        # Broadcast an integer size first, so that the length check below
        # also accepts it for 4D and 5D inputs.
        if isinstance(size, int):
            updated_size = [size] * num_resized_dims
        else:
            updated_size = size
        assert len(updated_size) == num_resized_dims, \
            f"Expecting {num_resized_dims} output sizes, got {size=}"
        updated_shape = [
            input_shape[i] if i < 2 else updated_size[i - 2]
            for i in range(input_ndim)
        ]

    # The layer is only added once the arguments have been validated, so a
    # failed check does not leave a dangling resize layer in the network.
    layer = default_trtnet().add_resize(input.trt_tensor)
    layer.shape = updated_shape

    if mode in ('linear', 'bilinear', 'trilinear'):
        layer.resize_mode = trt.InterpolationMode.LINEAR
        # TODO, need to confirm the align_corners effect on bilinear mode.
        # Until then bilinear always uses HALF_PIXEL.
        if align_corners and mode != 'bilinear':
            layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ALIGN_CORNERS
        else:
            layer.coordinate_transformation = trt.ResizeCoordinateTransformation.HALF_PIXEL
    elif mode == 'bicubic':
        layer.resize_mode = trt.InterpolationMode.CUBIC
        layer.coordinate_transformation = trt.ResizeCoordinateTransformation.HALF_PIXEL
    else:
        # 'nearest' and 'nearest-exact'; also the fallback for any other
        # value when the assertions above are disabled.
        layer.resize_mode = trt.InterpolationMode.NEAREST
        layer.coordinate_transformation = trt.ResizeCoordinateTransformation.ASYMMETRIC

    return _create_tensor(layer.get_output(0), layer)
def matmul(input: Tensor,
mat2: Tensor,
transa: bool = False,
transb: bool = False,
use_fp32_acc: bool = True) -> Tensor:
'''
Add a matrix multiplication.
That operation maps to a tensorrt.IMatrixMultiplyLayer layer. As explained
in the TensorRT documentation, it computes the inner product between the
two inputs after applying an optional transposition on the inputs.
Parameters:
input : Tensor
The first tensor (often called A).
mat2 : Tensor
The second tensor (often called B).
transa : bool
Is the first input transposed? Set to 'True' if you want the first
input to be transposed, 'False' otherwise.
transb : bool
Is the second input transposed? Set to 'True' if you want the
second input to be transposed, 'False' otherwise.
use_fp32_acc: bool
Set to 'True' if for accuracy reason, this fp16 matmul needs to use
fp32 accumulation. This can be a per model and per matmul decision.
Returns:
The tensor produced by the inserted layer.
'''
# This option is only supported for fp16, but not bf16 or any other precisions.
use_fp32_acc = use_fp32_acc and input.dtype == trt.DataType.HALF and mat2.dtype == trt.DataType.HALF
# TODO: fp32 accum has issues with strongly_typed and it will be fixed in TensorRT 10.0
if default_net().strongly_typed:
use_fp32_acc = False
if use_fp32_acc:
input = cast(input, 'float32')
mat2 = cast(mat2, 'float32')
input, mat2 = broadcast_helper(input, mat2)
op0 = trt.MatrixOperation.TRANSPOSE if transa \
else trt.MatrixOperation.NONE
op1 = trt.MatrixOperation.TRANSPOSE if transb \
else trt.MatrixOperation.NONE
layer = default_trtnet().add_matrix_multiply(input.trt_tensor, op0,