Skip to content

Commit

Permalink
Tweak to reference counting semantics
Browse files Browse the repository at this point in the history
The AD layer exposes a function named ``ad_var_inc_ref()`` that
increases the reference count of a variable analogous to
``jit_var_inc_ref()``. However, one difference between the two is that
the former detaches AD variables when the underlying index has
derivative tracking disabled.

For example, this ensures that code like

```python
x = Float(0)
dr.enable_grad(x)
with dr.suspend_grad():
    y = Float(x)
```
creates a non-differentiable copy.

However, since there are many other operations throughout the Dr.Jit
codebase that require reference counting, there were quite a few places
that exhibited this detaching behavior, which is not always wanted
(see issue #253).

This commit provides two reference counting functions:

- ``ad_var_inc_ref()`` which increases the reference count *without*
  detaching, and

- ``ad_var_copy_ref()``, which detaches (i.e., reproducing the former
  behavior)

Following this split, only the constructor of AD arrays uses the
detaching ``ad_var_copy_ref()``, while all other operations use
the new ``ad_var_inc_ref()``.
  • Loading branch information
wjakob authored and njroussel committed Oct 21, 2024
1 parent f62bf3c commit 091d8ca
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 37 deletions.
39 changes: 17 additions & 22 deletions include/drjit/autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray
}

DiffArray(const DiffArray &a) {
if constexpr (IsFloat) {
m_index = ad_var_inc_ref(a.m_index);
} else {
m_index = a.m_index;
jit_var_inc_ref(m_index);
}
if constexpr (IsFloat)
m_index = ad_var_copy_ref(a.m_index);
else
m_index = jit_var_inc_ref(a.m_index);
}

DiffArray(DiffArray &&a) noexcept : m_index(a.m_index) {
Expand All @@ -148,8 +146,8 @@ struct DRJIT_TRIVIAL_ABI DiffArray
m_index = jit_var_cast((uint32_t) v.m_index, Type, 1);
}

DiffArray(const Detached &v) : m_index(v.index()) {
jit_var_inc_ref((uint32_t) m_index);
DiffArray(const Detached &v) {
m_index = jit_var_inc_ref((uint32_t) v.index());
}

template <typename T, enable_if_scalar_t<T> = 0>
Expand All @@ -161,14 +159,15 @@ struct DRJIT_TRIVIAL_ABI DiffArray

DiffArray &operator=(const DiffArray &a) {
Index old_index = m_index;

if constexpr (IsFloat) {
m_index = ad_var_inc_ref(a.m_index);
m_index = ad_var_copy_ref(a.m_index);
ad_var_dec_ref(old_index);
} else {
m_index = a.m_index;
jit_var_inc_ref(m_index);
m_index = jit_var_inc_ref(a.m_index);
jit_var_dec_ref(old_index);
}

return *this;
}

Expand Down Expand Up @@ -685,12 +684,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray
static DRJIT_INLINE DiffArray borrow(Index index) {
DiffArray result;

if constexpr (IsFloat) {
if constexpr (IsFloat)
result.m_index = ad_var_inc_ref(index);
} else {
jit_var_inc_ref(index);
result.m_index = index;
}
else
result.m_index = jit_var_inc_ref(index);

return result;
}
Expand Down Expand Up @@ -742,7 +739,7 @@ struct DRJIT_TRIVIAL_ABI DiffArray
m_index = ad_var_new(jit_index);
jit_var_dec_ref(jit_index);
} else {
jit_var_inc_ref(jit_index);
jit_index = jit_var_inc_ref(jit_index);
ad_var_dec_ref(m_index);
m_index = jit_index;
}
Expand Down Expand Up @@ -956,7 +953,7 @@ NAMESPACE_BEGIN(detail)
template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) {
vector<uint64_t> &indices = *(vector<uint64_t> *) p;
if constexpr (IncRef)
ad_var_inc_ref(index);
index = ad_var_inc_ref(index);
indices.push_back(index);
}

Expand Down Expand Up @@ -1010,7 +1007,7 @@ struct ad_index32_vector : drjit::vector<uint32_t> {

void push_back_steal(uint32_t index) { push_back(index); }
void push_back_borrow(uint32_t index) {
push_back(uint32_t(ad_var_inc_ref(uint64_t(index) << 32) >> 32));
push_back((uint32_t) (ad_var_inc_ref(uint64_t(index) << 32) >> 32));
}
};

Expand All @@ -1033,9 +1030,7 @@ struct index64_vector : drjit::vector<uint64_t> {
}

void push_back_steal(uint64_t index) { push_back(index); }
void push_back_borrow(uint64_t index) {
push_back(ad_var_inc_ref(index));
}
void push_back_borrow(uint64_t index) { push_back(ad_var_inc_ref(index)); }
};

NAMESPACE_END(detail)
Expand Down
22 changes: 17 additions & 5 deletions include/drjit/extra.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,16 @@ extern DRJIT_EXTRA_EXPORT void ad_accum_grad(uint64_t index, uint32_t value);
/// Clear the gradient of a given variable
extern DRJIT_EXTRA_EXPORT void ad_clear_grad(uint64_t index);

