Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into dace-next
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 4, 2024
2 parents 77c35aa + 7a9489f commit 694bed0
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 17 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ markers = [
'uses_reduction_with_only_sparse_fields: tests that require backend support for reductions with only sparse fields',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_returns: tests that require backend support for tuple results',
Expand Down
4 changes: 4 additions & 0 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ def scan_operator(
forward: bool,
init: core_defs.Scalar,
backend: Optional[str],
grid_type: GridType,
) -> FieldOperator[foast.ScanOperator]:
...

Expand All @@ -786,6 +787,7 @@ def scan_operator(
forward: bool,
init: core_defs.Scalar,
backend: Optional[str],
grid_type: GridType,
) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]:
...

Expand All @@ -797,6 +799,7 @@ def scan_operator(
forward: bool = True,
init: core_defs.Scalar = 0.0,
backend=None,
grid_type: GridType = None,
) -> (
FieldOperator[foast.ScanOperator]
| Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]
Expand Down Expand Up @@ -834,6 +837,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator:
return FieldOperator.from_function(
definition,
backend,
grid_type,
operator_node_cls=foast.ScanOperator,
operator_attributes={"axis": axis, "forward": forward, "init": init},
)
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t
):
return UnknownLength

if not type_.dtype.has_known_length:
return UnknownLength

return len(type_.dtype)


Expand Down
6 changes: 6 additions & 0 deletions src/gt4py/next/iterator/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def __iter__(self) -> abc.Iterator[Type]:
raise ValueError(f"Can not iterate over partially defined tuple '{self}'.")
yield from self.others

@property
def has_known_length(self):
    """Whether the tail of this tuple type is fully defined.

    The length is known only when the chain of ``others`` links ends in an
    ``EmptyTuple``; a partially defined tuple (any other tail) has no
    statically known length.
    """
    tail = self.others
    if isinstance(tail, EmptyTuple):
        return True
    return isinstance(tail, Tuple) and tail.has_known_length

def __len__(self) -> int:
    """Return the number of elements by counting the iterated element types."""
    count = 0
    for _ in self:
        count += 1
    return count

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
neighbor_tables = filter_neighbor_tables(offset_provider)
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

sdfg_sig = sdfg.signature_arglist(with_types=False)
dace_args = get_args(sdfg, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
Expand All @@ -224,11 +225,8 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
expected_args = {key: all_args[key] for key in sdfg_sig}

return expected_args


Expand Down Expand Up @@ -258,9 +256,7 @@ def build_sdfg_from_itir(
# TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
# `lift_more` to `FORCE_INLINE` mode.
lift_mode = itir_transforms.LiftMode.FORCE_INLINE

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
Expand All @@ -274,9 +270,10 @@ def build_sdfg_from_itir(

# run DaCe auto-optimization heuristics
if auto_optimize:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
# TODO: Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)

if on_gpu:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

# Create the call signature for the SDFG.
# All arguments required by the SDFG, regardless if explicit and implicit, are added
# as positional arguments. In the front are all arguments to the Fencil, in that
# order, they are followed by the arguments created by the translation process,
arg_list = [str(a) for a in node.params]
sig_list = program_sdfg.signature_arglist(with_types=False)
implicit_args = set(sig_list) - set(arg_list)
call_params = arg_list + [ia for ia in sig_list if ia in implicit_args]
program_sdfg.arg_names = call_params
# Only the arguments required by the Fencil, i.e. `node.params`, are added as positional arguments.
# The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keyword-only arguments.
program_sdfg.arg_names = [str(a) for a in node.params]

program_sdfg.validate()
return program_sdfg
Expand Down
3 changes: 3 additions & 0 deletions tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions"
USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator"
USES_SPARSE_FIELDS = "uses_sparse_fields"
USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output"
USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields"
USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset"
USES_TUPLE_ARGS = "uses_tuple_args"
Expand All @@ -119,6 +120,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
(USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
]
DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
(USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
Expand Down Expand Up @@ -159,4 +161,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
ProgramFormatterId.GTFN_CPP_FORMATTER: [
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
],
ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)],
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,22 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32
out=cases.allocate(unstructured_case, testee, cases.RETURN)(),
ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1),
)


@pytest.mark.uses_sparse_fields_as_output
def test_write_local_field(unstructured_case):
    """Writing a sparse (local-dimension) field as the operator output.

    Each vertex writes one value per V2E neighbor, so the result field has
    the local dimension ``V2EDim`` in addition to ``Vertex``.
    """

    @gtx.field_operator
    def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]:
        # Gather the edge value of every V2E neighbor into the local dimension.
        return inp(V2E)

    # Output buffer shaped like the V2E connectivity table (one slot per
    # vertex/neighbor pair), initialized to zero.
    out = unstructured_case.as_field(
        [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table)
    )
    inp = cases.allocate(unstructured_case, testee, "inp")()
    cases.verify(
        unstructured_case,
        testee,
        inp,
        out=out,
        # Reference: index the edge values with the V2E neighbor table, which
        # is exactly the gather performed by `inp(V2E)`.
        ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table],
    )

0 comments on commit 694bed0

Please sign in to comment.