Skip to content

Commit

Permalink
Address review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 5, 2024
1 parent 7c271ca commit e3306fc
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DomainAccessDescriptor(eve.StrEnum):
Descriptor for domains that could not be inferred.
"""

#: The access if unknown because of a dynamic shift.whose extent is not known.
#: The access is unknown because of a dynamic shift.whose extent is not known.
#: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)`
UNKNOWN = "unknown"
#: The domain is never accessed.
Expand Down Expand Up @@ -83,7 +83,7 @@ def _domain_union(
return DomainAccessDescriptor.UNKNOWN

filtered_domains: list[domain_utils.SymbolicDomain] = [
d # type: ignore[misc] # domain can never be unknown because as these cases are filtered above
d # type: ignore[misc] # domain can never be unknown as these cases are filtered above
for d in domains
if d != DomainAccessDescriptor.NEVER
]
Expand Down Expand Up @@ -153,7 +153,6 @@ def _extract_accessed_domains(
target_domain: domain_utils.SymbolicDomain | DomainAccessDescriptor,
offset_provider: common.OffsetProvider,
symbolic_domain_sizes: Optional[dict[str, str]],
allow_uninferred: bool,
) -> ACCESSED_DOMAINS:
accessed_domains: dict[str, domain_utils.SymbolicDomain | DomainAccessDescriptor] = {}

Expand All @@ -178,6 +177,8 @@ def _extract_accessed_domains(
accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains
)

# Widen type to allow callee to all other types that can be in ACCESSED_DOMAINS, i.e. tuple.
# Fine since we transfer ownership of return value to callee.
return typing.cast(ACCESSED_DOMAINS, accessed_domains)


Expand All @@ -196,10 +197,7 @@ def _infer_as_fieldop(
# FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`.
if isinstance(target_domain, tuple):
target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough
if not isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)):
raise ValueError(
"'target_domain' needs to be a 'domain_utils.SymbolicDomain' or a 'DomainAccessDescriptor'."
)
assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor))

# `as_fieldop(stencil)(inputs...)`
stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args
Expand All @@ -222,7 +220,7 @@ def _infer_as_fieldop(
input_ids.append(id_)

inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains(
stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes, allow_uninferred
stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes
)

# Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s
Expand Down Expand Up @@ -406,7 +404,7 @@ def infer_expr(
- symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol
name that evaluates to the length of that axis.
- allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g.
because of a dynamic shift) or empty.
because of a dynamic shift) or never accessed.
Returns:
A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed)
Expand Down

0 comments on commit e3306fc

Please sign in to comment.