Skip to content

Commit

Permalink
Merge remote-tracking branch 'gt4py/main' into dace-fieldview-transfo…
Browse files Browse the repository at this point in the history
…rmations
  • Loading branch information
philip-paul-mueller committed Jul 31, 2024
2 parents 5ed2a8f + 9d1e4e9 commit 368c8ad
Show file tree
Hide file tree
Showing 10 changed files with 1,085 additions and 187 deletions.
5 changes: 0 additions & 5 deletions ci/cscs-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ stages:
variables:
CUDA_VERSION: 12.4.1
CUPY_PACKAGE: cupy-cuda12x
# TODO: re-enable CI job when Todi is back in operational state
when: manual

build_py311_baseimage_x86_64:
extends: .build_baseimage_x86_64
Expand Down Expand Up @@ -176,9 +174,6 @@ build_py38_image_x86_64:
- SUBPACKAGE: next
VARIANT: [-nomesh, -atlas]
SUBVARIANT: [-cuda12x, -cpu]
before_script:
# TODO: remove start of CUDA MPS daemon once CI-CD can handle CRAY_CUDA_MPS
- CUDA_MPS_PIPE_DIRECTORY="/tmp/nvidia-mps" nvidia-cuda-mps-control -d
variables:
# Grace-Hopper gpu architecture is not enabled by default in CUDA build
CUDAARCHS: "90"
Expand Down
10 changes: 10 additions & 0 deletions src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
)


def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `reduce(λ(...) → ...)(...)`."""
return (
isinstance(arg, itir.FunCall)
and isinstance(arg.fun, itir.FunCall)
and isinstance(arg.fun.fun, itir.SymRef)
and arg.fun.fun.id == "reduce"
)


def is_applied_shift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
"""Match expressions of the form `shift(λ(...) → ...)(...)`."""
return (
Expand Down
29 changes: 29 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,32 @@ def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call:
)
)
)


def op_as_fieldop(
op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None
) -> Callable[..., itir.FunCall]:
"""
Promotes a function `op` to a field_operator.
Args:
op: a function from values to value.
domain: the domain of the returned field.
Returns:
A function from Fields to Field.
Examples:
>>> str(op_as_fieldop("op")("a", "b"))
'(⇑(λ(__arg0, __arg1) → op(·__arg0, ·__arg1)))(a, b)'
"""
if isinstance(op, (str, itir.SymRef, itir.Lambda)):
op = call(op)

def _impl(*its: itir.Expr) -> itir.FunCall:
args = [
f"__arg{i}" for i in range(len(its))
] # TODO: `op` must not contain `SymRef(id="__argX")`
return as_fieldop(lambda_(*args)(op(*[deref(arg) for arg in args])), domain)(*its)

return _impl
15 changes: 4 additions & 11 deletions src/gt4py/next/iterator/transforms/fuse_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from gt4py.eve import NodeTranslator, traits
from gt4py.eve.utils import UIDGenerator
from gt4py.next.iterator import ir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.transforms import inline_lambdas


Expand All @@ -29,14 +30,6 @@ def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]:
)


def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]:
return (
isinstance(node, ir.FunCall)
and isinstance(node.fun, ir.FunCall)
and node.fun.fun == ir.SymRef(id="reduce")
)


@dataclasses.dataclass(frozen=True)
class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator):
"""
Expand Down Expand Up @@ -71,7 +64,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda:

def visit_FunCall(self, node: ir.FunCall, **kwargs):
node = self.generic_visit(node)
if _is_map(node) or _is_reduce(node):
if _is_map(node) or cpm.is_applied_reduce(node):
if any(_is_map(arg) for arg in node.args):
first_param = (
0 if _is_map(node) else 1
Expand All @@ -83,7 +76,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
inlined_args = []
new_params = []
new_args = []
if _is_reduce(node):
if cpm.is_applied_reduce(node):
# param corresponding to reduce acc
inlined_args.append(ir.SymRef(id=outer_op.params[0].id))
new_params.append(outer_op.params[0])
Expand Down Expand Up @@ -119,7 +112,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs):
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args
)
else: # _is_reduce(node)
else: # is_applied_reduce(node)
return ir.FunCall(
fun=ir.FunCall(fun=ir.SymRef(id="reduce"), args=[new_op, node.fun.args[1]]),
args=new_args,
Expand Down
9 changes: 3 additions & 6 deletions src/gt4py/next/iterator/transforms/unroll_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.eve.utils import UIDGenerator
from gt4py.next import common
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift


Expand Down Expand Up @@ -60,16 +61,12 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]:
return [_get_partial_offset_tag(arg) for arg in _get_neighbors_args(reduce_args)]


def _is_reduce(node: itir.FunCall) -> TypeGuard[itir.FunCall]:
return isinstance(node.fun, itir.FunCall) and node.fun.fun == itir.SymRef(id="reduce")


def _get_connectivity(
applied_reduce_node: itir.FunCall,
offset_provider: dict[str, common.Dimension | common.Connectivity],
) -> common.Connectivity:
"""Return single connectivity that is compatible with the arguments of the reduce."""
if not _is_reduce(applied_reduce_node):
if not cpm.is_applied_reduce(applied_reduce_node):
raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.")

connectivities: list[common.Connectivity] = []
Expand Down Expand Up @@ -158,6 +155,6 @@ def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr:

def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Expr:
node = self.generic_visit(node, **kwargs)
if _is_reduce(node):
if cpm.is_applied_reduce(node):
return self._visit_reduce(node, **kwargs)
return node
Loading

0 comments on commit 368c8ad

Please sign in to comment.