diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 57f484770..c53d42f82 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -11,6 +11,7 @@ from typing import Callable, Optional, Sequence import draccus +import mergedeep import ray from ray._private.accelerators import TPUAcceleratorManager from ray.dashboard.modules.job.sdk import JobSubmissionClient @@ -198,10 +199,15 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8 + # ray doesn't merge the runtime envs properly, so we have to do it ourselves + # we need to do a deep merge + runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE) + remote_fn = remote_fn.options( runtime_env=runtime_env, resources={tpu_name: 1, "TPU": num_tpus_per_host}, ) + logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host") return remote_fn, tpu_name @@ -583,6 +589,16 @@ def _make_unique_job_id(client, run_id): return job_id +def _deep_merge_envs(env1, env2): + """ + Merge two environment dictionaries, deeply. + """ + + merged = mergedeep.merge({}, env1, env2, strategy=mergedeep.Strategy.ADDITIVE) + + return merged + + @draccus.wrap() def main(args: RunDockerOnPodConfig): """