Skip to content

Commit

Permalink
feat: add broken example with lower
Browse files Browse the repository at this point in the history
  • Loading branch information
lvjonok committed Sep 3, 2024
1 parent cbd1da4 commit e9de0d7
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 12 deletions.
12 changes: 10 additions & 2 deletions examples/00_translate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
This example demonstrates how to define a mathematical function using CasADi
and translate it into JAX-compatible code using the jaxadi library.
It shows the function's signature in CasADi and its equivalent representation
in JAX, enabling the use of JAX's automatic differentiation and optimization
capabilities for the same function.
"""

import casadi as cs

from jaxadi import translate
Expand All @@ -7,9 +15,9 @@
y = cs.SX.sym("y", 1)
casadi_function = cs.Function("myfunc", [x, y], [x**2 + y**2 - 1])

print("Signature of the casadi function:")
print("Signature of the CasADi function:")
print(casadi_function)

print("Transcribed code:")
print("Translated JAX function:")
for cg_str in translate(casadi_function):
print(cg_str)
Empty file removed examples/01_lho.py
Empty file.
17 changes: 17 additions & 0 deletions examples/01_lower.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import casadi as cs

from jaxadi import lower, translate, declare

# define input variables for the function
x = cs.SX.sym("x", 10, 10)
y = cs.SX.sym("y", 10, 10)
casadi_fn = cs.Function("myfunc", [x, y], [x @ y])

print("Signature of the CasADi function:")
print(casadi_fn)

# define jax function from casadi one
jax_fn = declare(translate(casadi_fn))

print("Lowered JAX function:")
print(lower(jax_fn, casadi_fn))
1 change: 1 addition & 0 deletions jaxadi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._compile import lower
from ._convert import convert
from ._translate import translate
from ._declare import declare
13 changes: 3 additions & 10 deletions jaxadi/_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,8 @@ def translate(func: Function) -> list[str]:
const_instr = [func.instruction_constant(i) for i in range(n_instr)]

# generate string with complete code
codegen = textwrap.dedent(
"""
import jax
import jax.numpy as jnp
"""
)
codegen += "@jax.jit\n"
codegen = ""
# codegen += "@jax.jit\n"
codegen += f"def evaluate_{func.name()}(*args):\n"
codegen += " inputs = args\n" # combine all inputs into a single list
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n" # output variables
Expand Down Expand Up @@ -58,4 +51,4 @@ def translate(func: Function) -> list[str]:
# footer
codegen += "\n return outputs\n"

return codegen.split("\n")
return codegen

0 comments on commit e9de0d7

Please sign in to comment.