Skip to content

Commit

Permalink
Merge pull request #15 from epfml/revert-14-master
Browse files Browse the repository at this point in the history
Revert "Improve performances"
  • Loading branch information
younik authored May 12, 2022
2 parents bab2c19 + b522203 commit a386746
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions powersgd/powersgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,10 @@ def aggregate(self, gradients: List[torch.Tensor]) -> List[torch.Tensor]:
for group, in_batch, out_batch in zip(
shape_groups, out_batches, in_batches
):
torch.bmm(
in_batch,
out_batch.permute([0, 2, 1]),
out=maybe_transpose(group["approximation"])
)
group["grad_batch"].sub_(group["approximation"]) # error feedback
iter_approx = torch.einsum("bnr, bmr -> bmn", out_batch, in_batch)
maybe_transpose(group["grad_batch"]).sub_(iter_approx) # error feedback
maybe_transpose(group["approximation"]).add_(iter_approx)
del iter_approx

# Un-batch the approximation and error feedback, write to the output
for group in shape_groups:
Expand Down

0 comments on commit a386746

Please sign in to comment.