
Make sharding_tolerance configurable #1058

Merged: 1 commit merged into main on Dec 10, 2024
Conversation

@Doris26 (Collaborator) commented Nov 21, 2024

Using pure (512 DCN) FSDP triggers the MaxText error "Number of unsharded parameters exceeds tolerance 2% of total parameters."

Make the tolerance a configurable parameter to avoid this error on certain machine setups.
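
For context, here is a minimal sketch of the kind of check behind that error message. It is an illustration only, not MaxText's actual assert_params_sufficiently_sharded implementation, and the argument names are assumptions:

def check_sufficiently_sharded(num_unsharded_params: int,
                               num_total_params: int,
                               tolerance: float = 0.02) -> None:
  """Illustrative stand-in: fail if too large a fraction of parameters is unsharded."""
  unsharded_fraction = num_unsharded_params / num_total_params
  assert unsharded_fraction < tolerance, (
      f"Number of unsharded parameters exceeds tolerance {tolerance:.0%} of total parameters."
  )

# A hard-coded 2% cap can trip large pure-FSDP runs even when the layout is
# intentional; exposing the tolerance in config lets such setups raise the cap.
check_sufficiently_sharded(num_unsharded_params=10, num_total_params=1_000)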

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.


google-cla bot commented Nov 21, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Doris26 changed the title from "Tolerance configurable" to "Make tolerance configurable" on Nov 21, 2024
@Doris26 changed the title from "Make tolerance configurable" to "Make sharding_tolerance configurable" on Dec 3, 2024

@gobbleturk (Collaborator) left a comment


Thanks for this idea! Some small refactoring:

Ideally we keep train.py as small as possible and have this tolerance configurable in the typical config way, via base.yml.

MaxText/train.py Outdated
@@ -71,6 +71,7 @@
Transformer = models.Transformer
EPS = 1e-8
_DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE = 2 * 1024**3
_DEFAULT_TOLERANCE = 0.02

Can you move this to a field in base.yml named e.g. "sharded_tolerance", defaulting to 0.02?
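
A rough sketch of how that suggestion could surface in code; the field ended up named sharding_tolerance (per the final PR title), and the SimpleNamespace below is only a stand-in for the object pyconfig builds from base.yml:

# Suggested base.yml entry (written as a comment in this Python sketch):
#   sharding_tolerance: 0.02  # max fraction of parameters allowed to stay unsharded
from types import SimpleNamespace

config = SimpleNamespace(sharding_tolerance=0.02)  # stand-in for the pyconfig object

# train.py can then read the value directly instead of defining its own default:
tolerance = config.sharding_tolerance
print(f"sharding_tolerance = {tolerance}")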

MaxText/train.py Outdated
@@ -82,6 +83,8 @@ def validate_train_config(config):
if not config.base_output_directory.startswith("gs://"):
max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system")
assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
if "tolerance" in config.__dict__ and (config.tolerance > 1.0 or config.tolerance < 0.0):

Can you move this check to pyconfig.py? Ideally we keep train.py as small as possible
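
A minimal sketch of what that validation could look like once it lives in pyconfig.py; the function name and the raw-keys dict are assumptions for illustration, and only the 0.0 to 1.0 range check comes from the diff above:

def validate_sharding_tolerance(raw_keys: dict) -> None:
  """Hypothetical pyconfig-style check: the tolerance is a fraction of total parameters."""
  tolerance = raw_keys.get("sharding_tolerance", 0.02)
  assert 0.0 <= tolerance <= 1.0, (
      "sharding_tolerance should be between 0.0 and 1.0 (a fraction of total parameters)"
  )

# Example usage, with a plain dict standing in for the parsed base.yml keys:
validate_sharding_tolerance({"sharding_tolerance": 0.05})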

MaxText/train.py Outdated
@@ -550,14 +553,15 @@ def setup_train_loop(config):
record_goodput(recorder, config, recorder.record_tpu_init_end_time if recorder else None)
record_goodput(recorder, config, recorder.record_training_preparation_start_time if recorder else None)
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
tolerance = config.tolerance if "tolerance" in config.__dict__ else _DEFAULT_TOLERANCE

this shouldn't be necessary when tolerance is part of config

MaxText/train.py Outdated

state, _, state_mesh_shardings, data_iterator = max_utils.setup_training_state(
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
)

if not config.using_pipeline_parallelism:
# The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
- maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)
+ maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance)

You can use config.sharded_tolerance instead of tolerance here.

@Doris26 requested a review from gobbleturk on December 5, 2024 at 03:43

@gobbleturk (Collaborator) left a comment


Thanks for addressing feedback!

@Doris26 force-pushed the tolerance_configurable branch 3 times, most recently from 737a915 to 614fb0b on December 9, 2024 at 22:10
@Doris26 force-pushed the tolerance_configurable branch from 614fb0b to 2e0ac9d on December 9, 2024 at 22:58
@copybara-service bot merged commit 86d85e4 into main on Dec 10, 2024 (14 checks passed)
@copybara-service bot deleted the tolerance_configurable branch on December 10, 2024 at 00:01