Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proposing fix by changing variable names #274

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ function get_vars_code(d::SummationDecapode, vars::Vector{Symbol}, ::Type{statee
# TODO: we should fix that upstream so that we don't need this.
line = @match s_type begin
:Literal => :($s = $(parse(stateeltype, String(s))))
:Constant => :($s = p.$s)
:Parameter => :($s = (p.$s)(t))
:Constant => :($s = __p__.$s)
:Parameter => :($s = (__p__.$s)(t))
_ => hook_GVC_get_form(s, s_type, code_target) # ! WARNING: This assumes a form
# _ => throw(InvalidDecaTypeException(s, s_type)) # TODO: Use this for invalid types
end
Expand All @@ -326,7 +326,7 @@ end

# TODO: Expand on this to be able to handle vector and ComponentArrays inputs
function hook_GVC_get_form(var_name::Symbol, var_type::Symbol, ::Union{CPUBackend, CUDABackend})
return :($var_name = u.$var_name)
return :($var_name = __u__.$var_name)
end

"""
Expand Down Expand Up @@ -354,7 +354,7 @@ is the name of the variable whose data will be stored and a code target.
"""
function hook_STC_settvar(state_name::Symbol, tgt_name::Symbol, ::Union{CPUBackend, CUDABackend})
ssymb = QuoteNode(state_name)
return :(setproperty!(du, $ssymb, $tgt_name))
return :(setproperty!(__du__, $ssymb, $tgt_name))
end

const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+,
Expand Down Expand Up @@ -546,7 +546,7 @@ This hook is passed in `cache_exprs` which is the collection of exprs to be past
`AllocVecCall` that stores information about the allocated vector and a code target.
"""
function hook_PPVA_data_handle!(cache_exprs::Vector{Expr}, alloc_vec::AllocVecCall, ::CPUBackend)
line = :($(alloc_vec.name) = (Decapodes.get_tmp($(Symbol(:__,alloc_vec.name)), u)))
line = :($(alloc_vec.name) = (Decapodes.get_tmp($(Symbol(:__,alloc_vec.name)), __u__)))
push!(cache_exprs, line)
end

Expand Down Expand Up @@ -759,7 +759,7 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
$func_defs
$cont_defs
$vect_defs
f(du, u, p, t) = begin
f(__du__, __u__, __p__, t) = begin
$vars
$data
$(equations...)
Expand Down
21 changes: 15 additions & 6 deletions test/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ DiffusionWithLiteral = @decapode begin
end

# Verify the variable accessors.
@test Decapodes.get_vars_code(DiffusionWithConstant, [:k], Float64, CPUTarget()).args[2] == :(k = p.k)
@test Decapodes.get_vars_code(DiffusionWithConstant, [:k], Float64, CPUTarget()).args[2] == :(k = __p__.k)
@test infer_state_names(DiffusionWithConstant) == [:C, :k]

@test infer_state_names(DiffusionWithParameter) == [:C, :k]
@test Decapodes.get_vars_code(DiffusionWithParameter, [:k], Float64, CPUTarget()).args[2] == :(k = p.k(t))
@test Decapodes.get_vars_code(DiffusionWithParameter, [:k], Float64, CPUTarget()).args[2] == :(k = __p__.k(t))

@test infer_state_names(DiffusionWithLiteral) == [:C]
# TODO: Fix proper Expr equality, the Float64 does not equate here
Expand Down Expand Up @@ -682,6 +682,15 @@ end
return op
end

# tests that there is no variable shadowing for u and p
NoShadow = @decapode begin
u::Form0
v::Form0
quffaro marked this conversation as resolved.
Show resolved Hide resolved
end
symsim = gensim(NoShadow)
sim_NS = eval(symsim)
@test sim_NS(d_rect, generate, DiagonalHodge()) isa Any

HeatTransfer = @decapode begin
(HT, Tₛ)::Form0
(D, cosϕᵖ, cosϕᵈ)::Constant
Expand Down Expand Up @@ -770,11 +779,11 @@ for prealloc in [false, true]
bytes = @allocated f(du, u₀, p, (0,1.0))

if prealloc
@test nallocs == 3
@test bytes == 80
@test nallocs == 2
lukem12345 marked this conversation as resolved.
Show resolved Hide resolved
@test bytes == 32
elseif !prealloc
@test nallocs == 5
@test bytes == 400
@test nallocs == 6
@test bytes == 352
end
end

Expand Down
6 changes: 3 additions & 3 deletions test/simulation_core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,14 @@ import Decapodes: get_vars_code, AmbiguousNameException
let d = @decapode begin end
inputs = [:C]
add_parts!(d, :Var, 1, name=inputs, type=[:Constant])
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(C = p.C)
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(C = __p__.C)
end

# Test that parameters parse correctly
let d = @decapode begin end
inputs = [:P]
add_parts!(d, :Var, 1, name=inputs, type=[:Parameter])
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(P = p.P(t))
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(P = __p__.P(t))
end

# TODO: Remove when Literals are not parsed as symbols anymore
Expand All @@ -272,7 +272,7 @@ import Decapodes: get_vars_code, AmbiguousNameException
let d = @decapode begin end
inputs = [:F]
add_parts!(d, :Var, 1, name=inputs, type=[form])
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(F = u.F)
@test get_vars_code(d, inputs, Float64, CPUTarget()).args[begin+1] == :(F = __u__.F)
end
end

Expand Down
Loading