multinode metric tracking now working
josephdviviano committed Dec 20, 2024
1 parent ac9afe1 commit d37b5f1
Showing 1 changed file with 84 additions and 60 deletions.
144 changes: 84 additions & 60 deletions tutorials/examples/train_hypergrid.py
@@ -163,7 +163,7 @@ def _error_checker(self):
sys.exit(1)

except Exception as e:
- print(f'Process {self.rank}: Error in error checker: {str(e)}')
+ print('Process {}: Error in error checker: {}'.format(self.rank, e))
self.signal_error()
break

@@ -196,8 +196,8 @@ def _cleanup(self):


def gather_distributed_data(
- local_data: Union[List, torch.Tensor], world_size: int = None, rank: int = None
- ) -> List:
+ local_tensor: torch.Tensor, world_size: int = None, rank: int = None, verbose: bool = False,
+ ) -> torch.Tensor:
"""
Gather data from all processes in a distributed setting.
@@ -207,75 +207,92 @@ def gather_distributed_data(
rank: Current process rank (optional, will get from env if None)
Returns:
- List containing gathered data from all processes
+ On rank 0: Concatenated tensor from all processes
+ On other ranks: None
"""
print("syncing distributed data")
if verbose:
print("syncing distributed data")

if world_size is None:
world_size = dist.get_world_size()
if rank is None:
rank = dist.get_rank()

- # Convert data to tensor if it's not already.
- if not isinstance(local_data, torch.Tensor):
- print("+ converting tensor data via serialization")
- # Serialize complex data structures.
- serialized_data = pickle.dumps(local_data)
- local_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(serialized_data))
- else:
- print("+ tensor not converted")
- local_tensor = local_data

- # First gather sizes to allocate correct buffer sizes.
- local_size = torch.tensor([local_tensor.numel()], device=local_tensor.device)
+ # First gather batch_sizes to allocate correct buffer sizes.
+ local_batch_size = torch.tensor([local_tensor.shape[0]], device=local_tensor.device, dtype=local_tensor.dtype)
if rank == 0:
- size_list = [
- torch.tensor([0], device=local_tensor.device) for _ in range(world_size)
+ # Assumes same dimensionality on all ranks!
+ batch_size_list = [
+ torch.zeros((1, ), device=local_tensor.device, dtype=local_tensor.dtype) for _ in range(world_size)
]
else:
- size_list = None
- print("+ all gather of local_size={} to size_list".format(local_size))
- dist.gather(local_size, gather_list=size_list, dst=0)
- # 44 [0] Process 0: Caught error: Invalid function argument. Expected parameter `tensor` to be of type torch.Tensor.[0]
+ batch_size_list = None

+ if verbose:
+ print("rank={}, batch_size_list={}".format(rank, batch_size_list))
+ print("+ gather of local_batch_size={} to batch_size_list".format(local_batch_size))
+ dist.gather(local_batch_size, gather_list=batch_size_list, dst=0)
+ dist.barrier() # Add synchronization

+ # Pad local tensor to maximum size.
+ if verbose:
+ print("+ padding local tensor")

+ if rank == 0:
+ max_batch_size = (max(bs for bs in batch_size_list))
+ else:
+ max_batch_size = 0

+ state_size = local_tensor.shape[1] # assume states are 1-d, is true for this env.

+ # Broadcast max_size to all processes for padding
+ max_batch_size_tensor = torch.tensor(max_batch_size, device=local_tensor.device)
+ dist.broadcast(max_batch_size_tensor, src=0)

- # Pad local tensor to maximum size.
- print("+ padding local tensor")
- max_size = max(size.item() for size in size_list)
- if local_tensor.numel() < max_size:
+ if local_tensor.shape[0] < max_batch_size:
padding = torch.zeros(
- max_size - local_tensor.numel(),
+ (max_batch_size - local_tensor.shape[0], state_size),
dtype=local_tensor.dtype,
- device=local_tensor.device,
+ device=local_tensor.device
)
- local_tensor = torch.cat((local_tensor, padding))
+ local_tensor = torch.cat((local_tensor, padding), dim=0)

- # Gather all tensors.
- print("+ gathering all tensors from world_size={}".format(world_size))
+ # Gather padded tensors.
if rank == 0:
tensor_list = [
- torch.zeros(max_size, dtype=local_tensor.dtype, device=local_tensor.device)
+ torch.zeros(
+ (max_batch_size, state_size),
+ dtype=local_tensor.dtype,
+ device=local_tensor.device,
+ )
for _ in range(world_size)
]
else:
tensor_list = None

+ if verbose:
+ print("+ gathering all tensors from world_size={}".format(world_size))
+ print("rank={}, tensor_list={}".format(rank, tensor_list))
dist.gather(local_tensor, gather_list=tensor_list, dst=0)
+ dist.barrier() # Add synchronization

- # Trim padding and deserialize if necessary.
- result = []
- for tensor, size in zip(tensor_list, size_list):
- trimmed_tensor = tensor[: size.item()]
- if not isinstance(local_data, torch.Tensor):
- print("+ deserializing tensor.")
- # Deserialize data.
- result = pickle.loads(trimmed_tensor.cpu().numpy().tobytes())
- else:
- print("+ tensor not deserialized")
- result = trimmed_tensor
+ # Only rank 0 processes the results
+ if rank == 0:
+ results = []
+ for tensor, batch_size in zip(tensor_list, batch_size_list):
+ trimmed_tensor = tensor[:batch_size.item(), ...]
+ results.append(trimmed_tensor)

- return result
+ if verbose:
+ print("distributed n_results={}".format(len(results)))

+ for r in results:
+ print(" {}".format(r.shape))

+ return torch.cat(results, dim=0) # Concatenates along the batch dimension.

+ return None # For all non-zero ranks.
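
For context, a minimal usage sketch of the rewritten gather_distributed_data (illustrative only, not part of this commit). It assumes a process group launched via torchrun with the gloo backend, that the function above is in scope, and that every rank contributes an integer (batch_size, state_size) tensor of identical shape, mirroring how the training loop below passes visited_terminating_states.tensor:

import torch
import torch.distributed as dist

def demo_gather():
    dist.init_process_group(backend="gloo")  # env vars provided by torchrun
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Every rank contributes a (4, 2) tensor of integer "states" (same shape on all ranks).
    local_states = torch.full((4, 2), fill_value=rank, dtype=torch.long)

    gathered = gather_distributed_data(local_states, world_size=world_size, rank=rank, verbose=True)

    if rank == 0:
        # Rank 0 receives all ranks' rows concatenated along the batch dimension.
        assert gathered.shape == (4 * world_size, 2)
    else:
        assert gathered is None  # non-zero ranks get nothing back

    dist.destroy_process_group()

if __name__ == "__main__":
    demo_gather()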


def main(args): # noqa: C901
@@ -528,7 +545,7 @@ def cleanup():
world_size,
cleanup_callback=cleanup,
)
- handler.start()
+ #handler.start()

for iteration in trange(n_iterations):

@@ -616,6 +633,24 @@ def cleanup():
]
)

+ log_this_iter = ((iteration % args.validation_interval == 0) or iteration == n_iterations - 1)

+ print("before distributed -- orig_shape={}".format(visited_terminating_states.tensor.shape))
+ if args.distributed and log_this_iter:
+ try:
+ all_visited_terminating_states = gather_distributed_data(
+ visited_terminating_states.tensor
+ )
+ except Exception as e:
+ print('Process {}: Caught error: {}'.format(my_rank, e))
+ #handler.signal_error()
+ sys.exit(1)
+ else:
+ all_visited_terminating_states = visited_terminating_states.tensor

+ if my_rank == 0:
+ print("after distributed -- gathered_shape={}, orig_shape={}".format(all_visited_terminating_states.shape, visited_terminating_states.tensor.shape))

# If we are on the master node, calculate the validation metrics.
if my_rank == 0:
to_log = {
@@ -628,24 +663,12 @@
"opt_time": opt_time,
"rest_time": rest_time,
}

if use_wandb:
wandb.log(to_log, step=iteration)
- if (iteration % args.validation_interval == 0) or (
- iteration == n_iterations - 1
- ):

- if args.distributed:
- try:
- all_visited_terminating_states = gather_distributed_data(
- visited_terminating_states.tensor
- )
- except Exception as e:
- print(f'Process {my_rank}: Caught error: {str(e)}')
- handler.signal_error()
- sys.exit(1)
- else:
- all_visited_terminating_states = visited_terminating_states.tensor

+ if log_this_iter:
+ print("logging this iteration!")
validation_info, discovered_modes = validate_hypergrid(
env,
gflownet,
@@ -725,6 +748,7 @@ def validate_hypergrid(
modes_found = set([tuple(s.tolist()) for s in modes])
discovered_modes.update(modes_found)
validation_info["n_modes_found"] = len(discovered_modes)
+ print(len(discovered_modes))

# Old way of counting modes -- potentially buggy - to be removed.
# # Add the mode counting metric.
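
The mode counter above keeps a set of coordinate tuples that persists across validation calls, so revisiting the same mode does not inflate the count. A minimal standalone sketch of that bookkeeping (illustrative only, with made-up 2-D mode coordinates):

import torch

discovered_modes = set()

def update_discovered_modes(modes: torch.Tensor) -> int:
    # Each row of `modes` is a terminating state already identified as lying in a mode;
    # converting rows to tuples makes them hashable so the set deduplicates repeats.
    discovered_modes.update(tuple(s.tolist()) for s in modes)
    return len(discovered_modes)

print(update_discovered_modes(torch.tensor([[7, 7], [0, 7]])))  # 2
print(update_discovered_modes(torch.tensor([[7, 7], [7, 0]])))  # 3 -- only [7, 0] is new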