Skip to content

Commit

Permalink
Fix Llama 3 Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Helw150 committed Nov 2, 2024
1 parent 5ebf8ce commit 5118db0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/levanter/models/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def to_hf_config(self) -> tuple[float, dict]:
"low_freq_factor": self.low_freq_factor,
"high_freq_factor": self.high_freq_factor,
"original_max_position_embeddings": self.original_max_position_embeddings,
"rope_type": "llama3",
}


Expand Down
4 changes: 3 additions & 1 deletion tests/test_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ def get_config(vocab_size=1000):
"eos_token_id": 128001,
"hidden_act": "silu",
"hidden_size": 4096,
"head_dim": 64,
"initializer_range": 0.02,
"intermediate_size": 14336,
"max_position_embeddings": 8192,
"max_position_embeddings": 131072,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
Expand All @@ -55,6 +56,7 @@ def get_config(vocab_size=1000):
llama3_8b_config.hidden_size = 16
llama3_8b_config.intermediate_size = 64
llama3_8b_config.num_attention_heads = 4
llama3_8b_config.head_dim = 4
llama3_8b_config.num_hidden_layers = 4
llama3_8b_config.num_key_value_heads = 2
llama3_8b_config.max_position_embeddings = 128
Expand Down

0 comments on commit 5118db0

Please sign in to comment.