From e1fb1727ffb6dc58505322b2145908b37e5552f8 Mon Sep 17 00:00:00 2001
From: Myles Stapelberg
Date: Mon, 15 Jul 2024 20:55:55 -0400
Subject: [PATCH 1/2] fixed gather in metrics and loss

---
 nequip/train/loss.py    | 19 ++++++++++++++-----
 nequip/train/metrics.py | 21 ++++++++++++++++-----
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/nequip/train/loss.py b/nequip/train/loss.py
index 6a34cdc6..08c114a7 100644
--- a/nequip/train/loss.py
+++ b/nequip/train/loss.py
@@ -191,15 +191,24 @@ def gather(self):
         """Use `torch.distributed` to gather and accumulate state of this LossStat across nodes to rank 0."""
         state = (
             dist.get_rank(),
-            {k: rs.get_state() for k, rs in self.loss_stat.items()},
+            {k: (rs._state, rs._n) for k, rs in self.loss_stat.items()},  # Access _state and _n
         )
         states = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(states, state)  # list of dict
+        dist.all_gather_object(states, state)  # Gather state from all nodes
+
         if dist.get_rank() == 0:
             # accumulate on rank 0
             for from_rank, state in states:
                 if from_rank == 0:
-                    # we already have this don't accumulate it
+                    # Skip the current rank's state as it's already included
                     continue
-                for k, rs_state in state.items():
-                    self.loss_stat[k].accumulate_state(rs_state)
+                for k, (rs_state, rs_n) in state.items():
+                    # Ensure tensors are on the same device
+                    rs_state = rs_state.to(self.loss_stat[k]._state.device)
+                    rs_n = rs_n.to(self.loss_stat[k]._n.device)
+
+                    # Accumulate the state by updating _state and _n
+                    self.loss_stat[k]._state = (self.loss_stat[k]._state * self.loss_stat[k]._n + rs_state * rs_n) / (self.loss_stat[k]._n + rs_n)
+                    self.loss_stat[k]._n += rs_n
+                    # Ensure no division by zero issues
+                    self.loss_stat[k]._state = torch.nan_to_num_(self.loss_stat[k]._state, nan=0.0)
\ No newline at end of file
diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py
index 5e7ae98c..a9f59822 100644
--- a/nequip/train/metrics.py
+++ b/nequip/train/metrics.py
@@ -265,21 +265,32 @@ def flatten_metrics(self, metrics, type_names=None):
 
     def gather(self):
         """Use `torch.distributed` to gather and accumulate state of this Metrics across nodes to rank 0."""
+        # this version makes sure that the tensors are on the same device
         state = (
            dist.get_rank(),
             {
-                k1: {k2: rs.get_state() for k2, rs in v1.items()}
+                k1: {k2: (rs._state, rs._n) for k2, rs in v1.items()}  # Access _state and _n
                 for k1, v1 in self.running_stats.items()
             },
         )
         states = [None for _ in range(dist.get_world_size())]
-        dist.all_gather_object(states, state)  # list of dict
+        dist.all_gather_object(states, state)  # Gather state from all nodes
+
         if dist.get_rank() == 0:
             # accumulate on rank 0
             for from_rank, state in states:
                 if from_rank == 0:
-                    # we already have this don't accumulate it
+                    # Skip the current rank's state as it's already included
                     continue
                 for k1, v1 in state.items():
-                    for k2, rs_state in v1.items():
-                        self.running_stats[k1][k2].accumulate_state(rs_state)
+                    for k2, (rs_state, rs_n) in v1.items():
+                        # Ensure tensors are on the same device
+                        rs_state = rs_state.to(self.running_stats[k1][k2]._state.device)
+                        rs_n = rs_n.to(self.running_stats[k1][k2]._n.device)
+
+                        # Accumulate the state by updating _state and _n
+                        self.running_stats[k1][k2]._state = (self.running_stats[k1][k2]._state * self.running_stats[k1][k2]._n + rs_state * rs_n) / (self.running_stats[k1][k2]._n + rs_n)
+                        self.running_stats[k1][k2]._n += rs_n
+                        # Ensure no division by zero issues
+                        self.running_stats[k1][k2]._state = torch.nan_to_num_(self.running_stats[k1][k2]._state, nan=0.0)
+

From 891a0ca551c7f76a1b8b961ca2b837b7b8348991 Mon Sep 17 00:00:00 2001
From: Myles Stapelberg
Date: Mon, 15 Jul 2024 22:53:16 -0400
Subject: [PATCH 2/2] reformatted to fit black

---
 nequip/train/loss.py    | 8 ++++++--
 nequip/train/metrics.py | 8 ++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/nequip/train/loss.py b/nequip/train/loss.py
index 08c114a7..f3f1ea15 100644
--- a/nequip/train/loss.py
+++ b/nequip/train/loss.py
@@ -208,7 +208,11 @@ def gather(self):
                     rs_n = rs_n.to(self.loss_stat[k]._n.device)
 
                     # Accumulate the state by updating _state and _n
-                    self.loss_stat[k]._state = (self.loss_stat[k]._state * self.loss_stat[k]._n + rs_state * rs_n) / (self.loss_stat[k]._n + rs_n)
+                    self.loss_stat[k]._state = (
+                        self.loss_stat[k]._state * self.loss_stat[k]._n + rs_state * rs_n
+                    ) / (self.loss_stat[k]._n + rs_n)
                     self.loss_stat[k]._n += rs_n
                     # Ensure no division by zero issues
-                    self.loss_stat[k]._state = torch.nan_to_num_(self.loss_stat[k]._state, nan=0.0)
\ No newline at end of file
+                    self.loss_stat[k]._state = torch.nan_to_num_(
+                        self.loss_stat[k]._state, nan=0.0
+                    )
\ No newline at end of file
diff --git a/nequip/train/metrics.py b/nequip/train/metrics.py
index a9f59822..7d5f36ce 100644
--- a/nequip/train/metrics.py
+++ b/nequip/train/metrics.py
@@ -289,8 +289,12 @@ def gather(self):
                         rs_n = rs_n.to(self.running_stats[k1][k2]._n.device)
 
                         # Accumulate the state by updating _state and _n
-                        self.running_stats[k1][k2]._state = (self.running_stats[k1][k2]._state * self.running_stats[k1][k2]._n + rs_state * rs_n) / (self.running_stats[k1][k2]._n + rs_n)
+                        self.running_stats[k1][k2]._state = (
+                            self.running_stats[k1][k2]._state * self.running_stats[k1][k2]._n + rs_state * rs_n
+                        ) / (self.running_stats[k1][k2]._n + rs_n)
                         self.running_stats[k1][k2]._n += rs_n
                         # Ensure no division by zero issues
-                        self.running_stats[k1][k2]._state = torch.nan_to_num_(self.running_stats[k1][k2]._state, nan=0.0)
+                        self.running_stats[k1][k2]._state = torch.nan_to_num_(
+                            self.running_stats[k1][k2]._state, nan=0.0
+                        )
 
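
Note (not part of the patches above): a minimal, standalone sketch of the accumulation rule both hunks implement, assuming, as the diffs imply, that each RunningStats object keeps a per-key running average in `_state` and a sample count in `_n`. The helper name `merge_running_mean` and the example values are illustrative only.

    import torch


    def merge_running_mean(state_a, n_a, state_b, n_b):
        """Merge two (running mean, sample count) pairs into one."""
        # Move the gathered tensors onto the local tensors' devices before combining
        state_b = state_b.to(state_a.device)
        n_b = n_b.to(n_a.device)
        # Count-weighted average of the two running means
        merged_state = (state_a * n_a + state_b * n_b) / (n_a + n_b)
        merged_n = n_a + n_b
        # Guard against 0 / 0 when both counts are zero
        merged_state = torch.nan_to_num_(merged_state, nan=0.0)
        return merged_state, merged_n


    # Example: rank 0 holds mean 2.0 over 4 samples; a gathered rank holds mean 5.0 over 2 samples
    state, n = merge_running_mean(
        torch.tensor(2.0), torch.tensor(4.0), torch.tensor(5.0), torch.tensor(2.0)
    )
    # state == tensor(3.), n == tensor(6.)

The count-weighted average only reproduces the old `accumulate_state` behavior for mean-style reductions; `torch.nan_to_num_` covers the case where both counts are zero.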