diff --git a/mqbench/prepare_by_platform.py b/mqbench/prepare_by_platform.py index 9755b12..deed3d8 100644 --- a/mqbench/prepare_by_platform.py +++ b/mqbench/prepare_by_platform.py @@ -64,8 +64,10 @@ class BackendType(Enum): default_weight_observer=MinMaxObserver, default_act_observer=EMAMinMaxObserver), BackendType.Tensorrt: dict(qtype='affine', # noqa: E241 - w_qscheme=QuantizeScheme(symmetry=True, per_channel=True, pot_scale=False, bit=8, symmetric_range=True), - a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True), + w_qscheme=QuantizeScheme(symmetry=True, per_channel=True, pot_scale=False, bit=8, symmetric_range=True, + factory_kwargs={'not_calc_quant_min_max': True}), + a_qscheme=QuantizeScheme(symmetry=True, per_channel=False, pot_scale=False, bit=8, symmetric_range=True, + factory_kwargs={'not_calc_quant_min_max': True}), default_weight_quantize=LearnableFakeQuantize, default_act_quantize=LearnableFakeQuantize, default_weight_observer=MinMaxObserver, diff --git a/mqbench/scheme.py b/mqbench/scheme.py index 80a2ea8..401b370 100644 --- a/mqbench/scheme.py +++ b/mqbench/scheme.py @@ -16,6 +16,10 @@ def __init__(self, symmetry=True, per_channel=False, pot_scale=False, bit=8, **k if 'symmetric_range' in kwargs: self.symmetric_range = kwargs['symmetric_range'] del kwargs['symmetric_range'] + assert isinstance(kwargs.get('factory_kwargs', None), dict) \ + and kwargs['factory_kwargs'].get('not_calc_quant_min_max', False), \ + "QuantizeScheme with `symmetric_range=True` should provide kwargs " \ + "factory_kwargs={not_calc_quant_min_max: True, ...}" else: self.symmetric_range = False self.kwargs = kwargs