Support for f-strings in debug.print #26296
Comments
Thanks for the suggestion! That's a clever idea, though one downside is that if a tracer ends up in a regular format statement, the output would be pretty inscrutable. I suspect we would not add this sort of approach by default in JAX, but I'll ping @sharadmv who first implemented …
I actually tried this exact idea in early prototypes of debug print. It does work, with one major downside: if the variable you're tracking leaves scope and gets freed, you get crazy errors. A common failure mode would be doing something like …
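A plain-Python sketch of why a deferred lookup can fail once the value is freed (a generic illustration, not JAX's prototype): the template string only records a key, and the value is looked up later through a weak reference. `Traced`, `placeholder_for`, `make_template`, and `resolve` are hypothetical names.

```python
import gc
import weakref

class Traced:
    """Hypothetical stand-in for a tracer (not a JAX class)."""

# Hypothetical registry: placeholder key -> weak reference to the value.
_registry = {}

def placeholder_for(obj):
    key = f"t{id(obj)}"
    _registry[key] = weakref.ref(obj)  # weak, so the registry doesn't keep obj alive
    return "{" + key + "}"

def make_template():
    tmp = Traced()  # a local that is freed as soon as this function returns
    return f"value: {placeholder_for(tmp)}"  # only the key survives in the string

def resolve(template):
    # Pull the first "{...}" key back out and look up the value it names.
    key = template[template.index("{") + 1 : template.index("}")]
    obj = _registry[key]()
    if obj is None:
        raise LookupError(f"value for {key} was freed before it could be printed")
    return obj

template = make_template()
gc.collect()
try:
    resolve(template)
except LookupError as err:
    print(err)
```

Holding strong references instead would avoid the error but keep every formatted value alive indefinitely, and keying on a bare `id()` without a weak reference is worse: CPython can reuse the id of a freed object, so a stale key may silently resolve to an unrelated value.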
Thank you for taking the time to explain this. Yeah, I see f-strings are not really compatible as-is with deferred evaluation, and they would probably just introduce more non-intuitive behavior. Effectively a closure is needed to defer evaluation, e.g. passing separate …

I think my only motivation for f-strings, if it were possible, is to avoid the extra work and mental effort of preparing and maintaining a format string and its associated list of arguments on top of the actual debugging, which also happens with …

```python
# ugly part, see below
...

_debug_closure = partial(
    debug_closure,
    namespace={'zero': 0, 'one': 1, 'two': 2, 'three': 3})

def g():
    x = 'one'
    y = 'two'
    z = 123456

    _debug_closure(lambda: f"inline: {x+y}, {2*x-y}, {[z, z-x]}")

    @_debug_closure
    def f():
        print("callback:")
        print(f" {x+y=}")
        print(f" {2*x-y=}")
        print(f" {[z, z-x]=}")

g()
```
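Now, the ugly part to get that to work is something like:

```python
from functools import partial
from types import FunctionType, CellType

def _replace_closure(func, values: tuple):
    # Rebuild `func` with fresh cells holding `values`
    # (in the same order as func.__code__.co_freevars).
    return FunctionType(
        func.__code__,
        func.__globals__,
        name=func.__name__,
        argdefs=func.__defaults__,
        closure=tuple(CellType(v) for v in values))

def debug_closure(func, *, namespace):
    # Capture the current closure values, then detach them from the function.
    values = [c.cell_contents for c in func.__closure__]
    empty_func = _replace_closure(func, [None]*len(func.__closure__))
    # ...
    # Re-attach the values, swapping string placeholders for entries in
    # `namespace` (standing in for tracers being replaced by concrete arrays).
    renew_func = _replace_closure(empty_func, [
        namespace.get(k) if type(k) is str else k
        for k in values])
    out = renew_func()
    if out is not None:
        print(out)
...
```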
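The helper above substitutes captured values by *value* (string keys looked up in `namespace`). Because `func.__closure__` is aligned with `func.__code__.co_freevars`, captured variables can also be matched by *name*. A minimal sketch of that variant; `with_closure_values` and `make` are hypothetical names, not part of JAX.

```python
from types import CellType, FunctionType

def with_closure_values(func, **overrides):
    """Return a copy of `func` with some captured variables replaced by name.

    Names in `overrides` that `func` does not capture are ignored.
    """
    names = func.__code__.co_freevars          # captured variable names
    cells = func.__closure__ or ()             # matching cells, same order
    values = [overrides.get(n, c.cell_contents) for n, c in zip(names, cells)]
    closure = tuple(CellType(v) for v in values) if cells else None
    return FunctionType(func.__code__, func.__globals__, func.__name__,
                        func.__defaults__, closure)

def make():
    x, y = 1, 2
    return lambda: f"{x+y=}"

f = make()
print(f())                             # x+y=3
print(with_closure_values(f, x=10)())  # x+y=12
```

The original function is left untouched, since the copy gets its own new cells.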
I went ahead and attempted a more formal proof of concept for the "closure" approach (basically lambda capture). I think this boils down to "flattening" a function that has a closure, so it is more appropriate to store the decomposed form of the function than to reconstitute the function with a dummy closure. Below is an excerpt of a working example with JAX; the full code for …

```python
import sys
from functools import partial
from types import FunctionType

import jax

# StaticClosure is defined in the full code (not included in this excerpt).

def _debug_closure_callback(func: StaticClosure, closure):
    # Runs on the host with concrete values substituted into the closure.
    out = func(closure)
    if out is not None:
        if type(out) is not str:
            out = repr(out)
        sys.stdout.write(out + "\n")

def debug_closure(func: FunctionType = None, /, *, ordered: bool = False):
    # Support both @debug_closure and @debug_closure(ordered=True).
    if func is None:
        return partial(debug_closure, ordered=ordered)

    # Split the function into its captured values and a static remainder.
    closure, func = StaticClosure.from_function(func)

    jax.debug.callback(
        _debug_closure_callback,
        func,
        closure,
        ordered=ordered)
```

```python
import numpy as np
import jax
import jax.numpy as jnp

def g(x, y, z):
    out = x+y+z

    debug_closure(lambda: "no closure")
    debug_closure(lambda: f"inline: {x+y}, {2*x-y}, {[z, z-x]}")

    @debug_closure
    def f1():
        print("callback:")
        print(f" {x+y=}")
        print(f" {2*x-y+np.pi=}")
        print(f" {[z, z-x]=}")

    @debug_closure(ordered=True)
    def f2():
        print(f"callback(ordered): {out=}")
        return out

x = jnp.array(1.0)
y = jnp.array(2.0)
z = jnp.array(1234.0)

jax.jit(g)(x, y, z)
```
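The "flattening" idea can be sketched in plain Python, assuming a function is split into static parts (code, globals, name, defaults) and a tuple of captured values. `split_closure`, `rebuild`, and `make` are hypothetical names; this is not the `StaticClosure` from the full code. In a JAX version, the values tuple is presumably what would travel through `jax.debug.callback` as arguments, while the static parts stay on the Python side.

```python
from types import CellType, FunctionType

def split_closure(func):
    """Decompose `func` into (static parts, captured values)."""
    cells = func.__closure__ or ()
    values = tuple(c.cell_contents for c in cells)
    static = (func.__code__, func.__globals__, func.__name__, func.__defaults__)
    return static, values

def rebuild(static, values):
    """Recreate the function with fresh cells holding `values`."""
    code, globs, name, defaults = static
    closure = tuple(CellType(v) for v in values) if values else None
    return FunctionType(code, globs, name, defaults, closure)

def make(x, y):
    return lambda: f"inline: {x+y}, {2*x-y}"

static, values = split_closure(make(1.0, 2.0))
# Swap in new values (here every captured value scaled by 10), then call.
print(rebuild(static, tuple(v * 10 for v in values))())  # inline: 30.0, 0.0
```

A function with no closure, like `lambda: "no closure"`, splits into an empty values tuple and rebuilds with `closure=None`.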
An idea for supporting f-strings in `jax.debug.print` as an alternative to passing a template+args. This just occurred to me and may not be a fully formed idea, but I wanted to see if there was any interest in it or whether it had been considered before.

Just to demonstrate some possible behavior (not a robust implementation), currently only for `DynamicJaxprTracer`, here is an override for the `__format__` special method, which is called during formatting but not for normal `repr` or `str`.

The idea is that `__format__` would return an id for the tracer instead of the normal repr, and that id is itself interpreted as a string-template placeholder (i.e. the f-string actually generates a template for the tracer). Then, in `debug_print`, the tracer id is extracted by evaluating the template with a namespace mapping that, when keyed, looks up the tracer in the current trace and adds it to the kwargs of the callback, so that when the callback is evaluated on the backend, the template key pulls in the array instead.

There is a bit of example output below, but "traced-D" is the only one where the traced array is actually mapped. Clearly this is much more limited than passing by argument, since the values are not flattened/unflattened before formatting.
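A minimal plain-Python sketch of the `__format__` trick, assuming a stand-in class rather than a real `DynamicJaxprTracer`: formatting the object emits a `str.format` placeholder (keeping any format spec), and the resulting template is filled in later with `str.format_map`. `DeferredValue` and `render` are hypothetical names.

```python
import itertools

class DeferredValue:
    """Hypothetical stand-in for a tracer whose value is only known later."""

    _counter = itertools.count()

    def __init__(self):
        self.key = f"traced_{next(DeferredValue._counter)}"

    def __format__(self, spec):
        # Emit a str.format placeholder instead of a rendered value,
        # carrying the format spec along (e.g. "{traced_0:.2f}").
        return "{" + self.key + (":" + spec if spec else "") + "}"

    def __repr__(self):
        # repr()/str() do not go through __format__, so they stay readable.
        return f"DeferredValue({self.key})"

def render(template, values):
    """Fill the placeholders once concrete values are available."""
    return template.format_map(values)

t = DeferredValue()
template = f"mean: {t:.2f}, raw: {t}"
print(template)                            # mean: {traced_0:.2f}, raw: {traced_0}
print(render(template, {t.key: 3.14159}))  # mean: 3.14, raw: 3.14159
```

This also shows the downside raised in the comments: a plain `print(f"{t}")` outside the deferred path prints the raw placeholder, and any other text in the template that contains literal braces (e.g. a formatted dict) would be misread by `format_map`.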
Demonstration code:
Example:
Example output: