Skip to content

Commit

Permalink
feat: sequential translation
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Sep 17, 2024
1 parent 4a0ddb3 commit 9f476d5
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 85 deletions.
4 changes: 2 additions & 2 deletions jaxadi/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ._compile import compile as compile_fn


def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
def convert(casadi_fn: Function, compile=False, num_threads=1) -> Callable[..., Any]:
"""
Convert given casadi function into python
callable based on JAX backend, optionally
Expand All @@ -17,7 +17,7 @@ def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]:
:param compile (bool): Whether to AOT compile function
:return (Callable[..., Any]): Resulting python function
"""
jax_str = translate(casadi_fn)
jax_str = translate(casadi_fn, num_threads=num_threads)
jax_fn = declare(jax_str)

if compile:
Expand Down
92 changes: 46 additions & 46 deletions jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,51 +49,51 @@
)

OP_JAX_VALUE_DICT = {
OP_ASSIGN: "work[{0}]",
OP_ADD: "work[{0}] + work[{1}]",
OP_SUB: "work[{0}] - work[{1}]",
OP_MUL: "work[{0}] * work[{1}]",
OP_DIV: "work[{0}] / work[{1}]",
OP_NEG: "-work[{0}]",
OP_EXP: "jnp.exp(work[{0}])",
OP_LOG: "jnp.log(work[{0}])",
OP_POW: "jnp.power(work[{0}], work[{1}])",
OP_CONSTPOW: "jnp.power(work[{0}], work[{1}])",
OP_SQRT: "jnp.sqrt(work[{0}])",
OP_SQ: "work[{0}] * work[{0}]",
OP_TWICE: "2 * work[{0}]",
OP_SIN: "jnp.sin(work[{0}])",
OP_COS: "jnp.cos(work[{0}])",
OP_TAN: "jnp.tan(work[{0}])",
OP_ASIN: "jnp.arcsin(work[{0}])",
OP_ACOS: "jnp.arccos(work[{0}])",
OP_ATAN: "jnp.arctan(work[{0}])",
OP_LT: "work[{0}] < work[{1}]",
OP_LE: "work[{0}] <= work[{1}]",
OP_EQ: "work[{0}] == work[{1}]",
OP_NE: "work[{0}] != work[{1}]",
OP_NOT: "jnp.logical_not(work[{0}])",
OP_AND: "jnp.logical_and(work[{0}], work[{1}])",
OP_OR: "jnp.logical_or(work[{0}], work[{1}])",
OP_FLOOR: "jnp.floor(work[{0}])",
OP_CEIL: "jnp.ceil(work[{0}])",
OP_FMOD: "jnp.fmod(work[{0}], work[{1}])",
OP_FABS: "jnp.abs(work[{0}])",
OP_SIGN: "jnp.sign(work[{0}])",
OP_COPYSIGN: "jnp.copysign(work[{0}], work[{1}])",
OP_IF_ELSE_ZERO: "jnp.where(work[{0}] == 0, 0, work[{1}])",
OP_ERF: "jax.scipy.special.erf(work[{0}])",
OP_FMIN: "jnp.minimum(work[{0}], work[{1}])",
OP_FMAX: "jnp.maximum(work[{0}], work[{1}])",
OP_INV: "1.0 / work[{0}]",
OP_SINH: "jnp.sinh(work[{0}])",
OP_COSH: "jnp.cosh(work[{0}])",
OP_TANH: "jnp.tanh(work[{0}])",
OP_ASINH: "jnp.arcsinh(work[{0}])",
OP_ACOSH: "jnp.arccosh(work[{0}])",
OP_ATANH: "jnp.arctanh(work[{0}])",
OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])",
OP_ASSIGN: "{0}",
OP_ADD: "{0}+{1}",
OP_SUB: "{0}-{1}",
OP_MUL: "{0}*{1}",
OP_DIV: "{0}/{1}",
OP_NEG: "-{0}",
OP_EXP: "jnp.exp({0})",
OP_LOG: "jnp.log({0})",
OP_POW: "jnp.power({0}, {1})",
OP_CONSTPOW: "jnp.power({0}, {1})",
OP_SQRT: "jnp.sqrt({0})",
OP_SQ: "{0} * {0}",
OP_TWICE: "2 * {0}",
OP_SIN: "jnp.sin({0})",
OP_COS: "jnp.cos({0})",
OP_TAN: "jnp.tan({0})",
OP_ASIN: "jnp.arcsin({0})",
OP_ACOS: "jnp.arccos({0})",
OP_ATAN: "jnp.arctan({0})",
OP_LT: "{0} < {1}",
OP_LE: "{0} <= {1}",
OP_EQ: "{0} == {1}",
OP_NE: "{0} != {1}",
OP_NOT: "jnp.logical_not({0})",
OP_AND: "jnp.logical_and({0}, {1})",
OP_OR: "jnp.logical_or({0}, {1})",
OP_FLOOR: "jnp.floor({0})",
OP_CEIL: "jnp.ceil({0})",
OP_FMOD: "jnp.fmod({0}, {1})",
OP_FABS: "jnp.abs({0})",
OP_SIGN: "jnp.sign({0})",
OP_COPYSIGN: "jnp.copysign({0}, {1})",
OP_IF_ELSE_ZERO: "jnp.where({0} == 0, 0, {1})",
OP_ERF: "jax.scipy.special.erf({0})",
OP_FMIN: "jnp.minimum({0}, {1})",
OP_FMAX: "jnp.maximum({0}, {1})",
OP_INV: "1.0/{0}",
OP_SINH: "jnp.sinh({0})",
OP_COSH: "jnp.cosh({0})",
OP_TANH: "jnp.tanh({0})",
OP_ASINH: "jnp.arcsinh({0})",
OP_ACOSH: "jnp.arccosh({0})",
OP_ATANH: "jnp.arctanh({0})",
OP_ATAN2: "jnp.arctan2({0}, {1})",
OP_CONST: "{0:.16f}",
OP_INPUT: "inputs[{0}, {1}, {2}]",
OP_OUTPUT: "work[{0}][0]",
OP_INPUT: "i[{0},{1},{2}]",
OP_OUTPUT: "{0}[0]",
}
75 changes: 45 additions & 30 deletions jaxadi/_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from ._ops import OP_JAX_VALUE_DICT
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
import re
from tqdm import tqdm
from multiprocessing import Pool, cpu_count


