Skip to content

Commit

Permalink
refactor[next]: workflowify step3 (GridTools#1516)
Browse files Browse the repository at this point in the history
## Description

### New:

- `ffront.stages.FieldOperatorDefinition`
  - all the data to start the toolchain from a field operator DSL
definition

- `ffront.stages.FoastOperatorDefinition`
  - data after lowering from field operator DSL code

- `ffront.stages.FoastWithTypes`
  - program argument types in addition to the foast definition for
creating a program AST

- `ffront.stages.FoastClosure`
  - program arguments in addition to the foast definition, ready to run
the whole toolchain

### Changed:

- `decorator.Program.__post_init__`
  - implementation moved to `past_passes.linters` workflow steps
  - linting stage added to program transforms

- `decorator.FieldOperator.from_function`
  - implementation moved to workflow step in `ffront.func_to_foast`

- `decorator.FieldOperator.as_program`
  - implementation moved to workflow steps in `ffront.foast_to_past`

- `decorator.FieldOperator` data attributes
  - added: `definition_stage`
  - removed:
    - `.foast_node`: replaced with `.foast_stage.foast_node`
    - `.definition`: replaced with `.definition_stage.definition`

- `next.backend.Backend`
  - renamed: `.transformer` -> `.transforms_prog`
  - added: `.transforms_fop`, toolchain for starting from field operator

- `next.backend.FieldopTransformWorkflow`
  - now has all the steps from DSL field operator to `ProgramCall` via
`foast_to_past`, with additional steps to go to the field operator
IteratorIR expression directly instead (not run by default). The latter
`foast_to_itir` step is required during lowering of programs that call a
field operator.
  • Loading branch information
DropD authored Apr 29, 2024
1 parent 2cd0c91 commit 9cc7548
Show file tree
Hide file tree
Showing 21 changed files with 1,411 additions and 248 deletions.
514 changes: 514 additions & 0 deletions docs/user/next/Advanced_ToolchainWalkthrough.md

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions src/gt4py/eve/extended_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,20 @@ def __delete__(self, _instance: _C) -> None: ...
class HashlibAlgorithm(Protocol):
"""Used in the hashlib module of the standard library."""

digest_size: int
block_size: int
name: str
@property
def block_size(self) -> int: ...

@property
def digest_size(self) -> int: ...

@property
def name(self) -> str: ...

def __init__(self, data: ReadableBuffer = ...) -> None: ...

def copy(self) -> HashlibAlgorithm: ...
def copy(self) -> Self: ...

def update(self, data: ReadableBuffer) -> None: ...
def update(self, data: Buffer, /) -> None: ...

def digest(self) -> bytes: ...

Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,12 @@ def content_hash(*args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | No
"""
if hash_algorithm is None:
hash_algorithm = xxhash.xxh64() # type: ignore[assignment]
hash_algorithm = xxhash.xxh64()
elif isinstance(hash_algorithm, str):
hash_algorithm = hashlib.new(hash_algorithm) # type: ignore[assignment]
hash_algorithm = hashlib.new(hash_algorithm)

hash_algorithm.update(pickle.dumps(args)) # type: ignore[union-attr]
result = hash_algorithm.hexdigest() # type: ignore[union-attr]
hash_algorithm.update(pickle.dumps(args))
result = hash_algorithm.hexdigest()
assert isinstance(result, str)

return result
Expand Down
154 changes: 143 additions & 11 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,161 @@

from gt4py._core import definitions as core_defs
from gt4py.next import allocators as next_allocators
from gt4py.next.ffront import func_to_past, past_process_args, past_to_itir, stages as ffront_stages
from gt4py.next.otf import recipes
from gt4py.next.ffront import (
foast_to_itir,
foast_to_past,
func_to_foast,
func_to_past,
past_process_args,
past_to_itir,
stages as ffront_stages,
)
from gt4py.next.ffront.past_passes import linters as past_linters
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import stages, workflow
from gt4py.next.program_processors import processor_interface as ppi


DEFAULT_TRANSFORMS = recipes.ProgramTransformWorkflow(
func_to_past=func_to_past.OptionalFuncToPastFactory(cached=True),
past_transform_args=past_process_args.past_process_args,
past_to_itir=past_to_itir.PastToItirFactory(),
)
@dataclasses.dataclass(frozen=True)
class FopArgsInjector(workflow.Workflow):
    """Attach concrete call arguments to a lowered field operator definition.

    Wraps a ``FoastOperatorDefinition`` into a ``FoastClosure`` carrying the
    captured ``args``/``kwargs`` so the remaining toolchain steps can run.
    """

    args: tuple[Any, ...] = dataclasses.field(default_factory=tuple)
    kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
    from_fieldop: Any = None

    def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure:
        # Expose the field operator object itself in the closure under its own name.
        captured_vars = {inp.foast_node.id: self.from_fieldop}
        return ffront_stages.FoastClosure(
            foast_op_def=inp, args=self.args, kwargs=self.kwargs, closure_vars=captured_vars
        )


@dataclasses.dataclass(frozen=True)
class FieldopTransformWorkflow(workflow.NamedStepSequence):
    """Modular workflow for transformations with access to intermediates.

    Chains the steps that take a DSL field operator definition to a
    ``ProgramCall`` (via ``foast_to_past``). The extra ``foast_to_itir`` step
    is not listed in ``step_order`` and therefore is not run by default; it
    lowers a field operator directly to an IteratorIR expression instead.
    """

    # Parse the DSL definition into FOAST; skippable when the input is already FOAST.
    func_to_foast: workflow.SkippableStep[
        ffront_stages.FieldOperatorDefinition | ffront_stages.FoastOperatorDefinition,
        ffront_stages.FoastOperatorDefinition,
    ] = dataclasses.field(
        default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True)
    )
    # Attach concrete call args/kwargs (and the field operator object) to the definition.
    foast_inject_args: workflow.Workflow[
        ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure
    ] = dataclasses.field(default_factory=FopArgsInjector)
    # Lower the FOAST closure to a PAST closure; the inner foast_to_past step is
    # cached, keyed on the stage fingerprint.
    foast_to_past_closure: workflow.Workflow[
        ffront_stages.FoastClosure, ffront_stages.PastClosure
    ] = dataclasses.field(
        default_factory=lambda: foast_to_past.FoastToPastClosure(
            foast_to_past=workflow.CachedStep(
                foast_to_past.foast_to_past, hash_function=ffront_stages.fingerprint_stage
            )
        )
    )
    # Process the call arguments of the PAST closure.
    past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = (
        dataclasses.field(default=past_process_args.past_process_args)
    )
    # Lower the PAST closure to an IteratorIR ProgramCall.
    past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
        dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
    )

    # Direct FOAST -> IteratorIR expression lowering; not part of `step_order`.
    # Needed when lowering a program that calls a field operator. Cached on the
    # stage fingerprint.
    foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = (
        dataclasses.field(
            default_factory=lambda: workflow.CachedStep(
                step=foast_to_itir.foast_to_itir, hash_function=ffront_stages.fingerprint_stage
            )
        )
    )

    @property
    def step_order(self) -> list[str]:
        """Names of the steps to run, in order (excludes ``foast_to_itir``)."""
        return [
            "func_to_foast",
            "foast_inject_args",
            "foast_to_past_closure",
            "past_transform_args",
            "past_to_itir",
        ]


DEFAULT_FIELDOP_TRANSFORMS = FieldopTransformWorkflow()


@dataclasses.dataclass(frozen=True)
class ProgArgsInjector(workflow.Workflow):
    """Attach concrete call arguments to a program definition.

    Combines a ``PastProgramDefinition`` with the captured ``args``/``kwargs``
    into a ``PastClosure`` for the downstream toolchain steps.
    """

    args: tuple[Any, ...] = dataclasses.field(default_factory=tuple)
    kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)

    def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.PastClosure:
        # Carry the definition's fields over unchanged, adding the call arguments.
        definition_fields = {
            "past_node": inp.past_node,
            "closure_vars": inp.closure_vars,
            "grid_type": inp.grid_type,
        }
        return ffront_stages.PastClosure(args=self.args, kwargs=self.kwargs, **definition_fields)


@dataclasses.dataclass(frozen=True)
class ProgramTransformWorkflow(workflow.NamedStepSequence):
    """Modular workflow for transformations with access to intermediates.

    Takes a DSL program definition through parsing, linting and argument
    injection down to an IteratorIR ``ProgramCall``.
    """

    # Parse the DSL program into PAST; skippable when the input is already PAST.
    func_to_past: workflow.SkippableStep[
        ffront_stages.ProgramDefinition | ffront_stages.PastProgramDefinition,
        ffront_stages.PastProgramDefinition,
    ] = dataclasses.field(
        default_factory=lambda: func_to_past.OptionalFuncToPastFactory(cached=True)
    )
    # Lint the PAST definition; input and output stage types are identical.
    past_lint: workflow.Workflow[
        ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition
    ] = dataclasses.field(default_factory=past_linters.LinterFactory)
    # Attach concrete call args/kwargs to the program definition.
    past_inject_args: workflow.Workflow[
        ffront_stages.PastProgramDefinition, ffront_stages.PastClosure
    ] = dataclasses.field(default_factory=ProgArgsInjector)
    # Process the call arguments of the PAST closure.
    past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = (
        dataclasses.field(default=past_process_args.past_process_args)
    )
    # Lower the PAST closure to an IteratorIR ProgramCall.
    past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
        dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
    )


DEFAULT_PROG_TRANSFORMS = ProgramTransformWorkflow()


@dataclasses.dataclass(frozen=True)
class Backend(Generic[core_defs.DeviceTypeT]):
executor: ppi.ProgramExecutor
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
transformer: recipes.ProgramTransformWorkflow = DEFAULT_TRANSFORMS
transforms_fop: FieldopTransformWorkflow = DEFAULT_FIELDOP_TRANSFORMS
transforms_prog: ProgramTransformWorkflow = DEFAULT_PROG_TRANSFORMS

def __call__(
self, program: ffront_stages.ProgramDefinition, *args: tuple[Any], **kwargs: dict[str, Any]
self,
program: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition,
*args: tuple[Any],
**kwargs: dict[str, Any],
) -> None:
transformer = self.transformer.replace(args=args, kwargs=kwargs)
program_call = transformer(program)
if isinstance(
program, (ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition)
):
offset_provider = kwargs.pop("offset_provider")
from_fieldop = kwargs.pop("from_fieldop")
transforms_fop = self.transforms_fop.replace(
foast_inject_args=FopArgsInjector(
args=args, kwargs=kwargs, from_fieldop=from_fieldop
)
)
program_call = transforms_fop(program)
program_call = dataclasses.replace(
program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider}
)
else:
transforms_prog = self.transforms_prog.replace(
past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs)
)
program_call = transforms_prog(program)
self.executor(program_call.program, *program_call.args, **program_call.kwargs)

@property
Expand Down
Loading

0 comments on commit 9cc7548

Please sign in to comment.