Make sharding_tolerance configurable #1058
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Thanks for this idea! Some small refactoring: ideally we keep train.py as small as possible and make this tolerance configurable in the typical config way, via base.yml.
MaxText/train.py
Outdated
@@ -71,6 +71,7 @@
 Transformer = models.Transformer
 EPS = 1e-8
 _DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3
+_DEFAULT_TOLERANCE = 0.02
can you move this to a field in base.yml named e.g. "sharded_tolerance" defaulting to 0.02
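For illustration, a minimal sketch of that suggestion, assuming the field ends up named sharding_tolerance as in the PR title (the reviewer writes "sharded_tolerance"; the exact name and default are for the author to pick):

# In MaxText/configs/base.yml (assumed entry):
#   sharding_tolerance: 0.02
# train.py would then read the value from config instead of a module-level constant:
tolerance = config.sharding_tolerance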
MaxText/train.py
Outdated
@@ -82,6 +83,8 @@ def validate_train_config(config):
   if not config.base_output_directory.startswith("gs://"):
     max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
   assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
+  if "tolerance" in config.__dict__ and (config.tolerance > 1.0 or config.tolerance < 0.0):
Can you move this check to pyconfig.py? Ideally we keep train.py as small as possible
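A hedged sketch of what that check could look like if it moved into pyconfig.py; the function name and the way pyconfig.py wires in validation are assumptions here, not the actual MaxText code:

def validate_sharding_tolerance(raw_keys):
  # The tolerance is a fraction of total parameters, so it must lie in [0.0, 1.0].
  tolerance = raw_keys.get("sharding_tolerance", 0.02)
  assert 0.0 <= tolerance <= 1.0, "sharding_tolerance must be between 0.0 and 1.0"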
MaxText/train.py
Outdated
@@ -550,14 +553,15 @@ def setup_train_loop(config):
   record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None)
   record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
   data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
+  tolerance = config.tolerance if "tolerance" in config.__dict__ else _DEFAULT_TOLERANCE
this shouldn't be necessary when tolerance is part of config
MaxText/train.py
Outdated
   state, _, state_mesh_shardings, data_iterator = max_utils.setup_training_state(
       model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
   )

   if not config.using_pipeline_parallelism:
     # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
-    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)
+    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance)
you can use config.sharded_tolerance instead of tolerance here
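For illustration only, the call site could then drop the local variable and read the value straight from config (field name as the reviewer suggests; the merged change may spell it sharding_tolerance):

if not config.using_pipeline_parallelism:
  # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
  maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, config.sharded_tolerance)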
Thanks for addressing feedback!
Force-pushed from 737a915 to 614fb0b (Compare)
Force-pushed from 614fb0b to 2e0ac9d (Compare)
Using pure (512 DCN) FSDP triggers the MaxText error "Number of unsharded parameters exceeds tolerance 2% of total parameters."
Make the tolerance a configurable parameter to avoid future errors on certain machine setups.
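Once the field exists in base.yml, a user hitting this error could presumably raise the threshold without editing code, e.g. via MaxText's usual key=value command-line overrides (the run name, flag name, and value below are illustrative):

python3 MaxText/train.py MaxText/configs/base.yml run_name=my_run sharding_tolerance=0.05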
Checklist
Before submitting this PR, please make sure (put X in square brackets):