diff --git a/nequip/train/loss.py b/nequip/train/loss.py
index 6a34cdc6..f3f1ea15 100644
--- a/nequip/train/loss.py
+++ b/nequip/train/loss.py
@@ -191,15 +191,28 @@ def gather(self):
         """Use `torch.distributed` to gather and accumulate state of this LossStat across nodes to rank 0."""
         state = (
             dist.get_rank(),
-            {k: rs.get_state() for k, rs in self.loss_stat.items()},
+            {k: (rs._state, rs._n) for k, rs in self.loss_stat.items()},  # Access _state and _n
         )
         states = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(states, state)  # list of dict
+        dist.all_gather_object(states, state)  # Gather state from all nodes
+
         if dist.get_rank() == 0:
             # accumulate on rank 0
             for from_rank, state in states:
                 if from_rank == 0:
-                    # we already have this don't accumulate it
+                    # Skip the current rank's state as it's already included
                     continue
-                for k, rs_state in state.items():
-                    self.loss_stat[k].accumulate_state(rs_state)
+                for k, (rs_state, rs_n) in state.items():
+                    # Ensure tensors are on the same device
+                    rs_state = rs_state.to(self.loss_stat[k]._state.device)
+                    rs_n = rs_n.to(self.loss_stat[k]._n.device)
+
+                    # Accumulate the state by updating _state and _n
+                    self.loss_stat[k]._state = (
+                        self.loss_stat[k]._state * self.loss_stat[k]._n + rs_state * rs_n
+                    ) / (self.loss_stat[k]._n + rs_n)
+                    self.loss_stat[k]._n += rs_n
+                    # Ensure no division by zero issues
+                    self.loss_stat[k]._state = torch.nan_to_num_(
+                        self.loss_stat[k]._state, nan=0.0
+                    )
\ No newline at end of file
diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py
index 5e7ae98c..7d5f36ce 100644
--- a/nequip/train/metrics.py
+++ b/nequip/train/metrics.py
@@ -265,21 +265,36 @@ def flatten_metrics(self, metrics, type_names=None):
 
     def gather(self):
         """Use `torch.distributed` to gather and accumulate state of this Metrics across nodes to rank 0."""
+        # this version makes sure that the tensors are on the same device
         state = (
             dist.get_rank(),
             {
-                k1: {k2: rs.get_state() for k2, rs in v1.items()}
+                k1: {k2: (rs._state, rs._n) for k2, rs in v1.items()}  # Access _state and _n
                 for k1, v1 in self.running_stats.items()
             },
         )
         states = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(states, state)  # list of dict
+        dist.all_gather_object(states, state)  # Gather state from all nodes
+
         if dist.get_rank() == 0:
             # accumulate on rank 0
             for from_rank, state in states:
                 if from_rank == 0:
-                    # we already have this don't accumulate it
+                    # Skip the current rank's state as it's already included
                     continue
                 for k1, v1 in state.items():
-                    for k2, rs_state in v1.items():
-                        self.running_stats[k1][k2].accumulate_state(rs_state)
+                    for k2, (rs_state, rs_n) in v1.items():
+                        # Ensure tensors are on the same device
+                        rs_state = rs_state.to(self.running_stats[k1][k2]._state.device)
+                        rs_n = rs_n.to(self.running_stats[k1][k2]._n.device)
+
+                        # Accumulate the state by updating _state and _n
+                        self.running_stats[k1][k2]._state = (
+                            self.running_stats[k1][k2]._state * self.running_stats[k1][k2]._n + rs_state * rs_n
+                        ) / (self.running_stats[k1][k2]._n + rs_n)
+                        self.running_stats[k1][k2]._n += rs_n
+                        # Ensure no division by zero issues
+                        self.running_stats[k1][k2]._state = torch.nan_to_num_(
+                            self.running_stats[k1][k2]._state, nan=0.0
+                        )
+
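
Both hunks replace the `RunningStats.accumulate_state(...)` call with a manual merge of the private `_state` / `_n` buffers: rank 0 recomputes the pooled statistic as a count-weighted average of the per-rank values and then clamps any 0/0 NaN to zero. A minimal sketch of that merge logic, outside of `torch.distributed`, is below; the helper name `merge_running_mean` is hypothetical (not part of nequip), and it assumes `_state` holds a per-rank running mean with `_n` the matching sample count, as the patch does.

import torch


def merge_running_mean(state_a, n_a, state_b, n_b):
    """Combine two (mean, count) pairs the way rank 0 does in gather()."""
    # Move the incoming tensors onto the device of the local accumulator,
    # mirroring the .to(...) calls in the patch.
    state_b = state_b.to(state_a.device)
    n_b = n_b.to(n_a.device)

    # Count-weighted average of the two running means.
    merged_state = (state_a * n_a + state_b * n_b) / (n_a + n_b)
    merged_n = n_a + n_b

    # If both counts are zero the division above produces NaN; clamp it to 0,
    # matching the torch.nan_to_num_ call in the patch.
    merged_state = torch.nan_to_num_(merged_state, nan=0.0)
    return merged_state, merged_n


if __name__ == "__main__":
    # Rank 0 saw 3 samples with mean 2.0; rank 1 saw 1 sample with mean 6.0.
    s0, n0 = torch.tensor([2.0]), torch.tensor([3.0])
    s1, n1 = torch.tensor([6.0]), torch.tensor([1.0])
    print(merge_running_mean(s0, n0, s1, n1))  # -> (tensor([3.]), tensor([4.]))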