diff --git a/include/drjit/autodiff.h b/include/drjit/autodiff.h
index 10d2c869..b4a2af80 100644
--- a/include/drjit/autodiff.h
+++ b/include/drjit/autodiff.h
@@ -123,12 +123,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray
     }
 
     DiffArray(const DiffArray &a) {
-        if constexpr (IsFloat) {
-            m_index = ad_var_inc_ref(a.m_index);
-        } else {
-            m_index = a.m_index;
-            jit_var_inc_ref(m_index);
-        }
+        if constexpr (IsFloat)
+            m_index = ad_var_copy_ref(a.m_index);
+        else
+            m_index = jit_var_inc_ref(a.m_index);
     }
 
     DiffArray(DiffArray &&a) noexcept : m_index(a.m_index) {
@@ -148,8 +146,8 @@ struct DRJIT_TRIVIAL_ABI DiffArray
         m_index = jit_var_cast((uint32_t) v.m_index, Type, 1);
     }
 
-    DiffArray(const Detached &v) : m_index(v.index()) {
-        jit_var_inc_ref((uint32_t) m_index);
+    DiffArray(const Detached &v) {
+        m_index = jit_var_inc_ref((uint32_t) v.index());
     }
 
     template <typename T, enable_if_scalar_t<T> = 0>
@@ -161,14 +159,15 @@ struct DRJIT_TRIVIAL_ABI DiffArray
 
     DiffArray &operator=(const DiffArray &a) {
         Index old_index = m_index;
+
         if constexpr (IsFloat) {
-            m_index = ad_var_inc_ref(a.m_index);
+            m_index = ad_var_copy_ref(a.m_index);
             ad_var_dec_ref(old_index);
         } else {
-            m_index = a.m_index;
-            jit_var_inc_ref(m_index);
+            m_index = jit_var_inc_ref(a.m_index);
             jit_var_dec_ref(old_index);
         }
+
         return *this;
     }
 
@@ -685,12 +684,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray
     static DRJIT_INLINE DiffArray borrow(Index index) {
         DiffArray result;
 
-        if constexpr (IsFloat) {
+        if constexpr (IsFloat)
             result.m_index = ad_var_inc_ref(index);
-        } else {
-            jit_var_inc_ref(index);
-            result.m_index = index;
-        }
+        else
+            result.m_index = jit_var_inc_ref(index);
 
         return result;
     }
@@ -742,7 +739,7 @@ struct DRJIT_TRIVIAL_ABI DiffArray
                 m_index = ad_var_new(jit_index);
                 jit_var_dec_ref(jit_index);
             } else {
-                jit_var_inc_ref(jit_index);
+                jit_index = jit_var_inc_ref(jit_index);
                 ad_var_dec_ref(m_index);
                 m_index = jit_index;
             }
@@ -956,7 +953,7 @@ NAMESPACE_BEGIN(detail)
 template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) {
     vector<uint64_t> &indices = *(vector<uint64_t> *) p;
     if constexpr (IncRef)
-        ad_var_inc_ref(index);
+        index = ad_var_inc_ref(index);
     indices.push_back(index);
 }
 
@@ -1010,7 +1007,7 @@ struct ad_index32_vector : drjit::vector<uint32_t> {
 
     void push_back_steal(uint32_t index) { push_back(index); }
     void push_back_borrow(uint32_t index) {
-        push_back(uint32_t(ad_var_inc_ref(uint64_t(index) << 32) >> 32));
+        push_back((uint32_t) (ad_var_inc_ref(uint64_t(index) << 32) >> 32));
     }
 };
 
@@ -1033,9 +1030,7 @@ struct index64_vector : drjit::vector<uint64_t> {
     }
 
     void push_back_steal(uint64_t index) { push_back(index); }
-    void push_back_borrow(uint64_t index) {
-        push_back(ad_var_inc_ref(index));
-    }
+    void push_back_borrow(uint64_t index) { push_back(ad_var_inc_ref(index)); }
 };
 
 NAMESPACE_END(detail)
