Skip to content

Commit

Permalink
fix: pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Oct 20, 2024
1 parent d56d48c commit d2ca1ad
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion jaxadi/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[...,
"""
if translate is None:
translate = graph_translate

jax_str = translate(casadi_fn)
jax_fn = declare(jax_str)

Expand Down
1 change: 1 addition & 0 deletions jaxadi/_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def squeeze(stages: List[Stage], num_threads=1) -> List[Stage]:
cmd = combine_outputs(new_stages)
return cmd


def translate(func: Function, add_jit=False, add_import=False) -> str:
stages = stage_generator(func)
stages = squeeze(stages)
Expand Down
20 changes: 9 additions & 11 deletions jaxadi/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,22 @@
creation, traversion, code-generation and
compression/fusion if necessary/possible
"""

from casadi import Function
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
from collections import deque

from ._ops import OP_JAX_VALUE_DICT


def sort_by_height(graph, antigraph, heights):
nodes = [[] for i in range(max(heights) + 1)]
for i, h in enumerate(heights):
nodes[h].append(i)

return nodes


def codegen(graph, antigraph, heights, output_map, values):
sorted_nodes = sort_by_height(graph, antigraph, heights)
code = ""
Expand All @@ -27,14 +30,10 @@ def codegen(graph, antigraph, heights, output_map, values):
if node in output_map:
oo = output_map[node]
if outputs.get(oo[0], None) is None:
outputs[oo[0]] = {
'rows': [],
'cols': [],
'values': []
}
outputs[oo[0]]['rows'].append(oo[1])
outputs[oo[0]]['cols'].append(oo[2])
outputs[oo[0]]['values'].append(values[node])
outputs[oo[0]] = {"rows": [], "cols": [], "values": []}
outputs[oo[0]]["rows"].append(oo[1])
outputs[oo[0]]["cols"].append(oo[2])
outputs[oo[0]]["values"].append(values[node])
else:
if len(assignment) > 1:
assignment += ", "
Expand All @@ -47,6 +46,7 @@ def codegen(graph, antigraph, heights, output_map, values):

return code


def compute_heights(func, graph, antigraph):
heights = [0 for _ in range(len(graph))]
current_layer = set()
Expand All @@ -70,6 +70,7 @@ def compute_heights(func, graph, antigraph):

return heights


def create_graph(func: Function):
N = func.n_instructions()
graph = [[] for _ in range(N)]
Expand Down Expand Up @@ -121,7 +122,6 @@ def create_graph(func: Function):
else:
raise Exception("Unknown CasADi operation: " + str(op))


return graph, antigraph, output_map, values


Expand All @@ -142,5 +142,3 @@ def translate(func: Function, add_jit=False, add_import=False):
code += " return outputs"

return code


1 change: 1 addition & 0 deletions jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
OP_TANH,
OP_TWICE,
)

OP_JAX_VALUE_DICT = {
OP_ASSIGN: "work[{0}]",
OP_ADD: "work[{0}] + work[{1}]",
Expand Down

0 comments on commit d2ca1ad

Please sign in to comment.