Merge branch 'fix-grad-codegen' into 'master'
Fix 2 codegen issues for custom grad functions

See merge request omniverse/warp!331
eric-heiden committed Mar 6, 2024
2 parents cefd01f + 8b85247 commit 2ce2498
Showing 4 changed files with 135 additions and 9 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
 # CHANGELOG
 
+## [0.15.1] - 2024-03-05
+
+- Fix codegen for custom grad functions calling their respective forward functions
+- Fix custom grad function handling for functions that have no outputs
+- Fix issues when `wp.config.quiet = True`
 
 ## [0.15.0] - 2024-03-04
 
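For context, the first two fixes cover custom gradient functions that call their own forward function and forward functions that return no value. A minimal sketch of the first pattern, adapted from the sigmoid test added in warp/tests/test_grad_customs.py below:

import warp as wp

@wp.func
def sigmoid(x: float):
    return 1.0 / (1.0 + wp.exp(-x))

# custom gradient that calls the forward function again; prior to this fix,
# generating code for this pattern could produce invalid output
@wp.func_grad(sigmoid)
def adj_sigmoid(x: float, adj: float):
    wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))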
10 changes: 7 additions & 3 deletions warp/codegen.py
@@ -919,12 +919,16 @@ def add_call(adj, func, args, min_outputs=None, templates=[], kwds=None):
                 break
 
         # if it is a user-function then build it recursively
-        if not func.is_builtin():
+        if not func.is_builtin() and func not in adj.builder.functions:
             adj.builder.build_function(func)
+            # add custom grad, replay functions to the list of functions
+            # to be built later (invalid code could be generated if we built them now)
+            # so that they are not missed when only the forward function is imported
+            # from another module
             if func.custom_grad_func:
-                adj.builder.build_function(func.custom_grad_func)
+                adj.builder.deferred_functions.append(func.custom_grad_func)
             if func.custom_replay_func:
-                adj.builder.build_function(func.custom_replay_func)
+                adj.builder.deferred_functions.append(func.custom_replay_func)
 
         # evaluate the function type based on inputs
         arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
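The effect of this change is that a user function's custom grad and replay functions are no longer built inline while the forward function is being generated; they are queued on the builder and built in a second pass (see the ModuleBuilder changes in warp/context.py below). A simplified, hypothetical sketch of that two-pass pattern, illustrative only and not Warp's actual implementation:

class ToyModuleBuilder:
    # hypothetical names; this only mirrors the deferral idea, not the real ModuleBuilder
    def __init__(self, funcs):
        self.functions = set()        # functions already built
        self.deferred_functions = []  # custom grad/replay functions queued for later

        # first pass: build all forward functions
        for func in funcs:
            self.build_function(func)

        # second pass: build deferred custom grad/replay functions once all
        # forward code has been generated
        for func in self.deferred_functions:
            self.build_function(func)

    def build_function(self, func):
        if func in self.functions:
            return  # mirrors the new `func not in adj.builder.functions` guard
        self.functions.add(func)
        # ... forward code generation for `func` would happen here ...
        if getattr(func, "custom_grad_func", None) is not None:
            self.deferred_functions.append(func.custom_grad_func)
        if getattr(func, "custom_replay_func", None) is not None:
            self.deferred_functions.append(func.custom_replay_func)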
25 changes: 19 additions & 6 deletions warp/context.py
@@ -710,7 +710,11 @@ def wrapper(grad_fn):
         def match_function(f):
             # check whether the function overload f matches the signature of the provided gradient function
             if not hasattr(f.adj, "return_var"):
-                f.adj.build(None, f.module.options)
+                # we have to temporarily build this function to figure out its return type(s);
+                # note that we do not have a ModuleBuilder instance here at this wrapping stage, hence we
+                # have to create a dummy builder
+                builder = ModuleBuilder(Module("dummy", None), f.module.options)
+                f.adj.build(builder)
             expected_args = list(f.input_types.items())
             if f.adj.return_var is not None:
                 expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
@@ -1159,6 +1163,7 @@ def __init__(self, module, options):
         self.structs = {}
         self.options = options
         self.module = module
+        self.deferred_functions = []
 
         # build all functions declared in the module
         for func in module.functions.values():
@@ -1175,6 +1180,10 @@ def __init__(self, module, options):
             for k in kernel.overloads.values():
                 self.build_kernel(k)
 
+        # build all functions outside this module which are called from functions or kernels in this module
+        for func in self.deferred_functions:
+            self.build_function(func)
+
     def build_struct_recursive(self, struct: warp.codegen.Struct):
         structs = []
 
@@ -2871,6 +2880,15 @@ def __init__(self):
         # initialize kernel cache
         warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
 
+        devices_without_uva = []
+        devices_without_mempool = []
+        for cuda_device in self.cuda_devices:
+            if cuda_device.is_primary:
+                if not cuda_device.is_uva:
+                    devices_without_uva.append(cuda_device)
+                if not cuda_device.is_mempool_supported:
+                    devices_without_mempool.append(cuda_device)
+
         # print device and version information
         if not warp.config.quiet:
             greeting = []
@@ -2893,24 +2911,19 @@ def __init__(self):
             alias_str = f'"{self.cpu_device.alias}"'
             name_str = f'"{self.cpu_device.name}"'
             greeting.append(f" {alias_str:10s} : {name_str}")
-            devices_without_uva = []
-            devices_without_mempool = []
             for cuda_device in self.cuda_devices:
                 alias_str = f'"{cuda_device.alias}"'
                 if cuda_device.is_primary:
                     name_str = f'"{cuda_device.name}"'
                     arch_str = f"sm_{cuda_device.arch}"
                     mem_str = f"{cuda_device.total_memory / 1024 / 1024 / 1024:.0f} GiB"
-                    if not cuda_device.is_uva:
-                        devices_without_uva.append(cuda_device)
                     if cuda_device.is_mempool_supported:
                         if cuda_device.is_mempool_enabled:
                             mempool_str = "mempool enabled"
                         else:
                             mempool_str = "mempool supported"
                     else:
                         mempool_str = "mempool not supported"
-                        devices_without_mempool.append(cuda_device)
                     greeting.append(f" {alias_str:10s} : {name_str} ({mem_str}, {arch_str}, {mempool_str})")
                 else:
                     primary_alias_str = f'"{self.cuda_primary_devices[cuda_device.ordinal].alias}"'
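Moving the UVA and mempool capability checks ahead of the `if not warp.config.quiet:` block means the devices_without_uva and devices_without_mempool lists are populated even when the startup greeting is suppressed, which appears to be the `wp.config.quiet = True` fix listed in the changelog. A minimal usage sketch:

import warp as wp

# suppress the startup greeting; the device capability checks above still run
wp.config.quiet = True
wp.init()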
104 changes: 104 additions & 0 deletions warp/tests/test_grad_customs.py
@@ -197,6 +197,109 @@ def run_defined_float_fn(
     assert_np_equal(ys_float.grad.numpy(), ys_float.numpy() * 3.0)
 
 
+@wp.func
+def sigmoid(x: float):
+    return 1.0 / (1.0 + wp.exp(-x))
+
+
+@wp.func_grad(sigmoid)
+def adj_sigmoid(x: float, adj: float):
+    # unused function to test that we don't run into infinite recursion when calling
+    # the forward function from within the gradient function
+    wp.adjoint[x] += adj * sigmoid(x) * (1.0 - sigmoid(x))
+
+
+@wp.func
+def sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    # test function that does not return anything
+    ys[i] = sigmoid(xs[i])
+
+
+@wp.func_grad(sigmoid_no_return)
+def adj_sigmoid_no_return(i: int, xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    wp.adjoint[xs][i] += ys[i] * (1.0 - ys[i])
+
+
+@wp.kernel
+def eval_sigmoid(xs: wp.array(dtype=float), ys: wp.array(dtype=float)):
+    i = wp.tid()
+    sigmoid_no_return(i, xs, ys)
+
+
+def test_custom_grad_no_return(test, device):
+    xs = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
+    ys = wp.zeros_like(xs)
+    ys.grad.fill_(1.0)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(eval_sigmoid, dim=len(xs), inputs=[xs], outputs=[ys])
+    tape.backward()
+
+    sigmoids = ys.numpy()
+    grad = xs.grad.numpy()
+    assert_np_equal(grad, sigmoids * (1.0 - sigmoids))
+
+
+@wp.func
+def dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # this function doesn't get called but it is an important test for code generation
+    # multiply a `m x p` matrix A by a `p x n` matrix B to produce a `m x n` matrix C
+    for i in range(m):
+        for j in range(n):
+            sum = float(0.0)
+            for k in range(p):
+                if transpose_A:
+                    a_i = k * m + i
+                else:
+                    a_i = i * p + k
+                if transpose_B:
+                    b_j = j * p + k
+                else:
+                    b_j = k * n + j
+                sum += A[a_i] * B[b_j]
+
+            if add_to_C:
+                C[i * n + j] += sum
+            else:
+                C[i * n + j] = sum
+
+
+@wp.func_grad(dense_gemm)
+def adj_dense_gemm(
+    m: int,
+    n: int,
+    p: int,
+    transpose_A: bool,
+    transpose_B: bool,
+    add_to_C: bool,
+    A: wp.array(dtype=float),
+    B: wp.array(dtype=float),
+    # outputs
+    C: wp.array(dtype=float),
+):
+    # code generation would break here if we didn't defer building the custom grad
+    # function until after the forward functions + kernels of the module have been built
+    add_to_C = True
+    if transpose_A:
+        dense_gemm(p, m, n, False, True, add_to_C, B, wp.adjoint[C], wp.adjoint[A])
+        dense_gemm(p, n, m, False, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+    else:
+        dense_gemm(m, p, n, False, not transpose_B, add_to_C, wp.adjoint[C], B, wp.adjoint[A])
+        dense_gemm(p, n, m, True, False, add_to_C, A, wp.adjoint[C], wp.adjoint[B])
+
+
 devices = get_test_devices()
 
 
@@ -207,6 +310,7 @@ class TestGradCustoms(unittest.TestCase):
 add_function_test(TestGradCustoms, "test_custom_replay_grad", test_custom_replay_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_overload_grad", test_custom_overload_grad, devices=devices)
 add_function_test(TestGradCustoms, "test_custom_import_grad", test_custom_import_grad, devices=devices)
+add_function_test(TestGradCustoms, "test_custom_grad_no_return", test_custom_grad_no_return, devices=devices)
 
 
 if __name__ == "__main__":
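For reference, the gradient body of adj_dense_gemm above encodes the standard reverse-mode adjoints of a matrix product C = A B, namely adj_A += adj_C B^T and adj_B += A^T adj_C, expressed through calls back into dense_gemm. A small NumPy sanity check of those formulas, included here only as an aside and not part of the Warp test itself:

import numpy as np

m, p, n = 2, 3, 4
A = np.random.rand(m, p)
B = np.random.rand(p, n)
adj_C = np.random.rand(m, n)   # incoming gradient of C = A @ B

adj_A = adj_C @ B.T            # shape (m, p)
adj_B = A.T @ adj_C            # shape (p, n)
assert adj_A.shape == (m, p) and adj_B.shape == (p, n)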
