Skip to content

Commit

Permalink
Merge branch 'ccrouzet/gh-386-func-defaults' into 'main'
Browse files Browse the repository at this point in the history
Fix Default Args Not Working In Python’s Runtime

Closes GH-386

See merge request omniverse/warp!922
  • Loading branch information
christophercrouzet committed Dec 12, 2024
2 parents 9cdd185 + cbe16b2 commit c7fbe05
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- Fix custom colors being not being updated when rendering meshes with static topology in OpenGL ([GH-343](https://github.com/NVIDIA/warp/issues/343)).
- Fix `wp.launch_tiled()` not returning a `Launch` object when passed `record_cmd=True`.
- Mark kernel arrays as written to when passed to `wp.atomic_add()` or `wp.atomic_sub()`
- Fix default arguments not being resolved for `wp.func` when called from Python's runtime ([GH-386](https://github.com/NVIDIA/warp/issues/386)).

## [1.5.0] - 2024-12-02

Expand Down
14 changes: 7 additions & 7 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,22 @@ def __call__(self, *args, **kwargs):

if hasattr(self, "user_overloads") and len(self.user_overloads):
# user-defined function with overloads
bound_args = self.signature.bind(*args, **kwargs)
if self.defaults:
warp.codegen.apply_defaults(bound_args, self.defaults)

if len(kwargs):
raise RuntimeError(
f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
)
arguments = tuple(bound_args.arguments.values())

# try and find a matching overload
for overload in self.user_overloads.values():
if len(overload.input_types) != len(args):
if len(overload.input_types) != len(arguments):
continue
template_types = list(overload.input_types.values())
arg_names = list(overload.input_types.keys())
try:
# attempt to unify argument types with function template types
warp.types.infer_argument_types(args, template_types, arg_names)
return overload.func(*args)
warp.types.infer_argument_types(arguments, template_types, arg_names)
return overload.func(*arguments)
except Exception:
continue

Expand Down
25 changes: 21 additions & 4 deletions warp/tests/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def user_func_with_defaults(a: int = 123, b: int = 234) -> int:


@wp.kernel
def test_user_func_with_defaults():
def user_func_with_defaults_kernel():
a = user_func_with_defaults()
wp.expect_eq(a, 357)

Expand All @@ -179,6 +179,25 @@ def test_user_func_with_defaults():
wp.expect_eq(e, 234)


def test_user_func_with_defaults(test, device):
wp.launch(user_func_with_defaults_kernel, dim=1, device=device)

a = user_func_with_defaults()
assert a == 357

b = user_func_with_defaults(111)
assert b == 345

c = user_func_with_defaults(111, 222)
assert c == 333

d = user_func_with_defaults(a=111)
assert d == 345

e = user_func_with_defaults(b=111)
assert e == 234


@wp.func
def user_func_return_multiple_values(a: int, b: float) -> Tuple[int, float]:
return a + a, b * b
Expand Down Expand Up @@ -406,9 +425,7 @@ def test_native_function_error_resolution(self):
add_function_test(TestFunc, func=test_multi_valued_func, name="test_multi_valued_func", devices=devices)
add_kernel_test(TestFunc, kernel=test_func_defaults, name="test_func_defaults", dim=1, devices=devices)
add_kernel_test(TestFunc, kernel=test_builtin_shadowing, name="test_builtin_shadowing", dim=1, devices=devices)
add_kernel_test(
TestFunc, kernel=test_user_func_with_defaults, name="test_user_func_with_defaults", dim=1, devices=devices
)
add_function_test(TestFunc, func=test_user_func_with_defaults, name="test_user_func_with_defaults", devices=devices)
add_kernel_test(
TestFunc,
kernel=test_user_func_return_multiple_values,
Expand Down

0 comments on commit c7fbe05

Please sign in to comment.