Skip to content

Commit

Permalink
Use multiple forward diff trials to generate approximate Jacobian spa…
Browse files Browse the repository at this point in the history
…rsity pattern
  • Loading branch information
avik-pal committed Sep 10, 2023
1 parent 1eb7252 commit 3b3bf9b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 4 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseDiffTools"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
version = "2.6.0"
version = "2.7.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -14,6 +14,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand Down
4 changes: 2 additions & 2 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using ArrayInterface, SparseArrays
import ArrayInterface: matrix_colors
import StaticArrays
# Others
using SciMLOperators, LinearAlgebra
using SciMLOperators, LinearAlgebra, Random
import DataStructures: DisjointSets, find_root!, union!
import SciMLOperators: update_coefficients, update_coefficients!
import Setfield: @set!
Expand Down Expand Up @@ -89,7 +89,7 @@ export update_coefficients, update_coefficients!, value!
export AutoSparseEnzyme

export NoSparsityDetection, SymbolicsSparsityDetection, JacPrototypeSparsityDetection,
PrecomputedJacobianColorvec, AutoSparsityDetection
PrecomputedJacobianColorvec, ApproximateJacobianSparsity, AutoSparsityDetection
export sparse_jacobian, sparse_jacobian_cache, sparse_jacobian!
export init_jacobian

Expand Down
26 changes: 26 additions & 0 deletions src/highlevel/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,31 @@ function (alg::PrecomputedJacobianColorvec)(ad::AbstractSparseADType, args...; k
return MatrixColoringResult(colorvec, J, nz_rows, nz_cols)
end

# Approximate Jacobian Sparsity Detection
## Right now we hardcode it to use `ForwardDiff`
function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f, x; kwargs...)
@unpack ntrials, rng = alg
cfg = ForwardDiff.JacobianConfig(f, x)
J = sum(1:ntrials) do _
local x_ = similar(x)
rand!(rng, x_)
abs.(ForwardDiff.jacobian(f, x_, cfg))

Check warning on line 41 in src/highlevel/coloring.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/coloring.jl#L35-L41

Added lines #L35 - L41 were not covered by tests
end
return (JacPrototypeSparsityDetection(; jac_prototype = sparse(J), alg.alg))(ad, f, x;

Check warning on line 43 in src/highlevel/coloring.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/coloring.jl#L43

Added line #L43 was not covered by tests
kwargs...)
end

function (alg::ApproximateJacobianSparsity)(ad::AbstractSparseADType, f!, fx, x; kwargs...)
@unpack ntrials, rng = alg
cfg = ForwardDiff.JacobianConfig(f!, fx, x)
J = sum(1:ntrials) do _
local x_ = similar(x)
rand!(rng, x_)
abs.(ForwardDiff.jacobian(f!, fx, x_, cfg))

Check warning on line 53 in src/highlevel/coloring.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/coloring.jl#L47-L53

Added lines #L47 - L53 were not covered by tests
end
return (JacPrototypeSparsityDetection(; jac_prototype = sparse(J), alg.alg))(ad, f!, fx,

Check warning on line 55 in src/highlevel/coloring.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/coloring.jl#L55

Added line #L55 was not covered by tests
x; kwargs...)
end

# TODO: Heuristics to decide whether to use Sparse Differentiation or not
# Simple Idea: Check min(max(colorvec_cols), max(colorvec_rows))
25 changes: 25 additions & 0 deletions src/highlevel/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,31 @@ function _get_colorvec(alg::PrecomputedJacobianColorvec, ::AbstractReverseMode)
return cvec
end

"""
ApproximateJacobianSparsity(; ntrials = 5, rng = Random.default_rng(),
alg = GreedyD1Color())
Use `ntrials` random vectors to compute the sparsity pattern of the Jacobian. This is an
approximate method and the sparsity pattern may not be exact.
## Keyword Arguments
- `ntrials`: The number of random vectors to use for computing the sparsity pattern
- `rng`: The random number generator used for generating the random vectors
- `alg`: The algorithm used for computing the matrix colors
"""
struct ApproximateJacobianSparsity{R <: AbstractRNG,
A <: ArrayInterface.ColoringAlgorithm} <: AbstractSparsityDetection
ntrials::Int
rng::R
alg::A
end

function ApproximateJacobianSparsity(; ntrials::Int = 3,

Check warning on line 135 in src/highlevel/common.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/common.jl#L135

Added line #L135 was not covered by tests
rng::AbstractRNG = Random.default_rng(), alg = GreedyD1Color())
return ApproximateJacobianSparsity(ntrials, rng, alg)

Check warning on line 137 in src/highlevel/common.jl

View check run for this annotation

Codecov / codecov/patch

src/highlevel/common.jl#L137

Added line #L137 was not covered by tests
end

# No one should be using this currently
Base.@kwdef struct AutoSparsityDetection{A <: ArrayInterface.ColoringAlgorithm} <:
AbstractSparsityDetection
Expand Down
2 changes: 1 addition & 1 deletion test/test_sparse_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ row_colorvec = SparseDiffTools.matrix_colors(J_sparsity; partition_by_rows = tru
col_colorvec = SparseDiffTools.matrix_colors(J_sparsity; partition_by_rows = false)

SPARSITY_DETECTION_ALGS = [JacPrototypeSparsityDetection(; jac_prototype = J_sparsity),
SymbolicsSparsityDetection(), NoSparsityDetection(),
SymbolicsSparsityDetection(), NoSparsityDetection(), ApproximateJacobianSparsity(),
PrecomputedJacobianColorvec(; jac_prototype = J_sparsity, row_colorvec, col_colorvec)]

@testset "High-Level API" begin
Expand Down

0 comments on commit 3b3bf9b

Please sign in to comment.