Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Docs] Minor fixes to auto-diff documentation #5621

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/user-guide/07-autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ interface IDifferentiablePtrType
}
```

> #### Note ####
> Support for `IDifferentiablePtrType` is still experimental.

Types should not conform to both `IDifferentiablePtrType` and `IDifferentiable`. Such cases will result in a compiler error.


Expand Down
36 changes: 25 additions & 11 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,18 @@ void __requireGLSLExtension(constexpr String preludeText);
__intrinsic_op($(kIROp_StaticAssert))
void static_assert(constexpr bool condition, NativeString errorMessage);

/// Interface to denote types as differentiable.
/// Allows for user-specified differential types as
/// well as automatic generation, for when the associated type
/// hasn't been declared explicitly.
/// Note that the requirements must currently be defined in this exact order
/// since the auto-diff pass relies on the order to grab the struct keys.
/// Represents a type that is differentiable for the purposes of automatic differentiation.
///
/// Implemented by builtin floating-point scalar types (`float`, `half`, `double`)
///
/// vector<T, N>, matrix<T, N, M> and Array<T, N> automatically conform to
/// `IDifferentiable` if `T` conforms to `IDifferentiable`.
///
/// @remarks Types that implement `IDifferentiable` can be used with the automatic differentiation
/// primitives `bwd_diff` and `fwd_diff` to load and store gradients of parameters.
/// @remarks This interface supports automatic synthesis of requirements. A struct that conforms to `IDifferentiable`
/// will have its `Differential`, `dzero()` and `dadd()` methods automatically synthesized based on its fields, if
/// they are not already defined.
__magic_type(DifferentiableType)
interface IDifferentiable
{
Expand All @@ -446,9 +451,13 @@ interface IDifferentiable
static Differential dmul(T, Differential);
};

/// Represents a type that supports differentiation operations for pointer types.
/// This interface is used to define operations that are specific to pointer types
/// in the context of automatic differentiation.
/// @experimental
///
/// Represents a type that supports differentiation operations for pointers, buffers and
/// any other types
///
/// @remarks Support for this interface is still experimental and subject to change.
///
__magic_type(DifferentiablePtrType)
interface IDifferentiablePtrType
{
Expand All @@ -458,8 +467,9 @@ interface IDifferentiablePtrType


/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.

/// differential types of a differentiable value type
/// T that conforms to `IDifferentiable`.
///
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairUserCodeType))
Expand Down Expand Up @@ -528,6 +538,10 @@ struct DifferentialPair : IDifferentiable
}
};

/// Pair type that serves to wrap the primal and
/// differential types of a differentiable pointer type
/// T that conforms to `IDifferentiablePtrType`.
///
__generic<T : IDifferentiablePtrType>
__magic_type(DifferentialPtrPairType)
__intrinsic_type($(kIROp_DifferentialPtrPairType))
Expand Down
Loading