Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the behavior of control flow operations (dr.if_stmt(), drjit.while_loop()) in AD-suspended scopes #299

Merged
merged 3 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions include/drjit/autodiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray
}

DiffArray(const DiffArray &a) {
if constexpr (IsFloat) {
m_index = ad_var_inc_ref(a.m_index);
} else {
m_index = a.m_index;
jit_var_inc_ref(m_index);
}
if constexpr (IsFloat)
m_index = ad_var_copy_ref(a.m_index);
else
m_index = jit_var_inc_ref(a.m_index);
}

DiffArray(DiffArray &&a) noexcept : m_index(a.m_index) {
Expand All @@ -148,8 +146,8 @@ struct DRJIT_TRIVIAL_ABI DiffArray
m_index = jit_var_cast((uint32_t) v.m_index, Type, 1);
}

DiffArray(const Detached &v) : m_index(v.index()) {
jit_var_inc_ref((uint32_t) m_index);
DiffArray(const Detached &v) {
m_index = jit_var_inc_ref((uint32_t) v.index());
}

template <typename T, enable_if_scalar_t<T> = 0>
Expand All @@ -161,14 +159,15 @@ struct DRJIT_TRIVIAL_ABI DiffArray

DiffArray &operator=(const DiffArray &a) {
Index old_index = m_index;

if constexpr (IsFloat) {
m_index = ad_var_inc_ref(a.m_index);
m_index = ad_var_copy_ref(a.m_index);
ad_var_dec_ref(old_index);
} else {
m_index = a.m_index;
jit_var_inc_ref(m_index);
m_index = jit_var_inc_ref(a.m_index);
jit_var_dec_ref(old_index);
}

return *this;
}

Expand Down Expand Up @@ -685,12 +684,10 @@ struct DRJIT_TRIVIAL_ABI DiffArray
static DRJIT_INLINE DiffArray borrow(Index index) {
DiffArray result;

if constexpr (IsFloat) {
if constexpr (IsFloat)
result.m_index = ad_var_inc_ref(index);
} else {
jit_var_inc_ref(index);
result.m_index = index;
}
else
result.m_index = jit_var_inc_ref(index);

return result;
}
Expand Down Expand Up @@ -742,7 +739,7 @@ struct DRJIT_TRIVIAL_ABI DiffArray
m_index = ad_var_new(jit_index);
jit_var_dec_ref(jit_index);
} else {
jit_var_inc_ref(jit_index);
jit_index = jit_var_inc_ref(jit_index);
ad_var_dec_ref(m_index);
m_index = jit_index;
}
Expand Down Expand Up @@ -956,7 +953,7 @@ NAMESPACE_BEGIN(detail)
template <bool IncRef> void collect_indices_fn(void *p, uint64_t index) {
vector<uint64_t> &indices = *(vector<uint64_t> *) p;
if constexpr (IncRef)
ad_var_inc_ref(index);
index = ad_var_inc_ref(index);
indices.push_back(index);
}

Expand Down Expand Up @@ -1010,7 +1007,7 @@ struct ad_index32_vector : drjit::vector<uint32_t> {

void push_back_steal(uint32_t index) { push_back(index); }
void push_back_borrow(uint32_t index) {
push_back(uint32_t(ad_var_inc_ref(uint64_t(index) << 32) >> 32));
push_back((uint32_t) (ad_var_inc_ref(uint64_t(index) << 32) >> 32));
}
};

Expand All @@ -1033,9 +1030,7 @@ struct index64_vector : drjit::vector<uint64_t> {
}

void push_back_steal(uint64_t index) { push_back(index); }
void push_back_borrow(uint64_t index) {
push_back(ad_var_inc_ref(index));
}
void push_back_borrow(uint64_t index) { push_back(ad_var_inc_ref(index)); }
};

NAMESPACE_END(detail)
Expand Down
28 changes: 23 additions & 5 deletions include/drjit/extra.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,28 @@ extern DRJIT_EXTRA_EXPORT uint32_t ad_grad(uint64_t index, bool null_ok = false)
/// Check if gradient tracking is enabled for the given variable
extern DRJIT_EXTRA_EXPORT int ad_grad_enabled(uint64_t index);

