Skip to content

Commit

Permalink
feat: graph tranlsation
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Oct 20, 2024
1 parent 6c15dac commit 05c60f9
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 200 deletions.
4 changes: 2 additions & 2 deletions jaxadi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._compile import lower
from ._convert import convert
from ._translate import translate
from ._legacy_translate import translate as legacy_translate
from ._graph import translate as graph_translate
from ._expand import translate as expand_translate
from ._declare import declare
9 changes: 7 additions & 2 deletions jaxadi/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from collections.abc import Callable

from ._declare import declare
from ._translate import translate
from ._graph import translate as graph_translate
from ._expand import translate as expand_translate
from ._compile import compile as compile_fn


def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
def convert(casadi_fn: Function, translate=None, compile=False) -> Callable[..., Any]:
"""
Convert given casadi function into python
callable based on JAX backend, optionally
Expand All @@ -17,6 +18,10 @@ def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
:param compile (bool): Whether to AOT compile function
:return (Callable[..., Any]): Resulting python function
"""

if translate is None:
translate = graph_translate

jax_str = translate(casadi_fn)
jax_fn = declare(jax_str)

Expand Down
29 changes: 29 additions & 0 deletions jaxadi/_stages.py → jaxadi/_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,32 @@ def squeeze(stages: List[Stage]) -> List[Stage]:

cmd = combine_outputs(new_stages)
return cmd

def translate(func: Function, add_jit=False, add_import=False) -> str:
stages = stage_generator(func)
stages = squeeze(stages)
# get information about casadi function
n_out = func.n_out() # number of outputs in the function

# get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]

# generate string with complete code
codegen = ""
if add_import:
codegen += "import jax\nimport jax.numpy as jnp\n\n"
codegen += "@jax.jit\n" if add_jit else ""
codegen += f"def evaluate_{func.name()}(*args):\n"
# combine all inputs into a single list
codegen += " inputs = jnp.expand_dims(jnp.array(args), axis=-1)\n"
# output variables
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n"

# for stage in stages:
# codegen += stage.codegen()
codegen += stages

# footer
codegen += "\n return outputs\n"

return codegen
146 changes: 146 additions & 0 deletions jaxadi/_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
This module is supposed to implement graph
creation, traversion, code-generation and
compression/fusion if necessary/possible
"""
from casadi import Function
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
from collections import deque

from ._ops import OP_JAX_VALUE_DICT

def sort_by_height(graph, antigraph, heights):
nodes = [[] for i in range(max(heights) + 1)]
for i, h in enumerate(heights):
nodes[h].append(i)

return nodes

def codegen(graph, antigraph, heights, output_map, values):
sorted_nodes = sort_by_height(graph, antigraph, heights)
code = ""
outputs = {}
for layer in sorted_nodes:
indices = []
assignment = "["
for node in layer:
if node in output_map:
oo = output_map[node]
if outputs.get(oo[0], None) is None:
outputs[oo[0]] = {
'rows': [],
'cols': [],
'values': []
}
outputs[oo[0]]['rows'].append(oo[1])
outputs[oo[0]]['cols'].append(oo[2])
outputs[oo[0]]['values'].append(values[node])
else:
if len(assignment) > 1:
assignment += ", "
assignment += values[node]
indices += [node]
code += f" work = work.at[jnp.array({indices})].set({assignment}])\n"

for k, v in outputs.items():
code += f" outputs[{k}] = outputs[{k}].at[({v['rows']}, {v['cols']})].set([{', '.join(v['values'])}])\n"

return code

def compute_heights(func, graph, antigraph):
heights = [0 for _ in range(len(graph))]
current_layer = set()
next_layer = set()
# queue = deque()

for i in range(func.n_instructions()):
op = func.instruction_id(i)
if op == OP_INPUT:
current_layer.add(i)

while current_layer:
instr = current_layer.pop()
for parent in antigraph[instr]:
heights[instr] = max(heights[instr], heights[parent] + 1)
for child in graph[instr]:
next_layer.add(child)

if not current_layer:
current_layer, next_layer = next_layer, current_layer

return heights

def create_graph(func: Function):
N = func.n_instructions()
graph = [[] for _ in range(N)]
values = [None for _ in range(N)]
antigraph = [[] for _ in range(N)]
output_map = {}
workers = [None for _ in range(func.sz_w())]

for i in range(N):
op = func.instruction_id(i)
o_idx = func.instruction_output(i)
i_idx = func.instruction_input(i)

if op == OP_CONST:
values[i] = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(func.instruction_constant(i)) + "])"
workers[o_idx[0]] = i
elif op == OP_INPUT:
this_shape = func.size_in(i_idx[0])
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX

values[i] = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)
workers[o_idx[0]] = i
elif op == OP_OUTPUT:
rows, cols = func.size_out(o_idx[0])
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
output_map[i] = (o_idx[0], row_number, column_number)
values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
elif op == OP_SQ:
values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
graph[workers[i_idx[0]]].append(i)
antigraph[i].append(workers[i_idx[0]])
workers[o_idx[0]] = i
elif OP_JAX_VALUE_DICT[op].count("}") == 2:
w_id0 = workers[i_idx[0]]
w_id1 = workers[i_idx[1]]
graph[w_id0].append(i)
graph[w_id1].append(i)
antigraph[i].extend([w_id0, w_id1])
values[i] = OP_JAX_VALUE_DICT[op].format(w_id0, w_id1)
workers[o_idx[0]] = i
elif OP_JAX_VALUE_DICT[op].count("}") == 1:
graph[workers[i_idx[0]]].append(i)
antigraph[i].append(workers[i_idx[0]])
values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
workers[o_idx[0]] = i
else:
raise Exception("Unknown CasADi operation: " + str(op))


return graph, antigraph, output_map, values


def translate(func: Function, add_jit=False, add_import=False):
graph, antigraph, output_map, values = create_graph(func)
heights = compute_heights(func, graph, antigraph)

code = ""
if add_import:
code += "import jax\nimport jax.numpy as jnp\n\n"
if add_jit:
code += "@jax.jit\n"
code += f"def evaluate_{func.name()}(*args):\n"
code += " inputs = [jnp.expand_dims(jnp.array(arg), axis=-1) for arg in args]\n"
code += f" outputs = [jnp.zeros(out) for out in {[func.size_out(i) for i in range(func.n_out())]}]\n"
code += f" work = jnp.zeros(({func.n_instructions()}, 1))\n"
code += codegen(graph, antigraph, heights, output_map, values)
code += " return outputs"

return code


163 changes: 0 additions & 163 deletions jaxadi/_legacy_translate.py

This file was deleted.

2 changes: 1 addition & 1 deletion jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@
OP_ATANH: "jnp.arctanh(work[{0}])",
OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])",
OP_CONST: "{0:.16f}",
OP_INPUT: "inputs[{0}, {1}, {2}]",
OP_INPUT: "inputs[{0}][{1}, {2}]",
OP_OUTPUT: "work[{0}][0]",
}
Loading

0 comments on commit 05c60f9

Please sign in to comment.