class Stage:
Expand Down Expand Up @@ -56,6 +58,9 @@ def stage_generator(func: Function) -> str:
n_instr = func.n_instructions()
n_out = func.n_out() # number of outputs in the function
n_in = func.n_in() # number of outputs in the function
n_w = func.sz_w()

workers = [""] * n_w

# Get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]
Expand All @@ -67,23 +72,23 @@ def stage_generator(func: Function) -> str:
const_instr = [func.instruction_constant(i) for i in range(n_instr)]

stages = []
for k in range(n_instr):
for k in tqdm(range(n_instr)):
op = operations[k]
o_idx = output_idx[k]
i_idx = input_idx[k]
operation = Operation()
operation.op = op
if op == OP_CONST:
operation.output_idx = o_idx[0]
operation.value = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])"
# codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k])
workers[o_idx[0]
] = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])"

elif op == OP_INPUT:
this_shape = in_shapes[i_idx[0]]
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX
operation.output_idx = o_idx[0]
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)
workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(
i_idx[0], row_number, column_number)
elif op == OP_OUTPUT:
operation = OutputOperation()
operation.op = op
Expand All @@ -94,28 +99,25 @@ def stage_generator(func: Function) -> str:
operation.exact_idx2 = column_number
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
stage = Stage()
stage.output_idx.append(operation.output_idx)
stage.work_idx.extend(operation.work_idx)
stage.ops.append(operation)
stages.append(stage)
elif op == OP_SQ:
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
workers[o_idx[0]] = "(" + \
OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) + ")"
elif OP_JAX_VALUE_DICT[op].count("}") == 2:
operation.output_idx = o_idx[0]
operation.work_idx.extend([i_idx[0], i_idx[1]])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1])
workers[o_idx[0]] = "(" + OP_JAX_VALUE_DICT[op].format(
workers[i_idx[0]], workers[i_idx[1]]) + ")"
elif OP_JAX_VALUE_DICT[op].count("}") == 1:
operation.output_idx = o_idx[0]
operation.work_idx.append(i_idx[0])
operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0])
workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
else:
raise Exception("Unknown CasADi operation: " + str(op))
print(sum(len(s) for s in workers))

