diff --git a/src/levanter/models/rotary.py b/src/levanter/models/rotary.py
index 07657e5ff..55bbf3fcb 100644
--- a/src/levanter/models/rotary.py
+++ b/src/levanter/models/rotary.py
@@ -157,6 +157,7 @@ def to_hf_config(self) -> tuple[float, dict]:
             "low_freq_factor": self.low_freq_factor,
             "high_freq_factor": self.high_freq_factor,
             "original_max_position_embeddings": self.original_max_position_embeddings,
+            "rope_type": "llama3",
         }
diff --git a/tests/test_llama3.py b/tests/test_llama3.py
index 2fae326d1..653ba723c 100644
--- a/tests/test_llama3.py
+++ b/tests/test_llama3.py
@@ -26,9 +26,10 @@ def get_config(vocab_size=1000):
         "eos_token_id": 128001,
         "hidden_act": "silu",
         "hidden_size": 4096,
+        "head_dim": 64,
         "initializer_range": 0.02,
         "intermediate_size": 14336,
-        "max_position_embeddings": 8192,
+        "max_position_embeddings": 131072,
         "model_type": "llama",
         "num_attention_heads": 32,
         "num_hidden_layers": 32,
@@ -55,6 +56,7 @@ def get_config(vocab_size=1000):
     llama3_8b_config.hidden_size = 16
     llama3_8b_config.intermediate_size = 64
     llama3_8b_config.num_attention_heads = 4
+    llama3_8b_config.head_dim = 4
     llama3_8b_config.num_hidden_layers = 4
     llama3_8b_config.num_key_value_heads = 2
     llama3_8b_config.max_position_embeddings = 128
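
For reference (not part of the diff): a minimal sketch of the rope_scaling dict that to_hf_config emits once "rope_type" is included, matching the keys visible in the first hunk. The numeric values shown are the stock Llama 3.1 defaults and are assumptions for illustration only, not taken from this change.

# Illustrative only: shape of the scaling dict to_hf_config now produces for
# HF's llama3 rope scaling. The keys mirror the hunk above; the values below
# are the usual Llama 3.1 defaults and are assumed, not read from this diff.
rope_scaling = {
    "factor": 8.0,                             # not visible in the hunk; assumed
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3",                     # the key added by this change
}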