Fix hardcoded rope_scale factor to 32 for Llama 3.2
Differential Revision: D67061188

Pull Request resolved: pytorch#7272
mergennachin authored Dec 11, 2024
1 parent 59df3fe commit 957259e
Showing 7 changed files with 32 additions and 12 deletions.
1 change: 1 addition & 0 deletions backends/vulkan/docs/android_demo.md
@@ -59,6 +59,7 @@ partially lower the Llama model to Vulkan.
 # The files will usually be downloaded to ~/.llama
 python -m examples.models.llama.export_llama \
 --disable_dynamic_shape --vulkan -kv --use_sdpa_with_kv_cache -d fp32 \
+--model "llama3_2" \
 -c ~/.llama/checkpoints/Llama3.2-1B/consolidated.00.pth \
 -p ~/.llama/checkpoints/Llama3.2-1B/params.json \
 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
@@ -56,14 +56,14 @@ In this demo app, we support text-only inference with up-to-date Llama models an
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```
 
 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```
 
 ### For Llama 3.2 1B and 3B BF16 models
@@ -72,7 +72,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:
 
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```
 
 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
@@ -48,14 +48,14 @@ sh examples/models/llama/install_requirements.sh
 Meta has released prequantized INT4 SpinQuant Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --use_spin_quant native --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_spinquant.pte"
 ```
 
 ### For Llama 3.2 1B and 3B QAT+LoRA models
 Meta has released prequantized INT4 QAT+LoRA Llama 3.2 models that ExecuTorch supports on the XNNPACK backend.
 * Export Llama model and generate .pte file as below:
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -qat -lora 16 -kv --use_sdpa_with_kv_cache -X -d fp32 --xnnpack-extended-ops --preq_mode 8da4w_output_8da8w --preq_group_size 32 --max_seq_length 2048 --preq_embedding_quantize 8,0 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name "llama3_2_qat_lora.pte"
 ```
 
 ### For Llama 3.2 1B and 3B BF16 models
@@ -64,7 +64,7 @@ We have supported BF16 as a data type on the XNNPACK backend for Llama 3.2 1B/3B
 * Export Llama model and generate .pte file as below:
 
 ```
-python -m examples.models.llama.export_llama --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
+python -m examples.models.llama.export_llama --model "llama3_2" --checkpoint <path-to-your-checkpoint.pth> --params <path-to-your-params.json> -kv --use_sdpa_with_kv_cache -X -d bf16 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --output_name="llama3_2_bf16.pte"
 ```
 
 For more detail using Llama 3.2 lightweight models including prompt template, please go to our official [website](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-lightweight-models-(1b/3b)-).
3 changes: 3 additions & 0 deletions examples/models/llama/README.md
@@ -168,6 +168,7 @@ LLAMA_CHECKPOINT=path/to/checkpoint.pth
 LLAMA_PARAMS=path/to/params.json
 python -m examples.models.llama.export_llama \
+--model "llama3_2" \
 --checkpoint "${LLAMA_CHECKPOINT:?}" \
 --params "${LLAMA_PARAMS:?}" \
 -kv \
@@ -189,6 +190,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
 LLAMA_PARAMS=path/to/spinquant/params.json
 python -m examples.models.llama.export_llama \
+--model "llama3_2" \
 --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
 --params "${LLAMA_PARAMS:?}" \
 --use_sdpa_with_kv_cache \
@@ -214,6 +216,7 @@ LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
 LLAMA_PARAMS=path/to/qlora/params.json
 python -m examples.models.llama.export_llama \
+--model "llama3_2" \
 --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
 --params "${LLAMA_PARAMS:?}" \
 -qat \
5 changes: 4 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -113,6 +113,7 @@ class ModelArgs:
     )
     rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1.
+    rope_scale_factor: int = 8
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -155,7 +156,9 @@ def __init__(self, params: ModelArgs):
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
             self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=self.params.use_scaled_rope
+                precompute_freqs_cis,
+                use_scaled=self.params.use_scaled_rope,
+                scale_factor=self.params.rope_scale_factor,
             )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
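The `partial` binding above now carries both config values, so callers keep passing only the per-call arguments (`head_dim`, sequence length, and so on). As a quick sanity check of the new plumbing, the bound helper can be exercised directly. A minimal sketch, assuming the ExecuTorch repo root is on `PYTHONPATH`; the dimension and length values are arbitrary illustration values:

```python
from functools import partial

from examples.models.llama.rope import precompute_freqs_cis

# Mirror the binding the diff adds: use_scaled and scale_factor come from ModelArgs.
freqs_fn = partial(precompute_freqs_cis, use_scaled=True, scale_factor=32)

# head_dim=64, 2048 positions -- illustrative values only.
freqs_cos, freqs_sin = freqs_fn(64, 2048)
print(freqs_cos.shape, freqs_sin.shape)  # both torch.Size([2048, 32]): one column per frequency pair
```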
9 changes: 9 additions & 0 deletions examples/models/llama/model.py
@@ -145,6 +145,15 @@ def __init__(self, **kwargs):
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
+
+        if model_args.use_scaled_rope:
+            # Older models don't have use_scaled_rope configuration
+            assert self.args.model not in ["llama2", "stories110m"]
+
+            # Llama3_2 and newer models in ExecuTorch repo should set larger scale factor
+            if self.args.model not in ["llama3", "llama3_1"]:
+                model_args.rope_scale_factor = 32
+
         if kwargs.get("verbose", False):
             print("============= weights ================")
             print("{key} : {weights.numel()} : {weights.size()}")
14 changes: 9 additions & 5 deletions examples/models/llama/rope.py
@@ -8,16 +8,15 @@
 # Different RoPE implementations
 
 import math
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
 # ======================== Stock Implementation ========================
 
 
-def apply_scaling(freqs: torch.Tensor):
+def apply_scaling(freqs: torch.Tensor, scale_factor: int):
     # Values obtained from grid search
-    scale_factor = 8
     low_freq_factor = 1
     high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
@@ -41,14 +40,19 @@ def apply_scaling(freqs: torch.Tensor):
 
 
 def precompute_freqs_cis(
-    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+    dim: int,
+    end: int,
+    theta: float = 10000.0,
+    use_scaled: bool = False,
+    scale_factor: Optional[int] = None,
 ):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
-        freqs = apply_scaling(freqs)  # pyre-ignore
+        assert scale_factor is not None
+        freqs = apply_scaling(freqs, scale_factor)  # pyre-ignore
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
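The middle of `apply_scaling` is collapsed in the hunk above, so here is a self-contained sketch of the whole helper using the standard Llama 3.x wavelength-based smoothing, together with a small check of what the new parameter changes. Treat it as an illustration of the mechanism rather than a verbatim copy of the repo's code; the 500k theta is only an example value.

```python
import math

import torch


def apply_scaling(freqs: torch.Tensor, scale_factor: int) -> torch.Tensor:
    # Assumed to match the collapsed body above: standard Llama 3.x smoothing.
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original llama3 length
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)  # short wavelengths are left untouched
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)  # long wavelengths are fully scaled
        else:
            # Smooth interpolation between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


dim, theta = 64, 500000.0  # illustrative head_dim and RoPE base
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
# With the Llama 3.2 factor of 32, the lowest (longest-wavelength) frequency ends up
# 4x smaller than with the previously hardcoded factor of 8.
print(apply_scaling(freqs, 8)[-1] / apply_scaling(freqs, 32)[-1])  # tensor(4.)
```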
