From 488c263cea579e4e5ab2a118469bfee3ce5a7c60 Mon Sep 17 00:00:00 2001 From: Wenzel Jakob Date: Tue, 15 Oct 2024 23:55:39 +0900 Subject: [PATCH] ``dr.if_stmt()``: maintain AD status of unchanged variables in AD-suspended mode Dr.Jit control flow operations (``dr.if_stmt(), drjit.while_loop()``) disable gradient tracking of all variable state when the operation takes place within an AD-disabled scope. This can be surprising when a ``@dr.syntax`` transformation silently passes local variables to such an operation, which then become non-differentiable. This commit carves out an exception: when variables aren't actually modified by the control flow operation, they can retain their AD identity. This is part #1 of the fix for issue #253 reported by @dvicini and targets ``dr.if_stmt()`` only. The next commit will also fix the same problem for while loops. --- include/drjit/extra.h | 6 ++++++ src/extra/autodiff.cpp | 35 ++++++++++++++++++++++++++++------- src/extra/call.cpp | 6 +++--- src/extra/common.h | 13 ++++++++++--- src/extra/cond.cpp | 21 ++++++++++++++++----- src/extra/loop.cpp | 4 ++-- tests/test_if_stmt.py | 29 +++++++++++++++++++++++++++++ 7 files changed, 94 insertions(+), 20 deletions(-) diff --git a/include/drjit/extra.h b/include/drjit/extra.h index 249786496..422d1206e 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -119,6 +119,12 @@ extern DRJIT_EXTRA_EXPORT uint32_t ad_grad(uint64_t index, bool null_ok = false) /// Check if gradient tracking is enabled for the given variable extern DRJIT_EXTRA_EXPORT int ad_grad_enabled(uint64_t index); +/// Check if gradient tracking is disabled (can't create new AD variables) +extern DRJIT_EXTRA_EXPORT int ad_grad_suspended(); + +/// Temporarily enforce gradient tracking without creating a new scope +extern DRJIT_EXTRA_EXPORT int ad_set_force_grad(int status); + /// Accumulate into the gradient associated with a given variable extern DRJIT_EXTRA_EXPORT void ad_accum_grad(uint64_t index, uint32_t 
value); diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index ccbdeb976..17195cebd 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -506,6 +506,9 @@ struct Scope { */ bool isolate = false; + /// Flag to temporarily force gradient tracking in an ad-disabled scope + bool force_grad = false; + // Current ``state.counter`` value when entering this scope uint64_t counter = 0; @@ -537,7 +540,7 @@ struct Scope { /// Check if a variable has gradients enabled bool enabled(ADIndex index) const { - return (indices.find(index) != indices.end()) != complement; + return (indices.find(index) != indices.end()) != complement || force_grad; } /// Potentially zero out 'index' if the variable has gradients disabled @@ -549,7 +552,7 @@ struct Scope { /// Track gradients for the given variable void enable(ADIndex index) { - if (!index) + if (!index || force_grad) return; if (complement) @@ -967,7 +970,7 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result, active |= scope.maybe_disable(args[i].ad_index); } - if (!active) + if (!active && !scope.force_grad) return (Index) result.release(); } @@ -1784,6 +1787,25 @@ int ad_grad_enabled(Index index) { return ad_index != 0; } +int ad_grad_suspended() { + const std::vector<Scope> &scopes = local_state.scopes; + if (scopes.empty()) + return false; + else + return scopes.back().complement == false; +} + +/// Temporarily enforce gradient tracking without creating a new scope +int ad_set_force_grad(int status) { + std::vector<Scope> &scopes = local_state.scopes; + if (scopes.empty()) + return 0; + Scope &scope = scopes.back(); + bool old = scope.force_grad; + scope.force_grad = (bool) status; + return (int) old; +} + // ========================================================================== // AD traversal callbacks for special operations: masks, gathers, scatters // ========================================================================== @@ -2825,12 +2847,11 @@ Index ad_var_cast(Index i0, VarType vt)
{ void ad_var_map_put(Index source, Index target) { uint32_t ad_index_source = ad_index(source), ad_index_target = ad_index(target); - - if ((ad_index_source == 0) != (ad_index_target == 0)) - ad_raise("ad_var_map_put(): mixed attached/detached inputs!"); + if (ad_index_target == 0) + return; if (ad_index_source == 0) - return; + ad_raise("ad_var_map_put(): mixed attached/detached inputs!"); ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target); diff --git a/src/extra/call.cpp b/src/extra/call.cpp index 8db2a27e3..b5d17d7a7 100644 --- a/src/extra/call.cpp +++ b/src/extra/call.cpp @@ -90,7 +90,7 @@ static void ad_call_getter(JitBackend backend, const char *domain, "ad_call_getter(\"%s%s%s\", index=r%u, mask=r%u)", domain_or_empty, separator, name, index, mask.index()); - scoped_isolation_boundary guard; + scoped_isolation_guard guard; { scoped_record rec(backend, name, true); @@ -279,7 +279,7 @@ static void ad_call_symbolic(JitBackend backend, const char *domain, /* Postponed operations captured by the isolation scope should only * be executed once we've exited the symbolic scope. We therefore * need to declare the AD isolation guard before the recording guard. */ - scoped_isolation_boundary guard_1(1); + scoped_isolation_guard guard_1(1); scoped_record guard_2(backend, name, true); // Recording may fail due to recursion depth @@ -631,7 +631,7 @@ struct CallOp : public dr::detail::CustomOpBase { /// Implements f(arg..., grad(rv)...) -> grad(arg) ... 
void backward() override { - scoped_isolation_boundary isolation_guard; + scoped_isolation_guard isolation_guard; std::string name = m_name + " [ad, bwd]"; index64_vector args, rv; diff --git a/src/extra/common.h b/src/extra/common.h index c6c8e5033..6e1af4035 100644 --- a/src/extra/common.h +++ b/src/extra/common.h @@ -38,12 +38,12 @@ template class unlock_guard { }; /// RAII AD Isolation helper -struct scoped_isolation_boundary { - scoped_isolation_boundary(int symbolic = -1) : symbolic(symbolic) { +struct scoped_isolation_guard { + scoped_isolation_guard(int symbolic = -1) : symbolic(symbolic) { ad_scope_enter(drjit::ADScope::Isolate, 0, nullptr, symbolic); } - ~scoped_isolation_boundary() { + ~scoped_isolation_guard() { ad_scope_leave(success); } @@ -58,6 +58,13 @@ struct scoped_isolation_boundary { bool success = false; }; +struct scoped_force_grad_guard { + scoped_force_grad_guard() { value = ad_set_force_grad(1); } + ~scoped_force_grad_guard() { ad_set_force_grad(value); } + bool value; +}; + + /// RAII helper to temporarily push a mask onto the Dr.Jit mask stack struct scoped_push_mask { scoped_push_mask(JitBackend backend, uint32_t index) : backend(backend) { diff --git a/src/extra/cond.cpp b/src/extra/cond.cpp index 19517f208..b3e986571 100644 --- a/src/extra/cond.cpp +++ b/src/extra/cond.cpp @@ -35,7 +35,7 @@ static void ad_cond_evaluated(JitBackend backend, const char *label, index64_vector args_t, args_f; size_t cond_size = jit_var_size((uint32_t) cond_t); - /// For differentiable inputs, create masked AD variables + // For differentiable inputs, create masked AD variables for (size_t i = 0; i < args.size(); ++i) { uint64_t index = args[i]; uint32_t index_lo = (uint32_t) index; @@ -43,6 +43,11 @@ static void ad_cond_evaluated(JitBackend backend, const char *label, bool is_diff = index != index_lo; if (is_diff && (size == cond_size || size == 1 || cond_size == 1)) { + // Force the creation of AD variables even when in an AD-suspended + // scope. 
This is so that we can preserve the AD status of variables + // that aren't changed by the loop + scoped_force_grad_guard guard; + uint64_t idx_t = ad_var_select(cond_t, index, index_lo), idx_f = ad_var_select(cond_f, index, index_lo); uint32_t ad_idx_t = (uint32_t) (idx_t >> 32), @@ -145,18 +150,22 @@ static void ad_cond_symbolic(JitBackend backend, const char *label, /* Postponed operations captured by the isolation scope should only * be executed once we've exited the symbolic scope. We therefore * need to declare the AD isolation guard before the recording guard. */ - scoped_isolation_boundary isolation_guard(1); + scoped_isolation_guard isolation_guard(1); scoped_record record_guard(backend); index64_vector args_t, args_f, rv_t, rv_f, cleanup; dr::vector tmp; - // For differentiable inputs, create new disconnected AD variables for (size_t i = 0; i < args.size(); ++i) { uint64_t index = args[i]; uint32_t index_lo = (uint32_t) index; if (ad && (args[i] >> 32)) { + // Force the creation of AD variables even when in an AD-suspended + // scope. 
This is so that we can preserve the AD status of variables + that aren't changed by the conditional statement + scoped_force_grad_guard guard; + uint64_t idx_t = ad_var_new(index_lo), idx_f = ad_var_new(index_lo); @@ -259,6 +268,8 @@ static void ad_cond_symbolic(JitBackend backend, const char *label, // Unchanged differentiable outputs can be piped through directly // without being outputs of the CustomOp if (ad && ((idx_f >> 32) || (idx_t >> 32))) { + // Force the creation of AD variables even when in an AD-suspended scope + scoped_force_grad_guard guard; idx_t = ad_var_map_get(idx_t); idx_f = ad_var_map_get(idx_f); @@ -627,7 +638,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload, output_offsets, implicit_in, implicit_out, ad); } - if (!input_offsets.empty() || !output_offsets.empty()) { + if ((!input_offsets.empty() || !output_offsets.empty()) && !ad_grad_suspended()) { nanobind::ref<CondOp> op = new CondOp( backend, label, payload, cond, body_cb, delete_cb, args, rv, input_offsets, output_offsets, implicit_in, implicit_out); @@ -641,7 +652,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload, op->disable(rv); } } else { - scoped_isolation_boundary guard; + scoped_isolation_guard guard; ad_cond_evaluated(backend, label, payload, true_mask.index(), false_mask.index(), args, rv, body_cb); guard.disarm(); diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp index fa133acd5..35f0e7e2b 100644 --- a/src/extra/loop.cpp +++ b/src/extra/loop.cpp @@ -45,7 +45,7 @@ static bool ad_loop_symbolic(JitBackend backend, const char *name, /* Postponed operations captured by the isolation scope should only * be executed once we've exited the symbolic scope. We therefore * need to declare the AD isolation guard before the recording guard.
*/ - scoped_isolation_boundary isolation_guard(1); + scoped_isolation_guard isolation_guard(1); scoped_record record_guard(backend); // Rewrite the loop state variables @@ -980,7 +980,7 @@ bool ad_loop(JitBackend backend, int symbolic, int compress, "drjit.while_loop() for general information on symbolic and " "evaluated loops, as well as their limitations."); - scoped_isolation_boundary guard; + scoped_isolation_guard guard; ad_loop_evaluated(backend, name, payload, read_cb, write_cb, cond_cb, body_cb, compress); guard.disarm(); diff --git a/tests/test_if_stmt.py b/tests/test_if_stmt.py index c488a554d..7b53be563 100644 --- a/tests/test_if_stmt.py +++ b/tests/test_if_stmt.py @@ -584,3 +584,32 @@ class Z: assert dr.all(x == (10 + (mutate and tt != 'nested'), 20)) assert dr.all(y == (30, 40)) + +@pytest.test_arrays('float32,is_diff,shape=(*)') +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +def test18_if_stmt_preserve_unused_ad(t, mode): + with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False): + x = t(0, 1) + y = t(1, 3) + dr.enable_grad(x, y) + print(x.index) + print(y.index) + + with dr.suspend_grad(): + def true_fn(x, y): + return x + y, y + + def false_fn(x, y): + return x, y + + x, y = dr.if_stmt( + args=(x, y), + cond=x<.5, + arg_labels=('x', 'y'), + true_fn=true_fn, + false_fn=false_fn, + mode=mode + ) + + assert not dr.grad_enabled(x) + assert dr.grad_enabled(y)