merge main

ahmeda14960 committed Nov 6, 2024
2 parents 812accb + a7e42ec commit 2d7170c
Showing 32 changed files with 1,861 additions and 697 deletions.
2 changes: 1 addition & 1 deletion config/gpt2_small_fast_pile.yaml
@@ -1,4 +1,4 @@
data: !include data/pile_source_old.yaml
data: !include data/pile_mixture.yaml
model:
type: gpt2
hidden_dim: 768
1 change: 1 addition & 0 deletions config/gpt2_small_fast_supervised.yaml
@@ -15,6 +15,7 @@ data:
supervised_data:
validation_urls:
- "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz"
- "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz"
cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/"
input_field: "input"
output_field: "output"
32 changes: 32 additions & 0 deletions config/llama3_small_fast.yaml
@@ -0,0 +1,32 @@
data:
train_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
validation_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/"
tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
model:
type: llama
hidden_dim: 768
intermediate_dim: 2048
num_heads: 12
num_kv_heads: 12
num_layers: 12
seq_len: 1024
gradient_checkpointing: true
trainer:
tracker:
- type: wandb
project: "levanter"
tags: [ "openwebtext", "llama", "itest"]

mp: p=f32,c=bfloat16
model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
num_train_steps: 20000
optimizer:
learning_rate: 1E-3
weight_decay: 0.1
warmup: 0.01
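
The new config pulls its tokenizer from the Hugging Face hub. A minimal sketch for sanity-checking that the id resolves before launching a run (assumes transformers is installed and the repo is accessible; not part of this commit):

from transformers import AutoTokenizer

# Same tokenizer id as in the config above; downloads on first use.
tok = AutoTokenizer.from_pretrained("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
print(tok("hello world")["input_ids"])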
2 changes: 1 addition & 1 deletion config/llama_7b_with_dclm.yaml
@@ -17,7 +17,7 @@ trainer:

mp: p=f32,c=bfloat16
train_batch_size: 2048
num_train_steps: 70000 # 280B / 4M
num_train_steps: 480000 # 2T / 4M
steps_per_eval: 1000
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
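A quick check on the revised step count, assuming the 4M tokens/step in the comment comes from the 2048 batch size at a 2048 seq_len (seq_len is not shown in this hunk):

tokens_per_step = 2048 * 2048            # train_batch_size x assumed seq_len
print(480_000 * tokens_per_step / 1e12)  # ~2.01, i.e. roughly 2T tokens total
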
45 changes: 42 additions & 3 deletions infra/cluster/job-cluster.yaml
@@ -14,8 +14,8 @@ cluster_name: levanter-cluster
# Configure GCP
provider:
type: gcp
region: us-central2
availability_zone: us-central2-b
region: us-west4
availability_zone: us-west4-a
project_id: hai-gcp-models

# Maximum Workers (excluding Head Node)
@@ -126,6 +126,45 @@ available_node_types:
schedulingConfig:
preemptible: true

tpu_slice_v5e_16:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }

node_config:
acceleratorType: v5litepod-16
runtimeVersion: tpu-ubuntu2204-base

# [IMPORTANT] Configure all TPU Workers to be Preemptible!
schedulingConfig:
preemptible: true

tpu_slice_v5e_64:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }

node_config:
acceleratorType: v5litepod-64
runtimeVersion: tpu-ubuntu2204-base

# [IMPORTANT] Configure all TPU Workers to be Preemptible!
schedulingConfig:
preemptible: true

tpu_slice_v5e_256:
min_workers: 0
max_workers: 1024
resources: { "CPU": 120, "TPU": 4 }

node_config:
acceleratorType: v5litepod-256
runtimeVersion: tpu-ubuntu2204-base

# [IMPORTANT] Configure all TPU Workers to be Preemptible!
schedulingConfig:
preemptible: true

docker:
image: "ghcr.io/stanford-crfm/levanter-cluster:latest"
container_name: "ray_docker"
@@ -140,7 +179,7 @@ docker:
- -v "/var/run/docker.sock:/var/run/docker.sock"

initialization_commands:
- yes | gcloud auth configure-docker us-central2-docker.pkg.dev
- yes | gcloud auth configure-docker us-west4-docker.pkg.dev
- "export TPU_WORKER_ID=$(curl -H 'Metadata-Flavor: Google' http://metadata.google.internal/computeMetadata/v1/instance/attributes/agent-worker-number) || true"
- which docker || (curl -fsSL https://get.docker.com -o get-docker.sh; sudo sh get-docker.sh; sudo usermod -aG docker $USER; sudo systemctl restart docker -f)
# always run this because ray doesn't run with sudo
3 changes: 2 additions & 1 deletion infra/launch_on_ray.py
@@ -27,7 +27,7 @@ def main():
cli.add_arg(parser, config, ["--project"], default=cli.gcloud_config()["project"])
cli.add_arg(parser, config, ["--tpu_type"], required=True)
# TODO: bring node_count to Ray
# cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--node_count"], default=1, type=int)
cli.add_arg(parser, config, ["--foreground"], default=False, action="store_true")
cli.add_arg(parser, config, ["--retries"], default=10, type=int)
cli.add_arg(parser, config, ["--run_id"], default=cli.default_run_id(), type=str)
@@ -122,6 +122,7 @@ def main():
env=env,
name="levanter",
retries=retries,
node_count=args.node_count,
)

address = args.address or os.getenv("RAY_ADDRESS")
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@ dependencies = [
"draccus>=0.8.0",
"pyarrow>=11.0.0",
"zstandard>=0.20.0",
"datasets>=2.18,<4.0",
"datasets>=3.1.0,<4.0",
"gcsfs>=2024.2,<2024.10",
"braceexpand>=0.1.7",
"jmp>=0.0.3",
5 changes: 4 additions & 1 deletion src/levanter/data/sharded_datasource.py
@@ -262,7 +262,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
dataset = self._load_dataset()
if isinstance(dataset, datasets.IterableDataset) and shard_name != "data":
# ex_iterable has a key that gets discarded typically
shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards))
shard = map(
lambda t: t[1],
dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards),
)
else:
shard = dataset

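The rewritten call spells out the keyword arguments that newer datasets releases expect (see the version bump in pyproject.toml above); the surrounding map simply drops the key from each (key, example) pair. A standalone sketch of that pattern with hypothetical tuples:

# Sharded iterables here yield (key, example) pairs; only the example is kept.
pairs = [("shard0-0", {"text": "a"}), ("shard0-1", {"text": "b"})]
examples = map(lambda t: t[1], pairs)
assert list(examples) == [{"text": "a"}, {"text": "b"}]
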
4 changes: 2 additions & 2 deletions src/levanter/data/text.py
@@ -35,7 +35,7 @@
from levanter.store.cache import CacheOptions, TreeCache
from levanter.store.jagged_array import JaggedArrayStore
from levanter.store.tree_store import TreeStore
from levanter.utils.fsspec_utils import fsspec_expand_glob
from levanter.utils.fsspec_utils import expand_glob
from levanter.utils.hf_utils import num_cpus_used_by_tokenizer


@@ -588,7 +588,7 @@ def urls_for_split(self, split):
else:
raise ValueError(f"Unknown split {split}")

urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)]
urls = [globbed for url in urls for globbed in expand_glob(url)]
return urls


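Shard URL patterns like the {1..128} ranges in the configs above are brace-expanded (via the braceexpand dependency already listed in pyproject.toml), while the renamed expand_glob resolves * wildcards through fsspec. A minimal sketch of the brace-expansion half (pattern is hypothetical):

from braceexpand import braceexpand

pattern = "openwebtext_train.{1..4}-of-128.jsonl.gz"
print(list(braceexpand(pattern)))
# ['openwebtext_train.1-of-128.jsonl.gz', ..., 'openwebtext_train.4-of-128.jsonl.gz']
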
52 changes: 30 additions & 22 deletions src/levanter/eval.py
@@ -63,7 +63,7 @@ def __init__(
self.datasets = []
tag_index: dict[str, int] = {}
for i, (dataset, tags) in enumerate(datasets):
if tags is None:
if not tags and len(datasets) > 1:
warnings.warn("Dataset has no tags. Giving it an index")
tags = [f"domain_{i}"]
for tag in tags:
@@ -204,14 +204,16 @@ def eval_callback(step: StepInfo):
}

logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}")
for tag, loss in result.tag_macro_losses.items():
# don't log leaf tag macro losses because it doesn't mean anything different than micro loss
if tag in evaluator.dataset.tag_to_index:
continue
if not tag:
continue
log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
logger.info(f"{tag} macro loss: {loss:.3f}")
has_tags = len(evaluator.dataset.tag_to_index) > 1 # 1 tag means there's no difference between micro and macro
if has_tags:
for tag, loss in result.tag_macro_losses.items():
# don't log leaf tag macro losses because it doesn't mean anything different than micro loss
if tag in evaluator.dataset.tag_to_index:
continue
if not tag:
continue
log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss
logger.info(f"{tag} macro loss: {loss:.3f}")

for tag, loss in result.tag_micro_losses.items():
if not tag:
@@ -225,11 +227,14 @@ def eval_callback(step: StepInfo):

if tokenizer is not None:
log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb
log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
if has_tags:
log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb
for tag, bpb in result.tag_micro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb
for tag, bpb in result.tag_macro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb

if has_tags:
for tag, bpb in result.tag_macro_bpb.items():
log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb

levanter.tracker.log_metrics(log_dict, step=step.step)
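
The new has_tags guard reflects that with a single tag the macro average collapses to the micro average. A toy illustration of the two averages (numbers are hypothetical):

# Micro weights every token equally; macro weights every tag equally.
tag_loss = {"web": 2.0, "code": 4.0}   # mean loss per tag
tag_toks = {"web": 900, "code": 100}   # token counts per tag

micro = sum(tag_loss[t] * tag_toks[t] for t in tag_loss) / sum(tag_toks.values())
macro = sum(tag_loss.values()) / len(tag_loss)
print(micro, macro)  # 2.2 3.0 -- they coincide when there is only one tag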

@@ -304,26 +309,29 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample,
this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag]

mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens)
# careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)
state = dataclasses.replace(state, token_avg_loss=mean)

state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag)
if len(self.dataset.tag_to_index) > 0:
# careful: this_tokens_per_tag can be 0 if there are no tokens for that tag
safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0)
mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag)
state = dataclasses.replace(state, loss_per_tag=mean_per_tag)

if self.bytes_per_token is not None:
next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task
bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos]
bytes_per_pos = bytes_per_pos * mask # [Batch, Pos]
bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag]
total_bytes = hax.sum(bytes_per_tag)
bytes_per_tag = hax.einsum("-> tag", mask, bytes_per_pos, tags) # [Tag]
this_bytes = hax.einsum("->", bytes_per_pos, mask) # Scalar

# log loss -> bits is log2(e) * loss
bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e)
bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e)
bpb = this_loss / hax.maximum(this_bytes, 1) * jnp.log2(jnp.e)

bpb_mean = state.bpb.add(bpb, this_tokens)
bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean)
state = dataclasses.replace(state, bpb=bpb_mean)
if len(self.dataset.tag_to_index) > 0:
bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag)
state = dataclasses.replace(state, bpb_per_tag=bpb_per_tag_mean)

return state

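The corrected bpb line divides the mask-weighted loss (in nats) by the mask-weighted byte count and rescales by log2(e) to get bits per byte. A worked sketch with hypothetical totals:

import math

this_loss = 1386.0   # summed cross-entropy over masked positions, in nats (hypothetical)
this_bytes = 4000.0  # bytes covered by those same positions (hypothetical)

# nats -> bits is a factor of log2(e); dividing by bytes gives bits per byte
bpb = this_loss / max(this_bytes, 1.0) * math.log2(math.e)
print(round(bpb, 3))  # 0.5
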
5 changes: 5 additions & 0 deletions src/levanter/infra/cli_helpers.py
@@ -76,6 +76,11 @@ def make_docker_run_command(image_id, command, *, foreground, env, name="levante
"/tmp:/tmp",
]

# optionally add multislice env vars (if set by ray runtime env vars)
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
v = shlex.quote(str(v))
docker_command.extend(["-e", v])

for k, v in env.items():
v = shlex.quote(str(v))
k = shlex.quote(str(k))
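Passing -e NAME to docker run without a value forwards that variable from the host environment when it is set there, which is how the MEGASCALE settings injected by the Ray runtime reach the container. A standalone sketch of the flag construction:

import shlex

docker_command = ["docker", "run"]
for v in ["MEGASCALE_COORDINATOR_ADDRESS", "MEGASCALE_NUM_SLICES", "MEGASCALE_PORT", "MEGASCALE_SLICE_ID"]:
    # value-less -e flags inherit the variable from the launching environment
    docker_command.extend(["-e", shlex.quote(v)])
print(" ".join(docker_command))
# docker run -e MEGASCALE_COORDINATOR_ADDRESS -e MEGASCALE_NUM_SLICES -e MEGASCALE_PORT -e MEGASCALE_SLICE_ID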