diff --git a/powersgd/powersgd.py b/powersgd/powersgd.py index f0ea55a..67a4f87 100644 --- a/powersgd/powersgd.py +++ b/powersgd/powersgd.py @@ -197,10 +197,12 @@ def aggregate(self, gradients: List[torch.Tensor]) -> List[torch.Tensor]: for group, in_batch, out_batch in zip( shape_groups, out_batches, in_batches ): - iter_approx = torch.einsum("bnr, bmr -> bmn", out_batch, in_batch) - maybe_transpose(group["grad_batch"]).sub_(iter_approx) # error feedback - maybe_transpose(group["approximation"]).add_(iter_approx) - del iter_approx + torch.bmm( + in_batch, + out_batch.permute([0, 2, 1]), + out=maybe_transpose(group["approximation"]) + ) + group["grad_batch"].sub_(group["approximation"]) # error feedback # Un-batch the approximation and error feedback, write to the output for group in shape_groups: