Unable to run neural_dae_tests.jl #939

Closed
awecefil opened this issue Sep 6, 2024 · 1 comment
Labels
question Further information is requested

awecefil commented Sep 6, 2024

Hi, I am currently trying to use NeuralDAE() and am running neural_dae_tests.jl as a reference.
However, I get errors when I run the test code:

using ComponentArrays, Zygote, Optimization, OptimizationOptimJL, OrdinaryDiffEq, Random, Lux, DiffEqFlux

# A desired MWE for now, not a test yet.

function rober(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1
    nothing
end
M = [1.0 0 0
     0 1.0 0
     0 0 0]
prob_mm = ODEProblem(ODEFunction(rober; mass_matrix = M), [1.0, 0.0, 0.0], (0.0, 10.0), (0.04, 3e7, 1e4))
sol = solve(prob_mm, Rodas5(); reltol = 1e-8, abstol = 1e-8)

dudt2 = Chain(x -> x .^ 3, Dense(6, 50, tanh), Dense(50, 3))

u₀ = [1.0, 0, 0]
du₀ = [-0.04, 0.04, 0.0]
tspan = (0.0, 10.0)

ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan,
    DFBDF(); differential_vars = [true, true, false])
ps, st = Lux.setup(Xoshiro(0), ndae)
ps = ComponentArray(ps)

predict_n_dae(p) = first(ndae((u₀, du₀), p, st))

function loss(p)
    pred = predict_n_dae(p)
    loss = sum(abs2, sol .- pred)
    return loss, pred
end

begin
    optfunc = Optimization.OptimizationFunction((x, p) -> loss(x), Optimization.AutoZygote())
    optprob = Optimization.OptimizationProblem(optfunc, ps)
    res = Optimization.solve(optprob, BFGS(; initial_stepnorm = 0.0001))
end

Running this code, I got:

ERROR: MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})

Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat
   @ Base rounding.jl:207
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:792
  Float64(::IrrationalConstants.Inv4π)
   @ IrrationalConstants C:\.julia\packages\IrrationalConstants\vp5v4\src\macro.jl:112
  ...

Stacktrace:
   [1] convert(::Type{Float64}, x::Tracker.TrackedReal{Float64})
     @ Base .\number.jl:7
   [2] DiffEqBase.NonlinearTerminationModeCache{…}(u::Vector{…}, retcode::SciMLBase.ReturnCode.T, abstol::Float64, reltol::Float64, best_objective_value::Tracker.TrackedReal{…}, mode::AbsSafeBestTerminationMode{…}, initial_objective::Tracker.TrackedReal{…}, objectives_trace::Vector{…}, nsteps::Int64, saved_values::Nothing, u0_norm::Nothing, step_norm_trace::Vector{…}, max_stalled_steps::Int64, u_diff_cache::Vector{…})
     @ DiffEqBase C:\.julia\packages\DiffEqBase\52czI\src\termination_conditions.jl:201
   [3] init(::Tracker.TrackedVector{…}, ::Vector{…}, ::AbsSafeBestTerminationMode{…}; use_deprecated_retcodes::Val{…}, abstol::Nothing, reltol::Nothing, kwargs::@Kwargs{})   
     @ DiffEqBase C:\.julia\packages\DiffEqBase\52czI\src\termination_conditions.jl:282
   [4] init_termination_cache(abstol::Nothing, reltol::Nothing, du::Tracker.TrackedVector{…}, u::Vector{…}, tc::AbsSafeBestTerminationMode{…})
     @ NonlinearSolve C:\.julia\packages\NonlinearSolve\5yLII\src\internal\termination.jl:6
   [5] init_termination_cache(abstol::Nothing, reltol::Nothing, du::Tracker.TrackedVector{…}, u::Vector{…}, ::Nothing)
     @ NonlinearSolve C:\.julia\packages\NonlinearSolve\5yLII\src\internal\termination.jl:2
   [6] __init(::NonlinearProblem{…}, ::NonlinearSolve.GeneralizedFirstOrderAlgorithm{…}; alias_u0::Bool, maxiters::Int64, abstol::Nothing, reltol::Nothing, maxtime::Nothing, termination_condition::Nothing, internalnorm::Function, linsolve_kwargs::@NamedTuple{}, kwargs::@Kwargs{…})
     @ NonlinearSolve C:\.julia\packages\NonlinearSolve\5yLII\src\core\generalized_first_order.jl:158

The full stack trace is in this link

By the way, I also tried changing BFGS to DFBDF, but got a different error:

ERROR: Optimization algorithm not found. Either the chosen algorithm is not a valid solver
choice for the `OptimizationProblem`, or the Optimization solver library is not loaded.
Make sure that you have loaded an appropriate Optimization.jl solver library, for example,
`solve(prob,Optim.BFGS())` requires `using OptimizationOptimJL` and
`solve(prob,Adam())` requires `using OptimizationOptimisers`.

For more information, see the Optimization.jl documentation: https://docs.sciml.ai/Optimization/stable/.

Chosen Optimizer: DFBDF{5, 0, true, Nothing, NLNewton{Rational{Int64}, Rational{Int64}, Rational{Int64}, Rational{Int64}}, typeof(OrdinaryDiffEq.DEFAULT_PRECS), Val{:forward}, true, nothing, Nothing, Nothing}(Val{5}(), nothing, NLNewton{Rational{Int64}, Rational{Int64}, Rational{Int64}, Rational{Int64}}(1//100, 10, 1//5, 1//5, false, true, 0//1),
 OrdinaryDiffEq.DEFAULT_PRECS, nothing, nothing, :linear, :Standard)
Stacktrace:
 [1] __solve(::OptimizationProblem{…}, ::DFBDF{…}; kwargs::@Kwargs{})
   @ SciMLBase C:\.julia\packages\SciMLBase\NjslX\src\solve.jl:193
 [2] __solve(::OptimizationProblem{…}, ::DFBDF{…})
   @ SciMLBase C:\.julia\packages\SciMLBase\NjslX\src\solve.jl:192
 [3] solve(::OptimizationProblem{…}, ::DFBDF{…}; kwargs::@Kwargs{})
   @ SciMLBase C:\.julia\packages\SciMLBase\NjslX\src\solve.jl:99 
 [4] solve(::OptimizationProblem{…}, ::DFBDF{…})
   @ SciMLBase C:\.julia\packages\SciMLBase\NjslX\src\solve.jl:93
 [5] top-level scope
   @ d:\julia_neuralDAE.jl:102
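
On the second error: the message indicates that DFBDF is not an optimizer that Optimization.solve recognizes. In the script above, DFBDF() is the DAE integrator passed to NeuralDAE, while BFGS is the optimizer passed to Optimization.solve, so the two are not interchangeable. A minimal sketch of where each one goes, reusing names from the script (the variable names dae_solver and optimizer are illustrative and not part of the original code):

```julia
using OrdinaryDiffEq, Optimization, OptimizationOptimJL

# DAE integrator: this is what the NeuralDAE constructor expects,
# exactly as in the script above.
dae_solver = DFBDF()

# Optimizer: this is what Optimization.solve expects,
# exactly as in the original call.
optimizer = BFGS(; initial_stepnorm = 0.0001)

# With dudt2, tspan, and optprob defined as in the script above:
# ndae = NeuralDAE(dudt2, (u, p, t) -> [u[1] + u[2] + u[3] - 1], tspan,
#     dae_solver; differential_vars = [true, true, false])
# res = Optimization.solve(optprob, optimizer)
```

Swapping BFGS for DFBDF therefore only changes which error is raised; it does not address the original TrackedReal conversion failure.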
@avik-pal
Member

Can you try the most recent versions of the packages? This particular breakage came from an upstream issue that has since been fixed: https://github.com/SciML/DiffEqFlux.jl/actions/runs/10849209541
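
One way to pick up the newer versions is through Julia's built-in package manager. This is a minimal sketch, assuming the script runs in the currently active environment; it updates everything in that environment rather than pinning specific versions, since the thread does not name the release that contains the fix:

```julia
using Pkg

# Update all packages in the active environment to the newest
# versions allowed by the compatibility constraints.
Pkg.update()

# Print the resolved versions, which is useful to include
# when reporting back whether the error persists.
Pkg.status()
```

After updating, restart the Julia session before rerunning the script so that the new package versions are actually loaded.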
