diff --git a/src/Sundials.jl b/src/Sundials.jl index cdcb51d..3cd1fe9 100644 --- a/src/Sundials.jl +++ b/src/Sundials.jl @@ -22,18 +22,30 @@ const Ctime_t = UInt const Cclock_t = UInt export Ctm, Ctime_t, Cclock_t -const warnkeywords = - (:save_idxs, :d_discontinuities, :isoutofdomain, :unstable_check, - :calck, :internalnorm, :gamma, :beta1, :beta2, :qmax, :qmin, :qoldinit) +const warnkeywords = ( + :save_idxs, + :d_discontinuities, + :isoutofdomain, + :unstable_check, + :calck, + :internalnorm, + :gamma, + :beta1, + :beta2, + :qmax, + :qmin, + :qoldinit, +) function __init__() global warnlist = Set(warnkeywords) - global warnida = union(warnlist, Set((:dtmin,))) + global warnida = union(warnlist, Set((:dtmin,))) end using Sundials_jll -export solve, SundialsODEAlgorithm, SundialsDAEAlgorithm, ARKODE, CVODE_BDF, CVODE_Adams, IDA +export solve, + SundialsODEAlgorithm, SundialsDAEAlgorithm, ARKODE, CVODE_BDF, CVODE_Adams, IDA # some definitions from the system C headers wrapped into the types_and_consts.jl const DBL_MAX = prevfloat(Inf) diff --git a/src/common_interface/algorithms.jl b/src/common_interface/algorithms.jl index 00c241c..7e18a9b 100644 --- a/src/common_interface/algorithms.jl +++ b/src/common_interface/algorithms.jl @@ -1,11 +1,11 @@ # Sundials.jl algorithms # Abstract Types -abstract type SundialsODEAlgorithm{Method,LinearSolver} <: DiffEqBase.AbstractODEAlgorithm end +abstract type SundialsODEAlgorithm{Method, LinearSolver} <: DiffEqBase.AbstractODEAlgorithm end abstract type SundialsDAEAlgorithm{LinearSolver} <: DiffEqBase.AbstractDAEAlgorithm end # ODE Algorithms -struct CVODE_BDF{Method,LinearSolver,P,PS} <: SundialsODEAlgorithm{Method,LinearSolver} +struct CVODE_BDF{Method, LinearSolver, P, PS} <: SundialsODEAlgorithm{Method, LinearSolver} jac_upper::Int jac_lower::Int krylov_dim::Int @@ -19,33 +19,62 @@ struct CVODE_BDF{Method,LinearSolver,P,PS} <: SundialsODEAlgorithm{Method,Linear psetup::PS prec_side::Int end -Base.@pure function CVODE_BDF(;method=:Newton,linear_solver=:Dense, - jac_upper=0,jac_lower=0,non_zero=0,krylov_dim=0, - stability_limit_detect=false, - max_hnil_warns = 10, - max_order = 5, - max_error_test_failures = 7, - max_nonlinear_iters = 3, - max_convergence_failures = 10, - prec = nothing, psetup = nothing, prec_side = 0) - if linear_solver == :Band && (jac_upper==0 || jac_lower==0) +Base.@pure function CVODE_BDF(; + method = :Newton, + linear_solver = :Dense, + jac_upper = 0, + jac_lower = 0, + non_zero = 0, + krylov_dim = 0, + stability_limit_detect = false, + max_hnil_warns = 10, + max_order = 5, + max_error_test_failures = 7, + max_nonlinear_iters = 3, + max_convergence_failures = 10, + prec = nothing, + psetup = nothing, + prec_side = 0, +) + if linear_solver == :Band && (jac_upper == 0 || jac_lower == 0) error("Banded solver must set the jac_upper and jac_lower") end - if !(linear_solver in (:None, :Diagonal, :Dense, :LapackDense, :Band, :LapackBand, :BCG, :GMRES, :FGMRES, :PCG, :TFQMR, :KLU)) + if !( + linear_solver in ( + :None, + :Diagonal, + :Dense, + :LapackDense, + :Band, + :LapackBand, + :BCG, + :GMRES, + :FGMRES, + :PCG, + :TFQMR, + :KLU, + ) + ) error("Linear solver choice not accepted.") end - CVODE_BDF{method,linear_solver, typeof(prec), typeof(psetup)}( - jac_upper,jac_lower, - krylov_dim, - stability_limit_detect, - max_hnil_warns, - max_order, - max_error_test_failures, - max_nonlinear_iters, - max_convergence_failures, prec, psetup, prec_side) + CVODE_BDF{method, linear_solver, typeof(prec), typeof(psetup)}( + jac_upper, + jac_lower, + krylov_dim, + stability_limit_detect, + max_hnil_warns, + max_order, + max_error_test_failures, + max_nonlinear_iters, + max_convergence_failures, + prec, + psetup, + prec_side, + ) end -struct CVODE_Adams{Method,LinearSolver,P,PS} <: SundialsODEAlgorithm{Method,LinearSolver} +struct CVODE_Adams{Method, LinearSolver, P, PS} <: + SundialsODEAlgorithm{Method, LinearSolver} jac_upper::Int jac_lower::Int krylov_dim::Int @@ -59,36 +88,61 @@ struct CVODE_Adams{Method,LinearSolver,P,PS} <: SundialsODEAlgorithm{Method,Line psetup::PS prec_side::Int end -Base.@pure function CVODE_Adams(;method=:Functional,linear_solver=:None, - jac_upper=0,jac_lower=0, - krylov_dim=0, - stability_limit_detect=false, - max_hnil_warns = 10, - max_order = 12, - max_error_test_failures = 7, - max_nonlinear_iters = 3, - max_convergence_failures = 10, - prec = nothing, psetup = nothing, prec_side = 0 - ) - if linear_solver == :Band && (jac_upper==0 || jac_lower==0) +Base.@pure function CVODE_Adams(; + method = :Functional, + linear_solver = :None, + jac_upper = 0, + jac_lower = 0, + krylov_dim = 0, + stability_limit_detect = false, + max_hnil_warns = 10, + max_order = 12, + max_error_test_failures = 7, + max_nonlinear_iters = 3, + max_convergence_failures = 10, + prec = nothing, + psetup = nothing, + prec_side = 0, +) + if linear_solver == :Band && (jac_upper == 0 || jac_lower == 0) error("Banded solver must set the jac_upper and jac_lower") end - if !(linear_solver in (:None, :Diagonal, :Dense, :LapackDense, :Band, :LapackBand, :BCG, :GMRES, :FGMRES, :PCG, :TFQMR, :KLU)) + if !( + linear_solver in ( + :None, + :Diagonal, + :Dense, + :LapackDense, + :Band, + :LapackBand, + :BCG, + :GMRES, + :FGMRES, + :PCG, + :TFQMR, + :KLU, + ) + ) error("Linear solver choice not accepted.") end - CVODE_Adams{method,linear_solver,typeof(prec),typeof(psetup)}( - jac_upper,jac_lower, - krylov_dim, - stability_limit_detect, - max_hnil_warns, - max_order, - max_error_test_failures, - max_nonlinear_iters, - max_convergence_failures,prec,psetup, - prec_side) + CVODE_Adams{method, linear_solver, typeof(prec), typeof(psetup)}( + jac_upper, + jac_lower, + krylov_dim, + stability_limit_detect, + max_hnil_warns, + max_order, + max_error_test_failures, + max_nonlinear_iters, + max_convergence_failures, + prec, + psetup, + prec_side, + ) end -struct ARKODE{Method,LinearSolver,MassLinearSolver,T,T1,T2,P,PS} <: SundialsODEAlgorithm{Method,LinearSolver} +struct ARKODE{Method, LinearSolver, MassLinearSolver, T, T1, T2, P, PS} <: + SundialsODEAlgorithm{Method, LinearSolver} stiffness::T jac_upper::Int jac_lower::Int @@ -116,123 +170,203 @@ struct ARKODE{Method,LinearSolver,MassLinearSolver,T,T1,T2,P,PS} <: SundialsODEA prec_side::Int end -Base.@pure function ARKODE(stiffness=Implicit();method=:Newton,linear_solver=:Dense, - mass_linear_solver=:Dense, - jac_upper=0,jac_lower=0, - mass_upper=0,mass_lower=0, - non_zero=0,krylov_dim=0,mass_krylov_dim=0, - max_hnil_warns = 10, - max_error_test_failures = 7, - max_nonlinear_iters = 3, - max_convergence_failures = 10, - predictor_method = 0, - nonlinear_convergence_coefficient = 0.1, - dense_order = 3, - order = 4, - set_optimal_params = false, - crdown = 0.3, - dgmax = 0.2, - rdiv = 2.3, - msbp = 20, - adaptivity_method = 0, - itable = nothing, - etable = nothing, - prec = nothing, psetup = nothing, prec_side = 0 - ) - if linear_solver == :Band && (jac_upper==0 || jac_lower==0) +Base.@pure function ARKODE( + stiffness = Implicit(); + method = :Newton, + linear_solver = :Dense, + mass_linear_solver = :Dense, + jac_upper = 0, + jac_lower = 0, + mass_upper = 0, + mass_lower = 0, + non_zero = 0, + krylov_dim = 0, + mass_krylov_dim = 0, + max_hnil_warns = 10, + max_error_test_failures = 7, + max_nonlinear_iters = 3, + max_convergence_failures = 10, + predictor_method = 0, + nonlinear_convergence_coefficient = 0.1, + dense_order = 3, + order = 4, + set_optimal_params = false, + crdown = 0.3, + dgmax = 0.2, + rdiv = 2.3, + msbp = 20, + adaptivity_method = 0, + itable = nothing, + etable = nothing, + prec = nothing, + psetup = nothing, + prec_side = 0, +) + if linear_solver == :Band && (jac_upper == 0 || jac_lower == 0) error("Banded solver must set the jac_upper and jac_lower") end - if !(linear_solver in (:None, :Diagonal, :Dense, :LapackDense, :Band, :LapackBand, :BCG, :GMRES, :FGMRES, :PCG, :TFQMR, :KLU)) + if !( + linear_solver in ( + :None, + :Diagonal, + :Dense, + :LapackDense, + :Band, + :LapackBand, + :BCG, + :GMRES, + :FGMRES, + :PCG, + :TFQMR, + :KLU, + ) + ) error("Linear solver choice not accepted.") end - if !(mass_linear_solver in (:None, :Diagonal, :Dense, :LapackDense, :Band, :LapackBand, :BCG, :GMRES, :FGMRES, :PCG, :TFQMR, :KLU)) + if !( + mass_linear_solver in ( + :None, + :Diagonal, + :Dense, + :LapackDense, + :Band, + :LapackBand, + :BCG, + :GMRES, + :FGMRES, + :PCG, + :TFQMR, + :KLU, + ) + ) error("Mass Matrix Linear solver choice not accepted.") end - ARKODE{method,linear_solver,mass_linear_solver, - typeof(stiffness), - typeof(itable),typeof(etable), - typeof(prec),typeof(psetup)}( - stiffness, - jac_upper,jac_lower, - mass_upper,mass_lower, - krylov_dim,mass_krylov_dim, - max_hnil_warns, - max_error_test_failures, - max_nonlinear_iters, - max_convergence_failures, - predictor_method, - nonlinear_convergence_coefficient, - dense_order, - order, - set_optimal_params, - crdown, - dgmax, - rdiv, - msbp, - itable, - etable, - prec, psetup, prec_side) + ARKODE{ + method, + linear_solver, + mass_linear_solver, + typeof(stiffness), + typeof(itable), + typeof(etable), + typeof(prec), + typeof(psetup), + }( + stiffness, + jac_upper, + jac_lower, + mass_upper, + mass_lower, + krylov_dim, + mass_krylov_dim, + max_hnil_warns, + max_error_test_failures, + max_nonlinear_iters, + max_convergence_failures, + predictor_method, + nonlinear_convergence_coefficient, + dense_order, + order, + set_optimal_params, + crdown, + dgmax, + rdiv, + msbp, + itable, + etable, + prec, + psetup, + prec_side, + ) end # DAE Algorithms -struct IDA{LinearSolver,P,PS} <: SundialsDAEAlgorithm{LinearSolver} - jac_upper::Int - jac_lower::Int - krylov_dim::Int - max_order::Int - max_error_test_failures::Int - nonlinear_convergence_coefficient::Float64 - max_nonlinear_iters::Int - max_convergence_failures::Int - nonlinear_convergence_coefficient_ic::Float64 - max_num_steps_ic::Int - max_num_jacs_ic::Int - max_num_iters_ic::Int - max_num_backs_ic::Int - use_linesearch_ic::Bool - init_all::Bool - prec::P - psetup::PS - prec_side::Int +struct IDA{LinearSolver, P, PS} <: SundialsDAEAlgorithm{LinearSolver} + jac_upper::Int + jac_lower::Int + krylov_dim::Int + max_order::Int + max_error_test_failures::Int + nonlinear_convergence_coefficient::Float64 + max_nonlinear_iters::Int + max_convergence_failures::Int + nonlinear_convergence_coefficient_ic::Float64 + max_num_steps_ic::Int + max_num_jacs_ic::Int + max_num_iters_ic::Int + max_num_backs_ic::Int + use_linesearch_ic::Bool + init_all::Bool + prec::P + psetup::PS + prec_side::Int end -Base.@pure function IDA(;linear_solver=:Dense,jac_upper=0,jac_lower=0, - krylov_dim=0, - max_order = 5, - max_error_test_failures = 7, - max_nonlinear_iters = 3, - nonlinear_convergence_coefficient = 0.33, - nonlinear_convergence_coefficient_ic = 0.0033, - max_num_steps_ic = 5, - max_num_jacs_ic = 4, - max_num_iters_ic = 10, - max_num_backs_ic = 100, - use_linesearch_ic = true, - init_all = false, - max_convergence_failures = 10, - prec = nothing, psetup = nothing, prec_side = 0) - if linear_solver == :Band && (jac_upper==0 || jac_lower==0) - error("Banded solver must set the jac_upper and jac_lower") - end - if !(linear_solver in (:None, :Diagonal, :Dense, :LapackDense, :Band, :LapackBand, :BCG, :GMRES, :FGMRES, :PCG, :TFQMR, :KLU)) - error("Linear solver choice not accepted.") - end - IDA{linear_solver,typeof(prec),typeof(psetup)}( - jac_upper,jac_lower,krylov_dim, - max_order, - max_error_test_failures, - nonlinear_convergence_coefficient, - max_nonlinear_iters, - max_convergence_failures, - nonlinear_convergence_coefficient_ic, - max_num_steps_ic, - max_num_jacs_ic, - max_num_iters_ic, - max_num_backs_ic, - use_linesearch_ic, - init_all,prec, psetup, prec_side) +Base.@pure function IDA(; + linear_solver = :Dense, + jac_upper = 0, + jac_lower = 0, + krylov_dim = 0, + max_order = 5, + max_error_test_failures = 7, + max_nonlinear_iters = 3, + nonlinear_convergence_coefficient = 0.33, + nonlinear_convergence_coefficient_ic = 0.0033, + max_num_steps_ic = 5, + max_num_jacs_ic = 4, + max_num_iters_ic = 10, + max_num_backs_ic = 100, + use_linesearch_ic = true, + init_all = false, + max_convergence_failures = 10, + prec = nothing, + psetup = nothing, + prec_side = 0, +) + if linear_solver == :Band && (jac_upper == 0 || jac_lower == 0) + error("Banded solver must set the jac_upper and jac_lower") + end + if !( + linear_solver in ( + :None, + :Diagonal, + :Dense, + :LapackDense, + :Band, + :LapackBand, + :BCG, + :GMRES, + :FGMRES, + :PCG, + :TFQMR, + :KLU, + ) + ) + error("Linear solver choice not accepted.") + end + IDA{linear_solver, typeof(prec), typeof(psetup)}( + jac_upper, + jac_lower, + krylov_dim, + max_order, + max_error_test_failures, + nonlinear_convergence_coefficient, + max_nonlinear_iters, + max_convergence_failures, + nonlinear_convergence_coefficient_ic, + max_num_steps_ic, + max_num_jacs_ic, + max_num_iters_ic, + max_num_backs_ic, + use_linesearch_ic, + init_all, + prec, + psetup, + prec_side, + ) end -method_choice(alg::SundialsODEAlgorithm{Method}) where Method = Method +method_choice(alg::SundialsODEAlgorithm{Method}) where {Method} = Method method_choice(alg::SundialsDAEAlgorithm) = :Newton -linear_solver(alg::SundialsODEAlgorithm{Method,LinearSolver}) where {Method,LinearSolver}= LinearSolver -linear_solver(alg::SundialsDAEAlgorithm{LinearSolver}) where LinearSolver = LinearSolver +linear_solver( + alg::SundialsODEAlgorithm{Method, LinearSolver}, +) where {Method, LinearSolver} = LinearSolver +linear_solver(alg::SundialsDAEAlgorithm{LinearSolver}) where {LinearSolver} = LinearSolver diff --git a/src/common_interface/function_types.jl b/src/common_interface/function_types.jl index 1ff8f9c..c626e5c 100644 --- a/src/common_interface/function_types.jl +++ b/src/common_interface/function_types.jl @@ -12,60 +12,62 @@ mutable struct FunJac{F, F2, J, P, M, J2, uType, uType2, Prec, PS} <: AbstractFu du::uType resid::uType2 end -FunJac(fun,jac,p,m,jac_prototype,prec,psetup,u,du) = FunJac(fun,nothing,jac,p,m,jac_prototype,prec,psetup,u,du,nothing) -FunJac(fun,jac,p,m,jac_prototype,prec,psetup,u,du,resid) = FunJac(fun,nothing,jac,p,m,jac_prototype,prec,psetup,u,du,resid) - -function cvodefunjac(t::Float64, - u::N_Vector, - du::N_Vector, - funjac::FunJac) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),length(funjac.du)) +FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du) = + FunJac(fun, nothing, jac, p, m, jac_prototype, prec, psetup, u, du, nothing) +FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du, resid) = + FunJac(fun, nothing, jac, p, m, jac_prototype, prec, psetup, u, du, resid) + +function cvodefunjac(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac) + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) + funjac.du = + unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), length(funjac.du)) _du = funjac.du _u = funjac.u funjac.fun(_du, _u, funjac.p, t) return CV_SUCCESS end -function cvodefunjac2(t::Float64, - u::N_Vector, - du::N_Vector, - funjac::FunJac) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),length(funjac.du)) +function cvodefunjac2(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac) + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) + funjac.du = + unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), length(funjac.du)) _du = funjac.du _u = funjac.u funjac.fun2(_du, _u, funjac.p, t) return CV_SUCCESS end -function cvodejac(t::realtype, - u::N_Vector, - du::N_Vector, - J::SUNMatrix, - funjac::AbstractFunJac{Nothing}, - tmp1::N_Vector, - tmp2::N_Vector, - tmp3::N_Vector) +function cvodejac( + t::realtype, + u::N_Vector, + du::N_Vector, + J::SUNMatrix, + funjac::AbstractFunJac{Nothing}, + tmp1::N_Vector, + tmp2::N_Vector, + tmp3::N_Vector, +) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) _u = funjac.u funjac.jac(convert(Matrix, J), _u, funjac.p, t) return CV_SUCCESS end -function cvodejac(t::realtype, - u::N_Vector, - du::N_Vector, - _J::SUNMatrix, - funjac::AbstractFunJac{<:SparseArrays.SparseMatrixCSC}, - tmp1::N_Vector, - tmp2::N_Vector, - tmp3::N_Vector) +function cvodejac( + t::realtype, + u::N_Vector, + du::N_Vector, + _J::SUNMatrix, + funjac::AbstractFunJac{<:SparseArrays.SparseMatrixCSC}, + tmp1::N_Vector, + tmp2::N_Vector, + tmp3::N_Vector, +) jac_prototype = funjac.jac_prototype - J = convert(SparseArrays.SparseMatrixCSC,_J) + J = convert(SparseArrays.SparseMatrixCSC, _J) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) _u = funjac.u funjac.jac(jac_prototype, _u, funjac.p, t) @@ -77,154 +79,202 @@ function cvodejac(t::realtype, end function idasolfun(t::Float64, u::N_Vector, du::N_Vector, resid::N_Vector, funjac::FunJac) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) _u = funjac.u - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),length(funjac.du)) + funjac.du = + unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), length(funjac.du)) _du = funjac.du - funjac.resid = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(resid),length(funjac.resid)) + funjac.resid = + unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(resid), length(funjac.resid)) _resid = funjac.resid funjac.fun(_resid, _du, _u, funjac.p, t) return IDA_SUCCESS end -function idajac(t::realtype, - cj::realtype, - u::N_Vector, - du::N_Vector, - res::N_Vector, - J::SUNMatrix, - funjac::AbstractFunJac{Nothing}, - tmp1::N_Vector, - tmp2::N_Vector, - tmp3::N_Vector) - +function idajac( + t::realtype, + cj::realtype, + u::N_Vector, + du::N_Vector, + res::N_Vector, + J::SUNMatrix, + funjac::AbstractFunJac{Nothing}, + tmp1::N_Vector, + tmp2::N_Vector, + tmp3::N_Vector, +) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) _u = funjac.u - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),length(funjac.du)) + funjac.du = + unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), length(funjac.du)) _du = funjac.du - funjac.jac(convert(Matrix, J), _du, - _u, funjac.p, cj, t) + funjac.jac(convert(Matrix, J), _du, _u, funjac.p, cj, t) return IDA_SUCCESS end -function idajac(t::realtype, - cj::realtype, - u::N_Vector, - du::N_Vector, - res::N_Vector, - _J::SUNMatrix, - funjac::AbstractFunJac{<:SparseArrays.SparseMatrixCSC}, - tmp1::N_Vector, - tmp2::N_Vector, - tmp3::N_Vector) - - jac_prototype = funjac.jac_prototype - J = convert(SparseArrays.SparseMatrixCSC,_J) - - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u),length(funjac.u)) - _u = funjac.u - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),length(funjac.du)) - _du = funjac.du - - funjac.jac(jac_prototype, _du, convert(Vector, _u), funjac.p, cj, t) - J.nzval .= jac_prototype.nzval - # Sundials resets the value pointers each time, so reset it too - @. J.rowval = jac_prototype.rowval - 1 - @. J.colptr = jac_prototype.colptr - 1 - - return IDA_SUCCESS +function idajac( + t::realtype, + cj::realtype, + u::N_Vector, + du::N_Vector, + res::N_Vector, + _J::SUNMatrix, + funjac::AbstractFunJac{<:SparseArrays.SparseMatrixCSC}, + tmp1::N_Vector, + tmp2::N_Vector, + tmp3::N_Vector, +) + + jac_prototype = funjac.jac_prototype + J = convert(SparseArrays.SparseMatrixCSC, _J) + + funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) + _u = funjac.u + funjac.du = + unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), length(funjac.du)) + _du = funjac.du + + funjac.jac(jac_prototype, _du, convert(Vector, _u), funjac.p, cj, t) + J.nzval .= jac_prototype.nzval + # Sundials resets the value pointers each time, so reset it too + @. J.rowval = jac_prototype.rowval - 1 + @. J.colptr = jac_prototype.colptr - 1 + + return IDA_SUCCESS end -function massmat(t::Float64, - _M::SUNMatrix, - mmf::AbstractFunJac, - tmp1::N_Vector, - tmp2::N_Vector, - tmp3::N_Vector) - if typeof(mmf.mass_matrix) <: Array - M = convert(Matrix, _M) - else - M = convert(SparseArrays.SparseMatrixCSC, _M) - end - M .= mmf.mass_matrix - - return IDA_SUCCESS +function massmat( + t::Float64, + _M::SUNMatrix, + mmf::AbstractFunJac, + tmp1::N_Vector, + tmp2::N_Vector, + tmp3::N_Vector, +) + if typeof(mmf.mass_matrix) <: Array + M = convert(Matrix, _M) + else + M = convert(SparseArrays.SparseMatrixCSC, _M) + end + M .= mmf.mass_matrix + + return IDA_SUCCESS end -function jactimes(v::N_Vector, - Jv::N_Vector, - t::Float64, - y::N_Vector, - fy::N_Vector, - fj::AbstractFunJac, - tmp::N_Vector) - DiffEqBase.update_coefficients!(fj.jac_prototype,y,fj.p,t) - LinearAlgebra.mul!(convert(Vector,Jv),fj.jac_prototype,convert(Vector,v)) +function jactimes( + v::N_Vector, + Jv::N_Vector, + t::Float64, + y::N_Vector, + fy::N_Vector, + fj::AbstractFunJac, + tmp::N_Vector, +) + DiffEqBase.update_coefficients!(fj.jac_prototype, y, fj.p, t) + LinearAlgebra.mul!(convert(Vector, Jv), fj.jac_prototype, convert(Vector, v)) return CV_SUCCESS end function idajactimes( - t::Float64, - y::N_Vector, - fy::N_Vector, - r::N_Vector, - v::N_Vector, - Jv::N_Vector, - cj::Float64, - fj::AbstractFunJac, - tmp1::N_Vector, - tmp2::N_Vector) - DiffEqBase.update_coefficients!(fj.jac_prototype,y,fj.p,t) - LinearAlgebra.mul!(convert(Vector,Jv),fj.jac_prototype,convert(Vector,v)) + t::Float64, + y::N_Vector, + fy::N_Vector, + r::N_Vector, + v::N_Vector, + Jv::N_Vector, + cj::Float64, + fj::AbstractFunJac, + tmp1::N_Vector, + tmp2::N_Vector, +) + DiffEqBase.update_coefficients!(fj.jac_prototype, y, fj.p, t) + LinearAlgebra.mul!(convert(Vector, Jv), fj.jac_prototype, convert(Vector, v)) return IDA_SUCCESS end -function precsolve(t::Float64, - y::N_Vector, - fy::N_Vector, - r::N_Vector, - z::N_Vector, - gamma::Float64, - delta::Float64, - lr::Int, - fj::AbstractFunJac) - fj.prec(convert(Vector,z),convert(Vector,r),fj.p,t,convert(Vector,y),convert(Vector,fy),gamma,delta,lr) +function precsolve( + t::Float64, + y::N_Vector, + fy::N_Vector, + r::N_Vector, + z::N_Vector, + gamma::Float64, + delta::Float64, + lr::Int, + fj::AbstractFunJac, +) + fj.prec( + convert(Vector, z), + convert(Vector, r), + fj.p, + t, + convert(Vector, y), + convert(Vector, fy), + gamma, + delta, + lr, + ) return CV_SUCCESS end -function precsetup(t::Float64, - y::N_Vector, - fy::N_Vector, - jok::Int, - jcurPtr::Ref{Int}, - gamma::Float64, - fj::AbstractFunJac) - fj.psetup(fj.p,t,convert(Vector,y),convert(Vector,fy),jok==1,Base.unsafe_wrap(Vector{Int}, jcurPtr, 1),gamma) +function precsetup( + t::Float64, + y::N_Vector, + fy::N_Vector, + jok::Int, + jcurPtr::Ref{Int}, + gamma::Float64, + fj::AbstractFunJac, +) + fj.psetup( + fj.p, + t, + convert(Vector, y), + convert(Vector, fy), + jok == 1, + Base.unsafe_wrap(Vector{Int}, jcurPtr, 1), + gamma, + ) return CV_SUCCESS end -function idaprecsolve(t::Float64, - y::N_Vector, - fy::N_Vector, - resid::N_Vector, - r::N_Vector, - z::N_Vector, - gamma::Float64, - delta::Float64, - lr::Int, - fj::AbstractFunJac) - fj.prec(convert(Vector,z),convert(Vector,r),fj.p,t,convert(Vector,y),convert(Vector,fy),convert(Vector,resid),gamma,delta,lr) +function idaprecsolve( + t::Float64, + y::N_Vector, + fy::N_Vector, + resid::N_Vector, + r::N_Vector, + z::N_Vector, + gamma::Float64, + delta::Float64, + lr::Int, + fj::AbstractFunJac, +) + fj.prec( + convert(Vector, z), + convert(Vector, r), + fj.p, + t, + convert(Vector, y), + convert(Vector, fy), + convert(Vector, resid), + gamma, + delta, + lr, + ) return IDA_SUCCESS end -function idaprecsetup(t::Float64, - y::N_Vector, - fy::N_Vector, - rr::N_Vector, - gamma::Float64, - fj::AbstractFunJac) - fj.psetup(fj.p,t,convert(Vector,rr),convert(Vector,y),convert(Vector,fy),gamma) +function idaprecsetup( + t::Float64, + y::N_Vector, + fy::N_Vector, + rr::N_Vector, + gamma::Float64, + fj::AbstractFunJac, +) + fj.psetup(fj.p, t, convert(Vector, rr), convert(Vector, y), convert(Vector, fy), gamma) return IDA_SUCCESS end diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index afe14c7..3db3667 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -1,4 +1,4 @@ -mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType,F5} +mutable struct DEOptions{SType, TstopType, CType, reltolType, abstolType, F5} saveat::SType tstops::TstopType save_everystep::Bool @@ -19,9 +19,26 @@ mutable struct DEOptions{SType,TstopType,CType,reltolType,abstolType,F5} maxiters::Int end -abstract type AbstractSundialsIntegrator{algType} <: DiffEqBase.AbstractODEIntegrator{algType,true,Vector{Float64},Float64} end +abstract type AbstractSundialsIntegrator{algType} <: + DiffEqBase.AbstractODEIntegrator{algType, true, Vector{Float64}, Float64} end -mutable struct CVODEIntegrator{uType,pType,memType,solType,algType,fType,UFType,JType,oType,toutType,sizeType,tmpType,LStype,Atype,CallbackCacheType} <: AbstractSundialsIntegrator{algType} +mutable struct CVODEIntegrator{ + uType, + pType, + memType, + solType, + algType, + fType, + UFType, + JType, + oType, + toutType, + sizeType, + tmpType, + LStype, + Atype, + CallbackCacheType, +} <: AbstractSundialsIntegrator{algType} u::uType p::pType t::Float64 @@ -49,18 +66,45 @@ mutable struct CVODEIntegrator{uType,pType,memType,solType,algType,fType,UFType, last_event_error::Float64 end -function (integrator::CVODEIntegrator)(t::Number,deriv::Type{Val{T}}=Val{0};idxs=nothing) where T +function (integrator::CVODEIntegrator)( + t::Number, + deriv::Type{Val{T}} = Val{0}; + idxs = nothing, +) where {T} out = similar(integrator.u) integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), out) return idxs == nothing ? out : out[idxs] end -function (integrator::CVODEIntegrator)(out,t::Number,deriv::Type{Val{T}}=Val{0};idxs=nothing) where T +function (integrator::CVODEIntegrator)( + out, + t::Number, + deriv::Type{Val{T}} = Val{0}; + idxs = nothing, +) where {T} integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), out) return idxs == nothing ? out : @view out[idxs] end -mutable struct ARKODEIntegrator{uType,pType,memType,solType,algType,fType,UFType,JType,oType,toutType,sizeType,tmpType,LStype,Atype,MLStype,Mtype,CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE} +mutable struct ARKODEIntegrator{ + uType, + pType, + memType, + solType, + algType, + fType, + UFType, + JType, + oType, + toutType, + sizeType, + tmpType, + LStype, + Atype, + MLStype, + Mtype, + CallbackCacheType, +} <: AbstractSundialsIntegrator{ARKODE} u::uType p::pType t::Float64 @@ -90,18 +134,45 @@ mutable struct ARKODEIntegrator{uType,pType,memType,solType,algType,fType,UFType last_event_error::Float64 end -function (integrator::ARKODEIntegrator)(t::Number,deriv::Type{Val{T}}=Val{0};idxs=nothing) where T +function (integrator::ARKODEIntegrator)( + t::Number, + deriv::Type{Val{T}} = Val{0}; + idxs = nothing, +) where {T} out = similar(integrator.u) integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out) return idxs == nothing ? out : out[idxs] end -function (integrator::ARKODEIntegrator)(out,t::Number,deriv::Type{Val{T}}=Val{0};idxs=nothing) where T +function (integrator::ARKODEIntegrator)( + out, + t::Number, + deriv::Type{Val{T}} = Val{0}; + idxs = nothing, +) where {T} integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out) return idxs == nothing ? out : @view out[idxs] end -mutable struct IDAIntegrator{uType,duType,pType,memType,solType,algType,fType,UFType,JType,oType,toutType,sizeType,sizeDType,tmpType,LStype,Atype,CallbackCacheType} <: AbstractSundialsIntegrator{IDA} +mutable struct IDAIntegrator{ + uType, + duType, + pType, + memType, + solType, + algType, + fType, + UFType, + JType, + oType, + toutType, + sizeType, + sizeDType, + tmpType, + LStype, + Atype, + CallbackCacheType, +} <: AbstractSundialsIntegrator{IDA} u::uType du::duType p::pType @@ -131,13 +202,22 @@ mutable struct IDAIntegrator{uType,duType,pType,memType,solType,algType,fType,UF last_event_error::Float64 end -function (integrator::IDAIntegrator)(t::Number,deriv::Type{Val{T}}=Val{0};idxs=nothing) where T +function (integrator::IDAIntegrator)( + t::Number, + deriv::Type{Val{T}} = Val{0}; + idxs = nothing, +) where {T} out = similar(integrator.u) integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), out) return idxs == nothing ? out : out[idxs] end -function (integrator::IDAIntegrator)(out,t::Number,deriv::Type{Val{T}}=Val{0};idxs=nothing) where T +function (integrator::IDAIntegrator)( + out, + t::Number, + deriv::Type{Val{T}} = Val{0}; + idxs = nothing, +) where {T} integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), out) return idxs == nothing ? out : @view out[idxs] end @@ -154,38 +234,38 @@ DiffEqBase.postamble!(integrator::AbstractSundialsIntegrator) = nothing ### Iterator interface @inline function DiffEqBase.step!(integrator::AbstractSundialsIntegrator) - if integrator.opts.advance_to_tstop - # The call to first is an overload of Base.first implemented in DataStructures - while integrator.tdir*(integrator.t-first(integrator.opts.tstops)) < -1e6eps() - tstop = first(integrator.opts.tstops) - set_stop_time(integrator,tstop) + if integrator.opts.advance_to_tstop + # The call to first is an overload of Base.first implemented in DataStructures + while integrator.tdir * (integrator.t - first(integrator.opts.tstops)) < -1e6eps() + tstop = first(integrator.opts.tstops) + set_stop_time(integrator, tstop) + integrator.tprev = integrator.t + if !(typeof(integrator.opts.callback.continuous_callbacks) <: Tuple{}) + integrator.uprev .= integrator.u + end + solver_step(integrator, tstop) + integrator.t = first(integrator.tout) + DiffEqBase.check_error!(integrator) != :Success && return + handle_callbacks!(integrator) + DiffEqBase.check_error!(integrator) != :Success && return + end + else integrator.tprev = integrator.t - if !(typeof(integrator.opts.callback.continuous_callbacks)<:Tuple{}) + if !(typeof(integrator.opts.callback.continuous_callbacks) <: Tuple{}) integrator.uprev .= integrator.u end - solver_step(integrator,tstop) + if !isempty(integrator.opts.tstops) + tstop = first(integrator.opts.tstops) + set_stop_time(integrator, tstop) + solver_step(integrator, tstop) + else + solver_step(integrator, 1.0) # fake tstop + end integrator.t = first(integrator.tout) DiffEqBase.check_error!(integrator) != :Success && return handle_callbacks!(integrator) DiffEqBase.check_error!(integrator) != :Success && return end - else - integrator.tprev = integrator.t - if !(typeof(integrator.opts.callback.continuous_callbacks)<:Tuple{}) - integrator.uprev .= integrator.u - end - if !isempty(integrator.opts.tstops) - tstop = first(integrator.opts.tstops) - set_stop_time(integrator,tstop) - solver_step(integrator,tstop) - else - solver_step(integrator,1.0) # fake tstop - end - integrator.t = first(integrator.tout) - DiffEqBase.check_error!(integrator) != :Success && return - handle_callbacks!(integrator) - DiffEqBase.check_error!(integrator) != :Success && return - end - handle_tstop!(integrator) - nothing + handle_tstop!(integrator) + nothing end diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index c8561ea..52271c9 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -1,175 +1,212 @@ function handle_callbacks!(integrator) - discrete_callbacks = integrator.opts.callback.discrete_callbacks - continuous_callbacks = integrator.opts.callback.continuous_callbacks - atleast_one_callback = false - - continuous_modified = false - discrete_modified = false - saved_in_cb = false - if !(typeof(continuous_callbacks)<:Tuple{}) - time,upcrossing,event_occured,event_idx,idx,counter = - DiffEqBase.find_first_continuous_callback(integrator,continuous_callbacks...) - if event_occured - integrator.event_last_time = idx - integrator.vector_event_last_time = event_idx - continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing,event_idx) - else - integrator.event_last_time = 0 - integrator.vector_event_last_time = 1 + discrete_callbacks = integrator.opts.callback.discrete_callbacks + continuous_callbacks = integrator.opts.callback.continuous_callbacks + atleast_one_callback = false + + continuous_modified = false + discrete_modified = false + saved_in_cb = false + if !(typeof(continuous_callbacks) <: Tuple{}) + time, upcrossing, event_occured, event_idx, idx, counter = + DiffEqBase.find_first_continuous_callback(integrator, continuous_callbacks...) + if event_occured + integrator.event_last_time = idx + integrator.vector_event_last_time = event_idx + continuous_modified, saved_in_cb = DiffEqBase.apply_callback!( + integrator, + continuous_callbacks[idx], + time, + upcrossing, + event_idx, + ) + else + integrator.event_last_time = 0 + integrator.vector_event_last_time = 1 + end end - end - if !(typeof(discrete_callbacks)<:Tuple{}) - discrete_modified,saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,discrete_callbacks...) - end - - integrator.u_modified = continuous_modified || discrete_modified - if integrator.u_modified - handle_callback_modifiers!(integrator) - end - - if !saved_in_cb - savevalues!(integrator) - end - - integrator.u_modified = false -end - -function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator,force_save=false)::Tuple{Bool,Bool} - saved, savedexactly = false, false - !integrator.opts.save_on && return saved, savedexactly - uType = eltype(integrator.sol.u) - # The call to first is an overload of Base.first implemented in DataStructures - while !isempty(integrator.opts.saveat) && - integrator.tdir*first(integrator.opts.saveat) < integrator.tdir*integrator.t - saved = true - curt = pop!(integrator.opts.saveat) - - tmp = integrator(curt) - save_value!(integrator.sol.u,tmp,uType,integrator.sizeu,Val{false}) - push!(integrator.sol.t,curt) - if integrator.opts.dense - tmp = integrator(curt,Val{1}) - save_value!(integrator.sol.interp.du,tmp,uType,integrator.sizeu,Val{false}) + if !(typeof(discrete_callbacks) <: Tuple{}) + discrete_modified, saved_in_cb = + DiffEqBase.apply_discrete_callback!(integrator, discrete_callbacks...) end - end - - if integrator.opts.save_everystep || force_save - saved = true - save_value!(integrator.sol.u,integrator.u,uType,integrator.sizeu) - push!(integrator.sol.t, integrator.t) - if integrator.opts.dense - tmp = integrator(integrator.t,Val{1}) - save_value!(integrator.sol.interp.du,tmp,uType,integrator.sizeu) + + integrator.u_modified = continuous_modified || discrete_modified + if integrator.u_modified + handle_callback_modifiers!(integrator) end - end - savedexactly = !isempty(integrator.sol.t) && last(integrator.sol.t) == integrator.t - return saved, savedexactly -end -function save_value!(save_array,val,::Type{T},sizeu, - make_copy::Type{Val{bool}}=Val{true}) where {T <: Number,bool} - push!(save_array,first(val)) -end -function save_value!(save_array,val,::Type{T},sizeu, - make_copy::Type{Val{bool}}=Val{true}) where {T <: Vector,bool} + if !saved_in_cb + savevalues!(integrator) + end + + integrator.u_modified = false +end + +function DiffEqBase.savevalues!( + integrator::AbstractSundialsIntegrator, + force_save = false, +)::Tuple{Bool, Bool} + saved, savedexactly = false, false + !integrator.opts.save_on && return saved, savedexactly + uType = eltype(integrator.sol.u) + # The call to first is an overload of Base.first implemented in DataStructures + while !isempty(integrator.opts.saveat) && + integrator.tdir * first(integrator.opts.saveat) < integrator.tdir * integrator.t + saved = true + curt = pop!(integrator.opts.saveat) + + tmp = integrator(curt) + save_value!(integrator.sol.u, tmp, uType, integrator.sizeu, Val{false}) + push!(integrator.sol.t, curt) + if integrator.opts.dense + tmp = integrator(curt, Val{1}) + save_value!(integrator.sol.interp.du, tmp, uType, integrator.sizeu, Val{false}) + end + end + + if integrator.opts.save_everystep || force_save + saved = true + save_value!(integrator.sol.u, integrator.u, uType, integrator.sizeu) + push!(integrator.sol.t, integrator.t) + if integrator.opts.dense + tmp = integrator(integrator.t, Val{1}) + save_value!(integrator.sol.interp.du, tmp, uType, integrator.sizeu) + end + end + savedexactly = !isempty(integrator.sol.t) && last(integrator.sol.t) == integrator.t + return saved, savedexactly +end + +function save_value!( + save_array, + val, + ::Type{T}, + sizeu, + make_copy::Type{Val{bool}} = Val{true}, +) where {T <: Number, bool} + push!(save_array, first(val)) +end +function save_value!( + save_array, + val, + ::Type{T}, + sizeu, + make_copy::Type{Val{bool}} = Val{true}, +) where {T <: Vector, bool} bool ? save = copy(val) : save = val - push!(save_array,save) -end -function save_value!(save_array,val,::Type{T},sizeu, - make_copy::Type{Val{bool}}=Val{true}) where {T <: Array,bool} + push!(save_array, save) +end +function save_value!( + save_array, + val, + ::Type{T}, + sizeu, + make_copy::Type{Val{bool}} = Val{true}, +) where {T <: Array, bool} bool ? save = copy(val) : save = val - push!(save_array,reshape(save,sizeu)) -end -function save_value!(save_array,val,::Type{T},sizeu, - make_copy::Type{Val{bool}}=Val{true}) where {T <: AbstractArray,bool} + push!(save_array, reshape(save, sizeu)) +end +function save_value!( + save_array, + val, + ::Type{T}, + sizeu, + make_copy::Type{Val{bool}} = Val{true}, +) where {T <: AbstractArray, bool} bool ? save = copy(val) : save = val - push!(save_array,convert(T,reshape(save,sizeu))) + push!(save_array, convert(T, reshape(save, sizeu))) end function handle_callback_modifiers!(integrator::CVODEIntegrator) - CVodeReInit(integrator.mem,integrator.t,integrator.u) + CVodeReInit(integrator.mem, integrator.t, integrator.u) end function handle_callback_modifiers!(integrator::ARKODEIntegrator) - ARKStepReInit(integrator.mem,integrator.t,integrator.u) + ARKStepReInit(integrator.mem, integrator.t, integrator.u) end function handle_callback_modifiers!(integrator::IDAIntegrator) - IDAReInit(integrator.mem,integrator.t,integrator.u,integrator.du) - DiffEqBase.initialize_dae!(integrator) + IDAReInit(integrator.mem, integrator.t, integrator.u, integrator.du) + DiffEqBase.initialize_dae!(integrator) end -function DiffEqBase.add_tstop!(integrator::AbstractSundialsIntegrator,t) - t < integrator.t && error("Tried to add a tstop that is behind the current time. This is strictly forbidden") - push!(integrator.opts.tstops,t) +function DiffEqBase.add_tstop!(integrator::AbstractSundialsIntegrator, t) + t < integrator.t && + error("Tried to add a tstop that is behind the current time. This is strictly forbidden") + push!(integrator.opts.tstops, t) end -function DiffEqBase.add_saveat!(integrator::AbstractSundialsIntegrator,t) - integrator.tdir * (t - integrator.t) < 0 && error("Tried to add a saveat that is behind the current time. This is strictly forbidden") - push!(integrator.opts.saveat,t) +function DiffEqBase.add_saveat!(integrator::AbstractSundialsIntegrator, t) + integrator.tdir * (t - integrator.t) < 0 && + error("Tried to add a saveat that is behind the current time. This is strictly forbidden") + push!(integrator.opts.saveat, t) end DiffEqBase.get_tmp_cache(integrator::AbstractSundialsIntegrator) = (integrator.tmp,) -@inline function DiffEqBase.u_modified!(integrator::AbstractSundialsIntegrator,bool::Bool) - integrator.u_modified = bool +@inline function DiffEqBase.u_modified!(integrator::AbstractSundialsIntegrator, bool::Bool) + integrator.u_modified = bool end -function DiffEqBase.terminate!(integrator::AbstractSundialsIntegrator, - retcode = :Terminated) - integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, retcode) - integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)() +function DiffEqBase.terminate!( + integrator::AbstractSundialsIntegrator, + retcode = :Terminated, +) + integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, retcode) + integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)() end @inline function DiffEqBase.get_du(integrator::CVODEIntegrator) - integrator(integrator.t,Val{1}) + integrator(integrator.t, Val{1}) end -@inline function DiffEqBase.get_du!(out,integrator::CVODEIntegrator) - integrator(out,integrator.t,Val{1}) +@inline function DiffEqBase.get_du!(out, integrator::CVODEIntegrator) + integrator(out, integrator.t, Val{1}) end @inline function DiffEqBase.get_du(integrator::IDAIntegrator) - reshape(integrator.du,integrator.sizedu) + reshape(integrator.du, integrator.sizedu) end -@inline function DiffEqBase.get_du!(out,integrator::IDAIntegrator) - out .= reshape(integrator.du,integrator.sizedu) +@inline function DiffEqBase.get_du!(out, integrator::IDAIntegrator) + out .= reshape(integrator.du, integrator.sizedu) end -function DiffEqBase.change_t_via_interpolation!(integrator::AbstractSundialsIntegrator,t) +function DiffEqBase.change_t_via_interpolation!(integrator::AbstractSundialsIntegrator, t) integrator.t = t - integrator(integrator.u,integrator.t) + integrator(integrator.u, integrator.t) return nothing end @inline function Base.getproperty(integrator::AbstractSundialsIntegrator, sym::Symbol) - if sym == :dt - return integrator.t-integrator.tprev - else - return getfield(integrator, sym) - end + if sym == :dt + return integrator.t - integrator.tprev + else + return getfield(integrator, sym) + end end -DiffEqBase.reeval_internals_due_to_modification!(integrator::AbstractSundialsIntegrator) = nothing -DiffEqBase.reeval_internals_due_to_modification!(integrator::IDAIntegrator) = handle_callback_modifiers!(integrator::IDAIntegrator) +DiffEqBase.reeval_internals_due_to_modification!(integrator::AbstractSundialsIntegrator) = + nothing +DiffEqBase.reeval_internals_due_to_modification!(integrator::IDAIntegrator) = + handle_callback_modifiers!(integrator::IDAIntegrator) # Required for callbacks -DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator,dt) = nothing +DiffEqBase.set_proposed_dt!(i::AbstractSundialsIntegrator, dt) = nothing DiffEqBase.initialize_dae!(integrator::AbstractSundialsIntegrator) = nothing function DiffEqBase.initialize_dae!(integrator::IDAIntegrator) - integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t) - if any(abs.(integrator.tmp) .>= integrator.opts.reltol) - if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all - error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") - end - if integrator.alg.init_all - init_type = IDA_Y_INIT - else - init_type = IDA_YA_YDP_INIT - integrator.flag = IDASetId(integrator.mem, integrator.sol.prob.differential_vars) - end - integrator.flag = IDACalcIC(integrator.mem, init_type, integrator.dt) - end + integrator.f(integrator.tmp, integrator.du, integrator.u, integrator.p, integrator.t) + if any(abs.(integrator.tmp) .>= integrator.opts.reltol) + if integrator.sol.prob.differential_vars === nothing && !integrator.alg.init_all + error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") + end + if integrator.alg.init_all + init_type = IDA_Y_INIT + else + init_type = IDA_YA_YDP_INIT + integrator.flag = + IDASetId(integrator.mem, integrator.sol.prob.differential_vars) + end + integrator.flag = IDACalcIC(integrator.mem, init_type, integrator.dt) + end end diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 69ed5c1..18ce55d 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -1,43 +1,57 @@ ## Common Interface Solve Functions function DiffEqBase.__solve( - prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.AbstractDAEProblem}, - alg::algType,timeseries=[],ts=[],ks=[], - recompile::Type{Val{recompile_flag}}=Val{true}; - kwargs...) where {algType<:Union{SundialsODEAlgorithm,SundialsDAEAlgorithm}, - recompile_flag} - - integrator = DiffEqBase.__init(prob,alg,timeseries,ts,ks;kwargs...) - if integrator.sol.retcode == :Default - solve!(integrator) - end - integrator.sol + prob::Union{DiffEqBase.AbstractODEProblem, DiffEqBase.AbstractDAEProblem}, + alg::algType, + timeseries = [], + ts = [], + ks = [], + recompile::Type{Val{recompile_flag}} = Val{true}; + kwargs..., +) where {algType <: Union{SundialsODEAlgorithm, SundialsDAEAlgorithm}, recompile_flag} + + integrator = DiffEqBase.__init(prob, alg, timeseries, ts, ks; kwargs...) + if integrator.sol.retcode == :Default + solve!(integrator) + end + integrator.sol end function DiffEqBase.__init( prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace}, - alg::SundialsODEAlgorithm{Method,LinearSolver}, - timeseries=[], ts=[], ks=[]; - - verbose=true, - callback=nothing, abstol=1/10^6, reltol=1/10^3, - saveat=Float64[], tstops=Float64[], - maxiters=Int(1e5), - dt = nothing, dtmin = 0.0, dtmax = 0.0, - timeseries_errors=true, + alg::SundialsODEAlgorithm{Method, LinearSolver}, + timeseries = [], + ts = [], + ks = []; + verbose = true, + callback = nothing, + abstol = 1 / 10^6, + reltol = 1 / 10^3, + saveat = Float64[], + tstops = Float64[], + maxiters = Int(1e5), + dt = nothing, + dtmin = 0.0, + dtmax = 0.0, + timeseries_errors = true, dense_errors = false, - save_everystep=isempty(saveat), + save_everystep = isempty(saveat), save_on = true, - save_start = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[1] in saveat, - save_end = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : prob.tspan[2] in saveat, + save_start = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : + prob.tspan[1] in saveat, + save_end = save_everystep || isempty(saveat) || typeof(saveat) <: Number ? true : + prob.tspan[2] in saveat, dense = save_everystep && isempty(saveat), - progress=false,progress_name="ODE", + progress = false, + progress_name = "ODE", progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, save_timeseries = nothing, - advance_to_tstop = false,stop_at_next_tstop=false, - userdata=nothing, - alias_u0=false, - kwargs...) where {uType, tupType, isinplace, Method, LinearSolver} + advance_to_tstop = false, + stop_at_next_tstop = false, + userdata = nothing, + alias_u0 = false, + kwargs..., +) where {uType, tupType, isinplace, Method, LinearSolver} tType = eltype(tupType) @@ -54,24 +68,24 @@ function DiffEqBase.__init( error("Sundials only allows scalar reltol.") end - progress && Logging.@logmsg(-1,progress_name,_id=_id = :Sundials,progress=0) + progress && Logging.@logmsg(-1, progress_name, _id = _id = :Sundials, progress = 0) callbacks_internal = DiffEqBase.CallbackSet(callback) max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) if max_len_cb isa VectorContinuousCallback - callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,Float64,Float64) + callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64) else - callback_cache = nothing + callback_cache = nothing end tspan = prob.tspan t0 = tspan[1] - tdir = sign(tspan[2]-tspan[1]) + tdir = sign(tspan[2] - tspan[1]) tstops_internal, saveat_internal = - tstop_saveat_disc_handling(tstops,saveat,tdir,tspan,tType) + tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) if typeof(prob.u0) <: Number u0 = [prob.u0] @@ -86,22 +100,23 @@ function DiffEqBase.__init( sizeu = size(prob.u0) ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0)<:Number + if !isinplace && typeof(prob.u0) <: Number f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:Vector{Float64} + elseif !isinplace && typeof(prob.u0) <: Vector{Float64} f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:AbstractArray + elseif !isinplace && typeof(prob.u0) <: AbstractArray f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0)<:Vector{Float64} + elseif typeof(prob.u0) <: Vector{Float64} f! = prob.f else # Then it's an in-place function on an abstract array f! = (du, u, p, t) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t); - du=vec(du); 0) + du = vec(du); + 0) end if typeof(alg) <: CVODE_BDF alg_code = CV_BDF - elseif typeof(alg) <: CVODE_Adams + elseif typeof(alg) <: CVODE_Adams alg_code = CV_ADAMS end @@ -115,11 +130,13 @@ function DiffEqBase.__init( (mem_ptr == C_NULL) && error("Failed to allocate CVODE solver object") mem = Handle(mem_ptr) - !verbose && CVodeSetErrHandlerFn(mem,@cfunction(null_error_handler, Nothing, - (Cint, Char, - Char, Ptr{Cvoid})),C_NULL) + !verbose && CVodeSetErrHandlerFn( + mem, + @cfunction(null_error_handler, Nothing, (Cint, Char, Char, Ptr{Cvoid})), + C_NULL, + ) - ures = Vector{uType}() + ures = Vector{uType}() dures = Vector{uType}() save_start ? ts = [t0] : ts = Float64[] @@ -127,14 +144,23 @@ function DiffEqBase.__init( _u0 = copy(u0) utmp = NVector(_u0) - userfun = FunJac(f!,prob.f.jac,prob.p,nothing,prob.f.jac_prototype,alg.prec,alg.psetup,u0,_u0) - - function getcfunf(::T) where T + userfun = FunJac( + f!, + prob.f.jac, + prob.p, + nothing, + prob.f.jac_prototype, + alg.prec, + alg.psetup, + u0, + _u0, + ) + + function getcfunf(::T) where {T} @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) end - flag = CVodeInit(mem,getcfunf(userfun), - t0, convert(N_Vector, utmp)) + flag = CVodeInit(mem, getcfunf(userfun), t0, convert(N_Vector, utmp)) dt != nothing && (flag = CVodeSetInitStep(mem, dt)) flag = CVodeSetMinStep(mem, dtmin) @@ -158,25 +184,25 @@ function DiffEqBase.__init( if Method == :Newton # Only use a linear solver if it's a Newton-based method if LinearSolver in (:Dense, :LapackDense) nojacobian = false - A = SUNDenseMatrix(length(u0),length(u0)) - _A = MatrixHandle(A,DenseMatrix()) + A = SUNDenseMatrix(length(u0), length(u0)) + _A = MatrixHandle(A, DenseMatrix()) if LinearSolver === :Dense - LS = SUNLinSol_Dense(u0,A) - _LS = LinSolHandle(LS,Dense()) + LS = SUNLinSol_Dense(u0, A) + _LS = LinSolHandle(LS, Dense()) else - LS = SUNLinSol_LapackDense(u0,A) - _LS = LinSolHandle(LS,LapackDense()) + LS = SUNLinSol_LapackDense(u0, A) + _LS = LinSolHandle(LS, LapackDense()) end elseif LinearSolver in (:Band, :LapackBand) nojacobian = false A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) - _A = MatrixHandle(A,BandMatrix()) + _A = MatrixHandle(A, BandMatrix()) if LinearSolver === :Band - LS = SUNLinSol_Band(u0,A) - _LS = LinSolHandle(LS,Band()) + LS = SUNLinSol_Band(u0, A) + _LS = LinSolHandle(LS, Band()) else - LS = SUNLinSol_LapackBand(u0,A) - _LS = LinSolHandle(LS,LapackBand()) + LS = SUNLinSol_LapackBand(u0, A) + _LS = LinSolHandle(LS, LapackBand()) end elseif LinearSolver == :Diagonal nojacobian = false @@ -186,30 +212,30 @@ function DiffEqBase.__init( elseif LinearSolver == :GMRES LS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = Sundials.LinSolHandle(LS,Sundials.SPGMR()) + _LS = Sundials.LinSolHandle(LS, Sundials.SPGMR()) elseif LinearSolver == :FGMRES LS = SUNLinSol_SPFGMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPFGMR()) + _LS = LinSolHandle(LS, SPFGMR()) elseif LinearSolver == :BCG LS = SUNLinSol_SPBCGS(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPBCGS()) + _LS = LinSolHandle(LS, SPBCGS()) elseif LinearSolver == :PCG LS = SUNLinSol_PCG(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,PCG()) + _LS = LinSolHandle(LS, PCG()) elseif LinearSolver == :TFQMR LS = SUNLinSol_SPTFQMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,PTFQMR()) + _LS = LinSolHandle(LS, PTFQMR()) elseif LinearSolver == :KLU nojacobian = false nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) - A = SUNSparseMatrix(length(u0),length(u0), nnz, CSC_MAT) + A = SUNSparseMatrix(length(u0), length(u0), nnz, CSC_MAT) LS = SUNLinSol_KLU(u0, A) - _A = MatrixHandle(A,SparseMatrix()) - _LS = LinSolHandle(LS,KLU()) + _A = MatrixHandle(A, SparseMatrix()) + _LS = LinSolHandle(LS, KLU()) end if LinearSolver !== :Diagonal flag = CVodeSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) @@ -225,62 +251,67 @@ function DiffEqBase.__init( CVodeSetNonlinearSolver(mem, NLS) if DiffEqBase.has_jac(prob.f) && Method == :Newton - function getcfunjac(::T) where T - @cfunction(cvodejac, - Cint, - (realtype, - N_Vector, - N_Vector, - SUNMatrix, - Ref{T}, - N_Vector, - N_Vector, - N_Vector)) - end - jac = getcfunjac(userfun) - flag = CVodeSetUserData(mem, userfun) - nojacobian || (flag = CVodeSetJacFn(mem, jac)) + function getcfunjac(::T) where {T} + @cfunction( + cvodejac, + Cint, + ( + realtype, + N_Vector, + N_Vector, + SUNMatrix, + Ref{T}, + N_Vector, + N_Vector, + N_Vector, + ) + ) + end + jac = getcfunjac(userfun) + flag = CVodeSetUserData(mem, userfun) + nojacobian || (flag = CVodeSetJacFn(mem, jac)) else jac = nothing end if typeof(prob.f.jac_prototype) <: DiffEqBase.AbstractDiffEqLinearOperator - function getcfunjtimes(::T) where T - @cfunction(jactimes, - Cint, - (N_Vector, - N_Vector, - realtype, - N_Vector, - N_Vector, - Ref{T}, - N_Vector)) + function getcfunjtimes(::T) where {T} + @cfunction( + jactimes, + Cint, + (N_Vector, N_Vector, realtype, N_Vector, N_Vector, Ref{T}, N_Vector) + ) end jtimes = getcfunjtimes(userfun) CVodeSetJacTimes(mem, C_NULL, jtimes) end if alg.prec !== nothing - function getpercfun(::T) where T - @cfunction(precsolve, - Cint, - (Float64, - N_Vector, - N_Vector, - N_Vector, - N_Vector,Float64,Float64,Int, - Ref{T})) + function getpercfun(::T) where {T} + @cfunction( + precsolve, + Cint, + ( + Float64, + N_Vector, + N_Vector, + N_Vector, + N_Vector, + Float64, + Float64, + Int, + Ref{T}, + ) + ) end precfun = getpercfun(userfun) - function getpsetupfun(::T) where T - @cfunction(precsetup, - Cint, - (Float64, - N_Vector, - N_Vector, - Int, - Ptr{Int},Float64,Ref{T})) + function getpsetupfun(::T) where {T} + @cfunction( + precsetup, + Cint, + (Float64, N_Vector, N_Vector, Int, Ptr{Int}, Float64, Ref{T}) + ) end psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun) @@ -292,26 +323,72 @@ function DiffEqBase.__init( tout = [tspan[1]] if save_start - save_value!(ures,u0,uType,sizeu) - if dense - f!(_u0,u0,prob.p,tspan[1]) - save_value!(dures,utmp,uType,sizeu) - end + save_value!(ures, u0, uType, sizeu) + if dense + f!(_u0, u0, prob.p, tspan[1]) + save_value!(dures, utmp, uType, sizeu) + end end - sol = DiffEqBase.build_solution(prob, alg, ts, ures, - dense = dense, - interp = dense ? DiffEqBase.HermiteInterpolation(ts,ures,dures) : - DiffEqBase.LinearInterpolation(ts,ures), - timeseries_errors = timeseries_errors, - destats = DiffEqBase.DEStats(0), - calculate_error = false) - opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense, - timeseries_errors,dense_errors,save_on,save_end, - callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop, - progress,progress_name,progress_message,maxiters) - integrator = CVODEIntegrator(u0,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts, - tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0,1,callback_cache,0.) + sol = DiffEqBase.build_solution( + prob, + alg, + ts, + ures, + dense = dense, + interp = dense ? DiffEqBase.HermiteInterpolation(ts, ures, dures) : + DiffEqBase.LinearInterpolation(ts, ures), + timeseries_errors = timeseries_errors, + destats = DiffEqBase.DEStats(0), + calculate_error = false, + ) + opts = DEOptions( + saveat_internal, + tstops_internal, + save_everystep, + dense, + timeseries_errors, + dense_errors, + save_on, + save_end, + callbacks_internal, + abstol, + reltol, + verbose, + advance_to_tstop, + stop_at_next_tstop, + progress, + progress_name, + progress_message, + maxiters, + ) + integrator = CVODEIntegrator( + u0, + prob.p, + t0, + t0, + mem, + _LS, + _A, + sol, + alg, + f!, + userfun, + jac, + opts, + tout, + tdir, + sizeu, + false, + tmp, + uprev, + Cint(flag), + false, + 0, + 1, + callback_cache, + 0.0, + ) initialize_callbacks!(integrator) integrator @@ -319,25 +396,37 @@ end # function solve function DiffEqBase.__init( prob::DiffEqBase.AbstractODEProblem{uType, tupType, isinplace}, - alg::ARKODE{Method,LinearSolver,MassLinearSolver}, - timeseries=[], ts=[], ks=[]; - - verbose=true, - callback=nothing, abstol=1/10^6, reltol=1/10^3, - saveat=Float64[], tstops=Float64[], - maxiters=Int(1e5), - dt = nothing, dtmin = 0.0, dtmax = 0.0, - timeseries_errors=true, + alg::ARKODE{Method, LinearSolver, MassLinearSolver}, + timeseries = [], + ts = [], + ks = []; + verbose = true, + callback = nothing, + abstol = 1 / 10^6, + reltol = 1 / 10^3, + saveat = Float64[], + tstops = Float64[], + maxiters = Int(1e5), + dt = nothing, + dtmin = 0.0, + dtmax = 0.0, + timeseries_errors = true, dense_errors = false, - save_everystep=isempty(saveat), dense = save_everystep, - save_on = true, save_start = true, save_end = true, + save_everystep = isempty(saveat), + dense = save_everystep, + save_on = true, + save_start = true, + save_end = true, save_timeseries = nothing, - progress=false,progress_name="ODE", + progress = false, + progress_name = "ODE", progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, - advance_to_tstop = false,stop_at_next_tstop=false, - userdata=nothing, - alias_u0=false, - kwargs...) where {uType, tupType, isinplace, Method, LinearSolver, MassLinearSolver} + advance_to_tstop = false, + stop_at_next_tstop = false, + userdata = nothing, + alias_u0 = false, + kwargs..., +) where {uType, tupType, isinplace, Method, LinearSolver, MassLinearSolver} tType = eltype(tupType) @@ -350,24 +439,24 @@ function DiffEqBase.__init( error("Sundials only allows scalar reltol.") end - progress && Logging.@logmsg(-1,progress_name,_id=_id = :Sundials,progress=0) + progress && Logging.@logmsg(-1, progress_name, _id = _id = :Sundials, progress = 0) callbacks_internal = DiffEqBase.CallbackSet(callback) max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) if max_len_cb isa VectorContinuousCallback - callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,Float64,Float64) + callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64) else - callback_cache = nothing + callback_cache = nothing end tspan = prob.tspan t0 = tspan[1] - tdir = sign(tspan[2]-tspan[1]) + tdir = sign(tspan[2] - tspan[1]) tstops_internal, saveat_internal = - tstop_saveat_disc_handling(tstops,saveat,tdir,tspan,tType) + tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) if typeof(prob.u0) <: Number u0 = [prob.u0] @@ -381,97 +470,115 @@ function DiffEqBase.__init( sizeu = size(prob.u0) - - - ures = Vector{uType}() + ures = Vector{uType}() dures = Vector{uType}() save_start ? ts = [t0] : ts = Float64[] u0nv = NVector(u0) _u0 = copy(u0) utmp = NVector(_u0) - function arkodemem(;fe=C_NULL, fi=C_NULL, t0=t0, u0=convert(N_Vector, u0nv)) + function arkodemem(; fe = C_NULL, fi = C_NULL, t0 = t0, u0 = convert(N_Vector, u0nv)) mem_ptr = ARKStepCreate(fe, fi, t0, u0) (mem_ptr == C_NULL) && error("Failed to allocate ARKODE solver object") mem = Handle(mem_ptr) - !verbose && ARKStepSetErrHandlerFn(mem,@cfunction(null_error_handler, Nothing, - (Cint, Char, - Char, Ptr{Cvoid})),C_NULL) + !verbose && ARKStepSetErrHandlerFn( + mem, + @cfunction(null_error_handler, Nothing, (Cint, Char, Char, Ptr{Cvoid})), + C_NULL, + ) return mem end ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0)<:Number + if !isinplace && typeof(prob.u0) <: Number f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:Vector{Float64} + elseif !isinplace && typeof(prob.u0) <: Vector{Float64} f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:AbstractArray + elseif !isinplace && typeof(prob.u0) <: AbstractArray f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0)<:Vector{Float64} + elseif typeof(prob.u0) <: Vector{Float64} f! = prob.f else # Then it's an in-place function on an abstract array f! = (du, u, p, t) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t); - du=vec(du); Cint(0)) + du = vec(du); + Cint(0)) end if typeof(prob.problem_type) <: SplitODEProblem ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0)<:Number + if !isinplace && typeof(prob.u0) <: Number f1! = (du, u, p, t) -> (du .= prob.f.f1(first(u), p, t); Cint(0)) f2! = (du, u, p, t) -> (du .= prob.f.f2(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:Vector{Float64} + elseif !isinplace && typeof(prob.u0) <: Vector{Float64} f1! = (du, u, p, t) -> (du .= prob.f.f1(u, p, t); Cint(0)) f2! = (du, u, p, t) -> (du .= prob.f.f2(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:AbstractArray + elseif !isinplace && typeof(prob.u0) <: AbstractArray f1! = (du, u, p, t) -> (du .= vec(prob.f.f1(reshape(u, sizeu), p, t)); Cint(0)) f2! = (du, u, p, t) -> (du .= vec(prob.f.f2(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0)<:Vector{Float64} + elseif typeof(prob.u0) <: Vector{Float64} f1! = prob.f.f1 f2! = prob.f.f2 else # Then it's an in-place function on an abstract array - f1! = (du, u, p, t) -> (prob.f.f1(reshape(du, sizeu), reshape(u, sizeu), p, t); - du=vec(du); Cint(0)) - f2! = (du, u, p, t) -> (prob.f.f2(reshape(du, sizeu), reshape(u, sizeu), p, t); - du=vec(du); Cint(0)) + f1! = + (du, u, p, t) -> (prob.f.f1(reshape(du, sizeu), reshape(u, sizeu), p, t); + du = vec(du); + Cint(0)) + f2! = + (du, u, p, t) -> (prob.f.f2(reshape(du, sizeu), reshape(u, sizeu), p, t); + du = vec(du); + Cint(0)) end - userfun = FunJac(f1!,f2!,prob.f.f1.jac,prob.p,prob.f.mass_matrix, - prob.f.f1.jac_prototype,alg.prec,alg.psetup,u0,_u0,nothing) - - function getcfunjac(::T) where T - @cfunction(cvodefunjac, Cint, - (realtype, N_Vector, - N_Vector, Ref{T})) + userfun = FunJac( + f1!, + f2!, + prob.f.f1.jac, + prob.p, + prob.f.mass_matrix, + prob.f.f1.jac_prototype, + alg.prec, + alg.psetup, + u0, + _u0, + nothing, + ) + + function getcfunjac(::T) where {T} + @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) end - function getcfunjac2(::T) where T - @cfunction(cvodefunjac2, Cint, - (realtype, N_Vector, - N_Vector, Ref{T})) + function getcfunjac2(::T) where {T} + @cfunction(cvodefunjac2, Cint, (realtype, N_Vector, N_Vector, Ref{T})) end cfj1 = getcfunjac(userfun) cfj2 = getcfunjac2(userfun) - mem = arkodemem(fi=cfj1, fe=cfj2) + mem = arkodemem(fi = cfj1, fe = cfj2) else - userfun = FunJac(f!,prob.f.jac,prob.p,prob.f.mass_matrix,prob.f.jac_prototype,alg.prec,alg.psetup,u0,_u0) + userfun = FunJac( + f!, + prob.f.jac, + prob.p, + prob.f.mass_matrix, + prob.f.jac_prototype, + alg.prec, + alg.psetup, + u0, + _u0, + ) if alg.stiffness == Explicit() - function getcfun1(::T) where T - @cfunction(cvodefunjac, Cint, - (realtype, N_Vector, - N_Vector, Ref{T})) + function getcfun1(::T) where {T} + @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) end cfj1 = getcfun1(userfun) - mem = arkodemem(fe=cfj1) + mem = arkodemem(fe = cfj1) elseif alg.stiffness == Implicit() - function getcfun2(::T) where T - @cfunction(cvodefunjac, Cint, - (realtype, N_Vector, - N_Vector, Ref{T})) + function getcfun2(::T) where {T} + @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) end cfj2 = getcfun2(userfun) - mem = arkodemem(fi=cfj2) + mem = arkodemem(fi = cfj2) end end @@ -491,7 +598,7 @@ function DiffEqBase.__init( flag = ARKStepSetMaxConvFails(mem, alg.max_convergence_failures) flag = ARKStepSetPredictorMethod(mem, alg.predictor_method) flag = ARKStepSetNonlinConvCoef(mem, alg.nonlinear_convergence_coefficient) - flag = ARKStepSetDenseOrder(mem,alg.dense_order) + flag = ARKStepSetDenseOrder(mem, alg.dense_order) #= Reference from Manual on ARKODE @@ -499,7 +606,7 @@ function DiffEqBase.__init( To select an implicit table, set etable to a negative value. This automatically calls ARKStepSetImplicit(). If both itable and etable are non-negative, then these should match an existing implicit/explicit pair, listed in the section Additive Butcher tables. This automatically calls ARKStepSetImEx(). =# if alg.itable == nothing && alg.etable == nothing - flag = ARKStepSetOrder(mem,alg.order) + flag = ARKStepSetOrder(mem, alg.order) elseif alg.itable == nothing && alg.etable != nothing flag = ARKStepSetTableNum(mem, -1, alg.etable) elseif alg.itable != nothing && alg.etable == nothing @@ -508,65 +615,64 @@ function DiffEqBase.__init( flag = ARKStepSetTableNum(mem, alg.itable, alg.etable) end - flag = ARKStepSetNonlinCRDown(mem,alg.crdown) + flag = ARKStepSetNonlinCRDown(mem, alg.crdown) flag = ARKStepSetNonlinRDiv(mem, alg.rdiv) flag = ARKStepSetDeltaGammaMax(mem, alg.dgmax) flag = ARKStepSetMaxStepsBetweenLSet(mem, alg.msbp) #flag = ARKStepSetAdaptivityMethod(mem,alg.adaptivity_method,1,0) - #flag = ARKStepSetFixedStep(mem,) alg.set_optimal_params && (flag = ARKStepSetOptimalParams(mem)) if Method == :Newton # Only use a linear solver if it's a Newton-based method if LinearSolver in (:Dense, :LapackDense) nojacobian = false - A = SUNDenseMatrix(length(u0),length(u0)) - _A = MatrixHandle(A,DenseMatrix()) + A = SUNDenseMatrix(length(u0), length(u0)) + _A = MatrixHandle(A, DenseMatrix()) if LinearSolver === :Dense - LS = SUNLinSol_Dense(u0,A) - _LS = LinSolHandle(LS,Dense()) + LS = SUNLinSol_Dense(u0, A) + _LS = LinSolHandle(LS, Dense()) else - LS = SUNLinSol_LapackDense(u0,A) - _LS = LinSolHandle(LS,LapackDense()) + LS = SUNLinSol_LapackDense(u0, A) + _LS = LinSolHandle(LS, LapackDense()) end elseif LinearSolver in (:Band, :LapackBand) nojacobian = false A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) - _A = MatrixHandle(A,BandMatrix()) + _A = MatrixHandle(A, BandMatrix()) if LinearSolver === :Band - LS = SUNLinSol_Band(u0,A) - _LS = LinSolHandle(LS,Band()) + LS = SUNLinSol_Band(u0, A) + _LS = LinSolHandle(LS, Band()) else - LS = SUNLinSol_LapackBand(u0,A) - _LS = LinSolHandle(LS,LapackBand()) + LS = SUNLinSol_LapackBand(u0, A) + _LS = LinSolHandle(LS, LapackBand()) end elseif LinearSolver == :GMRES LS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = Sundials.LinSolHandle(LS,Sundials.SPGMR()) + _LS = Sundials.LinSolHandle(LS, Sundials.SPGMR()) elseif LinearSolver == :FGMRES LS = SUNLinSol_SPFGMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPFGMR()) + _LS = LinSolHandle(LS, SPFGMR()) elseif LinearSolver == :BCG LS = SUNLinSol_SPBCGS(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPBCGS()) + _LS = LinSolHandle(LS, SPBCGS()) elseif LinearSolver == :PCG LS = SUNLinSol_PCG(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,PCG()) + _LS = LinSolHandle(LS, PCG()) elseif LinearSolver == :TFQMR LS = SUNLinSol_SPTFQMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,PTFQMR()) + _LS = LinSolHandle(LS, PTFQMR()) elseif LinearSolver == :KLU nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) - A = SUNSparseMatrix(length(u0),length(u0), nnz, CSC_MAT) + A = SUNSparseMatrix(length(u0), length(u0), nnz, CSC_MAT) LS = SUNLinSol_KLU(u0, A) - _A = MatrixHandle(A,SparseMatrix()) - _LS = LinSolHandle(LS,KLU()) + _A = MatrixHandle(A, SparseMatrix()) + _LS = LinSolHandle(LS, KLU()) end flag = ARKStepSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) elseif Method == :Functional @@ -576,20 +682,19 @@ function DiffEqBase.__init( _LS = nothing end - if (typeof(prob.problem_type) <: SplitODEProblem && - typeof(prob.f.f1.jac_prototype) <: DiffEqBase.AbstractDiffEqLinearOperator) || - (!(typeof(prob.problem_type) <: SplitODEProblem) && - typeof(prob.f.jac_prototype) <: DiffEqBase.AbstractDiffEqLinearOperator) - function getcfunjtimes(::T) where T - @cfunction(jactimes, - Cint, - (N_Vector, - N_Vector, - realtype, - N_Vector, - N_Vector, - Ref{T}, - N_Vector)) + if ( + typeof(prob.problem_type) <: SplitODEProblem && + typeof(prob.f.f1.jac_prototype) <: DiffEqBase.AbstractDiffEqLinearOperator + ) || ( + !(typeof(prob.problem_type) <: SplitODEProblem) && + typeof(prob.f.jac_prototype) <: DiffEqBase.AbstractDiffEqLinearOperator + ) + function getcfunjtimes(::T) where {T} + @cfunction( + jactimes, + Cint, + (N_Vector, N_Vector, realtype, N_Vector, N_Vector, Ref{T}, N_Vector) + ) end jtimes = getcfunjtimes(userfun) ARKStepSetJacTimes(mem, C_NULL, jtimes) @@ -598,112 +703,118 @@ function DiffEqBase.__init( if prob.f.mass_matrix != LinearAlgebra.I if MassLinearSolver in (:Dense, :LapackDense) nojacobian = false - M = SUNDenseMatrix(length(u0),length(u0)) - _M = MatrixHandle(M,DenseMatrix()) + M = SUNDenseMatrix(length(u0), length(u0)) + _M = MatrixHandle(M, DenseMatrix()) if MassLinearSolver === :Dense - MLS = SUNLinSol_Dense(u0,M) - _MLS = LinSolHandle(MLS,Dense()) + MLS = SUNLinSol_Dense(u0, M) + _MLS = LinSolHandle(MLS, Dense()) else - MLS = SUNLinSol_LapackDense(u0,M) - _MLS = LinSolHandle(MLS,LapackDense()) + MLS = SUNLinSol_LapackDense(u0, M) + _MLS = LinSolHandle(MLS, LapackDense()) end elseif MassLinearSolver in (:Band, :LapackBand) nojacobian = false M = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) - _M = MatrixHandle(M,BandMatrix()) + _M = MatrixHandle(M, BandMatrix()) if MassLinearSolver === :Band - MLS = SUNLinSol_Band(u0,M) - _MLS = LinSolHandle(MLS,Band()) + MLS = SUNLinSol_Band(u0, M) + _MLS = LinSolHandle(MLS, Band()) else - MLS = SUNLinSol_LapackBand(u0,M) - _MLS = LinSolHandle(MLS,LapackBand()) + MLS = SUNLinSol_LapackBand(u0, M) + _MLS = LinSolHandle(MLS, LapackBand()) end elseif MassLinearSolver == :GMRES MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) _M = nothing - _MLS = LinSolHandle(MLS,SPGMR()) + _MLS = LinSolHandle(MLS, SPGMR()) elseif MassLinearSolver == :FGMRES MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) _M = nothing - _MLS = LinSolHandle(MLS,SPFGMR()) + _MLS = LinSolHandle(MLS, SPFGMR()) elseif MassLinearSolver == :BCG MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) _M = nothing - _MLS = LinSolHandle(MLS,SPBCGS()) + _MLS = LinSolHandle(MLS, SPBCGS()) elseif MassLinearSolver == :PCG MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) _M = nothing - _MLS = LinSolHandle(MLS,PCG()) + _MLS = LinSolHandle(MLS, PCG()) elseif MassLinearSolver == :TFQMR MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) _M = nothing - _MLS = LinSolHandle(MLS,PTFQMR()) + _MLS = LinSolHandle(MLS, PTFQMR()) elseif MassLinearSolver == :KLU nnz = length(SparseArrays.nonzeros(prob.f.mass_matrix)) - M = SUNSparseMatrix(length(u0),length(u0), nnz, CSC_MAT) + M = SUNSparseMatrix(length(u0), length(u0), nnz, CSC_MAT) MLS = SUNLinSol_KLU(u0, M) - _M = MatrixHandle(M,SparseMatrix()) - _MLS = LinSolHandle(MLS,KLU()) + _M = MatrixHandle(M, SparseMatrix()) + _MLS = LinSolHandle(MLS, KLU()) end flag = ARKStepSetMassLinearSolver(mem, MLS, _M === nothing ? C_NULL : M, false) - function getmatfun(::T) where T - @cfunction(massmat, - Cint, - (realtype, - SUNMatrix, - Ref{T}, - N_Vector, - N_Vector, - N_Vector)) + function getmatfun(::T) where {T} + @cfunction( + massmat, + Cint, + (realtype, SUNMatrix, Ref{T}, N_Vector, N_Vector, N_Vector) + ) end matfun = getmatfun(userfun) - ARKStepSetMassFn(mem,matfun) + ARKStepSetMassFn(mem, matfun) else _M = nothing _MLS = nothing end if DiffEqBase.has_jac(prob.f) - function getfunjac(::T) where T - @cfunction(cvodejac, - Cint, - (realtype, - N_Vector, - N_Vector, - SUNMatrix, - Ref{T}, - N_Vector, - N_Vector, - N_Vector)) - end - jac = getfunjac(userfun) - flag = ARKStepSetUserData(mem, userfun) - flag = ARKStepSetJacFn(mem, jac) + function getfunjac(::T) where {T} + @cfunction( + cvodejac, + Cint, + ( + realtype, + N_Vector, + N_Vector, + SUNMatrix, + Ref{T}, + N_Vector, + N_Vector, + N_Vector, + ) + ) + end + jac = getfunjac(userfun) + flag = ARKStepSetUserData(mem, userfun) + flag = ARKStepSetJacFn(mem, jac) else jac = nothing end if alg.prec !== nothing - function getpercfun(::T) where T - @cfunction(precsolve, - Cint, - (Float64, - N_Vector, - N_Vector, - N_Vector, - N_Vector,Float64,Float64,Int, - Ref{T})) + function getpercfun(::T) where {T} + @cfunction( + precsolve, + Cint, + ( + Float64, + N_Vector, + N_Vector, + N_Vector, + N_Vector, + Float64, + Float64, + Int, + Ref{T}, + ) + ) end precfun = getpercfun(userfun) - function getpsetupfun(::T) where T - @cfunction(precsetup, - Cint, - (Float64, - N_Vector, - N_Vector, - Int, - Ptr{Int},Float64,Ref{T})) + function getpsetupfun(::T) where {T} + @cfunction( + precsetup, + Cint, + (Float64, N_Vector, N_Vector, Int, Ptr{Int}, Float64, Ref{T}) + ) end psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun) @@ -715,64 +826,127 @@ function DiffEqBase.__init( tout = [tspan[1]] if save_start - save_value!(ures,u0,uType,sizeu) - if dense - f!(_u0,u0,prob.p,tspan[1]) - save_value!(dures,utmp,uType,sizeu) - end + save_value!(ures, u0, uType, sizeu) + if dense + f!(_u0, u0, prob.p, tspan[1]) + save_value!(dures, utmp, uType, sizeu) + end end - sol = DiffEqBase.build_solution(prob, alg, ts, ures, - dense = dense, - interp = dense ? DiffEqBase.HermiteInterpolation(ts,ures,dures) : - DiffEqBase.LinearInterpolation(ts,ures), - timeseries_errors = timeseries_errors, - destats = DiffEqBase.DEStats(0), - calculate_error = false) - opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense, - timeseries_errors,dense_errors,save_on,save_end, - callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop, - progress,progress_name,progress_message,maxiters) - integrator = ARKODEIntegrator(utmp,prob.p,t0,t0,mem,_LS,_A,_MLS,_M,sol,alg,f!,userfun,jac,opts, - tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0,1,callback_cache,0.) + sol = DiffEqBase.build_solution( + prob, + alg, + ts, + ures, + dense = dense, + interp = dense ? DiffEqBase.HermiteInterpolation(ts, ures, dures) : + DiffEqBase.LinearInterpolation(ts, ures), + timeseries_errors = timeseries_errors, + destats = DiffEqBase.DEStats(0), + calculate_error = false, + ) + opts = DEOptions( + saveat_internal, + tstops_internal, + save_everystep, + dense, + timeseries_errors, + dense_errors, + save_on, + save_end, + callbacks_internal, + abstol, + reltol, + verbose, + advance_to_tstop, + stop_at_next_tstop, + progress, + progress_name, + progress_message, + maxiters, + ) + integrator = ARKODEIntegrator( + utmp, + prob.p, + t0, + t0, + mem, + _LS, + _A, + _MLS, + _M, + sol, + alg, + f!, + userfun, + jac, + opts, + tout, + tdir, + sizeu, + false, + tmp, + uprev, + Cint(flag), + false, + 0, + 1, + callback_cache, + 0.0, + ) initialize_callbacks!(integrator) integrator end # function solve -function tstop_saveat_disc_handling(tstops,saveat,tdir,tspan,tType) +function tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) - if isempty(tstops) # TODO: Specialize more - tstops_vec = [tspan[2]] - else - tstops_vec = vec(collect(tType,Iterators.filter(x->tdir*tspan[1] tdir * tspan[1] < tdir * x ≤ tdir * tspan[end], + Iterators.flatten((tstops, tspan[end])), + ), + )) + end - if tdir>0 - tstops_internal = DataStructures.BinaryMinHeap(tstops_vec) - else - tstops_internal = DataStructures.BinaryMaxHeap(tstops_vec) - end + if tdir > 0 + tstops_internal = DataStructures.BinaryMinHeap(tstops_vec) + else + tstops_internal = DataStructures.BinaryMaxHeap(tstops_vec) + end - if typeof(saveat) <: Number - if (tspan[1]:saveat:tspan[end])[end] == tspan[end] - saveat_vec = convert(Vector{tType},collect(tType,tspan[1]+saveat:saveat:tspan[end])) + if typeof(saveat) <: Number + if (tspan[1]:saveat:tspan[end])[end] == tspan[end] + saveat_vec = convert( + Vector{tType}, + collect(tType, (tspan[1] + saveat):saveat:tspan[end]), + ) + else + saveat_vec = convert( + Vector{tType}, + collect(tType, (tspan[1] + saveat):saveat:(tspan[end] - saveat)), + ) + end + elseif isempty(saveat) + saveat_vec = saveat else - saveat_vec = convert(Vector{tType},collect(tType,tspan[1]+saveat:saveat:(tspan[end]-saveat))) + saveat_vec = vec(collect( + tType, + Iterators.filter(x -> tdir * tspan[1] < tdir * x < tdir * tspan[end], saveat), + )) end - elseif isempty(saveat) - saveat_vec = saveat - else - saveat_vec = vec(collect(tType,Iterators.filter(x->tdir*tspan[1]0 - saveat_internal = DataStructures.BinaryMinHeap(saveat_vec) - else - saveat_internal = DataStructures.BinaryMaxHeap(saveat_vec) - end - - tstops_internal,saveat_internal + + if tdir > 0 + saveat_internal = DataStructures.BinaryMinHeap(saveat_vec) + else + saveat_internal = DataStructures.BinaryMaxHeap(saveat_vec) + end + + tstops_internal, saveat_internal end ## Solve for DAEs uses IDA @@ -780,22 +954,34 @@ end function DiffEqBase.__init( prob::DiffEqBase.AbstractDAEProblem{uType, duType, tupType, isinplace}, alg::SundialsDAEAlgorithm{LinearSolver}, - timeseries=[], ts=[], ks=[]; - - verbose=true, - dt=nothing, dtmax=0.0, - save_on=true, save_start=true, - callback=nothing, abstol=1/10^6, reltol=1/10^3, - saveat=Float64[], tstops=Float64[], maxiters=Int(1e5), - timeseries_errors=true, + timeseries = [], + ts = [], + ks = []; + verbose = true, + dt = nothing, + dtmax = 0.0, + save_on = true, + save_start = true, + callback = nothing, + abstol = 1 / 10^6, + reltol = 1 / 10^3, + saveat = Float64[], + tstops = Float64[], + maxiters = Int(1e5), + timeseries_errors = true, dense_errors = false, - save_everystep=isempty(saveat), dense=save_everystep, - save_timeseries=nothing, save_end = true, - progress=false,progress_name="ODE", + save_everystep = isempty(saveat), + dense = save_everystep, + save_timeseries = nothing, + save_end = true, + progress = false, + progress_name = "ODE", progress_message = DiffEqBase.ODE_DEFAULT_PROG_MESSAGE, - advance_to_tstop = false, stop_at_next_tstop = false, - userdata=nothing, - kwargs...) where {uType, duType, tupType, isinplace, LinearSolver} + advance_to_tstop = false, + stop_at_next_tstop = false, + userdata = nothing, + kwargs..., +) where {uType, duType, tupType, isinplace, LinearSolver} tType = eltype(tupType) @@ -808,24 +994,24 @@ function DiffEqBase.__init( error("Sundials only allows scalar reltol.") end - progress && Logging.@logmsg(-1,progress_name,_id=_id = :Sundials,progress=0) + progress && Logging.@logmsg(-1, progress_name, _id = _id = :Sundials, progress = 0) callbacks_internal = DiffEqBase.CallbackSet(callback) max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) if max_len_cb isa VectorContinuousCallback - callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,Float64,Float64) + callback_cache = DiffEqBase.CallbackCache(max_len_cb.len, Float64, Float64) else - callback_cache = nothing + callback_cache = nothing end tspan = prob.tspan t0 = tspan[1] - tdir = sign(tspan[2]-tspan[1]) + tdir = sign(tspan[2] - tspan[1]) tstops_internal, saveat_internal = - tstop_saveat_disc_handling(tstops,saveat,tdir,tspan,tType) + tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) if typeof(prob.u0) <: Number u0 = [prob.u0] @@ -843,32 +1029,37 @@ function DiffEqBase.__init( sizedu = size(prob.du0) ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0)<:Number - f! = (out, du, u, p, t) -> (out .= prob.f(first(du),first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:Vector{Float64} + if !isinplace && typeof(prob.u0) <: Number + f! = (out, du, u, p, t) -> (out .= prob.f(first(du), first(u), p, t); Cint(0)) + elseif !isinplace && typeof(prob.u0) <: Vector{Float64} f! = (out, du, u, p, t) -> (out .= prob.f(du, u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0)<:AbstractArray - f! = (out, du, u, p, t) -> (out .= vec( - prob.f(reshape(du, sizedu), reshape(u, sizeu), p, t) - );Cint(0)) - elseif typeof(prob.u0)<:Vector{Float64} + elseif !isinplace && typeof(prob.u0) <: AbstractArray + f! = + (out, du, u, p, t) -> + (out .= vec(prob.f(reshape(du, sizedu), reshape(u, sizeu), p, t)); + Cint(0)) + elseif typeof(prob.u0) <: Vector{Float64} f! = prob.f else # Then it's an in-place function on an abstract array - f! = (out, du, u, p, t) -> (prob.f(reshape(out, sizeu), reshape(du, sizedu), - reshape(u, sizeu), p, t); Cint(0)) + f! = + (out, du, u, p, t) -> + (prob.f(reshape(out, sizeu), reshape(du, sizedu), reshape(u, sizeu), p, t); + Cint(0)) end mem_ptr = IDACreate() (mem_ptr == C_NULL) && error("Failed to allocate IDA solver object") mem = Handle(mem_ptr) - !verbose && IDASetErrHandlerFn(mem,@cfunction(null_error_handler, Nothing, - (Cint, Char, - Char, Ptr{Cvoid})),C_NULL) + !verbose && IDASetErrHandlerFn( + mem, + @cfunction(null_error_handler, Nothing, (Cint, Char, Char, Ptr{Cvoid})), + C_NULL, + ) ures = Vector{uType}() dures = Vector{uType}() - ts = [t0] + ts = [t0] _u0 = copy(u0) utmp = NVector(_u0) @@ -876,19 +1067,26 @@ function DiffEqBase.__init( dutmp = NVector(_du0) rtest = zeros(length(u0)) - userfun = FunJac(f!,prob.f.jac,prob.p,nothing,prob.f.jac_prototype,alg.prec,alg.psetup,_u0,_du0,rtest) + userfun = FunJac( + f!, + prob.f.jac, + prob.p, + nothing, + prob.f.jac_prototype, + alg.prec, + alg.psetup, + _u0, + _du0, + rtest, + ) u0nv = NVector(u0) - function getcfun(::T) where T - @cfunction(idasolfun, - Cint, (realtype, N_Vector, N_Vector, - N_Vector, Ref{T})) + function getcfun(::T) where {T} + @cfunction(idasolfun, Cint, (realtype, N_Vector, N_Vector, N_Vector, Ref{T})) end cfun = getcfun(userfun) - flag = IDAInit(mem, cfun, - t0, convert(N_Vector, utmp), - convert(N_Vector, dutmp)) + flag = IDAInit(mem, cfun, t0, convert(N_Vector, utmp), convert(N_Vector, dutmp)) dt != nothing && (flag = IDASetInitStep(mem, dt)) flag = IDASetUserData(mem, userfun) flag = IDASetMaxStep(mem, dtmax) @@ -898,104 +1096,119 @@ function DiffEqBase.__init( flag = IDASStolerances(mem, reltol, abstol) end flag = IDASetMaxNumSteps(mem, maxiters) - flag = IDASetMaxOrd(mem,alg.max_order) - flag = IDASetMaxErrTestFails(mem,alg.max_error_test_failures) - flag = IDASetNonlinConvCoef(mem,alg.nonlinear_convergence_coefficient) - flag = IDASetMaxNonlinIters(mem,alg.max_nonlinear_iters) - flag = IDASetMaxConvFails(mem,alg.max_convergence_failures) - flag = IDASetNonlinConvCoefIC(mem,alg.nonlinear_convergence_coefficient_ic) - flag = IDASetMaxNumStepsIC(mem,alg.max_num_steps_ic) - flag = IDASetMaxNumJacsIC(mem,alg.max_num_jacs_ic) - flag = IDASetMaxNumItersIC(mem,alg.max_num_iters_ic) + flag = IDASetMaxOrd(mem, alg.max_order) + flag = IDASetMaxErrTestFails(mem, alg.max_error_test_failures) + flag = IDASetNonlinConvCoef(mem, alg.nonlinear_convergence_coefficient) + flag = IDASetMaxNonlinIters(mem, alg.max_nonlinear_iters) + flag = IDASetMaxConvFails(mem, alg.max_convergence_failures) + flag = IDASetNonlinConvCoefIC(mem, alg.nonlinear_convergence_coefficient_ic) + flag = IDASetMaxNumStepsIC(mem, alg.max_num_steps_ic) + flag = IDASetMaxNumJacsIC(mem, alg.max_num_jacs_ic) + flag = IDASetMaxNumItersIC(mem, alg.max_num_iters_ic) #flag = IDASetMaxBacksIC(mem,alg.max_num_backs_ic) # Needs newer version? - flag = IDASetLineSearchOffIC(mem,alg.use_linesearch_ic) + flag = IDASetLineSearchOffIC(mem, alg.use_linesearch_ic) if LinearSolver in (:Dense, :LapackDense) nojacobian = false - A = SUNDenseMatrix(length(u0),length(u0)) - _A = MatrixHandle(A,DenseMatrix()) + A = SUNDenseMatrix(length(u0), length(u0)) + _A = MatrixHandle(A, DenseMatrix()) if LinearSolver === :Dense - LS = SUNLinSol_Dense(u0,A) - _LS = LinSolHandle(LS,Dense()) + LS = SUNLinSol_Dense(u0, A) + _LS = LinSolHandle(LS, Dense()) else - LS = SUNLinSol_LapackDense(u0,A) - _LS = LinSolHandle(LS,LapackDense()) + LS = SUNLinSol_LapackDense(u0, A) + _LS = LinSolHandle(LS, LapackDense()) end elseif LinearSolver in (:Band, :LapackBand) nojacobian = false A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) - _A = MatrixHandle(A,BandMatrix()) + _A = MatrixHandle(A, BandMatrix()) if LinearSolver === :Band - LS = SUNLinSol_Band(u0,A) - _LS = LinSolHandle(LS,Band()) + LS = SUNLinSol_Band(u0, A) + _LS = LinSolHandle(LS, Band()) else - LS = SUNLinSol_LapackBand(u0,A) - _LS = LinSolHandle(LS,LapackBand()) + LS = SUNLinSol_LapackBand(u0, A) + _LS = LinSolHandle(LS, LapackBand()) end elseif LinearSolver == :GMRES LS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPGMR()) + _LS = LinSolHandle(LS, SPGMR()) elseif LinearSolver == :FGMRES LS = SUNLinSol_SPFGMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPFGMR()) + _LS = LinSolHandle(LS, SPFGMR()) elseif LinearSolver == :BCG LS = SUNLinSol_SPBCGS(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,SPBCGS()) + _LS = LinSolHandle(LS, SPBCGS()) elseif LinearSolver == :PCG LS = SUNLinSol_PCG(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,PCG()) + _LS = LinSolHandle(LS, PCG()) elseif LinearSolver == :TFQMR LS = SUNLinSol_SPTFQMR(u0, alg.prec_side, alg.krylov_dim) _A = nothing - _LS = LinSolHandle(LS,PTFQMR()) + _LS = LinSolHandle(LS, PTFQMR()) elseif LinearSolver == :KLU nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) - A = SUNSparseMatrix(length(u0),length(u0), nnz, Sundials.CSC_MAT) + A = SUNSparseMatrix(length(u0), length(u0), nnz, Sundials.CSC_MAT) LS = SUNLinSol_KLU(u0, A) - _A = MatrixHandle(A,SparseMatrix()) - _LS = LinSolHandle(LS,KLU()) + _A = MatrixHandle(A, SparseMatrix()) + _LS = LinSolHandle(LS, KLU()) end flag = IDASetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) if typeof(prob.f.jac_prototype) <: DiffEqBase.AbstractDiffEqLinearOperator - function getcfunjtimes(::T) where T - @cfunction(idajactimes, - Cint, - (realtype, - N_Vector,N_Vector,N_Vector,N_Vector,N_Vector, - realtype, - Ref{T}, - N_Vector,N_Vector)) + function getcfunjtimes(::T) where {T} + @cfunction( + idajactimes, + Cint, + ( + realtype, + N_Vector, + N_Vector, + N_Vector, + N_Vector, + N_Vector, + realtype, + Ref{T}, + N_Vector, + N_Vector, + ) + ) end jtimes = getcfunjtimes(userfun) IDASetJacTimes(mem, C_NULL, jtimes) end if alg.prec !== nothing - function getpercfun(::T) where T - @cfunction(idaprecsolve, - Cint, - (Float64, - N_Vector, - N_Vector, - N_Vector, - N_Vector,N_Vector,Float64,Float64,Int, - Ref{T})) + function getpercfun(::T) where {T} + @cfunction( + idaprecsolve, + Cint, + ( + Float64, + N_Vector, + N_Vector, + N_Vector, + N_Vector, + N_Vector, + Float64, + Float64, + Int, + Ref{T}, + ) + ) end precfun = getpercfun(userfun) - function getpsetupfun(::T) where T - @cfunction(idaprecsetup, - Cint, - (Float64, - N_Vector, - N_Vector, - N_Vector, - Float64,Ref{T})) + function getpsetupfun(::T) where {T} + @cfunction( + idaprecsetup, + Cint, + (Float64, N_Vector, N_Vector, N_Vector, Float64, Ref{T}) + ) end psetupfun = alg.psetup === nothing ? C_NULL : getpsetupfun(userfun) @@ -1003,25 +1216,29 @@ function DiffEqBase.__init( end if DiffEqBase.has_jac(prob.f) - function getcfunjacc(::T) where T - @cfunction(idajac, - Cint, - (realtype, - realtype, - N_Vector, - N_Vector, - N_Vector, - SUNMatrix, - Ref{T}, - N_Vector, - N_Vector, - N_Vector)) - end - jac = getcfunjacc(userfun) - flag = IDASetUserData(mem, userfun) - flag = IDASetJacFn(mem, jac) + function getcfunjacc(::T) where {T} + @cfunction( + idajac, + Cint, + ( + realtype, + realtype, + N_Vector, + N_Vector, + N_Vector, + SUNMatrix, + Ref{T}, + N_Vector, + N_Vector, + N_Vector, + ) + ) + end + jac = getcfunjacc(userfun) + flag = IDASetUserData(mem, userfun) + flag = IDASetJacFn(mem, jac) else - jac = nothing + jac = nothing end tout = [tspan[1]] @@ -1031,8 +1248,8 @@ function DiffEqBase.__init( if prob.differential_vars === nothing && !alg.init_all error("Must supply differential_vars argument to DAEProblem constructor to use IDA initial value solver.") end - prob.differential_vars != nothing && (flag = IDASetId(mem, collect(Float64, prob.differential_vars))) - + prob.differential_vars != nothing && + (flag = IDASetId(mem, collect(Float64, prob.differential_vars))) if dt != nothing _t = float(dt) @@ -1048,10 +1265,10 @@ function DiffEqBase.__init( end if save_start - save_value!(ures,u0,uType,sizeu) - if dense - save_value!(dures,du0,uType,sizedu) # Does this need to update for IDACalcIC? - end + save_value!(ures, u0, uType, sizeu) + if dense + save_value!(dures, du0, uType, sizedu) # Does this need to update for IDACalcIC? + end end callbacks_internal == nothing ? tmp = nothing : tmp = similar(u0) @@ -1063,23 +1280,71 @@ function DiffEqBase.__init( retcode = :InitialFailure end - sol = DiffEqBase.build_solution(prob, alg, ts, ures, - dense = dense, - interp = dense ? DiffEqBase.HermiteInterpolation(ts,ures,dures) : - DiffEqBase.LinearInterpolation(ts,ures), - calculate_error = false, - timeseries_errors = timeseries_errors, - retcode = retcode, - destats = DiffEqBase.DEStats(0), - dense_errors = dense_errors) - - opts = DEOptions(saveat_internal,tstops_internal,save_everystep,dense, - timeseries_errors,dense_errors,save_on,save_end, - callbacks_internal,abstol,reltol,verbose,advance_to_tstop,stop_at_next_tstop, - progress,progress_name,progress_message,maxiters) - - integrator = IDAIntegrator(utmp,dutmp,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts, - tout,tdir,sizeu,sizedu,false,tmp,uprev,Cint(flag),false,0,1,callback_cache,0.) + sol = DiffEqBase.build_solution( + prob, + alg, + ts, + ures, + dense = dense, + interp = dense ? DiffEqBase.HermiteInterpolation(ts, ures, dures) : + DiffEqBase.LinearInterpolation(ts, ures), + calculate_error = false, + timeseries_errors = timeseries_errors, + retcode = retcode, + destats = DiffEqBase.DEStats(0), + dense_errors = dense_errors, + ) + + opts = DEOptions( + saveat_internal, + tstops_internal, + save_everystep, + dense, + timeseries_errors, + dense_errors, + save_on, + save_end, + callbacks_internal, + abstol, + reltol, + verbose, + advance_to_tstop, + stop_at_next_tstop, + progress, + progress_name, + progress_message, + maxiters, + ) + + integrator = IDAIntegrator( + utmp, + dutmp, + prob.p, + t0, + t0, + mem, + _LS, + _A, + sol, + alg, + f!, + userfun, + jac, + opts, + tout, + tdir, + sizeu, + sizedu, + false, + tmp, + uprev, + Cint(flag), + false, + 0, + 1, + callback_cache, + 0.0, + ) initialize_callbacks!(integrator) integrator @@ -1088,63 +1353,92 @@ end # function solve ## Common calls function interpret_sundials_retcode(flag) - flag >= 0 && return :Success - flag == -1 && return :MaxIters - (flag == -2 || flag == -3) && return :Unstable - flag == -4 && return :ConvergenceFailure - return :Failure + flag >= 0 && return :Success + flag == -1 && return :MaxIters + (flag == -2 || flag == -3) && return :Unstable + flag == -4 && return :ConvergenceFailure + return :Failure end -function solver_step(integrator::CVODEIntegrator,tstop) - integrator.flag = CVode(integrator.mem, tstop, integrator.u, integrator.tout, CV_ONE_STEP) +function solver_step(integrator::CVODEIntegrator, tstop) + integrator.flag = + CVode(integrator.mem, tstop, integrator.u, integrator.tout, CV_ONE_STEP) if integrator.opts.progress - Logging.@logmsg(-1, - integrator.opts.progress_name, - _id = :Sundials, - message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), - progress=integrator.t/integrator.sol.prob.tspan[2]) + Logging.@logmsg( + -1, + integrator.opts.progress_name, + _id = :Sundials, + message = integrator.opts.progress_message( + integrator.dt, + integrator.u, + integrator.p, + integrator.t, + ), + progress = integrator.t / integrator.sol.prob.tspan[2] + ) end end -function solver_step(integrator::ARKODEIntegrator,tstop) - integrator.flag = ARKStepEvolve(integrator.mem, tstop, integrator.u, integrator.tout, ARK_ONE_STEP) +function solver_step(integrator::ARKODEIntegrator, tstop) + integrator.flag = + ARKStepEvolve(integrator.mem, tstop, integrator.u, integrator.tout, ARK_ONE_STEP) if integrator.opts.progress - Logging.@logmsg(-1, - integrator.opts.progress_name, - _id = :Sundials, - message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), - progress=integrator.t/integrator.sol.prob.tspan[2]) + Logging.@logmsg( + -1, + integrator.opts.progress_name, + _id = :Sundials, + message = integrator.opts.progress_message( + integrator.dt, + integrator.u, + integrator.p, + integrator.t, + ), + progress = integrator.t / integrator.sol.prob.tspan[2] + ) end end -function solver_step(integrator::IDAIntegrator,tstop) - integrator.flag = IDASolve(integrator.mem, tstop, integrator.tout, - integrator.u, integrator.du, IDA_ONE_STEP) +function solver_step(integrator::IDAIntegrator, tstop) + integrator.flag = IDASolve( + integrator.mem, + tstop, + integrator.tout, + integrator.u, + integrator.du, + IDA_ONE_STEP, + ) if integrator.opts.progress - Logging.@logmsg(-1, - integrator.opts.progress_name, - _id = :Sundials, - message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), - progress=integrator.t/integrator.sol.prob.tspan[2]) + Logging.@logmsg( + -1, + integrator.opts.progress_name, + _id = :Sundials, + message = integrator.opts.progress_message( + integrator.dt, + integrator.u, + integrator.p, + integrator.t, + ), + progress = integrator.t / integrator.sol.prob.tspan[2] + ) end end -function set_stop_time(integrator::CVODEIntegrator,tstop) - CVodeSetStopTime(integrator.mem,tstop) +function set_stop_time(integrator::CVODEIntegrator, tstop) + CVodeSetStopTime(integrator.mem, tstop) end -function set_stop_time(integrator::ARKODEIntegrator,tstop) - ARKStepSetStopTime(integrator.mem,tstop) +function set_stop_time(integrator::ARKODEIntegrator, tstop) + ARKStepSetStopTime(integrator.mem, tstop) end -function set_stop_time(integrator::IDAIntegrator,tstop) - IDASetStopTime(integrator.mem,tstop) +function set_stop_time(integrator::IDAIntegrator, tstop) + IDASetStopTime(integrator.mem, tstop) end -function get_iters!(integrator::CVODEIntegrator,iters) - CVodeGetNumSteps(integrator.mem,iters) +function get_iters!(integrator::CVODEIntegrator, iters) + CVodeGetNumSteps(integrator.mem, iters) end -function get_iters!(integrator::ARKODEIntegrator,iters) - ARKStepGetNumSteps(integrator.mem,iters) +function get_iters!(integrator::ARKODEIntegrator, iters) + ARKStepGetNumSteps(integrator.mem, iters) end -function get_iters!(integrator::IDAIntegrator,iters) - IDAGetNumSteps(integrator.mem,iters) +function get_iters!(integrator::IDAIntegrator, iters) + IDAGetNumSteps(integrator.mem, iters) end function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator) @@ -1154,21 +1448,21 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator) # Sundials can have floating point issues approaching a tstop if # there is a modifying event each # The call to first is an overload of Base.first implemented in DataStructures - while integrator.tdir*(integrator.t-first(integrator.opts.tstops)) < -1e6eps() + while integrator.tdir * (integrator.t - first(integrator.opts.tstops)) < -1e6eps() tstop = first(integrator.opts.tstops) - set_stop_time(integrator,tstop) + set_stop_time(integrator, tstop) integrator.tprev = integrator.t - if !(typeof(integrator.opts.callback.continuous_callbacks)<:Tuple{}) + if !(typeof(integrator.opts.callback.continuous_callbacks) <: Tuple{}) integrator.uprev .= integrator.u end integrator.userfun.p = integrator.p - solver_step(integrator,tstop) + solver_step(integrator, tstop) integrator.t = first(integrator.tout) integrator.flag < 0 && break handle_callbacks!(integrator) integrator.flag < 0 && break if isempty(integrator.opts.tstops) - break + break end get_iters!(integrator, iters) if iters[] + 1 > integrator.opts.maxiters @@ -1180,21 +1474,29 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator) handle_tstop!(integrator) end - if integrator.opts.save_end && (isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t) - save_value!(integrator.sol.u,integrator.u,uType,integrator.sizeu) + if integrator.opts.save_end && + (isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t) + save_value!(integrator.sol.u, integrator.u, uType, integrator.sizeu) push!(integrator.sol.t, integrator.t) if integrator.opts.dense - integrator(integrator.u,integrator.t,Val{1}) - save_value!(integrator.sol.interp.du,integrator.u,uType,integrator.sizeu) + integrator(integrator.u, integrator.t, Val{1}) + save_value!(integrator.sol.interp.du, integrator.u, uType, integrator.sizeu) end end if integrator.opts.progress - Logging.@logmsg(-1, - integrator.opts.progress_name, - _id = :Sundials, - message=integrator.opts.progress_message(integrator.dt,integrator.u,integrator.p,integrator.t), - progress="done") + Logging.@logmsg( + -1, + integrator.opts.progress_name, + _id = :Sundials, + message = integrator.opts.progress_message( + integrator.dt, + integrator.u, + integrator.p, + integrator.t, + ), + progress = "done" + ) end fill_destats!(integrator) @@ -1203,13 +1505,18 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator) integrator.LS != nothing && empty!(integrator.LS) if DiffEqBase.has_analytic(integrator.sol.prob.f) - DiffEqBase.calculate_solution_errors!(integrator.sol; - timeseries_errors=integrator.opts.timeseries_errors, - dense_errors=integrator.opts.dense_errors) + DiffEqBase.calculate_solution_errors!( + integrator.sol; + timeseries_errors = integrator.opts.timeseries_errors, + dense_errors = integrator.opts.dense_errors, + ) end if integrator.sol.retcode === :Default - integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol,interpret_sundials_retcode(integrator.flag)) + integrator.sol = DiffEqBase.solution_new_retcode( + integrator.sol, + interpret_sundials_retcode(integrator.flag), + ) end return integrator.sol @@ -1218,35 +1525,34 @@ end function handle_tstop!(integrator::AbstractSundialsIntegrator) tstops = integrator.opts.tstops if !isempty(tstops) - if integrator.tdir*(integrator.t-first(integrator.opts.tstops)) > -1e6eps() - pop!(tstops) - t = integrator.t - integrator.just_hit_tstop = true - end + if integrator.tdir * (integrator.t - first(integrator.opts.tstops)) > -1e6eps() + pop!(tstops) + t = integrator.t + integrator.just_hit_tstop = true + end end end -function fill_destats!(integrator::AbstractSundialsIntegrator) -end +function fill_destats!(integrator::AbstractSundialsIntegrator) end function fill_destats!(integrator::CVODEIntegrator) destats = integrator.sol.destats mem = integrator.mem tmp = Ref(Clong(-1)) - CVodeGetNumRhsEvals(mem,tmp) + CVodeGetNumRhsEvals(mem, tmp) destats.nf = tmp[] - CVodeGetNumLinSolvSetups(mem,tmp) + CVodeGetNumLinSolvSetups(mem, tmp) destats.nw = tmp[] - CVodeGetNumErrTestFails(mem,tmp) + CVodeGetNumErrTestFails(mem, tmp) destats.nreject = tmp[] - CVodeGetNumSteps(mem,tmp) + CVodeGetNumSteps(mem, tmp) destats.naccept = tmp[] - destats.nreject - CVodeGetNumNonlinSolvIters(mem,tmp) + CVodeGetNumNonlinSolvIters(mem, tmp) destats.nnonliniter = tmp[] - CVodeGetNumNonlinSolvConvFails(mem,tmp) + CVodeGetNumNonlinSolvConvFails(mem, tmp) destats.nnonlinconvfail = tmp[] if method_choice(integrator.alg) == :Newton - CVodeGetNumJacEvals(mem,tmp) + CVodeGetNumJacEvals(mem, tmp) destats.njacs = tmp[] end end @@ -1256,21 +1562,21 @@ function fill_destats!(integrator::ARKODEIntegrator) mem = integrator.mem tmp = Ref(Clong(-1)) tmp2 = Ref(Clong(-1)) - ARKStepGetNumRhsEvals(mem,tmp,tmp2) + ARKStepGetNumRhsEvals(mem, tmp, tmp2) destats.nf = tmp[] destats.nf2 = tmp2[] - ARKStepGetNumLinSolvSetups(mem,tmp) + ARKStepGetNumLinSolvSetups(mem, tmp) destats.nw = tmp[] - ARKStepGetNumErrTestFails(mem,tmp) + ARKStepGetNumErrTestFails(mem, tmp) destats.nreject = tmp[] - ARKStepGetNumSteps(mem,tmp) + ARKStepGetNumSteps(mem, tmp) destats.naccept = tmp[] - destats.nreject - ARKStepGetNumNonlinSolvIters(mem,tmp) + ARKStepGetNumNonlinSolvIters(mem, tmp) destats.nnonliniter = tmp[] - ARKStepGetNumNonlinSolvConvFails(mem,tmp) + ARKStepGetNumNonlinSolvConvFails(mem, tmp) destats.nnonlinconvfail = tmp[] if method_choice(integrator.alg) == :Newton - ARKStepGetNumJacEvals(mem,tmp) + ARKStepGetNumJacEvals(mem, tmp) destats.njacs = tmp[] end end @@ -1279,44 +1585,45 @@ function fill_destats!(integrator::IDAIntegrator) destats = integrator.sol.destats mem = integrator.mem tmp = Ref(Clong(-1)) - IDAGetNumResEvals(mem,tmp) + IDAGetNumResEvals(mem, tmp) destats.nf = tmp[] - IDAGetNumLinSolvSetups(mem,tmp) + IDAGetNumLinSolvSetups(mem, tmp) destats.nw = tmp[] - IDAGetNumErrTestFails(mem,tmp) + IDAGetNumErrTestFails(mem, tmp) destats.nreject = tmp[] - IDAGetNumSteps(mem,tmp) + IDAGetNumSteps(mem, tmp) destats.naccept = tmp[] - destats.nreject - IDAGetNumNonlinSolvIters(mem,tmp) + IDAGetNumNonlinSolvIters(mem, tmp) destats.nnonliniter = tmp[] - IDAGetNumNonlinSolvConvFails(mem,tmp) + IDAGetNumNonlinSolvConvFails(mem, tmp) destats.nnonlinconvfail = tmp[] if method_choice(integrator.alg) == :Newton - IDAGetNumJacEvals(mem,tmp) + IDAGetNumJacEvals(mem, tmp) destats.njacs = tmp[] end end function initialize_callbacks!(integrator, initialize_save = true) - t = integrator.t - u = integrator.u - callbacks = integrator.opts.callback - integrator.u_modified = true + t = integrator.t + u = integrator.u + callbacks = integrator.opts.callback + integrator.u_modified = true - u_modified = initialize!(callbacks,u,t,integrator) + u_modified = initialize!(callbacks, u, t, integrator) - # if the user modifies u, we need to fix current values - if u_modified + # if the user modifies u, we need to fix current values + if u_modified - handle_callback_modifiers!(integrator) + handle_callback_modifiers!(integrator) - if initialize_save && - (any((c)->c.save_positions[2],callbacks.discrete_callbacks) || - any((c)->c.save_positions[2],callbacks.continuous_callbacks)) - savevalues!(integrator,true) + if initialize_save && ( + any((c) -> c.save_positions[2], callbacks.discrete_callbacks) || + any((c) -> c.save_positions[2], callbacks.continuous_callbacks) + ) + savevalues!(integrator, true) + end end - end - # reset this as it is now handled so the integrators should proceed as normal - integrator.u_modified = false + # reset this as it is now handled so the integrators should proceed as normal + integrator.u_modified = false end diff --git a/src/common_interface/verbosity.jl b/src/common_interface/verbosity.jl index b3a1d58..44aeda2 100644 --- a/src/common_interface/verbosity.jl +++ b/src/common_interface/verbosity.jl @@ -1,6 +1,3 @@ -function null_error_handler(error_code::Cint, - mod::Char, - func::Char, - eh_data::Ptr{Cvoid}) - nothing +function null_error_handler(error_code::Cint, mod::Char, func::Char, eh_data::Ptr{Cvoid}) + nothing end diff --git a/src/handle.jl b/src/handle.jl index 26eb4ae..e2347bf 100644 --- a/src/handle.jl +++ b/src/handle.jl @@ -43,55 +43,66 @@ abstract type SundialsHandle end struct Handle{T <: AbstractSundialsObject} <: SundialsHandle ptr_ref::Ref{Ptr{T}} # pointer to a pointer - function Handle(ptr::Ptr{T}) where T <: AbstractSundialsObject + function Handle(ptr::Ptr{T}) where {T <: AbstractSundialsObject} h = new{T}(Ref{Ptr{T}}(ptr)) finalizer(release_handle, h.ptr_ref) return h end end -mutable struct MatrixHandle{T<:SundialsMatrix} <: SundialsHandle +mutable struct MatrixHandle{T <: SundialsMatrix} <: SundialsHandle ptr::SUNMatrix destroyed::Bool - function MatrixHandle(ptr::SUNMatrix,M::T) where T<:SundialsMatrix - h = new{T}(ptr,false) + function MatrixHandle(ptr::SUNMatrix, M::T) where {T <: SundialsMatrix} + h = new{T}(ptr, false) finalizer(release_handle, h) return h end end -mutable struct LinSolHandle{T<:SundialsLinearSolver} <: SundialsHandle +mutable struct LinSolHandle{T <: SundialsLinearSolver} <: SundialsHandle ptr::SUNLinearSolver destroyed::Bool - function LinSolHandle(ptr::SUNLinearSolver,M::T) where T<:SundialsLinearSolver - h = new{T}(ptr,false) + function LinSolHandle(ptr::SUNLinearSolver, M::T) where {T <: SundialsLinearSolver} + h = new{T}(ptr, false) finalizer(release_handle, h) return h end end -mutable struct NonLinSolHandle{T<:SundialsNonLinearSolver} <: SundialsHandle +mutable struct NonLinSolHandle{T <: SundialsNonLinearSolver} <: SundialsHandle ptr::SUNNonlinearSolver destroyed::Bool - function NonLinSolHandle(ptr::SUNNonlinearSolver,M::T) where T<:SundialsNonLinearSolver - h = new{T}(ptr,false) + function NonLinSolHandle( + ptr::SUNNonlinearSolver, + M::T, + ) where {T <: SundialsNonLinearSolver} + h = new{T}(ptr, false) finalizer(release_handle, h) return h end end -Base.unsafe_convert(::Type{Ptr{T}}, h::Handle{T}) where T = h.ptr_ref[] -Base.unsafe_convert(::Type{Ptr{Cvoid}}, h::Handle{T}) where T = Ptr{Cvoid}(h.ptr_ref[]) -Base.convert(::Type{Ptr{T}}, h::Handle{T}) where T = h.ptr_ref[] -Base.convert(::Type{Ptr{Ptr{T}}}, h::Handle{T}) where {T} = convert(Ptr{Ptr{T}}, h.ptr_ref[]) - -release_handle(ptr_ref::Ref{Ptr{T}}) where {T} = throw(MethodError("Freeing objects of type $T not supported")) -release_handle(ptr_ref::Ref{Ptr{KINMem}}) = ((ptr_ref[] != C_NULL) && KINFree(ptr_ref); nothing) -release_handle(ptr_ref::Ref{Ptr{CVODEMem}}) = ((ptr_ref[] != C_NULL) && CVodeFree(ptr_ref); nothing) -release_handle(ptr_ref::Ref{Ptr{ARKStepMem}}) = ((ptr_ref[] != C_NULL) && ARKStepFree(ptr_ref); nothing) -release_handle(ptr_ref::Ref{Ptr{ERKStepMem}}) = ((ptr_ref[] != C_NULL) && ERKStepFree(ptr_ref); nothing) -release_handle(ptr_ref::Ref{Ptr{MRIStepMem}}) = ((ptr_ref[] != C_NULL) && MRIStepFree(ptr_ref); nothing) -release_handle(ptr_ref::Ref{Ptr{IDAMem}}) = ((ptr_ref[] != C_NULL) && IDAFree(ptr_ref); nothing) +Base.unsafe_convert(::Type{Ptr{T}}, h::Handle{T}) where {T} = h.ptr_ref[] +Base.unsafe_convert(::Type{Ptr{Cvoid}}, h::Handle{T}) where {T} = Ptr{Cvoid}(h.ptr_ref[]) +Base.convert(::Type{Ptr{T}}, h::Handle{T}) where {T} = h.ptr_ref[] +Base.convert(::Type{Ptr{Ptr{T}}}, h::Handle{T}) where {T} = + convert(Ptr{Ptr{T}}, h.ptr_ref[]) + +release_handle(ptr_ref::Ref{Ptr{T}}) where {T} = + throw(MethodError("Freeing objects of type $T not supported")) +release_handle(ptr_ref::Ref{Ptr{KINMem}}) = + ((ptr_ref[] != C_NULL) && KINFree(ptr_ref); nothing) +release_handle(ptr_ref::Ref{Ptr{CVODEMem}}) = + ((ptr_ref[] != C_NULL) && CVodeFree(ptr_ref); nothing) +release_handle(ptr_ref::Ref{Ptr{ARKStepMem}}) = + ((ptr_ref[] != C_NULL) && ARKStepFree(ptr_ref); nothing) +release_handle(ptr_ref::Ref{Ptr{ERKStepMem}}) = + ((ptr_ref[] != C_NULL) && ERKStepFree(ptr_ref); nothing) +release_handle(ptr_ref::Ref{Ptr{MRIStepMem}}) = + ((ptr_ref[] != C_NULL) && MRIStepFree(ptr_ref); nothing) +release_handle(ptr_ref::Ref{Ptr{IDAMem}}) = + ((ptr_ref[] != C_NULL) && IDAFree(ptr_ref); nothing) function release_handle(h::MatrixHandle{DenseMatrix}) if !isempty(h) @@ -226,9 +237,9 @@ Base.isempty(h::NonLinSolHandle) = h.destroyed # ################################################################## -const CVODEh = Handle{CVODEMem} -const ARKSteph = Handle{ARKStepMem} -const ERKSteph = Handle{ERKStepMem} -const MRISteph = Handle{MRIStepMem} -const KINh = Handle{KINMem} -const IDAh = Handle{IDAMem} +const CVODEh = Handle{CVODEMem} +const ARKSteph = Handle{ARKStepMem} +const ERKSteph = Handle{ERKStepMem} +const MRISteph = Handle{MRIStepMem} +const KINh = Handle{KINMem} +const IDAh = Handle{IDAMem} diff --git a/src/nvector_wrapper.jl b/src/nvector_wrapper.jl index 4ea3c3c..bdb0671 100644 --- a/src/nvector_wrapper.jl +++ b/src/nvector_wrapper.jl @@ -25,8 +25,8 @@ struct NVector <: DenseVector{realtype} end end -NVector(v::AbstractArray) = convert(Vector,v) -N_Vector(x::NVector) = convert(N_Vector,x) +NVector(v::AbstractArray) = convert(Vector, v) +N_Vector(x::NVector) = convert(N_Vector, x) release_handle(ref_nv::Ref{N_Vector}) = N_VDestroy_Serial(ref_nv[]) @@ -50,13 +50,14 @@ Base.pointer(nv::NVector) = Sundials.N_VGetArrayPointer_Serial(nv.ref_nv[]) ################################################################## Base.convert(::Type{NVector}, v::Vector{realtype}) = NVector(v) -Base.convert(::Type{NVector}, v::Vector{T}) where {T<:Real} = NVector(copy!(similar(v, realtype), v)) -Base.convert(::Type{NVector}, v::AbstractVector) = NVector(convert(Array,v)) +Base.convert(::Type{NVector}, v::Vector{T}) where {T <: Real} = + NVector(copy!(similar(v, realtype), v)) +Base.convert(::Type{NVector}, v::AbstractVector) = NVector(convert(Array, v)) Base.convert(::Type{NVector}, nv::NVector) = nv Base.convert(::Type{NVector}, nv::N_Vector) = NVector(nv) Base.convert(::Type{N_Vector}, nv::NVector) = nv.ref_nv[] -Base.convert(::Type{Vector{realtype}}, nv::NVector)= nv.v -Base.convert(::Type{Vector}, nv::NVector)= nv.v +Base.convert(::Type{Vector{realtype}}, nv::NVector) = nv.v +Base.convert(::Type{Vector}, nv::NVector) = nv.v """ `N_Vector(v::Vector{T})` @@ -66,16 +67,18 @@ Base.convert(::Type{Vector}, nv::NVector)= nv.v destruction of `N_Vector` object when no longer in use. """ Base.convert(::Type{N_Vector}, v::Vector{realtype}) = N_Vector(NVector(v)) -Base.convert(::Type{N_Vector}, v::Vector{T}) where {T<:Real} = N_Vector(NVector(v)) +Base.convert(::Type{N_Vector}, v::Vector{T}) where {T <: Real} = N_Vector(NVector(v)) Base.similar(nv::NVector) = NVector(similar(nv.v)) nvlength(x::N_Vector) = unsafe_load(unsafe_load(convert(Ptr{Ptr{Clong}}, x))) # asarray() creates an array pointing to N_Vector data, but does not take the ownership -@inline asarray(x::N_Vector) = unsafe_wrap(Array, N_VGetArrayPointer_Serial(x), (nvlength(x),), own=false) -@inline asarray(x::N_Vector, dims::Tuple) = unsafe_wrap(Array, N_VGetArrayPointer_Serial(x), dims, own=false) +@inline asarray(x::N_Vector) = + unsafe_wrap(Array, N_VGetArrayPointer_Serial(x), (nvlength(x),), own = false) +@inline asarray(x::N_Vector, dims::Tuple) = + unsafe_wrap(Array, N_VGetArrayPointer_Serial(x), dims, own = false) asarray(x::Vector{realtype}) = x -asarray(x::Ptr{realtype}, dims::Tuple) = unsafe_wrap(Array, x, dims, own=false) +asarray(x::Ptr{realtype}, dims::Tuple) = unsafe_wrap(Array, x, dims, own = false) @inline Base.convert(::Type{Vector{realtype}}, x::N_Vector) = asarray(x) @inline Base.convert(::Type{Vector}, x::N_Vector) = asarray(x) diff --git a/src/simple.jl b/src/simple.jl index ef20579..75aec40 100644 --- a/src/simple.jl +++ b/src/simple.jl @@ -8,7 +8,7 @@ Insert a check that the given function call returns 0, throw an error otherwise. Only apply directly to function calls. """ -macro checkflag(ex,throw_error=false) +macro checkflag(ex, throw_error = false) @assert Base.Meta.isexpr(ex, :call) fname = ex.args[1] quote @@ -44,9 +44,14 @@ function kinsolfun(y::N_Vector, fy::N_Vector, userfun) return KIN_SUCCESS end -function kinsol(f, y0::Vector{Float64}; - userdata::Any = nothing, - linear_solver=:Dense, jac_upper=0, jac_lower=0) +function kinsol( + f, + y0::Vector{Float64}; + userdata::Any = nothing, + linear_solver = :Dense, + jac_upper = 0, + jac_lower = 0, +) # f, Function to be optimized of the form f(y::Vector{Float64}, fy::Vector{Float64}) # where `y` is the input vector, and `fy` is the result of the function # y0, Vector of initial values @@ -65,11 +70,11 @@ function kinsol(f, y0::Vector{Float64}; end flag = @checkflag KINInit(kmem, getcfun(userfun), NVector(y0)) true if linear_solver == :Dense - A = Sundials.SUNDenseMatrix(length(y0),length(y0)) - LS = Sundials.SUNLinSol_Dense(y0,A) + A = Sundials.SUNDenseMatrix(length(y0), length(y0)) + LS = Sundials.SUNLinSol_Dense(y0, A) elseif linear_solver == :Band A = Sundials.SUNBandMatrix(length(y0), jac_upper, jac_lower) - LS = Sundials.SUNLinSol_Band(y0,A) + LS = Sundials.SUNLinSol_Band(y0, A) end flag = @checkflag Sundials.KINDlsSetLinearSolver(kmem, LS, A) true flag = @checkflag KINSetUserData(kmem, userfun) true @@ -116,17 +121,32 @@ end return: a solution matrix with time steps in `t` along rows and state variable `y` along columns """ -function cvode(f::Function, y0::Vector{Float64}, t::AbstractVector, userdata::Any=nothing; kwargs...) +function cvode( + f::Function, + y0::Vector{Float64}, + t::AbstractVector, + userdata::Any = nothing; + kwargs..., +) y = zeros(length(t), length(y0)) n = cvode!(f, y, y0, t, userdata; kwargs...) - return y[1:n,:] + return y[1:n, :] end -function cvode!(f::Function, y::Matrix{Float64}, y0::Vector{Float64}, t::AbstractVector, userdata::Any=nothing; - integrator=:BDF, reltol::Float64=1e-3, abstol::Float64=1e-6, callback=(x,y,z)->true) - if integrator==:BDF +function cvode!( + f::Function, + y::Matrix{Float64}, + y0::Vector{Float64}, + t::AbstractVector, + userdata::Any = nothing; + integrator = :BDF, + reltol::Float64 = 1e-3, + abstol::Float64 = 1e-6, + callback = (x, y, z) -> true, +) + if integrator == :BDF mem_ptr = CVodeCreate(CV_BDF) - elseif integrator==:Adams + elseif integrator == :Adams mem_ptr = CVodeCreate(CV_ADAMS) end @@ -145,11 +165,11 @@ function cvode!(f::Function, y::Matrix{Float64}, y0::Vector{Float64}, t::Abstrac flag = @checkflag CVodeSetUserData(mem, userfun) true flag = @checkflag CVodeSStolerances(mem, reltol, abstol) true - A = Sundials.SUNDenseMatrix(length(y0),length(y0)) - LS = Sundials.SUNLinSol_Dense(y0nv,A) + A = Sundials.SUNDenseMatrix(length(y0), length(y0)) + LS = Sundials.SUNLinSol_Dense(y0nv, A) flag = Sundials.@checkflag Sundials.CVDlsSetLinearSolver(mem, LS, A) true - y[1,:] = y0 + y[1, :] = y0 ynv = NVector(copy(y0)) tout = [0.0] for k in 2:length(t) @@ -157,7 +177,7 @@ function cvode!(f::Function, y::Matrix{Float64}, y0::Vector{Float64}, t::Abstrac if !callback(mem, t[k], ynv) break end - y[k,:] = convert(Vector, ynv) + y[k, :] = convert(Vector, ynv) c = c + 1 end @@ -168,8 +188,20 @@ function cvode!(f::Function, y::Matrix{Float64}, y0::Vector{Float64}, t::Abstrac return c end -function idasolfun(t::Float64, y::N_Vector, yp::N_Vector, r::N_Vector, userfun::UserFunctionAndData) - userfun.func(t, convert(Vector, y), convert(Vector, yp), convert(Vector, r), userfun.data) +function idasolfun( + t::Float64, + y::N_Vector, + yp::N_Vector, + r::N_Vector, + userfun::UserFunctionAndData, +) + userfun.func( + t, + convert(Vector, y), + convert(Vector, yp), + convert(Vector, r), + userfun.data, + ) return IDA_SUCCESS end @@ -195,8 +227,16 @@ end return: (y,yp) two solution matrices representing the states and state derivatives with time steps in `t` along rows and state variable `y` or `yp` along columns """ -function idasol(f, y0::Vector{Float64}, yp0::Vector{Float64}, t::Vector{Float64}, userdata::Any=nothing; - reltol::Float64=1e-3, abstol::Float64=1e-6, diffstates::Union{Vector{Bool},Nothing}=nothing) +function idasol( + f, + y0::Vector{Float64}, + yp0::Vector{Float64}, + t::Vector{Float64}, + userdata::Any = nothing; + reltol::Float64 = 1e-3, + abstol::Float64 = 1e-6, + diffstates::Union{Vector{Bool}, Nothing} = nothing, +) mem_ptr = IDACreate() (mem_ptr == C_NULL) && error("Failed to allocate IDA solver object") mem = Handle(mem_ptr) @@ -209,13 +249,12 @@ function idasol(f, y0::Vector{Float64}, yp0::Vector{Float64}, t::Vector{Float64} function getcfun(userfun::T) where {T} @cfunction(idasolfun, Cint, (realtype, N_Vector, N_Vector, N_Vector, Ref{T})) end - flag = @checkflag IDAInit(mem, getcfun(userfun), - t[1], y0, yp0) true + flag = @checkflag IDAInit(mem, getcfun(userfun), t[1], y0, yp0) true flag = @checkflag IDASetUserData(mem, userfun) true flag = @checkflag IDASStolerances(mem, reltol, abstol) true - A = Sundials.SUNDenseMatrix(length(y0),length(y0)) - LS = Sundials.SUNLinSol_Dense(y0,A) + A = Sundials.SUNDenseMatrix(length(y0), length(y0)) + LS = Sundials.SUNLinSol_Dense(y0, A) flag = Sundials.@checkflag Sundials.IDADlsSetLinearSolver(mem, LS, A) true rtest = zeros(length(y0)) @@ -227,15 +266,15 @@ function idasol(f, y0::Vector{Float64}, yp0::Vector{Float64}, t::Vector{Float64} flag = @checkflag IDASetId(mem, collect(Float64, diffstates)) true flag = @checkflag IDACalcIC(mem, IDA_YA_YDP_INIT, t[2]) true end - yres[1,:] = y0 - ypres[1,:] = yp0 + yres[1, :] = y0 + ypres[1, :] = yp0 y = copy(y0) yp = copy(yp0) tout = [0.0] for k in 2:length(t) retval = @checkflag IDASolve(mem, t[k], tout, y, yp, IDA_NORMAL) true - yres[k,:] = y - ypres[k,:] = yp + yres[k, :] = y + ypres[k, :] = yp end empty!(mem) diff --git a/src/types_and_consts_additions.jl b/src/types_and_consts_additions.jl index 4cf4b06..48dc488 100644 --- a/src/types_and_consts_additions.jl +++ b/src/types_and_consts_additions.jl @@ -3,36 +3,40 @@ function Base.convert(::Type{Matrix}, J::DlsMat) _dlsmat = unsafe_load(J) # own is false as memory is allocated by sundials - unsafe_wrap(Array, _dlsmat.data, (_dlsmat.M, _dlsmat.N), own=false) + unsafe_wrap(Array, _dlsmat.data, (_dlsmat.M, _dlsmat.N), own = false) end CVRhsFn_wrapper(fp::CVRhsFn) = fp -CVRhsFn_wrapper(f) = @cfunction($f,Cint,(realtype,N_Vector,N_Vector,Ptr{Cvoid})).ptr +CVRhsFn_wrapper(f) = @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr ARKRhsFn_wrapper(fp::ARKRhsFn) = fp -ARKRhsFn_wrapper(f) = @cfunction($f,Cint,(realtype,N_Vector,N_Vector,Ptr{Cvoid})).ptr +ARKRhsFn_wrapper(f) = @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr CVRootFn_wrapper(fp::CVRootFn) = fp -CVRootFn_wrapper(f) = @cfunction($f,Cint,(realtype,N_Vector,Ptr{realtype},Ptr{Cvoid})).ptr +CVRootFn_wrapper(f) = + @cfunction($f, Cint, (realtype, N_Vector, Ptr{realtype}, Ptr{Cvoid})).ptr CVQuadRhsFn_wrapper(fp::CVQuadRhsFn) = fp -CVQuadRhsFn_wrapper(f) = @cfunction($f,Cint,(realtype,N_Vector,N_Vector,Ptr{Cvoid})).ptr +CVQuadRhsFn_wrapper(f) = + @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{Cvoid})).ptr IDAResFn_wrapper(fp::IDAResFn) = fp -IDAResFn_wrapper(f) = @cfunction($f,Cint,(realtype,N_Vector,N_Vector,N_Vector,Ptr{Cvoid})).ptr +IDAResFn_wrapper(f) = + @cfunction($f, Cint, (realtype, N_Vector, N_Vector, N_Vector, Ptr{Cvoid})).ptr IDARootFn_wrapper(fp::IDARootFn) = fp -IDARootFn_wrapper(f) = @cfunction($f,Cint,(realtype,N_Vector,N_Vector,Ptr{realtype},Ptr{Cvoid})).ptr +IDARootFn_wrapper(f) = + @cfunction($f, Cint, (realtype, N_Vector, N_Vector, Ptr{realtype}, Ptr{Cvoid})).ptr KINSysFn_wrapper(fp::KINSysFn) = fp -KINSysFn_wrapper(f) = @cfunction($f,Cint,(N_Vector,N_Vector,Ptr{Cvoid})).ptr +KINSysFn_wrapper(f) = @cfunction($f, Cint, (N_Vector, N_Vector, Ptr{Cvoid})).ptr function Base.convert(::Type{Matrix}, J::SUNMatrix) _sunmat = unsafe_load(J) _mat = convert(SUNMatrixContent_Dense, _sunmat.content) mat = unsafe_load(_mat) # own is false as memory is allocated by sundials - unsafe_wrap(Array, mat.data, (mat.M, mat.N), own=false) + unsafe_wrap(Array, mat.data, (mat.M, mat.N), own = false) end function Base.convert(::Type{SparseArrays.SparseMatrixCSC}, J::SUNMatrix) @@ -41,13 +45,13 @@ function Base.convert(::Type{SparseArrays.SparseMatrixCSC}, J::SUNMatrix) mat = unsafe_load(_mat) # own is false as memory is allocated by sundials # TODO: Get rid of allocation for 1-based index change - rowval = unsafe_wrap(Array, mat.indexvals, (mat.NNZ), own=false) - colptr = unsafe_wrap(Array, mat.indexptrs, (mat.NP+1), own=false) + rowval = unsafe_wrap(Array, mat.indexvals, (mat.NNZ), own = false) + colptr = unsafe_wrap(Array, mat.indexptrs, (mat.NP + 1), own = false) colptr .+= 1 m = mat.M n = mat.N - nzval = unsafe_wrap(Array,mat.data, (mat.NNZ), own=false) - SparseArrays.SparseMatrixCSC(m,n,colptr,rowval,nzval) + nzval = unsafe_wrap(Array, mat.data, (mat.NNZ), own = false) + SparseArrays.SparseMatrixCSC(m, n, colptr, rowval, nzval) end abstract type SundialsMatrix end diff --git a/test/arkstep_Roberts_dns.jl b/test/arkstep_Roberts_dns.jl index 6779b30..923c0cf 100644 --- a/test/arkstep_Roberts_dns.jl +++ b/test/arkstep_Roberts_dns.jl @@ -5,8 +5,8 @@ using Sundials, Test function f(t, y_nv, ydot_nv, user_data) y = convert(Vector, y_nv) ydot = convert(Vector, ydot_nv) - ydot[1] = -0.04*y[1] + 1.0e4*y[2]*y[3] - ydot[3] = 3.0e7*y[2]*y[2] + ydot[1] = -0.04 * y[1] + 1.0e4 * y[2] * y[3] + ydot[3] = 3.0e7 * y[2] * y[2] ydot[2] = -ydot[1] - ydot[3] return Sundials.ARK_SUCCESS end @@ -33,8 +33,8 @@ Sundials.@checkflag Sundials.ARKStepSetMaxNumSteps(arkStep_mem, 100000) Sundials.@checkflag Sundials.ARKStepSetPredictorMethod(arkStep_mem, 1) Sundials.@checkflag Sundials.ARKStepSStolerances(arkStep_mem, reltol, abstol) -A = Sundials.SUNDenseMatrix(neq,neq) -LS = Sundials.SUNLinSol_Dense(y0,A) +A = Sundials.SUNDenseMatrix(neq, neq) +LS = Sundials.SUNLinSol_Dense(y0, A) Sundials.@checkflag Sundials.ARKStepSetLinearSolver(arkStep_mem, LS, A) iout = 0 diff --git a/test/common_interface/arkode.jl b/test/common_interface/arkode.jl index 902a8b4..626c64f 100644 --- a/test/common_interface/arkode.jl +++ b/test/common_interface/arkode.jl @@ -1,28 +1,37 @@ using Sundials, Test using DiffEqProblemLibrary -using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems; importodeproblems() +using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems; +importodeproblems(); import DiffEqProblemLibrary.ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear prob = prob_ode_linear -dt = 1//2^(4) -sol = solve(prob,ARKODE()) +dt = 1 // 2^(4) +sol = solve(prob, ARKODE()) @test sol.errors[:l2] < 5e-2 -sol = solve(prob,ARKODE(),abstol=[1e-7]) +sol = solve(prob, ARKODE(), abstol = [1e-7]) -f1 = (du,u,p,t) -> du .= u -f2 = (du,u,p,t) -> du .= u +f1 = (du, u, p, t) -> du .= u +f2 = (du, u, p, t) -> du .= u prob = prob_ode_2Dlinear -dt = 1//2^(4) -sol = solve(prob,ARKODE(linear_solver=:LapackDense)) +dt = 1 // 2^(4) +sol = solve(prob, ARKODE(linear_solver = :LapackDense)) -prob = SplitODEProblem(SplitFunction(f1,f2,analytic=(u0,p,t)->exp(2t)*u0), - rand(4,2),(0.0,1.0)) +prob = SplitODEProblem( + SplitFunction(f1, f2, analytic = (u0, p, t) -> exp(2t) * u0), + rand(4, 2), + (0.0, 1.0), +) -sol = solve(prob,ARKODE(linear_solver=:Dense)) +sol = solve(prob, ARKODE(linear_solver = :Dense)) @test sol.errors[:l2] < 1e-2 -sol = solve(prob,ARKODE(linear_solver=:LapackBand,jac_upper=3,jac_lower=3),reltol=1e-12,abstol=1e-12) +sol = solve( + prob, + ARKODE(linear_solver = :LapackBand, jac_upper = 3, jac_lower = 3), + reltol = 1e-12, + abstol = 1e-12, +) @test sol.errors[:l2] < 1e-6 # @@ -31,26 +40,28 @@ sol = solve(prob,ARKODE(linear_solver=:LapackBand,jac_upper=3,jac_lower=3),relto # ARKStepSetERKTableNum not defined # # Function -function Eq_Dif(dq,q,t) - dq .= 10*q +function Eq_Dif(dq, q, t) + dq .= 10 * q end # Alias -fn(dq,q,p,t) = Eq_Dif(dq,q,t) +fn(dq, q, p, t) = Eq_Dif(dq, q, t) # Time span tspan = (0.0, 1.0) # Initial values q = zeros(10) # Define problem -prob = ODEProblem(fn,q,tspan) +prob = ODEProblem(fn, q, tspan) # Define solution method -method = ARKODE(Sundials.Explicit(), - etable = Sundials.VERNER_8_5_6, - order = 8, - set_optimal_params = false, - max_hnil_warns = 10, - max_error_test_failures = 7, - max_nonlinear_iters = 4, - max_convergence_failures = 10) +method = ARKODE( + Sundials.Explicit(), + etable = Sundials.VERNER_8_5_6, + order = 8, + set_optimal_params = false, + max_hnil_warns = 10, + max_error_test_failures = 7, + max_nonlinear_iters = 4, + max_convergence_failures = 10, +) # Solve -sol = solve(prob,method) +sol = solve(prob, method) @test sol.retcode == :Success diff --git a/test/common_interface/callbacks.jl b/test/common_interface/callbacks.jl index bc95c1b..d56664d 100644 --- a/test/common_interface/callbacks.jl +++ b/test/common_interface/callbacks.jl @@ -1,65 +1,65 @@ using Sundials, Test callback_f = function (du, u, p, t) - du[1] = u[2] - du[2] = -9.81 + du[1] = u[2] + du[2] = -9.81 end -condtion= function (u,t,integrator) # Event when event_f(u,t,k) == 0 - u[1] +condtion = function (u, t, integrator) # Event when event_f(u,t,k) == 0 + u[1] end affect! = nothing affect_neg! = function (integrator) - integrator.u[2] = -integrator.u[2] + integrator.u[2] = -integrator.u[2] end -callback = ContinuousCallback(condtion,affect!,affect_neg!) +callback = ContinuousCallback(condtion, affect!, affect_neg!) -u0 = [50.0,0.0] -tspan = (0.0,15.0) -prob = ODEProblem(callback_f,u0,tspan) +u0 = [50.0, 0.0] +tspan = (0.0, 15.0) +prob = ODEProblem(callback_f, u0, tspan) -sol = solve(prob,CVODE_Adams(),callback=callback) +sol = solve(prob, CVODE_Adams(), callback = callback) @test sol(4.0)[1] > 0 -sol = solve(prob,CVODE_BDF(),callback=callback) +sol = solve(prob, CVODE_BDF(), callback = callback) @test sol(4.0)[1] > 0 -condition = function (out,u,t,integrator) - out[1] = u[1] +condition = function (out, u, t, integrator) + out[1] = u[1] end affect! = nothing affect_neg! = function (integrator, idx) - if idx == 1 - integrator.u[2] = -integrator.u[2] - end + if idx == 1 + integrator.u[2] = -integrator.u[2] + end end callback = VectorContinuousCallback(condition, affect!, affect_neg!, 1) -sol = solve(prob,CVODE_Adams(),callback=callback) +sol = solve(prob, CVODE_Adams(), callback = callback) @test sol(4.0)[1] > 0 -sol = solve(prob,CVODE_BDF(),callback=callback) +sol = solve(prob, CVODE_BDF(), callback = callback) @test sol(4.0)[1] > 0 -u0 = [1.,0.] +u0 = [1.0, 0.0] function fun2(du, u, p, t) - du[2] = -u[1] - du[1] = u[2] + du[2] = -u[1] + du[1] = u[2] end -tspan = (0.0,10.0) -prob = ODEProblem(fun2,u0,tspan) +tspan = (0.0, 10.0) +prob = ODEProblem(fun2, u0, tspan) -function condition2(u,t,integrator) - get_du(integrator)[1]>0 +function condition2(u, t, integrator) + get_du(integrator)[1] > 0 end affect2!(integrator) = terminate!(integrator) -cb = DiscreteCallback(condition2,affect2!) -sol = solve(prob,CVODE_BDF(),callback=cb) +cb = DiscreteCallback(condition2, affect2!) +sol = solve(prob, CVODE_BDF(), callback = cb) @test sol.t[end] < 3.5 -condition3(u,t,integrator) = u[2] +condition3(u, t, integrator) = u[2] affect3!(integrator) = terminate!(integrator) -cb = ContinuousCallback(condition3,affect3!) -sol = solve(prob,CVODE_Adams(),callback=cb,abstol=1e-12,reltol=1e-12) +cb = ContinuousCallback(condition3, affect3!) +sol = solve(prob, CVODE_Adams(), callback = cb, abstol = 1e-12, reltol = 1e-12) @test sol.t[end] ≈ pi diff --git a/test/common_interface/cvode.jl b/test/common_interface/cvode.jl index 029658f..31c7e2c 100644 --- a/test/common_interface/cvode.jl +++ b/test/common_interface/cvode.jl @@ -1,112 +1,119 @@ using Sundials, Test using DiffEqProblemLibrary -using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems; importodeproblems() +using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems; +importodeproblems(); import DiffEqProblemLibrary.ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear prob = prob_ode_linear -dt = 1//2^(4) +dt = 1 // 2^(4) saveat = float(collect(0:dt:1)) -sol = solve(prob,CVODE_BDF()) -sol = solve(prob,CVODE_Adams()) +sol = solve(prob, CVODE_BDF()) +sol = solve(prob, CVODE_Adams()) @test sol.errors[:l2] < 1e-3 -sol = solve(prob,CVODE_Adams(),reltol=1e-5) +sol = solve(prob, CVODE_Adams(), reltol = 1e-5) @test sol.errors[:l2] < 1e-5 -sol = solve(prob,CVODE_Adams(),reltol=1e-5,abstol=[1e-7]) -sol = solve(prob,CVODE_Adams(),saveat=saveat) +sol = solve(prob, CVODE_Adams(), reltol = 1e-5, abstol = [1e-7]) +sol = solve(prob, CVODE_Adams(), saveat = saveat) @test sol.t == saveat -sol = solve(prob,CVODE_Adams(),saveat=dt) +sol = solve(prob, CVODE_Adams(), saveat = dt) @test sol.t == saveat -sol = solve(prob,CVODE_Adams(),saveat=saveat,save_everystep=true) +sol = solve(prob, CVODE_Adams(), saveat = saveat, save_everystep = true) @test sol.t != saveat -@test intersect(sol.t,saveat) == saveat +@test intersect(sol.t, saveat) == saveat -sol = solve(prob,CVODE_Adams(),saveat=saveat,save_everystep=true,save_start=false) +sol = solve(prob, CVODE_Adams(), saveat = saveat, save_everystep = true, save_start = false) @test sol.t[1] != 0 -sol = solve(prob,CVODE_Adams(),tstops=[0.2,0.5,0.7]) -@test all(t ∈ sol.t for t in [0.2,0.5,0.7]) +sol = solve(prob, CVODE_Adams(), tstops = [0.2, 0.5, 0.7]) +@test all(t ∈ sol.t for t in [0.2, 0.5, 0.7]) prob = prob_ode_2Dlinear -sol = solve(prob,CVODE_BDF()) -sol = solve(prob,CVODE_Adams()) -sol = solve(prob,CVODE_Adams(),saveat=saveat) +sol = solve(prob, CVODE_BDF()) +sol = solve(prob, CVODE_Adams()) +sol = solve(prob, CVODE_Adams(), saveat = saveat) @test sol.t == saveat -sol = solve(prob,CVODE_Adams(),saveat=[prob.tspan[2]]) +sol = solve(prob, CVODE_Adams(), saveat = [prob.tspan[2]]) @test sol.t == [prob.tspan[2]] -sol = solve(prob,CVODE_Adams(),saveat=saveat,save_everystep=false) +sol = solve(prob, CVODE_Adams(), saveat = saveat, save_everystep = false) @test sol.t == saveat -sol = solve(prob,CVODE_Adams(),tstops=[0.9]) +sol = solve(prob, CVODE_Adams(), tstops = [0.9]) @test 0.9 ∈ sol.t # Test the other function conversions k = (du, u, p, t) -> du[1] = u[1] -prob = ODEProblem(k,[1.0],(0.0,1.0)) -sol = solve(prob,CVODE_BDF()) +prob = ODEProblem(k, [1.0], (0.0, 1.0)) +sol = solve(prob, CVODE_BDF()) h = (u, p, t) -> u -u0 = [1.0 2.0 - 3.0 2.0] -prob = ODEProblem(h,u0,(0.0,1.0)) -sol = solve(prob,CVODE_BDF()) +u0 = [ + 1.0 2.0 + 3.0 2.0 +] +prob = ODEProblem(h, u0, (0.0, 1.0)) +sol = solve(prob, CVODE_BDF()) # Test Algorithm Choices -sol1 = solve(prob,CVODE_BDF(method=:Functional)) -sol2 = solve(prob,CVODE_BDF(linear_solver=:Band,jac_upper=3,jac_lower=3)) -sol3 = solve(prob,CVODE_BDF(linear_solver=:Diagonal)) -sol4 = solve(prob,CVODE_BDF(linear_solver=:GMRES)) -sol5 = solve(prob,CVODE_BDF(linear_solver=:FGMRES)) -sol6 = solve(prob,CVODE_BDF(linear_solver=:PCG)) -sol7 = solve(prob,CVODE_BDF(linear_solver=:BCG)) -sol8 = solve(prob,CVODE_BDF(linear_solver=:TFQMR)) -sol9 = solve(prob,CVODE_BDF(linear_solver=:Dense)) +sol1 = solve(prob, CVODE_BDF(method = :Functional)) +sol2 = solve(prob, CVODE_BDF(linear_solver = :Band, jac_upper = 3, jac_lower = 3)) +sol3 = solve(prob, CVODE_BDF(linear_solver = :Diagonal)) +sol4 = solve(prob, CVODE_BDF(linear_solver = :GMRES)) +sol5 = solve(prob, CVODE_BDF(linear_solver = :FGMRES)) +sol6 = solve(prob, CVODE_BDF(linear_solver = :PCG)) +sol7 = solve(prob, CVODE_BDF(linear_solver = :BCG)) +sol8 = solve(prob, CVODE_BDF(linear_solver = :TFQMR)) +sol9 = solve(prob, CVODE_BDF(linear_solver = :Dense)) #sol9 = solve(prob,CVODE_BDF(linear_solver=:KLU)) # Requires Jacobian -sol10 = solve(prob,CVODE_BDF(linear_solver=:LapackDense)) -sol11 = solve(prob,CVODE_BDF(linear_solver=:LapackBand,jac_upper=3,jac_lower=3)) - -@test isapprox(sol1[end],sol2[end],rtol=1e-3) -@test isapprox(sol1[end],sol3[end],rtol=1e-3) -@test isapprox(sol1[end],sol4[end],rtol=1e-3) -@test isapprox(sol1[end],sol5[end],rtol=1e-3) -@test isapprox(sol1[end],sol6[end],rtol=1e-3) -@test isapprox(sol1[end],sol7[end],rtol=1e-3) -@test isapprox(sol1[end],sol8[end],rtol=1e-3) +sol10 = solve(prob, CVODE_BDF(linear_solver = :LapackDense)) +sol11 = solve(prob, CVODE_BDF(linear_solver = :LapackBand, jac_upper = 3, jac_lower = 3)) + +@test isapprox(sol1[end], sol2[end], rtol = 1e-3) +@test isapprox(sol1[end], sol3[end], rtol = 1e-3) +@test isapprox(sol1[end], sol4[end], rtol = 1e-3) +@test isapprox(sol1[end], sol5[end], rtol = 1e-3) +@test isapprox(sol1[end], sol6[end], rtol = 1e-3) +@test isapprox(sol1[end], sol7[end], rtol = 1e-3) +@test isapprox(sol1[end], sol8[end], rtol = 1e-3) #@test isapprox(sol1[end],sol9[end],rtol=1e-3) # Test identity preconditioner global prec_used = false global psetup_used = false -prec = (z,r,p,t,y,fy,gamma,delta,lr)->(global prec_used=true;z.=r) -psetup = (p,t,u,du,jok,jcurPtr,gamma) -> (global psetup_used = true; jcurPtr[]=false) -sol4 = solve(prob,CVODE_BDF(linear_solver=:GMRES, prec_side = 3, prec=prec)) -@test isapprox(sol1[end],sol4[end],rtol=1e-3) +prec = (z, r, p, t, y, fy, gamma, delta, lr) -> (global prec_used = true; z .= r) +psetup = + (p, t, u, du, jok, jcurPtr, gamma) -> (global psetup_used = true; jcurPtr[] = false) +sol4 = solve(prob, CVODE_BDF(linear_solver = :GMRES, prec_side = 3, prec = prec)) +@test isapprox(sol1[end], sol4[end], rtol = 1e-3) @test prec_used -sol4 = solve(prob,CVODE_BDF(linear_solver=:GMRES, prec_side = 3, prec=prec, psetup=psetup)) -@test isapprox(sol1[end],sol4[end],rtol=1e-3) +sol4 = solve( + prob, + CVODE_BDF(linear_solver = :GMRES, prec_side = 3, prec = prec, psetup = psetup), +) +@test isapprox(sol1[end], sol4[end], rtol = 1e-3) @test psetup_used # Backwards prob = deepcopy(prob_ode_2Dlinear) -prob2 = ODEProblem(prob.f,prob.u0,(1.0,0.0),1.01) -sol = solve(prob2,CVODE_BDF()) +prob2 = ODEProblem(prob.f, prob.u0, (1.0, 0.0), 1.01) +sol = solve(prob2, CVODE_BDF()) @test maximum(diff(sol.t)) < 0 # Make sure all go negative -number_test(u,p,t) = -u^2 + (p[1] + t + p[2])*u + p[2] +number_test(u, p, t) = -u^2 + (p[1] + t + p[2]) * u + p[2] u0 = 0.0; tspan = (0.0, 10) -prob = ODEProblem(number_test,u0,tspan,(2.0,0.01)) -sol = solve(prob,CVODE_BDF()) +prob = ODEProblem(number_test, u0, tspan, (2.0, 0.01)) +sol = solve(prob, CVODE_BDF()) -sol = solve(prob2,CVODE_BDF(), maxiters=1) +sol = solve(prob2, CVODE_BDF(), maxiters = 1) @test length(sol.t) == 2 diff --git a/test/common_interface/errors.jl b/test/common_interface/errors.jl index 0a775e9..abedabd 100644 --- a/test/common_interface/errors.jl +++ b/test/common_interface/errors.jl @@ -2,16 +2,17 @@ using Sundials, Test println("Test error handling") -f_error(u,p,t) = u/t +f_error(u, p, t) = u / t u0 = 1.0 -prob = ODEProblem(f_error,u0,(0.0,1.0)) -sol = solve(prob,CVODE_BDF()) -sol = solve(prob,CVODE_BDF(),verbose=false) +prob = ODEProblem(f_error, u0, (0.0, 1.0)) +sol = solve(prob, CVODE_BDF()) +sol = solve(prob, CVODE_BDF(), verbose = false) -f_error2(du,u,p,t) = u/t-1 -u0 = 1.0; du0 = 1.0 -prob = DAEProblem(f_error2,u0,du0,(0.0,1.0),differential_vars=[1]) -sol = solve(prob,IDA()) -sol = solve(prob,IDA(),verbose=false) +f_error2(du, u, p, t) = u / t - 1 +u0 = 1.0; +du0 = 1.0; +prob = DAEProblem(f_error2, u0, du0, (0.0, 1.0), differential_vars = [1]) +sol = solve(prob, IDA()) +sol = solve(prob, IDA(), verbose = false) @test sol.retcode == :InitialFailure diff --git a/test/common_interface/ida.jl b/test/common_interface/ida.jl index 0a38acd..69d2564 100644 --- a/test/common_interface/ida.jl +++ b/test/common_interface/ida.jl @@ -1,85 +1,81 @@ using DiffEqProblemLibrary, Sundials, Test -using DiffEqProblemLibrary.DAEProblemLibrary: importdaeproblems; importdaeproblems() +using DiffEqProblemLibrary.DAEProblemLibrary: importdaeproblems; +importdaeproblems(); using DiffEqProblemLibrary.DAEProblemLibrary: prob_dae_resrob # Test DAE prob = prob_dae_resrob dt = 1000 saveat = float(collect(0:dt:100000)) -sol = solve(prob,IDA()) +sol = solve(prob, IDA()) @info "Multiple abstol" -sol = solve(prob,IDA(),abstol=[1e-9,1e-8,1e-7]) +sol = solve(prob, IDA(), abstol = [1e-9, 1e-8, 1e-7]) @info "Band solver" -sol2 = solve(prob,IDA(linear_solver=:Band,jac_upper=2,jac_lower=2)) +sol2 = solve(prob, IDA(linear_solver = :Band, jac_upper = 2, jac_lower = 2)) @info "GMRES solver" -sol3 = solve(prob,IDA(linear_solver=:GMRES)) +sol3 = solve(prob, IDA(linear_solver = :GMRES)) #sol4 = solve(prob,IDA(linear_solver=:BCG)) # Fails but doesn't throw an error? @info "TFQMR solver" -sol5 = solve(prob,IDA(linear_solver=:TFQMR)) +sol5 = solve(prob, IDA(linear_solver = :TFQMR)) @info "FGMRES solver" -sol6 = solve(prob,IDA(linear_solver=:FGMRES)) +sol6 = solve(prob, IDA(linear_solver = :FGMRES)) @info "PCG solver" -sol7 = solve(prob,IDA(linear_solver=:PCG)) # Requires symmetric linear +sol7 = solve(prob, IDA(linear_solver = :PCG)) # Requires symmetric linear #@info "KLU solver" #sol8 = solve(prob,IDA(linear_solver=:KLU)) # Requires Jacobian -sol9 = solve(prob,IDA(linear_solver=:LapackBand,jac_upper=2,jac_lower=2)) -sol10 = solve(prob,IDA(linear_solver=:LapackDense)) -sol11 = solve(prob,IDA(linear_solver=:Dense)) +sol9 = solve(prob, IDA(linear_solver = :LapackBand, jac_upper = 2, jac_lower = 2)) +sol10 = solve(prob, IDA(linear_solver = :LapackDense)) +sol11 = solve(prob, IDA(linear_solver = :Dense)) # Test identity preconditioner global prec_used = false global psetup_used = false -prec = (z,r,p,t,y,fy,resid,gamma,delta,lr)->(global prec_used=true;z.=r) -psetup = (p,t,resid,u,du,gamma) -> (global psetup_used = true) +prec = (z, r, p, t, y, fy, resid, gamma, delta, lr) -> (global prec_used = true; z .= r) +psetup = (p, t, resid, u, du, gamma) -> (global psetup_used = true) @info "GMRES for identity preconditioner" -sol4 = solve(prob,IDA(linear_solver=:GMRES, - prec_side = 3, - prec=prec)) +sol4 = solve(prob, IDA(linear_solver = :GMRES, prec_side = 3, prec = prec)) @test prec_used @info "GMRES with pset" -sol4 = solve(prob,IDA(linear_solver=:GMRES, - prec_side = 3, - prec=prec, - psetup=psetup)) +sol4 = solve(prob, IDA(linear_solver = :GMRES, prec_side = 3, prec = prec, psetup = psetup)) @test psetup_used @info "IDA with saveat" -sol = solve(prob,IDA(),saveat=saveat) +sol = solve(prob, IDA(), saveat = saveat) @test sol.t == saveat @info "IDA with saveat everystep" -sol = solve(prob,IDA(),saveat=saveat,save_everystep=true) +sol = solve(prob, IDA(), saveat = saveat, save_everystep = true) @test sol.t != saveat -@test intersect(sol.t,saveat) == saveat +@test intersect(sol.t, saveat) == saveat @info "IDA with tstops" -sol = solve(prob,IDA(),tstops=[0.9]) +sol = solve(prob, IDA(), tstops = [0.9]) @test 0.9 ∈ sol.t prob = deepcopy(prob_dae_resrob) -prob2 = DAEProblem(prob.f,prob.du0,prob.u0,(1.0,0.0)) -sol = solve(prob2,IDA()) +prob2 = DAEProblem(prob.f, prob.du0, prob.u0, (1.0, 0.0)) +sol = solve(prob2, IDA()) @test maximum(diff(sol.t)) < 0 # Make sure all go negative -function f!(res, du, u, p ,t) - res[1] = du[1]-1.01 +function f!(res, du, u, p, t) + res[1] = du[1] - 1.01 return end -u0 = [0.] +u0 = [0.0] du0 = [1.01] -tspan = (0.0, 10.) +tspan = (0.0, 10.0) println("With consistent initial conditions:") -dae_prob = DAEProblem(f!,du0,u0,tspan, differential_vars=[true]) -sol = solve(dae_prob,IDA()) +dae_prob = DAEProblem(f!, du0, u0, tspan, differential_vars = [true]) +sol = solve(dae_prob, IDA()) println("With inconsistent initial conditions:") -du0 = [0.] +du0 = [0.0] -dae_prob = DAEProblem(f!,du0, u0,tspan, differential_vars=[true]) -sol = solve(dae_prob,IDA()) +dae_prob = DAEProblem(f!, du0, u0, tspan, differential_vars = [true]) +sol = solve(dae_prob, IDA()) diff --git a/test/common_interface/iterators.jl b/test/common_interface/iterators.jl index 31d6c58..5cf6077 100644 --- a/test/common_interface/iterators.jl +++ b/test/common_interface/iterators.jl @@ -1,11 +1,12 @@ using DiffEqProblemLibrary, Sundials, Test -using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems; importodeproblems() +using DiffEqProblemLibrary.ODEProblemLibrary: importodeproblems; +importodeproblems(); import DiffEqProblemLibrary.ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear prob = prob_ode_2Dlinear -integrator = init(prob,CVODE_BDF()) +integrator = init(prob, CVODE_BDF()) step!(integrator) -integrator(integrator.t,Val{1}) +integrator(integrator.t, Val{1}) for i in integrator @info i.t diff --git a/test/common_interface/jacobians.jl b/test/common_interface/jacobians.jl index c7f8673..fe7b085 100644 --- a/test/common_interface/jacobians.jl +++ b/test/common_interface/jacobians.jl @@ -2,41 +2,43 @@ using Sundials, Test, SparseArrays, DiffEqOperators # Test for Jacobian usage function Lotka(du, u, p, t) - du[1] = u[1] - u[1] * u[2] # REPL[7], line 3: - du[2] = -3 * u[2] + 1 * u[1] * u[2] - nothing + du[1] = u[1] - u[1] * u[2] # REPL[7], line 3: + du[2] = -3 * u[2] + 1 * u[1] * u[2] + nothing end jac_called = false -function Lotka_jac(J,u,p,t) - global jac_called - jac_called = true - J[1,1] = 1.0 - u[2] - J[1,2] = -u[1] - J[2,1] = 1 * u[2] - J[2,2] = -3 + u[1] - nothing +function Lotka_jac(J, u, p, t) + global jac_called + jac_called = true + J[1, 1] = 1.0 - u[2] + J[1, 2] = -u[1] + J[2, 1] = 1 * u[2] + J[2, 2] = -3 + u[1] + nothing end -Lotka_f = ODEFunction(Lotka,jac=Lotka_jac) -prob = ODEProblem(Lotka_f,ones(2),(0.0,10.0)) -good_sol = solve(prob,CVODE_BDF()) +Lotka_f = ODEFunction(Lotka, jac = Lotka_jac) +prob = ODEProblem(Lotka_f, ones(2), (0.0, 10.0)) +good_sol = solve(prob, CVODE_BDF()) @test jac_called == true -Lotka_f = ODEFunction(Lotka,jac=Lotka_jac, - jac_prototype = sparse([1,2,1,2],[1,1,2,2],zeros(4))) +Lotka_f = ODEFunction( + Lotka, + jac = Lotka_jac, + jac_prototype = sparse([1, 2, 1, 2], [1, 1, 2, 2], zeros(4)), +) -prob = ODEProblem(Lotka_f,ones(2),(0.0,10.0)) +prob = ODEProblem(Lotka_f, ones(2), (0.0, 10.0)) jac_called = false -sol9 = solve(prob,CVODE_BDF(linear_solver=:KLU)) +sol9 = solve(prob, CVODE_BDF(linear_solver = :KLU)) @test jac_called == true @test Array(sol9) ≈ Array(good_sol) -Lotka_fj = ODEFunction(Lotka, - jac_prototype = JacVecOperator{Float64}(Lotka,ones(2))) +Lotka_fj = ODEFunction(Lotka, jac_prototype = JacVecOperator{Float64}(Lotka, ones(2))) -prob = ODEProblem(Lotka_fj,ones(2),(0.0,10.0)) -sol9 = solve(prob,CVODE_BDF(linear_solver=:GMRES)) +prob = ODEProblem(Lotka_fj, ones(2), (0.0, 10.0)) +sol9 = solve(prob, CVODE_BDF(linear_solver = :GMRES)) function f2!(res, du, u, p, t) res[1] = 1.01du[1] @@ -50,43 +52,49 @@ function f2_jac!(out, du, u, p, gamma, t) out[1] = 1.01 end -f2_f = DAEFunction(f2!,jac=f2_jac!) +f2_f = DAEFunction(f2!, jac = f2_jac!) -u0 = [0.] -tspan = (0.0, 10.) -du0 = [0.] -dae_prob = DAEProblem(f2_f,du0, u0,tspan, differential_vars=[true]) -good_sol = solve(dae_prob,IDA()) +u0 = [0.0] +tspan = (0.0, 10.0) +du0 = [0.0] +dae_prob = DAEProblem(f2_f, du0, u0, tspan, differential_vars = [true]) +good_sol = solve(dae_prob, IDA()) @test jac_called == true function testjac(res, du, u, p, t) - res[1] = du[1] - 1.5 * u[1] + 1.0 * u[1]*u[2] - res[2] = du[2] +3 * u[2] - u[1]*u[2] + res[1] = du[1] - 1.5 * u[1] + 1.0 * u[1] * u[2] + res[2] = du[2] + 3 * u[2] - u[1] * u[2] end jac_called = false function testjac_jac(J, du, u, p, gamma, t) - global jac_called - jac_called = true - J[1,1] = gamma - 1.5 + 1.0 * u[2] - J[1,2] = 1.0 * u[1] - J[2,1] = - 1 * u[2] - J[2,2] = gamma + 3 - u[1] - nothing + global jac_called + jac_called = true + J[1, 1] = gamma - 1.5 + 1.0 * u[2] + J[1, 2] = 1.0 * u[1] + J[2, 1] = -1 * u[2] + J[2, 2] = gamma + 3 - u[1] + nothing end -testjac_f = DAEFunction(testjac,jac=testjac_jac) +testjac_f = DAEFunction(testjac, jac = testjac_jac) -prob3 = DAEProblem(testjac_f,[0.5,-2.0],ones(2),(0.0,10.0),differential_vars=[true,true]) +prob3 = DAEProblem( + testjac_f, + [0.5, -2.0], + ones(2), + (0.0, 10.0), + differential_vars = [true, true], +) sol3 = solve(prob3, IDA()) @test jac_called == true jac_called = false -prob4 = DAEProblem(testjac,[0.5,-2.0],ones(2),(0.0,10.0)) +prob4 = DAEProblem(testjac, [0.5, -2.0], ones(2), (0.0, 10.0)) sol4 = solve(prob4, IDA()) @test jac_called == false println("Jacobian vs no Jacobian difference:") -println(maximum(sol3-sol4)) -@test maximum(sol3-sol4) < 1e-6 +println(maximum(sol3 - sol4)) +@test maximum(sol3 - sol4) < 1e-6 diff --git a/test/common_interface/mass_matrix.jl b/test/common_interface/mass_matrix.jl index d47d44a..8132d92 100644 --- a/test/common_interface/mass_matrix.jl +++ b/test/common_interface/mass_matrix.jl @@ -1,35 +1,39 @@ -using Sundials, Test, LinearAlgebra - -# create mass matrix problems -function make_mm_probs(mm_A, ::Type{Val{iip}}) where iip - # iip - mm_b = vec(sum(mm_A; dims=2)) - function mm_f(du,u,p,t) - LinearAlgebra.mul!(du,mm_A,u) - du .+= t * mm_b - nothing - end - mm_g(du,u,p,t) = (@. du = u + t; nothing) - - # oop - mm_f(u,p,t) = mm_A * (u .+ t) - mm_g(u,p,t) = u .+ t - - mm_analytic(u0, p, t) = @. 2 * u0 * exp(t) - t - 1 - - u0 = ones(3) - tspan = (0.0, 1.0) - - prob = ODEProblem(ODEFunction{iip,true}(mm_f, analytic=mm_analytic, mass_matrix=mm_A), u0, tspan) - prob2 = ODEProblem(ODEFunction{iip,true}(mm_g, analytic=mm_analytic), u0, tspan) - - prob, prob2 -end - -mm_A = Float64[-2 1 4; 4 -2 1; 2 1 3] -prob, prob2 = make_mm_probs(mm_A, Val{true}) - -sol = solve(prob, ARKODE(), abstol=1e-8,reltol=1e-8) -sol2 = solve(prob2, ARKODE(), abstol=1e-8,reltol=1e-8) - -@test norm(sol .- sol2) ≈ 0 atol=1e-7 +using Sundials, Test, LinearAlgebra + +# create mass matrix problems +function make_mm_probs(mm_A, ::Type{Val{iip}}) where {iip} + # iip + mm_b = vec(sum(mm_A; dims = 2)) + function mm_f(du, u, p, t) + LinearAlgebra.mul!(du, mm_A, u) + du .+= t * mm_b + nothing + end + mm_g(du, u, p, t) = (@. du = u + t; nothing) + + # oop + mm_f(u, p, t) = mm_A * (u .+ t) + mm_g(u, p, t) = u .+ t + + mm_analytic(u0, p, t) = @. 2 * u0 * exp(t) - t - 1 + + u0 = ones(3) + tspan = (0.0, 1.0) + + prob = ODEProblem( + ODEFunction{iip, true}(mm_f, analytic = mm_analytic, mass_matrix = mm_A), + u0, + tspan, + ) + prob2 = ODEProblem(ODEFunction{iip, true}(mm_g, analytic = mm_analytic), u0, tspan) + + prob, prob2 +end + +mm_A = Float64[-2 1 4; 4 -2 1; 2 1 3] +prob, prob2 = make_mm_probs(mm_A, Val{true}) + +sol = solve(prob, ARKODE(), abstol = 1e-8, reltol = 1e-8) +sol2 = solve(prob2, ARKODE(), abstol = 1e-8, reltol = 1e-8) + +@test norm(sol .- sol2) ≈ 0 atol = 1e-7 diff --git a/test/cvode_Roberts_dns.jl b/test/cvode_Roberts_dns.jl index 954009e..de1c70a 100644 --- a/test/cvode_Roberts_dns.jl +++ b/test/cvode_Roberts_dns.jl @@ -5,13 +5,12 @@ using Sundials function f(t, y_nv, ydot_nv, user_data) y = convert(Vector, y_nv) ydot = convert(Vector, ydot_nv) - ydot[1] = -0.04*y[1] + 1.0e4*y[2]*y[3] - ydot[3] = 3.0e7*y[2]*y[2] + ydot[1] = -0.04 * y[1] + 1.0e4 * y[2] * y[3] + ydot[3] = 3.0e7 * y[2] * y[2] ydot[2] = -ydot[1] - ydot[3] return Sundials.CV_SUCCESS end - ## g routine. Compute functions g_i(t,y) for i = 0,1. function g(t, y_nv, gout_ptr, user_data) @@ -26,17 +25,22 @@ end # broken -- needs a wrapper from Sundials._DlsMat to Matrix and Jac user function wrapper function Jac(N, t, ny, fy, Jptr, user_data, tmp1, tmp2, tmp3) y = convert(Vector, ny) - dlsmat = unpack(IOString(unsafe_wrap(convert(Ptr{UInt8}, Jptr), - (sum(map(sizeof, Sundials._DlsMat))+10,), false)), - Sundials._DlsMat) + dlsmat = unpack( + IOString(unsafe_wrap( + convert(Ptr{UInt8}, Jptr), + (sum(map(sizeof, Sundials._DlsMat)) + 10,), + false, + )), + Sundials._DlsMat, + ) J = unsafe_wrap(unsafe_ref(dlsmat.cols), (Int(neq), Int(neq)), false) - J[1,1] = -0.04 - J[1,2] = 1.0e4*y[3] - J[1,3] = 1.0e4*y[2] - J[2,1] = 0.04 - J[2,2] = -1.0e4*y[3] - 6.0e7*y[2] - J[2,3] = -1.0e4*y[2] - J[3,2] = 6.0e7*y[2] + J[1, 1] = -0.04 + J[1, 2] = 1.0e4 * y[3] + J[1, 3] = 1.0e4 * y[2] + J[2, 1] = 0.04 + J[2, 2] = -1.0e4 * y[3] - 6.0e7 * y[2] + J[2, 3] = -1.0e4 * y[2] + J[3, 2] = 6.0e7 * y[2] return Sundials.CV_SUCCESS end @@ -55,21 +59,22 @@ cvode_mem = Sundials.Handle(mem_ptr) userfun = Sundials.UserFunctionAndData(f, userdata) Sundials.CVodeSetUserData(cvode_mem, userfun) -function getcfunrob(userfun::T) where T - @cfunction(Sundials.cvodefun, - Cint, (Sundials.realtype, Sundials.N_Vector, - Sundials.N_Vector, Ref{T})) +function getcfunrob(userfun::T) where {T} + @cfunction( + Sundials.cvodefun, + Cint, + (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Ref{T}) + ) end -Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t1, - convert(Sundials.N_Vector, y0)) +Sundials.CVodeInit(cvode_mem, getcfunrob(userfun), t1, convert(Sundials.N_Vector, y0)) Sundials.@checkflag Sundials.CVodeInit(cvode_mem, f, t0, y0) Sundials.@checkflag Sundials.CVodeSVtolerances(cvode_mem, reltol, abstol) Sundials.@checkflag Sundials.CVodeRootInit(cvode_mem, 2, g) -A = Sundials.SUNDenseMatrix(neq,neq) -mat_handle = Sundials.MatrixHandle(A,Sundials.DenseMatrix()) -LS = Sundials.SUNLinSol_Dense(convert(Sundials.N_Vector,y0),A) -LS_handle = Sundials.LinSolHandle(LS,Sundials.Dense()) +A = Sundials.SUNDenseMatrix(neq, neq) +mat_handle = Sundials.MatrixHandle(A, Sundials.DenseMatrix()) +LS = Sundials.SUNLinSol_Dense(convert(Sundials.N_Vector, y0), A) +LS_handle = Sundials.LinSolHandle(LS, Sundials.Dense()) Sundials.@checkflag Sundials.CVDlsSetLinearSolver(cvode_mem, LS, A) #Sundials.@checkflag Sundials.CVDlsSetDenseJacFn(cvode_mem, Jac) @@ -86,8 +91,8 @@ while iout < nout Sundials.@checkflag Sundials.CVodeGetRootInfo(cvode_mem, rootsfound) println("roots=", rootsfound) elseif flag == Sundials.CV_SUCCESS - global iout += 1 - global tout *= tmult + global iout += 1 + global tout *= tmult end end diff --git a/test/cvode_Roberts_simplified.jl b/test/cvode_Roberts_simplified.jl index a1feee9..5581100 100644 --- a/test/cvode_Roberts_simplified.jl +++ b/test/cvode_Roberts_simplified.jl @@ -3,12 +3,12 @@ using Sundials ## f routine. Compute function f(t,y). function f(t, y, ydot) - ydot[1] = -0.04*y[1] + 1.0e4*y[2]*y[3] - ydot[3] = 3.0e7*y[2]*y[2] + ydot[1] = -0.04 * y[1] + 1.0e4 * y[2] * y[3] + ydot[3] = 3.0e7 * y[2] * y[2] ydot[2] = -ydot[1] - ydot[3] return Sundials.CV_SUCCESS end -t = [0.0; 4 * exp10.(range(-1., stop=7., length=9))] +t = [0.0; 4 * exp10.(range(-1.0, stop = 7.0, length = 9))] y0 = [1.0, 0.0, 0.0] res = Sundials.cvode(f, y0, t) diff --git a/test/cvodes_dns.jl b/test/cvodes_dns.jl index b734bc1..962964a 100644 --- a/test/cvodes_dns.jl +++ b/test/cvodes_dns.jl @@ -3,38 +3,38 @@ using Sundials: N_Vector, N_Vector_S using LinearAlgebra function mycopy!(pp, arr::Matrix) - nj = size(arr,2) - ps = unsafe_wrap(Array, pp, nj) - for j = 1:nj - arr[:,j] = Sundials.asarray(ps[j]) - end - arr + nj = size(arr, 2) + ps = unsafe_wrap(Array, pp, nj) + for j in 1:nj + arr[:, j] = Sundials.asarray(ps[j]) + end + arr end function mycopy!(arr::Matrix, pp) - nj = size(arr,2) - ps = unsafe_wrap(Array, pp, nj) - for j = 1:nj - Sundials.asarray(ps[j])[:] = arr[:,j] - end + nj = size(arr, 2) + ps = unsafe_wrap(Array, pp, nj) + for j in 1:nj + Sundials.asarray(ps[j])[:] = arr[:, j] + end end -f!(dy,t,y,p) = (dy[:]=y.*p) - -function srhs(t,y,ydot,ys,ysdot) - n = length(y) - np = 2 - dyt = similar(y) - chunk = min(n, 8) - c1 = ForwardDiff.JacobianConfig(nothing, dyt, dyt, ForwardDiff.Chunk{chunk}()) - c2 = ForwardDiff.JacobianConfig(nothing, dyt, p, ForwardDiff.Chunk{chunk}()) - jac = ForwardDiff.jacobian((dy,y)->f!(dy,t,y,p), dyt, y, c1, Val{false}()) - #jac = ReverseDiff.jacobian!(t1, y) - ysdot[:] = jac * ys - - jac = ForwardDiff.jacobian((dy,p)->f!(dy,t,y,p), dyt, p, c2, Val{false}()) - #jac = ReverseDiff.jacobian!(t2, p) - ysdot[:, 1:np] += jac +f!(dy, t, y, p) = (dy[:] = y .* p) + +function srhs(t, y, ydot, ys, ysdot) + n = length(y) + np = 2 + dyt = similar(y) + chunk = min(n, 8) + c1 = ForwardDiff.JacobianConfig(nothing, dyt, dyt, ForwardDiff.Chunk{chunk}()) + c2 = ForwardDiff.JacobianConfig(nothing, dyt, p, ForwardDiff.Chunk{chunk}()) + jac = ForwardDiff.jacobian((dy, y) -> f!(dy, t, y, p), dyt, y, c1, Val{false}()) + #jac = ReverseDiff.jacobian!(t1, y) + ysdot[:] = jac * ys + + jac = ForwardDiff.jacobian((dy, p) -> f!(dy, t, y, p), dyt, p, c2, Val{false}()) + #jac = ReverseDiff.jacobian!(t2, p) + ysdot[:, 1:np] += jac end """ @@ -45,100 +45,134 @@ y[i, t] is the solutions component i at timestep t. ys[i, j, t] is the i-th component sensivity wrt the j-th parameter at timestep t. ys[i, np+j, t] the i-th component sensivity wrt the j-th initial condition value. """ -function sens(f!::Function, t0::Float64, y0::Vector{Float64}, p::Vector{Float64}, tout::Vector{Float64}; reltol::Float64 = 1e-5, abstol::Float64 = 1e-5) - n = length(y0) - np = length(p) - ys0 = zeros(n,np.+n) - ys0[:, np.+(1:n)] = Matrix(1.0I,n,n) - - #t1 = ReverseDiff.JacobianTape((dy,y)->f!(dy,0,y,p), dyt, y0) - #t2 = ReverseDiff.JacobianTape((dy,p)->f!(dy,0,y0,p), dyt, p) - - pbar = abs.(vcat(p, y0)) - y, ys = cvodes(f!, srhs, t0, y0, ys0, p, reltol, abstol, pbar, tout) +function sens( + f!::Function, + t0::Float64, + y0::Vector{Float64}, + p::Vector{Float64}, + tout::Vector{Float64}; + reltol::Float64 = 1e-5, + abstol::Float64 = 1e-5, +) + n = length(y0) + np = length(p) + ys0 = zeros(n, np .+ n) + ys0[:, np .+ (1:n)] = Matrix(1.0I, n, n) + + #t1 = ReverseDiff.JacobianTape((dy,y)->f!(dy,0,y,p), dyt, y0) + #t2 = ReverseDiff.JacobianTape((dy,p)->f!(dy,0,y0,p), dyt, p) + + pbar = abs.(vcat(p, y0)) + y, ys = cvodes(f!, srhs, t0, y0, ys0, p, reltol, abstol, pbar, tout) end - ### internals ## data structure dealing with the sundials callbacks struct CVSData - f # f(t,y,dy) - fs # fs() - p - jys - jdys + f::Any # f(t,y,dy) + fs::Any # fs() + p::Any + jys::Any + jdys::Any end -CVSData(f, fs, p, n::Int, nS::Int) = CVSData(f, fs, p, Array{Float64}(undef, n, nS), Array{Float64}(undef, n, nS)) +CVSData(f, fs, p, n::Int, nS::Int) = + CVSData(f, fs, p, Array{Float64}(undef, n, nS), Array{Float64}(undef, n, nS)) function cvrhsfn(t::Float64, y::N_Vector, dy::N_Vector, data::CVSData) - data.f(convert(Vector,dy), t, convert(Vector,y), data.p) + data.f(convert(Vector, dy), t, convert(Vector, y), data.p) return Sundials.CV_SUCCESS end -function cvsensrhsfn(ns::Cint, t::Float64, y::N_Vector, dy::N_Vector, ys::N_Vector_S, dys::N_Vector_S, data::CVSData, tmp1::N_Vector, tmp2::N_Vector) +function cvsensrhsfn( + ns::Cint, + t::Float64, + y::N_Vector, + dy::N_Vector, + ys::N_Vector_S, + dys::N_Vector_S, + data::CVSData, + tmp1::N_Vector, + tmp2::N_Vector, +) jys = data.jys jdys = data.jdys mycopy!(ys, data.jys) - data.fs(t, convert(Vector,y), convert(Vector,dy), jys, jdys) + data.fs(t, convert(Vector, y), convert(Vector, dy), jys, jdys) mycopy!(jdys, dys) return Sundials.CV_SUCCESS end - ## cvodes wrapper "Given the sensivity problem, return (y,ys) where y[i,t] is the solutions i-th componnent for timestep t and ys[i,j,t] is the sensivity of the i-th component wrt to the j-th paramater, where the last parameter indices correspond to the initial conditions components." -function cvodes(f,fS, t0, y0, yS0, p, reltol, abstol, pbar, t::AbstractVector) - N, Ns = size(yS0) - y = zeros(N, length(t)) - ys = zeros(N, Ns, length(t)) - tret = [t0] - yret = similar(y0) - ysret = similar(yS0) - yS0n = [Sundials.NVector(yS0[:,j]) for j=1:Ns] - yS0nv = [N_Vector(n) for n in yS0n] - pyS0 = pointer(yS0nv) - crhs = Sundials.@cfunction(cvrhsfn, Cint, (Sundials.realtype, N_Vector, N_Vector, Ref{CVSData})) - csensrhs = Sundials.@cfunction(cvsensrhsfn, Cint, (Cint, Sundials.realtype, N_Vector, N_Vector, N_Vector_S, N_Vector_S, Ref{CVSData}, N_Vector, N_Vector)) - - ## - - mem_ptr = Sundials.CVodeCreate(Sundials.CV_ADAMS) - #mem_ptr = Sundials.CVodeCreate(Sundials.CV_BDF) - cvode_mem = Sundials.Handle(mem_ptr) - Sundials.CVodeSetUserData(cvode_mem, CVSData(f, fS, p, size(yS0)...)) - - Sundials.CVodeInit(cvode_mem, crhs, t0, convert(N_Vector, y0)) - Sundials.CVodeSStolerances(cvode_mem, reltol, abstol) - - Sundials.CVodeSensInit(cvode_mem, Ns, Sundials.CV_STAGGERED, csensrhs, pyS0) - Sundials.CVodeSetSensParams(cvode_mem, C_NULL, pbar, C_NULL) - Sundials.CVodeSensEEtolerances(cvode_mem) - for i in 1:length(t) - @info "here1" - Sundials.CVode(cvode_mem, t[i], yret, tret, Sundials.CV_NORMAL) - @info "here2" - Sundials.CVodeGetSens(cvode_mem, tret, pyS0) - @info "here3" - mycopy!(pyS0, ysret) - y[:,i] = yret - ys[:,:,i] = ysret - end - empty!(cvode_mem) - y, ys +function cvodes(f, fS, t0, y0, yS0, p, reltol, abstol, pbar, t::AbstractVector) + N, Ns = size(yS0) + y = zeros(N, length(t)) + ys = zeros(N, Ns, length(t)) + tret = [t0] + yret = similar(y0) + ysret = similar(yS0) + yS0n = [Sundials.NVector(yS0[:, j]) for j in 1:Ns] + yS0nv = [N_Vector(n) for n in yS0n] + pyS0 = pointer(yS0nv) + crhs = Sundials.@cfunction( + cvrhsfn, + Cint, + (Sundials.realtype, N_Vector, N_Vector, Ref{CVSData}) + ) + csensrhs = Sundials.@cfunction( + cvsensrhsfn, + Cint, + ( + Cint, + Sundials.realtype, + N_Vector, + N_Vector, + N_Vector_S, + N_Vector_S, + Ref{CVSData}, + N_Vector, + N_Vector, + ) + ) + + ## + + mem_ptr = Sundials.CVodeCreate(Sundials.CV_ADAMS) + #mem_ptr = Sundials.CVodeCreate(Sundials.CV_BDF) + cvode_mem = Sundials.Handle(mem_ptr) + Sundials.CVodeSetUserData(cvode_mem, CVSData(f, fS, p, size(yS0)...)) + + Sundials.CVodeInit(cvode_mem, crhs, t0, convert(N_Vector, y0)) + Sundials.CVodeSStolerances(cvode_mem, reltol, abstol) + + Sundials.CVodeSensInit(cvode_mem, Ns, Sundials.CV_STAGGERED, csensrhs, pyS0) + Sundials.CVodeSetSensParams(cvode_mem, C_NULL, pbar, C_NULL) + Sundials.CVodeSensEEtolerances(cvode_mem) + for i in 1:length(t) + @info "here1" + Sundials.CVode(cvode_mem, t[i], yret, tret, Sundials.CV_NORMAL) + @info "here2" + Sundials.CVodeGetSens(cvode_mem, tret, pyS0) + @info "here3" + mycopy!(pyS0, ysret) + y[:, i] = yret + ys[:, :, i] = ysret + end + empty!(cvode_mem) + y, ys end - -t0 = 0. -t = [1., 2.] -y0 = [1., 2.] -p = [3., 4.] +t0 = 0.0 +t = [1.0, 2.0] +y0 = [1.0, 2.0] +p = [3.0, 4.0] y, ys = sens(f!, t0, y0, p, t) -@test isapprox(y[1,1], 20.0856, rtol=1e-3) -@test isapprox(ys[2,2,2], 11924.3, rtol=1e-3) # todo: check if these are indeed the right results +@test isapprox(y[1, 1], 20.0856, rtol = 1e-3) +@test isapprox(ys[2, 2, 2], 11924.3, rtol = 1e-3) # todo: check if these are indeed the right results diff --git a/test/erkstep_nonlin.jl b/test/erkstep_nonlin.jl index 07d4370..475b174 100644 --- a/test/erkstep_nonlin.jl +++ b/test/erkstep_nonlin.jl @@ -41,7 +41,7 @@ y0 = [0.0] function f(t, y, ydot, user_data) y = convert(Vector, y) ydot = convert(Vector, ydot) - ydot[1] = (t+1.0)*exp(-1*y[1]) + ydot[1] = (t + 1.0) * exp(-1 * y[1]) return Sundials.ARK_SUCCESS end @@ -61,9 +61,9 @@ while (tf - t[1] > 1e-15) end t = 0.0:1:10 -y_analytic = log.((0.5*t.^2 .+ t .+ 1)) +y_analytic = log.((0.5 * t .^ 2 .+ t .+ 1)) for i in 1:length(t) - @test isapprox(y_analytic[1], res[1]; atol= 1e-3) + @test isapprox(y_analytic[1], res[1]; atol = 1e-3) end y = nothing temp = Ref(Clong(-1)) diff --git a/test/handle_tests.jl b/test/handle_tests.jl index 623602e..a4bf137 100644 --- a/test/handle_tests.jl +++ b/test/handle_tests.jl @@ -17,32 +17,38 @@ h = Sundials.Handle(h1.ptr_ref[]) # Check construction with null pointers @test isempty(h) neq = 3 -h3 = Sundials.MatrixHandle(Sundials.SUNDenseMatrix(neq,neq),Sundials.DenseMatrix()) -h3 = Sundials.MatrixHandle(Sundials.SUNDenseMatrix(neq,neq),Sundials.DenseMatrix()) +h3 = Sundials.MatrixHandle(Sundials.SUNDenseMatrix(neq, neq), Sundials.DenseMatrix()) +h3 = Sundials.MatrixHandle(Sundials.SUNDenseMatrix(neq, neq), Sundials.DenseMatrix()) empty!(h3) @test isempty(h3) empty!(h3) @test isempty(h3) -h3 = Sundials.MatrixHandle(Sundials.SUNBandMatrix(100,3,3),Sundials.BandMatrix()) -h3 = Sundials.MatrixHandle(Sundials.SUNBandMatrix(100,3,3),Sundials.BandMatrix()) +h3 = Sundials.MatrixHandle(Sundials.SUNBandMatrix(100, 3, 3), Sundials.BandMatrix()) +h3 = Sundials.MatrixHandle(Sundials.SUNBandMatrix(100, 3, 3), Sundials.BandMatrix()) empty!(h3) @test isempty(h3) empty!(h3) @test isempty(h3) -h3 = Sundials.MatrixHandle(Sundials.SUNSparseMatrix(neq,neq,neq,Sundials.CSC_MAT),Sundials.SparseMatrix()) -h3 = Sundials.MatrixHandle(Sundials.SUNSparseMatrix(neq,neq,neq,Sundials.CSC_MAT),Sundials.SparseMatrix()) +h3 = Sundials.MatrixHandle( + Sundials.SUNSparseMatrix(neq, neq, neq, Sundials.CSC_MAT), + Sundials.SparseMatrix(), +) +h3 = Sundials.MatrixHandle( + Sundials.SUNSparseMatrix(neq, neq, neq, Sundials.CSC_MAT), + Sundials.SparseMatrix(), +) empty!(h3) @test isempty(h3) empty!(h3) @test isempty(h3) -A = Sundials.SUNDenseMatrix(neq,neq) +A = Sundials.SUNDenseMatrix(neq, neq) u0 = rand(neq) -Sundials.SUNLinSol_Dense(u0,A) -h3 = Sundials.LinSolHandle(Sundials.SUNLinSol_Dense(u0,A),Sundials.Dense()) -h3 = Sundials.LinSolHandle(Sundials.SUNLinSol_Dense(u0,A),Sundials.Dense()) +Sundials.SUNLinSol_Dense(u0, A) +h3 = Sundials.LinSolHandle(Sundials.SUNLinSol_Dense(u0, A), Sundials.Dense()) +h3 = Sundials.LinSolHandle(Sundials.SUNLinSol_Dense(u0, A), Sundials.Dense()) empty!(h3) @test isempty(h3) empty!(h3) diff --git a/test/ida_Heat2D.jl b/test/ida_Heat2D.jl index 4766d90..4157e7f 100644 --- a/test/ida_Heat2D.jl +++ b/test/ida_Heat2D.jl @@ -26,9 +26,9 @@ using Sundials ## MGRID = 10 -NEQ = MGRID*MGRID +NEQ = MGRID * MGRID -dx = 1.0/(MGRID - 1.0) +dx = 1.0 / (MGRID - 1.0) coeff = 1.0 / (dx * dx) bval = 0.1 ## @@ -45,32 +45,32 @@ function heatres(t, u, up, r) r[:] = u ## Initialize r to u, to take care of boundary equations. ## Loop over interior points; set res = up - (central difference). - for j = 2:(MGRID-2) + for j in 2:(MGRID - 2) offset = MGRID * j - for i = 2:(MGRID-2) + for i in 2:(MGRID - 2) loc = offset + i - r[loc] = up[loc] - coeff * (u[loc-1] + u[loc+1] + - u[loc-MGRID] + u[loc+MGRID] - - 4.0 * u[loc]) + r[loc] = + up[loc] - + coeff * + (u[loc - 1] + u[loc + 1] + u[loc - MGRID] + u[loc + MGRID] - 4.0 * u[loc]) end end return Sundials.CV_SUCCESS end - function initial() mm = MGRID mm1 = mm - 1 - u = zeros(NEQ) + u = zeros(NEQ) id = ones(NEQ) ## initialize u on all grid points - for j = 1:mm-1 + for j in 1:(mm - 1) yfact = dx * j - offset = mm*j - for i = 1:mm-1 + offset = mm * j + for i in 1:(mm - 1) xfact = dx * i loc = offset + i u[loc] = 48.0 * xfact * (1.0 - xfact) * yfact * (1.0 - yfact) @@ -78,7 +78,7 @@ function initial() end up = zeros(NEQ) - r = zeros(NEQ) + r = zeros(NEQ) heatres(0.0, u, up, r) @@ -86,10 +86,10 @@ function initial() up[:] = -1.0 * r ## Finally, set values of u, up, and id at boundary points. - for j = 1:mm-1 + for j in 1:(mm - 1) offset = mm * j - for i = 1:mm-1 - loc = offset + i; + for i in 1:(mm - 1) + loc = offset + i if j == 1 || j == mm1 || i == 1 || i == mm1 u[loc] = bval up[loc] = 0 @@ -99,26 +99,45 @@ function initial() end constraints = ones(NEQ) - return (u,up,id,constraints) + return (u, up, id, constraints) end -function idabandsol(f::Function, y0::Vector{Float64}, yp0::Vector{Float64}, - id::Vector{Float64}, constraints::Vector{Float64}, - t::Vector{Float64}; - reltol::Float64=1e-4, abstol::Float64=1e-6) +function idabandsol( + f::Function, + y0::Vector{Float64}, + yp0::Vector{Float64}, + id::Vector{Float64}, + constraints::Vector{Float64}, + t::Vector{Float64}; + reltol::Float64 = 1e-4, + abstol::Float64 = 1e-6, +) neq = length(y0) mem = Sundials.IDACreate() - Sundials.@checkflag Sundials.IDAInit(mem, @cfunction( - Sundials.idasolfun, Cint, - (Sundials.realtype, Sundials.N_Vector, Sundials.N_Vector, Sundials.N_Vector, Ref{Function})), - t[1], y0, yp0) + Sundials.@checkflag Sundials.IDAInit( + mem, + @cfunction( + Sundials.idasolfun, + Cint, + ( + Sundials.realtype, + Sundials.N_Vector, + Sundials.N_Vector, + Sundials.N_Vector, + Ref{Function}, + ) + ), + t[1], + y0, + yp0, + ) Sundials.@checkflag Sundials.IDASetId(mem, id) Sundials.@checkflag Sundials.IDASetConstraints(mem, constraints) Sundials.@checkflag Sundials.IDASetUserData(mem, f) Sundials.@checkflag Sundials.IDASStolerances(mem, reltol, abstol) A = Sundials.SUNBandMatrix(neq, MGRID, MGRID)#,2MGRID) - LS = Sundials.SUNLinSol_Band(y0,A) + LS = Sundials.SUNLinSol_Band(y0, A) Sundials.@checkflag Sundials.IDADlsSetLinearSolver(mem, LS, A) rtest = zeros(neq) @@ -144,8 +163,7 @@ end nsteps = 10 tstep = 0.005 -t = collect(0.0:tstep:(tstep*nsteps)) +t = collect(0.0:tstep:(tstep * nsteps)) u0, up0, id, constraints = initial() -idabandsol(heatres, u0, up0, id, constraints, map(x -> x, t), - reltol = 0.0, abstol = 1e-3) +idabandsol(heatres, u0, up0, id, constraints, map(x -> x, t), reltol = 0.0, abstol = 1e-3) diff --git a/test/ida_Roberts_dns.jl b/test/ida_Roberts_dns.jl index 9c80b11..391aba5 100644 --- a/test/ida_Roberts_dns.jl +++ b/test/ida_Roberts_dns.jl @@ -1,7 +1,6 @@ ## Adapted from doc/libsundials-serial-dev/examples/ida/serial/idaRoberts_dns.c and ## sundialsTB/ida/examples_ser/midasRoberts_dns.m - ## /* ## * ----------------------------------------------------------------- ## * $Revision: 1.2 $ @@ -33,16 +32,15 @@ using Sundials - ## Define the system residual function. function resrob(tres, yy_nv, yp_nv, rr_nv, user_data) yy = convert(Vector, yy_nv) yp = convert(Vector, yp_nv) rr = convert(Vector, rr_nv) - rr[1] = -0.04*yy[1] + 1.0e4*yy[2]*yy[3] - rr[2] = -rr[1] - 3.0e7*yy[2]*yy[2] - yp[2] - rr[1] -= yp[1] - rr[3] = yy[1] + yy[2] + yy[3] - 1.0 + rr[1] = -0.04 * yy[1] + 1.0e4 * yy[2] * yy[3] + rr[2] = -rr[1] - 3.0e7 * yy[2] * yy[2] - yp[2] + rr[1] -= yp[1] + rr[3] = yy[1] + yy[2] + yy[3] - 1.0 return Sundials.IDA_SUCCESS end @@ -56,26 +54,25 @@ function grob(t, yy_nv, yp_nv, gout_ptr, user_data) end ## Define the Jacobian function. BROKEN - JJ is wrong -function jacrob(Neq, tt, cj, yy, yp, resvec, - JJ, user_data, tempv1, tempv2, tempv3) - JJ = pointer_to_array(convert(Ptr{Float64}, JJ), (3,3)) - JJ[1,1] = -0.04 - cj - JJ[2,1] = 0.04 - JJ[3,1] = 1.0 - JJ[1,2] = 1.0e4*yy[3] - JJ[2,2] = -1.0e4*yy[3] - 6.0e7*yy[2] - cj - JJ[3,2] = 1.0 - JJ[1,3] = 1.0e4*yy[2] - JJ[2,3] = -1.0e4*yy[2] - JJ[3,3] = 1.0 +function jacrob(Neq, tt, cj, yy, yp, resvec, JJ, user_data, tempv1, tempv2, tempv3) + JJ = pointer_to_array(convert(Ptr{Float64}, JJ), (3, 3)) + JJ[1, 1] = -0.04 - cj + JJ[2, 1] = 0.04 + JJ[3, 1] = 1.0 + JJ[1, 2] = 1.0e4 * yy[3] + JJ[2, 2] = -1.0e4 * yy[3] - 6.0e7 * yy[2] - cj + JJ[3, 2] = 1.0 + JJ[1, 3] = 1.0e4 * yy[2] + JJ[2, 3] = -1.0e4 * yy[2] + JJ[3, 3] = 1.0 return Sundials.IDA_SUCCESS end neq = 3 nout = 12 t0 = 0.0 -yy0 = [1.0,0.0,0.0] -yp0 = [-0.04,0.04,0.0] +yy0 = [1.0, 0.0, 0.0] +yp0 = [-0.04, 0.04, 0.0] rtol = 1e-4 avtol = [1e-8, 1e-14, 1e-6] tout1 = 0.4 @@ -88,8 +85,8 @@ Sundials.@checkflag Sundials.IDASVtolerances(mem, rtol, avtol) Sundials.@checkflag Sundials.IDARootInit(mem, 2, grob) ## Call IDADense and set up the linear solver. -A = Sundials.SUNDenseMatrix(length(y0),length(y0)) -LS = Sundials.SUNLinSol_Dense(y0,A) +A = Sundials.SUNDenseMatrix(length(y0), length(y0)) +LS = Sundials.SUNLinSol_Dense(y0, A) Sundials.@checkflag Sundials.IDADlsSetLinearSolver(mem, LS, A) iout = 0 diff --git a/test/ida_Roberts_simplified.jl b/test/ida_Roberts_simplified.jl index 222ec95..22003d1 100644 --- a/test/ida_Roberts_simplified.jl +++ b/test/ida_Roberts_simplified.jl @@ -2,11 +2,11 @@ using Sundials ## Define the system residual function. function resrob(tres, y, yp, r) - r[1] = -0.04*y[1] + 1.0e4*y[2]*y[3] - r[2] = -r[1] - 3.0e7*y[2]*y[2] - yp[2] - r[1] -= yp[1] - r[3] = y[1] + y[2] + y[3] - 1.0 + r[1] = -0.04 * y[1] + 1.0e4 * y[2] * y[3] + r[2] = -r[1] - 3.0e7 * y[2] * y[2] - yp[2] + r[1] -= yp[1] + r[3] = y[1] + y[2] + y[3] - 1.0 end -t = [0.0; 4 * exp10.(range(-1., stop=5., length=7))] +t = [0.0; 4 * exp10.(range(-1.0, stop = 5.0, length = 7))] yout, ypout = Sundials.idasol(resrob, [1.0, 0, 0], [-0.04, 0.04, 0.0], t) diff --git a/test/kinsol_banded.jl b/test/kinsol_banded.jl index 820d23f..88993fd 100644 --- a/test/kinsol_banded.jl +++ b/test/kinsol_banded.jl @@ -1,7 +1,8 @@ function f!(resid, x) for i in eachindex(x) - resid[i] = sin(x[i]) + x[i]^3 + resid[i] = sin(x[i]) + x[i]^3 end end x = ones(5) -@test Sundials.kinsol(f!, x, linear_solver=:Band, jac_upper=0, jac_lower=0) == Sundials.kinsol(f!, x) +@test Sundials.kinsol(f!, x, linear_solver = :Band, jac_upper = 0, jac_lower = 0) == + Sundials.kinsol(f!, x) diff --git a/test/kinsol_mkinTest.jl b/test/kinsol_mkinTest.jl index 62dd57a..43b6fa1 100644 --- a/test/kinsol_mkinTest.jl +++ b/test/kinsol_mkinTest.jl @@ -28,13 +28,12 @@ Sundials.@checkflag Sundials.KINSetScaledStepTol(kmem, 1.0e-4) Sundials.@checkflag Sundials.KINSetMaxSetupCalls(kmem, 1) y = ones(neq) Sundials.@checkflag Sundials.KINInit(kmem, sysfn, y) -A = Sundials.SUNDenseMatrix(length(y),length(y)) -LS = Sundials.SUNLinSol_Dense(y,A) +A = Sundials.SUNDenseMatrix(length(y), length(y)) +LS = Sundials.SUNLinSol_Dense(y, A) Sundials.@checkflag Sundials.KINDlsSetLinearSolver(kmem, LS, A) ## Solve problem scale = ones(neq) -Sundials.@checkflag Sundials.KINSol(kmem, y, Sundials.KIN_LINESEARCH, - scale, scale) +Sundials.@checkflag Sundials.KINSol(kmem, y, Sundials.KIN_LINESEARCH, scale, scale) println("Solution: ", y) residual = ones(2) diff --git a/test/mri_twowaycouple.jl b/test/mri_twowaycouple.jl index 99e8143..db9eae6 100644 --- a/test/mri_twowaycouple.jl +++ b/test/mri_twowaycouple.jl @@ -39,8 +39,8 @@ using Sundials# Test function ff(t, y_nv, ydot_nv, user_data) y = convert(Vector, y_nv) ydot = convert(Vector, ydot_nv) - ydot[1] = 100.0*y[2] - ydot[2] = -100.0*y[1] + ydot[1] = 100.0 * y[2] + ydot[2] = -100.0 * y[1] ydot[3] = y[1] return Sundials.ARK_SUCCESS end @@ -58,7 +58,7 @@ T0 = 0.0 Tf = 2.0 dTout = 0.1 Neq = 3 -Nt = ceil(Tf/dTout) +Nt = ceil(Tf / dTout) hs = 0.001 hf = 0.00002 y0 = [0.90001, -9.999, 1000.0] @@ -66,7 +66,11 @@ y0 = [0.90001, -9.999, 1000.0] # Fast Integration portion _mem_ptr = Sundials.ARKStepCreate(ff, C_NULL, T0, y0); inner_arkode_mem = Sundials.Handle(_mem_ptr) -Sundials.@checkflag Sundials.ARKStepSetTableNum(inner_arkode_mem, -1, Sundials.KNOTH_WOLKE_3_3) +Sundials.@checkflag Sundials.ARKStepSetTableNum( + inner_arkode_mem, + -1, + Sundials.KNOTH_WOLKE_3_3, +) Sundials.@checkflag Sundials.ARKStepSetFixedStep(inner_arkode_mem, hf) # Slow integrator portion @@ -75,19 +79,19 @@ arkode_mem = Sundials.Handle(_arkode_mem_ptr) Sundials.@checkflag Sundials.MRIStepSetFixedStep(arkode_mem, hs) t = [T0] -tout = T0+dTout +tout = T0 + dTout res = Dict(0 => y0) for i in 1:Nt y = similar(y0) global retval = Sundials.MRIStepEvolve(arkode_mem, tout, y, t, Sundials.ARK_NORMAL) - global tout += dTout; - global tout = (tout > Tf) ? Tf : tout; + global tout += dTout + global tout = (tout > Tf) ? Tf : tout res[i] = y end for i in 1:3 - sol_1 = [ -0.927671 -8.500060 904.786828] - sol_end = [0.547358 -0.523577 135.169441] + sol_1 = [-0.927671 -8.500060 904.786828] + sol_end = [0.547358 -0.523577 135.169441] @test isapprox(res[1][i], sol_1[i], atol = 1e-3) @test isapprox(res[Nt][i], sol_end[i], atol = 1e-3) end diff --git a/test/runtests.jl b/test/runtests.jl index 44ee5ce..67d1b1b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,39 +2,77 @@ using Sundials using Test @testset "CVODE" begin - @testset "Roberts CVODE Simplified" begin include("cvode_Roberts_simplified.jl") end - @testset "Roberts CVODE Direct" begin include("cvode_Roberts_dns.jl") end + @testset "Roberts CVODE Simplified" begin + include("cvode_Roberts_simplified.jl") + end + @testset "Roberts CVODE Direct" begin + include("cvode_Roberts_dns.jl") + end #@testset "CVODES Direct" begin include("cvodes_dns.jl") end end @testset "IDA" begin - @testset "Roberts IDA Simplified" begin include("ida_Roberts_simplified.jl") end - @testset "Roberts IDA Direct" begin include("ida_Roberts_dns.jl") end - @testset "Heat IDA Direct" begin include("ida_Heat2D.jl") end + @testset "Roberts IDA Simplified" begin + include("ida_Roberts_simplified.jl") + end + @testset "Roberts IDA Direct" begin + include("ida_Roberts_dns.jl") + end + @testset "Heat IDA Direct" begin + include("ida_Heat2D.jl") + end # Commented out because still uses the syntax from Grid which is a deprecated package #@testset "Cable IDA Direct" begin include("ida_Cable.jl") end end @testset "ARK" begin - @testset "Roberts ARKStep Direct" begin include("arkstep_Roberts_dns.jl") end - @testset "NonLinear ERKStep Direct" begin include("erkstep_nonlin.jl") end - #@testset "MRI two way couple" begin include("mri_twowaycouple.jl") end + @testset "Roberts ARKStep Direct" begin + include("arkstep_Roberts_dns.jl") + end + @testset "NonLinear ERKStep Direct" begin + include("erkstep_nonlin.jl") + end + #@testset "MRI two way couple" begin include("mri_twowaycouple.jl") end end @testset "Kinsol" begin - @testset "Kinsol Simplified" begin include("kinsol_mkin_simplified.jl") end - @testset "Kinsol MKin" begin include("kinsol_mkinTest.jl") end - @testset "Kinsol Banded" begin include("kinsol_banded.jl") end + @testset "Kinsol Simplified" begin + include("kinsol_mkin_simplified.jl") + end + @testset "Kinsol MKin" begin + include("kinsol_mkinTest.jl") + end + @testset "Kinsol Banded" begin + include("kinsol_banded.jl") + end +end +@testset "Handle Tests" begin + include("handle_tests.jl") end -@testset "Handle Tests" begin include("handle_tests.jl") end @testset "Common Interface" begin - @testset "CVODE" begin include("common_interface/cvode.jl") end - @testset "ARKODE" begin include("common_interface/arkode.jl") end - @testset "IDA" begin include("common_interface/ida.jl") end - @testset "Jacobians" begin include("common_interface/jacobians.jl") end - @testset "Callbacks" begin include("common_interface/callbacks.jl") end - @testset "Iterator" begin include("common_interface/iterators.jl") end - @testset "Errors" begin include("common_interface/errors.jl") end - @testset "Mass Matrix" begin include("common_interface/mass_matrix.jl") end + @testset "CVODE" begin + include("common_interface/cvode.jl") + end + @testset "ARKODE" begin + include("common_interface/arkode.jl") + end + @testset "IDA" begin + include("common_interface/ida.jl") + end + @testset "Jacobians" begin + include("common_interface/jacobians.jl") + end + @testset "Callbacks" begin + include("common_interface/callbacks.jl") + end + @testset "Iterator" begin + include("common_interface/iterators.jl") + end + @testset "Errors" begin + include("common_interface/errors.jl") + end + @testset "Mass Matrix" begin + include("common_interface/mass_matrix.jl") + end end