Skip to content

Commit

Permalink
fix llama 3 rotary embeddings (#740)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Sep 24, 2024
1 parent fe3e2f3 commit 9fa3aaa
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 74 deletions.
9 changes: 6 additions & 3 deletions src/levanter/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
LlamaMlp,
)
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token

Expand Down Expand Up @@ -80,7 +81,6 @@ class GemmaConfig(HFCompatConfig):
attn_dropout = 0.0
norm_eps = 1e-6

rope_base: int = 10_000
norm_embeddings: bool = True

# Attention-related config
Expand All @@ -94,9 +94,12 @@ class GemmaConfig(HFCompatConfig):
scan_layers: bool = True

use_bias: bool = False
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0

@property
def rope(self) -> RotaryEmbeddingsConfig:
return DefaultRotaryEmbeddingsConfig(theta=self.rope_theta)

# Axis
Pos = property(lambda self: Axis(name="position", size=self.seq_len))
KeyPos = property(lambda self: self.Pos.alias("key_position"))
Expand Down Expand Up @@ -146,7 +149,7 @@ def from_hf_config(cls, hf_config: HfConfig):
num_kv_heads=hf_config.num_key_value_heads,
initializer_range=hf_config.initializer_range,
layer_norm_epsilon=hf_config.rms_norm_eps,
rope_base=hf_config.rope_theta,
rope_theta=hf_config.rope_theta,
)

def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfGemmaConfig:
Expand Down
75 changes: 13 additions & 62 deletions src/levanter/models/llama.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import dataclasses
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Tuple, Type, Union
from typing import Callable, Dict, Optional, Type, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import PRNGKeyArray
Expand All @@ -28,6 +27,7 @@
from levanter.models.attention import AttentionBackend, AttentionMask, dot_product_attention
from levanter.models.gpt2 import ACT2FN
from levanter.models.lm_model import LmConfig, LmHeadModel
from levanter.models.rotary import DefaultRotaryEmbeddingsConfig, RotaryEmbeddingsConfig
from levanter.types import BlockFoldable
from levanter.utils.flop_utils import lm_flops_per_token

Expand Down Expand Up @@ -77,8 +77,7 @@ class LlamaConfig(HFCompatConfig):

use_bias: bool = False
use_layer_norm_weight: bool = True
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope: RotaryEmbeddingsConfig = dataclasses.field(default_factory=DefaultRotaryEmbeddingsConfig)

reference_checkpoint: str = "meta-llama/Llama-2-7b-hf"
tokenizer: Optional[str] = None
Expand Down Expand Up @@ -109,6 +108,8 @@ def hf_checkpoint_converter(self) -> HFCheckpointConverter["LlamaConfig"]: # ty

@classmethod
def from_hf_config(cls, hf_config: HfConfig):
rope_theta = hf_config.rope_theta
rope_config = RotaryEmbeddingsConfig.from_hf_config(rope_theta, hf_config.rope_scaling)
return LlamaConfig(
seq_len=hf_config.max_position_embeddings,
hidden_dim=hf_config.hidden_size,
Expand All @@ -119,8 +120,7 @@ def from_hf_config(cls, hf_config: HfConfig):
activation_function=hf_config.hidden_act,
initializer_range=hf_config.initializer_range,
layer_norm_epsilon=hf_config.rms_norm_eps,
rope_scaling=hf_config.rope_scaling,
rope_theta=hf_config.rope_theta,
rope=rope_config,
)

def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfLlamaConfig:
Expand All @@ -136,6 +136,8 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
if config_overrides is None:
config_overrides = {}

rope_theta, rope_scaling = self.rope.to_hf_config()

return HfLlamaConfig(
max_position_embeddings=self.seq_len,
hidden_size=self.hidden_dim,
Expand All @@ -146,9 +148,10 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
hidden_act=self.activation_function,
initializer_range=self.initializer_range,
rms_norm_eps=self.layer_norm_epsilon,
rope_scaling=self.rope_scaling,
# rope_scaling=self.rope_scaling,
vocab_size=vocab_size,
rope_theta=self.rope_theta,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
**config_overrides,
)

Expand Down Expand Up @@ -274,13 +277,6 @@ def init(config: LlamaConfig, *, key) -> "LlamaAttention":
)
return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj)

def _rope_scale_factor(self) -> float:
# hasattr for gemma and I'm feeling lazy
if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
assert self.config.rope_scaling["type"] == "linear"
return self.config.rope_scaling["factor"]
return 1.0

