Skip to content

Commit

Permalink
update docs for accel_env
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Dec 20, 2024
1 parent 01f1792 commit fa700e8
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 56 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/scratch

# Configuration for TPU launches/secrets
.config
.levanter.yaml
.levanter.yaml

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
13 changes: 9 additions & 4 deletions docs/Getting-Started-TPU-VM.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ on your development machine to build and run images on TPUs.
First create a configuration file for future launches in your Levanter directory:

```bash
cat > .config <<EOF
cat > .levanter.yaml <<EOF
env:
WANDB_API_KEY:
WANDB_ENTITY:
Expand All @@ -93,15 +93,18 @@ env:
TPU_MIN_LOG_LEVEL: 0
LIBTPU_INIT_ARGS: <extra args to libtpu> # Optional
# Optional: specific environment variables for TPUs based on the TPU type
accel_env:
v6e:
# If you're lucky enough to have a v6e, you can set the following, which is pretty important for performance
LIBTPU_INIT_ARGS: "--xla_tpu_scoped_vmem_limit_kib=98304"
docker_repository: levanter # default
zone: us-west4-a # if not set, will use your default zone
tpu_name: test-spin-up-32
tpu_type: "v5litepod-16"
vm_image: "tpu-ubuntu2204-base" # default
capacity_type: "preemptible"
autodelete: false
subnetwork: "default" # default
EOF
```

Expand Down Expand Up @@ -155,6 +158,8 @@ a new file:
If you're using `launch.py`, the config will be automatically uploaded as part of your Docker image, so you
can just reference the local config path in your command line:

```bash
python infra/launch.py -- python src/levanter/main/train_lm.py --config_path config/my_config.yaml --trainer.checkpointer.base_path gs://<somewhere>'
```

Afterward, you can use the config directly from the TPU VM instance, e.g.:
Expand Down
15 changes: 10 additions & 5 deletions infra/launch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#!/usr/bin/python

import argparse
import getpass
import subprocess
import sys
import time
from pathlib import Path

Expand All @@ -11,6 +11,7 @@
import levanter.infra.tpus
from levanter.infra.tpus import launch_job


# default: tpu-ubuntu2204-base
TPU_TYPE_TO_VM_IMAGE = {
"v5litepod": "v2-alpha-tpuv5-lite",
Expand Down Expand Up @@ -44,9 +45,7 @@ def main():
cli.add_arg(parser, config, ["--github_token"], type=str)
cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None)

parser.add_argument(
"-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items())
)
parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"))
parser.add_argument("command", nargs=argparse.REMAINDER)

args = parser.parse_args()
Expand All @@ -68,6 +67,9 @@ def main():
tpu_gen = tpu_type.split("-")[0]
version = args.version or TPU_TYPE_TO_VM_IMAGE.get(tpu_gen, "tpu-ubuntu2204-base")

if not args.version:
print(f"Using default version: {version}", file=sys.stderr)

node_count = args.node_count
zone = args.zone
run_id = args.run_id
Expand All @@ -83,7 +85,10 @@ def main():
raise ValueError("Zone must be specified or set in gcloud config.")

region = "-".join(zone.split("-")[:-1])
env = {k: v for k, v in args.env}

env = config.env_for_accel(tpu_type)
for key, value in args.env or []:
env[key] = value

if "WANDB_PROJECT" not in env:
env["WANDB_PROJECT"] = "levanter"
Expand Down
9 changes: 6 additions & 3 deletions infra/launch_on_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ def main():
cli.add_arg(parser, config, ["--extra_context"], type=Path, required=False, default=None)
cli.add_arg(parser, config, ["--zone"], default=None, type=str, required=False)

parser.add_argument(
"-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"), default=list(config.get("env", {}).items())
)
parser.add_argument("-e", "--env", action="append", nargs=2, metavar=("KEY", "VALUE"))

parser.add_argument("command", nargs=argparse.REMAINDER)

args = parser.parse_args()
Expand All @@ -62,6 +61,10 @@ def main():
github_token = args.github_token
extra_context = args.extra_context

env = config.env_for_accel(tpu_type)
for key, value in args.env or []:
env[key] = value

if zone is None:
zone = cli.gcloud_config()["zone"]

Expand Down
99 changes: 61 additions & 38 deletions src/levanter/infra/cli_helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,62 @@
import argparse
import base64
import dataclasses
import os
import shlex
import subprocess
import warnings
from dataclasses import dataclass
from functools import cached_property
from typing import Optional

import draccus
import yaml
from google.cloud import storage


@dataclass(frozen=True)
class CliConfig:
project: str | None = None
zone: str | None = None
tpu: str | None = None
repository: str | None = None
image: str | None = None
tag: str | None = None
github_user: str | None = None
github_token: str | None = None
docker_file: str | None = None
extra_context: str | None = None
docker_target: str | None = None
docker_repository: str | None = None
subnetwork: str | None = None

env: dict[str, str] = dataclasses.field(default_factory=dict)

accel_env: dict[str, dict[str, str]] = dataclasses.field(default_factory=dict)
"""
Environment variables specific to a type of accelerator. The keys are the accelerator type (e.g. v5litepod-256) or
generation (e.g. v5litepod), with priority given to the more specific key. The values are dictionaries of environment
variables to set. These take priority over the general `env` field.
"""

