Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix several bugs related to remake and symbolic indexing #583

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ EnumX = "1"
FillArrays = "1.9"
FunctionWrappersWrappers = "0.1.3"
IteratorInterfaceExtensions = "^1"
JumpProcesses = "9.10.1"
LinearAlgebra = "1.9"
Logging = "1.9"
Markdown = "1.9"
Expand All @@ -81,7 +82,7 @@ SciMLOperators = "0.3.7"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.9"
SymbolicIndexingInterface = "0.3"
SymbolicIndexingInterface = "0.3.2"
Tables = "1.11"
TruncatedStacktraces = "1.4"
Zygote = "0.6.67"
Expand All @@ -92,6 +93,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Expand All @@ -102,8 +104,9 @@ RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq"]
test = ["Pkg", "PyCall", "PythonCall", "SafeTestsets", "Test", "StaticArrays", "StochasticDiffEq", "Aqua", "Zygote", "PartialFunctions", "DataFrames", "ModelingToolkit", "OrdinaryDiffEq", "JumpProcesses", "SymbolicIndexingInterface"]
4 changes: 3 additions & 1 deletion src/problems/problem_interface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
SymbolicIndexingInterface.symbolic_container(prob::AbstractSciMLProblem) = prob.f
SymbolicIndexingInterface.symbolic_container(prob::AbstractJumpProblem) = prob.prob

Check warning on line 2 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L2

Added line #L2 was not covered by tests
SymbolicIndexingInterface.parameter_values(prob::AbstractSciMLProblem) = prob.p
SymbolicIndexingInterface.parameter_values(prob::AbstractJumpProblem) = parameter_values(prob.prob)

Check warning on line 4 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L4

Added line #L4 was not covered by tests

Base.@propagate_inbounds function Base.getindex(prob::AbstractSciMLProblem, ::SymbolicIndexingInterface.SolvedVariables)
return getindex(prob, variable_symbols(prob))
Expand Down Expand Up @@ -30,7 +32,7 @@
elseif symbolic_type(sym) == ArraySymbolic()
return map(s -> prob[s], sym)
else
sym isa AbstractArray || error("Invalid indexing of problem")
sym isa Union{<:AbstractArray, <:Tuple} || error("Invalid indexing of problem")

Check warning on line 35 in src/problems/problem_interface.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/problem_interface.jl#L35

Added line #L35 was not covered by tests
return map(s -> prob[s], sym)
end
end
Expand Down
129 changes: 48 additions & 81 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,7 @@
@eval remaker_of(::$T) = $T
end

"""
remake(thing; <keyword arguments>)

Re-construct `thing` with new field values specified by the keyword
arguments.
"""
function remake(thing; kwargs...)
function _remake_internal(thing; kwargs...)

Check warning on line 22 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L22

Added line #L22 was not covered by tests
T = remaker_of(thing)
if :kwargs ∈ fieldnames(typeof(thing))
if :kwargs ∉ keys(kwargs)
Expand All @@ -38,6 +32,21 @@
end
end

"""
remake(thing; <keyword arguments>)

Re-construct `thing` with new field values specified by the keyword
arguments.
"""
function remake(thing; kwargs...)
_remake_internal(thing; kwargs...)

Check warning on line 42 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end

function remake(prob::DiscreteProblem; u0 = missing, p = missing, kwargs...)
p, u0 = _remake_get_p_u0(prob; p, u0)
_remake_internal(prob; p, u0, kwargs...)

Check warning on line 47 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L45-L47

Added lines #L45 - L47 were not covered by tests
end

function isrecompile(prob::ODEProblem{iip}) where {iip}
(prob.f isa ODEFunction) ? !isfunctionwrapper(prob.f.f) : true
end
Expand All @@ -59,25 +68,7 @@
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
if (eltype(p) <: Pair && !isempty(p)) || (eltype(u0) <: Pair && !isempty(u0)) # one is a non-empty symbolic map
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :states) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
p, u0 = process_p_u0_symbolic(prob, p, u0)
end
end
p, u0 = _remake_get_p_u0(prob; p, u0)

Check warning on line 71 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L71

Added line #L71 was not covered by tests

iip = isinplace(prob)

Expand Down Expand Up @@ -132,15 +123,11 @@
tspan = prob.tspan
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
if p === missing
p = prob.p

Check warning on line 127 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L126-L127

Added lines #L126 - L127 were not covered by tests
end
if u0 === missing
u0 = prob.u0

Check warning on line 130 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L129-L130

Added lines #L129 - L130 were not covered by tests
end

iip = isinplace(prob)
Expand Down Expand Up @@ -202,13 +189,7 @@
tspan = prob.tspan
end

if p === missing
p = prob.p
end

