Merge pull request #1058 from AI-Hypercomputer:tolerance_configurable

PiperOrigin-RevId: 704447936

maxtext authors committed Dec 10, 2024
2 parents ad97ccb + 2e0ac9d commit 86d85e4
Showing 4 changed files with 11 additions and 2 deletions.
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -272,6 +272,9 @@ logical_axis_rules: [
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]
+
+# sharding tolerance: float between 0.0 and 1.0 representing the allowed fraction of non-sharded parameters.
+sharding_tolerance: 0.02

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
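With this default, a run tolerates up to 2% of its parameters being left unsharded before the check in maxtext_utils.py fails. Since MaxText config keys can generally be overridden as key=value arguments on the command line, the threshold can be adjusted per run without editing the YAML, e.g. `python3 MaxText/train.py MaxText/configs/base.yml sharding_tolerance=0.05` (invocation shown for illustration; the exact entry point may differ by setup).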
2 changes: 1 addition & 1 deletion MaxText/maxtext_utils.py
@@ -211,7 +211,7 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co
  return total_tflops, learnable_weight_tflops, causal_attention_tflops


-def assert_params_sufficiently_sharded(params, mesh, tolerance=0.02):
+def assert_params_sufficiently_sharded(params, mesh, tolerance):
  """Checks whether most params are sharded across sharding axis.
  This function determines whether the majority of parameters are distributed
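The body of assert_params_sufficiently_sharded is collapsed in this diff. Below is a rough sketch of the kind of check it performs; the helper names and the count-based heuristic are illustrative, not the actual MaxText implementation (which may, for example, weight parameters by size):

```python
import jax


def _fraction_unsharded(params) -> float:
  """Fraction of parameter arrays that are fully replicated (i.e. unsharded)."""
  leaves = jax.tree_util.tree_leaves(params)
  unsharded = sum(
      1
      for p in leaves
      if hasattr(p, "sharding") and p.sharding.is_fully_replicated
  )
  return unsharded / max(len(leaves), 1)


def assert_params_sufficiently_sharded_sketch(params, mesh, tolerance):
  """Illustrative stand-in: fail if too large a fraction of params is unsharded."""
  del mesh  # The real check may inspect the mesh; this sketch does not.
  frac = _fraction_unsharded(params)
  assert frac <= tolerance, (
      f"{frac:.2%} of parameters are unsharded, exceeding sharding_tolerance "
      f"of {tolerance:.2%}"
  )
```

With tolerance now a required argument, callers must pass the configured value explicitly, which is exactly what the train.py change below does.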
5 changes: 5 additions & 0 deletions MaxText/pyconfig.py
@@ -156,6 +156,11 @@ def validate_data_input(keys):
if keys["eval_interval"] > 0:
assert keys["eval_split"], "Please specify eval_split or set eval_interval to <=0."

if keys["sharding_tolerance"] > 1.0 or keys["sharding_tolerance"] < 0.0:
max_logging.log(
"WARNING: 'sharding_tolerance: allowed percentage of non-sharded parameters' should be between 0.0 and 1.0"
)


def validate_model_name(s: str) -> bool:
"""Validate provided model name."""
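Note that this validation only logs a warning for an out-of-range value; training proceeds with whatever tolerance was supplied. A stricter, hypothetical variant (not part of this change) would fail fast instead:

```python
def validate_sharding_tolerance(keys):
  """Hypothetical hard check: reject out-of-range tolerances outright."""
  tol = keys["sharding_tolerance"]
  if not 0.0 <= tol <= 1.0:
    raise ValueError(f"sharding_tolerance must be in [0.0, 1.0], got {tol}")
```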
3 changes: 2 additions & 1 deletion MaxText/train.py
@@ -82,6 +82,7 @@ def validate_train_config(config):
  if not config.base_output_directory.startswith("gs://"):
    max_logging.log("WARNING: 'base_output_directory' might be pointing to your local file system")
  assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
+
  if config.quantization == "fp8":
    # pylint: disable=line-too-long
    assert (
@@ -557,7 +558,7 @@ def setup_train_loop(config):

  if not config.using_pipeline_parallelism:
    # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage
-    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)
+    maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance)
  record_goodput(recorder, config, recorder.record_training_preparation_end_time if recorder else None)
  return (
      init_rng,
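Taken together, the change threads one knob through the stack: base.yml supplies the 0.02 default, pyconfig.py range-checks it at startup, and train.py forwards config.sharding_tolerance to assert_params_sufficiently_sharded in place of the previously hard-coded tolerance=0.02. The check remains skipped under pipeline parallelism, since the [vocab, embed] tensors are intentionally not sharded by stage.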
