Skip to content

Commit

Permalink
Fixed making operator sets global consts
Browse files Browse the repository at this point in the history
  • Loading branch information
GeorgeR227 committed Nov 1, 2024
1 parent 1a170b6 commit c709926
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function default_dec_generate(sd::HasDeltaSet, my_symbol::Symbol, hodge::Discret

op = @match my_symbol begin

:plus => (+)
# :plus => (+)
:(-) || :neg => x -> -1 .* x
:ln => (x -> log.(x))
# Musical Isomorphisms
Expand Down
63 changes: 31 additions & 32 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ Base.showerror(io::IO, e::InvalidCodeTargetException) = print(io, "Provided code
This creates the symbol to function linking for the simulation output. Those run through the `default_dec` backend
expect both an in-place and an out-of-place variant in that order. User defined operations only support out-of-place.
"""
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Set{Symbol}, nonoptimizable_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec_operators::Set{Symbol}, code_target::AbstractGenerationTarget)
defined_ops = deepcopy(con_dec_operators)

defs = quote end
Expand All @@ -255,7 +255,7 @@ function compile_env(d::SummationDecapode, dec_matrices::Vector{Symbol}, con_dec
end

# These are nonoptimizable default DEC functions.
for op in nonoptimizable_operators
for op in non_optimizable(code_target)
op in defined_ops && continue

quote_op = QuoteNode(op)
Expand Down Expand Up @@ -391,15 +391,15 @@ const PROMOTE_ARITHMETIC_MAP = Dict(:(+) => :.+,
:.= => :.=)

"""
compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, matrix_optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
Function that compiles the computation body. `d` is the input Decapode, `inputs` is a vector of state variables and literals,
`alloc_vec` should be empty when passed in, `matrix_optimizable_dec_operators` is a collection of all DEC operator symbols that can use special
`alloc_vec` should be empty when passed in, `optimizable_dec_operators` is a collection of all DEC operator symbols that can use special
in-place methods, `dimension` is the dimension of the problem (usually 1 or 2), `stateeltype` is the type of the state elements
(usually Float32 or Float64), `code_target` determines what architecture the code is compiled for (either CPU or CUDA), and `preallocate`
which is set to `true` by default and determines if intermediate results can be preallocated..
"""
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, matrix_optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, preallocate::Bool)
# Get the Vars of the inputs (probably state Vars).
visited_Var = falses(nparts(d, :Var))

Expand Down Expand Up @@ -436,7 +436,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve

# TODO: Check to see if this is a DEC operator
if preallocate && is_form(d, t)
if operator in matrix_optimizable_dec_operators
if operator in optimizable_dec_operators
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = add_stub(GENSIM_INPLACE_STUB, operator)
push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -482,7 +482,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
elseif operator in matrix_optimizable_dec_operators
elseif operator in optimizable_dec_operators
operator = add_stub(GENSIM_INPLACE_STUB, operator)
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -617,13 +617,13 @@ function infer_overload_compiler!(d::SummationDecapode, dimension::Int)
end

"""
init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, matrix_optimizable_dec_operators::Set{Symbol})
init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
Collects all DEC operators that are concrete matrices.
"""
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, matrix_optimizable_dec_operators::Set{Symbol})
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
for op_name in vcat(d[:op1], d[:op2])
if op_name in matrix_optimizable_dec_operators
if op_name in optimizable_dec_operators
push!(dec_matrices, op_name)
end
end
Expand Down Expand Up @@ -698,15 +698,28 @@ end

Base.showerror(io::IO, e::UnsupportedStateeltypeException) = print(io, "Decapodes does not support state element types as $(e.type), only Float32 or Float64")

const MATRIX_OPTIMIZABLE_DEC_OPERATORS = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹,
:d₀, :d₁, :dual_d₀, :d̃₀, :dual_d₁, :d̃₁,
:avg₀₁])

const NONMATRIX_OPTIMIZABLE_DEC_OPERATORS = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])

const NON_OPTIMIZABLE_CPU_OPERATORS = Set([:♯ᵖᵖ, :♯ᵈᵈ, :♭ᵈᵖ])
const NON_OPTIMIZABLE_CUDA_OPERATORS = Set{Symbol}()

non_optimizable(::AbstractGenerationTarget) = NON_OPTIMIZABLE_CPU_OPERATORS
non_optimizable(::CPUBackend) = NON_OPTIMIZABLE_CPU_OPERATORS
non_optimizable(::CUDABackend) = NON_OPTIMIZABLE_CUDA_OPERATORS

"""
gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget(), preallocate::Bool = true)
Generates the entire code body for the simulation function. The returned simulation function can then be combined with a mesh, provided by `CombinatorialSpaces`, and a function describing symbol
Generates the entire code body for the simulation function. The returned simulation function can then be combined with a mesh, provided by `CombinatorialSpaces`, and a function describing symbol
to operator mappings to return a simulator that can be used to solve the represented equations given initial conditions.
**Arguments:**
`user_d`: The user passed Decapode for which simulation code will be generated. (This is not modified)
`user_d`: The user passed Decapode for which simulation code will be generated. (This is not modified)
`input_vars` is the collection of variables whose values are known at the beginning of the simulation. (Defaults to all state variables and literals in the Decapode)
Expand Down Expand Up @@ -752,34 +765,20 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
open_operators!(gen_d, dimension = dimension)
infer_overload_compiler!(gen_d, dimension)

