Skip to content

Commit

Permalink
feat: more works
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Sep 3, 2024
1 parent abeb975 commit 474a954
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 13 deletions.
3 changes: 2 additions & 1 deletion jaxadi/_declare.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Any
import jax.numpy as jnp


def declare(f: str) -> Callable[..., Any]:
Expand All @@ -7,6 +8,6 @@ def declare(f: str) -> Callable[..., Any]:
based on string definition
"""
local_vars = {}
exec(f, {}, local_vars)
exec(f, globals(), local_vars)
func_name = next(iter(local_vars))
return local_vars[func_name]
4 changes: 2 additions & 2 deletions jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@
OP_ATANH: "\n work = work.at[{0}].set(jnp.arctanh(work[{1}]))",
OP_ATAN2: "\n work = work.at[{0}].set(jnp.arctan2(work[{1}], work[{2}]))",
OP_CONST: "\n work = work.at[{0}].set({1:.16f})",
OP_INPUT: "\n work = work.at[{0}].set(inputs[{1}][{2}])",
OP_OUTPUT: "\n outputs[{0}] = outputs[{0}].at[{1}].set(work[{2}])",
OP_INPUT: "\n work = work.at[{0}].set(inputs[{1}][{2}, {3}])",
OP_OUTPUT: "\n outputs[{0}] = outputs[{0}].at[{1}, {2}].set(work[{3}][0])",
}
32 changes: 22 additions & 10 deletions jaxadi/_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,34 @@
from ._ops import OP_JAX_DICT


def translate(func: Function, add_jit=False, add_import=False) -> list[str]:
def translate(func: Function, add_jit=False, add_import=False) -> str:
# Get information about Casadi function
n_instr = func.n_instructions()
n_out = func.n_out() # number of outputs in the function
n_in = func.n_in() # number of outputs in the function

# get the shapes of input and output
# Get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]
in_shapes = [func.size_in(i) for i in range(n_in)]

# number of work variables
# Number of work variables
n_w = func.sz_w()

input_idx = [func.instruction_input(i) for i in range(n_instr)]
output_idx = [func.instruction_output(i) for i in range(n_instr)]
operations = [func.instruction_id(i) for i in range(n_instr)]
const_instr = [func.instruction_constant(i) for i in range(n_instr)]

# generate string with complete code
# Generate string with complete code
codegen = ""
if add_import:
codegen += "import jax\nimport jax.numpy as jnp\n\n"
codegen += "@jax.jit\n" if add_jit else ""
codegen += f"def evaluate_{func.name()}(*args):\n"
codegen += " inputs = args\n" # combine all inputs into a single list
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n" # output variables
codegen += f" work = jnp.zeros(({n_w}, 1))\n" # work variables
codegen += " inputs = args\n" # Combine all inputs into a single list
# Output variables
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n"
codegen += f" work = jnp.zeros(({n_w}, 1))\n" # Work variables

for k in range(n_instr):
op = operations[k]
Expand All @@ -36,9 +39,18 @@ def translate(func: Function, add_jit=False, add_import=False) -> list[str]:
if op == OP_CONST:
codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k])
elif op == OP_INPUT:
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], i_idx[1])
this_shape = in_shapes[i_idx[0]]
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], row_number, column_number)
elif op == OP_OUTPUT:
codegen += OP_JAX_DICT[op].format(o_idx[0], o_idx[1], i_idx[0])
# Fix for OP_OUTPUT
rows, cols = out_shapes[o_idx[0]] # Get the shape of the output matrix
# Adjust the calculation to switch from CasADi's column-major to JAX's row-major
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
codegen += OP_JAX_DICT[op].format(o_idx[0], row_number, column_number, i_idx[0])
elif op == OP_SQ:
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], i_idx[0])
elif OP_JAX_DICT[op].count("{") == 3:
Expand All @@ -48,7 +60,7 @@ def translate(func: Function, add_jit=False, add_import=False) -> list[str]:
else:
raise Exception("Unknown CasADi operation: " + str(op))

# footer
# Footer
codegen += "\n return outputs\n"

return codegen

0 comments on commit 474a954

Please sign in to comment.