
dr.syntax: AD gets disabled by variable use in while loop #253

Closed
dvicini opened this issue Aug 12, 2024 · 8 comments

Comments

@dvicini
Member

dvicini commented Aug 12, 2024

I have some code that mixes AD with handwritten derivatives & loops (e.g., something similar to PRB).

It appears that within dr.syntax, any use of a differentiable variable within a loop disables the variable's AD graph.

Here is an example:

import drjit as dr

@dr.syntax
def f():

  i = dr.zeros(dr.llvm.Int32, 10)
  result = dr.zeros(dr.llvm.ad.Float, 10)

  a = dr.linspace(dr.llvm.ad.Float, 0, 1, 10)
  dr.enable_grad(a)
  print(dr.grad_enabled(a))

  with dr.suspend_grad():
    while i < 5:
      result += a
      i += 1  # advance the counter so the loop terminates

  print(dr.grad_enabled(a))

f()

This prints

True
False

But I would have expected this to print

True
True

The current behavior is a bit unintuitive and leads to a confusing loss of gradient tracking. To me, it seems that the loop should have no influence on whether a has gradients enabled or not, similar to the pre-dr.syntax behavior.

@njroussel
Member

Hi @dvicini

It looks like the syntax rewriting is a bit aggressive here: it thinks a is part of the loop state. This means that a gets re-assigned at the end of the loop, and hence within the suspend_grad scope. Changing the loop to while dr.hint(i < 5, exclude=[a]): fixes the issue.

I'll look into why it thinks a should be in the state. As a general rule, the @dr.syntax transformation tends to be "safer" than necessary in order to guarantee that the loop is valid, but that can come at a cost, as you see here.
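
For reference, here is a sketch of the reproducer with that one-line change applied (only the loop condition differs):

```python
import drjit as dr

@dr.syntax
def f():
  i = dr.zeros(dr.llvm.Int32, 10)
  result = dr.zeros(dr.llvm.ad.Float, 10)

  a = dr.linspace(dr.llvm.ad.Float, 0, 1, 10)
  dr.enable_grad(a)

  with dr.suspend_grad():
    # 'exclude' keeps 'a' out of the captured loop state, so it is not
    # re-assigned (and thereby detached) when the loop finishes
    while dr.hint(i < 5, exclude=[a]):
      result += a
      i += 1

  print(dr.grad_enabled(a))  # now prints True

f()
```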

@njroussel
Member

Ok, there's not much we can do here, I believe.

In short, even though it's clear in your reproducer that a is not being written to in the loop, we cannot guarantee that in the general case. For example, a random number generator will effectively never appear on the left-hand side of an assignment, but it must still be considered part of the loop state in order for the loop to work, because its state might evolve implicitly (e.g., through calls like sampler.next_1d()). So, in this case, we still consider a to be part of the loop state: even though it only appears on the right-hand side of an assignment, it might have evolved implicitly.
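
To illustrate, a hedged sketch of that pattern (assuming Dr.Jit's dr.llvm.PCG32 random number generator and its next_float32() method):

```python
import drjit as dr

@dr.syntax
def g():
  rng = dr.llvm.PCG32(16)              # RNG state, never assigned to below
  total = dr.zeros(dr.llvm.Float, 16)
  i = dr.zeros(dr.llvm.Int32, 16)
  while i < 4:
    # 'rng' only ever appears on the right-hand side, yet every call below
    # advances its internal state. If it were excluded from the loop state,
    # each iteration would draw the same numbers.
    total += rng.next_float32()
    i += 1
  return total
```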

The workflow for these kinds of situations is usually to add print_code=True to @dr.syntax() and look at the rewritten function. It is usually fairly obvious when too many variables are included; you can then specify them in the exclude list of a dr.hint statement.

@dvicini
Member Author

dvicini commented Aug 14, 2024

Fair enough, thanks for checking. I have to say I wasn't super aware of the various debug options for dr.syntax, but with print_code=True and the exclude hints, this should be okay in practice.

@wjakob
Member

wjakob commented Aug 17, 2024

The loop constructs, dr.syntax, and dr.hint all have extensive documentation. Please take a look and post an issue/PR if anything should be added. We should document any potential gotchas.

wjakob reopened this Sep 27, 2024
@dvicini
Member Author

dvicini commented Oct 4, 2024

I've run into another, seemingly related problem: the following code silently produces a result value of 0 instead of the expected 1. If I remove the if-statement, it works as expected. I tried adding dr.hint(..., exclude=[b]), but that produces an error: RuntimeError: ad_traverse(): tried to forward-propagate derivatives across edge a1 -> a2, which lies outside of the current dr.isolate_grad() scope

import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax(print_code=True)
def f():
  param = mi.Float(0.0)
  dr.enable_grad(param)
  dr.set_grad(param, 1.0)

  a = dr.linspace(mi.Float, 1, 2, 16) + param

  result = mi.Float(0.0)
  b = dr.gather(mi.Float, a, 3)
  if b == b: # Always true
    result += dr.forward_to(b)  # Fails silently

  # Doing the same without the if-statement works as expected
  # result += dr.forward_to(b)

  return result


result = f()
print(result)

wjakob added a commit that referenced this issue Oct 5, 2024
The AD layer exposes a function named ``ad_var_inc_ref()`` that
increases the reference count of a variable analogous to
``jit_var_inc_ref()``. However, one difference between the two is that
the former detaches AD variables when the underlying index has
derivative tracking disabled.

For example, this ensures that code like

```python
x = Float(0)
dr.enable_grad(x)
with dr.suspend_grad():
    y = Float(x)
```
creates a non-differentiable copy.

However, since there are many other operations throughout the Dr.Jit
codebase that require reference counting, there were quite a few places
that exhibited this detaching behavior, which is not always wanted
(see issue #253).

This commit provides two reference counting functions:

- ``ad_var_inc_ref()`` which increases the reference count *without*
  detaching, and

- ``ad_var_copy_ref()``, which detaches (i.e., reproducing the former
  behavior)

Following this split, only the constructor of AD arrays uses the
detaching ``ad_var_copy_ref()``, while all other operations use
the new ``ad_var_inc_ref()``.
@wjakob
Copy link
Member

wjakob commented Oct 5, 2024

The last issue you posted is unrelated to the problem with dr.suspend_grad and symbolic operations. Could you create a separate issue for it? It's actually not quite clear how this should behave in general. Suppose we have an arbitrarily nested sequence of symbolic operations, and the user calls dr.forward at the innermost level. The system then has to, in a sense, travel back in time and forward-propagate derivatives into each of the outer scopes.

The way we handled this in Mitsuba before is that you had to do a forward AD pass outside of the symbolic region, whose derivative values can then be picked up. But this of course isn't fully general.
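
A sketch of that workaround applied to the reproducer above (an adaptation, not verbatim code from this thread): the forward pass runs outside the symbolic if-statement, which then only reads the precomputed derivative.

```python
import drjit as dr
import mitsuba as mi

mi.set_variant('llvm_ad_rgb')

@dr.syntax
def f():
  param = mi.Float(0.0)
  dr.enable_grad(param)
  dr.set_grad(param, 1.0)

  a = dr.linspace(mi.Float, 1, 2, 16) + param
  b = dr.gather(mi.Float, a, 3)

  # Forward AD pass outside of any symbolic region
  db = dr.forward_to(b)

  result = mi.Float(0.0)
  if b == b:        # symbolic if-statement
    result += db    # only picks up the already-computed derivative
  return result

print(f())  # should print a value of 1 with this workaround
```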

@dvicini
Member Author

dvicini commented Oct 5, 2024

Yes, you are right. I created a separate issue to track this: #295

wjakob added a commit that referenced this issue Oct 16, 2024
…pended mode

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope.

This can be surprising when a ``@dr.syntax`` transformation silently
passes local variables to such an operation, which then become
non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD
identity.

This is part #1 of the fix for issue #253 reported by @dvicini and
targets ``dr.if_stmt()`` only. The next commit will also fix the same
problem for while loops.
wjakob added a commit that referenced this issue Oct 16, 2024
…AD-suspended mode

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope.

This can be surprising when a ``@dr.syntax`` transformation silently
passes local variables to such an operation, which then become
non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD
identity.

This is part #2 of the fix for issue #253 reported by @dvicini and
targets ``dr.while_loop()`` only. The previous commit fixed the same
problem for ``if`` statements.
@njroussel
Member

Fixed in #299
