Merge pull request #20 from laserkelvin/no-initial-loads
Removing unnecessary gradient tensor loads
smiret-intel authored Nov 11, 2024
2 parents cff27df + 5bee674 commit 6a78110
Showing 1 changed file with 10 additions and 33 deletions.
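The change affects the second-, third-, and fourth-order backward kernels in the same way: each kernel previously loaded the pre-allocated (and assumed zero-initialized) gradient outputs g_x, g_y, g_z from global memory and accumulated the first-order contribution into them. Because each kernel fully recomputes these gradients before storing them, the three loads per kernel are redundant, and the registers can be initialized directly from the first-order term. The sketch below illustrates the pattern for a single component; the kernel name, pointer names, and launch parameters are hypothetical and this is not the repository's kernel.

import triton
import triton.language as tl

@triton.jit
def _bwd_first_order_sketch(
    g_x_ptr,        # gradient output w.r.t. x (hypothetical pointer name)
    g_1_0_ptr,      # incoming gradient of the first-order harmonic
    vector_length,
    BLOCK_SIZE: tl.constexpr,
):
    sqrt_3 = 3**0.5
    block_id = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + block_id * BLOCK_SIZE
    mask = offset < vector_length
    g_1_0 = tl.load(g_1_0_ptr + offset, mask=mask)
    # old pattern: g_x = tl.load(g_x_ptr + offset, mask=mask); g_x += sqrt_3 * g_1_0
    # new pattern: initialize the register directly, skipping one global-memory load
    g_x = sqrt_3 * g_1_0
    # ... higher-order contributions are accumulated into g_x here ...
    tl.store(g_x_ptr + offset, g_x, mask=mask)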
43 changes: 10 additions & 33 deletions src/equitriton/sph_harm/triton_kernels.py
@@ -168,21 +168,13 @@ def _triton_second_order_bwd(
  x = tl.load(x_row_start, mask=offset < vector_length)
  y = tl.load(y_row_start, mask=offset < vector_length)
  z = tl.load(z_row_start, mask=offset < vector_length)
- # load the pre-allocated xyz gradients
- g_x_start = g_x_ptr + offset
- g_y_start = g_y_ptr + offset
- g_z_start = g_z_ptr + offset
- # NOTE: these are the gradient outputs and are assumed to be initially zeros
- g_x = tl.load(g_x_start, mask=offset < vector_length)
- g_y = tl.load(g_y_start, mask=offset < vector_length)
- g_z = tl.load(g_z_start, mask=offset < vector_length)
  # this is the first order derivative, which is just root 3
  g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length)
  g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length)
  g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length)
- g_x += sqrt_3 * g_1_0
- g_y += sqrt_3 * g_1_1
- g_z += sqrt_3 * g_1_2
+ g_x = sqrt_3 * g_1_0
+ g_y = sqrt_3 * g_1_1
+ g_z = sqrt_3 * g_1_2
  # now work on the second order derivatives, grouped by m
  g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length)
  g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length)
@@ -347,21 +339,14 @@ def _triton_third_order_bwd(
  x = tl.load(x_row_start, mask=offset < vector_length)
  y = tl.load(y_row_start, mask=offset < vector_length)
  z = tl.load(z_row_start, mask=offset < vector_length)
- # load the pre-allocated xyz gradients
- g_x_start = g_x_ptr + offset
- g_y_start = g_y_ptr + offset
- g_z_start = g_z_ptr + offset
- # NOTE: these are the gradient outputs and are assumed to be initially zeros
- g_x = tl.load(g_x_start, mask=offset < vector_length)
- g_y = tl.load(g_y_start, mask=offset < vector_length)
- g_z = tl.load(g_z_start, mask=offset < vector_length)
  # this is the first order derivative, which is just root 3
  g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length)
  g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length)
  g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length)
- g_x += sqrt_3 * g_1_0
- g_y += sqrt_3 * g_1_1
- g_z += sqrt_3 * g_1_2
+ # initialize gradients
+ g_x = sqrt_3 * g_1_0
+ g_y = sqrt_3 * g_1_1
+ g_z = sqrt_3 * g_1_2
  # now work on the second order derivatives, grouped by m
  g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length)
  g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length)
@@ -666,21 +651,13 @@ def _triton_fourth_order_bwd(
  x = tl.load(x_row_start, mask=offset < vector_length)
  y = tl.load(y_row_start, mask=offset < vector_length)
  z = tl.load(z_row_start, mask=offset < vector_length)
- # load the pre-allocated xyz gradients
- g_x_start = g_x_ptr + offset
- g_y_start = g_y_ptr + offset
- g_z_start = g_z_ptr + offset
- # NOTE: these are the gradient outputs and are assumed to be initially zeros
- g_x = tl.load(g_x_start, mask=offset < vector_length)
- g_y = tl.load(g_y_start, mask=offset < vector_length)
- g_z = tl.load(g_z_start, mask=offset < vector_length)
  # this is the first order derivative, which is just root 3
  g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length)
  g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length)
  g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length)
- g_x += sqrt_3 * g_1_0
- g_y += sqrt_3 * g_1_1
- g_z += sqrt_3 * g_1_2
+ g_x = sqrt_3 * g_1_0
+ g_y = sqrt_3 * g_1_1
+ g_z = sqrt_3 * g_1_2
  # now work on the second order derivatives, grouped by m
  g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length)
  g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length)
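One hedged side note on the new pattern: removing the loads is only safe because each kernel computes and stores the full g_x, g_y, g_z values itself, which also means the launching wrapper no longer needs to zero-initialize those buffers for the kernels' sake. The snippet below is an illustrative allocation only, with hypothetical tensor names and sizes not taken from this PR.

import torch

num_points = 1024
# previously the kernels assumed zero-filled gradient buffers, e.g. torch.zeros(num_points, ...)
g_x = torch.empty(num_points, device="cuda")  # uninitialized memory now suffices
g_y = torch.empty(num_points, device="cuda")
g_z = torch.empty(num_points, device="cuda")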
