diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h index 10d2c869..b4a2af80 100644 --- a/include/drjit/autodiff.h +++ b/include/drjit/autodiff.h @@ -123,12 +123,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray } DiffArray(const DiffArray &a) { - if constexpr (IsFloat) { - m_index = ad_var_inc_ref(a.m_index); - } else { - m_index = a.m_index; - jit_var_inc_ref(m_index); - } + if constexpr (IsFloat) + m_index = ad_var_copy_ref(a.m_index); + else + m_index = jit_var_inc_ref(a.m_index); } DiffArray(DiffArray &&a) noexcept : m_index(a.m_index) { @@ -148,8 +146,8 @@ struct DRJIT_TRIVIAL_ABI DiffArray m_index = jit_var_cast((uint32_t) v.m_index, Type, 1); } - DiffArray(const Detached &v) : m_index(v.index()) { - jit_var_inc_ref((uint32_t) m_index); + DiffArray(const Detached &v) { + m_index = jit_var_inc_ref((uint32_t) v.index()); } template <typename T, enable_if_scalar_t<T> = 0> @@ -161,14 +159,15 @@ struct DRJIT_TRIVIAL_ABI DiffArray DiffArray &operator=(const DiffArray &a) { Index old_index = m_index; + if constexpr (IsFloat) { - m_index = ad_var_inc_ref(a.m_index); + m_index = ad_var_copy_ref(a.m_index); ad_var_dec_ref(old_index); } else { - m_index = a.m_index; - jit_var_inc_ref(m_index); + m_index = jit_var_inc_ref(a.m_index); jit_var_dec_ref(old_index); } + return *this; } @@ -685,12 +684,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray static DRJIT_INLINE DiffArray borrow(Index index) { DiffArray result; - if constexpr (IsFloat) { + if constexpr (IsFloat) result.m_index = ad_var_inc_ref(index); - } else { - jit_var_inc_ref(index); - result.m_index = index; - } + else + result.m_index = jit_var_inc_ref(index); return result; } @@ -742,7 +739,7 @@ struct DRJIT_TRIVIAL_ABI DiffArray m_index = ad_var_new(jit_index); jit_var_dec_ref(jit_index); } else { - jit_var_inc_ref(jit_index); + jit_index = jit_var_inc_ref(jit_index); ad_var_dec_ref(m_index); m_index = jit_index; } @@ -956,7 +953,7 @@ NAMESPACE_BEGIN(detail) template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) { vector<uint64_t> &indices = *(vector<uint64_t> *) p; if constexpr (IncRef) - ad_var_inc_ref(index); + index = ad_var_inc_ref(index); indices.push_back(index); } @@ -1010,7 +1007,7 @@ struct ad_index32_vector : drjit::vector<uint32_t> { void push_back_steal(uint32_t index) { push_back(index); } void push_back_borrow(uint32_t index) { - push_back(uint32_t(ad_var_inc_ref(uint64_t(index) << 32) >> 32)); + push_back((uint32_t) (ad_var_inc_ref(uint64_t(index) << 32) >> 32)); } }; @@ -1033,9 +1030,7 @@ struct index64_vector : drjit::vector<uint64_t> { } void push_back_steal(uint64_t index) { push_back(index); } - void push_back_borrow(uint64_t index) { - push_back(ad_var_inc_ref(index)); - } + void push_back_borrow(uint64_t index) { push_back(ad_var_inc_ref(index)); } }; NAMESPACE_END(detail) diff --git a/include/drjit/extra.h b/include/drjit/extra.h index d2bea714..422d1206 100644 --- a/include/drjit/extra.h +++ b/include/drjit/extra.h @@ -119,20 +119,28 @@ extern DRJIT_EXTRA_EXPORT uint32_t ad_grad(uint64_t index, bool null_ok = false) /// Check if gradient tracking is enabled for the given variable extern DRJIT_EXTRA_EXPORT int ad_grad_enabled(uint64_t index); +/// Check if gradient tracking is disabled (can't create new AD variables) +extern DRJIT_EXTRA_EXPORT int ad_grad_suspended(); + +/// Temporarily enforce gradient tracking without creating a new scope +extern DRJIT_EXTRA_EXPORT int ad_set_force_grad(int status); + /// Accumulate into the gradient associated with a given variable extern DRJIT_EXTRA_EXPORT void ad_accum_grad(uint64_t index, uint32_t value); /// Clear the gradient of a given variable extern DRJIT_EXTRA_EXPORT void ad_clear_grad(uint64_t index); +/// Increase the reference count of the given AD variable +extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT; + /** - * \brief Increase the reference count of the given AD variable + * \brief Variant of 'ad_var_inc_ref' that conceptually creates a copy * - * This function is typically called when an AD variable is copied. It may - * return a detached variable when an active AD scope disables differentiation - * of the provided input variable. + * This function return a detached variable when an active AD scope disables + * differentiation of the provided input variable. */ -extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT; +extern DRJIT_EXTRA_EXPORT uint64_t ad_var_copy_ref_impl(uint64_t) JIT_NOEXCEPT; /// Decrease the reference count of the given AD variable extern DRJIT_EXTRA_EXPORT void ad_var_dec_ref_impl(uint64_t) JIT_NOEXCEPT; @@ -484,6 +492,15 @@ DRJIT_INLINE uint64_t ad_var_inc_ref(uint64_t index) JIT_NOEXCEPT { return ad_var_inc_ref_impl(index); } +DRJIT_INLINE uint64_t ad_var_copy_ref(uint64_t index) JIT_NOEXCEPT { + /* If 'index' is known at compile time, it can only be zero, in + which case we can skip the redundant call to ad_var_dec_ref */ + if (__builtin_constant_p(index)) + return 0; + else + return ad_var_copy_ref_impl(index); +} + DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT { if (!__builtin_constant_p(index)) ad_var_dec_ref_impl(index); @@ -491,6 +508,7 @@ DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT { #else #define ad_var_dec_ref ad_var_dec_ref_impl #define ad_var_inc_ref ad_var_inc_ref_impl +#define ad_var_copy_ref ad_var_copy_ref_impl #endif // Return the AD reference count of a variable (for debugging) diff --git a/include/drjit/jit.h b/include/drjit/jit.h index 12a4aca1..d89335fd 100644 --- a/include/drjit/jit.h +++ b/include/drjit/jit.h @@ -71,8 +71,8 @@ struct DRJIT_TRIVIAL_ABI JitArray ~JitArray() noexcept { jit_var_dec_ref(m_index); } - JitArray(const JitArray &a) : m_index(a.m_index) { - jit_var_inc_ref(m_index); + JitArray(const JitArray &a) { + m_index = jit_var_inc_ref(a.m_index); } JitArray(JitArray &&a) noexcept : m_index(a.m_index) { @@ -122,9 +122,9 @@ struct DRJIT_TRIVIAL_ABI JitArray } JitArray &operator=(const JitArray &a) { - jit_var_inc_ref(a.m_index); + uint32_t index = jit_var_inc_ref(a.m_index); jit_var_dec_ref(m_index); - m_index = a.m_index; + m_index = index; return *this; } @@ -677,8 +677,7 @@ struct DRJIT_TRIVIAL_ABI JitArray static DRJIT_INLINE JitArray borrow(Index index) { JitArray result; - jit_var_inc_ref(index); - result.m_index = index; + result.m_index = jit_var_inc_ref(index); return result; } @@ -733,8 +732,7 @@ struct index32_vector : drjit::vector<uint32_t> { void push_back_steal(uint32_t index) { push_back(index); } void push_back_borrow(uint32_t index) { - jit_var_inc_ref(index); - push_back(index); + push_back(jit_var_inc_ref(index)); } }; diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp index f993826d..17195ceb 100644 --- a/src/extra/autodiff.cpp +++ b/src/extra/autodiff.cpp @@ -506,6 +506,9 @@ struct Scope { */ bool isolate = false; + /// Flag to temporarily force gradient tracking in an ad-disabled scope + bool force_grad = false; + // Current ``state.counter`` value when entering this scope uint64_t counter = 0; @@ -537,7 +540,7 @@ struct Scope { /// Check if a variable has gradients enabled bool enabled(ADIndex index) const { - return (indices.find(index) != indices.end()) != complement; + return (indices.find(index) != indices.end()) != complement || force_grad; } /// Potentially zero out 'index' if the variable has gradients disabled @@ -549,7 +552,7 @@ struct Scope { /// Track gradients for the given variable void enable(ADIndex index) { - if (!index) + if (!index || force_grad) return; if (complement) @@ -695,11 +698,11 @@ static void ad_free(ADIndex index, Variable *v) { state.unused_variables.push(index); } -Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT { +Index ad_var_copy_ref_impl(Index index) JIT_NOEXCEPT { JitIndex jit_index = ::jit_index(index); ADIndex ad_index = ::ad_index(index); - jit_var_inc_ref(jit_index); + jit_index = jit_var_inc_ref(jit_index); if (unlikely(ad_index)) { const std::vector<Scope> &scopes = local_state.scopes; @@ -715,6 +718,20 @@ Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT { return combine(ad_index, jit_index); } +Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT { + JitIndex jit_index = ::jit_index(index); + ADIndex ad_index = ::ad_index(index); + + jit_var_inc_ref(jit_index); + + if (unlikely(ad_index)) { + std::lock_guard<std::mutex> guard(state.mutex); + ad_var_inc_ref_int(ad_index, state[ad_index]); + } + + return index; +} + uint32_t ad_var_ref(uint64_t index) { uint32_t ad_index = ::ad_index(index); @@ -953,7 +970,7 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result, active |= scope.maybe_disable(args[i].ad_index); } - if (!active) + if (!active && !scope.force_grad) return (Index) result.release(); } @@ -1770,6 +1787,25 @@ int ad_grad_enabled(Index index) { return ad_index != 0; } +int ad_grad_suspended() { + const std::vector<Scope> &scopes = local_state.scopes; + if (scopes.empty()) + return false; + else + return scopes.back().complement == false; +} + +/// Temporarily enforce gradient tracking without creating a new scope +int ad_set_force_grad(int status) { + std::vector<Scope> &scopes = local_state.scopes; + if (scopes.empty()) + return 0; + Scope &scope = scopes.back(); + bool old = scope.force_grad; + scope.force_grad = (bool) status; + return (int) old; +} + // ========================================================================== // AD traversal callbacks for special operations: masks, gathers, scatters // ========================================================================== @@ -2811,12 +2847,11 @@ Index ad_var_cast(Index i0, VarType vt) { void ad_var_map_put(Index source, Index target) { uint32_t ad_index_source = ad_index(source), ad_index_target = ad_index(target); - - if ((ad_index_source == 0) != (ad_index_target == 0)) - ad_raise("ad_var_map_put(): mixed attached/detached inputs!"); + if (ad_index_target == 0) + return; if (ad_index_source == 0) - return; + ad_raise("ad_var_map_put(): mixed attached/detached inputs!"); ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target); diff --git a/src/extra/call.cpp b/src/extra/call.cpp index 8db2a27e..b5d17d7a 100644 --- a/src/extra/call.cpp +++ b/src/extra/call.cpp @@ -90,7 +90,7 @@ static void ad_call_getter(JitBackend backend, const char *domain, "ad_call_getter(\"%s%s%s\", index=r%u, mask=r%u)", domain_or_empty, separator, name, index, mask.index()); - scoped_isolation_boundary guard; + scoped_isolation_guard guard; { scoped_record rec(backend, name, true); @@ -279,7 +279,7 @@ static void ad_call_symbolic(JitBackend backend, const char *domain, /* Postponed operations captured by the isolation scope should only * be executed once we've exited the symbolic scope. We therefore * need to declare the AD isolation guard before the recording guard. */ - scoped_isolation_boundary guard_1(1); + scoped_isolation_guard guard_1(1); scoped_record guard_2(backend, name, true); // Recording may fail due to recursion depth @@ -631,7 +631,7 @@ struct CallOp : public dr::detail::CustomOpBase { /// Implements f(arg..., grad(rv)...) -> grad(arg) ... void backward() override { - scoped_isolation_boundary isolation_guard; + scoped_isolation_guard isolation_guard; std::string name = m_name + " [ad, bwd]"; index64_vector args, rv; diff --git a/src/extra/common.h b/src/extra/common.h index c6c8e503..6e1af403 100644 --- a/src/extra/common.h +++ b/src/extra/common.h @@ -38,12 +38,12 @@ template <typename T> class unlock_guard { }; /// RAII AD Isolation helper -struct scoped_isolation_boundary { - scoped_isolation_boundary(int symbolic = -1) : symbolic(symbolic) { +struct scoped_isolation_guard { + scoped_isolation_guard(int symbolic = -1) : symbolic(symbolic) { ad_scope_enter(drjit::ADScope::Isolate, 0, nullptr, symbolic); } - ~scoped_isolation_boundary() { + ~scoped_isolation_guard() { ad_scope_leave(success); } @@ -58,6 +58,13 @@ struct scoped_isolation_boundary { bool success = false; }; +struct scoped_force_grad_guard { + scoped_force_grad_guard() { value = ad_set_force_grad(1); } + ~scoped_force_grad_guard() { ad_set_force_grad(value); } + bool value; +}; + + /// RAII helper to temporarily push a mask onto the Dr.Jit mask stack struct scoped_push_mask { scoped_push_mask(JitBackend backend, uint32_t index) : backend(backend) { diff --git a/src/extra/cond.cpp b/src/extra/cond.cpp index 19517f20..b3e98657 100644 --- a/src/extra/cond.cpp +++ b/src/extra/cond.cpp @@ -35,7 +35,7 @@ static void ad_cond_evaluated(JitBackend backend, const char *label, index64_vector args_t, args_f; size_t cond_size = jit_var_size((uint32_t) cond_t); - /// For differentiable inputs, create masked AD variables + // For differentiable inputs, create masked AD variables for (size_t i = 0; i < args.size(); ++i) { uint64_t index = args[i]; uint32_t index_lo = (uint32_t) index; @@ -43,6 +43,11 @@ static void ad_cond_evaluated(JitBackend backend, const char *label, bool is_diff = index != index_lo; if (is_diff && (size == cond_size || size == 1 || cond_size == 1)) { + // Force the creation of AD variables even when in an AD-suspended + // scope. This is so that we can preserve the AD status of variables + // that aren't changed by the loop + scoped_force_grad_guard guard; + uint64_t idx_t = ad_var_select(cond_t, index, index_lo), idx_f = ad_var_select(cond_f, index, index_lo); uint32_t ad_idx_t = (uint32_t) (idx_t >> 32), @@ -145,18 +150,22 @@ static void ad_cond_symbolic(JitBackend backend, const char *label, /* Postponed operations captured by the isolation scope should only * be executed once we've exited the symbolic scope. We therefore * need to declare the AD isolation guard before the recording guard. */ - scoped_isolation_boundary isolation_guard(1); + scoped_isolation_guard isolation_guard(1); scoped_record record_guard(backend); index64_vector args_t, args_f, rv_t, rv_f, cleanup; dr::vector<uint32_t> tmp; - // For differentiable inputs, create new disconnected AD variables for (size_t i = 0; i < args.size(); ++i) { uint64_t index = args[i]; uint32_t index_lo = (uint32_t) index; if (ad && (args[i] >> 32)) { + // Force the creation of AD variables even when in an AD-suspended + // scope. This is so that we can preserve the AD status of variables + // that aren't changed by the loop + scoped_force_grad_guard guard; + uint64_t idx_t = ad_var_new(index_lo), idx_f = ad_var_new(index_lo); @@ -259,6 +268,8 @@ static void ad_cond_symbolic(JitBackend backend, const char *label, // Unchanged differentiable outputs can be piped through directly // without being outputs of the CustomOp if (ad && ((idx_f >> 32) || (idx_t >> 32))) { + // Force the creation of AD variables even when in an AD-suspended + scoped_force_grad_guard guard; idx_t = ad_var_map_get(idx_t); idx_f = ad_var_map_get(idx_f); @@ -627,7 +638,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload, output_offsets, implicit_in, implicit_out, ad); } - if (!input_offsets.empty() || !output_offsets.empty()) { + if ((!input_offsets.empty() || !output_offsets.empty()) && !ad_grad_suspended()) { nanobind::ref<CondOp> op = new CondOp( backend, label, payload, cond, body_cb, delete_cb, args, rv, input_offsets, output_offsets, implicit_in, implicit_out); @@ -641,7 +652,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload, op->disable(rv); } } else { - scoped_isolation_boundary guard; + scoped_isolation_guard guard; ad_cond_evaluated(backend, label, payload, true_mask.index(), false_mask.index(), args, rv, body_cb); guard.disarm(); diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp index fa133acd..ddcb66ab 100644 --- a/src/extra/loop.cpp +++ b/src/extra/loop.cpp @@ -45,7 +45,7 @@ static bool ad_loop_symbolic(JitBackend backend, const char *name, /* Postponed operations captured by the isolation scope should only * be executed once we've exited the symbolic scope. We therefore * need to declare the AD isolation guard before the recording guard. */ - scoped_isolation_boundary isolation_guard(1); + scoped_isolation_guard isolation_guard(1); scoped_record record_guard(backend); // Rewrite the loop state variables @@ -155,6 +155,7 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, index64_vector indices2; JitVar active_it; size_t it = 0; + bool grad_suspended = ad_grad_suspended(); while (true) { // Evaluate the loop state @@ -189,8 +190,15 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name, } for (size_t i = 0; i < indices2.size(); ++i) { - uint64_t i1 = indices2[i]; - uint64_t i2 = ad_var_copy(i1); + // Kernel caching: Must create an AD copy so that gradient + // computation steps involving this variable (even if unchangecd + // & only used as a read-only dependency) are correctly placed + // within their associated loop iterations. This does not create + // a copy of the underlying JIT variable. + + uint64_t i1 = indices2[i], + i2 = grad_suspended ? ad_var_inc_ref(i1) : ad_var_copy(i1); + ad_var_dec_ref(i1); ad_mark_loop_boundary(i2); int unused = 0; @@ -926,8 +934,26 @@ bool ad_loop(JitBackend backend, int symbolic, int compress, write_cb, cond_cb, body_cb, indices_in, implicit_in, implicit_out); } + needs_ad &= ad; - if (needs_ad && ad) { + if (needs_ad && ad_grad_suspended()) { + // Maintain differentiability of unchanged variables + bool rewrite = false; + index64_vector indices_out; + + read_cb(payload, indices_out); + for (size_t i = 0; i < indices_out.size(); ++i) { + if ((uint32_t) indices_in[i] == (uint32_t) indices_out[i] && + indices_in[i] != indices_out[i]) { + ad_var_inc_ref(indices_in[i]); + jit_var_dec_ref((uint32_t) indices_out[i]); + indices_out[i] = indices_in[i]; + rewrite = true; + } + } + if (rewrite) + write_cb(payload, indices_out, false); + } else if (needs_ad) { index64_vector indices_out; read_cb(payload, indices_out); @@ -980,7 +1006,7 @@ bool ad_loop(JitBackend backend, int symbolic, int compress, "drjit.while_loop() for general information on symbolic and " "evaluated loops, as well as their limitations."); - scoped_isolation_boundary guard; + scoped_isolation_guard guard; ad_loop_evaluated(backend, name, payload, read_cb, write_cb, cond_cb, body_cb, compress); guard.disarm(); diff --git a/tests/test_if_stmt.py b/tests/test_if_stmt.py index c488a554..64f43093 100644 --- a/tests/test_if_stmt.py +++ b/tests/test_if_stmt.py @@ -584,3 +584,32 @@ class Z: assert dr.all(x == (10 + (mutate and tt != 'nested'), 20)) assert dr.all(y == (30, 40)) + +@pytest.test_arrays('float32,is_diff,shape=(*)') +@pytest.mark.parametrize('mode', ['evaluated', 'symbolic']) +def test18_if_stmt_preserve_unused_ad(t, mode): + with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False): + x = t(0, 1) + y = t(1, 3) + dr.enable_grad(x, y) + y_id = y.index_ad + + with dr.suspend_grad(): + def true_fn(x, y): + return x + y, y + + def false_fn(x, y): + return x, y + + x, y = dr.if_stmt( + args=(x, y), + cond=x<.5, + arg_labels=('x', 'y'), + true_fn=true_fn, + false_fn=false_fn, + mode=mode + ) + + assert not dr.grad_enabled(x) + assert dr.grad_enabled(y) + assert y.index_ad == y_id diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py index 0a2f28c4..0ae0a337 100644 --- a/tests/test_while_loop.py +++ b/tests/test_while_loop.py @@ -664,3 +664,31 @@ def loop(t, x, y: t, n = 10): dr.make_opaque(x, y) y = loop(t, [x, x], y) + + +@pytest.test_arrays('float32,diff,shape=(*)') +@pytest.mark.parametrize('mode', ['symbolic', 'evaluated']) +def test29_preserve_differentiability_suspend(t, mode): + x = t(0, 0) + y = t(1, 2) + dr.enable_grad(x, y) + y_id = y.index_ad + + with dr.suspend_grad(): + def cond_fn(x, _): + return x < 10 + + def body_fn(x, y): + return x + y, y + + x, y = dr.while_loop( + state=(x, y), + cond=cond_fn, + labels=('x', 'y'), + body=body_fn, + mode=mode + ) + + assert not dr.grad_enabled(x) + assert dr.grad_enabled(y) + assert y.index_ad == y_id