From 4871dd744d52717d35bc5a5075db1167fc5308a3 Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Mon, 18 Nov 2024 23:04:20 +0000 Subject: [PATCH 1/8] Add checkpoint topology discovery for the replicator service --- MaxText/configs/base.yml | 6 ++++++ MaxText/max_utils.py | 43 ++++++++++++++++++++++++++++++++++++++++ MaxText/pyconfig.py | 25 ++++++++--------------- 3 files changed, 57 insertions(+), 17 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 19b7287c9..308fa83bb 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -201,6 +201,12 @@ local_checkpoint_directory: "" # It should be a positive number when and only when `enable_emergency_checkpoint` is True. local_checkpoint_period: 0 +# Whether or not to use emergency checkpoint with the replicator service. +use_replicator_service: False + +# The interval to backup local checkpoints to the persistent storage. +replicator_backup_interval_minutes: 0 + # Jax cache directory jax_cache_dir: "~/jax_cache" diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 0a9a26128..9ca5cf0b2 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -273,6 +273,34 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): " coordinator_address to initialize JAX distributed runtime..." ) jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id)) + if raw_keys["use_replicator_service"]: + REPLICATOR_FILE = "replicator.yaml" + TEMP_FILE = REPLICATOR_FILE + ".tmp" + replicator_file = epath.Path(raw_keys["local_checkpoint_directory"]) / REPLICATOR_FILE + temp_file = epath.Path(raw_keys["local_checkpoint_directory"]) / TEMP_FILE + num_slices = get_num_slices(raw_keys) + num_nodes = jax.process_count() + nodes_per_slice = num_nodes // num_slices + max_logging.log(f"num_slices: {num_slices}, num_nodes: {num_nodes}, nodes_per_slice: {nodes_per_slice}") + node_rank = jax.process_index() + peer_ranks = [] + for i in range(num_slices): + peer = node_rank % nodes_per_slice + i * nodes_per_slice + if peer != node_rank: + peer_ranks.append(peer) + run_name = raw_keys["run_name"] + if run_name == "": + run_name = os.environ.get("JOBSET_NAME") # using XPK default + + replicator_yaml = f"""job-name: {run_name} + node-rank: {node_rank} + nodes: {num_nodes} + workers-per-node: 1 + peer-ranks: {peer_ranks} + backup-interval-minutes: {raw_keys["replicator_backup_interval_minutes"]}""" + + temp_file.write_text('\n'.join([l.strip() for l in replicator_yaml.split('\n')])) + os.rename(temp_file, replicator_file) else: max_logging.log( "Initializing JAX distributed runtime without args when emergency checkpointing is" @@ -303,6 +331,21 @@ def _retrieve_jax_init_info(raw_keys): return "", "" +def get_num_slices(raw_keys): + """Calculate num_slices based on number of devices.""" + if raw_keys["hardware"] == "cpu": + max_logging.log(" Setting num_slices=1 for CPU hardware type") + return 1 + if int(raw_keys["compile_topology_num_slices"]) > 0: + return raw_keys["compile_topology_num_slices"] + else: + devices = jax.devices() + try: + return 1 + max([d.slice_index for d in devices]) + except: + return 1 + + def is_cpu_backend(raw_keys): """Determine whether Maxtext is intended to run on a CPU backend.""" return raw_keys["hardware"] == "cpu" diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 5adc88199..d71fe1eff 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -105,8 +105,14 @@ def validate_keys(keys): assert ( keys["local_checkpoint_period"] > 0 ), "A positive local checkpoint period must be specified when using emergency checkpoint" + if keys["use_replicator_service"]: + assert ( + keys["replicator_backup_interval_minutes"] > 0 + ), "Replicator service is enabled, the backup interval minutes must be positive" else: - max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period") + max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period," + "use_replicator_service and replicator_backup_interval_minutes") + if keys["num_experts"] > 1: validate_megablox_parallelism(keys) @@ -388,7 +394,7 @@ def user_init(raw_keys): raw_keys["eval_per_device_batch_size"], raw_keys["expansion_factor_real_data"], get_num_target_devices(raw_keys), 1 ) - raw_keys["num_slices"] = get_num_slices(raw_keys) + raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) if using_pipeline_parallelism(raw_keys): @@ -589,21 +595,6 @@ def get_num_target_devices(raw_keys): return len(jax.devices()) -def get_num_slices(raw_keys): - """Calculate num_slices based on number of devices.""" - if raw_keys["hardware"] == "cpu": - max_logging.log(" Setting num_slices=1 for CPU hardware type") - return 1 - if int(raw_keys["compile_topology_num_slices"]) > 0: - return raw_keys["compile_topology_num_slices"] - else: - devices = jax.devices() - try: - return 1 + max([d.slice_index for d in devices]) - except: - return 1 - - def get_quantization_local_shard_count(raw_keys): if raw_keys["quantization_local_shard_count"] == -1: return raw_keys["num_slices"] From 712bddae8b8ae14f8db716806b2fded1d8738dfa Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Mon, 18 Nov 2024 23:17:30 +0000 Subject: [PATCH 2/8] Fix linting --- MaxText/pyconfig.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index d71fe1eff..facfc1d6d 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -110,8 +110,8 @@ def validate_keys(keys): keys["replicator_backup_interval_minutes"] > 0 ), "Replicator service is enabled, the backup interval minutes must be positive" else: - max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period," - "use_replicator_service and replicator_backup_interval_minutes") + max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period," + " use_replicator_service and replicator_backup_interval_minutes") if keys["num_experts"] > 1: validate_megablox_parallelism(keys) From 41afdde24e0888809a8f826e8fe4202ed34fb81f Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 19 Nov 2024 00:29:43 +0000 Subject: [PATCH 3/8] Use generator --- MaxText/max_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 9ca5cf0b2..7f673d482 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -341,7 +341,7 @@ def get_num_slices(raw_keys): else: devices = jax.devices() try: - return 1 + max([d.slice_index for d in devices]) + return 1 + max(d.slice_index for d in devices) except: return 1 From a5a107e6cc8abbb180ee8d23bd8a6c94b97e7e8f Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 19 Nov 2024 00:43:47 +0000 Subject: [PATCH 4/8] More linting... --- MaxText/max_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 7f673d482..35bd5e605 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -291,14 +291,14 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): run_name = raw_keys["run_name"] if run_name == "": run_name = os.environ.get("JOBSET_NAME") # using XPK default - + replicator_yaml = f"""job-name: {run_name} node-rank: {node_rank} nodes: {num_nodes} workers-per-node: 1 peer-ranks: {peer_ranks} backup-interval-minutes: {raw_keys["replicator_backup_interval_minutes"]}""" - + temp_file.write_text('\n'.join([l.strip() for l in replicator_yaml.split('\n')])) os.rename(temp_file, replicator_file) else: @@ -342,7 +342,7 @@ def get_num_slices(raw_keys): devices = jax.devices() try: return 1 + max(d.slice_index for d in devices) - except: + except Exception: return 1 From f9a9fd299c3999c3f7ec5c66ba4fa68486d31469 Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 19 Nov 2024 01:04:42 +0000 Subject: [PATCH 5/8] Use ValueError for exception --- MaxText/max_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 35bd5e605..a5c19825a 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -342,7 +342,7 @@ def get_num_slices(raw_keys): devices = jax.devices() try: return 1 + max(d.slice_index for d in devices) - except Exception: + except ValueError: return 1 From 38309765a0565a3b53f75bbce95ab0988688c405 Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 19 Nov 2024 06:15:26 +0000 Subject: [PATCH 6/8] Disable linting for general exception --- MaxText/max_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index a5c19825a..71f4e47e6 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -342,7 +342,7 @@ def get_num_slices(raw_keys): devices = jax.devices() try: return 1 + max(d.slice_index for d in devices) - except ValueError: + except: # pylint: disable=broad-except return 1 From 61e169cb883d5f820c8ad74aedd112a09f8e6484 Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 19 Nov 2024 18:07:55 +0000 Subject: [PATCH 7/8] Use specific exceptions --- MaxText/max_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 71f4e47e6..04b30384b 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -342,7 +342,7 @@ def get_num_slices(raw_keys): devices = jax.devices() try: return 1 + max(d.slice_index for d in devices) - except: # pylint: disable=broad-except + except (ValueError, AttributeError): return 1 From 818203bae1527b58db5832cc0ab49515d06ae3a9 Mon Sep 17 00:00:00 2001 From: Xuefeng Gu Date: Tue, 19 Nov 2024 18:16:12 +0000 Subject: [PATCH 8/8] More linting --- MaxText/max_utils.py | 2 +- MaxText/pyconfig.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 04b30384b..338116d30 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -299,7 +299,7 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys): peer-ranks: {peer_ranks} backup-interval-minutes: {raw_keys["replicator_backup_interval_minutes"]}""" - temp_file.write_text('\n'.join([l.strip() for l in replicator_yaml.split('\n')])) + temp_file.write_text("\n".join([l.strip() for l in replicator_yaml.split("\n")])) os.rename(temp_file, replicator_file) else: max_logging.log( diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index facfc1d6d..fa80c9226 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -107,11 +107,13 @@ def validate_keys(keys): ), "A positive local checkpoint period must be specified when using emergency checkpoint" if keys["use_replicator_service"]: assert ( - keys["replicator_backup_interval_minutes"] > 0 + keys["replicator_backup_interval_minutes"] > 0 ), "Replicator service is enabled, the backup interval minutes must be positive" else: - max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period," - " use_replicator_service and replicator_backup_interval_minutes") + max_logging.log( + "Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period," + " use_replicator_service and replicator_backup_interval_minutes" + ) if keys["num_experts"] > 1: validate_megablox_parallelism(keys)