migrate static quant tutorials to direct configuration (#1710)
vkuzo authored Feb 14, 2025
1 parent 17b9ce3 commit 3fa8e44
Showing 3 changed files with 178 additions and 133 deletions.
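At a high level, the tutorials move from passing a configuration function into quantize_ to passing an AOBaseConfig subclass whose module transform is registered with register_quantize_module_handler. Below is a minimal sketch of that pattern, assuming quantize_ is imported from torchao.quantization as in the tutorials; the config class, handler, and toy model are illustrative stand-ins, not code from this commit.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class MyStaticQuantConfig(AOBaseConfig):
    # configuration now lives on a dataclass instead of closure arguments
    target_dtype: torch.dtype = torch.uint8


@register_quantize_module_handler(MyStaticQuantConfig)
def _my_static_quant_transform(module: torch.nn.Module, config: MyStaticQuantConfig):
    # quantize_ calls the handler registered for the config's type on every
    # module matching its filter (nn.Linear by default) and swaps in the
    # returned module; a real handler would replace the weight, this one is a no-op
    return module


m = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(m, MyStaticQuantConfig(torch.uint8))  # dispatches on the config type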
114 changes: 65 additions & 49 deletions tutorials/calibration_flow/awq_like.py
@@ -8,11 +8,13 @@
"""

import copy
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
    Float8Layout,
    to_affine_quantized_floatx_static,
@@ -33,6 +35,9 @@
from torchao.quantization.quant_primitives import (
    MappingType,
)
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error


@@ -83,61 +88,72 @@ def replacement_fn(m):
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)


@dataclass
class ApplyAWQConfig(AOBaseConfig):
    target_dtype: torch.dtype


# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_awq(target_dtype: torch.dtype):
    # target_dtype = torch.uint8
    def _apply_awq_to_linear(observed_linear):
        # weight quantization
        weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

        def weight_quant_func(weight):
            block_size = (1, weight.shape[1])
            if target_dtype == torch.uint8:
                return to_affine_quantized_intx_static(
                    weight, weight_scale, weight_zero_point, block_size, target_dtype
                )
            elif target_dtype == torch.float8_e4m3fn:
                return to_affine_quantized_floatx_static(
                    weight,
                    weight_scale,
                    block_size,
                    target_dtype,
                    Float8Layout(mm_config=None),
                )
            else:
                raise ValueError(f"Unsupported target dtype {target_dtype}")

        linear = torch.nn.Linear(
            observed_linear.in_features,
            observed_linear.out_features,
            False,
            device=observed_linear.weight.device,
            dtype=observed_linear.weight.dtype,
        )
        linear.weight = observed_linear.weight
        linear.bias = observed_linear.bias

        # activation quantization
        # pretend this is the equalization scale; in reality the `act_obs` should
        # be an observer that can calculate the equalization scale
        equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
        equalization_scale = torch.ones_like(equalization_scale)

        linear.weight = torch.nn.Parameter(
            weight_quant_func(linear.weight * equalization_scale), requires_grad=False
        )

        linear.weight = torch.nn.Parameter(
            to_weight_tensor_with_linear_activation_scale_metadata(
                linear.weight, equalization_scale
            ),
            requires_grad=False,
        )

        return linear

    return _apply_awq_to_linear


@register_quantize_module_handler(ApplyAWQConfig)
def _apply_awq_transform(
    module: torch.nn.Module,
    config: ApplyAWQConfig,
):
    target_dtype = config.target_dtype
    observed_linear = module

    # target_dtype = torch.uint8
    # weight quantization
    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

    def weight_quant_func(weight):
        block_size = (1, weight.shape[1])
        if target_dtype == torch.uint8:
            return to_affine_quantized_intx_static(
                weight, weight_scale, weight_zero_point, block_size, target_dtype
            )
        elif target_dtype == torch.float8_e4m3fn:
            return to_affine_quantized_floatx_static(
                weight,
                weight_scale,
                block_size,
                target_dtype,
                Float8Layout(mm_config=None),
            )
        else:
            raise ValueError(f"Unsupported target dtype {target_dtype}")

    linear = torch.nn.Linear(
        observed_linear.in_features,
        observed_linear.out_features,
        False,
        device=observed_linear.weight.device,
        dtype=observed_linear.weight.dtype,
    )
    linear.weight = observed_linear.weight
    linear.bias = observed_linear.bias

    # activation quantization
    # pretend this is the equalization scale; in reality the `act_obs` should
    # be an observer that can calculate the equalization scale
    equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
    equalization_scale = torch.ones_like(equalization_scale)

    linear.weight = torch.nn.Parameter(
        weight_quant_func(linear.weight * equalization_scale), requires_grad=False
    )

    linear.weight = torch.nn.Parameter(
        to_weight_tensor_with_linear_activation_scale_metadata(
            linear.weight, equalization_scale
        ),
        requires_grad=False,
    )

    return linear


######## Test ##########
@@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):

    # quantized linear represented as an nn.Linear with modified tensor subclass weights
    # for both activation and weight quantization
    quantize_(m, apply_awq(target_dtype), is_observed_linear)
    quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear)
    print("quantized model (applying tensor subclass to weight):", m)
    after_quant = m(*example_inputs)
    assert compute_error(before_quant, after_quant) > 25
66 changes: 38 additions & 28 deletions tutorials/calibration_flow/gptq_like.py
@@ -33,6 +33,7 @@
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
    to_affine_quantized_intx,
    to_affine_quantized_intx_static,
@@ -47,6 +48,9 @@
    to_linear_activation_quantized,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error

torch.manual_seed(0)
@@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module):
    )


# using a function to align with the API in quant_api
def apply_activation_static_weight_quant():
    def _apply_activation_static_weight_quant(observed_linear):
        target_dtype = torch.uint8

        # we can quantize the weight here as well

        # activation quantization
        act_scale, act_zero_point = (
            observed_linear.input_scale,
            observed_linear.input_zp,
        )
        input_quant_func = lambda x: to_affine_quantized_intx_static(
            x, act_scale, act_zero_point, x.shape, target_dtype
        )
        # for demo purpose only, we quantize the weight here
        weight = observed_linear.weight
        weight = to_affine_quantized_intx(
            weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
        )
        observed_linear.weight = torch.nn.Parameter(
            to_linear_activation_quantized(weight, input_quant_func),
            requires_grad=False,
        )

        del observed_linear.input_scale
        del observed_linear.input_zp
        return observed_linear

    return _apply_activation_static_weight_quant


class ApplyActivationStaticWeightQuantConfig(AOBaseConfig):
    pass


# using a function to align with the API in quant_api
@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig)
def _apply_activation_static_weight_quant_transform(
    module: torch.nn.Module,
    config: ApplyActivationStaticWeightQuantConfig,
):
    observed_linear = module
    target_dtype = torch.uint8

    # we can quantize the weight here as well

    # activation quantization
    act_scale, act_zero_point = (
        observed_linear.input_scale,
        observed_linear.input_zp,
    )
    input_quant_func = lambda x: to_affine_quantized_intx_static(
        x, act_scale, act_zero_point, x.shape, target_dtype
    )
    # for demo purpose only, we quantize the weight here
    weight = observed_linear.weight
    weight = to_affine_quantized_intx(
        weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
    )
    observed_linear.weight = torch.nn.Parameter(
        to_linear_activation_quantized(weight, input_quant_func),
        requires_grad=False,
    )

    del observed_linear.input_scale
    del observed_linear.input_zp
    return observed_linear


example_inputs = (torch.randn(32, 64),)
Expand All @@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear):

# just quantizing activation since we only observed quantization, this could be extended to support
# quantizing weight as well
quantize_(m, apply_activation_static_weight_quant(), _is_linear)
quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear)
for l in m.modules():
    if isinstance(l, torch.nn.Linear):
        assert isinstance(l.weight, LinearActivationQuantizedTensor)
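To make the dispatch concrete, here is a small runnable sketch that performs the same per-row symmetric int8 weight quantization used in the handlers above; the config class, handler name, and toy model are hypothetical stand-ins, while the to_affine_quantized_intx call mirrors the one in this diff.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class ToyWeightOnlyConfig(AOBaseConfig):
    pass


@register_quantize_module_handler(ToyWeightOnlyConfig)
def _toy_weight_only_transform(module: torch.nn.Module, config: ToyWeightOnlyConfig):
    # per-row symmetric int8 quantization of the linear weight, as in the diff above
    weight = module.weight
    quantized_weight = to_affine_quantized_intx(
        weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module


m = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(m, ToyWeightOnlyConfig())  # each matching Linear weight is replaced by a quantized tensor subclass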