diff --git a/include/drjit/extra.h b/include/drjit/extra.h
index d2bea714..422d1206 100644
--- a/include/drjit/extra.h
+++ b/include/drjit/extra.h
@@ -119,20 +119,28 @@ extern DRJIT_EXTRA_EXPORT uint32_t ad_grad(uint64_t index, bool null_ok = false)
 /// Check if gradient tracking is enabled for the given variable
 extern DRJIT_EXTRA_EXPORT int ad_grad_enabled(uint64_t index);
 
+/// Check if gradient tracking is disabled (can't create new AD variables)
+extern DRJIT_EXTRA_EXPORT int ad_grad_suspended();
+
+/// Temporarily enforce gradient tracking without creating a new scope
+extern DRJIT_EXTRA_EXPORT int ad_set_force_grad(int status);
+
 /// Accumulate into the gradient associated with a given variable
 extern DRJIT_EXTRA_EXPORT void ad_accum_grad(uint64_t index, uint32_t value);
 
 /// Clear the gradient of a given variable
 extern DRJIT_EXTRA_EXPORT void ad_clear_grad(uint64_t index);
 
+/// Increase the reference count of the given AD variable
+extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT;
+
 /**
- * \brief Increase the reference count of the given AD variable
+ * \brief Variant of 'ad_var_inc_ref' that conceptually creates a copy
  *
- * This function is typically called when an AD variable is copied. It may
- * return a detached variable when an active AD scope disables differentiation
- * of the provided input variable.
+ * This function return a detached variable when an active AD scope disables
+ * differentiation of the provided input variable.
  */
-extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT;
+extern DRJIT_EXTRA_EXPORT uint64_t ad_var_copy_ref_impl(uint64_t) JIT_NOEXCEPT;
 
 /// Decrease the reference count of the given AD variable
 extern DRJIT_EXTRA_EXPORT void ad_var_dec_ref_impl(uint64_t) JIT_NOEXCEPT;
@@ -484,6 +492,15 @@ DRJIT_INLINE uint64_t ad_var_inc_ref(uint64_t index) JIT_NOEXCEPT {
         return ad_var_inc_ref_impl(index);
 }
 