stage = Stage()
stage.output_idx.append(operation.output_idx)
stage.work_idx.extend(operation.work_idx)
stage.ops.append(operation)
stages.append(stage)

print("finished stages")
return stages


Expand Down Expand Up @@ -146,7 +148,7 @@ def combine_outputs(stages: List[Stage]) -> str:
rows = "[" + ", ".join(row_indices) + "]"
columns = "[" + ", ".join(column_indices) + "]"
values_str = ", ".join(values)
command = f" outputs[{output_idx}] = outputs[{output_idx}].at[({rows}, {columns})].set([{values_str}])"
command = f" o[{output_idx}] = o[{output_idx}].at[({rows}, {columns})].set([{values_str}])"
commands.append(command)

# Combine all the commands into a single string
Expand Down Expand Up @@ -177,19 +179,32 @@ def recursive_subs(stages: List[Stage], idx: int) -> str:
for i in range(idx - 1, -1, -1):
if stages[i].ops[0].output_idx == number and stages[i].ops[0].op != OP_OUTPUT:
# Recursively replace the found work[<number>] with expanded value
expanded_value = recursive_subs(stages, i)
result = result.replace(f"work[{number}]", expanded_value)
stages[i].ops[0].value = recursive_subs(stages, i)
result = result.replace(
f"work[{number}]", stages[i].ops[0].value)
break

return f"({result})"


def squeeze(stages: List[Stage]) -> List[Stage]:
def process_stage(args):
stages, i = args
if len(stages[i].ops) != 0:
stages[i].ops[0].value = recursive_subs(stages, i)
return stages[i]
return None


def squeeze(stages: List[Stage], num_threads=1) -> List[Stage]:
new_stages = []
for i in range(len(stages)):
if len(stages[i].ops) != 0 and stages[i].ops[0].op == OP_OUTPUT:
stages[i].ops[0].value = recursive_subs(stages, i)
new_stages.append(stages[i])
working_stages = []
for i, stage in enumerate(stages):
if len(stage.ops) != 0 and stage.ops[0].op == OP_OUTPUT:
working_stages.append((i, stage))
for i in tqdm(range(len(working_stages))):
i, stage = working_stages[i]
stage.value = recursive_subs(stages, i)
new_stages.append(stage)

cmd = combine_outputs(new_stages)
return cmd
15 changes: 8 additions & 7 deletions jaxadi/_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from ._stages import stage_generator, squeeze


def translate(func: Function, add_jit=False, add_import=False) -> str:
stages = stage_generator(func)
stages = squeeze(stages)
# get information about casadi function
def translate(func: Function, add_jit=False, add_import=False, num_threads=1) -> str:
n_out = func.n_out() # number of outputs in the function

# get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]
print(out_shapes)
stages = stage_generator(func)
stages = squeeze(stages, num_threads=num_threads)
# get information about casadi function

# generate string with complete code
codegen = ""
Expand All @@ -18,15 +19,15 @@ def translate(func: Function, add_jit=False, add_import=False) -> str:
codegen += "@jax.jit\n" if add_jit else ""
codegen += f"def evaluate_{func.name()}(*args):\n"
# combine all inputs into a single list
codegen += " inputs = jnp.expand_dims(jnp.array(args), axis=-1)\n"
codegen += " i = jnp.expand_dims(jnp.array(args), axis=-1)\n"
# output variables
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n"
codegen += f" o = [jnp.zeros(out) for out in {out_shapes}]\n"

# for stage in stages:
# codegen += stage.codegen()
codegen += stages

# footer
codegen += "\n return outputs\n"
codegen += "\n return o\n"

return codegen

0 comments on commit 9f476d5

Please sign in to comment.