Custom callable function from within the C++ API #1614
-
If we define a function in Python, it can then be passed to mlx.core.compile along with the inputs for the invocation. However, it is not clear to me if there is an API for, say, defining symbolic MLX arrays that correspond to each function argument that would then trace the graph, or how we would deal with tuple results in this workflow. Since MLX is mostly lazy-eval, not being able to achieve this isn't a dealbreaker, as we have the …
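To make the question concrete, here is a minimal sketch of the kind of tracing being asked about. This is not an MLX API; `SymbolicArray` and `trace` are hypothetical names, and the sketch only illustrates how placeholder arguments could record a graph, including tuple results:

```python
# Illustrative only: MLX does not expose this API. A placeholder object
# records operations instead of computing them, so calling the function
# once with placeholders traces its graph.

class SymbolicArray:
    """Hypothetical symbolic array that records the ops applied to it."""
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = inputs

    def __add__(self, other):
        return SymbolicArray(f"add({self.name}, {other.name})", (self, other))

    def __mul__(self, other):
        return SymbolicArray(f"mul({self.name}, {other.name})", (self, other))

def trace(fun, arg_names):
    # One symbolic placeholder per argument; normalize the result to a
    # tuple so functions returning tuples are handled uniformly.
    placeholders = [SymbolicArray(n) for n in arg_names]
    out = fun(*placeholders)
    return out if isinstance(out, tuple) else (out,)

def fun(a, b):
    s = a + b
    return s, s * a  # tuple result

outs = trace(fun, ["a", "b"])
print([o.name for o in outs])
# ['add(a, b)', 'mul(add(a, b), a)']
```

The key point is that the traced graph falls out of ordinary operator overloading, which is also roughly how tracing-based compilers capture a Python function without a separate symbolic-array API.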
Replies: 1 comment 4 replies
-
I'm not familiar at all with Nx and how compilation works there, but I can say a bit more about how it works in MLX:

```python
def fun(a, b, c):
    return a + b + c

# Step 0: Nothing much has happened here yet other than wrapping `fun`
# in another function which knows to compile it
compiled_fun = mx.compile(fun)

# Step 1: The first time the compiled function is called it gets partially
# compiled. We trace the graph using the provided inputs and do some
# optimization passes on the graph
out = compiled_fun(a, b, c)

# Step 2: The rest of the compilation happens the first time you call
# mx.eval. This is where kernel source is actually JIT compiled
mx.eval(out)

# Calling it again on inputs with the same shape and type doesn't recompile
mx.eval(compiled_fun(a, b, c))
```

I'm not sure if it makes sense to do step 0 (wrapping the …

For example, if you are changing the shape of the input from call to call, say increasing the shape of an input by one at each call, then compiling can slow you down. Compiling a large graph can take some time (milliseconds), and for latency-sensitive applications that can add up. You typically want to amortize the cost of compiling over repeated applications of the compiled function.