/// Check if gradient tracking is disabled (can't create new AD variables)
extern DRJIT_EXTRA_EXPORT int ad_grad_suspended();

/// Temporarily enforce gradient tracking without creating a new scope
extern DRJIT_EXTRA_EXPORT int ad_set_force_grad(int status);

/// Accumulate into the gradient associated with a given variable
extern DRJIT_EXTRA_EXPORT void ad_accum_grad(uint64_t index, uint32_t value);

/// Clear the gradient of a given variable
extern DRJIT_EXTRA_EXPORT void ad_clear_grad(uint64_t index);

/// Increase the reference count of the given AD variable
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT;

/**
* \brief Increase the reference count of the given AD variable
* \brief Variant of 'ad_var_inc_ref' that conceptually creates a copy
*
* This function is typically called when an AD variable is copied. It may
* return a detached variable when an active AD scope disables differentiation
* of the provided input variable.
* This function returns a detached variable when an active AD scope disables
* differentiation of the provided input variable.
*/
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_inc_ref_impl(uint64_t) JIT_NOEXCEPT;
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_copy_ref_impl(uint64_t) JIT_NOEXCEPT;

/// Decrease the reference count of the given AD variable
extern DRJIT_EXTRA_EXPORT void ad_var_dec_ref_impl(uint64_t) JIT_NOEXCEPT;
Expand Down Expand Up @@ -484,13 +492,23 @@ DRJIT_INLINE uint64_t ad_var_inc_ref(uint64_t index) JIT_NOEXCEPT {
return ad_var_inc_ref_impl(index);
}

/// Inline front end for ad_var_copy_ref_impl().
///
/// Returns 0 without calling into the library when the compiler can prove
/// that 'index' is a compile-time constant; otherwise forwards to
/// ad_var_copy_ref_impl() and returns its result.
DRJIT_INLINE uint64_t ad_var_copy_ref(uint64_t index) JIT_NOEXCEPT {
/* If 'index' is known at compile time, it can only be zero, in
which case we can skip the redundant call to ad_var_copy_ref_impl */
if (__builtin_constant_p(index))
return 0;
else
return ad_var_copy_ref_impl(index);
}

/// Inline front end for ad_var_dec_ref_impl().
///
/// The call is omitted entirely when the compiler can prove that 'index' is
/// a compile-time constant; otherwise it forwards to ad_var_dec_ref_impl().
DRJIT_INLINE void ad_var_dec_ref(uint64_t index) JIT_NOEXCEPT {
if (!__builtin_constant_p(index))
ad_var_dec_ref_impl(index);
}
#else
#define ad_var_dec_ref ad_var_dec_ref_impl
#define ad_var_inc_ref ad_var_inc_ref_impl
#define ad_var_copy_ref ad_var_copy_ref_impl
#endif

// Return the AD reference count of a variable (for debugging)
Expand Down
14 changes: 6 additions & 8 deletions include/drjit/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ struct DRJIT_TRIVIAL_ABI JitArray

~JitArray() noexcept { jit_var_dec_ref(m_index); }

