diff --git a/runhouse/resources/module.py b/runhouse/resources/module.py index e61fcda77..af1b05dcc 100644 --- a/runhouse/resources/module.py +++ b/runhouse/resources/module.py @@ -811,7 +811,7 @@ def distribute( name = name or f"distributed_{self.name}" pooled_module = DistributedPool( **distribution_kwargs, name=name, replicas=replicas - ).to(self.system, env=self.env.name) + ).to(self.system, process=self.env.name) return pooled_module elif distribution == "ray": from runhouse.resources.distributed.ray_distributed import RayDistributed @@ -819,7 +819,7 @@ def distribute( name = name or f"ray_{self.name}" ray_module = RayDistributed( **distribution_kwargs, name=name, module=self - ).to(self.system, self.env.name) + ).to(system=self.system, process=self.env.name) return ray_module elif distribution == "dask": from runhouse.resources.distributed.dask_distributed import DaskDistributed @@ -831,7 +831,7 @@ def distribute( num_replicas=num_replicas, replicas_per_node=replicas_per_node, **distribution_kwargs, - ).to(self.system, self.env.name) + ).to(system=self.system, process=self.env.name) return dask_module elif distribution == "pytorch": from runhouse.resources.distributed.pytorch_distributed import ( @@ -846,7 +846,7 @@ def distribute( name = name or f"pytorch_{self.local.name}" ptd_module = PyTorchDistributed( **distribution_kwargs, name=name, replicas=replicas - ).to(self.system, env=self.env.name) + ).to(system=self.system, process=self.env.name) return ptd_module @property