From 903227605a7174700ab715d016ebfcbfc01260d8 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 18 Oct 2024 16:20:51 -0400 Subject: [PATCH] proposing fix by changing variable names --- src/simulation.jl | 12 ++++++------ test/simulation.jl | 21 +++++++++++++++------ test/simulation_core.jl | 6 +++--- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/simulation.jl b/src/simulation.jl index c0b65288..6ddf3d26 100644 --- a/src/simulation.jl +++ b/src/simulation.jl @@ -314,8 +314,8 @@ function get_vars_code(d::SummationDecapode, vars::Vector{Symbol}, ::Type{statee # TODO: we should fix that upstream so that we don't need this. line = @match s_type begin :Literal => :($s = $(parse(stateeltype, String(s)))) - :Constant => :($s = p.$s) - :Parameter => :($s = (p.$s)(t)) + :Constant => :($s = __p__.$s) + :Parameter => :($s = (__p__.$s)(t)) _ => hook_GVC_get_form(s, s_type, code_target) # ! WARNING: This assumes a form # _ => throw(InvalidDecaTypeException(s, s_type)) # TODO: Use this for invalid types end @@ -326,7 +326,7 @@ end # TODO: Expand on this to be able to handle vector and ComponentArrays inputs function hook_GVC_get_form(var_name::Symbol, var_type::Symbol, ::Union{CPUBackend, CUDABackend}) - return :($var_name = u.$var_name) + return :($var_name = __u__.$var_name) end """ @@ -354,7 +354,7 @@ is the name of the variable whose data will be stored and a code target. """ function hook_STC_settvar(state_name::Symbol, tgt_name::Symbol, ::Union{CPUBackend, CUDABackend}) ssymb = QuoteNode(state_name) - return :(setproperty!(du, $ssymb, $tgt_name)) + return :(setproperty!(__du__, $ssymb, $tgt_name)) end const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+, @@ -546,7 +546,7 @@ This hook is passed in `cache_exprs` which is the collection of exprs to be past `AllocVecCall` that stores information about the allocated vector and a code target. """ function hook_PPVA_data_handle!(cache_exprs::Vector{Expr}, alloc_vec::AllocVecCall, ::CPUBackend) - line = :($(alloc_vec.name) = (Decapodes.get_tmp($(Symbol(:__,alloc_vec.name)), u))) + line = :($(alloc_vec.name) = (Decapodes.get_tmp($(Symbol(:__,alloc_vec.name)), __u__))) push!(cache_exprs, line) end @@ -759,7 +759,7 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension $func_defs $cont_defs $vect_defs - f(du, u, p, t) = begin + f(__du__, __u__, __p__, t) = begin $vars $data $(equations...) diff --git a/test/simulation.jl b/test/simulation.jl index d3666fba..306b2bc8 100644 --- a/test/simulation.jl +++ b/test/simulation.jl @@ -102,11 +102,11 @@ DiffusionWithLiteral = @decapode begin end # Verify the variable accessors. -@test Decapodes.get_vars_code(DiffusionWithConstant, [:k], Float64, CPUTarget()).args[2] == :(k = p.k) +@test Decapodes.get_vars_code(DiffusionWithConstant, [:k], Float64, CPUTarget()).args[2] == :(k = __p__.k) @test infer_state_names(DiffusionWithConstant) == [:C, :k] @test infer_state_names(DiffusionWithParameter) == [:C, :k] -@test Decapodes.get_vars_code(DiffusionWithParameter, [:k], Float64, CPUTarget()).args[2] == :(k = p.k(t)) +@test Decapodes.get_vars_code(DiffusionWithParameter, [:k], Float64, CPUTarget()).args[2] == :(k = __p__.k(t)) @test infer_state_names(DiffusionWithLiteral) == [:C] # TODO: Fix proper Expr equality, the Float64 does not equate here @@ -682,6 +682,15 @@ end return op end + # tests that there is no variable shadowing for u and p + NoShadow = @decapode begin + u::Form0 + v::Form0 + end + symsim = gensim(NoShadow) + sim_NS = eval(symsim) + @test sim_NS(d_rect, generate, DiagonalHodge()) isa Any + HeatTransfer = @decapode begin (HT, Tₛ)::Form0 (D, cosϕᵖ, cosϕᵈ)::Constant @@ -770,11 +779,11 @@ for prealloc in [false, true] bytes = @allocated f(du, u₀, p, (0,1.0)) if prealloc - @test nallocs == 3 - @test bytes == 80 + @test nallocs == 2 + @test bytes == 32 elseif !prealloc - @test nallocs == 5 - @test bytes == 400 + @test nallocs == 6 + @test bytes == 352 end end diff --git a/test/simulation_core.jl b/test/simulation_core.jl index 6aa2a9d4..030b7a60 100644 --- a/test/simulation_core.jl +++ b/test/simulation_core.jl @@ -249,14 +249,14 @@ import Decapodes: get_vars_code, AmbiguousNameException let d = @decapode begin end inputs = [:C] add_parts!(d, :Var, 1, name=inputs, type=[:Constant]) - @test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(C = p.C) + @test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(C = __p__.C) end # Test that parameters parse correctly let d = @decapode begin end inputs = [:P] add_parts!(d, :Var, 1, name=inputs, type=[:Parameter]) - @test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(P = p.P(t)) + @test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(P = __p__.P(t)) end # TODO: Remove when Literals are not parsed as symbols anymore @@ -272,7 +272,7 @@ import Decapodes: get_vars_code, AmbiguousNameException let d = @decapode begin end inputs = [:F] add_parts!(d, :Var, 1, name=inputs, type=[form]) - @test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(F = u.F) + @test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(F = __u__.F) end end