Skip to content

Commit

Permalink
feat: allow specifying nothing as default value to skip it
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jul 22, 2024
1 parent 9b7e139 commit 9ab8ecf
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
:ODESystem, force = true)
end
defaults = todict(defaults)
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)
var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
process_variables!(var_to_name, defaults, ps′)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv
:SDESystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps;
:DiscreteSystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

var_to_name = Dict()
process_variables!(var_to_name, defaults, dvs′)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ function JumpSystem(eqs, iv, unknowns, ps;
:JumpSystem, force = true)
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

unknowns, ps = value.(unknowns), value.(ps)
var_to_name = Dict()
Expand Down
3 changes: 2 additions & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ function NonlinearSystem(eqs, unknowns, ps;
end
jac = RefValue{Any}(EMPTY_JAC)
defaults = todict(defaults)
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict{Any, Any}(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

unknowns, ps = value.(unknowns), value.(ps)
var_to_name = Dict()
Expand Down
3 changes: 2 additions & 1 deletion src/systems/optimization/constraints_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ function ConstraintsSystem(constraints, unknowns, ps;

jac = RefValue{Any}(EMPTY_JAC)
defaults = todict(defaults)
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

var_to_name = Dict()
process_variables!(var_to_name, defaults, unknowns′)
Expand Down
3 changes: 2 additions & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ function OptimizationSystem(op, unknowns, ps;
throw(ArgumentError("System names must be unique."))
end
defaults = todict(defaults)
defaults = Dict(value(k) => value(v) for (k, v) in pairs(defaults))
defaults = Dict(value(k) => value(v)
for (k, v) in pairs(defaults) if value(v) !== nothing)

var_to_name = Dict()
process_variables!(var_to_name, defaults, unknowns′)
Expand Down
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ end

function collect_defaults!(defs, vars)
for v in vars
(haskey(defs, v) || !hasdefault(unwrap(v))) && continue
if haskey(defs, v) || !hasdefault(unwrap(v)) || (def = getdefault(v)) === nothing
continue
end
defs[v] = getdefault(v)
end
return defs
Expand Down
17 changes: 17 additions & 0 deletions test/initial_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,20 @@ eqs = [D(D(z)) ~ ones(2, 2)]
prob = ODEProblem(sys, [], (0.0, 1.0), [A1 => 0.3])
@test prob.ps[B1] == 0.3
@test prob.ps[B2] == 0.7

@testset "default=nothing is skipped" begin
@parameters p = nothing
@variables x(t)=nothing y(t)
for sys in [
ODESystem(Equation[], t, [x, y], [p]; defaults = [y => nothing], name = :osys),
SDESystem(Equation[], [], t, [x, y], [p]; defaults = [y => nothing], name = :ssys),
JumpSystem(Equation[], t, [x, y], [p]; defaults = [y => nothing], name = :jsys),
NonlinearSystem(Equation[], [x, y], [p]; defaults = [y => nothing], name = :nsys),
OptimizationSystem(
Equation[], [x, y], [p]; defaults = [y => nothing], name = :optsys),
ConstraintsSystem(
Equation[], [x, y], [p]; defaults = [y => nothing], name = :conssys)
]
@test isempty(defaults(sys))
end
end

0 comments on commit 9ab8ecf

Please sign in to comment.