Skip to content

Commit

Permalink
Formatted Code
Browse files Browse the repository at this point in the history
  • Loading branch information
hippyhippohops committed May 6, 2024
1 parent 1ed7683 commit 2f9db68
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
14 changes: 7 additions & 7 deletions src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,20 @@ function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p,
return loss
end

function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
function generate_loss(
strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
differential_vars::AbstractVector)
autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
minT = tspan[1]
maxT = tspan[2]

weights = strategy.weights ./ sum(strategy.weights)

N = length(weights)
points = strategy.points

difference = (maxT-minT)/N
difference = (maxT - minT) / N

data = Float64[]
for (index, item) in enumerate(weights)
temp_data = rand(1, trunc(Int, points * item)) .* difference .+ minT .+
Expand All @@ -102,12 +103,11 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo
ts = data

function loss(θ, _)
sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
sum(inner_loss(phi, f, autodiff, ts, θ, p, differential_vars))
end
return loss
end


function SciMLBase.__solve(prob::SciMLBase.AbstractDAEProblem,
alg::NNDAE,
args...;
Expand Down
6 changes: 2 additions & 4 deletions test/NNDAE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,12 @@ end
opt = OptimizationOptimisers.Adam(0.1)
weights = [0.7, 0.2, 0.1]
points = 200
alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1), strategy = NeuralPDE.WeightedIntervalTraining(weights, points); autodiff = false)
alg = NeuralPDE.NNDAE(chain, OptimizationOptimisers.Adam(0.1),
strategy = NeuralPDE.WeightedIntervalTraining(weights, points); autodiff = false)

sol = solve(prob,
alg, verbose = false, dt = 1 / 100.0f0,
maxiters = 3000, abstol = 1.0f-10)

@test ground_sol(0:(1 / 100):(pi / 2))sol atol=0.4
end



0 comments on commit 2f9db68

Please sign in to comment.