Skip to content

Commit

Permalink
feat: First Batch of Translators (#20)
Browse files Browse the repository at this point in the history
This PR introduces a series of primitive translators; most of them are
based on the prototype, with some improvements.
I just copied them over from the development branch, which is not so
nice, but was the simplest thing to do without also introducing the
other stuff.

It is important to note that the tests from the development branch were
deliberately not added, to keep the PR small.
Furthermore, we need something to test, so this PR must go first.

For organizational reasons, the development history of this PR happened
to be contained in [PR#21](#21).

---------

Co-authored-by: Enrique González Paredes <[email protected]>
  • Loading branch information
philip-paul-mueller and egparedes authored Sep 26, 2024
1 parent 19c89b0 commit 0a9f361
Show file tree
Hide file tree
Showing 21 changed files with 1,868 additions and 298 deletions.
64 changes: 57 additions & 7 deletions src/jace/translator/jaxpr_translator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import dace
from dace import data as dace_data, properties as dace_properties
from dace.sdfg import propagation as dace_propagation
from jax import core as jax_core

from jace import util
Expand All @@ -35,8 +36,11 @@ class JaxprTranslationBuilder:
- there are only transient variables inside the SDFG,
- it lacks the special `__return` variable,
- the `arg_names` parameter is not set,
- for all scalar values a ` Scalar` SDFG variable is used, thus they cannot
be used to return anything.
- for all scalar values a `Scalar` SDFG variable is used, thus they cannot
be used for returning values,
- for every transient there is exactly one access node that writes to it,
except if the name of the array starts with `__jace_mutable_`, in which case
it can be written to multiple times.
For these reasons the SDFG is not directly usable, and further manipulations
have to be performed. Especially, DaCe's validation function will fail and
Expand Down Expand Up @@ -498,7 +502,8 @@ def _allocate_translation_ctx(
@property
def _ctx(self) -> TranslationContext:
    """
    Returns the currently active translation context.

    The active context is the top-most entry of `self._ctx_stack`.

    Raises:
        RuntimeError: If `self.is_allocated()` reports that no context
            is allocated.
    """
    # An explicit exception is used instead of an `assert`: asserts are
    #  stripped under `-O` and would otherwise shadow this check with an
    #  `AssertionError`, making the `RuntimeError` unreachable.
    if not self.is_allocated():
        raise RuntimeError("The context is not allocated.")
    return self._ctx_stack[-1]

def _clear_translation_ctx(self) -> TranslationContext | None:
Expand Down Expand Up @@ -550,6 +555,7 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None:
translator = self._primitive_translators[primitive_name]

# Create the state into which the equation should be translated
prev_terminal_state = self._ctx.terminal_state
eqn_state = self.append_new_state(
label=f"{primitive_name}_{'_'.join(out_var_names)}",
prev_state=None, # forces the creation of a new terminal state
Expand All @@ -569,11 +575,15 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None:
if eqn_state is not self._ctx.terminal_state:
raise RuntimeError("Inconsistent terminal state was detected.")
new_sdfg_term_state = eqn_state
if not self._ctx.validate():
raise RuntimeError("Detected an invalid SDFG under construction.")

# Propagate the Memlets through the newly created state machine
self._propagate_memlets_in_new_states(
prev_terminal_state,
new_sdfg_term_state,
)
# Modify terminal root state of 'self'
self._ctx.terminal_state = new_sdfg_term_state
self._ctx.validate()

def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext:
"""
Expand Down Expand Up @@ -680,6 +690,47 @@ def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]:

return out_var_names

def _propagate_memlets_in_new_states(
    self,
    prev_terminal_state: dace.SDFGState,
    new_terminal_state: dace.SDFGState,
) -> None:
    """
    Propagate the Memlets inside the newly added parts of the state machine.

    Starting from the successors of `prev_terminal_state`, the function follows
    the outgoing edges of the SDFG and calls
    `dace_propagation.propagate_memlets_state()` exactly once on every state
    it reaches. `prev_terminal_state` itself is not processed. The traversal
    uses a stack as worklist, i.e. states are visited in depth-first order.

    Args:
        prev_terminal_state: Terminal state before the expansion of the
            state machine.
        new_terminal_state: Terminal state after the expansion.

    Raises:
        dace.sdfg.InvalidSDFGError: If a reached state has no outgoing edge
            but is not `new_terminal_state`.
    """
    seen: set[dace.SDFGState] = {prev_terminal_state}
    nodes_to_process: list[dace.SDFGState] = [
        edge.dst for edge in self.sdfg.out_edges(prev_terminal_state)
    ]

    while nodes_to_process:
        currently_processing = nodes_to_process.pop()

        # A state can be queued several times, once for each predecessor that
        #  is handled before the state itself; process it only the first time.
        if currently_processing in seen:
            continue
        seen.add(currently_processing)

        # The only state that is allowed to be a leaf is the new terminal state.
        if (
            self.sdfg.out_degree(currently_processing) == 0
            and currently_processing != new_terminal_state
        ):
            raise dace.sdfg.InvalidSDFGError(
                f"Found leaf node '{currently_processing}' that is not the terminal node.",
                self.sdfg,
                self.sdfg.node_id(currently_processing),
            )

        dace_propagation.propagate_memlets_state(self.sdfg, currently_processing)
        nodes_to_process.extend(
            edge.dst
            for edge in self.sdfg.out_edges(currently_processing)
            if edge.dst not in seen
        )

@property
def _start_state(self) -> dace.SDFGState:
    """Returns the start state of the currently active translation context."""
    # The cast only narrows the static type; the object is returned unchanged.
    start_state = self._ctx.start_state
    return cast(dace.SDFGState, start_state)
Expand Down Expand Up @@ -739,7 +790,7 @@ def __init__(self, name: str | None, jaxpr: jax_core.ClosedJaxpr) -> None:
self.terminal_state = self.start_state
self.jaxpr = jaxpr

def validate(self) -> bool:
def validate(self) -> None:
"""
Validate internal state of `self`.
Expand Down Expand Up @@ -778,4 +829,3 @@ def validate(self) -> bool:
self.sdfg,
None,
)
return True
214 changes: 214 additions & 0 deletions src/jace/translator/mapped_operation_base_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming)
#
# Copyright (c) 2024, ETH Zurich
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Module implementing the `MappedOperationTranslatorBase` helper class."""

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING

import dace
from typing_extensions import final, override

from jace import translator, util


if TYPE_CHECKING:
from collections.abc import Sequence

from jax import core as jax_core


class MappedOperationTranslatorBase(translator.PrimitiveTranslator):
    """
    Implements the base for all "mapped base operations".

    A mapped base operation `f` is an operation that has several input arrays
    that are elementwise combined to a single output array. A prime example for
    this would be the addition of two arrays. Essentially it assumes that the
    Tasklet code can be written as:
    ```
        __out = f(__in0, __in1, __in2, ...)
    ```
    where `__in*` are the connector names of the Tasklet and `__out` is the
    output connector. For problems such as this, the SDFG API provides the
    `SDFGState.add_mapped_tasklet()` function. However, because the function
    operates on a very low level and is very verbose to use, this class acts
    as a convenience wrapper around it.

    To use this class a user has to define the abstract `write_tasklet_code()`
    method. This function generates the entire code that should be put into the
    Tasklet, including the assignment to `__out`. If needed the translator will
    perform literal substitution on the returned code and broadcast the inputs
    to match the outputs.
    If needed a subclass can also override the `make_input_memlets()` function
    to generate custom input Memlets, such as adding an offset.

    Args:
        primitive_name: The name of the primitive `self` should bind to.

    Note:
        This class will always generate a mapped Tasklet, even if a scalar is handled.
    """

    def __init__(self, primitive_name: str) -> None:
        self._prim_name = primitive_name

    @property
    def primitive(self) -> str:
        """Returns the primitive that should be translated."""
        return self._prim_name

    @final
    @override
    def __call__(
        self,
        builder: translator.JaxprTranslationBuilder,
        in_var_names: Sequence[str | None],
        out_var_names: Sequence[str],
        eqn: jax_core.JaxprEqn,
        eqn_state: dace.SDFGState,
    ) -> None:
        """
        Create the mapped Tasklet.

        The function will create the map ranges based on the shape of the
        output array. It will then call `make_input_memlets()` to get the input
        Memlets. After that it calls `write_tasklet_code()` to get the Tasklet
        code and perform literal substitution by forwarding it to
        `self.literal_substitution()`. After that it will create the mapped Tasklet.

        Note:
            For a description of the arguments see `PrimitiveTranslatorCallable`.
        """
        assert len(out_var_names) == 1
        out_name = out_var_names[0]
        out_shape = util.get_jax_var_shape(eqn.outvars[0])

        if out_shape:
            # One map parameter per output dimension, each covering the full extent.
            tskl_ranges: list[tuple[str, str]] = [
                (f"__i{dim}", f"0:{N}") for dim, N in enumerate(out_shape)
            ]
            tskl_output: dict[str, dace.Memlet] = {
                "__out": dace.Memlet.simple(out_name, ", ".join(name for name, _ in tskl_ranges))
            }

        else:
            # If we have a scalar we will generate a Map, but it will be trivial.
            tskl_ranges = [("__jace_iterator_SCALAR", "0:1")]
            tskl_output = {"__out": dace.Memlet.simple(out_name, "0")}

        tskl_inputs: dict[str, dace.Memlet] = self.make_input_memlets(
            tskl_ranges, in_var_names, eqn
        )
        tskl_name = f"{self.primitive}_{out_name}"
        tskl_code = self.write_tasklet_code(tskl_ranges, in_var_names, eqn)
        tskl_code = self.literal_substitution(tskl_code, in_var_names, eqn)

        eqn_state.add_mapped_tasklet(
            name=tskl_name,
            map_ranges=tskl_ranges,
            inputs=tskl_inputs,
            code=tskl_code,
            outputs=tskl_output,
            external_edges=True,
        )

        # NOTE(review): The signature is annotated with `-> None`, but the state
        #  is returned; kept as is to not change what the caller observes.
        #  Confirm against the `PrimitiveTranslator` interface.
        return eqn_state

    @abstractmethod
    def write_tasklet_code(
        self,
        tskl_ranges: Sequence[tuple[str, str]],
        in_var_names: Sequence[str | None],
        eqn: jax_core.JaxprEqn,
    ) -> str:
        """
        Return the Python code that should be put inside the Tasklet.

        This also includes the assignment statement, i.e. `__out`.
        However, the base will do literal substitution on the returned object.

        Args:
            tskl_ranges: List of pairs used as map parameter, first element
                is the name iteration index of the dimension, second is its range.
            in_var_names: The list of SDFG variables used as input, `None` if literal.
            eqn: The equation.
        """
        ...

    def make_input_memlets(  # noqa: PLR6301 [no-self-use]  # Subclasses might need them.
        self,
        tskl_ranges: Sequence[tuple[str, str]],
        in_var_names: Sequence[str | None],
        eqn: jax_core.JaxprEqn,
    ) -> dict[str, dace.Memlet]:
        """
        Generate the input Memlets for the non literal operators of the primitive.

        The returned `dict` maps the input connector of the Tasklet to the Memlet
        that is used to connect it to the Map entry node. Literal inputs, i.e.
        the ones whose name is `None`, get no Memlet. Scalar inputs are always
        read at index `0`. For array inputs every dimension that has size one
        in the input but not in the output is broadcast, i.e. read at index `0`.

        Args:
            tskl_ranges: List of pairs used as map parameter, first element
                is the name iteration index of the dimension, second is its range
            in_var_names: The list of SDFG variables used as input, `None` if literal.
            eqn: The equation object.

        Raises:
            NotImplementedError: If an input is neither a scalar nor has the
                same rank as the output.
        """
        out_shape = tuple(util.get_jax_var_shape(eqn.outvars[0]))
        out_rank = len(out_shape)
        in_shapes = [util.get_jax_var_shape(invar) for invar in eqn.invars]

        if any(len(in_shape) not in {0, out_rank} for in_shape in in_shapes):
            raise NotImplementedError(
                f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! "
                f"Eqn: {eqn} || {out_shape}"
            )

        # Now we will generate the input Memlets.
        tskl_inputs: dict[str, dace.Memlet] = {}
        for i, (in_var_name, in_shape) in enumerate(zip(in_var_names, in_shapes)):
            if in_var_name is None:
                # Literals are substituted directly into the Tasklet code.
                continue

            if in_shape == ():
                tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0")
                continue

            dims_to_bcast = {
                dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1
            }
            tskl_inputs[f"__in{i}"] = dace.Memlet.simple(
                in_var_name,
                ", ".join(
                    ("0" if dim in dims_to_bcast else it_var)
                    for dim, (it_var, _) in enumerate(tskl_ranges)
                ),
            )
        return tskl_inputs

    def literal_substitution(  # noqa: PLR6301 [no-self-use]  # Subclasses might need it.
        self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn
    ) -> str:
        """
        Perform literal substitution on the proto Tasklet code `tskl_code`.

        Every occurrence of the connector name `__in{i}` whose corresponding
        entry in `in_var_names` is `None` is replaced by the string
        representation of the literal value of `eqn.invars[i]`. A connector
        name is only replaced if it is not followed by another digit, thus
        substituting `__in1` leaves `__in10` untouched.

        Args:
            tskl_code: The proto Tasklet code with literal.
            in_var_names: The list of SDFG variables used as input.
            eqn: The equation.

        Note:
            It is allowed but not recommended to override this function.
        """
        # Imported locally to keep this method self-contained.
        import re

        for i, in_var_name in enumerate(in_var_names):
            if in_var_name is not None:
                continue
            t_val = str(util.get_jax_literal_value(eqn.invars[i]))
            # A callable replacement is used such that `t_val` is inserted
            #  verbatim, i.e. without any interpretation of backslashes.
            tskl_code = re.sub(
                rf"__in{i}(?!\d)",
                lambda _match, repl=t_val: repl,
                tskl_code,
            )
        return tskl_code
Loading

0 comments on commit 0a9f361

Please sign in to comment.