Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MTK adjustments (draft) #62

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

strategy:
matrix:
julia-version: ["1.7", "1.8", "1.9"]
julia-version: ["!.9","1.10"]
julia-arch: [x64]
os: [ubuntu-latest, windows-latest, macOS-latest]

Expand Down
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TimeSeriesEcon = "8b6756d2-c55c-11ea-2998-5f67ea17da60"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[extensions]
StateSpaceEconMTKExt = "ModelingToolkit"

[compat]
DelimitedFiles = "1.7"
DiffResults = "1.0"
Expand All @@ -36,8 +42,11 @@ julia = "1.7"
[extras]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "ModelingToolkit", "SymbolicIndexingInterface", "NonlinearSolve"]
268 changes: 268 additions & 0 deletions ext/StateSpaceEconMTKExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
module StateSpaceEconMTKExt

using ModelBaseEcon, StateSpaceEcon, ModelingToolkit
using ModelingToolkit.SciMLBase: solve, AbstractSolution
using StateSpaceEcon.StackedTimeSolver: StackedTimeSolverData
using StateSpaceEcon.SteadyStateSolver: SolverData as SteadyStateSolverData
using TimeSeriesEcon: rangeof, MVTSeries
import StateSpaceEcon.MTKExt:
stacked_time_system,
compute_residuals_stacked_time,
_create_system,
rename_variables,
get_var_names,
solve_steady_state!,
steady_state_system,
compute_residuals_steady_state,
sol2data

# Note: Unfortunately, Documenter.jl seems not to include docstrings from package extensions.
# (At least it's not obvious how to do so.)
# As a result, the docstrings for the following methods live in `../src/MTKExt.jl`.

##############################
# Stacked time system
##############################

function stacked_time_system(m::Model, exog_data::AbstractArray{Float64,2}; fctype = fcgiven, plan::Plan=Plan(m, 1+m.maxlag:size(exog_data, 1)-m.maxlead))

data = zeroarray(m, plan)
data[plan.exogenous .== true] = exog_data[plan.exogenous .== true]
data[1:m.maxlag, :] = exog_data[1:m.maxlag, :] # Initial conditions
if fctype === fcgiven
data[end-m.maxlead+1:end, :] = exog_data[end-m.maxlead+1:end, :] # Final conditions
end

# TODO: Should other variants be accepted, or is `:default` sufficient?
sd = StackedTimeSolverData(m, plan, fctype, :default)

return stacked_time_system(sd, data)

end

function stacked_time_system(sd::StackedTimeSolverData, data::AbstractArray{Float64,2})

f = let sd = sd, data = data
(u, p) -> compute_residuals_stacked_time(u, sd, data)
end

# Since we are just creating a MTK system out of the `NonlinearProblem`,
# the `u0` we specify here is not actually used for solving the system.
u0 = zeros(count(sd.solve_mask))
prob = NonlinearProblem(f, u0)
s = _create_system(prob, sd)
s.defaults[:data_in] = data # hack to pass along the data to be used when constructing an MVTSeries

return s

end

function compute_residuals_stacked_time(u, sd::StackedTimeSolverData, data::AbstractArray{Float64,2})

# We need to use `zeros` instead of `similar` because `assign_exog_data!`
# doesn't assign final conditions for the shocks.
point = zeros(eltype(u), size(data))
point[sd.solve_mask] = u
# Note: `assign_exog_data!` assigns both exogenous data (including initial conditions) *and* final conditions.
StateSpaceEcon.StackedTimeSolver.assign_exog_data!(point, data, sd)

# Emulate `StackedTimeSolver.stackedtime_R!`, computing the residual
# for each equation at each simulation period.
resid = map(sd.TT[1:length(sd.II)]) do tt
# `tt` grabs all the data from `point` for the current simulation period, lags, and leads (e.g., `tt = [1, 2, 3, 4, 5, 6]`).
# The last value of `tt` contains the indices for just the leads (e.g., `tt = [100, 101, 102]`)
# (but we don't use it in this loop).
p = view(point, tt, :)
med = sd.evaldata
map(med.alleqns, med.allinds) do eqn, inds
eqn.eval_resid(p[inds])
end
end

return vcat(resid...)

end

function _create_system(prob::NonlinearProblem, sd::StackedTimeSolverData)

# Convert `prob` into a ModelingToolkit system.
@named old_sys = modelingtoolkitize(prob)

# `modelingtoolkitize` creates a system with automatically generated variable names
# `x₁`, `x₂`, etc.
# Rename the variables to make it easier to query the solution to problems
# created from this system.
sys = rename_variables(old_sys, sd)

# We need `conservative = true` to prevent errors in `structural_simplify`.
s = structural_simplify(complete(sys); conservative = true)

return s

end

function rename_variables(old_sys::NonlinearSystem, sd::StackedTimeSolverData)

# reshape the solve_mask
solve_mask = reshape(sd.solve_mask, :, maximum(last, sd.evaldata.var_to_idx))

# Find all column indices of columns that have at least one `true`.
var_cols = vec(mapslices(any, solve_mask; dims = 1)) |> findall

# Convert from column indices to variable/shock names.
var_names = map(var_cols) do i
findfirst(p -> last(p) == i, sd.evaldata.var_to_idx)
end

var_lengths = Dict{Symbol,Int64}()
col_count = 1
for col in var_cols
var_lengths[var_names[col_count]] = count(solve_mask[:, col])
col_count += 1
end

# For each model variable, create a ModelingToolkit variable vector
# with the correct number of time steps.
# `vars` will end up being, e.g., `[pinf[1:97], rate[1:97], ygap[1:97]]`.
vars = map(var_names) do var_name
NE = var_lengths[var_name]
ModelingToolkit.@variables $(var_name)[1:NE]
end |> x -> reduce(vcat, x)

# Collect all the individual variables to enable substituting the variable names
# that were automatically generated by `modelingtoolkitize`.
# `vars_enumerated` will end up being, e.g., `[pinf[1], pinf[2], ..., ygap[97]]`.
vars_enumerated = reduce(vcat, map(collect, vars))
old_vars = unknowns(old_sys)
subs = Dict(old_vars .=> vars_enumerated)

# Take all the equations from `old_sys`, rename the variables,
# and create a new system with the new variable names.
eqs = substitute.(ModelingToolkit.equations(old_sys), Ref(subs))
@named sys = NonlinearSystem(eqs, vars_enumerated, [])

return sys

end

function get_var_names(sd)

# Reshape the solve mask to have a column for each variable/shock.
solve_mask = reshape(sd.solve_mask, :, maximum(last, sd.evaldata.var_to_idx))
# Find all column indices of columns that have at least one `true`.
var_cols = vec(mapslices(any, solve_mask; dims = 1)) |> findall
# Convert from column indices to variable/shock names.
var_names = map(var_cols) do i
findfirst(p -> last(p) == i, sd.evaldata.var_to_idx)
end

return var_names

end

##############################
# Steady state system
##############################

function solve_steady_state!(m::Model, sys::NonlinearSystem; u0 = zeros(length(unknowns(sys))), solver = nothing, solve_kwargs...)

# Creating a `NonlinearProblem` from a `NonlinearFunction` seems to be faster
# than creating a `NonlinearProblem` directly from a `NonlinearSystem`.
nf = NonlinearFunction(sys)
prob = NonlinearProblem(nf, u0)
sol = isnothing(solver) ? solve(prob; solve_kwargs...) : solve(prob, solver; solve_kwargs...)

# Copy the steady state solution to the model.
m.sstate.values[.!m.sstate.mask] = sol.u
# Mark the steady state as solved.
m.sstate.mask .= true

return m

end

steady_state_system(m::Model) = steady_state_system(SteadyStateSolverData(m))

function steady_state_system(sd::SteadyStateSolverData)

f = let sd = sd
(u, p) -> compute_residuals_steady_state(u, sd)
end

# Since we are just creating a MTK system out of the `NonlinearProblem`,
# the `u0` we specify here is not actually used for solving the system.
u0 = zeros(count(sd.solve_var))
prob = NonlinearProblem(f, u0)
@named sys = modelingtoolkitize(prob)
s = complete(sys)

return s

end

# `ModelBaseEcon.__to_dyn_pt` doesn't work with automatic differentiation,
# so copy it here to allow passing in an apropriately typed `buffer`.
# TODO: Should this method be added directly to ModelBaseEcon.jl?
function __to_dyn_pt!(buffer, pt, s)
# This function applies the transformation from steady
# state equation unknowns to dynamic equation unknowns
for (i, jt) in enumerate(s.JT)
if length(jt.ssinds) == 1
pti = pt[jt.ssinds[1]]
else
pti = pt[jt.ssinds[1]] + ModelBaseEcon.__lag(jt, s) * pt[jt.ssinds[2]]
end
buffer[i] += pti
end
return buffer
end

function compute_residuals_steady_state(u, sd)

# `point` needs to contain the levels and slopes for all variables and shocks,
# but these values are all zero except for those included in `u`.
point = zeros(eltype(u), length(sd.solve_var))
point[sd.solve_var] = u

# Emulate `SteadyStateSolver.global_SS_R!`, computing the residual for each equation.
resid = map(sd.alleqns) do eqn
if hasproperty(eqn.eval_resid, :s)
# `eqn.eval_resid` closes over `s::SSEqData`.
# There's no function to create just `s`, so just grab it from the closure.
s = eqn.eval_resid.s
buffer = zeros(eltype(u), length(s.JT))
s.eqn.eval_resid(__to_dyn_pt!(buffer, point[eqn.vinds], s))
else
# `eqn.eval_resid` is an `EquationEvaluator` (not a closure)
# that computes the residual for an equation defined by a steady state constraint.
eqn.eval_resid(point[eqn.vinds])
end
end

return resid

end

function sol2data(sol::AbstractSolution, s::ModelingToolkit.AbstractSystem, m::Model, p::Plan)
rng = rangeof(p)

# get endogenous points
_result = ModelingToolkit.SymbolicIndexingInterface.getu(sol,
reduce(vcat, map(collect, map(var -> hasproperty(s, var.name) ? getproperty(s, var.name) : [], m.varshks)))
)(sol) # is this a fair assumption, that we will get back the variables?

# get mask for endogenous points
mask = p.exogenous .== false
mask[begin:begin+m.maxlag-1, :] .= false # initial conditions
mask[end-m.maxlead+1:end, :] .= false # final conditions

# construct results from exogenous data
result = MVTSeries(rng, m.varshks, s.defaults[:data_in])

# update endogenous points
result[mask] = _result

return result
end

end
Loading
Loading