This repository has been archived by the owner on May 11, 2023. It is now read-only.

dev: Consider equinox.Module (or similar) in place of jaxutils.PyTree #16

Open
patrick-kidger opened this issue Jan 21, 2023 · 4 comments
Labels
enhancement New feature or request

Comments

@patrick-kidger

I just stumbled across jaxutils, and spotted jaxutils.PyTree. I can see that this is based off of Distrax's Jittable base class.

I wanted to give a heads-up that Distrax's approach has some performance issues, and some compatibility issues. So I'd really recommend against using it.

Equinox has an equinox.Module which accomplishes the same thing (registering a class as a pytree), and also automatically handles a lot of edge cases. (E.g. bound methods are pytrees too; multiple inheritance works smoothly; good performance; pretty-printing; etc.) I realise I am advertising my own library here... but hopefully it's of interest!

@patrick-kidger patrick-kidger added the enhancement New feature or request label Jan 21, 2023
@daniel-dodd
Member

Thanks @patrick-kidger! I was not aware of this. Much appreciated that you opened this issue!

The framework provided by Equinox looks neat! I have not yet dug into your source code, but if we were to adopt Equinox as a backend for the JaxGP ecosystem, I have some preliminary questions:

  • Parameters: How are these handled in Equinox? We have so far kept everything entirely stateless: we initialise a default set of parameters from each object and pass these through to each method to compute things (e.g. predictions, log likelihoods) each time. Am I right in thinking Equinox stores parameters via the Module class? Additionally, is there functionality to register parameter attributes, e.g. bijector transformations or default prior distributions? These are things that might be nice to improve on in our current JaxGP implementation.

  • Compatibility with Distrax: we are on the same page regarding the Distrax Jittable class issues. However, since we use Distrax distribution objects (for example, when we return predictions for a Gaussian process, we return the corresponding predictive marginal distribution), I was wondering whether you know of any compatibility issues between these objects and the methods provided by Equinox?

Thanks, Dan :)

@patrick-kidger
Author

Glad you're interested!

  • Parameters: an Equinox module is just a PyTree. The leaves of that PyTree are typically model parameters. Thus "initialising a default set of parameters" is just the same thing as "initialising the object".

    In particular everything is stateless: modules are immutable. (Out-of-place updates are possible using equinox.tree_at.)

    Registering parameter attributes: for any static metadata, you can use dataclass fields, e.g.

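    from dataclasses import field

    import equinox as eqx
    import jax.numpy as jnp

    class MyModule(eqx.Module):
        # Static metadata (e.g. a prior) is attached to the dataclass field,
        # not stored in the pytree leaves.
        param: jnp.ndarray = field(metadata={"prior": ...})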

    (An eqx.Module is, in total, a dataclass + pytree registration + misc. utility stuff like pretty-printing.)

    If you have something more complicated (i.e. you want the attributes to be part of the pytree leaves) then that's possible too; lmk if that'd be useful and I can write out more on that. (This exact point will actually also be appearing in the Equinox documentation later this week, as it happens.)

  • Distrax compatibility: unfortunately Distrax is a little buggy on this front, see this Distrax issue. Distrax isn't even really compatible with core JAX here, as its classes declare themselves to be PyTrees but then break the PyTree invariant.

    Two options here.

    1. Heavyweight approach: write your own Distrax-like library as a spin-out library. (At some point this is something I'd quite like to do anyway, as part of the broader Equinox-scientific-computing-ecosystem push -- if you went this way I'd be happy to work together on that.)

    2. Lightweight approach: wrap each Distrax class into something that fixes this bug:

    import equinox as eqx
    from typing import Any
    
    DistraxClass = Any
    
    class FixedDistrax(eqx.Module):
        cls: DistraxClass
        args: tuple[Any, ...]
        kwargs: dict[str, Any]
    
        def __init__(self, cls, *args, **kwargs):
            self.cls = cls
            self.args = args
            self.kwargs = kwargs
    
        def log_prob(self, value):
            # Rebuild the Distrax object on each call, then delegate to it.
            return self.cls(*self.args, **self.kwargs).log_prob(value)

    And then just use e.g. FixedDistrax(distrax.MultivariateNormal, mu, sigma) wherever you were using distrax.MultivariateNormal(mu, sigma) before.

    This works because we now have a FixedDistrax(...) object as our pytree, with subnodes cls, args, kwargs. Whenever we need to call a method, we just instantiate the class and then call the method.

    (Even if you decide against Equinox, I'd advocate still doing your own version of this approach! Otherwise you'd also end up being incompatible with the broader JAX ecosystem via the Distrax dependency.)

@daniel-dodd
Member

Thanks for such a detailed comment, @patrick-kidger!

We discussed moving to Equinox as our base module, and all agree that it is the way to go!

The code is straightforward to migrate. I quite like that we can detach loss functions from the objects.

The only issue I have is dealing with parameter transformations. I have had a go at playing around with the metadata, and wondered whether you know of a nice way of building a PyTree of bijectors with the same structure as the model? Or is there a way of accessing the metadata during a tree map?

For the short term, to get things rolling, we may go for a Distrax wrapper, as you suggested. But we are keen to get a new framework up and running as soon as possible, particularly one that provides the infrastructure for fast variational inference that is seemingly absent from the tfp ecosystem. If we can, minimally, get this running for multivariate Gaussian and Bernoulli distributions, then that would complete our Equinox transition. 🙌

I am currently away this week, but would you be free to meet the week commencing 6th Feb?

Thanks, Dan :)

@patrick-kidger
Author

patrick-kidger commented Feb 1, 2023

Excellent! That's all really great to hear.

In terms of parameter transformations, you've got several options. Here's one simple approach:

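from typing import Callable

import equinox as eqx
import jax
import jax.tree_util as jtu
from jax import Array

class Bijector(eqx.Module):
    param: Array  # parameter in its original (untransformed) space
    transform: Callable  # applied to `param` when the bijector is resolved

def is_bijector(x):
    return isinstance(x, Bijector)

def resolve_bijector(x):
    # Apply the transformation to Bijector nodes; leave everything else untouched.
    if is_bijector(x):
        return x.transform(x.param)
    else:
        return x

def resolve_all_bijectors(tree):
    # `is_leaf` stops tree_map from descending into a Bijector, so each one is handled whole.
    return jtu.tree_map(resolve_bijector, tree, is_leaf=is_bijector)

model = ...  # some pytree with Bijectors on its nodes. It can have nodes of other things too.

# Pass the untransformed parameters across the `grad` API boundary, to get
# gradients wrt parameters in their original space.
# As written, `transform` is a non-array leaf, so the eqx.filter_* versions are
# used here; plain jax.jit / jax.grad require every leaf to be an array.
@eqx.filter_jit
@eqx.filter_grad
def make_step(model, ...):
    model = resolve_all_bijectors(model)
    ...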

If you control model's forward pass then you can resolve the bijectors there too:

class Model(eqx.Module):
    ...

    def __call__(self, ...):
        self = resolve_all_bijectors(self)
        ...

(One can find other approaches too, but this is one of the more common patterns.)

Meeting -- sure thing. I'm about to leave for two weeks (hiking Kilimanjaro, should be fun) but would be free from the 16th Feb. Send me an email (on my website) and we'll schedule something.
