-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
327 additions
and
0 deletions.
There are no files selected for viewing
327 changes: 327 additions & 0 deletions
327
machine_learning_with_python/08.model_quantization.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,327 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e4ea6dad-a631-495c-b34b-614cf6f3adce", | ||
"metadata": {}, | ||
"source": [ | ||
"# Model Quantization Methods" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a7317247-28f2-4b21-ba7c-07cc6a770c62", | ||
"metadata": {}, | ||
"source": [ | ||
"* Model quantization methods aim to reduce the size and improve inference speed of the models." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ab826755-f2bc-47e7-85ef-91ef733dd0f1", | ||
"metadata": {}, | ||
"source": [ | ||
"### Memory Computation\n", | ||
"\n", | ||
"* Models can be trained and inferenced on different precision levels: float32, float16 (half precision), bfloat16 (developed by Google), int8, int4, int2.\n", | ||
"* Usually when we convert weights into a different precision, there is some loss of model performance.\n", | ||
"* Choosing the correct quantization is a tradeoff between accuracy, speed and memory." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1f38eb23-44a5-4673-9ad1-9c0c13a13a3c", | ||
"metadata": {}, | ||
"source": [ | ||
"### Estimating memory for a model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "39fa78a4-34f9-417d-8c14-a95c5a51077e", | ||
"metadata": {}, | ||
"source": [ | ||
"If a model is 8B parameters in float16, what is the memory consumed by it?" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"id": "0cf7b04e-10a1-4aed-9b10-3efc3c5013ef", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def get_required_memory(params, fp=16):\n", | ||
" parms = params # billions\n", | ||
" floating_point = fp\n", | ||
" parms_bytes = parms * (floating_point / 8)\n", | ||
" # 1 kilobyte = 1024 bytes, 1 megabyte = 1024 kilobytes, 1 gigabyte = 1024 megabytes\n", | ||
" parms_bytes_gb = parms_bytes / 1024 / 1024 / 1024\n", | ||
" return parms_bytes_gb" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 28, | ||
"id": "3e9188c6-a08a-4d44-b1ef-f7fd7c1d069b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"14.901161193847656" | ||
] | ||
}, | ||
"execution_count": 28, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"get_required_memory(8e9,fp=16)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 29, | ||
"id": "4610be17-4ea6-444d-84c4-bbcc4d904439", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"7.450580596923828" | ||
] | ||
}, | ||
"execution_count": 29, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"get_required_memory(8e9,fp=8)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 30, | ||
"id": "016922e1-f6f6-4a31-8472-3709a3eb4f4b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"3.725290298461914" | ||
] | ||
}, | ||
"execution_count": 30, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"get_required_memory(8e9,fp=4)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"id": "356488f1-f179-4639-b1d1-6d339130e481", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"130.385160446167" | ||
] | ||
}, | ||
"execution_count": 31, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"get_required_memory(70e9,fp=16)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3be3d0b5-f6c6-4535-9e85-56a453ed7a1f", | ||
"metadata": {}, | ||
"source": [ | ||
"### Quantization Methods" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "857279e2-abdf-4f97-a66b-e7fb8ff0bbbd", | ||
"metadata": {}, | ||
"source": [ | ||
"* Post-training Quantization: Quantization is done after training. It may lead to accuracy drop.\n", | ||
"* Quantization Aware Training: Involves training with quantization in mind\n", | ||
"* Mixed-Precision Quantization: Some weights are computed with higher precision & some weights are quantized to lower precision" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "faa97780-91ec-4ad7-816a-0e20deeb7a5a", | ||
"metadata": {}, | ||
"source": [ | ||
"#### K_S\n", | ||
"\n", | ||
"* Uniform quantization, divides the float unifromly into buckets to achieve desired quantization.\n", | ||
"* Most simple method, fast but may lose significant accuracy\n", | ||
"\n", | ||
"#### K_M\n", | ||
"\n", | ||
"* Non-uniform quantization, the distribution of the quantizer is learned with respect to weights of the model. It is more complex than K_S, but offers better performance. A simple k-means clustering would allow us to round the number to the nearest cluster.\n", | ||
"\n", | ||
"#### K_L\n", | ||
"\n", | ||
"* The k-l divergence between the original and quantized model weights is minimized.\n", | ||
"* Results in lower information loss\n", | ||
"* Computationally expensive" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "8156188d-5c21-4e44-9f71-0d14b6deb51c", | ||
"metadata": {}, | ||
"source": [ | ||
"### How to quantize a model?" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 43, | ||
"id": "4c4c90cc-2525-4049-b552-68b760d57439", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from transformers import AutoModelForCausalLM\n", | ||
"from optimum.quanto import quantize, qint8" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 49, | ||
"id": "fb84a7a1-307f-41c6-a11d-428acabdf8ac", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = AutoModelForCausalLM.from_pretrained('gpt2')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 50, | ||
"id": "a1bf20c3-eb60-4724-9fa3-2ebe45a32abf", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Linear(in_features=768, out_features=50257, bias=False)" | ||
] | ||
}, | ||
"execution_count": 50, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"model.lm_head" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 51, | ||
"id": "dbd1d1ef-c5b5-4a81-a66f-daab9ac5ab46", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"quantize(model, weights=qint8, activations=qint8)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 52, | ||
"id": "651596ed-25f7-4bda-8337-9c11ab56ec0c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from optimum.quanto import freeze\n", | ||
"\n", | ||
"freeze(model)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 53, | ||
"id": "674e3f04-6f6d-43fb-a2cf-6c5ea4d147e0", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"GPT2LMHeadModel(\n", | ||
" (transformer): GPT2Model(\n", | ||
" (wte): Embedding(50257, 768)\n", | ||
" (wpe): Embedding(1024, 768)\n", | ||
" (drop): Dropout(p=0.1, inplace=False)\n", | ||
" (h): ModuleList(\n", | ||
" (0-11): 12 x GPT2Block(\n", | ||
" (ln_1): QLayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", | ||
" (attn): GPT2SdpaAttention(\n", | ||
" (c_attn): Conv1D()\n", | ||
" (c_proj): Conv1D()\n", | ||
" (attn_dropout): Dropout(p=0.1, inplace=False)\n", | ||
" (resid_dropout): Dropout(p=0.1, inplace=False)\n", | ||
" )\n", | ||
" (ln_2): QLayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", | ||
" (mlp): GPT2MLP(\n", | ||
" (c_fc): Conv1D()\n", | ||
" (c_proj): Conv1D()\n", | ||
" (act): NewGELUActivation()\n", | ||
" (dropout): Dropout(p=0.1, inplace=False)\n", | ||
" )\n", | ||
" )\n", | ||
" )\n", | ||
" (ln_f): QLayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", | ||
" )\n", | ||
" (lm_head): QLinear(in_features=768, out_features=50257, bias=False)\n", | ||
")" | ||
] | ||
}, | ||
"execution_count": 53, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"model" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |