Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removing unnecessary gradient tensor loads #20

Merged
merged 3 commits into from
Nov 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 10 additions & 33 deletions src/equitriton/sph_harm/triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,21 +168,13 @@ def _triton_second_order_bwd(
x = tl.load(x_row_start, mask=offset < vector_length)
y = tl.load(y_row_start, mask=offset < vector_length)
z = tl.load(z_row_start, mask=offset < vector_length)
# load the pre-allocated xyz gradients
g_x_start = g_x_ptr + offset
g_y_start = g_y_ptr + offset
g_z_start = g_z_ptr + offset
# NOTE: these are the gradient outputs and are assumed to be initially zeros
g_x = tl.load(g_x_start, mask=offset < vector_length)
g_y = tl.load(g_y_start, mask=offset < vector_length)
g_z = tl.load(g_z_start, mask=offset < vector_length)
# first-order contribution: the derivative is the constant sqrt(3), so g_1_0, g_1_1, g_1_2 are each scaled by sqrt_3
g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length)
g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length)
g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length)
g_x += sqrt_3 * g_1_0
g_y += sqrt_3 * g_1_1
g_z += sqrt_3 * g_1_2
g_x = sqrt_3 * g_1_0
g_y = sqrt_3 * g_1_1
g_z = sqrt_3 * g_1_2
# now work on the second order derivatives, grouped by m
g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length)
g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length)
Expand Down Expand Up @@ -347,21 +339,14 @@ def _triton_third_order_bwd(
x = tl.load(x_row_start, mask=offset < vector_length)
y = tl.load(y_row_start, mask=offset < vector_length)
z = tl.load(z_row_start, mask=offset < vector_length)
# load the pre-allocated xyz gradients
g_x_start = g_x_ptr + offset
g_y_start = g_y_ptr + offset
g_z_start = g_z_ptr + offset
# NOTE: these are the gradient outputs and are assumed to be initially zeros
g_x = tl.load(g_x_start, mask=offset < vector_length)
g_y = tl.load(g_y_start, mask=offset < vector_length)
g_z = tl.load(g_z_start, mask=offset < vector_length)
# first-order contribution: the derivative is the constant sqrt(3), so g_1_0, g_1_1, g_1_2 are each scaled by sqrt_3
g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length)
g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length)
g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length)
g_x += sqrt_3 * g_1_0
g_y += sqrt_3 * g_1_1
g_z += sqrt_3 * g_1_2
# initialize gradients
g_x = sqrt_3 * g_1_0
g_y = sqrt_3 * g_1_1
g_z = sqrt_3 * g_1_2
# now work on the second order derivatives, grouped by m
g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length)
g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length)
Expand Down Expand Up @@ -666,21 +651,13 @@ def _triton_fourth_order_bwd(
x = tl.load(x_row_start, mask=offset < vector_length)
y = tl.load(y_row_start, mask=offset < vector_length)
z = tl.load(z_row_start, mask=offset < vector_length)
# load the pre-allocated xyz gradients
g_x_start = g_x_ptr + offset
g_y_start = g_y_ptr + offset
g_z_start = g_z_ptr + offset
# NOTE: these are the gradient outputs and are assumed to be initially zeros
g_x = tl.load(g_x_start, mask=offset < vector_length)
g_y = tl.load(g_y_start, mask=offset < vector_length)
g_z = tl.load(g_z_start, mask=offset < vector_length)
# first-order contribution: the derivative is the constant sqrt(3), so g_1_0, g_1_1, g_1_2 are each scaled by sqrt_3
g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length)
g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length)
g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length)
g_x += sqrt_3 * g_1_0
g_y += sqrt_3 * g_1_1
g_z += sqrt_3 * g_1_2
g_x = sqrt_3 * g_1_0
g_y = sqrt_3 * g_1_1
g_z = sqrt_3 * g_1_2
# now work on the second order derivatives, grouped by m
g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length)
g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length)
Expand Down