Skip to content

Commit

Permalink
Merge pull request #844 from mlcommons/python_upgrades
Browse files Browse the repository at this point in the history
fix: frozen dict conversion
  • Loading branch information
priyakasimbeg authored Feb 13, 2025
2 parents 5c4c07d + a687fa7 commit c1a1ef0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tests/modeldiffs/diff.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from flax import jax_utils
from flax.core import FrozenDict
import jax
import numpy as np
import torch
Expand All @@ -16,6 +17,8 @@ def torch2jax(jax_workload,
jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0),
**init_kwargs)
pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs)
if isinstance(jax_params, dict):
jax_params = FrozenDict(jax_params)
jax_params = jax_utils.unreplicate(jax_params).unfreeze()
if model_state is not None:
model_state = jax_utils.unreplicate(model_state)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_traindiffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_workload(self, workload):
pyt_logs = '/tmp/pyt_log.pkl'
try:
run(
f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python3 -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}'
f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs}'
f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
shell=True,
stdout=DEVNULL,
Expand Down

0 comments on commit c1a1ef0

Please sign in to comment.