diff --git a/CHANGELOG.md b/CHANGELOG.md index 97c45de80..53c40b4ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ - Fix custom colors being not being updated when rendering meshes with static topology in OpenGL ([GH-343](https://github.com/NVIDIA/warp/issues/343)). - Fix `wp.launch_tiled()` not returning a `Launch` object when passed `record_cmd=True`. - Mark kernel arrays as written to when passed to `wp.atomic_add()` or `wp.atomic_sub()` +- Fix default arguments not being resolved for `wp.func` when called from Python's runtime ([GH-386](https://github.com/NVIDIA/warp/issues/386)). ## [1.5.0] - 2024-12-02 diff --git a/warp/context.py b/warp/context.py index d4092c4e2..c368d3493 100644 --- a/warp/context.py +++ b/warp/context.py @@ -293,22 +293,22 @@ def __call__(self, *args, **kwargs): if hasattr(self, "user_overloads") and len(self.user_overloads): # user-defined function with overloads + bound_args = self.signature.bind(*args, **kwargs) + if self.defaults: + warp.codegen.apply_defaults(bound_args, self.defaults) - if len(kwargs): - raise RuntimeError( - f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads." - ) + arguments = tuple(bound_args.arguments.values()) # try and find a matching overload for overload in self.user_overloads.values(): - if len(overload.input_types) != len(args): + if len(overload.input_types) != len(arguments): continue template_types = list(overload.input_types.values()) arg_names = list(overload.input_types.keys()) try: # attempt to unify argument types with function template types - warp.types.infer_argument_types(args, template_types, arg_names) - return overload.func(*args) + warp.types.infer_argument_types(arguments, template_types, arg_names) + return overload.func(*arguments) except Exception: continue diff --git a/warp/tests/test_func.py b/warp/tests/test_func.py index 0b39891cd..e15481b4b 100644 --- a/warp/tests/test_func.py +++ b/warp/tests/test_func.py @@ -162,7 +162,7 @@ def user_func_with_defaults(a: int = 123, b: int = 234) -> int: @wp.kernel -def test_user_func_with_defaults(): +def user_func_with_defaults_kernel(): a = user_func_with_defaults() wp.expect_eq(a, 357) @@ -179,6 +179,25 @@ def test_user_func_with_defaults(): wp.expect_eq(e, 234) +def test_user_func_with_defaults(test, device): + wp.launch(user_func_with_defaults_kernel, dim=1, device=device) + + a = user_func_with_defaults() + assert a == 357 + + b = user_func_with_defaults(111) + assert b == 345 + + c = user_func_with_defaults(111, 222) + assert c == 333 + + d = user_func_with_defaults(a=111) + assert d == 345 + + e = user_func_with_defaults(b=111) + assert e == 234 + + @wp.func def user_func_return_multiple_values(a: int, b: float) -> Tuple[int, float]: return a + a, b * b @@ -406,9 +425,7 @@ def test_native_function_error_resolution(self): add_function_test(TestFunc, func=test_multi_valued_func, name="test_multi_valued_func", devices=devices) add_kernel_test(TestFunc, kernel=test_func_defaults, name="test_func_defaults", dim=1, devices=devices) add_kernel_test(TestFunc, kernel=test_builtin_shadowing, name="test_builtin_shadowing", dim=1, devices=devices) -add_kernel_test( - TestFunc, kernel=test_user_func_with_defaults, name="test_user_func_with_defaults", dim=1, devices=devices -) +add_function_test(TestFunc, func=test_user_func_with_defaults, name="test_user_func_with_defaults", devices=devices) add_kernel_test( TestFunc, kernel=test_user_func_return_multiple_values,