/// Increase the reference count of the given AD variable
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT;

/**
* \brief Increase the reference count of the given AD variable
* \brief Variant of 'ad_var_inc_ref' that conceptually creates a copy
*
* This function is typically called when an AD variable is copied. It may
* return a detached variable when an active AD scope disables differentiation
* of the provided input variable.
* This function returns a detached variable when an active AD scope disables
* differentiation of the provided input variable.
*/
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT;
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_copy_ref_impl(uint64_t) JIT_NOEXCEPT;

/// Decrease the reference count of the given AD variable
extern DRJIT_EXTRA_EXPORT void ad_var_dec_ref_impl(uint64_t) JIT_NOEXCEPT;
Expand Down Expand Up @@ -484,13 +486,23 @@ DRJIT_INLINE uint64_t ad_var_inc_ref(uint64_t index) JIT_NOEXCEPT {
return ad_var_inc_ref_impl(index);
}

/// Inline front-end for 'ad_var_copy_ref_impl' that avoids the library call
/// when the compiler can prove that 'index' is a constant.
DRJIT_INLINE uint64_t ad_var_copy_ref(uint64_t index) JIT_NOEXCEPT {
/* If 'index' is known at compile time, it can only be zero, in
which case we can skip the redundant call to ad_var_copy_ref_impl
and directly return the (empty) zero index */
if (__builtin_constant_p(index))
return 0;
else
return ad_var_copy_ref_impl(index);
}

/// Inline front-end for 'ad_var_dec_ref_impl' (decreases the reference count
/// of the given AD variable). The library call is only issued when 'index'
/// is not a compile-time constant; presumably a constant index is always
/// zero, in which case there is nothing to release.
DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT {
if (!__builtin_constant_p(index))
ad_var_dec_ref_impl(index);
}
#else
#define ad_var_dec_ref ad_var_dec_ref_impl
#define ad_var_inc_ref ad_var_inc_ref_impl
#define ad_var_copy_ref ad_var_copy_ref_impl
#endif

// Return the AD reference count of a variable (for debugging)
Expand Down
14 changes: 6 additions & 8 deletions include/drjit/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ struct DRJIT_TRIVIAL_ABI JitArray

~JitArray() noexcept { jit_var_dec_ref(m_index); }

JitArray(const JitArray &a) : m_index(a.m_index) {
jit_var_inc_ref(m_index);
JitArray(const JitArray &a) {
m_index = jit_var_inc_ref(a.m_index);
}

JitArray(JitArray &&a) noexcept : m_index(a.m_index) {
Expand Down Expand Up @@ -122,9 +122,9 @@ struct DRJIT_TRIVIAL_ABI JitArray
}

JitArray &operator=(const JitArray &a) {
jit_var_inc_ref(a.m_index);
uint32_t index = jit_var_inc_ref(a.m_index);
jit_var_dec_ref(m_index);
m_index = a.m_index;
m_index = index;
return *this;
}

Expand Down Expand Up @@ -677,8 +677,7 @@ struct DRJIT_TRIVIAL_ABI JitArray

static DRJIT_INLINE JitArray borrow(Index index) {
JitArray result;
jit_var_inc_ref(index);
result.m_index = index;
result.m_index = jit_var_inc_ref(index);
return result;
}

Expand Down Expand Up @@ -733,8 +732,7 @@ struct index32_vector : drjit::vector<uint32_t> {

void push_back_steal(uint32_t index) { push_back(index); }
void push_back_borrow(uint32_t index) {
jit_var_inc_ref(index);
push_back(index);
push_back(jit_var_inc_ref(index));
}
};

Expand Down
18 changes: 16 additions & 2 deletions src/extra/autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,11 @@ static void ad_free(ADIndex index, Variable *v) {
state.unused_variables.push(index);
}

Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
Index ad_var_copy_ref_impl(Index index) JIT_NOEXCEPT {
JitIndex jit_index = ::jit_index(index);
ADIndex ad_index = ::ad_index(index);

jit_var_inc_ref(jit_index);
jit_index = jit_var_inc_ref(jit_index);

if (unlikely(ad_index)) {
const std::vector<Scope> &scopes = local_state.scopes;
Expand All @@ -715,6 +715,20 @@ Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
return combine(ad_index, jit_index);
}

/// Increase the reference count of the given AD variable *without* detaching.
///
/// The combined index is split into its JIT and AD components. The JIT
/// component's reference count is always increased; the AD component's
/// reference count is only increased when it is nonzero. The input index is
/// returned unchanged.
Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
JitIndex jit_index = ::jit_index(index);
ADIndex ad_index = ::ad_index(index);

// The JIT-level reference is acquired regardless of AD tracking
jit_var_inc_ref(jit_index);

if (unlikely(ad_index)) {
// Hold the global AD state mutex while touching the AD variable record
std::lock_guard<std::mutex> guard(state.mutex);
ad_var_inc_ref_int(ad_index, state[ad_index]);
}

// Both components are preserved, hence the caller receives the same index
return index;
}


uint32_t ad_var_ref(uint64_t index) {
uint32_t ad_index = ::ad_index(index);
Expand Down

0 comments on commit 091d8ca

Please sign in to comment.