From 2f9db684af12abb78512d6a746396492115f6197 Mon Sep 17 00:00:00 2001 From: hippyhippohops Date: Sun, 5 May 2024 23:39:26 -0500 Subject: [PATCH] Formatted Code --- src/dae_solve.jl | 14 +++++++------- test/NNDAE_tests.jl | 6 ++---- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/dae_solve.jl b/src/dae_solve.jl index df18ad08d9..5c2936c4d3 100644 --- a/src/dae_solve.jl +++ b/src/dae_solve.jl @@ -79,19 +79,20 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, return loss end -function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, - differential_vars::AbstractVector) +function generate_loss( + strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p, + differential_vars::AbstractVector) autodiff && throw(ArgumentError("autodiff not supported for GridTraining.")) minT = tspan[1] maxT = tspan[2] - + weights = strategy.weights ./ sum(strategy.weights) N = length(weights) points = strategy.points - difference = (maxT-minT)/N - + difference = (maxT - minT) / N + data = Float64[] for (index, item) in enumerate(weights) temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+ @@ -102,12 +103,11 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo ts = data function loss(θ, _) - sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) + sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars)) end return loss end - function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem, alg::NNDAE, args...; diff --git a/test/NNDAE_tests.jl b/test/NNDAE_tests.jl index e5930ec063..1f9a31438d 100644 --- a/test/NNDAE_tests.jl +++ b/test/NNDAE_tests.jl @@ -85,7 +85,8 @@ end opt = OptimizationOptimisers.Adam(0.1) weights = [0.7, 0.2, 0.1] points = 200 - alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1), strategy = NeuralPDE.WeightedIntervalTraining(weights, points); autodiff = false) + alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1), + strategy = NeuralPDE.WeightedIntervalTraining(weights, points); autodiff = false) sol = solve(prob, alg, verbose = false, dt = 1 / 100.0f0, @@ -93,6 +94,3 @@ end @test ground_sol(0:(1 / 100):(pi / 2))≈sol atol=0.4 end - - -