Skip to content

Commit

Permalink
feat: optimisiing comp graph
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Oct 17, 2024
1 parent 0f5c01a commit 6c15dac
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 0 deletions.
39 changes: 39 additions & 0 deletions examples/optimize-ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import ast
import numpy as np


class MatrixOptimizer(ast.NodeTransformer):
def visit_BinOp(self, node):
# Detect if we are working with multiplications and additions
if isinstance(node.op, ast.Add):
left = self.visit(node.left)
right = self.visit(node.right)
# If both sides are multiplications, return a matrix operation
if isinstance(left, ast.BinOp) and isinstance(right, ast.BinOp):
return self.merge_to_matrix(node)
return self.generic_visit(node)

def merge_to_matrix(self, node):
# Replace detected subtree with a matrix expression
q = np.array([1, 2]) # Example q vector
A = np.array([[1, 2], [3, 4]]) # Example A matrix
result = q.T @ A @ q
print(f"Compressed to: {result}")
return ast.Constant(value=result)


# Example code AST (simulated distributed computation)
code = """
w0 = q[0] * A[0][0] + q[1] * A[0][1]
w1 = q[0] * A[1][0] + q[1] * A[1][1]
result = w0 * q[0] + w1 * q[1]
"""
tree = ast.parse(code)

# Optimize AST
optimizer = MatrixOptimizer()
optimized_tree = optimizer.visit(tree)

# Display the optimized code
optimized_code = ast.unparse(optimized_tree)
print(optimized_code)
25 changes: 25 additions & 0 deletions examples/print_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import casadi as cs
from visast import generate, visualise

from jaxadi import translate, legacy_translate
import jax

A = cs.SX.sym("A", 2, 2)
q = cs.SX.sym("q", 2)

fn = cs.Function("large_mult", [A, q], [q.T @ A @ q])

jax_fn_equiv = """
@jax.jit
def jax_fn_equiv(A, q):
return q.T @ A @ q.T
"""

jax_fn_text = translate(fn)
jax_fn_text_legacy = legacy_translate(fn)
jax_fn_ast = generate.fromString(jax_fn_text)
jax_fn_ast_ideal = generate.fromString(jax_fn_equiv)
jax_fn_ast_legacy = generate.fromString(jax_fn_text_legacy)
visualise.graph(jax_fn_ast)
visualise.graph(jax_fn_ast_legacy)
visualise.graph(jax_fn_ast_ideal)
1 change: 1 addition & 0 deletions jaxadi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._compile import lower
from ._convert import convert
from ._translate import translate
from ._legacy_translate import translate as legacy_translate
from ._declare import declare
163 changes: 163 additions & 0 deletions jaxadi/_legacy_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from casadi import (
OP_ACOS,
OP_ACOSH,
OP_ADD,
OP_AND,
OP_ASIN,
OP_ASINH,
OP_ASSIGN,
OP_ATAN,
OP_ATAN2,
OP_ATANH,
OP_CEIL,
OP_CONST,
OP_CONSTPOW,
OP_COPYSIGN,
OP_COS,
OP_COSH,
OP_DIV,
OP_EQ,
OP_ERF,
OP_EXP,
OP_FABS,
OP_FLOOR,
OP_FMAX,
OP_FMIN,
OP_FMOD,
OP_IF_ELSE_ZERO,
OP_INPUT,
OP_INV,
OP_LE,
OP_LOG,
OP_LT,
OP_MUL,
OP_NE,
OP_NEG,
OP_NOT,
OP_OR,
OP_OUTPUT,
OP_POW,
OP_SIGN,
OP_SIN,
OP_SINH,
OP_SQ,
OP_SQRT,
OP_SUB,
OP_TAN,
OP_TANH,
OP_TWICE,
Function,
)

OP_JAX_DICT = {
OP_ASSIGN: "\n work = work.at[{0}].set(work[{1}])",
OP_ADD: "\n work = work.at[{0}].set(work[{1}] + work[{2}])",
OP_SUB: "\n work = work.at[{0}].set(work[{1}] - work[{2}])",
OP_MUL: "\n work = work.at[{0}].set(work[{1}] * work[{2}])",
OP_DIV: "\n work = work.at[{0}].set(work[{1}] / work[{2}])",
OP_NEG: "\n work = work.at[{0}].set(-work[{1}])",
OP_EXP: "\n work = work.at[{0}].set(jnp.exp(work[{1}]))",
OP_LOG: "\n work = work.at[{0}].set(jnp.log(work[{1}]))",
OP_POW: "\n work = work.at[{0}].set(jnp.power(work[{1}], work[{2}]))",
OP_CONSTPOW: "\n work = work.at[{0}].set(jnp.power(work[{1}], work[{2}]))",
OP_SQRT: "\n work = work.at[{0}].set(jnp.sqrt(work[{1}]))",
OP_SQ: "\n work = work.at[{0}].set(work[{1}] * work[{2}])",
OP_TWICE: "\n work = work.at[{0}].set(2 * work[{1}])",
OP_SIN: "\n work = work.at[{0}].set(jnp.sin(work[{1}]))",
OP_COS: "\n work = work.at[{0}].set(jnp.cos(work[{1}]))",
OP_TAN: "\n work = work.at[{0}].set(jnp.tan(work[{1}]))",
OP_ASIN: "\n work = work.at[{0}].set(jnp.arcsin(work[{1}]))",
OP_ACOS: "\n work = work.at[{0}].set(jnp.arccos(work[{1}]))",
OP_ATAN: "\n work = work.at[{0}].set(jnp.arctan(work[{1}]))",
OP_LT: "\n work = work.at[{0}].set(work[{1}] < work[{2}])",
OP_LE: "\n work = work.at[{0}].set(work[{1}] <= work[{2}])",
OP_EQ: "\n work = work.at[{0}].set(work[{1}] == work[{2}])",
OP_NE: "\n work = work.at[{0}].set(work[{1}] != work[{2}])",
OP_NOT: "\n work = work.at[{0}].set(jnp.logical_not(work[{1}]))",
OP_AND: "\n work = work.at[{0}].set(jnp.logical_and(work[{1}], work[{2}]))",
OP_OR: "\n work = work.at[{0}].set(jnp.logical_or(work[{1}], work[{2}]))",
OP_FLOOR: "\n work = work.at[{0}].set(jnp.floor(work[{1}]))",
OP_CEIL: "\n work = work.at[{0}].set(jnp.ceil(work[{1}]))",
OP_FMOD: "\n work = work.at[{0}].set(jnp.fmod(work[{1}], work[{2}]))",
OP_FABS: "\n work = work.at[{0}].set(jnp.abs(work[{1}]))",
OP_SIGN: "\n work = work.at[{0}].set(jnp.sign(work[{1}]))",
OP_COPYSIGN: "\n work = work.at[{0}].set(jnp.copysign(work[{1}], work[{2}]))",
OP_IF_ELSE_ZERO: "\n work = work.at[{0}].set(jnp.where(work[{1}] == 0, 0, work[{2}]))",
OP_ERF: "\n work = work.at[{0}].set(jax.scipy.special.erf(work[{1}]))",
OP_FMIN: "\n work = work.at[{0}].set(jnp.minimum(work[{1}], work[{2}]))",
OP_FMAX: "\n work = work.at[{0}].set(jnp.maximum(work[{1}], work[{2}]))",
OP_INV: "\n work = work.at[{0}].set(1.0 / work[{1}])",
OP_SINH: "\n work = work.at[{0}].set(jnp.sinh(work[{1}]))",
OP_COSH: "\n work = work.at[{0}].set(jnp.cosh(work[{1}]))",
OP_TANH: "\n work = work.at[{0}].set(jnp.tanh(work[{1}]))",
OP_ASINH: "\n work = work.at[{0}].set(jnp.arcsinh(work[{1}]))",
OP_ACOSH: "\n work = work.at[{0}].set(jnp.arccosh(work[{1}]))",
OP_ATANH: "\n work = work.at[{0}].set(jnp.arctanh(work[{1}]))",
OP_ATAN2: "\n work = work.at[{0}].set(jnp.arctan2(work[{1}], work[{2}]))",
OP_CONST: "\n work = work.at[{0}].set({1:.16f})",
OP_INPUT: "\n work = work.at[{0}].set(inputs[{1}][{2}, {3}])",
OP_OUTPUT: "\n outputs[{0}] = outputs[{0}].at[{1}, {2}].set(work[{3}][0])",
}


def translate(func: Function, add_jit=False, add_import=False) -> str:
# Get information about Casadi function
n_instr = func.n_instructions()
n_out = func.n_out() # number of outputs in the function
n_in = func.n_in() # number of outputs in the function

# Get the shapes of input and output
out_shapes = [func.size_out(i) for i in range(n_out)]
in_shapes = [func.size_in(i) for i in range(n_in)]

# Number of work variables
n_w = func.sz_w()

input_idx = [func.instruction_input(i) for i in range(n_instr)]
output_idx = [func.instruction_output(i) for i in range(n_instr)]
operations = [func.instruction_id(i) for i in range(n_instr)]
const_instr = [func.instruction_constant(i) for i in range(n_instr)]

# Generate string with complete code
codegen = ""
if add_import:
codegen += "import jax\nimport jax.numpy as jnp\n\n"
codegen += "@jax.jit\n" if add_jit else ""
codegen += f"def evaluate_{func.name()}(*args):\n"
codegen += " inputs = args\n" # Combine all inputs into a single list
# Output variables
codegen += f" outputs = [jnp.zeros(out) for out in {out_shapes}]\n"
codegen += f" work = jnp.zeros(({n_w}, 1))\n" # Work variables

for k in range(n_instr):
op = operations[k]
o_idx = output_idx[k]
i_idx = input_idx[k]
if op == OP_CONST:
codegen += OP_JAX_DICT[op].format(o_idx[0], const_instr[k])
elif op == OP_INPUT:
this_shape = in_shapes[i_idx[0]]
rows, cols = this_shape # Get the shape of the output
row_number = i_idx[1] % rows # Compute row index for JAX
column_number = i_idx[1] // rows # Compute column index for JAX
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], row_number, column_number)
elif op == OP_OUTPUT:
# Fix for OP_OUTPUT
rows, cols = out_shapes[o_idx[0]] # Get the shape of the output matrix
# Adjust the calculation to switch from CasADi's column-major to JAX's row-major
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
codegen += OP_JAX_DICT[op].format(o_idx[0], row_number, column_number, i_idx[0])
elif op == OP_SQ:
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], i_idx[0])
elif OP_JAX_DICT[op].count("{") == 3:
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0], i_idx[1])
elif OP_JAX_DICT[op].count("{") == 2:
codegen += OP_JAX_DICT[op].format(o_idx[0], i_idx[0])
else:
raise Exception("Unknown CasADi operation: " + str(op))

# Footer
codegen += "\n return outputs\n"

return codegen

0 comments on commit 6c15dac

Please sign in to comment.