Support for f-strings in debug.print #26296

Open
kcdodd opened this issue Feb 4, 2025 · 4 comments
Labels: enhancement (New feature or request)

kcdodd commented Feb 4, 2025

Here is an idea for supporting f-strings in jax.debug.print as an alternative to passing a format template plus arguments. It may not be a fully formed idea, but I wanted to see whether there is any interest in it, or whether it has been considered before. To demonstrate some possible behavior (not a robust implementation, and currently only for DynamicJaxprTracer), here is an override of the __format__ special method, which is called by string formatting (format(), str.format and f-strings) but not by repr() or str().

The idea is that __format__ returns an ID for the tracer instead of the normal repr, wrapped in braces so that it is itself a format placeholder (i.e. the f-string actually produces a template that refers to the tracer). debug_print then recovers the tracer IDs by formatting that template against a namespace mapping: each key lookup finds the matching tracer in the current trace and adds it to the kwargs passed to the callback, so that when the template is evaluated on the backend, the placeholder pulls in the array instead.
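
The same two-pass mechanism can be sketched in plain Python, without any JAX internals. This is a hypothetical illustration only: `Placeholder` stands in for a tracer, `RecordingNamespace` for the `_namespace` mapping, and a list of candidate objects for the current trace. None of these are JAX APIs.

```python
class Placeholder:
    """Hypothetical stand-in for a tracer wrapping a value that is known only later."""

    def __init__(self, value):
        self.value = value

    def __format__(self, spec):
        # Emit "{Placeholder-<id>}" so the f-string result is itself a template.
        return "{" + f"{type(self).__name__}-{id(self)}" + "}"


class RecordingNamespace:
    """Hypothetical mapping for str.format_map that records the placeholders used."""

    def __init__(self, candidates):
        self.by_id = {id(c): c for c in candidates}  # stand-in for the current trace
        self.found = {}

    def __getitem__(self, key):
        _name, _id = key.rsplit("-", 1)
        obj = self.by_id.get(int(_id))
        if obj is None:
            raise KeyError(f"Placeholder key not found: {key}")
        self.found[key] = obj
        return key  # the formatted text is discarded; only the lookup matters


def render(template, candidates):
    # Pass 1 ("trace time"): discover which placeholders the template refers to.
    ns = RecordingNamespace(candidates)
    template.format_map(ns)
    # Pass 2 ("backend"): substitute the concrete values for those keys.
    return template.format_map({k: v.value for k, v in ns.found.items()})


p = Placeholder(42)
print(render(f"answer: {p}", [p]))  # answer: 42
# !r calls repr() and never reaches __format__, so no placeholder is produced.
print(render(f"answer: {p!r}", [p]))
```

Two limitations follow directly from this design: any literal `{{`/`}}` in the f-string becomes a real brace in the template and is re-parsed as a placeholder in pass 1, and `!r` conversions bypass `__format__` entirely, which matches the "E" lines in the example output.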

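Some example output is shown below; "traced-D" is the only case in which the traced array is actually mapped. A and C go through a plain print (A prints the raw placeholder, and C formats a list, which uses the element reprs), while B and E use !r, which calls repr() and bypasses __format__. This is clearly much more limited than passing values as arguments, since the values are not flattened/unflattened before formatting.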

Demonstration code:

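class DynamicJaxprTracer(core.Tracer):
  ...
  def __format__(self, spec):
    # Return "{<key>}" instead of a formatted value, so an f-string that
    # interpolates this tracer produces a format template rather than text.
    key = f"{type(self).__name__}-{id(self)}"
    return f"{{{key}}}"


class _namespace:
  """Mapping used with str.format_map to find the tracers a template refers to."""

  def __init__(self, trace):
    self.trace = trace
    self.tracer_lookup = {}

  def __getitem__(self, key):
    _name, _id = key.split('-')

    _id = int(_id)
    # Look for a tracer with this id among those in the current trace.
    for v in self.trace.frame.tracers:
      if id(v) == _id:
        self.tracer_lookup[key] = v
        return key

    raise KeyError(f"Tracer key not found in current trace: {key}")


def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
  trace = core.trace_ctx.trace

  if args or kwargs or type(trace) is not pe.DynamicJaxprTrace:
    # Existing path: check that the explicit arguments match the template.
    formatter.format(fmt, *args, **kwargs)
  else:
    # f-string path: collect the tracers referenced by the template.
    ns = _namespace(trace)
    # Formatting forces a lookup of every placeholder key in the template.
    fmt.format_map(ns)
    # The tracers that were found become the keyword arguments.
    kwargs = ns.tracer_lookup

  ...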

Example:

import jax
import jax.numpy as jnp
import numpy as np

# Assumes the patched DynamicJaxprTracer.__format__ and debug_print above.
def f(case, x):
  print(f"{case}-A: {x}")
  print(f"{case}-B: {x!r}")
  print(f"{case}-C: {[x, x]}")
  jax.debug.print(f"{case}-D: {x}")
  jax.debug.print(f"{case}-E: {x!r}")
  return x**2


x = np.zeros((3,4), jnp.float32)
jax.jit(f, static_argnums=[0])('traced', x)
f('direct', x)

Example output:

traced-A: {DynamicJaxprTracer-140107375262464}
traced-B: Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>
traced-C: [Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>, Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>]
traced-D: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
traced-E: Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace>
direct-A: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
direct-B: array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
direct-C: [array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)]
direct-D: [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
direct-E: array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)

kcdodd added the enhancement (New feature or request) label on Feb 4, 2025

jakevdp (Collaborator) commented Feb 5, 2025

Thanks for the suggestion! That's a clever idea, though one downside is that if a tracer ends up in a regular format statement, the output would be pretty inscrutable. I suspect we probably wouldn't add this sort of approach to JAX by default, but I'll ping @sharadmv, who first implemented debug.print, to get his thoughts.

sharadmv (Collaborator) commented Feb 5, 2025

I actually tried this exact idea in early prototypes of debug.print. It does work, with one major downside: if the variable you're tracking leaves scope and gets freed, you get crazy errors.

A common failure mode would be doing something like jax.debug.print(f"{x + 1}"), where x + 1 is materialized just for the print and then not reused, so it is freed.
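
A pure-Python analogy for this failure mode, for illustration only (JAX's actual tracer bookkeeping may differ): the template string retains just the object's id, not a reference, so nothing keeps a temporary created inside the f-string alive. `Temp`, its weak `registry`, and `key_of` are hypothetical names.

```python
import gc
import weakref


class Temp:
    """Hypothetical tracer stand-in; `registry` maps id -> live instance, held weakly."""

    registry = weakref.WeakValueDictionary()

    def __format__(self, spec):
        Temp.registry[id(self)] = self
        return "{" + f"Temp-{id(self)}" + "}"


def key_of(template):
    # Extract the integer id from a "{Temp-<id>}" template.
    return int(template.strip("{}").rsplit("-", 1)[1])


kept = Temp()
kept_template = f"{kept}"    # `kept` is still referenced after formatting
temp_template = f"{Temp()}"  # like f"{x + 1}": a temporary used only for the print
gc.collect()

print(key_of(kept_template) in Temp.registry)  # True
print(key_of(temp_template) in Temp.registry)  # False: the object behind the id is gone
```

Worse, Python may give a later object the same id() as a freed one, so a lookup keyed only by id could silently match an unrelated object.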

kcdodd (Author) commented Feb 6, 2025

Thank you for taking the time to explain this. I see that f-strings are not really compatible, as-is, with deferred evaluation, and supporting them would probably just introduce more unintuitive behavior. Deferring evaluation effectively requires a closure (e.g. passing separate arguments), except that the values in the closure also have to be swapped out: tuple[Tracer, ...] -> tuple[Array, ...].
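
The difference is easy to see in plain Python: an f-string is evaluated immediately, while a closure reads its captured variables only when called, and rebuilding the function with new closure cells swaps what it sees. A minimal sketch with illustrative names (`make_pair`, `eager`, `lazy`, `swapped`) and an arbitrary replacement value:

```python
from types import CellType, FunctionType


def make_pair():
    x = 1
    eager = f"x={x}"         # formatted now, while x == 1
    lazy = lambda: f"x={x}"  # formatted only when called
    x = 2                    # the closure cell now holds 2
    return eager, lazy


eager, lazy = make_pair()
print(eager)   # x=1
print(lazy())  # x=2

# Illustrative only: swap the captured value (think tracer -> array) by
# rebuilding the function around a fresh cell; `lazy` itself is unchanged.
swapped = FunctionType(lazy.__code__, lazy.__globals__, lazy.__name__,
                       lazy.__defaults__, (CellType(3),))
print(swapped())  # x=3
```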

My only motivation for f-strings (if they were possible) is to avoid the extra work and mental effort of preparing and maintaining a format string and its associated list of arguments on top of the actual debugging; the same overhead arises with debug.callback when more analysis of the runtime values is needed. This may just be an exercise in seeing how much closer one can get to that ideal (and, yes, probably still misguided), but another way to meet the requirements above is to use an actual function closure to maintain the mapping "automatically", without separate arguments. I worked up a demo, debug_closure(), to show conceptually how it might work. To avoid JAX internals for the moment, strings represent "tracers": placeholders in the enclosing scope that have to be remapped to actual values. Example use:

# ugly part, see below
...

_debug_closure = partial(
  debug_closure,
  namespace={'zero': 0, 'one': 1, 'two': 2, 'three': 3})

def g():
  x = 'one'
  y = 'two'
  z = 123456
  _debug_closure(lambda: f"inline: {x+y}, {2*x-y}, {[z, z-x]}")

  @_debug_closure
  def f():
    print("callback:")
    print(f" {x+y=}")
    print(f" {2*x-y=}")
    print(f" {[z, z-x]=}")

g()

Output:

inline: 3, 0, [123456, 123455]
callback:
 x+y=3
 2*x-y=0
 [z, z-x]=[123456, 123455]

Now, the ugly part that makes this work is something like:

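from functools import partial
from types import FunctionType, CellType

def _replace_closure(func, values):
  # Rebuild `func` with the same code, globals, name and defaults,
  # but with fresh closure cells holding `values`.
  return FunctionType(
    func.__code__,
    func.__globals__,
    name = func.__name__,
    argdefs = func.__defaults__,
    closure = tuple(CellType(v) for v in values))


def debug_closure(func, *, namespace):
  # __closure__ is None for functions that capture nothing.
  cells = func.__closure__ or ()
  values = [c.cell_contents for c in cells]

  # A copy of the function with the captured values blanked out.
  empty_func = _replace_closure(func, [None]*len(cells))

  # ...

  # Re-bind the closure, mapping placeholder strings to their actual values.
  renew_func = _replace_closure(empty_func, [
    namespace.get(k) if type(k) is str else k
    for k in values])

  out = renew_func()

  # Inline form (a lambda returning a string): print the returned value.
  if out is not None:
    print(out)

...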

kcdodd (Author) commented Feb 8, 2025

I went ahead and attempted a more formal proof of concept for the "closure" approach (essentially lambda capture). I think this boils down to "flattening" a function that has a closure, so it is more appropriate to store the decomposed form of the function than to reconstitute the function with a dummy closure. Below is an excerpt of a working example with JAX; the full code for StaticClosure is a bit more involved, so I have put it in a gist, debug_closure.py.
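
The gist's StaticClosure is not reproduced here. As a rough, hypothetical sketch of the "flattening" idea only, a zero-argument closure function can be split into its captured values plus a static part that rebuilds and calls the function from new values. The class name `StaticClosureSketch` and its layout are assumptions, chosen only to mirror the `from_function` / `func(closure)` usage in the excerpt.

```python
from types import CellType, FunctionType


class StaticClosureSketch:
    """Hypothetical: the static half of a zero-argument closure function."""

    def __init__(self, func):
        # Keep everything except the captured values.
        self.code = func.__code__
        self.globals = func.__globals__
        self.name = func.__name__
        self.defaults = func.__defaults__

    @classmethod
    def from_function(cls, func):
        # Dynamic half: the captured values, in the code's free-variable order.
        closure = tuple(c.cell_contents for c in (func.__closure__ or ()))
        return closure, cls(func)

    def __call__(self, closure):
        # Rebuild the function around fresh cells holding `closure`, then call it.
        cells = tuple(CellType(v) for v in closure)
        func = FunctionType(self.code, self.globals, self.name, self.defaults, cells)
        return func()


def outer(a, b):
    return lambda: f"{a} + {b} = {a + b}"


closure, static = StaticClosureSketch.from_function(outer(1, 2))
print(static(closure))                         # 1 + 2 = 3
print(static(tuple(10 * v for v in closure)))  # 10 + 20 = 30
```

In the excerpt, the `closure` tuple is passed to `jax.debug.callback` as an ordinary argument, which is what lets JAX hand the callback concrete values in place of the tracers.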

import sys
from functools import partial
from types import FunctionType

import jax
import jax.numpy as jnp
import numpy as np

# StaticClosure is defined in the debug_closure.py gist.


def _debug_closure_callback(func: StaticClosure, closure):
  # Runs on the host with concrete values: rebuild the function and call it.
  out = func(closure)

  if out is not None:
    if type(out) is not str:
      out = repr(out)

    sys.stdout.write(out + "\n")


def debug_closure(func: FunctionType = None, /, *, ordered: bool = False):
  # Support both @debug_closure and @debug_closure(ordered=True).
  if func is None:
    return partial(debug_closure, ordered=ordered)

  # Split into the captured values (which may be tracers) and the static part.
  closure, func = StaticClosure.from_function(func)

  jax.debug.callback(
    _debug_closure_callback,
    func,
    closure,
    ordered=ordered)


def g(x, y, z):
  out = x+y+z

  debug_closure(lambda: "no closure")
  debug_closure(lambda: f"inline: {x+y}, {2*x-y}, {[z, z-x]}")

  @debug_closure
  def f1():
    print("callback:")
    print(f" {x+y=}")
    print(f" {2*x-y+np.pi=}")
    print(f" {[z, z-x]=}")

  @debug_closure(ordered=True)
  def f2():
    print(f"callback(ordered): {out=}")

  return out


x = jnp.array(1.0)
y = jnp.array(2.0)
z = jnp.array(1234.0)

jax.jit(g)(x, y, z)