JitArray(const JitArray &a) : m_index(a.m_index) {
jit_var_inc_ref(m_index);
JitArray(const JitArray &a) {
m_index = jit_var_inc_ref(a.m_index);
}

JitArray(JitArray &&a) noexcept : m_index(a.m_index) {
Expand Down Expand Up @@ -122,9 +122,9 @@ struct DRJIT_TRIVIAL_ABI JitArray
}

JitArray &operator=(const JitArray &a) {
jit_var_inc_ref(a.m_index);
uint32_t index = jit_var_inc_ref(a.m_index);
jit_var_dec_ref(m_index);
m_index = a.m_index;
m_index = index;
return *this;
}

Expand Down Expand Up @@ -677,8 +677,7 @@ struct DRJIT_TRIVIAL_ABI JitArray

static DRJIT_INLINE JitArray borrow(Index index) {
JitArray result;
jit_var_inc_ref(index);
result.m_index = index;
result.m_index = jit_var_inc_ref(index);
return result;
}

Expand Down Expand Up @@ -733,8 +732,7 @@ struct index32_vector : drjit::vector<uint32_t> {

void push_back_steal(uint32_t index) { push_back(index); }
void push_back_borrow(uint32_t index) {
jit_var_inc_ref(index);
push_back(index);
push_back(jit_var_inc_ref(index));
}
};

Expand Down
53 changes: 44 additions & 9 deletions src/extra/autodiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ struct Scope {
*/
bool isolate = false;

/// Flag to temporarily force gradient tracking in an ad-disabled scope
bool force_grad = false;

// Current ``state.counter`` value when entering this scope
uint64_t counter = 0;

Expand Down Expand Up @@ -537,7 +540,7 @@ struct Scope {

/// Check if a variable has gradients enabled
bool enabled(ADIndex index) const {
return (indices.find(index) != indices.end()) != complement;
return (indices.find(index) != indices.end()) != complement || force_grad;
}

/// Potentially zero out 'index' if the variable has gradients disabled
Expand All @@ -549,7 +552,7 @@ struct Scope {

/// Track gradients for the given variable
void enable(ADIndex index) {
if (!index)
if (!index || force_grad)
return;

if (complement)
Expand Down Expand Up @@ -695,11 +698,11 @@ static void ad_free(ADIndex index, Variable *v) {
state.unused_variables.push(index);
}

Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
Index ad_var_copy_ref_impl(Index index) JIT_NOEXCEPT {
JitIndex jit_index = ::jit_index(index);
ADIndex ad_index = ::ad_index(index);

jit_var_inc_ref(jit_index);
jit_index = jit_var_inc_ref(jit_index);

if (unlikely(ad_index)) {
const std::vector<Scope> &scopes = local_state.scopes;
Expand All @@ -715,6 +718,20 @@ Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
return combine(ad_index, jit_index);
}

/// Increase the reference count of the given AD variable.
///
/// Splits 'index' into its JIT and AD components, calls jit_var_inc_ref() on
/// the JIT component and, when the AD component is nonzero, calls
/// ad_var_inc_ref_int() on the corresponding entry of 'state' while holding
/// 'state.mutex'. No AD scope is consulted here, and the function always
/// returns the 'index' value it was given.
Index ad_var_inc_ref_impl(Index index) JIT_NOEXCEPT {
JitIndex jit_index = ::jit_index(index);
ADIndex ad_index = ::ad_index(index);

// The return value of jit_var_inc_ref() is intentionally not used: the
// caller receives the original combined index back
jit_var_inc_ref(jit_index);

// Only variables with a nonzero AD component touch the shared AD state
if (unlikely(ad_index)) {
std::lock_guard<std::mutex> guard(state.mutex);
ad_var_inc_ref_int(ad_index, state[ad_index]);
}

return index;
}


uint32_t ad_var_ref(uint64_t index) {
uint32_t ad_index = ::ad_index(index);
Expand Down Expand Up @@ -953,7 +970,7 @@ DRJIT_NOINLINE Index ad_var_new_impl(const char *label, JitVar &&result,
active |= scope.maybe_disable(args[i].ad_index);
}

if (!active)
if (!active && !scope.force_grad)
return (Index) result.release();
}

Expand Down Expand Up @@ -1770,6 +1787,25 @@ int ad_grad_enabled(Index index) {
return ad_index != 0;
}

/// Check whether the innermost entry of 'local_state.scopes' suspends
/// gradient tracking.
///
/// Returns 0 when no scope is active. Otherwise returns nonzero exactly when
/// the innermost scope has 'complement == false'. The scope's 'force_grad'
/// flag is not consulted by this function.
/// NOTE(review): presumably 'complement == false' identifies a scope that only
/// tracks an explicit set of variables -- confirm against the Scope definition.
int ad_grad_suspended() {
const std::vector<Scope> &scopes = local_state.scopes;
if (scopes.empty())
return false;
else
return scopes.back().complement == false;
}

/// Temporarily enforce gradient tracking without creating a new scope
///
/// Sets the 'force_grad' flag of the innermost entry of 'local_state.scopes'
/// to 'status' (converted to bool) and returns the previous value of the flag
/// as 0 or 1. When no scope is active, nothing is modified and 0 is returned.
int ad_set_force_grad(int status) {
std::vector<Scope> &scopes = local_state.scopes;
// Without an active scope there is no flag to set
if (scopes.empty())
return 0;
Scope &scope = scopes.back();
// Remember the prior setting so that callers can restore it later
bool old = scope.force_grad;
scope.force_grad = (bool) status;
return (int) old;
}

// ==========================================================================
// AD traversal callbacks for special operations: masks, gathers, scatters
// ==========================================================================
Expand Down Expand Up @@ -2811,12 +2847,11 @@ Index ad_var_cast(Index i0, VarType vt) {
void ad_var_map_put(Index source, Index target) {
uint32_t ad_index_source = ad_index(source),
ad_index_target = ad_index(target);

if ((ad_index_source == 0) != (ad_index_target == 0))
ad_raise("ad_var_map_put(): mixed attached/detached inputs!");
if (ad_index_target == 0)
return;

if (ad_index_source == 0)
return;
ad_raise("ad_var_map_put(): mixed attached/detached inputs!");

ad_log("ad_var_map_put(): a%u -> a%u", ad_index_source, ad_index_target);

Expand Down
6 changes: 3 additions & 3 deletions src/extra/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static void ad_call_getter(JitBackend backend, const char *domain,
"ad_call_getter(\"%s%s%s\", index=r%u, mask=r%u)", domain_or_empty,
separator, name, index, mask.index());

scoped_isolation_boundary guard;
scoped_isolation_guard guard;
{
scoped_record rec(backend, name, true);

Expand Down Expand Up @@ -279,7 +279,7 @@ static void ad_call_symbolic(JitBackend backend, const char *domain,
/* Postponed operations captured by the isolation scope should only
* be executed once we've exited the symbolic scope. We therefore
* need to declare the AD isolation guard before the recording guard. */
scoped_isolation_boundary guard_1(1);
scoped_isolation_guard guard_1(1);

scoped_record guard_2(backend, name, true);
// Recording may fail due to recursion depth
Expand Down Expand Up @@ -631,7 +631,7 @@ struct CallOp : public dr::detail::CustomOpBase {

/// Implements f(arg..., grad(rv)...) -> grad(arg) ...
void backward() override {
scoped_isolation_boundary isolation_guard;
scoped_isolation_guard isolation_guard;
std::string name = m_name + " [ad, bwd]";

index64_vector args, rv;
Expand Down
13 changes: 10 additions & 3 deletions src/extra/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ template <typename T> class unlock_guard {
};

/// RAII AD Isolation helper
struct scoped_isolation_boundary {
scoped_isolation_boundary(int symbolic = -1) : symbolic(symbolic) {
struct scoped_isolation_guard {
scoped_isolation_guard(int symbolic = -1) : symbolic(symbolic) {
ad_scope_enter(drjit::ADScope::Isolate, 0, nullptr, symbolic);
}

~scoped_isolation_boundary() {
~scoped_isolation_guard() {
ad_scope_leave(success);
}

Expand All @@ -58,6 +58,13 @@ struct scoped_isolation_boundary {
bool success = false;
};

/// RAII helper that calls ad_set_force_grad(1) on construction and passes the
/// status returned by that call back to ad_set_force_grad() on destruction
struct scoped_force_grad_guard {
scoped_force_grad_guard() { value = ad_set_force_grad(1); }
~scoped_force_grad_guard() { ad_set_force_grad(value); }
// Status returned by ad_set_force_grad() in the constructor
bool value;
};


/// RAII helper to temporarily push a mask onto the Dr.Jit mask stack
struct scoped_push_mask {
scoped_push_mask(JitBackend backend, uint32_t index) : backend(backend) {
Expand Down
Loading