Removing unnecessary gradient tensor loads #20
Merged
Proposed changes
Based on discussions with @Pennycook, this PR provides a small but noticeable performance bump to the first version of the EquiTriton kernels. Comparison with the `e3nn` + `torch.compile` autotuned kernels showed that the backward pass contained unnecessary memory loads that introduced element-wise latency. Originally, the variables used to accumulate gradients were initialized by performing a `tl.load`. The optimization introduced here removes that load dependency, allowing `g_x/y/z` to be calculated as soon as the first-order gradients are streamed in. The main impact I've seen from running these kernels is the removal of the significant performance drop-off at higher node counts observed in #9 (from 10^5 nodes and above), making relative performance more or less constant across node counts. Tested on a GPU Max 1100 with PyTorch `2.6.0a0+git487873f` and `triton==3.1.0`.
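For readers who want the shape of the change without opening the diff, here is a minimal, hypothetical Triton sketch of the before/after pattern. The kernel name, pointers, and coefficient are illustrative only and not taken from the actual EquiTriton kernels:

```python
import triton
import triton.language as tl


@triton.jit
def sph_backward_sketch(grad_out_ptr, g_x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Hypothetical, heavily simplified backward kernel illustrating the change.
    block_start = tl.program_id(0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # First-order gradients streamed in from the upstream pass.
    grad_out = tl.load(grad_out_ptr + offsets, mask=mask)

    # Before: the accumulator was initialized from global memory, putting an
    # extra load on the critical path of every element:
    #   g_x = tl.load(g_x_ptr + offsets, mask=mask)
    #   g_x += COEFF * grad_out
    #
    # After: start from the streamed-in gradient directly, so g_x can be
    # computed as soon as grad_out arrives, with no load dependency.
    COEFF = 0.81649658  # illustrative constant, not an actual EquiTriton coefficient
    g_x = COEFF * grad_out

    tl.store(g_x_ptr + offsets, g_x, mask=mask)
```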
cc @mitkotak in case you're interested in tracking these changes.
Types of changes
What types of changes does your code introduce to the project?
Put an `x` in the boxes that apply.

Checklist
Put an `x` in the boxes that apply. You can also fill these out after creating the PR. If you are unsure about any of them, do not hesitate to ask. We are here to help! This is simply a reminder of what we are going to look for before merging your code.