diff --git a/Manifest.toml b/Manifest.toml index e62925a59..b4503615c 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -164,8 +164,8 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" [[deps.DiffEqBase]] deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"] -git-tree-sha1 = "94384b09e50ea01819b6db01ac08403ebe09bf65" -repo-rev = "ap/tstable_termination" +git-tree-sha1 = "53ad089996089756cae5a098b1a0542aeaab466f" +repo-rev = "master" repo-url = "https://github.com/SciML/DiffEqBase.jl" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" version = "6.136.0" @@ -425,9 +425,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LinearSolve]] deps = ["ArrayInterface", "ConcreteStructs", "DocStringExtensions", "EnumX", "EnzymeCore", "FastLapackInterface", "GPUArraysCore", "InteractiveUtils", "KLU", "Krylov", "Libdl", "LinearAlgebra", "MKL_jll", "PrecompileTools", "Preferences", "RecursiveFactorization", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Sparspak", "SuiteSparse", "UnPack"] -git-tree-sha1 = "27732d23d88534a7b735dcf8f411daf34293a39e" +git-tree-sha1 = "9f807ca41005f9a8f092716e48022ee5b36cf5b1" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" -version = "2.14.0" +version = "2.14.1" [deps.LinearSolve.extensions] LinearSolveBandedMatricesExt = "BandedMatrices" @@ -639,9 +639,9 @@ version = "1.3.4" [[deps.RecursiveArrayTools]] deps = ["Adapt", "ArrayInterface", "DocStringExtensions", "GPUArraysCore", "IteratorInterfaceExtensions", "LinearAlgebra", "RecipesBase", "Requires", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables"] -git-tree-sha1 = "d7087c013e8a496ff396bae843b1e16d9a30ede8" +git-tree-sha1 = "fa453b42ba1623bd2e70260bf44dac850a3430a7" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" -version = "2.38.10" +version = "2.39.0" [deps.RecursiveArrayTools.extensions] RecursiveArrayToolsMeasurementsExt = "Measurements" @@ -689,9 +689,9 @@ version = "0.1.0" [[deps.SLEEFPirates]] deps = ["IfElse", "Static", "VectorizationBase"] -git-tree-sha1 = "f5c896d781486f1d67c8492f0e0ead2c3517208c" +git-tree-sha1 = "3aac6d68c5e57449f5b9b865c9ba50ac2970c4cf" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.41" +version = "0.6.42" [[deps.SciMLBase]] deps = ["ADTypes", "ArrayInterface", "ChainRulesCore", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces", "ZygoteRules"] @@ -766,7 +766,7 @@ version = "1.10.0" [[deps.SparseDiffTools]] deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"] -git-tree-sha1 = "5188e5e415908a19a41cd90d8ab74a23affacba6" +git-tree-sha1 = "888937b8348e1e9ffae1c31efa61e693bc5463ba" repo-rev = "ap/tagging" repo-url = "https://github.com/avik-pal/SparseDiffTools.jl" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" diff --git a/src/utils.jl b/src/utils.jl index 2777c93e4..ba1230258 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -10,6 +10,11 @@ end struct NonlinearSolveTag end +function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:NonlinearSolveTag, <:T}}, f::F, + x::AbstractArray{T}) where {T, F} + return true +end + """ default_adargs_to_adtype(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), diff_type = Val{:forward}) @@ -43,7 +48,8 @@ function default_adargs_to_adtype(; chunk_size = missing, autodiff = nothing, ad = _unwrap_val(autodiff) # We don't really know the typeof the input yet, so we can't use the correct tag! - ad && return AutoForwardDiff{_unwrap_val(chunk_size), Nothing}(nothing) + ad && return AutoForwardDiff{_unwrap_val(chunk_size), NonlinearSolveTag}(; + tag = NonlinearSolveTag()) return AutoFiniteDiff(; fdtype = diff_type) end @@ -117,17 +123,6 @@ function wrapprecs(_Pl, _Pr, weight) return Pl, Pr end -function _nfcount(N, ::Type{diff_type}) where {diff_type} - if diff_type === Val{:complex} - tmp = N - elseif diff_type === Val{:forward} - tmp = N + 1 - else - tmp = 2N - end - return tmp -end - get_loss(fu) = norm(fu)^2 / 2 function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-function for adaptive trust region method @@ -203,7 +198,7 @@ function __get_concrete_algorithm(alg, prob) use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff() else (use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(; - tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0))) + tag = NonlinearSolveTag()) end return set_ad(alg, ad) end diff --git a/test/sparse.jl b/test/sparse.jl index ce338489e..729629c38 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -41,6 +41,13 @@ end u0 = init_brusselator_2d(xyd_brusselator) prob_brusselator_2d = NonlinearProblem(brusselator_2d_loop, u0, p) sol = solve(prob_brusselator_2d, NewtonRaphson()) +@test norm(sol.resid) < 1e-8 + +sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseForwardDiff())) +@test norm(sol.resid) < 1e-8 + +sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseFiniteDiff())) +@test norm(sol.resid) < 1e-8 du0 = copy(u0) jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> brusselator_2d_loop(du, u, p), du0, @@ -57,7 +64,8 @@ sol = solve(prob_brusselator_2d, NewtonRaphson()) @test !all(iszero, jac_prototype) sol = solve(prob_brusselator_2d, NewtonRaphson(autodiff = AutoSparseFiniteDiff())) -@test norm(sol.resid) < 1e-6 +@test norm(sol.resid) < 1e-8 cache = init(prob_brusselator_2d, NewtonRaphson(; autodiff = AutoSparseForwardDiff())); @test maximum(cache.jac_cache.coloring.colorvec) == 12 +@test cache.alg.ad isa AutoSparseForwardDiff