Skip to content

Commit

Permalink
Merge pull request #194 from SciML/update
Browse files Browse the repository at this point in the history
Update for removal of pins
  • Loading branch information
ChrisRackauckas authored Feb 17, 2021
2 parents db927f8 + 8c13066 commit bb000f4
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DataDrivenDiffEq"
uuid = "2445eb08-9709-466a-b3fc-47e12bd697a2"
authors = ["Julius Martensen <[email protected]>"]
version = "0.5.4"
version = "0.5.5"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 2 additions & 2 deletions src/DataDrivenDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using FiniteDifferences, DataInterpolations
using Compat
using DocStringExtensions

abstract type AbstractKoopmanOperator end;
abstract type AbstractKoopmanOperator <: Function end;

include("./optimizers/Optimize.jl")
using .Optimize
Expand Down Expand Up @@ -85,6 +85,6 @@ export burst_sampling, subsample

include("./basis_generators.jl")
export chebyshev_basis, monomial_basis, polynomial_basis
export sin_basis, cos_basis, fourier_basis
export sin_basis, cos_basis, fourier_basis

end # module
10 changes: 4 additions & 6 deletions src/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ It can be called with the typical DiffEq signature, meaning out of place with `f
or in place with `f(du, u, p, t)`.
If `linear_independent` is set to `true`, a linear independent basis is created from all atom function in `f`.
If `simplify_eqs` is set to `true`, `simplify` is called on `f`.
Additional keyworded arguments include `name`, which can be used to name the basis, `pins` used for connections and
Additional keyworded arguments include `name`, which can be used to name the basis, and
`observed` for defining observeables.
# Fields
Expand Down Expand Up @@ -48,7 +48,6 @@ mutable struct Basis <: ModelingToolkit.AbstractSystem
states::Vector
"""Parameters"""
ps::Vector
pins::Vector
observed::Vector
"""Independent variable"""
iv::Num
Expand All @@ -63,7 +62,7 @@ end
function Basis(eqs::AbstractVector, states::AbstractVector; parameters::AbstractArray = [], iv = nothing,

simplify = false, linear_independent = false, name = gensym(:Basis), eval_expression = false,
pins = [], observed = [],
observed = [],
kwargs...)

eqs = simplify ? ModelingToolkit.simplify.(eqs) : eqs
Expand All @@ -80,7 +79,7 @@ function Basis(eqs::AbstractVector, states::AbstractVector; parameters::Abstract
f_(u,p,t) = f_oop(u,p,t)
f_(du, u, p, t) = f_iip(du, u, p, t)

return Basis(eqs, value.(states), value.(parameters), pins, observed, value(iv), f_, name, Basis[])
return Basis(eqs, value.(states), value.(parameters), observed, value(iv), f_, name, Basis[])
end

function Basis(f::Function, states::AbstractVector; parameters::AbstractArray = [], iv = nothing, kwargs...)
Expand Down Expand Up @@ -239,9 +238,8 @@ function Base.merge(x::Basis, y::Basis; eval_expression = false)
b = unique(vcat([xi.rhs for xi equations(x)], [xi.rhs for xi equations(y)]))
vs = unique(vcat(x.states, y.states))
ps = unique(vcat(x.ps, y.ps))
pins = unique(vcat(x.pins, y.pins))
observed = unique(vcat(x.observed, y.observed))
return Basis(Num.(b), vs, parameters = ps, pins = pins, observed = observed, eval_expression = eval_expression)
return Basis(Num.(b), vs, parameters = ps, observed = observed, eval_expression = eval_expression)
end

"""
Expand Down
18 changes: 8 additions & 10 deletions src/system_conversions.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import ModelingToolkit.SymbolicUtils.FnType

function _generate_deqs(x::Basis, states, iv, p)
@assert length(x) == length(states)
@assert length(x) == length(states)
# Create new variables with time dependency
∂t = Differential(iv)
dvs = [Num(Sym{FnType{Tuple{Any}, Real}}(_get_name(xi))(value(iv))) for xi in states]
dvs = [Num(Sym{FnType{Tuple{Any}, Real}}(_get_name(xi))(value(iv))) for xi in states]
dvsdt = ∂t.(dvs)
# Adapt equations
eqs = dvsdt .~ x(dvs,p,iv)
Expand All @@ -16,17 +16,16 @@ $(SIGNATURES)
Convert a given `Basis` or `SparseIdentificationResult` into an `ODESystem`. For details, see ModelingToolkit.jl.
"""
function ModelingToolkit.ODESystem(x::Basis, iv = nothing, dvs = Num[], ps = Num[]; pins = Num[], observed = Num[], systems = ODESystem[],kwargs...)
function ModelingToolkit.ODESystem(x::Basis, iv = nothing, dvs = Num[], ps = Num[]; observed = Num[], systems = ODESystem[],kwargs...)
iv = isnothing(iv) ? independent_variable(x) : iv
dvs = isempty(dvs) ? variables(x) : dvs
ps = isempty(ps) ? parameters(x) : ps
eqs, dvs = _generate_deqs(x, dvs, iv, ps)
pins = isempty(pins) ? x.pins : pins
observed = isempty(observed) ? x.observed : observed
systems = isempty(systems) ? x.systems : systems
return ODESystem(
eqs, iv, dvs, ps,
pins = pins, observed = observed, systems = systems, kwargs...)
observed = observed, systems = systems, kwargs...)
end

function ModelingToolkit.ODESystem(b::SparseIdentificationResult, iv = nothing, dvs = Num[], ps = Num[]; kwargs...)
Expand All @@ -46,7 +45,7 @@ function _generate_deqs(x::Basis, states, iv, p, controls)
states_ = _remove_controls(states, controls)
@assert length(x) == length(states_)
∂t = Differential(iv)
dvs = [Num(Sym{FnType{Tuple{Any}, Real}}(_get_name(xi))(value(iv))) for xi in states_]
dvs = [Num(Sym{FnType{Tuple{Any}, Real}}(_get_name(xi))(value(iv))) for xi in states_]
input_states = _create_input_vec(states, dvs, controls)
dvsdt = ∂t.(dvs)
# Adapt equations
Expand Down Expand Up @@ -77,20 +76,19 @@ $(SIGNATURES)
Convert a given `Basis` or `SparseIdentificationResult` into a `ControlSystem`. For details, see ModelingToolkit.jl.
"""
function ModelingToolkit.ControlSystem(loss, x::Basis, controls, iv = nothing, dvs = nothing, ps = nothing;
pins = Num[], observed = Num[], systems = ODESystem[], kwargs...)
function ModelingToolkit.ControlSystem(loss, x::Basis, controls, iv = nothing, dvs = nothing, ps = nothing;
observed = Num[], systems = ODESystem[], kwargs...)
iv = isnothing(iv) ? independent_variable(x) : iv
dvs = isnothing(dvs) ? variables(x) : dvs
ps = isnothing(ps) ? parameters(x) : ps
eqs, dvs, input_states = _generate_deqs(x, dvs, iv, ps, controls)
#return input_states
subs = [(xi => is) for (xi, is) in zip(variables(x), input_states)]
loss = substitute.(loss, (subs,))[1]
pins = isempty(pins) ? x.pins : pins
observed = isempty(observed) ? x.observed : observed
systems = isempty(systems) ? x.systems : systems
return ControlSystem(loss, eqs, iv, dvs, controls, ps,
pins = pins, observed = observed, systems = systems, kwargs...)
observed = observed, systems = systems, kwargs...)
end

function ModelingToolkit.ControlSystem(loss, b::SparseIdentificationResult, controls, iv = nothing, dvs = Num[], ps = Num[]; kwargs...)
Expand Down

0 comments on commit bb000f4

Please sign in to comment.