[nnx] add cached_partial #4469

Open: wants to merge 1 commit into main
Conversation

cgarciae (Collaborator) commented Jan 1, 2025

What does this PR do?

  • Optimizes nnx.jit.
  • Fixes nanobind setup for flaxlib (unreleased)
  • Adds the FLAX_USE_FLAXLIB flag.
  • Adds the nnx.cached_partial API to cache the traversal of NNX objects; this API yields performance on par with split/merge.
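The caching idea behind nnx.cached_partial can be sketched in plain Python. This is an illustrative stand-in only, not Flax's implementation: the expensive graph traversal is done once when the partial is created, and every subsequent call reuses the cached result. All names below are hypothetical.

```python
# Sketch: pay the traversal cost once, close over the result, and
# reuse it on every call instead of re-traversing per call.
import functools

def cached_partial(fn, obj):
  # Expensive traversal done once, at partial-creation time.
  flat = sorted(obj.items())  # stand-in for graph flattening

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    return fn(flat, *args, **kwargs)

  return wrapper

def apply(flat, scale):
  return [(k, v * scale) for k, v in flat]

model = {'w': 2, 'b': 1}
fast_apply = cached_partial(apply, model)
print(fast_apply(3))  # [('b', 3), ('w', 6)]
```

The design point is the same one the PR description makes: amortizing the traversal gets call overhead down to roughly what a manual split/merge workflow costs.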

@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 9 times, most recently from ad1a9a6 to 78be9e9 Compare January 3, 2025 23:32
@cgarciae cgarciae changed the base branch from nnx-cache-flatten to main January 14, 2025 14:46
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch from 78be9e9 to 2ef1999 Compare January 14, 2025 14:53
@cgarciae cgarciae marked this pull request as ready for review January 14, 2025 14:53
@cgarciae cgarciae changed the title [nnx] add flaxlib [nnx] add cache_args Jan 14, 2025
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch from 2ef1999 to 89aee13 Compare January 14, 2025 21:10

@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 15 times, most recently from 186ca31 to e421e3c Compare January 17, 2025 09:42
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 5 times, most recently from 035857e to 8c13c94 Compare January 24, 2025 00:23
@cgarciae cgarciae changed the title [nnx] add cache_args [nnx] add partial_cache Jan 24, 2025
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch from 8c13c94 to dd5755a Compare January 25, 2025 02:27
@cgarciae cgarciae changed the title [nnx] add partial_cache [nnx] add cached_partial Jan 25, 2025
IvyZX (Collaborator) left a comment:
Thanks for working on this massive change! Glad we got a solution for this.

@@ -22,6 +22,7 @@


class Config:
flax_use_flaxlib: bool
Collaborator:
Do you still need this predefined attribute if you already have the bool_flag(...) line below?
If yes, does it make sense to add other flags here too?

cgarciae (Author):
Python doesn't treat this as an attribute because Config is not a dataclass; it's a pure type hint.
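The distinction being made here can be demonstrated in a few lines: a bare annotation on a plain class creates no class attribute, it only records the name in `__annotations__` (a dataclass decorator is what would turn it into a real field).

```python
# A bare annotation on a plain (non-dataclass) class is recorded in
# __annotations__ but does not create an attribute on the class.
class Config:
  flax_use_flaxlib: bool  # pure type hint, no value assigned

print('flax_use_flaxlib' in Config.__annotations__)  # True
print(hasattr(Config, 'flax_use_flaxlib'))           # False
```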

node_dict = graph.get_node_impl(self).node_dict(self)
node_impl = graph.get_node_impl(self)
if node_impl is None:
raise RuntimeError(f'Unsupported type: {type(self)}, this is a bug.')
Collaborator:
Can we have a better error message here?

cgarciae (Author):
Actually, changing this to an assert.
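The change discussed above, sketched with hypothetical helpers: an explicit None check that raised RuntimeError becomes an assert. The trade-off is that asserts are stripped under `python -O`, which makes them a fit for internal "this is a bug" invariants rather than user-facing validation.

```python
# Sketch only: get_node_impl here is a toy stand-in, not Flax's.
def get_node_impl(node):
  # Only dicts are "supported" node types in this sketch.
  return dict if isinstance(node, dict) else None

def flatten(node):
  node_impl = get_node_impl(node)
  # Internal invariant: a None impl means a bug, not bad user input.
  assert node_impl is not None, (
      f'Unsupported type: {type(node)}, this is a bug.'
  )
  return list(node.items())

print(flatten({'a': 1}))  # [('a', 1)]
```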

node_dict = graph.get_node_impl(node).node_dict(node)
node_impl = graph.get_node_impl(node)
if node_impl is None:
raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')
Collaborator:
And also a better message here.

cgarciae (Author):
Changed to an assert.

attributes={'jax_trace': self._jax_trace},
path=path,
subtree_renderer=subtree_renderer,
import treescope # type: ignore[import-not-found,import-untyped]
Collaborator:
Style nit: Can imports be moved to top level? We do top-level import of treescope in a bunch of files so adding this one won't make a difference.

@@ -250,7 +250,9 @@ def test_nnx_to_linen(self):
assert y.shape == (1, 64)
np.testing.assert_allclose(y, x @ variables['params']['kernel'])
assert 'nnx' in variables
assert isinstance(variables['nnx']['graphdef'], nnx.GraphDef)
assert isinstance(
variables['nnx']['graphdef'], nnx.graph.NodeDef | nnx.graph.NodeRef
Collaborator:
Should we have a type annotation for nnx.graph.NodeDef | nnx.graph.NodeRef?

cgarciae (Author):
GraphDef is now defined as GraphDef = tp.Union[NodeDef[Node], NodeRef[Node]], but isinstance doesn't accept it because it's a parameterized Generic.
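The runtime restriction being referenced is easy to reproduce with toy classes (these mirror the shape of NodeDef/NodeRef but are not Flax's definitions): isinstance rejects subscripted generics, so the test checks against the bare origin classes instead.

```python
# isinstance cannot take subscripted generics; bare classes work.
import typing as tp

T = tp.TypeVar('T')

class NodeDef(tp.Generic[T]): ...
class NodeRef(tp.Generic[T]): ...

# A Union alias like GraphDef is fine for annotations...
GraphDef = tp.Union[NodeDef[tp.Any], NodeRef[tp.Any]]

x = NodeDef()
try:
  isinstance(x, NodeDef[int])  # rejected at runtime
except TypeError as e:
  print(type(e).__name__)      # TypeError

# ...but runtime checks must use the unparameterized classes:
print(isinstance(x, (NodeDef, NodeRef)))  # True
```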

@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch 5 times, most recently from 95c54e5 to 7aea7d9 Compare February 4, 2025 18:03
@cgarciae cgarciae force-pushed the nnx-optimize-fingerprint branch from 7aea7d9 to 48c59d3 Compare February 4, 2025 18:09