diff --git a/.gitignore b/.gitignore index 9615f94ab..018ad5497 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ /scratch # Configuration for TPU launches/secrets -.config +.levanter.yaml +.levanter.yaml # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/docs/Getting-Started-TPU-VM.md b/docs/Getting-Started-TPU-VM.md index 20fdaa765..53aaed218 100644 --- a/docs/Getting-Started-TPU-VM.md +++ b/docs/Getting-Started-TPU-VM.md @@ -83,7 +83,7 @@ on your development machine to build and run images on TPUs. First create a configuration file for future launches in your Levanter directory: ```bash -cat > .config < .levanter.yaml < # Optional +# Optional: specific environment variables for TPUs based on the TPU type +accel_env: + v6e: + # If you're lucky enough to have a v6e, you can set the following, which is pretty important for performance + LIBTPU_INIT_ARGS: "--xla_tpu_scoped_vmem_limit_kib=98304" + docker_repository: levanter # default zone: us-west4-a # if not set, will use your default zone tpu_name: test-spin-up-32 tpu_type: "v5litepod-16" -vm_image: "tpu-ubuntu2204-base" # default capacity_type: "preemptible" -autodelete: false subnetwork: "default" # default - EOF ``` @@ -155,6 +158,8 @@ a new file: If you're using `launch.py`, the config will be automatically uploaded as part of your Docker image, so you can just reference the local config path in your command line: +```bash +python infra/launch.py -- python src/levanter/main/train_lm.py --config_path config/my_config.yaml --trainer.checkpointer.base_path gs://' ``` Afterward, you can use the config directly from the TPU VM instance, e.g.: diff --git a/infra/launch.py b/infra/launch.py index 15591096d..612b77c8a 100755 --- a/infra/launch.py +++ b/infra/launch.py @@ -1,8 +1,8 @@ #!/usr/bin/python - import argparse import getpass import subprocess +import sys import time from pathlib import Path @@ -11,6 +11,7 @@ import levanter.infra.tpus from levanter.infra.tpus import launch_job + # default: tpu-ubuntu2204-base TPU_TYPE_TO_VM_IMAGE = { "v5litepod": "v2-alpha-tpuv5-lite", @@ -44,9 +45,7 @@ def main(): cli.add_arg(parser, config, ["--github_token"], type=str) cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None) - parser.add_argument( - "-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items()) - ) + parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE")) parser.add_argument("command", nargs=argparse.REMAINDER) args = parser.parse_args() @@ -68,6 +67,9 @@ def main(): tpu_gen = tpu_type.split("-")[0] version = args.version or TPU_TYPE_TO_VM_IMAGE.get(tpu_gen, "tpu-ubuntu2204-base") + if not args.version: + print(f"Using default version: {version}", file=sys.stderr) + node_count = args.node_count zone = args.zone run_id = args.run_id @@ -83,7 +85,10 @@ def main(): raise ValueError("Zone must be specified or set in gcloud config.") region = "-".join(zone.split("-")[:-1]) - env = {k: v for k, v in args.env} + + env = config.env_for_accel(tpu_type) + for key, value in args.env or []: + env[key] = value if "WANDB_PROJECT" not in env: env["WANDB_PROJECT"] = "levanter" diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py index 90f2c586a..2e7551f8b 100755 --- a/infra/launch_on_ray.py +++ b/infra/launch_on_ray.py @@ -37,9 +37,8 @@ def main(): cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None) cli.add_arg(parser, config, ["--zone"], default=None, type=str, required=False) - parser.add_argument( - "-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items()) - ) + parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE")) + parser.add_argument("command", nargs=argparse.REMAINDER) args = parser.parse_args() @@ -62,6 +61,10 @@ def main(): github_token = args.github_token extra_context = args.extra_context + env = config.env_for_accel(tpu_type) + for key, value in args.env or []: + env[key] = value + if zone is None: zone = cli.gcloud_config()["zone"] diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py index 58413ef2b..6c4229224 100644 --- a/src/levanter/infra/cli_helpers.py +++ b/src/levanter/infra/cli_helpers.py @@ -1,14 +1,62 @@ import argparse import base64 +import dataclasses import os -import shlex import subprocess +import warnings +from dataclasses import dataclass +from functools import cached_property from typing import Optional +import draccus import yaml from google.cloud import storage +@dataclass(frozen=True) +class CliConfig: + project: str | None = None + zone: str | None = None + tpu: str | None = None + repository: str | None = None + image: str | None = None + tag: str | None = None + github_user: str | None = None + github_token: str | None = None + docker_file: str | None = None + extra_context: str | None = None + docker_target: str | None = None + docker_repository: str | None = None + subnetwork: str | None = None + + env: dict[str, str] = dataclasses.field(default_factory=dict) + + accel_env: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict) + """ + Environment variables specific to a type of accelerator. The keys are the accelerator type (e.g. v5litepod-256) or + generation (e.g. v5litepod), with priority given to the more specific key. The values are dictionaries of environment + variables to set. These take priority over the general `env` field. + """ + + def env_for_accel(self, accel_type: str) -> dict[str, str]: + + base_env = self.env.copy() + + if "-" in accel_type: + base_env.update(self.accel_env.get(accel_type.split("-")[0], {})) + + if accel_type in self.accel_env: + base_env.update(self.accel_env[accel_type]) + + return base_env + + @cached_property + def as_dict(self): + dict = dataclasses.asdict(self) + # remove Nones + return {k: v for k, v in dict.items() if v is not None} + + # Oddly enough, there's no API to simply fetch the current gcloud configuration... def gcloud_config(): client = storage.Client() @@ -31,11 +79,11 @@ def get_default_zone() -> Optional[str]: return None -def add_arg(parser: argparse.ArgumentParser, config: dict, flags: list[str], required=False, default=None, **kw): +def add_arg(parser: argparse.ArgumentParser, config: CliConfig, flags: list[str], required=False, default=None, **kw): """Add an argument to the parser, using `config` or the environment to resolve default values.""" key = flags[0].lstrip("-").replace("-", "_") - if key in config: - default = config[key] + if key in config.as_dict: + default = config.as_dict[key] if key.upper() in os.environ: default = os.environ[key.upper()] @@ -48,11 +96,16 @@ def add_arg(parser: argparse.ArgumentParser, config: dict, flags: list[str], req parser.add_argument(*flags, **kw) -def load_config(): - if os.path.exists(".config"): - return yaml.load(open(".config", "r"), Loader=yaml.SafeLoader) +def load_config() -> CliConfig: + if os.path.exists(".levanter.yaml"): + d = yaml.load(open(".levanter.yaml", "r"), Loader=yaml.SafeLoader) + elif os.path.exists(".config"): + warnings.warn("Using deprecated .config file. Please rename to .levanter.yaml") + d = yaml.load(open(".config", "r"), Loader=yaml.SafeLoader) else: - return {} + d = {} + + return draccus.decode(CliConfig, d) def get_git_commit(): @@ -60,36 +113,6 @@ def get_git_commit(): return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip() -def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"): - docker_command = [ - "docker", - "run", - "-t" if foreground else "-d", - f"--name={shlex.quote(name)}", - "--privileged", - "--shm-size=32gb", - "--net=host", - "--init", - "--mount", - "type=volume,source=levanter,target=/home/levanter", - "-v", - "/tmp:/tmp", - ] - - # optionally add multislice env vars (if set by ray runtime env vars) - for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]: - v = shlex.quote(str(v)) - docker_command.extend(["-e", v]) - - for k, v in env.items(): - v = shlex.quote(str(v)) - k = shlex.quote(str(k)) - docker_command.extend(["-e", f"{k}={v}"]) - - docker_command.extend([image_id, *command]) - return docker_command - - def default_run_id(): """Generate a run ID for wandb and continuation. diff --git a/src/levanter/infra/docker.py b/src/levanter/infra/docker.py index d48b558a5..aabce6b1a 100644 --- a/src/levanter/infra/docker.py +++ b/src/levanter/infra/docker.py @@ -1,6 +1,7 @@ import json import os import pty +import shlex import shutil import subprocess import sys @@ -236,3 +237,33 @@ def split_image_and_tag(docker_base_image): base_image = docker_base_image base_tag = "latest" return base_image, base_tag + + +def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"): + docker_command = [ + "docker", + "run", + "-t" if foreground else "-d", + f"--name={shlex.quote(name)}", + "--privileged", + "--shm-size=32gb", + "--net=host", + "--init", + "--mount", + "type=volume,source=levanter,target=/home/levanter", + "-v", + "/tmp:/tmp", + ] + + # optionally add multislice env vars (if set by ray runtime env vars) + for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]: + v = shlex.quote(str(v)) + docker_command.extend(["-e", v]) + + for k, v in env.items(): + v = shlex.quote(str(v)) + k = shlex.quote(str(k)) + docker_command.extend(["-e", f"{k}={v}"]) + + docker_command.extend([image_id, *command]) + return docker_command diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 1a9342c54..86ce4223a 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -18,7 +18,7 @@ from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError from ray.remote_function import RemoteFunction -from levanter.infra.cli_helpers import make_docker_run_command +from levanter.infra.docker import make_docker_run_command from levanter.utils.ray_utils import ser_exc_info diff --git a/src/levanter/infra/tpus.py b/src/levanter/infra/tpus.py index bbb1cc5f5..fa0f1a23c 100644 --- a/src/levanter/infra/tpus.py +++ b/src/levanter/infra/tpus.py @@ -10,7 +10,7 @@ import requests # type: ignore -from levanter.infra.cli_helpers import make_docker_run_command +from levanter.infra.docker import make_docker_run_command logger = logging.getLogger(__name__) diff --git a/src/levanter/utils/flop_utils.py b/src/levanter/utils/flop_utils.py index f57e90c7f..eef91f110 100644 --- a/src/levanter/utils/flop_utils.py +++ b/src/levanter/utils/flop_utils.py @@ -180,7 +180,6 @@ def _simplify_device_kind(kind: str) -> str: # TPU looks like 'TPU v4' if kind.startswith("tpu"): - print(f"TPU kind: {kind}") return kind if "h100" in kind and ("sxm" in kind or "hbm3" in kind): @@ -200,8 +199,6 @@ def _simplify_device_kind(kind: str) -> str: if "a6000" in kind: return "a6000" - - return kind