diff --git a/src/dae_solve.jl b/src/dae_solve.jl
index 5a5ee83be3..d131676391 100644
--- a/src/dae_solve.jl
+++ b/src/dae_solve.jl
@@ -47,6 +47,30 @@ function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff =
     NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
 end
 
+# Time derivative of `phi(t, θ)` at a single time `t`, for an `ODEPhi` whose type
+# parameter `U <: Number`. Uses ForwardDiff when `autodiff` is true, otherwise a
+# forward difference with step `sqrt(eps(typeof(t)))`. `differential_vars` is unused.
+function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
+        autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: Number}
+    if autodiff
+        ForwardDiff.derivative(t -> phi(t, θ), t)
+    else
+        (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
+    end
+end
+
+# Same as the method above for `U <: AbstractVector`. `t` is a scalar here, so the
+# AD branch uses `ForwardDiff.derivative` (valid for array-valued functions of a
+# real argument); `ForwardDiff.jacobian` expects an array input.
+function dfdx(phi::ODEPhi{C, T, U}, t::Number, θ,
+        autodiff::Bool, differential_vars::AbstractVector) where {C, T, U <: AbstractVector}
+    if autodiff
+        ForwardDiff.derivative(t -> phi(t, θ), t)
+    else
+        (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, θ)) / sqrt(eps(typeof(t)))
+    end
+end
+
 function dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool,
         differential_vars::AbstractVector)
     if autodiff
@@ -69,6 +93,15 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
     sum(abs2, loss) / length(t)
 end
 
+# Loss at a single time `t`: calls `f(du, u, p, t)` with `u = phi(t, θ)` and
+# `du = dfdx(phi, t, θ, ...)` and returns the sum of squares of the result
+# (presumably the residual of the implicit problem; verify against the caller).
+function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
+        p, differential_vars::AbstractVector) where {C, T, U}
+    dphi = dfdx(phi, t, θ, autodiff, differential_vars)
+    sum(abs2, f(dphi, phi(t, θ), p, t))
+end
+
 function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
         differential_vars::AbstractVector)
     ts = tspan[1]:(strategy.dx):tspan[2]
@@ -79,6 +112,73 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
     return loss
 end
 
+# Loss for `WeightedIntervalTraining`: `tspan` is split into `length(weights)`
+# equal sub-intervals and sub-interval `i` receives
+# `trunc(Int, points * weights[i])` uniformly random sample times (weights are
+# normalised first). The samples are drawn once, when the loss is built.
+function generate_loss(
+        strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
+        differential_vars::AbstractVector)
+    autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
+    minT = tspan[1]
+    maxT = tspan[2]
+
+    weights = strategy.weights ./ sum(strategy.weights)
+
+    N = length(weights)
+    points = strategy.points
+
+    difference = (maxT - minT) / N
+
+    data = Float64[]
+    for (index, item) in enumerate(weights)
+        temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
+                    ((index - 1) * difference)
+        data = append!(data, temp_data)
+    end
+
+    ts = data
+
+    function loss(θ, _)
+        sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
+    end
+    return loss
+end
+
+# Loss for `QuadratureTraining`: integrates the pointwise loss over `tspan` with
+# `strategy.quadrature_alg`, using the tolerances and iteration cap stored on the
+# strategy. `integrand` has a scalar-time method and a batched method.
+function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
+        differential_vars::AbstractVector)
+    integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, differential_vars))
+
+    function integrand(ts, θ)
+        [sum(abs2, inner_loss(phi, f, autodiff, t, θ, p, differential_vars)) for t in ts]
+    end
+
+    function loss(θ, _)
+        intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
+        intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
+        sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol,
+            reltol = strategy.reltol, maxiters = strategy.maxiters)
+        sol.u
+    end
+    return loss
+end
+
+# Loss for `StochasticTraining`: every evaluation draws `strategy.points` fresh
+# uniform sample times in `tspan`. Automatic differentiation is rejected.
+function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
+        differential_vars::AbstractVector)
+    autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
+    function loss(θ, _)
+        ts = adapt(parameterless_type(θ),
+            [(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
+        sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
+    end
+    return loss
+end
+
 function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
         alg::NNDAE,
         args...;
@@ -136,8 +236,13 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
         if dt !== nothing
             GridTraining(dt)
         else
-            error("dt is not defined")
+            QuadratureTraining(; quadrature_alg = QuadGKJL(),
+                reltol = convert(eltype(u0), reltol),
+                abstol = convert(eltype(u0), abstol), maxiters = maxiters,
+                batch = 0)
         end
+    else
+        alg.strategy
     end
 
     inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, differential_vars)
@@ -189,3 +294,4 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
         dense_errors = false)
     sol
 end
+
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index bcf9c68ebe..7a0b597317 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -304,6 +304,8 @@ function generate_loss(
     return loss
 end
 
+
+
 function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_estim::Bool)
     function loss(θ, _)
         if batch
@@ -319,6 +321,7 @@ function generate_loss(strategy::QuasiRandomTraining, phi, f, autodiff::Bool, ts
     error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional spaces only. Use StochasticTraining instead.")
 end
 
+
 struct NNODEInterpolation{T <: ODEPhi, T2}
     phi::T
     θ::T2
@@ -490,3 +493,4 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
         dense_errors = false)
     sol
 end #solve
+
diff --git a/test/NNDAE_tests.jl b/test/NNDAE_tests.jl
index bbcf12dd6d..11d679c382 100644
--- a/test/NNDAE_tests.jl
+++ b/test/NNDAE_tests.jl
@@ -16,7 +16,7 @@ Random.seed!(100)
     M = [1.0 0
          0 0]
     f = ODEFunction(example1, mass_matrix = M)
-    tspan = (0.0f0, 1.0f0)
+    tspan = (0.0, 1.0)
     prob_mm = ODEProblem(f, u₀, tspan)
     ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
 
@@ -25,13 +25,13 @@ Random.seed!(100)
     differential_vars = [true, false]
     prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
     chain = Lux.Chain(Lux.Dense(1, 15, cos), Lux.Dense(15, 15, sin), Lux.Dense(15, 2))
-    opt = OptimizationOptimisers.Adam(0.1)
-    alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
+    opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
+    alg = NNDAE(chain, opt; autodiff = false)
 
     sol = solve(prob,
-        alg, verbose = false, dt = 1 / 100.0f0,
-        maxiters = 3000, abstol = 1.0f-10)
-    @test ground_sol(0:(1 / 100):1)≈sol atol=0.4
+        alg, verbose = false, dt = 1 / 100.0,
+        maxiters = 3000, abstol = 1e-10)
+    @test reduce(hcat, ground_sol(0:(1 / 100):1).u)≈reduce(hcat, sol.u) rtol=1e-1
 end
 
 @testset "Example 2" begin
