From 08b347cff4b222635779436bd3c2d8ace974aaaa Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Wed, 27 Nov 2024 14:35:25 +0100 Subject: [PATCH 01/25] higgs init --- .../Dockerfile | 3 + src/transformers/__init__.py | 2 + src/transformers/integrations/__init__.py | 2 + src/transformers/integrations/higgs.py | 492 ++++++++++++++++++ src/transformers/quantizers/auto.py | 3 + .../quantizers/quantizer_higgs.py | 130 +++++ 6 files changed, 632 insertions(+) create mode 100644 src/transformers/integrations/higgs.py create mode 100644 src/transformers/quantizers/quantizer_higgs.py diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index a8f131164eb4ae..a7bf472033c51b 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -64,6 +64,9 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto # Add eetq for quantization testing RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git +# Add FLUTE for quantization testing +RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6 + # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. RUN cd transformers && python3 setup.py develop diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fa54ced6a13486..6b7ec5af37c872 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -981,6 +981,7 @@ "CompressedTensorsConfig", "EetqConfig", "FbgemmFp8Config", + "FluteConfig", "GPTQConfig", "HqqConfig", "QuantoConfig", @@ -5925,6 +5926,7 @@ CompressedTensorsConfig, EetqConfig, FbgemmFp8Config, + FluteConfig, GPTQConfig, HqqConfig, QuantoConfig, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 093e0af29844e4..228474f1034e00 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -54,6 +54,7 @@ ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], + "higgs": ["HiggsLinear", "replace_with_higgs_linear"], "fsdp": ["is_fsdp_managed_module"], "ggml": [ "GGUF_CONFIG_MAPPING", @@ -156,6 +157,7 @@ ) from .eetq import replace_with_eetq_linear from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear + from .higgs import HiggsLinear, replace_with_higgs_linear from .fsdp import is_fsdp_managed_module from .ggml import ( GGUF_CONFIG_MAPPING, diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py new file mode 100644 index 00000000000000..a015fd71055614 --- /dev/null +++ b/src/transformers/integrations/higgs.py @@ -0,0 +1,492 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
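Before the integration files themselves, note the runtime dependencies this series leans on: the Dockerfile change above installs flute-kernel, and the error messages further down additionally ask for fast_hadamard_transform and accelerate. A minimal, illustrative check (not part of the patch) that those modules are importable:

import importlib.util

# Module names inferred from the imports in this series: the pip package
# flute-kernel exposes `flute`, while fast_hadamard_transform and accelerate
# are imported under their own names.
for module in ("flute", "fast_hadamard_transform", "accelerate"):
    status = "available" if importlib.util.find_spec(module) is not None else "missing"
    print(f"{module}: {status}")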
+"HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file" + +from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available + + +# if is_torch_available(): +import torch +import torch.nn as nn + + +# if is_flute_available(): +import flute.utils + +if is_hadamard_available(): + from fast_hadamard_transform import hadamard_transform + +if is_flute_available(): + import flute.utils + from flute.integrations.higgs import prepare_data_transposed + + +def pad_to_block(tensor, dims, had_block_size, value=0): + pad_dims = [0 for _ in range(2 * len(tensor.shape))] + for dim in dims: + size = tensor.shape[dim] + next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size + delta = next_multiple_of_1024 - size + pad_dims[-2 * dim - 1] = delta + + return nn.functional.pad(tensor, pad_dims, "constant", value) + + +def get_higgs_grid(p: int, n: int) -> torch.Tensor: + if (p, n) == (2, 256): + return torch.tensor( + [[-2.501467704772949, 0.17954708635807037], + [-0.6761789321899414, 1.2728623151779175], + [-1.8025816679000854, 0.7613157629966736], + [-0.538287878036499, -2.6028504371643066], + [0.8415029644966125, -0.8600977659225464], + [0.7023013234138489, 3.3138747215270996], + [0.5699077844619751, 2.5782253742218018], + [3.292393207550049, -0.6016128063201904], + [0.5561617016792297, -1.7723814249038696], + [-2.1012380123138428, 0.020958125591278076], + [0.46085724234580994, 0.8428705334663391], + [1.4548040628433228, -0.6156039237976074], + [3.210029363632202, 0.3546904921531677], + [0.8893890976905823, -0.5967988967895508], + [0.8618854284286499, -3.2061192989349365], + [1.1360996961593628, -0.23852407932281494], + [1.6646337509155273, -0.9265465140342712], + [1.4767773151397705, 1.2476022243499756], + [-1.0511897802352905, 1.94503915309906], + [-1.56318998336792, -0.3264186680316925], + [-0.1829211413860321, 0.2922491431236267], + [-0.8950616717338562, -1.3887052536010742], + [-0.08206957578659058, -1.329533576965332], + [-0.487422913312912, 1.4817842245101929], + [-1.6769757270812988, -2.8269758224487305], + [-1.5057679414749146, 1.8905963897705078], + [1.8335362672805786, 1.0515104532241821], + [0.3273945450782776, 1.0491033792495728], + [-3.295924186706543, -0.7021600008010864], + [-1.8428784608840942, -1.2315762042999268], + [-0.8575026392936707, -1.7005949020385742], + [-1.120667815208435, 0.6467998027801514], + [-0.1588846743106842, -1.804071068763733], + [-0.8539647459983826, 0.5645008683204651], + [-1.4192019701004028, -0.6175029873847961], + [1.0799058675765991, 1.7871345281600952], + [1.171311855316162, 0.7511613965034485], + [2.162078380584717, 0.8044339418411255], + [1.3969420194625854, -1.243762493133545], + [-0.23818807303905487, 0.053944624960422516], + [2.304199457168579, -1.2667627334594727], + [1.4225027561187744, 0.568610668182373], + [0.376836895942688, -0.7134661674499512], + [2.0404467582702637, 0.4087389409542084], + [0.7639489769935608, -1.1367933750152588], + [0.3622530400753021, -1.4827953577041626], + [0.4100743532180786, 0.36108437180519104], + [-1.5867475271224976, -1.618212342262268], + [-2.2769672870635986, -1.2132309675216675], + [0.9184022545814514, -0.34428009390830994], + [-0.3902314603328705, 0.21785245835781097], + [3.120687484741211, 1.3077973127365112], + [1.587440848350525, -1.6506884098052979], + [-1.718808889389038, -0.038405973464250565], + [-0.6888407468795776, -0.8402308821678162], + [-0.7981445789337158, -1.1117373704910278], + 
[-2.4124443531036377, 1.3419722318649292], + [-0.6611530184745789, 0.9939885139465332], + [-0.33103418350219727, -0.16702833771705627], + [-2.4091389179229736, -2.326857566833496], + [1.6610108613967896, -2.159703254699707], + [0.014884627424180508, 0.3887578248977661], + [0.029668325558304787, 1.8786455392837524], + [1.180362582206726, 2.699317216873169], + [1.821286678314209, -0.5960053205490112], + [-0.44835323095321655, 3.327436685562134], + [-0.3714401423931122, -2.1466753482818604], + [-1.1103475093841553, -2.4536871910095215], + [-0.39110705256462097, 0.6670510172843933], + [0.474752813577652, -1.1959707736968994], + [-0.013110585510730743, -2.52519154548645], + [-2.0836575031280518, -1.703289270401001], + [-1.1077687740325928, -0.1252644956111908], + [-0.4138077199459076, 1.1837692260742188], + [-1.977599024772644, 1.688241720199585], + [-1.659559965133667, -2.1387736797332764], + [0.03242531046271324, 0.6526556015014648], + [0.9127950072288513, 0.6099498867988586], + [-0.38478314876556396, 0.433487206697464], + [0.27454206347465515, -0.27719801664352417], + [0.10388526320457458, 2.2812814712524414], + [-0.014394169673323631, -3.177137613296509], + [-1.2871228456497192, -0.8961855173110962], + [0.5720916986465454, -0.921597957611084], + [1.1159656047821045, -0.7609877586364746], + [2.4383342266082764, -2.2983546257019043], + [-0.294057160615921, -0.9770799875259399], + [-0.9342701435089111, 1.107579231262207], + [-1.549338698387146, 3.090520143508911], + [2.6076579093933105, 2.051239013671875], + [-0.9259037375450134, 1.407211184501648], + [-0.1747353971004486, 0.540488600730896], + [-0.8963701725006104, 0.8271111249923706], + [0.6480194926261902, 1.0128909349441528], + [0.980783998966217, -0.06156221032142639], + [-0.16883476078510284, 1.0601658821105957], + [0.5839992761611938, 0.004697148688137531], + [-0.34228450059890747, -1.2423977851867676], + [2.500824451446533, 0.3665279746055603], + [-0.17641609907150269, 1.3529551029205322], + [0.05378641560673714, 2.817232847213745], + [-1.2391047477722168, 2.354328155517578], + [0.630434513092041, -0.668536365032196], + [1.7576488256454468, 0.6738647818565369], + [0.4435231387615204, 0.6000469326972961], + [-0.08794835954904556, -0.11511358618736267], + [1.6540337800979614, 0.33995017409324646], + [-0.04202975332736969, -0.5375117063522339], + [-0.4247745871543884, -0.7897617220878601], + [0.06695003807544708, 1.2000739574432373], + [-3.2508881092071533, 0.28734830021858215], + [-1.613816261291504, 0.4944162368774414], + [1.3598989248275757, 0.26117825508117676], + [2.308382511138916, 1.3462618589401245], + [-1.2137469053268433, -1.9254342317581177], + [-0.4889402985572815, 1.8136259317398071], + [-0.1870335340499878, -0.3480615019798279], + [1.0766386985778809, -1.0627082586288452], + [0.4651014506816864, 2.131748914718628], + [-0.1306295394897461, -0.7811847925186157], + [0.06433182954788208, -1.5397958755493164], + [-0.2894323468208313, -0.5789554715156555], + [-0.6081662178039551, 0.4845278263092041], + [2.697964668273926, -0.18515698611736298], + [0.1277363896369934, -0.7221432328224182], + [0.8700758218765259, 0.35042452812194824], + [0.22088994085788727, 0.495242178440094], + [-2.5843818187713623, -0.8000828623771667], + [0.6732649803161621, -1.4362232685089111], + [-1.5286413431167603, 1.0417330265045166], + [-1.1222513914108276, -0.6269875764846802], + [-0.9752035140991211, -0.8750635385513306], + [-2.6369473934173584, 0.6918523907661438], + [0.14478731155395508, -0.041986867785453796], + [-1.5629483461380005, 
1.4369450807571411], + [0.38952457904815674, -2.16428804397583], + [-0.16885095834732056, 0.7976621985435486], + [-3.12416934967041, 1.256506085395813], + [0.6843105554580688, -0.4203019142150879], + [1.9345275163650513, 1.934950351715088], + [0.012184220366179943, -2.1080918312072754], + [-0.6350273489952087, 0.7358828186988831], + [-0.837304949760437, -0.6214472651481628], + [0.08211923390626907, -0.9472538232803345], + [2.9332995414733887, -1.4956780672073364], + [1.3806978464126587, -0.2916182279586792], + [0.06773144006729126, 0.9285762310028076], + [-1.1943119764328003, 1.5963770151138306], + [1.6395620107650757, -0.32285431027412415], + [-1.390851378440857, -0.08273141086101532], + [1.816330909729004, -1.2812227010726929], + [0.7921574711799622, -2.1135804653167725], + [0.5817914605140686, 1.2644577026367188], + [1.929347038269043, -0.2386285960674286], + [0.8877345323562622, 1.190008521080017], + [1.4732073545455933, 0.8935023546218872], + [-2.8518524169921875, -1.5478795766830444], + [0.2439267635345459, 0.7576767802238464], + [0.5246709585189819, -2.606659412384033], + [1.150876760482788, 1.4073830842971802], + [-0.2643202245235443, 2.0634236335754395], + [1.555483341217041, -0.0023102816194295883], + [2.0830578804016113, -1.7225427627563477], + [-0.5424830317497253, -1.070199728012085], + [0.9168899655342102, 0.8955540060997009], + [-0.8120972514152527, 2.696739912033081], + [-0.29908373951911926, -1.5310651063919067], + [1.2320337295532227, -1.556247353553772], + [1.8612544536590576, 0.08704725652933121], + [0.22133447229862213, -1.8091708421707153], + [-0.4403655230998993, -0.38571012020111084], + [-1.88539457321167, 1.192205786705017], + [2.239687919616699, 0.004709010478109121], + [1.139495611190796, 0.45733731985092163], + [-1.507995367050171, 0.19716016948223114], + [0.46986445784568787, 1.5422041416168213], + [-1.2573751211166382, -0.35984551906585693], + [-1.7415345907211304, -0.6020717024803162], + [1.0751984119415283, 0.19006384909152985], + [2.24186635017395, -0.46343153715133667], + [0.3610347509384155, -0.07658443599939346], + [-1.3111497163772583, 0.432013601064682], + [0.6164408326148987, 0.24538464844226837], + [-1.9266542196273804, -0.3256155550479889], + [-0.5870336890220642, -0.1879584938287735], + [-1.0476511716842651, 0.3677721917629242], + [-1.229940414428711, 1.2433830499649048], + [0.18550436198711395, 0.22753673791885376], + [-0.017921989783644676, 0.12625974416732788], + [1.1659504175186157, -0.5020995736122131], + [-0.5983408093452454, -1.40438973903656], + [0.7519024014472961, -0.16282692551612854], + [0.9920787811279297, -1.344896912574768], + [-0.8103678226470947, 0.3064485788345337], + [0.6956969499588013, 1.8208192586898804], + [-2.7830491065979004, -0.2299390584230423], + [-0.34681546688079834, 2.4890666007995605], + [-1.4452646970748901, -1.2216600179672241], + [-2.1872897148132324, 0.8926076292991638], + [1.706072211265564, -2.8440372943878174], + [1.1119003295898438, -2.4923460483551025], + [-2.582794666290283, 2.0973289012908936], + [0.04987720400094986, -0.2964983284473419], + [-2.063807487487793, -0.7847916483879089], + [-0.4068813621997833, 0.9135897755622864], + [-0.9814359545707703, -0.3874954879283905], + [-1.4227229356765747, 0.7337291240692139], + [0.3065044581890106, 1.3125417232513428], + [1.2160996198654175, -1.9643305540084839], + [-1.2163853645324707, 0.14608727395534515], + [-2.3030710220336914, -0.37558120489120483], + [0.9232977628707886, 2.1843791007995605], + [-0.1989777386188507, 1.651851773262024], + [-0.714374840259552, 
-0.39365994930267334], + [-0.7805715799331665, -2.099881887435913], + [0.9015759229660034, -1.7053706645965576], + [0.1033422127366066, 1.5256654024124146], + [-1.8773194551467896, 2.324174165725708], + [1.9227174520492554, 2.7441604137420654], + [-0.5994020104408264, 0.23984014987945557], + [1.3496100902557373, -0.9126054644584656], + [-0.8765304088592529, -3.1877026557922363], + [-1.2040035724639893, -1.5169521570205688], + [1.4261796474456787, 2.150200128555298], + [1.463774561882019, 1.6656692028045654], + [0.20364105701446533, -0.4988172650337219], + [0.5195154547691345, -0.24067887663841248], + [-1.1116786003112793, -1.1599653959274292], + [-0.8490808606147766, -0.1681060940027237], + [0.3189965784549713, -0.9641751646995544], + [-0.5664751529693604, -0.5951744318008423], + [-1.6347930431365967, -0.9137664437294006], + [0.44048091769218445, -0.47259435057640076], + [-2.147747039794922, 0.47442489862442017], + [1.834734320640564, 1.4462147951126099], + [1.1777573823928833, 1.0659226179122925], + [-0.9568989872932434, 0.09495053440332413], + [-1.838529348373413, 0.2950586676597595], + [-0.4800611734390259, 0.014894310384988785], + [-0.5235516428947449, -1.7687653303146362], + [2.0735011100769043, -0.8825281262397766], + [2.637502431869507, 0.8455678224563599], + [2.606602907180786, -0.7848446369171143], + [-1.1886937618255615, 0.9330510497093201], + [0.38082656264305115, 0.13328030705451965], + [0.6847941875457764, 0.7384101152420044], + [1.2638574838638306, -0.007309418171644211], + [0.18292222917079926, -1.22371244430542], + [0.8143821954727173, 1.4976691007614136], + [0.6571850776672363, 0.48368802666664124], + [-0.6991601586341858, 2.150190830230713], + [0.8101756572723389, 0.10206498205661774], + [-0.08768226951360703, -1.084917664527893], + [-0.7208092212677002, 0.03657956421375275], + [0.3211449086666107, 1.803687334060669], + [-0.7835946083068848, 1.6869111061096191]] + ) + else: + raise NotImplementedError(f"Unsupported p={p}, n={n}") + + +def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): + assert len(weight.shape) == 2, "Only 2D weights are supported for now" + assert weight.device.type == "cuda", "Only CUDA devices are supported for now" + + grid = get_higgs_grid(p, 2**(p * bits)).to(weight.device) + grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2 + + weight = weight.clone().float() + # Pad to Hadamard transform size + weight = pad_to_block(weight, [1], 1024) + + # Scale and Hadamard transform + mult = weight.shape[1] // 1024 + weight = weight.reshape(-1, mult, 1024) + scales = torch.linalg.norm(weight, axis=-1) + weight = torch.ops.fast_hadamard_transform.fast_hadamard_transform(weight, 1) / scales[:, :, None] + + # Pad to edenn_d and project + weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p) + + # Quantize + codes = torch.empty(weight.shape[:-1], device=weight.device, dtype=torch.uint8) + for i in range(0, weight.shape[0], 64): + codes[i:i+64] = torch.argmax( + 2 * weight[i:i+64] @ grid.T - grid_norm_2, dim=-1 + ).to(torch.uint8) + + codes = codes.reshape(codes.shape[0], -1) + scales = scales / 32 + + weight, scales, tables, tables2 = prepare_data_transposed( + codes, + torch.repeat_interleave(scales.half(), 1024 // 256, dim=1), + grid.half(), + num_bits=bits, + group_size=256, + vector_size=p, + dtype=torch.float16, + device=weight.device, + ) + + return { + "weight": weight, + "scales": scales, + "tables": tables, + "tables2": tables2, + } + + +class HiggsLinear(nn.Module): + def __init__( + self, + in_features: int, + 
out_features: int, + num_bits: int, + group_size: int, + num_sms_packed: int, + bias=True, + dtype: torch.dtype=None, + device: torch.device=None, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.num_bits = num_bits + self.group_size = group_size + self.num_sms_packed = num_sms_packed + + self.workspace = flute.utils.make_workspace_streamk(device=device) + + assert in_features % 16 == 0 + assert in_features % group_size == 0 + assert num_bits in [2, 3, 4] + + self.weight = nn.Parameter(torch.empty((in_features * num_bits // 16, out_features), dtype=torch.int16, device=device), requires_grad=False) + self.scales = nn.Parameter(torch.empty((out_features, in_features//group_size), dtype=dtype, device=device), requires_grad=False) + self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False) + self.tables2 = nn.Parameter(torch.empty((2**num_bits, 2**num_bits, 1), dtype=torch.float32, device=device), requires_grad=False) + + if bias: + self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False) + else: + self.register_parameter("bias", None) + + def forward(self, x): + x = pad_to_block(x, [-1], 1024) + + orig_shape = x.shape + x = x.reshape(-1, 1024) + x = hadamard_transform(x, scale=1/32) + x = x.reshape(orig_shape) + + return flute.qgemm_simple( + x, + self.weight, + self.scales, + self.tables, + self.tables2, + self.workspace, + self.num_bits, + self.group_size, + ) + + +def replace_with_higgs_linear( + model, + quantization_config=None, + linear_weights_not_to_quantize=None, + current_key_name=None, + has_been_replaced=False, +): + """ + Public method that recursively replaces the Linear layers of the given model with HIGGS quantized layers. + `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the + conversion has been successfull or not. + + Args: + model (`torch.nn.Module`): + The model to convert, can be any `torch.nn.Module` instance. + quantization_config (`HiggsConfig`): + The quantization config object that contains the quantization parameters. + linear_weights_not_to_quantize (`list[str]`, *optional*): + A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be + converted. + current_key_name (`list`, *optional*): + A list that contains the current key name. This is used for recursion and should not be passed by the user. + has_been_replaced (`bool`, *optional*): + A boolean that indicates if the conversion has been successful or not. This is used for recursion and + should not be passed by the user. + """ + if not is_flute_available(): + raise ValueError("FLUTE is not available. Please install it with `pip install flute-kernel`") + + if not is_hadamard_available(): + raise ValueError("Fast Hadamard Transform is not available. 
Please install it with `pip install fast_hadamard_transform`") + + if not is_accelerate_available(): + raise ValueError( + f"HIGGS requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" + ) + + if linear_weights_not_to_quantize is None: + linear_weights_not_to_quantize = [] + + from accelerate import init_empty_weights + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear): + # Check if the current key is not in the `linear_weights_not_to_quantize` + if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize: + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + model._modules[name] = HiggsLinear( + in_features, + out_features, + bias=module.bias is not None, + num_bits=quantization_config.num_bits, + group_size=quantization_config.group_size, + num_sms_packed=quantization_config.num_sms_packed, + ) + has_been_replaced = True + + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = replace_with_higgs_linear( + module, + quantization_config=quantization_config, + linear_weights_not_to_quantize=linear_weights_not_to_quantize, + current_key_name=current_key_name, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index 38bebd2d8410e4..b6a03653b80157 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -23,6 +23,7 @@ CompressedTensorsConfig, EetqConfig, FbgemmFp8Config, + HiggsConfig, GPTQConfig, HqqConfig, QuantizationConfigMixin, @@ -38,6 +39,7 @@ from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer from .quantizer_eetq import EetqHfQuantizer from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer +from .quantizer_higgs import HiggsHfQuantizer from .quantizer_gptq import GptqHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer @@ -70,6 +72,7 @@ "hqq": HqqConfig, "compressed-tensors": CompressedTensorsConfig, "fbgemm_fp8": FbgemmFp8Config, + "higgs": HiggsConfig, "torchao": TorchAoConfig, "bitnet": BitNetConfig, } diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py new file mode 100644 index 00000000000000..065e9b53a06d8b --- /dev/null +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -0,0 +1,130 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
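The codeword search in `quantize_with_higgs` above relies on the identity argmin_g ||w - g||^2 = argmax_g (2 w.g - ||g||^2), since ||w||^2 is constant across codewords; that is what the `2 * weight[i:i+64] @ grid.T - grid_norm_2` line computes. A self-contained sketch (illustrative only, using a random stand-in for the HIGGS grid) checking that the two formulations agree:

import torch

torch.manual_seed(0)
p, bits = 2, 4
grid = torch.randn(2 ** (p * bits), p)   # stand-in for get_higgs_grid(p, 2**(p * bits))
weight = torch.randn(128, p)             # stand-in for one block of Hadamard-transformed weights

grid_norm_2 = torch.linalg.norm(grid, dim=-1) ** 2
codes_fast = torch.argmax(2 * weight @ grid.T - grid_norm_2, dim=-1)  # form used in the patch
codes_ref = torch.cdist(weight, grid).argmin(dim=-1)                  # explicit nearest neighbour
print("agreement:", (codes_fast == codes_ref).float().mean().item())  # expected 1.0 barring float ties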
+import importlib +from typing import TYPE_CHECKING, Optional, Dict, List, Any + +from packaging import version + +from .base import HfQuantizer +from .quantizers_utils import get_module_from_name + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..integrations import replace_with_higgs_linear, quantize_with_higgs +from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging +from ..utils.quantization_config import QuantizationConfigMixin + + +# if is_torch_available(): +import torch + +logger = logging.get_logger(__name__) + + +# Finds the parent of a node module named "name" +def find_parent(model, name): + module_tree = name.split(".")[:-1] + parent = model + for m in module_tree: + parent = parent._modules[m] + return parent + + +class AqlmHfQuantizer(HfQuantizer): + """ + Quantizer of the AQLM method. Enables the loading of prequantized models. + """ + + requires_calibration = True + required_packages = ["aqlm"] + optimum_quantizer = None + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + super().__init__(quantization_config, **kwargs) + self.quantization_config = quantization_config + + def validate_environment(self, *args, **kwargs): + if not is_accelerate_available(): + raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`") + + if not is_flute_available(): + raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel`") + + if not is_hadamard_available(): + raise ImportError("Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`") + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + if torch.cuda.is_available(): + torch_dtype = torch.float16 + logger.info( + "CUDA available. Assuming HIGGS inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually." + ) + else: + raise NotImplementedError("HIGGS quantization is only supported on GPU. 
Please use a different quantizer.") + return torch_dtype + + def create_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + """ + Quantizes weights into weight and weight_scale + """ + + flute_dict = quantize_with_higgs( + param_value, + self.quantization_config.bits, + self.quantization_config.p, + ) + + raise NotImplementedError("This function is not implemented yet.") + + module, tensor_name = get_module_from_name(model, param_name) + module._buffers[tensor_name] = new_value.to(target_device) + # to have the right output shape -> (out_features, 1) + module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device) + + if unexpected_keys is not None and param_name in unexpected_keys: + unexpected_keys.remove(param_name) + del param_name + + def _process_model_before_weight_loading( + self, + model: "PreTrainedModel", + **kwargs, + ): + replace_with_higgs_linear( + model, + quantization_config=self.quantization_config, + linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + return model + + @property + def is_trainable(self, model: Optional["PreTrainedModel"] = None): + return False + + def is_serializable(self, safe_serialization=None): + return True From 14a0c82c5a46260c2bff17223d9801fce6eac4b2 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Wed, 27 Nov 2024 15:58:28 +0100 Subject: [PATCH 02/25] working with crunches --- src/transformers/__init__.py | 3 +- src/transformers/integrations/__init__.py | 4 +- src/transformers/integrations/higgs.py | 31 ++++++------- src/transformers/quantizers/auto.py | 1 + .../quantizers/quantizer_higgs.py | 45 ++++++++++++++----- src/transformers/testing_utils.py | 2 + src/transformers/utils/__init__.py | 2 + src/transformers/utils/import_utils.py | 10 +++++ src/transformers/utils/quantization_config.py | 43 ++++++++++++++++++ 9 files changed, 109 insertions(+), 32 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6b7ec5af37c872..c8315734c5659e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -983,6 +983,7 @@ "FbgemmFp8Config", "FluteConfig", "GPTQConfig", + "HiggsConfig", "HqqConfig", "QuantoConfig", "TorchAoConfig", @@ -5926,7 +5927,7 @@ CompressedTensorsConfig, EetqConfig, FbgemmFp8Config, - FluteConfig, + HiggsConfig, GPTQConfig, HqqConfig, QuantoConfig, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 228474f1034e00..4d239680298366 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -54,7 +54,7 @@ ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], - "higgs": ["HiggsLinear", "replace_with_higgs_linear"], + "higgs": ["HiggsLinear", "quantize_with_higgs", "replace_with_higgs_linear"], "fsdp": ["is_fsdp_managed_module"], "ggml": [ "GGUF_CONFIG_MAPPING", @@ -157,7 +157,7 @@ ) from .eetq import replace_with_eetq_linear from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear - from .higgs import HiggsLinear, replace_with_higgs_linear + from .higgs import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear from 
.fsdp import is_fsdp_managed_module from .ggml import ( GGUF_CONFIG_MAPPING, diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index a015fd71055614..8cb1917909f7ae 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -314,6 +314,7 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): grid = get_higgs_grid(p, 2**(p * bits)).to(weight.device) grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2 + device = weight.device weight = weight.clone().float() # Pad to Hadamard transform size weight = pad_to_block(weight, [1], 1024) @@ -322,17 +323,18 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): mult = weight.shape[1] // 1024 weight = weight.reshape(-1, mult, 1024) scales = torch.linalg.norm(weight, axis=-1) - weight = torch.ops.fast_hadamard_transform.fast_hadamard_transform(weight, 1) / scales[:, :, None] + weight = hadamard_transform(weight, 1) / scales[:, :, None] # Pad to edenn_d and project weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p) # Quantize - codes = torch.empty(weight.shape[:-1], device=weight.device, dtype=torch.uint8) + codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8) for i in range(0, weight.shape[0], 64): codes[i:i+64] = torch.argmax( 2 * weight[i:i+64] @ grid.T - grid_norm_2, dim=-1 ).to(torch.uint8) + del weight codes = codes.reshape(codes.shape[0], -1) scales = scales / 32 @@ -345,7 +347,7 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): group_size=256, vector_size=p, dtype=torch.float16, - device=weight.device, + device=device, ) return { @@ -354,7 +356,8 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): "tables": tables, "tables2": tables2, } - + +WORKSPACE = flute.utils.make_workspace_streamk(device="cuda") class HiggsLinear(nn.Module): def __init__( @@ -362,7 +365,6 @@ def __init__( in_features: int, out_features: int, num_bits: int, - group_size: int, num_sms_packed: int, bias=True, dtype: torch.dtype=None, @@ -372,17 +374,13 @@ def __init__( self.in_features = in_features self.out_features = out_features self.num_bits = num_bits - self.group_size = group_size self.num_sms_packed = num_sms_packed - - self.workspace = flute.utils.make_workspace_streamk(device=device) - - assert in_features % 16 == 0 - assert in_features % group_size == 0 + + assert in_features % 256 == 0 assert num_bits in [2, 3, 4] self.weight = nn.Parameter(torch.empty((in_features * num_bits // 16, out_features), dtype=torch.int16, device=device), requires_grad=False) - self.scales = nn.Parameter(torch.empty((out_features, in_features//group_size), dtype=dtype, device=device), requires_grad=False) + self.scales = nn.Parameter(torch.empty((out_features, in_features//256), dtype=dtype, device=device), requires_grad=False) self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False) self.tables2 = nn.Parameter(torch.empty((2**num_bits, 2**num_bits, 1), dtype=torch.float32, device=device), requires_grad=False) @@ -405,9 +403,9 @@ def forward(self, x): self.scales, self.tables, self.tables2, - self.workspace, + WORKSPACE, self.num_bits, - self.group_size, + 256, ) @@ -449,7 +447,7 @@ def replace_with_higgs_linear( ) if linear_weights_not_to_quantize is None: - linear_weights_not_to_quantize = [] + linear_weights_not_to_quantize = ["lm_head.weight"] from accelerate import init_empty_weights @@ -469,8 +467,7 @@ def 
replace_with_higgs_linear( in_features, out_features, bias=module.bias is not None, - num_bits=quantization_config.num_bits, - group_size=quantization_config.group_size, + num_bits=quantization_config.bits, num_sms_packed=quantization_config.num_sms_packed, ) has_been_replaced = True diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index b6a03653b80157..d28412c7f34b94 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -54,6 +54,7 @@ "aqlm": AqlmHfQuantizer, "quanto": QuantoHfQuantizer, "eetq": EetqHfQuantizer, + "higgs": HiggsHfQuantizer, "hqq": HqqHfQuantizer, "compressed-tensors": CompressedTensorsHfQuantizer, "fbgemm_fp8": FbgemmFp8HfQuantizer, diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 065e9b53a06d8b..868800707b715a 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..integrations import replace_with_higgs_linear, quantize_with_higgs +from ..integrations import HiggsLinear, replace_with_higgs_linear, quantize_with_higgs from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging from ..utils.quantization_config import QuantizationConfigMixin @@ -43,13 +43,14 @@ def find_parent(model, name): return parent -class AqlmHfQuantizer(HfQuantizer): +class HiggsHfQuantizer(HfQuantizer): """ - Quantizer of the AQLM method. Enables the loading of prequantized models. + Quantizer of the HIGGS method. Enables the loading of prequantized models. """ - requires_calibration = True - required_packages = ["aqlm"] + requires_calibration = False + requires_parameters_quantization = True + required_packages = ["flute-kernel", "fast_hadamard_transform"] optimum_quantizer = None def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): @@ -91,21 +92,24 @@ def create_quantized_param( """ flute_dict = quantize_with_higgs( - param_value, + param_value.to(target_device), self.quantization_config.bits, self.quantization_config.p, ) - raise NotImplementedError("This function is not implemented yet.") - + del param_value + module, tensor_name = get_module_from_name(model, param_name) - module._buffers[tensor_name] = new_value.to(target_device) - # to have the right output shape -> (out_features, 1) - module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device) + for key, value in flute_dict.items(): + if key in module._parameters: + module._parameters[key] = value + elif key in module._buffers: + module._buffers[key] = value + else: + raise ValueError(f"Unexpected key {key} in module {module}") if unexpected_keys is not None and param_name in unexpected_keys: unexpected_keys.remove(param_name) - del param_name def _process_model_before_weight_loading( self, @@ -128,3 +132,20 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None): def is_serializable(self, safe_serialization=None): return True + + def check_quantized_param( + self, + model: "PreTrainedModel", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module, HiggsLinear) and tensor_name == "weight": + # Add here check for loaded components' dtypes once serialization is 
implemented + return True + else: + return False diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 25d837ccec0fbe..30ede3724939c8 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -75,12 +75,14 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flax_available, + is_flute_available, is_fsdp_available, is_ftfy_available, is_g2p_en_available, is_galore_torch_available, is_gguf_available, is_grokadamw_available, + is_hadamard_available, is_ipex_available, is_jieba_available, is_jinja_available, diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 492642d61babb5..7e3d18c657ded3 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -138,12 +138,14 @@ is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_flax_available, + is_flute_available, is_fsdp_available, is_ftfy_available, is_g2p_en_available, is_galore_torch_available, is_gguf_available, is_grokadamw_available, + is_hadamard_available, is_hqq_available, is_in_notebook, is_ipex_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 70bd236e3bb4ac..95ba54a8d81113 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -102,6 +102,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _bitsandbytes_available = _is_package_available("bitsandbytes") _eetq_available = _is_package_available("eetq") _fbgemm_gpu_available = _is_package_available("fbgemm_gpu") +_flute_available = True # _is_package_available("flute") _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") _grokadamw_available = _is_package_available("grokadamw") @@ -126,6 +127,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _faiss_available = False _ftfy_available = _is_package_available("ftfy") _g2p_en_available = _is_package_available("g2p_en") +_hadamard_available = _is_package_available("fast_hadamard_transform") _ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) _jieba_available = _is_package_available("jieba") _jinja_available = _is_package_available("jinja2") @@ -328,6 +330,10 @@ def is_torch_deterministic(): return False else: return True + + +def is_hadamard_available(): + return _hadamard_available def is_hqq_available(min_version: str = HQQ_MIN_VERSION): @@ -602,6 +608,10 @@ def is_flax_available(): return _flax_available +def is_flute_available(): + return _flute_available + + def is_ftfy_available(): return _ftfy_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index ac81864e50869b..1836d79ac3b0eb 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -41,6 +41,7 @@ class QuantizationMethod(str, Enum): AQLM = "aqlm" QUANTO = "quanto" EETQ = "eetq" + HIGGS = "higgs" HQQ = "hqq" COMPRESSED_TENSORS = "compressed-tensors" FBGEMM_FP8 = "fbgemm_fp8" @@ -1222,6 +1223,48 @@ def get_loading_attributes(self): loading_attibutes = ["activation_scale_ub"] loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} return loading_attibutes_dict + + +@dataclass +class HiggsConfig(QuantizationConfigMixin): + """ + This is a wrapper class about `aqlm` parameters. 
+ + Args: + in_group_size (`int`, *optional*, defaults to 8): + The group size along the input dimension. + out_group_size (`int`, *optional*, defaults to 1): + The group size along the output dimension. It's recommended to always use 1. + num_codebooks (`int`, *optional*, defaults to 1): + Number of codebooks for the Additive Quantization procedure. + nbits_per_codebook (`int`, *optional*, defaults to 16): + Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook. + linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*): + List of full paths of `nn.Linear` weight parameters that shall not be quantized. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + def __init__( + self, + bits: int = 4, + p: int = 2, + linear_weights_not_to_quantize: Optional[List[str]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.HIGGS + self.bits = bits + self.p = p + self.linear_weights_not_to_quantize = linear_weights_not_to_quantize + self.num_sms_packed=128 + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + return @dataclass From 1c5b9e7d6290a987875b2a80ef14b9114606e297 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 11:50:11 +0100 Subject: [PATCH 03/25] per-model workspaces --- src/transformers/integrations/higgs.py | 17 +++++++++-------- src/transformers/quantizers/quantizer_higgs.py | 8 ++++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 8cb1917909f7ae..ab061b671339b1 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -16,13 +16,13 @@ from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available -# if is_torch_available(): -import torch -import torch.nn as nn +if is_torch_available(): + import torch + import torch.nn as nn -# if is_flute_available(): -import flute.utils +if is_flute_available(): + import flute.utils if is_hadamard_available(): from fast_hadamard_transform import hadamard_transform @@ -356,8 +356,7 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): "tables": tables, "tables2": tables2, } - -WORKSPACE = flute.utils.make_workspace_streamk(device="cuda") + class HiggsLinear(nn.Module): def __init__( @@ -388,6 +387,8 @@ def __init__( self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False) else: self.register_parameter("bias", None) + + self.workspace = None # must be set externally to be reused among layers def forward(self, x): x = pad_to_block(x, [-1], 1024) @@ -403,7 +404,7 @@ def forward(self, x): self.scales, self.tables, self.tables2, - WORKSPACE, + self.workspace, self.num_bits, 256, ) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 868800707b715a..5539f1c4dc9985 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -90,7 +90,12 @@ def create_quantized_param( """ Quantizes weights into weight and weight_scale """ + import flute.utils + if target_device not in model.flute_workspaces: + model.flute_workspaces[target_device] = flute.utils.make_workspace_streamk(device=target_device) + + flute_dict = 
quantize_with_higgs( param_value.to(target_device), self.quantization_config.bits, @@ -111,6 +116,8 @@ def create_quantized_param( if unexpected_keys is not None and param_name in unexpected_keys: unexpected_keys.remove(param_name) + module.workspace = model.flute_workspaces[target_device] + def _process_model_before_weight_loading( self, model: "PreTrainedModel", @@ -122,6 +129,7 @@ def _process_model_before_weight_loading( linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize, ) model.config.quantization_config = self.quantization_config + model.flute_workspaces = {} def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): return model From 9f2ef77e981272f10a32b01c713b502dca82a4a1 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 12:19:48 +0100 Subject: [PATCH 04/25] style --- docs/source/en/main_classes/quantization.md | 4 + src/transformers/__init__.py | 3 +- src/transformers/integrations/__init__.py | 2 +- src/transformers/integrations/higgs.py | 589 +++++++++--------- src/transformers/quantizers/auto.py | 4 +- .../quantizers/quantizer_higgs.py | 38 +- src/transformers/testing_utils.py | 2 - src/transformers/utils/import_utils.py | 9 +- src/transformers/utils/quantization_config.py | 4 +- 9 files changed, 337 insertions(+), 318 deletions(-) diff --git a/docs/source/en/main_classes/quantization.md b/docs/source/en/main_classes/quantization.md index 3f44569697777b..341f1147d74e46 100755 --- a/docs/source/en/main_classes/quantization.md +++ b/docs/source/en/main_classes/quantization.md @@ -53,6 +53,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide. [[autodoc]] quantizers.base.HfQuantizer +## HiggsConfig + +[[autodoc]] HiggsConfig + ## HqqConfig [[autodoc]] HqqConfig diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c8315734c5659e..1aba83aa092649 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -981,7 +981,6 @@ "CompressedTensorsConfig", "EetqConfig", "FbgemmFp8Config", - "FluteConfig", "GPTQConfig", "HiggsConfig", "HqqConfig", @@ -5927,8 +5926,8 @@ CompressedTensorsConfig, EetqConfig, FbgemmFp8Config, - HiggsConfig, GPTQConfig, + HiggsConfig, HqqConfig, QuantoConfig, TorchAoConfig, diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 4d239680298366..a12552b10841d3 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -157,7 +157,6 @@ ) from .eetq import replace_with_eetq_linear from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear - from .higgs import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear from .fsdp import is_fsdp_managed_module from .ggml import ( GGUF_CONFIG_MAPPING, @@ -167,6 +166,7 @@ load_dequant_gguf_tensor, load_gguf, ) + from .higgs import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear from .hqq import prepare_for_hqq_linear from .integration_utils import ( INTEGRATION_TO_CALLBACK, diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index ab061b671339b1..d289ef4b5799c2 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -13,24 +13,30 @@ # limitations under the License. 
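With HiggsConfig registered in the auto-quantizer mapping and documented above, end-user loading is expected to follow the usual quantization_config pattern. A hedged usage sketch, not part of the patch: the model id is a placeholder, a CUDA device is required (HiggsHfQuantizer.update_torch_dtype raises otherwise), and the argument names follow the config defined in this series.

import torch
from transformers import AutoModelForCausalLM, HiggsConfig

# 4-bit HIGGS with a 2-dimensional grid, matching the defaults introduced above.
quantization_config = HiggsConfig(bits=4, p=2)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",   # placeholder model id
    quantization_config=quantization_config,
    device_map="cuda",                    # HIGGS quantization is GPU-only in this series
)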
"HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file" -from ..utils import ACCELERATE_MIN_VERSION, is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available +from ..utils import ( + ACCELERATE_MIN_VERSION, + is_accelerate_available, + is_flute_available, + is_hadamard_available, + is_torch_available, +) if is_torch_available(): import torch import torch.nn as nn - + if is_flute_available(): import flute.utils if is_hadamard_available(): from fast_hadamard_transform import hadamard_transform - + if is_flute_available(): import flute.utils from flute.integrations.higgs import prepare_data_transposed - + def pad_to_block(tensor, dims, had_block_size, value=0): pad_dims = [0 for _ in range(2 * len(tensor.shape))] @@ -39,303 +45,303 @@ def pad_to_block(tensor, dims, had_block_size, value=0): next_multiple_of_1024 = ((size - 1) // had_block_size + 1) * had_block_size delta = next_multiple_of_1024 - size pad_dims[-2 * dim - 1] = delta - + return nn.functional.pad(tensor, pad_dims, "constant", value) def get_higgs_grid(p: int, n: int) -> torch.Tensor: if (p, n) == (2, 256): return torch.tensor( - [[-2.501467704772949, 0.17954708635807037], - [-0.6761789321899414, 1.2728623151779175], - [-1.8025816679000854, 0.7613157629966736], - [-0.538287878036499, -2.6028504371643066], - [0.8415029644966125, -0.8600977659225464], - [0.7023013234138489, 3.3138747215270996], - [0.5699077844619751, 2.5782253742218018], - [3.292393207550049, -0.6016128063201904], - [0.5561617016792297, -1.7723814249038696], - [-2.1012380123138428, 0.020958125591278076], - [0.46085724234580994, 0.8428705334663391], - [1.4548040628433228, -0.6156039237976074], - [3.210029363632202, 0.3546904921531677], - [0.8893890976905823, -0.5967988967895508], - [0.8618854284286499, -3.2061192989349365], - [1.1360996961593628, -0.23852407932281494], - [1.6646337509155273, -0.9265465140342712], - [1.4767773151397705, 1.2476022243499756], - [-1.0511897802352905, 1.94503915309906], - [-1.56318998336792, -0.3264186680316925], - [-0.1829211413860321, 0.2922491431236267], - [-0.8950616717338562, -1.3887052536010742], - [-0.08206957578659058, -1.329533576965332], - [-0.487422913312912, 1.4817842245101929], - [-1.6769757270812988, -2.8269758224487305], - [-1.5057679414749146, 1.8905963897705078], - [1.8335362672805786, 1.0515104532241821], - [0.3273945450782776, 1.0491033792495728], - [-3.295924186706543, -0.7021600008010864], - [-1.8428784608840942, -1.2315762042999268], - [-0.8575026392936707, -1.7005949020385742], - [-1.120667815208435, 0.6467998027801514], - [-0.1588846743106842, -1.804071068763733], - [-0.8539647459983826, 0.5645008683204651], - [-1.4192019701004028, -0.6175029873847961], - [1.0799058675765991, 1.7871345281600952], - [1.171311855316162, 0.7511613965034485], - [2.162078380584717, 0.8044339418411255], - [1.3969420194625854, -1.243762493133545], - [-0.23818807303905487, 0.053944624960422516], - [2.304199457168579, -1.2667627334594727], - [1.4225027561187744, 0.568610668182373], - [0.376836895942688, -0.7134661674499512], - [2.0404467582702637, 0.4087389409542084], - [0.7639489769935608, -1.1367933750152588], - [0.3622530400753021, -1.4827953577041626], - [0.4100743532180786, 0.36108437180519104], - [-1.5867475271224976, -1.618212342262268], - [-2.2769672870635986, -1.2132309675216675], - [0.9184022545814514, -0.34428009390830994], - [-0.3902314603328705, 0.21785245835781097], - [3.120687484741211, 1.3077973127365112], - [1.587440848350525, -1.6506884098052979], - 
[-1.718808889389038, -0.038405973464250565], - [-0.6888407468795776, -0.8402308821678162], - [-0.7981445789337158, -1.1117373704910278], - [-2.4124443531036377, 1.3419722318649292], - [-0.6611530184745789, 0.9939885139465332], - [-0.33103418350219727, -0.16702833771705627], - [-2.4091389179229736, -2.326857566833496], - [1.6610108613967896, -2.159703254699707], - [0.014884627424180508, 0.3887578248977661], - [0.029668325558304787, 1.8786455392837524], - [1.180362582206726, 2.699317216873169], - [1.821286678314209, -0.5960053205490112], - [-0.44835323095321655, 3.327436685562134], - [-0.3714401423931122, -2.1466753482818604], - [-1.1103475093841553, -2.4536871910095215], - [-0.39110705256462097, 0.6670510172843933], - [0.474752813577652, -1.1959707736968994], - [-0.013110585510730743, -2.52519154548645], - [-2.0836575031280518, -1.703289270401001], - [-1.1077687740325928, -0.1252644956111908], - [-0.4138077199459076, 1.1837692260742188], - [-1.977599024772644, 1.688241720199585], - [-1.659559965133667, -2.1387736797332764], - [0.03242531046271324, 0.6526556015014648], - [0.9127950072288513, 0.6099498867988586], - [-0.38478314876556396, 0.433487206697464], - [0.27454206347465515, -0.27719801664352417], - [0.10388526320457458, 2.2812814712524414], - [-0.014394169673323631, -3.177137613296509], - [-1.2871228456497192, -0.8961855173110962], - [0.5720916986465454, -0.921597957611084], - [1.1159656047821045, -0.7609877586364746], - [2.4383342266082764, -2.2983546257019043], - [-0.294057160615921, -0.9770799875259399], - [-0.9342701435089111, 1.107579231262207], - [-1.549338698387146, 3.090520143508911], - [2.6076579093933105, 2.051239013671875], - [-0.9259037375450134, 1.407211184501648], - [-0.1747353971004486, 0.540488600730896], - [-0.8963701725006104, 0.8271111249923706], - [0.6480194926261902, 1.0128909349441528], - [0.980783998966217, -0.06156221032142639], - [-0.16883476078510284, 1.0601658821105957], - [0.5839992761611938, 0.004697148688137531], - [-0.34228450059890747, -1.2423977851867676], - [2.500824451446533, 0.3665279746055603], - [-0.17641609907150269, 1.3529551029205322], - [0.05378641560673714, 2.817232847213745], - [-1.2391047477722168, 2.354328155517578], - [0.630434513092041, -0.668536365032196], - [1.7576488256454468, 0.6738647818565369], - [0.4435231387615204, 0.6000469326972961], - [-0.08794835954904556, -0.11511358618736267], - [1.6540337800979614, 0.33995017409324646], - [-0.04202975332736969, -0.5375117063522339], - [-0.4247745871543884, -0.7897617220878601], - [0.06695003807544708, 1.2000739574432373], - [-3.2508881092071533, 0.28734830021858215], - [-1.613816261291504, 0.4944162368774414], - [1.3598989248275757, 0.26117825508117676], - [2.308382511138916, 1.3462618589401245], - [-1.2137469053268433, -1.9254342317581177], - [-0.4889402985572815, 1.8136259317398071], - [-0.1870335340499878, -0.3480615019798279], - [1.0766386985778809, -1.0627082586288452], - [0.4651014506816864, 2.131748914718628], - [-0.1306295394897461, -0.7811847925186157], - [0.06433182954788208, -1.5397958755493164], - [-0.2894323468208313, -0.5789554715156555], - [-0.6081662178039551, 0.4845278263092041], - [2.697964668273926, -0.18515698611736298], - [0.1277363896369934, -0.7221432328224182], - [0.8700758218765259, 0.35042452812194824], - [0.22088994085788727, 0.495242178440094], - [-2.5843818187713623, -0.8000828623771667], - [0.6732649803161621, -1.4362232685089111], - [-1.5286413431167603, 1.0417330265045166], - [-1.1222513914108276, -0.6269875764846802], - [-0.9752035140991211, 
-0.8750635385513306], - [-2.6369473934173584, 0.6918523907661438], - [0.14478731155395508, -0.041986867785453796], - [-1.5629483461380005, 1.4369450807571411], - [0.38952457904815674, -2.16428804397583], - [-0.16885095834732056, 0.7976621985435486], - [-3.12416934967041, 1.256506085395813], - [0.6843105554580688, -0.4203019142150879], - [1.9345275163650513, 1.934950351715088], - [0.012184220366179943, -2.1080918312072754], - [-0.6350273489952087, 0.7358828186988831], - [-0.837304949760437, -0.6214472651481628], - [0.08211923390626907, -0.9472538232803345], - [2.9332995414733887, -1.4956780672073364], - [1.3806978464126587, -0.2916182279586792], - [0.06773144006729126, 0.9285762310028076], - [-1.1943119764328003, 1.5963770151138306], - [1.6395620107650757, -0.32285431027412415], - [-1.390851378440857, -0.08273141086101532], - [1.816330909729004, -1.2812227010726929], - [0.7921574711799622, -2.1135804653167725], - [0.5817914605140686, 1.2644577026367188], - [1.929347038269043, -0.2386285960674286], - [0.8877345323562622, 1.190008521080017], - [1.4732073545455933, 0.8935023546218872], - [-2.8518524169921875, -1.5478795766830444], - [0.2439267635345459, 0.7576767802238464], - [0.5246709585189819, -2.606659412384033], - [1.150876760482788, 1.4073830842971802], - [-0.2643202245235443, 2.0634236335754395], - [1.555483341217041, -0.0023102816194295883], - [2.0830578804016113, -1.7225427627563477], - [-0.5424830317497253, -1.070199728012085], - [0.9168899655342102, 0.8955540060997009], - [-0.8120972514152527, 2.696739912033081], - [-0.29908373951911926, -1.5310651063919067], - [1.2320337295532227, -1.556247353553772], - [1.8612544536590576, 0.08704725652933121], - [0.22133447229862213, -1.8091708421707153], - [-0.4403655230998993, -0.38571012020111084], - [-1.88539457321167, 1.192205786705017], - [2.239687919616699, 0.004709010478109121], - [1.139495611190796, 0.45733731985092163], - [-1.507995367050171, 0.19716016948223114], - [0.46986445784568787, 1.5422041416168213], - [-1.2573751211166382, -0.35984551906585693], - [-1.7415345907211304, -0.6020717024803162], - [1.0751984119415283, 0.19006384909152985], - [2.24186635017395, -0.46343153715133667], - [0.3610347509384155, -0.07658443599939346], - [-1.3111497163772583, 0.432013601064682], - [0.6164408326148987, 0.24538464844226837], - [-1.9266542196273804, -0.3256155550479889], - [-0.5870336890220642, -0.1879584938287735], - [-1.0476511716842651, 0.3677721917629242], - [-1.229940414428711, 1.2433830499649048], - [0.18550436198711395, 0.22753673791885376], - [-0.017921989783644676, 0.12625974416732788], - [1.1659504175186157, -0.5020995736122131], - [-0.5983408093452454, -1.40438973903656], - [0.7519024014472961, -0.16282692551612854], - [0.9920787811279297, -1.344896912574768], - [-0.8103678226470947, 0.3064485788345337], - [0.6956969499588013, 1.8208192586898804], - [-2.7830491065979004, -0.2299390584230423], - [-0.34681546688079834, 2.4890666007995605], - [-1.4452646970748901, -1.2216600179672241], - [-2.1872897148132324, 0.8926076292991638], - [1.706072211265564, -2.8440372943878174], - [1.1119003295898438, -2.4923460483551025], - [-2.582794666290283, 2.0973289012908936], - [0.04987720400094986, -0.2964983284473419], - [-2.063807487487793, -0.7847916483879089], - [-0.4068813621997833, 0.9135897755622864], - [-0.9814359545707703, -0.3874954879283905], - [-1.4227229356765747, 0.7337291240692139], - [0.3065044581890106, 1.3125417232513428], - [1.2160996198654175, -1.9643305540084839], - [-1.2163853645324707, 0.14608727395534515], - 
[-2.3030710220336914, -0.37558120489120483], - [0.9232977628707886, 2.1843791007995605], - [-0.1989777386188507, 1.651851773262024], - [-0.714374840259552, -0.39365994930267334], - [-0.7805715799331665, -2.099881887435913], - [0.9015759229660034, -1.7053706645965576], - [0.1033422127366066, 1.5256654024124146], - [-1.8773194551467896, 2.324174165725708], - [1.9227174520492554, 2.7441604137420654], - [-0.5994020104408264, 0.23984014987945557], - [1.3496100902557373, -0.9126054644584656], - [-0.8765304088592529, -3.1877026557922363], - [-1.2040035724639893, -1.5169521570205688], - [1.4261796474456787, 2.150200128555298], - [1.463774561882019, 1.6656692028045654], - [0.20364105701446533, -0.4988172650337219], - [0.5195154547691345, -0.24067887663841248], - [-1.1116786003112793, -1.1599653959274292], - [-0.8490808606147766, -0.1681060940027237], - [0.3189965784549713, -0.9641751646995544], - [-0.5664751529693604, -0.5951744318008423], - [-1.6347930431365967, -0.9137664437294006], - [0.44048091769218445, -0.47259435057640076], - [-2.147747039794922, 0.47442489862442017], - [1.834734320640564, 1.4462147951126099], - [1.1777573823928833, 1.0659226179122925], - [-0.9568989872932434, 0.09495053440332413], - [-1.838529348373413, 0.2950586676597595], - [-0.4800611734390259, 0.014894310384988785], - [-0.5235516428947449, -1.7687653303146362], - [2.0735011100769043, -0.8825281262397766], - [2.637502431869507, 0.8455678224563599], - [2.606602907180786, -0.7848446369171143], - [-1.1886937618255615, 0.9330510497093201], - [0.38082656264305115, 0.13328030705451965], - [0.6847941875457764, 0.7384101152420044], - [1.2638574838638306, -0.007309418171644211], - [0.18292222917079926, -1.22371244430542], - [0.8143821954727173, 1.4976691007614136], - [0.6571850776672363, 0.48368802666664124], - [-0.6991601586341858, 2.150190830230713], - [0.8101756572723389, 0.10206498205661774], - [-0.08768226951360703, -1.084917664527893], - [-0.7208092212677002, 0.03657956421375275], - [0.3211449086666107, 1.803687334060669], - [-0.7835946083068848, 1.6869111061096191]] + [ + [-2.501467704772949, 0.17954708635807037], + [-0.6761789321899414, 1.2728623151779175], + [-1.8025816679000854, 0.7613157629966736], + [-0.538287878036499, -2.6028504371643066], + [0.8415029644966125, -0.8600977659225464], + [0.7023013234138489, 3.3138747215270996], + [0.5699077844619751, 2.5782253742218018], + [3.292393207550049, -0.6016128063201904], + [0.5561617016792297, -1.7723814249038696], + [-2.1012380123138428, 0.020958125591278076], + [0.46085724234580994, 0.8428705334663391], + [1.4548040628433228, -0.6156039237976074], + [3.210029363632202, 0.3546904921531677], + [0.8893890976905823, -0.5967988967895508], + [0.8618854284286499, -3.2061192989349365], + [1.1360996961593628, -0.23852407932281494], + [1.6646337509155273, -0.9265465140342712], + [1.4767773151397705, 1.2476022243499756], + [-1.0511897802352905, 1.94503915309906], + [-1.56318998336792, -0.3264186680316925], + [-0.1829211413860321, 0.2922491431236267], + [-0.8950616717338562, -1.3887052536010742], + [-0.08206957578659058, -1.329533576965332], + [-0.487422913312912, 1.4817842245101929], + [-1.6769757270812988, -2.8269758224487305], + [-1.5057679414749146, 1.8905963897705078], + [1.8335362672805786, 1.0515104532241821], + [0.3273945450782776, 1.0491033792495728], + [-3.295924186706543, -0.7021600008010864], + [-1.8428784608840942, -1.2315762042999268], + [-0.8575026392936707, -1.7005949020385742], + [-1.120667815208435, 0.6467998027801514], + [-0.1588846743106842, -1.804071068763733], + 
[-0.8539647459983826, 0.5645008683204651], + [-1.4192019701004028, -0.6175029873847961], + [1.0799058675765991, 1.7871345281600952], + [1.171311855316162, 0.7511613965034485], + [2.162078380584717, 0.8044339418411255], + [1.3969420194625854, -1.243762493133545], + [-0.23818807303905487, 0.053944624960422516], + [2.304199457168579, -1.2667627334594727], + [1.4225027561187744, 0.568610668182373], + [0.376836895942688, -0.7134661674499512], + [2.0404467582702637, 0.4087389409542084], + [0.7639489769935608, -1.1367933750152588], + [0.3622530400753021, -1.4827953577041626], + [0.4100743532180786, 0.36108437180519104], + [-1.5867475271224976, -1.618212342262268], + [-2.2769672870635986, -1.2132309675216675], + [0.9184022545814514, -0.34428009390830994], + [-0.3902314603328705, 0.21785245835781097], + [3.120687484741211, 1.3077973127365112], + [1.587440848350525, -1.6506884098052979], + [-1.718808889389038, -0.038405973464250565], + [-0.6888407468795776, -0.8402308821678162], + [-0.7981445789337158, -1.1117373704910278], + [-2.4124443531036377, 1.3419722318649292], + [-0.6611530184745789, 0.9939885139465332], + [-0.33103418350219727, -0.16702833771705627], + [-2.4091389179229736, -2.326857566833496], + [1.6610108613967896, -2.159703254699707], + [0.014884627424180508, 0.3887578248977661], + [0.029668325558304787, 1.8786455392837524], + [1.180362582206726, 2.699317216873169], + [1.821286678314209, -0.5960053205490112], + [-0.44835323095321655, 3.327436685562134], + [-0.3714401423931122, -2.1466753482818604], + [-1.1103475093841553, -2.4536871910095215], + [-0.39110705256462097, 0.6670510172843933], + [0.474752813577652, -1.1959707736968994], + [-0.013110585510730743, -2.52519154548645], + [-2.0836575031280518, -1.703289270401001], + [-1.1077687740325928, -0.1252644956111908], + [-0.4138077199459076, 1.1837692260742188], + [-1.977599024772644, 1.688241720199585], + [-1.659559965133667, -2.1387736797332764], + [0.03242531046271324, 0.6526556015014648], + [0.9127950072288513, 0.6099498867988586], + [-0.38478314876556396, 0.433487206697464], + [0.27454206347465515, -0.27719801664352417], + [0.10388526320457458, 2.2812814712524414], + [-0.014394169673323631, -3.177137613296509], + [-1.2871228456497192, -0.8961855173110962], + [0.5720916986465454, -0.921597957611084], + [1.1159656047821045, -0.7609877586364746], + [2.4383342266082764, -2.2983546257019043], + [-0.294057160615921, -0.9770799875259399], + [-0.9342701435089111, 1.107579231262207], + [-1.549338698387146, 3.090520143508911], + [2.6076579093933105, 2.051239013671875], + [-0.9259037375450134, 1.407211184501648], + [-0.1747353971004486, 0.540488600730896], + [-0.8963701725006104, 0.8271111249923706], + [0.6480194926261902, 1.0128909349441528], + [0.980783998966217, -0.06156221032142639], + [-0.16883476078510284, 1.0601658821105957], + [0.5839992761611938, 0.004697148688137531], + [-0.34228450059890747, -1.2423977851867676], + [2.500824451446533, 0.3665279746055603], + [-0.17641609907150269, 1.3529551029205322], + [0.05378641560673714, 2.817232847213745], + [-1.2391047477722168, 2.354328155517578], + [0.630434513092041, -0.668536365032196], + [1.7576488256454468, 0.6738647818565369], + [0.4435231387615204, 0.6000469326972961], + [-0.08794835954904556, -0.11511358618736267], + [1.6540337800979614, 0.33995017409324646], + [-0.04202975332736969, -0.5375117063522339], + [-0.4247745871543884, -0.7897617220878601], + [0.06695003807544708, 1.2000739574432373], + [-3.2508881092071533, 0.28734830021858215], + [-1.613816261291504, 0.4944162368774414], + 
[1.3598989248275757, 0.26117825508117676], + [2.308382511138916, 1.3462618589401245], + [-1.2137469053268433, -1.9254342317581177], + [-0.4889402985572815, 1.8136259317398071], + [-0.1870335340499878, -0.3480615019798279], + [1.0766386985778809, -1.0627082586288452], + [0.4651014506816864, 2.131748914718628], + [-0.1306295394897461, -0.7811847925186157], + [0.06433182954788208, -1.5397958755493164], + [-0.2894323468208313, -0.5789554715156555], + [-0.6081662178039551, 0.4845278263092041], + [2.697964668273926, -0.18515698611736298], + [0.1277363896369934, -0.7221432328224182], + [0.8700758218765259, 0.35042452812194824], + [0.22088994085788727, 0.495242178440094], + [-2.5843818187713623, -0.8000828623771667], + [0.6732649803161621, -1.4362232685089111], + [-1.5286413431167603, 1.0417330265045166], + [-1.1222513914108276, -0.6269875764846802], + [-0.9752035140991211, -0.8750635385513306], + [-2.6369473934173584, 0.6918523907661438], + [0.14478731155395508, -0.041986867785453796], + [-1.5629483461380005, 1.4369450807571411], + [0.38952457904815674, -2.16428804397583], + [-0.16885095834732056, 0.7976621985435486], + [-3.12416934967041, 1.256506085395813], + [0.6843105554580688, -0.4203019142150879], + [1.9345275163650513, 1.934950351715088], + [0.012184220366179943, -2.1080918312072754], + [-0.6350273489952087, 0.7358828186988831], + [-0.837304949760437, -0.6214472651481628], + [0.08211923390626907, -0.9472538232803345], + [2.9332995414733887, -1.4956780672073364], + [1.3806978464126587, -0.2916182279586792], + [0.06773144006729126, 0.9285762310028076], + [-1.1943119764328003, 1.5963770151138306], + [1.6395620107650757, -0.32285431027412415], + [-1.390851378440857, -0.08273141086101532], + [1.816330909729004, -1.2812227010726929], + [0.7921574711799622, -2.1135804653167725], + [0.5817914605140686, 1.2644577026367188], + [1.929347038269043, -0.2386285960674286], + [0.8877345323562622, 1.190008521080017], + [1.4732073545455933, 0.8935023546218872], + [-2.8518524169921875, -1.5478795766830444], + [0.2439267635345459, 0.7576767802238464], + [0.5246709585189819, -2.606659412384033], + [1.150876760482788, 1.4073830842971802], + [-0.2643202245235443, 2.0634236335754395], + [1.555483341217041, -0.0023102816194295883], + [2.0830578804016113, -1.7225427627563477], + [-0.5424830317497253, -1.070199728012085], + [0.9168899655342102, 0.8955540060997009], + [-0.8120972514152527, 2.696739912033081], + [-0.29908373951911926, -1.5310651063919067], + [1.2320337295532227, -1.556247353553772], + [1.8612544536590576, 0.08704725652933121], + [0.22133447229862213, -1.8091708421707153], + [-0.4403655230998993, -0.38571012020111084], + [-1.88539457321167, 1.192205786705017], + [2.239687919616699, 0.004709010478109121], + [1.139495611190796, 0.45733731985092163], + [-1.507995367050171, 0.19716016948223114], + [0.46986445784568787, 1.5422041416168213], + [-1.2573751211166382, -0.35984551906585693], + [-1.7415345907211304, -0.6020717024803162], + [1.0751984119415283, 0.19006384909152985], + [2.24186635017395, -0.46343153715133667], + [0.3610347509384155, -0.07658443599939346], + [-1.3111497163772583, 0.432013601064682], + [0.6164408326148987, 0.24538464844226837], + [-1.9266542196273804, -0.3256155550479889], + [-0.5870336890220642, -0.1879584938287735], + [-1.0476511716842651, 0.3677721917629242], + [-1.229940414428711, 1.2433830499649048], + [0.18550436198711395, 0.22753673791885376], + [-0.017921989783644676, 0.12625974416732788], + [1.1659504175186157, -0.5020995736122131], + [-0.5983408093452454, 
-1.40438973903656], + [0.7519024014472961, -0.16282692551612854], + [0.9920787811279297, -1.344896912574768], + [-0.8103678226470947, 0.3064485788345337], + [0.6956969499588013, 1.8208192586898804], + [-2.7830491065979004, -0.2299390584230423], + [-0.34681546688079834, 2.4890666007995605], + [-1.4452646970748901, -1.2216600179672241], + [-2.1872897148132324, 0.8926076292991638], + [1.706072211265564, -2.8440372943878174], + [1.1119003295898438, -2.4923460483551025], + [-2.582794666290283, 2.0973289012908936], + [0.04987720400094986, -0.2964983284473419], + [-2.063807487487793, -0.7847916483879089], + [-0.4068813621997833, 0.9135897755622864], + [-0.9814359545707703, -0.3874954879283905], + [-1.4227229356765747, 0.7337291240692139], + [0.3065044581890106, 1.3125417232513428], + [1.2160996198654175, -1.9643305540084839], + [-1.2163853645324707, 0.14608727395534515], + [-2.3030710220336914, -0.37558120489120483], + [0.9232977628707886, 2.1843791007995605], + [-0.1989777386188507, 1.651851773262024], + [-0.714374840259552, -0.39365994930267334], + [-0.7805715799331665, -2.099881887435913], + [0.9015759229660034, -1.7053706645965576], + [0.1033422127366066, 1.5256654024124146], + [-1.8773194551467896, 2.324174165725708], + [1.9227174520492554, 2.7441604137420654], + [-0.5994020104408264, 0.23984014987945557], + [1.3496100902557373, -0.9126054644584656], + [-0.8765304088592529, -3.1877026557922363], + [-1.2040035724639893, -1.5169521570205688], + [1.4261796474456787, 2.150200128555298], + [1.463774561882019, 1.6656692028045654], + [0.20364105701446533, -0.4988172650337219], + [0.5195154547691345, -0.24067887663841248], + [-1.1116786003112793, -1.1599653959274292], + [-0.8490808606147766, -0.1681060940027237], + [0.3189965784549713, -0.9641751646995544], + [-0.5664751529693604, -0.5951744318008423], + [-1.6347930431365967, -0.9137664437294006], + [0.44048091769218445, -0.47259435057640076], + [-2.147747039794922, 0.47442489862442017], + [1.834734320640564, 1.4462147951126099], + [1.1777573823928833, 1.0659226179122925], + [-0.9568989872932434, 0.09495053440332413], + [-1.838529348373413, 0.2950586676597595], + [-0.4800611734390259, 0.014894310384988785], + [-0.5235516428947449, -1.7687653303146362], + [2.0735011100769043, -0.8825281262397766], + [2.637502431869507, 0.8455678224563599], + [2.606602907180786, -0.7848446369171143], + [-1.1886937618255615, 0.9330510497093201], + [0.38082656264305115, 0.13328030705451965], + [0.6847941875457764, 0.7384101152420044], + [1.2638574838638306, -0.007309418171644211], + [0.18292222917079926, -1.22371244430542], + [0.8143821954727173, 1.4976691007614136], + [0.6571850776672363, 0.48368802666664124], + [-0.6991601586341858, 2.150190830230713], + [0.8101756572723389, 0.10206498205661774], + [-0.08768226951360703, -1.084917664527893], + [-0.7208092212677002, 0.03657956421375275], + [0.3211449086666107, 1.803687334060669], + [-0.7835946083068848, 1.6869111061096191], + ] ) else: raise NotImplementedError(f"Unsupported p={p}, n={n}") -def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): +def quantize_with_higgs(weight: torch.Tensor, bits: int = 4, p: int = 2): assert len(weight.shape) == 2, "Only 2D weights are supported for now" assert weight.device.type == "cuda", "Only CUDA devices are supported for now" - - grid = get_higgs_grid(p, 2**(p * bits)).to(weight.device) + + grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device) grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2 - + device = weight.device weight = weight.clone().float() # Pad 
to Hadamard transform size weight = pad_to_block(weight, [1], 1024) - + # Scale and Hadamard transform mult = weight.shape[1] // 1024 weight = weight.reshape(-1, mult, 1024) scales = torch.linalg.norm(weight, axis=-1) weight = hadamard_transform(weight, 1) / scales[:, :, None] - + # Pad to edenn_d and project weight = pad_to_block(weight, [2], p).reshape(weight.shape[0], mult, -1, p) # Quantize codes = torch.empty(weight.shape[:-1], device=device, dtype=torch.uint8) for i in range(0, weight.shape[0], 64): - codes[i:i+64] = torch.argmax( - 2 * weight[i:i+64] @ grid.T - grid_norm_2, dim=-1 - ).to(torch.uint8) + codes[i : i + 64] = torch.argmax(2 * weight[i : i + 64] @ grid.T - grid_norm_2, dim=-1).to(torch.uint8) del weight - + codes = codes.reshape(codes.shape[0], -1) scales = scales / 32 @@ -349,7 +355,7 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int=4, p: int=2): dtype=torch.float16, device=device, ) - + return { "weight": weight, "scales": scales, @@ -366,8 +372,8 @@ def __init__( num_bits: int, num_sms_packed: int, bias=True, - dtype: torch.dtype=None, - device: torch.device=None, + dtype: torch.dtype = None, + device: torch.device = None, ): super().__init__() self.in_features = in_features @@ -377,27 +383,34 @@ def __init__( assert in_features % 256 == 0 assert num_bits in [2, 3, 4] - - self.weight = nn.Parameter(torch.empty((in_features * num_bits // 16, out_features), dtype=torch.int16, device=device), requires_grad=False) - self.scales = nn.Parameter(torch.empty((out_features, in_features//256), dtype=dtype, device=device), requires_grad=False) + + self.weight = nn.Parameter( + torch.empty((in_features * num_bits // 16, out_features), dtype=torch.int16, device=device), + requires_grad=False, + ) + self.scales = nn.Parameter( + torch.empty((out_features, in_features // 256), dtype=dtype, device=device), requires_grad=False + ) self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False) - self.tables2 = nn.Parameter(torch.empty((2**num_bits, 2**num_bits, 1), dtype=torch.float32, device=device), requires_grad=False) - + self.tables2 = nn.Parameter( + torch.empty((2**num_bits, 2**num_bits, 1), dtype=torch.float32, device=device), requires_grad=False + ) + if bias: self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype), requires_grad=False) else: self.register_parameter("bias", None) - - self.workspace = None # must be set externally to be reused among layers - + + self.workspace = None # must be set externally to be reused among layers + def forward(self, x): x = pad_to_block(x, [-1], 1024) - + orig_shape = x.shape x = x.reshape(-1, 1024) - x = hadamard_transform(x, scale=1/32) + x = hadamard_transform(x, scale=1 / 32) x = x.reshape(orig_shape) - + return flute.qgemm_simple( x, self.weight, @@ -438,9 +451,11 @@ def replace_with_higgs_linear( """ if not is_flute_available(): raise ValueError("FLUTE is not available. Please install it with `pip install flute-kernel`") - + if not is_hadamard_available(): - raise ValueError("Fast Hadamard Transform is not available. Please install it with `pip install fast_hadamard_transform`") + raise ValueError( + "Fast Hadamard Transform is not available. 
Please install it with `pip install fast_hadamard_transform`" + ) if not is_accelerate_available(): raise ValueError( diff --git a/src/transformers/quantizers/auto.py b/src/transformers/quantizers/auto.py index d28412c7f34b94..80ed2f8562488b 100755 --- a/src/transformers/quantizers/auto.py +++ b/src/transformers/quantizers/auto.py @@ -23,8 +23,8 @@ CompressedTensorsConfig, EetqConfig, FbgemmFp8Config, - HiggsConfig, GPTQConfig, + HiggsConfig, HqqConfig, QuantizationConfigMixin, QuantizationMethod, @@ -39,8 +39,8 @@ from .quantizer_compressed_tensors import CompressedTensorsHfQuantizer from .quantizer_eetq import EetqHfQuantizer from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer -from .quantizer_higgs import HiggsHfQuantizer from .quantizer_gptq import GptqHfQuantizer +from .quantizer_higgs import HiggsHfQuantizer from .quantizer_hqq import HqqHfQuantizer from .quantizer_quanto import QuantoHfQuantizer from .quantizer_torchao import TorchAoHfQuantizer diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 5539f1c4dc9985..4fd69da7ae38e4 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -11,10 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import importlib -from typing import TYPE_CHECKING, Optional, Dict, List, Any - -from packaging import version +from typing import TYPE_CHECKING, Any, Dict, List, Optional from .base import HfQuantizer from .quantizers_utils import get_module_from_name @@ -23,14 +20,14 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..integrations import HiggsLinear, replace_with_higgs_linear, quantize_with_higgs -from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging -from ..utils.quantization_config import QuantizationConfigMixin - - # if is_torch_available(): import torch +from ..integrations import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear +from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, logging +from ..utils.quantization_config import QuantizationConfigMixin + + logger = logging.get_logger(__name__) @@ -63,9 +60,11 @@ def validate_environment(self, *args, **kwargs): if not is_flute_available(): raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel`") - + if not is_hadamard_available(): - raise ImportError("Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`") + raise ImportError( + "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`" + ) def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: @@ -75,9 +74,11 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": "CUDA available. Assuming HIGGS inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually." ) else: - raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.") + raise NotImplementedError( + "HIGGS quantization is only supported on GPU. Please use a different quantizer." 
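(Aside on the quantize_with_higgs hunk earlier in this patch: the argmax over 2 * weight @ grid.T - grid_norm_2 is plain nearest-neighbour assignment to the HIGGS grid. Expanding ||w - g||^2 = ||w||^2 - 2<w, g> + ||g||^2, the ||w||^2 term is constant per row and can be dropped, so minimising the distance is the same as maximising 2<w, g> - ||g||^2. Likewise, the 1/32 factor around the Hadamard transform is 1/sqrt(1024), the orthonormal scaling for 1024-wide blocks. A minimal, self-contained check of the argmax identity on random data; nothing here is HIGGS-specific:

import torch

d, n_rows, n_codes = 2, 1024, 256
w = torch.randn(n_rows, d)
grid = torch.randn(n_codes, d)

# Nearest-neighbour codes via explicit pairwise distances ...
codes_ref = torch.cdist(w, grid).argmin(dim=-1)
# ... and via the expanded form used in quantize_with_higgs (||w||^2 dropped).
codes_fast = torch.argmax(2 * w @ grid.T - torch.linalg.norm(grid, dim=-1) ** 2, dim=-1)

assert torch.equal(codes_ref, codes_fast)  # identical up to exact ties, which random floats avoid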
+ ) return torch_dtype - + def create_quantized_param( self, model: "PreTrainedModel", @@ -91,19 +92,18 @@ def create_quantized_param( Quantizes weights into weight and weight_scale """ import flute.utils - + if target_device not in model.flute_workspaces: model.flute_workspaces[target_device] = flute.utils.make_workspace_streamk(device=target_device) - flute_dict = quantize_with_higgs( param_value.to(target_device), self.quantization_config.bits, self.quantization_config.p, ) - + del param_value - + module, tensor_name = get_module_from_name(model, param_name) for key, value in flute_dict.items(): if key in module._parameters: @@ -140,7 +140,7 @@ def is_trainable(self, model: Optional["PreTrainedModel"] = None): def is_serializable(self, safe_serialization=None): return True - + def check_quantized_param( self, model: "PreTrainedModel", @@ -149,8 +149,6 @@ def check_quantized_param( state_dict: Dict[str, Any], **kwargs, ) -> bool: - import bitsandbytes as bnb - module, tensor_name = get_module_from_name(model, param_name) if isinstance(module, HiggsLinear) and tensor_name == "weight": # Add here check for loaded components' dtypes once serialization is implemented diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 30ede3724939c8..25d837ccec0fbe 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -75,14 +75,12 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flax_available, - is_flute_available, is_fsdp_available, is_ftfy_available, is_g2p_en_available, is_galore_torch_available, is_gguf_available, is_grokadamw_available, - is_hadamard_available, is_ipex_available, is_jieba_available, is_jinja_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 95ba54a8d81113..5ea304d28bbd48 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -102,7 +102,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _bitsandbytes_available = _is_package_available("bitsandbytes") _eetq_available = _is_package_available("eetq") _fbgemm_gpu_available = _is_package_available("fbgemm_gpu") -_flute_available = True # _is_package_available("flute") +try: + _flute_available = package_exists = ( + importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") == "0.2.6" + ) +except importlib.metadata.PackageNotFoundError: + _flute_available = False _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") _grokadamw_available = _is_package_available("grokadamw") @@ -330,7 +335,7 @@ def is_torch_deterministic(): return False else: return True - + def is_hadamard_available(): return _hadamard_available diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 1836d79ac3b0eb..833ae05bc36765 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1223,7 +1223,7 @@ def get_loading_attributes(self): loading_attibutes = ["activation_scale_ub"] loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} return loading_attibutes_dict - + @dataclass class HiggsConfig(QuantizationConfigMixin): @@ -1256,7 +1256,7 @@ def __init__( self.bits = bits self.p = p self.linear_weights_not_to_quantize = linear_weights_not_to_quantize - self.num_sms_packed=128 + self.num_sms_packed = 128 
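(Aside on the flute availability check added to import_utils.py above: gating on an exact flute-kernel == 0.2.6 match means any other release, including a compatible bugfix release, silently reports FLUTE as unavailable. If a minimum-version gate were wanted instead, a sketch along these lines would do it; the bound below is an assumption, not something this patch establishes:

import importlib.metadata
import importlib.util

from packaging import version


def _flute_kernel_at_least(min_version: str = "0.2.6") -> bool:
    # FLUTE is importable as `flute` but distributed on PyPI as `flute-kernel`.
    if importlib.util.find_spec("flute") is None:
        return False
    try:
        installed = version.parse(importlib.metadata.version("flute-kernel"))
    except importlib.metadata.PackageNotFoundError:
        return False
    return installed >= version.parse(min_version)
)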
self.post_init() From 0ff58c372063ffa5ad2efa9f1b10f9fa9a424d63 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 12:51:32 +0100 Subject: [PATCH 05/25] style 2 --- src/transformers/integrations/higgs.py | 2 +- src/transformers/utils/quantization_config.py | 22 ++++++++----------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index d289ef4b5799c2..ed9c64a951a34f 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -463,7 +463,7 @@ def replace_with_higgs_linear( ) if linear_weights_not_to_quantize is None: - linear_weights_not_to_quantize = ["lm_head.weight"] + linear_weights_not_to_quantize = [] from accelerate import init_empty_weights diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 833ae05bc36765..d8080434b296e5 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1228,21 +1228,15 @@ def get_loading_attributes(self): @dataclass class HiggsConfig(QuantizationConfigMixin): """ - This is a wrapper class about `aqlm` parameters. + HiggsConfig is a configuration class for quantization using the HIGGS method. Args: - in_group_size (`int`, *optional*, defaults to 8): - The group size along the input dimension. - out_group_size (`int`, *optional*, defaults to 1): - The group size along the output dimension. It's recommended to always use 1. - num_codebooks (`int`, *optional*, defaults to 1): - Number of codebooks for the Additive Quantization procedure. - nbits_per_codebook (`int`, *optional*, defaults to 16): - Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook. - linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*): - List of full paths of `nn.Linear` weight parameters that shall not be quantized. - kwargs (`Dict[str, Any]`, *optional*): - Additional parameters from which to initialize the configuration object. + bits (int, *optional*, defaults to 4): + Number of bits to use for quantization. Default is 4. + p (int, *optional*, defaults to 2): + Parameter for the HIGGS quantization method. Default is 2. + linear_weights_not_to_quantize (`list`, *optional*, default to ["lm_head.weight"]): + List of linear weight names that should not be quantized. 
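For orientation, a minimal end-to-end sketch of how this configuration is meant to be used. The checkpoint is the one exercised by the tests added later in this series; any causal LM works, and a CUDA device with flute-kernel and fast_hadamard_transform installed is assumed:

from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model_id = "meta-llama/Meta-Llama-3.1-8B"  # example checkpoint, as in the tests below
quantization_config = HiggsConfig(bits=4, p=2)  # lm_head.weight stays unquantized by default

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    quantization_config=quantization_config,
)

inputs = tokenizer("A quick brown fox jumps over the", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=2)[0], skip_special_tokens=True))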
""" def __init__( @@ -1252,6 +1246,8 @@ def __init__( linear_weights_not_to_quantize: Optional[List[str]] = None, **kwargs, ): + if linear_weights_not_to_quantize is None: + linear_weights_not_to_quantize = ["lm_head.weight"] self.quant_method = QuantizationMethod.HIGGS self.bits = bits self.p = p From b6bad710609350dd3d6b028b14fff0cddf38d02d Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 13:41:45 +0100 Subject: [PATCH 06/25] tests and style --- src/transformers/integrations/higgs.py | 15 +- .../quantizers/quantizer_higgs.py | 1 - src/transformers/testing_utils.py | 9 + tests/quantization/higgs/test_higgs.py | 300 ++++++++++++++++++ 4 files changed, 314 insertions(+), 11 deletions(-) create mode 100644 tests/quantization/higgs/test_higgs.py diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index ed9c64a951a34f..be3860103312dd 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -426,7 +426,6 @@ def forward(self, x): def replace_with_higgs_linear( model, quantization_config=None, - linear_weights_not_to_quantize=None, current_key_name=None, has_been_replaced=False, ): @@ -440,9 +439,6 @@ def replace_with_higgs_linear( The model to convert, can be any `torch.nn.Module` instance. quantization_config (`HiggsConfig`): The quantization config object that contains the quantization parameters. - linear_weights_not_to_quantize (`list[str]`, *optional*): - A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`), the corresponding module will not be - converted. current_key_name (`list`, *optional*): A list that contains the current key name. This is used for recursion and should not be passed by the user. 
has_been_replaced (`bool`, *optional*): @@ -462,9 +458,6 @@ def replace_with_higgs_linear( f"HIGGS requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" ) - if linear_weights_not_to_quantize is None: - linear_weights_not_to_quantize = [] - from accelerate import init_empty_weights for name, module in model.named_children(): @@ -473,8 +466,11 @@ def replace_with_higgs_linear( current_key_name.append(name) if isinstance(module, nn.Linear): - # Check if the current key is not in the `linear_weights_not_to_quantize` - if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize: + # Check if the current key is not in the `quantization_config.linear_weights_not_to_quantize` + current_key_name_str = ".".join(current_key_name) + ".weight" + if not any( + current_key_name_str.endswith(key) for key in quantization_config.linear_weights_not_to_quantize + ): with init_empty_weights(): in_features = module.in_features out_features = module.out_features @@ -496,7 +492,6 @@ def replace_with_higgs_linear( _, has_been_replaced = replace_with_higgs_linear( module, quantization_config=quantization_config, - linear_weights_not_to_quantize=linear_weights_not_to_quantize, current_key_name=current_key_name, has_been_replaced=has_been_replaced, ) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 4fd69da7ae38e4..9903e08ad0b717 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -126,7 +126,6 @@ def _process_model_before_weight_loading( replace_with_higgs_linear( model, quantization_config=self.quantization_config, - linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize, ) model.config.quantization_config = self.quantization_config model.flute_workspaces = {} diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 25d837ccec0fbe..6581248d6b4133 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -75,12 +75,14 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flax_available, + is_flute_available, is_fsdp_available, is_ftfy_available, is_g2p_en_available, is_galore_torch_available, is_gguf_available, is_grokadamw_available, + is_hadamard_available, is_ipex_available, is_jieba_available, is_jinja_available, @@ -1227,6 +1229,13 @@ def require_fbgemm_gpu(test_case): return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) +def require_flute_hadamard(test_case): + """ + Decorator marking a test that requires higgs and hadamard + """ + return unittest.skipUnless(is_flute_available() and is_hadamard_available(), "test requires aqlm")(test_case) + + def require_phonemizer(test_case): """ Decorator marking a test that requires phonemizer diff --git a/tests/quantization/higgs/test_higgs.py b/tests/quantization/higgs/test_higgs.py new file mode 100644 index 00000000000000..da89deee197892 --- /dev/null +++ b/tests/quantization/higgs/test_higgs.py @@ -0,0 +1,300 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HiggsConfig, OPTForCausalLM +from transformers.testing_utils import ( + require_accelerate, + require_flute_hadamard, + require_torch_gpu, + require_torch_multi_gpu, + slow, + torch_device, +) +from transformers.utils import is_accelerate_available, is_torch_available + + +if is_torch_available(): + import torch + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +# @require_torch_gpu +# class HiggsConfigTest(unittest.TestCase): +# def test_to_dict(self): +# """ +# Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object +# """ +# quantization_config = HiggsConfig() +# config_to_dict = quantization_config.to_dict() + +# for key in config_to_dict: +# self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) + +# def test_from_dict(self): +# """ +# Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict +# """ +# dict = {"linear_weights_not_to_quantize": ["embed_tokens.weight", "lm_head.weight"], "quant_method": "higgs"} +# quantization_config = HiggsConfig.from_dict(dict) + +# self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize) +# self.assertEqual(dict["quant_method"], quantization_config.quant_method) + + +@slow +@require_torch_gpu +@require_flute_hadamard +@require_accelerate +# @require_read_token +class HiggsTest(unittest.TestCase): + model_name = "meta-llama/Meta-Llama-3.1-8B" + + input_text = "What are we having for dinner?" 
+ max_new_tokens = 9 + + EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad" + + device_map = "cuda" + + offload_device_map = { + "model.embed_tokens": 0, + "model.layers.0": 0, + "model.layers.1": 0, + "model.layers.2": 0, + "model.layers.3": 0, + "model.layers.4": 0, + "model.layers.5": 0, + "model.layers.6": 0, + "model.layers.7": 0, + "model.layers.8": 0, + "model.layers.9": 0, + "model.layers.10": 0, + "model.layers.11": 0, + "model.layers.12": 0, + "model.layers.13": 0, + "model.layers.14": 0, + "model.layers.15": 0, + "model.layers.16": "cpu", + "model.layers.17": "cpu", + "model.layers.18": "cpu", + "model.layers.19": "cpu", + "model.layers.20": "disk", + "model.layers.21": "disk", + "model.layers.22": "disk", + "model.layers.23": "disk", + "model.layers.24": "disk", + "model.layers.25": "disk", + "model.layers.26": "disk", + "model.layers.27": "disk", + "model.layers.28": "disk", + "model.layers.29": "disk", + "model.layers.30": "disk", + "model.layers.31": "disk", + "model.norm": "disk", + "lm_head": "disk", + } + + # called only once for all test in this class + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + quantization_config = HiggsConfig() + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, device_map=cls.device_map, quantization_config=quantization_config + ) + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + def test_quantized_model_conversion(self): + """ + Simple test that checks if the quantized model has been converted properly + """ + + from transformers.integrations import HiggsLinear, replace_with_higgs_linear + + model_id = "facebook/opt-350m" + config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5") + quantization_config = HiggsConfig() + + with init_empty_weights(): + model = OPTForCausalLM(config) + + nb_linears = 0 + for module in model.modules(): + if isinstance(module, torch.nn.Linear): + nb_linears += 1 + + model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config) + nb_fbgemm_linear = 0 + for module in model.modules(): + if isinstance(module, HiggsLinear): + nb_fbgemm_linear += 1 + + self.assertEqual(nb_linears - 1, nb_fbgemm_linear) + + with init_empty_weights(): + model = OPTForCausalLM(config) + quantization_config = HiggsConfig(linear_weights_not_to_quantize=["fc1.weight"]) + model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config) + nb_fbgemm_linear = 0 + for module in model.modules(): + if isinstance(module, HiggsLinear): + nb_fbgemm_linear += 1 + + self.assertEqual(nb_linears - 24, nb_fbgemm_linear) + + def test_quantized_model(self): + """ + Simple test that checks if the quantized model is working properly + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_save_pretrained(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + input_ids = self.tokenizer(self.input_text, 
return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_change_loading_attributes(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + quantization_config = HiggsConfig() + + model = AutoModelForCausalLM.from_pretrained( + tmpdirname, device_map=self.device_map, quantization_config=quantization_config + ) + + self.assertEqual(model.model.layers[1].mlp.down_proj.input_scale_ub.item(), 1000.0) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_quantized_model_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly with multiple GPUs + set CUDA_VISIBLE_DEVICES=0,1 if you have more than 2 GPUS + """ + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + quantization_config = HiggsConfig() + quantized_model = AutoModelForCausalLM.from_pretrained( + self.model_name, device_map="auto", quantization_config=quantization_config + ) + self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1}) + + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + def test_quantized_model_offload(self): + """ + Simple test that checks if the quantized model returns an error when loading with cpu/disk offloaded + """ + quantization_config = HiggsConfig() + + with self.assertRaisesRegex( + ValueError, "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device." 
+ ): + AutoModelForCausalLM.from_pretrained( + self.model_name, device_map=self.offload_device_map, quantization_config=quantization_config + ) + + def test_save_pretrained_offload(self): + """ + Simple test that checks if the saved quantized model is working properly cpu/disk offload + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map) + output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + @require_torch_multi_gpu + def test_save_pretrained_multi_gpu(self): + """ + Simple test that checks if the quantized model is working properly after being saved and loaded + """ + with tempfile.TemporaryDirectory() as tmpdirname: + self.quantized_model.save_pretrained(tmpdirname) + + model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto") + self.assertTrue(set(model.hf_device_map.values()) == {0, 1}) + + input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) + + output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) + + +# @require_torch_gpu +# @require_accelerate +# @require_flute_hadamard +# class HiggsLinearTest(unittest.TestCase): +# def test_linear_preserves_shape(self): +# """ +# Test that HiggsLinear preserves shape when in_features == out_features. +# """ +# from transformers.integrations import HiggsLinear + +# with init_empty_weights(include_buffers=True): +# linear = HiggsLinear(1024, 1024, num_bits=4, num_sms_packed=128, bias=True) +# x = torch.rand((17, 23, 1024)) + +# # x_ = linear(x) +# # self.assertEqual(x_.shape, x.shape) # TODO: Fix this + +# def test_linear_with_diff_feature_size_preserves_shape(self): +# """ +# Test that HiggsLinear generates the correct shape when in_features != out_features. 
+# """ +# from transformers.integrations import HiggsLinear + +# with init_empty_weights(include_buffers=True): +# linear = HiggsLinear(1024, 2048, num_bits=4, num_sms_packed=128, bias=True) +# x = torch.rand((17, 23, 1024)) + +# # x_ = linear(x) +# # self.assertEqual(x_.shape, (17, 23, 2048)) # TODO: Fix this From c2bcf39bb68e1b9b3cb518341110b59df1426f68 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 15:11:36 +0100 Subject: [PATCH 07/25] higgs tests passing --- .../Dockerfile | 3 +- src/transformers/integrations/higgs.py | 16 ++++-- .../quantizers/quantizer_higgs.py | 40 +++++++++----- tests/quantization/higgs/__init__.py | 0 tests/quantization/higgs/test_higgs.py | 52 ++----------------- 5 files changed, 44 insertions(+), 67 deletions(-) create mode 100644 tests/quantization/higgs/__init__.py diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index a7bf472033c51b..dabcaf103c2e4d 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -64,8 +64,9 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto # Add eetq for quantization testing RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git -# Add FLUTE for quantization testing +# Add flute-kernel and fast_hadamard_transform for quantization testing RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6 +RUN python3 -m pip install --no-cache-dir fast_hadamard_transform==1.0.4.post1 # When installing in editable mode, `transformers` is not recognized as a package. # this line must be added in order for python to be aware of transformers. diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index be3860103312dd..ea5d3b321a938b 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -317,7 +317,10 @@ def get_higgs_grid(p: int, n: int) -> torch.Tensor: def quantize_with_higgs(weight: torch.Tensor, bits: int = 4, p: int = 2): assert len(weight.shape) == 2, "Only 2D weights are supported for now" - assert weight.device.type == "cuda", "Only CUDA devices are supported for now" + if weight.device.type != "cuda": + raise ValueError( + "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device." 
+ ) grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device) grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2 @@ -360,7 +363,7 @@ def quantize_with_higgs(weight: torch.Tensor, bits: int = 4, p: int = 2): "weight": weight, "scales": scales, "tables": tables, - "tables2": tables2, + "tables2": tables2.view(dtype=torch.float16), } @@ -385,7 +388,7 @@ def __init__( assert num_bits in [2, 3, 4] self.weight = nn.Parameter( - torch.empty((in_features * num_bits // 16, out_features), dtype=torch.int16, device=device), + torch.empty((out_features * num_bits // 16, in_features), dtype=torch.int16, device=device), requires_grad=False, ) self.scales = nn.Parameter( @@ -393,7 +396,7 @@ def __init__( ) self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False) self.tables2 = nn.Parameter( - torch.empty((2**num_bits, 2**num_bits, 1), dtype=torch.float32, device=device), requires_grad=False + torch.empty((2**num_bits, 2**num_bits, 2), dtype=dtype, device=device), requires_grad=False ) if bias: @@ -411,12 +414,15 @@ def forward(self, x): x = hadamard_transform(x, scale=1 / 32) x = x.reshape(orig_shape) + if self.workspace is None: + raise Exception("Workspace must be set before calling forward") + return flute.qgemm_simple( x, self.weight, self.scales, self.tables, - self.tables2, + self.tables2.view(dtype=torch.float32), self.workspace, self.num_bits, 256, diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 9903e08ad0b717..80859833bbb434 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -91,11 +91,6 @@ def create_quantized_param( """ Quantizes weights into weight and weight_scale """ - import flute.utils - - if target_device not in model.flute_workspaces: - model.flute_workspaces[target_device] = flute.utils.make_workspace_streamk(device=target_device) - flute_dict = quantize_with_higgs( param_value.to(target_device), self.quantization_config.bits, @@ -107,17 +102,15 @@ def create_quantized_param( module, tensor_name = get_module_from_name(model, param_name) for key, value in flute_dict.items(): if key in module._parameters: - module._parameters[key] = value + module._parameters[key] = torch.nn.Parameter(value, requires_grad=False) elif key in module._buffers: - module._buffers[key] = value + module._buffers[key] = torch.nn.Buffer(value) else: raise ValueError(f"Unexpected key {key} in module {module}") if unexpected_keys is not None and param_name in unexpected_keys: unexpected_keys.remove(param_name) - module.workspace = model.flute_workspaces[target_device] - def _process_model_before_weight_loading( self, model: "PreTrainedModel", @@ -128,10 +121,33 @@ def _process_model_before_weight_loading( quantization_config=self.quantization_config, ) model.config.quantization_config = self.quantization_config - model.flute_workspaces = {} def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): - return model + import flute.utils + + flute_workspaces = {} + for name, module in model.named_modules(): + if isinstance(module, HiggsLinear): + if module.weight.device not in flute_workspaces: + flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk( + device=module.weight.device + ) + module.workspace = flute_workspaces[module.weight.device] + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + from ..integrations import HiggsLinear + + not_missing_keys 
= [] + for name, module in model.named_modules(): + if isinstance(module, HiggsLinear): + for missing in missing_keys: + if ( + (name in missing or name in f"{prefix}.{missing}") + and not missing.endswith(".weight") + and not missing.endswith(".bias") + ): + not_missing_keys.append(missing) + return [k for k in missing_keys if k not in not_missing_keys] @property def is_trainable(self, model: Optional["PreTrainedModel"] = None): @@ -149,7 +165,7 @@ def check_quantized_param( **kwargs, ) -> bool: module, tensor_name = get_module_from_name(model, param_name) - if isinstance(module, HiggsLinear) and tensor_name == "weight": + if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16: # Add here check for loaded components' dtypes once serialization is implemented return True else: diff --git a/tests/quantization/higgs/__init__.py b/tests/quantization/higgs/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/quantization/higgs/test_higgs.py b/tests/quantization/higgs/test_higgs.py index da89deee197892..41cb3c278e80df 100644 --- a/tests/quantization/higgs/test_higgs.py +++ b/tests/quantization/higgs/test_higgs.py @@ -67,10 +67,10 @@ class HiggsTest(unittest.TestCase): model_name = "meta-llama/Meta-Llama-3.1-8B" - input_text = "What are we having for dinner?" - max_new_tokens = 9 + input_text = "A quick brown fox jumps over the" + max_new_tokens = 2 - EXPECTED_OUTPUT = "What are we having for dinner?\nI'm having a steak and a salad" + EXPECTED_OUTPUT = "A quick brown fox jumps over the lazy dog" device_map = "cuda" @@ -190,26 +190,6 @@ def test_save_pretrained(self): output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - def test_change_loading_attributes(self): - """ - Simple test that checks if the quantized model is working properly after being saved and loaded - """ - with tempfile.TemporaryDirectory() as tmpdirname: - self.quantized_model.save_pretrained(tmpdirname) - - quantization_config = HiggsConfig() - - model = AutoModelForCausalLM.from_pretrained( - tmpdirname, device_map=self.device_map, quantization_config=quantization_config - ) - - self.assertEqual(model.model.layers[1].mlp.down_proj.input_scale_ub.item(), 1000.0) - - input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) - - output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - @require_torch_multi_gpu def test_quantized_model_multi_gpu(self): """ @@ -226,32 +206,6 @@ def test_quantized_model_multi_gpu(self): output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - def test_quantized_model_offload(self): - """ - Simple test that checks if the quantized model returns an error when loading with cpu/disk offloaded - """ - quantization_config = HiggsConfig() - - with self.assertRaisesRegex( - ValueError, "You are attempting to load an FP8 model with a device_map that contains a CPU or disk device." 
- ): - AutoModelForCausalLM.from_pretrained( - self.model_name, device_map=self.offload_device_map, quantization_config=quantization_config - ) - - def test_save_pretrained_offload(self): - """ - Simple test that checks if the saved quantized model is working properly cpu/disk offload - """ - with tempfile.TemporaryDirectory() as tmpdirname: - self.quantized_model.save_pretrained(tmpdirname) - - input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device) - - quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.offload_device_map) - output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - @require_torch_multi_gpu def test_save_pretrained_multi_gpu(self): """ From a1e7b35cadbf23d1fb8df06109915b8e3a04e5cd Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 15:18:15 +0100 Subject: [PATCH 08/25] protecting torch import --- src/transformers/quantizers/quantizer_higgs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 80859833bbb434..875892d1a9ffd3 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -20,14 +20,14 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -# if is_torch_available(): -import torch - from ..integrations import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear -from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, logging +from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging from ..utils.quantization_config import QuantizationConfigMixin +if is_torch_available(): + import torch + logger = logging.get_logger(__name__) From 8f1a0a6a1e40bfa56c06fc092c42daadb32cdfd2 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 15:22:54 +0100 Subject: [PATCH 09/25] removed torch.Tensor type annotations --- src/transformers/integrations/higgs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index ea5d3b321a938b..9535931f27f4eb 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -49,7 +49,7 @@ def pad_to_block(tensor, dims, had_block_size, value=0): return nn.functional.pad(tensor, pad_dims, "constant", value) -def get_higgs_grid(p: int, n: int) -> torch.Tensor: +def get_higgs_grid(p: int, n: int): if (p, n) == (2, 256): return torch.tensor( [ @@ -315,7 +315,7 @@ def get_higgs_grid(p: int, n: int) -> torch.Tensor: raise NotImplementedError(f"Unsupported p={p}, n={n}") -def quantize_with_higgs(weight: torch.Tensor, bits: int = 4, p: int = 2): +def quantize_with_higgs(weight, bits: int = 4, p: int = 2): assert len(weight.shape) == 2, "Only 2D weights are supported for now" if weight.device.type != "cuda": raise ValueError( From 120f36005869f96179842f945ae4a8b0ddae8dd9 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 15:27:45 +0100 Subject: [PATCH 10/25] torch.nn.Module inheritance fix maybe --- src/transformers/integrations/higgs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 
9535931f27f4eb..86aef13274b0f1 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -24,7 +24,7 @@ if is_torch_available(): import torch - import torch.nn as nn + from torch import nn if is_flute_available(): @@ -367,7 +367,7 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2): } -class HiggsLinear(nn.Module): +class HiggsLinear(torch.nn.Module): def __init__( self, in_features: int, From fdb71a5dc7e90cfa6a77c026ae79c3dda3780a92 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 15:35:20 +0100 Subject: [PATCH 11/25] hide inputs inside quantizer calls --- src/transformers/quantizers/quantizer_higgs.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 875892d1a9ffd3..ae20831f3d60aa 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -20,7 +20,6 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel -from ..integrations import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging from ..utils.quantization_config import QuantizationConfigMixin @@ -88,6 +87,8 @@ def create_quantized_param( state_dict: Dict[str, Any], unexpected_keys: Optional[List[str]] = None, ): + from ..integrations import quantize_with_higgs + """ Quantizes weights into weight and weight_scale """ @@ -116,6 +117,8 @@ def _process_model_before_weight_loading( model: "PreTrainedModel", **kwargs, ): + from ..integrations import replace_with_higgs_linear + replace_with_higgs_linear( model, quantization_config=self.quantization_config, @@ -125,6 +128,8 @@ def _process_model_before_weight_loading( def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): import flute.utils + from ..integrations import HiggsLinear + flute_workspaces = {} for name, module in model.named_modules(): if isinstance(module, HiggsLinear): @@ -164,6 +169,8 @@ def check_quantized_param( state_dict: Dict[str, Any], **kwargs, ) -> bool: + from ..integrations import HiggsLinear + module, tensor_name = get_module_from_name(model, param_name) if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16: # Add here check for loaded components' dtypes once serialization is implemented From 127c5f0bc8983670d6d8232c893e33fdac26f9e7 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 15:39:03 +0100 Subject: [PATCH 12/25] style structure something --- src/transformers/integrations/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index a12552b10841d3..c2c13a5f85239f 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -54,7 +54,6 @@ ], "eetq": ["replace_with_eetq_linear"], "fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"], - "higgs": ["HiggsLinear", "quantize_with_higgs", "replace_with_higgs_linear"], "fsdp": ["is_fsdp_managed_module"], "ggml": [ "GGUF_CONFIG_MAPPING", @@ -64,6 +63,7 @@ "load_dequant_gguf_tensor", "load_gguf", ], + "higgs": ["HiggsLinear", "quantize_with_higgs", "replace_with_higgs_linear"], "hqq": ["prepare_for_hqq_linear"], "integration_utils": [ "INTEGRATION_TO_CALLBACK", From 
0de97f141f8870711ea3568d4148c00bedd90f6c Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 20:30:29 +0300 Subject: [PATCH 13/25] Update src/transformers/quantizers/quantizer_higgs.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/quantizers/quantizer_higgs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index ae20831f3d60aa..7883147798d60d 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -47,7 +47,6 @@ class HiggsHfQuantizer(HfQuantizer): requires_calibration = False requires_parameters_quantization = True required_packages = ["flute-kernel", "fast_hadamard_transform"] - optimum_quantizer = None def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): super().__init__(quantization_config, **kwargs) From 1f08cb0b54a6b36b8711b91c1531087cd7fcd7c4 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 19:58:08 +0100 Subject: [PATCH 14/25] reworked num_sms --- src/transformers/integrations/higgs.py | 5 +-- .../quantizers/quantizer_higgs.py | 43 +++++++++++++++++++ src/transformers/utils/quantization_config.py | 1 - 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 86aef13274b0f1..68bdc4c0bdc339 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -373,7 +373,6 @@ def __init__( in_features: int, out_features: int, num_bits: int, - num_sms_packed: int, bias=True, dtype: torch.dtype = None, device: torch.device = None, @@ -382,7 +381,8 @@ def __init__( self.in_features = in_features self.out_features = out_features self.num_bits = num_bits - self.num_sms_packed = num_sms_packed + + self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False) assert in_features % 256 == 0 assert num_bits in [2, 3, 4] @@ -486,7 +486,6 @@ def replace_with_higgs_linear( out_features, bias=module.bias is not None, num_bits=quantization_config.bits, - num_sms_packed=quantization_config.num_sms_packed, ) has_been_replaced = True diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index ae20831f3d60aa..f6ee7a29b3252a 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -39,6 +39,18 @@ def find_parent(model, name): return parent +def get_num_sms_from_device(device): + target_device_cc = torch.cuda.get_device_capability(device=device) + if target_device_cc == (8, 6): + return 84 + elif target_device_cc == (8, 0): + return 108 + elif target_device_cc == (8, 9): + return 128 + else: + raise NotImplementedError(f"Device capability {target_device_cc} not supported for FLUTE (yet?)") + + class HiggsHfQuantizer(HfQuantizer): """ Quantizer of the HIGGS method. Enables the loading of prequantized models. 
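For context on the get_num_sms_from_device helper added above: FLUTE packs quantized weights in a layout tied to the number of streaming multiprocessors (SMs) its kernels were built for, so the quantizer maps the device's compute capability to a fixed SM count instead of querying the hardware. The following is a standalone sketch, not part of this patch series, that reads the physical SM count from the CUDA runtime for comparison; note that a single capability can cover GPUs with different SM counts (several distinct 8.6 cards exist, for example), which is presumably why the table pins one value per capability.

import torch

def physical_num_sms(device: int = 0) -> int:
    # Actual SM count reported by the CUDA runtime for this GPU.
    return torch.cuda.get_device_properties(device).multi_processor_count

if torch.cuda.is_available():
    # An A100 reports capability (8, 0) and 108 SMs, matching the (8, 0) -> 108 entry above.
    print(torch.cuda.get_device_capability(0), physical_num_sms(0))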
@@ -112,6 +124,11 @@ def create_quantized_param( if unexpected_keys is not None and param_name in unexpected_keys: unexpected_keys.remove(param_name) + module.num_sms_packed = torch.nn.Parameter( + torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32), + requires_grad=False, + ) + def _process_model_before_weight_loading( self, model: "PreTrainedModel", @@ -133,12 +150,38 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs flute_workspaces = {} for name, module in model.named_modules(): if isinstance(module, HiggsLinear): + # Every HiggsLinear needs a "workspace": a buffer for the unpacking operation. + # This buffer needs to be on the same device as the weights, but can be reused across modules otherwise. if module.weight.device not in flute_workspaces: flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk( device=module.weight.device ) module.workspace = flute_workspaces[module.weight.device] + # FLUTE weights are packed in a way that is optimized for a specific number of SMs (GPU streaming multiprocessors). + # If the model is loaded on a different device than the one it was saved on, we need to repack the weights. + if module.num_sms_packed.item() != get_num_sms_from_device(module.weight.device): + new_device = module.weight.device + new_num_sms = get_num_sms_from_device(new_device) + module.weight.data = flute.utils.pack( + flute.utils.unpack( + weight=module.weight.data, + scales=module.scales.data, + workspace=module.workspace, + num_bits=module.num_bits, + group_size=256, + num_sms_packed=module.num_sms_packed.item(), + ) + .T.contiguous() + .cpu(), + module.num_bits, + 256, + ).to(device=new_device) + module.num_sms_packed = torch.nn.Parameter( + torch.tensor(new_num_sms, device=new_device, dtype=torch.int32), + requires_grad=False, + ) + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: from ..integrations import HiggsLinear diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index d8080434b296e5..ac5594f52f229e 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1252,7 +1252,6 @@ def __init__( self.bits = bits self.p = p self.linear_weights_not_to_quantize = linear_weights_not_to_quantize - self.num_sms_packed = 128 self.post_init() From 96023ab6a06d35a88dfb5fbe2fe00a46c33a4374 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Thu, 28 Nov 2024 22:18:59 +0300 Subject: [PATCH 15/25] Update src/transformers/integrations/higgs.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/integrations/higgs.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 68bdc4c0bdc339..899ec024cdde0c 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -451,18 +451,6 @@ def replace_with_higgs_linear( A boolean that indicates if the conversion has been successful or not. This is used for recursion and should not be passed by the user. """ - if not is_flute_available(): - raise ValueError("FLUTE is not available. Please install it with `pip install flute-kernel`") - - if not is_hadamard_available(): - raise ValueError( - "Fast Hadamard Transform is not available. 
Please install it with `pip install fast_hadamard_transform`" - ) - - if not is_accelerate_available(): - raise ValueError( - f"HIGGS requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`" - ) from accelerate import init_empty_weights From 60ce44be17f181a02398918f0598101d86e1036f Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 15:52:49 +0100 Subject: [PATCH 16/25] revamped device checks --- src/transformers/integrations/higgs.py | 6 ------ src/transformers/quantizers/quantizer_higgs.py | 14 +++++--------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 899ec024cdde0c..546ba016f0f58c 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -14,8 +14,6 @@ "HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file" from ..utils import ( - ACCELERATE_MIN_VERSION, - is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, @@ -317,10 +315,6 @@ def get_higgs_grid(p: int, n: int): def quantize_with_higgs(weight, bits: int = 4, p: int = 2): assert len(weight.shape) == 2, "Only 2D weights are supported for now" - if weight.device.type != "cuda": - raise ValueError( - "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device." - ) grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device) grid_norm_2 = torch.linalg.norm(grid, axis=-1) ** 2 diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index c672be679e65a8..e590635df4acf3 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -65,6 +65,9 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): self.quantization_config = quantization_config def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.") + if not is_accelerate_available(): raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`") @@ -78,15 +81,8 @@ def validate_environment(self, *args, **kwargs): def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: - if torch.cuda.is_available(): - torch_dtype = torch.float16 - logger.info( - "CUDA available. Assuming HIGGS inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually." - ) - else: - raise NotImplementedError( - "HIGGS quantization is only supported on GPU. Please use a different quantizer." - ) + torch_dtype = torch.float16 + return torch_dtype def create_quantized_param( From 81424439afb2bf04708ba170261494d13d5f5c5c Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 15:56:57 +0100 Subject: [PATCH 17/25] docstring upd --- src/transformers/utils/quantization_config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 05a5719a00e52d..5fc55b4f16b1ee 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1240,9 +1240,9 @@ class HiggsConfig(QuantizationConfigMixin): Args: bits (int, *optional*, defaults to 4): - Number of bits to use for quantization. 
Default is 4. + Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4. p (int, *optional*, defaults to 2): - Parameter for the HIGGS quantization method. Default is 2. + Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2. linear_weights_not_to_quantize (`list`, *optional*, default to ["lm_head.weight"]): List of linear weight names that should not be quantized. """ @@ -1267,7 +1267,8 @@ def post_init(self): r""" Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. """ - return + assert self.bits in [2, 3, 4], "bits must be 2, 3 or 4" + assert self.p in [1, 2], "p must be 1 or 2. 2 is always better in practice" @dataclass From 1d636ac06e922513417a821dbb8a27f93ed5b14e Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 18:06:47 +0300 Subject: [PATCH 18/25] Update src/transformers/quantizers/quantizer_higgs.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --- src/transformers/quantizers/quantizer_higgs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index e590635df4acf3..b87e413117154a 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -48,7 +48,7 @@ def get_num_sms_from_device(device): elif target_device_cc == (8, 9): return 128 else: - raise NotImplementedError(f"Device capability {target_device_cc} not supported for FLUTE (yet?)") + raise NotImplementedError(f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus") class HiggsHfQuantizer(HfQuantizer): From 66ece1d19e9ee24d838078f743b6da8bbba02c4e Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:10:21 +0100 Subject: [PATCH 19/25] edited tests and device map assertions --- .../quantizers/quantizer_higgs.py | 13 +- tests/quantization/higgs/test_higgs.py | 117 ++++-------------- 2 files changed, 36 insertions(+), 94 deletions(-) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index e590635df4acf3..8c3ceeaa5d74e9 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -64,7 +64,7 @@ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): super().__init__(quantization_config, **kwargs) self.quantization_config = quantization_config - def validate_environment(self, *args, **kwargs): + def validate_environment(self, device_map, **kwargs): if not torch.cuda.is_available(): raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.") @@ -79,6 +79,17 @@ def validate_environment(self, *args, **kwargs): "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`" ) + if device_map is None: + raise ValueError( + "You are attempting to load a HIGGS model without setting device_map." + " Please set device_map comprised of 'cuda' devices." + ) + elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()): + raise ValueError( + "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." 
+ ) + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: torch_dtype = torch.float16 diff --git a/tests/quantization/higgs/test_higgs.py b/tests/quantization/higgs/test_higgs.py index 41cb3c278e80df..9a42908d69e479 100644 --- a/tests/quantization/higgs/test_higgs.py +++ b/tests/quantization/higgs/test_higgs.py @@ -36,27 +36,27 @@ from accelerate import init_empty_weights -# @require_torch_gpu -# class HiggsConfigTest(unittest.TestCase): -# def test_to_dict(self): -# """ -# Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object -# """ -# quantization_config = HiggsConfig() -# config_to_dict = quantization_config.to_dict() +@require_torch_gpu +class HiggsConfigTest(unittest.TestCase): + def test_to_dict(self): + """ + Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object + """ + quantization_config = HiggsConfig() + config_to_dict = quantization_config.to_dict() -# for key in config_to_dict: -# self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) + for key in config_to_dict: + self.assertEqual(getattr(quantization_config, key), config_to_dict[key]) -# def test_from_dict(self): -# """ -# Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict -# """ -# dict = {"linear_weights_not_to_quantize": ["embed_tokens.weight", "lm_head.weight"], "quant_method": "higgs"} -# quantization_config = HiggsConfig.from_dict(dict) + def test_from_dict(self): + """ + Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict + """ + dict = {"linear_weights_not_to_quantize": ["embed_tokens.weight", "lm_head.weight"], "quant_method": "higgs"} + quantization_config = HiggsConfig.from_dict(dict) -# self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize) -# self.assertEqual(dict["quant_method"], quantization_config.quant_method) + self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize) + self.assertEqual(dict["quant_method"], quantization_config.quant_method) @slow @@ -74,44 +74,6 @@ class HiggsTest(unittest.TestCase): device_map = "cuda" - offload_device_map = { - "model.embed_tokens": 0, - "model.layers.0": 0, - "model.layers.1": 0, - "model.layers.2": 0, - "model.layers.3": 0, - "model.layers.4": 0, - "model.layers.5": 0, - "model.layers.6": 0, - "model.layers.7": 0, - "model.layers.8": 0, - "model.layers.9": 0, - "model.layers.10": 0, - "model.layers.11": 0, - "model.layers.12": 0, - "model.layers.13": 0, - "model.layers.14": 0, - "model.layers.15": 0, - "model.layers.16": "cpu", - "model.layers.17": "cpu", - "model.layers.18": "cpu", - "model.layers.19": "cpu", - "model.layers.20": "disk", - "model.layers.21": "disk", - "model.layers.22": "disk", - "model.layers.23": "disk", - "model.layers.24": "disk", - "model.layers.25": "disk", - "model.layers.26": "disk", - "model.layers.27": "disk", - "model.layers.28": "disk", - "model.layers.29": "disk", - "model.layers.30": "disk", - "model.layers.31": "disk", - "model.norm": "disk", - "lm_head": "disk", - } - # called only once for all test in this class @classmethod def setUpClass(cls): @@ -149,23 +111,23 @@ def test_quantized_model_conversion(self): nb_linears += 1 model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config) - 
nb_fbgemm_linear = 0 + nb_higgs_linear = 0 for module in model.modules(): if isinstance(module, HiggsLinear): - nb_fbgemm_linear += 1 + nb_higgs_linear += 1 - self.assertEqual(nb_linears - 1, nb_fbgemm_linear) + self.assertEqual(nb_linears - 1, nb_higgs_linear) with init_empty_weights(): model = OPTForCausalLM(config) quantization_config = HiggsConfig(linear_weights_not_to_quantize=["fc1.weight"]) model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config) - nb_fbgemm_linear = 0 + nb_higgs_linear = 0 for module in model.modules(): if isinstance(module, HiggsLinear): - nb_fbgemm_linear += 1 + nb_higgs_linear += 1 - self.assertEqual(nb_linears - 24, nb_fbgemm_linear) + self.assertEqual(nb_linears - 24, nb_higgs_linear) def test_quantized_model(self): """ @@ -221,34 +183,3 @@ def test_save_pretrained_multi_gpu(self): output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) - - -# @require_torch_gpu -# @require_accelerate -# @require_flute_hadamard -# class HiggsLinearTest(unittest.TestCase): -# def test_linear_preserves_shape(self): -# """ -# Test that HiggsLinear preserves shape when in_features == out_features. -# """ -# from transformers.integrations import HiggsLinear - -# with init_empty_weights(include_buffers=True): -# linear = HiggsLinear(1024, 1024, num_bits=4, num_sms_packed=128, bias=True) -# x = torch.rand((17, 23, 1024)) - -# # x_ = linear(x) -# # self.assertEqual(x_.shape, x.shape) # TODO: Fix this - -# def test_linear_with_diff_feature_size_preserves_shape(self): -# """ -# Test that HiggsLinear generates the correct shape when in_features != out_features. -# """ -# from transformers.integrations import HiggsLinear - -# with init_empty_weights(include_buffers=True): -# linear = HiggsLinear(1024, 2048, num_bits=4, num_sms_packed=128, bias=True) -# x = torch.rand((17, 23, 1024)) - -# # x_ = linear(x) -# # self.assertEqual(x_.shape, (17, 23, 2048)) # TODO: Fix this From 1cb9f0c14275b4c54cdc271c0e9b7f436ebfa7b9 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:13:02 +0100 Subject: [PATCH 20/25] minor edits --- src/transformers/quantizers/quantizer_higgs.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/transformers/quantizers/quantizer_higgs.py b/src/transformers/quantizers/quantizer_higgs.py index 6c7cd2662c181e..9e42069f57eecc 100644 --- a/src/transformers/quantizers/quantizer_higgs.py +++ b/src/transformers/quantizers/quantizer_higgs.py @@ -30,15 +30,6 @@ logger = logging.get_logger(__name__) -# Finds the parent of a node module named "name" -def find_parent(model, name): - module_tree = name.split(".")[:-1] - parent = model - for m in module_tree: - parent = parent._modules[m] - return parent - - def get_num_sms_from_device(device): target_device_cc = torch.cuda.get_device_capability(device=device) if target_device_cc == (8, 6): @@ -48,12 +39,14 @@ def get_num_sms_from_device(device): elif target_device_cc == (8, 9): return 128 else: - raise NotImplementedError(f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus") + raise NotImplementedError( + f"Device capability {target_device_cc} not supported for FLUTE (yet?) to verify your device capability check out https://developer.nvidia.com/cuda-gpus" + ) class HiggsHfQuantizer(HfQuantizer): """ - Quantizer of the HIGGS method. 
Enables the loading of prequantized models. + Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models. """ requires_calibration = False From 257c39b0d68203c9c763305d84ee82f5da853b37 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:15:54 +0100 Subject: [PATCH 21/25] updated flute cuda version in docker --- docker/transformers-quantization-latest-gpu/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/transformers-quantization-latest-gpu/Dockerfile b/docker/transformers-quantization-latest-gpu/Dockerfile index a917e6e8b4fe11..0e1ebd88adabbf 100755 --- a/docker/transformers-quantization-latest-gpu/Dockerfile +++ b/docker/transformers-quantization-latest-gpu/Dockerfile @@ -67,7 +67,7 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git # Add flute-kernel and fast_hadamard_transform for quantization testing -RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6 +RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6 -i https://flute-ai.github.io/whl/cu118 RUN python3 -m pip install --no-cache-dir fast_hadamard_transform==1.0.4.post1 # When installing in editable mode, `transformers` is not recognized as a package. From f82d1a387183bd55e16c34c5beeaaaf924023d81 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:22:46 +0100 Subject: [PATCH 22/25] Added p=1 and 2,3bit HIGGS --- src/transformers/integrations/higgs.py | 126 +++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 546ba016f0f58c..205ad6badb25bf 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -309,6 +309,132 @@ def get_higgs_grid(p: int, n: int): [-0.7835946083068848, 1.6869111061096191], ] ) + if (p, n) == (2, 64): + return torch.tensor( + [ + [-2.7216711044311523, 0.14431366324424744], + [-0.766914427280426, 1.7193410396575928], + [-2.2575762271881104, 1.2476624250411987], + [1.233758807182312, -2.3560616970062256], + [0.8701965808868408, -0.2649352252483368], + [1.4506438970565796, 2.1776366233825684], + [-0.06305818259716034, 1.9049758911132812], + [2.536226511001587, 0.563927412033081], + [0.4599496126174927, -1.8745561838150024], + [-1.900517225265503, -0.30703988671302795], + [0.09386251866817474, 0.8755807280540466], + [1.946500539779663, -0.6743080615997314], + [2.1338934898376465, 1.4581491947174072], + [0.9429940581321716, -0.8038390278816223], + [2.0697755813598633, -1.614896535873413], + [0.772676408290863, 0.22017823159694672], + [1.0689979791641235, -1.525044322013855], + [0.6813604831695557, 1.1345642805099487], + [0.4706456661224365, 2.606626272201538], + [-1.294018030166626, -0.4372096061706543], + [-0.09134224057197571, 0.4610418677330017], + [-0.7907772064208984, -0.48412787914276123], + [0.060459110885858536, -0.9172890186309814], + [-0.5855047702789307, 2.56172513961792], + [0.11484206467866898, -2.659848213195801], + [-1.5893300771713257, 2.188580274581909], + [1.6750942468643188, 0.7089915871620178], + [-0.445697546005249, 0.7452405095100403], + [-1.8539940118789673, -1.8377939462661743], + [-1.5791912078857422, -1.017285943031311], + [-1.030419945716858, -1.5746369361877441], + [-1.9511750936508179, 0.43696075677871704], + [-0.3446580767631531, -1.8953213691711426], + [-1.4219647645950317, 0.7676230669021606], + 
[-0.9191089272499084, 0.5021472573280334], + [0.20464491844177246, 1.3684605360031128], + [0.5402919054031372, 0.6699410676956177], + [1.8903915882110596, 0.03638288006186485], + [0.4723062515258789, -0.6216739416122437], + [-0.41345009207725525, -0.22752176225185394], + [2.7119064331054688, -0.5111885070800781], + [1.065286636352539, 0.6950305700302124], + [0.40629103779792786, -0.14339995384216309], + [1.2815024852752686, 0.17108257114887238], + [0.01785222627222538, -0.43778058886528015], + [0.054590027779340744, -1.4225547313690186], + [0.3076786696910858, 0.30697619915008545], + [-0.9498570561408997, -0.9576997756958008], + [-2.4640724658966064, -0.9660449028015137], + [1.3714425563812256, -0.39760473370552063], + [-0.4857747256755829, 0.2386789172887802], + [1.2797833681106567, 1.3097363710403442], + [0.5508887767791748, -1.1777795553207397], + [-1.384316325187683, 0.1465839296579361], + [-0.46556955575942993, -1.2442727088928223], + [-0.3915477693080902, -0.7319604158401489], + [-1.4005504846572876, 1.3890998363494873], + [-0.8647305965423584, 1.0617644786834717], + [-0.8901953101158142, -0.01650036871433258], + [-0.9893633723258972, -2.4662880897521973], + [1.445534110069275, -1.049334168434143], + [-0.041650623083114624, 0.012734669260680676], + [-0.3302375078201294, 1.26217782497406], + [0.6934980154037476, 1.7714335918426514], + ] + ) + elif (p, n) == (2, 16): + return torch.tensor( + [ + [-0.8996632695198059, -1.6360418796539307], + [-0.961183488368988, 1.5999565124511719], + [-1.882026195526123, 0.678778350353241], + [0.36300793290138245, -1.9667866230010986], + [-0.6814072728157043, -0.576818585395813], + [0.7270012497901917, 0.6186859607696533], + [0.3359416127204895, 1.8371193408966064], + [1.859930396080017, 0.036668598651885986], + [0.17208248376846313, -0.9401724338531494], + [-1.7599700689315796, -0.6244229674339294], + [-0.8993809223175049, 0.32267823815345764], + [0.839488685131073, -0.3017036020755768], + [1.5314953327178955, 1.2942044734954834], + [-0.0011779458727687597, 0.00022069070837460458], + [1.4274526834487915, -1.207889199256897], + [-0.16123905777931213, 0.8787511587142944], + ] + ) + elif (p, n) == (1, 16): + return torch.tensor( + [ + [-2.7325894832611084], + [-2.069017171859741], + [-1.6180464029312134], + [-1.2562311887741089], + [-0.9423404335975647], + [-0.6567591428756714], + [-0.38804829120635986], + [-0.12839503586292267], + [0.12839503586292267], + [0.38804829120635986], + [0.6567591428756714], + [0.9423404335975647], + [1.2562311887741089], + [1.6180464029312134], + [2.069017171859741], + [2.7325894832611084], + ] + ) + elif (p, n) == (1, 8): + return torch.tensor( + [ + [-2.1519455909729004], + [-1.3439092636108398], + [-0.7560052871704102], + [-0.2450941801071167], + [0.2450941801071167], + [0.7560052871704102], + [1.3439092636108398], + [2.1519455909729004], + ] + ) + elif (p, n) == (1, 4): + return torch.tensor([[-1.5104175806045532], [-0.4527800381183624], [0.4527800381183624], [1.5104175806045532]]) else: raise NotImplementedError(f"Unsupported p={p}, n={n}") From b74798044bb0dab9e2a0a5218384afbb0707fa1a Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:24:29 +0100 Subject: [PATCH 23/25] flute version check update --- src/transformers/utils/import_utils.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 5ea304d28bbd48..5cac0efd730e89 100755 --- a/src/transformers/utils/import_utils.py +++ 
b/src/transformers/utils/import_utils.py @@ -102,12 +102,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _bitsandbytes_available = _is_package_available("bitsandbytes") _eetq_available = _is_package_available("eetq") _fbgemm_gpu_available = _is_package_available("fbgemm_gpu") -try: - _flute_available = package_exists = ( - importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") == "0.2.6" - ) -except importlib.metadata.PackageNotFoundError: - _flute_available = False _galore_torch_available = _is_package_available("galore_torch") _lomo_available = _is_package_available("lomo_optim") _grokadamw_available = _is_package_available("grokadamw") @@ -614,7 +608,10 @@ def is_flax_available(): def is_flute_available(): - return _flute_available + try: + return importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") >= "0.2.6" + except importlib.metadata.PackageNotFoundError: + return False def is_ftfy_available(): From 398d5b187e3481fc5dec5d68f36436943ff1845e Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:29:36 +0100 Subject: [PATCH 24/25] incorporated `modules_to_not_convert` --- src/transformers/integrations/higgs.py | 8 +++----- src/transformers/utils/quantization_config.py | 12 ++++++------ tests/quantization/higgs/test_higgs.py | 6 +++--- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index 205ad6badb25bf..fae465c3b96765 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -580,11 +580,9 @@ def replace_with_higgs_linear( current_key_name.append(name) if isinstance(module, nn.Linear): - # Check if the current key is not in the `quantization_config.linear_weights_not_to_quantize` - current_key_name_str = ".".join(current_key_name) + ".weight" - if not any( - current_key_name_str.endswith(key) for key in quantization_config.linear_weights_not_to_quantize - ): + # Check if the current key is not in the `quantization_config.modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any(current_key_name_str.endswith(key) for key in quantization_config.modules_to_not_convert): with init_empty_weights(): in_features = module.in_features out_features = module.out_features diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 5fc55b4f16b1ee..d23e7ddea7d0c5 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1243,23 +1243,23 @@ class HiggsConfig(QuantizationConfigMixin): Number of bits to use for quantization. Can be 2, 3 or 4. Default is 4. p (int, *optional*, defaults to 2): Quantization grid dimension. 1 and 2 are supported. 2 is always better in practice. Default is 2. - linear_weights_not_to_quantize (`list`, *optional*, default to ["lm_head.weight"]): - List of linear weight names that should not be quantized. + modules_to_not_convert (`list`, *optional*, default to ["lm_head"]): + List of linear layers that should not be quantized. 
""" def __init__( self, bits: int = 4, p: int = 2, - linear_weights_not_to_quantize: Optional[List[str]] = None, + modules_to_not_convert: Optional[List[str]] = None, **kwargs, ): - if linear_weights_not_to_quantize is None: - linear_weights_not_to_quantize = ["lm_head.weight"] + if modules_to_not_convert is None: + modules_to_not_convert = ["lm_head"] self.quant_method = QuantizationMethod.HIGGS self.bits = bits self.p = p - self.linear_weights_not_to_quantize = linear_weights_not_to_quantize + self.modules_to_not_convert = modules_to_not_convert self.post_init() diff --git a/tests/quantization/higgs/test_higgs.py b/tests/quantization/higgs/test_higgs.py index 9a42908d69e479..fa524f276ce9ed 100644 --- a/tests/quantization/higgs/test_higgs.py +++ b/tests/quantization/higgs/test_higgs.py @@ -52,10 +52,10 @@ def test_from_dict(self): """ Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict """ - dict = {"linear_weights_not_to_quantize": ["embed_tokens.weight", "lm_head.weight"], "quant_method": "higgs"} + dict = {"modules_to_not_convert": ["embed_tokens", "lm_head"], "quant_method": "higgs"} quantization_config = HiggsConfig.from_dict(dict) - self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize) + self.assertEqual(dict["modules_to_not_convert"], quantization_config.modules_to_not_convert) self.assertEqual(dict["quant_method"], quantization_config.quant_method) @@ -120,7 +120,7 @@ def test_quantized_model_conversion(self): with init_empty_weights(): model = OPTForCausalLM(config) - quantization_config = HiggsConfig(linear_weights_not_to_quantize=["fc1.weight"]) + quantization_config = HiggsConfig(modules_to_not_convert=["fc1"]) model, _ = replace_with_higgs_linear(model, quantization_config=quantization_config) nb_higgs_linear = 0 for module in model.modules(): From 0ede69c2b9082692425ef0b0b45ae8ca020207c1 Mon Sep 17 00:00:00 2001 From: Andrei Panferov Date: Fri, 29 Nov 2024 16:40:18 +0100 Subject: [PATCH 25/25] less hardcoding --- src/transformers/integrations/higgs.py | 33 +++++++++++++++----------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/transformers/integrations/higgs.py b/src/transformers/integrations/higgs.py index fae465c3b96765..c2cbca483f0def 100644 --- a/src/transformers/integrations/higgs.py +++ b/src/transformers/integrations/higgs.py @@ -13,6 +13,8 @@ # limitations under the License. 
"HIGGS through FLUTE (Flexible Lookup Table Engine for LUT-quantized LLMs) integration file" +from math import sqrt + from ..utils import ( is_flute_available, is_hadamard_available, @@ -439,7 +441,7 @@ def get_higgs_grid(p: int, n: int): raise NotImplementedError(f"Unsupported p={p}, n={n}") -def quantize_with_higgs(weight, bits: int = 4, p: int = 2): +def quantize_with_higgs(weight, bits: int = 4, p: int = 2, group_size: int = 256, hadamard_size: int = 1024): assert len(weight.shape) == 2, "Only 2D weights are supported for now" grid = get_higgs_grid(p, 2 ** (p * bits)).to(weight.device) @@ -448,11 +450,11 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2): device = weight.device weight = weight.clone().float() # Pad to Hadamard transform size - weight = pad_to_block(weight, [1], 1024) + weight = pad_to_block(weight, [1], hadamard_size) # Scale and Hadamard transform - mult = weight.shape[1] // 1024 - weight = weight.reshape(-1, mult, 1024) + mult = weight.shape[1] // hadamard_size + weight = weight.reshape(-1, mult, hadamard_size) scales = torch.linalg.norm(weight, axis=-1) weight = hadamard_transform(weight, 1) / scales[:, :, None] @@ -466,14 +468,14 @@ def quantize_with_higgs(weight, bits: int = 4, p: int = 2): del weight codes = codes.reshape(codes.shape[0], -1) - scales = scales / 32 + scales = scales / sqrt(hadamard_size) weight, scales, tables, tables2 = prepare_data_transposed( codes, - torch.repeat_interleave(scales.half(), 1024 // 256, dim=1), + torch.repeat_interleave(scales.half(), hadamard_size // group_size, dim=1), grid.half(), num_bits=bits, - group_size=256, + group_size=group_size, vector_size=p, dtype=torch.float16, device=device, @@ -496,15 +498,18 @@ def __init__( bias=True, dtype: torch.dtype = None, device: torch.device = None, + group_size: int = 256, + hadamard_size: int = 1024, ): super().__init__() self.in_features = in_features self.out_features = out_features self.num_bits = num_bits - + self.group_size = group_size + self.hadamard_size = hadamard_size self.num_sms_packed = nn.Parameter(torch.tensor(-1, dtype=torch.int32, device=device), requires_grad=False) - assert in_features % 256 == 0 + assert in_features % group_size == 0 assert num_bits in [2, 3, 4] self.weight = nn.Parameter( @@ -512,7 +517,7 @@ def __init__( requires_grad=False, ) self.scales = nn.Parameter( - torch.empty((out_features, in_features // 256), dtype=dtype, device=device), requires_grad=False + torch.empty((out_features, in_features // group_size), dtype=dtype, device=device), requires_grad=False ) self.tables = nn.Parameter(torch.empty((2**num_bits,), dtype=dtype, device=device), requires_grad=False) self.tables2 = nn.Parameter( @@ -527,11 +532,11 @@ def __init__( self.workspace = None # must be set externally to be reused among layers def forward(self, x): - x = pad_to_block(x, [-1], 1024) + x = pad_to_block(x, [-1], self.hadamard_size) orig_shape = x.shape - x = x.reshape(-1, 1024) - x = hadamard_transform(x, scale=1 / 32) + x = x.reshape(-1, self.hadamard_size) + x = hadamard_transform(x, scale=1 / sqrt(self.hadamard_size)) x = x.reshape(orig_shape) if self.workspace is None: @@ -545,7 +550,7 @@ def forward(self, x): self.tables2.view(dtype=torch.float32), self.workspace, self.num_bits, - 256, + self.group_size, )