From 21946f750f577772065370e3e9ad452452999377 Mon Sep 17 00:00:00 2001
From: Priya Nagda
Date: Sat, 2 Sep 2023 22:58:39 +0530
Subject: [PATCH] Change `iteration` type from size-1 `Vector{Int}` to
 `Ref{Int}`

---
 .../test/adaptive_loss_log_tests.jl | 18 +++++++++---------
 src/adaptive_losses.jl              | 16 ++++++++--------
 src/discretize.jl                   | 24 ++++++++++++------------
 src/pinn_types.jl                   |  8 ++++----
 test/adaptive_loss_tests.jl         |  8 ++++----
 5 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/lib/NeuralPDELogging/test/adaptive_loss_log_tests.jl b/lib/NeuralPDELogging/test/adaptive_loss_log_tests.jl
index 81990377ae..1feb1b84ac 100644
--- a/lib/NeuralPDELogging/test/adaptive_loss_log_tests.jl
+++ b/lib/NeuralPDELogging/test/adaptive_loss_log_tests.jl
@@ -45,7 +45,7 @@ function test_2d_poisson_equation_adaptive_loss(adaptive_loss, run, outdir, hasl
     domains = [x ∈ Interval(0.0, 1.0),
         y ∈ Interval(0.0, 1.0)]
 
-    iteration = [0]
+    iteration = Ref(0)
     discretization = NeuralPDE.PhysicsInformedNN(chain_,
                                                  strategy_;
                                                  adaptive_loss = adaptive_loss,
@@ -63,25 +63,25 @@
                         (length(xs), length(ys)))
 
     callback = function (p, l)
-        iteration[1] += 1
-        if iteration[1] % 100 == 0
-            @info "Current loss is: $l, iteration is $(iteration[1])"
+        iteration[] += 1
+        if iteration[] % 100 == 0
+            @info "Current loss is: $l, iteration is $(iteration[])"
         end
         if haslogger
-            log_value(logger, "outer_error/loss", l, step = iteration[1])
-            if iteration[1] % 30 == 0
+            log_value(logger, "outer_error/loss", l, step = iteration[])
+            if iteration[] % 30 == 0
                 u_predict = reshape([first(phi([x, y], p)) for x in xs for y in ys],
                                     (length(xs), length(ys)))
                 diff_u = abs.(u_predict .- u_real)
                 total_diff = sum(diff_u)
-                log_value(logger, "outer_error/total_diff", total_diff, step = iteration[1])
+                log_value(logger, "outer_error/total_diff", total_diff, step = iteration[])
                 total_u = sum(abs.(u_real))
                 total_diff_rel = total_diff / total_u
                 log_value(logger, "outer_error/total_diff_rel", total_diff_rel,
-                          step = iteration[1])
+                          step = iteration[])
                 total_diff_sq = sum(diff_u .^ 2)
                 log_value(logger, "outer_error/total_diff_sq", total_diff_sq,
-                          step = iteration[1])
+                          step = iteration[])
             end
         end
         return false
diff --git a/src/adaptive_losses.jl b/src/adaptive_losses.jl
index b37023da7c..4da3ee6dfe 100644
--- a/src/adaptive_losses.jl
+++ b/src/adaptive_losses.jl
@@ -126,7 +126,7 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,
     adaloss_T = eltype(adaloss.pde_loss_weights)
 
     function run_loss_gradients_adaptive_loss(θ, pde_losses, bc_losses)
-        if iteration[1] % adaloss.reweight_every == 0
+        if iteration[] % adaloss.reweight_every == 0
             # the paper assumes a single pde loss function, so here we grab the maximum of the maximums of each pde loss function
             pde_grads_maxes = [maximum(abs.(Zygote.gradient(pde_loss_function, θ)[1]))
                                for pde_loss_function in pde_loss_functions]
@@ -143,14 +143,14 @@
                                      (1 .- weight_change_inertia) .*
                                      bc_loss_weights_proposed
             logscalar(pinnrep.logger, pde_grads_max, "adaptive_loss/pde_grad_max",
-                      iteration[1])
+                      iteration[])
             logvector(pinnrep.logger, pde_grads_maxes, "adaptive_loss/pde_grad_maxes",
-                      iteration[1])
+                      iteration[])
             logvector(pinnrep.logger, bc_grads_mean, "adaptive_loss/bc_grad_mean",
-                      iteration[1])
+                      iteration[])
             logvector(pinnrep.logger, adaloss.bc_loss_weights,
                       "adaptive_loss/bc_loss_weights",
-                      iteration[1])
+                      iteration[])
         end
         nothing
     end
@@ -244,15 +244,15 @@
function generate_adaptive_loss_function(pinnrep::PINNRepresentation, iteration = pinnrep.iteration function run_minimax_adaptive_loss(θ, pde_losses, bc_losses) - if iteration[1] % adaloss.reweight_every == 0 + if iteration[] % adaloss.reweight_every == 0 Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights, -pde_losses) Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses) logvector(pinnrep.logger, adaloss.pde_loss_weights, - "adaptive_loss/pde_loss_weights", iteration[1]) + "adaptive_loss/pde_loss_weights", iteration[]) logvector(pinnrep.logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", - iteration[1]) + iteration[]) end nothing end diff --git a/src/discretize.jl b/src/discretize.jl index 4308a79b4e..09114e8abd 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -595,7 +595,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized # that's why we prefer the user to maintain the increment in the outer loop callback during optimization ChainRulesCore.@ignore_derivatives if self_increment - iteration[1] += 1 + iteration[] += 1 end ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses, @@ -630,33 +630,33 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, weighted_loss_before_additional + weighted_additional_loss_val end - ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0 + ChainRulesCore.@ignore_derivatives begin if iteration[] % log_frequency == 0 logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses", - iteration[1]) - logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1]) + iteration[]) + logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[]) logvector(pinnrep.logger, weighted_pde_losses, "weighted_loss/weighted_pde_losses", - iteration[1]) + iteration[]) logvector(pinnrep.logger, weighted_bc_losses, "weighted_loss/weighted_bc_losses", - iteration[1]) + iteration[]) if !(additional_loss isa Nothing) logscalar(pinnrep.logger, weighted_additional_loss_val, - "weighted_loss/weighted_additional_loss", iteration[1]) + "weighted_loss/weighted_additional_loss", iteration[]) end logscalar(pinnrep.logger, sum_weighted_pde_losses, - "weighted_loss/sum_weighted_pde_losses", iteration[1]) + "weighted_loss/sum_weighted_pde_losses", iteration[]) logscalar(pinnrep.logger, sum_weighted_bc_losses, - "weighted_loss/sum_weighted_bc_losses", iteration[1]) + "weighted_loss/sum_weighted_bc_losses", iteration[]) logscalar(pinnrep.logger, full_weighted_loss, "weighted_loss/full_weighted_loss", - iteration[1]) + iteration[]) logvector(pinnrep.logger, adaloss.pde_loss_weights, "adaptive_loss/pde_loss_weights", - iteration[1]) + iteration[]) logvector(pinnrep.logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", - iteration[1]) + iteration[]) end end return full_weighted_loss diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 3c74022b38..a8733903d5 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -89,7 +89,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN adaptive_loss::ADA logger::LOG log_options::LogOptions - iteration::Vector{Int64} + iteration::Ref{Int64} self_increment::Bool multioutput::Bool kwargs::K @@ -124,10 +124,10 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN _derivative = derivative end 
- if iteration isa Vector{Int64} + if iteration isa Ref{Int64} self_increment = false else - iteration = [1] + iteration = Ref(1) self_increment = true end @@ -228,7 +228,7 @@ mutable struct PINNRepresentation """ The iteration counter used inside the cost function """ - iteration::Vector{Int} + iteration::Ref{Int} """ The initial parameters as provided by the user. If the PDE is a system of PDEs, this will be an array of arrays. If Lux.jl is used, then this is an array of ComponentArrays. diff --git a/test/adaptive_loss_tests.jl b/test/adaptive_loss_tests.jl index bd49783ec4..d3e4a049b7 100644 --- a/test/adaptive_loss_tests.jl +++ b/test/adaptive_loss_tests.jl @@ -37,7 +37,7 @@ function test_2d_poisson_equation_adaptive_loss(adaptive_loss; seed = 60, maxite domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)] - iteration = [0] + iteration = Ref(0) discretization = NeuralPDE.PhysicsInformedNN(chain_, strategy_; adaptive_loss = adaptive_loss, @@ -55,9 +55,9 @@ function test_2d_poisson_equation_adaptive_loss(adaptive_loss; seed = 60, maxite (length(xs), length(ys))) callback = function (p, l) - iteration[1] += 1 - if iteration[1] % 100 == 0 - @info "Current loss is: $l, iteration is $(iteration[1])" + iteration[] += 1 + if iteration[] % 100 == 0 + @info "Current loss is: $l, iteration is $(iteration[])" end return false end
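Note for reviewers, not part of the patch: a minimal sketch of the counter pattern this change adopts. A `Ref{Int}` is read and updated through the empty-bracket dereference `counter[]`, where the old one-element `Vector{Int}` used `counter[1]`. Interpolating or passing the `Ref` itself (e.g. `step = iteration`) hands over the `RefValue` object rather than the integer, which is why every touched call site dereferences it. The names below are illustrative only.

    # Old pattern: one-element Vector{Int}
    old_counter = [0]
    old_counter[1] += 1            # increment via index 1
    @assert old_counter[1] == 1

    # New pattern: Ref{Int} (concretely a Base.RefValue{Int})
    new_counter = Ref(0)
    new_counter[] += 1             # increment via empty-bracket dereference
    @assert new_counter[] == 1

    # In a callback, the Ref must be dereferenced wherever an integer is expected:
    callback = function (p, l)
        new_counter[] += 1
        if new_counter[] % 100 == 0
            # "$(new_counter)" would print Base.RefValue{Int64}(...), not the count
            @info "iteration is $(new_counter[])"
        end
        return false
    end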