diff --git a/jaxadi/_convert.py b/jaxadi/_convert.py index 2ce714f..e0ab5e7 100644 --- a/jaxadi/_convert.py +++ b/jaxadi/_convert.py @@ -20,7 +20,7 @@ def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[..., """ if translate is None: translate = graph_translate - + jax_str = translate(casadi_fn) jax_fn = declare(jax_str) diff --git a/jaxadi/_expand.py b/jaxadi/_expand.py index 84c285f..3c7b887 100644 --- a/jaxadi/_expand.py +++ b/jaxadi/_expand.py @@ -201,6 +201,7 @@ def squeeze(stages: List[Stage], num_threads=1) -> List[Stage]: cmd = combine_outputs(new_stages) return cmd + def translate(func: Function, add_jit=False, add_import=False) -> str: stages = stage_generator(func) stages = squeeze(stages) diff --git a/jaxadi/_graph.py b/jaxadi/_graph.py index 607dd4c..84584e2 100644 --- a/jaxadi/_graph.py +++ b/jaxadi/_graph.py @@ -3,12 +3,14 @@ creation, traversion, code-generation and compression/fusion if necessary/possible """ + from casadi import Function from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function from collections import deque from ._ops import OP_JAX_VALUE_DICT + def sort_by_height(graph, antigraph, heights): nodes = [[] for i in range(max(heights) + 1)] for i, h in enumerate(heights): @@ -16,6 +18,7 @@ def sort_by_height(graph, antigraph, heights): return nodes + def codegen(graph, antigraph, heights, output_map, values): sorted_nodes = sort_by_height(graph, antigraph, heights) code = "" @@ -27,14 +30,10 @@ def codegen(graph, antigraph, heights, output_map, values): if node in output_map: oo = output_map[node] if outputs.get(oo[0], None) is None: - outputs[oo[0]] = { - 'rows': [], - 'cols': [], - 'values': [] - } - outputs[oo[0]]['rows'].append(oo[1]) - outputs[oo[0]]['cols'].append(oo[2]) - outputs[oo[0]]['values'].append(values[node]) + outputs[oo[0]] = {"rows": [], "cols": [], "values": []} + outputs[oo[0]]["rows"].append(oo[1]) + outputs[oo[0]]["cols"].append(oo[2]) + outputs[oo[0]]["values"].append(values[node]) else: if len(assignment) > 1: assignment += ", " @@ -47,6 +46,7 @@ def codegen(graph, antigraph, heights, output_map, values): return code + def compute_heights(func, graph, antigraph): heights = [0 for _ in range(len(graph))] current_layer = set() @@ -70,6 +70,7 @@ def compute_heights(func, graph, antigraph): return heights + def create_graph(func: Function): N = func.n_instructions() graph = [[] for _ in range(N)] @@ -121,7 +122,6 @@ def create_graph(func: Function): else: raise Exception("Unknown CasADi operation: " + str(op)) - return graph, antigraph, output_map, values @@ -142,5 +142,3 @@ def translate(func: Function, add_jit=False, add_import=False): code += " return outputs" return code - - diff --git a/jaxadi/_ops.py b/jaxadi/_ops.py index 762572b..86cc970 100644 --- a/jaxadi/_ops.py +++ b/jaxadi/_ops.py @@ -47,6 +47,7 @@ OP_TANH, OP_TWICE, ) + OP_JAX_VALUE_DICT = { OP_ASSIGN: "work[{0}]", OP_ADD: "work[{0}] + work[{1}]",