From fb4459641aec367fe07119f963f18c488147c060 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Thu, 7 Nov 2024 08:36:46 -0800 Subject: [PATCH 1/3] refactor: removing initial zero gradient loads from l=2 and 3 Signed-off-by: Kin Long Kelvin Lee --- src/equitriton/sph_harm/triton_kernels.py | 35 +++++++++-------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/equitriton/sph_harm/triton_kernels.py b/src/equitriton/sph_harm/triton_kernels.py index 1271127..9b50c62 100644 --- a/src/equitriton/sph_harm/triton_kernels.py +++ b/src/equitriton/sph_harm/triton_kernels.py @@ -168,21 +168,13 @@ def _triton_second_order_bwd( x = tl.load(x_row_start, mask=offset < vector_length) y = tl.load(y_row_start, mask=offset < vector_length) z = tl.load(z_row_start, mask=offset < vector_length) - # load the pre-allocated xyz gradients - g_x_start = g_x_ptr + offset - g_y_start = g_y_ptr + offset - g_z_start = g_z_ptr + offset - # NOTE: these are the gradient outputs and are assumed to be initially zeros - g_x = tl.load(g_x_start, mask=offset < vector_length) - g_y = tl.load(g_y_start, mask=offset < vector_length) - g_z = tl.load(g_z_start, mask=offset < vector_length) # this is the first order derivative, which is just root 3 g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length) g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length) g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length) - g_x += sqrt_3 * g_1_0 - g_y += sqrt_3 * g_1_1 - g_z += sqrt_3 * g_1_2 + g_x = sqrt_3 * g_1_0 + g_y = sqrt_3 * g_1_1 + g_z = sqrt_3 * g_1_2 # now work on the second order derivatives, grouped by m g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length) g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length) @@ -206,6 +198,9 @@ def _triton_second_order_bwd( g_x += -1.0 * sqrt_15 * x * g_2_4 g_z += sqrt_15 * z * g_2_4 # after all the operations are done, write back to memory + g_x_start = g_x_ptr + offset + g_y_start = g_y_ptr + offset + g_z_start = g_z_ptr + offset tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) @@ -347,21 +342,14 @@ def _triton_third_order_bwd( x = tl.load(x_row_start, mask=offset < vector_length) y = tl.load(y_row_start, mask=offset < vector_length) z = tl.load(z_row_start, mask=offset < vector_length) - # load the pre-allocated xyz gradients - g_x_start = g_x_ptr + offset - g_y_start = g_y_ptr + offset - g_z_start = g_z_ptr + offset - # NOTE: these are the gradient outputs and are assumed to be initially zeros - g_x = tl.load(g_x_start, mask=offset < vector_length) - g_y = tl.load(g_y_start, mask=offset < vector_length) - g_z = tl.load(g_z_start, mask=offset < vector_length) # this is the first order derivative, which is just root 3 g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length) g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length) g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length) - g_x += sqrt_3 * g_1_0 - g_y += sqrt_3 * g_1_1 - g_z += sqrt_3 * g_1_2 + # initialize gradients + g_x = sqrt_3 * g_1_0 + g_y = sqrt_3 * g_1_1 + g_z = sqrt_3 * g_1_2 # now work on the second order derivatives, grouped by m g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length) g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length) @@ -438,6 +426,9 @@ def _triton_third_order_bwd( * (1.08012344973464 * sq_x + 0.540061724867322 * sq_x - 1.62018517460196 * sq_z) ) # after all the operations are done, write back to memory + g_x_start = g_x_ptr + offset + g_y_start = g_y_ptr + offset + g_z_start = g_z_ptr + offset tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) From f912e44609726e8938e4dcccbe52f10806944d91 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Thu, 7 Nov 2024 08:41:38 -0800 Subject: [PATCH 2/3] refactor: removing zero gradient loads for l=4 Signed-off-by: Kin Long Kelvin Lee --- src/equitriton/sph_harm/triton_kernels.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/equitriton/sph_harm/triton_kernels.py b/src/equitriton/sph_harm/triton_kernels.py index 9b50c62..ccd02f8 100644 --- a/src/equitriton/sph_harm/triton_kernels.py +++ b/src/equitriton/sph_harm/triton_kernels.py @@ -657,21 +657,13 @@ def _triton_fourth_order_bwd( x = tl.load(x_row_start, mask=offset < vector_length) y = tl.load(y_row_start, mask=offset < vector_length) z = tl.load(z_row_start, mask=offset < vector_length) - # load the pre-allocated xyz gradients - g_x_start = g_x_ptr + offset - g_y_start = g_y_ptr + offset - g_z_start = g_z_ptr + offset - # NOTE: these are the gradient outputs and are assumed to be initially zeros - g_x = tl.load(g_x_start, mask=offset < vector_length) - g_y = tl.load(g_y_start, mask=offset < vector_length) - g_z = tl.load(g_z_start, mask=offset < vector_length) # this is the first order derivative, which is just root 3 g_1_0 = tl.load(g_1_0_ptr + offset, mask=offset < vector_length) g_1_1 = tl.load(g_1_1_ptr + offset, mask=offset < vector_length) g_1_2 = tl.load(g_1_2_ptr + offset, mask=offset < vector_length) - g_x += sqrt_3 * g_1_0 - g_y += sqrt_3 * g_1_1 - g_z += sqrt_3 * g_1_2 + g_x = sqrt_3 * g_1_0 + g_y = sqrt_3 * g_1_1 + g_z = sqrt_3 * g_1_2 # now work on the second order derivatives, grouped by m g_2_0 = tl.load(g_2_0_ptr + offset, mask=offset < vector_length) g_2_1 = tl.load(g_2_1_ptr + offset, mask=offset < vector_length) @@ -948,6 +940,9 @@ def _triton_fourth_order_bwd( ) ) # after all the operations are done, write back to memory + g_x_start = g_x_ptr + offset + g_y_start = g_y_ptr + offset + g_z_start = g_z_ptr + offset tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) From 5bee6747cfc73dbfce3fb58b4cf310478b67762c Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Thu, 7 Nov 2024 08:44:06 -0800 Subject: [PATCH 3/3] refactor: removing unused gradient pointer offset variables Signed-off-by: Kin Long Kelvin Lee --- src/equitriton/sph_harm/triton_kernels.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/equitriton/sph_harm/triton_kernels.py b/src/equitriton/sph_harm/triton_kernels.py index ccd02f8..bbfdc50 100644 --- a/src/equitriton/sph_harm/triton_kernels.py +++ b/src/equitriton/sph_harm/triton_kernels.py @@ -198,9 +198,6 @@ def _triton_second_order_bwd( g_x += -1.0 * sqrt_15 * x * g_2_4 g_z += sqrt_15 * z * g_2_4 # after all the operations are done, write back to memory - g_x_start = g_x_ptr + offset - g_y_start = g_y_ptr + offset - g_z_start = g_z_ptr + offset tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) @@ -426,9 +423,6 @@ def _triton_third_order_bwd( * (1.08012344973464 * sq_x + 0.540061724867322 * sq_x - 1.62018517460196 * sq_z) ) # after all the operations are done, write back to memory - g_x_start = g_x_ptr + offset - g_y_start = g_y_ptr + offset - g_z_start = g_z_ptr + offset tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length) @@ -940,9 +934,6 @@ def _triton_fourth_order_bwd( ) ) # after all the operations are done, write back to memory - g_x_start = g_x_ptr + offset - g_y_start = g_y_ptr + offset - g_z_start = g_z_ptr + offset tl.store(g_x_ptr + offset, g_x, mask=offset < vector_length) tl.store(g_y_ptr + offset, g_y, mask=offset < vector_length) tl.store(g_z_ptr + offset, g_z, mask=offset < vector_length)