diff --git a/docs/user/next/Advanced_ToolchainWalkthrough.md b/docs/user/next/Advanced_ToolchainWalkthrough.md
new file mode 100644
index 0000000000..94a7bfa7e2
--- /dev/null
+++ b/docs/user/next/Advanced_ToolchainWalkthrough.md
@@ -0,0 +1,514 @@
+```python
+import dataclasses
+import inspect
+
+import gt4py.next as gtx
+from gt4py.next import backend
+
+import devtools
+```
+
+
+
+
+```python
+I = gtx.Dimension("I")
+Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,))
+OFFSET_PROVIDER = {"Ioff": I}
+```
+
+# Toolchain Overview
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+```
+
+# Walkthrough from Field Operator
+
+## Starting Out
+
+```python
+@gtx.field_operator
+def example_fo(a: gtx.Field[[I], gtx.float64]) -> gtx.Field[[I], gtx.float64]:
+ return a + 1.0
+```
+
+```python
+start = example_fo.definition_stage
+```
+
+```python
+gtx.ffront.stages.FieldOperatorDefinition?
+```
+
+ [0;31mInit signature:[0m
+ [0mgtx[0m[0;34m.[0m[0mffront[0m[0;34m.[0m[0mstages[0m[0;34m.[0m[0mFieldOperatorDefinition[0m[0;34m([0m[0;34m[0m
+ [0;34m[0m [0mdefinition[0m[0;34m:[0m [0;34m'types.FunctionType'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mgrid_type[0m[0;34m:[0m [0;34m'Optional[common.GridType]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mnode_class[0m[0;34m:[0m [0;34m'type[OperatorNodeT]'[0m [0;34m=[0m [0;34m<[0m[0;32mclass[0m [0;34m'gt4py.next.ffront.field_operator_ast.FieldOperator'[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mattributes[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m [0;34m=[0m [0;34m<[0m[0mfactory[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
+ [0;31mDocstring:[0m FieldOperatorDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None, node_class: 'type[OperatorNodeT]' = , attributes: 'dict[str, Any]' = )
+ [0;31mFile:[0m ~/Code/gt4py/src/gt4py/next/ffront/stages.py
+ [0;31mType:[0m type
+ [0;31mSubclasses:[0m
+
+## DSL -> FOAST
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style fdef fill:red
+style foast fill:red
+linkStyle 0 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+foast = backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(start)
+```
+
+```python
+gtx.ffront.stages.FoastOperatorDefinition?
+```
+
+ [0;31mInit signature:[0m
+ [0mgtx[0m[0;34m.[0m[0mffront[0m[0;34m.[0m[0mstages[0m[0;34m.[0m[0mFoastOperatorDefinition[0m[0;34m([0m[0;34m[0m
+ [0;34m[0m [0mfoast_node[0m[0;34m:[0m [0;34m'OperatorNodeT'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mclosure_vars[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mgrid_type[0m[0;34m:[0m [0;34m'Optional[common.GridType]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mattributes[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m [0;34m=[0m [0;34m<[0m[0mfactory[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
+ [0;31mDocstring:[0m FoastOperatorDefinition(foast_node: 'OperatorNodeT', closure_vars: 'dict[str, Any]', grid_type: 'Optional[common.GridType]' = None, attributes: 'dict[str, Any]' = )
+ [0;31mFile:[0m ~/Code/gt4py/src/gt4py/next/ffront/stages.py
+ [0;31mType:[0m type
+ [0;31mSubclasses:[0m
+
+## FOAST -> ITIR
+
+This also happens inside the `decorator.FieldOperator.__gt_itir__` method during the lowering from calling Programs to ITIR
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style foast fill:red
+style itir_expr fill:red
+linkStyle 1 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+fitir = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(foast)
+```
+
+```python
+fitir.__class__
+```
+
+ gt4py.next.iterator.ir.FunctionDefinition
+
+## FOAST -> FOAST closure
+
+This is preparation for "directly calling" a field operator.
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style foast fill:red
+style fclos fill:red
+linkStyle 2 stroke:red,stroke-width:4px,color:pink
+```
+
+Here we have to dynamically generate a workflow step, because the arguments were not known before.
+
+```python
+fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_inject_args.__class__(
+ args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),),
+ kwargs={
+ "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64)
+ },
+ from_fieldop=example_fo
+)(foast)
+```
+
+```python
+gtx.ffront.stages.FoastClosure?
+```
+
+ [0;31mInit signature:[0m
+ [0mgtx[0m[0;34m.[0m[0mffront[0m[0;34m.[0m[0mstages[0m[0;34m.[0m[0mFoastClosure[0m[0;34m([0m[0;34m[0m
+ [0;34m[0m [0mfoast_op_def[0m[0;34m:[0m [0;34m'FoastOperatorDefinition[OperatorNodeT]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0margs[0m[0;34m:[0m [0;34m'tuple[Any, ...]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mkwargs[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mclosure_vars[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
+ [0;31mDocstring:[0m FoastClosure(foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]', closure_vars: 'dict[str, Any]')
+ [0;31mFile:[0m ~/Code/gt4py/src/gt4py/next/ffront/stages.py
+ [0;31mType:[0m type
+ [0;31mSubclasses:[0m
+
+## FOAST with args -> PAST closure
+
+This auto-generates a program for us, directly in PAST representation and forwards the call arguments to it
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style fclos fill:red
+style pclos fill:red
+linkStyle 3 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+pclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure(fclos)
+```
+
+```python
+gtx.ffront.stages.PastClosure?
+```
+
+ [0;31mInit signature:[0m
+ [0mgtx[0m[0;34m.[0m[0mffront[0m[0;34m.[0m[0mstages[0m[0;34m.[0m[0mPastClosure[0m[0;34m([0m[0;34m[0m
+ [0;34m[0m [0mclosure_vars[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mpast_node[0m[0;34m:[0m [0;34m'past.Program'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mgrid_type[0m[0;34m:[0m [0;34m'Optional[common.GridType]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0margs[0m[0;34m:[0m [0;34m'tuple[Any, ...]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mkwargs[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
+ [0;31mDocstring:[0m PastClosure(closure_vars: 'dict[str, Any]', past_node: 'past.Program', grid_type: 'Optional[common.GridType]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]')
+ [0;31mFile:[0m ~/Code/gt4py/src/gt4py/next/ffront/stages.py
+ [0;31mType:[0m type
+ [0;31mSubclasses:[0m
+
+## Transform PAST closure arguments
+
+Don't ask me, seems to be necessary though
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style pclos fill:red
+%%style pclos fill:red
+linkStyle 4 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+pclost = backend.DEFAULT_PROG_TRANSFORMS.past_transform_args(pclos)
+```
+
+## Lower PAST -> ITIR
+
+still forwarding the call arguments
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style pclos fill:red
+style pcall fill:red
+linkStyle 5 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+pitir = backend.DEFAULT_PROG_TRANSFORMS.past_to_itir(pclost)
+```
+
+```python
+gtx.otf.stages.ProgramCall?
+```
+
+ [0;31mInit signature:[0m
+ [0mgtx[0m[0;34m.[0m[0motf[0m[0;34m.[0m[0mstages[0m[0;34m.[0m[0mProgramCall[0m[0;34m([0m[0;34m[0m
+ [0;34m[0m [0mprogram[0m[0;34m:[0m [0;34m'itir.FencilDefinition'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0margs[0m[0;34m:[0m [0;34m'tuple[Any, ...]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mkwargs[0m[0;34m:[0m [0;34m'dict[str, Any]'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
+ [0;31mDocstring:[0m Iterator IR representaion of a program together with arguments to be passed to it.
+ [0;31mFile:[0m ~/Code/gt4py/src/gt4py/next/otf/stages.py
+ [0;31mType:[0m type
+ [0;31mSubclasses:[0m
+
+## Executing The Result
+
+```python
+gtx.program_processors.runners.roundtrip.executor(pitir.program, *pitir.args, offset_provider=OFFSET_PROVIDER, **pitir.kwargs)
+```
+
+```python
+pitir.args
+```
+
+ (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])),
+ NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])),
+ 10,
+ 10)
+
+## Full Field Operator Toolchain
+
+using the default step order
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style fdef fill:red
+style foast fill:red
+style fclos fill:red
+style pclos fill:red
+style pcall fill:red
+linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink
+```
+
+### Starting from DSL
+
+```python
+foast_toolchain = backend.DEFAULT_FIELDOP_TRANSFORMS.replace(
+ foast_inject_args=backend.FopArgsInjector(args=fclos.args, kwargs=fclos.kwargs, from_fieldop=example_fo)
+)
+pitir2 = foast_toolchain(start)
+assert pitir2 == pitir
+```
+
+#### Pass The result to the compile workflow and execute
+
+```python
+example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflow(
+ dataclasses.replace(pitir2, kwargs=pitir2.kwargs | {"offset_provider": OFFSET_PROVIDER})
+)
+```
+
+```python
+example_compiled(*pitir2.args, offset_provider=OFFSET_PROVIDER)
+```
+
+```python
+example_compiled(pitir2.args[1], *pitir2.args[1:], offset_provider=OFFSET_PROVIDER)
+```
+
+```python
+pitir2.args[1].asnumpy()
+```
+
+ array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.])
+
+### Starting from FOAST
+
+Note that it is the exact same call but with a different input stage
+
+```python
+pitir3 = foast_toolchain(foast)
+assert pitir3 == pitir
+```
+
+# Walkthrough starting from Program
+
+## Starting Out
+
+```python
+@gtx.program
+def example_prog(a: gtx.Field[[I], gtx.float64], out: gtx.Field[[I], gtx.float64]) -> None:
+ example_fo(a, out=out)
+```
+
+```python
+p_start = example_prog.definition_stage
+```
+
+```python
+gtx.ffront.stages.ProgramDefinition?
+```
+
+ [0;31mInit signature:[0m
+ [0mgtx[0m[0;34m.[0m[0mffront[0m[0;34m.[0m[0mstages[0m[0;34m.[0m[0mProgramDefinition[0m[0;34m([0m[0;34m[0m
+ [0;34m[0m [0mdefinition[0m[0;34m:[0m [0;34m'types.FunctionType'[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m [0mgrid_type[0m[0;34m:[0m [0;34m'Optional[common.GridType]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
+ [0;34m[0m[0;34m)[0m [0;34m->[0m [0;32mNone[0m[0;34m[0m[0;34m[0m[0m
+ [0;31mDocstring:[0m ProgramDefinition(definition: 'types.FunctionType', grid_type: 'Optional[common.GridType]' = None)
+ [0;31mFile:[0m ~/Code/gt4py/src/gt4py/next/ffront/stages.py
+ [0;31mType:[0m type
+ [0;31mSubclasses:[0m
+
+## DSL -> PAST
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style pdef fill:red
+style past fill:red
+linkStyle 6 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+p_past = backend.DEFAULT_PROG_TRANSFORMS.func_to_past(p_start)
+```
+
+## PAST -> Closure
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style past fill:red
+style pclos fill:red
+linkStyle 7 stroke:red,stroke-width:4px,color:pink
+```
+
+```python
+pclos = backend.DEFAULT_PROG_TRANSFORMS.replace(
+ past_inject_args=backend.ProgArgsInjector(
+ args=fclos.args,
+ kwargs=fclos.kwargs
+ )
+)(p_past)
+```
+
+## Full Program Toolchain
+
+```mermaid
+graph LR
+
+fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition)
+foast -->|foast_to_itir| itir_expr(itir.Expr)
+foast -->|foast_inject_args| fclos(FoastClosure)
+fclos -->|foast_to_past_closure| pclos(PastClosure)
+pclos -->|past_process_args| pclos
+pclos -->|past_to_itir| pcall(ProgramCall)
+
+pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition)
+past -->|past_lint| past
+past -->|past_inject_args| pclos(ProgramClosure)
+
+style pdef fill:red
+style past fill:red
+style pclos fill:red
+style pcall fill:red
+linkStyle 4,5,6,7 stroke:red,stroke-width:4px,color:pink
+```
+
+### Starting from DSL
+
+```python
+toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace(
+ past_inject_args=backend.ProgArgsInjector(
+ args=fclos.args,
+ kwargs=fclos.kwargs
+ )
+)
+```
+
+```python
+p_itir1 = toolchain(p_start)
+```
+
+```python
+p_itir2 = toolchain(p_past)
+```
+
+```python
+assert p_itir1 == p_itir2
+```
diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py
index 42473bea63..e406a5f097 100644
--- a/src/gt4py/eve/extended_typing.py
+++ b/src/gt4py/eve/extended_typing.py
@@ -207,15 +207,20 @@ def __delete__(self, _instance: _C) -> None: ...
class HashlibAlgorithm(Protocol):
"""Used in the hashlib module of the standard library."""
- digest_size: int
- block_size: int
- name: str
+ @property
+ def block_size(self) -> int: ...
+
+ @property
+ def digest_size(self) -> int: ...
+
+ @property
+ def name(self) -> str: ...
def __init__(self, data: ReadableBuffer = ...) -> None: ...
- def copy(self) -> HashlibAlgorithm: ...
+ def copy(self) -> Self: ...
- def update(self, data: ReadableBuffer) -> None: ...
+ def update(self, data: Buffer, /) -> None: ...
def digest(self) -> bytes: ...
diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py
index 01c066ca91..d1f9d0f7d5 100644
--- a/src/gt4py/eve/utils.py
+++ b/src/gt4py/eve/utils.py
@@ -401,12 +401,12 @@ def content_hash(*args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | No
"""
if hash_algorithm is None:
- hash_algorithm = xxhash.xxh64() # type: ignore[assignment]
+ hash_algorithm = xxhash.xxh64()
elif isinstance(hash_algorithm, str):
- hash_algorithm = hashlib.new(hash_algorithm) # type: ignore[assignment]
+ hash_algorithm = hashlib.new(hash_algorithm)
- hash_algorithm.update(pickle.dumps(args)) # type: ignore[union-attr]
- result = hash_algorithm.hexdigest() # type: ignore[union-attr]
+ hash_algorithm.update(pickle.dumps(args))
+ result = hash_algorithm.hexdigest()
assert isinstance(result, str)
return result
diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py
index cfa4911b57..3d3c7a27e1 100644
--- a/src/gt4py/next/backend.py
+++ b/src/gt4py/next/backend.py
@@ -19,29 +19,161 @@
from gt4py._core import definitions as core_defs
from gt4py.next import allocators as next_allocators
-from gt4py.next.ffront import func_to_past, past_process_args, past_to_itir, stages as ffront_stages
-from gt4py.next.otf import recipes
+from gt4py.next.ffront import (
+ foast_to_itir,
+ foast_to_past,
+ func_to_foast,
+ func_to_past,
+ past_process_args,
+ past_to_itir,
+ stages as ffront_stages,
+)
+from gt4py.next.ffront.past_passes import linters as past_linters
+from gt4py.next.iterator import ir as itir
+from gt4py.next.otf import stages, workflow
from gt4py.next.program_processors import processor_interface as ppi
-DEFAULT_TRANSFORMS = recipes.ProgramTransformWorkflow(
- func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True),
- past_transform_args=past_process_args.past_process_args,
- past_to_itir=past_to_itir.PastToItirFactory(),
-)
+@dataclasses.dataclass(frozen=True)
+class FopArgsInjector(workflow.Workflow):
+ args: tuple[Any, ...] = dataclasses.field(default_factory=tuple)
+ kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
+ from_fieldop: Any = None
+
+ def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure:
+ return ffront_stages.FoastClosure(
+ foast_op_def=inp,
+ args=self.args,
+ kwargs=self.kwargs,
+ closure_vars={inp.foast_node.id: self.from_fieldop},
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class FieldopTransformWorkflow(workflow.NamedStepSequence):
+ """Modular workflow for transformations with access to intermediates."""
+
+ func_to_foast: workflow.SkippableStep[
+ ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition,
+ ffront_stages.FoastOperatorDefinition,
+ ] = dataclasses.field(
+ default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True)
+ )
+ foast_inject_args: workflow.Workflow[
+ ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure
+ ] = dataclasses.field(default_factory=FopArgsInjector)
+ foast_to_past_closure: workflow.Workflow[
+ ffront_stages.FoastClosure, ffront_stages.PastClosure
+ ] = dataclasses.field(
+ default_factory=lambda: foast_to_past.FoastToPastClosure(
+ foast_to_past=workflow.CachedStep(
+ foast_to_past.foast_to_past, hash_function=ffront_stages.fingerprint_stage
+ )
+ )
+ )
+ past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = (
+ dataclasses.field(default=past_process_args.past_process_args)
+ )
+ past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
+ dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
+ )
+
+ foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = (
+ dataclasses.field(
+ default_factory=lambda: workflow.CachedStep(
+ step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.fingerprint_stage
+ )
+ )
+ )
+
+ @property
+ def step_order(self) -> list[str]:
+ return [
+ "func_to_foast",
+ "foast_inject_args",
+ "foast_to_past_closure",
+ "past_transform_args",
+ "past_to_itir",
+ ]
+
+
+DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow()
+
+
+@dataclasses.dataclass(frozen=True)
+class ProgArgsInjector(workflow.Workflow):
+ args: tuple[Any, ...] = dataclasses.field(default_factory=tuple)
+ kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.PastClosure:
+ return ffront_stages.PastClosure(
+ past_node=inp.past_node,
+ closure_vars=inp.closure_vars,
+ grid_type=inp.grid_type,
+ args=self.args,
+ kwargs=self.kwargs,
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class ProgramTransformWorkflow(workflow.NamedStepSequence):
+ """Modular workflow for transformations with access to intermediates."""
+
+ func_to_past: workflow.SkippableStep[
+ ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition,
+ ffront_stages.PastProgramDefinition,
+ ] = dataclasses.field(
+ default_factory=lambda: func_to_past.OptionalFuncToPastFactory(cached=True)
+ )
+ past_lint: workflow.Workflow[
+ ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition
+ ] = dataclasses.field(default_factory=past_linters.LinterFactory)
+ past_inject_args: workflow.Workflow[
+ ffront_stages.PastProgramDefinition, ffront_stages.PastClosure
+ ] = dataclasses.field(default_factory=ProgArgsInjector)
+ past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = (
+ dataclasses.field(default=past_process_args.past_process_args)
+ )
+ past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
+ dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
+ )
+
+
+DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow()
@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
executor: ppi.ProgramExecutor
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
- transformer: recipes.ProgramTransformWorkflow = DEFAULT_TRANSFORMS
+ transforms_fop: FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS
+ transforms_prog: ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS
def __call__(
- self, program: ffront_stages.ProgramDefinition, *args: tuple[Any], **kwargs: dict[str, Any]
+ self,
+ program: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition,
+ *args: tuple[Any],
+ **kwargs: dict[str, Any],
) -> None:
- transformer = self.transformer.replace(args=args, kwargs=kwargs)
- program_call = transformer(program)
+ if isinstance(
+ program, (ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition)
+ ):
+ offset_provider = kwargs.pop("offset_provider")
+ from_fieldop = kwargs.pop("from_fieldop")
+ transforms_fop = self.transforms_fop.replace(
+ foast_inject_args=FopArgsInjector(
+ args=args, kwargs=kwargs, from_fieldop=from_fieldop
+ )
+ )
+ program_call = transforms_fop(program)
+ program_call = dataclasses.replace(
+ program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider}
+ )
+ else:
+ transforms_prog = self.transforms_prog.replace(
+ past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs)
+ )
+ program_call = transforms_prog(program)
self.executor(program_call.program, *program_call.args, **program_call.kwargs)
@property
diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index a01cf0959a..be1b3c1fa8 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -24,12 +24,11 @@
import typing
import warnings
from collections.abc import Callable
-from typing import Generic, TypeVar
+from typing import Any, Generic, Optional, TypeVar
from gt4py import eve
from gt4py._core import definitions as core_defs
-from gt4py.eve import utils as eve_utils
-from gt4py.eve.extended_typing import Any, Optional
+from gt4py.eve import extended_typing as xtyping
from gt4py.next import (
allocators as next_allocators,
backend as next_backend,
@@ -39,24 +38,14 @@
from gt4py.next.common import Dimension, GridType
from gt4py.next.embedded import operators as embedded_operators
from gt4py.next.ffront import (
- dialect_ast_enums,
field_operator_ast as foast,
past_process_args,
past_to_itir,
- program_ast as past,
stages as ffront_stages,
transform_utils,
type_specifications as ts_ffront,
)
-from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
-from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering
-from gt4py.next.ffront.func_to_foast import FieldOperatorParser
from gt4py.next.ffront.gtcallable import GTCallable
-from gt4py.next.ffront.past_passes.closure_var_type_deduction import (
- ClosureVarTypeDeduction as ProgramClosureVarTypeDeduction,
-)
-from gt4py.next.ffront.past_passes.type_deduction import ProgramTypeDeduction
-from gt4py.next.ffront.source_utils import SourceDefinition, get_closure_vars_from_function
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils.ir_makers import (
literal_from_value,
@@ -111,33 +100,16 @@ def definition(self):
@functools.cached_property
def past_stage(self):
- if self.backend is not None and self.backend.transformer is not None:
- return self.backend.transformer.func_to_past(self.definition_stage)
- return next_backend.DEFAULT_TRANSFORMS.func_to_past(self.definition_stage)
+ # backwards compatibility for backends that do not support the full toolchain
+ if self.backend is not None and self.backend.transforms_prog is not None:
+ return self.backend.transforms_prog.func_to_past(self.definition_stage)
+ return next_backend.DEFAULT_PROG_TRANSFORMS.func_to_past(self.definition_stage)
+ # TODO(ricoh): linting should become optional, up to the backend.
def __post_init__(self):
- function_closure_vars = transform_utils._filter_closure_vars_by_type(
- self.past_stage.closure_vars, GTCallable
- )
- misnamed_functions = [
- f"{name} vs. {func.id}"
- for name, func in function_closure_vars.items()
- if name != func.__gt_itir__().id
- ]
- if misnamed_functions:
- raise RuntimeError(
- f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}."
- )
-
- undefined_symbols = [
- symbol.id
- for symbol in self.past_stage.past_node.closure_vars
- if symbol.id not in self.past_stage.closure_vars
- ]
- if undefined_symbols:
- raise RuntimeError(
- f"The following closure variables are undefined: {', '.join(undefined_symbols)}."
- )
+ if self.backend is not None and self.backend.transforms_prog is not None:
+ self.backend.transforms_prog.past_lint(self.past_stage)
+ return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage)
@property
def __name__(self) -> str:
@@ -207,8 +179,8 @@ def itir(self) -> itir.FencilDefinition:
args=[],
kwargs={},
)
- if self.backend is not None and self.backend.transformer is not None:
- return self.backend.transformer.past_to_itir(no_args_past).program
+ if self.backend is not None and self.backend.transforms_prog is not None:
+ return self.backend.transforms_prog.past_to_itir(no_args_past).program
return past_to_itir.PastToItirFactory()(no_args_past).program
def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None:
@@ -236,6 +208,13 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any)
@dataclasses.dataclass(frozen=True)
class ProgramFromPast(Program):
+ """
+ This version of program has no DSL definition associated with it.
+
+ PAST nodes can be built programmatically from field operators or from scratch.
+ This wrapper provides the appropriate toolchain entry points.
+ """
+
past_stage: ffront_stages.PastProgramDefinition
def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs):
@@ -247,6 +226,12 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs):
ppi.ensure_processor_kind(self.backend.executor, ppi.ProgramExecutor)
self.backend(self.past_stage, *args, **(kwargs | {"offset_provider": offset_provider}))
+ # TODO(ricoh): linting should become optional, up to the backend.
+ def __post_init__(self):
+ if self.backend is not None and self.backend.transforms_prog is not None:
+ self.backend.transforms_prog.past_lint(self.past_stage)
+ return next_backend.DEFAULT_PROG_TRANSFORMS.past_lint(self.past_stage)
+
@dataclasses.dataclass(frozen=True)
class ProgramWithBoundArgs(Program):
@@ -391,12 +376,8 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
it will be deduced from actually occurring dimensions.
"""
- foast_node: OperatorNodeT
- closure_vars: dict[str, Any]
- definition: Optional[types.FunctionType]
+ definition_stage: ffront_stages.FieldOperatorDefinition
backend: Optional[ppi.ProgramExecutor]
- grid_type: Optional[GridType]
- operator_attributes: Optional[dict[str, Any]] = None
_program_cache: dict = dataclasses.field(
init=False, default_factory=dict
) # init=False ensure the cache is not copied in calls to replace
@@ -411,39 +392,37 @@ def from_function(
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator,
operator_attributes: Optional[dict[str, Any]] = None,
) -> FieldOperator[OperatorNodeT]:
- operator_attributes = operator_attributes or {}
-
- source_def = SourceDefinition.from_function(definition)
- closure_vars = get_closure_vars_from_function(definition)
- annotations = typing.get_type_hints(definition)
- foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations)
- loc = foast_definition_node.location
- operator_attribute_nodes = {
- key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc)
- for key, value in operator_attributes.items()
- }
- untyped_foast_node = operator_node_cls(
- id=foast_definition_node.id,
- definition=foast_definition_node,
- location=loc,
- **operator_attribute_nodes,
- )
- foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node)
return cls(
- foast_node=foast_node,
- closure_vars=closure_vars,
- definition=definition,
+ definition_stage=ffront_stages.FieldOperatorDefinition(
+ definition=definition,
+ grid_type=grid_type,
+ node_class=operator_node_cls,
+ attributes=operator_attributes or {},
+ ),
backend=backend,
- grid_type=grid_type,
- operator_attributes=operator_attributes,
)
+ # TODO(ricoh): linting should become optional, up to the backend.
+ def __post_init__(self):
+ """This ensures that DSL linting occurs at decoration time."""
+ _ = self.foast_stage
+
+ @functools.cached_property
+ def foast_stage(self) -> ffront_stages.FoastOperatorDefinition:
+ if self.backend is not None and self.backend.transforms_fop is not None:
+ return self.backend.transforms_fop.func_to_foast(self.definition_stage)
+ return next_backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(self.definition_stage)
+
@property
def __name__(self) -> str:
- return self.definition.__name__
+ return self.definition_stage.definition.__name__
+
+ @property
+ def definition(self) -> str:
+ return self.definition_stage.definition
def __gt_type__(self) -> ts.CallableType:
- type_ = self.foast_node.type
+ type_ = self.foast_stage.foast_node.type
assert isinstance(type_, ts.CallableType)
return type_
@@ -451,98 +430,46 @@ def with_backend(self, backend: ppi.ProgramExecutor) -> FieldOperator:
return dataclasses.replace(self, backend=backend)
def with_grid_type(self, grid_type: GridType) -> FieldOperator:
- return dataclasses.replace(self, grid_type=grid_type)
+ return dataclasses.replace(
+ self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type)
+ )
def __gt_itir__(self) -> itir.FunctionDefinition:
- if hasattr(self, "__cached_itir"):
- return getattr(self, "__cached_itir")
-
- itir_node: itir.FunctionDefinition = FieldOperatorLowering.apply(self.foast_node)
-
- object.__setattr__(self, "__cached_itir", itir_node)
-
- return itir_node
+ if self.backend is not None and self.backend.transforms_fop is not None:
+ return self.backend.transforms_fop.foast_to_itir(self.foast_stage)
+ return next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_itir(self.foast_stage)
def __gt_closure_vars__(self) -> dict[str, Any]:
- return self.closure_vars
+ return self.foast_stage.closure_vars
def as_program(
self, arg_types: list[ts.TypeSpec], kwarg_types: dict[str, ts.TypeSpec]
) -> Program:
- # TODO(tehrengruber): implement mechanism to deduce default values
- # of arg and kwarg types
- # TODO(tehrengruber): check foast operator has no out argument that clashes
- # with the out argument of the program we generate here.
- hash_ = eve_utils.content_hash(
- (tuple(arg_types), tuple((name, arg) for name, arg in kwarg_types.items()))
- )
- try:
- return self._program_cache[hash_]
- except KeyError:
- pass
-
- loc = self.foast_node.location
- # use a new UID generator to allow caching
- param_sym_uids = eve_utils.UIDGenerator()
-
- type_ = self.__gt_type__()
- params_decl: list[past.Symbol] = [
- past.DataSymbol(
- id=param_sym_uids.sequential_id(prefix="__sym"),
- type=arg_type,
- namespace=dialect_ast_enums.Namespace.LOCAL,
- location=loc,
- )
- for arg_type in arg_types
- ]
- params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl]
- out_sym: past.Symbol = past.DataSymbol(
- id="out",
- type=type_info.return_type(type_, with_args=arg_types, with_kwargs=kwarg_types),
- namespace=dialect_ast_enums.Namespace.LOCAL,
- location=loc,
+ foast_with_types = (
+ ffront_stages.FoastWithTypes(
+ foast_op_def=self.foast_stage,
+ arg_types=tuple(arg_types),
+ kwarg_types=kwarg_types,
+ closure_vars={self.foast_stage.foast_node.id: self},
+ ),
)
- out_ref = past.Name(id="out", location=loc)
-
- if self.foast_node.id in self.closure_vars:
- raise RuntimeError("A closure variable has the same name as the field operator itself.")
- closure_vars = {self.foast_node.id: self}
- closure_symbols = [
- past.Symbol(
- id=self.foast_node.id,
- type=ts.DeferredType(constraint=None),
- namespace=dialect_ast_enums.Namespace.CLOSURE,
- location=loc,
+ past_stage = None
+ if self.backend is not None and self.backend.transforms_fop is not None:
+ past_stage = self.backend.transforms_fop.foast_to_past_closure.foast_to_past(
+ foast_with_types
)
- ]
-
- untyped_past_node = past.Program(
- id=f"__field_operator_{self.foast_node.id}",
- type=ts.DeferredType(constraint=ts_ffront.ProgramType),
- params=[*params_decl, out_sym],
- body=[
- past.Call(
- func=past.Name(id=self.foast_node.id, location=loc),
- args=params_ref,
- kwargs={"out": out_ref},
- location=loc,
+ else:
+ past_stage = (
+ next_backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_past_closure.foast_to_past(
+ ffront_stages.FoastWithTypes(
+ foast_op_def=self.foast_stage,
+ arg_types=tuple(arg_types),
+ kwarg_types=kwarg_types,
+ closure_vars={self.foast_stage.foast_node.id: self},
+ ),
)
- ],
- closure_vars=closure_symbols,
- location=loc,
- )
- untyped_past_node = ProgramClosureVarTypeDeduction.apply(untyped_past_node, closure_vars)
- past_node = ProgramTypeDeduction.apply(untyped_past_node)
-
- self._program_cache[hash_] = ProgramFromPast(
- definition_stage=None,
- past_stage=ffront_stages.PastProgramDefinition(
- past_node=past_node, closure_vars=closure_vars, grid_type=self.grid_type
- ),
- backend=self.backend,
- )
-
- return self._program_cache[hash_]
+ )
+ return ProgramFromPast(definition_stage=None, past_stage=past_stage, backend=self.backend)
def __call__(self, *args, **kwargs) -> None:
if not next_embedded.context.within_valid_context() and self.backend is not None:
@@ -554,36 +481,56 @@ def __call__(self, *args, **kwargs) -> None:
if "out" not in kwargs:
raise errors.MissingArgumentError(None, "out", True)
out = kwargs.pop("out")
- args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
- # TODO(tehrengruber): check all offset providers are given
- # deduce argument types
- arg_types = []
- for arg in args:
- arg_types.append(type_translation.from_value(arg))
- kwarg_types = {}
- for name, arg in kwargs.items():
- kwarg_types[name] = type_translation.from_value(arg)
-
- return self.as_program(arg_types, kwarg_types)(
- *args, out, offset_provider=offset_provider, **kwargs
+ args, kwargs = type_info.canonicalize_arguments(
+ self.foast_stage.foast_node.type, args, kwargs
+ )
+ return self.backend(
+ self.definition_stage,
+ *args,
+ out=out,
+ offset_provider=offset_provider,
+ from_fieldop=self,
+ **kwargs,
)
else:
- if self.operator_attributes is not None and any(
+ attributes = (
+ self.definition_stage.attributes
+ if self.definition_stage
+ else self.foast_stage.attributes
+ )
+ if attributes is not None and any(
has_scan_op_attribute := [
- attribute in self.operator_attributes
- for attribute in ["init", "axis", "forward"]
+ attribute in attributes for attribute in ["init", "axis", "forward"]
]
):
assert all(has_scan_op_attribute)
- forward = self.operator_attributes["forward"]
- init = self.operator_attributes["init"]
- axis = self.operator_attributes["axis"]
- op = embedded_operators.ScanOperator(self.definition, forward, init, axis)
+ forward = attributes["forward"]
+ init = attributes["init"]
+ axis = attributes["axis"]
+ op = embedded_operators.ScanOperator(
+ self.definition_stage.definition, forward, init, axis
+ )
else:
- op = embedded_operators.EmbeddedOperator(self.definition)
+ op = embedded_operators.EmbeddedOperator(self.definition_stage.definition)
return embedded_operators.field_operator_call(op, args, kwargs)
+@dataclasses.dataclass(frozen=True)
+class FieldOperatorFromFoast(FieldOperator):
+ """
+ This version of the field operator does not have a DSL definition.
+
+ FieldOperator AST nodes can be programmatically built, which may be
+ particularly useful in testing and debugging.
+ This class provides the appropriate toolchain entry points.
+ """
+
+ foast_stage: ffront_stages.FoastOperatorDefinition
+
+ def __call__(self, *args, **kwargs) -> None:
+ return self.backend(self.foast_stage, *args, from_fieldop=self, **kwargs)
+
+
@typing.overload
def field_operator(
definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor]
@@ -695,3 +642,29 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator:
)
return scan_operator_inner if definition is None else scan_operator_inner(definition)
+
+
+@ffront_stages.add_content_to_fingerprint.register
+def add_fieldop_to_fingerprint(obj: FieldOperator, hasher: xtyping.HashlibAlgorithm) -> None:
+ ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher)
+ ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
+
+
+@ffront_stages.add_content_to_fingerprint.register
+def add_foast_fieldop_to_fingerprint(
+ obj: FieldOperatorFromFoast, hasher: xtyping.HashlibAlgorithm
+) -> None:
+ ffront_stages.add_content_to_fingerprint(obj.foast_stage, hasher)
+ ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
+
+
+@ffront_stages.add_content_to_fingerprint.register
+def add_program_to_fingerprint(obj: Program, hasher: xtyping.HashlibAlgorithm) -> None:
+ ffront_stages.add_content_to_fingerprint(obj.definition_stage, hasher)
+ ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
+
+
+@ffront_stages.add_content_to_fingerprint.register
+def add_past_program_to_fingerprint(obj: ProgramFromPast, hasher: xtyping.HashlibAlgorithm) -> None:
+ ffront_stages.add_content_to_fingerprint(obj.past_stage, hasher)
+ ffront_stages.add_content_to_fingerprint(obj.backend, hasher)
diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py
index 471840ff1b..fad4df8c84 100644
--- a/src/gt4py/next/ffront/foast_passes/type_deduction.py
+++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py
@@ -12,7 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
-from typing import Any, Optional, cast
+from typing import Any, Optional, TypeVar, cast
import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
@@ -29,6 +29,9 @@
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
+OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode)
+
+
def with_altered_scalar_kind(
type_spec: ts.TypeSpec, new_scalar_kind: ts.ScalarKind
) -> ts.ScalarType | ts.FieldType:
@@ -247,7 +250,7 @@ class FieldOperatorTypeDeduction(traits.VisitorWithSymbolTableTrait, NodeTransla
"""
@classmethod
- def apply(cls, node: foast.FunctionDefinition) -> foast.FunctionDefinition:
+ def apply(cls, node: OperatorNodeT) -> OperatorNodeT:
typed_foast_node = cls().visit(node)
FieldOperatorTypeDeductionCompletnessValidator.apply(typed_foast_node)
diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py
index e589ecb601..6194647e1f 100644
--- a/src/gt4py/next/ffront/foast_pretty_printer.py
+++ b/src/gt4py/next/ffront/foast_pretty_printer.py
@@ -126,6 +126,17 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o
UnaryOp = as_fmt("{op}{operand}")
+ IfStmt = as_fmt(
+ textwrap.dedent(
+ """
+ if {condition}:
+ {true_branch}
+ else:
+ {false_branch}
+ """
+ ).strip()
+ )
+
def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str:
if node.op is dialect_ast_enums.UnaryOperator.NOT:
op = "not "
@@ -234,7 +245,7 @@ def pretty_format(node: foast.LocatedNode) -> str:
>>> @field_operator
... def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]:
... return a + 1.0
- >>> print(pretty_format(field_op.foast_node))
+ >>> print(pretty_format(field_op.foast_stage.foast_node))
@field_operator
def field_op(a: Field[[IDim], float64]) -> Field[[IDim], float64]:
return a + 1.0
diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py
index 80c0f1fea3..4a2a043fcc 100644
--- a/src/gt4py/next/ffront/foast_to_itir.py
+++ b/src/gt4py/next/ffront/foast_to_itir.py
@@ -23,6 +23,7 @@
fbuiltins,
field_operator_ast as foast,
lowering_utils,
+ stages as ffront_stages,
type_specifications as ts_ffront,
)
from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES
@@ -33,6 +34,10 @@
from gt4py.next.type_system import type_info, type_specifications as ts
+def foast_to_itir(inp: ffront_stages.FoastOperatorDefinition) -> itir.Expr:
+ return FieldOperatorLowering.apply(inp.foast_node)
+
+
def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]:
if not type_info.contains_local_field(node.type):
return lambda x: im.promote_to_lifted_stencil("make_const_list")(x)
diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py
new file mode 100644
index 0000000000..b2e6324860
--- /dev/null
+++ b/src/gt4py/next/ffront/foast_to_past.py
@@ -0,0 +1,129 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import dataclasses
+
+from gt4py.eve import utils as eve_utils
+from gt4py.next.ffront import (
+ dialect_ast_enums,
+ program_ast as past,
+ stages as ffront_stages,
+ type_specifications as ts_ffront,
+)
+from gt4py.next.ffront.past_passes import closure_var_type_deduction, type_deduction
+from gt4py.next.otf import workflow
+from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
+
+
+def foast_to_past(inp: ffront_stages.FoastWithTypes) -> ffront_stages.PastProgramDefinition:
+ # TODO(tehrengruber): implement mechanism to deduce default values
+ # of arg and kwarg types
+ # TODO(tehrengruber): check foast operator has no out argument that clashes
+ # with the out argument of the program we generate here.
+
+ loc = inp.foast_op_def.foast_node.location
+ # use a new UID generator to allow caching
+ param_sym_uids = eve_utils.UIDGenerator()
+
+ type_ = inp.foast_op_def.foast_node.type
+ params_decl: list[past.Symbol] = [
+ past.DataSymbol(
+ id=param_sym_uids.sequential_id(prefix="__sym"),
+ type=arg_type,
+ namespace=dialect_ast_enums.Namespace.LOCAL,
+ location=loc,
+ )
+ for arg_type in inp.arg_types
+ ]
+ params_ref = [past.Name(id=pdecl.id, location=loc) for pdecl in params_decl]
+ out_sym: past.Symbol = past.DataSymbol(
+ id="out",
+ type=type_info.return_type(
+ type_, with_args=list(inp.arg_types), with_kwargs=inp.kwarg_types
+ ),
+ namespace=dialect_ast_enums.Namespace.LOCAL,
+ location=loc,
+ )
+ out_ref = past.Name(id="out", location=loc)
+
+ if inp.foast_op_def.foast_node.id in inp.foast_op_def.closure_vars:
+ raise RuntimeError("A closure variable has the same name as the field operator itself.")
+ closure_symbols: list[past.Symbol] = [
+ past.Symbol(
+ id=inp.foast_op_def.foast_node.id,
+ type=ts.DeferredType(constraint=None),
+ namespace=dialect_ast_enums.Namespace.CLOSURE,
+ location=loc,
+ ),
+ ]
+
+ untyped_past_node = past.Program(
+ id=f"__field_operator_{inp.foast_op_def.foast_node.id}",
+ type=ts.DeferredType(constraint=ts_ffront.ProgramType),
+ params=[*params_decl, out_sym],
+ body=[
+ past.Call(
+ func=past.Name(id=inp.foast_op_def.foast_node.id, location=loc),
+ args=params_ref,
+ kwargs={"out": out_ref},
+ location=loc,
+ )
+ ],
+ closure_vars=closure_symbols,
+ location=loc,
+ )
+ untyped_past_node = closure_var_type_deduction.ClosureVarTypeDeduction.apply(
+ untyped_past_node, inp.closure_vars
+ )
+ past_node = type_deduction.ProgramTypeDeduction.apply(untyped_past_node)
+
+ return ffront_stages.PastProgramDefinition(
+ past_node=past_node,
+ closure_vars=inp.closure_vars,
+ grid_type=inp.foast_op_def.grid_type,
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class FoastToPastClosure(workflow.NamedStepSequence):
+ foast_to_past: workflow.Workflow[
+ ffront_stages.FoastWithTypes, ffront_stages.PastProgramDefinition
+ ]
+
+ def __call__(self, inp: ffront_stages.FoastClosure) -> ffront_stages.PastClosure:
+ # TODO(tehrengruber): check all offset providers are given
+ # deduce argument types
+ arg_types = []
+ for arg in inp.args:
+ arg_types.append(type_translation.from_value(arg))
+ kwarg_types = {}
+ for name, arg in inp.kwargs.items():
+ kwarg_types[name] = type_translation.from_value(arg)
+
+ past_def = super().__call__(
+ ffront_stages.FoastWithTypes(
+ foast_op_def=inp.foast_op_def,
+ arg_types=tuple(arg_types),
+ kwarg_types=kwarg_types,
+ closure_vars=inp.closure_vars,
+ )
+ )
+
+ return ffront_stages.PastClosure(
+ past_node=past_def.past_node,
+ closure_vars=past_def.closure_vars,
+ grid_type=past_def.grid_type,
+ args=inp.args,
+ kwargs=inp.kwargs,
+ )
diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py
index 9f24dbf6db..6127cbdef5 100644
--- a/src/gt4py/next/ffront/func_to_foast.py
+++ b/src/gt4py/next/ffront/func_to_foast.py
@@ -16,11 +16,21 @@
import ast
import builtins
-from typing import Any, Callable, Iterable, Mapping, Type, cast
+import dataclasses
+import typing
+from typing import Any, Callable, Iterable, Mapping, Type
+
+import factory
import gt4py.eve as eve
from gt4py.next import errors
-from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast
+from gt4py.next.ffront import (
+ dialect_ast_enums,
+ fbuiltins,
+ field_operator_ast as foast,
+ source_utils,
+ stages as ffront_stages,
+)
from gt4py.next.ffront.ast_passes import (
SingleAssignTargetPass,
SingleStaticAssignPass,
@@ -35,9 +45,74 @@
from gt4py.next.ffront.foast_passes.iterable_unpack import UnpackedAssignPass
from gt4py.next.ffront.foast_passes.type_alias_replacement import TypeAliasReplacement
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
+from gt4py.next.otf import workflow
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation
+@workflow.make_step
+def func_to_foast(
+ inp: ffront_stages.FieldOperatorDefinition[ffront_stages.OperatorNodeT],
+) -> ffront_stages.FoastOperatorDefinition[ffront_stages.OperatorNodeT]:
+ source_def = source_utils.SourceDefinition.from_function(inp.definition)
+ closure_vars = source_utils.get_closure_vars_from_function(inp.definition)
+ annotations = typing.get_type_hints(inp.definition)
+ foast_definition_node = FieldOperatorParser.apply(source_def, closure_vars, annotations)
+ loc = foast_definition_node.location
+ operator_attribute_nodes = {
+ key: foast.Constant(value=value, type=type_translation.from_value(value), location=loc)
+ for key, value in inp.attributes.items()
+ }
+ untyped_foast_node = inp.node_class(
+ id=foast_definition_node.id,
+ definition=foast_definition_node,
+ location=loc,
+ **operator_attribute_nodes,
+ )
+ foast_node = FieldOperatorTypeDeduction.apply(untyped_foast_node)
+ return ffront_stages.FoastOperatorDefinition(
+ foast_node=foast_node,
+ closure_vars=closure_vars,
+ grid_type=inp.grid_type,
+ attributes=inp.attributes,
+ )
+
+
+@dataclasses.dataclass(frozen=True)
+class OptionalFuncToFoast(workflow.SkippableStep):
+ step: workflow.Workflow[
+ ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition
+ ] = func_to_foast
+
+ def skip_condition(
+ self, inp: ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition
+ ) -> bool:
+ match inp:
+ case ffront_stages.FieldOperatorDefinition():
+ return False
+ case ffront_stages.FoastOperatorDefinition():
+ return True
+
+
+@dataclasses.dataclass(frozen=True)
+class OptionalFuncToFoastFactory(factory.Factory):
+ class Meta:
+ model = OptionalFuncToFoast
+
+ class Params:
+ workflow: workflow.Workflow[
+ ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition
+ ] = func_to_foast
+ cached = factory.Trait(
+ step=factory.LazyAttribute(
+ lambda o: workflow.CachedStep(
+ step=o.workflow, hash_function=ffront_stages.fingerprint_stage
+ )
+ )
+ )
+
+ step = factory.LazyAttribute(lambda o: o.workflow)
+
+
class FieldOperatorParser(DialectParser[foast.FunctionDefinition]):
"""
Parse field operator function definition from source code into FOAST.
@@ -141,7 +216,7 @@ def _builtin_type_constructor_symbols(
], # this is a constraint type that will not be inferred (as the function is polymorphic)
pos_or_kw_args={},
kw_only_args={},
- returns=cast(ts.DataType, type_translation.from_type_hint(value)),
+ returns=typing.cast(ts.DataType, type_translation.from_type_hint(value)),
),
namespace=dialect_ast_enums.Namespace.CLOSURE,
location=location,
diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py
index 6864993f4c..372386aaf4 100644
--- a/src/gt4py/next/ffront/func_to_past.py
+++ b/src/gt4py/next/ffront/func_to_past.py
@@ -21,7 +21,6 @@
import factory
-from gt4py.eve import utils as eve_utils
from gt4py.next import errors
from gt4py.next.ffront import (
dialect_ast_enums,
@@ -73,7 +72,9 @@ class Params:
workflow = func_to_past
cached = factory.Trait(
step=factory.LazyAttribute(
- lambda o: workflow.CachedStep(step=o.workflow, hash_function=eve_utils.content_hash)
+ lambda o: workflow.CachedStep(
+ step=o.workflow, hash_function=ffront_stages.fingerprint_stage
+ )
)
)
diff --git a/src/gt4py/next/ffront/past_passes/linters.py b/src/gt4py/next/ffront/past_passes/linters.py
new file mode 100644
index 0000000000..6e77262fd1
--- /dev/null
+++ b/src/gt4py/next/ffront/past_passes/linters.py
@@ -0,0 +1,59 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import factory
+
+from gt4py.next.ffront import gtcallable, stages as ffront_stages, transform_utils
+from gt4py.next.otf import workflow
+
+
+@workflow.make_step
+def lint_misnamed_functions(
+ inp: ffront_stages.PastProgramDefinition,
+) -> ffront_stages.PastProgramDefinition:
+ function_closure_vars = transform_utils._filter_closure_vars_by_type(
+ inp.closure_vars, gtcallable.GTCallable
+ )
+ misnamed_functions = [
+ f"{name} vs. {func.__gt_itir__().id}"
+ for name, func in function_closure_vars.items()
+ if name != func.__gt_itir__().id
+ ]
+ if misnamed_functions:
+ raise RuntimeError(
+ f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}."
+ )
+ return inp
+
+
+@workflow.make_step
+def lint_undefined_symbols(
+ inp: ffront_stages.PastProgramDefinition,
+) -> ffront_stages.PastProgramDefinition:
+ undefined_symbols = [
+ symbol.id for symbol in inp.past_node.closure_vars if symbol.id not in inp.closure_vars
+ ]
+ if undefined_symbols:
+ raise RuntimeError(
+ f"The following closure variables are undefined: {', '.join(undefined_symbols)}."
+ )
+ return inp
+
+
+class LinterFactory(factory.Factory):
+ class Meta:
+ model = workflow.CachedStep
+
+ step = lint_misnamed_functions.chain(lint_undefined_symbols)
+ hash_function = ffront_stages.fingerprint_stage
diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py
index ed7c65c0af..1da6c85981 100644
--- a/src/gt4py/next/ffront/stages.py
+++ b/src/gt4py/next/ffront/stages.py
@@ -14,12 +14,59 @@
from __future__ import annotations
+import collections
import dataclasses
+import functools
+import hashlib
import types
-from typing import Any, Optional
+import typing
+from typing import Any, Generic, Optional, TypeVar
+import xxhash
+
+from gt4py.eve import extended_typing as xtyping
from gt4py.next import common
-from gt4py.next.ffront import program_ast as past
+from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils
+from gt4py.next.type_system import type_specifications as ts
+
+
+if typing.TYPE_CHECKING:
+ pass
+
+
+OperatorNodeT = TypeVar("OperatorNodeT", bound=foast.LocatedNode)
+
+
+@dataclasses.dataclass(frozen=True)
+class FieldOperatorDefinition(Generic[OperatorNodeT]):
+ definition: types.FunctionType
+ grid_type: Optional[common.GridType] = None
+ node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains
+ attributes: dict[str, Any] = dataclasses.field(default_factory=dict)
+
+
+@dataclasses.dataclass(frozen=True)
+class FoastOperatorDefinition(Generic[OperatorNodeT]):
+ foast_node: OperatorNodeT
+ closure_vars: dict[str, Any]
+ grid_type: Optional[common.GridType] = None
+ attributes: dict[str, Any] = dataclasses.field(default_factory=dict)
+
+
+@dataclasses.dataclass(frozen=True)
+class FoastWithTypes(Generic[OperatorNodeT]):
+ foast_op_def: FoastOperatorDefinition[OperatorNodeT]
+ arg_types: tuple[ts.TypeSpec, ...]
+ kwarg_types: dict[str, ts.TypeSpec]
+ closure_vars: dict[str, Any]
+
+
+@dataclasses.dataclass(frozen=True)
+class FoastClosure(Generic[OperatorNodeT]):
+ foast_op_def: FoastOperatorDefinition[OperatorNodeT]
+ args: tuple[Any, ...]
+ kwargs: dict[str, Any]
+ closure_vars: dict[str, Any]
@dataclasses.dataclass(frozen=True)
@@ -42,3 +89,73 @@ class PastClosure:
grid_type: Optional[common.GridType]
args: tuple[Any, ...]
kwargs: dict[str, Any]
+
+
+def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str:
+ hasher: xtyping.HashlibAlgorithm
+ if not algorithm:
+ hasher = xxhash.xxh64()
+ elif isinstance(algorithm, str):
+ hasher = hashlib.new(algorithm)
+ else:
+ hasher = algorithm
+
+ add_content_to_fingerprint(obj, hasher)
+ return hasher.hexdigest()
+
+
+@functools.singledispatch
+def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None:
+ hasher.update(str(obj).encode())
+
+
+for t in (str, int):
+ add_content_to_fingerprint.register(t, add_content_to_fingerprint.registry[object])
+
+
+@add_content_to_fingerprint.register(FieldOperatorDefinition)
+@add_content_to_fingerprint.register(FoastOperatorDefinition)
+@add_content_to_fingerprint.register(FoastWithTypes)
+@add_content_to_fingerprint.register(FoastClosure)
+@add_content_to_fingerprint.register(ProgramDefinition)
+@add_content_to_fingerprint.register(PastProgramDefinition)
+@add_content_to_fingerprint.register(PastClosure)
+def add_content_to_fingerprint_stages(obj: Any, hasher: xtyping.HashlibAlgorithm) -> None:
+ add_content_to_fingerprint(obj.__class__, hasher)
+ for field in dataclasses.fields(obj):
+ add_content_to_fingerprint(getattr(obj, field.name), hasher)
+
+
+@add_content_to_fingerprint.register
+def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgorithm) -> None:
+ sourcedef = source_utils.SourceDefinition.from_function(obj)
+ for item in sourcedef:
+ add_content_to_fingerprint(item, hasher)
+
+
+@add_content_to_fingerprint.register
+def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None:
+ for key, value in obj.items():
+ add_content_to_fingerprint(key, hasher)
+ add_content_to_fingerprint(value, hasher)
+
+
+@add_content_to_fingerprint.register
+def add_type_to_fingerprint(obj: type, hasher: xtyping.HashlibAlgorithm) -> None:
+ hasher.update(obj.__name__.encode())
+
+
+@add_content_to_fingerprint.register
+def add_sequence_to_fingerprint(
+ obj: collections.abc.Iterable, hasher: xtyping.HashlibAlgorithm
+) -> None:
+ for item in obj:
+ add_content_to_fingerprint(item, hasher)
+
+
+@add_content_to_fingerprint.register
+def add_foast_located_node_to_fingerprint(
+ obj: foast.LocatedNode, hasher: xtyping.HashlibAlgorithm
+) -> None:
+ add_content_to_fingerprint(obj.location, hasher)
+ add_content_to_fingerprint(str(obj), hasher)
diff --git a/src/gt4py/next/otf/recipes.py b/src/gt4py/next/otf/recipes.py
index 41d0c8947f..982e2e9b7b 100644
--- a/src/gt4py/next/otf/recipes.py
+++ b/src/gt4py/next/otf/recipes.py
@@ -14,43 +14,10 @@
from __future__ import annotations
import dataclasses
-from typing import Any
-from gt4py.next.ffront import stages as ffront_stages
from gt4py.next.otf import stages, step_types, workflow
-@dataclasses.dataclass(frozen=True)
-class ProgramTransformWorkflow(workflow.NamedStepSequence):
- """Modular workflow for transformations with access to intermediates."""
-
- func_to_past: workflow.SkippableStep[
- ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition,
- ffront_stages.PastProgramDefinition,
- ]
- past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure]
- past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall]
-
- args: tuple[Any, ...] = dataclasses.field(default_factory=tuple)
- kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
-
- def __call__(
- self, inp: ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition
- ) -> stages.ProgramCall:
- past_stage = self.func_to_past(inp)
- return self.past_to_itir(
- self.past_transform_args(
- ffront_stages.PastClosure(
- past_node=past_stage.past_node,
- closure_vars=past_stage.closure_vars,
- grid_type=past_stage.grid_type,
- args=self.args,
- kwargs=self.kwargs,
- )
- )
- )
-
-
@dataclasses.dataclass(frozen=True)
class OTFCompileWorkflow(workflow.NamedStepSequence):
"""The typical compiled backend steps composed into a workflow."""
diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py
index 8ae741195f..c83748dece 100644
--- a/src/gt4py/next/otf/workflow.py
+++ b/src/gt4py/next/otf/workflow.py
@@ -84,8 +84,10 @@ def replace(self, **kwargs: Any) -> Self:
return dataclasses.replace(self, **kwargs)
-class ChainableWorkflowMixin(Workflow[StartT, EndT]):
- def chain(self, next_step: Workflow[EndT, NewEndT]) -> ChainableWorkflowMixin[StartT, NewEndT]:
+class ChainableWorkflowMixin(Workflow[StartT, EndT_co], Protocol[StartT, EndT_co]):
+ def chain(
+ self, next_step: Workflow[EndT_co, NewEndT]
+ ) -> ChainableWorkflowMixin[StartT, NewEndT]:
return make_step(self).chain(next_step)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
index 388849bf09..840f6d6143 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py
@@ -49,7 +49,7 @@ def __call__(self, program, *args, **kwargs) -> None:
raise ValueError("No backend selected! Backend selection is mandatory in tests.")
-no_backend = NoBackend(executor=no_exec, transformer=None, allocator=None)
+no_backend = NoBackend(executor=no_exec, transforms_prog=None, allocator=None)
OPTIONAL_PROCESSORS = []
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py
index c1bee4fa2f..77ae302efa 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py
@@ -69,7 +69,7 @@ def bar(inp1: Field[[I], int64], inp2: Field[[I], int64]) -> Field[[I], int64]:
"""
).strip()
- assert pretty_format(bar.foast_node) == expected
+ assert pretty_format(bar.foast_stage.foast_node) == expected
def test_scanop():
@@ -89,4 +89,4 @@ def scan(inp: int32) -> int32:
"""
).strip()
- assert pretty_format(scan.foast_node) == expected
+ assert pretty_format(scan.foast_stage.foast_node) == expected
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
index e076ec4227..3419930588 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py
@@ -18,9 +18,13 @@
import numpy as np
import pytest
-import gt4py.next as gtx
-from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast
-from gt4py.next.ffront.decorator import FieldOperator
+from gt4py.next.ffront import (
+ decorator,
+ dialect_ast_enums,
+ fbuiltins,
+ field_operator_ast as foast,
+ stages as ffront_stages,
+)
from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_translation
@@ -107,12 +111,14 @@ def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.Program
)
typed_foast_node = FieldOperatorTypeDeduction.apply(foast_node)
- return FieldOperator(
- foast_node=typed_foast_node,
- closure_vars=closure_vars,
- definition=None,
+ return decorator.FieldOperatorFromFoast(
+ definition_stage=None,
+ foast_stage=ffront_stages.FoastOperatorDefinition(
+ foast_node=typed_foast_node,
+ closure_vars=closure_vars,
+ grid_type=None,
+ ),
backend=backend,
- grid_type=None,
)
diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py
new file mode 100644
index 0000000000..67ac96d653
--- /dev/null
+++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py
@@ -0,0 +1,162 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+from gt4py import next as gtx
+from gt4py.next.ffront import stages
+
+
+@pytest.fixture
+def idim():
+ yield gtx.Dimension("I")
+
+
+@pytest.fixture
+def jdim():
+ yield gtx.Dimension("J")
+
+
+@pytest.fixture
+def fieldop(idim):
+ @gtx.field_operator
+ def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]:
+ return a
+
+ yield copy
+
+
+@pytest.fixture
+def samecode_fieldop(idim):
+ @gtx.field_operator
+ def copy(a: gtx.Field[[idim], gtx.int32]) -> gtx.Field[[idim], gtx.int32]:
+ return a
+
+ yield copy
+
+
+@pytest.fixture
+def different_fieldop(jdim):
+ @gtx.field_operator
+ def copy(a: gtx.Field[[jdim], gtx.int32]) -> gtx.Field[[jdim], gtx.int32]:
+ return a
+
+ yield copy
+
+
+@pytest.fixture
+def program(fieldop, idim):
+ copy = fieldop
+
+ @gtx.program
+ def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]):
+ copy(a, out=out)
+
+ yield copy_program
+
+
+@pytest.fixture
+def samecode_program(samecode_fieldop, idim):
+ copy = samecode_fieldop
+
+ @gtx.program
+ def copy_program(a: gtx.Field[[idim], gtx.int32], out: gtx.Field[[idim], gtx.int32]):
+ copy(a, out=out)
+
+ yield copy_program
+
+
+@pytest.fixture
+def different_program(different_fieldop, jdim):
+ copy = different_fieldop
+
+ @gtx.program
+ def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int32]):
+ copy(a, out=out)
+
+ yield copy_program
+
+
+def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop):
+ assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage(
+ fieldop.definition_stage
+ )
+ assert stages.fingerprint_stage(different_fieldop.definition_stage) != stages.fingerprint_stage(
+ fieldop.definition_stage
+ )
+
+
+def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop):
+ foast = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(fieldop.definition_stage)
+ samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(
+ samecode_fieldop.definition_stage
+ )
+ different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(
+ different_fieldop.definition_stage
+ )
+
+ assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast)
+ assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast)
+
+
+def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim):
+ foast_closure = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain(
+ gtx.backend.FopArgsInjector(
+ args=(gtx.zeros({idim: 10}, gtx.int32),),
+ kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)},
+ from_fieldop=fieldop,
+ ),
+ )(fieldop.definition_stage)
+ samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain(
+ gtx.backend.FopArgsInjector(
+ args=(gtx.zeros({idim: 10}, gtx.int32),),
+ kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)},
+ from_fieldop=samecode_fieldop,
+ )
+ )(samecode_fieldop.definition_stage)
+ different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain(
+ gtx.backend.FopArgsInjector(
+ args=(gtx.zeros({jdim: 10}, gtx.int32),),
+ kwargs={"out": gtx.zeros({jdim: 10}, gtx.int32)},
+ from_fieldop=different_fieldop,
+ )
+ )(different_fieldop.definition_stage)
+ different_args = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain(
+ gtx.backend.FopArgsInjector(
+ args=(gtx.zeros({idim: 11}, gtx.int32),),
+ kwargs={"out": gtx.zeros({idim: 11}, gtx.int32)},
+ from_fieldop=fieldop,
+ )
+ )(fieldop.definition_stage)
+
+ assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast_closure)
+ assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast_closure)
+ assert stages.fingerprint_stage(different_args) != stages.fingerprint_stage(foast_closure)
+
+
+def test_cache_key_program_def(program, samecode_program, different_program):
+ assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage(
+ program.definition_stage
+ )
+ assert stages.fingerprint_stage(different_program.definition_stage) != stages.fingerprint_stage(
+ program.definition_stage
+ )
+
+
+def test_cache_key_past_def(program, samecode_program, different_program):
+ past = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(program.definition_stage)
+ samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage)
+ different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage)
+
+ assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(past)
+ assert stages.fingerprint_stage(different) != stages.fingerprint_stage(past)
diff --git a/tests/next_tests/unit_tests/otf_tests/test_workflow.py b/tests/next_tests/unit_tests/otf_tests/test_workflow.py
index 9e1e14edaf..2a274c0110 100644
--- a/tests/next_tests/unit_tests/otf_tests/test_workflow.py
+++ b/tests/next_tests/unit_tests/otf_tests/test_workflow.py
@@ -77,7 +77,7 @@ def test_cached_with_hashing():
def hashing(inp: list[int]) -> int:
return hash(sum(inp))
- wf = workflow.CachedStep(step=lambda inp: inp + [1], hash_function=hashing)
+ wf = workflow.CachedStep(step=lambda inp: [*inp, 1], hash_function=hashing)
assert wf([1, 2, 3]) == [1, 2, 3, 1]
assert wf([3, 2, 1]) == [1, 2, 3, 1]
diff --git a/tox.ini b/tox.ini
index d4418c4ebc..8479e4c52c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -109,10 +109,12 @@ commands =
description = Run notebooks
commands_pre =
jupytext docs/user/next/QuickstartGuide.md --to .ipynb
+ jupytext docs/user/next/Advanced_ToolchainWalkthrough.md --to .ipynb
commands =
python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1}
python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1}
python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1}
+ python -m pytest --nbmake docs/user/next/Advanced_ToolchainWalkthrough.ipynb -v -n {env:NUM_PROCESSES:1}
python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1}
# -- Other artefacts --