From 447b4673788c920d4e7295d27c635f9a98581ab7 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 5 Dec 2024 15:13:18 +0100 Subject: [PATCH] Address review comments. --- .../next/iterator/transforms/infer_domain.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 60346e069c..0332d60168 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -38,13 +38,13 @@ class DomainAccessDescriptor(eve.StrEnum): NEVER = "never" -NON_TUPLE_DOMAIN = domain_utils.SymbolicDomain | DomainAccessDescriptor +NonTupleDomain: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor #: The domain can also be a tuple of domains, usually this only occurs for scan operators returning #: a tuple since other occurrences for tuples are removed before domain inference. This is #: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just #: fine to a tuple of a vertexn and an edge domain. -DOMAIN: TypeAlias = NON_TUPLE_DOMAIN | tuple["DOMAIN", ...] -ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +Domain: TypeAlias = NonTupleDomain | tuple["Domain", ...] +AccessedDomains: TypeAlias = dict[str, Domain] class InferenceOptions(typing.TypedDict): @@ -97,7 +97,7 @@ def _domain_union( return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure(d1: Domain, d2: Domain) -> tuple[Domain, Domain]: """ Given two domains or composites thereof, canonicalize their structure. @@ -138,9 +138,9 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): @@ -155,11 +155,11 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: NON_TUPLE_DOMAIN, + target_domain: NonTupleDomain, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> dict[str, NON_TUPLE_DOMAIN]: - accessed_domains: dict[str, NON_TUPLE_DOMAIN] = {} +) -> dict[str, NonTupleDomain]: + accessed_domains: dict[str, NonTupleDomain] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) @@ -187,12 +187,12 @@ def _extract_accessed_domains( def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: Domain, *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], allow_uninferred: bool, -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: @@ -222,12 +222,12 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: dict[str, NON_TUPLE_DOMAIN] = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomain] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s - accessed_domains: ACCESSED_DOMAINS = {} + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( @@ -258,9 +258,9 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, + input_domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy let_params = {param_sym.id for param_sym in let_expr.fun.params} @@ -296,12 +296,12 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -323,11 +323,11 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) @@ -343,12 +343,12 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) @@ -360,9 +360,9 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, **kwargs: Unpack[InferenceOptions], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): @@ -389,12 +389,12 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: Domain, *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, allow_uninferred: bool = False, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: +) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`.