diff --git a/Project.toml b/Project.toml index c34cc46c..7054f706 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseConnectivityTracer" uuid = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" authors = ["Adrian Hill "] -version = "0.2.1" +version = "0.3.0-DEV" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/README.md b/README.md index 439d00d6..a2a56fc3 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,9 @@ Fast Jacobian and Hessian sparsity detection via operator-overloading. +> [!WARNING] +> This package is in early development. Expect frequent breaking changes and refer to the stable documentation. + ## Installation To install this package, open the Julia REPL and run @@ -28,7 +31,7 @@ julia> x = rand(3); julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; -julia> pattern(f, JacobianTracer, x) +julia> pattern(f, JacobianTracer{BitSet}, x) 3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries: 1 ⋅ ⋅ 1 1 ⋅ @@ -43,7 +46,7 @@ julia> x = rand(28, 28, 3, 1); julia> layer = Conv((3, 3), 3 => 8); -julia> pattern(layer, JacobianTracer, x) +julia> pattern(layer, JacobianTracer{BitSet}, x) 5408×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 146016 stored entries: ⎡⠙⢦⡀⠀⠀⠘⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠀⎤ ⎢⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⠀⠈⠻⣦⡀⠀⎥ @@ -76,7 +79,7 @@ julia> x = rand(5); julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5]; -julia> pattern(f, HessianTracer, x) +julia> pattern(f, HessianTracer{BitSet}, x) 5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ @@ -86,7 +89,7 @@ julia> pattern(f, HessianTracer, x) julia> g(x) = f(x) + x[2]^x[5]; -julia> pattern(g, HessianTracer, x) +julia> pattern(g, HessianTracer{BitSet}, x) 5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 diff --git a/src/adtypes.jl b/src/adtypes.jl index c08702b1..8706f4ff 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -28,16 +28,24 @@ julia> ADTypes.hessian_sparsity(f, rand(4), TracerSparsityDetector()) ⋅ ⋅ ⋅ 1 ``` """ -struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end - -function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector) - return pattern(f, JacobianTracer, x) +struct TracerSparsityDetector{S<:AbstractIndexSet} <: ADTypes.AbstractSparsityDetector end +TracerSparsityDetector(::Type{S}) where {S<:AbstractIndexSet} = TracerSparsityDetector{S}() +TracerSparsityDetector() = TracerSparsityDetector(BitSet) + +function ADTypes.jacobian_sparsity( + f, x, ::TracerSparsityDetector{S} +) where {S<:AbstractIndexSet} + return pattern(f, JacobianTracer{S}, x) end -function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector) - return pattern(f!, y, JacobianTracer, x) +function ADTypes.jacobian_sparsity( + f!, y, x, ::TracerSparsityDetector{S} +) where {S<:AbstractIndexSet} + return pattern(f!, y, JacobianTracer{S}, x) end -function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector) - return pattern(f, HessianTracer, x) +function ADTypes.hessian_sparsity( + f, x, ::TracerSparsityDetector{S} +) where {S<:AbstractIndexSet} + return pattern(f, HessianTracer{S}, x) end diff --git a/src/conversion.jl b/src/conversion.jl index 2856879d..7d3ddca6 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,28 +1,31 @@ ## Type conversions -for T in (:JacobianTracer, :ConnectivityTracer, :HessianTracer) - @eval Base.promote_rule(::Type{$T}, ::Type{N}) where {N<:Number} = $T - @eval Base.promote_rule(::Type{N}, ::Type{$T}) where {N<:Number} = $T +for TT in (:JacobianTracer, :ConnectivityTracer, :HessianTracer) + @eval Base.promote_rule(::Type{T}, ::Type{N}) where {T<:$TT,N<:Number} = T + @eval Base.promote_rule(::Type{N}, ::Type{T}) where {T<:$TT,N<:Number} = T - @eval Base.big(::Type{$T}) = $T - @eval Base.widen(::Type{$T}) = $T - @eval Base.widen(t::$T) = t + @eval Base.big(::Type{T}) where {T<:$TT} = T + @eval Base.widen(::Type{T}) where {T<:$TT} = T + @eval Base.widen(t::T) where {T<:$TT} = t - @eval Base.convert(::Type{$T}, x::Number) = empty($T) - @eval Base.convert(::Type{$T}, t::$T) = t - @eval Base.convert(::Type{<:Number}, t::$T) = t + @eval Base.convert(::Type{T}, x::Number) where {T<:$TT} = empty(T) + @eval Base.convert(::Type{T}, t::T) where {T<:$TT} = t + @eval Base.convert(::Type{<:Number}, t::T) where {T<:$TT} = t ## Constants - @eval Base.zero(::Type{$T}) = empty($T) - @eval Base.one(::Type{$T}) = empty($T) - @eval Base.typemin(::Type{$T}) = empty($T) - @eval Base.typemax(::Type{$T}) = empty($T) + @eval Base.zero(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.one(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.typemin(::Type{T}) where {T<:$TT} = empty(T) + @eval Base.typemax(::Type{T}) where {T<:$TT} = empty(T) ## Array constructors - @eval Base.similar(a::Array{$T,1}) = zeros($T, size(a, 1)) - @eval Base.similar(a::Array{$T,2}) = zeros($T, size(a, 1), size(a, 2)) - @eval Base.similar(a::Array{A,1}, ::Type{$T}) where {A} = zeros($T, size(a, 1)) - @eval Base.similar(a::Array{A,2}, ::Type{$T}) where {A} = zeros($T, size(a, 1), size(a, 2)) - @eval Base.similar(::Array{$T}, m::Int) = zeros($T, m) - @eval Base.similar(::Array, ::Type{$T}, dims::Dims{N}) where {N} = zeros($T, dims) - @eval Base.similar(::Array{$T}, dims::Dims{N}) where {N} = zeros($T, dims) + @eval Base.similar(a::Array{T,1}) where {T<:$TT} = zeros(T, size(a, 1)) + @eval Base.similar(a::Array{T,2}) where {T<:$TT} = zeros(T, size(a, 1), size(a, 2)) + @eval Base.similar(a::Array{A,1}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1)) + @eval Base.similar(a::Array{A,2}, ::Type{T}) where {A,T<:$TT} = zeros(T, size(a, 1), size(a, 2)) + @eval Base.similar(::Array{T}, m::Int) where {T<:$TT} = zeros(T, m) + @eval Base.similar(::Array{T}, dims::Dims{N}) where {N,T<:$TT} = zeros(T, dims) + + @eval Base.similar( + ::Array, ::Type{$TT{S}}, dims::Dims{N} + ) where {N,S<:AbstractIndexSet} = zeros($TT{S}, dims) end diff --git a/src/overload_connectivity.jl b/src/overload_connectivity.jl index 406a8c75..1641b541 100644 --- a/src/overload_connectivity.jl +++ b/src/overload_connectivity.jl @@ -3,7 +3,7 @@ for fn in union(ops_1_to_1_s, ops_1_to_1_f, ops_1_to_1_z) end for fn in ops_1_to_1_const - @eval Base.$fn(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER + @eval Base.$fn(::T) where {T<:ConnectivityTracer} = empty(T) end for fn in ops_1_to_2 @@ -28,4 +28,4 @@ Base.:^(::Irrational{:ℯ}, t::ConnectivityTracer) = t Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) = t ## Random numbers -rand(::AbstractRNG, ::SamplerType{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER +rand(::AbstractRNG, ::SamplerType{T}) where {T<:ConnectivityTracer} = empty(T) diff --git a/src/overload_hessian.jl b/src/overload_hessian.jl index 443f658a..03aa13a9 100644 --- a/src/overload_hessian.jl +++ b/src/overload_hessian.jl @@ -7,7 +7,7 @@ for fn in ops_1_to_1_f end for fn in union(ops_1_to_1_z, ops_1_to_1_const) - @eval Base.$fn(::HessianTracer) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::T) where {T<:HessianTracer} = empty(T) end ## 2-to-1 @@ -86,31 +86,31 @@ end for fn in ops_2_to_1_szz @eval Base.$fn(t::HessianTracer, ::HessianTracer) = promote_order(t) @eval Base.$fn(t::HessianTracer, ::Number) = promote_order(t) - @eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::Number, t::T) where {T<:HessianTracer} = empty(T) end for fn in ops_2_to_1_zsz @eval Base.$fn(::HessianTracer, t::HessianTracer) = promote_order(t) - @eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::T, ::Number) where {T<:HessianTracer} = empty(T) @eval Base.$fn(::Number, t::HessianTracer) = promote_order(t) end for fn in ops_2_to_1_fzz @eval Base.$fn(t::HessianTracer, ::HessianTracer) = t @eval Base.$fn(t::HessianTracer, ::Number) = t - @eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::Number, t::T) where {T<:HessianTracer} = empty(T) end for fn in ops_2_to_1_zfz @eval Base.$fn(::HessianTracer, t::HessianTracer) = t - @eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::T, ::Number) where {T<:HessianTracer} = empty(T) @eval Base.$fn(::Number, t::HessianTracer) = t end for fn in ops_2_to_1_zzz - @eval Base.$fn(::HessianTracer, t::HessianTracer) = EMPTY_HESSIAN_TRACER - @eval Base.$fn(::HessianTracer, ::Number) = EMPTY_HESSIAN_TRACER - @eval Base.$fn(::Number, t::HessianTracer) = EMPTY_HESSIAN_TRACER + @eval Base.$fn(::T, t::T) where {T<:HessianTracer} = empty(T) + @eval Base.$fn(::T, ::Number) where {T<:HessianTracer} = empty(T) + @eval Base.$fn(::Number, t::T) where {T<:HessianTracer} = empty(T) end # Extra types required for exponent @@ -122,7 +122,7 @@ Base.:^(t::HessianTracer, ::Irrational{:ℯ}) = promote_order(t) Base.:^(::Irrational{:ℯ}, t::HessianTracer) = promote_order(t) ## Rounding -Base.round(t::HessianTracer, ::RoundingMode; kwargs...) = EMPTY_HESSIAN_TRACER +Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = empty(T) ## Random numbers -rand(::AbstractRNG, ::SamplerType{HessianTracer}) = EMPTY_HESSIAN_TRACER +rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = empty(T) diff --git a/src/overload_jacobian.jl b/src/overload_jacobian.jl index 8fad8b31..c19698b8 100644 --- a/src/overload_jacobian.jl +++ b/src/overload_jacobian.jl @@ -3,7 +3,7 @@ for fn in union(ops_1_to_1_s, ops_1_to_1_f) end for fn in union(ops_1_to_1_z, ops_1_to_1_const) - @eval Base.$fn(::JacobianTracer) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::T) where {T<:JacobianTracer} = empty(T) end for fn in union( @@ -23,18 +23,18 @@ end for fn in union(ops_2_to_1_zsz, ops_2_to_1_zfz) @eval Base.$fn(::JacobianTracer, t::JacobianTracer) = t - @eval Base.$fn(::JacobianTracer, ::Number) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::T, ::Number) where {T<:JacobianTracer} = empty(T) @eval Base.$fn(::Number, t::JacobianTracer) = t end for fn in union(ops_2_to_1_szz, ops_2_to_1_fzz) @eval Base.$fn(t::JacobianTracer, ::JacobianTracer) = t @eval Base.$fn(t::JacobianTracer, ::Number) = t - @eval Base.$fn(::Number, t::JacobianTracer) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::Number, ::T) where {T<:JacobianTracer} = empty(T) end for fn in ops_2_to_1_zzz - @eval Base.$fn(::JacobianTracer, ::JacobianTracer) = EMPTY_JACOBIAN_TRACER - @eval Base.$fn(::JacobianTracer, ::Number) = EMPTY_JACOBIAN_TRACER - @eval Base.$fn(::Number, ::JacobianTracer) = EMPTY_JACOBIAN_TRACER + @eval Base.$fn(::T, ::T) where {T<:JacobianTracer} = empty(T) + @eval Base.$fn(::T, ::Number) where {T<:JacobianTracer} = empty(T) + @eval Base.$fn(::Number, ::T) where {T<:JacobianTracer} = empty(T) end for fn in union(ops_1_to_2_ss, ops_1_to_2_sf, ops_1_to_2_fs, ops_1_to_2_ff) @@ -42,14 +42,14 @@ for fn in union(ops_1_to_2_ss, ops_1_to_2_sf, ops_1_to_2_fs, ops_1_to_2_ff) end for fn in union(ops_1_to_2_sz, ops_1_to_2_fz) - @eval Base.$fn(t::JacobianTracer) = (t, EMPTY_JACOBIAN_TRACER) + @eval Base.$fn(t::T) where {T<:JacobianTracer} = (t, empty(T)) end for fn in union(ops_1_to_2_zs, ops_1_to_2_zf) - @eval Base.$fn(t::JacobianTracer) = (EMPTY_JACOBIAN_TRACER, t) + @eval Base.$fn(t::T) where {T<:JacobianTracer} = (empty(T), t) end for fn in ops_1_to_2_zz - @eval Base.$fn(::JacobianTracer) = (EMPTY_JACOBIAN_TRACER, EMPTY_JACOBIAN_TRACER) + @eval Base.$fn(::T) where {T<:JacobianTracer} = (empty(T), empty(T)) end # Extra types required for exponent @@ -61,7 +61,7 @@ Base.:^(t::JacobianTracer, ::Irrational{:ℯ}) = t Base.:^(::Irrational{:ℯ}, t::JacobianTracer) = t ## Rounding -Base.round(t::JacobianTracer, ::RoundingMode; kwargs...) = EMPTY_JACOBIAN_TRACER +Base.round(t::T, ::RoundingMode; kwargs...) where {T<:JacobianTracer} = empty(T) ## Random numbers -rand(::AbstractRNG, ::SamplerType{JacobianTracer}) = EMPTY_JACOBIAN_TRACER +rand(::AbstractRNG, ::SamplerType{T}) where {T<:JacobianTracer} = empty(T) diff --git a/src/pattern.jl b/src/pattern.jl index f1bfefbc..3f2af560 100644 --- a/src/pattern.jl +++ b/src/pattern.jl @@ -1,38 +1,38 @@ ## Enumerate inputs """ - trace_input(JacobianTracer, x) - trace_input(ConnectivityTracer, x) + trace_input(T, x) + trace_input(T, x) -Enumerates input indices and constructs the specified type of tracer. -Supports [`JacobianTracer`](@ref) and [`ConnectivityTracer`](@ref). +Enumerates input indices and constructs the specified type `T` of tracer. +Supports [`ConnectivityTracer`](@ref), [`JacobianTracer`](@ref) and [`HessianTracer`](@ref). ## Example ```jldoctest julia> x = rand(3); -julia> trace_input(ConnectivityTracer, x) -3-element Vector{ConnectivityTracer}: - ConnectivityTracer(1,) - ConnectivityTracer(2,) - ConnectivityTracer(3,) - -julia> trace_input(JacobianTracer, x) -3-element Vector{JacobianTracer}: - JacobianTracer(1,) - JacobianTracer(2,) - JacobianTracer(3,) - -julia> trace_input(HessianTracer, x) -3-element Vector{HessianTracer}: - HessianTracer( +julia> trace_input(ConnectivityTracer{BitSet}, x) +3-element Vector{ConnectivityTracer{BitSet}}: + ConnectivityTracer{BitSet}(1,) + ConnectivityTracer{BitSet}(2,) + ConnectivityTracer{BitSet}(3,) + +julia> trace_input(JacobianTracer{BitSet}, x) +3-element Vector{JacobianTracer{BitSet}}: + JacobianTracer{BitSet}(1,) + JacobianTracer{BitSet}(2,) + JacobianTracer{BitSet}(3,) + +julia> trace_input(HessianTracer{BitSet}, x) +3-element Vector{HessianTracer{BitSet}}: + HessianTracer{BitSet}( 1 => (), ) - HessianTracer( + HessianTracer{BitSet}( 2 => (), ) - HessianTracer( + HessianTracer{BitSet}( 3 => (), ) ``` @@ -46,16 +46,16 @@ end ## Construct sparsity pattern matrix """ - pattern(f, ConnectivityTracer, x) + pattern(f, ConnectivityTracer{S}, x) where {S<:AbstractSet{<:Integer}} Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. - pattern(f, JacobianTracer, x) + pattern(f, JacobianTracer{S}, x) where {S<:AbstractSet{<:Integer}} Computes the sparsity pattern of the Jacobian of `y = f(x)`. - pattern(f, HessianTracer, x) + pattern(f, HessianTracer{S}, x) where {S<:AbstractSet{<:Integer}} Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`. @@ -67,7 +67,7 @@ julia> x = rand(3); julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; -julia> pattern(f, ConnectivityTracer, x) +julia> pattern(f, ConnectivityTracer{BitSet}, x) 3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries: 1 ⋅ ⋅ 1 1 ⋅ @@ -81,7 +81,7 @@ julia> x = rand(5); julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5]; -julia> pattern(f, HessianTracer, x) +julia> pattern(f, HessianTracer{BitSet}, x) 5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ @@ -91,7 +91,7 @@ julia> pattern(f, HessianTracer, x) julia> g(x) = f(x) + x[2]^x[5]; -julia> pattern(g, HessianTracer, x) +julia> pattern(g, HessianTracer{BitSet}, x) 5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 @@ -107,11 +107,11 @@ function pattern(f, ::Type{T}, x) where {T<:AbstractTracer} end """ - pattern(f!, y, JacobianTracer, x) + pattern(f!, y, JacobianTracer{S}, x) where {S<:AbstractSet{<:Integer}} Computes the sparsity pattern of the Jacobian of `f!(y, x)`. - pattern(f!, y, ConnectivityTracer, x) + pattern(f!, y, ConnectivityTracer{S}, x) where {S<:AbstractSet{<:Integer}} Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)` where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. @@ -151,8 +151,8 @@ function _pattern_to_sparsemat( end function _pattern_to_sparsemat( - xt::AbstractArray{HessianTracer}, yt::AbstractArray{HessianTracer} -) + xt::AbstractArray{HessianTracer{S}}, yt::AbstractArray{HessianTracer{S}} +) where {S<:AbstractIndexSet} length(yt) != 1 && error("pattern(f, HessianTracer, x) expects scalar output y=f(x).") y = only(yt) diff --git a/src/tracers.jl b/src/tracers.jl index c9ae226d..c81699ac 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -1,36 +1,56 @@ +const AbstractIndexSet = AbstractSet{<:Integer} +abstract type AbstractTracer <: Number end + +# Convenience constructor for empty tracers +empty(tracer::T) where {T<:AbstractTracer} = empty(T) + #==============# # Connectivity # #==============# """ - ConnectivityTracer(indexset) <: Number + ConnectivityTracer{S}(indexset) <: Number Number type keeping track of input indices of previous computations. +The provided set type `S` has to be an `AbstractSet{<:Integer}`, e.g. `BitSet` or `Set{UInt64}`. See also the convenience constructor [`tracer`](@ref). For a higher-level interface, refer to [`pattern`](@ref). """ -struct ConnectivityTracer <: AbstractTracer - inputs::BitSet # indices of connected, enumerated inputs +struct ConnectivityTracer{S<:AbstractIndexSet} <: AbstractTracer + inputs::S # indices of connected, enumerated inputs end -function Base.show(io::IO, t::ConnectivityTracer) - return Base.show_delim_array(io, inputs(t), "ConnectivityTracer(", ',', ')', true) +function Base.show(io::IO, t::ConnectivityTracer{S}) where {S<:AbstractIndexSet} + return Base.show_delim_array(io, inputs(t), "ConnectivityTracer{$S}(", ',', ')', true) end -const EMPTY_CONNECTIVITY_TRACER = ConnectivityTracer(BitSet()) -empty(::ConnectivityTracer) = EMPTY_CONNECTIVITY_TRACER -empty(::Type{ConnectivityTracer}) = EMPTY_CONNECTIVITY_TRACER +empty(::Type{ConnectivityTracer{S}}) where {S<:AbstractIndexSet} = ConnectivityTracer(S()) + +# Performance can be gained by not re-allocating empty tracers +const EMPTY_CONNECTIVITY_TRACER_BITSET = ConnectivityTracer(BitSet()) +const EMPTY_CONNECTIVITY_TRACER_SET_UINT8 = ConnectivityTracer(Set{UInt8}()) +const EMPTY_CONNECTIVITY_TRACER_SET_UINT16 = ConnectivityTracer(Set{UInt16}()) +const EMPTY_CONNECTIVITY_TRACER_SET_UINT32 = ConnectivityTracer(Set{UInt32}()) +const EMPTY_CONNECTIVITY_TRACER_SET_UINT64 = ConnectivityTracer(Set{UInt64}()) + +empty(::Type{ConnectivityTracer{BitSet}}) = EMPTY_CONNECTIVITY_TRACER_BITSET +empty(::Type{ConnectivityTracer{Set{UInt8}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT8 +empty(::Type{ConnectivityTracer{Set{UInt16}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT16 +empty(::Type{ConnectivityTracer{Set{UInt32}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT32 +empty(::Type{ConnectivityTracer{Set{UInt64}}}) = EMPTY_CONNECTIVITY_TRACER_SET_UINT64 # We have to be careful when defining constructors: # Generic code expecting "regular" numbers `x` will sometimes convert them # by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`. # When this happens, we create a new empty tracer with no input pattern. -ConnectivityTracer(::Number) = EMPTY_CONNECTIVITY_TRACER +ConnectivityTracer{S}(::Number) where {S<:AbstractIndexSet} = empty(ConnectivityTracer{S}) ConnectivityTracer(t::ConnectivityTracer) = t ## Unions of tracers -function uniontracer(a::ConnectivityTracer, b::ConnectivityTracer) +function uniontracer( + a::ConnectivityTracer{S}, b::ConnectivityTracer{S} +) where {S<:AbstractIndexSet} return ConnectivityTracer(union(a.inputs, b.inputs)) end @@ -39,50 +59,62 @@ end #==========# """ - JacobianTracer(indexset) <: Number + JacobianTracer{S}(indexset) <: Number Number type keeping track of input indices of previous computations with non-zero derivatives. +The provided set type `S` has to be an `AbstractSet{<:Integer}`, e.g. `BitSet` or `Set{UInt64}`. See also the convenience constructor [`tracer`](@ref). For a higher-level interface, refer to [`pattern`](@ref). """ -struct JacobianTracer <: AbstractTracer - inputs::BitSet +struct JacobianTracer{S<:AbstractIndexSet} <: AbstractTracer + inputs::S end -function Base.show(io::IO, t::JacobianTracer) - return Base.show_delim_array(io, inputs(t), "JacobianTracer(", ',', ')', true) +function Base.show(io::IO, t::JacobianTracer{S}) where {S<:AbstractIndexSet} + return Base.show_delim_array(io, inputs(t), "JacobianTracer{$S}(", ',', ')', true) end -const EMPTY_JACOBIAN_TRACER = JacobianTracer(BitSet()) -empty(::JacobianTracer) = EMPTY_JACOBIAN_TRACER -empty(::Type{JacobianTracer}) = EMPTY_JACOBIAN_TRACER +empty(::Type{JacobianTracer{S}}) where {S<:AbstractIndexSet} = JacobianTracer(S()) + +# Performance can be gained by not re-allocating empty tracers +const EMPTY_JACOBIAN_TRACER_BITSET = JacobianTracer(BitSet()) +const EMPTY_JACOBIAN_TRACER_SET_UINT8 = JacobianTracer(Set{UInt8}()) +const EMPTY_JACOBIAN_TRACER_SET_UINT16 = JacobianTracer(Set{UInt16}()) +const EMPTY_JACOBIAN_TRACER_SET_UINT32 = JacobianTracer(Set{UInt32}()) +const EMPTY_JACOBIAN_TRACER_SET_UINT64 = JacobianTracer(Set{UInt64}()) -JacobianTracer(::Number) = EMPTY_JACOBIAN_TRACER +empty(::Type{JacobianTracer{BitSet}}) = EMPTY_JACOBIAN_TRACER_BITSET +empty(::Type{JacobianTracer{Set{UInt8}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT8 +empty(::Type{JacobianTracer{Set{UInt16}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT16 +empty(::Type{JacobianTracer{Set{UInt32}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT32 +empty(::Type{JacobianTracer{Set{UInt64}}}) = EMPTY_JACOBIAN_TRACER_SET_UINT64 + +JacobianTracer{S}(::Number) where {S<:AbstractIndexSet} = empty(JacobianTracer{S}) JacobianTracer(t::JacobianTracer) = t ## Unions of tracers -function uniontracer(a::JacobianTracer, b::JacobianTracer) +function uniontracer(a::JacobianTracer{S}, b::JacobianTracer{S}) where {S<:AbstractIndexSet} return JacobianTracer(union(a.inputs, b.inputs)) end #=========# # Hessian # #=========# -const HessianDict = Dict{UInt64,BitSet} """ - HessianTracer(indexset) <: Number + HessianTracer{S}(indexset) <: Number Number type keeping track of input indices of previous computations with non-zero first and second derivatives. +The provided set type `S` has to be an `AbstractSet{<:Integer}`, e.g. `BitSet` or `Set{UInt64}`. See also the convenience constructor [`tracer`](@ref). For a higher-level interface, refer to [`pattern`](@ref). """ -struct HessianTracer <: AbstractTracer - inputs::HessianDict +struct HessianTracer{S<:AbstractIndexSet} <: AbstractTracer + inputs::Dict{UInt64,S} end -function Base.show(io::IO, t::HessianTracer) - println(io, "HessianTracer(") +function Base.show(io::IO, t::HessianTracer{S}) where {S<:AbstractIndexSet} + println(io, "HessianTracer{", S, "}(") for key in keys(t.inputs) print(io, " ", key, " => ") Base.show_delim_array(io, collect(t.inputs[key]), "(", ',', ')', true) @@ -91,11 +123,24 @@ function Base.show(io::IO, t::HessianTracer) return print(io, ")") end -const EMPTY_HESSIAN_TRACER = HessianTracer(HessianDict()) -empty(::HessianTracer) = EMPTY_HESSIAN_TRACER -empty(::Type{HessianTracer}) = EMPTY_HESSIAN_TRACER +function empty(::Type{HessianTracer{S}}) where {S<:AbstractIndexSet} + return HessianTracer(Dict{UInt64,S}()) +end + +# Performance can be gained by not re-allocating empty tracers +const EMPTY_HESSIAN_TRACER_BITSET = HessianTracer(Dict{UInt64,BitSet}()) +const EMPTY_HESSIAN_TRACER_SET_UINT8 = HessianTracer(Dict{UInt64,Set{UInt8}}()) +const EMPTY_HESSIAN_TRACER_SET_UINT16 = HessianTracer(Dict{UInt64,Set{UInt16}}()) +const EMPTY_HESSIAN_TRACER_SET_UINT32 = HessianTracer(Dict{UInt64,Set{UInt32}}()) +const EMPTY_HESSIAN_TRACER_SET_UINT64 = HessianTracer(Dict{UInt64,Set{UInt64}}()) + +empty(::Type{HessianTracer{BitSet}}) = EMPTY_HESSIAN_TRACER_BITSET +empty(::Type{HessianTracer{Set{UInt8}}}) = EMPTY_HESSIAN_TRACER_SET_UINT8 +empty(::Type{HessianTracer{Set{UInt16}}}) = EMPTY_HESSIAN_TRACER_SET_UINT16 +empty(::Type{HessianTracer{Set{UInt32}}}) = EMPTY_HESSIAN_TRACER_SET_UINT32 +empty(::Type{HessianTracer{Set{UInt64}}}) = EMPTY_HESSIAN_TRACER_SET_UINT64 -HessianTracer(::Number) = empty(HessianTracer) +HessianTracer{S}(::Number) where {S<:AbstractIndexSet} = empty(HessianTracer{S}) HessianTracer(t::HessianTracer) = t # Turn first-order interactions into second-order interactions @@ -145,16 +190,21 @@ end """ inputs(tracer) -Return raw `UInt64` input indices of a [`ConnectivityTracer`](@ref) or [`JacobianTracer`](@ref) +Return input indices of a [`ConnectivityTracer`](@ref) or [`JacobianTracer`](@ref) ## Example ```jldoctest -julia> t = tracer(ConnectivityTracer, 1, 2, 4) -ConnectivityTracer(1, 2, 4) +julia> a = tracer(ConnectivityTracer{BitSet}, 2) +ConnectivityTracer{BitSet}(2,) -julia> inputs(t) -3-element Vector{Int64}: - 1 +julia> b = tracer(ConnectivityTracer{BitSet}, 4) +ConnectivityTracer{BitSet}(4,) + +julia> c = a + b +ConnectivityTracer{BitSet}(2, 4) + +julia> inputs(c) +2-element Vector{Int64}: 2 4 ``` @@ -163,27 +213,48 @@ inputs(t::ConnectivityTracer) = collect(t.inputs) inputs(t::JacobianTracer) = collect(t.inputs) """ - tracer(JacobianTracer, index) - tracer(JacobianTracer, indices) - tracer(ConnectivityTracer, index) - tracer(ConnectivityTracer, indices) + tracer(ConnectivityTracer{S}, index) + tracer(ConnectivityTracer{S}, indices) + tracer(JacobianTracer{S}, index) + tracer(JacobianTracer{S}, indices) + tracer(HessianTracer{S}, index) + tracer(HessianTracer{S}, indices) + +Convenience constructor for [`ConnectivityTracer`](@ref), [`JacobianTracer`](@ref) and [`HessianTracer`](@ref) from input indices. +The provided set type `S` has to be an `AbstractSet{<:Integer}`, e.g. `BitSet` or `Set{UInt64}`. -Convenience constructor for [`JacobianTracer`](@ref) [`ConnectivityTracer`](@ref) from input indices. +## Example +```jldoctest +julia> tracer(JacobianTracer{BitSet}, 2) +JacobianTracer{BitSet}(2,) + +julia> tracer(HessianTracer{Set{UInt64}}, 2) +HessianTracer{Set{UInt64}}( + 2 => (), +) +``` """ -tracer(::Type{JacobianTracer}, index::Integer) = JacobianTracer(BitSet(index)) -tracer(::Type{ConnectivityTracer}, index::Integer) = ConnectivityTracer(BitSet(index)) -function tracer(::Type{HessianTracer}, index::Integer) - return HessianTracer(Dict{UInt64,BitSet}(index => BitSet())) +tracer(::Type{JacobianTracer{S}}, index::Integer) where {S<:AbstractIndexSet} = + JacobianTracer(S(index)) +function tracer(::Type{ConnectivityTracer{S}}, index::Integer) where {S<:AbstractIndexSet} + return ConnectivityTracer(S(index)) +end +function tracer(::Type{HessianTracer{S}}, index::Integer) where {S<:AbstractIndexSet} + return HessianTracer(Dict{UInt64,S}(index => S())) end -function tracer(::Type{JacobianTracer}, inds::NTuple{N,<:Integer}) where {N} - return JacobianTracer(BitSet(inds)) +function tracer( + ::Type{JacobianTracer{S}}, inds::NTuple{N,<:Integer} +) where {N,S<:AbstractIndexSet} + return JacobianTracer{S}(S(inds)) end -function tracer(::Type{ConnectivityTracer}, inds::NTuple{N,<:Integer}) where {N} - return ConnectivityTracer(BitSet(inds)) +function tracer( + ::Type{ConnectivityTracer{S}}, inds::NTuple{N,<:Integer} +) where {N,S<:AbstractIndexSet} + return ConnectivityTracer{S}(S(inds)) end -function tracer(::Type{HessianTracer}, inds::NTuple{N,<:Integer}) where {N} - return HessianTracer(Dict{UInt64,BitSet}(i => BitSet() for i in inds)) +function tracer( + ::Type{HessianTracer{S}}, inds::NTuple{N,<:Integer} +) where {N,S<:AbstractIndexSet} + return HessianTracer{S}(Dict{UInt64,S}(i => S() for i in inds)) end - -tracer(::Type{T}, inds...) where {T<:AbstractTracer} = tracer(T, inds) diff --git a/test/adtypes.jl b/test/adtypes.jl index 81628282..25b5272c 100644 --- a/test/adtypes.jl +++ b/test/adtypes.jl @@ -3,7 +3,7 @@ using SparseConnectivityTracer using SparseArrays using Test -sd = TracerSparsityDetector() +sd = TracerSparsityDetector(BitSet) x = rand(10) y = zeros(9) diff --git a/test/references/show/ConnectivityTracer.txt b/test/references/show/ConnectivityTracer.txt deleted file mode 100644 index 6f171124..00000000 --- a/test/references/show/ConnectivityTracer.txt +++ /dev/null @@ -1 +0,0 @@ -ConnectivityTracer(1, 2, 3) \ No newline at end of file diff --git a/test/references/show/ConnectivityTracer_BitSet.txt b/test/references/show/ConnectivityTracer_BitSet.txt new file mode 100644 index 00000000..5bf162ba --- /dev/null +++ b/test/references/show/ConnectivityTracer_BitSet.txt @@ -0,0 +1 @@ +ConnectivityTracer{BitSet}(2,) \ No newline at end of file diff --git a/test/references/show/ConnectivityTracer_Set{UInt64}.txt b/test/references/show/ConnectivityTracer_Set{UInt64}.txt new file mode 100644 index 00000000..81843470 --- /dev/null +++ b/test/references/show/ConnectivityTracer_Set{UInt64}.txt @@ -0,0 +1 @@ +ConnectivityTracer{Set{UInt64}}(0x0000000000000002,) \ No newline at end of file diff --git a/test/references/show/HessianTracer.txt b/test/references/show/HessianTracer.txt deleted file mode 100644 index e220af71..00000000 --- a/test/references/show/HessianTracer.txt +++ /dev/null @@ -1,5 +0,0 @@ -HessianTracer( - 2 => (), - 3 => (), - 1 => (), -) \ No newline at end of file diff --git a/test/references/show/HessianTracer_BitSet.txt b/test/references/show/HessianTracer_BitSet.txt new file mode 100644 index 00000000..0aeee5ed --- /dev/null +++ b/test/references/show/HessianTracer_BitSet.txt @@ -0,0 +1,3 @@ +HessianTracer{BitSet}( + 2 => (), +) \ No newline at end of file diff --git a/test/references/show/HessianTracer_Set{UInt64}.txt b/test/references/show/HessianTracer_Set{UInt64}.txt new file mode 100644 index 00000000..54572703 --- /dev/null +++ b/test/references/show/HessianTracer_Set{UInt64}.txt @@ -0,0 +1,3 @@ +HessianTracer{Set{UInt64}}( + 2 => (), +) \ No newline at end of file diff --git a/test/references/show/JacobianTracer.txt b/test/references/show/JacobianTracer.txt deleted file mode 100644 index 27034783..00000000 --- a/test/references/show/JacobianTracer.txt +++ /dev/null @@ -1 +0,0 @@ -JacobianTracer(1, 2, 3) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_BitSet.txt b/test/references/show/JacobianTracer_BitSet.txt new file mode 100644 index 00000000..706e87bf --- /dev/null +++ b/test/references/show/JacobianTracer_BitSet.txt @@ -0,0 +1 @@ +JacobianTracer{BitSet}(2,) \ No newline at end of file diff --git a/test/references/show/JacobianTracer_Set{UInt64}.txt b/test/references/show/JacobianTracer_Set{UInt64}.txt new file mode 100644 index 00000000..47a9299a --- /dev/null +++ b/test/references/show/JacobianTracer_Set{UInt64}.txt @@ -0,0 +1 @@ +JacobianTracer{Set{UInt64}}(0x0000000000000002,) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index a086337a..4e2210e9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,255 +44,264 @@ DocMeta.setdocmeta!( include("test_differentiability.jl") end @testset "First order" begin - x = rand(3) - xt = trace_input(ConnectivityTracer, x) - - # Matrix multiplication - A = rand(1, 3) - yt = only(A * xt) - @test inputs(yt) == [1, 2, 3] - - @test pattern(x -> only(A * x), ConnectivityTracer, x) ≈ [1 1 1] - - # Custom functions - f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] - yt = f(xt) - @test inputs(yt[1]) == [1] - @test inputs(yt[2]) == [1, 2] - @test inputs(yt[3]) == [3] - - @test pattern(f, ConnectivityTracer, x) ≈ [1 0 0; 1 1 0; 0 0 1] - @test pattern(f, JacobianTracer, x) ≈ [1 0 0; 1 1 0; 0 0 1] - - @test pattern(identity, ConnectivityTracer, rand()) ≈ [1;;] - @test pattern(identity, JacobianTracer, rand()) ≈ [1;;] - @test pattern(Returns(1), ConnectivityTracer, 1) ≈ [0;;] - @test pattern(Returns(1), JacobianTracer, 1) ≈ [0;;] - - # Test JacobianTracer on functions with zero derivatives - x = rand(2) - g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] - @test pattern(g, ConnectivityTracer, x) ≈ [1 1; 1 1; 1 1] - @test pattern(g, JacobianTracer, x) ≈ [1 1; 0 0; 1 0] - - # Code coverage - @test pattern(x -> [sincos(x)...], ConnectivityTracer, 1) ≈ [1; 1] - @test pattern(x -> [sincos(x)...], JacobianTracer, 1) ≈ [1; 1] - @test pattern(typemax, ConnectivityTracer, 1) ≈ [0;;] - @test pattern(typemax, JacobianTracer, 1) ≈ [0;;] - @test pattern(x -> x^(2//3), ConnectivityTracer, 1) ≈ [1;;] - @test pattern(x -> x^(2//3), JacobianTracer, 1) ≈ [1;;] - @test pattern(x -> (2//3)^x, ConnectivityTracer, 1) ≈ [1;;] - @test pattern(x -> (2//3)^x, JacobianTracer, 1) ≈ [1;;] - @test pattern(x -> x^ℯ, ConnectivityTracer, 1) ≈ [1;;] - @test pattern(x -> x^ℯ, JacobianTracer, 1) ≈ [1;;] - @test pattern(x -> ℯ^x, ConnectivityTracer, 1) ≈ [1;;] - @test pattern(x -> ℯ^x, JacobianTracer, 1) ≈ [1;;] - @test pattern(x -> round(x, RoundNearestTiesUp), ConnectivityTracer, 1) ≈ [1;;] - @test pattern(x -> round(x, RoundNearestTiesUp), JacobianTracer, 1) ≈ [0;;] - - @test rand(ConnectivityTracer) == empty(ConnectivityTracer) - @test rand(JacobianTracer) == empty(JacobianTracer) - - t = tracer(ConnectivityTracer, 1, 2, 3) - @test ConnectivityTracer(t) == t - @test empty(t) == empty(ConnectivityTracer) - @test ConnectivityTracer(1) == empty(ConnectivityTracer) - - t = tracer(JacobianTracer, 1, 2, 3) - @test JacobianTracer(t) == t - @test empty(t) == empty(JacobianTracer) - @test JacobianTracer(1) == empty(JacobianTracer) - - # Base.show - @test_reference "references/show/ConnectivityTracer.txt" repr( - "text/plain", tracer(ConnectivityTracer, 1, 2, 3) - ) - @test_reference "references/show/JacobianTracer.txt" repr( - "text/plain", tracer(JacobianTracer, 1, 2, 3) - ) + for S in (BitSet, Set{UInt64}) + @testset "Set type $S" begin + CT = ConnectivityTracer{S} + JT = JacobianTracer{S} + + x = rand(3) + xt = trace_input(CT, x) + + # Matrix multiplication + A = rand(1, 3) + yt = only(A * xt) + @test pattern(x -> only(A * x), CT, x) ≈ [1 1 1] + + # Custom functions + f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] + yt = f(xt) + + @test pattern(f, CT, x) ≈ [1 0 0; 1 1 0; 0 0 1] + @test pattern(f, JT, x) ≈ [1 0 0; 1 1 0; 0 0 1] + + @test pattern(identity, CT, rand()) ≈ [1;;] + @test pattern(identity, JT, rand()) ≈ [1;;] + @test pattern(Returns(1), CT, 1) ≈ [0;;] + @test pattern(Returns(1), JT, 1) ≈ [0;;] + + # Test JacobianTracer on functions with zero derivatives + x = rand(2) + g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] + @test pattern(g, CT, x) ≈ [1 1; 1 1; 1 1] + @test pattern(g, JT, x) ≈ [1 1; 0 0; 1 0] + + # Code coverage + @test pattern(x -> [sincos(x)...], CT, 1) ≈ [1; 1] + @test pattern(x -> [sincos(x)...], JT, 1) ≈ [1; 1] + @test pattern(typemax, CT, 1) ≈ [0;;] + @test pattern(typemax, JT, 1) ≈ [0;;] + @test pattern(x -> x^(2//3), CT, 1) ≈ [1;;] + @test pattern(x -> x^(2//3), JT, 1) ≈ [1;;] + @test pattern(x -> (2//3)^x, CT, 1) ≈ [1;;] + @test pattern(x -> (2//3)^x, JT, 1) ≈ [1;;] + @test pattern(x -> x^ℯ, CT, 1) ≈ [1;;] + @test pattern(x -> x^ℯ, JT, 1) ≈ [1;;] + @test pattern(x -> ℯ^x, CT, 1) ≈ [1;;] + @test pattern(x -> ℯ^x, JT, 1) ≈ [1;;] + @test pattern(x -> round(x, RoundNearestTiesUp), CT, 1) ≈ [1;;] + @test pattern(x -> round(x, RoundNearestTiesUp), JT, 1) ≈ [0;;] + + @test rand(CT) == empty(CT) + @test rand(JT) == empty(JT) + + t = tracer(CT, 2) + @test ConnectivityTracer(t) == t + @test empty(t) == empty(CT) + @test CT(1) == empty(CT) + + t = tracer(JT, 2) + @test JacobianTracer(t) == t + @test empty(t) == empty(JT) + @test JT(1) == empty(JT) + + # Base.show + @test_reference "references/show/ConnectivityTracer_$S.txt" repr( + "text/plain", tracer(CT, 2) + ) + @test_reference "references/show/JacobianTracer_$S.txt" repr( + "text/plain", tracer(JT, 2) + ) + end + end end @testset "Second order" begin - @test pattern(identity, HessianTracer, rand()) ≈ [0;;] - @test pattern(sqrt, HessianTracer, rand()) ≈ [1;;] - - @test pattern(x -> 1 * x, HessianTracer, rand()) ≈ [0;;] - @test pattern(x -> x * 1, HessianTracer, rand()) ≈ [0;;] - - # Code coverage - @test pattern(typemax, HessianTracer, 1) ≈ [0;;] - @test pattern(x -> x^(2im), HessianTracer, 1) ≈ [1;;] - @test pattern(x -> (2im)^x, HessianTracer, 1) ≈ [1;;] - @test pattern(x -> x^(2//3), HessianTracer, 1) ≈ [1;;] - @test pattern(x -> (2//3)^x, HessianTracer, 1) ≈ [1;;] - @test pattern(x -> x^ℯ, HessianTracer, 1) ≈ [1;;] - @test pattern(x -> ℯ^x, HessianTracer, 1) ≈ [1;;] - @test pattern(x -> round(x, RoundNearestTiesUp), HessianTracer, 1) ≈ [0;;] - - @test rand(HessianTracer) == empty(HessianTracer) - - t = tracer(HessianTracer, 1, 2, 3) - @test HessianTracer(t) == t - @test empty(t) == empty(HessianTracer) - @test HessianTracer(1) == empty(HessianTracer) - - x = rand(4) - - f(x) = x[1] / x[2] + x[3] / 1 + 1 / x[4] - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 1 0 0 - 1 1 0 0 - 0 0 0 0 - 0 0 0 1 - ] - - f(x) = x[1] * x[2] + x[3] * 1 + 1 * x[4] - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - f(x) = (x[1] * x[2]) * (x[3] * x[4]) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 1 1 1 - 1 0 1 1 - 1 1 0 1 - 1 1 1 0 - ] - - f(x) = (x[1] + x[2]) * (x[3] + x[4]) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 0 1 1 - 0 0 1 1 - 1 1 0 0 - 1 1 0 0 - ] - - f(x) = (x[1] + x[2] + x[3] + x[4])^2 - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - ] - - f(x) = 1 / (x[1] + x[2] + x[3] + x[4]) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - 1 1 1 1 - ] - - f(x) = (x[1] - x[2]) + (x[3] - 1) + (1 - x[4]) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - f(x) = copysign(x[1] * x[2], x[3] * x[4]) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 1 0 0 - 1 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - f(x) = div(x[1] * x[2], x[3] * x[4]) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - 0 0 0 0 - ] - - x = rand() - f(x) = sum(sincosd(x)) - H = pattern(f, HessianTracer, x) - @test H ≈ [1;;] - - x = rand(4) - f(x) = sum(diff(x) .^ 3) - H = pattern(f, HessianTracer, x) - @test H ≈ [ - 1 1 0 0 - 1 1 1 0 - 0 1 1 1 - 0 0 1 1 - ] - - x = rand(5) - foo(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] - H = pattern(foo, HessianTracer, x) - @test H ≈ [ - 0 0 0 0 0 - 0 0 1 0 0 - 0 1 0 0 0 - 0 0 0 1 0 - 0 0 0 0 0 - ] - - bar(x) = foo(x) + x[2]^x[5] - H = pattern(bar, HessianTracer, x) - @test H ≈ [ - 0 0 0 0 0 - 0 1 1 0 1 - 0 1 0 0 0 - 0 0 0 1 0 - 0 1 0 0 1 - ] - - # Base.show - @test_reference "references/show/HessianTracer.txt" repr( - "text/plain", tracer(HessianTracer, 1, 2, 3) - ) + for S in (BitSet, Set{UInt64}) + @testset "Set type $S" begin + HT = HessianTracer{S} + @test pattern(identity, HT, rand()) ≈ [0;;] + @test pattern(sqrt, HT, rand()) ≈ [1;;] + + @test pattern(x -> 1 * x, HT, rand()) ≈ [0;;] + @test pattern(x -> x * 1, HT, rand()) ≈ [0;;] + + # Code coverage + @test pattern(typemax, HT, 1) ≈ [0;;] + @test pattern(x -> x^(2im), HT, 1) ≈ [1;;] + @test pattern(x -> (2im)^x, HT, 1) ≈ [1;;] + @test pattern(x -> x^(2//3), HT, 1) ≈ [1;;] + @test pattern(x -> (2//3)^x, HT, 1) ≈ [1;;] + @test pattern(x -> x^ℯ, HT, 1) ≈ [1;;] + @test pattern(x -> ℯ^x, HT, 1) ≈ [1;;] + @test pattern(x -> round(x, RoundNearestTiesUp), HT, 1) ≈ [0;;] + + @test rand(HT) == empty(HT) + + t = tracer(HT, 2) + @test HessianTracer(t) == t + @test empty(t) == empty(HT) + @test HT(1) == empty(HT) + + H = pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], HT, rand(4)) + @test H ≈ [ + 0 1 0 0 + 1 1 0 0 + 0 0 0 0 + 0 0 0 1 + ] + + H = pattern(x -> x[1] * x[2] + x[3] * 1 + 1 * x[4], HT, rand(4)) + @test H ≈ [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + H = pattern(x -> (x[1] * x[2]) * (x[3] * x[4]), HT, rand(4)) + @test H ≈ [ + 0 1 1 1 + 1 0 1 1 + 1 1 0 1 + 1 1 1 0 + ] + + H = pattern(x -> (x[1] + x[2]) * (x[3] + x[4]), HT, rand(4)) + @test H ≈ [ + 0 0 1 1 + 0 0 1 1 + 1 1 0 0 + 1 1 0 0 + ] + + H = pattern(x -> (x[1] + x[2] + x[3] + x[4])^2, HT, rand(4)) + @test H ≈ [ + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + ] + + H = pattern(x -> 1 / (x[1] + x[2] + x[3] + x[4]), HT, rand(4)) + @test H ≈ [ + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + ] + + H = pattern(x -> (x[1] - x[2]) + (x[3] - 1) + (1 - x[4]), HT, rand(4)) + @test H ≈ [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + H = pattern(x -> copysign(x[1] * x[2], x[3] * x[4]), HT, rand(4)) + @test H ≈ [ + 0 1 0 0 + 1 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + H = pattern(x -> div(x[1] * x[2], x[3] * x[4]), HT, rand(4)) + @test H ≈ [ + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + 0 0 0 0 + ] + + H = pattern(x -> sum(sincosd(x)), HT, 1.0) + @test H ≈ [1;;] + + H = pattern(x -> sum(diff(x) .^ 3), HT, rand(4)) + @test H ≈ [ + 1 1 0 0 + 1 1 1 0 + 0 1 1 1 + 0 0 1 1 + ] + + x = rand(5) + foo(x) = x[1] + x[2] * x[3] + 1 / x[4] + 1 * x[5] + H = pattern(foo, HT, x) + @test H ≈ [ + 0 0 0 0 0 + 0 0 1 0 0 + 0 1 0 0 0 + 0 0 0 1 0 + 0 0 0 0 0 + ] + + bar(x) = foo(x) + x[2]^x[5] + H = pattern(bar, HT, x) + @test H ≈ [ + 0 0 0 0 0 + 0 1 1 0 1 + 0 1 0 0 0 + 0 0 0 1 0 + 0 1 0 0 1 + ] + + # Base.show + @test_reference "references/show/HessianTracer_$S.txt" repr( + "text/plain", tracer(HT, 2) + ) + end + end end @testset "Real-world tests" begin - @testset "NNlib" begin - x = rand(3, 3, 2, 1) # WHCN - w = rand(2, 2, 2, 1) # Conv((2, 2), 2 => 1) - C = pattern(x -> NNlib.conv(x, w), ConnectivityTracer, x) - @test_reference "references/pattern/connectivity/NNlib/conv.txt" BitMatrix(C) - J = pattern(x -> NNlib.conv(x, w), JacobianTracer, x) - @test_reference "references/pattern/jacobian/NNlib/conv.txt" BitMatrix(J) - @test C == J - end - @testset "Brusselator" begin - include("brusselator.jl") - N = 6 - dims = (N, N, 2) - A = 1.0 - B = 1.0 - alpha = 1.0 - xyd = fill(1.0, N) - dx = 1.0 - p = (A, B, alpha, xyd, dx, N) - - u = rand(dims...) - du = similar(u) - f!(du, u) = brusselator_2d_loop(du, u, p, nothing) - - C = pattern(f!, du, ConnectivityTracer, u) - @test_reference "references/pattern/connectivity/Brusselator.txt" BitMatrix(C) - J = pattern(f!, du, JacobianTracer, u) - @test_reference "references/pattern/jacobian/Brusselator.txt" BitMatrix(J) - @test C == J - - C_ref = Symbolics.jacobian_sparsity(f!, du, u) - @test C == C_ref + include("brusselator.jl") + + for S in (BitSet, Set{UInt64}) + @testset "Set type $S" begin + CT = ConnectivityTracer{S} + JT = JacobianTracer{S} + HT = HessianTracer{S} + + @testset "Brusselator" begin + N = 6 + dims = (N, N, 2) + A = 1.0 + B = 1.0 + alpha = 1.0 + xyd = fill(1.0, N) + dx = 1.0 + p = (A, B, alpha, xyd, dx, N) + + u = rand(dims...) + du = similar(u) + f!(du, u) = brusselator_2d_loop(du, u, p, nothing) + + C = pattern(f!, du, CT, u) + @test_reference "references/pattern/connectivity/Brusselator.txt" BitMatrix( + C + ) + J = pattern(f!, du, JT, u) + @test_reference "references/pattern/jacobian/Brusselator.txt" BitMatrix( + J + ) + @test C == J + + C_ref = Symbolics.jacobian_sparsity(f!, du, u) + @test C == C_ref + end + @testset "NNlib" begin + x = rand(3, 3, 2, 1) # WHCN + w = rand(2, 2, 2, 1) # Conv((2, 2), 2 => 1) + C = pattern(x -> NNlib.conv(x, w), CT, x) + @test_reference "references/pattern/connectivity/NNlib/conv.txt" BitMatrix( + C + ) + J = pattern(x -> NNlib.conv(x, w), JT, x) + @test_reference "references/pattern/jacobian/NNlib/conv.txt" BitMatrix( + J + ) + @test C == J + end + end end end @testset "ADTypes integration" begin