From 9f476d5d1e7e5ff2df6820c96d16d6fbbac2e629 Mon Sep 17 00:00:00 2001 From: mattephi Date: Tue, 17 Sep 2024 14:56:17 +0900 Subject: [PATCH] feat: sequential translation --- jaxadi/_convert.py | 4 +- jaxadi/_ops.py | 92 ++++++++++++++++++++++---------------------- jaxadi/_stages.py | 75 +++++++++++++++++++++--------------- jaxadi/_translate.py | 15 ++++---- 4 files changed, 101 insertions(+), 85 deletions(-) diff --git a/jaxadi/_convert.py b/jaxadi/_convert.py index 72cdf30..a6f8cd9 100644 --- a/jaxadi/_convert.py +++ b/jaxadi/_convert.py @@ -7,7 +7,7 @@ from ._compile import compile as compile_fn -def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]: +def convert(casadi_fn: Function, compile=False, num_threads=1) -> Callable[..., Any]: """ Convert given casadi function into python callable based on JAX backend, optionally @@ -17,7 +17,7 @@ def convert(casadi_fn: Function, compile=False) -> Callable[..., Any]: :param compile (bool): Whether to AOT compile function :return (Callable[..., Any]): Resulting python function """ - jax_str = translate(casadi_fn) + jax_str = translate(casadi_fn, num_threads=num_threads) jax_fn = declare(jax_str) if compile: diff --git a/jaxadi/_ops.py b/jaxadi/_ops.py index fcdc8d5..6986ef3 100644 --- a/jaxadi/_ops.py +++ b/jaxadi/_ops.py @@ -49,51 +49,51 @@ ) OP_JAX_VALUE_DICT = { - OP_ASSIGN: "work[{0}]", - OP_ADD: "work[{0}] + work[{1}]", - OP_SUB: "work[{0}] - work[{1}]", - OP_MUL: "work[{0}] * work[{1}]", - OP_DIV: "work[{0}] / work[{1}]", - OP_NEG: "-work[{0}]", - OP_EXP: "jnp.exp(work[{0}])", - OP_LOG: "jnp.log(work[{0}])", - OP_POW: "jnp.power(work[{0}], work[{1}])", - OP_CONSTPOW: "jnp.power(work[{0}], work[{1}])", - OP_SQRT: "jnp.sqrt(work[{0}])", - OP_SQ: "work[{0}] * work[{0}]", - OP_TWICE: "2 * work[{0}]", - OP_SIN: "jnp.sin(work[{0}])", - OP_COS: "jnp.cos(work[{0}])", - OP_TAN: "jnp.tan(work[{0}])", - OP_ASIN: "jnp.arcsin(work[{0}])", - OP_ACOS: "jnp.arccos(work[{0}])", - OP_ATAN: "jnp.arctan(work[{0}])", - OP_LT: "work[{0}] < work[{1}]", - OP_LE: "work[{0}] <= work[{1}]", - OP_EQ: "work[{0}] == work[{1}]", - OP_NE: "work[{0}] != work[{1}]", - OP_NOT: "jnp.logical_not(work[{0}])", - OP_AND: "jnp.logical_and(work[{0}], work[{1}])", - OP_OR: "jnp.logical_or(work[{0}], work[{1}])", - OP_FLOOR: "jnp.floor(work[{0}])", - OP_CEIL: "jnp.ceil(work[{0}])", - OP_FMOD: "jnp.fmod(work[{0}], work[{1}])", - OP_FABS: "jnp.abs(work[{0}])", - OP_SIGN: "jnp.sign(work[{0}])", - OP_COPYSIGN: "jnp.copysign(work[{0}], work[{1}])", - OP_IF_ELSE_ZERO: "jnp.where(work[{0}] == 0, 0, work[{1}])", - OP_ERF: "jax.scipy.special.erf(work[{0}])", - OP_FMIN: "jnp.minimum(work[{0}], work[{1}])", - OP_FMAX: "jnp.maximum(work[{0}], work[{1}])", - OP_INV: "1.0 / work[{0}]", - OP_SINH: "jnp.sinh(work[{0}])", - OP_COSH: "jnp.cosh(work[{0}])", - OP_TANH: "jnp.tanh(work[{0}])", - OP_ASINH: "jnp.arcsinh(work[{0}])", - OP_ACOSH: "jnp.arccosh(work[{0}])", - OP_ATANH: "jnp.arctanh(work[{0}])", - OP_ATAN2: "jnp.arctan2(work[{0}], work[{1}])", + OP_ASSIGN: "{0}", + OP_ADD: "{0}+{1}", + OP_SUB: "{0}-{1}", + OP_MUL: "{0}*{1}", + OP_DIV: "{0}/{1}", + OP_NEG: "-{0}", + OP_EXP: "jnp.exp({0})", + OP_LOG: "jnp.log({0})", + OP_POW: "jnp.power({0}, {1})", + OP_CONSTPOW: "jnp.power({0}, {1})", + OP_SQRT: "jnp.sqrt({0})", + OP_SQ: "{0} * {0}", + OP_TWICE: "2 * {0}", + OP_SIN: "jnp.sin({0})", + OP_COS: "jnp.cos({0})", + OP_TAN: "jnp.tan({0})", + OP_ASIN: "jnp.arcsin({0})", + OP_ACOS: "jnp.arccos({0})", + OP_ATAN: "jnp.arctan({0})", + OP_LT: "{0} < {1}", + OP_LE: "{0} <= {1}", + OP_EQ: "{0} == {1}", + OP_NE: "{0} != {1}", + OP_NOT: "jnp.logical_not({0})", + OP_AND: "jnp.logical_and({0}, {1})", + OP_OR: "jnp.logical_or({0}, {1})", + OP_FLOOR: "jnp.floor({0})", + OP_CEIL: "jnp.ceil({0})", + OP_FMOD: "jnp.fmod({0}, {1})", + OP_FABS: "jnp.abs({0})", + OP_SIGN: "jnp.sign({0})", + OP_COPYSIGN: "jnp.copysign({0}, {1})", + OP_IF_ELSE_ZERO: "jnp.where({0} == 0, 0, {1})", + OP_ERF: "jax.scipy.special.erf({0})", + OP_FMIN: "jnp.minimum({0}, {1})", + OP_FMAX: "jnp.maximum({0}, {1})", + OP_INV: "1.0/{0}", + OP_SINH: "jnp.sinh({0})", + OP_COSH: "jnp.cosh({0})", + OP_TANH: "jnp.tanh({0})", + OP_ASINH: "jnp.arcsinh({0})", + OP_ACOSH: "jnp.arccosh({0})", + OP_ATANH: "jnp.arctanh({0})", + OP_ATAN2: "jnp.arctan2({0}, {1})", OP_CONST: "{0:.16f}", - OP_INPUT: "inputs[{0}, {1}, {2}]", - OP_OUTPUT: "work[{0}][0]", + OP_INPUT: "i[{0},{1},{2}]", + OP_OUTPUT: "{0}[0]", } diff --git a/jaxadi/_stages.py b/jaxadi/_stages.py index 3663d9d..135e33b 100644 --- a/jaxadi/_stages.py +++ b/jaxadi/_stages.py @@ -2,6 +2,8 @@ from ._ops import OP_JAX_VALUE_DICT from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function import re +from tqdm import tqdm +from multiprocessing import Pool, cpu_count class Stage: @@ -56,6 +58,9 @@ def stage_generator(func: Function) -> str: n_instr = func.n_instructions() n_out = func.n_out() # number of outputs in the function n_in = func.n_in() # number of outputs in the function + n_w = func.sz_w() + + workers = [""] * n_w # Get the shapes of input and output out_shapes = [func.size_out(i) for i in range(n_out)] @@ -67,23 +72,23 @@ def stage_generator(func: Function) -> str: const_instr = [func.instruction_constant(i) for i in range(n_instr)] stages = [] - for k in range(n_instr): + for k in tqdm(range(n_instr)): op = operations[k] o_idx = output_idx[k] i_idx = input_idx[k] operation = Operation() operation.op = op if op == OP_CONST: - operation.output_idx = o_idx[0] - operation.value = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])" - # codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k]) + workers[o_idx[0] + ] = "jnp.array([" + OP_JAX_VALUE_DICT[op].format(const_instr[k]) + "])" + elif op == OP_INPUT: this_shape = in_shapes[i_idx[0]] rows, cols = this_shape # Get the shape of the output row_number = i_idx[1] % rows # Compute row index for JAX column_number = i_idx[1] // rows # Compute column index for JAX - operation.output_idx = o_idx[0] - operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number) + workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format( + i_idx[0], row_number, column_number) elif op == OP_OUTPUT: operation = OutputOperation() operation.op = op @@ -94,28 +99,25 @@ def stage_generator(func: Function) -> str: operation.exact_idx2 = column_number operation.output_idx = o_idx[0] operation.work_idx.append(i_idx[0]) - operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0]) + operation.value = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) + stage = Stage() + stage.output_idx.append(operation.output_idx) + stage.work_idx.extend(operation.work_idx) + stage.ops.append(operation) + stages.append(stage) elif op == OP_SQ: - operation.output_idx = o_idx[0] - operation.work_idx.append(i_idx[0]) - operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0]) + workers[o_idx[0]] = "(" + \ + OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) + ")" elif OP_JAX_VALUE_DICT[op].count("}") == 2: - operation.output_idx = o_idx[0] - operation.work_idx.extend([i_idx[0], i_idx[1]]) - operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0], i_idx[1]) + workers[o_idx[0]] = "(" + OP_JAX_VALUE_DICT[op].format( + workers[i_idx[0]], workers[i_idx[1]]) + ")" elif OP_JAX_VALUE_DICT[op].count("}") == 1: - operation.output_idx = o_idx[0] - operation.work_idx.append(i_idx[0]) - operation.value = OP_JAX_VALUE_DICT[op].format(i_idx[0]) + workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]]) else: raise Exception("Unknown CasADi operation: " + str(op)) + print(sum(len(s) for s in workers)) - stage = Stage() - stage.output_idx.append(operation.output_idx) - stage.work_idx.extend(operation.work_idx) - stage.ops.append(operation) - stages.append(stage) - + print("finished stages") return stages @@ -146,7 +148,7 @@ def combine_outputs(stages: List[Stage]) -> str: rows = "[" + ", ".join(row_indices) + "]" columns = "[" + ", ".join(column_indices) + "]" values_str = ", ".join(values) - command = f" outputs[{output_idx}] = outputs[{output_idx}].at[({rows}, {columns})].set([{values_str}])" + command = f" o[{output_idx}] = o[{output_idx}].at[({rows}, {columns})].set([{values_str}])" commands.append(command) # Combine all the commands into a single string @@ -177,19 +179,32 @@ def recursive_subs(stages: List[Stage], idx: int) -> str: for i in range(idx - 1, -1, -1): if stages[i].ops[0].output_idx == number and stages[i].ops[0].op != OP_OUTPUT: # Recursively replace the found work[] with expanded value - expanded_value = recursive_subs(stages, i) - result = result.replace(f"work[{number}]", expanded_value) + stages[i].ops[0].value = recursive_subs(stages, i) + result = result.replace( + f"work[{number}]", stages[i].ops[0].value) break return f"({result})" -def squeeze(stages: List[Stage]) -> List[Stage]: +def process_stage(args): + stages, i = args + if len(stages[i].ops) != 0: + stages[i].ops[0].value = recursive_subs(stages, i) + return stages[i] + return None + + +def squeeze(stages: List[Stage], num_threads=1) -> List[Stage]: new_stages = [] - for i in range(len(stages)): - if len(stages[i].ops) != 0 and stages[i].ops[0].op == OP_OUTPUT: - stages[i].ops[0].value = recursive_subs(stages, i) - new_stages.append(stages[i]) + working_stages = [] + for i, stage in enumerate(stages): + if len(stage.ops) != 0 and stage.ops[0].op == OP_OUTPUT: + working_stages.append((i, stage)) + for i in tqdm(range(len(working_stages))): + i, stage = working_stages[i] + stage.value = recursive_subs(stages, i) + new_stages.append(stage) cmd = combine_outputs(new_stages) return cmd diff --git a/jaxadi/_translate.py b/jaxadi/_translate.py index b617cf9..ff8c0e2 100644 --- a/jaxadi/_translate.py +++ b/jaxadi/_translate.py @@ -2,14 +2,15 @@ from ._stages import stage_generator, squeeze -def translate(func: Function, add_jit=False, add_import=False) -> str: - stages = stage_generator(func) - stages = squeeze(stages) - # get information about casadi function +def translate(func: Function, add_jit=False, add_import=False, num_threads=1) -> str: n_out = func.n_out() # number of outputs in the function # get the shapes of input and output out_shapes = [func.size_out(i) for i in range(n_out)] + print(out_shapes) + stages = stage_generator(func) + stages = squeeze(stages, num_threads=num_threads) + # get information about casadi function # generate string with complete code codegen = "" @@ -18,15 +19,15 @@ def translate(func: Function, add_jit=False, add_import=False) -> str: codegen += "@jax.jit\n" if add_jit else "" codegen += f"def evaluate_{func.name()}(*args):\n" # combine all inputs into a single list - codegen += " inputs = jnp.expand_dims(jnp.array(args), axis=-1)\n" + codegen += " i = jnp.expand_dims(jnp.array(args), axis=-1)\n" # output variables - codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n" + codegen += f" o = [jnp.zeros(out) for out in {out_shapes}]\n" # for stage in stages: # codegen += stage.codegen() codegen += stages # footer - codegen += "\n return outputs\n" + codegen += "\n return o\n" return codegen