if u0 === missing
u0 = prob.u0
end
p, u0 = _remake_get_p_u0(prob; p, u0)

Check warning on line 192 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L192

Added line #L192 was not covered by tests

if noise === missing
noise = prob.noise
Expand Down Expand Up @@ -263,26 +244,7 @@
sense = missing,
kwargs = missing,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
if (eltype(p) <: Pair && !isempty(p)) || (eltype(u0) <: Pair && !isempty(u0)) # one is a non-empty symbolic map
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :states) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
p, u0 = process_p_u0_symbolic(prob, p, u0)
end
end

p, u0 = _remake_get_p_u0(prob; p, u0)

Check warning on line 247 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L247

Added line #L247 was not covered by tests
if f === missing
f = prob.f
end
Expand Down Expand Up @@ -332,25 +294,7 @@
problem_type = missing,
kwargs = missing,
_kwargs...)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0
else # at least one of them has a value
if p === missing
p = prob.p
end
if u0 === missing
u0 = prob.u0
end
if (eltype(p) <: Pair && !isempty(p)) || (eltype(u0) <: Pair && !isempty(u0)) # one is a non-empty symbolic map
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :states) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
p, u0 = process_p_u0_symbolic(prob, p, u0)
end
end
p, u0 = _remake_get_p_u0(prob; p, u0)

Check warning on line 297 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L297

Added line #L297 was not covered by tests

if f === missing
f = prob.f
Expand Down Expand Up @@ -419,3 +363,26 @@
en_kwargs = [k for k in kwargs if first(k) ∈ fieldnames(T)]
T(remake(thing.prob; setdiff(kwargs, en_kwargs)...); en_kwargs...)
end

function _remake_get_p_u0(prob; p = missing, u0 = missing)
if p === missing && u0 === missing
p, u0 = prob.p, prob.u0

Check warning on line 369 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L367-L369

Added lines #L367 - L369 were not covered by tests
else # at least one of them has a value
if p === missing
p = prob.p

Check warning on line 372 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L371-L372

Added lines #L371 - L372 were not covered by tests
end
if u0 === missing
u0 = prob.u0

Check warning on line 375 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L374-L375

Added lines #L374 - L375 were not covered by tests
end
if (eltype(p) <: Pair && !isempty(p)) || (eltype(u0) <: Pair && !isempty(u0)) # one is a non-empty symbolic map
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||

Check warning on line 378 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L377-L378

Added lines #L377 - L378 were not covered by tests
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :states) ||

Check warning on line 381 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L381

Added line #L381 was not covered by tests
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
p, u0 = process_p_u0_symbolic(prob, p, u0)

Check warning on line 384 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L384

Added line #L384 was not covered by tests
end
end
return p, u0

Check warning on line 387 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L387

Added line #L387 was not covered by tests
end
16 changes: 16 additions & 0 deletions src/solutions/rode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@

function (sol::RODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
continuity = :left) where {deriv}
if idxs !== nothing
if !(idxs isa Union{<:AbstractArray, <:Tuple})
idxs = [idxs]

Check warning on line 68 in src/solutions/rode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/rode_solutions.jl#L66-L68

Added lines #L66 - L68 were not covered by tests
end
idxs = map(idxs) do idx
if symbolic_type(idx) === NotSymbolic()
return idx
elseif symbolic_type(idx) === ScalarSymbolic()
return variable_index(sol, idx)

Check warning on line 74 in src/solutions/rode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/rode_solutions.jl#L70-L74

Added lines #L70 - L74 were not covered by tests
else
return variable_index.((sol,), collect(idx))

Check warning on line 76 in src/solutions/rode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/rode_solutions.jl#L76

Added line #L76 was not covered by tests
end
end
any(i === nothing for i in idxs) && error("All idxs must be variables")
idxs = reduce(vcat, idxs)

Check warning on line 80 in src/solutions/rode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/rode_solutions.jl#L79-L80

Added lines #L79 - L80 were not covered by tests
end
sol.interp(t, idxs, deriv, sol.prob.p, continuity)
end
function (sol::RODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing,
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ getβ3 = getp(sys, :β)
@test oprob[x] == oprob[sys.x] == oprob[:x] == 1.0
@test oprob[y] == oprob[sys.y] == oprob[:y] == 0.0
@test oprob[z] == oprob[sys.z] == oprob[:z] == 0.0
@test oprob[[x, y]] == oprob[[sys.x, sys.y]] == oprob[[:x, :y]] == [1.0, 0.0]
@test oprob[(x, y)] == oprob[(sys.x, sys.y)] == oprob[(:x, :y)] == (1.0, 0.0)
@test oprob[solvedvariables] == oprob[variable_symbols(sys)]
@test oprob[allvariables] == oprob[all_variable_symbols(sys)]

Expand Down
Loading
Loading