Misc fixes from sweep (disable blocked CE by default) #798

Merged 8 commits on Nov 12, 2024
4 changes: 3 additions & 1 deletion src/levanter/data/text.py
@@ -708,7 +708,9 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po
max_length=Pos.size,
)
ex = {k: v[0] for k, v in ex.items()}
-input_ids = hax.named(ex["input_ids"], Pos)
+# padding doesn't do truncation, so we have to do it ourselves.
+# Truncate from the left since we want to predict the last tokens
+input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
# mask out padding and anything before the start of the target
loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

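Note on the truncation change above: the tokenizer call pads to `Pos.size` but does not truncate, so over-long supervised examples must be trimmed by hand, and trimming from the left keeps the target tokens at the end of the sequence. A minimal sketch in plain Python (not Levanter code) of what the `[-Pos.size :]` slice does:

```python
# Plain-Python sketch: left-truncation keeps the final pos_size tokens,
# which are the ones the supervised loss is computed on.
pos_size = 8
input_ids = list(range(12))        # a tokenized example that is too long
truncated = input_ids[-pos_size:]  # drop tokens from the left -> [4, 5, ..., 11]
assert len(truncated) == pos_size
assert truncated[-1] == input_ids[-1]  # the trailing target tokens survive
```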
10 changes: 6 additions & 4 deletions src/levanter/doremi.py
@@ -1,6 +1,6 @@
import dataclasses
import logging
-from typing import Optional, Tuple, TypeVar
+from typing import Mapping, Optional, Tuple, TypeVar

import equinox as eqx
import jax.numpy as jnp
@@ -56,7 +56,7 @@ def estimate_mixture_weights(
loss_fn: ComputeLossFunction[M, T],
initial_proxy: M,
ref: M,
-data_sources: dict[str, AsyncDataset[T]],
+data_sources: Mapping[str, AsyncDataset[T]],
sampling_weights: Optional[dict[str, float]] = None,
*,
validation_sets: Optional[dict[str, AsyncDataset[T]]] = None,
@@ -184,7 +184,9 @@ def doremi_step(state: DoremiState, ref, batch, domains):

# we're not actually going to use the trainer for very much but it holds hooks and sets up contexts
with trainer:
-tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key)
+tagged_mixture: MixtureDataset = domain_tagged_mixture(
+    data_sources, sampling_weights, domain_to_index, key=data_key
+)
state = load_checkpoint_or_initialize(
DoremiState.init,
trainer.checkpoint_path,
@@ -263,7 +265,7 @@ def _prepare_ref_model(ref, trainer):


def domain_tagged_mixture(
-data_sources: dict[str, AsyncDataset[T]],
+data_sources: Mapping[str, AsyncDataset[T]],
weights: dict[str, float],
domain_to_index: dict[str, int],
*,
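Why `Mapping` instead of `dict` in the signatures above: under mypy, `dict` is invariant in its value type while `Mapping` is covariant, so a caller holding a dict of a more specific dataset type can now pass it directly. A toy sketch with stand-in classes (not Levanter's real types):

```python
# Toy illustration of dict invariance vs. Mapping covariance under mypy.
from typing import Mapping


class Dataset:
    ...


class CausalDataset(Dataset):
    ...


def accepts_mapping(sources: Mapping[str, Dataset]) -> None:
    ...


def accepts_dict(sources: dict[str, Dataset]) -> None:
    ...


causal: dict[str, CausalDataset] = {"web": CausalDataset()}
accepts_mapping(causal)  # OK: Mapping's value type is covariant
accepts_dict(causal)     # mypy error: dict's value type is invariant
```

The rename to `causal_train_datasets` in doremi_lm.py below likewise avoids rebinding one name to values of two different types.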
4 changes: 2 additions & 2 deletions src/levanter/main/doremi_lm.py
@@ -109,7 +109,7 @@ def init_proxy_model():
train_datasets = config.data.training_sets(ref_model.Pos.size)
valid_datasets = config.data.validation_sets(ref_model.Pos.size)

-train_datasets = {
+causal_train_datasets = {
k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id)
for k, v in train_datasets.items()
}
@@ -122,7 +122,7 @@ def init_proxy_model():
loss_function,
proxy_model,
ref=ref_model,
-data_sources=train_datasets,
+data_sources=causal_train_datasets,
trainer_config=config.trainer,
optimizer=optimizer,
domain_weight_step_size=config.doremi.domain_weight_step_size,
3 changes: 3 additions & 0 deletions src/levanter/main/train_lm.py
@@ -268,6 +268,9 @@ def compute_log_probs(model, example):
checkpointer = trainer.config.checkpointer.create(trainer.run_id)
checkpointer.wait_until_finished()

+# This isn't necessary except when Levanter is run in a subprocess (as happens w/ ray)
+trainer.tracker.finish()


if __name__ == "__main__":
levanter.config.main(main)()
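The explicit `trainer.tracker.finish()` matters when the script does not exit through a normal interpreter shutdown, e.g. when launched as a Ray task. A hedged sketch of the general pattern (the wrapper function is hypothetical, not Levanter's API):

```python
# Hypothetical wrapper: when training runs inside a subprocess, atexit-style
# flushing may never fire, so finish() is called explicitly; try/finally keeps
# the flush robust to failures in the training loop.
def run_and_flush(trainer, train_fn):
    try:
        train_fn()
    finally:
        trainer.tracker.finish()  # commit buffered metrics before the process exits
```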
4 changes: 3 additions & 1 deletion src/levanter/models/lm_model.py
@@ -1,4 +1,5 @@
import abc
+from dataclasses import dataclass
from typing import Generic, Optional, Type, TypeVar

import draccus
@@ -48,6 +49,7 @@ def causal(


# TODO: for some reason, mypy doesn't like the discover_packages_path argument?
+@dataclass(frozen=True)
class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore
@property
@abc.abstractmethod
@@ -69,7 +71,7 @@ def Pos(self) -> Axis:
def Embed(self) -> Axis:
pass

-cross_entropy_block_size: Optional[int] = 64000
+cross_entropy_block_size: Optional[int] = None
"""
The block size for computing cross-entropy loss. This is the number of tokens that are processed together
in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large
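With the default now `None`, the blocked (chunked-vocab) cross-entropy is off unless explicitly enabled, which is the "disable blocked CE by default" in the PR title. A rough back-of-envelope (made-up shapes, plain Python) of what a block size buys when it is turned on:

```python
# Assumed shapes for illustration only: the full softmax materializes a
# [batch, seq, vocab] logit tensor, blocked CE only [batch, seq, block].
batch, seq, vocab, block = 8, 4096, 128_256, 32_000
bytes_per_float = 4  # float32
full_gb = batch * seq * vocab * bytes_per_float / 1e9
blocked_gb = batch * seq * block * bytes_per_float / 1e9
print(f"full logits: {full_gb:.1f} GB, one block: {blocked_gb:.1f} GB")  # ~16.8 vs ~4.2
```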
22 changes: 13 additions & 9 deletions src/levanter/models/loss.py
@@ -58,7 +58,9 @@ def next_token_loss(

if block_size is None:
# Full softmax computation
-logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype)
+logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed)
+if dtype is not None:
+    logits = logits.astype(dtype)
target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype)
return cross_entropy_and_logsumexp_penalty(
logits,
@@ -261,9 +263,10 @@ def process_block(block_idx, acc, current_block_size):

# Materialize the logits for the current block
lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block]
-logits_b = hax.dot(
-    pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-) # [Batch, Seq, Block]
+logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract) # [Batch, Seq, Block]

+if dtype is not None:
+    logits_b = logits_b.astype(dtype)

# Update max and logsumexp
max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block)) # [Batch, Seq]
@@ -278,7 +281,7 @@ def process_block(block_idx, acc, current_block_size):
# Update sumV. This is actually unnecessary if we're using one-hot targets
# sV = sV_prev + hax.sum(target_y_b, axis=Label.name)

-loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype) # [Batch, Seq]
+loss += hax.dot(logits_b, target_y_b, axis=Block) # [Batch, Seq]

return loss, logsumexp, max_logit # , sV

@@ -351,7 +354,7 @@ def _block_cross_entropy_backward(
num_blocks = vocab_size // block_size

grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype)
-grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype)
+grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype)

def process_block(block_idx, acc, current_block_size):
"""
@@ -372,14 +375,15 @@ def process_block(block_idx, acc, current_block_size):

# Materialize the logits for the current block
lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block]
-logits_b = hax.dot(
-    pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype
-) # [Batch, Seq, Block]
+logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract) # [Batch, Seq, Block]

# Materialize the target for the current block (one-hot)
target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype) # [Batch, Seq, Block]

# materialize the softmax for the current block
+if dtype is not None:
+    logits_b = logits_b.astype(dtype)

p_b = hax.exp(logits_b - log_z) # [Batch, Seq, Block]

delta_b = p_b - target_y_block
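For readers of the blocked-loss code above: the forward pass maintains a running max and running log-sum-exp across vocab blocks, so the full `[Batch, Seq, Vocab]` logit matrix is never materialized. A small numpy sketch of that recurrence (1-D logits for simplicity; not the haliax implementation):

```python
import numpy as np


def streaming_logsumexp(blocks):
    """Log-sum-exp over concatenated blocks without materializing them at once."""
    m = -np.inf  # running max
    s = 0.0      # running sum of exp(logit - m)
    for block in blocks:
        m_new = max(m, float(block.max()))
        s = s * np.exp(m - m_new) + np.exp(block - m_new).sum()  # rescale old sum
        m = m_new
    return m + np.log(s)


blocks = [np.random.randn(5) for _ in range(4)]
reference = np.log(np.exp(np.concatenate(blocks)).sum())
assert np.isclose(streaming_logsumexp(blocks), reference)
```

The dtype change in this file simply moves the downcast out of `hax.dot` (no more `preferred_element_type`) and applies it explicitly afterwards.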
2 changes: 1 addition & 1 deletion src/levanter/models/mpt.py
@@ -107,7 +107,7 @@ def from_hf(config: HfMptAttentionConfig):


@LmConfig.register_subclass("mpt")
-@dataclass
+@dataclass(frozen=True)
class MptConfig(HFCompatConfig):
d_model: int = 768
n_heads: int = 12
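`MptConfig` becomes frozen here, presumably because the base `LmConfig` is now declared `@dataclass(frozen=True)` (see lm_model.py above) and Python refuses to mix frozen and non-frozen dataclasses in one inheritance chain. A standalone sketch of that rule:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BaseConfig:
    seq_len: int = 1024


try:
    @dataclass  # non-frozen subclass of a frozen dataclass
    class MutableConfig(BaseConfig):
        d_model: int = 768
except TypeError as err:
    print(err)  # "cannot inherit non-frozen dataclass from a frozen one"


@dataclass(frozen=True)  # matching frozen-ness is accepted
class FrozenConfig(BaseConfig):
    d_model: int = 768
```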
3 changes: 3 additions & 0 deletions src/levanter/tracker/tensorboard.py
@@ -43,6 +43,9 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
pylogger.exception(f"Error logging artifact {artifact_path} to {log_path}")
return

+def finish(self):
+    self.writer.close()


@TrackerConfig.register_subclass("tensorboard")
@dataclass
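The new `finish()` just closes the writer, which is what guarantees buffered events actually reach disk. A hedged sketch of the underlying writer behavior, assuming the PyTorch `SummaryWriter` (the tracker may wrap a different writer class):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="/tmp/tb-demo")  # hypothetical log dir
writer.add_scalar("train/loss", 0.5, global_step=1)
writer.close()  # flushes pending events and closes the event file
```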
22 changes: 22 additions & 0 deletions src/levanter/tracker/tracker.py
@@ -46,6 +46,14 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
pass

+@abc.abstractmethod
+def finish(self):
+    """
+    Finish the tracker. This is called when the tracker is no longer needed. This can, e.g.,
+    force a commit of all metrics.
+    """
+    pass

def __enter__(self):
import levanter.tracker.tracker_fns as tracker_fns

@@ -81,6 +89,17 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio
for tracker in self.loggers:
tracker.log_artifact(artifact_path, name=name, type=type)

+def finish(self):
+    excs = []
+    for tracker in self.loggers:
+        try:
+            tracker.finish()
+        except Exception as e:
+            excs.append(e)
+
+    if excs:
+        raise RuntimeError("Errors occurred when finishing trackers") from excs[0]


class TrackerConfig(draccus.PluginRegistry, abc.ABC):
discover_packages_path = "levanter.tracker"
@@ -109,6 +128,9 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
pass

+def finish(self):
+    pass


@TrackerConfig.register_subclass("noop")
@dataclasses.dataclass
4 changes: 4 additions & 0 deletions src/levanter/tracker/wandb.py
@@ -72,6 +72,10 @@ def log_summary(self, metrics: dict[str, Any]):
def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None):
self.run.log_artifact(artifact_path, name=name, type=type)

+def finish(self):
+    logger.info("Finishing wandb run...")
+    self.run.finish()


def is_wandb_available():
try:
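Likewise, `run.finish()` is wandb's own flush-and-close call; without it, a run started in a short-lived subprocess can lose its final metrics or be left unmarked as finished. A hedged standalone sketch (offline mode so it runs anywhere; the project name is made up):

```python
import wandb

run = wandb.init(project="levanter-demo", mode="offline")
run.log({"loss": 1.23})
run.finish()  # flush buffered history and mark the run finished
```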