Skip to content

Commit

Permalink
Merge pull request #2503 from jClugstor/AliasingAPI
Browse files Browse the repository at this point in the history
Use Aliasing API
  • Loading branch information
ChrisRackauckas authored Dec 21, 2024
2 parents fdc341a + ac32012 commit 8353139
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 111 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ PrecompileTools = "1"
Preferences = "1.3"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SciMLBase = "2.53.2"
SciMLBase = "2.69"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleNonlinearSolve = "1, 2"
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Random = "<0.0.1, 1"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SafeTestsets = "0.1.0"
SciMLBase = "2.62"
SciMLBase = "2.68"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleUnPack = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using DiffEqBase: check_error!, @def, _vec, _reshape

using FastBroadcast: @.., True, False

using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val
using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val, ODEAliasSpecifier

import SciMLBase: AbstractNonlinearProblem, alg_order

Expand Down
62 changes: 55 additions & 7 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ function DiffEqBase.__init(
userdata = nothing,
allow_extrapolation = alg_extrapolates(alg),
initialize_integrator = true,
alias_u0 = false,
alias_du0 = false,
alias = ODEAliasSpecifier(),
initializealg = DefaultInit(),
kwargs...) where {recompile_flag}
if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm
Expand Down Expand Up @@ -156,19 +155,62 @@ function DiffEqBase.__init(
else
_alg = alg
end
f = prob.f
p = prob.p

# Get the control variables
use_old_kwargs = haskey(kwargs,:alias_u0) || haskey(kwargs,:alias_du0)

if alias_u0
if use_old_kwargs
aliases = ODEAliasSpecifier()
if haskey(kwargs, :alias_u0)
message = "`alias_u0` keyword argument is deprecated, to set `alias_u0`,
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = ODEAliasSpecifier(alias_u0 = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = ODEAliasSpecifier(alias_u0 = values(kwargs).alias_u0)
else
aliases = ODEAliasSpecifier(alias_u0 = nothing)
end

if haskey(kwargs, :alias_du0)
message = "`alias_du0` keyword argument is deprecated, to set `alias_du0`,
please use an ODEAliasSpecifier, e.g. `solve(prob, alias = ODEAliasSpecifier(alias_du0 = true))"
Base.depwarn(message, :init)
Base.depwarn(message, :solve)
aliases = ODEAliasSpecifier(alias_u0 = aliases.alias_u0, alias_du0 = values(kwargs).alias_du0)
else
aliases = ODEAliasSpecifier(alias_u0 = aliases.alias_u0, alias_du0 = nothing)
end

aliases

else
# If alias isa Bool, all fields of ODEAliases set to alias
if alias isa Bool
aliases = ODEAliasSpecifier(alias = alias)
elseif alias isa ODEAliasSpecifier
aliases = alias
end
end

if isnothing(aliases.alias_f) || aliases.alias_f
f = prob.f
else
f = deepcopy(prob.f)
end

if isnothing(aliases.alias_p) || aliases.alias_p
p = prob.p
else
p = recursivecopy(prob.p)
end

if !isnothing(aliases.alias_u0) && aliases.alias_u0
u = prob.u0
else
u = recursivecopy(prob.u0)
end

if _alg isa DAEAlgorithm
if alias_du0
if !isnothing(aliases.alias_du0) && aliases.alias_du0
du = prob.du0
else
du = recursivecopy(prob.du0)
Expand Down Expand Up @@ -240,6 +282,12 @@ function DiffEqBase.__init(
resType = typeof(res_prototype)
end

if isnothing(aliases.alias_tstops) || aliases.alias_tstops
tstops = tstops
else
tstops = recursivecopy(tstops)
end

if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number
_tstops = nothing
else
Expand Down
Loading

0 comments on commit 8353139

Please sign in to comment.