Skip to content

Commit

Permalink
tests pass!
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Nov 12, 2024
1 parent e3a6e3e commit 14f91a9
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,4 +387,5 @@ def test_state_dict_consistency(scan_layers, num_kv_heads):
model = LlamaLMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(0))
hf_config = config.to_hf_config(Vocab.size)
hf_model = LlamaForCausalLM(hf_config)
assert set(hf_model.state_dict().keys()) == set(model.to_state_dict().keys())
levanter_state_dict = hax.state_dict.to_torch_compatible_state_dict(model)
assert set(hf_model.state_dict().keys()) == set(levanter_state_dict.keys())

0 comments on commit 14f91a9

Please sign in to comment.