Skip to content

Commit

Permalink
Enable tag checking
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2023
1 parent 5d53d9d commit 90bba63
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
18 changes: 9 additions & 9 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DiffEqBase]]
deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"]
git-tree-sha1 = "94384b09e50ea01819b6db01ac08403ebe09bf65"
repo-rev = "ap/tstable_termination"
git-tree-sha1 = "53ad089996089756cae5a098b1a0542aeaab466f"
repo-rev = "master"
repo-url = "https://github.com/SciML/DiffEqBase.jl"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
version = "6.136.0"
Expand Down Expand Up @@ -425,9 +425,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[deps.LinearSolve]]
deps = ["ArrayInterface", "ConcreteStructs", "DocStringExtensions", "EnumX", "EnzymeCore", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "Libdl", "LinearAlgebra", "MKL_jll", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "SuiteSparse", "UnPack"]
git-tree-sha1 = "27732d23d88534a7b735dcf8f411daf34293a39e"
git-tree-sha1 = "9f807ca41005f9a8f092716e48022ee5b36cf5b1"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
version = "2.14.0"
version = "2.14.1"

[deps.LinearSolve.extensions]
LinearSolveBandedMatricesExt = "BandedMatrices"
Expand Down Expand Up @@ -639,9 +639,9 @@ version = "1.3.4"

[[deps.RecursiveArrayTools]]
deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"]
git-tree-sha1 = "d7087c013e8a496ff396bae843b1e16d9a30ede8"
git-tree-sha1 = "fa453b42ba1623bd2e70260bf44dac850a3430a7"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "2.38.10"
version = "2.39.0"

[deps.RecursiveArrayTools.extensions]
RecursiveArrayToolsMeasurementsExt = "Measurements"
Expand Down Expand Up @@ -689,9 +689,9 @@ version = "0.1.0"

[[deps.SLEEFPirates]]
deps = ["IfElse", "Static", "VectorizationBase"]
git-tree-sha1 = "f5c896d781486f1d67c8492f0e0ead2c3517208c"
git-tree-sha1 = "3aac6d68c5e57449f5b9b865c9ba50ac2970c4cf"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
version = "0.6.41"
version = "0.6.42"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "ChainRulesCore", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces", "ZygoteRules"]
Expand Down Expand Up @@ -766,7 +766,7 @@ version = "1.10.0"

[[deps.SparseDiffTools]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
git-tree-sha1 = "5188e5e415908a19a41cd90d8ab74a23affacba6"
git-tree-sha1 = "888937b8348e1e9ffae1c31efa61e693bc5463ba"
repo-rev = "ap/tagging"
repo-url = "https://github.com/avik-pal/SparseDiffTools.jl"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Expand Down
20 changes: 7 additions & 13 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ end

struct NonlinearSolveTag end

function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:NonlinearSolveTag, <:T}}, f::F,
x::AbstractArray{T}) where {T, F}
return true
end

"""
default_adargs_to_adtype(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), diff_type = Val{:forward})
Expand Down Expand Up @@ -43,7 +48,7 @@ function default_adargs_to_adtype(; chunk_size = missing, autodiff = nothing,

ad = _unwrap_val(autodiff)
# We don't really know the typeof the input yet, so we can't use the correct tag!
ad && return AutoForwardDiff{_unwrap_val(chunk_size), Nothing}(nothing)
ad && return AutoForwardDiff{_unwrap_val(chunk_size)}(; tag = NonlinearSolveTag())
return AutoFiniteDiff(; fdtype = diff_type)
end

Expand Down Expand Up @@ -117,17 +122,6 @@ function wrapprecs(_Pl, _Pr, weight)
return Pl, Pr
end

function _nfcount(N, ::Type{diff_type}) where {diff_type}
if diff_type === Val{:complex}
tmp = N
elseif diff_type === Val{:forward}
tmp = N + 1
else
tmp = 2N
end
return tmp
end

get_loss(fu) = norm(fu)^2 / 2

function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method
Expand Down Expand Up @@ -203,7 +197,7 @@ function __get_concrete_algorithm(alg, prob)
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
else
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
tag = NonlinearSolveTag())
end
return set_ad(alg, ad)
end
Expand Down
7 changes: 7 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ end
u0 = init_brusselator_2d(xyd_brusselator)
prob_brusselator_2d = NonlinearProblem(brusselator_2d_loop, u0, p)
sol = solve(prob_brusselator_2d, NewtonRaphson())
@test norm(sol.resid) < 1e-8

sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseForwardDiff()))
@test norm(sol.resid) < 1e-8

sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseFiniteDiff()))
@test norm(sol.resid) < 1e-8

du0 = copy(u0)
jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> brusselator_2d_loop(du, u, p), du0,
Expand Down

0 comments on commit 90bba63

Please sign in to comment.