diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index 3d7d04c..3e48964 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -14,6 +14,7 @@ import dace from dace import data as dace_data, properties as dace_properties +from dace.sdfg import propagation as dace_propagation from jax import core as jax_core from jace import util @@ -35,8 +36,11 @@ class JaxprTranslationBuilder: - there are only transient variables inside the SDFG, - it lacks the special `__return` variable, - the `arg_names` parameter is not set, - - for all scalar values a ` Scalar` SDFG variable is used, thus they cannot - be used to return anything. + - for all scalar values a `Scalar` SDFG variable is used, thus they cannot + be used for returning values, + - for every transient there is exactly one access node that writes to it, + except if the name of the array starts with `__jace_mutable_`, in which case + it can be written to multiple times. For these reasons the SDFG is not directly usable, and further manipulations have to be performed. Especially, DaCe's validation function will fail and @@ -498,7 +502,8 @@ def _allocate_translation_ctx( @property def _ctx(self) -> TranslationContext: """Returns the currently active translation context.""" - assert len(self._ctx_stack) != 0, "No context is active." + if not self.is_allocated(): + raise RuntimeError("The context is not allocated.") return self._ctx_stack[-1] def _clear_translation_ctx(self) -> TranslationContext | None: @@ -550,6 +555,7 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: translator = self._primitive_translators[primitive_name] # Create the state into which the equation should be translated + prev_terminal_state = self._ctx.terminal_state eqn_state = self.append_new_state( label=f"{primitive_name}_{'_'.join(out_var_names)}", prev_state=None, # forces the creation of a new terminal state @@ -569,11 +575,15 @@ def _translate_single_eqn(self, eqn: jax_core.JaxprEqn) -> None: if eqn_state is not self._ctx.terminal_state: raise RuntimeError("Inconsistent terminal state was detected.") new_sdfg_term_state = eqn_state - if not self._ctx.validate(): - raise RuntimeError("Detected an invalid SDFG under construction.") + # Propagate the Memlets through the newly created state machine + self._propagate_memlets_in_new_states( + prev_terminal_state, + new_sdfg_term_state, + ) # Modify terminal root state of 'self' self._ctx.terminal_state = new_sdfg_term_state + self._ctx.validate() def _translate_jaxpr_internal(self, jaxpr: jax_core.ClosedJaxpr) -> TranslationContext: """ @@ -680,6 +690,47 @@ def _handle_null_jaxpr(self, jaxpr: jax_core.ClosedJaxpr) -> list[str]: return out_var_names + def _propagate_memlets_in_new_states( + self, + prev_terminal_state: dace.SDFGState, + new_terminal_state: dace.SDFGState, + ) -> None: + """ + Propagate the Memlets inside the newly added parts of the state machine. + + This function performs BFS starting at `prev_terminal_state` that is bound + by `new_terminal_state`. + + Args: + prev_terminal_state: Terminal state before the expansion of the + state machine. + new_terminal_state: Terminal state after the expansion. + """ + seen: set[dace.SDFGState] = {prev_terminal_state} + nodes_to_process: list[dace.SDFGState] = [ + edge.dst for edge in self.sdfg.out_edges(prev_terminal_state) + ] + + while nodes_to_process: + currently_processing = nodes_to_process.pop() + if ( + self.sdfg.out_degree(currently_processing) == 0 + and currently_processing != new_terminal_state + ): + raise dace.sdfg.InvalidSDFGError( + f"Found leaf node '{currently_processing}' that is not the terminal node.", + self.sdfg, + self.sdfg.node_id(currently_processing), + ) + + seen.add(currently_processing) + dace_propagation.propagate_memlets_state(self.sdfg, currently_processing) + nodes_to_process.extend( + edge.dst + for edge in self.sdfg.out_edges(currently_processing) + if edge.dst not in seen + ) + @property def _start_state(self) -> dace.SDFGState: return cast(dace.SDFGState, self._ctx.start_state) @@ -739,7 +790,7 @@ def __init__(self, name: str | None, jaxpr: jax_core.ClosedJaxpr) -> None: self.terminal_state = self.start_state self.jaxpr = jaxpr - def validate(self) -> bool: + def validate(self) -> None: """ Validate internal state of `self`. @@ -778,4 +829,3 @@ def validate(self) -> bool: self.sdfg, None, ) - return True diff --git a/src/jace/translator/mapped_operation_base_translator.py b/src/jace/translator/mapped_operation_base_translator.py new file mode 100644 index 0000000..508ad13 --- /dev/null +++ b/src/jace/translator/mapped_operation_base_translator.py @@ -0,0 +1,214 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Module implementing the `MappedOperationTranslatorBase` helper class.""" + +from __future__ import annotations + +from abc import abstractmethod +from typing import TYPE_CHECKING + +import dace +from typing_extensions import final, override + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class MappedOperationTranslatorBase(translator.PrimitiveTranslator): + """ + Implements the base for all "mapped base operations". + + A mapped base operation `f` is an operation that has several inputs arrays + that are elementwise combined to a single output array. A prime example for + this would be the addition of two arrays. Essentially it assumes that the + Tasklet code can be written as: + ``` + __out = f(__in0, __in1, __in3, ...) + ``` + where `__in*` are the connector names of the Tasklet and `__out` is the + output connector. For problems such as this, the SDFG API provides the + `SDFGState.add_mapped_tasklet()` function. However, because the function + operates on a very low level and is very verbose to use, this class acts + as a convenience wrapper around it. + + To use this class a user has to define the abstract `write_tasklet_code()` method. + This function generates the entire code that should be put into the Tasklet, + include the assignment to `__out`. If needed the translator will perform + literal substitution on the returned code and broadcast the inputs to match + the outputs. + + If needed a subclass can also override the `make_input_memlets()` function + to generate custom input Memlets, such as adding an offset. + + Args: + primitive_name: The name of the primitive `self` should bind to. + + Note: + This class will always generate a mapped Tasklet, even if a scalar is handled. + """ + + def __init__(self, primitive_name: str) -> None: + self._prim_name = primitive_name + + @property + def primitive(self) -> str: + """Returns the primitive that should be translated.""" + return self._prim_name + + @final + @override + def __call__( + self, + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, + ) -> None: + """ + Create the mapped Tasklet. + + The function will create the map ranges based on the shape of the + output array. It will then call `make_input_memlets()` to get the input + Memlets. After that it calls `write_tasklet_code()` to get the Tasklet + code and perform literal substitution by forwarding it to + `self.literal_substitution()`. After that it will create the mapped Tasklet. + + Note: + For a description of the arguments see `PrimitiveTranslatorCallable`. + """ + assert len(out_var_names) == 1 + if util.get_jax_var_shape(eqn.outvars[0]): + tskl_ranges: list[tuple[str, str]] = [ + (f"__i{dim}", f"0:{N}") + for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) + ] + tskl_output: dict[str, dace.Memlet] = { + "__out": dace.Memlet.simple( + out_var_names[0], ", ".join(name for name, _ in tskl_ranges) + ) + } + + else: + # If we have a scalar we will generate a Map, but it will be trivial. + tskl_ranges = [("__jace_iterator_SCALAR", "0:1")] + tskl_output = {"__out": dace.Memlet.simple(out_var_names[0], "0")} + + tskl_inputs: dict[str, dace.Memlet] = self.make_input_memlets( + tskl_ranges, in_var_names, eqn + ) + tskl_name = f"{self.primitive}_{out_var_names[0]}" + tskl_code = self.write_tasklet_code(tskl_ranges, in_var_names, eqn) + tskl_code = self.literal_substitution(tskl_code, in_var_names, eqn) + + eqn_state.add_mapped_tasklet( + name=tskl_name, + map_ranges=tskl_ranges, + inputs=tskl_inputs, + code=tskl_code, + outputs=tskl_output, + external_edges=True, + ) + + return eqn_state + + @abstractmethod + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + """ + Return the Python code that should be put inside the Tasklet. + + This also includes the assignment statement, i.e. `__out`. + However, the base will do literal substitution on the returned object. + + Args: + tskl_ranges: List of pairs used as map parameter, first element + is the name iteration index of the dimension, second is its range. + in_var_names: The list of SDFG variables used as input, `None` if literal. + eqn: The equation. + """ + ... + + def make_input_memlets( # noqa: PLR6301 [no-self-use] # Subclasses might need them. + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + """ + Generate the input Memlets for the non literal operators of the primitive. + + The returned `dict` maps the input connector of the Tasklet to the Memlet + that is used to connect it to the Map entry node. + + Args: + tskl_ranges: List of pairs used as map parameter, first element + is the name iteration index of the dimension, second is its range + in_var_names: The list of SDFG variables used as input, `None` if literal. + eqn: The equation object. + """ + out_shape = tuple(util.get_jax_var_shape(eqn.outvars[0])) + out_rank = len(out_shape) + if any(len(util.get_jax_var_shape(invar)) not in {0, out_rank} for invar in eqn.invars): + raise NotImplementedError( + f"'MappedOperationTranslatorBase' Inputs must have the same rank as the output! " + f"Eqn: {eqn} || {tuple(util.get_jax_var_shape(eqn.outvars[0]))}" + ) + + # Now we will generate the input Memlets. + tskl_inputs: dict[str, dace.Memlet] = {} + for i, (in_var_name, in_shape) in enumerate( + zip(in_var_names, (util.get_jax_var_shape(invar) for invar in eqn.invars)) + ): + if in_var_name is None: + pass + + elif in_shape == (): + tskl_inputs[f"__in{i}"] = dace.Memlet.simple(in_var_name, "0") + + else: + dims_to_bcast = [ + dim for dim in range(out_rank) if in_shape[dim] == 1 and out_shape[dim] != 1 + ] + tskl_inputs[f"__in{i}"] = dace.Memlet.simple( + in_var_name, + ", ".join( + ("0" if i in dims_to_bcast else it_var) + for i, (it_var, _) in enumerate(tskl_ranges) + ), + ) + return tskl_inputs + + def literal_substitution( # noqa: PLR6301 [no-self-use] # Subclasses might need it. + self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn + ) -> str: + """ + Perform literal substitution on the proto Tasklet code `tskl_code`. + + Args: + tskl_code: The proto Tasklet code with literal. + in_var_names: The list of SDFG variables used as input. + eqn: The equation. + + Note: + It is allowed but not recommended to override this function. + """ + for i, in_var_name in enumerate(in_var_names): + if in_var_name is None: + t_val = util.get_jax_literal_value(eqn.invars[i]) + tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) + return tskl_code diff --git a/src/jace/translator/post_translation.py b/src/jace/translator/post_translation.py index a00b651..6b27b4f 100644 --- a/src/jace/translator/post_translation.py +++ b/src/jace/translator/post_translation.py @@ -19,6 +19,9 @@ if TYPE_CHECKING: + from dace.sdfg import nodes as dace_nodes + from jax import core as jax_core + from jace import translator @@ -234,3 +237,139 @@ def finalize_translation_context( if validate: tsdfg.validate() return tsdfg + + +def add_nested_sdfg( + state: dace.SDFGState, + child_ctx: translator.TranslationContext, + parent_ctx: translator.TranslationContext, + in_var_names: Sequence[str], + out_var_names: Sequence[str], +) -> dace_nodes.NestedSDFG: + """ + Adds the SDFG in `child_ctx` as nested SDFG at state `state` in `parent_ctx`. + + The function is a convenience wrapper that operates directly on translation + contexts instead of SDFGs. The function will also create the necessary Memlet + connections. + + Args: + state: The state at which the nested SDFG should be inserted. + Must be part of `parent_ctx`. + child_ctx: The translation context representing the SDFG that should be added. + parent_ctx: The parent SDFG to which `child_ctx` should be added as nested + SDFG in state `state`. + in_var_names: Names of the variables in `parent_ctx` that are used as inputs for + the nested SDFG, must have the same order as `child_ctx.input_names`. + out_var_names: Names of the variables in `parent_ctx` that are used as outputs + for the nested SDFG, must have the same order as `child_ctx.output_names`. + + Returns: + The nested SDFG object. + + Note: + The function will not add `child_ctx` directly as nested SDFG. Instead it + will first pass it to `finalize_translation_context()` and operates on the + return values. This means that `child_ctx` will be modified in place, and + a copy will be added to `parent_ctx`. + It is highly recommended that `state` is empty, this makes subsequent + inlining of the nested SDFG simpler. + """ + if child_ctx.sdfg.free_symbols: + raise NotImplementedError("Symbol Mapping is not implemented.") + assert not (child_ctx.input_names is None or child_ctx.output_names is None) # Silence mypy + assert len(child_ctx.input_names) == len(in_var_names) + assert len(child_ctx.output_names) == len(out_var_names) + assert state in parent_ctx.sdfg.nodes() + assert not set(in_var_names).intersection(out_var_names) + + if any(input_name.startswith("__jace_mutable_") for input_name in in_var_names): + raise NotImplementedError( + "'__jace_mutable_' variables are not yet handled in 'add_nested_sdfg()'." + ) + if len(set(in_var_names)) != len(in_var_names): + raise ValueError( + f"An input can only be passed once, but { {in_var_name for in_var_name in in_var_names if in_var_names.count(in_var_name) > 1} } were passed multiple times." + ) + if len(set(out_var_names)) != len(out_var_names): + raise NotImplementedError( + f"Tried to write multiple times to variables: { {out_var_name for out_var_name in out_var_names if out_var_names.count(out_var_name) > 1} }." + ) + + final_child_ctx = finalize_translation_context(child_ctx) + nested_sdfg: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=final_child_ctx.sdfg, + parent=parent_ctx.sdfg, + inputs=set(final_child_ctx.input_names), + outputs=set(final_child_ctx.output_names), + ) + + # Now create the connections for the input. + for outer_name, inner_name in zip(in_var_names, final_child_ctx.input_names): + outer_array = parent_ctx.sdfg.arrays[outer_name] + state.add_edge( + state.add_read(outer_name), + None, + nested_sdfg, + inner_name, + dace.Memlet.from_array(outer_name, outer_array), + ) + + # Now we create the output connections. + for outer_name, inner_name in zip(out_var_names, final_child_ctx.output_names): + outer_array = parent_ctx.sdfg.arrays[outer_name] + state.add_edge( + nested_sdfg, + inner_name, + state.add_write(outer_name), + None, + dace.Memlet.from_array(outer_name, outer_array), + ) + + return nested_sdfg + + +def promote_literals_to_constants( + builder: translator.JaxprTranslationBuilder, + var_names: Sequence[str | None], + jax_vars: Sequence[jax_core.Atom], + name_pattern: str, +) -> list[str]: + """ + Promotes all literals in `var_names` to DaCe constants and add them to the SDFG. + + The function assumes that `var_names` are the SDFG variables equivalents of + `jax_vars`, as by convention `None` indicates a literal. The function will create + a constant for each literal and return `var_names` cleared of all literals. + For naming the variables the function will use `name_pattern`. + + Args: + builder: The builder that is used for translation. + var_names: Names of the SDFG variables, `None` indicates a literal. + jax_vars: The JAX variables, in the same order than `var_names`. + name_pattern: A pattern to generate a unique name for the variables. + + Todo: + Is a constant the right idea or should we generate a symbol? + """ + promoted_var_names: list[str] = [] + for i, var_name in enumerate(var_names): + if var_name is None: + promoted_var_name = f"__const_{name_pattern}_literal_promotion_{i}" + jax_var = jax_vars[i] + promoted_jace_var = util.JaCeVar.from_atom( + jax_var=jax_var, + name=promoted_var_name, + ) + builder.add_array(promoted_jace_var) + builder.sdfg.add_constant( + promoted_var_name, + util.get_jax_literal_value(jax_var), + builder.arrays[promoted_var_name], + ) + + else: + # Already an SDFG variable, so nothing to do. + promoted_var_name = var_name + promoted_var_names.append(promoted_var_name) + return promoted_var_names diff --git a/src/jace/translator/primitive_translator.py b/src/jace/translator/primitive_translator.py index ab84c5d..71aa067 100644 --- a/src/jace/translator/primitive_translator.py +++ b/src/jace/translator/primitive_translator.py @@ -64,6 +64,9 @@ def __call__( primitive translator was able to fully construct the dataflow graph within `eqn_state`. + After the primitive translator returns, the builder will propagate the + Memlets in all states that were newly created. + A primitive translator has to use the passed input variables, `in_var_names` and must write its output into the variables indicated by `out_var_names`. But it is allowed that a primitive translator @@ -74,7 +77,7 @@ def __call__( Args: builder: The builder object of the translation. in_var_names: List of the names of the arrays created inside the - SDFG for the inpts or `None` in case of a literal. + SDFG for the inputs or `None` in case of a literal. out_var_names: List of the names of the arrays created inside the SDFG for the outputs. eqn: The JAX primitive that should be translated. diff --git a/src/jace/translator/primitive_translators/__init__.py b/src/jace/translator/primitive_translators/__init__.py index 65f9153..f019964 100644 --- a/src/jace/translator/primitive_translators/__init__.py +++ b/src/jace/translator/primitive_translators/__init__.py @@ -8,7 +8,39 @@ from __future__ import annotations -from .alu_translator import ALUTranslator +from .arithmetic_logical_translators import ( + ArithmeticOperationTranslator, + LogicalOperationTranslator, +) +from .broadcast_in_dim_translator import BroadcastInDimTranslator +from .concatenate_translator import concatenate_translator +from .conditions import condition_translator +from .convert_element_type_translator import ConvertElementTypeTranslator +from .copy_translator import copy_translator, device_put_translator +from .gather_translator import gather_translator +from .iota_translator import IotaTranslator +from .pjit_translator import pjit_translator +from .reshape_translator import reshape_translator +from .select_n_translator import SelectNTranslator +from .slicing import SlicingTranslator, dynamic_slicing_translator +from .squeeze_translator import SqueezeTranslator -__all__ = ["ALUTranslator"] +__all__ = [ + "ArithmeticOperationTranslator", + "BroadcastInDimTranslator", + "ConvertElementTypeTranslator", + "IotaTranslator", + "LogicalOperationTranslator", + "SelectNTranslator", + "SlicingTranslator", + "SqueezeTranslator", + "concatenate_translator", + "condition_translator", + "copy_translator", + "device_put_translator", + "dynamic_slicing_translator", + "gather_translator", + "pjit_translator", + "reshape_translator", +] diff --git a/src/jace/translator/primitive_translators/alu_translator.py b/src/jace/translator/primitive_translators/alu_translator.py deleted file mode 100644 index f217924..0000000 --- a/src/jace/translator/primitive_translators/alu_translator.py +++ /dev/null @@ -1,287 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""This module contains the `ALUTranslator` which translates all arithmetic and logic primitives.""" -# ruff: noqa: W505 PLR0912 C901 PLR0914 PLR0915 D417 - -from __future__ import annotations - -from collections.abc import Sequence -from typing import Any, Final, cast - -import dace -import numpy as np -from jax import core as jax_core -from typing_extensions import override - -from jace import translator, util - - -class ALUTranslator(translator.PrimitiveTranslator): - """ - This translator handles all arithmetic and logical operations. - - This translator will be reworked soon, it just exists that the initial PR can do anything at all!! - """ - - def __init__(self, prim_name: str, prim_tmpl: str) -> None: - """Initialize the `ALUTranslator`.""" - self._prim_name = prim_name - self._prim_tmpl = prim_tmpl - - @property - @override - def primitive(self) -> str: - return self._prim_name - - @override - def __call__( - self, - builder: translator.JaxprTranslationBuilder, - in_var_names: Sequence[str | None], - out_var_names: Sequence[str], - eqn: jax_core.JaxprEqn, - eqn_state: dace.SDFGState, - ) -> None: - """ - Perform the translation. - - Deepening on the shapes of the input the function will either create a Tasklet or a mapped Tasklet. - The translator is able to handle broadcasting with NumPy rules. - The function will always perform the translation inside the provided state. - - Args: - builder: The builder object of the translation. - in_var_names: List of the names of the arrays created inside the SDFG for the inpts or 'None' in case of a literal. - out_var_names: List of the names of the arrays created inside the SDFG for the outputs. - eqn: The JAX equation that is translated. - eqn_state: State into which the primitive's SDFG representation is constructed. - """ - assert self._prim_name == eqn.primitive.name - - # Determine what kind of input we got and how we should proceed. - is_scalar = len(util.get_jax_var_shape(eqn.outvars[0])) == 0 - input_scalars = [len(util.get_jax_var_shape(Inp)) == 0 for i, Inp in enumerate(eqn.invars)] - has_scalars_as_inputs = any(input_scalars) - has_some_literals = any(x is None for x in in_var_names) - inps_same_shape = all( - util.get_jax_var_shape(eqn.invars[0]) == util.get_jax_var_shape(eqn.invars[i]) - for i in range(1, len(eqn.invars)) - ) - - # We will now look which dimensions have to be broadcasted on which operator. - # I.e. in the dimensions in the lists below there will be no map iteration index. - dims_to_bcastl: list[int] = [] - dims_to_bcastr: list[int] = [] - - # Determine if and how we have to broadcast. - if inps_same_shape or is_scalar: - pass - - elif has_some_literals or has_scalars_as_inputs: - # This is essentially an array plus a scalar, that is eitehr a literal or a variable. - assert (not has_some_literals) or all( - util.get_jax_var_shape(invar) == util.get_jax_var_shape(eqn.outvars[0]) - for (invar, x) in zip(eqn.invars, in_var_names, strict=False) - if x is not None - ) - assert (not has_scalars_as_inputs) or all( - util.get_jax_var_shape(invar) in {util.get_jax_var_shape(eqn.outvars[0]), ()} - for (invar, x) in zip(eqn.invars, in_var_names, strict=False) - if x is not None - ) - - else: - # This is the general broadcasting case - # We assume that both inputs and the output have the same rank but different sizes in each dimension. - # It seems that JAX ensures this. - # We further assume that if the size in a dimension differs then one must have size 1. - # This is the size we broadcast over, i.e. conceptually replicated. - out_shps = tuple(util.get_jax_var_shape(eqn.outvars[0])) # Shape of the output - input_shpl = tuple( - util.get_jax_var_shape(eqn.invars[0]) - ) # Shape of the left/first input - input_shpr = tuple( - util.get_jax_var_shape(eqn.invars[1]) - ) # Shape of the right/second input - - if not ((len(input_shpl) == len(input_shpr)) and (len(out_shps) == len(input_shpr))): - raise NotImplementedError("Can not broadcast over different ranks.") - - for dim, (shp_lft, shp_rgt, out_shp) in enumerate( - zip(input_shpl, input_shpr, out_shps) - ): - if shp_lft == shp_rgt: - assert out_shp == shp_lft - elif shp_lft == 1: - assert shp_rgt == out_shp - dims_to_bcastl.append(dim) - elif shp_rgt == 1: - assert shp_lft == out_shp - dims_to_bcastr.append(dim) - else: - raise ValueError(f"Invalid shapes in dimension {dim} for broadcasting.") - - # Now we create the Tasklet in which the calculation is performed. - tskl_code: str = self._write_tasklet_code(in_var_names, eqn) - tskl_name: str = eqn.primitive.name - tskl_map_ranges: list[tuple[str, str]] = [ - (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) - ] - tskl_output: tuple[str, dace.Memlet] = None # type: ignore[assignment] - tskl_inputs: list[tuple[str, dace.Memlet] | tuple[None, None]] = [] - - # Generate the Memlets for the input. - for i, dims_to_bcast in zip(range(len(in_var_names)), [dims_to_bcastl, dims_to_bcastr]): - if in_var_names[i] is None: # Literal: No input needed. - tskl_inputs.append((None, None)) - continue - if input_scalars[i]: # Scalar - assert len(dims_to_bcast) == 0 - i_memlet = dace.Memlet.simple(in_var_names[i], "0") - else: # Array: We may have to broadcast - inputs_: list[str] = [] - for dim, (map_var, _) in enumerate(tskl_map_ranges): - if dim in dims_to_bcast: - inputs_.append("0") - else: - inputs_.append(map_var) - i_memlet = dace.Memlet.simple(in_var_names[i], ", ".join(inputs_)) - del inputs_ - tskl_inputs.append((f"__in{i}", i_memlet)) - - # Now generate the Memlets for the output - if is_scalar: - tskl_output = ("__out0", dace.Memlet.simple(out_var_names[0], "0")) - else: - tskl_output = ( - "__out0", - dace.Memlet.simple(out_var_names[0], ", ".join([X[0] for X in tskl_map_ranges])), - ) - - if is_scalar: - tskl_tasklet = eqn_state.add_tasklet( - tskl_name, - _list_to_dict(tskl_inputs).keys(), - _list_to_dict([tskl_output]).keys(), - tskl_code, - ) - for in_var, (in_connector, in_memlet) in zip(in_var_names, tskl_inputs, strict=False): - if in_var is None: # So access node for literal - continue - eqn_state.add_edge( - eqn_state.add_read(in_var), None, tskl_tasklet, in_connector, in_memlet - ) - eqn_state.add_edge( - tskl_tasklet, - tskl_output[0], - eqn_state.add_write(out_var_names[0]), - None, - tskl_output[1], - ) - else: - eqn_state.add_mapped_tasklet( - name=tskl_name, - map_ranges=_list_to_dict(tskl_map_ranges), - inputs=_list_to_dict(tskl_inputs), - code=tskl_code, - outputs=_list_to_dict([tskl_output]), - external_edges=True, - ) - - return eqn_state - - def _write_tasklet_code( - self, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn - ) -> str: - """ - This function generates the Tasklet code based on a primitive. - - The function will also perform literal substitution and parameter handling. - - Args: - in_var_names: The list of SDFG variables used as input. - """ - t_code = self._prim_tmpl - - # Now we handle Literal substitution - for i, in_var_name in enumerate(in_var_names): - if in_var_name is not None: - continue - - jax_in_var: jax_core.Literal = cast(jax_core.Literal, eqn.invars[i]) - if util.get_jax_var_shape(jax_in_var) == (): - t_val = jax_in_var.val - if isinstance(t_val, np.ndarray): - t_val = jax_in_var.val.max() # I do not know a better way in that case - t_code = t_code.replace(f"__in{i}", str(t_val)) - else: - raise ValueError( - f"Can not handle the literal case of shape: {util.get_jax_var_shape(jax_in_var)}" - ) - - # Now replace the parameters - if len(eqn.params) != 0: - t_code = t_code.format(**eqn.params) - - return t_code - - -def _list_to_dict(inp: Sequence[tuple[None | Any, Any]]) -> dict[Any, Any]: - """ - This method turns a `list` of pairs into a `dict` and applies a `None` filter. - - The function will only include pairs whose key, i.e. first element is not `None`. - """ - return {k: v for k, v in inp if k is not None} - - -# Contains all the templates for ALU operations. -_ALU_OPS_TASKLET_TEMPLATES: Final[dict[str, str]] = { - # Unary operations - "pos": "__out0 = +(__in0)", - "neg": "__out0 = -(__in0)", - "not": "__out0 = not (__in0)", - "floor": "__out0 = floor(__in0)", - "ceil": "__out0 = ceil(__in0)", - "round": "__out0 = round(__in0)", - "abs": "__out0 = abs(__in0)", - "sign": "__out0 = sign(__in0)", - "sqrt": "__out0 = sqrt(__in0)", - "log": "__out0 = log(__in0)", - "exp": "__out0 = exp(__in0)", - "integer_pow": "__out0 = (__in0)**({y})", # 'y' is a parameter of the primitive - "sin": "__out0 = sin(__in0)", - "asin": "__out0 = asin(__in0)", - "cos": "__out0 = cos(__in0)", - "acos": "__out0 = acos(__in0)", - "tan": "__out0 = tan(__in0)", - "atan": "__out0 = atan(__in0)", - "tanh": "__out0 = tanh(__in0)", - # Binary operations - "add": "__out0 = (__in0)+(__in1)", - "add_any": "__out0 = (__in0)+(__in1)", # No idea what makes `add_any` differ from `add` - "sub": "__out0 = (__in0)-(__in1)", - "mul": "__out0 = (__in0)*(__in1)", - "div": "__out0 = (__in0)/(__in1)", - "rem": "__out0 = (__in0)%(__in1)", - "and": "__out0 = (__in0) and (__in1)", - "or": "__out0 = (__in0) or (__in1)", - "pow": "__out0 = (__in0)**(__in1)", - "ipow": "__out0 = (__in0)**(int(__in1))", - "min": "__out0 = min(__in0, __in1)", - "max": "__out0 = max(__in0, __in1)", - "eq": "__out0 = __in0 == __in1", - "ne": "__out0 = __in0 != __in1", - "ge": "__out0 = __in0 >= __in1", - "gt": "__out0 = __in0 > __in1", - "le": "__out0 = __in0 <= __in1", - "lt": "__out0 = __in0 < __in1", -} - -for prim_name, prim_tmpl in _ALU_OPS_TASKLET_TEMPLATES.items(): - translator.register_primitive_translator(ALUTranslator(prim_name, prim_tmpl)) diff --git a/src/jace/translator/primitive_translators/arithmetic_logical_translators.py b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py new file mode 100644 index 0000000..28f9a3a --- /dev/null +++ b/src/jace/translator/primitive_translators/arithmetic_logical_translators.py @@ -0,0 +1,204 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Primitive translators related to all arithmetic, logical and comparison operations. + +Todo: + - Hijack Jax to inject a proper modulo operation. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class ArithmeticOperationTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Translator for all arithmetic operations and comparisons. + + Args: + prim_name: The name of the primitive that should be handled. + tskl_tmpl: Template used for generating the Tasklet code. + + Note: + Logical and bitwise operations are implemented by `LogicalOperationTranslator`. + """ + + def __init__(self, prim_name: str, tskl_tmpl: str) -> None: + super().__init__(primitive_name=prim_name) + self._tskl_tmpl = tskl_tmpl + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + """Returns the code for the Tasklet, with all parameters replaced.""" + return self._tskl_tmpl.format(**eqn.params) + + +class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Translator for all logical operations. + + The reason why the logical operations are separated from the arithmetic + operations is quite complicated and in fact the whole thing is harder than + it should be. NumPy has two kinds of these operations, i.e. + `logical_{and, or, xor, not}()` and `bitwise_{and, or, xor, not}()`, but Jax + has only a single kind of logical operation, that operate in bitwise mode. + The first idea would be to use `ArithmeticOperationTranslator` with a template + such as `__out = __in0 & __in1` or `__out = ~__in0`. Since DaCe eventually + generates C++ code and C++ has a native bool type, and `true` is guaranteed + to be `1` and `false` equals `0`, this works for all operations except `not`, + as `~true` in C++ is essentially `~1`, which is again `true`! + Thus the `not` primitive must be handled separately. + + The solution to the problem is to introduce two templates, one used for the + bool context and one used in the integer context. This works because depending + if the `logical_*()` or `bitwise_*()` functions are used the input is either + of type bool or an integer. + + Args: + prim_name: The name of the primitive that should be handled. + bitwise_tmpl: The template used for the bitwise case. + logical_tmpl: The template used for the logical case. + + Note: + Since it does not make sense to single out `not` and keep the other + logical operations in `ArithmeticOperationTranslator` all of them are + handled by this class. + """ + + def __init__(self, prim_name: str, bitwise_tmpl: str, logical_tmpl: str) -> None: + super().__init__(primitive_name=prim_name) + self._bitwise_tmpl = bitwise_tmpl + self._logical_tmpl = logical_tmpl + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars): + return self._logical_tmpl + return self._bitwise_tmpl + + +# Maps the name of an arithmetic JAX primitive to the code template that is used to +# generate the body of the mapped tasklet. These are used to instantiate the +# `ArithmeticOperationTranslator` objects. +# fmt: off +_ARITMETIC_OPERATION_TEMPLATES: Final[dict[str, str]] = { + # Unary operations + "pos": "__out = +(__in0)", + "neg": "__out = -(__in0)", + + "floor": "__out = floor(__in0)", + "ceil": "__out = ceil(__in0)", + "round": "__out = round(__in0)", + + "abs": "__out = abs(__in0)", + "sign": "__out = sign(__in0)", + "exp": "__out = exp(__in0)", + "exp2": "__out = exp2(__in0)", + "expm1": "__out = expm1(__in0)", + "log": "__out = log(__in0)", + "log1p": "__out = log1p(__in0)", + "conj": "__out = conj(__in0)", + "sqrt": "__out = sqrt(__in0)", + "cbrt": "__out = cbrt(__in0)", + + "integer_pow": "__out = (__in0)**({y})", # 'y' is a parameter of the primitive + "is_finite": "__out = isfinite(__in0)", + + "sin": "__out = sin(__in0)", + "asin": "__out = asin(__in0)", + "cos": "__out = cos(__in0)", + "acos": "__out = acos(__in0)", + "tan": "__out = tan(__in0)", + "atan": "__out = atan(__in0)", + + "sinh": "__out = sinh(__in0)", + "asinh": "__out = asinh(__in0)", + "cosh": "__out = cosh(__in0)", + "acosh": "__out = acosh(__in0)", + "tanh": "__out = tanh(__in0)", + "atanh": "__out = atanh(__in0)", + + # Binary operations + "add": "__out = (__in0)+(__in1)", + "add_any": "__out = (__in0)+(__in1)", # No idea what makes `add_any` differ from `add` + "sub": "__out = (__in0)-(__in1)", + "mul": "__out = (__in0)*(__in1)", + "div": "__out = (__in0)/(__in1)", + "rem": "__out = (__in0)%(__in1)", + "pow": "__out = (__in0)**(__in1)", + "min": "__out = min((__in0), (__in1))", + "max": "__out = max((__in0), (__in1))", + + "eq": "__out = (__in0) == (__in1)", + "ne": "__out = (__in0) != (__in1)", + "ge": "__out = (__in0) >= (__in1)", + "gt": "__out = (__in0) > (__in1)", + "le": "__out = (__in0) <= (__in1)", + "lt": "__out = (__in0) < (__in1)", + + "atan2": "__out = atan2((__in0), (__in1))", + + "nextafter": "__out = nextafter((__in0), (__in1))", + + # Ternary operations + "clamp": "__out = ((__in0) if (__in1) < (__in0) else ((__in1) if (__in1) < (__in2) else (__in2)))" +} + + +# Maps the name of a logical primitive to the two code templates, first the integer +# case and second the boolean case, that are used to create the body of the mapped +# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators. +_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, dict[str, str]]] = { + "or": { + "bitwise_tmpl": "__out = (__in0) | (__in1)", + "logical_tmpl": "__out = (__in0) or (__in1)", + }, + "not": { + "bitwise_tmpl": "__out = ~(__in0)", + "logical_tmpl": "__out = not (__in0)", + }, + "and": { + "bitwise_tmpl": "__out = (__in0) & (__in1)", + "logical_tmpl": "__out = (__in0) and (__in1)", + }, + "xor": { + "bitwise_tmpl": "__out = (__in0) ^ (__in1)", + "logical_tmpl": "__out = (__in0) != (__in1)", + }, +} +# fmt: on + + +# Instantiate the arithmetic and logical translators from the templates. +for pname, ptmpl in _ARITMETIC_OPERATION_TEMPLATES.items(): + translator.register_primitive_translator(ArithmeticOperationTranslator(pname, ptmpl)) +for pname, ptmpl in _LOGICAL_OPERATION_TEMPLATES.items(): # type: ignore[assignment] # Type confusion + translator.register_primitive_translator(LogicalOperationTranslator(pname, **ptmpl)) # type: ignore[arg-type] # Type confusion diff --git a/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py new file mode 100644 index 0000000..d8bd388 --- /dev/null +++ b/src/jace/translator/primitive_translators/broadcast_in_dim_translator.py @@ -0,0 +1,64 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for broadcasting operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class BroadcastInDimTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `broadcast_in_dim` primitive. + + Essentially creates a copy tasklet, however, the memlets are made in such a + way that some dimensions are replicated. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="broadcast_in_dim") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return "__out = __in0" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + if in_var_names[0] is None: # Broadcast a literal (scalar) to a matrix. + return {} + subset_str = ( + ", ".join(tskl_ranges[bdim][0] for bdim in eqn.params["broadcast_dimensions"]) + if eqn.params["broadcast_dimensions"] + else "0" + ) + return {"__in0": dace.Memlet.simple(in_var_names[0], subset_str)} + + +translator.register_primitive_translator(BroadcastInDimTranslator()) diff --git a/src/jace/translator/primitive_translators/concatenate_translator.py b/src/jace/translator/primitive_translators/concatenate_translator.py new file mode 100644 index 0000000..b327bde --- /dev/null +++ b/src/jace/translator/primitive_translators/concatenate_translator.py @@ -0,0 +1,75 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for concatenation operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("concatenate") +def concatenate_translator( + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `concatenate` primitive. + + Each source array is copied by its own map, but all maps write to the same + access node. + + Args: + builder: The builder object of the translation; unused. + in_var_names: The SDFG variables used an input arguments in order as they + should be concatenated. + out_var_names: Names of SDFG variables that should be used as outputs. + eqn: The equation that should be translated, the concatenation dimensions + is read from the `dimension` parameter. + eqn_state: State into which the nested SDFG should be constructed. + """ + if any(in_var_name is None for in_var_name in in_var_names): + raise NotImplementedError("Concatenate: No literal inputs supported.") + + # Access node that is used by all maps. + output_nodes = {out_var_names[0]: eqn_state.add_write(out_var_names[0])} + + cat_dim = eqn.params["dimension"] + cat_offset = 0 + for i, in_var_name in enumerate(in_var_names): + input_shape = util.get_jax_var_shape(eqn.invars[i]) + + tskl_range = [(f"__dim{d}", f"0:{dim_size}") for d, dim_size in enumerate(input_shape)] + tskl_input_access = [it_var for it_var, _ in tskl_range] + + tskl_output_access = tskl_input_access.copy() + tskl_output_access[cat_dim] = f"{tskl_output_access[cat_dim]} + {cat_offset}" + + eqn_state.add_mapped_tasklet( + f"_concatenate_{out_var_names[0]}_{in_var_name}", + map_ranges=tskl_range, + inputs={"__in": dace.Memlet.simple(in_var_name, ", ".join(tskl_input_access))}, + code="__out = __in", + outputs={"__out": dace.Memlet.simple(out_var_names[0], ",".join(tskl_output_access))}, + output_nodes=output_nodes, + external_edges=True, + ) + cat_offset += input_shape[cat_dim] diff --git a/src/jace/translator/primitive_translators/conditions.py b/src/jace/translator/primitive_translators/conditions.py new file mode 100644 index 0000000..e13920b --- /dev/null +++ b/src/jace/translator/primitive_translators/conditions.py @@ -0,0 +1,127 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for condition operations, i.e. scalar `if` and `switch`.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import dace + +from jace import translator, util +from jace.translator import post_translation as ptranslation + + +if TYPE_CHECKING: + from jax._src import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("cond") +def condition_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> dace.SDFGState: + """ + Implements the translation of scalar conditional branches. + + This translator handles both `jax.lax.cond()` and `jax.lax.switch()` cases. + The sub expression of the branches are each translated into a separate nested + SDFG, each located in their own state. These state are then connected to the + joint state which is returned. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variables used an input arguments. First is the + selection variable. The remaining ones are passed to the branches as + inputs. + out_var_names: Names of SDFG variables that should be used as outputs. + eqn: The equation that should be translated. + eqn_state: State into which the nested SDFG should be constructed. + + Notes: + - According to the JAX documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html) + the selector is clamped. But according to XLA (https://openxla.org/xla/operation_semantics#conditional) + an out of range selector uses the last branch. JaCe conforms to JAX + semantic. + - After this function the terminal state of the `builder` is unspecific. + """ + if util.get_jax_var_dtype(eqn.invars[0]) is dace.bool_: + # XLA explicitly provides a binary form of the primitive + # (https://openxla.org/xla/operation_semantics#conditional) JAX however, + # does not seem to use it at the moment and instead forwards it to the + # integer implementation. + raise NotImplementedError("The boolean conditional primitive is not implemented.") + + # Used as prefix to give all additional states/variables a unique name. + name_pattern = eqn_state.name + + # To avoid special cases promote all symbols to constants. + branch_input_variable_names: list[str] = ptranslation.promote_literals_to_constants( + builder=builder, + var_names=in_var_names[1:], + jax_vars=eqn.invars[1:], + name_pattern=name_pattern, + ) + + # expressions of the branches. + branches: list[jax_core.ClosedJaxpr] = eqn.params["branches"] + + # Make sure that the selection variable is a DaCe symbol. + if in_var_names[0] is None: + literal_selection_value = str(util.get_jax_literal_value(eqn.invars[0])) + selection_symbol = f"min({len(branches)}, max(0, {literal_selection_value}))" + selection_state = eqn_state + else: + selection_variable_name = in_var_names[0] + selection_symbol = f"{selection_variable_name}_symb" + selection_state = builder.append_new_state( + label=f"{name_pattern}_fork", + assignments={ + selection_symbol: f"min({len(branches)}, max(0, {selection_variable_name}))" + }, + prev_state=eqn_state, + ) + + # Translate the subbranches, the branches are all connected from `selection_state`. + branch_states: list[dace.SDFGState] = [] + for i, branch_jaxpr in enumerate(branches): + branch_pattern = f"{name_pattern}_{{}}_branch_{i}" + branch_ctx = builder.translate_jaxpr(jaxpr=branch_jaxpr, name=branch_pattern.format("sdfg")) + + # The first time it is called it will update the builder's terminal state + # but since we will return the join state it will be updated later. But + # until then the terminal state of the builder is invalid. + branch_state = builder.append_new_state( + label=branch_pattern.format("state"), + condition=f"{selection_symbol} == {i}", + prev_state=selection_state, + ) + ptranslation.add_nested_sdfg( + state=branch_state, + child_ctx=branch_ctx, + parent_ctx=builder._ctx, + in_var_names=branch_input_variable_names, + out_var_names=out_var_names, + ) + branch_states.append(branch_state) + + # Connect all branch states to the join state + join_state = builder._ctx.sdfg.add_state(label=f"{name_pattern}__join_state") + for branch_state in branch_states: + builder.sdfg.add_edge( + branch_state, + join_state, + dace.sdfg.InterstateEdge(), + ) + + return join_state diff --git a/src/jace/translator/primitive_translators/convert_element_type_translator.py b/src/jace/translator/primitive_translators/convert_element_type_translator.py new file mode 100644 index 0000000..e1fb8e5 --- /dev/null +++ b/src/jace/translator/primitive_translators/convert_element_type_translator.py @@ -0,0 +1,82 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for type casting operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class ConvertElementTypeTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `convert_element_type` primitive. + + The primitive is implemented as a copy operation. However, the tasklet body + will perform the type conversion operation. + + Note: + The type to cast to is inferred from the output variable and the `new_dtype` + parameter of the equation is ignored. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="convert_element_type") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + if in_var_names[0] is None: + raise NotImplementedError("'convert_element_type' is not supported for literals.") + + in_dtype = util.get_jax_var_dtype(eqn.invars[0]).type + in_dtype_s: str = in_dtype.__name__ + out_dtype = util.get_jax_var_dtype(eqn.outvars[0]).type + out_dtype_s: str = out_dtype.__name__ + + if in_dtype == out_dtype: + # JAX sometimes adds conversions which are not needed. In these cases + # make a copy out of it. + # TODO(phimuell): Create a Memlet instead. + return "__out = __in0" + + # A simple copy tasklet `__out = __in0` and rely on the implicit type + # conversion of the C++ compiler, is not enough. Due to a bug in DaCe + # (see https://github.com/spcl/dace/issues/1665) this conversion might be + # lost, thus we have to perform the conversion explicitly in the tasklet. + conv_code = "__in0" + + if in_dtype_s.startswith("bool"): + conv_code = f"(1 if {conv_code} else 0)" + if out_dtype_s.startswith("bool"): + conv_code = f"dace.bool_({conv_code})" + elif hasattr(dace.dtypes, out_dtype_s): + conv_code = f"dace.{out_dtype_s}({conv_code})" + else: + raise NotImplementedError( + f"Cannot convert '{in_dtype}' to '{out_dtype}' as this type is not known to DaCe." + ) + return f"__out = {conv_code}" + + +_ = translator.register_primitive_translator(ConvertElementTypeTranslator()) diff --git a/src/jace/translator/primitive_translators/copy_translator.py b/src/jace/translator/primitive_translators/copy_translator.py new file mode 100644 index 0000000..9e0d2d1 --- /dev/null +++ b/src/jace/translator/primitive_translators/copy_translator.py @@ -0,0 +1,94 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translators related to data movement operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace + +from jace import translator + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("copy") +def copy_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, # noqa: ARG001 [unused-function-argument] # Required by the interface. + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `copy` primitive. + + The copy is implemented by creating a memlet between the source and destination. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variable that acts as source. + out_var_names: The SDFG variable that acts as destination of the copy. + eqn: The equation that should be translated; unused. + eqn_state: State into which the nested SDFG should be constructed. + + Todo: + Investigate if operation should expand to a map. + """ + assert in_var_names[0] is not None + eqn_state.add_nedge( + eqn_state.add_read(in_var_names[0]), + eqn_state.add_write(out_var_names[0]), + dace.Memlet.from_array( + in_var_names[0], + builder.arrays[in_var_names[0]], + ), + ) + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("device_put") +def device_put_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `device_put` primitive. + + In JAX this primitive is used to copy data between the host and the device, + in DaCe only memlets can do this. However, because of the way JaCe (currently) + operates (a computation is either fully on the host or on GPU), the `device_put` + primitive essentially decays to a copy. + + Args: + builder: The builder object of the translation. + in_var_names: The SDFG variable that acts as source. + out_var_names: The SDFG variable that acts as destination of the copy. + eqn: The equation that should be translated. + eqn_state: State into which the nested SDFG should be constructed. + """ + if not (eqn.params["device"] is None and eqn.params["src"] is None): + raise NotImplementedError( + f"Can only copy on the host, but not from {eqn.params['src']} to {eqn.params['device']}." + ) + copy_translator( + builder=builder, + in_var_names=in_var_names, + out_var_names=out_var_names, + eqn=eqn, + eqn_state=eqn_state, + ) diff --git a/src/jace/translator/primitive_translators/gather_translator.py b/src/jace/translator/primitive_translators/gather_translator.py new file mode 100644 index 0000000..51f5730 --- /dev/null +++ b/src/jace/translator/primitive_translators/gather_translator.py @@ -0,0 +1,195 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for indexing operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from jax import lax as jax_lax + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("gather") +def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any further. + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `gather` primitive. + + These primitive is used to implement the `array.at[...].get()` access. In the + end the primitive extracts patches/windows of a certain size, known as + `slice_size`, from an array, which is called source or input array. The start + points of these windows are given by another array, the so called index array. + + Args: + builder: The builder object that is active. + in_var_names: The names of the input variables, the first array is + assumed as source array and the second is the index array. + out_var_names: The names of the output variables. + eqn: The equation to translate. + eqn_state: The state in which we put the extraction. + + See Also: + https://www.tensorflow.org/xla/operation_semantics#gather + https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html + """ + out_name = out_var_names[0] + out_shape = util.get_jax_var_shape(eqn.outvars[0]) + src_name = in_var_names[0] + src_shape = util.get_jax_var_shape(eqn.invars[0]) + idx_name = in_var_names[1] + idx_shape = util.get_jax_var_shape(eqn.invars[1]) + dimension_numbers = eqn.params["dimension_numbers"] + + if eqn.params["mode"] != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS: + raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.") + + # This is the size of the slice window that is copied. Its length is the rank + # of the source array, dimensions that are excluded from copying are listed + # in `collapsed_slice_dims`. + slice_sizes: Sequence[int] = eqn.params["slice_sizes"] + collapsed_slice_dims: Sequence[int] = dimension_numbers.collapsed_slice_dims + not_collapsed_slice_dims = tuple( + dim for dim in range(len(slice_sizes)) if dim not in collapsed_slice_dims + ) + assert len(slice_sizes) == len(src_shape) + + # The batch dimensions are used to iterate through the different slice windows + # (not inside them) thus they access the index array, with the exception of the + # last dimension, see below. + # NOTE: In pure XLA this last dimension is in certain cases optional, however, + # JAX adds it and our implementation relies on it. + batch_dims = tuple(d for d in range(len(out_shape)) if d not in dimension_numbers.offset_dims) + if (len(batch_dims) + 1) != len(idx_shape): + raise ValueError( + f"Expected that the index array has {len(batch_dims) + 1} dimensions, but it had {len(idx_shape)}." + ) + + # The last dimension of the index array is special, as it contains the actual + # start point for the slice windows when the dimension is only partially copied. + # Thus the last dimension must be seen as a list of start indexes and the other + # dimensions are used to enumerate the slice windows. The `start_index_map` + # associates each position in the last dimension with the corresponding + # dimension of the source array. + start_index_map: Sequence[int] = dimension_numbers.start_index_map + assert len(start_index_map) == idx_shape[-1] + + # The iteration variable of the final map can be divided into two parts or + # categories. The first part iterates through all the slice windows that are + # given through the index array. If a dimension is not fully copied then the + # start index of the window is given through the elements of the last dimensions + # of the index array. Map variables that are used for this use the pattern + # `__i{out_name}_gather{bd}`. The second kind of variables are used to copy the + # content of the slice windows themselves, these map variables follow the + # pattern `__i{i}`. + + # Because the offsets of the slice window (which are given by the elements of + # the last dimension of the index array) are variables and not symbols, they + # can not be included in the memlets. Instead we generate a tasklet that + # performs an indirect access and get all elements of the last dimension of the + # index array (with the names `__gather_{dim}`), together with the full source + # array as input. + + # Access pattern of the source array _inside_ the tasklet. + src_access_pattern: list[str] = [] + + # The map variables and their ranges of the second part implicit loop; the one + # that copy the content inside the window. + inside_window_map_ranges: list[tuple[str, str]] = [] + + for dim, slice_size in enumerate(slice_sizes): + # Order is important! + if dim not in start_index_map: + # This dimension is fully copied + inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) + src_access_pattern.append(inside_window_map_ranges[-1][0]) + assert dim in not_collapsed_slice_dims + + elif dim in collapsed_slice_dims: + # This dimension is only partially copied, but because it is collapsed, + # only a single element is copied. Thus the offset is only given by the + # what we read from the index array. + src_access_pattern.append(f"__gather_{dim}") + + else: + # This dimension is partially copied, but _not colapsed_. This the element + # that is read depends on the (static) offset of this window and the + # current position within the slicing window. + inside_window_map_ranges.append((f"__i{dim}", f"0:{slice_size}")) + src_access_pattern.append(f"__gather_{dim} + {inside_window_map_ranges[-1][0]}") + assert dim in not_collapsed_slice_dims + + # These are the map variables that are associated to the first implicit loop (the + # iteration over the index array, excluding the last dimension). + batch_map_ranges = [ + (f"__i{out_name}_gather{batch_dim}", f"0:{batch_loop_bound}") + for batch_dim, batch_loop_bound in zip(batch_dims, idx_shape[:-1]) + ] + assert len(batch_map_ranges) + len(inside_window_map_ranges) == len(out_shape) + + tasklet_inputs: dict[str, dace.Memlet] = {} + + # We need to pass the full array into the tasklet, however, we know that we + # will read only one element. + tasklet_inputs["__arr"] = dace.Memlet.simple( + data=src_name, + subset_str=", ".join(f"0:{size}" for size in src_shape), + num_accesses=1, + ) + + # The static offsets of the slice window, are given through the elements of the + # last dimensions of the index array. + for i, dim in enumerate(start_index_map): + tasklet_inputs[f"__gather_{dim}"] = dace.Memlet.simple( + data=idx_name, + subset_str=( + ", ".join(batch_loop_var for batch_loop_var, _ in batch_map_ranges) + f", {i}" + ), + ) + + # The output shape is given by the combination of the not collapsed slice sizes + # and the index array (without the last dimension) with some permutation. + # While the relative order of slice window does not change, `start_index_map` + # already applied a permutation, it might be interleaved with batch dimensions. + output_memlet_pattern: list[str] = [] + dim_counter = 0 + for dim in range(len(out_shape)): + if dim in batch_dims: + batch_loop_var = batch_map_ranges[batch_dims.index(dim)][0] + output_memlet_pattern.append(str(batch_loop_var)) + + else: + associated_map_idx = not_collapsed_slice_dims[dim_counter] + dim_counter += 1 + output_memlet_pattern.append(f"__i{associated_map_idx}") + assert dim_counter == len(not_collapsed_slice_dims) + + eqn_state.add_mapped_tasklet( + name=f"_gather_map_{out_name}", + map_ranges=batch_map_ranges + inside_window_map_ranges, + inputs=tasklet_inputs, + code="__out = __arr[" + ", ".join(src_access_pattern) + "]", + outputs={ + "__out": dace.Memlet.simple(data=out_name, subset_str=", ".join(output_memlet_pattern)) + }, + external_edges=True, + ) diff --git a/src/jace/translator/primitive_translators/iota_translator.py b/src/jace/translator/primitive_translators/iota_translator.py new file mode 100644 index 0000000..035caf7 --- /dev/null +++ b/src/jace/translator/primitive_translators/iota_translator.py @@ -0,0 +1,56 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for the `iota` primitive.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from typing_extensions import override + +from jace import translator +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + import dace + from jax import core as jax_core + + +class IotaTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `iota` primitive. + + Essentially, a very general `jnp.arange()` function. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="iota") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return f"__out = {tskl_ranges[eqn.params['dimension']][0]}" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + return {} + + +translator.register_primitive_translator(IotaTranslator()) diff --git a/src/jace/translator/primitive_translators/pjit_translator.py b/src/jace/translator/primitive_translators/pjit_translator.py new file mode 100644 index 0000000..95cb3d4 --- /dev/null +++ b/src/jace/translator/primitive_translators/pjit_translator.py @@ -0,0 +1,94 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator related handling nested Jaxpr operations.""" + +from __future__ import annotations + +import re +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from jax._src import sharding_impls as jax_sharding # noqa: PLC2701 [import-private-name] + +from jace import translator +from jace.translator import post_translation as ptranslation + + +if TYPE_CHECKING: + import dace + from jax._src import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("pjit") +def pjit_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `pjit` translator that handles nested Jaxpr. + + `pjit` primitives in JAX represents nested calls, for example the branches of a + conditional are nested Jaxpr. However, in JAX `pjit` is also used to indicate that + a computation should be done on the device or on sharded memory. + In case an input is a literal the translator will create a constant for it. + + Args: + builder: The builder object of the translation. + in_var_names: Names of the SDFG variables that should be used as inputs + inside the parent SDFG. + out_var_names: Names of SDFG variables that should be used as outputs + inside the parent SDFG. + eqn: The equation that contains the `pjit` primitive. + eqn_state: State into which the nested SDFG should be constructed. + + Note: + The translator ignores the `donated_invars`, the `keep_unused` and the + `inline` parameter and let's DaCe handle it. + """ + nested_jaxpr: jax_core.ClosedJaxpr = eqn.params["jaxpr"] + in_shardings = eqn.params["in_shardings"] + out_shardings = eqn.params["out_shardings"] + # "donated_invars", "keep_unused", "inline" parameters are just ignored + + if not all(in_sharding is jax_sharding.UNSPECIFIED for in_sharding in in_shardings): + raise NotImplementedError("Currently 'pjit' does not support sharding in its input.") + if not all(out_sharding is jax_sharding.UNSPECIFIED for out_sharding in out_shardings): + raise NotImplementedError("Currently 'pjit' does not support sharding in its output.") + + # TODO(phimuell): Controlflow region and name + pjit_name = eqn.params["name"] + + # Name in SDFG must be unique, thus we mangle it, furthermore, we have to clean it. + sdfg_name = f"pjit_{re.subn('[^a-zA-Z0-9_]', '_', pjit_name)[0]}__{'_'.join(out_var_names)}" + + # Ensure that all inputs are SDFG variables + final_input_names = ptranslation.promote_literals_to_constants( + builder=builder, + var_names=in_var_names, + jax_vars=eqn.invars, + name_pattern=sdfg_name, + ) + + # Translate the nested expression + nested_context: translator.TranslationContext = builder.translate_jaxpr( + jaxpr=nested_jaxpr, + name=sdfg_name, + ) + + # Now lets add the nested SDFG + ptranslation.add_nested_sdfg( + state=eqn_state, + child_ctx=nested_context, + parent_ctx=builder._ctx, + in_var_names=final_input_names, + out_var_names=out_var_names, + ) diff --git a/src/jace/translator/primitive_translators/reshape_translator.py b/src/jace/translator/primitive_translators/reshape_translator.py new file mode 100644 index 0000000..79b9bb0 --- /dev/null +++ b/src/jace/translator/primitive_translators/reshape_translator.py @@ -0,0 +1,63 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for reshaping operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace + +from jace import translator, util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("reshape") +def reshape_translator( + builder: translator.JaxprTranslationBuilder, # noqa: ARG001 [unused-function-argument] # Required by the interface. + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `reshape` primitive. + + The function creates a memlet between the input (old shape) and output (final + shape). Because of this, it is best if both arrays do not have any paddings. + + Args: + builder: The builder object of the translation. + in_var_names: Name of the SDFG variable of the source array, + with the old shape. + out_var_names: Name of SDFG variable that acts as destination, + with the new shape. + eqn: The equation that contains the `pjit` primitive. + eqn_state: State into which the nested SDFG should be constructed. + + Note: + The optional `dimensions` parameters, which allows to permute the input, + is not supported. + """ + if eqn.params["dimensions"] is not None: + raise NotImplementedError("Currently 'dimensions' must be 'None'.") + eqn_state.add_nedge( + eqn_state.add_read(in_var_names[0]), + eqn_state.add_write(out_var_names[0]), + dace.Memlet( + data=in_var_names[0], + subset=", ".join(f"0:{size}" for size in util.get_jax_var_shape(eqn.invars[0])), + other_subset=", ".join(f"0:{size}" for size in eqn.params["new_sizes"]), + ), + ) diff --git a/src/jace/translator/primitive_translators/select_n_translator.py b/src/jace/translator/primitive_translators/select_n_translator.py new file mode 100644 index 0000000..aa96922 --- /dev/null +++ b/src/jace/translator/primitive_translators/select_n_translator.py @@ -0,0 +1,88 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for select operations, i.e. generalized `np.where()`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class SelectNTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `select_n` primitive. + + The `select_n` primitive is a generalization of `np.where`, that can take an + arbitrary number of cases, which are selected by an integer predicate. + The behaviour is undefined if the predicate is out of bound. + + Note: + For a better understanding this function renames its input connectors. + The first one, which is the predicate, is renamed to `__cond` and the + others are renamed again to `__in{i}`, starting with zero. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="select_n") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + if len(in_var_names) == 3: # noqa: PLR2004 [magic-value-comparison] # Ternary conditional expression. + # The order is correct, since `False` is interpreted as `0`, + # which means "the first case". + return "__out = __in1 if __cond else __in0" + + return "\n".join( + ["if __cond == 0: __out = __in0"] + + [f"elif __cond == {i}: __out = __in{i}" for i in range(1, len(in_var_names) - 1)] + ) + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + return { + f"__in{i - 1}" if i else "__cond": dace.Memlet.simple( + in_var_name, ", ".join(f"{it_idx}" for it_idx, _ in tskl_ranges) + ) + for i, in_var_name in enumerate(in_var_names) + if in_var_name + } + + @override + def literal_substitution( + self, tskl_code: str, in_var_names: Sequence[str | None], eqn: jax_core.JaxprEqn + ) -> str: + assert in_var_names[0] # Condition can never be a literal. + for i, in_var_name in enumerate(in_var_names[1:]): + if in_var_name is None: + t_val = util.get_jax_literal_value(eqn.invars[i + 1]) + tskl_code = tskl_code.replace(f"__in{i}", str(t_val)) + return tskl_code + + +translator.register_primitive_translator(SelectNTranslator()) diff --git a/src/jace/translator/primitive_translators/slicing.py b/src/jace/translator/primitive_translators/slicing.py new file mode 100644 index 0000000..6d9ae26 --- /dev/null +++ b/src/jace/translator/primitive_translators/slicing.py @@ -0,0 +1,187 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translators for slicing operations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class SlicingTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `slice` primitive. + + The `slice` primitive represents the static case of slicing, i.e. a fixed + window starting from a fixed starting point. + The slicing is implemented by performing a partial copy. + + Note: + Slices are essentially optimization barriers as they can not be fused + with Maps _before_ them. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="slice") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return "__out = __in0" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + strides: Sequence[int] = ( + eqn.params["strides"] if eqn.params["strides"] else ((1,) * len(tskl_ranges)) + ) + start_indices: Sequence[int] = eqn.params["start_indices"] # Fist index to slice + return { + "__in0": dace.Memlet.simple( + in_var_names[0], + ", ".join( + f"{start_index} + ({it_idx} * {stride})" + for (it_idx, _), start_index, stride in zip(tskl_ranges, start_indices, strides) + ), + ) + } + + +@translator.register_primitive_translator() +@translator.make_primitive_translator("dynamic_slice") +def dynamic_slicing_translator( + builder: translator.JaxprTranslationBuilder, + in_var_names: Sequence[str | None], + out_var_names: Sequence[str], + eqn: jax_core.JaxprEqn, + eqn_state: dace.SDFGState, +) -> None: + """ + Implements the `dynamic_slice` primitive. + + Dynamic slicing (see: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dynamic_slice.html) + performs a slicing of a _fixed_ window, but the start of the window is defined + through some input variables. Furthermore, if the window would overrun the + start indexes are adjusted. + + Todo: + - Prevent that the modified start indexes are promoted to symbols, + to ensure mergability. + """ + assert in_var_names[0] + assert len(in_var_names) == len(util.get_jax_var_shape(eqn.invars[0])) + 1 + + window_sizes: Sequence[int] = eqn.params["slice_sizes"] + + # Maps the variable name, that stores the _adjusted_ start index of the window + # of a dimension to the access node that holds the value. Needed to ensure the + # correct order of computation. + in_access: dict[str, dace.nodes.AccessNode] = {} + + # Name of the variables (DaCe arrays) from where we get the start index of the + # window or the value itself if it is a literal (`None` means not yet processed). + # The first input argument is always the array we slice from. + window_start_indices: list[str | None] = list(in_var_names[1:]) + + for dim, (window_start_index, dim_size, window_size) in enumerate( + zip(window_start_indices, util.get_jax_var_shape(eqn.invars[0]), window_sizes) + ): + if window_start_index is None: + # The start is a literal value. + # Jax does not adjust the literals on its own so we have to do it. + raw_window_start = int(util.get_jax_literal_value(eqn.invars[dim + 1])) # type: ignore[arg-type] # type confusion + adjusted_window_start = min(dim_size, raw_window_start + window_size) - window_size + window_start_indices[dim] = str(adjusted_window_start) + + else: + tasklet = dace.nodes.Tasklet( + label=f"adjustment_of_slice_start_{window_start_index}_for_{out_var_names[0]}", + inputs={"unadjusted_start_idx": None}, + outputs={"adjusted_start_idx": None}, + code=f"adjusted_start_idx = min(unadjusted_start_idx + {window_size}, {dim_size}) - {window_size}", + ) + # Name of the variable holding the (adjusted) start of the window. + # It is important that this name is also used for the dynamic map range + # symbols created below. This prevents some errors if DaCe promotes them + # to symbols and does not handle the DMR correctly. + # (see https://github.com/spcl/dace/issues/1665) + new_start_idx_var_name = builder.add_array( + eqn.invars[dim + 1], name_prefix="__jace_adapted_start_idx_" + ) + new_start_idx_acc = eqn_state.add_access(new_start_idx_var_name) + + eqn_state.add_edge( + eqn_state.add_read(window_start_index), + None, + tasklet, + "unadjusted_start_idx", + dace.Memlet.simple(window_start_index, "0"), + ) + eqn_state.add_edge( + tasklet, + "adjusted_start_idx", + new_start_idx_acc, + None, + dace.Memlet.simple(new_start_idx_var_name, "0"), + ) + window_start_indices[dim] = new_start_idx_var_name + in_access[new_start_idx_var_name] = new_start_idx_acc + + tskl_ranges: list[tuple[str, str]] = [ + (f"__i{dim}", f"0:{N}") for dim, N in enumerate(util.get_jax_var_shape(eqn.outvars[0])) + ] + tskl_input = dace.Memlet.simple( + in_var_names[0], + ", ".join( + f"{it_var} + {offset_symbol_name}" + for (it_var, _), offset_symbol_name in zip(tskl_ranges, window_start_indices) + ), + ) + tskl_output = dace.Memlet.simple(out_var_names[0], ", ".join(name for name, _ in tskl_ranges)) + _, map_entry, _ = eqn_state.add_mapped_tasklet( + name=f"dynamic_slice_{out_var_names[0]}", + map_ranges=tskl_ranges, + inputs={"__in": tskl_input}, + code="__out = __in", + outputs={"__out": tskl_output}, + external_edges=True, + ) + + # Create the dynamic ranges, i.e. read the start indexes for the window + # from variable and create symbols out of it, without an interstate edge. + for window_start_index_name, windows_start_access_node in in_access.items(): + eqn_state.add_edge( + windows_start_access_node, + None, + map_entry, + window_start_index_name, + dace.Memlet.simple(window_start_index_name, "0"), + ) + map_entry.add_in_connector(window_start_index_name) + + +translator.register_primitive_translator(SlicingTranslator()) diff --git a/src/jace/translator/primitive_translators/squeeze_translator.py b/src/jace/translator/primitive_translators/squeeze_translator.py new file mode 100644 index 0000000..dbaa548 --- /dev/null +++ b/src/jace/translator/primitive_translators/squeeze_translator.py @@ -0,0 +1,69 @@ +# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) +# +# Copyright (c) 2024, ETH Zurich +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Primitive translator for squeezing (the removal of size 1 dimensions) operations.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING + +import dace +from typing_extensions import override + +from jace import translator, util +from jace.translator import mapped_operation_base_translator as mapped_base + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from jax import core as jax_core + + +class SqueezeTranslator(mapped_base.MappedOperationTranslatorBase): + """ + Implements the `squeeze` primitive. + + The primitives allows to remove dimensions of size one. Essentially + equivalent to `np.squeeze` and the inverse to `np.expand_dims()`. + """ + + def __init__(self) -> None: + super().__init__(primitive_name="squeeze") + + @override + def write_tasklet_code( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> str: + return "__out = __in0" + + @override + def make_input_memlets( + self, + tskl_ranges: Sequence[tuple[str, str]], + in_var_names: Sequence[str | None], + eqn: jax_core.JaxprEqn, + ) -> dict[str, dace.Memlet]: + dims_to_delete: Sequence[str] = eqn.params["dimensions"] + in_rank: int = len(util.get_jax_var_shape(eqn.invars[0])) + cnt = itertools.count(0) + return { + "__in0": dace.Memlet.simple( + in_var_names[0], + ", ".join( + "0" if dim in dims_to_delete else tskl_ranges[next(cnt)][0] + for dim in range(in_rank) + ), + ) + } + + +translator.register_primitive_translator(SqueezeTranslator()) diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index bc2de21..7c9f2f0 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -81,6 +81,27 @@ def __eq__(self, other: Any) -> bool: return NotImplemented return id(self) == id(other) + @classmethod + def from_atom( + cls, + jax_var: jax_core.Atom, + name: str | None, + ) -> JaCeVar: + """ + Generates a `JaCeVar` from the JAX variable `jax_var`. + + If `jax_var` is a literal its value is ignored. + + Args: + jax_var: The variable to process. + name: The optional name of the variable. + """ + return cls( + shape=get_jax_var_shape(jax_var), + dtype=get_jax_var_dtype(jax_var), + name=name, + ) + def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar) -> str: """Returns the name of `jax_var` as a string.""" diff --git a/tests/test_subtranslator_helper.py b/tests/test_subtranslator_helper.py index a4c4ad9..52672b0 100644 --- a/tests/test_subtranslator_helper.py +++ b/tests/test_subtranslator_helper.py @@ -75,7 +75,7 @@ def fake_add_translator(*args: Any, **kwargs: Any) -> None: # noqa: ARG001 def test_are_subtranslators_imported(): """Tests if something is inside the list of subtranslators.""" # Must be adapted if new primitives are implemented. - assert len(get_registered_primitive_translators()) == 37 + assert len(get_registered_primitive_translators()) > 0 @pytest.mark.usefixtures("no_builtin_translators")