Skip to content

Commit

Permalink
Address review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 5, 2024
1 parent 2928503 commit 447b467
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ class DomainAccessDescriptor(eve.StrEnum):
NEVER = "never"


NON_TUPLE_DOMAIN = domain_utils.SymbolicDomain | DomainAccessDescriptor
NonTupleDomain: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor
#: The domain can also be a tuple of domains, usually this only occurs for scan operators returning
#: a tuple since other occurrences for tuples are removed before domain inference. This is
#: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just
#: fine to a tuple of a vertexn and an edge domain.
DOMAIN: TypeAlias = NON_TUPLE_DOMAIN | tuple["DOMAIN", ...]
ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN]
Domain: TypeAlias = NonTupleDomain | tuple["Domain", ...]
AccessedDomains: TypeAlias = dict[str, Domain]


class InferenceOptions(typing.TypedDict):
Expand Down Expand Up @@ -97,7 +97,7 @@ def _domain_union(
return domain_utils.domain_union(*filtered_domains)


def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]:
def _canonicalize_domain_structure(d1: Domain, d2: Domain) -> tuple[Domain, Domain]:
"""
Given two domains or composites thereof, canonicalize their structure.
Expand Down Expand Up @@ -138,9 +138,9 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA


def _merge_domains(
original_domains: ACCESSED_DOMAINS,
additional_domains: ACCESSED_DOMAINS,
) -> ACCESSED_DOMAINS:
original_domains: AccessedDomains,
additional_domains: AccessedDomains,
) -> AccessedDomains:
new_domains = {**original_domains}

for key, domain in additional_domains.items():
Expand All @@ -155,11 +155,11 @@ def _merge_domains(
def _extract_accessed_domains(
stencil: itir.Expr,
input_ids: list[str],
target_domain: NON_TUPLE_DOMAIN,
target_domain: NonTupleDomain,
offset_provider: common.OffsetProvider,
symbolic_domain_sizes: Optional[dict[str, str]],
) -> dict[str, NON_TUPLE_DOMAIN]:
accessed_domains: dict[str, NON_TUPLE_DOMAIN] = {}
) -> dict[str, NonTupleDomain]:
accessed_domains: dict[str, NonTupleDomain] = {}

shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids))

Expand Down Expand Up @@ -187,12 +187,12 @@ def _extract_accessed_domains(

def _infer_as_fieldop(
applied_fieldop: itir.FunCall,
target_domain: DOMAIN,
target_domain: Domain,
*,
offset_provider: common.OffsetProvider,
symbolic_domain_sizes: Optional[dict[str, str]],
allow_uninferred: bool,
) -> tuple[itir.FunCall, ACCESSED_DOMAINS]:
) -> tuple[itir.FunCall, AccessedDomains]:
assert isinstance(applied_fieldop, itir.FunCall)
assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop")
if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER:
Expand Down Expand Up @@ -222,12 +222,12 @@ def _infer_as_fieldop(
raise ValueError(f"Unsupported expression of type '{type(in_field)}'.")
input_ids.append(id_)

inputs_accessed_domains: dict[str, NON_TUPLE_DOMAIN] = _extract_accessed_domains(
inputs_accessed_domains: dict[str, NonTupleDomain] = _extract_accessed_domains(
stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes
)

# Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s
accessed_domains: ACCESSED_DOMAINS = {}
accessed_domains: AccessedDomains = {}
transformed_inputs: list[itir.Expr] = []
for in_field_id, in_field in zip(input_ids, inputs):
transformed_input, accessed_domains_tmp = infer_expr(
Expand Down Expand Up @@ -258,9 +258,9 @@ def _infer_as_fieldop(

def _infer_let(
let_expr: itir.FunCall,
input_domain: DOMAIN,
input_domain: Domain,
**kwargs: Unpack[InferenceOptions],
) -> tuple[itir.FunCall, ACCESSED_DOMAINS]:
) -> tuple[itir.FunCall, AccessedDomains]:
assert cpm.is_let(let_expr)
assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy
let_params = {param_sym.id for param_sym in let_expr.fun.params}
Expand Down Expand Up @@ -296,12 +296,12 @@ def _infer_let(

def _infer_make_tuple(
expr: itir.Expr,
domain: DOMAIN,
domain: Domain,
**kwargs: Unpack[InferenceOptions],
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
) -> tuple[itir.Expr, AccessedDomains]:
assert cpm.is_call_to(expr, "make_tuple")
infered_args_expr = []
actual_domains: ACCESSED_DOMAINS = {}
actual_domains: AccessedDomains = {}
if not isinstance(domain, tuple):
# promote domain to a tuple of domains such that it has the same structure as
# the expression
Expand All @@ -323,11 +323,11 @@ def _infer_make_tuple(

def _infer_tuple_get(
expr: itir.Expr,
domain: DOMAIN,
domain: Domain,
**kwargs: Unpack[InferenceOptions],
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
) -> tuple[itir.Expr, AccessedDomains]:
assert cpm.is_call_to(expr, "tuple_get")
actual_domains: ACCESSED_DOMAINS = {}
actual_domains: AccessedDomains = {}
idx_expr, tuple_arg = expr.args
assert isinstance(idx_expr, itir.Literal)
idx = int(idx_expr.value)
Expand All @@ -343,12 +343,12 @@ def _infer_tuple_get(

def _infer_if(
expr: itir.Expr,
domain: DOMAIN,
domain: Domain,
**kwargs: Unpack[InferenceOptions],
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
) -> tuple[itir.Expr, AccessedDomains]:
assert cpm.is_call_to(expr, "if_")
infered_args_expr = []
actual_domains: ACCESSED_DOMAINS = {}
actual_domains: AccessedDomains = {}
cond, true_val, false_val = expr.args
for arg in [true_val, false_val]:
infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs)
Expand All @@ -360,9 +360,9 @@ def _infer_if(

def _infer_expr(
expr: itir.Expr,
domain: DOMAIN,
domain: Domain,
**kwargs: Unpack[InferenceOptions],
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
) -> tuple[itir.Expr, AccessedDomains]:
if isinstance(expr, itir.SymRef):
return expr, {str(expr.id): domain}
elif isinstance(expr, itir.Literal):
Expand All @@ -389,12 +389,12 @@ def _infer_expr(

def infer_expr(
expr: itir.Expr,
domain: DOMAIN,
domain: Domain,
*,
offset_provider: common.OffsetProvider,
symbolic_domain_sizes: Optional[dict[str, str]] = None,
allow_uninferred: bool = False,
) -> tuple[itir.Expr, ACCESSED_DOMAINS]:
) -> tuple[itir.Expr, AccessedDomains]:
"""
Infer the domain of all field subexpressions of `expr`.
Expand Down

0 comments on commit 447b467

Please sign in to comment.