Skip to content

Commit

Permalink
test eval
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Apr 24, 2024
1 parent aebaa53 commit 0c3282e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/levanter/main/eval_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def compute_jsd_loss(model1: LmHeadModel, model2: LmHeadModel, example: LmExampl
with use_cpu_device():
model = eqx.filter_eval_shape(config.model.build, Vocab, key=key)
# TODO: don't load the entire checkpoint into CPU memory when we only need our share of the model
model = load_checkpoint(model, config.checkpoint_path, subpath="model")

model = load_checkpoint(model, "gs://levanter-checkpoints/llama2-trace-1b-seed42/19zqlxdg/step-200/model/", subpath="model")

model = hax.shard_with_axis_mapping(model, parameter_axis_mapping)

Expand Down

0 comments on commit 0c3282e

Please sign in to comment.