Added support for subsets of nodes (agents) exploring independently for a while and then syncing periodically
sanchit-misra committed Jan 24, 2025
1 parent d37b5f1 commit 41aa1f2
Showing 1 changed file with 76 additions and 11 deletions.
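In short: the global ranks are partitioned into equally sized agent groups, each agent wraps its modules in DDP over its own process group (so gradients are only averaged within a group), and every --average_every iterations the model weights are averaged across all ranks via the new average_models helper, so the agents explore independently between syncs.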
87 changes: 76 additions & 11 deletions tutorials/examples/train_hypergrid.py
@@ -54,6 +54,13 @@ def average_gradients(model):
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size

def average_models(model):
    """Averages model weights across all ranks."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        param_tensor = param.data.clone()  # clone to avoid in-place operations
        dist.all_reduce(param_tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD)
        param.data = param_tensor / world_size
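This helper all-reduces over the global dist.group.WORLD, so it mixes weights across every agent group, whereas the DDP wrappers further down only synchronize gradients within an agent's group. A minimal single-process sketch of exercising it, assuming a toy gloo process group and a throwaway nn.Linear purely for illustration (the real script initializes the ccl/mpi backend via initialize_distributed_compute):

import os
import torch.distributed as dist
import torch.nn as nn

# Toy single-process "gloo" group: with world_size == 1 the average is a no-op,
# but the all_reduce path of average_models (defined above) is exercised.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(4, 2)
average_models(model)  # with more ranks, this mixes weights across all agent groups
dist.destroy_process_group()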

def initialize_distributed_compute(dist_backend: str = "ccl"):
    """Initializes distributed compute using either ccl or mpi backends."""
@@ -99,9 +106,17 @@ def initialize_distributed_compute(dist_backend: str = "ccl"):
    my_rank = dist.get_rank()  # Global!
    my_size = dist.get_world_size()  # Global!

    # For now, enforce that each agent gets an equal number of ranks.
    # TODO: later, we can relax this condition.
    assert my_size % args.num_agents == 0
    agent_group_size = my_size // args.num_agents
    agent_group_rank_list = [list(range(i * agent_group_size, (i + 1) * agent_group_size)) for i in range(args.num_agents)]
    print(agent_group_rank_list)
    agent_group_list = [dist.new_group(agent_group_rank_list[i], backend=dist_backend, timeout=datetime.timedelta(minutes=5)) for i in range(args.num_agents)]

print(f"+ My rank: {my_rank} size: {my_size}")

return (my_rank, my_size)
return (my_rank, my_size, agent_group_size, agent_group_list)
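To make the partitioning concrete, with hypothetical values of 8 total ranks and --num_agents 2, the ranks split into two consecutive blocks of 4, and a rank's agent id (computed later in main() as my_rank // agent_group_size) falls out directly:

my_size, num_agents = 8, 2                 # hypothetical world size and --num_agents
assert my_size % num_agents == 0           # same divisibility requirement as above
agent_group_size = my_size // num_agents   # -> 4
agent_group_rank_list = [
    list(range(i * agent_group_size, (i + 1) * agent_group_size))
    for i in range(num_agents)
]
print(agent_group_rank_list)               # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(5 // agent_group_size)               # rank 5 belongs to agent group 1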


class DistributedErrorHandler:
@@ -311,10 +326,12 @@ def main(args): # noqa: C901
wandb.config.update(args)

if args.distributed:
my_rank, my_size = initialize_distributed_compute()
my_rank, my_size, agent_group_size, agent_group_list = initialize_distributed_compute()
my_rank = dist.get_rank()
world_size = torch.distributed.get_world_size()
my_agent_group_id = my_rank // agent_group_size
print(f"Running with DDP on rank {my_rank}/{world_size}.")
print(f"agent_group_size, my_agent_group_id = {agent_group_size, my_agent_group_id}")
else:
world_size = 1 # Single machine.
my_rank = 0 # Single machine.
@@ -355,7 +372,7 @@ def main(args): # noqa: C901
)

if args.distributed:
module = DDP(module)
module = DDP(module, process_group=agent_group_list[my_agent_group_id])

estimator = DiscretePolicyEstimator(
module=module,
@@ -391,8 +408,8 @@ def main(args): # noqa: C901
assert locals()[v] is not None, f"{v} is None, Args: {args}"

if args.distributed:
pf_module = DDP(pf_module)
pb_module = DDP(pb_module)
pf_module = DDP(pf_module, process_group=agent_group_list[my_agent_group_id])
pb_module = DDP(pb_module, process_group=agent_group_list[my_agent_group_id])

pf_estimator = DiscretePolicyEstimator(
module=pf_module,
@@ -428,7 +445,7 @@ def main(args): # noqa: C901
)

if args.distributed:
module = DDP(module)
module = DDP(module, process_group=agent_group_list[my_agent_group_id])

logF_estimator = ScalarEstimator(
module=module, preprocessor=env.preprocessor
@@ -518,6 +535,8 @@ def main(args): # noqa: C901
total_sample_time, total_to_train_samples_time = 0, 0
total_loss_time, total_loss_backward_time = 0, 0
total_opt_time, total_rest_time = 0, 0
total_average_time = 0
total_bar0_time, total_bar1_time = 0, 0
time_start = time.time()

if args.profile:
@@ -547,6 +566,8 @@ def cleanup():
)
#handler.start()

if args.distributed:
world_size = torch.distributed.get_world_size()
for iteration in trange(n_iterations):

iteration_start = time.time()
@@ -573,9 +594,6 @@ def cleanup():
# Time to_training_samples method.
to_train_samples_start = time.time()
training_samples = gflownet.to_training_samples(trajectories)
to_train_samples_end = time.time()
to_train_samples_time = to_train_samples_end - to_train_samples_start
total_to_train_samples_time += to_train_samples_time

if replay_buffer is not None:
with torch.no_grad():
@@ -585,6 +603,9 @@ def cleanup():
)
else:
training_objects = training_samples
to_train_samples_end = time.time()
to_train_samples_time = to_train_samples_end - to_train_samples_start
total_to_train_samples_time += to_train_samples_time

optimizer.zero_grad()

@@ -595,13 +616,19 @@ def cleanup():
training_objects,
reduction="sum" if args.distributed else "mean",
)

# Normalize the loss by the local batch size if distributed
if args.distributed:
loss = loss / (per_node_batch_size)
loss_end = time.time()
loss_time = loss_end - loss_start
total_loss_time += loss_time

# Normalize the loss by the local batch size if distributed
bar0_start = time.time()
if args.distributed:
loss = loss / (per_node_batch_size)
dist.barrier()
bar0_end = time.time()
total_bar0_time += (bar0_end - bar0_start)

# Time backpropagation computation.
loss_backward_start = time.time()
@@ -617,6 +644,21 @@ def cleanup():
opt_time = opt_end - opt_start
total_opt_time += opt_time

bar1_start = time.time()
if args.distributed:
dist.barrier()
bar1_end = time.time()
total_bar1_time += (bar1_end - bar1_start)

average_start = time.time()
if args.distributed and (iteration % args.average_every == 0):
print("before averaging model, iteration =", iteration)
average_models(gflownet)
print("after averaging model, iteration =", iteration)
average_end = time.time()
average_time = average_end - average_start
total_average_time += average_time
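Note on the cadence: since the check is iteration % args.average_every == 0, iteration 0 also triggers an average. A tiny illustration with a hypothetical average_every of 20 over 100 iterations:

average_every = 20   # hypothetical --average_every
n_iterations = 100   # hypothetical
print([it for it in range(n_iterations) if it % average_every == 0])
# [0, 20, 40, 60, 80]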

# Keep track of trajectories / states.
visited_terminating_states.extend(trajectories.last_states)
states_visited += len(trajectories)
@@ -630,6 +672,7 @@ def cleanup():
loss_time,
loss_backward_time,
opt_time,
average_time,
]
)

@@ -661,6 +704,7 @@ def cleanup():
"loss_time": loss_time,
"loss_backward_time": loss_backward_time,
"opt_time": opt_time,
"average_time": average_time,
"rest_time": rest_time,
}

@@ -692,6 +736,7 @@ def cleanup():
total_loss_time,
total_loss_backward_time,
total_opt_time,
total_average_time,
]
)

@@ -703,9 +748,13 @@ def cleanup():
"total_sample_time": total_sample_time,
"total_to_train_samples_time": total_to_train_samples_time,
"total_loss_time": total_loss_time,
"total_actor_barrier_time": total_bar0_time,
"total_loss_backward_time": total_loss_backward_time,
"total_opt_time": total_opt_time,
"total_average_time": total_average_time,
"total_learner_barrier_time": total_bar1_time,
"total_rest_time": total_rest_time,
"total_time": total_time,
}

print("+ Final timing.")
@@ -788,6 +837,22 @@ def validate_hypergrid(
help="Initalizes distributed computation (torch.distributed)",
)

# Multi-agent training settings.
parser.add_argument(
"--num_agents",
type=int,
default=1,
help="Number of agents (groups of ranks) learning together",
)

# Multi-agent training settings.
parser.add_argument(
"--average_every",
type=int,
default=20,
help="Number of iterations after which the model is averaged across all agents",
)
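Putting the flags together, a run with two agents of four ranks each could be launched along the lines of mpiexec -n 8 python tutorials/examples/train_hypergrid.py --distributed --num_agents 2 --average_every 20 (the launcher and rank count are illustrative assumptions; --distributed, --num_agents and --average_every are the flags defined in this file).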

# Environment settings.
parser.add_argument(
"--ndim",
