diff --git a/src/GridTools.jl b/src/GridTools.jl index 31b80ba..71e911a 100644 --- a/src/GridTools.jl +++ b/src/GridTools.jl @@ -455,6 +455,16 @@ struct Connectivity dims::Integer end +@generated function Base.getindex(conn::Connectivity, row::Union{Integer, Colon}, col::Integer) + if row <: Integer + return :(conn.data[row, col]) + elseif row <: Colon + return :(conn.data[:, col]) + else + throw(ArgumentError("Unsupported index type")) + end +end + # Field operator ---------------------------------------------------------------------- struct FieldOp diff --git a/test/gt2py_fo_exec.jl b/test/gt2py_fo_exec.jl index 2e61c30..3df2814 100644 --- a/test/gt2py_fo_exec.jl +++ b/test/gt2py_fo_exec.jl @@ -22,14 +22,6 @@ end # Utility ---------------------------------------------------------------------------------------------------- -struct ConnectivityData - edge_to_cell_table::Matrix{Integer} - cell_to_edge_table::Matrix{Integer} - E2C_offset_provider::Connectivity - C2E_offset_provider::Connectivity - offset_provider::Dict{String,Connectivity} -end - """ testwrapper(setupfunc::Union{Function,Nothing}, testfunc::Function, args...) @@ -86,7 +78,7 @@ end # Setup ------------------------------------------------------------------------------------------------------ -function setup_simple_connectivity()::ConnectivityData +function setup_simple_connectivity()::Dict{String,Connectivity} edge_to_cell_table = [ [1 -1]; [3 -1]; @@ -120,7 +112,7 @@ function setup_simple_connectivity()::ConnectivityData "E2CDim" => E2C_offset_provider #TODO(lorenzovarese) this is required for the embedded backend (note: python already uses E2C) ) - return ConnectivityData(edge_to_cell_table, cell_to_edge_table, E2C_offset_provider, C2E_offset_provider, offset_provider) + return offset_provider end # Test Definitions ------------------------------------------------------------------------------------------- @@ -169,20 +161,20 @@ function test_fo_nested_if_else(backend::String) @test all(out.data .== collect(22:36)) end -function test_fo_remapping(data::ConnectivityData, backend::String) +function test_fo_remapping(offset_provider::Dict{String,Connectivity}, backend::String) a = Field(Cell, collect(1.0:15.0)) out = Field(Edge, zeros(Float64, 12)) - expected_output = a[data.edge_to_cell_table[:, 1]] # First column of the edge to cell connectivity table + expected_output = a[offset_provider["E2C"][:, 1]] # First column of the edge to cell connectivity table @field_operator function fo_remapping(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64} return a(E2C[1]) end - fo_remapping(a, offset_provider=data.offset_provider, backend=backend, out=out) + fo_remapping(a, offset_provider=offset_provider, backend=backend, out=out) @test all(out.data .== expected_output) end -function test_fo_neighbor_sum(data::ConnectivityData, backend::String) +function test_fo_neighbor_sum(offset_provider::Dict{String,Connectivity}, backend::String) a = Field(Cell, collect(1.0:15.0)) out = Field(Edge, zeros(Float64, 12)) @@ -190,32 +182,32 @@ function test_fo_neighbor_sum(data::ConnectivityData, backend::String) sum_positive_elements(v) = sum(x -> x > 0 ? x : 0, v) # Compute the ground truth manually computing the sum on that dimension - expected_output = a[Integer.(map(sum_positive_elements, eachrow(data.edge_to_cell_table)))] + expected_output = a[Integer.(map(sum_positive_elements, eachrow(offset_provider["E2C"].data)))] @field_operator function fo_neighbor_sum(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64} return neighbor_sum(a(E2C), axis=E2CDim) end - fo_neighbor_sum(a, offset_provider=data.offset_provider, backend=backend, out=out) + fo_neighbor_sum(a, offset_provider=offset_provider, backend=backend, out=out) @test out == expected_output end -function test_fo_max_over(data::ConnectivityData, backend::String) +function test_fo_max_over(offset_provider::Dict{String,Connectivity}, backend::String) a = Field(Cell, collect(1.0:15.0)) out = Field(Edge, zeros(Float64, 12)) # Compute the ground truth manually computing max on that dimension - expected_output = a[Integer.(map(maximum, eachrow(data.edge_to_cell_table)))] + expected_output = a[Integer.(map(maximum, eachrow(offset_provider["E2C"].data)))] @field_operator function fo_max_over(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64} return max_over(a(E2C), axis=E2CDim) end - fo_max_over(a, offset_provider=data.offset_provider, backend=backend, out=out) + fo_max_over(a, offset_provider=offset_provider, backend=backend, out=out) @test out == expected_output end -function test_fo_min_over(data::ConnectivityData, backend::String) +function test_fo_min_over(offset_provider::Dict{String,Connectivity}, backend::String) a = Field(Cell, collect(1.0:15.0)) out = Field(Edge, zeros(Float64, 12)) @@ -223,13 +215,13 @@ function test_fo_min_over(data::ConnectivityData, backend::String) mim_positive_element(v) = minimum(filter(x -> x > 0, v)) # Compute the ground truth manually computing min on that dimension - expected_output = a[Integer.(map(mim_positive_element, eachrow(data.edge_to_cell_table)))] # We exclude the -1 + expected_output = a[Integer.(map(mim_positive_element, eachrow(offset_provider["E2C"].data)))] # We exclude the -1 @field_operator function fo_min_over(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64} return min_over(a(E2C), axis=E2CDim) end - fo_min_over(a, offset_provider=data.offset_provider, backend=backend, out=out) + fo_min_over(a, offset_provider=offset_provider, backend=backend, out=out) @test out == expected_output end