Skip to content

Commit

Permalink
Merge pull request #4385 from emmanuel-ferdman:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702816307
  • Loading branch information
Flax Authors committed Dec 4, 2024
2 parents 7eac051 + 6a4a957 commit 30b438c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions docs/guides/training_techniques/use_checkpointing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@
"\n",
"Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread, while continuing with your training.\n",
"\n",
"You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) in place of the `orbax.checkpoint.PyTreeCheckpointer`.\n",
"You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) in place of the `orbax.checkpoint.PyTreeCheckpointer`.\n",
"\n",
"Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can make sure that a previous async save is done before the next one begins. This enables bookkeeping, such as `keep` (the number of checkpoints) and `overwrite` to be consistent across steps.\n",
"\n",
Expand Down Expand Up @@ -1051,7 +1051,7 @@
"\n",
"In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX `jit`, a large multi-process array can have its data sharded across different devices. (Note that JAX `pjit` and `jit` have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html).) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket.\n",
"\n",
"Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context."
"Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context."
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/guides/training_techniques/use_checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ flax.serialization.from_state_dict(custom_target, raw_state_dict)

Checkpointing is I/O heavy, and if you have a large amount of data to save, it may be worthwhile to put it into a background thread, while continuing with your training.

You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) in place of the `orbax.checkpoint.PyTreeCheckpointer`.
You can do this by creating an [`orbax.checkpoint.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) in place of the `orbax.checkpoint.PyTreeCheckpointer`.

Note: You should use the same `async_checkpointer` to handle all your async saves across your training steps, so that it can make sure that a previous async save is done before the next one begins. This enables bookkeeping, such as `keep` (the number of checkpoints) and `overwrite` to be consistent across steps.

Expand Down Expand Up @@ -389,7 +389,7 @@ JAX provides a few ways to scale up your code on multiple hosts at the same time

In the [Single Program Multi Data (SPMD)](https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD) paradigm with JAX `jit`, a large multi-process array can have its data sharded across different devices. (Note that JAX `pjit` and `jit` have been merged into a single unified interface. To learn about compiling and executing JAX functions in multi-host or multi-core environments, refer to [this guide](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) and the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html).) When a multi-process array is serialized, each host dumps its data shards to a single shared storage, such as a Google Cloud bucket.

Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context.
Orbax supports saving and loading pytrees with multi-process arrays in the same fashion as single-process pytrees. However, it's recommended to use the asynchronized [`orbax.AsyncCheckpointer`](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/async_checkpointer.py) to save large multi-process arrays on another thread, so that you can perform computation alongside the saves. With pure Orbax, saving checkpoints in a multi-process context uses the same API as in a single-process context.

```python
from jax.sharding import PartitionSpec, NamedSharding
Expand Down

0 comments on commit 30b438c

Please sign in to comment.