diff --git a/powersgd/powersgd.py b/powersgd/powersgd.py index 67a4f87..f0ea55a 100644 --- a/powersgd/powersgd.py +++ b/powersgd/powersgd.py @@ -197,12 +197,10 @@ def aggregate(self, gradients: List[torch.Tensor]) -> List[torch.Tensor]: for group, in_batch, out_batch in zip( shape_groups, out_batches, in_batches ): - torch.bmm( - in_batch, - out_batch.permute([0, 2, 1]), - out=maybe_transpose(group["approximation"]) - ) - group["grad_batch"].sub_(group["approximation"]) # error feedback + iter_approx = torch.einsum("bnr, bmr -> bmn", out_batch, in_batch) + maybe_transpose(group["grad_batch"]).sub_(iter_approx) # error feedback + maybe_transpose(group["approximation"]).add_(iter_approx) + del iter_approx # Un-batch the approximation and error feedback, write to the output for group in shape_groups: