dr.syntax: AD gets disabled by variable use in while loop #253
Hi @dvicini, it looks like the syntax rewriting is a bit aggressive here. I'll look into why it behaves this way.
Ok, there's not much we can do here, I believe. In short, even though it's clear in your reproducer that the variable is never actually modified by the loop, the transformation cannot know this in general. The workflow for these kinds of situations is usually to add a ``dr.hint()`` annotation.
Fair enough, thanks for checking. I have to say I wasn't super aware of the various debug options for ``@dr.syntax``.
The loop constructs, ``dr.syntax``, and ``dr.hint`` have extensive documentation. Please take a look and post an issue/PR if anything should be added. We should document any potential gotchas.
I've run into another seemingly related problem. The following code silently produces an incorrect result:

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)
def f():
    param = mi.Float(0.0)
    dr.enable_grad(param)
    dr.set_grad(param, 1.0)
    a = dr.linspace(mi.Float, 1, 2, 16) + param
    result = mi.Float(0.0)
    b = dr.gather(mi.Float, a, 3)
    if b == b:  # Always true
        result += dr.forward_to(b)  # Fails silently
    # Doing the same without the if-statement works as expected
    # result += dr.forward_to(b)
    return result

result = f()
print(result)
```
The AD layer exposes a function named ``ad_var_inc_ref()`` that increases the reference count of a variable, analogous to ``jit_var_inc_ref()``. One difference between the two is that the former detaches AD variables when the underlying index has derivative tracking disabled. For example, this ensures that code like

```python
x = Float(0)
dr.enable_grad(x)
with dr.suspend_grad():
    y = Float(x)
```

creates a non-differentiable copy. However, since many other operations throughout the Dr.Jit codebase require reference counting, quite a few places exhibited this detaching behavior, which is not always wanted (see issue #253). This commit provides two reference counting functions:

- ``ad_var_inc_ref()``, which increases the reference count *without* detaching, and
- ``ad_var_copy_ref()``, which detaches (i.e., reproduces the former behavior).

Following this split, only the constructor of AD arrays uses the detaching ``ad_var_copy_ref()``, while all other operations use the new ``ad_var_inc_ref()``.
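The distinction between the two reference counting functions can be illustrated with a small pure-Python sketch. This is a conceptual model only, not the real Dr.Jit C++ API; the class and function names mirror the commit message but are otherwise made up:

```python
class ADVar:
    """Toy stand-in for an AD-tracked variable (illustrative, not Dr.Jit's API)."""
    def __init__(self, index, grad_enabled):
        self.index = index              # underlying JIT variable index
        self.grad_enabled = grad_enabled
        self.refcount = 1

grad_suspended = True  # pretend we are inside dr.suspend_grad()

def ad_var_inc_ref(v):
    # New behavior: plain reference count increase, AD identity is kept.
    v.refcount += 1
    return v

def ad_var_copy_ref(v):
    # Old/constructor behavior: detach when derivative tracking is
    # disabled for the underlying index (e.g., inside suspend_grad()).
    if grad_suspended and v.grad_enabled:
        return ADVar(v.index, grad_enabled=False)  # detached copy
    v.refcount += 1
    return v

x = ADVar(index=42, grad_enabled=True)
kept = ad_var_inc_ref(x)   # same variable, still differentiable
copy = ad_var_copy_ref(x)  # detached: grad_enabled == False
print(kept is x, copy.grad_enabled)  # → True False
```

The point of the split is visible here: a plain reference count bump must not change a variable's AD identity, while a detaching copy is only appropriate in the one place (the constructor) where that semantics is wanted.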
The last issue you posted is unrelated to the problem discussed in this issue. The way we handled this in Mitsuba before is that you had to do a forward AD pass outside of the symbolic region, whose derivative values can then be picked up. But this of course isn't fully general.
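The workaround described above (run the forward AD pass before entering the symbolic region, then only read the already-materialized derivative inside it) can be sketched with plain dual numbers, independent of Dr.Jit. All names here are illustrative:

```python
from dataclasses import dataclass

@dataclass
class Dual:
    # Minimal forward-mode AD value: primal and derivative parts.
    val: float
    grad: float

def add(a, b):
    return Dual(a.val + b.val, a.grad + b.grad)

# The forward pass happens *outside* the conditional region: the
# derivative of b w.r.t. param is already materialized here ...
param = Dual(0.0, 1.0)          # seed d(param)/d(param) = 1
b = add(Dual(1.2, 0.0), param)  # analogous to gather(linspace + param)

# ... so inside a (symbolic) conditional we only read stored values
# instead of triggering an AD graph traversal from within the region.
result = 0.0
if b.val == b.val:  # always true, stands in for the symbolic 'if'
    result += b.grad

print(result)  # → 1.0
```

This mirrors the limitation mentioned in the comment: the approach works when the needed derivatives can be computed ahead of time, but it isn't fully general.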
Yes, you are right. I created a separate issue to track this: #295
…AD-suspended mode

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``) disable gradient tracking of all variable state when the operation takes place within an AD-disabled scope. This can be surprising when a ``@dr.syntax`` transformation silently passes local variables to such an operation, which then become non-differentiable. This commit carves out an exception: when variables aren't actually modified by the control flow operation, they can retain their AD identity. This is part #1 of the fix for issue #253 reported by @dvicini and targets ``dr.if_stmt()`` only. The next commit will also fix the same problem for while loops.
…AD-suspended mode

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``) disable gradient tracking of all variable state when the operation takes place within an AD-disabled scope. This can be surprising when a ``@dr.syntax`` transformation silently passes local variables to such an operation, which then become non-differentiable. This commit carves out an exception: when variables aren't actually modified by the control flow operation, they can retain their AD identity. This is part #2 of the fix for issue #253 reported by @dvicini and targets ``dr.while_loop()`` only. The previous commit fixed the same problem for ``if`` statements.
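The carve-out described in these two commits can be modeled with a small pure-Python sketch: a control flow operation running in an AD-suspended scope detaches only the state variables it actually modifies, while untouched variables keep their AD identity. The function and variable names are illustrative, not Dr.Jit internals:

```python
def run_cond_op(state, modified, ad_suspended):
    """Toy model of a Dr.Jit control flow op's state handling.

    state: dict mapping variable name -> (value, grad_enabled)
    modified: set of names the op actually writes to
    ad_suspended: True when inside an AD-disabled scope
    """
    out = {}
    for name, (value, grad_enabled) in state.items():
        if ad_suspended and name in modified:
            # Only variables touched by the op lose derivative tracking.
            out[name] = (value, False)
        else:
            # Unmodified variables retain their AD identity.
            out[name] = (value, grad_enabled)
    return out

# 'a' is merely read inside the op; only 'result' is written to.
state = {'a': (1.0, True), 'result': (0.0, True)}
new_state = run_cond_op(state, modified={'result'}, ad_suspended=True)
print(new_state['a'][1], new_state['result'][1])  # → True False
```

Before this fix, the toy model would have set ``grad_enabled = False`` for *every* entry in ``state``, which is exactly what made ``@dr.syntax`` silently detach variables such as ``a`` in the original report.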
Fixed in #299
I have some code that mixes AD with handwritten derivatives and loops (e.g., similar to something like PRB). It appears that within ``dr.syntax``, any use of a differentiable variable within a loop disables the variable's AD graph. Here is an example:

This prints …, but I would have expected it to print ….

The current behavior is a bit unintuitive and leads to a confusing loss of gradient tracking. To me it seems that the loop should have no influence on whether ``a`` has gradients enabled or not, similar to the pre-``dr.syntax`` behavior.