diff --git a/.gitignore b/.gitignore index 0bc7f3cbe6..a1de4b1dda 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,8 @@ flaxlib_src/build flaxlib_src/builddir flaxlib_src/dist flaxlib_src/subprojects - +target/ +flaxlib.cpython-* # used by direnv .envrc diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py new file mode 100644 index 0000000000..73cff6d6d6 --- /dev/null +++ b/benchmarks/nnx_graph_overhead.py @@ -0,0 +1,118 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import numpy as np +import optax +from time import time + +from flax import nnx + +from absl import flags +from absl import app + +FLAGS = flags.FLAGS +flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') +flags.DEFINE_integer('width', 32, 'Hidden layer size') +flags.DEFINE_integer('depth', 5, 'Depth of the model') + + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.list = [ + nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), + nnx.Param(jnp.zeros((dout,))), + ] + self.dict = { + 'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), + 'b': nnx.Param(jnp.zeros((dout,))), + } + + + +class MLP(nnx.Module): + def __init__(self, depth, *, rngs: nnx.Rngs): + self.intermediates = [ + Linear(10, 10, rngs=rngs) for _ in range(depth) + ] + + +def main(argv): + print(argv) + mode: str = FLAGS.mode + total_steps: int = FLAGS.total_steps + width: int = FLAGS.width + depth: int = FLAGS.depth + + print(f'{mode=}, {total_steps=}, {width=}') + + X = np.linspace(0, 1, 100)[:, None] + Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + model = MLP(depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer = nnx.Optimizer(model, tx) + + #------------------------------------------------------------ + # NNX + #------------------------------------------------------------ + if mode in ['all', 'nnx']: + @nnx.jit + def step_nnx(model: MLP, optimizer: nnx.Optimizer): + pass + + t0 = time() + for _ in range(total_steps): + step_nnx(model, optimizer) + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + print("### NNX ###") + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + + + 
#------------------------------------------------------------ + # JAX + #------------------------------------------------------------ + + if mode in ['all', 'jax']: + @jax.jit + def step_jax(graphdef, state): + return graphdef, state + + graphdef, state = nnx.split((model, optimizer)) + t0 = time() + for _ in range(total_steps): + graphdef, state = step_jax(graphdef, state) + + total_time = time() - t0 + time_per_step = total_time / total_steps + time_per_layer = time_per_step / depth + print("### JAX ###") + print('total time:', total_time) + print(f'time per step: {time_per_step * 1e6:.2f} µs') + print(f'time per layer: {time_per_layer * 1e6:.2f} µs') + print() + + + +if __name__ == '__main__': + app.run(main) diff --git a/benchmarks/nnx_simple_training.py b/benchmarks/nnx_simple_training.py new file mode 100644 index 0000000000..0cb08066fe --- /dev/null +++ b/benchmarks/nnx_simple_training.py @@ -0,0 +1,168 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# %% +import jax +import jax.numpy as jnp +import numpy as np +import optax +from time import time + +from flax import nnx + +from absl import flags +from absl import app + +FLAGS = flags.FLAGS +flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in') +flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') +flags.DEFINE_integer('batch_size', 32, 'Batch size') +flags.DEFINE_integer('width', 32, 'Hidden layer size') +flags.DEFINE_integer('depth', 5, 'Depth of the model') + + +def dataset(X, Y, batch_size): + while True: + idx = np.random.choice(len(X), size=batch_size) + yield X[idx], Y[idx] + + +class Linear(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) + self.b = nnx.Param(jnp.zeros((dout,))) + + def __call__(self, x): + return x @ self.w + self.b + + +class Count(nnx.Variable): + pass + + +class MLP(nnx.Module): + def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): + self.count = Count(jnp.array(0)) + self.linear_in = Linear(din, dhidden, rngs=rngs) + self.intermediates = [ + Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) + ] + self.linear_out = Linear(dhidden, dout, rngs=rngs) + + def __call__(self, x): + self.count.value += 1 + x = nnx.relu(self.linear_in(x)) + for layer in self.intermediates: + x = nnx.relu(layer(x)) + x = self.linear_out(x) + return x + + +def main(argv): + print(argv) + mode: str = FLAGS.mode + total_steps: int = FLAGS.total_steps + batch_size: int = FLAGS.batch_size + width: int = FLAGS.width + depth: int = FLAGS.depth + + print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') + + if mode not in ['nnx', 'jax']: + raise ValueError(f'Invalid mode: {mode}') + + X = np.linspace(0, 1, 100)[:, None] + Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) + + model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) + tx = optax.sgd(1e-3) + optimizer 
= nnx.Optimizer(model, tx) + t0 = time() + + if mode == 'nnx': + + @nnx.jit + def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads: nnx.State = nnx.grad(loss_fn)(model) + optimizer.update(grads) + + @nnx.jit + def test_step_nnx(model: MLP, batch): + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + return {'loss': loss} + + for step, batch in enumerate(dataset(X, Y, batch_size)): + train_step_nnx(model, optimizer, batch) + + if step % 1000 == 0: + logs = test_step_nnx(model, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + else: + + @jax.jit + def train_step_jax(graphdef, state, batch): + model, optimizer = nnx.merge(graphdef, state) + x, y = batch + + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y - y_pred) ** 2) + + grads = nnx.grad(loss_fn)(model) + optimizer.update(grads) + + return nnx.state((model, optimizer)) + + @jax.jit + def test_step_jax(graphdef, state, batch): + model, optimizer = nnx.merge(graphdef, state) + x, y = batch + y_pred = model(x) + loss = jnp.mean((y - y_pred) ** 2) + state = nnx.state((model, optimizer)) + return state, {'loss': loss} + + graphdef, state = nnx.split((model, optimizer)) + + for step, batch in enumerate(dataset(X, Y, batch_size)): + state = train_step_jax(graphdef, state, batch) + + if step % 1000 == 0: + state, logs = test_step_jax(graphdef, state, (X, Y)) + print(f"step: {step}, loss: {logs['loss']}") + + if step >= total_steps - 1: + break + + model, optimizer = nnx.merge(graphdef, state) + + total_time = time() - t0 + print('total time:', total_time) + print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') + print('times called:', model.count.value) + + +if __name__ == '__main__': + app.run(main) diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 2a92b8b5ad..fec21add20 100644 --- 
a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -94,9 +94,9 @@ def __str__(self) -> str: return repr(self) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): - type: type + type: type[Node] flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] def node_dict(self, node: Node) -> dict[Key, Leaf]: @@ -104,7 +104,7 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]: return dict(nodes) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): set_key: tp.Callable[[Node, Key, Leaf], None] pop_key: tp.Callable[[Node, Key], Leaf] @@ -116,7 +116,7 @@ def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): self.set_key(node, key, value) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node] @@ -126,7 +126,8 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): ] -_node_impl_for_type: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} +GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} +PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {} def register_graph_node_type( @@ -137,7 +138,10 @@ def register_graph_node_type( create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node], None], ): - _node_impl_for_type[type] = GraphNodeImpl( + if type in GRAPH_REGISTRY: + raise ValueError(f'Node type {type} is already registered.') + + GRAPH_REGISTRY[type] = GraphNodeImpl( type=type, flatten=flatten, set_key=set_key, @@ -146,19 +150,30 @@ def register_graph_node_type( clear=clear, ) +def register_pytree_node_type( + type: type, + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], + unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node], +): 
+ if type in PYTREE_REGISTRY: + raise ValueError(f'Node type {type} is already registered.') + + PYTREE_REGISTRY[type] = PytreeNodeImpl( + type=type, flatten=flatten, unflatten=unflatten + ) def is_node(x: tp.Any) -> bool: - if type(x) in _node_impl_for_type: + if type(x) in GRAPH_REGISTRY: return True return is_pytree_node(x) def is_graph_node(x: tp.Any) -> bool: - return type(x) in _node_impl_for_type + return type(x) in GRAPH_REGISTRY def is_node_type(x: type[tp.Any]) -> bool: - return x in _node_impl_for_type or x is PytreeType + return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: @@ -167,19 +182,23 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: node_type = type(x) - if node_type not in _node_impl_for_type: - if is_pytree_node(x): - return PYTREE_NODE_IMPL - else: - raise ValueError(f'Unknown node type: {x}') - - return _node_impl_for_type[node_type] + if node_type in GRAPH_REGISTRY: + return GRAPH_REGISTRY[node_type] + elif node_type in PYTREE_REGISTRY: + return PYTREE_REGISTRY[node_type] + elif is_pytree_node(x): + return PYTREE_NODE_IMPL # type: ignore + else: + raise ValueError(f'Unknown node type: {x}') def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: - if x is PytreeType: - return PYTREE_NODE_IMPL - return _node_impl_for_type[x] + if x is GenericPytree: + return PYTREE_NODE_IMPL # type: ignore + elif x in PYTREE_REGISTRY: + return PYTREE_REGISTRY[x] + else: + return GRAPH_REGISTRY[x] class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): @@ -1751,11 +1770,23 @@ class Static(tp.Generic[A]): # --------------------------------------------------------- # Pytree # --------------------------------------------------------- -class PytreeType: ... +class GenericPytree: ... 
def is_pytree_node(x: tp.Any) -> bool: - return not jax.tree_util.all_leaves((x,)) + t = type(x) + if t in PYTREE_REGISTRY: + return True + elif t in GRAPH_REGISTRY: + return False + # known non-pytree types + elif isinstance(x, Variable): + return False + # known pytree types + elif isinstance(x, (VariableState, State)): + return True + else: + return not jax.tree_util.all_leaves((x,)) def _key_path_to_key(key: tp.Any) -> Key: @@ -1792,7 +1823,33 @@ def _unflatten_pytree( PYTREE_NODE_IMPL = PytreeNodeImpl( - type=PytreeType, + type=GenericPytree, flatten=_flatten_pytree, unflatten=_unflatten_pytree, ) + +# common pytrees +# list: children are keyed by their index +register_pytree_node_type( + list, + flatten=lambda x: (list(enumerate(x)), None), + unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore +) +# tuple: children are keyed by their index +register_pytree_node_type( + tuple, + flatten=lambda x: (list(enumerate(x)), None), + unflatten=lambda nodes, _: tuple(value for _, value in nodes), # type: ignore +) +# dict: items are flattened in sorted order +register_pytree_node_type( + dict, + flatten=lambda x: (sorted(x.items()), None), + unflatten=lambda nodes, _: {key: value for key, value in nodes}, # type: ignore +) +# None: flattens to no children +register_pytree_node_type( + type(None), + flatten=lambda x: ([], None), + unflatten=lambda _, __: None, # type: ignore +) \ No newline at end of file diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 8983acbe7f..fb0496e07a 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -303,7 +303,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.subgraphs['tree'].type is nnx.graph.PytreeType + assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state)