From 6c69ebd3d60888582b0ea97f93163adfa613b392 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Wed, 24 Jul 2024 02:05:48 -0700
Subject: [PATCH] Support llama3.1 (#4376)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4376

Add scaled RoPE

Reviewed By: Gasoonjia

Differential Revision: D60129927

fbshipit-source-id: b8d2fadcd3e6985740965ad0185b8fb516806c22
---
 examples/models/llama2/llama_transformer.py |  6 +++-
 examples/models/llama2/rope.py              | 34 +++++++++++++++++++--
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 431cc6dc2c..56bf4a96c3 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -8,6 +8,7 @@
 # Please refer to README.md in the same folder for more information.
 
 from dataclasses import dataclass
+from functools import partial
 from typing import Optional, Tuple
 
 import torch
@@ -101,6 +102,7 @@ class ModelArgs:
         None  # The official name to override self.rope_freq_base.
     )
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
+    use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -453,7 +455,9 @@ def __init__(self, params: ModelArgs):
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
-            self.precompute_freqs_cis = precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=params.use_scaled_rope
+            )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             params.dim // params.n_heads,
             (
diff --git a/examples/models/llama2/rope.py b/examples/models/llama2/rope.py
index 8c948617bb..233c7a2f98 100644
--- a/examples/models/llama2/rope.py
+++ b/examples/models/llama2/rope.py
@@ -7,6 +7,7 @@
 # Different RoPE implementations
 
+import math
 from typing import Tuple
 
 import torch
 
@@ -14,12 +15,41 @@
 # ======================== Stock Implementation ========================
 
 
-def precompute_freqs_cis(dim: int, end: int, theta: float):
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+def precompute_freqs_cis(
+    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
-    freqs = torch.outer(t, freqs).float()  # pyre-ignore
+    if use_scaled:
+        freqs = apply_scaling(freqs)  # pyre-ignore
+    freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
     return freqs_cos, freqs_sin
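
For reference, a minimal standalone sketch of how the scaled-RoPE path added by this patch can be exercised. This is illustrative only and not part of the patch: the head_dim and context_len values are made up, and it assumes the examples.models.llama2.rope module path above is importable as-is.

    import torch

    from examples.models.llama2.rope import precompute_freqs_cis

    # Illustrative sizes: a 128-dim attention head and a 16k context window.
    head_dim = 128
    context_len = 16384

    # Stock RoPE frequencies (pre-llama3.1 behavior).
    cos_base, sin_base = precompute_freqs_cis(head_dim, context_len)

    # llama3.1-style scaled frequencies, enabled by the new use_scaled flag.
    cos_scaled, sin_scaled = precompute_freqs_cis(
        head_dim, context_len, use_scaled=True
    )

    # apply_scaling leaves high-frequency bands (wavelength < 2048 tokens)
    # untouched, divides low-frequency bands (wavelength > 8192 tokens) by
    # scale_factor=8, and interpolates smoothly in between. The first band
    # (frequency 1.0, wavelength 2*pi) should therefore be unchanged, while
    # the lowest-frequency band should differ.
    assert torch.allclose(cos_base[:, 0], cos_scaled[:, 0])
    assert not torch.allclose(cos_base[:, -1], cos_scaled[:, -1])

Note that binding the flag via partial(...) in llama_transformer.py keeps the call site unchanged: the HF and stock paths still share the same (dim, end, theta) calling convention, with the llama3.1 scaling decided once at model construction time.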