-
Notifications
You must be signed in to change notification settings - Fork 662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] add cached_partial #4469
base: main
Are you sure you want to change the base?
[nnx] add cached_partial #4469
Conversation
ad1a9a6
to
78be9e9
Compare
78be9e9
to
2ef1999
Compare
2ef1999
to
89aee13
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
186ca31
to
e421e3c
Compare
035857e
to
8c13c94
Compare
8c13c94
to
dd5755a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this massive change! Glad we got a solution for this.
@@ -22,6 +22,7 @@ | |||
|
|||
|
|||
class Config: | |||
flax_use_flaxlib: bool |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you still need this predefined attribute if you already have the bool_flag(...)
line below?
If yes, does it make sense to add other flags here too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Python doesn't treat this as an attribute because Config
is not a dataclass, this is a pure type hint.
flax/nnx/module.py
Outdated
node_dict = graph.get_node_impl(self).node_dict(self) | ||
node_impl = graph.get_node_impl(self) | ||
if node_impl is None: | ||
raise RuntimeError(f'Unsupported type: {type(self)}, this is a bug.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have better error message here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, changing this to an assert
flax/nnx/object.py
Outdated
node_dict = graph.get_node_impl(node).node_dict(node) | ||
node_impl = graph.get_node_impl(node) | ||
if node_impl is None: | ||
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And also better message here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed to an assert
flax/nnx/tracers.py
Outdated
attributes={'jax_trace': self._jax_trace}, | ||
path=path, | ||
subtree_renderer=subtree_renderer, | ||
import treescope # type: ignore[import-not-found,import-untyped] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Style nit: Can imports be moved to top level? We do top-level import of treescope
in a bunch of files so adding this one won't make a difference.
@@ -250,7 +250,9 @@ def test_nnx_to_linen(self): | |||
assert y.shape == (1, 64) | |||
np.testing.assert_allclose(y, x @ variables['params']['kernel']) | |||
assert 'nnx' in variables | |||
assert isinstance(variables['nnx']['graphdef'], nnx.GraphDef) | |||
assert isinstance( | |||
variables['nnx']['graphdef'], nnx.graph.NodeDef | nnx.graph.NodeRef |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have a type annotation for nnx.graph.NodeDef | nnx.graph.NodeRef
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GraphDef
is now defined as GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]]
but isinstance
doesn't like it because its Generic.
95c54e5
to
7aea7d9
Compare
7aea7d9
to
48c59d3
Compare
What does this PR do?
nnx.jit
.flaxlib
(unreleased)FLAX_USE_FLAXLIB
flag.nnx.cached_partial
API to cache the traversal or NNX objects, this API yields performance on par withsplit
/merge
.