def env_for_accel(self, accel_type: str) -> dict[str, str]:

base_env = self.env.copy()

if "-" in accel_type:
base_env.update(self.accel_env.get(accel_type.split("-")[0], {}))

if accel_type in self.accel_env:
base_env.update(self.accel_env[accel_type])

return base_env

@cached_property
def as_dict(self):
dict = dataclasses.asdict(self)
# remove Nones
return {k: v for k, v in dict.items() if v is not None}


# Oddly enough, there's no API to simply fetch the current gcloud configuration...
def gcloud_config():
client = storage.Client()
Expand All @@ -31,11 +79,11 @@ def get_default_zone() -> Optional[str]:
return None


def add_arg(parser: argparse.ArgumentParser, config: dict, flags: list[str], required=False, default=None, **kw):
def add_arg(parser: argparse.ArgumentParser, config: CliConfig, flags: list[str], required=False, default=None, **kw):
"""Add an argument to the parser, using `config` or the environment to resolve default values."""
key = flags[0].lstrip("-").replace("-", "_")
if key in config:
default = config[key]
if key in config.as_dict:
default = config.as_dict[key]

if key.upper() in os.environ:
default = os.environ[key.upper()]
Expand All @@ -48,48 +96,23 @@ def add_arg(parser: argparse.ArgumentParser, config: dict, flags: list[str], req
parser.add_argument(*flags, **kw)


def load_config():
if os.path.exists(".config"):
return yaml.load(open(".config", "r"), Loader=yaml.SafeLoader)
def load_config() -> CliConfig:
if os.path.exists(".levanter.yaml"):
d = yaml.load(open(".levanter.yaml", "r"), Loader=yaml.SafeLoader)
elif os.path.exists(".config"):
warnings.warn("Using deprecated .config file. Please rename to .levanter.yaml")
d = yaml.load(open(".config", "r"), Loader=yaml.SafeLoader)
else:
return {}
d = {}

return draccus.decode(CliConfig, d)


def get_git_commit():
"""Get the current git commit hash."""
return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()


def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"):
docker_command = [
"docker",
"run",
"-t" if foreground else "-d",
f"--name={shlex.quote(name)}",
"--privileged",
"--shm-size=32gb",
"--net=host",
"--init",
"--mount",
"type=volume,source=levanter,target=/home/levanter",
"-v",
"/tmp:/tmp",
]

# optionally add multislice env vars (if set by ray runtime env vars)
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
v = shlex.quote(str(v))
docker_command.extend(["-e", v])

for k, v in env.items():
v = shlex.quote(str(v))
k = shlex.quote(str(k))
docker_command.extend(["-e", f"{k}={v}"])

docker_command.extend([image_id, *command])
return docker_command


def default_run_id():
"""Generate a run ID for wandb and continuation.
Expand Down
31 changes: 31 additions & 0 deletions src/levanter/infra/docker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import pty
import shlex
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -236,3 +237,33 @@ def split_image_and_tag(docker_base_image):
base_image = docker_base_image
base_tag = "latest"
return base_image, base_tag


def make_docker_run_command(image_id, command, *, foreground, env, name="levanter"):
docker_command = [
"docker",
"run",
"-t" if foreground else "-d",
f"--name={shlex.quote(name)}",
"--privileged",
"--shm-size=32gb",
"--net=host",
"--init",
"--mount",
"type=volume,source=levanter,target=/home/levanter",
"-v",
"/tmp:/tmp",
]

# optionally add multislice env vars (if set by ray runtime env vars)
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
v = shlex.quote(str(v))
docker_command.extend(["-e", v])

for k, v in env.items():
v = shlex.quote(str(v))
k = shlex.quote(str(k))
docker_command.extend(["-e", f"{k}={v}"])

docker_command.extend([image_id, *command])
return docker_command
2 changes: 1 addition & 1 deletion src/levanter/infra/ray_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ray.exceptions import NodeDiedError, RayError, RaySystemError, RayTaskError, WorkerCrashedError
from ray.remote_function import RemoteFunction

from levanter.infra.cli_helpers import make_docker_run_command
from levanter.infra.docker import make_docker_run_command
from levanter.utils.ray_utils import ser_exc_info


Expand Down
2 changes: 1 addition & 1 deletion src/levanter/infra/tpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import requests # type: ignore

from levanter.infra.cli_helpers import make_docker_run_command
from levanter.infra.docker import make_docker_run_command


logger = logging.getLogger(__name__)
Expand Down
3 changes: 0 additions & 3 deletions src/levanter/utils/flop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def _simplify_device_kind(kind: str) -> str:

# TPU looks like 'TPU v4'
if kind.startswith("tpu"):
print(f"TPU kind: {kind}")
return kind

if "h100" in kind and ("sxm" in kind or "hbm3" in kind):
Expand All @@ -200,8 +199,6 @@ def _simplify_device_kind(kind: str) -> str:
if "a6000" in kind:
return "a6000"



return kind


Expand Down

0 comments on commit fa700e8

Please sign in to comment.