diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index b81e928aa6..13f3800891 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -33,11 +33,14 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
+    fpx_weight_only,
+    gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    uintx_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.subclass import (
@@ -55,6 +58,13 @@
     unwrap_tensor_subclass,
 )
 
+try:
+    import gemlite  # noqa: F401
+
+    has_gemlite = True
+except ModuleNotFoundError:
+    has_gemlite = False
+
 
 def dynamic_quant(model, example_inputs):
     m = torch.export.export(model, example_inputs, strict=True).module()
@@ -804,6 +814,9 @@ def test_int4wo_cpu(self, dtype, x_dim):
             int8_dynamic_activation_int8_weight(),
             int8_dynamic_activation_int4_weight(),
             int8_weight_only(),
+            fpx_weight_only(ebits=4, mbits=3),
+            gemlite_uintx_weight_only(),
+            uintx_weight_only(dtype=torch.uint4),
         ],
     )
     def test_workflow_e2e_numerics(self, config):
@@ -827,17 +840,23 @@ def test_workflow_e2e_numerics(self, config):
             and is_sm_at_least_90()
         ):
             return unittest.skip("only supported on CUDA capability 8.9, not greater")
+        elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
+            return unittest.skip("gemlite not available")
 
         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
         if isinstance(config, float8_static_activation_float8_weight):
             config.scale = config.scale.to("cuda")
 
+        dtype = torch.bfloat16
+        if isinstance(config, gemlite_uintx_weight_only):
+            dtype = torch.float16
+
         # set up inputs
-        x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+        x = torch.randn(128, 128, device="cuda", dtype=dtype)
         # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