+DRJIT_INLINE uint64_t ad_var_copy_ref(uint64_t index) JIT_NOEXCEPT {
+    /* If 'index' is known at compile time, it can only be zero, in
+       which case we can skip the redundant call to ad_var_dec_ref */
+    if (__builtin_constant_p(index))
+        return 0;
+    else
+        return ad_var_copy_ref_impl(index);
+}
+
 DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT {
     if (!__builtin_constant_p(index))
         ad_var_dec_ref_impl(index);
@@ -491,6 +508,7 @@ DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT {
 #else
 #define ad_var_dec_ref ad_var_dec_ref_impl
 #define ad_var_inc_ref ad_var_inc_ref_impl
+#define ad_var_copy_ref ad_var_copy_ref_impl
 #endif
 
 // Return the AD reference count of a variable (for debugging)
diff --git a/include/drjit/jit.h b/include/drjit/jit.h
index 12a4aca1..d89335fd 100644
--- a/include/drjit/jit.h
+++ b/include/drjit/jit.h
@@ -71,8 +71,8 @@ struct DRJIT_TRIVIAL_ABI JitArray
 
     ~JitArray() noexcept { jit_var_dec_ref(m_index); }
 
-    JitArray(const JitArray &a) : m_index(a.m_index) {
-        jit_var_inc_ref(m_index);
+    JitArray(const JitArray &a) {
+        m_index = jit_var_inc_ref(a.m_index);
     }
 
     JitArray(JitArray &&a) noexcept : m_index(a.m_index) {
@@ -122,9 +122,9 @@ struct DRJIT_TRIVIAL_ABI JitArray
     }
 
     JitArray &operator=(const JitArray &a) {
-        jit_var_inc_ref(a.m_index);
+        uint32_t index = jit_var_inc_ref(a.m_index);
         jit_var_dec_ref(m_index);
-        m_index = a.m_index;
+        m_index = index;
         return *this;
     }
 
@@ -677,8 +677,7 @@ struct DRJIT_TRIVIAL_ABI JitArray
 
     static DRJIT_INLINE JitArray borrow(Index index) {
         JitArray result;
-        jit_var_inc_ref(index);
-        result.m_index = index;
+        result.m_index = jit_var_inc_ref(index);
         return result;
     }
 
@@ -733,8 +732,7 @@ struct index32_vector : drjit::vector<uint32_t> {
 
     void push_back_steal(uint32_t index) { push_back(index); }
     void push_back_borrow(uint32_t index) {
-        jit_var_inc_ref(index);
-        push_back(index);
+        push_back(jit_var_inc_ref(index));
     }
 };
 
diff --git a/src/extra/autodiff.cpp b/src/extra/autodiff.cpp
index f993826d..17195ceb 100644
--- a/src/extra/autodiff.cpp
+++ b/src/extra/autodiff.cpp
@@ -506,6 +506,9 @@ struct Scope {
      */
     bool isolate = false;
 
+    /// Flag to temporarily force gradient tracking in an ad-disabled scope
+    bool force_grad = false;
+
     // Current ``state.counter`` value when entering this scope
     uint64_t counter = 0;
 
@@ -537,7 +540,7 @@ struct Scope {
 
     /// Check if a variable has gradients enabled
     bool enabled(ADIndex index) const {
-        return (indices.find(index) != indices.end()) != complement;
+        return (indices.find(index) != indices.end()) != complement || force_grad;
     }
 
     /// Potentially zero out 'index' if the variable has gradients disabled
@@ -549,7 +552,7 @@ struct Scope {
 
     /// Track gradients for the given variable
     void enable(ADIndex index) {
-        if (!index)
+        if (!index || force_grad)
             return;
 
         if (complement)
@@ -695,11 +698,11 @@ static void ad_free(ADIndex index, Variable *v) {
     state.unused_variables.push(index);
 }
 
-Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
+Index ad_var_copy_ref_impl(Index index) JIT_NOEXCEPT {
     JitIndex jit_index = ::jit_index(index);
     ADIndex ad_index = ::ad_index(index);
 
-    jit_var_inc_ref(jit_index);
+    jit_index = jit_var_inc_ref(jit_index);
 
     if (unlikely(ad_index)) {
         const std::vector<Scope> &scopes = local_state.scopes;
@@ -715,6 +718,20 @@ Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
     return combine(ad_index, jit_index);
 }
 
+Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
+    JitIndex jit_index = ::jit_index(index);
+    ADIndex ad_index = ::ad_index(index);
+
+    jit_var_inc_ref(jit_index);
+
+    if (unlikely(ad_index)) {
+        std::lock_guard<std::mutex> guard(state.mutex);
+        ad_var_inc_ref_int(ad_index, state[ad_index]);
+    }
+
+    return index;
+}
+
 
 uint32_t ad_var_ref(uint64_t index) {
     uint32_t ad_index = ::ad_index(index);
@@ -953,7 +970,7 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result,
                 active |= scope.maybe_disable(args[i].ad_index);
         }
 
-        if (!active)
+        if (!active && !scope.force_grad)
             return (Index) result.release();
     }
 
@@ -1770,6 +1787,25 @@ int ad_grad_enabled(Index index) {
     return ad_index != 0;
 }
 
+int ad_grad_suspended() {
+    const std::vector<Scope> &scopes = local_state.scopes;
+    if (scopes.empty())
+        return false;
+    else
+        return scopes.back().complement == false;
+}
+
+/// Temporarily enforce gradient tracking without creating a new scope
+int ad_set_force_grad(int status) {
+    std::vector<Scope> &scopes = local_state.scopes;
+    if (scopes.empty())
+        return 0;
+    Scope &scope = scopes.back();
+    bool old = scope.force_grad;
+    scope.force_grad = (bool) status;
+    return (int) old;
+}
+
 // ==========================================================================
 // AD traversal callbacks for special operations: masks, gathers, scatters
 // ==========================================================================
@@ -2811,12 +2847,11 @@ Index ad_var_cast(Index i0, VarType vt) {
 void ad_var_map_put(Index source, Index target) {
     uint32_t ad_index_source = ad_index(source),
              ad_index_target = ad_index(target);
-
-    if ((ad_index_source == 0) != (ad_index_target == 0))
-        ad_raise("ad_var_map_put(): mixed attached/detached inputs!");
+    if (ad_index_target == 0)
+        return;
 
     if (ad_index_source == 0)
-        return;
+        ad_raise("ad_var_map_put(): mixed attached/detached inputs!");
 
     ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target);
 
diff --git a/src/extra/call.cpp b/src/extra/call.cpp
index 8db2a27e..b5d17d7a 100644
--- a/src/extra/call.cpp
+++ b/src/extra/call.cpp
@@ -90,7 +90,7 @@ static void ad_call_getter(JitBackend backend, const char *domain,
             "ad_call_getter(\"%s%s%s\", index=r%u, mask=r%u)", domain_or_empty,
             separator, name, index, mask.index());
 
-    scoped_isolation_boundary guard;
+    scoped_isolation_guard guard;
     {
         scoped_record rec(backend, name, true);
 
@@ -279,7 +279,7 @@ static void ad_call_symbolic(JitBackend backend, const char *domain,
         /* Postponed operations captured by the isolation scope should only
          * be executed once we've exited the symbolic scope. We therefore
          * need to declare the AD isolation guard before the recording guard. */
-        scoped_isolation_boundary guard_1(1);
+        scoped_isolation_guard guard_1(1);
 
         scoped_record guard_2(backend, name, true);
         // Recording may fail due to recursion depth
@@ -631,7 +631,7 @@ struct CallOp : public dr::detail::CustomOpBase {
 
     /// Implements f(arg..., grad(rv)...) -> grad(arg) ...
     void backward() override {
-        scoped_isolation_boundary isolation_guard;
+        scoped_isolation_guard isolation_guard;
         std::string name = m_name + " [ad, bwd]";
 
         index64_vector args, rv;
diff --git a/src/extra/common.h b/src/extra/common.h
index c6c8e503..6e1af403 100644
--- a/src/extra/common.h
+++ b/src/extra/common.h
@@ -38,12 +38,12 @@ template <typename T> class unlock_guard {
 };
 
 /// RAII AD Isolation helper
-struct scoped_isolation_boundary {
-    scoped_isolation_boundary(int symbolic = -1) : symbolic(symbolic) {
+struct scoped_isolation_guard {
+    scoped_isolation_guard(int symbolic = -1) : symbolic(symbolic) {
         ad_scope_enter(drjit::ADScope::Isolate, 0, nullptr, symbolic);
     }
 
-    ~scoped_isolation_boundary() {
+    ~scoped_isolation_guard() {
         ad_scope_leave(success);
     }
 
@@ -58,6 +58,13 @@ struct scoped_isolation_boundary {
     bool success = false;
 };
 
+struct scoped_force_grad_guard {
+    scoped_force_grad_guard() { value = ad_set_force_grad(1); }
+    ~scoped_force_grad_guard() { ad_set_force_grad(value); }
+    bool value;
+};
+
+
 /// RAII helper to temporarily push a mask onto the Dr.Jit mask stack
 struct scoped_push_mask {
     scoped_push_mask(JitBackend backend, uint32_t index) : backend(backend) {
diff --git a/src/extra/cond.cpp b/src/extra/cond.cpp
index 19517f20..b3e98657 100644
--- a/src/extra/cond.cpp
+++ b/src/extra/cond.cpp
@@ -35,7 +35,7 @@ static void ad_cond_evaluated(JitBackend backend, const char *label,
     index64_vector args_t, args_f;
     size_t cond_size = jit_var_size((uint32_t) cond_t);
 
-    /// For differentiable inputs, create masked AD variables
+    // For differentiable inputs, create masked AD variables
     for (size_t i = 0; i < args.size(); ++i) {
         uint64_t index = args[i];
         uint32_t index_lo = (uint32_t) index;
@@ -43,6 +43,11 @@ static void ad_cond_evaluated(JitBackend backend, const char *label,
         bool is_diff = index != index_lo;
 
         if (is_diff && (size == cond_size || size == 1 || cond_size == 1)) {
+            // Force the creation of AD variables even when in an AD-suspended
+            // scope. This is so that we can preserve the AD status of variables
+            // that aren't changed by the loop
+            scoped_force_grad_guard guard;
+
             uint64_t idx_t = ad_var_select(cond_t, index, index_lo),
                      idx_f = ad_var_select(cond_f, index, index_lo);
             uint32_t ad_idx_t = (uint32_t) (idx_t >> 32),
@@ -145,18 +150,22 @@ static void ad_cond_symbolic(JitBackend backend, const char *label,
     /* Postponed operations captured by the isolation scope should only
      * be executed once we've exited the symbolic scope. We therefore
      * need to declare the AD isolation guard before the recording guard. */
-    scoped_isolation_boundary isolation_guard(1);
+    scoped_isolation_guard isolation_guard(1);
     scoped_record record_guard(backend);
 
     index64_vector args_t, args_f, rv_t, rv_f, cleanup;
     dr::vector<uint32_t> tmp;
 
-
     // For differentiable inputs, create new disconnected AD variables
     for (size_t i = 0; i < args.size(); ++i) {
         uint64_t index = args[i];
         uint32_t index_lo = (uint32_t) index;
         if (ad && (args[i] >> 32)) {
+            // Force the creation of AD variables even when in an AD-suspended
+            // scope. This is so that we can preserve the AD status of variables
+            // that aren't changed by the loop
+            scoped_force_grad_guard guard;
+
             uint64_t idx_t = ad_var_new(index_lo),
                      idx_f = ad_var_new(index_lo);
 
@@ -259,6 +268,8 @@ static void ad_cond_symbolic(JitBackend backend, const char *label,
         // Unchanged differentiable outputs can be piped through directly
         // without being outputs of the CustomOp
         if (ad && ((idx_f >> 32) || (idx_t >> 32))) {
+            // Force the creation of AD variables even when in an AD-suspended
+            scoped_force_grad_guard guard;
             idx_t = ad_var_map_get(idx_t);
             idx_f = ad_var_map_get(idx_f);
 
@@ -627,7 +638,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload,
                              output_offsets, implicit_in, implicit_out, ad);
         }
 
-        if (!input_offsets.empty() || !output_offsets.empty()) {
+        if ((!input_offsets.empty() || !output_offsets.empty()) && !ad_grad_suspended()) {
             nanobind::ref<CondOp> op = new CondOp(
                 backend, label, payload, cond, body_cb, delete_cb, args, rv,
                 input_offsets, output_offsets, implicit_in, implicit_out);
@@ -641,7 +652,7 @@ bool ad_cond(JitBackend backend, int symbolic, const char *label, void *payload,
             op->disable(rv);
         }
     } else {
-        scoped_isolation_boundary guard;
+        scoped_isolation_guard guard;
         ad_cond_evaluated(backend, label, payload, true_mask.index(),
                           false_mask.index(), args, rv, body_cb);
         guard.disarm();
diff --git a/src/extra/loop.cpp b/src/extra/loop.cpp
index fa133acd..ddcb66ab 100644
--- a/src/extra/loop.cpp
+++ b/src/extra/loop.cpp
@@ -45,7 +45,7 @@ static bool ad_loop_symbolic(JitBackend backend, const char *name,
         /* Postponed operations captured by the isolation scope should only
          * be executed once we've exited the symbolic scope. We therefore
          * need to declare the AD isolation guard before the recording guard. */
-        scoped_isolation_boundary isolation_guard(1);
+        scoped_isolation_guard isolation_guard(1);
         scoped_record record_guard(backend);
 
         // Rewrite the loop state variables
@@ -155,6 +155,7 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name,
     index64_vector indices2;
     JitVar active_it;
     size_t it = 0;
+    bool grad_suspended = ad_grad_suspended();
 
     while (true) {
         // Evaluate the loop state
@@ -189,8 +190,15 @@ static size_t ad_loop_evaluated_mask(JitBackend backend, const char *name,
         }
 
         for (size_t i = 0; i < indices2.size(); ++i) {
-            uint64_t i1 = indices2[i];
-            uint64_t i2 = ad_var_copy(i1);
+            // Kernel caching: Must create an AD copy so that gradient
+            // computation steps involving this variable (even if unchangecd
+            // & only used as a read-only dependency) are correctly placed
+            // within their associated loop iterations. This does not create
+            // a copy of the underlying JIT variable.
+
+            uint64_t i1 = indices2[i],
+                     i2 = grad_suspended ? ad_var_inc_ref(i1) : ad_var_copy(i1);
+
             ad_var_dec_ref(i1);
             ad_mark_loop_boundary(i2);
             int unused = 0;
@@ -926,8 +934,26 @@ bool ad_loop(JitBackend backend, int symbolic, int compress,
                                         write_cb, cond_cb, body_cb, indices_in,
                                         implicit_in, implicit_out);
         }
+        needs_ad &= ad;
 
-        if (needs_ad && ad) {
+        if (needs_ad && ad_grad_suspended()) {
+            // Maintain differentiability of unchanged variables
+            bool rewrite = false;
+            index64_vector indices_out;
+
+            read_cb(payload, indices_out);
+            for (size_t i = 0; i < indices_out.size(); ++i) {
+                if ((uint32_t) indices_in[i] == (uint32_t) indices_out[i] &&
+                    indices_in[i] != indices_out[i]) {
+                    ad_var_inc_ref(indices_in[i]);
+                    jit_var_dec_ref((uint32_t) indices_out[i]);
+                    indices_out[i] = indices_in[i];
+                    rewrite = true;
+                }
+            }
+            if (rewrite)
+                write_cb(payload, indices_out, false);
+        } else if (needs_ad) {
             index64_vector indices_out;
             read_cb(payload, indices_out);
 
@@ -980,7 +1006,7 @@ bool ad_loop(JitBackend backend, int symbolic, int compress,
                       "drjit.while_loop() for general information on symbolic and "
                       "evaluated loops, as well as their limitations.");
 
-        scoped_isolation_boundary guard;
+        scoped_isolation_guard guard;
         ad_loop_evaluated(backend, name, payload, read_cb, write_cb,
                           cond_cb, body_cb, compress);
         guard.disarm();
diff --git a/tests/test_if_stmt.py b/tests/test_if_stmt.py
index c488a554..64f43093 100644
--- a/tests/test_if_stmt.py
+++ b/tests/test_if_stmt.py
@@ -584,3 +584,32 @@ class Z:
     assert dr.all(x == (10 + (mutate and tt != 'nested'), 20))
     assert dr.all(y == (30, 40))
 
+
+@pytest.test_arrays('float32,is_diff,shape=(*)')
+@pytest.mark.parametrize('mode', ['evaluated', 'symbolic'])
+def test18_if_stmt_preserve_unused_ad(t, mode):
+    with dr.scoped_set_flag(dr.JitFlag.SymbolicConditionals, False):
+        x = t(0, 1)
+        y = t(1, 3)
+        dr.enable_grad(x, y)
+        y_id = y.index_ad
+
+        with dr.suspend_grad():
+            def true_fn(x, y):
+                return x + y, y
+
+            def false_fn(x, y):
+                return x, y
+
+            x, y = dr.if_stmt(
+                args=(x, y),
+                cond=x<.5,
+                arg_labels=('x', 'y'),
+                true_fn=true_fn,
+                false_fn=false_fn,
+                mode=mode
+            )
+
+        assert not dr.grad_enabled(x)
+        assert dr.grad_enabled(y)
+        assert y.index_ad == y_id
diff --git a/tests/test_while_loop.py b/tests/test_while_loop.py
index 0a2f28c4..0ae0a337 100644
--- a/tests/test_while_loop.py
+++ b/tests/test_while_loop.py
@@ -664,3 +664,31 @@ def loop(t, x, y: t, n = 10):
     dr.make_opaque(x, y)
 
     y = loop(t, [x, x], y)
+
+
+@pytest.test_arrays('float32,diff,shape=(*)')
+@pytest.mark.parametrize('mode', ['symbolic', 'evaluated'])
+def test29_preserve_differentiability_suspend(t, mode):
+    x = t(0, 0)
+    y = t(1, 2)
+    dr.enable_grad(x, y)
+    y_id = y.index_ad
+
+    with dr.suspend_grad():
+        def cond_fn(x, _):
+            return x < 10
+
+        def body_fn(x, y):
+            return x + y, y
+
+        x, y = dr.while_loop(
+            state=(x, y),
+            cond=cond_fn,
+            labels=('x', 'y'),
+            body=body_fn,
+            mode=mode
+        )
+
+    assert not dr.grad_enabled(x)
+    assert dr.grad_enabled(y)
+    assert y.index_ad == y_id