From 9cc754877270121043ddbc711903dbc36171e657 Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Mon, 29 Apr 2024 15:46:55 +0200 Subject: [PATCH] refactor[next]: workflowify step3 (#1516) ## Description ### New: - `ffront.stages.FieldOperatorDefinition` - all the data to start the toolchain from a field operator dsl definition - `ffront.stages.FoastOperatorDefinition` - data after lowering from field operator dsl code - `ffront.stages.FoastWithTypes` - program argument types in addition to the foast definition for creating a program AST - `ffront.stages.FoastClosure` - program arguments in addition to the foast definition, ready to run the whole toolchain ### Changed: - `decorator.Program.__post_init__` - implementation moved to `past_passes.linters` workflow steps - linting stage added to program transforms - `decorator.FieldOperator.from_function` - implementation moved to workflow step in `ffront.func_to_foast` - `decorator.FieldOperator.as_program` - implementation moved to workflow steps in `ffront.foast_to_past` - `decorator.FieldOperator` data attributes - added: `definition_stage` - removed: - `.foast_node`: replaced with `.foast_stage.foast_node` - `.definition`: replaced with `.definition_stage.definition` - `next.backend.Backend` - renamed: `.transformer` -> `.transforms_prog` - added: `.transforms_fop`, toolchain for starting from field operator - `otf.recipes.FieldOpTransformWorkflow` - now has all the steps from DSL field operator to `ProgramCall` via `foast_to_past`, with additional steps to go to the field operator IteratorIR expression directly instead (not run by default). The latter `foast_to_itir` step is required during lowering of programs that call a field operator. --- .../next/Advanced_ToolchainWalkthrough.md | 514 ++++++++++++++++++ src/gt4py/eve/extended_typing.py | 15 +- src/gt4py/eve/utils.py | 8 +- src/gt4py/next/backend.py | 154 +++++- src/gt4py/next/ffront/decorator.py | 315 +++++------ .../ffront/foast_passes/type_deduction.py | 7 +- src/gt4py/next/ffront/foast_pretty_printer.py | 13 +- src/gt4py/next/ffront/foast_to_itir.py | 5 + src/gt4py/next/ffront/foast_to_past.py | 129 +++++ src/gt4py/next/ffront/func_to_foast.py | 81 ++- src/gt4py/next/ffront/func_to_past.py | 5 +- src/gt4py/next/ffront/past_passes/linters.py | 59 ++ src/gt4py/next/ffront/stages.py | 121 ++++- src/gt4py/next/otf/recipes.py | 33 -- src/gt4py/next/otf/workflow.py | 6 +- .../ffront_tests/ffront_test_utils.py | 2 +- .../ffront_tests/test_foast_pretty_printer.py | 4 +- .../test_math_builtin_execution.py | 22 +- .../unit_tests/ffront_tests/test_stages.py | 162 ++++++ .../unit_tests/otf_tests/test_workflow.py | 2 +- tox.ini | 2 + 21 files changed, 1411 insertions(+), 248 deletions(-) create mode 100644 docs/user/next/Advanced_ToolchainWalkthrough.md create mode 100644 src/gt4py/next/ffront/foast_to_past.py create mode 100644 src/gt4py/next/ffront/past_passes/linters.py create mode 100644 tests/next_tests/unit_tests/ffront_tests/test_stages.py diff --git a/docs/user/next/Advanced_ToolchainWalkthrough.md b/docs/user/next/Advanced_ToolchainWalkthrough.md new file mode 100644 index 0000000000..94a7bfa7e2 --- /dev/null +++ b/docs/user/next/Advanced_ToolchainWalkthrough.md @@ -0,0 +1,514 @@ +```python +import dataclasses +import inspect + +import gt4py.next as gtx +from gt4py.next import backend + +import devtools +``` + + + + +```python +I = gtx.Dimension("I") +Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) +OFFSET_PROVIDER = {"Ioff": I} +``` + +# Toolchain Overview + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) +``` + +# Walkthrough from Field Operator + +## Starting Out + +```python +@gtx.field_operator +def example_fo(a: gtx.Field[[I], gtx.float64]) -> gtx.Field[[I], gtx.float64]: + return a + 1.0 +``` + +```python +start = example_fo.definition_stage +``` + +```python +gtx.ffront.stages.FieldOperatorDefinition? +``` + + Init signature: + gtx.ffront.stages.FieldOperatorDefinition( +  definition: 'types.FunctionType', +  grid_type: 'Optional[common.GridType]' = None, +  node_class: 'type[OperatorNodeT]' = <class 'gt4py.next.ffront.field_operator_ast.FieldOperator'>, +  attributes: 'dict[str, Any]' = <factory>, + ) -> None + Docstring: FieldOperatorDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None, node_class: 'type[OperatorNodeT]' = , attributes: 'dict[str, Any]' = ) + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## DSL -> FOAST + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style fdef fill:red +style foast fill:red +linkStyle 0 stroke:red,stroke-width:4px,color:pink +``` + +```python +foast = backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(start) +``` + +```python +gtx.ffront.stages.FoastOperatorDefinition? +``` + + Init signature: + gtx.ffront.stages.FoastOperatorDefinition( +  foast_node: 'OperatorNodeT', +  closure_vars: 'dict[str, Any]', +  grid_type: 'Optional[common.GridType]' = None, +  attributes: 'dict[str, Any]' = <factory>, + ) -> None + Docstring: FoastOperatorDefinition(foast_node: 'OperatorNodeT', closure_vars: 'dict[str, Any]', grid_type: 'Optional[common.GridType]' = None, attributes: 'dict[str, Any]' = ) + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## FOAST -> ITIR + +This also happens inside the `decorator.FieldOperator.__gt_itir__` method during the lowering from calling Programs to ITIR + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style foast fill:red +style itir_expr fill:red +linkStyle 1 stroke:red,stroke-width:4px,color:pink +``` + +```python +fitir = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(foast) +``` + +```python +fitir.__class__ +``` + + gt4py.next.iterator.ir.FunctionDefinition + +## FOAST -> FOAST closure + +This is preparation for "directly calling" a field operator. + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style foast fill:red +style fclos fill:red +linkStyle 2 stroke:red,stroke-width:4px,color:pink +``` + +Here we have to dynamically generate a workflow step, because the arguments were not known before. + +```python +fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_inject_args.__class__( + args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), + kwargs={ + "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64) + }, + from_fieldop=example_fo +)(foast) +``` + +```python +gtx.ffront.stages.FoastClosure? +``` + + Init signature: + gtx.ffront.stages.FoastClosure( +  foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', +  args: 'tuple[Any, ...]', +  kwargs: 'dict[str, Any]', +  closure_vars: 'dict[str, Any]', + ) -> None + Docstring: FoastClosure(foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]', closure_vars: 'dict[str, Any]') + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## FOAST with args -> PAST closure + +This auto-generates a program for us, directly in PAST representation and forwards the call arguments to it + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style fclos fill:red +style pclos fill:red +linkStyle 3 stroke:red,stroke-width:4px,color:pink +``` + +```python +pclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure(fclos) +``` + +```python +gtx.ffront.stages.PastClosure? +``` + + Init signature: + gtx.ffront.stages.PastClosure( +  closure_vars: 'dict[str, Any]', +  past_node: 'past.Program', +  grid_type: 'Optional[common.GridType]', +  args: 'tuple[Any, ...]', +  kwargs: 'dict[str, Any]', + ) -> None + Docstring: PastClosure(closure_vars: 'dict[str, Any]', past_node: 'past.Program', grid_type: 'Optional[common.GridType]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]') + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## Transform PAST closure arguments + +Don't ask me, seems to be necessary though + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pclos fill:red +%%style pclos fill:red +linkStyle 4 stroke:red,stroke-width:4px,color:pink +``` + +```python +pclost = backend.DEFAULT_PROG_TRANSFORMS.past_transform_args(pclos) +``` + +## Lower PAST -> ITIR + +still forwarding the call arguments + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pclos fill:red +style pcall fill:red +linkStyle 5 stroke:red,stroke-width:4px,color:pink +``` + +```python +pitir = backend.DEFAULT_PROG_TRANSFORMS.past_to_itir(pclost) +``` + +```python +gtx.otf.stages.ProgramCall? +``` + + Init signature: + gtx.otf.stages.ProgramCall( +  program: 'itir.FencilDefinition', +  args: 'tuple[Any, ...]', +  kwargs: 'dict[str, Any]', + ) -> None + Docstring: Iterator IR representaion of a program together with arguments to be passed to it. + File: ~/Code/gt4py/src/gt4py/next/otf/stages.py + Type: type + Subclasses: + +## Executing The Result + +```python +gtx.program_processors.runners.roundtrip.executor(pitir.program, *pitir.args, offset_provider=OFFSET_PROVIDER, **pitir.kwargs) +``` + +```python +pitir.args +``` + + (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), + NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])), + 10, + 10) + +## Full Field Operator Toolchain + +using the default step order + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style fdef fill:red +style foast fill:red +style fclos fill:red +style pclos fill:red +style pcall fill:red +linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink +``` + +### Starting from DSL + +```python +foast_toolchain = backend.DEFAULT_FIELDOP_TRANSFORMS.replace( + foast_inject_args=backend.FopArgsInjector(args=fclos.args, kwargs=fclos.kwargs, from_fieldop=example_fo) +) +pitir2 = foast_toolchain(start) +assert pitir2 == pitir +``` + +#### Pass The result to the compile workflow and execute + +```python +example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflow( + dataclasses.replace(pitir2, kwargs=pitir2.kwargs | {"offset_provider": OFFSET_PROVIDER}) +) +``` + +```python +example_compiled(*pitir2.args, offset_provider=OFFSET_PROVIDER) +``` + +```python +example_compiled(pitir2.args[1], *pitir2.args[1:], offset_provider=OFFSET_PROVIDER) +``` + +```python +pitir2.args[1].asnumpy() +``` + + array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]) + +### Starting from FOAST + +Note that it is the exact same call but with a different input stage + +```python +pitir3 = foast_toolchain(foast) +assert pitir3 == pitir +``` + +# Walkthrough starting from Program + +## Starting Out + +```python +@gtx.program +def example_prog(a: gtx.Field[[I], gtx.float64], out: gtx.Field[[I], gtx.float64]) -> None: + example_fo(a, out=out) +``` + +```python +p_start = example_prog.definition_stage +``` + +```python +gtx.ffront.stages.ProgramDefinition? +``` + + Init signature: + gtx.ffront.stages.ProgramDefinition( +  definition: 'types.FunctionType', +  grid_type: 'Optional[common.GridType]' = None, + ) -> None + Docstring: ProgramDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None) + File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py + Type: type + Subclasses: + +## DSL -> PAST + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pdef fill:red +style past fill:red +linkStyle 6 stroke:red,stroke-width:4px,color:pink +``` + +```python +p_past = backend.DEFAULT_PROG_TRANSFORMS.func_to_past(p_start) +``` + +## PAST -> Closure + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style past fill:red +style pclos fill:red +linkStyle 7 stroke:red,stroke-width:4px,color:pink +``` + +```python +pclos = backend.DEFAULT_PROG_TRANSFORMS.replace( + past_inject_args=backend.ProgArgsInjector( + args=fclos.args, + kwargs=fclos.kwargs + ) +)(p_past) +``` + +## Full Program Toolchain + +```mermaid +graph LR + +fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) +foast -->|foast_to_itir| itir_expr(itir.Expr) +foast -->|foast_inject_args| fclos(FoastClosure) +fclos -->|foast_to_past_closure| pclos(PastClosure) +pclos -->|past_process_args| pclos +pclos -->|past_to_itir| pcall(ProgramCall) + +pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) +past -->|past_lint| past +past -->|past_inject_args| pclos(ProgramClosure) + +style pdef fill:red +style past fill:red +style pclos fill:red +style pcall fill:red +linkStyle 4,5,6,7 stroke:red,stroke-width:4px,color:pink +``` + +### Starting from DSL + +```python +toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace( + past_inject_args=backend.ProgArgsInjector( + args=fclos.args, + kwargs=fclos.kwargs + ) +) +``` + +```python +p_itir1 = toolchain(p_start) +``` + +```python +p_itir2 = toolchain(p_past) +``` + +```python +assert p_itir1 == p_itir2 +``` diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 42473bea63..e406a5f097 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -207,15 +207,20 @@ def __delete__(self, _instance: _C) -> None: ... class HashlibAlgorithm(Protocol): """Used in the hashlib module of the standard library.""" - digest_size: int - block_size: int - name: str + @property + def block_size(self) -> int: ... + + @property + def digest_size(self) -> int: ... + + @property + def name(self) -> str: ... def __init__(self, data: ReadableBuffer = ...) -> None: ... - def copy(self) -> HashlibAlgorithm: ... + def copy(self) -> Self: ... - def update(self, data: ReadableBuffer) -> None: ... + def update(self, data: Buffer, /) -> None: ... def digest(self) -> bytes: ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 01c066ca91..d1f9d0f7d5 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -401,12 +401,12 @@ def content_hash(*args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | No """ if hash_algorithm is None: - hash_algorithm = xxhash.xxh64() # type: ignore[assignment] + hash_algorithm = xxhash.xxh64() elif isinstance(hash_algorithm, str): - hash_algorithm = hashlib.new(hash_algorithm) # type: ignore[assignment] + hash_algorithm = hashlib.new(hash_algorithm) - hash_algorithm.update(pickle.dumps(args)) # type: ignore[union-attr] - result = hash_algorithm.hexdigest() # type: ignore[union-attr] + hash_algorithm.update(pickle.dumps(args)) + result = hash_algorithm.hexdigest() assert isinstance(result, str) return result diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index cfa4911b57..3d3c7a27e1 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -19,29 +19,161 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators -from gt4py.next.ffront import func_to_past, past_process_args, past_to_itir, stages as ffront_stages -from gt4py.next.otf import recipes +from gt4py.next.ffront import ( + foast_to_itir, + foast_to_past, + func_to_foast, + func_to_past, + past_process_args, + past_to_itir, + stages as ffront_stages, +) +from gt4py.next.ffront.past_passes import linters as past_linters +from gt4py.next.iterator import ir as itir +from gt4py.next.otf import stages, workflow from gt4py.next.program_processors import processor_interface as ppi -DEFAULT_TRANSFORMS = recipes.ProgramTransformWorkflow( - func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True), - past_transform_args=past_process_args.past_process_args, - past_to_itir=past_to_itir.PastToItirFactory(), -) +@dataclasses.dataclass(frozen=True) +class FopArgsInjector(workflow.Workflow): + args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + from_fieldop: Any = None + + def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure: + return ffront_stages.FoastClosure( + foast_op_def=inp, + args=self.args, + kwargs=self.kwargs, + closure_vars={inp.foast_node.id: self.from_fieldop}, + ) + + +@dataclasses.dataclass(frozen=True) +class FieldopTransformWorkflow(workflow.NamedStepSequence): + """Modular workflow for transformations with access to intermediates.""" + + func_to_foast: workflow.SkippableStep[ + ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition, + ffront_stages.FoastOperatorDefinition, + ] = dataclasses.field( + default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True) + ) + foast_inject_args: workflow.Workflow[ + ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure + ] = dataclasses.field(default_factory=FopArgsInjector) + foast_to_past_closure: workflow.Workflow[ + ffront_stages.FoastClosure, ffront_stages.PastClosure + ] = dataclasses.field( + default_factory=lambda: foast_to_past.FoastToPastClosure( + foast_to_past=workflow.CachedStep( + foast_to_past.foast_to_past, hash_function=ffront_stages.fingerprint_stage + ) + ) + ) + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( + dataclasses.field(default=past_process_args.past_process_args) + ) + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( + dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + ) + + foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( + dataclasses.field( + default_factory=lambda: workflow.CachedStep( + step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.fingerprint_stage + ) + ) + ) + + @property + def step_order(self) -> list[str]: + return [ + "func_to_foast", + "foast_inject_args", + "foast_to_past_closure", + "past_transform_args", + "past_to_itir", + ] + + +DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow() + + +@dataclasses.dataclass(frozen=True) +class ProgArgsInjector(workflow.Workflow): + args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) + kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) + + def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.PastClosure: + return ffront_stages.PastClosure( + past_node=inp.past_node, + closure_vars=inp.closure_vars, + grid_type=inp.grid_type, + args=self.args, + kwargs=self.kwargs, + ) + + +@dataclasses.dataclass(frozen=True) +class ProgramTransformWorkflow(workflow.NamedStepSequence): + """Modular workflow for transformations with access to intermediates.""" + + func_to_past: workflow.SkippableStep[ + ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, + ffront_stages.PastProgramDefinition, + ] = dataclasses.field( + default_factory=lambda: func_to_past.OptionalFuncToPastFactory(cached=True) + ) + past_lint: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition + ] = dataclasses.field(default_factory=past_linters.LinterFactory) + past_inject_args: workflow.Workflow[ + ffront_stages.PastProgramDefinition, ffront_stages.PastClosure + ] = dataclasses.field(default_factory=ProgArgsInjector) + past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( + dataclasses.field(default=past_process_args.past_process_args) + ) + past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( + dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + ) + + +DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow() @dataclasses.dataclass(frozen=True) class Backend(Generic[core_defs.DeviceTypeT]): executor: ppi.ProgramExecutor allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT] - transformer: recipes.ProgramTransformWorkflow = DEFAULT_TRANSFORMS + transforms_fop: FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS + transforms_prog: ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS def __call__( - self, program: ffront_stages.ProgramDefinition, *args: tuple[Any], **kwargs: dict[str, Any] + self, + program: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition, + *args: tuple[Any], + **kwargs: dict[str, Any], ) -> None: - transformer = self.transformer.replace(args=args, kwargs=kwargs) - program_call = transformer(program) + if isinstance( + program, (ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition) + ): + offset_provider = kwargs.pop("offset_provider") + from_fieldop = kwargs.pop("from_fieldop") + transforms_fop = self.transforms_fop.replace( + foast_inject_args=FopArgsInjector( + args=args, kwargs=kwargs, from_fieldop=from_fieldop + ) + ) + program_call = transforms_fop(program) + program_call = dataclasses.replace( + program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider} + ) + else: + transforms_prog = self.transforms_prog.replace( + past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs) + ) + program_call = transforms_prog(program) self.executor(program_call.program, *program_call.args, **program_call.kwargs) @property diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a01cf0959a..be1b3c1fa8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -24,12 +24,11 @@ import typing import warnings from collections.abc import Callable -from typing import Generic, TypeVar +from typing import Any, Generic, Optional, TypeVar from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.eve import utils as eve_utils -from gt4py.eve.extended_typing import Any, Optional +from gt4py.eve import extended_typing as xtyping from gt4py.next import ( allocators as next_allocators, backend as next_backend, @@ -39,24 +38,14 @@ from gt4py.next.common import Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( - dialect_ast_enums, field_operator_ast as foast, past_process_args, past_to_itir, - program_ast as past, stages as ffront_stages, transform_utils, type_specifications as ts_ffront, ) -from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.ffront.gtcallable import GTCallable -from gt4py.next.ffront.past_passes.closure_var_type_deduction import ( - ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction, -) -from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction -from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils.ir_makers import ( literal_from_value, @@ -111,33 +100,16 @@ def definition(self): @functools.cached_property def past_stage(self): - if self.backend is not None and self.backend.transformer is not None: - return self.backend.transformer.func_to_past(self.definition_stage) - return next_backend.DEFAULT_TRANSFORMS.func_to_past(self.definition_stage) + # backwards compatibility for backends that do not support the full toolchain + if self.backend is not None and self.backend.transforms_prog is not None: + return self.backend.transforms_prog.func_to_past(self.definition_stage) + return next_backend.DEFAULT_PROG_TRANSFORMS.func_to_past(self.definition_stage) + # TODO(ricoh): linting should become optional, up to the backend. def __post_init__(self): - function_closure_vars = transform_utils._filter_closure_vars_by_type( - self.past_stage.closure_vars, GTCallable - ) - misnamed_functions = [ - f"{name} vs. {func.id}" - for name, func in function_closure_vars.items() - if name != func.__gt_itir__().id - ] - if misnamed_functions: - raise RuntimeError( - f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." - ) - - undefined_symbols = [ - symbol.id - for symbol in self.past_stage.past_node.closure_vars - if symbol.id not in self.past_stage.closure_vars - ] - if undefined_symbols: - raise RuntimeError( - f"The following closure variables are undefined: {', '.join(undefined_symbols)}." - ) + if self.backend is not None and self.backend.transforms_prog is not None: + self.backend.transforms_prog.past_lint(self.past_stage) + return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage) @property def __name__(self) -> str: @@ -207,8 +179,8 @@ def itir(self) -> itir.FencilDefinition: args=[], kwargs={}, ) - if self.backend is not None and self.backend.transformer is not None: - return self.backend.transformer.past_to_itir(no_args_past).program + if self.backend is not None and self.backend.transforms_prog is not None: + return self.backend.transforms_prog.past_to_itir(no_args_past).program return past_to_itir.PastToItirFactory()(no_args_past).program def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: @@ -236,6 +208,13 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) @dataclasses.dataclass(frozen=True) class ProgramFromPast(Program): + """ + This version of program has no DSL definition associated with it. + + PAST nodes can be built programmatically from field operators or from scratch. + This wrapper provides the appropriate toolchain entry points. + """ + past_stage: ffront_stages.PastProgramDefinition def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): @@ -247,6 +226,12 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor) self.backend(self.past_stage, *args, **(kwargs | {"offset_provider": offset_provider})) + # TODO(ricoh): linting should become optional, up to the backend. + def __post_init__(self): + if self.backend is not None and self.backend.transforms_prog is not None: + self.backend.transforms_prog.past_lint(self.past_stage) + return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage) + @dataclasses.dataclass(frozen=True) class ProgramWithBoundArgs(Program): @@ -391,12 +376,8 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): it will be deduced from actually occurring dimensions. """ - foast_node: OperatorNodeT - closure_vars: dict[str, Any] - definition: Optional[types.FunctionType] + definition_stage: ffront_stages.FieldOperatorDefinition backend: Optional[ppi.ProgramExecutor] - grid_type: Optional[GridType] - operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field( init=False, default_factory=dict ) # init=False ensure the cache is not copied in calls to replace @@ -411,39 +392,37 @@ def from_function( operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, ) -> FieldOperator[OperatorNodeT]: - operator_attributes = operator_attributes or {} - - source_def = SourceDefinition.from_function(definition) - closure_vars = get_closure_vars_from_function(definition) - annotations = typing.get_type_hints(definition) - foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) - loc = foast_definition_node.location - operator_attribute_nodes = { - key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc) - for key, value in operator_attributes.items() - } - untyped_foast_node = operator_node_cls( - id=foast_definition_node.id, - definition=foast_definition_node, - location=loc, - **operator_attribute_nodes, - ) - foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node) return cls( - foast_node=foast_node, - closure_vars=closure_vars, - definition=definition, + definition_stage=ffront_stages.FieldOperatorDefinition( + definition=definition, + grid_type=grid_type, + node_class=operator_node_cls, + attributes=operator_attributes or {}, + ), backend=backend, - grid_type=grid_type, - operator_attributes=operator_attributes, ) + # TODO(ricoh): linting should become optional, up to the backend. + def __post_init__(self): + """This ensures that DSL linting occurs at decoration time.""" + _ = self.foast_stage + + @functools.cached_property + def foast_stage(self) -> ffront_stages.FoastOperatorDefinition: + if self.backend is not None and self.backend.transforms_fop is not None: + return self.backend.transforms_fop.func_to_foast(self.definition_stage) + return next_backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(self.definition_stage) + @property def __name__(self) -> str: - return self.definition.__name__ + return self.definition_stage.definition.__name__ + + @property + def definition(self) -> str: + return self.definition_stage.definition def __gt_type__(self) -> ts.CallableType: - type_ = self.foast_node.type + type_ = self.foast_stage.foast_node.type assert isinstance(type_, ts.CallableType) return type_ @@ -451,98 +430,46 @@ def with_backend(self, backend: ppi.ProgramExecutor) -> FieldOperator: return dataclasses.replace(self, backend=backend) def with_grid_type(self, grid_type: GridType) -> FieldOperator: - return dataclasses.replace(self, grid_type=grid_type) + return dataclasses.replace( + self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) + ) def __gt_itir__(self) -> itir.FunctionDefinition: - if hasattr(self, "__cached_itir"): - return getattr(self, "__cached_itir") - - itir_node: itir.FunctionDefinition = FieldOperatorLowering.apply(self.foast_node) - - object.__setattr__(self, "__cached_itir", itir_node) - - return itir_node + if self.backend is not None and self.backend.transforms_fop is not None: + return self.backend.transforms_fop.foast_to_itir(self.foast_stage) + return next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(self.foast_stage) def __gt_closure_vars__(self) -> dict[str, Any]: - return self.closure_vars + return self.foast_stage.closure_vars def as_program( self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec] ) -> Program: - # TODO(tehrengruber): implement mechanism to deduce default values - # of arg and kwarg types - # TODO(tehrengruber): check foast operator has no out argument that clashes - # with the out argument of the program we generate here. - hash_ = eve_utils.content_hash( - (tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items())) - ) - try: - return self._program_cache[hash_] - except KeyError: - pass - - loc = self.foast_node.location - # use a new UID generator to allow caching - param_sym_uids = eve_utils.UIDGenerator() - - type_ = self.__gt_type__() - params_decl: list[past.Symbol] = [ - past.DataSymbol( - id=param_sym_uids.sequential_id(prefix="__sym"), - type=arg_type, - namespace=dialect_ast_enums.Namespace.LOCAL, - location=loc, - ) - for arg_type in arg_types - ] - params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] - out_sym: past.Symbol = past.DataSymbol( - id="out", - type=type_info.return_type(type_, with_args=arg_types, with_kwargs=kwarg_types), - namespace=dialect_ast_enums.Namespace.LOCAL, - location=loc, + foast_with_types = ( + ffront_stages.FoastWithTypes( + foast_op_def=self.foast_stage, + arg_types=tuple(arg_types), + kwarg_types=kwarg_types, + closure_vars={self.foast_stage.foast_node.id: self}, + ), ) - out_ref = past.Name(id="out", location=loc) - - if self.foast_node.id in self.closure_vars: - raise RuntimeError("A closure variable has the same name as the field operator itself.") - closure_vars = {self.foast_node.id: self} - closure_symbols = [ - past.Symbol( - id=self.foast_node.id, - type=ts.DeferredType(constraint=None), - namespace=dialect_ast_enums.Namespace.CLOSURE, - location=loc, + past_stage = None + if self.backend is not None and self.backend.transforms_fop is not None: + past_stage = self.backend.transforms_fop.foast_to_past_closure.foast_to_past( + foast_with_types ) - ] - - untyped_past_node = past.Program( - id=f"__field_operator_{self.foast_node.id}", - type=ts.DeferredType(constraint=ts_ffront.ProgramType), - params=[*params_decl, out_sym], - body=[ - past.Call( - func=past.Name(id=self.foast_node.id, location=loc), - args=params_ref, - kwargs={"out": out_ref}, - location=loc, + else: + past_stage = ( + next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure.foast_to_past( + ffront_stages.FoastWithTypes( + foast_op_def=self.foast_stage, + arg_types=tuple(arg_types), + kwarg_types=kwarg_types, + closure_vars={self.foast_stage.foast_node.id: self}, + ), ) - ], - closure_vars=closure_symbols, - location=loc, - ) - untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars) - past_node = ProgramTypeDeduction.apply(untyped_past_node) - - self._program_cache[hash_] = ProgramFromPast( - definition_stage=None, - past_stage=ffront_stages.PastProgramDefinition( - past_node=past_node, closure_vars=closure_vars, grid_type=self.grid_type - ), - backend=self.backend, - ) - - return self._program_cache[hash_] + ) + return ProgramFromPast(definition_stage=None, past_stage=past_stage, backend=self.backend) def __call__(self, *args, **kwargs) -> None: if not next_embedded.context.within_valid_context() and self.backend is not None: @@ -554,36 +481,56 @@ def __call__(self, *args, **kwargs) -> None: if "out" not in kwargs: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") - args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - arg_types = [] - for arg in args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - return self.as_program(arg_types, kwarg_types)( - *args, out, offset_provider=offset_provider, **kwargs + args, kwargs = type_info.canonicalize_arguments( + self.foast_stage.foast_node.type, args, kwargs + ) + return self.backend( + self.definition_stage, + *args, + out=out, + offset_provider=offset_provider, + from_fieldop=self, + **kwargs, ) else: - if self.operator_attributes is not None and any( + attributes = ( + self.definition_stage.attributes + if self.definition_stage + else self.foast_stage.attributes + ) + if attributes is not None and any( has_scan_op_attribute := [ - attribute in self.operator_attributes - for attribute in ["init", "axis", "forward"] + attribute in attributes for attribute in ["init", "axis", "forward"] ] ): assert all(has_scan_op_attribute) - forward = self.operator_attributes["forward"] - init = self.operator_attributes["init"] - axis = self.operator_attributes["axis"] - op = embedded_operators.ScanOperator(self.definition, forward, init, axis) + forward = attributes["forward"] + init = attributes["init"] + axis = attributes["axis"] + op = embedded_operators.ScanOperator( + self.definition_stage.definition, forward, init, axis + ) else: - op = embedded_operators.EmbeddedOperator(self.definition) + op = embedded_operators.EmbeddedOperator(self.definition_stage.definition) return embedded_operators.field_operator_call(op, args, kwargs) +@dataclasses.dataclass(frozen=True) +class FieldOperatorFromFoast(FieldOperator): + """ + This version of the field operator does not have a DSL definition. + + FieldOperator AST nodes can be programmatically built, which may be + particularly useful in testing and debugging. + This class provides the appropriate toolchain entry points. + """ + + foast_stage: ffront_stages.FoastOperatorDefinition + + def __call__(self, *args, **kwargs) -> None: + return self.backend(self.foast_stage, *args, from_fieldop=self, **kwargs) + + @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] @@ -695,3 +642,29 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: ) return scan_operator_inner if definition is None else scan_operator_inner(definition) + + +@ffront_stages.add_content_to_fingerprint.register +def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: xtyping.HashlibAlgorithm) -> None: + ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) + + +@ffront_stages.add_content_to_fingerprint.register +def add_foast_fieldop_to_fingerprint( + obj: FieldOperatorFromFoast, hasher: xtyping.HashlibAlgorithm +) -> None: + ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) + + +@ffront_stages.add_content_to_fingerprint.register +def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None: + ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) + + +@ffront_stages.add_content_to_fingerprint.register +def add_past_program_to_fingerprint(obj: ProgramFromPast, hasher: xtyping.HashlibAlgorithm) -> None: + ffront_stages.add_content_to_fingerprint(obj.past_stage, hasher) + ffront_stages.add_content_to_fingerprint(obj.backend, hasher) diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 471840ff1b..fad4df8c84 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, cast +from typing import Any, Optional, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -29,6 +29,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) + + def with_altered_scalar_kind( type_spec: ts.TypeSpec, new_scalar_kind: ts.ScalarKind ) -> ts.ScalarType | ts.FieldType: @@ -247,7 +250,7 @@ class FieldOperatorTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTransla """ @classmethod - def apply(cls, node: foast.FunctionDefinition) -> foast.FunctionDefinition: + def apply(cls, node: OperatorNodeT) -> OperatorNodeT: typed_foast_node = cls().visit(node) FieldOperatorTypeDeductionCompletnessValidator.apply(typed_foast_node) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index e589ecb601..6194647e1f 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -126,6 +126,17 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o UnaryOp = as_fmt("{op}{operand}") + IfStmt = as_fmt( + textwrap.dedent( + """ + if {condition}: + {true_branch} + else: + {false_branch} + """ + ).strip() + ) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: if node.op is dialect_ast_enums.UnaryOperator.NOT: op = "not " @@ -234,7 +245,7 @@ def pretty_format(node: foast.LocatedNode) -> str: >>> @field_operator ... def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]: ... return a + 1.0 - >>> print(pretty_format(field_op.foast_node)) + >>> print(pretty_format(field_op.foast_stage.foast_node)) @field_operator def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]: return a + 1.0 diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 80c0f1fea3..4a2a043fcc 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -23,6 +23,7 @@ fbuiltins, field_operator_ast as foast, lowering_utils, + stages as ffront_stages, type_specifications as ts_ffront, ) from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES @@ -33,6 +34,10 @@ from gt4py.next.type_system import type_info, type_specifications as ts +def foast_to_itir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr: + return FieldOperatorLowering.apply(inp.foast_node) + + def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: if not type_info.contains_local_field(node.type): return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py new file mode 100644 index 0000000000..b2e6324860 --- /dev/null +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -0,0 +1,129 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses + +from gt4py.eve import utils as eve_utils +from gt4py.next.ffront import ( + dialect_ast_enums, + program_ast as past, + stages as ffront_stages, + type_specifications as ts_ffront, +) +from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction +from gt4py.next.otf import workflow +from gt4py.next.type_system import type_info, type_specifications as ts, type_translation + + +def foast_to_past(inp: ffront_stages.FoastWithTypes) -> ffront_stages.PastProgramDefinition: + # TODO(tehrengruber): implement mechanism to deduce default values + # of arg and kwarg types + # TODO(tehrengruber): check foast operator has no out argument that clashes + # with the out argument of the program we generate here. + + loc = inp.foast_op_def.foast_node.location + # use a new UID generator to allow caching + param_sym_uids = eve_utils.UIDGenerator() + + type_ = inp.foast_op_def.foast_node.type + params_decl: list[past.Symbol] = [ + past.DataSymbol( + id=param_sym_uids.sequential_id(prefix="__sym"), + type=arg_type, + namespace=dialect_ast_enums.Namespace.LOCAL, + location=loc, + ) + for arg_type in inp.arg_types + ] + params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl] + out_sym: past.Symbol = past.DataSymbol( + id="out", + type=type_info.return_type( + type_, with_args=list(inp.arg_types), with_kwargs=inp.kwarg_types + ), + namespace=dialect_ast_enums.Namespace.LOCAL, + location=loc, + ) + out_ref = past.Name(id="out", location=loc) + + if inp.foast_op_def.foast_node.id in inp.foast_op_def.closure_vars: + raise RuntimeError("A closure variable has the same name as the field operator itself.") + closure_symbols: list[past.Symbol] = [ + past.Symbol( + id=inp.foast_op_def.foast_node.id, + type=ts.DeferredType(constraint=None), + namespace=dialect_ast_enums.Namespace.CLOSURE, + location=loc, + ), + ] + + untyped_past_node = past.Program( + id=f"__field_operator_{inp.foast_op_def.foast_node.id}", + type=ts.DeferredType(constraint=ts_ffront.ProgramType), + params=[*params_decl, out_sym], + body=[ + past.Call( + func=past.Name(id=inp.foast_op_def.foast_node.id, location=loc), + args=params_ref, + kwargs={"out": out_ref}, + location=loc, + ) + ], + closure_vars=closure_symbols, + location=loc, + ) + untyped_past_node = closure_var_type_deduction.ClosureVarTypeDeduction.apply( + untyped_past_node, inp.closure_vars + ) + past_node = type_deduction.ProgramTypeDeduction.apply(untyped_past_node) + + return ffront_stages.PastProgramDefinition( + past_node=past_node, + closure_vars=inp.closure_vars, + grid_type=inp.foast_op_def.grid_type, + ) + + +@dataclasses.dataclass(frozen=True) +class FoastToPastClosure(workflow.NamedStepSequence): + foast_to_past: workflow.Workflow[ + ffront_stages.FoastWithTypes, ffront_stages.PastProgramDefinition + ] + + def __call__(self, inp: ffront_stages.FoastClosure) -> ffront_stages.PastClosure: + # TODO(tehrengruber): check all offset providers are given + # deduce argument types + arg_types = [] + for arg in inp.args: + arg_types.append(type_translation.from_value(arg)) + kwarg_types = {} + for name, arg in inp.kwargs.items(): + kwarg_types[name] = type_translation.from_value(arg) + + past_def = super().__call__( + ffront_stages.FoastWithTypes( + foast_op_def=inp.foast_op_def, + arg_types=tuple(arg_types), + kwarg_types=kwarg_types, + closure_vars=inp.closure_vars, + ) + ) + + return ffront_stages.PastClosure( + past_node=past_def.past_node, + closure_vars=past_def.closure_vars, + grid_type=past_def.grid_type, + args=inp.args, + kwargs=inp.kwargs, + ) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index 9f24dbf6db..6127cbdef5 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -16,11 +16,21 @@ import ast import builtins -from typing import Any, Callable, Iterable, Mapping, Type, cast +import dataclasses +import typing +from typing import Any, Callable, Iterable, Mapping, Type + +import factory import gt4py.eve as eve from gt4py.next import errors -from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast +from gt4py.next.ffront import ( + dialect_ast_enums, + fbuiltins, + field_operator_ast as foast, + source_utils, + stages as ffront_stages, +) from gt4py.next.ffront.ast_passes import ( SingleAssignTargetPass, SingleStaticAssignPass, @@ -35,9 +45,74 @@ from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction +from gt4py.next.otf import workflow from gt4py.next.type_system import type_info, type_specifications as ts, type_translation +@workflow.make_step +def func_to_foast( + inp: ffront_stages.FieldOperatorDefinition[ffront_stages.OperatorNodeT], +) -> ffront_stages.FoastOperatorDefinition[ffront_stages.OperatorNodeT]: + source_def = source_utils.SourceDefinition.from_function(inp.definition) + closure_vars = source_utils.get_closure_vars_from_function(inp.definition) + annotations = typing.get_type_hints(inp.definition) + foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations) + loc = foast_definition_node.location + operator_attribute_nodes = { + key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc) + for key, value in inp.attributes.items() + } + untyped_foast_node = inp.node_class( + id=foast_definition_node.id, + definition=foast_definition_node, + location=loc, + **operator_attribute_nodes, + ) + foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node) + return ffront_stages.FoastOperatorDefinition( + foast_node=foast_node, + closure_vars=closure_vars, + grid_type=inp.grid_type, + attributes=inp.attributes, + ) + + +@dataclasses.dataclass(frozen=True) +class OptionalFuncToFoast(workflow.SkippableStep): + step: workflow.Workflow[ + ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition + ] = func_to_foast + + def skip_condition( + self, inp: ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition + ) -> bool: + match inp: + case ffront_stages.FieldOperatorDefinition(): + return False + case ffront_stages.FoastOperatorDefinition(): + return True + + +@dataclasses.dataclass(frozen=True) +class OptionalFuncToFoastFactory(factory.Factory): + class Meta: + model = OptionalFuncToFoast + + class Params: + workflow: workflow.Workflow[ + ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition + ] = func_to_foast + cached = factory.Trait( + step=factory.LazyAttribute( + lambda o: workflow.CachedStep( + step=o.workflow, hash_function=ffront_stages.fingerprint_stage + ) + ) + ) + + step = factory.LazyAttribute(lambda o: o.workflow) + + class FieldOperatorParser(DialectParser[foast.FunctionDefinition]): """ Parse field operator function definition from source code into FOAST. @@ -141,7 +216,7 @@ def _builtin_type_constructor_symbols( ], # this is a constraint type that will not be inferred (as the function is polymorphic) pos_or_kw_args={}, kw_only_args={}, - returns=cast(ts.DataType, type_translation.from_type_hint(value)), + returns=typing.cast(ts.DataType, type_translation.from_type_hint(value)), ), namespace=dialect_ast_enums.Namespace.CLOSURE, location=location, diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 6864993f4c..372386aaf4 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -21,7 +21,6 @@ import factory -from gt4py.eve import utils as eve_utils from gt4py.next import errors from gt4py.next.ffront import ( dialect_ast_enums, @@ -73,7 +72,9 @@ class Params: workflow = func_to_past cached = factory.Trait( step=factory.LazyAttribute( - lambda o: workflow.CachedStep(step=o.workflow, hash_function=eve_utils.content_hash) + lambda o: workflow.CachedStep( + step=o.workflow, hash_function=ffront_stages.fingerprint_stage + ) ) ) diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py new file mode 100644 index 0000000000..6e77262fd1 --- /dev/null +++ b/src/gt4py/next/ffront/past_passes/linters.py @@ -0,0 +1,59 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import factory + +from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils +from gt4py.next.otf import workflow + + +@workflow.make_step +def lint_misnamed_functions( + inp: ffront_stages.PastProgramDefinition, +) -> ffront_stages.PastProgramDefinition: + function_closure_vars = transform_utils._filter_closure_vars_by_type( + inp.closure_vars, gtcallable.GTCallable + ) + misnamed_functions = [ + f"{name} vs. {func.__gt_itir__().id}" + for name, func in function_closure_vars.items() + if name != func.__gt_itir__().id + ] + if misnamed_functions: + raise RuntimeError( + f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." + ) + return inp + + +@workflow.make_step +def lint_undefined_symbols( + inp: ffront_stages.PastProgramDefinition, +) -> ffront_stages.PastProgramDefinition: + undefined_symbols = [ + symbol.id for symbol in inp.past_node.closure_vars if symbol.id not in inp.closure_vars + ] + if undefined_symbols: + raise RuntimeError( + f"The following closure variables are undefined: {', '.join(undefined_symbols)}." + ) + return inp + + +class LinterFactory(factory.Factory): + class Meta: + model = workflow.CachedStep + + step = lint_misnamed_functions.chain(lint_undefined_symbols) + hash_function = ffront_stages.fingerprint_stage diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index ed7c65c0af..1da6c85981 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -14,12 +14,59 @@ from __future__ import annotations +import collections import dataclasses +import functools +import hashlib import types -from typing import Any, Optional +import typing +from typing import Any, Generic, Optional, TypeVar +import xxhash + +from gt4py.eve import extended_typing as xtyping from gt4py.next import common -from gt4py.next.ffront import program_ast as past +from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils +from gt4py.next.type_system import type_specifications as ts + + +if typing.TYPE_CHECKING: + pass + + +OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode) + + +@dataclasses.dataclass(frozen=True) +class FieldOperatorDefinition(Generic[OperatorNodeT]): + definition: types.FunctionType + grid_type: Optional[common.GridType] = None + node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains + attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass(frozen=True) +class FoastOperatorDefinition(Generic[OperatorNodeT]): + foast_node: OperatorNodeT + closure_vars: dict[str, Any] + grid_type: Optional[common.GridType] = None + attributes: dict[str, Any] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass(frozen=True) +class FoastWithTypes(Generic[OperatorNodeT]): + foast_op_def: FoastOperatorDefinition[OperatorNodeT] + arg_types: tuple[ts.TypeSpec, ...] + kwarg_types: dict[str, ts.TypeSpec] + closure_vars: dict[str, Any] + + +@dataclasses.dataclass(frozen=True) +class FoastClosure(Generic[OperatorNodeT]): + foast_op_def: FoastOperatorDefinition[OperatorNodeT] + args: tuple[Any, ...] + kwargs: dict[str, Any] + closure_vars: dict[str, Any] @dataclasses.dataclass(frozen=True) @@ -42,3 +89,73 @@ class PastClosure: grid_type: Optional[common.GridType] args: tuple[Any, ...] kwargs: dict[str, Any] + + +def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: + hasher: xtyping.HashlibAlgorithm + if not algorithm: + hasher = xxhash.xxh64() + elif isinstance(algorithm, str): + hasher = hashlib.new(algorithm) + else: + hasher = algorithm + + add_content_to_fingerprint(obj, hasher) + return hasher.hexdigest() + + +@functools.singledispatch +def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: + hasher.update(str(obj).encode()) + + +for t in (str, int): + add_content_to_fingerprint.register(t, add_content_to_fingerprint.registry[object]) + + +@add_content_to_fingerprint.register(FieldOperatorDefinition) +@add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(FoastWithTypes) +@add_content_to_fingerprint.register(FoastClosure) +@add_content_to_fingerprint.register(ProgramDefinition) +@add_content_to_fingerprint.register(PastProgramDefinition) +@add_content_to_fingerprint.register(PastClosure) +def add_content_to_fingerprint_stages(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None: + add_content_to_fingerprint(obj.__class__, hasher) + for field in dataclasses.fields(obj): + add_content_to_fingerprint(getattr(obj, field.name), hasher) + + +@add_content_to_fingerprint.register +def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None: + sourcedef = source_utils.SourceDefinition.from_function(obj) + for item in sourcedef: + add_content_to_fingerprint(item, hasher) + + +@add_content_to_fingerprint.register +def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: + for key, value in obj.items(): + add_content_to_fingerprint(key, hasher) + add_content_to_fingerprint(value, hasher) + + +@add_content_to_fingerprint.register +def add_type_to_fingerprint(obj: type, hasher: xtyping.HashlibAlgorithm) -> None: + hasher.update(obj.__name__.encode()) + + +@add_content_to_fingerprint.register +def add_sequence_to_fingerprint( + obj: collections.abc.Iterable, hasher: xtyping.HashlibAlgorithm +) -> None: + for item in obj: + add_content_to_fingerprint(item, hasher) + + +@add_content_to_fingerprint.register +def add_foast_located_node_to_fingerprint( + obj: foast.LocatedNode, hasher: xtyping.HashlibAlgorithm +) -> None: + add_content_to_fingerprint(obj.location, hasher) + add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py index 41d0c8947f..982e2e9b7b 100644 --- a/src/gt4py/next/otf/recipes.py +++ b/src/gt4py/next/otf/recipes.py @@ -14,43 +14,10 @@ from __future__ import annotations import dataclasses -from typing import Any -from gt4py.next.ffront import stages as ffront_stages from gt4py.next.otf import stages, step_types, workflow -@dataclasses.dataclass(frozen=True) -class ProgramTransformWorkflow(workflow.NamedStepSequence): - """Modular workflow for transformations with access to intermediates.""" - - func_to_past: workflow.SkippableStep[ - ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition, - ffront_stages.PastProgramDefinition, - ] - past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] - past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] - - args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __call__( - self, inp: ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition - ) -> stages.ProgramCall: - past_stage = self.func_to_past(inp) - return self.past_to_itir( - self.past_transform_args( - ffront_stages.PastClosure( - past_node=past_stage.past_node, - closure_vars=past_stage.closure_vars, - grid_type=past_stage.grid_type, - args=self.args, - kwargs=self.kwargs, - ) - ) - ) - - @dataclasses.dataclass(frozen=True) class OTFCompileWorkflow(workflow.NamedStepSequence): """The typical compiled backend steps composed into a workflow.""" diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index 8ae741195f..c83748dece 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -84,8 +84,10 @@ def replace(self, **kwargs: Any) -> Self: return dataclasses.replace(self, **kwargs) -class ChainableWorkflowMixin(Workflow[StartT, EndT]): - def chain(self, next_step: Workflow[EndT, NewEndT]) -> ChainableWorkflowMixin[StartT, NewEndT]: +class ChainableWorkflowMixin(Workflow[StartT, EndT_co], Protocol[StartT, EndT_co]): + def chain( + self, next_step: Workflow[EndT_co, NewEndT] + ) -> ChainableWorkflowMixin[StartT, NewEndT]: return make_step(self).chain(next_step) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 388849bf09..840f6d6143 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -49,7 +49,7 @@ def __call__(self, program, *args, **kwargs) -> None: raise ValueError("No backend selected! Backend selection is mandatory in tests.") -no_backend = NoBackend(executor=no_exec, transformer=None, allocator=None) +no_backend = NoBackend(executor=no_exec, transforms_prog=None, allocator=None) OPTIONAL_PROCESSORS = [] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py index c1bee4fa2f..77ae302efa 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py @@ -69,7 +69,7 @@ def bar(inp1: Field[[I], int64], inp2: Field[[I], int64]) -> Field[[I], int64]: """ ).strip() - assert pretty_format(bar.foast_node) == expected + assert pretty_format(bar.foast_stage.foast_node) == expected def test_scanop(): @@ -89,4 +89,4 @@ def scan(inp: int32) -> int32: """ ).strip() - assert pretty_format(scan.foast_node) == expected + assert pretty_format(scan.foast_stage.foast_node) == expected diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index e076ec4227..3419930588 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -18,9 +18,13 @@ import numpy as np import pytest -import gt4py.next as gtx -from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast -from gt4py.next.ffront.decorator import FieldOperator +from gt4py.next.ffront import ( + decorator, + dialect_ast_enums, + fbuiltins, + field_operator_ast as foast, + stages as ffront_stages, +) from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_translation @@ -107,12 +111,14 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.Program ) typed_foast_node = FieldOperatorTypeDeduction.apply(foast_node) - return FieldOperator( - foast_node=typed_foast_node, - closure_vars=closure_vars, - definition=None, + return decorator.FieldOperatorFromFoast( + definition_stage=None, + foast_stage=ffront_stages.FoastOperatorDefinition( + foast_node=typed_foast_node, + closure_vars=closure_vars, + grid_type=None, + ), backend=backend, - grid_type=None, ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py new file mode 100644 index 0000000000..67ac96d653 --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -0,0 +1,162 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from gt4py import next as gtx +from gt4py.next.ffront import stages + + +@pytest.fixture +def idim(): + yield gtx.Dimension("I") + + +@pytest.fixture +def jdim(): + yield gtx.Dimension("J") + + +@pytest.fixture +def fieldop(idim): + @gtx.field_operator + def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: + return a + + yield copy + + +@pytest.fixture +def samecode_fieldop(idim): + @gtx.field_operator + def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]: + return a + + yield copy + + +@pytest.fixture +def different_fieldop(jdim): + @gtx.field_operator + def copy(a: gtx.Field[[jdim], gtx.int32]) -> gtx.Field[[jdim], gtx.int32]: + return a + + yield copy + + +@pytest.fixture +def program(fieldop, idim): + copy = fieldop + + @gtx.program + def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): + copy(a, out=out) + + yield copy_program + + +@pytest.fixture +def samecode_program(samecode_fieldop, idim): + copy = samecode_fieldop + + @gtx.program + def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]): + copy(a, out=out) + + yield copy_program + + +@pytest.fixture +def different_program(different_fieldop, jdim): + copy = different_fieldop + + @gtx.program + def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int32]): + copy(a, out=out) + + yield copy_program + + +def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): + assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage( + fieldop.definition_stage + ) + assert stages.fingerprint_stage(different_fieldop.definition_stage) != stages.fingerprint_stage( + fieldop.definition_stage + ) + + +def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop): + foast = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(fieldop.definition_stage) + samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( + samecode_fieldop.definition_stage + ) + different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( + different_fieldop.definition_stage + ) + + assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) + + +def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): + foast_closure = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({idim: 10}, gtx.int32),), + kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, + from_fieldop=fieldop, + ), + )(fieldop.definition_stage) + samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({idim: 10}, gtx.int32),), + kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, + from_fieldop=samecode_fieldop, + ) + )(samecode_fieldop.definition_stage) + different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({jdim: 10}, gtx.int32),), + kwargs={"out": gtx.zeros({jdim: 10}, gtx.int32)}, + from_fieldop=different_fieldop, + ) + )(different_fieldop.definition_stage) + different_args = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( + gtx.backend.FopArgsInjector( + args=(gtx.zeros({idim: 11}, gtx.int32),), + kwargs={"out": gtx.zeros({idim: 11}, gtx.int32)}, + from_fieldop=fieldop, + ) + )(fieldop.definition_stage) + + assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast_closure) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast_closure) + assert stages.fingerprint_stage(different_args) != stages.fingerprint_stage(foast_closure) + + +def test_cache_key_program_def(program, samecode_program, different_program): + assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage( + program.definition_stage + ) + assert stages.fingerprint_stage(different_program.definition_stage) != stages.fingerprint_stage( + program.definition_stage + ) + + +def test_cache_key_past_def(program, samecode_program, different_program): + past = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(program.definition_stage) + samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage) + different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage) + + assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(past) + assert stages.fingerprint_stage(different) != stages.fingerprint_stage(past) diff --git a/tests/next_tests/unit_tests/otf_tests/test_workflow.py b/tests/next_tests/unit_tests/otf_tests/test_workflow.py index 9e1e14edaf..2a274c0110 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_workflow.py +++ b/tests/next_tests/unit_tests/otf_tests/test_workflow.py @@ -77,7 +77,7 @@ def test_cached_with_hashing(): def hashing(inp: list[int]) -> int: return hash(sum(inp)) - wf = workflow.CachedStep(step=lambda inp: inp + [1], hash_function=hashing) + wf = workflow.CachedStep(step=lambda inp: [*inp, 1], hash_function=hashing) assert wf([1, 2, 3]) == [1, 2, 3, 1] assert wf([3, 2, 1]) == [1, 2, 3, 1] diff --git a/tox.ini b/tox.ini index d4418c4ebc..8479e4c52c 100644 --- a/tox.ini +++ b/tox.ini @@ -109,10 +109,12 @@ commands = description = Run notebooks commands_pre = jupytext docs/user/next/QuickstartGuide.md --to .ipynb + jupytext docs/user/next/Advanced_ToolchainWalkthrough.md --to .ipynb commands = python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1} + python -m pytest --nbmake docs/user/next/Advanced_ToolchainWalkthrough.ipynb -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} # -- Other artefacts --