Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor high-level API #32

Merged
merged 8 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ julia> x = rand(3);

julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];

julia> pattern(f, JacobianTracer{BitSet}, x)
julia> jacobian_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
Expand All @@ -46,7 +46,7 @@ julia> x = rand(28, 28, 3, 1);

julia> layer = Conv((3, 3), 3 => 8);

julia> pattern(layer, JacobianTracer{BitSet}, x)
julia> jacobian_pattern(layer, x)
5408×2352 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 146016 stored entries:
⎡⠙⢦⡀⠀⠀⠘⢷⣄⠀⠀⠈⠻⣦⡀⠀⠀⠀⎤
⎢⠀⠀⠙⢷⣄⠀⠀⠙⠷⣄⠀⠀⠈⠻⣦⡀⠀⎥
Expand All @@ -69,6 +69,9 @@ julia> pattern(layer, JacobianTracer{BitSet}, x)
⎣⠀⠀⠀⠙⢷⣄⠀⠀⠈⠻⣦⠀⠀⠀⠙⢦⡀⎦
```

The type of index set `T<:AbstractSet{<:Integer}` that is internally used to keep track of connectivity can be specified via `jacobian_pattern(f, x, T)`, defaulting to `BitSet`.
For high-dimensional functions, `Set{UInt64}` can be more efficient .

### Hessian

For scalar functions `y = f(x)`, the sparsity pattern of the Hessian of $f$ can be obtained
Expand All @@ -79,7 +82,7 @@ julia> x = rand(5);

julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];

julia> pattern(f, HessianTracer{BitSet}, x)
julia> hessian_pattern(f, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
Expand All @@ -89,7 +92,7 @@ julia> pattern(f, HessianTracer{BitSet}, x)

julia> g(x) = f(x) + x[2]^x[5];

julia> pattern(g, HessianTracer{BitSet}, x)
julia> hessian_pattern(g, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ makedocs(;
assets = String[],
),
pages=["Home" => "index.md", "API Reference" => "api.md"],
warnonly=[:missing_docs],
)

deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main")
19 changes: 6 additions & 13 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,20 @@ CollapsedDocStrings = true

## Interface
```@docs
pattern
connectivity_pattern
jacobian_pattern
hessian_pattern
```
```@docs
TracerSparsityDetector
```

## Internals
SparseConnectivityTracer works by pushing `Number` types called tracers through generic functions.
Currently, two tracer types are provided:
Currently, three tracer types are provided:

```@docs
ConnectivityTracer
JacobianTracer
HessianTracer
```

Utilities to create tracers:
```@docs
tracer
trace_input
```

Utility to extract input indices from tracers:
```@docs
inputs
```
10 changes: 4 additions & 6 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using ADTypes: ADTypes
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

abstract type AbstractTracer <: Number end

include("tracers.jl")
include("conversion.jl")
include("operators.jl")
Expand All @@ -15,10 +13,10 @@ include("overload_hessian.jl")
include("pattern.jl")
include("adtypes.jl")

export JacobianTracer, ConnectivityTracer, HessianTracer
export tracer, trace_input
export inputs
export pattern
export ConnectivityTracer, connectivity_pattern
export JacobianTracer, jacobian_pattern
export HessianTracer, hessian_pattern

export TracerSparsityDetector

end # module
6 changes: 3 additions & 3 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ TracerSparsityDetector() = TracerSparsityDetector(BitSet)
function ADTypes.jacobian_sparsity(
f, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f, JacobianTracer{S}, x)
return jacobian_pattern(f, x, S)
end

function ADTypes.jacobian_sparsity(
f!, y, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f!, y, JacobianTracer{S}, x)
return jacobian_pattern(f!, y, x, S)
end

function ADTypes.hessian_sparsity(
f, x, ::TracerSparsityDetector{S}
) where {S<:AbstractIndexSet}
return pattern(f, HessianTracer{S}, x)
return hessian_pattern(f, x, S)
end
132 changes: 78 additions & 54 deletions src/pattern.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,13 @@
## Enumerate inputs
const DEFAULT_SET_TYPE = BitSet

## Enumerate inputs
"""
trace_input(T, x)
trace_input(T, x)


Enumerates input indices and constructs the specified type `T` of tracer.
Supports [`ConnectivityTracer`](@ref), [`JacobianTracer`](@ref) and [`HessianTracer`](@ref).

## Example
```jldoctest
julia> x = rand(3);

julia> trace_input(ConnectivityTracer{BitSet}, x)
3-element Vector{ConnectivityTracer{BitSet}}:
ConnectivityTracer{BitSet}(1,)
ConnectivityTracer{BitSet}(2,)
ConnectivityTracer{BitSet}(3,)

