Support llama3.1 (pytorch#4376)
Summary:
Pull Request resolved: pytorch#4376

Add scaled RoPE (the RoPE frequency scaling introduced in llama3.1 for longer context), gated by a new use_scaled_rope flag on ModelArgs.

Reviewed By: Gasoonjia

Differential Revision: D60129927

fbshipit-source-id: b8d2fadcd3e6985740965ad0185b8fb516806c22
larryliu0820 authored and facebook-github-bot committed Jul 24, 2024
1 parent 11b2fcb commit 6c69ebd
Showing 2 changed files with 37 additions and 3 deletions.
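
For orientation, a minimal sketch of how a caller opts into the new behavior. It assumes the executorch repo root is on the import path and that the remaining ModelArgs fields keep usable defaults; use_scaled_rope is the only knob added by this commit.

from examples.models.llama2.llama_transformer import ModelArgs

# use_scaled_rope is the new field added below; it defaults to False, so
# existing configs keep the stock (unscaled) RoPE tables.
params = ModelArgs(use_scaled_rope=True)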
6 changes: 5 additions & 1 deletion examples/models/llama2/llama_transformer.py
@@ -8,6 +8,7 @@
 # Please refer to README.md in the same folder for more information.

 from dataclasses import dataclass
+from functools import partial
 from typing import Optional, Tuple

 import torch
@@ -101,6 +102,7 @@ class ModelArgs:
         None  # The official name to override self.rope_freq_base.
     )
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
+    use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
@@ -453,7 +455,9 @@ def __init__(self, params: ModelArgs):
         if params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
         else:
-            self.precompute_freqs_cis = precompute_freqs_cis
+            self.precompute_freqs_cis = partial(
+                precompute_freqs_cis, use_scaled=params.use_scaled_rope
+            )
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             params.dim // params.n_heads,
             (
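
Because the flag is bound with functools.partial, downstream call sites such as the freqs_cos, freqs_sin = self.precompute_freqs_cis(...) call above are unchanged: the bound callable keeps the same positional interface with use_scaled pre-filled. A toy sketch of that pattern (the names here are illustrative, not from the diff):

from functools import partial

def precompute(dim, end, theta=10000.0, use_scaled=False):
    return (dim, end, theta, use_scaled)

# Bind the keyword once; callers still pass only (dim, end).
bound = partial(precompute, use_scaled=True)
print(bound(128, 2048))  # (128, 2048, 10000.0, True)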
34 changes: 32 additions & 2 deletions examples/models/llama2/rope.py
@@ -7,19 +7,49 @@

 # Different RoPE implementations

+import math
 from typing import Tuple

 import torch

 # ======================== Stock Implementation ========================


-def precompute_freqs_cis(dim: int, end: int, theta: float):
+def apply_scaling(freqs: torch.Tensor):
+    # Values obtained from grid search
+    scale_factor = 8
+    low_freq_factor = 1
+    high_freq_factor = 4
+    old_context_len = 8192  # original llama3 length
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / scale_factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
+            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
+def precompute_freqs_cis(
+    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
+):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
     )
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
-    freqs = torch.outer(t, freqs).float()  # pyre-ignore
+    if use_scaled:
+        freqs = apply_scaling(freqs)  # pyre-ignore
+    freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
     return freqs_cos, freqs_sin
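
The scaling is piecewise over RoPE wavelengths: components with wavelength below old_context_len / high_freq_factor (2048) are kept as-is, components with wavelength above old_context_len (8192) are divided by scale_factor (8), and the band in between is linearly interpolated via smooth. A rough sketch exercising the new code; head_dim and seq_len are arbitrary example values, and it assumes rope.py is importable from the repo root:

import math

import torch

from examples.models.llama2.rope import apply_scaling, precompute_freqs_cis

head_dim, seq_len = 128, 4096  # example values only

# Same per-dimension inverse-frequency construction as precompute_freqs_cis.
freqs = 1.0 / (
    10000.0 ** (torch.arange(0, head_dim, 2)[: head_dim // 2].float() / head_dim)
)
scaled = apply_scaling(freqs)
wavelens = 2 * math.pi / freqs

print(torch.allclose(scaled[wavelens < 2048], freqs[wavelens < 2048]))      # high-freq band untouched
print(torch.allclose(scaled[wavelens > 8192], freqs[wavelens > 8192] / 8))  # low-freq band scaled down

# End to end: the scaled tables keep the stock shape (seq_len, head_dim // 2).
cos0, _ = precompute_freqs_cis(head_dim, seq_len)
cos1, _ = precompute_freqs_cis(head_dim, seq_len, use_scaled=True)
print(cos0.shape == cos1.shape == (seq_len, head_dim // 2))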
