Skip to content

Commit

Permalink
dr.if_stmt(): maintain AD status of unchanged variables in AD-suspended mode
Browse files Browse the repository at this point in the history

Dr.Jit control flow operations (``dr.if_stmt()``, ``dr.while_loop()``)
disable gradient tracking of all variable state when the operation takes
place within an AD-disabled scope.

This can be surprising when a ``@dr.syntax`` transformation silently
passes local variables to such an operation, which then become
non-differentiable.

This commit carves out an exception: when variables aren't actually
modified by the control flow operation, they can retain their AD
identity.

This is part #1 of the fix for issue #253 reported by @dvicini and
targets ``dr.if_stmt()`` only. The next commit will also fix the same
problem for while loops.
  • Loading branch information
wjakob authored and njroussel committed Oct 21, 2024
1 parent 091d8ca commit 494c571
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 20 deletions.
6 changes: 6 additions & 0 deletions include/drjit/extra.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ extern DRJIT_EXTRA_EXPORT uint32_t ad_grad(uint64_t index, bool null_ok = false)
/// Check if gradient tracking is enabled for the given variable
extern DRJIT_EXTRA_EXPORT int ad_grad_enabled(uint64_t index);

/// Check if gradient tracking is disabled (can't create new AD variables)
extern DRJIT_EXTRA_EXPORT int ad_grad_suspended();

/// Temporarily enforce gradient tracking without creating a new scope
extern DRJIT_EXTRA_EXPORT int ad_set_force_grad(int status);

/// Accumulate into the gradient associated with a given variable
extern DRJIT_EXTRA_EXPORT void ad_accum_grad(uint64_t index, uint32_t value);