julia> trace_input(JacobianTracer{BitSet}, x)
3-element Vector{JacobianTracer{BitSet}}:
JacobianTracer{BitSet}(1,)
JacobianTracer{BitSet}(2,)
JacobianTracer{BitSet}(3,)

julia> trace_input(HessianTracer{BitSet}, x)
3-element Vector{HessianTracer{BitSet}}:
HessianTracer{BitSet}(
1 => (),
)
HessianTracer{BitSet}(
2 => (),
)
HessianTracer{BitSet}(
3 => (),
)
```
"""
trace_input(::Type{T}, x) where {T<:AbstractTracer} = trace_input(T, x, 1)
trace_input(::Type{T}, ::Number, i) where {T<:AbstractTracer} = tracer(T, i)
Expand All @@ -46,42 +18,100 @@ end

## Construct sparsity pattern matrix
"""
pattern(f, ConnectivityTracer{S}, x) where {S<:AbstractSet{<:Integer}}
connectivity_pattern(f, x)
connectivity_pattern(f, x, T)

Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.

pattern(f, JacobianTracer{S}, x) where {S<:AbstractSet{<:Integer}}
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.

Computes the sparsity pattern of the Jacobian of `y = f(x)`.
## Example

pattern(f, HessianTracer{S}, x) where {S<:AbstractSet{<:Integer}}
```jldoctest
julia> x = rand(3);

Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`.
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])];

julia> connectivity_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ 1
```
"""
connectivity_pattern(f, x, settype::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet} =
pattern(f, ConnectivityTracer{S}, x)

"""
connectivity_pattern(f!, y, x)
connectivity_pattern(f!, y, x, T)

Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.

The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
"""
function connectivity_pattern(
f!, y, x, ::Type{S}=DEFAULT_SET_TYPE
) where {S<:AbstractIndexSet}
return pattern(f!, y, ConnectivityTracer{S}, x)
end

"""
jacobian_pattern(f, x)
jacobian_pattern(f, x, T)

Compute the sparsity pattern of the Jacobian of `y = f(x)`.

## Examples
### First order
The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.

## Example

```jldoctest
julia> x = rand(3);

julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])];
julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])];

julia> pattern(f, ConnectivityTracer{BitSet}, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 4 stored entries:
julia> jacobian_pattern(f, x)
3×3 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
1 ⋅ ⋅
1 1 ⋅
⋅ ⋅ 1
⋅ ⋅
```
"""
function jacobian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet}
return pattern(f, JacobianTracer{S}, x)
end

### Second order
"""
jacobian_pattern(f!, y, x)
jacobian_pattern(f!, y, x, T)

Compute the sparsity pattern of the Jacobian of `f!(y, x)`.

The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.
"""
function jacobian_pattern(f!, y, x, ::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet}
return pattern(f!, y, JacobianTracer{S}, x)
end

"""
hessian_pattern(f, x)
hessian_pattern(f, x, T)

Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`.

The type of index set `T<:AbstractSet{<:Integer}` can be specified as an optional argument and defaults to `BitSet`.

## Example

```jldoctest
julia> x = rand(5);

julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5];

julia> pattern(f, HessianTracer{BitSet}, x)
julia> hessian_pattern(f, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 3 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
Expand All @@ -91,7 +121,7 @@ julia> pattern(f, HessianTracer{BitSet}, x)

julia> g(x) = f(x) + x[2]^x[5];

julia> pattern(g, HessianTracer{BitSet}, x)
julia> hessian_pattern(g, x)
5×5 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 7 stored entries:
⋅ ⋅ ⋅ ⋅ ⋅
⋅ 1 1 ⋅ 1
Expand All @@ -100,22 +130,16 @@ julia> pattern(g, HessianTracer{BitSet}, x)
⋅ 1 ⋅ ⋅ 1
```
"""
function hessian_pattern(f, x, ::Type{S}=DEFAULT_SET_TYPE) where {S<:AbstractIndexSet}
return pattern(f, HessianTracer{S}, x)
end

function pattern(f, ::Type{T}, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = f(xt)
return _pattern(xt, yt)
end

"""
pattern(f!, y, JacobianTracer{S}, x) where {S<:AbstractSet{<:Integer}}

Computes the sparsity pattern of the Jacobian of `f!(y, x)`.

pattern(f!, y, ConnectivityTracer{S}, x) where {S<:AbstractSet{<:Integer}}

Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)`
where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`.
"""
function pattern(f!, y, ::Type{T}, x) where {T<:AbstractTracer}
xt = trace_input(T, x)
yt = similar(y, T)
Expand Down
Loading
Loading