Skip to content

Commit

Permalink
Remove struct for connectivity data (ensure usage of offset_provider …
Browse files Browse the repository at this point in the history
…in all cases)
  • Loading branch information
lorenzovarese committed Jul 11, 2024
1 parent b1cda4d commit 26f7686
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
10 changes: 10 additions & 0 deletions src/GridTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,16 @@ struct Connectivity
dims::Integer
end

@generated function Base.getindex(conn::Connectivity, row::Union{Integer, Colon}, col::Integer)
if row <: Integer
return :(conn.data[row, col])
elseif row <: Colon
return :(conn.data[:, col])
else
throw(ArgumentError("Unsupported index type"))
end
end

# Field operator ----------------------------------------------------------------------

struct FieldOp
Expand Down
36 changes: 14 additions & 22 deletions test/gt2py_fo_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,6 @@ end

# Utility ----------------------------------------------------------------------------------------------------

struct ConnectivityData
edge_to_cell_table::Matrix{Integer}
cell_to_edge_table::Matrix{Integer}
E2C_offset_provider::Connectivity
C2E_offset_provider::Connectivity
offset_provider::Dict{String,Connectivity}
end

"""
testwrapper(setupfunc::Union{Function,Nothing}, testfunc::Function, args...)
Expand Down Expand Up @@ -86,7 +78,7 @@ end

# Setup ------------------------------------------------------------------------------------------------------

function setup_simple_connectivity()::ConnectivityData
function setup_simple_connectivity()::Dict{String,Connectivity}
edge_to_cell_table = [
[1 -1];
[3 -1];
Expand Down Expand Up @@ -120,7 +112,7 @@ function setup_simple_connectivity()::ConnectivityData
"E2CDim" => E2C_offset_provider #TODO(lorenzovarese) this is required for the embedded backend (note: python already uses E2C)
)

return ConnectivityData(edge_to_cell_table, cell_to_edge_table, E2C_offset_provider, C2E_offset_provider, offset_provider)
return offset_provider
end

# Test Definitions -------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -169,67 +161,67 @@ function test_fo_nested_if_else(backend::String)
@test all(out.data .== collect(22:36))
end

function test_fo_remapping(data::ConnectivityData, backend::String)
function test_fo_remapping(offset_provider::Dict{String,Connectivity}, backend::String)
a = Field(Cell, collect(1.0:15.0))
out = Field(Edge, zeros(Float64, 12))
expected_output = a[data.edge_to_cell_table[:, 1]] # First column of the edge to cell connectivity table
expected_output = a[offset_provider["E2C"][:, 1]] # First column of the edge to cell connectivity table

@field_operator function fo_remapping(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64}
return a(E2C[1])
end

fo_remapping(a, offset_provider=data.offset_provider, backend=backend, out=out)
fo_remapping(a, offset_provider=offset_provider, backend=backend, out=out)
@test all(out.data .== expected_output)
end

function test_fo_neighbor_sum(data::ConnectivityData, backend::String)
function test_fo_neighbor_sum(offset_provider::Dict{String,Connectivity}, backend::String)
a = Field(Cell, collect(1.0:15.0))
out = Field(Edge, zeros(Float64, 12))

# Function to sum only the positive elements of each inner vector (to exclude the -1 in the connectivity)
sum_positive_elements(v) = sum(x -> x > 0 ? x : 0, v)

# Compute the ground truth manually computing the sum on that dimension
expected_output = a[Integer.(map(sum_positive_elements, eachrow(data.edge_to_cell_table)))]
expected_output = a[Integer.(map(sum_positive_elements, eachrow(offset_provider["E2C"].data)))]

@field_operator function fo_neighbor_sum(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64}
return neighbor_sum(a(E2C), axis=E2CDim)
end

fo_neighbor_sum(a, offset_provider=data.offset_provider, backend=backend, out=out)
fo_neighbor_sum(a, offset_provider=offset_provider, backend=backend, out=out)
@test out == expected_output
end

function test_fo_max_over(data::ConnectivityData, backend::String)
function test_fo_max_over(offset_provider::Dict{String,Connectivity}, backend::String)
a = Field(Cell, collect(1.0:15.0))
out = Field(Edge, zeros(Float64, 12))

# Compute the ground truth manually computing max on that dimension
expected_output = a[Integer.(map(maximum, eachrow(data.edge_to_cell_table)))]
expected_output = a[Integer.(map(maximum, eachrow(offset_provider["E2C"].data)))]

@field_operator function fo_max_over(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64}
return max_over(a(E2C), axis=E2CDim)
end

fo_max_over(a, offset_provider=data.offset_provider, backend=backend, out=out)
fo_max_over(a, offset_provider=offset_provider, backend=backend, out=out)
@test out == expected_output
end

function test_fo_min_over(data::ConnectivityData, backend::String)
function test_fo_min_over(offset_provider::Dict{String,Connectivity}, backend::String)
a = Field(Cell, collect(1.0:15.0))
out = Field(Edge, zeros(Float64, 12))

# Function to return the minimum positive element of each inner vector
mim_positive_element(v) = minimum(filter(x -> x > 0, v))

# Compute the ground truth manually computing min on that dimension
expected_output = a[Integer.(map(mim_positive_element, eachrow(data.edge_to_cell_table)))] # We exclude the -1
expected_output = a[Integer.(map(mim_positive_element, eachrow(offset_provider["E2C"].data)))] # We exclude the -1

@field_operator function fo_min_over(a::Field{Tuple{Cell_},Float64})::Field{Tuple{Edge_},Float64}
return min_over(a(E2C), axis=E2CDim)
end

fo_min_over(a, offset_provider=data.offset_provider, backend=backend, out=out)
fo_min_over(a, offset_provider=offset_provider, backend=backend, out=out)
@test out == expected_output
end

Expand Down

0 comments on commit 26f7686

Please sign in to comment.