Expand Down
35 changes: 28 additions & 7 deletions src/extra/autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ struct Scope {
*/
bool isolate = false;

/// Flag to temporarily force gradient tracking in an ad-disabled scope
bool force_grad = false;

// Current ``state.counter`` value when entering this scope
uint64_t counter = 0;

Expand Down Expand Up @@ -537,7 +540,7 @@ struct Scope {

/// Check if a variable has gradients enabled
bool enabled(ADIndex index) const {
return (indices.find(index) != indices.end()) != complement;
return (indices.find(index) != indices.end()) != complement || force_grad;
}

/// Potentially zero out 'index' if the variable has gradients disabled
Expand All @@ -549,7 +552,7 @@ struct Scope {

/// Track gradients for the given variable
void enable(ADIndex index) {
if (!index)
if (!index || force_grad)
return;

if (complement)
Expand Down Expand Up @@ -967,7 +970,7 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result,
active |= scope.maybe_disable(args[i].ad_index);
}

if (!active)
if (!active && !scope.force_grad)
return (Index) result.release();
}

Expand Down Expand Up @@ -1784,6 +1787,25 @@ int ad_grad_enabled(Index index) {
return ad_index != 0;
}

int ad_grad_suspended() {
const std::vector<Scope> &scopes = local_state.scopes;
if (scopes.empty())
return false;
else
return scopes.back().complement == false;
}

/// Temporarily enforce gradient tracking without creating a new scope
int ad_set_force_grad(int status) {
std::vector<Scope> &scopes = local_state.scopes;
if (scopes.empty())
return 0;
Scope &scope = scopes.back();
bool old = scope.force_grad;
scope.force_grad = (bool) status;
return (int) old;
}

// ==========================================================================
// AD traversal callbacks for special operations: masks, gathers, scatters
// ==========================================================================
Expand Down Expand Up @@ -2825,12 +2847,11 @@ Index ad_var_cast(Index i0, VarType vt) {
void ad_var_map_put(Index source, Index target) {
uint32_t ad_index_source = ad_index(source),
ad_index_target = ad_index(target);

if ((ad_index_source == 0) != (ad_index_target == 0))
ad_raise("ad_var_map_put(): mixed attached/detached inputs!");
if (ad_index_target == 0)
return;

if (ad_index_source == 0)
return;
ad_raise("ad_var_map_put(): mixed attached/detached inputs!");

ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target);

Expand Down
6 changes: 3 additions & 3 deletions src/extra/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static void ad_call_getter(JitBackend backend, const char *domain,
"ad_call_getter(\"%s%s%s\", index=r%u, mask=r%u)", domain_or_empty,
separator, name, index, mask.index());

scoped_isolation_boundary guard;
scoped_isolation_guard guard;
{
scoped_record rec(backend, name, true);

Expand Down Expand Up @@ -279,7 +279,7 @@ static void ad_call_symbolic(JitBackend backend, const char *domain,
/* Postponed operations captured by the isolation scope should only
* be executed once we've exited the symbolic scope. We therefore
* need to declare the AD isolation guard before the recording guard. */
scoped_isolation_boundary guard_1(1);
scoped_isolation_guard guard_1(1);

scoped_record guard_2(backend, name, true);
// Recording may fail due to recursion depth
Expand Down Expand Up @@ -631,7 +631,7 @@ struct CallOp : public dr::detail::CustomOpBase {

/// Implements f(arg..., grad(rv)...) -> grad(arg) ...
void backward() override {
scoped_isolation_boundary isolation_guard;
scoped_isolation_guard isolation_guard;
std::string name = m_name + " [ad, bwd]";

index64_vector args, rv;
Expand Down
13 changes: 10 additions & 3 deletions src/extra/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ template <typename T> class unlock_guard {
};

/// RAII AD Isolation helper
struct scoped_isolation_boundary {
scoped_isolation_boundary(int symbolic = -1) : symbolic(symbolic) {
struct scoped_isolation_guard {
scoped_isolation_guard(int symbolic = -1) : symbolic(symbolic) {
ad_scope_enter(drjit::ADScope::Isolate, 0, nullptr, symbolic);
}

~scoped_isolation_boundary() {
~scoped_isolation_guard() {
ad_scope_leave(success);
}

Expand All @@ -58,6 +58,13 @@ struct scoped_isolation_boundary {
bool success = false;
};

/// RAII guard that temporarily forces gradient tracking while inside an
/// AD-suspended scope (via ``ad_set_force_grad()``). The previous state of
/// the flag is captured on construction and restored by the destructor.
struct scoped_force_grad_guard {
    scoped_force_grad_guard() { value = ad_set_force_grad(1); }
    ~scoped_force_grad_guard() { ad_set_force_grad(value); }

    // Non-copyable: a copy would restore the saved flag value a second time
    // when the duplicate is destroyed, clobbering the actual state
    scoped_force_grad_guard(const scoped_force_grad_guard &) = delete;
    scoped_force_grad_guard &operator=(const scoped_force_grad_guard &) = delete;

    /// Previous value of the 'force_grad' flag, restored on destruction
    bool value;
};


/// RAII helper to temporarily push a mask onto the Dr.Jit mask stack
struct scoped_push_mask {
scoped_push_mask(JitBackend backend, uint32_t index) : backend(backend) {
Expand Down
21 changes: 16 additions & 5 deletions src/extra/cond.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,19 @@ static void ad_cond_evaluated(JitBackend backend, const char *label,
index64_vector args_t, args_f;
size_t cond_size = jit_var_size((uint32_t) cond_t);

/// For differentiable inputs, create masked AD variables
// For differentiable inputs, create masked AD variables
for (size_t i = 0; i < args.size(); ++i) {
uint64_t index = args[i];
uint32_t index_lo = (uint32_t) index;
size_t size = jit_var_size((uint32_t) index);
bool is_diff = index != index_lo;

if (is_diff && (size == cond_size || size == 1 || cond_size == 1)) {
// Force the creation of AD variables even when in an AD-suspended
// scope. This is so that we can preserve the AD status of variables
// that aren't changed by the loop
scoped_force_grad_guard guard;

uint64_t idx_t = ad_var_select(cond_t, index, index_lo),
idx_f = ad_var_select(cond_f, index, index_lo);
uint32_t ad_idx_t = (uint32_t) (idx_t >> 32),
Expand Down Expand Up @@ -145,18 +150,22 @@ static void ad_cond_symbolic(JitBackend backend, const char *label,
/* Postponed operations captured by the isolation scope should only
* be executed once we've exited the symbolic scope. We therefore
* need to declare the AD isolation guard before the recording guard. */
scoped_isolation_boundary isolation_guard(1);
scoped_isolation_guard isolation_guard(1);
scoped_record record_guard(backend);

index64_vector args_t, args_f, rv_t, rv_f, cleanup;
dr::vector<uint32_t> tmp;


// For differentiable inputs, create new disconnected AD variables
for (size_t i = 0; i < args.size(); ++i) {
uint64_t index = args[i];
uint32_t index_lo = (uint32_t) index;
if (ad && (args[i] >> 32)) {
// Force the creation of AD variables even when in an AD-suspended
// scope. This is so that we can preserve the AD status of variables
// that aren't changed by the loop
scoped_force_grad_guard guard;

uint64_t idx_t = ad_var_new(index_lo),
idx_f = ad_var_new(index_lo);

Expand Down Expand Up @@ -259,6 +268,8 @@ static void ad_cond_symbolic(JitBackend backend, const char *label,
// Unchanged differentiable outputs can be piped through directly
// without being outputs of the CustomOp
if (ad && ((idx_f >> 32) || (idx_t >> 32))) {
// Force the creation of AD variables even when in an AD-suspended
scoped_force_grad_guard guard;
idx_t = ad_var_map_get(idx_t);
idx_f = ad_var_map_get(idx_f);

Expand Down Expand Up @@ -627,7 +638,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload,
output_offsets, implicit_in, implicit_out, ad);
}

if (!input_offsets.empty() || !output_offsets.empty()) {
if ((!input_offsets.empty() || !output_offsets.empty()) && !ad_grad_suspended()) {
nanobind::ref<CondOp> op = new CondOp(
backend, label, payload, cond, body_cb, delete_cb, args, rv,
input_offsets, output_offsets, implicit_in, implicit_out);
Expand All @@ -641,7 +652,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload,
op->disable(rv);
}
} else {
scoped_isolation_boundary guard;
scoped_isolation_guard guard;
ad_cond_evaluated(backend, label, payload, true_mask.index(),
false_mask.index(), args, rv, body_cb);
guard.disarm();
Expand Down
4 changes: 2 additions & 2 deletions src/extra/loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static bool ad_loop_symbolic(JitBackend backend, const char *name,
/* Postponed operations captured by the isolation scope should only
* be executed once we've exited the symbolic scope. We therefore
* need to declare the AD isolation guard before the recording guard. */
scoped_isolation_boundary isolation_guard(1);
scoped_isolation_guard isolation_guard(1);
scoped_record record_guard(backend);

// Rewrite the loop state variables
Expand Down Expand Up @@ -980,7 +980,7 @@ bool ad_loop(JitBackend backend, int symbolic, int compress,
"drjit.while_loop() for general information on symbolic and "
"evaluated loops, as well as their limitations.");

scoped_isolation_boundary guard;
scoped_isolation_guard guard;
ad_loop_evaluated(backend, name, payload, read_cb, write_cb,
cond_cb, body_cb, compress);
guard.disarm();
Expand Down
29 changes: 29 additions & 0 deletions tests/test_if_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,3 +584,32 @@ class Z:
assert dr.all(x == (10 + (mutate and tt != 'nested'), 20))
assert dr.all(y == (30, 40))


@pytest.test_arrays('float32,is_diff,shape=(*)')
@pytest.mark.parametrize('mode', ['evaluated', 'symbolic'])
def test18_if_stmt_preserve_unused_ad(t, mode):
    """Variables left unmodified by ``dr.if_stmt()`` should retain their AD
    status even when the operation executes inside an AD-suspended scope:
    'x' is changed by 'true_fn' and must become detached, while 'y' is
    passed through untouched and must stay differentiable."""
    with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False):
        x = t(0, 1)
        y = t(1, 3)
        dr.enable_grad(x, y)

        with dr.suspend_grad():
            def true_fn(x, y):
                return x + y, y

            def false_fn(x, y):
                return x, y

            x, y = dr.if_stmt(
                args=(x, y),
                cond=x<.5,
                arg_labels=('x', 'y'),
                true_fn=true_fn,
                false_fn=false_fn,
                mode=mode
            )

        # 'x' was written by the conditional inside the suspended scope
        assert not dr.grad_enabled(x)
        # 'y' was returned unchanged and keeps its AD identity
        assert dr.grad_enabled(y)

0 comments on commit 494c571

Please sign in to comment.