# This will generate all of the fundemental DEC operators present
matrix_optimizable_dec_operators = Set([:₀, :₁, :₂, :₀⁻¹, :₂⁻¹,
:d₀, :d₁, :dual_d₀, :d̃₀, :dual_d₁, :d̃₁,
:avg₀₁])
nonmatrix_optimizable_dec_operators = Set([:₁⁻¹, :₀₁, :₁₀, :₁₁, :₀₂, :₂₀])

nonoptimizable_cpu_operators = Set([:♯ᵖᵖ, :♯ᵈᵈ, :♭ᵈᵖ])
nonoptimizable_cuda_operators = Set{Symbol}()
nonoptimizable_operators = @match code_target begin
::CPUBackend => nonoptimizable_cpu_operators
::CUDABackend => nonoptimizable_cuda_operators
_ => throw(InvalidCodeTargetException(code_target))
end

init_dec_matrices!(gen_d, dec_matrices, union(matrix_optimizable_dec_operators, nonmatrix_optimizable_dec_operators))
init_dec_matrices!(gen_d, dec_matrices, union(MATRIX_OPTIMIZABLE_DEC_OPERATORS, NONMATRIX_OPTIMIZABLE_DEC_OPERATORS))

# This contracts matrices together into a single matrix
contracted_dec_operators = Set{Symbol}()
contract_operators!(gen_d, white_list = matrix_optimizable_dec_operators)
contract_operators!(gen_d, white_list = MATRIX_OPTIMIZABLE_DEC_OPERATORS)
cont_defs = link_contract_operators(gen_d, contracted_dec_operators, stateeltype, code_target)

optimizable_dec_operators = union(matrix_optimizable_dec_operators, contracted_dec_operators, nonmatrix_optimizable_dec_operators)
optimizable_dec_operators = union(MATRIX_OPTIMIZABLE_DEC_OPERATORS, contracted_dec_operators, NONMATRIX_OPTIMIZABLE_DEC_OPERATORS)

# Compilation of the simulation
equations = compile(gen_d, input_vars, alloc_vectors, optimizable_dec_operators, dimension, stateeltype, code_target, preallocate)
data = post_process_vector_allocs(alloc_vectors, code_target)

func_defs = compile_env(gen_d, dec_matrices, contracted_dec_operators, nonoptimizable_operators, code_target)
func_defs = compile_env(gen_d, dec_matrices, contracted_dec_operators, code_target)
vect_defs = compile_var(alloc_vectors)

quote
Expand Down
23 changes: 11 additions & 12 deletions test/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,19 +655,19 @@ end
add_edges!(primal_line, [1,2], [2,3])
line = generate_dual_mesh(primal_line)

# Testing Diagonal inverse hodge 1
DiagonalInvHodge1 = @decapode begin
A::DualForm1
# Testing Diagonal inverse hodge 1
DiagonalInvHodge1 = @decapode begin
A::DualForm1

B == ∂ₜ(A)
B == (A)
end
g = gensim(DiagonalInvHodge1)
@test gensim(DiagonalInvHodge1).args[2].args[2].args[3].args[2].args[2].args[3].value == :₁⁻¹
sim = eval(g)
B == ∂ₜ(A)
B == (A)
end
g = gensim(DiagonalInvHodge1)
@test g.args[2].args[2].args[3].args[2].args[2].args[3].value == :₁⁻¹
sim = eval(g)

# Test that no error is thrown here
f = sim(line, default_dec_generate, DiagonalHodge())
# TODO: Error is being thrown here
# @test f = sim(line, default_dec_generate, DiagonalHodge()) isa Any
end

@testset "GenSim Compilation" begin
Expand Down Expand Up @@ -829,4 +829,3 @@ haystack = string(gensim(LargeSum))
@test occursin(needle, haystack)

end

0 comments on commit c709926

Please sign in to comment.