Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pathways workloads for v6e benchmarks #1040

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion benchmarks/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ def add_shared_arguments(custom_parser: argparse.ArgumentParser):
default='maxtext_base_image',
help='version of base docker image to be benchmarked command.',
)
custom_parser.add_argument(
    '--use_pathways',
    # NOTE(review): argparse's `type=bool` is broken for CLI flags —
    # bool('False') is True, so ANY supplied value parsed as True.
    # Parse common truthy strings instead; `--use_pathways True` keeps
    # working, `--use_pathways False` now correctly yields False, and
    # omitting the flag still gives the default False.
    type=lambda v: str(v).strip().lower() in ('true', '1', 'yes', 'y'),
    default=False,
    help='whether to use pathways or not.',
)
custom_parser.add_argument(
    '--xpk_path',
    type=str,
    # Default assumes xpk is checked out in the user's home directory.
    default='~/xpk',
    help='path to xpk dir.',
)

def main() -> None:
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -153,9 +165,12 @@ def main() -> None:
model_name=benchmark_model,
software_config=v6e_env_configs,
hardware_config=v6e_256_configs,
use_pathways=options.use_pathways,
)

xpk_benchmark_runner(cluster_config, [model_runner])
xpk_benchmark_runner(
cluster_config, [model_runner], xpk_path=options.xpk_path
)


if __name__ == '__main__':
Expand Down
81 changes: 57 additions & 24 deletions benchmarks/maxtext_xpk_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class BenchmarkRunner:
model_name: str
hardware_config: HWConfig
software_config: SWconfig
use_pathways: bool


def chunks(lst: list, n: int):
Expand Down Expand Up @@ -265,13 +266,16 @@ def build_user_command(
cluster_config: XpkConfig,
base_output_directory: str,
buffer_size: int,
use_pathways: bool = False,
):
config_tuning_params = ''
for key, value in model.tuning_params.items():
config_tuning_params += f'{key}={value} '

install_libtpu_cmd = ''
if libtpu_type == LibTpuType.NIGHTLY:
if use_pathways:
pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future - Should we perhaps give users a message that we won't be running Pathways with nightly Libtpu?

