dev: Consider equinox.Module (or similar) in place of jaxutils.PyTree #16
Thanks @patrick-kidger! I was not aware of this, and I much appreciate you opening this issue! The framework provided by Equinox looks neat! I have not yet dug into your source code, but if we were to adopt Equinox as a backend for the JaxGP ecosystem, I have some preliminary questions:

Thanks, Dan :)
Glad you're interested!
Thanks for such a detailed comment, @patrick-kidger! We discussed moving to Equinox as our base module, and we all agree that it is the way to go! The code is straightforward to migrate. I quite like that we can detach loss functions from the objects. The only issue I have is dealing with parameter transformations. I have had a go playing around with the metadata, and wondered if you know of a nice way of building a PyTree of bijectors with the same structure as the model? Or is there a way of accessing the metadata during a tree map?

For the short term, to get things rolling, we may go for a Distrax wrapper, as you suggested. But we're keen to get a new framework up and running ASAP, one that particularly facilitates infrastructure for fast variational inference, which is seemingly absent from the TFP ecosystem. If we can, minimally, get this running for multivariate Gaussians and Bernoulli distributions, then that would complete our Equinox transition. 🙌

I am currently away this week, but would you be free to meet the week commencing 6th Feb?

Thanks, Dan :)
Excellent! That's all really great to hear.

In terms of parameter transformations, you've got several options. Here's one simple approach:

```python
from typing import Callable

import equinox as eqx
import jax.tree_util as jtu
from jax import Array


class Bijector(eqx.Module):
    param: Array  # unconstrained parameter
    transform: Callable  # maps `param` to its constrained value


def is_bijector(x):
    return isinstance(x, Bijector)


def resolve_bijector(x):
    if is_bijector(x):
        return x.transform(x.param)
    else:
        return x


def resolve_all_bijectors(tree):
    # Treat each Bijector as a leaf, so it is replaced whole by its transformed value.
    return jtu.tree_map(resolve_bijector, tree, is_leaf=is_bijector)


model = ...  # some pytree with Bijectors on its nodes. It can have nodes of other things too.


@eqx.filter_jit  # jax.jit only works if every leaf is an array; `transform` is a function
@eqx.filter_grad  # likewise for jax.grad
def make_step(model, ...):  # pass the untransformed parameters across the `grad` API boundary, to get gradients wrt parameters in their original space
    model = resolve_all_bijectors(model)
    ...
```

If you control the model class, you can also resolve the bijectors inside it:

```python
class Model(eqx.Module):
    ...

    def __call__(self, ...):
        self = resolve_all_bijectors(self)
        ...
```

(One can find other approaches too, but this is one of the more common patterns.)

Meeting: sure thing. I'm about to leave for two weeks (hiking Kilimanjaro, should be fun) but would be free from the 16th Feb. Send me an email (via my website) and we'll schedule something.
I just stumbled across `jaxutils`, and spotted `jaxutils.PyTree`. I can see that this is based off of Distrax's `Jittable` base class. I wanted to give a heads-up that Distrax's approach has some performance issues, and some compatibility issues. So I'd really recommend against using it.
Equinox has an `equinox.Module`, which accomplishes the same thing (registering a class as a pytree) and also automatically handles a lot of edge cases (e.g. bound methods are pytrees too; multiple inheritance works smoothly; good performance; pretty-printing; etc.). I realise I am advertising my own library here... but hopefully it's of interest!