@named_call
def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray:
key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)
Expand All @@ -290,13 +286,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,
k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))

cos, sin = llama_rotary_pos_emb(
self.config.HeadSize,
x.resolve_axis("position"),
scale=self._rope_scale_factor(),
theta=self.config.rope_theta,
)
q, k = _apply_rotary_pos_emb(q, k, cos, sin)
rot_embs = self.config.rope.build(self.config.HeadSize, q.resolve_axis("position"))
q, k = rot_embs(self.config.HeadSize, q, k)

k = k.rename({"position": "key_position"})
v = v.rename({"position": "key_position"})
Expand Down Expand Up @@ -588,43 +579,3 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None)

state_dict.update(my_dict)
return state_dict


def _rotate_half(x: NamedArray) -> NamedArray:
"""Rotates half of the hidden dims of the input and concatenates them."""
HeadSize = x.axes[-1]
x1 = x[HeadSize, : HeadSize.size // 2]
x2 = x[HeadSize, HeadSize.size // 2 :]
out = hax.concatenate(HeadSize, (-x2, x1))
return out


def _apply_rotary_pos_emb(
q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size]
k: NamedArray, # [batch, position, kv_heads, head_size]
cos: NamedArray, # [position, head_size]
sin: NamedArray, # [position, head_size]
) -> Tuple[NamedArray, NamedArray]:
"""Applies rotary position embedding to q and k."""
q_embed = q * cos + _rotate_half(q) * sin
k_embed = k * cos + _rotate_half(k) * sin
return q_embed, k_embed


def llama_rotary_pos_emb(
HeadSize: Axis, Pos: Axis, theta: float = 10000, scale: float = 1.0
) -> Tuple[NamedArray, NamedArray]:
with jax.ensure_compile_time_eval():
HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
inv_freq: NamedArray = 1.0 / (theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))

position_ids: NamedArray = hax.arange(Pos) / scale

freqs = position_ids * inv_freq.broadcast_axis(Pos)
# This is different from the paper but aligns with HF implementation:
# It uses a different permutation in order to obtain the same calculation
emb = hax.concatenate(HeadSize, (freqs, freqs))
cos = hax.cos(emb)
sin = hax.sin(emb)
# This is different from the paper but aligns with HF implementation:
return cos, sin
182 changes: 182 additions & 0 deletions src/levanter/models/rotary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import abc
from dataclasses import dataclass
from typing import Tuple

import draccus
import equinox as eqx
import jax
import jax.numpy as jnp

import haliax as hax
from haliax import Axis, NamedArray


def _rotate_half(x: NamedArray, HeadSize: Axis) -> NamedArray:
"""Rotates half of the hidden dims of the input and concatenates them."""
x1 = x[HeadSize, : HeadSize.size // 2]
x2 = x[HeadSize, HeadSize.size // 2 :]
out = hax.concatenate(HeadSize, (-x2, x1))
return out


class RotaryEmbeddings(eqx.Module):
cos: NamedArray
sin: NamedArray

@property
def nograd_cos(self):
return jax.lax.stop_gradient(self.cos)

@property
def nograd_sin(self):
return jax.lax.stop_gradient(self.sin)

def __call__(self, HeadDim: Axis, q: NamedArray, k: NamedArray) -> tuple[NamedArray, NamedArray]:
q_embed = q * self.nograd_cos + _rotate_half(q, HeadDim) * self.nograd_sin
k_embed = k * self.nograd_cos + _rotate_half(k, HeadDim) * self.nograd_sin
return q_embed, k_embed


@dataclass
class RotaryEmbeddingsConfig(abc.ABC, draccus.ChoiceRegistry):
@abc.abstractmethod
def build(self, HeadSize: Axis, Pos: Axis) -> RotaryEmbeddings:
pass

@staticmethod
def from_hf_config(rope_theta, config: dict | None) -> "RotaryEmbeddingsConfig":
if config is None:
return DefaultRotaryEmbeddingsConfig(theta=rope_theta)
tpe = config.get("rope_type") or config.get("type") or "default"
return RotaryEmbeddingsConfig.get_choice_class(tpe).make_from_hf_config(rope_theta, config)

@classmethod
@abc.abstractmethod
def make_from_hf_config(cls, rope_theta: float, config: dict) -> "RotaryEmbeddingsConfig":
pass

@abc.abstractmethod
def to_hf_config(self) -> tuple[float, dict | None]:
"""Returns the rope_theta and config dict for the HF config."""
pass


@dataclass
class DefaultRotaryEmbeddingsConfig(RotaryEmbeddingsConfig):
theta: float = 10000
factor: float = 1.0 # this should have been called scale_factor, but for hf compat

def build(self, HeadSize: Axis, Pos: Axis) -> RotaryEmbeddings:
with jax.ensure_compile_time_eval():
HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
inv_freq: NamedArray = 1.0 / (self.theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))
inv_freq = inv_freq / self.factor

position_ids: NamedArray = hax.arange(Pos)

freqs = position_ids * inv_freq.broadcast_axis(Pos)
emb = hax.concatenate(HeadSize, (freqs, freqs))
cos = hax.cos(emb)
sin = hax.sin(emb)
return RotaryEmbeddings(cos=cos, sin=sin)

@classmethod
def make_from_hf_config(cls, rope_theta: float, config: dict) -> "RotaryEmbeddingsConfig":
return DefaultRotaryEmbeddingsConfig(theta=rope_theta, factor=config.get("factor", 1.0))

def to_hf_config(self) -> tuple[float, dict | None]:
if self.factor == 1.0:
return self.theta, None
return self.theta, {"factor": self.factor}


RotaryEmbeddingsConfig.register_subclass("default", DefaultRotaryEmbeddingsConfig)
RotaryEmbeddingsConfig.register_subclass("linear", DefaultRotaryEmbeddingsConfig)


@dataclass
class Llama3RotaryEmbeddingsConfig(RotaryEmbeddingsConfig):
"""
To match this from HF:
"rope_scaling": {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
},
"""

theta: float = 500000
factor: float = 8.0
low_freq_factor: float = 1.0
high_freq_factor: float = 4.0
original_max_position_embeddings: int = 8192

def build(self, HeadSize: Axis, Pos: Axis) -> RotaryEmbeddings:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
# Porting that to JAX/Haliax:
with jax.ensure_compile_time_eval():
HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
inv_freq: NamedArray = 1.0 / (self.theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size))

old_context_len = self.original_max_position_embeddings
low_freq_wavelen = old_context_len / self.low_freq_factor
high_freq_wavelen = old_context_len / self.high_freq_factor

wavelen = 2 * jnp.pi / inv_freq
inv_freq_llama = hax.where(wavelen > low_freq_wavelen, inv_freq / self.factor, inv_freq)
smooth_factor = (old_context_len / wavelen - self.low_freq_factor) / (
self.high_freq_factor - self.low_freq_factor
)
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / self.factor + smooth_factor * inv_freq_llama
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
inv_freq_llama = hax.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

position_ids: NamedArray = hax.arange(Pos)

freqs = position_ids * inv_freq_llama.broadcast_axis(Pos)
emb = hax.concatenate(HeadSize, (freqs, freqs))
cos = hax.cos(emb)
sin = hax.sin(emb)
return RotaryEmbeddings(cos=cos, sin=sin)

@classmethod
def make_from_hf_config(cls, rope_theta: float, config: dict) -> "RotaryEmbeddingsConfig":
return Llama3RotaryEmbeddingsConfig(
theta=rope_theta,
factor=config.get("factor", 8.0),
low_freq_factor=config.get("low_freq_factor", 1.0),
high_freq_factor=config.get("high_freq_factor", 4.0),
original_max_position_embeddings=config.get("original_max_position_embeddings", 8192),
)

def to_hf_config(self) -> tuple[float, dict]:
return self.theta, {
"factor": self.factor,
"low_freq_factor": self.low_freq_factor,
"high_freq_factor": self.high_freq_factor,
"original_max_position_embeddings": self.original_max_position_embeddings,
}


RotaryEmbeddingsConfig.register_subclass("llama3", Llama3RotaryEmbeddingsConfig)


def rotary_pos_emb(
HeadSize: Axis, Pos: Axis, theta: float = 10000, scale: float = 1.0
) -> Tuple[NamedArray, NamedArray]:
with jax.ensure_compile_time_eval():
HeadHalfSize = HeadSize.resize(HeadSize.size // 2)
inv_freq: NamedArray = 1.0 / (theta ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) / scale

position_ids: NamedArray = hax.arange(Pos)

freqs = position_ids * inv_freq.broadcast_axis(Pos)
# This is different from the paper but aligns with HF implementation:
# It uses a different permutation in order to obtain the same calculation
emb = hax.concatenate(HeadSize, (freqs, freqs))
cos = hax.cos(emb)
sin = hax.sin(emb)
# This is different from the paper but aligns with HF implementation:
return cos, sin
Loading

0 comments on commit 9fa3aaa

Please sign in to comment.