From 888bc19e92edb03a4b49f28169ba32f0f4c5d694 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 6 Feb 2025 14:13:33 -0500 Subject: [PATCH] No arg for repl placement --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 7364bc480..935e6f548 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -604,7 +604,7 @@ def load_distributed_state_dict( """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate(0)], True) + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Replicate()], True) inp = {"state":deepcopy(base), "dstate":dstate} # Read distributed state dict reader = checkpoint.FileSystemReader(path)