From 23fd4a85bfad0fb49538cabedc53a9bf2c7f00f6 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Mon, 21 Oct 2024 21:38:22 -0700
Subject: [PATCH 1/6] workaround connection refused error coming from GCE metadata server (#775)

---
 docker/tpu/Dockerfile.base          | 3 ++-
 infra/helpers/setup-tpu-vm-tests.sh | 2 +-
 infra/helpers/setup-tpu-vm.sh       | 3 +--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/docker/tpu/Dockerfile.base b/docker/tpu/Dockerfile.base
index 3c0d1cc5e..d276c974d 100644
--- a/docker/tpu/Dockerfile.base
+++ b/docker/tpu/Dockerfile.base
@@ -5,7 +5,8 @@ RUN pip install virtualenv
 # venv binaries encode their directory, so we need to setup the venv in the final location
 RUN virtualenv -p python3.10 /opt/levanter/.venv
 ENV PATH /opt/levanter/.venv/bin:$PATH
-RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

 # Install package dependencies to make incremental builds faster.
 WORKDIR /tmp/
diff --git a/infra/helpers/setup-tpu-vm-tests.sh b/infra/helpers/setup-tpu-vm-tests.sh
index 71bead17e..33c1c4add 100755
--- a/infra/helpers/setup-tpu-vm-tests.sh
+++ b/infra/helpers/setup-tpu-vm-tests.sh
@@ -105,7 +105,7 @@ pip install -U wheel

 # jax and jaxlib
 # libtpu sometimes has issues installing for clinical (probably firewall?)
-retry pip install -U "jax[tpu]==0.4.31" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

 # clone levanter
 git clone $REPO levanter
diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh
index f80e586bb..3ca81d76b 100755
--- a/infra/helpers/setup-tpu-vm.sh
+++ b/infra/helpers/setup-tpu-vm.sh
@@ -105,8 +105,7 @@ pip install -U wheel

 # jax and jaxlib
 # libtpu sometimes has issues installing for clinical (probably firewall?)
-#retry pip install -U "jax[tpu]==0.4.5" libtpu-nightly==0.1.dev20230216 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-retry pip install -U "jax[tpu]==0.4.31" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

 # clone levanter
 git clone $REPO levanter

From 655f48f9066bac645c297396418ea64f4ea00953 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Wed, 23 Oct 2024 12:40:10 -0700
Subject: [PATCH 2/6] Infra tweaks (#776)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

found an issue with the ray stuff where somehow ray seemingly wasn’t raising
RayTaskErrors but other exceptions.

also fixes an issue with quoting
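As a quick illustration of the quoting issue (an editor's sketch, not code from the patch): shlex.quote keeps an env value containing spaces or shell metacharacters from breaking the generated docker run command. The env var name and value below are hypothetical.

import shlex

env = {"WANDB_NOTES": "smoke test; rerun of preemption fix"}  # hypothetical value with a ';'

unsafe = [f"{k}={v}" for k, v in env.items()]
safe = [f"{shlex.quote(k)}={shlex.quote(str(v))}" for k, v in env.items()]

print(unsafe[0])  # WANDB_NOTES=smoke test; rerun of preemption fix  (the ';' would split a shell command)
print(safe[0])    # WANDB_NOTES='smoke test; rerun of preemption fix'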
---
 src/levanter/infra/cli_helpers.py |  7 +++++--
 src/levanter/infra/ray_tpu.py     | 12 ++++++++++--
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py
index eef8fa969..b92b6efb5 100644
--- a/src/levanter/infra/cli_helpers.py
+++ b/src/levanter/infra/cli_helpers.py
@@ -1,6 +1,7 @@
 import argparse
 import base64
 import os
+import shlex
 import subprocess
 from typing import Optional

@@ -64,7 +65,7 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
         "docker",
         "run",
         "-t" if foreground else "-d",
-        f"--name={name}",
+        f"--name={shlex.quote(name)}",
         "--privileged",
         "--shm-size=32gb",
         "--net=host",
@@ -76,7 +77,9 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
     ]

     for k, v in env.items():
-        docker_command.extend(["-e", k + f"={str(v)}"])
+        v = shlex.quote(str(v))
+        k = shlex.quote(str(k))
+        docker_command.extend(["-e", f"{k}={v}"])

     docker_command.extend([image_id, *command])
     return docker_command
diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py
index 3ae5d0105..2dc554808 100644
--- a/src/levanter/infra/ray_tpu.py
+++ b/src/levanter/infra/ray_tpu.py
@@ -76,6 +76,7 @@ def run_on_pod(remote_fn: RemoteFunction | Callable, tpu_type: str) -> ray.Objec

     @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
     def do_run(remote_fn) -> _TpuRunResult:
+        logging.basicConfig(level=logging.INFO)
         num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()  # -> 4
         remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, num_hosts)

@@ -92,6 +93,13 @@ def do_run(remote_fn) -> _TpuRunResult:
                 except Exception:
                     logger.exception("Failed to kill job after primary failure")
             return _handle_ray_error(info, e)
+        except Exception as e:
+            for f in futures:
+                try:
+                    ray.cancel(f)
+                except Exception:
+                    logger.exception("Failed to kill job after primary failure")
+            return TpuFailed(info, e)

     return do_run.remote(remote_fn)

@@ -144,12 +152,12 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
             out = ray.get(run_on_pod(remote_fn, tpu_type))
         except ray.exceptions.RayTaskError as e:
             problem = e
-            if "preempted" in str(e):
+            if "preempted" in str(e).lower():
                 num_preemptions += 1
                 logger.warning(f"Preempted {num_preemptions} times, {e}")
             else:
                 num_failures += 1
-                logger.warning(f"Failed {num_failures} times")
+                logger.warning(f"Failed {num_failures} times", exc_info=e)
             continue
         except Exception as e:
             problem = e

From d3774370a98a1151bb97d094cff25a4b37142bbf Mon Sep 17 00:00:00 2001
From: William Held
Date: Fri, 25 Oct 2024 13:00:10 -0400
Subject: [PATCH 3/6] Support Tied Weights in Llama Models (#777)

The new smaller Llama 3.2 1B and 3.2 3B models have tied weights - so
Levanter throws an error currently if we try to import these models.

https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/blob/main/config.json

![Screenshot 2024-10-24 8 17 03 PM](https://github.com/user-attachments/assets/08e79ed7-cab5-43f0-9ca6-f90e2fe73249)

This adds HF support for that argument and just switches to using
embedding.unembed when Embeddings are tied!
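For background, a minimal sketch of what weight tying means (plain jax.numpy with hypothetical shapes, not Levanter's named-axis API): the same [vocab, embed] matrix embeds tokens on the way in and, transposed, produces logits on the way out, so a tied model stores no separate lm_head parameter.

import jax.numpy as jnp

vocab, embed = 128, 16
table = jnp.full((vocab, embed), 0.01)   # the one shared [vocab, embed] parameter

def embed_tokens(token_ids):
    return table[token_ids]              # lookup: [seq] -> [seq, embed]

def unembed(hidden):
    # project hidden states back onto the vocabulary with the transpose of the
    # same matrix; this is the role embedding.unembed plays when lm_head is None
    return hidden @ table.T              # [seq, embed] -> [seq, vocab]

logits = unembed(embed_tokens(jnp.array([1, 2, 3])))  # shape (3, 128)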
---
 src/levanter/models/llama.py | 42 ++++++++++++++++++++++++------------
 1 file changed, 28 insertions(+), 14 deletions(-)

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index e777b7636..1e09ffbc5 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -64,6 +64,7 @@ class LlamaConfig(HFCompatConfig):
     activation_function: str = "silu"
     initializer_range: float = 0.02
     layer_norm_epsilon: float = 1e-5
+    tie_word_embeddings: bool = False

     # Attention-related config
     upcast_attn: bool = False
@@ -120,6 +121,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             activation_function=hf_config.hidden_act,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
+            tie_word_embeddings=hf_config.tie_word_embeddings,
             rope=rope_config,
         )

@@ -148,6 +150,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
             hidden_act=self.activation_function,
             initializer_range=self.initializer_range,
             rms_norm_eps=self.layer_norm_epsilon,
+            tie_word_embeddings=self.tie_word_embeddings,
             # rope_scaling=self.rope_scaling,
             vocab_size=vocab_size,
             rope_theta=rope_theta,
@@ -504,7 +507,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None):
 class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin):
     transformer: LlamaTransformer
     embeddings: LlamaEmbedding
-    lm_head: hnn.Linear
+    lm_head: Optional[hnn.Linear]

     @property
     def config(self):
@@ -523,7 +526,11 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel":
         k_t, k_emb = jrandom.split(key, 2)
         transformer = LlamaTransformer.init(config, key=k_t)
         embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
-        lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+        if config.tie_word_embeddings:
+            lm_head = None
+        else:
+            lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+
         return LlamaLMHeadModel(transformer, embeddings, lm_head)

     def __call__(
@@ -544,17 +551,22 @@ def __call__(
         k_t, k_head = maybe_rng_split(key, 2)
         x = self.embeddings.embed(input_ids)
         x = self.transformer(x, attn_mask=attn_mask, key=k_t)
-        lm_logits = self.lm_head(x, key=k_head)
+        if self.lm_head:
+            lm_logits = self.lm_head(x, key=k_head)
+        else:
+            lm_logits = self.embeddings.unembed(x)
         return lm_logits

     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
         new_Vocab = self.Vocab.resize(new_size)
         k1, k2 = maybe_rng_split(key, 2)
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
-        new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
-        new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
-
-        return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        if self.lm_head is not None:
+            new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
+            new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
+            return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        else:
+            return dataclasses.replace(self, embeddings=new_embeddings)

     def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
         return {"transformer": "model", "embeddings": None}
@@ -562,20 +574,22 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]:

     def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
         # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
         d = state_dict.copy()
-        d.update(
-            unflatten_linear_layers(
-                apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
+        if self.lm_head is not None:
+            d.update(
+                unflatten_linear_layers(
+                    apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
+                )
             )
-        )
         return super().from_state_dict(d, prefix)

     def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
         my_dict: StateDict = {}
         super().update_state_dict(my_dict, prefix=prefix)
-        my_dict.update(
-            flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
-        )
+        if self.lm_head is not None:
+            my_dict.update(
+                flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
+            )
         state_dict.update(my_dict)
         return state_dict

From e6dfd177fd6f31ba5710397357843ab1088d38c4 Mon Sep 17 00:00:00 2001
From: Jason Wang
Date: Tue, 29 Oct 2024 21:50:55 -0700
Subject: [PATCH 4/6] add multislice support in ray (#771)
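As background for the diff below (an editor's sketch, not code from the patch): every slice runs the same program, one slice's IP is chosen as coordinator, and each slice is told the total slice count and its own slice id through MEGASCALE_* environment variables; the cli_helpers change then forwards those variables into the docker container. The helper name and IP here are hypothetical; in the patch the port is hard-coded to 8081 and the coordinator IP comes from the first slice's actor.

def mxla_env_for_slice(coordinator_ip: str, slice_id: int, num_slices: int, port: int = 8081) -> dict:
    # sketch only: how a per-slice environment could be assembled
    return {
        "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
        "MEGASCALE_NUM_SLICES": str(num_slices),
        "MEGASCALE_PORT": str(port),
        "MEGASCALE_SLICE_ID": str(slice_id),
    }

# e.g. two slices coordinated by the (hypothetical) first slice at 10.0.0.5
envs = [mxla_env_for_slice("10.0.0.5", i, 2) for i in range(2)]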
---
 infra/cluster/job-cluster.yaml    |  45 ++++++-
 infra/launch_on_ray.py            |   3 +-
 src/levanter/infra/cli_helpers.py |   5 +
 src/levanter/infra/ray_tpu.py     | 201 +++++++++++++++++++++++++++++-
 4 files changed, 244 insertions(+), 10 deletions(-)

diff --git a/infra/cluster/job-cluster.yaml b/infra/cluster/job-cluster.yaml
index cf8703d54..cff7d4884 100644
--- a/infra/cluster/job-cluster.yaml
+++ b/infra/cluster/job-cluster.yaml
@@ -14,8 +14,8 @@ cluster_name: levanter-cluster
 # Configure GCP
 provider:
   type: gcp
-  region: us-central2
-  availability_zone: us-central2-b
+  region: us-west4
+  availability_zone: us-west4-a
   project_id: hai-gcp-models

 # Maximum Workers (excluding Head Node)
@@ -126,6 +126,45 @@ available_node_types:
       schedulingConfig:
         preemptible: true

+  tpu_slice_v5e_16:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v5litepod-16
+      runtimeVersion: tpu-ubuntu2204-base
+
+      # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+      schedulingConfig:
+        preemptible: true
+
+  tpu_slice_v5e_64:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v5litepod-64
+      runtimeVersion: tpu-ubuntu2204-base
+
+      # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+      schedulingConfig:
+        preemptible: true
+
+  tpu_slice_v5e_256:
+    min_workers: 0
+    max_workers: 1024
+    resources: { "CPU": 120, "TPU": 4 }
+
+    node_config:
+      acceleratorType: v5litepod-256
+      runtimeVersion: tpu-ubuntu2204-base
+
+      # [IMPORTANT] Configure all TPU Workers to be Preemptible!
+      schedulingConfig:
+        preemptible: true
+
 docker:
     image: "ghcr.io/stanford-crfm/levanter-cluster:latest"
     container_name: "ray_docker"
@@ -140,7 +179,7 @@ docker:
     - -v "/var/run/docker.sock:/var/run/docker.sock"

 initialization_commands:
-    - yes | gcloud auth configure-docker us-central2-docker.pkg.dev
+    - yes | gcloud auth configure-docker us-west4-docker.pkg.dev
     - "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true"
     - which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f)
     # always run this because ray doesn't run with sudo
diff --git a/infra/launch_on_ray.py b/infra/launch_on_ray.py
index fa5e81f27..90f2c586a 100755
--- a/infra/launch_on_ray.py
+++ b/infra/launch_on_ray.py
@@ -27,7 +27,7 @@ def main():
     cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
     cli.add_arg(parser, config, ["--tpu_type"], required=True)
     # TODO: bring node_count to Ray
-    # cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
+    cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
     cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
     cli.add_arg(parser, config, ["--retries"], default=10, type=int)
     cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
@@ -122,6 +122,7 @@ def main():
         env=env,
         name="levanter",
         retries=retries,
+        node_count=args.node_count,
     )

     address = args.address or os.getenv("RAY_ADDRESS")
diff --git a/src/levanter/infra/cli_helpers.py b/src/levanter/infra/cli_helpers.py
index b92b6efb5..58413ef2b 100644
--- a/src/levanter/infra/cli_helpers.py
+++ b/src/levanter/infra/cli_helpers.py
@@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
         "/tmp:/tmp",
     ]

+    # optionally add multislice env vars (if set by ray runtime env vars)
+    for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
+        v = shlex.quote(str(v))
+        docker_command.extend(["-e", v])
+
     for k, v in env.items():
         v = shlex.quote(str(v))
         k = shlex.quote(str(k))
diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py
index 2dc554808..57f484770 100644
--- a/src/levanter/infra/ray_tpu.py
+++ b/src/levanter/infra/ray_tpu.py
@@ -3,6 +3,7 @@
 import logging
 import multiprocessing
 import os
+import socket
 import subprocess
 import tempfile
 import time
@@ -104,7 +105,83 @@ def do_run(remote_fn) -> _TpuRunResult:
     return do_run.remote(remote_fn)


-def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
+def run_on_pod_multislice(remote_fn: RemoteFunction | Callable, tpu_type: str, num_slices: int) -> ray.ObjectRef:
+    """
+    Run a remote function on multiple TPU slices.
+
+    Args:
+        remote_fn: A remote function that takes no arguments
+        tpu_type: The type of TPU to run on, e.g. "v4-32"
+        num_slices: The number of slices to run
+
+    Returns:
+        A Ray ObjectRef that represents the result of the function
+    """
+
+    @ray.remote(resources={f"TPU-{tpu_type}-head": 1})
+    class MultisliceActor:
+        def __init__(self):
+            self.pod_name = ray.util.accelerators.tpu.get_current_pod_name()
+            self.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
+            self.ip = socket.gethostbyname(socket.gethostname())
+
+        def get_slice_info(self):
+            return self.pod_name, self.num_hosts, self.ip
+
+        def do_run(self, remote_fn, coordinator_ip, slice_id, num_slices) -> _TpuRunResult:
+            port = 8081
+            mxla_env = {
+                "MEGASCALE_COORDINATOR_ADDRESS": f"{coordinator_ip}:{port}",
+                "MEGASCALE_NUM_SLICES": str(num_slices),
+                "MEGASCALE_PORT": f"{port}",
+                "MEGASCALE_SLICE_ID": str(slice_id),
+            }
+
+            remote_fn, tpu_name = _redecorate_remote_fn_for_tpu(remote_fn, self.num_hosts, env_vars=mxla_env)
+
+            info = _TpuInfo(tpu_name, "ACTIVE", "TPU")
+            futures = [remote_fn.remote() for _ in range(self.num_hosts)]
+            try:
+                out = ray.get(futures)
+                logger.info("TPU job finished")
+                return TpuSuccess(info, out)
+            except RayError as e:
+                for f in futures:
+                    try:
+                        ray.cancel(f)
+                    except Exception:
+                        logger.exception("Failed to kill job after primary failure")
+                return _handle_ray_error(info, e)
+            except Exception as e:
+                for f in futures:
+                    try:
+                        ray.cancel(f)
+                    except Exception:
+                        logger.exception("Failed to kill job after primary failure")
+                return TpuFailed(info, e)
+
+    actors = [MultisliceActor.remote() for _ in range(num_slices)]  # type: ignore
+    futures = [actor.get_slice_info.remote() for actor in actors]
+    try:
+        logger.info("Getting slice infos...")
+        # also act as a sync step
+        slice_infos = ray.get(futures)
+        logger.info(f"TPU slice infos {slice_infos}")
+    except RayError as e:
+        logger.exception(e)
+        for actor in actors:
+            try:
+                ray.cancel(actor)
+            except Exception:
+                logger.exception("Failed to kill actor after primary failure")
+        return futures
+
+    coordinator_ip = slice_infos[0][2]
+
+    return [actor.do_run.remote(remote_fn, coordinator_ip, i, num_slices) for i, actor in enumerate(actors)]
+
+
+def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):
     """
     Redecorate a remote function to run on a TPU pod.

@@ -120,7 +197,11 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts):
     tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> my-tpu
     num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators()  # -> 8
-    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": num_tpus_per_host})
+
+    remote_fn = remote_fn.options(
+        runtime_env=runtime_env,
+        resources={tpu_name: 1, "TPU": num_tpus_per_host},
+    )
     logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host")

     return remote_fn, tpu_name
@@ -193,11 +274,107 @@ def run_on_pod_resumable(remote_fn, tpu_type, max_retries_preemption=1e6, max_re
         raise RuntimeError("Failed too many times") from problem


+def run_on_pod_multislice_resumable(
+    remote_fn, tpu_type, num_slices, max_retries_preemption=1e6, max_retries_failure=10
+):
+    """
+    Repeatedly run a function on a TPU pod until it succeeds or a maximum number of retries is reached.
+
+    Args:
+        remote_fn: A remote function that takes no arguments
+        tpu_type: The type of TPU to run on, e.g. "v4-32"
+        num_slices: The number of slices to run
+        max_retries_preemption: The maximum number of times to retry if the job is preempted
+        max_retries_failure: The maximum number of times to retry if the job fails
+
+    Returns:
+        The result of the function (not an ObjectRef)
+
+    """
+    num_failures = 0
+    num_preemptions = 0
+    attempt = 0
+    problem: Exception | None = None
+
+    while num_failures < max_retries_failure and num_preemptions < max_retries_preemption:
+        logger.info(f"Running on TPU {tpu_type}. Attempt {attempt}")
+        attempt += 1
+        problem = None
+        futures = run_on_pod_multislice(remote_fn, tpu_type, num_slices)
+        try:
+            outs = ray.get(futures)
+        except ray.exceptions.RayTaskError as e:
+            for f in futures:
+                try:
+                    ray.cancel(f)
+                except Exception:
+                    logger.exception("Failed to kill job after primary failure")
+            problem = e
+            if "preempted" in str(e).lower():
+                num_preemptions += 1
+                logger.warning(f"Preempted {num_preemptions} times, {e}")
+            else:
+                num_failures += 1
+                logger.warning(f"Failed {num_failures} times", exc_info=e)
+            continue
+        except Exception as e:
+            for f in futures:
+                try:
+                    ray.cancel(f)
+                except Exception:
+                    logger.exception("Failed to kill job after primary failure")
+            problem = e
+            num_failures += 1
+            if num_failures >= max_retries_failure:
+                logger.exception("Failed too many times", exc_info=e)
+                raise e
+            else:
+                logger.warning(f"Failed {num_failures} times", exc_info=e)
+                continue
+
+        if all(isinstance(out, TpuSuccess) for out in outs):
+            results = [out.result for out in outs]
+            logger.info("Success")
+            return results
+        elif any(isinstance(out, TpuPreempted) for out in outs):
+            out = None
+            for o in outs:
+                if isinstance(o, TpuPreempted):
+                    out = o
+            assert out is not None
+            problem = out.error
+            num_preemptions += 1
+            logger.warning(f"Preempted {num_preemptions} times. {problem}", exc_info=problem)
+        elif any(isinstance(out, TpuFailed) for out in outs):
+            num_preemptions += 1
+            logger.warning(f"TPU node failure. Treating as preempted: {num_preemptions} times")
+        elif any(isinstance(out, TpuRunError) for out in outs):
+            out = None
+            for o in outs:
+                if isinstance(o, TpuRunError):
+                    out = o
+            assert out is not None
+            problem = out.error
+            num_preemptions += 1
+            problem = out.error
+            num_failures += 1
+            logger.warning(f"Failed {num_failures} times", exc_info=problem)
+        else:
+            raise RuntimeError(f"Unexpected result: {out}")
+
+    if num_preemptions >= max_retries_preemption:
+        raise RuntimeError("Preempted too many times") from problem
+    elif num_failures >= max_retries_failure:
+        raise RuntimeError("Failed too many times") from problem
+
+
 def _run_command(*args, **kwargs):
     return subprocess.check_call(args, **kwargs)


-def run_docker_on_pod(image_id: str, command: Sequence[str], *, tpu_type: str, env: dict, name="levanter", retries=10):
+def run_docker_on_pod(
+    image_id: str, command: Sequence[str], *, tpu_type: str, num_slices: int, env: dict, name="levanter", retries=10
+):
     env = _massage_env(env)

     docker_cmd = make_docker_run_command(image_id, command, env=env, foreground=True, name=name)
@@ -210,9 +387,18 @@ def run_docker():
             logger.exception("Failed to run docker command")
             raise e

-    run_on_pod_resumable(
-        ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
-    )
+    if num_slices == 1:
+        run_on_pod_resumable(
+            ray.remote(run_docker), tpu_type=tpu_type, max_retries_failure=retries, max_retries_preemption=10000
+        )
+    else:
+        run_on_pod_multislice_resumable(
+            ray.remote(run_docker),
+            tpu_type=tpu_type,
+            num_slices=num_slices,
+            max_retries_failure=retries,
+            max_retries_preemption=10000,
+        )


 def _kill_old_container(name):
@@ -351,6 +537,7 @@ class RunDockerOnPodConfig:
     env: dict = dataclasses.field(default_factory=dict)
     name: str = "levanter"
     retries: int = 10
+    node_count: int = 1


 def submit_tpu_job_on_ray(config: RunDockerOnPodConfig, ray_address: str, run_id: Optional[str] = None):
@@ -419,6 +606,8 @@ def main(args: RunDockerOnPodConfig):
         tpu_type=args.tpu_type,
         env=args.env,
         name=args.name,
+        retries=args.retries,
+        num_slices=args.node_count,
     )


From f6de2311e4be5a2ccd6cfacf8c9590261f9c7230 Mon Sep 17 00:00:00 2001
From: David Hall
Date: Sun, 3 Nov 2024 16:40:55 -0800
Subject: [PATCH 5/6] Fix hf datasets for new version (#784)

---
 pyproject.toml                          | 2 +-
 src/levanter/data/sharded_datasource.py | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 19fb077bf..0831605cb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
     "draccus>=0.8.0",
     "pyarrow>=11.0.0",
     "zstandard>=0.20.0",
-    "datasets>=2.18,<4.0",
+    "datasets>=3.1.0,<4.0",
     "gcsfs>=2024.2,<2024.10",
     "braceexpand>=0.1.7",
     "jmp>=0.0.3",
diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py
index 186a0d9dd..90803df3e 100644
--- a/src/levanter/data/sharded_datasource.py
+++ b/src/levanter/data/sharded_datasource.py
@@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
         dataset = self._load_dataset()
         if isinstance(dataset, datasets.IterableDataset) and shard_name != "data":
             # ex_iterable has a key that gets discarded typically
-            shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards))
+            shard = map(
+                lambda t: t[1],
+                dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards),
+            )
         else:
             shard = dataset

From 22ed1405b44c476090101046282835ab7f4b2204 Mon Sep 17 00:00:00 2001
From: Kamyar Salahi
Date: Tue, 5 Nov 2024 14:01:22 -0500
Subject: [PATCH 6/6] Allowing internal supervised eval to work without
 separate eval set

---
 src/levanter/data/text.py     | 3 +--
 src/levanter/main/train_lm.py | 8 ++++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 70c1fe4b3..7fb844d8f 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -575,7 +575,6 @@ class LMSupervisedDatasetConfig:
     validation_urls: List[str] = ()  # type:ignore


-
 def preprocess_supervised_example(
     batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str
 ) -> dict:
@@ -631,7 +630,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
     input_field = config.input_field
     output_field = config.output_field

-    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)}
+    output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)}

     dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar)  # type: ignore
     dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)  # type: ignore
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index fe5e5dd35..79095d601 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -160,13 +160,13 @@ def main(config: TrainLmConfig):

     levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})

+    max_eval_examples_per_ds = config.trainer.max_eval_batches
+    if max_eval_examples_per_ds is not None:
+        max_eval_examples_per_ds *= config.trainer.eval_batch_size
+
     if len(tagged_eval_datasets) == 0:
         logger.warning("No evaluation datasets provided.")
     else:
-        max_eval_examples_per_ds = config.trainer.max_eval_batches
-        if max_eval_examples_per_ds is not None:
-            max_eval_examples_per_ds *= config.trainer.eval_batch_size
-
         causal_datasets = [
             (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
             for ds, tags in tagged_eval_datasets