From ae6296546d91f41e40451403c3560b1744d467cc Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Dec 2024 21:15:55 +0100 Subject: [PATCH 1/6] feat[next]: Inline dynamic shifts (#1738) Dynamic shifts are not supported in the domain inference. In order to make them work nonetheless this PR aggressively inlines all arguments to `as_fieldop` until they contain only references to `itir.Program` params. Additionally the domain inference is extended to tolerate such `as_fieldop` by introducing a special domain marker that signifies a domain is unknown. --------- Co-authored-by: Hannes Vogt Co-authored-by: Edoardo Paone --- .../iterator/transforms/fuse_as_fieldop.py | 209 ++++++++------ .../next/iterator/transforms/global_tmps.py | 4 +- .../next/iterator/transforms/infer_domain.py | 272 +++++++++++------- .../transforms/inline_dynamic_shifts.py | 73 +++++ .../next/iterator/transforms/pass_manager.py | 7 + tests/next_tests/definitions.py | 1 - .../test_inline_dynamic_shifts.py | 48 ++++ .../transforms_tests/test_domain_inference.py | 115 +++++--- 8 files changed, 492 insertions(+), 237 deletions(-) create mode 100644 src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py create mode 100644 tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 9076bf2d3f..e8a221b814 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr) + type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -68,6 +68,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr): return isinstance(expr, itir.Literal) +def _inline_as_fieldop_arg( + arg: itir.Expr, *, uids: eve_utils.UIDGenerator +) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. 
containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def fuse_as_fieldop( + expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator +) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + args: list[itir.Expr] = expr.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True): + if eligible: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif _is_tuple_expr_of_literals(arg): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + # just a safety check if typing information is available + if arg.type and not isinstance(arg.type, ts.DeferredType): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, it_ts.ListType) + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + + type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) + + return new_node + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -98,38 +199,6 @@ class FuseAsFieldOp(eve.NodeTranslator): uids: eve_utils.UIDGenerator - def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: - assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) - - stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args: list[itir.Expr] = arg.args - extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg - - stencil_params: list[itir.Sym] = [] - stencil_body: itir.Expr = stencil.expr - - for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): - if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) - 
extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. - stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( - stencil_body - ) - else: - # a scalar expression, a previously not inlined `as_fieldop` call or an opaque - # expression e.g. containing a tuple - stencil_params.append(inner_param) - new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") - extracted_args[new_outer_stencil_param] = inner_arg - - return im.lift(im.lambda_(*stencil_params)(stencil_body))( - *extracted_args.keys() - ), extracted_args - @classmethod def apply( cls, @@ -158,72 +227,26 @@ def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): stencil: itir.Lambda = node.fun.args[0] - domain = node.fun.args[1] if len(node.fun.args) > 1 else None - - shifts = trace_shifts.trace_stencil(stencil) - args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil) - new_args: dict[str, itir.Expr] = {} - new_stencil_body: itir.Expr = stencil.expr - - for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + eligible_args = [] + for arg, arg_shifts in zip(args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = _is_tuple_expr_of_literals(arg) or ( - isinstance(arg, itir.FunCall) - and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - or cpm.is_call_to(arg, "if_") + eligible_args.append( + _is_tuple_expr_of_literals(arg) + or ( + isinstance(arg, itir.FunCall) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - if should_inline: - if cpm.is_applied_as_fieldop(arg): - pass - elif cpm.is_call_to(arg, "if_"): - # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type - arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ - elif _is_tuple_expr_of_literals(arg): - arg = im.op_as_fieldop(im.lambda_()(arg))() - else: - raise NotImplementedError() - - inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) - - new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) - - new_args = _merge_arguments(new_args, extracted_args) - else: - assert not isinstance(dtype, it_ts.ListType) - new_param: str - if isinstance( - arg, itir.SymRef - ): # use name from outer scope (optional, just to get a nice IR) - new_param = arg.id - new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) - else: - new_param = stencil_param.id - new_args = _merge_arguments(new_args, {new_param: arg}) - - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) - - # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node - ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True - ) - new_node = inline_lifts.InlineLifts().visit(new_node) - - 
type_inference.copy_type(from_=node, to=new_node)
-            return new_node
+            return fuse_as_fieldop(node, eligible_args, uids=self.uids)
 
         return node
 
diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py
index a6d39883e3..334fb330d7 100644
--- a/src/gt4py/next/iterator/transforms/global_tmps.py
+++ b/src/gt4py/next/iterator/transforms/global_tmps.py
@@ -74,7 +74,7 @@ def _transform_by_pattern(
         #   or a tuple thereof)
         # - one `SetAt` statement that materializes the expression into the temporary
         for tmp_sym, tmp_expr in extracted_fields.items():
-            domain = tmp_expr.annex.domain
+            domain: infer_domain.DomainAccess = tmp_expr.annex.domain
 
             # TODO(tehrengruber): Implement. This happens when the expression is a combination
             # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are
@@ -186,7 +186,7 @@ def create_global_tmps(
     This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of
     its arguments into temporaries.
     """
-    program = infer_domain.infer_program(program, offset_provider)
+    program = infer_domain.infer_program(program, offset_provider=offset_provider)
    program = type_inference.infer(
         program, offset_provider_type=common.offset_provider_to_type(offset_provider)
     )
diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py
index 6852b47a7a..f26d3f9ec2 100644
--- a/src/gt4py/next/iterator/transforms/infer_domain.py
+++ b/src/gt4py/next/iterator/transforms/infer_domain.py
@@ -10,10 +10,10 @@
 
 import itertools
 import typing
-from typing import Callable, Optional, TypeAlias
 
 from gt4py import eve
 from gt4py.eve import utils as eve_utils
+from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack
 from gt4py.next import common
 from gt4py.next.iterator import ir as itir
 from gt4py.next.iterator.ir_utils import (
@@ -25,8 +25,35 @@
 from gt4py.next.utils import flatten_nested_tuple, tree_map
 
 
-DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...]
-ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN]
+class DomainAccessDescriptor(eve.StrEnum):
+    """
+    Descriptor for domains that could not be inferred.
+    """
+
+    # TODO(tehrengruber): Revisit this concept. It is strange that we don't have a descriptor
+    # `KNOWN`, but since we don't need it, it wasn't added.
+
+    #: The access is unknown because of a dynamic shift whose extent is not known.
+    #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)`
+    UNKNOWN = "unknown"
+    #: The domain is never accessed.
+    #: E.g.: `{in_field1, in_field2}[0]`
+    NEVER = "never"
+
+
+NonTupleDomainAccess: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor
+#: The domain can also be a tuple of domains; usually this occurs only for scan operators
+#: returning a tuple, since other occurrences of tuples are removed before domain inference.
+#: This is, however, not a requirement of the pass, and `make_tuple(vertex_field, edge_field)`
+#: infers just fine to a tuple of a vertex and an edge domain.
+DomainAccess: TypeAlias = NonTupleDomainAccess | tuple["DomainAccess", ...]
+AccessedDomains: TypeAlias = dict[str, DomainAccess] + + +class InferenceOptions(typing.TypedDict): + offset_provider: common.OffsetProvider + symbolic_domain_sizes: Optional[dict[str, str]] + allow_uninferred: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -57,43 +84,58 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. -def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None, -) -> domain_utils.SymbolicDomain | None: - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] +def _domain_union( + *domains: domain_utils.SymbolicDomain | DomainAccessDescriptor, +) -> domain_utils.SymbolicDomain | DomainAccessDescriptor: + if any(d == DomainAccessDescriptor.UNKNOWN for d in domains): + return DomainAccessDescriptor.UNKNOWN + + filtered_domains: list[domain_utils.SymbolicDomain] = [ + d # type: ignore[misc] # domain can never be unknown as these cases are filtered above + for d in domains + if d != DomainAccessDescriptor.NEVER + ] if len(filtered_domains) == 0: - return None + return DomainAccessDescriptor.NEVER return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure( + d1: DomainAccess, d2: DomainAccess +) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. If one of the arguments is a tuple the other one will be promoted to a tuple of same structure - unless it already is a tuple. Missing values are replaced by None, meaning no domain is - specified. + unless it already is a tuple. Missing values are filled by :ref:`DomainAccessDescriptor.NEVER`. >>> domain = im.domain(common.GridType.CARTESIAN, {}) >>> _canonicalize_domain_structure((domain,), (domain, domain)) == ( - ... (domain, None), + ... (domain, DomainAccessDescriptor.NEVER), ... (domain, domain), ... ) True - >>> _canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + >>> _canonicalize_domain_structure( + ... (domain, DomainAccessDescriptor.NEVER), DomainAccessDescriptor.NEVER + ... ) == ( + ... (domain, DomainAccessDescriptor.NEVER), + ... (DomainAccessDescriptor.NEVER, DomainAccessDescriptor.NEVER), + ... 
) True """ - if d1 is None and isinstance(d2, tuple): - return _canonicalize_domain_structure((None,) * len(d2), d2) - if d2 is None and isinstance(d1, tuple): - return _canonicalize_domain_structure(d1, (None,) * len(d1)) + if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): + return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): + return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): return tuple( zip( *( _canonicalize_domain_structure(el1, el2) - for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + for el1, el2 in itertools.zip_longest( + d1, d2, fillvalue=DomainAccessDescriptor.NEVER + ) ) ) ) # type: ignore[return-value] # mypy not smart enough @@ -101,16 +143,16 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): original_domain, domain = _canonicalize_domain_structure( - original_domains.get(key, None), domain + original_domains.get(key, DomainAccessDescriptor.NEVER), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union)(original_domain, domain) return new_domains @@ -118,44 +160,52 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain, + target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} +) -> dict[str, NonTupleDomainAccess]: + accessed_domains: dict[str, NonTupleDomainAccess] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): + # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use + # special `UNKNOWN` marker for them until we have implemented a proper solution. 
+ if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): + accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN + continue + new_domains = [ domain_utils.SymbolicDomain.translate( target_domain, shift, offset_provider, symbolic_domain_sizes ) + if not isinstance(target_domain, DomainAccessDescriptor) + else target_domain for shift in shifts_list ] - # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( - accessed_domains.get(in_field_id, None), *new_domains + accessed_domains[in_field_id] = _domain_union( + accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) - return typing.cast(ACCESSED_DOMAINS, accessed_domains) + return accessed_domains def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + allow_uninferred: bool, +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is None: - raise ValueError("'target_domain' cannot be 'None'.") + if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): - target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) - if not isinstance(target_domain, domain_utils.SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") + target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough + assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -177,22 +227,29 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomainAccess] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s - accessed_domains: ACCESSED_DOMAINS = {} + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes + in_field, + inputs_accessed_domains[in_field_id], + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + if not isinstance(target_domain, DomainAccessDescriptor): + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + else: + target_domain_expr = None transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) accessed_domains_without_tmp = { @@ -206,17 
+263,15 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + input_domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy - transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes - ) - let_params = {param_sym.id for param_sym in let_expr.fun.params} + + transformed_calls_expr, accessed_domains = infer_expr(let_expr.fun.expr, input_domain, **kwargs) + accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains ) @@ -227,10 +282,9 @@ def _infer_let( arg, accessed_domains_let_args.get( param.id, - None, + DomainAccessDescriptor.NEVER, ), - offset_provider, - symbolic_domain_sizes, + **kwargs, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -247,13 +301,12 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -261,13 +314,12 @@ def _infer_make_tuple( # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; domain = (domain,) * len(expr.args) assert len(expr.args) >= len(domain) - # There may be less domains than tuple args, pad the domain with `None` in that case. - # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` - domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + # There may be fewer domains than tuple args, pad the domain with `NEVER` + # in that case. + # e.g. 
`im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(DomainAccessDescriptor.NEVER for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain[i], offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) @@ -276,19 +328,18 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) - tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr( - tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + tuple_domain = tuple( + DomainAccessDescriptor.NEVER if i != idx else domain for i in range(idx + 1) ) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, **kwargs) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -297,18 +348,15 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain, offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -317,24 +365,23 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_as_fieldop(expr, domain, **kwargs) elif cpm.is_let(expr): - return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_let(expr, domain, **kwargs) elif cpm.is_call_to(expr, "make_tuple"): - return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_make_tuple(expr, domain, **kwargs) elif cpm.is_call_to(expr, "tuple_get"): - return _infer_tuple_get(expr, domain, offset_provider, 
symbolic_domain_sizes) + return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): - return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_if(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -347,10 +394,12 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + allow_uninferred: bool = False, +) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -362,30 +411,35 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. + - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. + because of a dynamic shift) or never accessed. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to domain they are accessed at. """ - # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) + expr, accessed_domains = _infer_expr( + expr, + domain, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) expr.annex.domain = domain + return expr, accessed_domains def _infer_stmt( stmt: itir.Stmt, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], + **kwargs: Unpack[InferenceOptions], ): if isinstance(stmt, itir.SetAt): - transformed_call, _unused_domain = infer_expr( - stmt.expr, - domain_utils.SymbolicDomain.from_expr(stmt.domain), - offset_provider, - symbolic_domain_sizes, + transformed_call, _ = infer_expr( + stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs ) + return itir.SetAt( expr=transformed_call, domain=stmt.domain, @@ -394,20 +448,18 @@ def _infer_stmt( elif isinstance(stmt, itir.IfStmt): return itir.IfStmt( cond=stmt.cond, - true_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch - ], - false_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch - ], + true_branch=[_infer_stmt(c, **kwargs) for c in stmt.true_branch], + false_branch=[_infer_stmt(c, **kwargs) for c in stmt.false_branch], ) raise ValueError(f"Unsupported stmt: {stmt}") def infer_program( program: itir.Program, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. 
@@ -423,5 +475,13 @@ def infer_program( function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], + body=[ + _infer_stmt( + stmt, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) + for stmt in program.body + ], ) diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py new file mode 100644 index 0000000000..0af9d9dab9 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs + + +def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: + if not cpm.is_applied_as_fieldop(node): + return None + params_shifts = trace_shifts.trace_stencil( + node.fun.args[0], # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + num_args=len(node.args), + save_to_annex=True, + ) + dynamic_shifts = [ + any(trace_shifts.Sentinel.VALUE in shifts for shifts in param_shifts) + for param_shifts in params_shifts + ] + return dynamic_shifts + + +@dataclasses.dataclass +class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + uids: eve_utils.UIDGenerator + + @classmethod + def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + node = self.generic_visit(node, **kwargs) + + if cpm.is_let(node) and ( + dynamic_shift_args := _dynamic_shift_args(let_body := node.fun.expr) # type: ignore[attr-defined] # ensured by is_let + ): + inline_let_params = {p.id: False for p in node.fun.params} # type: ignore[attr-defined] # ensured by is_let + + for inp, is_dynamic_shift_arg in zip(let_body.args, dynamic_shift_args, strict=True): + for ref in collect_symbol_refs(inp): + if ref in inline_let_params and is_dynamic_shift_arg: + inline_let_params[ref] = True + + if any(inline_let_params): + node = inline_lambdas.inline_lambda( + node, eligible_params=list(inline_let_params.values()) + ) + + if dynamic_shift_args := _dynamic_shift_args(node): + assert len(node.fun.args) in [1, 2] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop in _dynamic_shift_args + fuse_args = [ + not isinstance(inp, itir.SymRef) and dynamic_shift_arg + for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) + ] + if any(fuse_args): + return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + + return node diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index ec4207d726..d967c8fbb8 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -15,6 +15,7 
@@ fuse_as_fieldop, global_tmps, infer_domain, + inline_dynamic_shifts, inline_fundefs, inline_lifts, ) @@ -73,6 +74,9 @@ def apply_common_transforms( ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) # required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed) ir = CollapseTuple.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[assignment] # always an itir.Program + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program( ir, offset_provider=offset_provider, @@ -158,5 +162,8 @@ def apply_fieldview_transforms( ir = CollapseTuple.apply( ir, offset_provider_type=common.offset_provider_to_type(offset_provider) ) # type: ignore[assignment] # type is still `itir.Program` + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index d7413f32d7..bed6e89a52 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -130,7 +130,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py new file mode 100644 index 0000000000..ff7a761c5a --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -0,0 +1,48 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_dynamic_shifts +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_inline_dynamic_shift_as_fieldop_arg(): + testee = im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + im.as_fieldop("deref")("inp"), "offset_field" + ) + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected + + +def test_inline_dynamic_shift_let_var(): + testee = im.let("tmp", im.as_fieldop("deref")("inp"))( + im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + "tmp", "offset_field" + ) + ) + + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 2492fc446d..779ab738cb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -76,7 +76,7 @@ def setup_test_as_fieldop( def run_test_program( testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_domain.infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider=offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected @@ -89,12 +89,14 @@ def run_test_expr( expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ): actual_call, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, - symbolic_domain_sizes, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -104,10 +106,8 @@ def run_test_expr( def canonicalize_domain(d): if isinstance(d, dict): return im.domain(grid_type, d) - elif isinstance(d, itir.FunCall): + elif isinstance(d, (itir.FunCall, infer_domain.DomainAccessDescriptor)): return d - elif d is None: - return None raise AssertionError() expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} @@ -128,10 +128,12 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( - domains: infer_domain.ACCESSED_DOMAINS, -) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: domain_utils.SymbolicDomain | None): - if 
domain is None: + domains: infer_domain.AccessedDomains, +) -> infer_domain.AccessedDomains: + def fold_domain( + domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], + ): + if isinstance(domain, infer_domain.DomainAccessDescriptor): return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -154,7 +156,7 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( - shift_list, offset_provider + shift_list, offset_provider=offset_provider ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -340,7 +342,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -384,7 +386,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -397,7 +399,10 @@ def test_unused_input(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + expected_domains = { + "in_field1": {IDim: (0, 11)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } testee, expected = setup_test_as_fieldop( stencil, domain, @@ -409,7 +414,7 @@ def test_let_unused_field(offset_provider): testee = im.let("a", "c")("b") domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.let("a", "c")("b") - expected_domains = {"b": {IDim: (0, 11)}, "c": None} + expected_domains = {"b": {IDim: (0, 11)}, "c": infer_domain.DomainAccessDescriptor.NEVER} run_test_expr(testee, expected, domain, expected_domains, offset_provider) @@ -522,7 +527,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -579,7 +584,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -803,7 +808,7 @@ def test_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -815,13 +820,13 @@ def 
test_tuple_get_1_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) expected_domains = { - "a": None, + "a": infer_domain.DomainAccessDescriptor.NEVER, "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), - "c": None, + "c": infer_domain.DomainAccessDescriptor.NEVER, } actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -833,7 +838,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) - expected_domains = {"a": None, "b": domain1, "c": domain2} + expected_domains = {"a": infer_domain.DomainAccessDescriptor.NEVER, "b": domain1, "c": domain2} actual, actual_domains = infer_domain.infer_expr( testee, @@ -841,7 +846,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -852,14 +857,18 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) - expected_domains = {"b": None, "c": None, "d": (None, domain)} + expected_domains = { + "b": infer_domain.DomainAccessDescriptor.NEVER, + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": (infer_domain.DomainAccessDescriptor.NEVER, domain), + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -870,12 +879,16 @@ def test_tuple_get_let_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) - expected_domains = {"c": None, "d": domain, "b": None} + expected_domains = { + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": domain, + "b": infer_domain.DomainAccessDescriptor.NEVER, + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -903,7 +916,7 @@ def test_nested_make_tuple(offset_provider): ), domain_utils.SymbolicDomain.from_expr(domain3), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -914,10 +927,10 @@ def test_tuple_get_1(offset_provider): testee = im.tuple_get(1, im.ref("a")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.ref("a")) - expected_domains = {"a": (None, domain)} + expected_domains = {"a": (infer_domain.DomainAccessDescriptor.NEVER, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), 
offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -937,7 +950,7 @@ def test_domain_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -953,7 +966,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -973,7 +986,7 @@ def test_make_tuple_2tuple_get(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -990,7 +1003,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -1004,7 +1017,7 @@ def test_arithmetic_builtin(offset_provider): expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) @@ -1048,3 +1061,35 @@ def test_symbolic_domain_sizes(unstructured_offset_provider): unstructured_offset_provider, symbolic_domain_sizes, ) + + +def test_unknown_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": infer_domain.DomainAccessDescriptor.UNKNOWN, + "in_field2": {IDim: (0, 10)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain_tuple(offset_provider): + testee = im.tuple_get(0, im.make_tuple("in_field1", "in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + run_test_expr(testee, testee, domain, expected_domains, offset_provider) From 29b6af23c15955910f413ed12e5d1a463e7b5b4b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 9 Dec 2024 16:44:28 +0100 Subject: [PATCH 2/6] build: fix min version of filelock (#1777) ... and fix linting after ruff update. 
--- .pre-commit-config.yaml | 10 ++-- constraints.txt | 48 +++++++++---------- min-extra-requirements-test.txt | 2 +- min-requirements-test.txt | 2 +- pyproject.toml | 2 +- requirements-dev.txt | 48 +++++++++---------- src/gt4py/__init__.py | 2 +- src/gt4py/cartesian/__init__.py | 4 +- src/gt4py/cartesian/backend/__init__.py | 2 +- src/gt4py/cartesian/cli.py | 2 +- src/gt4py/cartesian/frontend/__init__.py | 2 +- src/gt4py/cartesian/gtscript.py | 6 +-- src/gt4py/cartesian/testing/__init__.py | 2 +- src/gt4py/cartesian/utils/__init__.py | 2 +- src/gt4py/cartesian/utils/base.py | 6 +-- src/gt4py/eve/__init__.py | 2 +- src/gt4py/eve/datamodels/validators.py | 2 +- src/gt4py/next/errors/__init__.py | 2 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/runtime.py | 2 +- .../next/iterator/transforms/__init__.py | 2 +- .../iterator/transforms/fuse_as_fieldop.py | 6 ++- .../transformations/__init__.py | 14 +++--- src/gt4py/storage/__init__.py | 6 +-- 24 files changed, 88 insertions(+), 90 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e1870c67f..e383112310 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -50,7 +50,7 @@ repos: ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: v{version}") ##]]] - rev: v0.7.4 + rev: v0.8.2 ##[[[end]]] hooks: # Run the linter. @@ -96,7 +96,7 @@ repos: - boltons==24.1.0 - cached-property==2.0.1 - click==8.1.7 - - cmake==3.31.0.1 + - cmake==3.31.1 - cytoolz==1.0.0 - deepdiff==8.0.1 - devtools==0.12.2 @@ -108,9 +108,9 @@ repos: - importlib-resources==6.4.5 - jinja2==3.1.4 - lark==1.2.2 - - mako==1.3.6 - - nanobind==2.2.0 - - ninja==1.11.1.1 + - mako==1.3.8 + - nanobind==2.4.0 + - ninja==1.11.1.2 - numpy==1.24.4 - packaging==24.2 - pybind11==2.13.6 diff --git a/constraints.txt b/constraints.txt index f039fa2125..fbdfb6e267 100644 --- a/constraints.txt +++ b/constraints.txt @@ -23,9 +23,9 @@ certifi==2024.8.30 # via requests cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox charset-normalizer==3.4.0 # via requests -clang-format==19.1.3 # via -r requirements-dev.in, gt4py (pyproject.toml) +clang-format==19.1.4 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.0.1 # via gt4py (pyproject.toml) +cmake==3.31.1 # via gt4py (pyproject.toml) cogapp==3.4.1 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.2 # via ipykernel @@ -35,7 +35,7 @@ cycler==0.12.1 # via matplotlib cytoolz==1.0.0 # via gt4py (pyproject.toml) dace==1.0.0 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.8 # via ipykernel +debugpy==1.8.9 # via ipykernel decorator==5.1.1 # via ipython deepdiff==8.0.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) @@ -47,11 +47,11 @@ exceptiongroup==1.2.2 # via hypothesis, pytest execnet==2.1.1 # via pytest-cache, pytest-xdist executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==33.0.0 # via factory-boy -fastjsonschema==2.20.0 # via nbformat +faker==33.1.0 # via factory-boy +fastjsonschema==2.21.1 # via nbformat filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.0 # via matplotlib -fparser==0.1.4 # via dace +fonttools==4.55.2 # via matplotlib +fparser==0.2.0 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) gitdb==4.0.11 # via gitpython gitpython==3.1.43 # via tach @@ -75,7 
+75,7 @@ jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbformat jupytext==1.16.4 # via -r requirements-dev.in kiwisolver==1.4.7 # via matplotlib lark==1.2.2 # via gt4py (pyproject.toml) -mako==1.3.6 # via gt4py (pyproject.toml) +mako==1.3.8 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via jinja2, mako matplotlib==3.7.5 # via -r requirements-dev.in @@ -85,13 +85,13 @@ mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy mypy==1.13.0 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==2.2.0 # via gt4py (pyproject.toml) +nanobind==2.4.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.10.4 # via jupytext, nbclient, nbmake nbmake==1.5.4 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace, tach -ninja==1.11.1.1 # via gt4py (pyproject.toml) +ninja==1.11.1.2 # via gt4py (pyproject.toml) nodeenv==1.9.1 # via pre-commit numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy orderly-set==5.2.2 # via deepdiff @@ -102,7 +102,7 @@ pexpect==4.9.0 # via ipython pickleshare==0.7.5 # via ipython pillow==10.4.0 # via matplotlib pip-tools==7.4.1 # via -r requirements-dev.in -pipdeptree==2.23.4 # via -r requirements-dev.in +pipdeptree==2.24.0 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema platformdirs==4.3.6 # via black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via pytest, tox @@ -113,15 +113,15 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.10.0 # via bump-my-version, pydantic-settings -pydantic-core==2.27.0 # via pydantic +pydantic==2.10.3 # via bump-my-version, pydantic-settings +pydantic-core==2.27.1 # via pydantic pydantic-settings==2.6.1 # via bump-my-version -pydot==3.0.2 # via tach +pydot==3.0.3 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via matplotlib, pydot pyproject-api==1.8.0 # via tox pyproject-hooks==1.2.0 # via build, pip-tools -pytest==8.3.3 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist +pytest==8.3.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==5.0.0 # via -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -r requirements-dev.in @@ -137,12 +137,12 @@ questionary==2.0.1 # via bump-my-version referencing==0.35.1 # via jsonschema, jsonschema-specifications requests==2.32.3 # via sphinx rich==13.9.4 # via bump-my-version, rich-click, tach -rich-click==1.8.4 # via bump-my-version +rich-click==1.8.5 # via bump-my-version rpds-py==0.20.1 # via jsonschema, referencing -ruff==0.7.4 # via -r requirements-dev.in +ruff==0.8.2 # via -r requirements-dev.in scipy==1.10.1 # via gt4py (pyproject.toml) setuptools-scm==8.1.0 # via fparser -six==1.16.0 # via asttokens, astunparse, python-dateutil +six==1.17.0 # via asttokens, astunparse, python-dateutil smmap==5.0.1 # via gitdb snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via hypothesis @@ -159,21 +159,21 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace 
tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.4 # via -r requirements-dev.in -tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tach==0.16.5 # via -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version toolz==1.0.0 # via cytoolz -tornado==6.4.1 # via ipykernel, jupyter-client +tornado==6.4.2 # via ipykernel, jupyter-client tox==4.23.2 # via -r requirements-dev.in traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -r requirements-dev.in +types-tabulate==0.9.0.20241207 # via -r requirements-dev.in typing-extensions==4.12.2 # via annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via requests -virtualenv==20.27.1 # via pre-commit, tox +virtualenv==20.28.0 # via pre-commit, tox wcmatch==10.0 # via bump-my-version wcwidth==0.2.13 # via prompt-toolkit -wheel==0.45.0 # via astunparse, pip-tools +wheel==0.45.1 # via astunparse, pip-tools xxhash==3.0.0 # via gt4py (pyproject.toml) zipp==3.20.2 # via importlib-metadata, importlib-resources diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index d7679a1f0f..6d75415181 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,7 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 -filelock==3.0.0 +filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index cf505e88d6..991b7a6941 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,7 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 -filelock==3.0.0 +filelock==3.16.1 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index e859c9b4f7..d086363ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', - 'filelock>=3.0.0', + 'filelock>=3.16.1', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", diff --git a/requirements-dev.txt b/requirements-dev.txt index 6542be36f1..40554cef13 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -23,9 +23,9 @@ certifi==2024.8.30 # via -c constraints.txt, requests cfgv==3.4.0 # via -c constraints.txt, pre-commit chardet==5.2.0 # via -c constraints.txt, tox charset-normalizer==3.4.0 # via -c constraints.txt, requests -clang-format==19.1.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) +clang-format==19.1.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.31.0.1 # via -c constraints.txt, gt4py (pyproject.toml) +cmake==3.31.1 # via -c constraints.txt, gt4py (pyproject.toml) cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in colorama==0.4.6 # via -c constraints.txt, tox comm==0.2.2 # via -c constraints.txt, ipykernel @@ -35,7 +35,7 @@ 
cycler==0.12.1 # via -c constraints.txt, matplotlib cytoolz==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) dace==1.0.0 # via -c constraints.txt, gt4py (pyproject.toml) darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.8 # via -c constraints.txt, ipykernel +debugpy==1.8.9 # via -c constraints.txt, ipykernel decorator==5.1.1 # via -c constraints.txt, ipython deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) @@ -47,11 +47,11 @@ exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==33.0.0 # via -c constraints.txt, factory-boy -fastjsonschema==2.20.0 # via -c constraints.txt, nbformat +faker==33.1.0 # via -c constraints.txt, factory-boy +fastjsonschema==2.21.1 # via -c constraints.txt, nbformat filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv -fonttools==4.55.0 # via -c constraints.txt, matplotlib -fparser==0.1.4 # via -c constraints.txt, dace +fonttools==4.55.2 # via -c constraints.txt, matplotlib +fparser==0.2.0 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) gitdb==4.0.11 # via -c constraints.txt, gitpython gitpython==3.1.43 # via -c constraints.txt, tach @@ -75,7 +75,7 @@ jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, n jupytext==1.16.4 # via -c constraints.txt, -r requirements-dev.in kiwisolver==1.4.7 # via -c constraints.txt, matplotlib lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) -mako==1.3.6 # via -c constraints.txt, gt4py (pyproject.toml) +mako==1.3.8 # via -c constraints.txt, gt4py (pyproject.toml) markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako matplotlib==3.7.5 # via -c constraints.txt, -r requirements-dev.in @@ -85,13 +85,13 @@ mdurl==0.1.2 # via -c constraints.txt, markdown-it-py mpmath==1.3.0 # via -c constraints.txt, sympy mypy==1.13.0 # via -c constraints.txt, -r requirements-dev.in mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy -nanobind==2.2.0 # via -c constraints.txt, gt4py (pyproject.toml) +nanobind==2.4.0 # via -c constraints.txt, gt4py (pyproject.toml) nbclient==0.6.8 # via -c constraints.txt, nbmake nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake nbmake==1.5.4 # via -c constraints.txt, -r requirements-dev.in nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel, nbclient networkx==3.1 # via -c constraints.txt, dace, tach -ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) +ninja==1.11.1.2 # via -c constraints.txt, gt4py (pyproject.toml) nodeenv==1.9.1 # via -c constraints.txt, pre-commit numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib orderly-set==5.2.2 # via -c constraints.txt, deepdiff @@ -102,7 +102,7 @@ pexpect==4.9.0 # via -c constraints.txt, ipython pickleshare==0.7.5 # via -c constraints.txt, ipython pillow==10.4.0 # via -c constraints.txt, matplotlib pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in -pipdeptree==2.23.4 # via -c constraints.txt, -r requirements-dev.in +pipdeptree==2.24.0 # via -c constraints.txt, -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via -c constraints.txt, 
jsonschema platformdirs==4.3.6 # via -c constraints.txt, black, jupyter-core, tox, virtualenv pluggy==1.5.0 # via -c constraints.txt, pytest, tox @@ -113,15 +113,15 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.27.0 # via -c constraints.txt, pydantic +pydantic==2.10.3 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.1 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version -pydot==3.0.2 # via -c constraints.txt, tach +pydot==3.0.3 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx pyparsing==3.1.4 # via -c constraints.txt, matplotlib, pydot pyproject-api==1.8.0 # via -c constraints.txt, tox pyproject-hooks==1.2.0 # via -c constraints.txt, build, pip-tools -pytest==8.3.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist +pytest==8.3.4 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in pytest-cov==5.0.0 # via -c constraints.txt, -r requirements-dev.in pytest-custom-exit-code==0.3.0 # via -c constraints.txt, -r requirements-dev.in @@ -137,11 +137,11 @@ questionary==2.0.1 # via -c constraints.txt, bump-my-version referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications requests==2.32.3 # via -c constraints.txt, sphinx rich==13.9.4 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.4 # via -c constraints.txt, bump-my-version +rich-click==1.8.5 # via -c constraints.txt, bump-my-version rpds-py==0.20.1 # via -c constraints.txt, jsonschema, referencing -ruff==0.7.4 # via -c constraints.txt, -r requirements-dev.in +ruff==0.8.2 # via -c constraints.txt, -r requirements-dev.in setuptools-scm==8.1.0 # via -c constraints.txt, fparser -six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil +six==1.17.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil smmap==5.0.1 # via -c constraints.txt, gitdb snowballstemmer==2.2.0 # via -c constraints.txt, sphinx sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis @@ -158,21 +158,21 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in -tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox +tach==0.16.5 # via -c constraints.txt, -r requirements-dev.in +tomli==2.2.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, 
bump-my-version toolz==1.0.0 # via -c constraints.txt, cytoolz -tornado==6.4.1 # via -c constraints.txt, ipykernel, jupyter-client +tornado==6.4.2 # via -c constraints.txt, ipykernel, jupyter-client tox==4.23.2 # via -c constraints.txt, -r requirements-dev.in traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -c constraints.txt, -r requirements-dev.in +types-tabulate==0.9.0.20241207 # via -c constraints.txt, -r requirements-dev.in typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, faker, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm, tox urllib3==2.2.3 # via -c constraints.txt, requests -virtualenv==20.27.1 # via -c constraints.txt, pre-commit, tox +virtualenv==20.28.0 # via -c constraints.txt, pre-commit, tox wcmatch==10.0 # via -c constraints.txt, bump-my-version wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -wheel==0.45.0 # via -c constraints.txt, astunparse, pip-tools +wheel==0.45.1 # via -c constraints.txt, astunparse, pip-tools xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) zipp==3.20.2 # via -c constraints.txt, importlib-metadata, importlib-resources diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 1b88285475..c0bf9580b3 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -27,6 +27,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 shadowing a Python builtin __all__ += ["next"] diff --git a/src/gt4py/cartesian/__init__.py b/src/gt4py/cartesian/__init__.py index c03ef15105..90df315d5c 100644 --- a/src/gt4py/cartesian/__init__.py +++ b/src/gt4py/cartesian/__init__.py @@ -27,7 +27,7 @@ __all__ = [ - "typing", + "StencilObject", "caching", "cli", "config", @@ -39,5 +39,5 @@ "stencil_builder", "stencil_object", "type_hints", - "StencilObject", + "typing", ] diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index e58c7a01a7..4296e3b389 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -32,9 +32,9 @@ "BasePyExtBackend", "CLIBackendMixin", "CudaBackend", - "GTGpuBackend", "GTCpuIfirstBackend", "GTCpuKfirstBackend", + "GTGpuBackend", "NumpyBackend", "PurePythonBackendCLIMixin", "from_name", diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index 91daed9e98..4ea5e44074 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -90,7 +90,7 @@ def backend_table(cls) -> str: ", ".join(backend.languages["bindings"]) if backend and backend.languages else "?" 
for backend in backends ] - enabled = [backend is not None and "Yes" or "No" for backend in backends] + enabled = [(backend is not None and "Yes") or "No" for backend in backends] data = zip(names, comp_langs, binding_langs, enabled) return tabulate.tabulate(data, headers=headers) diff --git a/src/gt4py/cartesian/frontend/__init__.py b/src/gt4py/cartesian/frontend/__init__.py index 6988fb6aab..f1e0f9a775 100644 --- a/src/gt4py/cartesian/frontend/__init__.py +++ b/src/gt4py/cartesian/frontend/__init__.py @@ -10,4 +10,4 @@ from .base import REGISTRY, Frontend, from_name, register -__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"] +__all__ = ["REGISTRY", "Frontend", "from_name", "gtscript_frontend", "register"] diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 643ecba010..59f3ef37c2 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -657,10 +657,8 @@ def __str__(self) -> str: class _FieldDescriptorMaker: @staticmethod def _is_axes_spec(spec) -> bool: - return ( - isinstance(spec, Axis) - or isinstance(spec, collections.abc.Collection) - and all(isinstance(i, Axis) for i in spec) + return isinstance(spec, Axis) or ( + isinstance(spec, collections.abc.Collection) and all(isinstance(i, Axis) for i in spec) ) def __getitem__(self, field_spec): diff --git a/src/gt4py/cartesian/testing/__init__.py b/src/gt4py/cartesian/testing/__init__.py index 288d7b1d2d..0753b4175e 100644 --- a/src/gt4py/cartesian/testing/__init__.py +++ b/src/gt4py/cartesian/testing/__init__.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -__all__ = ["field", "global_name", "none", "parameter", "StencilTestSuite"] +__all__ = ["StencilTestSuite", "field", "global_name", "none", "parameter"] try: from .input_strategies import field, global_name, none, parameter from .suites import StencilTestSuite diff --git a/src/gt4py/cartesian/utils/__init__.py b/src/gt4py/cartesian/utils/__init__.py index 3c0bdb3fc3..626d29b167 100644 --- a/src/gt4py/cartesian/utils/__init__.py +++ b/src/gt4py/cartesian/utils/__init__.py @@ -37,7 +37,7 @@ ) -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # Modules "attrib", "meta", diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index d5d43a4103..35184a3f7b 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -63,10 +63,8 @@ def flatten_iter(nested_iterables, filter_none=False, *, skip_types=(str, bytes) def get_member(instance, item_name): try: - if ( - isinstance(instance, collections.abc.Mapping) - or isinstance(instance, collections.abc.Sequence) - and isinstance(item_name, int) + if isinstance(instance, collections.abc.Mapping) or ( + isinstance(instance, collections.abc.Sequence) and isinstance(item_name, int) ): return instance[item_name] else: diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 5adac47da3..e6044f15ef 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -71,7 +71,7 @@ from .visitors import NodeTranslator, NodeVisitor -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # version "__version__", "__version_info__", diff --git a/src/gt4py/eve/datamodels/validators.py b/src/gt4py/eve/datamodels/validators.py index 119410460c..4ce6f94c5e 100644 --- a/src/gt4py/eve/datamodels/validators.py +++ b/src/gt4py/eve/datamodels/validators.py @@ -42,7 +42,7 @@ from .core import DataModelTP, 
FieldValidator -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # reexported from attrs "and_", "deep_iterable", diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 89f78a45e4..9febe098a4 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -23,9 +23,9 @@ __all__ = [ "DSLError", "InvalidParameterAnnotationError", + "MissingArgumentError", "MissingAttributeError", "MissingParameterAnnotationError", - "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", ] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b60fa63f95..1210e96efc 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -10,7 +10,7 @@ import functools import inspect import math -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index e47a6886ad..c9a5b15de7 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"] +__all__ = ["fendef", "fundef", "if_stmt", "offset", "set_at"] @dataclass(frozen=True) diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index d0afc610e7..1d91254ee8 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -13,4 +13,4 @@ ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] +__all__ = ["GTIRTransform", "apply_common_transforms", "apply_fieldview_transforms"] diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index e8a221b814..661b456608 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -240,8 +240,10 @@ def visit_FunCall(self, node: itir.FunCall): or ( isinstance(arg, itir.FunCall) and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) + ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + ) or cpm.is_call_to(arg, "if_") ) and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 2232bcef01..4f3efb19b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -43,25 +43,25 @@ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", "GT4PyGlobalSelfCopyElimination", - "GT4PyMoveTaskletIntoMap", "GT4PyMapBufferElimination", + "GT4PyMoveTaskletIntoMap", "LoopBlocking", - "MapIterationOrder", "MapFusionParallel", "MapFusionSerial", + "MapIterationOrder", "SerialMapPromoter", "SerialMapPromoterGPU", "gt_auto_optimize", "gt_change_transient_strides", "gt_create_local_double_buffering", + "gt_find_constant_arguments", 
+ "gt_gpu_transform_non_standard_memlet", "gt_gpu_transformation", "gt_inline_nested_sdfg", - "gt_set_iteration_order", - "gt_set_gpu_blocksize", - "gt_simplify", "gt_make_transients_persistent", "gt_reduce_distributed_buffering", - "gt_find_constant_arguments", + "gt_set_gpu_blocksize", + "gt_set_iteration_order", + "gt_simplify", "gt_substitute_compiletime_symbols", - "gt_gpu_transform_non_standard_memlet", ] diff --git a/src/gt4py/storage/__init__.py b/src/gt4py/storage/__init__.py index 4866cd480c..5986baa65e 100644 --- a/src/gt4py/storage/__init__.py +++ b/src/gt4py/storage/__init__.py @@ -16,12 +16,12 @@ __all__ = [ "cartesian", - "layout", "empty", "from_array", + "from_name", "full", + "layout", "ones", - "zeros", - "from_name", "register", + "zeros", ] From 98889056c914886912d9131793deb67b5f947602 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 10 Dec 2024 10:02:22 +0100 Subject: [PATCH 3/6] feat[next]: Change interval syntax in ITIR pretty printer (#1766) We currently use `)` in the pretty printer to express an open interval. This is quite cumbersome when debugging the IR because it breaks matching parenthesis in the editor of functions and calls, e.g. when does a function start and end. This PR simply uses `[` instead. --- src/gt4py/next/iterator/ir_utils/ir_makers.py | 6 +++--- src/gt4py/next/iterator/pretty_parser.py | 2 +- src/gt4py/next/iterator/pretty_printer.py | 4 +++- src/gt4py/next/iterator/transforms/fuse_as_fieldop.py | 6 +++--- src/gt4py/next/iterator/transforms/inline_fundefs.py | 2 +- .../unit_tests/iterator_tests/test_pretty_parser.py | 4 ++-- .../unit_tests/iterator_tests/test_pretty_printer.py | 2 +- 7 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index a4e111e785..0839e95b5b 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -423,11 +423,11 @@ def domain( ... }, ... ) ... ) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' >>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)})) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' >>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)})) - 'u⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index 29b30beae1..a077b39911 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -84,7 +84,7 @@ else_branch_seperator: "else" if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" - named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 ")" + named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 "[" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? 
")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index a25f99356c..7acbf5d23d 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -190,7 +190,9 @@ def visit_FunCall(self, node: ir.FunCall, *, prec: int) -> list[str]: if fun_name == "named_range" and len(node.args) == 3: # named_range(dim, start, stop) → dim: [star, stop) dim, start, end = self.visit(node.args, prec=0) - res = self._hmerge(dim, [": ["], start, [", "], end, [")"]) + res = self._hmerge( + dim, [": ["], start, [", "], end, ["["] + ) # to get matching parenthesis of functions return self._prec_parens(res, prec, PRECEDENCE["__call__"]) if fun_name == "cartesian_domain" and len(node.args) >= 1: # cartesian_domain(x, y, ...) → c{ x × y × ... } # noqa: RUF003 [ambiguous-unicode-character-comment] diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 661b456608..b7087472e0 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -186,15 +186,15 @@ class FuseAsFieldOp(eve.NodeTranslator): ... im.ref("inp3", field_type), ... ) >>> print(nested_as_fieldop) - as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)( - as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2), inp3 + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2), inp3 ) >>> print( ... FuseAsFieldOp.apply( ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True ... ) ... 
) - as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1) ⟩)(inp1, inp2, inp3) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) """ # noqa: RUF002 # ignore ambiguous multiplication character uids: eve_utils.UIDGenerator diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index a2188030a1..e4cae978da 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -59,7 +59,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> print(prune_unreferenced_fundefs(program)) testee(inp, out) { fun1 = λ(a) → ·a; - out @ c⟨ IDimₕ: [0, 10) ⟩ ← fun1(inp); + out @ c⟨ IDimₕ: [0, 10[ ⟩ ← fun1(inp); } """ fun_names = [fun.id for fun in program.function_definitions] diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index bf47f997d6..af9084f407 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -127,7 +127,7 @@ def test_make_tuple(): def test_named_range_horizontal(): - testee = "IDimₕ: [x, y)" + testee = "IDimₕ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], @@ -137,7 +137,7 @@ def test_named_range_horizontal(): def test_named_range_vertical(): - testee = "IDimᵥ: [x, y)" + testee = "IDimᵥ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 11f50dbf6d..6b45f470b7 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -233,7 +233,7 @@ def test_named_range_horizontal(): fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], ) - expected = "IDimₕ: [x, y)" + expected = "IDimₕ: [x, y[" actual = pformat(testee) assert actual == expected From 06b398af7c5a4235d2c595bbbac93ec70f31a5a6 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 16 Dec 2024 15:32:16 +0100 Subject: [PATCH 4/6] refact[next][dace]: split handling of let-statement lambdas from stencil body (#1781) This is a refactoring of the code to lower lambda nodes: it splits the lowering of let-statements from the lowering of stencil expressions. 
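For reference, a let-statement in the iterator IR is simply a lambda that is immediately applied to its arguments; the stencil expression is the innermost lambda that remains once all such outer lambdas are peeled off. A minimal sketch built with the `im` IR makers (illustrative only, assuming the helpers from `gt4py.next.iterator.ir_utils.ir_makers`):

    from gt4py.next.iterator.ir_utils import ir_makers as im

    # `let tmp = a + b in deref(tmp)` is represented as the immediate call
    # `(λ(tmp) → ·tmp)(a + b)`. The new `visit_let()` strips such outer
    # lambdas recursively and only lowers the innermost lambda, the actual
    # stencil expression, with the regular `visit()` machinery.
    let_expr = im.let("tmp", im.plus("a", "b"))(im.deref("tmp"))
    print(let_expr)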
--- .../gtir_builtin_translators.py | 43 ++--- .../runners/dace_fieldview/gtir_dataflow.py | 165 +++++++++++++----- .../runners/dace_fieldview/gtir_sdfg.py | 5 +- 3 files changed, 143 insertions(+), 70 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index ff011c4193..cffbd74c90 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins @@ -30,7 +30,7 @@ gtir_python_codegen, utility as dace_gtir_utils, ) -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info as ti, type_specifications as ts if TYPE_CHECKING: @@ -39,7 +39,7 @@ def _get_domain_indices( dims: Sequence[gtx_common.Dimension], offsets: Optional[Sequence[dace.symbolic.SymExpr]] = None -) -> sbs.Indices: +) -> dace_subsets.Indices: """ Helper function to construct the list of indices for a field domain, applying an optional offset in each dimension as start index. @@ -55,9 +55,9 @@ def _get_domain_indices( """ index_variables = [dace.symbolic.SymExpr(dace_gtir_utils.get_map_variable(dim)) for dim in dims] if offsets is None: - return sbs.Indices(index_variables) + return dace_subsets.Indices(index_variables) else: - return sbs.Indices( + return dace_subsets.Indices( [ index - offset if offset != 0 else index for index, offset in zip(index_variables, offsets, strict=True) @@ -96,7 +96,7 @@ def get_local_view( """Helper method to access a field in local view, given the compute domain of a field operator.""" if isinstance(self.gt_type, ts.ScalarType): return gtir_dataflow.MemletExpr( - dc_node=self.dc_node, gt_dtype=self.gt_type, subset=sbs.Indices([0]) + dc_node=self.dc_node, gt_dtype=self.gt_type, subset=dace_subsets.Indices([0]) ) if isinstance(self.gt_type, ts.FieldType): @@ -263,7 +263,7 @@ def _create_field_operator( dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - field_subset = sbs.Range.from_indices(field_indices) + field_subset = dace_subsets.Range.from_indices(field_indices) if isinstance(output_edge.result.gt_dtype, ts.ScalarType): assert output_edge.result.gt_dtype == node_type.dtype assert isinstance(dataflow_output_desc, dace.data.Scalar) @@ -280,7 +280,7 @@ def _create_field_operator( field_dims.append(output_edge.result.gt_dtype.offset_type) field_shape.extend(dataflow_output_desc.shape) field_offset.extend(dataflow_output_desc.offset) - field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) # allocate local temporary storage field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) @@ -366,36 +366,37 @@ def translate_as_fieldop( """ assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) fun_node = node.fun assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args + fieldop_expr, domain_expr = fun_node.args - if isinstance(stencil_expr, gtir.Lambda): - # 
Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - pass - elif cpm.is_ref_to(stencil_expr, "deref"): + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) - stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] + stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( - f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." ) # parse the domain of the field operator domain = extract_domain(domain_expr) # visit the list of arguments to be passed to the lambda expression - stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) + input_edges, output_edge = gtir_dataflow.visit_lambda( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge @@ -654,7 +655,7 @@ def translate_tuple_get( if not isinstance(node.args[0], gtir.Literal): raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_itir_type(INDEX_DTYPE) + assert ti.is_integral(node.args[0].type) index = int(node.args[0].value) data_nodes = sdfg_builder.visit( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index cfba4d61e5..a3653fb519 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,10 +10,22 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + Final, + List, + Optional, + Protocol, + Sequence, + Set, + Tuple, + TypeAlias, + Union, +) import dace -import dace.subsets as sbs +from dace import subsets as dace_subsets from gt4py import eve from gt4py.next import common as gtx_common @@ -68,7 +80,7 @@ class MemletExpr: dc_node: dace.nodes.AccessNode gt_dtype: itir_ts.ListType | ts.ScalarType - subset: sbs.Indices | sbs.Range + subset: dace_subsets.Range @dataclasses.dataclass(frozen=True) @@ -104,7 +116,7 @@ class IteratorExpr: field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymExpr]] indices: dict[gtx_common.Dimension, DataExpr] - def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: if not all(isinstance(self.indices[dim], 
SymbolExpr) for dim, _ in self.field_domain): raise ValueError(f"Cannot deref iterator {self}.") @@ -117,7 +129,7 @@ def get_memlet_subset(self, sdfg: dace.SDFG) -> sbs.Range: assert len(field_desc.shape) == len(self.field_domain) field_domain = self.field_domain - return sbs.Range.from_string( + return dace_subsets.Range.from_string( ",".join( str(self.indices[dim].value - offset) # type: ignore[union-attr] if dim in self.indices @@ -152,7 +164,7 @@ class MemletInputEdge(DataflowInputEdge): state: dace.SDFGState source: dace.nodes.AccessNode - subset: sbs.Range + subset: dace_subsets.Range dest: dace.nodes.AccessNode | dace.nodes.Tasklet dest_conn: Optional[str] @@ -202,7 +214,7 @@ def connect( self, mx: dace.nodes.MapExit, dest: dace.nodes.AccessNode, - subset: sbs.Range, + subset: dace_subsets.Range, ) -> None: # retrieve the node which writes the result last_node = self.state.in_edges(self.result.dc_node)[0].src @@ -256,10 +268,12 @@ def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: return op_name, reduce_init, reduce_identity +@dataclasses.dataclass(frozen=True) class LambdaToDataflow(eve.NodeVisitor): """ - Translates an `ir.Lambda` expression to a dataflow graph. + Visitor class to translate a `Lambda` expression to a dataflow graph. + This visitor should be applied by calling `apply()` method on a `Lambda` IR. The dataflow graph generated here typically represents the stencil function of a field operator. It only computes single elements or pure local fields, in case of neighbor values. In case of local fields, the dataflow contains @@ -275,25 +289,15 @@ class LambdaToDataflow(eve.NodeVisitor): sdfg: dace.SDFG state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.input_edges = [] - self.symbol_map = {} + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] = dataclasses.field( + default_factory=lambda: {} + ) def _add_input_data_edge( self, src: dace.nodes.AccessNode, - src_subset: sbs.Range, + src_subset: dace_subsets.Range, dst_node: dace.nodes.Node, dst_conn: Optional[str] = None, src_offset: Optional[list[dace.symbolic.SymExpr]] = None, @@ -301,7 +305,7 @@ def _add_input_data_edge( input_subset = ( src_subset if src_offset is None - else sbs.Range( + else dace_subsets.Range( (start - off, stop - off, step) for (start, stop, step), off in zip(src_subset, src_offset, strict=True) ) @@ -512,7 +516,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # add new termination point for the field parameter self._add_input_data_edge( arg_expr.field, - sbs.Range.from_array(field_desc), + dace_subsets.Range.from_array(field_desc), deref_node, "field", src_offset=[offset for (_, offset) in arg_expr.field_domain], @@ -580,7 +584,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( dc_node=it.field, gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( ",".join( str(it.indices[dim].value - offset) # type: ignore[union-attr] if dim != offset_provider.codomain @@ -596,7 +600,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: MemletExpr( 
dc_node=self.state.add_access(connectivity), gt_dtype=node.type, - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_index.value}, 0:{offset_provider.max_neighbors}" ), ) @@ -758,7 +762,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=itir_ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), - subset=sbs.Range.from_string( + subset=dace_subsets.Range.from_string( f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" ), ) @@ -908,7 +912,9 @@ def _make_reduce_with_skip_values( ) self._add_input_data_edge( connectivity_node, - sbs.Range.from_string(f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"), + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), nsdfg_node, "neighbor_indices", ) @@ -1081,7 +1087,7 @@ def _make_dynamic_neighbor_offset( ) self._add_input_data_edge( offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), tasklet_node, "table", ) @@ -1127,7 +1133,7 @@ def _make_unstructured_shift( shifted_indices[neighbor_dim] = MemletExpr( dc_node=offset_table_node, gt_dtype=it.gt_dtype, - subset=sbs.Indices([origin_index.value, offset_expr.value]), + subset=dace_subsets.Indices([origin_index.value, offset_expr.value]), ) else: # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node @@ -1264,39 +1270,39 @@ def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | DataExpr: elif cpm.is_applied_shift(node): return self._visit_shift(node) + elif isinstance(node.fun, gtir.Lambda): + # Lambda node should be visited with 'visit_let()' method. + raise ValueError(f"Unexpected lambda in 'FunCall' node: {node}.") + elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) else: raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - def visit_Lambda( - self, node: gtir.Lambda, args: list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: DataExpr = self.visit(node.expr) - if isinstance(output_expr, ValueExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) + def visit_Lambda(self, node: gtir.Lambda) -> DataflowOutputEdge: + result: DataExpr = self.visit(node.expr) + + if isinstance(result, ValueExpr): + return DataflowOutputEdge(self.state, result) - if isinstance(output_expr, MemletExpr): + if isinstance(result, MemletExpr): # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.dc_node.desc(self.sdfg).dtype + output_dtype = result.dc_node.desc(self.sdfg).dtype tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") self._add_input_data_edge( - output_expr.dc_node, - output_expr.subset, + result.dc_node, + result.subset, tasklet_node, "__inp", ) else: - assert isinstance(output_expr, SymbolExpr) # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dc_dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") + output_dtype = result.dc_dtype + tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {result.value}") output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, 
DataflowOutputEdge(self.state, output_expr) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) @@ -1309,3 +1315,68 @@ def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolE # if not in the lambda symbol map, this must be a symref to a builtin function assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING return SymbolExpr(param, dace.string) + + def visit_let( + self, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], + ) -> DataflowOutputEdge: + """ + Maps lambda arguments to internal parameters. + + This method is responsible for recognizing how the `Lambda` node is used: + it can be either a let-statement or the stencil expression in local view. + The usage of a `Lambda` as let-statement corresponds to computing some results + and making them available inside the lambda scope, represented as a nested SDFG. + All let-statements, if any, are supposed to be encountered before the stencil + expression. In other words, the `Lambda` node representing the stencil expression + is always the innermost node. + Therefore, the lowering of let-statements results in recursive calls to + `visit_let()` until the stencil expression is found. At that point, it falls + back to the `visit()` function. + """ + + # lambda arguments are mapped to symbols defined in the lambda scope. + for p, arg in zip(node.params, args, strict=True): + self.symbol_map[str(p.id)] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + return self.visit_let(let_node.fun, args=let_args) + else: + # this lambda node is not a let-statement, but a stencil expression + return self.visit(node) + + +def visit_lambda( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + node: gtir.Lambda, + args: Sequence[IteratorExpr | MemletExpr | SymbolExpr], +) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: + """ + Entry point to visit a `Lambda` node and lower it to a dataflow graph + that can be instantiated inside a map scope implementing the field operator. + + It calls `LambdaToDataflow.visit_let()` to map the lambda arguments to internal + parameters and visit the let-statements (if any), which always appear as outermost + nodes. Finally, the visitor returns the output edge of the dataflow. + + Args: + sdfg: The SDFG where the dataflow graph will be instantiated. + state: The SDFG state where the dataflow graph will be instantiated. + sdfg_builder: Helper class to build the SDFG. + node: Lambda node to visit. + args: Arguments passed to lambda node. + + Returns: + A tuple of two elements: + - List of connections for data inputs to the dataflow. + - Output data connection.
+ """ + taskgen = LambdaToDataflow(sdfg, state, sdfg_builder) + output_edge = taskgen.visit_let(node, args) + return taskgen.input_edges, output_edge diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 6b5e164458..9bd40f75f8 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -602,7 +602,7 @@ def visit_Lambda( node: gtir.Lambda, sdfg: dace.SDFG, head_state: dace.SDFGState, - args: list[gtir_builtin_translators.FieldopResult], + args: Sequence[gtir_builtin_translators.FieldopResult], ) -> gtir_builtin_translators.FieldopResult: """ Translates a `Lambda` node to a nested SDFG in the current state. @@ -679,7 +679,7 @@ def get_field_domain_offset( self.offset_provider_type, lambda_symbols, lambda_field_offsets ) nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) - nstate = nsdfg.add_state("lambda") + nsdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) # add sdfg storage for the symbols that need to be passed as input parameters lambda_params = [ @@ -690,6 +690,7 @@ def get_field_domain_offset( nsdfg, node_params=lambda_params, symbolic_arguments=lambda_domain_symbols ) + nstate = nsdfg.add_state("lambda") lambda_result = lambda_translator.visit( node.expr, sdfg=nsdfg, From 77cad7c8862c6164dff5f9e192ffef8fc9a2b1af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:53:40 +0100 Subject: [PATCH 5/6] feat[dace][next]: Fixing strides in optimization (#1782) Added functionality to properly handle changes of strides. During the implementation of the scan we found that the strides were not handled properly. Most importantly a change on one level was not propagated into the next levels, i.e. they were still using the old strides. This PR Solves most of the problems, but there are still some issues that are unsolved: - Views are not adjusted yet (Fixed in [PR@1784](https://github.com/GridTools/gt4py/pull/1784)). - It is not properly checked if the symbols of the propagated strides are safe to introduce into the nested SDFG. The initial functionality of this PR was done by Edoardo Paone (@edopao). 
--------- Co-authored-by: edopao --- .../transformations/__init__.py | 12 +- .../transformations/gpu_utils.py | 2 +- .../transformations/simplify.py | 5 +- .../dace_fieldview/transformations/strides.py | 611 +++++++++++++++++- .../test_map_buffer_elimination.py | 93 ++- .../transformation_tests/test_strides.py | 541 ++++++++++++++++ 6 files changed, 1238 insertions(+), 26 deletions(-) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4f3efb19b0..0902bd665a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,13 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, + gt_propagate_strides_from_access_node, + gt_propagate_strides_of, +) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,6 +65,10 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", + "gt_propagate_strides_from_access_node", + "gt_propagate_strides_of", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 2cd3020180..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -95,7 +95,7 @@ def gt_gpu_transformation( if try_removing_trivial_maps: # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So # we might end up with lots of these trivial Maps, each requiring a separate # kernel launch. To prevent this we will combine these trivial maps, if # possible, with their downstream maps. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..4339a761fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - graph: dace.SDFGState | dace.SDFG, + graph: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -971,6 +971,9 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + # Recursively visit the nested SDFGs for mapping of strides from inner to outer array + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. 
new_map_to_glob_edge = graph.add_edge( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 4e254f2880..980b2a8fdf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,14 +6,30 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import Optional, TypeAlias + import dace from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, ) +PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] +"""Record of a stride that has been propagated into a NestedSDFG. + +The type combines the NestedSDFG into which the strides were already propagated +and the data within that NestedSDFG to which we have propagated the strides, +which is the connector name on the NestedSDFG. +We need the NestedSDFG because we have to know what was already processed; +however, we also need the inner array name because of aliasing, i.e. a data +descriptor on the outside could be mapped to multiple data descriptors +inside the NestedSDFG. +""" + + def gt_change_transient_strides( sdfg: dace.SDFG, gpu: bool, @@ -24,6 +40,11 @@ transients in the optimal way. The function should run after all maps have been created. + After the strides have been adjusted, the function will also propagate + the strides into nested SDFGs. This propagation will happen with + `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()` + for more details. + Args: sdfg: The SDFG to process. gpu: If the SDFG is supposed to run on the GPU. @@ -35,8 +56,6 @@ Todo: - Implement the estimation correctly. - - Handle the case of nested SDFGs correctly; on the outside a transient, - but on the inside a non transient. """ # TODO(phimeull): Implement this function correctly. @@ -46,54 +65,608 @@ return sdfg for nsdfg in sdfg.all_sdfgs_recursive(): - # TODO(phimuell): Handle the case when transient goes into nested SDFG - # on the inside it is a non transient, so it is ignored. _gt_change_transient_strides_non_recursive_impl(nsdfg) def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order.""" - for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + """Set optimal strides of all transients in the SDFG. + + The function will look for all top level transients (see `_gt_find_toplevel_data_accesses()`) + and set their strides such that the access is optimal, see the Note below. The function + will also run `gt_propagate_strides_of()` to propagate the strides into nested SDFGs. + + This function should never be called directly but always through + `gt_change_transient_strides()`! + + Note: + Currently the function just reverses the strides of the data descriptor + it processes. Since DaCe generates `C` order by default this leads to + FORTRAN order, which is (for now) sufficient to optimize the memory + layout for GPU. + + Todo: + Make this function more intelligent: analyse the access pattern and then + figure out the best order. + """ + # NOTE: Processing the transient here is enough.
If we are inside a + # NestedSDFG then they were handled before on the level above us. + top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( + sdfg=sdfg, + only_transients=True, + only_arrays=True, + ) + for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] + + # Setting the strides only makes sense if we have more than one dimension ndim = len(desc.shape) if ndim <= 1: continue + # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. + # TODO(phimuell): Improve this. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) + # Now we have to propagate the changed strides. Because we already have + # collected all the AccessNodes we are using the + # `gt_propagate_strides_from_access_node()` function, but we have to + # create the `processed_nsdfgs` set already here, outside. + # Furthermore, the same comment as above applies here: we do not have to + # propagate the non-transients, because they either come from outside, + # or they were already handled in the levels above, where they were + # defined and then propagated down. + # TODO(phimuell): Update the functions such that only one scan is needed. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state, access_node in accesses: + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=access_node, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=True, + ) + + +def gt_propagate_strides_of( + sdfg: dace.SDFG, + data_name: str, + ignore_symbol_mapping: bool = True, +) -> None: + """Propagates the strides of `data_name` within the whole SDFG. + + This function will call `gt_propagate_strides_from_access_node()` for every + AccessNode that refers to `data_name`. It will also make sure that a descriptor + inside a NestedSDFG is only processed once. + + Args: + sdfg: The SDFG on which we operate. + data_name: Name of the data descriptor that should be handled. + ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + """ + + # Defining it here ensures that we will not enter a NestedSDFG multiple times. + processed_nsdfgs: set[PropagatedStrideRecord] = set() + + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode.data != data_name: + continue + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=dnode, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_propagate_strides_from_access_node( + sdfg: dace.SDFG, + state: dace.SDFGState, + outer_node: dace_nodes.AccessNode, + ignore_symbol_mapping: bool = True, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the stride of `outer_node` to any adjacent NestedSDFG. + + The function will propagate the strides of the data descriptor `outer_node` + refers to along all adjacent edges of `outer_node`. If one of these edges + leads to a NestedSDFG then the function will modify the strides of the data + descriptor within to match the strides on the outside. The function will then + recursively process the NestedSDFGs. + + It is important that this function will only handle the NestedSDFGs that are + reachable from `outer_node`. To fully propagate the strides, + `gt_propagate_strides_of()` should be used. + + Args: + sdfg: The SDFG to process.
+def gt_propagate_strides_from_access_node(
+    sdfg: dace.SDFG,
+    state: dace.SDFGState,
+    outer_node: dace_nodes.AccessNode,
+    ignore_symbol_mapping: bool = True,
+    processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None,
+) -> None:
+    """Propagates the stride of `outer_node` to any adjacent NestedSDFG.
+
+    The function will propagate the strides of the data descriptor `outer_node`
+    refers to along all adjacent edges of `outer_node`. If one of these edges
+    leads to a NestedSDFG then the function will modify the strides of the data
+    descriptor within to match the strides on the outside. The function will then
+    process that NestedSDFG recursively.
+
+    It is important that this function only handles the NestedSDFGs that are
+    reachable from `outer_node`. To fully propagate the strides,
+    `gt_propagate_strides_of()` should be used.
+
+    Args:
+        sdfg: The SDFG to process.
+        state: The state where the data node is used.
+        outer_node: The data node whose strides should be propagated.
+        ignore_symbol_mapping: If `False` (the default is `True`), try to modify the
+            `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor.
+        processed_nsdfgs: Set of NestedSDFGs that were already processed and will be
+            ignored. Only specify when you know what you are doing.
+    """
+    if processed_nsdfgs is None:
+        # For preventing the case that nested SDFGs are handled multiple times.
+        processed_nsdfgs = set()
+
+    for in_edge in state.in_edges(outer_node):
+        gt_map_strides_to_src_nested_sdfg(
+            sdfg=sdfg,
+            state=state,
+            edge=in_edge,
+            outer_node=outer_node,
+            processed_nsdfgs=processed_nsdfgs,
+            ignore_symbol_mapping=ignore_symbol_mapping,
+        )
+    for out_edge in state.out_edges(outer_node):
+        gt_map_strides_to_dst_nested_sdfg(
+            sdfg=sdfg,
+            state=state,
+            edge=out_edge,
+            outer_node=outer_node,
+            processed_nsdfgs=processed_nsdfgs,
+            ignore_symbol_mapping=ignore_symbol_mapping,
+        )
+
+
+def gt_map_strides_to_dst_nested_sdfg(
+    sdfg: dace.SDFG,
+    state: dace.SDFGState,
+    edge: dace.sdfg.graph.Edge,
+    outer_node: dace.nodes.AccessNode,
+    ignore_symbol_mapping: bool = True,
+    processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None,
+) -> None:
+    """Propagates the strides of `outer_node` along `edge` in the dataflow direction.
+
+    In this context "along the dataflow direction" means that `edge` is an outgoing
+    edge of `outer_node` and the strides are propagated into all NestedSDFGs that
+    are downstream of `outer_node`.
+
+    Except in certain cases this function should not be used directly. It is
+    instead recommended to use `gt_propagate_strides_of()`, which propagates
+    along all edges in the SDFG.
+
+    Args:
+        sdfg: The SDFG to process.
+        state: The state where the data node is used.
+        edge: The edge that reads from the data node; the nested SDFG is expected
+            as the destination.
+        outer_node: The data node whose strides should be propagated.
+        ignore_symbol_mapping: If `False` (the default is `True`), try to modify the
+            `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor.
+        processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify
+            when you know what you are doing.
+    """
+    assert edge.src is outer_node
+    _gt_map_strides_to_nested_sdfg_src_dst(
+        sdfg=sdfg,
+        state=state,
+        edge=edge,
+        outer_node=outer_node,
+        processed_nsdfgs=processed_nsdfgs,
+        propagate_along_dataflow=True,
+        ignore_symbol_mapping=ignore_symbol_mapping,
+    )
+def gt_map_strides_to_src_nested_sdfg(
+    sdfg: dace.SDFG,
+    state: dace.SDFGState,
+    edge: dace.sdfg.graph.Edge,
+    outer_node: dace.nodes.AccessNode,
+    ignore_symbol_mapping: bool = False,
+    processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None,
+) -> None:
+    """Propagates the strides of `outer_node` along `edge` in the opposite direction of the dataflow.
+
+    In this context "in the opposite direction of the dataflow" means that `edge`
+    is an incoming edge of `outer_node` and the strides are propagated into all
+    NestedSDFGs that are upstream of `outer_node`.
+
+    Except in certain cases this function should not be used directly. It is
+    instead recommended to use `gt_propagate_strides_of()`, which propagates
+    along all edges in the SDFG.
+
+    Args:
+        sdfg: The SDFG to process.
+        state: The state where the data node is used.
+        edge: The edge that writes to the data node; the nested SDFG is expected
+            as the source.
+        outer_node: The data node whose strides should be propagated.
+        ignore_symbol_mapping: If `False` (the default), try to modify the `symbol_mapping`
+            of NestedSDFGs instead of manipulating the data descriptor.
+        processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify
+            when you know what you are doing.
+    """
+    _gt_map_strides_to_nested_sdfg_src_dst(
+        sdfg=sdfg,
+        state=state,
+        edge=edge,
+        outer_node=outer_node,
+        processed_nsdfgs=processed_nsdfgs,
+        propagate_along_dataflow=False,
+        ignore_symbol_mapping=ignore_symbol_mapping,
+    )
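Editor's note, hedged: `PropagatedStrideRecord` is defined earlier in this module, outside the excerpt. Judging from how records are created further down (`process_record = (inner_data, nsdfg_node)`), it is presumably an alias along these lines, so that a nested SDFG is revisited once per data connector rather than once overall:

    from typing import TypeAlias

    # Assumed definition (illustrative only; the real one lives elsewhere in the module):
    PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG]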
+def _gt_map_strides_to_nested_sdfg_src_dst(
+    sdfg: dace.SDFG,
+    state: dace.SDFGState,
+    edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
+    outer_node: dace.nodes.AccessNode,
+    processed_nsdfgs: Optional[set[PropagatedStrideRecord]],
+    propagate_along_dataflow: bool,
+    ignore_symbol_mapping: bool = False,
+) -> None:
+    """Propagates the stride of `outer_node` along `edge`.
+
+    The function will follow `edge`, the direction depends on the value of
+    `propagate_along_dataflow`, and propagate the strides of `outer_node`
+    into every NestedSDFG that is reachable by following `edge`.
+
+    When the function encounters a NestedSDFG it will determine what data
+    `outer_node` is mapped to on the inside of the NestedSDFG.
+    It will then replace the strides of the inner descriptor with the ones
+    of the outside. Afterwards it will recursively propagate the strides
+    inside the NestedSDFG.
+    During this propagation the function will follow all edges.
+
+    If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs`
+    then it will be skipped. NestedSDFGs that have been processed will be added
+    to `processed_nsdfgs`.
+
+    Args:
+        sdfg: The SDFG to process.
+        state: The state where the data node is used.
+        edge: The edge to follow; depending on `propagate_along_dataflow` the
+            nested SDFG is expected as destination or source.
+        outer_node: The data node whose strides should be propagated.
+        processed_nsdfgs: Set of NestedSDFGs that were already processed and will be
+            ignored. Only specify when you know what you are doing.
+        propagate_along_dataflow: Determines the direction of propagation. If `True` the
+            function follows the dataflow.
+        ignore_symbol_mapping: If `False` (the default), try to modify the `symbol_mapping`
+            of NestedSDFGs instead of manipulating the data descriptor.
+
+    Note:
+        A user should not use this function directly; instead `gt_propagate_strides_of()`,
+        `gt_map_strides_to_src_nested_sdfg()` (`propagate_along_dataflow == False`)
+        or `gt_map_strides_to_dst_nested_sdfg()` (`propagate_along_dataflow == True`)
+        should be used.
+
+    Todo:
+        Try using `MemletTree` for the propagation.
+    """
+    # If `processed_nsdfgs` is `None` then this is the first call. We will now
+    # allocate the `set` and pass it as argument to all recursive calls; this
+    # ensures that the `set` is the same everywhere.
+    if processed_nsdfgs is None:
+        processed_nsdfgs = set()
+
+    if propagate_along_dataflow:
+        # Propagate along the dataflow, i.e. forward, so we are interested in the `dst` of the edge.
+        ScopeNode = dace_nodes.MapEntry
+
+        def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node:
+            return edge.dst
+
+        def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str:
+            return edge.dst_conn
+
+        def get_subset(
+            state: dace.SDFGState,
+            edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
+        ) -> dace.subsets.Subset:
+            return edge.data.get_src_subset(edge, state)
-def _find_toplevel_transients(
+        def next_edges_by_connector(
+            state: dace.SDFGState,
+            edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
+        ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]:
+            if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"):
+                return []
+            return list(state.out_edges_by_connector(edge.dst, "OUT_" + edge.dst_conn[3:]))
+
+    else:
+        # Propagate against the dataflow, i.e. backward, so we are interested in the `src` of the edge.
+        ScopeNode = dace_nodes.MapExit
+
+        def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node:
+            return edge.src
+
+        def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str:
+            return edge.src_conn
+
+        def get_subset(
+            state: dace.SDFGState,
+            edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
+        ) -> dace.subsets.Subset:
+            return edge.data.get_dst_subset(edge, state)
+
+        def next_edges_by_connector(
+            state: dace.SDFGState,
+            edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
+        ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]:
+            return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:]))
+
+    if isinstance(get_node(edge), ScopeNode):
+        for next_edge in next_edges_by_connector(state, edge):
+            _gt_map_strides_to_nested_sdfg_src_dst(
+                sdfg=sdfg,
+                state=state,
+                edge=next_edge,
+                outer_node=outer_node,
+                processed_nsdfgs=processed_nsdfgs,
+                propagate_along_dataflow=propagate_along_dataflow,
+                ignore_symbol_mapping=ignore_symbol_mapping,
+            )
+
+    elif isinstance(get_node(edge), dace.nodes.NestedSDFG):
+        nsdfg_node = get_node(edge)
+        inner_data = get_inner_data(edge)
+        process_record = (inner_data, nsdfg_node)
+
+        if process_record in processed_nsdfgs:
+            # We already handled this NestedSDFG and the inner data.
+            return
+
+        # Mark this nested SDFG as processed.
+        processed_nsdfgs.add(process_record)
+
+        # Now set the strides of the data descriptor inside the nested SDFG to
+        # the ones it has outside.
+        _gt_map_strides_into_nested_sdfg(
+            sdfg=sdfg,
+            nsdfg_node=nsdfg_node,
+            inner_data=inner_data,
+            outer_subset=get_subset(state, edge),
+            outer_desc=outer_node.desc(sdfg),
+            ignore_symbol_mapping=ignore_symbol_mapping,
+        )
+
+        # Since the function call above is not recursive we now have to propagate
+        # the change into the NestedSDFGs. Using `_gt_find_toplevel_data_accesses()`
+        # is a bit overkill, but allows for a more uniform processing.
+        # TODO(phimuell): Instead of scanning every level for every data we modify,
+        #   we should scan the whole SDFG once and then reuse this information.
+        accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses(
+            sdfg=nsdfg_node.sdfg,
+            only_transients=False,  # Because on the nested levels they are globals.
+            only_arrays=True,
+        )
+        for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()):
+            # We have to use `gt_propagate_strides_from_access_node()` here because we
+            # have to handle the data in its entirety. We could wait until the other branch
+            # processes the nested SDFG, but this might not work, so let's do it fully now.
+            gt_propagate_strides_from_access_node(
+                sdfg=nsdfg_node.sdfg,
+                state=nested_state,
+                outer_node=nested_access,
+                processed_nsdfgs=processed_nsdfgs,
+                ignore_symbol_mapping=ignore_symbol_mapping,
+            )
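Editor's illustration of the connector convention that the two `next_edges_by_connector` variants above rely on; plain Python with a hypothetical connector name:

    # Following the dataflow, an edge entering a MapEntry on "IN_x" continues
    # on the matching "OUT_x" connector; against the dataflow the inverse
    # mapping holds at a MapExit.
    in_conn = "IN_field"
    out_conn = "OUT_" + in_conn[3:]          # -> "OUT_field"
    assert "IN_" + out_conn[4:] == in_conn   # round-trip back to "IN_field"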
+
+
+def _gt_map_strides_into_nested_sdfg(
     sdfg: dace.SDFG,
+    nsdfg_node: dace.nodes.NestedSDFG,
+    inner_data: str,
+    outer_subset: dace.subsets.Subset,
+    outer_desc: dace_data.Data,
+    ignore_symbol_mapping: bool,
+) -> None:
+    """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`.
+
+    `inner_data` is the name of a data descriptor inside the NestedSDFG.
+    The function will then modify the strides of `inner_data`, assuming this
+    is an array, to match the ones of `outer_desc`.
+
+    Args:
+        sdfg: The SDFG containing the NestedSDFG.
+        nsdfg_node: The node in the parent SDFG that contains the NestedSDFG.
+        inner_data: The name of the data descriptor that should be processed
+            inside the NestedSDFG (by construction also a connector name).
+        outer_subset: The subset that describes what part of the outer data is
+            mapped into the NestedSDFG.
+        outer_desc: The data descriptor of the data on the outside.
+        ignore_symbol_mapping: If `False` the function will, if possible, perform
+            the renaming through the `symbol_mapping` of the nested SDFG. If `True`
+            the function will always replace the strides of the inner descriptor
+            directly. Note that setting this value to `False` might have negative
+            side effects.
+
+    Todo:
+        - Handle explicit dimensions of size 1.
+        - What should we do if the stride symbol is used somewhere else? Creating
+          an alias is probably not the right thing.
+        - Handle the case where the outer stride symbol is already used in another
+          context inside the Nested SDFG.
+    """
+    # We need to compute the new strides. In the following we assume that the
+    # relative order of the dimensions does not change, but we support the case
+    # where some dimensions of the outer data descriptor are not present on the
+    # inside. For example this happens for the Memlet `a[__i0, 0:__a_size1]`. We
+    # detect this case by checking if the Memlet subset in that dimension has size 1.
+    # TODO(phimuell): Handle the case where some additional size 1 dimensions are added.
+    inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data]
+    inner_shape = inner_desc.shape
+    inner_strides_init = inner_desc.strides
+
+    outer_strides = outer_desc.strides
+    outer_inflow = outer_subset.size()
+
+    new_strides: list = []
+    for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True):
+        if dim_oinflow == 1:
+            # This is the case of implicit slicing along one dimension.
+            pass
+        else:
+            # There is inflow into the SDFG, so we need the stride.
+            new_strides.append(dim_ostride)
+    assert len(new_strides) <= len(inner_shape)
+
+    # If we have a scalar on the inside, then there is nothing to adjust.
+    # We could have performed the test above, but doing it here gives us
+    # the chance of validating it.
+    if isinstance(inner_desc, dace_data.Scalar):
+        if len(new_strides) != 0:
+            raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.")
+        return
+
+    if not isinstance(inner_desc, dace_data.Array):
+        raise TypeError(
+            f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'."
+        )
+
+    if len(new_strides) != len(inner_shape):
+        raise ValueError("Failed to compute the inner strides.")
+
+    # Now we actually replace the strides; there are two ways of doing it.
+    # The first is to create an alias in the `symbol_mapping`; however,
+    # this is only possible if the current strides are singular symbols,
+    # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start`
+    # or literal values. Furthermore, this would change the meaning of the
+    # old stride symbol in every context and not only in the one of the stride
+    # of a single, isolated data descriptor.
+    # The second way would be to replace the `strides` attribute of the
+    # inner data descriptor. In case the new stride consists of expressions
+    # such as `value1 - value2` we have to make them available inside the
+    # NestedSDFG. However, it could be that the stride symbols are used somewhere else.
+    # We will do the following: if `ignore_symbol_mapping` is `False` and
+    # the strides of the inner descriptor are symbols, we will use the
+    # symbol mapping. Otherwise, we will replace the `strides` attribute
+    # of the inner descriptor; in addition we will map the symbols that the
+    # new strides need into the NestedSDFG.
+    if (not ignore_symbol_mapping) and all(
+        isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init
+    ):
+        # Use the symbol mapping to alias the old stride symbols.
+        for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True):
+            nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride
+    else:
+        # We have to replace the `strides` attribute of the inner descriptor.
+        inner_desc.set_shape(inner_desc.shape, new_strides)
+
+        # Now find the free symbols that the new strides need.
+        # Note that usually `free_symbols` returns `set[str]`, but here, because
+        # we fall back on SymPy, we get back symbols. We will keep them, because
+        # then we can use them to extract the type from them, which we need later.
+        new_strides_symbols: list[dace.symbol] = []
+        for new_stride_dim in new_strides:
+            if dace.symbolic.issymbolic(new_stride_dim):
+                new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols)
+            else:
+                # It is not already a symbol, so we try to turn it into one.
+                # However, we only add it if it actually is a symbol; a constant
+                # such as `1` should not be added.
+                new_stride_symbol = dace.symbolic.pystr_to_symbolic(new_stride_dim)
+                if new_stride_symbol.is_symbol:
+                    new_strides_symbols.append(new_stride_symbol)
+
+        # Now we determine the set of symbols that should be mapped inside the NestedSDFG.
+        # We will exclude all that are already inside the `symbol_mapping` (we do not
+        # check if they map to the same value, we just hope they do). Furthermore,
+        # we will exclude all symbols that are listed in the `symbols` property
+        # of the SDFG that is nested, and hope that they have the same meaning.
+        # TODO(phimuell): Add better checks to avoid overwriting.
+        missing_symbol_mappings: set[dace.symbol] = {
+            sym
+            for sym in new_strides_symbols
+            if not (sym.name in nsdfg_node.sdfg.symbols or sym.name in nsdfg_node.symbol_mapping)
+        }
+
+        # Now propagate the symbols from the parent SDFG to the NestedSDFG.
+        for sym in missing_symbol_mappings:
+            assert sym.name in sdfg.symbols, f"Expected that '{sym}' is defined in the parent SDFG."
+            nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name])
+            nsdfg_node.symbol_mapping[sym.name] = sym
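Editor's illustration of the stride-selection rule implemented above, as runnable plain Python with hypothetical values (mirroring the `a[__i0, 0:__a_size1]` case from the comment):

    # A 2D outer array enters the nested SDFG with its first dimension sliced
    # away (subset size 1), so only the second outer stride survives.
    outer_strides = ("__a_stride_0", "__a_stride_1")
    outer_inflow = (1, 10)  # sizes of the Memlet subset per dimension
    new_strides = [s for s, n in zip(outer_strides, outer_inflow) if n != 1]
    assert new_strides == ["__a_stride_1"]  # strides of the 1D inner array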
+
+
+def _gt_find_toplevel_data_accesses(
+    sdfg: dace.SDFG,
+    only_transients: bool,
     only_arrays: bool = False,
-) -> set[str]:
-    """Find all top level transients in the SDFG.
+) -> dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]]:
+    """Find all data that is accessed on the top level.
 
     The function will scan the SDFG, ignoring nested ones, and return the
-    name of all transients that have an access node at the top level.
-    However, it will ignore access nodes that refers to registers.
+    name of all data that only has AccessNodes on the top level. If data
+    is found that has an AccessNode on both the top level and in a nested
+    scope an error is generated.
+    By default the function will return transient and non-transient data;
+    however, if `only_transients` is `True` then only transient data will
+    be returned.
+    Furthermore, the function will ignore an access in the following cases:
+    - The AccessNode refers to data that is a register.
+    - The AccessNode refers to a View.
+
+    Args:
+        sdfg: The SDFG to process.
+        only_transients: If `True` only include transients.
+        only_arrays: If `True` (defaults to `False`), only arrays are returned.
+
+    Returns:
+        A `dict` that maps the name of a data container to a list of tuples
+        containing the state where the AccessNode was found and the AccessNode.
     """
-    top_level_transients: set[str] = set()
+    # Data that is accessed on the top level, together with all its AccessNodes.
+    top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict()
+
+    # All data that was found not on the top level.
+    not_top_level_data: set[str] = set()
+
     for state in sdfg.states():
         scope_dict = state.scope_dict()
         for dnode in state.data_nodes():
             data: str = dnode.data
             if scope_dict[dnode] is not None:
-                if data in top_level_transients:
-                    top_level_transients.remove(data)
+                # The node was not found on the top level, so we can ignore it.
+                # We also check that it was never found on the top level; this should
+                # not happen, as everything should go through Maps. But some strange
+                # DaCe transformation might do it.
+                assert (
+                    data not in top_level_data
+                ), f"Found {data} on the top level and inside a scope."
+                not_top_level_data.add(data)
                 continue
-            elif data in top_level_transients:
+
+            elif data in top_level_data:
+                # The data is already known to be top level data, so we must add the
+                # AccessNode to the list of known nodes. But nothing else.
+                top_level_data[data].append((state, dnode))
                 continue
+            elif gtx_transformations.util.is_view(dnode, sdfg):
+                # The AccessNode refers to a View so we ignore it anyway.
                 continue
+
+            # We have found a new data node that is on the top level and is unknown.
+            assert (
+                data not in not_top_level_data
+            ), f"Found {data} on the top level and inside a scope."
             desc: dace_data.Data = dnode.desc(sdfg)
-            if not desc.transient:
+
+            # Check if we only accept arrays.
+            if only_arrays and not isinstance(desc, dace_data.Array):
                 continue
-            elif only_arrays and not isinstance(desc, dace_data.Array):
+
+            # For now we ignore registers.
+            # We do this because registers are allocated on the stack, so the compiler
+            # has all information and should organize the best thing possible.
+            # TODO(phimuell): verify this.
+            elif desc.storage is dace.StorageType.Register:
                 continue
-            top_level_transients.add(data)
-    return top_level_transients
+
+            # We are only interested in transients.
+            if only_transients and (not desc.transient):
+                continue
+
+            # Now create the new entry in the list and record the AccessNode.
+            top_level_data[data] = [(state, dnode)]
+    return top_level_data
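Editor's sketch of the shape of the returned mapping (hypothetical data names; not part of the patch):

    # {"tmp_0": [(state_a, access_node_1), (state_b, access_node_2)], ...}
    accesses = _gt_find_toplevel_data_accesses(sdfg, only_transients=True, only_arrays=True)
    for data_name, nodes in accesses.items():
        print(f"{data_name}: accessed by {len(nodes)} top-level AccessNode(s)")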
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py
index 1a4ce6d047..a98eac3c2c 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py
@@ -22,10 +22,6 @@
 import dace
 
 
-def _make_test_data(names: list[str]) -> dict[str, np.ndarray]:
-    return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names}
-
-
 def _make_test_sdfg(
     output_name: str = "G",
     input_name: str = "G",
@@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply():
         validate_all=True,
     )
     assert count == 0
+
+
+def test_map_buffer_elimination_with_nested_sdfgs():
+    """
+    After removing a transient connected to a nested SDFG node, ensure that the strides
+    are propagated to the arrays in the nested SDFGs.
+    """
+
+    stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)]
+
+    # top-level sdfg
+    sdfg = dace.SDFG(util.unique_name("map_buffer"))
+    inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64)
+    out, out_desc = sdfg.add_array(
+        "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3)
+    )
+    tmp, _ = sdfg.add_temp_transient_like(out_desc)
+    state = sdfg.add_state()
+    tmp_node = state.add_access(tmp)
+
+    nsdfg1 = dace.SDFG(util.unique_name("map_buffer"))
+    inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64)
+    out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64)
+    tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc)
+    state1 = nsdfg1.add_state()
+    tmp1_node = state1.add_access(tmp1)
+
+    nsdfg2 = dace.SDFG(util.unique_name("map_buffer"))
+    inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64)
+    out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64)
+    tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc)
+    state2 = nsdfg2.add_state()
+    tmp2_node = state2.add_access(tmp2)
+
+    state2.add_mapped_tasklet(
+        "broadcast2",
+        map_ranges={"__i": "0:10"},
+        code="__oval = __ival + 1.0",
+        inputs={
+            "__ival": dace.Memlet(f"{inp2}[__i]"),
+        },
+        outputs={
+            "__oval": dace.Memlet(f"{tmp2}[__i]"),
+        },
+        output_nodes={tmp2_node},
+        external_edges=True,
+    )
+    state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc))
+
+    nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"})
+    me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"})
+    state1.add_memlet_path(
+        state1.add_access(inp1),
+        me1,
+        nsdfg2_node,
+        dst_conn="__inp",
+        memlet=dace.Memlet.from_array(inp1, inp1_desc),
+    )
+    state1.add_memlet_path(
+        nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]")
+    )
+    state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc))
+
+    nsdfg1_node = state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"})
+    me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"})
+    state.add_memlet_path(
+        state.add_access(inp),
+        me,
+        nsdfg1_node,
+        dst_conn="__inp",
+        memlet=dace.Memlet.from_array(inp, inp_desc),
+    )
+    state.add_memlet_path(
+        nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]")
+    )
+    state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc))
+
+    sdfg.validate()
+
+    count = sdfg.apply_transformations_repeated(
+        gtx_transformations.GT4PyMapBufferElimination(
+            assume_pointwise=False,
+        ),
+        validate=True,
+        validate_all=True,
+    )
+    assert count == 3
+    assert out1_desc.strides == out_desc.strides[1:]
+    assert out2_desc.strides == out_desc.strides[2:]
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py
new file mode 100644
index 0000000000..5b16e41bc3
--- /dev/null
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py
@@ -0,0 +1,541 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import pytest
+import numpy as np
+import copy
+
+dace = pytest.importorskip("dace")
+from dace import symbolic as dace_symbolic
+from dace.sdfg import nodes as dace_nodes
+
+from gt4py.next.program_processors.runners.dace_fieldview import (
+    transformations as gtx_transformations,
+)
+
+from . import util
+
+
+def _make_strides_propagation_level3_sdfg() -> dace.SDFG:
+    """Generates the level 3 (nested-nested) SDFG for `test_strides_propagation()`."""
+    sdfg = dace.SDFG(util.unique_name("level3"))
+    state = sdfg.add_state(is_start_block=True)
+    names = ["a3", "c3"]
+
+    for name in names:
+        stride_name = name + "_stride"
+        stride_sym = dace_symbolic.pystr_to_symbolic(stride_name)
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+            strides=(stride_sym,),
+        )
+
+    state.add_mapped_tasklet(
+        "compL3",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("a3[__i0]")},
+        code="__out = __in1 + 10.",
+        outputs={"__out": dace.Memlet("c3[__i0]")},
+        external_edges=True,
+    )
+    sdfg.validate()
+    return sdfg
+
+
+def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]:
+    """Generates the level 2 (nested) SDFG for `test_strides_propagation()`.
+
+    The function returns the level 2 SDFG and the NestedSDFG node that contains
+    the level 3 SDFG.
+    """
+    sdfg = dace.SDFG(util.unique_name("level2"))
+    state = sdfg.add_state(is_start_block=True)
+    names = ["a2", "a2_alias", "b2", "c2"]
+
+    for name in names:
+        stride_name = name + "_stride"
+        stride_sym = dace_symbolic.pystr_to_symbolic(stride_name)
+        sdfg.add_symbol(stride_name, dace.int64)
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+            strides=(stride_sym,),
+        )
+
+    state.add_mapped_tasklet(
+        "compL2_1",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("a2[__i0]")},
+        code="__out = __in1 + 10",
+        outputs={"__out": dace.Memlet("b2[__i0]")},
+        external_edges=True,
+    )
+
+    state.add_mapped_tasklet(
+        "compL2_2",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("c2[__i0]")},
+        code="__out = __in1",
+        outputs={"__out": dace.Memlet("a2_alias[__i0]")},
+        external_edges=True,
+    )
+
+    # This is the nested SDFG we have here.
+    sdfg_level3 = _make_strides_propagation_level3_sdfg()
+
+    nsdfg = state.add_nested_sdfg(
+        sdfg=sdfg_level3,
+        parent=sdfg,
+        inputs={"a3"},
+        outputs={"c3"},
+        symbol_mapping={s3: s3 for s3 in sdfg_level3.free_symbols},
+    )
+
+    state.add_edge(state.add_access("a2"), None, nsdfg, "a3", dace.Memlet("a2[0:10]"))
+    state.add_edge(nsdfg, "c3", state.add_access("c2"), None, dace.Memlet("c2[0:10]"))
+    sdfg.validate()
+
+    return sdfg, nsdfg
+
+
+def _make_strides_propagation_level1_sdfg() -> (
+    tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG]
+):
+    """Generates the level 1 (top) SDFG for `test_strides_propagation()`.
+
+    Note that the SDFG is valid, but its computation is indeterminate. The only
+    point of this SDFG is to provide a lot of different situations that have to
+    be handled for renaming.
+
+    Returns:
+        A tuple of length three, with the following members:
+        - The top level SDFG.
+        - The NestedSDFG node that contains the level 2 SDFG (member of the top level SDFG).
+        - The NestedSDFG node that contains the level 3 SDFG (member of the level 2 SDFG).
+    """
+
+    sdfg = dace.SDFG(util.unique_name("level1"))
+    state = sdfg.add_state(is_start_block=True)
+    names = ["a1", "b1", "c1"]
+
+    for name in names:
+        stride_name = name + "_stride"
+        stride_sym = dace_symbolic.pystr_to_symbolic(stride_name)
+        sdfg.add_symbol(stride_name, dace.int64)
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            transient=False,
+            strides=(stride_sym,),
+        )
+
+    sdfg_level2, nsdfg_level3 = _make_strides_propagation_level2_sdfg()
+
+    nsdfg_level2: dace_nodes.NestedSDFG = state.add_nested_sdfg(
+        sdfg=sdfg_level2,
+        parent=sdfg,
+        inputs={"a2", "c2"},
+        outputs={"a2_alias", "b2", "c2"},
+        symbol_mapping={s: s for s in sdfg_level2.free_symbols},
+    )
+
+    for inner_name in nsdfg_level2.in_connectors:
+        outer_name = inner_name[0] + "1"
+        state.add_edge(
+            state.add_access(outer_name),
+            None,
+            nsdfg_level2,
+            inner_name,
+            dace.Memlet(f"{outer_name}[0:10]"),
+        )
+    for inner_name in nsdfg_level2.out_connectors:
+        outer_name = inner_name[0] + "1"
+        state.add_edge(
+            nsdfg_level2,
+            inner_name,
+            state.add_access(outer_name),
+            None,
+            dace.Memlet(f"{outer_name}[0:10]"),
+        )
+
+    sdfg.validate()
+
+    return sdfg, nsdfg_level2, nsdfg_level3
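Editor's illustration, derived from the assertions in the tests below: with `ignore_symbol_mapping=False` each propagated stride is aliased exactly one level up through the `symbol_mapping` of the enclosing NestedSDFG node, e.g. for the `a` arrays:

    # nsdfg_level3.symbol_mapping: "a3_stride" -> a2_stride
    # nsdfg_level2.symbol_mapping: "a2_stride" -> a1_stride
    # sdfg_level1:                 "a1_stride" stays a free symbol of the top SDFG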
+def test_strides_propagation_use_symbol_mapping():
+    # Note that the SDFG we are building here is not really meaningful.
+    sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg()
+
+    # Test if all strides are distinct in the beginning and match what we expect.
+    for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+        for aname, adesc in sdfg.arrays.items():
+            exp_stride = f"{aname}_stride"
+            actual_stride = adesc.strides[0]
+            assert len(adesc.strides) == 1
+            assert (
+                str(actual_stride) == exp_stride
+            ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+            nsdfg = sdfg.parent_nsdfg_node
+            if nsdfg is not None:
+                assert exp_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride
+
+    # Now we propagate `a` and `b`, but not `c`.
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False)
+    sdfg_level1.validate()
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False)
+    sdfg_level1.validate()
+
+    # Because `ignore_symbol_mapping=False` the strides of the data descriptors should
+    # not have changed. But the `symbol_mapping` has been updated for `a` and `b`.
+    # However, the symbols will only point one level above.
+    for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1):
+        for aname, adesc in sdfg.arrays.items():
+            nsdfg = sdfg.parent_nsdfg_node
+            original_stride = f"{aname}_stride"
+
+            if aname.startswith("c"):
+                target_symbol = f"{aname}_stride"
+            else:
+                target_symbol = f"{aname[0]}{level - 1}_stride"
+
+            if nsdfg is not None:
+                assert original_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol
+            assert len(adesc.strides) == 1
+            assert (
+                str(adesc.strides[0]) == original_stride
+            ), f"Expected that '{aname}' has strides '{original_stride}', but found '{adesc.strides}'."
+
+    # Now we also propagate `c`; thus all data descriptors have the same stride.
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False)
+    sdfg_level1.validate()
+    for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1):
+        for aname, adesc in sdfg.arrays.items():
+            nsdfg = sdfg.parent_nsdfg_node
+            original_stride = f"{aname}_stride"
+            target_symbol = f"{aname[0]}{level - 1}_stride"
+            if nsdfg is not None:
+                assert original_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol
+            assert len(adesc.strides) == 1
+            assert (
+                str(adesc.strides[0]) == original_stride
+            ), f"Expected that '{aname}' has strides '{original_stride}', but found '{adesc.strides}'."
+
+
+def test_strides_propagation_ignore_symbol_mapping():
+    # Note that the SDFG we are building here is not really meaningful.
+    sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg()
+
+    # Test if all strides are distinct in the beginning and match what we expect.
+    for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+        for aname, adesc in sdfg.arrays.items():
+            exp_stride = f"{aname}_stride"
+            actual_stride = adesc.strides[0]
+            assert len(adesc.strides) == 1
+            assert (
+                str(actual_stride) == exp_stride
+            ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+            nsdfg = sdfg.parent_nsdfg_node
+            if nsdfg is not None:
+                assert exp_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride
+
+    # Now we propagate `a` and `b`, but not `c`.
+    # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`.
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+
+    # After the propagation `a` and `b` should use the same stride (the one they
+    # have on level 1), but `c` should still be level dependent.
+    for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+        for aname, adesc in sdfg.arrays.items():
+            original_stride = f"{aname}_stride"
+            if aname.startswith("c"):
+                exp_stride = f"{aname}_stride"
+            else:
+                exp_stride = f"{aname[0]}1_stride"
+            assert len(adesc.strides) == 1
+            assert (
+                str(adesc.strides[0]) == exp_stride
+            ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+            nsdfg = sdfg.parent_nsdfg_node
+            if nsdfg is not None:
+                assert original_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[original_stride]) == original_stride
+
+    # Now we also propagate `c`; thus all data descriptors have the same stride.
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+    for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+        for aname, adesc in sdfg.arrays.items():
+            exp_stride = f"{aname[0]}1_stride"
+            original_stride = f"{aname}_stride"
+            assert len(adesc.strides) == 1
+            assert (
+                str(adesc.strides[0]) == exp_stride
+            ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+            nsdfg = sdfg.parent_nsdfg_node
+            if nsdfg is not None:
+                # The symbol mapping should not be updated.
+                assert original_stride in nsdfg.symbol_mapping
+                assert str(nsdfg.symbol_mapping[original_stride]) == original_stride
+
+
+def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG:
+    sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg"))
+    state = sdfg.add_state(is_start_block=True)
+
+    array_names = ["a2", "b2"]
+    for name in array_names:
+        stride_sym = dace.symbol(f"{name}_stride", dtype=dace.uint64)
+        sdfg.add_symbol(stride_sym.name, stride_sym.dtype)
+        sdfg.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            strides=(stride_sym,),
+            transient=False,
+        )
+
+    state.add_mapped_tasklet(
+        "nested_comp",
+        map_ranges={"__i0": "0:10"},
+        inputs={"__in1": dace.Memlet("a2[__i0]")},
+        code="__out = __in1 + 10.",
+        outputs={"__out": dace.Memlet("b2[__i0]")},
+        external_edges=True,
+    )
+    sdfg.validate()
+    return sdfg
+
+
+def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]:
+    sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg"))
+    state = sdfg_level1.add_state(is_start_block=True)
+
+    array_names = ["a1", "b1"]
+    for name in array_names:
+        stride_sym1 = dace.symbol(f"{name}_1stride", dtype=dace.uint64)
+        stride_sym2 = dace.symbol(f"{name}_2stride", dtype=dace.int64)
+        sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype)
+        sdfg_level1.add_symbol(stride_sym2.name, stride_sym2.dtype)
+        stride_sym = stride_sym1 * stride_sym2
+        sdfg_level1.add_array(
+            name,
+            shape=(10,),
+            dtype=dace.float64,
+            strides=(stride_sym,),
+            transient=False,
+        )
+
+    sdfg_level2 = _make_strides_propagation_dependent_symbol_nsdfg()
+
+    for sym, sym_dtype in sdfg_level2.symbols.items():
+        sdfg_level1.add_symbol(sym, sym_dtype)
+
+    nsdfg = state.add_nested_sdfg(
+        sdfg=sdfg_level2,
+        parent=sdfg_level1,
+        inputs={"a2"},
+        outputs={"b2"},
+        symbol_mapping={s: s for s in sdfg_level2.symbols},
+    )
+
+    state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10]"))
+    state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10]"))
+    sdfg_level1.validate()
+
+    return sdfg_level1, nsdfg
+
+
+def test_strides_propagation_dependent_symbol():
+    sdfg_level1, nsdfg_level2 = _make_strides_propagation_dependent_symbol_sdfg()
+    sym1_dtype = dace.uint64
+    sym2_dtype = dace.int64
+
+    # Ensure that the special symbols are not already present inside the nested SDFG.
+    for aname, adesc in sdfg_level1.arrays.items():
+        sym1 = f"{aname}_1stride"
+        sym2 = f"{aname}_2stride"
+        for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]:
+            assert sym in {fs.name for fs in adesc.strides[0].free_symbols}
+            assert sym not in nsdfg_level2.symbol_mapping
+            assert sym not in nsdfg_level2.sdfg.symbols
+            assert sym in sdfg_level1.symbols
+            assert sdfg_level1.symbols[sym] == dtype
+
+    # Now propagate `a1` and `b1`.
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+    gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True)
+    sdfg_level1.validate()
+
+    # Now we check if the update has worked.
+    for aname, adesc in sdfg_level1.arrays.items():
+        sym1 = f"{aname}_1stride"
+        sym2 = f"{aname}_2stride"
+        adesc2 = nsdfg_level2.sdfg.arrays[aname.replace("1", "2")]
+        assert adesc2.strides == adesc.strides
+
+        for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]:
+            assert sym in nsdfg_level2.symbol_mapping
+            assert nsdfg_level2.symbol_mapping[sym].name == sym
+            assert sym in sdfg_level1.symbols
+            assert sdfg_level1.symbols[sym] == dtype
+            assert sym in nsdfg_level2.sdfg.symbols
+            assert nsdfg_level2.sdfg.symbols[sym] == dtype
+
+
+def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG:
+    sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg"))
+    state = sdfg.add_state(is_start_block=True)
+
+    # NOTE: Both arrays have the same symbols used for strides.
+    array_names = ["a2", "b2"]
+    stride_sym0 = dace.symbol("__stride_0", dtype=dace.uint64)
+    stride_sym1 = dace.symbol("__stride_1", dtype=dace.uint64)
+    sdfg.add_symbol(stride_sym0.name, stride_sym0.dtype)
+    sdfg.add_symbol(stride_sym1.name, stride_sym1.dtype)
+    for name in array_names:
+        sdfg.add_array(
+            name,
+            shape=(10, 10),
+            dtype=dace.float64,
+            strides=(stride_sym0, stride_sym1),
+            transient=False,
+        )
+
+    state.add_mapped_tasklet(
+        "nested_comp",
+        map_ranges={
+            "__i0": "0:10",
+            "__i1": "0:10",
+        },
+        inputs={"__in1": dace.Memlet("a2[__i0, __i1]")},
+        code="__out = __in1 + 10.",
+        outputs={"__out": dace.Memlet("b2[__i0, __i1]")},
+        external_edges=True,
+    )
+    sdfg.validate()
+    return sdfg
+
+
+def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]:
+    sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg"))
+    state = sdfg_level1.add_state(is_start_block=True)
+
+    # NOTE: Both arrays use the same symbols as strides.
+    # Furthermore, they are the same as in the nested SDFG, i.e. they are shared.
+    array_names = ["a1", "b1"]
+    stride_sym0 = dace.symbol("__stride_0", dtype=dace.uint64)
+    stride_sym1 = dace.symbol("__stride_1", dtype=dace.uint64)
+    sdfg_level1.add_symbol(stride_sym0.name, stride_sym0.dtype)
+    sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype)
+    for name in array_names:
+        sdfg_level1.add_array(
+            name,
+            shape=(10, 10),
+            dtype=dace.float64,
+            strides=(
+                stride_sym0,
+                stride_sym1,
+            ),
+            transient=False,
+        )
+
+    sdfg_level2 = _make_strides_propagation_shared_symbols_nsdfg()
+    nsdfg = state.add_nested_sdfg(
+        sdfg=sdfg_level2,
+        parent=sdfg_level1,
+        inputs={"a2"},
+        outputs={"b2"},
+        symbol_mapping={s: s for s in sdfg_level2.symbols},
+    )
+
+    state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10, 0:10]"))
+    state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10, 0:10]"))
+    sdfg_level1.validate()
+
+    return sdfg_level1, nsdfg
+
+
+def test_strides_propagation_shared_symbols_sdfg():
+    """Tests what happens if symbols are (unintentionally) shared between descriptors.
+
+    This test looks rather artificial, but the situation is actually quite likely:
+    transients will most likely have the same shape, and if the strides are not
+    set explicitly, which is the case, the strides will also be related to their
+    shape. This test explores the situation where we can, for whatever reason,
+    only propagate the strides of one such data descriptor.
+
+    Note:
+        If `ignore_symbol_mapping` is `False` then this test will fail.
+        This is because the `symbol_mapping` of the NestedSDFG will act on the
+        whole SDFG. Thus it will not only change the strides of `b` but, as an
+        unintended side effect, also the strides of `a`.
+    """
+
+    def ref(a1, b1):
+        for i in range(10):
+            for j in range(10):
+                b1[i, j] = a1[i, j] + 10.0
+
+    sdfg_level1, nsdfg_level2 = _make_strides_propagation_shared_symbols_sdfg()
+
+    res_args = {
+        "a1": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True),
+        "b1": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True),
+    }
+    ref_args = copy.deepcopy(res_args)
+
+    # Now we change the strides of `b1`, and then we propagate the new strides
+    # into the nested SDFG. We want to keep (for whatever reason) the strides of `a1`.
+    stride_b1_sym0 = dace.symbol("__b1_stride_0", dtype=dace.uint64)
+    stride_b1_sym1 = dace.symbol("__b1_stride_1", dtype=dace.uint64)
+    sdfg_level1.add_symbol(stride_b1_sym0.name, stride_b1_sym0.dtype)
+    sdfg_level1.add_symbol(stride_b1_sym1.name, stride_b1_sym1.dtype)
+
+    desc_b1 = sdfg_level1.arrays["b1"]
+    desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1))
+
+    # Now we propagate the new strides of `b1` into the nested SDFG.
+    gtx_transformations.gt_propagate_strides_of(
+        sdfg=sdfg_level1,
+        data_name="b1",
+    )
+
+    # Now we have to prepare the call arguments, i.e. add the strides.
+    itemsize = res_args["b1"].itemsize
+    res_args.update(
+        {
+            "__b1_stride_0": res_args["b1"].strides[0] // itemsize,
+            "__b1_stride_1": res_args["b1"].strides[1] // itemsize,
+            "__stride_0": res_args["a1"].strides[0] // itemsize,
+            "__stride_1": res_args["a1"].strides[1] // itemsize,
+        }
+    )
+    ref(**ref_args)
+    sdfg_level1(**res_args)
+    assert np.allclose(ref_args["b1"], res_args["b1"])
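Editor's note on the `// itemsize` conversions above: DaCe stride symbols count elements while NumPy strides count bytes, hence the division. A minimal standalone check:

    import numpy as np

    b = np.zeros((10, 10), dtype=np.float64, order="F")
    assert b.strides[0] // b.itemsize == 1    # FORTRAN order: first dim is contiguous
    assert b.strides[1] // b.itemsize == 10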
From e8743dd357656f25c2b73884858bddc56f72d0a0 Mon Sep 17 00:00:00 2001
From: Hannes Vogt
Date: Mon, 6 Jan 2025 10:38:26 +0100
Subject: [PATCH 6/6] ci: fix boost install in cartesian and daily ci plan
 (#1787)

Boost download link expired, but actually no custom boost (header)
installation is required.
---
 .github/workflows/daily-ci.yml       |  7 -------
 .github/workflows/test-cartesian.yml | 10 ++--------
 2 files changed, 2 insertions(+), 15 deletions(-)

diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml
index 30ad0a6ff9..7ece5a4d5e 100644
--- a/.github/workflows/daily-ci.yml
+++ b/.github/workflows/daily-ci.yml
@@ -34,13 +34,6 @@ jobs:
         shell: bash
         run: |
           sudo apt install libboost-dev
-          wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz
-          echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt
-          sha256sum -c boost_hash.txt
-          tar xzf boost_1_76_0.tar.gz
-          mkdir -p boost/include
-          mv boost_1_76_0/boost boost/include/
-          echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml
index aa59660a68..f7e78ee6c1 100644
--- a/.github/workflows/test-cartesian.yml
+++ b/.github/workflows/test-cartesian.yml
@@ -29,16 +29,10 @@ jobs:
         tox-factor: [internal, dace]
     steps:
       - uses: actions/checkout@v4
-      - name: Install boost
+      - name: Install C++ libraries
         shell: bash
         run: |
-          wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz
-          echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt
-          sha256sum -c boost_hash.txt
-          tar xzf boost_1_76_0.tar.gz
-          mkdir -p boost/include
-          mv boost_1_76_0/boost boost/include/
-          echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV
+          sudo apt install libboost-dev
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with: