Fixed gather method to work with distributed training in metrics.py and loss.py #450

Closed · wants to merge 2 commits
23 changes: 18 additions & 5 deletions nequip/train/loss.py
@@ -191,15 +191,28 @@ def gather(self):
         """Use `torch.distributed` to gather and accumulate state of this LossStat across nodes to rank 0."""
         state = (
             dist.get_rank(),
-            {k: rs.get_state() for k, rs in self.loss_stat.items()},
+            {k: (rs._state, rs._n) for k, rs in self.loss_stat.items()},  # Access _state and _n
         )
         states = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(states, state)  # list of dict
+        dist.all_gather_object(states, state)  # Gather state from all nodes
+
         if dist.get_rank() == 0:
             # accumulate on rank 0
             for from_rank, state in states:
                 if from_rank == 0:
-                    # we already have this don't accumulate it
+                    # Skip the current rank's state as it's already included
                     continue
-                for k, rs_state in state.items():
-                    self.loss_stat[k].accumulate_state(rs_state)
+                for k, (rs_state, rs_n) in state.items():
+                    # Ensure tensors are on the same device
+                    rs_state = rs_state.to(self.loss_stat[k]._state.device)
+                    rs_n = rs_n.to(self.loss_stat[k]._n.device)
+
+                    # Accumulate the state by updating _state and _n
+                    self.loss_stat[k]._state = (
+                        self.loss_stat[k]._state * self.loss_stat[k]._n + rs_state * rs_n
+                    ) / (self.loss_stat[k]._n + rs_n)
+                    self.loss_stat[k]._n += rs_n
+                    # Ensure no division by zero issues
+                    self.loss_stat[k]._state = torch.nan_to_num_(
+                        self.loss_stat[k]._state, nan=0.0
+                    )
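
The accumulation in this diff merges per-rank running means by re-weighting them with their sample counts and then zeroing out any NaN produced by empty counts. A minimal sketch of that merge rule, assuming `_state` holds a running mean and `_n` a sample count; the helper name `merge_running_means` and the example numbers are illustrative, not part of the PR:

```python
import torch


def merge_running_means(state_a, n_a, state_b, n_b):
    # Weighted average of two running means; if both counts are zero the
    # division yields NaN, which is replaced by 0.0, mirroring the
    # torch.nan_to_num_ call in the diff above.
    merged = (state_a * n_a + state_b * n_b) / (n_a + n_b)
    merged = torch.nan_to_num(merged, nan=0.0)
    return merged, n_a + n_b


# Rank 0 averaged 2.0 over 3 samples, rank 1 averaged 6.0 over 1 sample:
mean, n = merge_running_means(
    torch.tensor(2.0), torch.tensor(3.0), torch.tensor(6.0), torch.tensor(1.0)
)
# mean == 3.0, n == 4.0 — the same result as pooling all 4 samples directly.
```

This weighted-average rule is exact when `_state` stores a plain mean over `_n` samples; other reductions would need their own combination rule.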
25 changes: 20 additions & 5 deletions nequip/train/metrics.py
@@ -265,21 +265,36 @@ def flatten_metrics(self, metrics, type_names=None):
 
     def gather(self):
         """Use `torch.distributed` to gather and accumulate state of this Metrics across nodes to rank 0."""
+        # this version makes sure that the tensors are on the same device
         state = (
             dist.get_rank(),
             {
-                k1: {k2: rs.get_state() for k2, rs in v1.items()}
+                k1: {k2: (rs._state, rs._n) for k2, rs in v1.items()}  # Access _state and _n
                 for k1, v1 in self.running_stats.items()
             },
         )
         states = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(states, state)  # list of dict
+        dist.all_gather_object(states, state)  # Gather state from all nodes
+
         if dist.get_rank() == 0:
             # accumulate on rank 0
             for from_rank, state in states:
                 if from_rank == 0:
-                    # we already have this don't accumulate it
+                    # Skip the current rank's state as it's already included
                     continue
                 for k1, v1 in state.items():
-                    for k2, rs_state in v1.items():
-                        self.running_stats[k1][k2].accumulate_state(rs_state)
+                    for k2, (rs_state, rs_n) in v1.items():
+                        # Ensure tensors are on the same device
+                        rs_state = rs_state.to(self.running_stats[k1][k2]._state.device)
+                        rs_n = rs_n.to(self.running_stats[k1][k2]._n.device)
+
+                        # Accumulate the state by updating _state and _n
+                        self.running_stats[k1][k2]._state = (
+                            self.running_stats[k1][k2]._state * self.running_stats[k1][k2]._n + rs_state * rs_n
+                        ) / (self.running_stats[k1][k2]._n + rs_n)
+                        self.running_stats[k1][k2]._n += rs_n
+                        # Ensure no division by zero issues
+                        self.running_stats[k1][k2]._state = torch.nan_to_num_(
+                            self.running_stats[k1][k2]._state, nan=0.0
+                        )
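
Both patched `gather()` methods rely on `torch.distributed.all_gather_object` to hand each rank's `(rank, state)` tuple to every process before rank 0 folds them together. A self-contained sketch of that round trip on two CPU processes with the gloo backend; the `_worker` helper, the port, and the example values are illustrative only and not part of nequip:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    # Minimal single-machine setup; address and port are arbitrary local values.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank contributes a (rank, {key: (mean, count)}) payload, mimicking
    # what the patched gather() sends per metric key.
    state = (rank, {"loss": (torch.tensor(float(rank + 1)), torch.tensor(2.0))})
    states = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(states, state)

    if dist.get_rank() == 0:
        # Fold the other ranks into rank 0's (mean, count), as the diffs do.
        mean, n = states[0][1]["loss"]
        for _, payload in states[1:]:
            other_mean, other_n = payload["loss"]
            mean = (mean * n + other_mean * other_n) / (n + other_n)
            n = n + other_n
        print("merged mean:", mean.item(), "count:", n.item())  # 1.5 and 4.0

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```

Because `all_gather_object` pickles arbitrary Python objects, the received tensors are not guaranteed to land on the same device as the local running stats, which appears to be why the diffs move `rs_state` and `rs_n` with `.to(...)` before combining them.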