Skip to content

Commit

Permalink
proposing fix by changing variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
quffaro committed Oct 18, 2024
1 parent 586915f commit 9032276
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
12 changes: 6 additions & 6 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ function get_vars_code(d::SummationDecapode, vars::Vector{Symbol}, ::Type{statee
# TODO: we should fix that upstream so that we don't need this.
line = @match s_type begin
:Literal => :($s = $(parse(stateeltype, String(s))))
:Constant => :($s = p.$s)
:Parameter => :($s = (p.$s)(t))
:Constant => :($s = __p__.$s)
:Parameter => :($s = (__p__.$s)(t))
_ => hook_GVC_get_form(s, s_type, code_target) # ! WARNING: This assumes a form
# _ => throw(InvalidDecaTypeException(s, s_type)) # TODO: Use this for invalid types
end
Expand All @@ -326,7 +326,7 @@ end

# TODO: Expand on this to be able to handle vector and ComponentArrays inputs
function hook_GVC_get_form(var_name::Symbol, var_type::Symbol, ::Union{CPUBackend, CUDABackend})
return :($var_name = u.$var_name)
return :($var_name = __u__.$var_name)
end

"""
Expand Down Expand Up @@ -354,7 +354,7 @@ is the name of the variable whose data will be stored and a code target.
"""
function hook_STC_settvar(state_name::Symbol, tgt_name::Symbol, ::Union{CPUBackend, CUDABackend})
ssymb = QuoteNode(state_name)
return :(setproperty!(du, $ssymb, $tgt_name))
return :(setproperty!(__du__, $ssymb, $tgt_name))
end

const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+,
Expand Down Expand Up @@ -546,7 +546,7 @@ This hook is passed in `cache_exprs` which is the collection of exprs to be past
`AllocVecCall` that stores information about the allocated vector and a code target.
"""
function hook_PPVA_data_handle!(cache_exprs::Vector{Expr}, alloc_vec::AllocVecCall, ::CPUBackend)
line = :($(alloc_vec.name) = (Decapodes.get_tmp($(Symbol(:__,alloc_vec.name)), u)))
line = :($(alloc_vec.name) = (Decapodes.get_tmp($(Symbol(:__,alloc_vec.name)), __u__)))
push!(cache_exprs, line)
end

Expand Down Expand Up @@ -759,7 +759,7 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
$func_defs
$cont_defs
$vect_defs
f(du, u, p, t) = begin
f(__du__, __u__, __p__, t) = begin
$vars
$data
$(equations...)
Expand Down
21 changes: 15 additions & 6 deletions test/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ DiffusionWithLiteral = @decapode begin
end

# Verify the variable accessors.
@test Decapodes.get_vars_code(DiffusionWithConstant, [:k], Float64, CPUTarget()).args[2] == :(k = p.k)
@test Decapodes.get_vars_code(DiffusionWithConstant, [:k], Float64, CPUTarget()).args[2] == :(k = __p__.k)
@test infer_state_names(DiffusionWithConstant) == [:C, :k]

@test infer_state_names(DiffusionWithParameter) == [:C, :k]
@test Decapodes.get_vars_code(DiffusionWithParameter, [:k], Float64, CPUTarget()).args[2] == :(k = p.k(t))
@test Decapodes.get_vars_code(DiffusionWithParameter, [:k], Float64, CPUTarget()).args[2] == :(k = __p__.k(t))

@test infer_state_names(DiffusionWithLiteral) == [:C]
# TODO: Fix proper Expr equality, the Float64 does not equate here
Expand Down Expand Up @@ -682,6 +682,15 @@ end
return op
end

# tests that there is no variable shadowing for u and p
NoShadow = @decapode begin
u::Form0
v::Form0
end
symsim = gensim(NoShadow)
sim_NS = eval(symsim)
@test sim_NS(d_rect, generate, DiagonalHodge()) isa Any

HeatTransfer = @decapode begin
(HT, Tₛ)::Form0
(D, cosϕᵖ, cosϕᵈ)::Constant
Expand Down Expand Up @@ -770,11 +779,11 @@ for prealloc in [false, true]
bytes = @allocated f(du, u₀, p, (0,1.0))

if prealloc
@test nallocs == 3
@test bytes == 80
@test nallocs == 2
@test bytes == 32
elseif !prealloc
@test nallocs == 5
@test bytes == 400
@test nallocs == 6
@test bytes == 352
end
end

Expand Down
6 changes: 3 additions & 3 deletions test/simulation_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ import Decapodes: get_vars_code, AmbiguousNameException
let d = @decapode begin end
inputs = [:C]
add_parts!(d, :Var, 1, name=inputs, type=[:Constant])
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(C = p.C)
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(C = __p__.C)
end

# Test that parameters parse correctly
let d = @decapode begin end
inputs = [:P]
add_parts!(d, :Var, 1, name=inputs, type=[:Parameter])
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(P = p.P(t))
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(P = __p__.P(t))
end

# TODO: Remove when Literals are not parsed as symbols anymore
Expand All @@ -272,7 +272,7 @@ import Decapodes: get_vars_code, AmbiguousNameException
let d = @decapode begin end
inputs = [:F]
add_parts!(d, :Var, 1, name=inputs, type=[form])
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(F = u.F)
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(F = __u__.F)
end
end

Expand Down

0 comments on commit 9032276

Please sign in to comment.