From bab2c19ce2ec987ffb1785d32d636f58f8863a6f Mon Sep 17 00:00:00 2001 From: Omar Younis <42100908+younik@users.noreply.github.com> Date: Thu, 12 May 2022 15:11:18 +0200 Subject: [PATCH] improve performance #13 (#14) --- powersgd/powersgd.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/powersgd/powersgd.py b/powersgd/powersgd.py index f0ea55a..67a4f87 100644 --- a/powersgd/powersgd.py +++ b/powersgd/powersgd.py @@ -197,10 +197,12 @@ def aggregate(self, gradients: List[torch.Tensor]) -> List[torch.Tensor]: for group, in_batch, out_batch in zip( shape_groups, out_batches, in_batches ): - iter_approx = torch.einsum("bnr, bmr -> bmn", out_batch, in_batch) - maybe_transpose(group["grad_batch"]).sub_(iter_approx) # error feedback - maybe_transpose(group["approximation"]).add_(iter_approx) - del iter_approx + torch.bmm( + in_batch, + out_batch.permute([0, 2, 1]), + out=maybe_transpose(group["approximation"]) + ) + group["grad_batch"].sub_(group["approximation"]) # error feedback # Un-batch the approximation and error feedback, write to the output for group in shape_groups: