Skip to content

Commit

Permalink
feat: graph expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
mattephi committed Oct 21, 2024
1 parent d2ca1ad commit e9c8151
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 7 deletions.
4 changes: 2 additions & 2 deletions jaxadi/_expand.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Any, Dict
from ._ops import OP_JAX_VALUE_DICT
from ._ops import OP_JAX_EXPAND_VALUE_DICT as OP_JAX_VALUE_DICT
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
import re
from multiprocessing import Pool, cpu_count
Expand Down Expand Up @@ -141,7 +141,7 @@ def combine_outputs(stages: List[Stage]) -> str:
rows = "[" + ", ".join(row_indices) + "]"
columns = "[" + ", ".join(column_indices) + "]"
values_str = ", ".join(values)
command = f" o[{output_idx}] = o[{output_idx}].at[({rows}, {columns})].set([{values_str}])"
command = f" outputs[{output_idx}] = outputs[{output_idx}].at[({rows}, {columns})].set([{values_str}])"
commands.append(command)

# Combine all the commands into a single string
Expand Down
97 changes: 92 additions & 5 deletions jaxadi/_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,32 @@
compression/fusion if necessary/possible
"""

from casadi import Function
from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function
import random
import re
from collections import deque

from casadi import OP_CONST, OP_INPUT, OP_OUTPUT, OP_SQ, Function

from ._ops import OP_JAX_VALUE_DICT


def test_and_compress(s):
# Step 1: Check if the string has the desired form using a regex
pattern = re.compile(r"\[\s*work\[(\d+)\]\s*\*\s*work\[(\d+)\](?:\s*,\s*work\[(\d+)\]\s*\*\s*work\[(\d+)\])*\s*\]")

if not pattern.fullmatch(s.strip()):
return s

# Step 2.1: Extract the indices from the matches
matches = re.findall(r"work\[(\d+)\]\s*\*\s*work\[(\d+)\]", s)
left_indices = [int(m[0]) for m in matches]
right_indices = [int(m[1]) for m in matches]

# Construct the compressed string
compressed_string = f"jnp.multiply(work[jnp.array({left_indices})], work[jnp.array({right_indices})])"
return compressed_string


def sort_by_height(graph, antigraph, heights):
nodes = [[] for i in range(max(heights) + 1)]
for i, h in enumerate(heights):
Expand All @@ -27,6 +46,8 @@ def codegen(graph, antigraph, heights, output_map, values):
indices = []
assignment = "["
for node in layer:
if len(graph[node]) == 0 and not node in output_map:
continue
if node in output_map:
oo = output_map[node]
if outputs.get(oo[0], None) is None:
Expand All @@ -39,7 +60,11 @@ def codegen(graph, antigraph, heights, output_map, values):
assignment += ", "
assignment += values[node]
indices += [node]
code += f" work = work.at[jnp.array({indices})].set({assignment}])\n"
if len(indices) == 0:
continue
assignment += "]"
# assignment = test_and_compress(assignment)
code += f" work = work.at[jnp.array({indices})].set({assignment})\n"

for k, v in outputs.items():
code += f" outputs[{k}] = outputs[{k}].at[({v['rows']}, {v['cols']})].set([{', '.join(v['values'])}])\n"
Expand Down Expand Up @@ -97,10 +122,20 @@ def create_graph(func: Function):
workers[o_idx[0]] = i
elif op == OP_OUTPUT:
rows, cols = func.size_out(o_idx[0])
row_number = o_idx[1] % rows # Compute row index for JAX
column_number = o_idx[1] // rows # Compute column index for JAX
row_number = o_idx[1] % rows
column_number = o_idx[1] // rows
output_map[i] = (o_idx[0], row_number, column_number)
values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])

# Update the graph: add this output node as a child of its input (work node)
parent = workers[i_idx[0]]
graph[parent].append(i)
antigraph[i].append(parent)
# rows, cols = func.size_out(o_idx[0])
# row_number = o_idx[1] % rows # Compute row index for JAX
# column_number = o_idx[1] // rows # Compute column index for JAX
# output_map[i] = (o_idx[0], row_number, column_number)
# values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
elif op == OP_SQ:
values[i] = OP_JAX_VALUE_DICT[op].format(workers[i_idx[0]])
graph[workers[i_idx[0]]].append(i)
Expand All @@ -125,8 +160,60 @@ def create_graph(func: Function):
return graph, antigraph, output_map, values


def expand_graph(func, graph, antigraph, output_map, values):
heights = compute_heights(func, graph, antigraph)
sorted_nodes = sort_by_height(graph, antigraph, heights)

# Calculate the average number of vertices per layer
total_vertices = sum(len(layer) for layer in sorted_nodes)
avg_vertices = total_vertices / len(sorted_nodes)

new_graph = [[] for _ in range(len(graph))]
new_antigraph = [[] for _ in range(len(antigraph))]

# Iterate over layers
for layer in sorted_nodes:
expand_layer = len(layer) < avg_vertices and not any(node in output_map for node in layer)

if expand_layer:
# Expand nodes and update their values
for node in layer:
value_expr = values[node]
expanded_expr = re.sub(r"work\[(\d+)\]", lambda m: f"({values[int(m.group(1))]})", value_expr)
values[node] = expanded_expr

# Recalculate dependencies for expanded nodes
for node in layer:
new_parents = set()
for parent in antigraph[node]:
new_parents.update(new_antigraph[parent]) # Use updated parents

# Update new_antigraph and new_graph accordingly
new_antigraph[node] = list(new_parents)

for new_parent in new_parents:
new_graph[new_parent].append(node)

# Retain the original child relationships
for child in graph[node]:
new_graph[node].append(child)
new_antigraph[child].append(node)
else:
# Maintain existing connections for nodes without expansion
for node in layer:
for parent in antigraph[node]:
new_graph[parent].append(node)
new_antigraph[node].append(parent)
for child in graph[node]:
new_graph[node].append(child)
new_antigraph[child].append(node)

return new_graph, new_antigraph, output_map, values


def translate(func: Function, add_jit=False, add_import=False):
graph, antigraph, output_map, values = create_graph(func)
graph, antigraph, output_map, values = expand_graph(func, graph, antigraph, output_map, values)
heights = compute_heights(func, graph, antigraph)

code = ""
Expand Down
49 changes: 49 additions & 0 deletions jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,52 @@
OP_INPUT: "inputs[{0}][{1}, {2}]",
OP_OUTPUT: "work[{0}][0]",
}
OP_JAX_EXPAND_VALUE_DICT = {
OP_ASSIGN: "{0}",
OP_ADD: "{0} + {1}",
OP_SUB: "{0} - {1}",
OP_MUL: "{0} * {1}",
OP_DIV: "{0} / {1}",
OP_NEG: "-{0}",
OP_EXP: "jnp.exp({0})",
OP_LOG: "jnp.log({0})",
OP_POW: "jnp.power({0}, {1})",
OP_CONSTPOW: "jnp.power({0}, {1})",
OP_SQRT: "jnp.sqrt({0})",
OP_SQ: "{0} * {0}",
OP_TWICE: "2 * {0}",
OP_SIN: "jnp.sin({0})",
OP_COS: "jnp.cos({0})",
OP_TAN: "jnp.tan({0})",
OP_ASIN: "jnp.arcsin({0})",
OP_ACOS: "jnp.arccos({0})",
OP_ATAN: "jnp.arctan({0})",
OP_LT: "{0} < {1}",
OP_LE: "{0} <= {1}",
OP_EQ: "{0} == {1}",
OP_NE: "{0} != {1}",
OP_NOT: "jnp.logical_not({0})",
OP_AND: "jnp.logical_and({0}, {1})",
OP_OR: "jnp.logical_or({0}, {1})",
OP_FLOOR: "jnp.floor({0})",
OP_CEIL: "jnp.ceil({0})",
OP_FMOD: "jnp.fmod({0}, {1})",
OP_FABS: "jnp.abs({0})",
OP_SIGN: "jnp.sign({0})",
OP_COPYSIGN: "jnp.copysign({0}, {1})",
OP_IF_ELSE_ZERO: "jnp.where({0} == 0, 0, {1})",
OP_ERF: "jax.scipy.special.erf({0})",
OP_FMIN: "jnp.minimum({0}, {1})",
OP_FMAX: "jnp.maximum({0}, {1})",
OP_INV: "1.0 / {0}",
OP_SINH: "jnp.sinh({0})",
OP_COSH: "jnp.cosh({0})",
OP_TANH: "jnp.tanh({0})",
OP_ASINH: "jnp.arcsinh({0})",
OP_ACOSH: "jnp.arccosh({0})",
OP_ATANH: "jnp.arctanh({0})",
OP_ATAN2: "jnp.arctan2({0}, {1})",
OP_CONST: "{0:.16f}",
OP_INPUT: "inputs[{0}][{1}, {2}]",
OP_OUTPUT: "{0}[0]",
}

0 comments on commit e9c8151

Please sign in to comment.