From 8b8524757435aeed11f5446b02a0638d541d12c3 Mon Sep 17 00:00:00 2001
From: Eric Heiden
Date: Tue, 5 Mar 2024 17:24:26 -0800
Subject: [PATCH] Fix 2 codegen issues for custom grad functions

---
 CHANGELOG.md                    |   5 ++
 warp/codegen.py                 |  10 ++-
 warp/context.py                 |  25 ++++++--
 warp/tests/test_grad_customs.py | 104 ++++++++++++++++++++++++++++++++
 4 files changed, 135 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d29416f28..062d1d8b4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # CHANGELOG
 
+## [0.15.1] - 2024-03-05
+
+- Fix codegen for custom grad functions calling their respective forward functions
+- Fix custom grad function handling for functions that have no outputs
+- Fix issues when `wp.config.quiet = True`
 
 ## [0.15.0] - 2024-03-04
 
diff --git a/warp/codegen.py b/warp/codegen.py
index 5d1febd6e..70c40f736 100644
--- a/warp/codegen.py
+++ b/warp/codegen.py
@@ -919,12 +919,16 @@ def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
                 break
 
         # if it is a user-function then build it recursively
-        if not func.is_builtin():
+        if not func.is_builtin() and func not in adj.builder.functions:
             adj.builder.build_function(func)
+            # add custom grad, replay functions to the list of functions
+            # to be built later (invalid code could be generated if we built them now)
+            # so that they are not missed when only the forward function is imported
+            # from another module
             if func.custom_grad_func:
-                adj.builder.build_function(func.custom_grad_func)
+                adj.builder.deferred_functions.append(func.custom_grad_func)
             if func.custom_replay_func:
-                adj.builder.build_function(func.custom_replay_func)
+                adj.builder.deferred_functions.append(func.custom_replay_func)
 
         # evaluate the function type based on inputs
         arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
diff --git a/warp/context.py b/warp/context.py
index bbf2515d6..13082dd8c 100644
--- a/warp/context.py
+++ b/warp/context.py
@@ -710,7 +710,11 @@ def wrapper(grad_fn):
         def match_function(f):
             # check whether the function overload f matches the signature of the provided gradient function
             if not hasattr(f.adj, "return_var"):
-                f.adj.build(None, f.module.options)
+                # we have to temporarily build this function to figure out its return type(s);
+                # note that we do not have a ModuleBuilder instance here at this wrapping stage, hence we
+                # have to create a dummy builder
+                builder = ModuleBuilder(Module("dummy", None), f.module.options)
+                f.adj.build(builder)
             expected_args = list(f.input_types.items())
             if f.adj.return_var is not None:
                 expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
@@ -1159,6 +1163,7 @@ def __init__(self, module, options):
         self.structs = {}
         self.options = options
         self.module = module
+        self.deferred_functions = []
 
         # build all functions declared in the module
         for func in module.functions.values():
@@ -1175,6 +1180,10 @@
                 for k in kernel.overloads.values():
                     self.build_kernel(k)
 
+        # build all functions outside this module which are called from functions or kernels in this module
+        for func in self.deferred_functions:
+            self.build_function(func)
+
     def build_struct_recursive(self, struct: warp.codegen.Struct):
         structs = []
 
@@ -2871,6 +2880,15 @@ def __init__(self):
         # initialize kernel cache
         warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
 
+        devices_without_uva = []
+        devices_without_mempool = []
+        for cuda_device in self.cuda_devices:
+            if cuda_device.is_primary:
+                if not cuda_device.is_uva:
+                    devices_without_uva.append(cuda_device)
+                if not cuda_device.is_mempool_supported:
+                    devices_without_mempool.append(cuda_device)
+
         # print device and version information
         if not warp.config.quiet:
             greeting = []
@@ -2893,24 +2911,19 @@ def __init__(self):
             alias_str = f'"{self.cpu_device.alias}"'
             name_str = f'"{self.cpu_device.name}"'
             greeting.append(f"     {alias_str:10s} : {name_str}")
-            devices_without_uva = []
-            devices_without_mempool = []
             for cuda_device in self.cuda_devices:
                 alias_str = f'"{cuda_device.alias}"'
                 if cuda_device.is_primary:
                     name_str = f'"{cuda_device.name}"'
                     arch_str = f"sm_{cuda_device.arch}"
                     mem_str = f"{cuda_device.total_memory / 1024 / 1024 / 1024:.0f} GiB"
-                    if not cuda_device.is_uva:
-                        devices_without_uva.append(cuda_device)
                     if cuda_device.is_mempool_supported:
                         if cuda_device.is_mempool_enabled:
                             mempool_str = "mempool enabled"
                         else:
                             mempool_str = "mempool supported"
                     else:
                         mempool_str = "mempool not supported"
-                        devices_without_mempool.append(cuda_device)
                     greeting.append(f"     {alias_str:10s} : {name_str} ({mem_str}, {arch_str}, {mempool_str})")
                 else:
                     primary_alias_str = f'"{self.cuda_primary_devices[cuda_device.ordinal].alias}"'
diff --git a/warp/tests/test_grad_customs.py b/warp/tests/test_grad_customs.py
index 588c555ca..20099bb05 100644
--- a/warp/tests/test_grad_customs.py
+++ b/warp/tests/test_grad_customs.py
@@ -197,6 +197,109 @@ def run_defined_float_fn(
     assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
 
 
+@wp.func
+def sigmoid(x: float):
+    return 1.0 / (1.0 + wp.exp(-x))
+
+
+@wp.func_grad(sigmoid)
+def adj_sigmoid(x: float, adj: float):
+    # unused function to test that we don't run into infinite recursion when calling
+    # the forward function from within the gradient function
+    wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))
+
+
+@wp.func
+def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    # test function that does not return anything
+    ys[i] = sigmoid(xs[i])
+
+
+@wp.func_grad(sigmoid_no_return)
+def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    wp.adjoint[xs][i] += ys[i] * (1.0 - ys[i])
+
+
+@wp.kernel
+def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    i = wp.tid()
+    sigmoid_no_return(i, xs, ys)
+
+
+def test_custom_grad_no_return(test, device):
+    xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
+    ys = wp.zeros_like(xs)
+    ys.grad.fill_(1.0)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys])
+    tape.backward()
+
+    sigmoids = ys.numpy()
+    grad = xs.grad.numpy()
+    assert_np_equal(grad, sigmoids * (1.0 - sigmoids))
+
+
+@wp.func
+def dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # this function doesn't get called but it is an important test for code generation
+    # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
+    for i in range(m):
+        for j in range(n):
+            sum = float(0.0)
+            for k in range(p):
+                if transpose_A:
+                    a_i = k * m + i
+                else:
+                    a_i = i * p + k
+                if transpose_B:
+                    b_j = j * p + k
+                else:
+                    b_j = k * n + j
+                sum += A[a_i] * B[b_j]
+
+            if add_to_C:
+                C[i * n + j] += sum
+            else:
+                C[i * n + j] = sum
+
+
+@wp.func_grad(dense_gemm)
+def adj_dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # code generation would break here if we didn't defer building the custom grad
+    # function until after the forward functions + kernels of the module have been built
+    add_to_C = True
+    if transpose_A:
+        dense_gemm(p, m, n, False, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
+        dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+    else:
+        dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
+        dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+
+
 devices = get_test_devices()
 
 
@@ -207,6 +310,7 @@ class TestGradCustoms(unittest.TestCase):
 add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_import_grad", test_custom_import_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_grad_no_return", test_custom_grad_no_return, devices=devices)
 
 
 if __name__ == "__main__":
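
For reference, below is a minimal end-user sketch of the scenario behind the first CHANGELOG entry above: a custom gradient registered with `wp.func_grad` that calls its own forward function, which relies on the deferred grad/replay build in warp/codegen.py. The snippet is illustrative only and not part of the patch; the `softplus` / `adj_softplus` / `eval_softplus` names are hypothetical, and it assumes Warp 0.15.1 or newer with a working device.

# Hypothetical usage sketch (not part of the patch), assuming Warp >= 0.15.1.
import numpy as np
import warp as wp

wp.init()


@wp.func
def softplus(x: float):
    return wp.log(1.0 + wp.exp(x))


@wp.func_grad(softplus)
def adj_softplus(x: float, adj: float):
    # the gradient calls its own forward function, the case covered by the first
    # CHANGELOG item; d/dx softplus(x) = exp(x - softplus(x)) = sigmoid(x)
    wp.adjoint[x] += adj * wp.exp(x - softplus(x))


@wp.kernel
def eval_softplus(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
    i = wp.tid()
    ys[i] = softplus(xs[i])


xs = wp.array([-2.0, 0.0, 1.5], dtype=wp.float32, requires_grad=True)
ys = wp.zeros_like(xs)
ys.grad.fill_(1.0)

tape = wp.Tape()
with tape:
    wp.launch(eval_softplus, dim=len(xs), inputs=[xs], outputs=[ys])
tape.backward()

# the analytic gradient of softplus is the sigmoid function
expected = 1.0 / (1.0 + np.exp(-xs.numpy()))
print(xs.grad.numpy(), expected)

The remaining cases are covered directly by the new tests in the patch: test_custom_grad_no_return exercises a custom gradient for a function with no return value, and the dense_gemm / adj_dense_gemm pair exercises deferred building of a gradient that calls other user functions.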