-        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
+        m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
         m_q = copy.deepcopy(m_ref)
 
         # quantize
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index b4f6c86252..1e968e557b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -729,12 +729,8 @@ def _int4_dynamic_activation_int4_weight_transform(
     return module
 
 
-def gemlite_uintx_weight_only(
-    group_size: Optional[int] = 64,
-    bit_width: int = 4,
-    packing_bitwidth: int = 32,
-    contiguous: Optional[bool] = None,
-):
+@dataclass
+class GemliteUIntXWeightOnlyConfig(AOBaseConfig):
     """
     applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format.
     This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric.
@@ -747,16 +743,39 @@ def gemlite_uintx_weight_only(
         `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
""" + group_size: Optional[int] = 64 + bit_width: int = 4 + packing_bitwidth: int = 32 + contiguous: Optional[bool] = None + + +# for BC +gemlite_uintx_weight_only = GemliteUIntXWeightOnlyConfig + + +@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) +def _gemlite_uintx_weight_only_transform( + module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig +): + group_size = config.group_size + bit_width = config.bit_width + packing_bitwidth = config.packing_bitwidth + contiguous = config.contiguous + + weight = module.weight + from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs use_hqq = True if bit_width == 4 else False - apply_fn = lambda weight: to_affine_quantized_intx( + new_weight = to_affine_quantized_intx( weight, **get_gemlite_aqt_kwargs( weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq ), ) - return _get_linear_subclass_inserter(apply_fn) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module @dataclass @@ -1380,9 +1399,10 @@ def _float8_static_activation_float8_weight_transform( return module -def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): +@dataclass +class UIntXWeightOnlyConfig(AOBaseConfig): """ - Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where + Configuration for applying uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where x is the number of bits specified by `dtype` Args: @@ -1392,6 +1412,28 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): `pack_dim`: the dimension we use for packing, defaults to -1 `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight """ + + dtype: torch.dtype + group_size: int = 64 + pack_dim: int = -1 + use_hqq: bool = False + + +# for BC +uintx_weight_only = UIntXWeightOnlyConfig + + +@register_quantize_module_handler(UIntXWeightOnlyConfig) +def _uintx_weight_only_transform( + module: torch.nn.Module, config: UIntXWeightOnlyConfig +): + dtype = config.dtype + group_size = config.group_size + pack_dim = config.pack_dim + use_hqq = config.use_hqq + + weight = module.weight + from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS SUPPORTED_DTYPES = { @@ -1406,49 +1448,50 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False): } assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}" - def apply_uintx_weight_only_quant(weight, dtype): - mapping_type = MappingType.ASYMMETRIC - block_size = (1, group_size) - - if use_hqq: - if dtype == torch.uint4: - logger.warn( - "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" - ) - quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] - dtype = torch.uint8 - eps = None - zero_point_dtype = None - zero_point_domain = ZeroPointDomain.FLOAT - preserve_zero = False - _layout = PlainLayout() - else: - quant_min, quant_max = None, None - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int32 - zero_point_domain = ZeroPointDomain.INT - preserve_zero = True - _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) - return to_affine_quantized_intx( - weight, - mapping_type, - block_size, - dtype, - quant_min=quant_min, - quant_max=quant_max, - eps=eps, - zero_point_dtype=zero_point_dtype, - 
zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - _layout=_layout, - use_hqq=use_hqq, - ) + if use_hqq: + if dtype == torch.uint4: + logger.warn( + "Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance" + ) + quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype] + dtype = torch.uint8 + eps = None + zero_point_dtype = None + zero_point_domain = ZeroPointDomain.FLOAT + preserve_zero = False + _layout = PlainLayout() + else: + quant_min, quant_max = None, None + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + zero_point_domain = ZeroPointDomain.INT + preserve_zero = True + _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim) - return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + dtype, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + zero_point_dtype=zero_point_dtype, + zero_point_domain=zero_point_domain, + preserve_zero=preserve_zero, + _layout=_layout, + use_hqq=use_hqq, + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module -def fpx_weight_only(ebits: int, mbits: int): +@dataclass +class FPXWeightOnlyConfig(AOBaseConfig): """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits e.g. fp6_e3_m2, fp6_e2_m3, ... The packing format and kernels are from the fp6-llm paper: https://arxiv.org/abs/2401.14112 @@ -1459,26 +1502,40 @@ def fpx_weight_only(ebits: int, mbits: int): in the future """ - def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor: - from torchao.dtypes import to_affine_quantized_fpx - from torchao.dtypes.floatx import FloatxTensorCoreLayout + ebits: int + mbits: int - assert ( - weight.dim() == 2 - ), f"floatx only works for 2-d Tensor, got: {weight.dim()}" - out_dim, in_dim = weight.shape - if (in_dim % 64 != 0) or (out_dim % 256 != 0): - logger.info( - f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " - f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " - "expected in_dim % 64 == 0 and out_dim % 256 == 0" - ) - return weight - _layout = FloatxTensorCoreLayout(ebits, mbits) - return to_affine_quantized_fpx(weight, _layout) +# for BC +fpx_weight_only = FPXWeightOnlyConfig + + +@register_quantize_module_handler(FPXWeightOnlyConfig) +def _fpx_weight_only_transform( + module: torch.nn.Module, config: FPXWeightOnlyConfig +) -> torch.nn.Module: + ebits = config.ebits + mbits = config.mbits + weight = module.weight + + from torchao.dtypes import to_affine_quantized_fpx + from torchao.dtypes.floatx import FloatxTensorCoreLayout - return _get_linear_subclass_inserter(apply_quant_llm) + assert weight.dim() == 2, f"floatx only works for 2-d Tensor, got: {weight.dim()}" + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + logger.info( + f"Skipping floatx quantization float{ebits + mbits + 1}_{ebits}_{mbits} because " + f"the shape is not compatible with the kernel: in_dim={in_dim}, out_dim={out_dim} " + "expected in_dim % 64 == 0 and out_dim % 256 == 0" + ) + return module + + _layout = FloatxTensorCoreLayout(ebits, mbits) + new_weight = to_affine_quantized_fpx(weight, _layout) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module if TORCH_VERSION_AT_LEAST_2_5: