From b52220313b8bceb6832e15534f202a829375cb08 Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Thu, 12 May 2022 16:27:23 +0200 Subject: [PATCH] Revert "improve performances #13 (#14)" This reverts commit bab2c19ce2ec987ffb1785d32d636f58f8863a6f. --- powersgd/powersgd.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/powersgd/powersgd.py b/powersgd/powersgd.py index 67a4f87..f0ea55a 100644 --- a/powersgd/powersgd.py +++ b/powersgd/powersgd.py @@ -197,12 +197,10 @@ def aggregate(self, gradients: List[torch.Tensor]) -> List[torch.Tensor]: for group, in_batch, out_batch in zip( shape_groups, out_batches, in_batches ): - torch.bmm( - in_batch, - out_batch.permute([0, 2, 1]), - out=maybe_transpose(group["approximation"]) - ) - group["grad_batch"].sub_(group["approximation"]) # error feedback + iter_approx = torch.einsum("bnr, bmr -> bmn", out_batch, in_batch) + maybe_transpose(group["grad_batch"]).sub_(iter_approx) # error feedback + maybe_transpose(group["approximation"]).add_(iter_approx) + del iter_approx # Un-batch the approximation and error feedback, write to the output for group in shape_groups: