From 6985dfd91b5906fd90c6c15a3f8b60d2872712df Mon Sep 17 00:00:00 2001 From: jiayifeng Date: Mon, 30 Jan 2023 12:09:57 +0800 Subject: [PATCH] fix bug --- varuna/varuna.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/varuna/varuna.py b/varuna/varuna.py index a67d32a..ea733ae 100755 --- a/varuna/varuna.py +++ b/varuna/varuna.py @@ -654,6 +654,8 @@ def all_reduce_pipeline_meta(self, master_grads, overflow_buf=None): def extra_grad_norm_sq(self): extra_norm_sq = 0.0 + if self.shared_weights is None: + return extra_norm_sq for i,w in enumerate(self.shared_weights): recv_stage, send_stage = self.shared_weight_stages[i] recv_wt_name, send_wt_name = w