@@ -44,7 +44,7 @@ end
          0 1]
     u₀ = [0.0, 0.0]
     du₀ = [0.0, 0.0]
-    tspan = (0.0f0, pi / 2.0f0)
+    tspan = (0.0, pi / 2.0)
     f = ODEFunction(example2, mass_matrix = M)
     prob_mm = ODEProblem(f, u₀, tspan)
     ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
@@ -57,8 +57,97 @@ end
     alg = NNDAE(chain, OptimizationOptimisers.Adam(0.1); autodiff = false)
 
     sol = solve(prob,
-        alg, verbose = false, dt = 1 / 100.0f0,
-        maxiters = 3000, abstol = 1.0f-10)
+        alg, verbose = false, dt = 1 / 100.0,
+        maxiters = 3000, abstol = 1e-10)
+
+    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
+end
+
+@testset "WeightedIntervalTraining" begin
+    function example2(du, u, p, t)
+        du[1] = u[1] - t
+        du[2] = u[2] - t
+        nothing
+    end
+    M = [0.0 0.0
+         0.0 1.0]
+    u₀ = [0.0, 0.0]
+    du₀ = [0.0, 0.0]
+    tspan = (0.0, pi / 2.0)
+    f = ODEFunction(example2, mass_matrix = M)
+    prob_mm = ODEProblem(f, u₀, tspan)
+    ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
+
+    example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
+    differential_vars = [false, true]
+    prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
+    chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
+    opt = OptimizationOptimisers.Adam(0.1)
+    weights = [0.7, 0.2, 0.1]
+    points = 200
+    alg = NNDAE(chain, opt;
+        strategy = WeightedIntervalTraining(weights, points), autodiff = false)
+
+    sol = solve(prob,
+        alg, verbose = false, dt = 1 / 100.0,
+        maxiters = 3000, abstol = 1e-10)
+
+    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
+end
+
+@testset "StochasticTraining" begin
+    function example2(du, u, p, t)
+        du[1] = u[1] - t
+        du[2] = u[2] - t
+        nothing
+    end
+    M = [0.0 0.0
+         0.0 1.0]
+    u₀ = [0.0, 0.0]
+    du₀ = [0.0, 0.0]
+    tspan = (0.0, pi / 2.0)
+    f = ODEFunction(example2, mass_matrix = M)
+    prob_mm = ODEProblem(f, u₀, tspan)
+    ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
+
+    example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
+    differential_vars = [false, true]
+    prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
+    chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
+    opt = OptimizationOptimisers.Adam(0.1)
+    alg = NeuralPDE.NNDAE(chain, opt;
+        strategy = NeuralPDE.StochasticTraining(1000), autodiff = false)
+    sol = solve(prob,
+        alg, verbose = false, dt = 1 / 100.0,
+        maxiters = 3000, abstol = 1e-10)
+    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
+end
+
+@testset "QuadratureTraining" begin
+    function example2(du, u, p, t)
+        du[1] = u[1] - t
+        du[2] = u[2] - t
+        nothing
+    end
+    M = [0.0 0.0
+         0.0 1.0]
+    u₀ = [0.0, 0.0]
+    du₀ = [0.0, 0.0]
+    tspan = (0.0, pi / 2.0)
+    f = ODEFunction(example2, mass_matrix = M)
+    prob_mm = ODEProblem(f, u₀, tspan)
+    ground_sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
 
-    @test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4
+    example = (du, u, p, t) -> [u[1] - t - du[1], u[2] - t - du[2]]
+    differential_vars = [false, true]
+    prob = DAEProblem(example, du₀, u₀, tspan; differential_vars = differential_vars)
+    chain = Lux.Chain(Lux.Dense(1, 15, Lux.σ), Lux.Dense(15, 2))
+    opt = OptimizationOptimJL.BFGS(linesearch = BackTracking())
+    # NOTE(review): no `strategy` is given and `dt` is passed to `solve`; the default
+    # selection in `__solve` appears to pick GridTraining whenever `dt` is set and
+    # only falls back to QuadratureTraining when it is not. Confirm this testset
+    # actually exercises the strategy it is named after.
+    alg = NeuralPDE.NNDAE(chain, opt; autodiff = false)
+    sol = solve(prob, alg, verbose = true, maxiters = 6000, abstol = 1e-10, dt = 1 / 100.0)
+    @test reduce(hcat, ground_sol(0:(1 / 100):(pi / 2.0)).u)≈reduce(hcat, sol.u) rtol=1e-2
 end