Skip to content

Commit

Permalink
Applied Enriques primarly fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Sep 26, 2024
1 parent d88752a commit 71f9f86
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 35 deletions.
21 changes: 2 additions & 19 deletions src/jace/translator/jaxpr_translator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,24 +179,6 @@ def append_new_state(
self._ctx.terminal_state = new_state
return new_state

def add_orphan_state(
self,
label: str,
) -> dace.SDFGState:
"""
Add a new orphan state to the SDFG.
The state is not connected to any other state, nor it is the new start state.
Except you know what you are doing you should not use this function and
instead use `self.append_new_state()`.
Args:
label: The name of the state.
"""
if not self.is_allocated():
raise RuntimeError("Builder is not allocated.")
return self._ctx.sdfg.add_state(label=label, is_start_block=False)

@property
def arrays(self) -> Mapping[str, dace_data.Data]:
"""
Expand Down Expand Up @@ -520,7 +502,8 @@ def _allocate_translation_ctx(
@property
def _ctx(self) -> TranslationContext:
"""Returns the currently active translation context."""
assert len(self._ctx_stack) != 0, "No context is active."
if not self.is_allocated():
raise RuntimeError("The context is not allocated.")
return self._ctx_stack[-1]

def _clear_translation_ctx(self) -> TranslationContext | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,19 @@ class LogicalOperationTranslator(mapped_base.MappedOperationTranslatorBase):
Args:
prim_name: The name of the primitive that should be handled.
int_tmpl: The template used for the integer case.
bool_tmpl: The template used for the bool case.
bitwise_tmpl: The template used for the bitwise case.
logical_tmpl: The template used for the logical case.
Note:
Since it does not make sense to single out `not` and keep the other
logical operations in `ArithmeticOperationTranslator` all of them are
handled by this class.
"""

def __init__(self, prim_name: str, int_tmpl: str, bool_tmpl: str) -> None:
def __init__(self, prim_name: str, bitwise_tmpl: str, logical_tmpl: str) -> None:
super().__init__(primitive_name=prim_name)
self._int_tmpl = int_tmpl
self._bool_tmpl = bool_tmpl
self._bitwise_tmpl = bitwise_tmpl
self._logical_tmpl = logical_tmpl

@override
def write_tasklet_code(
Expand All @@ -101,8 +101,8 @@ def write_tasklet_code(
eqn: jax_core.JaxprEqn,
) -> str:
if all(util.get_jax_var_dtype(invar) is dace.bool_ for invar in eqn.invars):
return self._bool_tmpl
return self._int_tmpl
return self._logical_tmpl
return self._bitwise_tmpl


# Maps the name of an arithmetic JAX primitive to the code template that is used to
Expand Down Expand Up @@ -176,17 +176,29 @@ def write_tasklet_code(
# Maps the name of a logical primitive to the two code templates, first the integer
# case and second the boolean case, that are used to create the body of the mapped
# tasklet. They are used to instantiate the `LogicalOperationTranslator` translators.
_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, tuple[str, str]]] = {
"or": ("__out = (__in0) | (__in1)", "__out = (__in0) or (__in1)"),
"not": ("__out = ~(__in0)", "__out = not (__in0)"),
"and": ("__out = (__in0) & (__in1)", "__out = (__in0) and (__in1)"),
"xor": ("__out = (__in0) ^ (__in1)", "__out = (__in0) != (__in1)"),
_LOGICAL_OPERATION_TEMPLATES: Final[dict[str, dict[str, str]]] = {
"or": {
"bitwise_tmpl": "__out = (__in0) | (__in1)",
"logical_tmpl": "__out = (__in0) or (__in1)",
},
"not": {
"bitwise_tmpl": "__out = ~(__in0)",
"logical_tmpl": "__out = not (__in0)",
},
"and": {
"bitwise_tmpl": "__out = (__in0) & (__in1)",
"logical_tmpl": "__out = (__in0) and (__in1)",
},
"xor": {
"bitwise_tmpl": "__out = (__in0) ^ (__in1)",
"logical_tmpl": "__out = (__in0) != (__in1)",
},
}
# fmt: on


# Instantiate the arithmetic and logical translators from the templates.
for pname, ptmpl in _ARITMETIC_OPERATION_TEMPLATES.items():
translator.register_primitive_translator(ArithmeticOperationTranslator(pname, ptmpl))
for pname, (itmpl, btmpl) in _LOGICAL_OPERATION_TEMPLATES.items():
translator.register_primitive_translator(LogicalOperationTranslator(pname, itmpl, btmpl))
for pname, ptmpl in _LOGICAL_OPERATION_TEMPLATES.items(): # type: ignore[assignment] # Type confusion
translator.register_primitive_translator(LogicalOperationTranslator(pname, **ptmpl)) # type: ignore[arg-type] # Type confusion
2 changes: 1 addition & 1 deletion src/jace/translator/primitive_translators/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def condition_translator(
branch_states.append(branch_state)

# Connect all branch states to the join state
join_state = builder.add_orphan_state(f"{name_pattern}__join_state")
join_state = builder._ctx.sdfg.add_state(label=f"{name_pattern}__join_state")
for branch_state in branch_states:
builder.sdfg.add_edge(
branch_state,
Expand Down
2 changes: 1 addition & 1 deletion src/jace/translator/primitive_translators/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def make_input_memlets(
eqn: jax_core.JaxprEqn,
) -> dict[str, dace.Memlet]:
strides: Sequence[int] = (
((1,) * len(tskl_ranges)) if eqn.params["strides"] is None else eqn.params["strides"]
eqn.params["strides"] if eqn.params["strides"] else ((1,) * len(tskl_ranges))
)
start_indices: Sequence[int] = eqn.params["start_indices"] # Fist index to slice
return {
Expand Down

0 comments on commit 71f9f86

Please sign in to comment.