elif libtpu_type == LibTpuType.NIGHTLY:
install_libtpu_cmd += (
f' pip install libtpu-nightly==0.1.dev{libtpu_date} -f'
' https://storage.googleapis.com/libtpu-releases/index.html &&'
Expand All @@ -288,7 +292,9 @@ def build_user_command(
# model.xla_flags += ' --grpc_enable_rpc_receive_coalescing=true'
# model.xla_flags += ' --grpc_experiments=tcp_rcv_lowat'

libtpu_flags = f"LIBTPU_INIT_ARGS='{model.xla_flags}'"
libtpu_flags = '' if use_pathways else f"LIBTPU_INIT_ARGS='{model.xla_flags}'"
jax_platforms = 'proxy' if use_pathways else 'tpu,cpu'
pathways_prefix = 'pw-' if use_pathways else ''

return (
# f'python3 -m pip install google-cloud-aiplatform==v1.61.0 &&'
Expand All @@ -301,7 +307,7 @@ def build_user_command(
f' echo {libtpu_flags} &&'
# f' echo {model.tuning_params["sa_block_q"]}-q-dq-{model.tuning_params["sa_block_q_dq"]}-q-dkv-{model.tuning_params["sa_block_q_dkv"]} &&'
# f' echo {model.tuning_params["ici_fsdp_parallelism"]} {model.tuning_params["ici_tensor_parallelism"]} &&'
f' export JAX_PLATFORMS=tpu,cpu &&'
f' export JAX_PLATFORMS={jax_platforms} &&'
# f' export JAX_DEBUG_NANS=True &&'
# f' export TPU_MEGACORE=megachip_tccontrol &&'
# f' echo TPU MEGACORE: $TPU_MEGACORE &&'
Expand All @@ -315,10 +321,14 @@ def build_user_command(
f' base_output_directory={base_output_directory}'
f' use_vertex_tensorboard=false'
' vertex_tensorboard_project="" vertex_tensorboard_region=""'
f' run_name="{model.model_name}-{num_slices}-{libtpu_date}"'
f' run_name="{pathways_prefix}{model.model_name}-{num_slices}-{libtpu_date}"'
)


def reformat_xla_flags_for_xpk(flags: str):
  """Wrap a space-separated XLA flag string for an xpk command line.

  Each space between flags is turned into a backslash-continued newline
  (' \\' + newline) and the whole string is surrounded by double quotes,
  so the flags stay readable when embedded in the generated command.

  Args:
    flags: space-separated flags, e.g. '--a=1 --b=2'.

  Returns:
    The quoted, line-continued flag string.
  """
  continued = flags.replace(' ', ' \\\n')
  return f'"{continued}"'


def generate_xpk_workload_cmd(
model: model_configs.MaxTextModel,
cluster_config: XpkConfig,
Expand All @@ -327,6 +337,8 @@ def generate_xpk_workload_cmd(
libtpu_version: str,
base_output_directory: str,
buffer_size: int,
xpk_path: str,
use_pathways: bool = False,
):
"""Generates a command to run a maxstar model on XPK."""
num_steps = 20
Expand All @@ -337,8 +349,9 @@ def generate_xpk_workload_cmd(
random.choice(string.ascii_lowercase + string.digits) for _ in range(N)
)

pw_prefix = 'pw-' if use_pathways else ''
name = (
f"{model.model_name.replace('_', '-')}-{cluster_config.num_slices}-{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
f"{pw_prefix}{model.model_name.replace('_', '-')}-{cluster_config.num_slices}{time.strftime('%m%d%H', time.localtime())}-{temp_post_fix}"
)
user_command = build_user_command(
model,
Expand All @@ -349,6 +362,7 @@ def generate_xpk_workload_cmd(
cluster_config,
base_output_directory,
buffer_size,
use_pathways,
)

additional_flags = ''
Expand All @@ -361,25 +375,42 @@ def generate_xpk_workload_cmd(
' https://raw.githubusercontent.com/GoogleCloudPlatform/ai-on-gke/9ff340f07f70be0130454f9e7238551587242b75/scripts/network-setup/v6e-network-optimization.yaml'
)

# pathways-related flags
pathways_specific_flags = ''
docker_image_flag = f'--base-docker-image="{BASE_DOCKER_IMAGE}"'
if use_pathways:
pathways_specific_flags = (
' --use-pathways'
f' --additional_pw_proxy_args={reformat_xla_flags_for_xpk(model.xla_flags)}'
Copy link
Collaborator

@RoshaniN RoshaniN Nov 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need hyphens instead of underscores here in "--additional_pw_proxy_args" to be consistent with rest of the parsed args in XPK.

)
docker_image_flag = (
'--docker-image="us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/'
'maxtext_jax_stable:latest"'
)

print(f'User command: {user_command}')
xpk_command = (
# f'{perf_optimzation_dcn} &&'
f'python3 {xpk_path}/xpk.py workload create'
f' {pathways_specific_flags}'
f' --cluster={cluster_config.cluster_name}'
f' --project={cluster_config.project}'
f' --zone={cluster_config.zone}'
f' --device-type={cluster_config.device_type}'
f' --num-slices={cluster_config.num_slices}'
f' --command="{user_command}"'
f' {docker_image_flag}'
' --enable-debug-logs'
f' --workload={name}'
' --priority=medium'
# ' --use-vertex-tensorboard'
# f' --experiment-name={test_purpose_name}'
f' {additional_flags}'
)
print(f'XPK command: {xpk_command}')

return (
(
# f'{perf_optimzation_dcn} &&'
'python3 ~/xpk/xpk.py workload create'
f' --cluster={cluster_config.cluster_name}'
f' --project={cluster_config.project}'
f' --zone={cluster_config.zone}'
f' --device-type={cluster_config.device_type}'
f' --num-slices={cluster_config.num_slices}'
f' --command="{user_command}"'
f' --base-docker-image="{BASE_DOCKER_IMAGE}"'
' --enable-debug-logs'
f' --workload={name}'
' --priority=medium'
# ' --use-vertex-tensorboard'
# f' --experiment-name={test_purpose_name}'
f' {additional_flags}'
),
xpk_command,
name,
)

Expand All @@ -401,12 +432,12 @@ def run_xpk_workload(
Returns:
"""
command, _ = generate_xpk_workload_cmd(
model, cluster_config, num_slices, libtpu_type, libtpu_version, base_output_directory=cluster_config.base_output_directory, buffer_size=buffer_size
model, cluster_config, num_slices, libtpu_type, libtpu_version, base_output_directory=cluster_config.base_output_directory, buffer_size=buffer_size, xpk_path=xpk_path,
)
return run_command_with_updates(command, 'Run XPK workload', cluster_config)


def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner]):
def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRunner], xpk_path: str):
xpk_workload_names = []
xpk_workload_cmds = []
for benchmark in benchmarks:
Expand All @@ -418,6 +449,8 @@ def xpk_benchmark_runner(cluster_config: XpkConfig, benchmarks: list[BenchmarkRu
libtpu_version=benchmark.software_config.libtpu_version,
base_output_directory=cluster_config.base_output_directory,
buffer_size=4294967296,
xpk_path=xpk_path,
use_pathways=benchmark.use_pathways,
)
xpk_workload_names.append(name)
xpk_workload_cmds.append(command)
Expand Down
Loading