From 234a945669354ec85d3b985c7c3cde4d39ef2e0b Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 15 Oct 2024 20:32:03 -0700 Subject: [PATCH 01/66] wip epochs --- config/llama_7b_with_olmo_config.yaml | 6 +++ examples/alpaca/alpaca.py | 4 +- infra/helpers/setup-tpu-vm.sh | 2 +- pyproject.toml | 2 +- src/levanter/data/text.py | 67 ++++++++++++++++++++++++++- src/levanter/main/train_lm.py | 12 +++-- src/levanter/trainer.py | 15 +++++- 7 files changed, 98 insertions(+), 10 deletions(-) diff --git a/config/llama_7b_with_olmo_config.yaml b/config/llama_7b_with_olmo_config.yaml index 9864a000f..595fabe68 100644 --- a/config/llama_7b_with_olmo_config.yaml +++ b/config/llama_7b_with_olmo_config.yaml @@ -15,6 +15,10 @@ trainer: project: "marin" tags: ["dolma", "olmo", "llama"] + checkpointer: + keep: + - every: 250 + mp: p=f32,c=bfloat16 train_batch_size: 2048 num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 @@ -27,3 +31,5 @@ optimizer: weight_decay: 0.1 min_lr_ratio: 0.1 warmup: 0.01 + + data_shuffle: true \ No newline at end of file diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index a2201de76..97ef8d7ef 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -162,11 +162,13 @@ def _prepare_example(ex: dict) -> LmExample: # mask out padding and anything before the start of the target Pos = input_ids.resolve_axis("position") if config.mask_inputs: - loss_mask = hax.arange(Pos) >= ex["source_lens"] + loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1 # should be minus 1? # don't predict the padding targets = hax.roll(input_ids, -1, Pos) loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + # to not predict EOS token since we don't have target! + loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh index f80e586bb..d1a24b263 100755 --- a/infra/helpers/setup-tpu-vm.sh +++ b/infra/helpers/setup-tpu-vm.sh @@ -8,7 +8,7 @@ if [ "$DEBUG" == "1" ]; then fi REPO="https://github.com/stanford-crfm/levanter.git" -BRANCH=main +BRANCH=prefetch_actor_tokenizer if [ "$GIT_BRANCH" != "" ]; then BRANCH="$GIT_BRANCH" diff --git a/pyproject.toml b/pyproject.toml index b0c3df90a..fba7f0b74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dependencies = [ "pydantic<3", "rich~=13.0", "filelock~=3.13", - # "ai2-olmo", + "ai2-olmo", "async-lru~=2.0", "tqdm-loggable>=0.2", "deepdiff" diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index a1e20384f..fde6b9f5b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -63,6 +63,57 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index +class TokenSeqEpochDataset(AsyncDataset[np.ndarray]): + def __init__(self, doc_cache: TreeCache[dict], seq_len: int): + self.doc_cache = doc_cache + self.seq_len = seq_len + self._store: Optional[TreeStore] = None + self._cached_len: Optional[int] = None + + async def async_len(self) -> int: + await self.doc_cache.finished() + token_arrays = await self._await_token_cache() + return token_arrays.data_size // self.seq_len + + async def _await_token_cache(self) -> JaggedArrayStore: + if self._store is None: + self._store = await self.doc_cache.store_async() + return self._store.tree["input_ids"] + + async def final_length_is_known(self) -> bool: + return await self.doc_cache.final_length_is_known() + + def is_finite(self) -> bool: + return False # Now infinite due to epoch wrapping + + async def current_len(self) -> Optional[int]: + store = await self._await_token_cache() + return store.data_size // self.seq_len + + async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: + token_arrays = await self._await_token_cache() + dataset_len = await self.async_len() + + wrapped_indices = [idx % dataset_len for idx in indices] + offsets = np.array(wrapped_indices) * self.seq_len + + with ts.Batch(): + out = [] + for offset in offsets: + out.append(token_arrays.data[offset : offset + self.seq_len].read()) + + out = await asyncio.gather(*out) + return out + + async def wait_until_len_at_least(self, length: int) -> int: + # length is brutally slow to compute, so we cache it + if self._cached_len is not None: + return self._cached_len + + # TODO: would be better to listen for cache updates + length = await super().wait_until_len_at_least(length) + self._cached_len = length + return length class TokenSeqDataset(AsyncDataset[np.ndarray]): """ @@ -642,7 +693,13 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): def train_set( self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None ) -> AsyncDataset[np.ndarray]: - ds = self.token_seq_dataset("train", seq_len, monitors) + + if self.epochs is not None: + ds = self.token_epoch_dataset("train", seq_len, monitors) + else: + ds = self.token_seq_dataset("train", seq_len, monitors) + + # add epoch flag here. if ds is None: raise ValueError("No training set!") @@ -693,6 +750,14 @@ def token_seq_dataset( if cache is None: return None return TokenSeqDataset(cache, seq_len) + + def token_epoch_dataset( + self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True + ) -> Optional[TokenSeqDataset]: + cache = self.build_or_load_cache(split, monitors=monitors) + if cache is None: + return None + return TokenSeqEpochDataset(cache, seq_len) def build_or_load_cache( self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index fe5e5dd35..717941038 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -117,6 +117,7 @@ def main(config: TrainLmConfig): # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) + # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) train_dataset = CausalLmDataset( config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id ) @@ -229,13 +230,14 @@ def compute_log_probs(model, example): return logprobs.rearrange((EvalBatch, Pos)).array train_loader = trainer.data_loader(train_dataset, Batch) - if seek_dataloader: - train_loader = train_loader.iter_from_step(state.step) - else: - train_loader = iter(train_loader) + # if seek_dataloader: + # train_loader = train_loader.iter_from_step(int(state.step)) + # else: + # train_loader = iter(train_loader) ## OK, actually run training! - trainer.train(state, train_loader) + train_loader = train_loader.iter_from_step(int(state.step)) + trainer.train(state, train_loader, epochs=config.epochs) # checkpointer.on_step(last_step, force=True) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 8e98eaedb..1ebc22122 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -376,13 +376,26 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi while int(state.step) < self.num_train_steps: with capture_time() as loading_time: example = next(iter_data) - + # while int(state.step) < target_steps and (epochs is None or current_epoch < epochs): + # current_epoch += 1 + # print(f"Starting epoch {current_epoch}") + # levanter.tracker.log_metrics({"epochs": current_epoch }, step=state.step) info = self.train_step(state, example) state = info.state if run_hooks: with capture_time() as hook_time: self.run_hooks(info) + # while True: + # try: + # with capture_time() as loading_time: + # example = next(iter_data) + # except StopIteration: + # # End of DataLoader iterator, proceed to next epoch + # train_loader = train_loader.iter_from_step(int(state.step)) + # print(f"End of epoch {current_epoch}") + # levanter.tracker.log_metrics({"epochs": current_epoch }, step=state.step) + # current_epoch += 1 levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=info.step) From f0b1eaa1e2e7033ea178dcd088965d0b189d785e Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 15 Oct 2024 20:34:30 -0700 Subject: [PATCH 02/66] fix --- infra/helpers/setup-tpu-vm.sh | 2 +- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh index d1a24b263..f80e586bb 100755 --- a/infra/helpers/setup-tpu-vm.sh +++ b/infra/helpers/setup-tpu-vm.sh @@ -8,7 +8,7 @@ if [ "$DEBUG" == "1" ]; then fi REPO="https://github.com/stanford-crfm/levanter.git" -BRANCH=prefetch_actor_tokenizer +BRANCH=main if [ "$GIT_BRANCH" != "" ]; then BRANCH="$GIT_BRANCH" diff --git a/pyproject.toml b/pyproject.toml index fba7f0b74..a316ffb56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ dependencies = [ "pydantic<3", "rich~=13.0", "filelock~=3.13", - "ai2-olmo", "async-lru~=2.0", "tqdm-loggable>=0.2", "deepdiff" From 020a1b283adaa32906b40b0c6dd9d17a3549fe88 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 15 Oct 2024 20:48:47 -0700 Subject: [PATCH 03/66] add epoch flag, sanity check tulu one epoch --- src/levanter/data/text.py | 4 ++-- src/levanter/main/train_lm.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index fde6b9f5b..88c278de2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -691,10 +691,10 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): cache_dir: Optional[str] = "cache/" def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False ) -> AsyncDataset[np.ndarray]: - if self.epochs is not None: + if epochs: ds = self.token_epoch_dataset("train", seq_len, monitors) else: ds = self.token_seq_dataset("train", seq_len, monitors) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 717941038..3866ac0f8 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -119,7 +119,7 @@ def main(config: TrainLmConfig): tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) train_dataset = CausalLmDataset( - config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id + config.data.train_set(Pos.size, key=data_key, epochs=False), Pos, KeyPos, ignore_index=config.data.ignore_token_id ) # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to @@ -230,14 +230,16 @@ def compute_log_probs(model, example): return logprobs.rearrange((EvalBatch, Pos)).array train_loader = trainer.data_loader(train_dataset, Batch) - # if seek_dataloader: - # train_loader = train_loader.iter_from_step(int(state.step)) - # else: - # train_loader = iter(train_loader) + if seek_dataloader: + train_loader = train_loader.iter_from_step(state.step) + else: + train_loader = iter(train_loader) + + ## OK, actually run training! + trainer.train(state, train_loader) ## OK, actually run training! - train_loader = train_loader.iter_from_step(int(state.step)) - trainer.train(state, train_loader, epochs=config.epochs) + # checkpointer.on_step(last_step, force=True) From 50500b9aed32a024d604d01507723036e8757c72 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 16 Oct 2024 16:03:47 -0700 Subject: [PATCH 04/66] epochs work --- src/levanter/callbacks.py | 44 +++++++++++++++++++++++++++++++++++ src/levanter/main/train_lm.py | 9 ++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index e03add43d..9a5b290dc 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -25,11 +25,55 @@ from levanter.utils import flop_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs +from levanter.data.text import TokenSeqEpochDataset +from concurrent.futures import ThreadPoolExecutor + logger = pylogging.getLogger(__name__) +def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size): + total_tokens = None + + def log_epoch(step_info: StepInfo): + nonlocal total_tokens + if total_tokens is None: + if not total_tokens_future.done(): + return # We don't have the total tokens yet, so we can't calculate epoch + total_tokens = total_tokens_future.result() + + # Get the total processed tokens from the metrics logged by log_performance_stats + processed_tokens = tokens_per_example * batch_size * step_info.step + if processed_tokens is None: + return # No token count available yet + + current_epoch = processed_tokens / total_tokens + levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step) + + return log_epoch + +def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int): + def log_length(): + # If ds.async_len() is the only option, run it in an event loop inside the thread + import asyncio + + async def compute_length(): + length = await ds.async_len() + return length + + # Run the async function synchronously in this thread + length = asyncio.run(compute_length()) + total_tokens = length * seq_length + levanter.tracker.log_summary({"dataset/total_tokens": total_tokens}) + return total_tokens + + # Create a ThreadPoolExecutor with a single worker thread + executor = ThreadPoolExecutor(max_workers=1) + # Submit the log_length function to be executed in a separate thread + future = executor.submit(log_length) + return future + def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): total_loss = 0.0 total_load_time = 0.0 diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 3866ac0f8..17de8f52d 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -54,6 +54,7 @@ class TrainLmConfig: data_seed: Optional[int] = None # if provided, will override the data seed from the trainer initialize_from_checkpoint_path: Optional[str] = None # if provided, will initialize from this checkpoint, used for llama style data mixture + epoch: bool = False # if true, will keep epoching over the dataset and track epochs def main(config: TrainLmConfig): @@ -118,10 +119,16 @@ def main(config: TrainLmConfig): # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) + train_dataset = CausalLmDataset( - config.data.train_set(Pos.size, key=data_key, epochs=False), Pos, KeyPos, ignore_index=config.data.ignore_token_id + config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id ) + if config.epoch: + # add epoch logging + total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) + trainer.add_hook(callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1) + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. From 49afb5d824a7957d1155f9b6e103887a51f4b638 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 16 Oct 2024 16:07:13 -0700 Subject: [PATCH 05/66] minor fix --- src/levanter/main/train_lm.py | 2 -- src/levanter/trainer.py | 14 -------------- 2 files changed, 16 deletions(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 17de8f52d..9c511d31a 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -244,8 +244,6 @@ def compute_log_probs(model, example): ## OK, actually run training! trainer.train(state, train_loader) - - ## OK, actually run training! # checkpointer.on_step(last_step, force=True) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 1ebc22122..3973c025a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -376,26 +376,12 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi while int(state.step) < self.num_train_steps: with capture_time() as loading_time: example = next(iter_data) - # while int(state.step) < target_steps and (epochs is None or current_epoch < epochs): - # current_epoch += 1 - # print(f"Starting epoch {current_epoch}") - # levanter.tracker.log_metrics({"epochs": current_epoch }, step=state.step) info = self.train_step(state, example) state = info.state if run_hooks: with capture_time() as hook_time: self.run_hooks(info) - # while True: - # try: - # with capture_time() as loading_time: - # example = next(iter_data) - # except StopIteration: - # # End of DataLoader iterator, proceed to next epoch - # train_loader = train_loader.iter_from_step(int(state.step)) - # print(f"End of epoch {current_epoch}") - # levanter.tracker.log_metrics({"epochs": current_epoch }, step=state.step) - # current_epoch += 1 levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=info.step) From c2ed3ee23f2075ce86ef2724faadad8aa4f911fb Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 16 Oct 2024 16:37:40 -0700 Subject: [PATCH 06/66] fix ci --- config/llama_7b_tulu.yaml | 39 +++++++++++++++++++++++++++ config/llama_7b_with_olmo_config.yaml | 4 +-- src/levanter/data/text.py | 10 +++---- src/levanter/main/train_lm.py | 4 +-- 4 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 config/llama_7b_tulu.yaml diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml new file mode 100644 index 000000000..1c059a509 --- /dev/null +++ b/config/llama_7b_tulu.yaml @@ -0,0 +1,39 @@ +data: + train_urls: + - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz" + - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz" + - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz" + cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/" + tokenizer: "allenai/OLMo-1B" +model: # 7B class model + type: llama + seq_len: 4096 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True + flash_attention_block_size: 1024 + use_bias: false + use_layer_norm_weight: false +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 256 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 4E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + warmup: 5000 + +epoch: False \ No newline at end of file diff --git a/config/llama_7b_with_olmo_config.yaml b/config/llama_7b_with_olmo_config.yaml index 595fabe68..e41f7dbc2 100644 --- a/config/llama_7b_with_olmo_config.yaml +++ b/config/llama_7b_with_olmo_config.yaml @@ -31,5 +31,5 @@ optimizer: weight_decay: 0.1 min_lr_ratio: 0.1 warmup: 0.01 - - data_shuffle: true \ No newline at end of file + + data_shuffle: true diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 88c278de2..9605ff74c 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -93,10 +93,10 @@ async def current_len(self) -> Optional[int]: async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: token_arrays = await self._await_token_cache() dataset_len = await self.async_len() - + wrapped_indices = [idx % dataset_len for idx in indices] offsets = np.array(wrapped_indices) * self.seq_len - + with ts.Batch(): out = [] for offset in offsets: @@ -691,9 +691,9 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): cache_dir: Optional[str] = "cache/" def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False + self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False ) -> AsyncDataset[np.ndarray]: - + if epochs: ds = self.token_epoch_dataset("train", seq_len, monitors) else: @@ -750,7 +750,7 @@ def token_seq_dataset( if cache is None: return None return TokenSeqDataset(cache, seq_len) - + def token_epoch_dataset( self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True ) -> Optional[TokenSeqDataset]: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 9c511d31a..9134591f2 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -119,7 +119,7 @@ def main(config: TrainLmConfig): # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) - + train_dataset = CausalLmDataset( config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id ) @@ -244,7 +244,7 @@ def compute_log_probs(model, example): ## OK, actually run training! trainer.train(state, train_loader) - + # checkpointer.on_step(last_step, force=True) From 667a5a3602cf48d1bb40742677b2ee48c5bcc8de Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 16 Oct 2024 16:59:55 -0700 Subject: [PATCH 07/66] fix ci --- config/llama_7b_tulu.yaml | 4 ++-- config/llama_7b_with_olmo_config.yaml | 2 +- examples/alpaca/alpaca.py | 2 +- src/levanter/callbacks.py | 7 ++++--- src/levanter/data/text.py | 9 ++++++++- src/levanter/main/train_lm.py | 9 +++++++-- 6 files changed, 23 insertions(+), 10 deletions(-) diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml index 1c059a509..2cd9bf5a2 100644 --- a/config/llama_7b_tulu.yaml +++ b/config/llama_7b_tulu.yaml @@ -27,7 +27,7 @@ trainer: train_batch_size: 256 num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 steps_per_eval: 1000 - tensor_parallel_axes: ["mlp", "heads"] + tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" batch_axis: "batch" optimizer: @@ -36,4 +36,4 @@ optimizer: min_lr_ratio: 0.1 warmup: 5000 -epoch: False \ No newline at end of file +epoch: False diff --git a/config/llama_7b_with_olmo_config.yaml b/config/llama_7b_with_olmo_config.yaml index e41f7dbc2..0b5bc4067 100644 --- a/config/llama_7b_with_olmo_config.yaml +++ b/config/llama_7b_with_olmo_config.yaml @@ -32,4 +32,4 @@ optimizer: min_lr_ratio: 0.1 warmup: 0.01 - data_shuffle: true +data_shuffle: true diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index 97ef8d7ef..e8f805cde 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -162,7 +162,7 @@ def _prepare_example(ex: dict) -> LmExample: # mask out padding and anything before the start of the target Pos = input_ids.resolve_axis("position") if config.mask_inputs: - loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1 # should be minus 1? + loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1 # should be minus 1? # don't predict the padding targets = hax.roll(input_ids, -1, Pos) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 9a5b290dc..a96f904ad 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -8,6 +8,7 @@ import threading import time import warnings +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from typing import Callable, Optional @@ -18,6 +19,7 @@ import levanter.tracker from levanter.data import DataLoader +from levanter.data.text import TokenSeqEpochDataset from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig @@ -25,9 +27,6 @@ from levanter.utils import flop_utils from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs -from levanter.data.text import TokenSeqEpochDataset -from concurrent.futures import ThreadPoolExecutor - logger = pylogging.getLogger(__name__) @@ -53,6 +52,7 @@ def log_epoch(step_info: StepInfo): return log_epoch + def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int): def log_length(): # If ds.async_len() is the only option, run it in an event loop inside the thread @@ -74,6 +74,7 @@ async def compute_length(): future = executor.submit(log_length) return future + def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None): total_loss = 0.0 total_load_time = 0.0 diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 9605ff74c..9f9a24a1b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -63,6 +63,7 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index + class TokenSeqEpochDataset(AsyncDataset[np.ndarray]): def __init__(self, doc_cache: TreeCache[dict], seq_len: int): self.doc_cache = doc_cache @@ -115,6 +116,7 @@ async def wait_until_len_at_least(self, length: int) -> int: self._cached_len = length return length + class TokenSeqDataset(AsyncDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from an underlying TreeCache. @@ -691,7 +693,12 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): cache_dir: Optional[str] = "cache/" def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray] = None, + epochs: bool = False, ) -> AsyncDataset[np.ndarray]: if epochs: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 9134591f2..96323dc03 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -121,13 +121,18 @@ def main(config: TrainLmConfig): # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) train_dataset = CausalLmDataset( - config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id + config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), + Pos, + KeyPos, + ignore_index=config.data.ignore_token_id, ) if config.epoch: # add epoch logging total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) - trainer.add_hook(callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1) + trainer.add_hook( + callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1 + ) # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of From 37e77fb4f2815c8a8c183608c5cc193470f23858 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 17 Oct 2024 11:57:25 -0700 Subject: [PATCH 08/66] fix config file --- config/llama_7b_with_olmo_config.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/config/llama_7b_with_olmo_config.yaml b/config/llama_7b_with_olmo_config.yaml index 0b5bc4067..9864a000f 100644 --- a/config/llama_7b_with_olmo_config.yaml +++ b/config/llama_7b_with_olmo_config.yaml @@ -15,10 +15,6 @@ trainer: project: "marin" tags: ["dolma", "olmo", "llama"] - checkpointer: - keep: - - every: 250 - mp: p=f32,c=bfloat16 train_batch_size: 2048 num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 @@ -31,5 +27,3 @@ optimizer: weight_decay: 0.1 min_lr_ratio: 0.1 warmup: 0.01 - -data_shuffle: true From 7c195ba4cda506990a5aa54e10510e492df0cba4 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 17 Oct 2024 17:06:05 -0700 Subject: [PATCH 09/66] add suggested fix from david --- src/levanter/callbacks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index a96f904ad..01e3d5528 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -39,6 +39,8 @@ def log_epoch(step_info: StepInfo): nonlocal total_tokens if total_tokens is None: if not total_tokens_future.done(): + if step_info.step % 1000 == 0: + logger.info("Dataset not finished. Can't compute epochs.") return # We don't have the total tokens yet, so we can't calculate epoch total_tokens = total_tokens_future.result() From 54a6007bea3f2fb562c11ede519001101efdfb59 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 22 Oct 2024 18:38:13 -0700 Subject: [PATCH 10/66] restore toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 2e217bc9a..19fb077bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "pydantic<3", "rich~=13.0", "filelock~=3.13", + # "ai2-olmo", "async-lru~=2.0", "tqdm-loggable>=0.2", "deepdiff" From e2646d6532530b46a386adecefc24c754f18a36c Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 22 Oct 2024 21:49:00 -0400 Subject: [PATCH 11/66] Update src/levanter/callbacks.py Co-authored-by: David Hall --- src/levanter/callbacks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 01e3d5528..6d3a8a154 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -56,6 +56,9 @@ def log_epoch(step_info: StepInfo): def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int): + if not ds.is_finite(): + raise ValueError("Epochs don't make sense with an infinite dataset.") + def log_length(): # If ds.async_len() is the only option, run it in an event loop inside the thread import asyncio From fd18cae541639517292293fb949c7ca7ac53aa86 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 22 Oct 2024 19:45:18 -0700 Subject: [PATCH 12/66] refactor --- src/levanter/data/text.py | 111 ++++++++++++++++++++-------------- src/levanter/main/train_lm.py | 14 ++--- 2 files changed, 71 insertions(+), 54 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index b8c714dd1..44931414c 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -64,59 +64,84 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index -class TokenSeqEpochDataset(AsyncDataset[np.ndarray]): - def __init__(self, doc_cache: TreeCache[dict], seq_len: int): - self.doc_cache = doc_cache - self.seq_len = seq_len - self._store: Optional[TreeStore] = None - self._cached_len: Optional[int] = None +class EpochDataset(AsyncDataset[T_co]): + """ + A dataset that wraps another dataset, providing infinite epochs by recycling indices. + If `max_epochs` is specified, it limits the number of cycles before raising StopIteration. - async def async_len(self) -> int: - await self.doc_cache.finished() - token_arrays = await self._await_token_cache() - return token_arrays.data_size // self.seq_len + :param dataset: The dataset to wrap. + :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely. + """ + def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None): + if dataset.is_finite(): + raise ValueError("Cannot apply epoching to a finite dataset.") + self.dataset = dataset + self.max_epochs = max_epochs - async def _await_token_cache(self) -> JaggedArrayStore: - if self._store is None: - self._store = await self.doc_cache.store_async() - return self._store.tree["input_ids"] + async def async_len(self) -> int: + if self.max_epochs is None: + raise ValueError("Cannot determine length of an infinite dataset without max_epochs.") + # Return the total number of samples: max_epochs * length of the dataset + return self.max_epochs * await self.dataset.async_len() async def final_length_is_known(self) -> bool: - return await self.doc_cache.final_length_is_known() + return await self.dataset.final_length_is_known() def is_finite(self) -> bool: - return False # Now infinite due to epoch wrapping + # EpochDataset can be finite if max_epochs is set. + return self.max_epochs is not None async def current_len(self) -> Optional[int]: - store = await self._await_token_cache() - return store.data_size // self.seq_len + # If max_epochs is None, the dataset is effectively infinite. + if self.max_epochs is None: + return None + + # If the final length of the dataset is not known, return the current length of the underlying dataset. + if not await self.dataset.final_length_is_known(): + return await self.dataset.current_len() + + # If the final length is known, return the max_epochs * async_len of the dataset. + return self.max_epochs * await self.dataset.async_len() async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: - token_arrays = await self._await_token_cache() - dataset_len = await self.async_len() + # Use self.wait_until_len_at_least to ensure we have enough data for the batch. + max_index = max(indices) + ds_len = await self.wait_until_len_at_least(max_index + 1) - wrapped_indices = [idx % dataset_len for idx in indices] - offsets = np.array(wrapped_indices) * self.seq_len + # Determine the epoch based on the largest index + epoch = max_index // ds_len - with ts.Batch(): - out = [] - for offset in offsets: - out.append(token_arrays.data[offset : offset + self.seq_len].read()) + # If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs + if self.max_epochs is not None and epoch >= self.max_epochs: + raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}") - out = await asyncio.gather(*out) - return out + # Wrap the indices within the bounds of the dataset length + wrapped_indices = [idx % ds_len for idx in indices] - async def wait_until_len_at_least(self, length: int) -> int: - # length is brutally slow to compute, so we cache it - if self._cached_len is not None: - return self._cached_len + # Delegate to the underlying dataset's get_batch + return await self.dataset.get_batch(wrapped_indices) - # TODO: would be better to listen for cache updates - length = await super().wait_until_len_at_least(length) - self._cached_len = length - return length + async def wait_until_len_at_least(self, length: int) -> int: + """ + Returns the length of the dataset once it is at least `length` or if the dataset has a known (finished) length. + If the dataset's actual length is less than `length`, it returns the minimum of async_len and the current length. + """ + # Wait until the underlying dataset's length is at least `length` + if not self.is_finite(): + return length + + if await self.dataset.final_length_is_known(): + base_length = await self.dataset.async_len() + else: + base_length = await self.dataset.wait_until_len_at_least(length) + if base_length < length: + # hit epoch boundary + assert self.max_epochs is not None + return self.max_epochs * base_length + return base_length + class TokenSeqDataset(AsyncDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from an underlying TreeCache. @@ -709,10 +734,10 @@ def train_set( epochs: bool = False, ) -> AsyncDataset[np.ndarray]: + ds = self.token_seq_dataset("train", seq_len, monitors) if epochs: - ds = self.token_epoch_dataset("train", seq_len, monitors) - else: - ds = self.token_seq_dataset("train", seq_len, monitors) + logger.info("Wrapping dataset in epoch dataset") + ds = EpochDataset(ds) # add epoch flag here. if ds is None: @@ -766,14 +791,6 @@ def token_seq_dataset( return None return TokenSeqDataset(cache, seq_len) - def token_epoch_dataset( - self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True - ) -> Optional[TokenSeqDataset]: - cache = self.build_or_load_cache(split, monitors=monitors) - if cache is None: - return None - return TokenSeqEpochDataset(cache, seq_len) - def build_or_load_cache( self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None ) -> Optional[TreeCache[BatchEncoding]]: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 96323dc03..6f76482f2 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -54,7 +54,7 @@ class TrainLmConfig: data_seed: Optional[int] = None # if provided, will override the data seed from the trainer initialize_from_checkpoint_path: Optional[str] = None # if provided, will initialize from this checkpoint, used for llama style data mixture - epoch: bool = False # if true, will keep epoching over the dataset and track epochs + epoch: bool | int = False def main(config: TrainLmConfig): @@ -127,12 +127,12 @@ def main(config: TrainLmConfig): ignore_index=config.data.ignore_token_id, ) - if config.epoch: - # add epoch logging - total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) - trainer.add_hook( - callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1 - ) + + # add epoch logging + total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) + trainer.add_hook( + callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1 + ) # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of From 1706803e87a818e1125994b3f6c84e2c9a4f03ee Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 22 Oct 2024 19:58:17 -0700 Subject: [PATCH 13/66] add suggested fix from david --- src/levanter/callbacks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 6d3a8a154..49da98456 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -18,8 +18,7 @@ from tqdm_loggable.auto import tqdm import levanter.tracker -from levanter.data import DataLoader -from levanter.data.text import TokenSeqEpochDataset +from levanter.data import DataLoader, AsyncDataset from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig @@ -55,7 +54,7 @@ def log_epoch(step_info: StepInfo): return log_epoch -def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int): +def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int): if not ds.is_finite(): raise ValueError("Epochs don't make sense with an infinite dataset.") From f0ca1637401a8c9a18e9cbf105686b0049f4a54b Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 23 Oct 2024 12:13:23 -0700 Subject: [PATCH 14/66] update for v4 so we don't crash --- config/llama_7b_tulu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml index 2cd9bf5a2..48af18b2a 100644 --- a/config/llama_7b_tulu.yaml +++ b/config/llama_7b_tulu.yaml @@ -14,7 +14,7 @@ model: # 7B class model num_heads: 32 num_kv_heads: 32 use_flash_attention: True - flash_attention_block_size: 1024 + flash_attention_block_size: 512 use_bias: false use_layer_norm_weight: false trainer: From c971ebfeb61a87169d2cc62951bb514958fc97af Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 23 Oct 2024 14:21:04 -0700 Subject: [PATCH 15/66] remove changes that break epochs --- config/llama_7b_tulu.yaml | 2 +- src/levanter/callbacks.py | 2 -- src/levanter/data/text.py | 4 +--- src/levanter/main/train_lm.py | 2 +- 4 files changed, 3 insertions(+), 7 deletions(-) diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml index 48af18b2a..cf333f850 100644 --- a/config/llama_7b_tulu.yaml +++ b/config/llama_7b_tulu.yaml @@ -36,4 +36,4 @@ optimizer: min_lr_ratio: 0.1 warmup: 5000 -epoch: False +epoch: 0 diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 49da98456..2eae0185e 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -55,8 +55,6 @@ def log_epoch(step_info: StepInfo): def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int): - if not ds.is_finite(): - raise ValueError("Epochs don't make sense with an infinite dataset.") def log_length(): # If ds.async_len() is the only option, run it in an event loop inside the thread diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 44931414c..3b9789cda 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -73,8 +73,6 @@ class EpochDataset(AsyncDataset[T_co]): :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely. """ def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None): - if dataset.is_finite(): - raise ValueError("Cannot apply epoching to a finite dataset.") self.dataset = dataset self.max_epochs = max_epochs @@ -737,7 +735,7 @@ def train_set( ds = self.token_seq_dataset("train", seq_len, monitors) if epochs: logger.info("Wrapping dataset in epoch dataset") - ds = EpochDataset(ds) + ds = EpochDataset(ds, max_epochs=epochs) # add epoch flag here. if ds is None: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 6f76482f2..54a39700e 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -54,7 +54,7 @@ class TrainLmConfig: data_seed: Optional[int] = None # if provided, will override the data seed from the trainer initialize_from_checkpoint_path: Optional[str] = None # if provided, will initialize from this checkpoint, used for llama style data mixture - epoch: bool | int = False + epoch: int = 0 def main(config: TrainLmConfig): From 4733f3b978a170d48feb2e5698dcc40e4d4d2cfd Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 23 Oct 2024 14:41:03 -0700 Subject: [PATCH 16/66] final fixes --- src/levanter/data/text.py | 2 +- src/levanter/main/train_lm.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 3b9789cda..ff1ce154f 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -729,7 +729,7 @@ def train_set( monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, - epochs: bool = False, + epochs: int = 0, ) -> AsyncDataset[np.ndarray]: ds = self.token_seq_dataset("train", seq_len, monitors) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 54a39700e..87e6cdc13 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -128,11 +128,12 @@ def main(config: TrainLmConfig): ) - # add epoch logging - total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) - trainer.add_hook( - callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1 - ) + # add epoch logging if epochs specified + if config.epoch > 0: + total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) + trainer.add_hook( + callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1 + ) # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of From e82eec22b0521f2ec6805f1583f295ce02a7b6f2 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 23 Oct 2024 17:45:53 -0700 Subject: [PATCH 17/66] final fixes --- src/levanter/data/text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index ff1ce154f..8c9b5af05 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -104,7 +104,7 @@ async def current_len(self) -> Optional[int]: async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: # Use self.wait_until_len_at_least to ensure we have enough data for the batch. max_index = max(indices) - ds_len = await self.wait_until_len_at_least(max_index + 1) + ds_len = await self.dataset.wait_until_len_at_least(max_index + 1) # Determine the epoch based on the largest index epoch = max_index // ds_len From 08fd427d7e37a4c5bb709519a6767a30ba8df100 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 23 Oct 2024 21:35:18 -0700 Subject: [PATCH 18/66] substatial changes to save on epochs w callback --- src/levanter/callbacks.py | 13 ++++++------- src/levanter/checkpoint.py | 34 ++++++++++++++++++++++++++++++++++ src/levanter/main/train_lm.py | 13 +++++++++++-- 3 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 2eae0185e..22be96cd8 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -27,11 +27,9 @@ from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs - logger = pylogging.getLogger(__name__) - -def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size): +def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size, max_epochs: Optional[int] = None): total_tokens = None def log_epoch(step_info: StepInfo): @@ -45,10 +43,11 @@ def log_epoch(step_info: StepInfo): # Get the total processed tokens from the metrics logged by log_performance_stats processed_tokens = tokens_per_example * batch_size * step_info.step - if processed_tokens is None: - return # No token count available yet - - current_epoch = processed_tokens / total_tokens + + # If we're doing multiple epochs, adjust the denominator + total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens + current_epoch = processed_tokens / total_tokens_for_epochs + levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step) return log_epoch diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 5bfb6be30..00ad37491 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -27,6 +27,7 @@ from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore from levanter.types import FilterSpec +# from levanter.trainer import StepInfo logger = logging.getLogger(__name__) @@ -261,6 +262,39 @@ def _async_checkpoint_remover(self): self._do_rm_checkpoint(checkpoint) self._checkpoint_being_removed = None +# In callbacks.py - Add a new callback that handles epoch checkpointing +class EpochCheckpointer: + """ + A separate checkpointing system that saves based on epochs. + Works alongside the regular step-based checkpointer without modifying core state. + """ + def __init__(self, + checkpointer: Checkpointer, + every_n_epochs: int = 1, + total_dataset_size: Optional[int] = None, + batch_size: int = 1): + self.checkpointer = checkpointer + self.every_n_epochs = every_n_epochs + self.total_dataset_size = total_dataset_size + self.batch_size = batch_size + self._last_saved_epoch = -1 + + def __call__(self, step_info): + if self.total_dataset_size is None: + return # Can't calculate epochs without dataset size + + # Calculate current epoch from steps without modifying StepInfo + current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size + + # Only save if we've moved to a new epoch and it matches our interval + if (current_epoch > self._last_saved_epoch and + current_epoch % self.every_n_epochs == 0): + # Use existing checkpointer's save_checkpoint method + self.checkpointer.save_checkpoint( + step_info, + f"epoch-{current_epoch}" + ) + self._last_saved_epoch = current_epoch def save_checkpoint( tree: M, diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 87e6cdc13..e70c14c82 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -14,7 +14,7 @@ import levanter from levanter import callbacks -from levanter.checkpoint import load_checkpoint +from levanter.checkpoint import EpochCheckpointer, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig from levanter.models.gpt2 import Gpt2Config @@ -132,9 +132,18 @@ def main(config: TrainLmConfig): if config.epoch > 0: total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) trainer.add_hook( - callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1 + callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch), every=1 ) + # Add epoch checkpoint callback + epoch_checkpointer = EpochCheckpointer( + checkpointer=trainer.config.checkpointer.create(trainer.run_id), + every_n_epochs=1, # Or configure as needed + total_dataset_size=total_tokens_future.result(), + batch_size=trainer.config.train_batch_size + ) + trainer.add_hook(epoch_checkpointer, every=1) + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. From 18a535299275477039bd24e08df7f8d495304ef7 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 24 Oct 2024 10:50:25 -0700 Subject: [PATCH 19/66] epoch tracking still broken --- src/levanter/callbacks.py | 2 +- src/levanter/checkpoint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 22be96cd8..37ea38272 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -60,7 +60,7 @@ def log_length(): import asyncio async def compute_length(): - length = await ds.async_len() + length = await ds.dataset.async_len() return length # Run the async function synchronously in this thread diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 00ad37491..d6f062ffc 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -292,7 +292,7 @@ def __call__(self, step_info): # Use existing checkpointer's save_checkpoint method self.checkpointer.save_checkpoint( step_info, - f"epoch-{current_epoch}" + f"epoch-{current_epoch}", ) self._last_saved_epoch = current_epoch From c38b076506e83e054e36b67ee2c2a20d74d5a822 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 24 Oct 2024 20:34:54 -0700 Subject: [PATCH 20/66] WIP --- src/levanter/data/text.py | 52 +++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 8c9b5af05..e59f74423 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -638,18 +638,23 @@ def tagged_eval_sets( @dataclass class LMSupervisedDatasetConfig: - """This class represents a dataset source with URLs or hf name/id.""" - + """Config for supervised fine-tuning datasets""" cache_dir: str = "cache/" - + + # HF dataset config + hf_dataset_name: Optional[str] = None # e.g. "tatsu-lab/alpaca" or "OpenAssistant/oasst1" + hf_dataset_split: str = "train" # which split to use + + # Local files config + validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files + + # Field names in the data + input_field: str = "prompt" # name of the input field + output_field: str = "response" # name of output field + + # Optional metadata tags: Optional[List[str]] = None - """tags for the dataset. Typically the name of the dataset in the config will be added as a tag as well""" - name: Optional[str] = None # name for hf dataset - - input_field: str = "prompt" # name of the input field in the jsonl file - output_field: str = "response" # name of the output field in the jsonl file - - validation_urls: List[str] = () # type:ignore + name: Optional[str] = None def preprocess_supervised_example( @@ -700,23 +705,38 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): import levanter.data - - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] - dataset = levanter.data.datasource_from_jsonl(validation_urls) + + # Choose data source based on config + if config.hf_dataset_name is not None: + # Using HF dataset + dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split) + else: + # Using local files + validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + if not validation_urls: + raise ValueError("Must specify either hf_dataset_name or validation_urls") + dataset = levanter.data.datasource_from_jsonl(validation_urls) input_field = config.input_field output_field = config.output_field output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)} - dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore - dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore + # Use the same preprocessing as before + dataset = dataset.map_batches( + lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar + ) + + dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) - @dataclass class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" From 7331774379f54d7aec427fb6149a308a074f1e35 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sun, 27 Oct 2024 23:28:09 -0700 Subject: [PATCH 21/66] update epochs to save latest checkpoints --- src/levanter/main/train_lm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index e70c14c82..7a0527dc5 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -258,9 +258,15 @@ def compute_log_probs(model, example): train_loader = iter(train_loader) ## OK, actually run training! - trainer.train(state, train_loader) + last_info = trainer.train(state, train_loader) - # checkpointer.on_step(last_step, force=True) + + + # If running EpochDataset save latest checkpoint by default + if trainer.config.checkpointer is not None and config.epoch > 0: + trainer.run_hooks(last_info, force=True) + checkpointer = trainer.config.checkpointer.create(trainer.run_id) + checkpointer.wait_until_finished() if __name__ == "__main__": From aa47d4ec3f84f79148b46c4b421c40d86190271d Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 28 Oct 2024 02:29:14 -0400 Subject: [PATCH 22/66] Update src/levanter/checkpoint.py Co-authored-by: David Hall --- src/levanter/checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index d6f062ffc..651b84c68 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -27,7 +27,6 @@ from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore from levanter.types import FilterSpec -# from levanter.trainer import StepInfo logger = logging.getLogger(__name__) From 0148cd092c1c5234787d5c71e0154069cfd53fee Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 28 Oct 2024 13:20:54 -0700 Subject: [PATCH 23/66] update tulu config to match olmo sft --- config/llama_7b_tulu.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml index cf333f850..2da335a17 100644 --- a/config/llama_7b_tulu.yaml +++ b/config/llama_7b_tulu.yaml @@ -7,7 +7,7 @@ data: tokenizer: "allenai/OLMo-1B" model: # 7B class model type: llama - seq_len: 4096 + seq_len: 2048 hidden_dim: 4096 intermediate_dim: 11008 num_layers: 32 @@ -36,4 +36,4 @@ optimizer: min_lr_ratio: 0.1 warmup: 5000 -epoch: 0 +epoch: 3 From 5343096d875d33bd26b9f399dfafcaebc88a4a0b Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 28 Oct 2024 13:25:54 -0700 Subject: [PATCH 24/66] pre commit --- src/levanter/callbacks.py | 9 +++++---- src/levanter/checkpoint.py | 22 +++++++++++++--------- src/levanter/data/text.py | 28 +++++++++++++++++----------- src/levanter/main/train_lm.py | 12 ++++++------ 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/levanter/callbacks.py b/src/levanter/callbacks.py index 37ea38272..897109ffc 100644 --- a/src/levanter/callbacks.py +++ b/src/levanter/callbacks.py @@ -18,7 +18,7 @@ from tqdm_loggable.auto import tqdm import levanter.tracker -from levanter.data import DataLoader, AsyncDataset +from levanter.data import AsyncDataset, DataLoader from levanter.logging import save_xla_dumps_to_wandb from levanter.tracker.helpers import log_optimizer_hyperparams from levanter.tracker.wandb import WandbConfig @@ -27,8 +27,10 @@ from levanter.utils.jax_utils import barrier_sync, jnp_to_python from levanter.visualization import compute_and_visualize_log_probs as viz_probs + logger = pylogging.getLogger(__name__) + def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size, max_epochs: Optional[int] = None): total_tokens = None @@ -43,18 +45,17 @@ def log_epoch(step_info: StepInfo): # Get the total processed tokens from the metrics logged by log_performance_stats processed_tokens = tokens_per_example * batch_size * step_info.step - + # If we're doing multiple epochs, adjust the denominator total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens current_epoch = processed_tokens / total_tokens_for_epochs - + levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step) return log_epoch def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int): - def log_length(): # If ds.async_len() is the only option, run it in an event loop inside the thread import asyncio diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 651b84c68..38b039f20 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -261,17 +261,21 @@ def _async_checkpoint_remover(self): self._do_rm_checkpoint(checkpoint) self._checkpoint_being_removed = None + # In callbacks.py - Add a new callback that handles epoch checkpointing class EpochCheckpointer: """ A separate checkpointing system that saves based on epochs. Works alongside the regular step-based checkpointer without modifying core state. """ - def __init__(self, - checkpointer: Checkpointer, - every_n_epochs: int = 1, - total_dataset_size: Optional[int] = None, - batch_size: int = 1): + + def __init__( + self, + checkpointer: Checkpointer, + every_n_epochs: int = 1, + total_dataset_size: Optional[int] = None, + batch_size: int = 1, + ): self.checkpointer = checkpointer self.every_n_epochs = every_n_epochs self.total_dataset_size = total_dataset_size @@ -281,13 +285,12 @@ def __init__(self, def __call__(self, step_info): if self.total_dataset_size is None: return # Can't calculate epochs without dataset size - + # Calculate current epoch from steps without modifying StepInfo current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size - + # Only save if we've moved to a new epoch and it matches our interval - if (current_epoch > self._last_saved_epoch and - current_epoch % self.every_n_epochs == 0): + if current_epoch > self._last_saved_epoch and current_epoch % self.every_n_epochs == 0: # Use existing checkpointer's save_checkpoint method self.checkpointer.save_checkpoint( step_info, @@ -295,6 +298,7 @@ def __call__(self, step_info): ) self._last_saved_epoch = current_epoch + def save_checkpoint( tree: M, step: int, diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index e59f74423..f52799eea 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -72,6 +72,7 @@ class EpochDataset(AsyncDataset[T_co]): :param dataset: The dataset to wrap. :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely. """ + def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None): self.dataset = dataset self.max_epochs = max_epochs @@ -111,7 +112,9 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]: # If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs if self.max_epochs is not None and epoch >= self.max_epochs: - raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}") + raise StopIteration( + f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}" + ) # Wrap the indices within the bounds of the dataset length wrapped_indices = [idx % ds_len for idx in indices] @@ -139,7 +142,8 @@ async def wait_until_len_at_least(self, length: int) -> int: return self.max_epochs * base_length return base_length - + + class TokenSeqDataset(AsyncDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from an underlying TreeCache. @@ -639,19 +643,20 @@ def tagged_eval_sets( @dataclass class LMSupervisedDatasetConfig: """Config for supervised fine-tuning datasets""" + cache_dir: str = "cache/" - + # HF dataset config hf_dataset_name: Optional[str] = None # e.g. "tatsu-lab/alpaca" or "OpenAssistant/oasst1" hf_dataset_split: str = "train" # which split to use - + # Local files config validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files - + # Field names in the data input_field: str = "prompt" # name of the input field output_field: str = "response" # name of output field - + # Optional metadata tags: Optional[List[str]] = None name: Optional[str] = None @@ -705,7 +710,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): import levanter.data - + # Choose data source based on config if config.hf_dataset_name is not None: # Using HF dataset @@ -725,18 +730,19 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain # Use the same preprocessing as before dataset = dataset.map_batches( lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), - batch_size=128, + batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), - output_exemplar=output_exemplar + output_exemplar=output_exemplar, ) - + dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) - + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + @dataclass class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 7a0527dc5..ee7e353c5 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -54,7 +54,7 @@ class TrainLmConfig: data_seed: Optional[int] = None # if provided, will override the data seed from the trainer initialize_from_checkpoint_path: Optional[str] = None # if provided, will initialize from this checkpoint, used for llama style data mixture - epoch: int = 0 + epoch: int = 0 def main(config: TrainLmConfig): @@ -127,12 +127,14 @@ def main(config: TrainLmConfig): ignore_index=config.data.ignore_token_id, ) - # add epoch logging if epochs specified if config.epoch > 0: total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len) trainer.add_hook( - callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch), every=1 + callbacks.log_epoch_progress( + total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch + ), + every=1, ) # Add epoch checkpoint callback @@ -140,7 +142,7 @@ def main(config: TrainLmConfig): checkpointer=trainer.config.checkpointer.create(trainer.run_id), every_n_epochs=1, # Or configure as needed total_dataset_size=total_tokens_future.result(), - batch_size=trainer.config.train_batch_size + batch_size=trainer.config.train_batch_size, ) trainer.add_hook(epoch_checkpointer, every=1) @@ -260,8 +262,6 @@ def compute_log_probs(model, example): ## OK, actually run training! last_info = trainer.train(state, train_loader) - - # If running EpochDataset save latest checkpoint by default if trainer.config.checkpointer is not None and config.epoch > 0: trainer.run_hooks(last_info, force=True) From fd3982830c688520a7ab416c5eb77e3e39445462 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 28 Oct 2024 18:51:02 -0700 Subject: [PATCH 25/66] fix sft bug caused by exemplar tldr exemplar is a schema for data. I was storing the sequences and lengths with np arrays as schema objects and one of them had dim 0 which is a scalar and invalid for arrays --- src/levanter/data/text.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index f52799eea..b9690e8b7 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -725,7 +725,8 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain input_field = config.input_field output_field = config.output_field - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)} + + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} # Use the same preprocessing as before dataset = dataset.map_batches( From 313a3f4c3ae2ba7bf60e329409348b3b557cb316 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 28 Oct 2024 18:53:20 -0700 Subject: [PATCH 26/66] add actual sft file --- examples/sft/alpaca-llama.yaml | 32 +++++++ examples/sft/dolly-llama.yaml | 32 +++++++ examples/sft/oasst-llama.yaml | 38 ++++++++ examples/sft/sft.py | 154 +++++++++++++++++++++++++++++++++ 4 files changed, 256 insertions(+) create mode 100644 examples/sft/alpaca-llama.yaml create mode 100644 examples/sft/dolly-llama.yaml create mode 100644 examples/sft/oasst-llama.yaml create mode 100644 examples/sft/sft.py diff --git a/examples/sft/alpaca-llama.yaml b/examples/sft/alpaca-llama.yaml new file mode 100644 index 000000000..ac2de709d --- /dev/null +++ b/examples/sft/alpaca-llama.yaml @@ -0,0 +1,32 @@ +model_name_or_path: meta-llama/Llama-2-7b-hf + +# Training configuration +trainer: + mp: p=f32,c=bfloat16 + wandb: + project: "levanter-sft" + tags: ["llama2", "alpaca"] + num_train_steps: 1218 + train_batch_size: 64 + # If using model parallelism + tensor_parallel_axes: ["mlp", "heads"] + +# Optimizer settings +optimizer: + learning_rate: 2e-5 + weight_decay: 0.0 + +supervised_data: + hf_dataset_name: "tatsu-lab/alpaca" + hf_dataset_split: "train" + input_field: "instruction" # change from prompt + output_field: "output" # this is correct + cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-new" + +max_tune_length: 2048 +trust_remote_code: false +model_cache_dir: null + +hf_save_path: "sft_hf_ckpts" +hf_upload: false +hf_save_steps: 1000 \ No newline at end of file diff --git a/examples/sft/dolly-llama.yaml b/examples/sft/dolly-llama.yaml new file mode 100644 index 000000000..9dd68f984 --- /dev/null +++ b/examples/sft/dolly-llama.yaml @@ -0,0 +1,32 @@ +model_name_or_path: meta-llama/Llama-2-7b-hf + +# Training configuration +trainer: + mp: p=f32,c=bfloat16 + wandb: + project: "levanter-sft" + tags: ["llama2", "oasst"] + num_train_steps: 1218 + train_batch_size: 128 + # If using model parallelism + tensor_parallel_axes: ["mlp", "heads"] + +# Optimizer settings +optimizer: + learning_rate: 2e-5 + weight_decay: 0.0 + +supervised_data: + hf_dataset_name: "databricks/databricks-dolly-15k" + hf_dataset_split: "train" + input_field: "instruction" # change from prompt + output_field: "response" # this is correct + cache_dir: "cache/dolly" + +max_tune_length: 2048 +trust_remote_code: false +model_cache_dir: null + +hf_save_path: "sft_hf_ckpts" +hf_upload: false +hf_save_steps: 1000 \ No newline at end of file diff --git a/examples/sft/oasst-llama.yaml b/examples/sft/oasst-llama.yaml new file mode 100644 index 000000000..46f89f0ea --- /dev/null +++ b/examples/sft/oasst-llama.yaml @@ -0,0 +1,38 @@ +model_name_or_path: meta-llama/Llama-2-7b-hf + +# Training configuration +trainer: + mp: p=f32,c=bfloat16 + wandb: + project: "levanter-sft" + tags: ["llama2", "oasst"] + num_train_steps: 1218 + train_batch_size: 128 + + # If using model parallelism + tensor_parallel_axes: ["mlp", "heads"] + +# Optimizer settings +optimizer: + learning_rate: 2e-5 + weight_decay: 0.0 + +# Supervised data configuration +supervised_data: + # For HF dataset + id: "databricks/databricks-dolly-15k" + input_field: "instruction" # adjust based on dataset + output_field: "response" # adjust based on dataset + cache_dir: "cache/dolly" + +# Model configuration +max_tune_length: 2048 +trust_remote_code: false +model_cache_dir: null + +# Checkpoint saving configuration +hf_save_path: "sft_hf_ckpts" +hf_upload: false +hf_save_steps: 1000 + +# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml \ No newline at end of file diff --git a/examples/sft/sft.py b/examples/sft/sft.py new file mode 100644 index 000000000..8fa81624a --- /dev/null +++ b/examples/sft/sft.py @@ -0,0 +1,154 @@ +import json +import logging +import os +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import fsspec +import jax +import jax.random as jrandom +import transformers + +import haliax as hax + +import levanter +from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback +from levanter.data import PermutationDataset +from levanter.models.lm_model import LmHeadModel, compute_next_token_loss +from levanter.optim import OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig +from levanter.utils import fsspec_utils +from levanter.utils.hf_utils import num_cpus_used_by_tokenizer +from levanter.utils.py_utils import non_caching_cycle +from levanter.data.text import mk_supervised_dataset, LMSupervisedDatasetConfig, EpochDataset + + +logger = logging.getLogger(__name__) + +# Define default special tokens +DEFAULT_PAD_TOKEN = "[PAD]" +DEFAULT_EOS_TOKEN = "" +DEFAULT_BOS_TOKEN = "" +DEFAULT_UNK_TOKEN = "" + + +@dataclass +class TrainArgs: + optimizer: OptimizerConfig + trainer: TrainerConfig + + max_tune_length: int = 2048 # maximum length of the input to the model during tuning + + # Supervision config + supervised_data: LMSupervisedDatasetConfig = LMSupervisedDatasetConfig() + input_field: str = "instruction" # field name for input in dataset + output_field: str = "output" # field name for output in dataset + data_cache_dir: str = "cache/" # Path to cache the tokenized data + + model_name_or_path: str = "meta-llama/Llama-2-7b-hf" + trust_remote_code: bool = False # Trust remote code when loading from HuggingFace checkpoints. + model_cache_dir: Optional[str] = None # Path to cache the model. must be local. + + hf_save_path: Optional[str] = "sft_hf_ckpts" # Path to save the HuggingFace checkpoint + hf_upload: Union[bool, str] = False # Name of the HuggingFace repo to upload to (if any) + hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint + + epochs: int = 0 # Number of epochs to train for + + +def train(config: TrainArgs): + levanter.initialize(config) + + converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=config.trust_remote_code) + model_config = converter.default_config + + if config.max_tune_length > model_config.Pos.size: + logger.warning( + f"max_tune_length ({config.max_tune_length}) is greater than the model's maximum length" + f" ({model_config.Pos.size}). " + ) + + training_key, data_key = jrandom.split(jrandom.PRNGKey(config.trainer.seed), 2) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + config.model_name_or_path, + cache_dir=config.model_cache_dir, + model_max_length=config.max_tune_length, + padding_side="right", + ) + num_new_tokens = add_special_tokens(tokenizer) + logger.info(f"Added {num_new_tokens} new tokens") + + # modify converter to use our tokenizer + converter = converter.replaced(tokenizer=tokenizer) + + # Configure supervised dataset + supervised_config = config.supervised_data + + # Create supervised dataset using generic machinery + logger.info("Creating supervised dataset") + train_dataset = mk_supervised_dataset(supervised_config, tokenizer) + logger.info("Supervised dataset created") + train_dataset = PermutationDataset(train_dataset, data_key) + + # Then wrap for epochs + if config.epochs > 0: + logger.info(f"Wrapping dataset for {config.epochs} epochs") + train_dataset = EpochDataset(train_dataset, max_epochs=config.epochs) + + logger.info("Creating optimizer") + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: + parameter_axis_mapping = trainer.parameter_axis_mapping + + logger.info(f"Loading pretrained model from {converter.reference_checkpoint}") + model: LmHeadModel = converter.load_pretrained( + model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype + ) + + model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model) + + loader = trainer.data_loader(train_dataset, trainer.TrainBatch) + loader = non_caching_cycle(loader) + + state = trainer.initial_state(training_key, model=model) + + if int(state.step) != 0: + logger.info(f"Resuming training from step {state.step}") + for i in range(state.step): + next(loader) + + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + + trainer.add_hook( + save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload), + every=config.hf_save_steps, + ) + + trainer.train(state, loader) + + +def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): + special_tokens_dict = dict() + if use_unk_instead_of_adding: + if tokenizer.unk_token is None: + raise ValueError("use_unk_instead_of_add is True but tokenizer doesn't have an unk token") + + unk = tokenizer.unk_token if use_unk_instead_of_adding else None + + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN if not use_unk_instead_of_adding else unk + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN if not use_unk_instead_of_adding else unk + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN if not use_unk_instead_of_adding else unk + if tokenizer.unk_token is None: + special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN + + return tokenizer.add_special_tokens(special_tokens_dict) + + +if __name__ == "__main__": + levanter.config.main(train)() \ No newline at end of file From b3718c1b7225cbc526e622146d1fc1f20ae79951 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 28 Oct 2024 18:57:59 -0700 Subject: [PATCH 27/66] precommit --- examples/sft/alpaca-llama.yaml | 4 ++-- examples/sft/dolly-llama.yaml | 4 ++-- examples/sft/oasst-llama.yaml | 8 ++++---- examples/sft/sft.py | 15 +++++---------- src/levanter/data/text.py | 1 - 5 files changed, 13 insertions(+), 19 deletions(-) diff --git a/examples/sft/alpaca-llama.yaml b/examples/sft/alpaca-llama.yaml index ac2de709d..5817bf0e3 100644 --- a/examples/sft/alpaca-llama.yaml +++ b/examples/sft/alpaca-llama.yaml @@ -6,7 +6,7 @@ trainer: wandb: project: "levanter-sft" tags: ["llama2", "alpaca"] - num_train_steps: 1218 + num_train_steps: 1218 train_batch_size: 64 # If using model parallelism tensor_parallel_axes: ["mlp", "heads"] @@ -29,4 +29,4 @@ model_cache_dir: null hf_save_path: "sft_hf_ckpts" hf_upload: false -hf_save_steps: 1000 \ No newline at end of file +hf_save_steps: 1000 diff --git a/examples/sft/dolly-llama.yaml b/examples/sft/dolly-llama.yaml index 9dd68f984..f386c32b7 100644 --- a/examples/sft/dolly-llama.yaml +++ b/examples/sft/dolly-llama.yaml @@ -6,7 +6,7 @@ trainer: wandb: project: "levanter-sft" tags: ["llama2", "oasst"] - num_train_steps: 1218 + num_train_steps: 1218 train_batch_size: 128 # If using model parallelism tensor_parallel_axes: ["mlp", "heads"] @@ -29,4 +29,4 @@ model_cache_dir: null hf_save_path: "sft_hf_ckpts" hf_upload: false -hf_save_steps: 1000 \ No newline at end of file +hf_save_steps: 1000 diff --git a/examples/sft/oasst-llama.yaml b/examples/sft/oasst-llama.yaml index 46f89f0ea..48cd6ae2b 100644 --- a/examples/sft/oasst-llama.yaml +++ b/examples/sft/oasst-llama.yaml @@ -6,9 +6,9 @@ trainer: wandb: project: "levanter-sft" tags: ["llama2", "oasst"] - num_train_steps: 1218 + num_train_steps: 1218 train_batch_size: 128 - + # If using model parallelism tensor_parallel_axes: ["mlp", "heads"] @@ -25,7 +25,7 @@ supervised_data: output_field: "response" # adjust based on dataset cache_dir: "cache/dolly" -# Model configuration +# Model configuration max_tune_length: 2048 trust_remote_code: false model_cache_dir: null @@ -35,4 +35,4 @@ hf_save_path: "sft_hf_ckpts" hf_upload: false hf_save_steps: 1000 -# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml \ No newline at end of file +# python examples/sft/sft.py --config_path examples/sft/oasst-llama2.yaml diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 8fa81624a..90bd1ab85 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -1,11 +1,8 @@ -import json import logging import os from dataclasses import dataclass -from typing import Dict, Optional, Union +from typing import Optional, Union -import fsspec -import jax import jax.random as jrandom import transformers @@ -14,13 +11,11 @@ import levanter from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback from levanter.data import PermutationDataset +from levanter.data.text import EpochDataset, LMSupervisedDatasetConfig, mk_supervised_dataset from levanter.models.lm_model import LmHeadModel, compute_next_token_loss from levanter.optim import OptimizerConfig from levanter.trainer import Trainer, TrainerConfig -from levanter.utils import fsspec_utils -from levanter.utils.hf_utils import num_cpus_used_by_tokenizer from levanter.utils.py_utils import non_caching_cycle -from levanter.data.text import mk_supervised_dataset, LMSupervisedDatasetConfig, EpochDataset logger = logging.getLogger(__name__) @@ -38,7 +33,7 @@ class TrainArgs: trainer: TrainerConfig max_tune_length: int = 2048 # maximum length of the input to the model during tuning - + # Supervision config supervised_data: LMSupervisedDatasetConfig = LMSupervisedDatasetConfig() input_field: str = "instruction" # field name for input in dataset @@ -81,7 +76,7 @@ def train(config: TrainArgs): # modify converter to use our tokenizer converter = converter.replaced(tokenizer=tokenizer) - + # Configure supervised dataset supervised_config = config.supervised_data @@ -151,4 +146,4 @@ def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): if __name__ == "__main__": - levanter.config.main(train)() \ No newline at end of file + levanter.config.main(train)() diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index b9690e8b7..0181889d9 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -725,7 +725,6 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain input_field = config.input_field output_field = config.output_field - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} # Use the same preprocessing as before From 5f36eb8e75a045aa145f2e966d3572c8b3807510 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 29 Oct 2024 16:03:58 -0700 Subject: [PATCH 28/66] sft working w levanter chkpt --- examples/sft/alpaca-llama-sft.yaml | 52 ++++++++++++++ examples/sft/sft.py | 106 ++++++++++++++++------------- 2 files changed, 110 insertions(+), 48 deletions(-) create mode 100644 examples/sft/alpaca-llama-sft.yaml diff --git a/examples/sft/alpaca-llama-sft.yaml b/examples/sft/alpaca-llama-sft.yaml new file mode 100644 index 000000000..72c9aad78 --- /dev/null +++ b/examples/sft/alpaca-llama-sft.yaml @@ -0,0 +1,52 @@ +# Model configuration +model: + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: true + flash_attention_block_size: 512 + use_bias: false + use_layer_norm_weight: false + +# Training configuration +trainer: + mp: p=f32,c=bfloat16 + tracker: + type: wandb + project: "levanter-sft" + tags: ["llama", "sft"] + num_train_steps: 1218 + train_batch_size: 64 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + steps_per_eval: 1000 + +# Optimizer settings +optimizer: + learning_rate: 2e-5 + weight_decay: 0.0 + min_lr_ratio: 0.1 + warmup: 100 + +# Supervised data configuration +supervised_data: + cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo" + input_field: "instruction" + output_field: "output" + hf_dataset_name: "tatsu-lab/alpaca" # Changed from id + hf_dataset_split: "train" + name: "alpaca" # Optional metadata + tags: ["instruction-tuning"] # Optional metadata + validation_urls: [] # Empty list for no validation files + +# Additional settings +tokenizer: "allenai/OLMo-1B" +max_tune_length: 2048 +epoch: 3 + +initialize_from_hf: false \ No newline at end of file diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 90bd1ab85..74d4f6dc9 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -1,20 +1,21 @@ import logging import os from dataclasses import dataclass -from typing import Optional, Union import jax.random as jrandom import transformers -import haliax as hax +from haliax import Axis +from haliax.partitioning import round_axis_for_partitioning import levanter +from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback from levanter.data import PermutationDataset -from levanter.data.text import EpochDataset, LMSupervisedDatasetConfig, mk_supervised_dataset -from levanter.models.lm_model import LmHeadModel, compute_next_token_loss -from levanter.optim import OptimizerConfig -from levanter.trainer import Trainer, TrainerConfig +from levanter.data.text import EpochDataset, mk_supervised_dataset +from levanter.main.train_lm import TrainLmConfig +from levanter.models.lm_model import compute_next_token_loss +from levanter.trainer import Trainer from levanter.utils.py_utils import non_caching_cycle @@ -28,55 +29,46 @@ @dataclass -class TrainArgs: - optimizer: OptimizerConfig - trainer: TrainerConfig - +class SFTConfig(TrainLmConfig): + # inherit most of the config from TrainLmConfig max_tune_length: int = 2048 # maximum length of the input to the model during tuning - - # Supervision config - supervised_data: LMSupervisedDatasetConfig = LMSupervisedDatasetConfig() - input_field: str = "instruction" # field name for input in dataset - output_field: str = "output" # field name for output in dataset - data_cache_dir: str = "cache/" # Path to cache the tokenized data - model_name_or_path: str = "meta-llama/Llama-2-7b-hf" - trust_remote_code: bool = False # Trust remote code when loading from HuggingFace checkpoints. - model_cache_dir: Optional[str] = None # Path to cache the model. must be local. + tokenizer: str = "gpt2" # Tokenizer to use - hf_save_path: Optional[str] = "sft_hf_ckpts" # Path to save the HuggingFace checkpoint - hf_upload: Union[bool, str] = False # Name of the HuggingFace repo to upload to (if any) - hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint - epochs: int = 0 # Number of epochs to train for +def train(config: SFTConfig): + if config.initialize_from_hf: + if config.trainer.initialize_from is not None: + raise ValueError("Cannot use both --initialize_from_hf and --initialize_from") -def train(config: TrainArgs): - levanter.initialize(config) + converter = HFCheckpointConverter.from_hf( + config.model_name_or_path, trust_remote_code=config.trust_remote_code + ) + else: + converter = None - converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=config.trust_remote_code) - model_config = converter.default_config + levanter.initialize(config) - if config.max_tune_length > model_config.Pos.size: - logger.warning( - f"max_tune_length ({config.max_tune_length}) is greater than the model's maximum length" - f" ({model_config.Pos.size}). " - ) + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, _, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) - training_key, data_key = jrandom.split(jrandom.PRNGKey(config.trainer.seed), 2) + if config.data_seed is not None: + logger.info(f"Overriding data seed with {config.data_seed}") + data_key = jrandom.PRNGKey(config.data_seed) tokenizer = transformers.AutoTokenizer.from_pretrained( - config.model_name_or_path, - cache_dir=config.model_cache_dir, + config.tokenizer, model_max_length=config.max_tune_length, padding_side="right", + trust_remote_code=True, ) + logger.info(f"Loaded tokenizer {tokenizer}") num_new_tokens = add_special_tokens(tokenizer) logger.info(f"Added {num_new_tokens} new tokens") - # modify converter to use our tokenizer - converter = converter.replaced(tokenizer=tokenizer) - # Configure supervised dataset supervised_config = config.supervised_data @@ -87,28 +79,46 @@ def train(config: TrainArgs): train_dataset = PermutationDataset(train_dataset, data_key) # Then wrap for epochs - if config.epochs > 0: - logger.info(f"Wrapping dataset for {config.epochs} epochs") - train_dataset = EpochDataset(train_dataset, max_epochs=config.epochs) + if config.epoch > 0: + logger.info(f"Wrapping dataset for {config.epoch} epochs") + train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch) logger.info("Creating optimizer") optimizer = config.optimizer.build(config.trainer.num_train_steps) + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics tracker with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: parameter_axis_mapping = trainer.parameter_axis_mapping - logger.info(f"Loading pretrained model from {converter.reference_checkpoint}") - model: LmHeadModel = converter.load_pretrained( - model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype - ) + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Pos = config.model.Pos + + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model) + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + + flops_per_token = config.model.flops_per_token(vocab_size) + flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None + trainer.add_hook( + callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 + ) loader = trainer.data_loader(train_dataset, trainer.TrainBatch) loader = non_caching_cycle(loader) - state = trainer.initial_state(training_key, model=model) - if int(state.step) != 0: logger.info(f"Resuming training from step {state.step}") for i in range(state.step): From f5533d678dd021b344aaf00d2767f2e080ccf060 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 29 Oct 2024 16:31:32 -0700 Subject: [PATCH 29/66] add back option for hf models on sft --- examples/sft/alpaca-llama-sft.yaml | 8 ++--- examples/sft/sft.py | 56 +++++++++++++++++++----------- 2 files changed, 40 insertions(+), 24 deletions(-) diff --git a/examples/sft/alpaca-llama-sft.yaml b/examples/sft/alpaca-llama-sft.yaml index 72c9aad78..8f1c408b4 100644 --- a/examples/sft/alpaca-llama-sft.yaml +++ b/examples/sft/alpaca-llama-sft.yaml @@ -12,7 +12,7 @@ model: use_bias: false use_layer_norm_weight: false -# Training configuration +# Training configuration trainer: mp: p=f32,c=bfloat16 tracker: @@ -21,7 +21,7 @@ trainer: tags: ["llama", "sft"] num_train_steps: 1218 train_batch_size: 64 - tensor_parallel_axes: ["mlp", "heads"] + tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" batch_axis: "batch" steps_per_eval: 1000 @@ -33,7 +33,7 @@ optimizer: min_lr_ratio: 0.1 warmup: 100 -# Supervised data configuration +# Supervised data configuration supervised_data: cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo" input_field: "instruction" @@ -49,4 +49,4 @@ tokenizer: "allenai/OLMo-1B" max_tune_length: 2048 epoch: 3 -initialize_from_hf: false \ No newline at end of file +initialize_from_hf: false diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 74d4f6dc9..53638db12 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -5,16 +5,17 @@ import jax.random as jrandom import transformers +import haliax as hax from haliax import Axis from haliax.partitioning import round_axis_for_partitioning import levanter from levanter import callbacks -from levanter.compat.hf_checkpoints import HFCheckpointConverter, save_hf_checkpoint_callback +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset from levanter.data.text import EpochDataset, mk_supervised_dataset from levanter.main.train_lm import TrainLmConfig -from levanter.models.lm_model import compute_next_token_loss +from levanter.models.lm_model import LmHeadModel, compute_next_token_loss from levanter.trainer import Trainer from levanter.utils.py_utils import non_caching_cycle @@ -33,23 +34,41 @@ class SFTConfig(TrainLmConfig): # inherit most of the config from TrainLmConfig max_tune_length: int = 2048 # maximum length of the input to the model during tuning model_name_or_path: str = "meta-llama/Llama-2-7b-hf" - tokenizer: str = "gpt2" # Tokenizer to use + tokenizer: str = "meta-llama/Llama-2-7b-hf" # Tokenizer to use def train(config: SFTConfig): + tokenizer = transformers.AutoTokenizer.from_pretrained( + config.tokenizer, + model_max_length=config.max_tune_length, + padding_side="right", + trust_remote_code=True, + ) + logger.info(f"Loaded tokenizer {tokenizer}") if config.initialize_from_hf: if config.trainer.initialize_from is not None: raise ValueError("Cannot use both --initialize_from_hf and --initialize_from") - converter = HFCheckpointConverter.from_hf( - config.model_name_or_path, trust_remote_code=config.trust_remote_code - ) + assert isinstance(config.model, HFCompatConfig) + + converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True) + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. You may want to check this.") + if isinstance(config.initialize_from_hf, str): + converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer) + else: + converter = converter.replaced(tokenizer=tokenizer) + + model_config = converter.default_config + else: converter = None levanter.initialize(config) + num_new_tokens = add_special_tokens(tokenizer) + logger.info(f"Added {num_new_tokens} new tokens") # randomness in jax is tightly controlled by "keys" which are the states of the random number generators # this makes deterministic training pretty easy seed = config.trainer.seed @@ -59,16 +78,6 @@ def train(config: SFTConfig): logger.info(f"Overriding data seed with {config.data_seed}") data_key = jrandom.PRNGKey(config.data_seed) - tokenizer = transformers.AutoTokenizer.from_pretrained( - config.tokenizer, - model_max_length=config.max_tune_length, - padding_side="right", - trust_remote_code=True, - ) - logger.info(f"Loaded tokenizer {tokenizer}") - num_new_tokens = add_special_tokens(tokenizer) - logger.info(f"Added {num_new_tokens} new tokens") - # Configure supervised dataset supervised_config = config.supervised_data @@ -105,10 +114,17 @@ def train(config: SFTConfig): # tokens: gpt-2 has 50257, for example. So we round up. vocab_size = len(tokenizer) Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) - if vocab_size != Vocab.size: - logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - - state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + if config.initialize_from_hf: + logger.info(f"Loading pretrained model from {converter.reference_checkpoint}") + model: LmHeadModel = converter.load_pretrained( + model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype + ) + model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model) + state = trainer.initial_state(training_key, model=model) + else: + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) flops_per_token = config.model.flops_per_token(vocab_size) flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None From 91fc5df95edcc4b178cc1dc77d3ddd434be14fb3 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 29 Oct 2024 22:05:02 -0700 Subject: [PATCH 30/66] WIP for david --- examples/sft/alpaca-llama-fix.yaml | 55 ++++++++++++++++++++++++++++++ examples/sft/alpaca-llama-sft.yaml | 2 +- examples/sft/sft.py | 3 +- 3 files changed, 58 insertions(+), 2 deletions(-) create mode 100644 examples/sft/alpaca-llama-fix.yaml diff --git a/examples/sft/alpaca-llama-fix.yaml b/examples/sft/alpaca-llama-fix.yaml new file mode 100644 index 000000000..1590b7184 --- /dev/null +++ b/examples/sft/alpaca-llama-fix.yaml @@ -0,0 +1,55 @@ +# Model configuration +model: + activation_function: silu + gradient_checkpointing: true + hidden_dim: 4096 + initializer_range: 0.02 + intermediate_dim: 11008 + layer_norm_epsilon: 1.0e-05 + num_heads: 32 + num_kv_heads: 32 + num_layers: 32 + reference_checkpoint: meta-llama/Llama-2-7b-hf + seq_len: 4096 + type: llama + use_bias: false + use_layer_norm_weight: false + +# Training configuration +trainer: + mp: p=f32,c=bfloat16 + tracker: + type: wandb + project: "levanter-sft" + tags: ["llama", "sft"] + num_train_steps: 1218 + train_batch_size: 64 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + steps_per_eval: 1000 + +# Optimizer settings +optimizer: + learning_rate: 2e-5 + weight_decay: 0.0 + min_lr_ratio: 0.1 + warmup: 100 + +# Supervised data configuration +supervised_data: + cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo" + input_field: "instruction" + output_field: "output" + hf_dataset_name: "tatsu-lab/alpaca" # Changed from id + hf_dataset_split: "train" + name: "alpaca" # Optional metadata + tags: ["instruction-tuning"] # Optional metadata + validation_urls: [] # Empty list for no validation files + +# Additional settings +tokenizer: "allenai/OLMo-1B" +max_tune_length: 2048 +epoch: 3 + +initialize_from_hf: false \ No newline at end of file diff --git a/examples/sft/alpaca-llama-sft.yaml b/examples/sft/alpaca-llama-sft.yaml index 8f1c408b4..58422c7ab 100644 --- a/examples/sft/alpaca-llama-sft.yaml +++ b/examples/sft/alpaca-llama-sft.yaml @@ -19,7 +19,7 @@ trainer: type: wandb project: "levanter-sft" tags: ["llama", "sft"] - num_train_steps: 1218 + num_train_steps: 750000 train_batch_size: 64 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 53638db12..594e1b41f 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -61,7 +61,8 @@ def train(config: SFTConfig): converter = converter.replaced(tokenizer=tokenizer) model_config = converter.default_config - + elif config.trainer.initialize_from is None: + raise ValueError("Must specify either --initialize_from_hf or --initialize_from") else: converter = None From ba682cac098e38331b9bcc2d1da5e96ea5e697b5 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 31 Oct 2024 09:58:06 -0700 Subject: [PATCH 31/66] debug epochs --- examples/sft/alpaca-llama-sft.yaml | 2 +- examples/sft/sft.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/sft/alpaca-llama-sft.yaml b/examples/sft/alpaca-llama-sft.yaml index 58422c7ab..f3667c489 100644 --- a/examples/sft/alpaca-llama-sft.yaml +++ b/examples/sft/alpaca-llama-sft.yaml @@ -47,6 +47,6 @@ supervised_data: # Additional settings tokenizer: "allenai/OLMo-1B" max_tune_length: 2048 -epoch: 3 +epoch: 0 initialize_from_hf: false diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 594e1b41f..9813184b9 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -89,9 +89,9 @@ def train(config: SFTConfig): train_dataset = PermutationDataset(train_dataset, data_key) # Then wrap for epochs - if config.epoch > 0: - logger.info(f"Wrapping dataset for {config.epoch} epochs") - train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch) + # if config.epoch > 0: + # logger.info(f"Wrapping dataset for {config.epoch} epochs") + # train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch) logger.info("Creating optimizer") optimizer = config.optimizer.build(config.trainer.num_train_steps) From 8e60ba9d66cb3f4c1e8ceeff84d2dae97ff9de48 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 09:53:22 -0700 Subject: [PATCH 32/66] almost there --- src/levanter/store/cache.py | 846 +++++++++++++++++-------------- src/levanter/store/tree_store.py | 3 + tests/test_new_cache.py | 116 +---- 3 files changed, 462 insertions(+), 503 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 45265c994..62752fc55 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -3,6 +3,7 @@ import copy import dataclasses import logging as pylogging +import operator import os import pprint import random @@ -12,24 +13,26 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core import humanfriendly import jax +import numpy as np import pyarrow as pa import ray +import tensorstore as ts from dataclasses_json import dataclass_json from fsspec import AbstractFileSystem from jaxtyping import PyTree from ray.actor import ActorHandle +from tqdm_loggable.auto import tqdm +from levanter.data import batched from levanter.data.dataset import AsyncDataset -from levanter.store._prefetch_actor import QueueEmpty, RayPrefetchQueue -from levanter.utils.py_utils import Stopwatch -from ..data._preprocessor import BatchProcessor, BatchProcessorPool, BatchResult, dict_from_record_batch +from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource from ..utils.ray_utils import ( @@ -40,8 +43,7 @@ log_failures_to, ser_exc_info, ) -from ..utils.thread_utils import ExceptionTrackingThread -from .jagged_array import PreparedBatch +from .jagged_array import JaggedArrayStore, PreparedBatch from .tree_store import TreeStore @@ -69,20 +71,13 @@ class CacheOptions: """ num_shard_groups: Optional[int] = 128 - """Number of groups to divide the shards into. This is used to parallelize the cache building process without - overloading Ray. If None, all shards will be in their own group.""" - shard_order_randomization_key: Optional[int] = 0 - """A key used to randomize the order of the shards before building and grouping.""" - batch_size: int = 128 - """The batch size to use when processing the data. This is used to control the memory usage of the cache building - process. Lower values will use less memory but take somewhat longer to build the cache.""" # the below options don't actually impact the cache's result, but do impact construction target_size_per_flush: int | str = "512MB" """The number of bytes to buffer before flushing to disk. This is used to control the memory usage of the cache building process. Lower values will use less memory but could take somewhat longer to build the cache.""" - prefetch_per_group: int = 4 - """The number of batches to prefetch per group. This is used to keep the processors busy and to reduce the time""" + + batch_size: int = 128 @property def target_bytes_per_flush(self): @@ -99,14 +94,14 @@ def no_fanciness(batch_size: Optional[int] = None): """ if batch_size is None: batch_size = 128 - return CacheOptions(num_shard_groups=None, shard_order_randomization_key=None, batch_size=batch_size) + return CacheOptions(num_shard_groups=None, batch_size=batch_size) @staticmethod def one_group(): """ For testing, disables all the fancy features of the cache. This makes it easier to predict the behavior """ - return CacheOptions(num_shard_groups=1, shard_order_randomization_key=None, batch_size=128) + return CacheOptions(num_shard_groups=1, batch_size=128) def build_or_load_cache( @@ -116,7 +111,6 @@ def build_or_load_cache( await_finished: bool = True, monitors: Optional[Sequence["MetricsMonitor"]] = None, options: CacheOptions = CacheOptions.default(), - force_flush: bool = False, split: str = "test", ) -> "TreeCache[U]": """ @@ -144,8 +138,6 @@ def build_or_load_cache( options: Configuration for the cache. This is used to configure a few parts of the cache creation process - force_flush: for testing, forces the cache to flush after every batch. This is useful for testing. - Returns: (TreeCache) A TreeCache object that can be used to read the cache. @@ -156,7 +148,6 @@ def build_or_load_cache( shard_source=input_shards, processor=processor, options=options, - force_flush=force_flush, split=split, ) @@ -320,12 +311,11 @@ def build_or_load( shard_source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: Optional["CacheOptions"] = None, - force_flush: bool = False, split: str = "test", ) -> "TreeCache[U]": if options is None: options = CacheOptions.default() - metadata = CacheMetadata(options=options, preprocessor_metadata=processor.metadata) + metadata = CacheMetadata(preprocessor_metadata=processor.metadata) try: return TreeCache.load(cache_dir, processor.output_exemplar, metadata) except FileNotFoundError: @@ -334,7 +324,6 @@ def build_or_load( shard_source=shard_source, processor=processor, options=options, - force_flush=force_flush, split=split, ) return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) @@ -489,13 +478,11 @@ class CacheLedger: is_finished: bool = False finished_shards: List[str] = dataclasses.field(default_factory=list) field_counts: Dict[str, int] = dataclasses.field(default_factory=dict) - metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata(CacheOptions(), {})) + metadata: "CacheMetadata" = dataclasses.field(default_factory=lambda: CacheMetadata({})) @staticmethod - def load_or_initialize( - cache_dir: str, source: ShardedDataSource, processor: BatchProcessor, config: "CacheOptions" - ): - metadata = CacheMetadata(options=config, preprocessor_metadata=processor.metadata) + def load_or_initialize(cache_dir: str, source: ShardedDataSource, processor: BatchProcessor): + metadata = CacheMetadata(preprocessor_metadata=processor.metadata) try: return CacheLedger.load(cache_dir, metadata) except FileNotFoundError: @@ -531,7 +518,6 @@ def _serialize_and_commit(self, cache_dir): @dataclass_json @dataclass(frozen=True) class CacheMetadata: - options: CacheOptions = CacheOptions.default() preprocessor_metadata: Optional[dict[str, Any]] = None def compare_to(self, other: "CacheMetadata") -> deepdiff.DeepDiff: @@ -711,11 +697,10 @@ def _serialize_json_and_commit(path, obj): fs.copy(path, f"{path}.bak") for i in range(10): - with fsspec.open(f"{path}.tmp", "w") as file: - file.write(obj.to_json()) try: - fs.rename(f"{path}.tmp", path) + with fsspec.open(path, "w") as file: + file.write(obj.to_json()) break except FileNotFoundError: # this happens for some reason sometimes. It makes no sense. @@ -740,7 +725,6 @@ def __init__( source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: CacheOptions, - force_flush: bool, ): pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) self.logger = pylogging.getLogger(f"{__name__}.{name}") @@ -751,7 +735,7 @@ def __init__( self._options = options self._updated_ledger_condition = asyncio.Condition() # used to subscribe to metrics updates - self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor, options) + self._ledger = CacheLedger.load_or_initialize(cache_dir, source, processor) if self._ledger.is_finished: self._finished_promise.set_result(None) @@ -770,7 +754,7 @@ def __init__( # (we get twice from we need to concatenate prepared batches into the accumulator) # TODO: measure. memory=2 * self._options.target_bytes_per_flush, - ).remote(current_actor_handle(), cache_dir, self._ledger, source, processor, force_flush) + ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -827,6 +811,9 @@ def _writer_exception(self, shard_name, exc_info: ExceptionInfo): pass self._do_notify() + def _child_failed(self, child: ray.actor.ActorHandle | str | None, exception: ExceptionInfo): + self._writer_exception(str(child), exception) + def _notify_updated_ledger(self, ledger: CacheLedger): """ Called by the cache writer when it has updated the ledger. @@ -855,7 +842,7 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) -def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default(), force_flush=False): +def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): name = f"lev_cache_manager::{split}::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" @@ -867,27 +854,13 @@ def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheO source=shard_source, processor=processor, options=options, - force_flush=force_flush, ) ##### # Core implementation starts below. ##### -# The main idea is to have a bunch of reader tasks that read batches, dispatch tokenization tasks, producing -# a stream of tokenized batches. We then interleave these tokenized batches and write them to the cache. -# The reader tasks are given a group of shards, which are implicitly concatenated together. - - -@dataclass -class _Batch: - """ - A batch of data that has either been read or tokenized. - """ - - shard_name: str - row_indices: List[int] - payload: ray.ObjectRef +# The main idea is to tokenize each shard group in parallel, and then write the results to the cache in order. @dataclass @@ -898,14 +871,7 @@ class _ShardFinished: shard_name: str total_rows: int - - -_Message = _Batch | _ShardFinished -""" -A message that can be sent from a reader task to the writer task. -""" - -_TIME_BETWEEN_WRITES = 20.0 # seconds + path_to_shard: str @ray.remote(num_cpus=1) @@ -914,17 +880,15 @@ def _core_writer_task( cache_dir, initial_ledger: CacheLedger, source: ShardedDataSource, + options: CacheOptions, processor, - force_flush: bool, ): """ This is the main task that processes the data and writes it to the cache. - It chains together: - * 1 generator per shard group - * interleaving of the generators - * processing of the batches - * writing of the batches to the cache + It receives "finished shards" messages from the reader tasks, and copies the data from temporary files + to the cache directory. + """ pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) logger.info("Starting writer task") @@ -933,400 +897,489 @@ def _core_writer_task( # append a small random number to the name to avoid collisions name += f"::{random.randint(0, 1000)}" - with log_failures_to(parent): - - def on_write(ledger): - ray.get(parent._notify_updated_ledger.remote(ledger)) + # We want to make sure it's there + initial_ledger._serialize_and_commit(cache_dir) - sharded_cache_writer = ShardedCacheWriter( - cache_dir, initial_ledger, processor.output_exemplar, on_write=on_write - ) + # we want to do the following: + # 1. write the 0th shard group to the output cache directly, updating metrics as we go + # 2. in the background, start processing other shard groups to temporary caches + # 3. once (1) is done, we start copying the temporary caches to the output cache (in order) + # for now we're going to punt on (1) + with log_failures_to(parent): + temporary_cache_path = os.path.join(cache_dir, "___temp") - options = initial_ledger.metadata.options - num_groups = min(options.num_shard_groups or 1000000, len(source.shard_names)) + paths: dict[str, str] = {} + ledgers: dict[str, CacheLedger | None] = {} + already_finished_paths: list[str] = [] + refs: dict[str, ray.ObjectRef] = {} - processor_pool = _mk_processor_pool(processor, 0, num_groups * 4) + shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) - interleave: RayPrefetchQueue = RayPrefetchQueue( - lambda: _make_interleave(name, source, initial_ledger, processor_pool), - 64, - producer_options={"num_cpus": 1, "name": f"{name}::interleave"}, + logger.info( + f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) - total_time = Stopwatch() - loading_time = Stopwatch() - append_time = Stopwatch() - flush_time = Stopwatch() - flush_amortized_time = Stopwatch() - - current_prepared_batch: Optional[PyTree[PreparedBatch]] = None - current_shard_rows: dict[str, int] = {} - time_of_last_write = time.time() - batches_total = 0.0 - flush_thread = None - finished_shards_last_flush: list = [] - - while True: - with total_time: # 0.0051 - try: - cur_time = time.time() - time_since_last_write = cur_time - time_of_last_write - remaining_time = _TIME_BETWEEN_WRITES - time_since_last_write - - if current_prepared_batch is not None: - with flush_amortized_time: # 6e-4 - current_byte_size = sum( - b.byte_size for b in jax.tree_util.tree_flatten(current_prepared_batch)[0] - ) - should_flush = ( - force_flush - or remaining_time <= 0 - or (current_byte_size >= options.target_bytes_per_flush) - ) - if should_flush: - with flush_time: # 0.613s - if flush_thread is not None: - flush_thread.join() - - flush_thread = ExceptionTrackingThread( - target=_write_batches, - args=( - sharded_cache_writer, - current_shard_rows, - current_prepared_batch, - finished_shards_last_flush, - ), - ) - flush_thread.start() - - current_prepared_batch = None - current_shard_rows = {} - finished_shards_last_flush = [] - - time_of_last_write = time.time() - continue - else: - remaining_time = _TIME_BETWEEN_WRITES - - with loading_time: - try: - message = interleave.get_next(timeout=max(remaining_time, 0.1)) - except QueueEmpty: - logger.info("Writer running ahead of reader.") - continue - - with append_time: - match message: - case _Batch(shard, row_indices, payload): - batches_total += 1 - this_prepared_batch = ray.get(payload) - if current_prepared_batch is None: - # TODO: actually check row indices - current_shard_rows = {shard: len(row_indices)} - current_prepared_batch = this_prepared_batch - else: - current_shard_rows[shard] = current_shard_rows.get(shard, 0) + len(row_indices) - current_prepared_batch = _concat_prepared_batches( - current_prepared_batch, this_prepared_batch - ) - del this_prepared_batch - - if force_flush: - _write_batches( - sharded_cache_writer, - current_shard_rows, - current_prepared_batch, - finished_shards_last_flush, - ) - finished_shards_last_flush = [] - current_prepared_batch = None - current_shard_rows = {} - - case _ShardFinished(shard, total_rows): - finished_shards_last_flush.append((shard, total_rows)) - case _: - raise AssertionError(f"Unexpected message type {type(message)}") - - # if batches_total % 1000 == 0: - # print( - # f"Processed {batches_total} batches: {loading_time.average()}s load," - # f" {append_time.average()}s append, {flush_time.average()}s flush blocked, " - # f"{flush_amortized_time.average()}s amortized flush, " - # f"{total_time.average()}s total" - # ) - except StopIteration: - logger.info("Finished all shards") - break - except Exception as e: - logger.exception("Error while processing batch") - raise e - - # force a flush - if current_prepared_batch is not None or finished_shards_last_flush: - if flush_thread is not None: - flush_thread.join() - _write_batches( - sharded_cache_writer, current_shard_rows, current_prepared_batch, finished_shards_last_flush + unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" + pbar = tqdm(total=len(shard_groups), desc="Tokenizing", unit=unit) + + processor_ref = ray.put(processor) + source_ref = ray.put(source) + + for group_name, shards in shard_groups.items(): + path = os.path.join(temporary_cache_path, group_name) + paths[group_name] = path + + ledger = _try_load(path) + ledgers[group_name] = ledger + + if ledger is not None: + already_finished_paths.append(path) + pbar.update(1) + continue + + ref = ( + ray.remote(_tokenize_one_shard_group) + .options( # type: ignore + num_cpus=processor.num_cpus, + num_gpus=processor.num_gpus, + resources=processor.resources, + memory=3 * 1024 * 1024 * 1024, # made this up + name=f"tokenize::{temporary_cache_path}::{group_name}", + retry_exceptions=True, + max_retries=10, + ) + .remote(os.path.join(temporary_cache_path, group_name), source_ref, shards, processor_ref, options) ) - sharded_cache_writer.finish() + refs[group_name] = ref + + # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) + # This logic is a bit hairy thanks to resumes. + # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these + # separately. We also need to update the ledger as we go. + # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. + # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. + # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. + + # * When we load the permanent cache, we have already written some number of groups to it. + # (We check this invariant with an assert) + # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. + # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" + # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate + # the data offsets for where the group goes in the permanent cache. This is just a running sum of the + # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree + # of integers, one for each array. + # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache + # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. + # * We also need to update the ledger with the total number of rows + permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) + # initialize the data offset tree + data_offset_tree = jax.tree_map(lambda x: 0, permanent_cache.tree) + total_rows_from_caches = 0 + + copy_refs: dict[str, ray.ObjectRef] = {} + last_ref: ray.ObjectRef | None = None + + for group in shard_groups: + # first make sure it's either done this run or already done + if refs.get(group) is not None: + this_ledger = ray.get(refs[group]) + ledgers[group] = ledger + else: + this_ledger = ledgers[group] + + assert this_ledger is not None + # see if we already copied this group, meaning all the shards are in the permanent cache + shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) + if shards_copied == len(shard_groups[group]): + assert initial_ledger.total_num_rows >= total_rows_from_caches + elif shards_copied > 0: + # In theory we can handle this, but it's a bit tricky, so we're going to punt for now + raise RuntimeError("Some shards were copied but not all. This should never happen.") + else: + # we need to copy this group + ref_to_send = None if last_ref is None else RefBox(last_ref) + last_ref = _copy_cache.remote( + cache_dir, + paths[group], + processor_ref, + data_offset_tree, + ref_to_send, + total_rows_from_caches, + parent, + ) + copy_refs[group] = last_ref - out = sharded_cache_writer.get_ledger() - return out + # update the data offset tree + this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) + data_offset_tree = jax.tree.map( + operator.add, data_offset_tree, jax.tree_map(lambda x: x.data_size, this_cache.tree) + ) + total_rows_from_caches += this_ledger.total_num_rows + if last_ref is not None: + ledger = ray.get(last_ref) + else: + ledger = initial_ledger -def _concat_prepared_batches( - current_prepared_batch: PyTree[PreparedBatch], this_prepared_batch: PyTree[PreparedBatch] -): - return jax.tree.map(lambda *bs: PreparedBatch.concat(bs), current_prepared_batch, this_prepared_batch) + ledger.is_finished = True + parent._notify_updated_ledger.remote(ledger) -def _write_batches(writer: ShardedCacheWriter, shard_totals, batch: Optional[PyTree[PreparedBatch]], finished_shards): - # concatenate the payloads - if batch is not None: - writer.write_prepared_batch(shard_totals, batch) +def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: + if num_groups is None or num_groups >= len(source.shard_names): + return {shard_name: [shard_name] for shard_name in source.shard_names} - for shard, total_rows in finished_shards: - writer.finish_shard(shard, total_rows) + shard_names = source.shard_names + num_shards_per_group = len(shard_names) // num_groups + # if we have a remainder, we'll just add it to the last group + out_groups = { + f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) + for i in range(num_groups) + } + if len(shard_names) % num_shards_per_group != 0: + out_groups[f"group_{num_groups - 1}"].extend(shard_names[num_groups * num_shards_per_group :]) + return out_groups # type: ignore -def _fetch_batches(batches) -> tuple[dict[str, int], list[PreparedBatch]]: - shards_for_batches, payloads_for_batches = zip(*batches) - payloads_for_batches = ray.get(list(payloads_for_batches)) - shard_row_totals: dict[str, int] = {} - for shard, payload in zip(shards_for_batches, payloads_for_batches): - shard_row_totals[shard] = shard_row_totals.get(shard, 0) + jax.tree.leaves(payload)[0].num_rows +def _merge_ledgers(dest, source): + dest.total_num_rows += source.total_num_rows + for shard, rows in source.shard_rows.items(): + current_value = dest.shard_rows.get(shard, 0) + assert current_value == 0, f"Shard {shard} already has {current_value} rows" + dest.shard_rows[shard] = rows - return shard_row_totals, payloads_for_batches + dest.finished_shards.extend(source.finished_shards) + dest.field_counts.update(source.field_counts) + return dest -def _interleave_shards(readers: Sequence[RayPrefetchQueue], first_index: int) -> Iterator[T]: # _Message +@ray.remote(num_cpus=4, memory=4 * 1024 * 1024 * 1024) +def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: RefBox, rows_so_far, parent): """ - Interleaves the results of multiple iterators. To support resume, - we need to be able to start from not the "first" iterator. + Copies the data from one cache to another, appending it to the end of the destination cache. + Once the copy is done and the last_ref is set, the data is "unlocked" in the destination cache by updating the + offsets[0] of the destination cache to the total number of rows in the cache. Args: - readers: A list of iterators - first_index: The index of the first iterator to start from. We use this to support resuming. + dest_path: The path to the destination cache. + source_path: The path to the source cache. + processor: The processor used to create the cache. + data_offset_tree: The data offset tree for the destination cache. + last_ref: The ref to wait on before updating the ledger. + rows_so_far: The total number of rows in the destination cache before this copy. + + Returns: + + """ + with log_failures_to(parent): + asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far)) + print("done copying", flush=True) + if last_ref is not None: + ray.wait([last_ref.ref], fetch_local=False) + print("done waiting", flush=True) + permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + dest_ledger = CacheLedger.load(dest_path) + source_ledger = CacheLedger.load(source_path) + + new_num_rows = source_ledger.total_num_rows + rows_so_far + + futures = jax.tree.leaves(jax.tree.map(lambda x: x.offsets[0].write(new_num_rows), permanent_cache.tree)) + for future in futures: + future.result() + + print("wrote rows", flush=True) + + _merge_ledgers(dest_ledger, source_ledger) + dest_ledger._serialize_and_commit(dest_path) + assert not dest_ledger.is_finished + parent._notify_updated_ledger.remote(dest_ledger) + print("done", flush=True) + return dest_ledger + + +async def _extend_cache_with_other_cache( + dest_path: str, source_path: str, processor: BatchProcessor, data_offset_tree: PyTree[int], row_offset +) -> int: """ + Copies the data from one cache to another, appending it to the end of the destination cache. - finished: set[int] = set() - total = 0 - while len(finished) < len(readers): - for i in range(first_index, len(readers)): - reader = readers[i] - if i not in finished: - try: - message = reader.get_next() - total += 1 - yield message - except StopIteration: - finished.add(i) - except Exception as e: - logger.exception(f"Error while processing group {i}") - raise e + Returns: + The number of rows in the source cache. + """ + logger.info(f"Copying data from {source_path} to {dest_path}.") + dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True) + + source_num_rows = await source.async_len() + + async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): + """Copies **just the data array** from one shard to the permanent cache at a given offset.""" + # TODO: it'd be good if we just didn't expose the full data array (but only the used part) + data_size = source_array.data_size + data = source_array.data[0:data_size] + print(f"starting to write data. {data.read().result()=}", flush=True) + print(f"{row_offset=}", flush=True) + futures: list[ts.Future] = [] + + # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) + async with ts.Transaction() as txn: + dest = dest_array.data + out_end = data_offset + data_size + write_future = dest.with_transaction(txn)[data_offset:out_end].write(data) + futures.append(write_future) + + if source_array.shapes is not None: + source_shapes = source_array.shapes[0:source_num_rows] + async with ts.Transaction() as txn: + dest = dest_array.shapes + out_end = row_offset + source_num_rows + shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes) + futures.append(shape_future) + print("done writing shapes", flush=True) + + source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]] + source_offsets = _virtual_offset(source_offsets, data_offset) + + async with ts.Transaction() as txn: + dest = dest_array.offsets + out_end = row_offset + 1 + source_num_rows + offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets) + + print("hi", flush=True) + print(f"done writing offsets {source_offsets.domain}", flush=True) + print(f"done writing offsets {dest[row_offset+1:out_end].read().result()}", flush=True) + + futures.append(offset_future) + + out = await asyncio.gather(*futures) + print("done writing", flush=True) + return out + + futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) - first_index = 0 + await asyncio.gather(*jax.tree.leaves(futures)) + logger.info(f"Finished copying data from {source_path} to {dest_path}.") - logger.info(f"Finished all shards, got {total} batches") + return source_num_rows -def _assign_shards_to_groups(shards: Sequence[_ShardStatus], num_groups: int) -> list["_ShardGroup"]: +def _virtual_offset(base: ts.TensorStore, offset_amount): """ - Assigns shards to groups in a round-robin fashion. + This function creates a new tensorstore that is a virtual offset of another tensorstore. + That is, it's y[i] = x[i] + offset_amount. """ - groups: list[list] = [[] for _ in range(num_groups)] - for i, shard in enumerate(shards): - groups[i % num_groups].append(shard) - return [_ShardGroup(group) for group in groups] + async def do_read(domain: ts.IndexDomain, array: np.ndarray, read_params: ts.VirtualChunkedReadParameters): + array[...] = (await base[domain].read()) + offset_amount -def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: - prng = random.Random(seed) - shuffled = list(shards) - prng.shuffle(shuffled) - return shuffled + return ts.virtual_chunked(do_read, dtype=base.dtype, domain=base.domain, shape=base.shape) -class _ShardGroup: - """ - Given a group of shards and a list of statuses, implicitly concatenates the shards and reads from them. +async def _copy_data_from_one_shard_to_permanent_memory( + dest_path: str, + source_path: str, + processor: BatchProcessor, + data_offset_tree: PyTree[int], +): + """Copies from one tree store to the permanent cache at a given offset (for each leaf)""" + logger.info(f"Copying data from {source_path} to {dest_path}.") + dest = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) + source = TreeStore.open(processor.output_exemplar, source_path, mode="r", cache_metadata=True) - This class mostly exists for resuming: we want to be able to start from the last shard we were working on. - """ + def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore, data_offset: int): + # TODO: it'd be good if we just didn't expose the full data array (but only the used part) + data = source_array.data[0 : source_array.data_size] + # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) + with ts.Transaction() as txn: + dest = dest_array.data + out_end = data_offset + source_array.data_size + write_future = dest.with_transaction(txn)[data_offset:out_end].write(data) - def __init__(self, group: list[_ShardStatus]): - self.shards = group - self.total_rows_committed, _all_finished = self._impute_total_rows_committed_and_check_invariants() - - def _impute_total_rows_committed_and_check_invariants(self): - # we also want to ensure that we haven't started any shards until we've finished the previous ones - total_committed = 0 - last_shard_name = None - last_was_finished = True - all_finished = True - - for status in self.shards: - shard_name = status.shard_name - if not last_was_finished and status.num_rows_committed > 0: - raise ValueError( - f"Shard {shard_name} has rows committed but previous shard in group {last_shard_name} " - "is not finished. Something about the cache configuration has changed: either the " - "number/order of shards, the shard shuffle random seed, or the number of groups." - ) - total_committed += status.num_rows_committed - if not status.is_finished: - all_finished = False - last_was_finished = status.is_finished - last_shard_name = shard_name + return write_future - return total_committed, all_finished + futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) + await asyncio.gather(*jax.tree.leaves(futures)) + logger.info(f"Finished copying data from {source_path} to {dest_path}.") + return -def _make_interleave(name: str, source: ShardedDataSource, initial_ledger: CacheLedger, processor_pool: ActorHandle): - """ - Given a list of ShardStatus objects and sources, creates an interleaving generator - that reads from shards and tokenizes them in parallel. - We use ShardStatus objects to track the progress of each shard. If we're preempted, we can resume - from the last shard we were working on. This function starts each shard at the last committed row - and starts interleaving from the next shard (i.e. the one with the fewest rows that isn't finished). - """ - logger.setLevel(DEFAULT_LOG_LEVEL) - statuses = _get_shard_statuses(initial_ledger, source) +def _tokenize_one_shard_group( + temporary_cache_path: str, + source: ShardedDataSource, + shards: list[str], + processor: BatchProcessor, + options: CacheOptions, +) -> CacheLedger: + # ray breaks if this is top level + import humanfriendly - options = initial_ledger.metadata.options + logger = pylogging.getLogger("tokenize") + pylogging.basicConfig(level=pylogging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") - unfinished_shards = _check_current_shard_progress(statuses) + # restrict shards to the ones we're supposed to process + # this is a bit hacky but when there are a lot of shards (e.g. SlimPajama 122K), + # we encounter significant overhead just parsing the shard names from the json + source = _RestrictedShardedDataSource(source, shards) - if not unfinished_shards: - logger.info("All shards finished. Nothing to do.") - return + ledger = CacheLedger.load_or_initialize(temporary_cache_path, source, processor) - group_names, groups = _randomize_and_group_shards(name, options, statuses) + if ledger.is_finished: + logger.info("Shard group already processed.") + return ledger - logger.warning(f"Starting cache build with {len(statuses)} shards, in {len(groups)} groups") + writer = ShardGroupCacheWriter(temporary_cache_path, ledger, shards, processor.output_exemplar) - def _make_generator_fn(group: _ShardGroup): - def generator(): - pylogging.basicConfig(level=DEFAULT_LOG_LEVEL, format=LOG_FORMAT) - for message in _shard_reader_generator(source, group, options.batch_size): - match message: - case _Batch(): - # processed = ray.put(process_task(ray.get(message.payload))) - # processed = process_task.remote(processor_ref, message.payload) - processed = processor_pool.process_batch.remote(RefBox(message.payload)) - yield dataclasses.replace(message, payload=processed) - case _ShardFinished(): - yield message - case _: - raise AssertionError(f"Unexpected message type {type(message)}") + total_rows = ledger.total_num_rows + found_shard_with_rows = False - return generator + for shard_name in shards: + if shard_name in ledger.finished_shards: + logger.info(f"Shard {shard_name} already processed.") + continue - generator_fns = [_make_generator_fn(group) for group in groups] + logger.debug(f"Processing {shard_name}.") - readers = [ - RayPrefetchQueue( - fn, - options.prefetch_per_group, - producer_options=dict(num_cpus=0.1, name=name, scheduling_strategy="SPREAD"), - ) - for name, fn in zip(group_names, generator_fns) - ] + rows_this_shard = ledger.shard_rows.get(shard_name, 0) - # then figure out the first shard to start from. This is the first unfinished shard with the minimum number of rows - first_group_to_start = min( - range(len(groups)), - key=lambda i: groups[i].total_rows_committed, - ) + if found_shard_with_rows and rows_this_shard != 0: + raise ValueError("Found more than one shard with rows to process.") - yield from _interleave_shards(readers, first_group_to_start) + if rows_this_shard != 0: + found_shard_with_rows = True + shard_iterator = source.open_shard_at_row(shard_name, rows_this_shard) -def _mk_processor_pool(processor, min_size, max_size): - import hashlib + prepared_batch: PyTree[PreparedBatch] | None = None + this_batch_size = 0 - metadata_hash = hashlib.md5(str(processor.metadata).encode()).hexdigest() - processor_pool_name = f"processor_pool::{metadata_hash}" - processor_pool = BatchProcessorPool.options( # type: ignore - name=processor_pool_name, get_if_exists=True, lifetime="detached" - ).remote( # type: ignore - processor, min_size, max_size - ) + for batch in batched(shard_iterator, options.batch_size): + tokenized = processor(batch) + tokenized = _canonicalize_batch(tokenized) # type: ignore + this_prepared = writer._tree_store.batch_preparer(tokenized) + + this_batch_size += len(batch) + rows_this_shard += len(batch) - ray.get(processor_pool.ensure_max_at_least.remote(max_size)) + if prepared_batch is None: + prepared_batch = this_prepared + else: + prepared_batch = jax.tree.map( + lambda *trees: PreparedBatch.concat(trees), prepared_batch, this_prepared + ) - return processor_pool + batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) + if batch_byte_size > options.target_bytes_per_flush: + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + nice_bytes = humanfriendly.format_size(batch_byte_size) + logger.debug( + f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" + ) + this_batch_size = 0 + prepared_batch = None + + if prepared_batch is not None: + batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) + nice_bytes = humanfriendly.format_size(batch_byte_size) + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + logger.debug( + f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" + ) + this_batch_size = 0 + prepared_batch = None -def _check_current_shard_progress(statuses): - unfinished_shards: list[_ShardStatus] = [] - shards_with_progress: dict[str, int] = {} - for status in statuses: - if not status.is_finished: - unfinished_shards.append(status) - if status.num_rows_committed > 0: - shards_with_progress[status.shard_name] = status.num_rows_committed - if unfinished_shards and shards_with_progress: - formatted = ", ".join(f"{k}: {v}" for k, v in shards_with_progress.items()) - logger.info(f"Resuming from shards with progress: {formatted}") - return unfinished_shards + total_rows += rows_this_shard + writer.finish_shard(shard_name, rows_this_shard) -def _randomize_and_group_shards(name, options, statuses): - if options.shard_order_randomization_key is not None: - seed = options.shard_order_randomization_key - logger.info(f"Randomizing shard order with seed {seed}") - statuses = _randomize_shards(statuses, seed) + writer.finish() - num_groups = min( - options.num_shard_groups if options.num_shard_groups is not None else len(statuses), len(statuses) - ) - if num_groups == 1: - group_names = [f"generator::{name}::all_shards"] - elif len(statuses) == num_groups: - group_names = [f"generator::{name}::{status.shard_name}" for status in statuses] - else: - group_names = [f"generator::{name}::group_{i}" for i in range(num_groups)] + logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") - groups = _assign_shards_to_groups(statuses, num_groups) - return group_names, groups + return writer.ledger -def _shard_reader_generator( - shard_source: ShardedDataSource[T], group: _ShardGroup, batch_size: int -) -> Iterator[_Message]: +class ShardGroupCacheWriter: """ - Given a group of shards, implicitly concatenates the shards and reads from them. + Similar to SerialCacheWriter, but tracks shard metadata for one shard. """ - for status in group.shards: - if status.is_finished: - logger.info(f"Skipping finished shard {status.shard_name}") - continue - start_row = status.num_rows_committed - logger.info(f"Opening shard {status.shard_name} at row {start_row}") - shard_iter = shard_source.open_shard_at_row(status.shard_name, start_row) - batch = [] - batch_idxes = [] - row_idx = start_row - for row in shard_iter: - batch.append(row) - batch_idxes.append(row_idx) - row_idx += 1 + def __init__(self, cache_dir: str, initial_ledger: CacheLedger, shards: list[str], exemplar: T): + self.cache_dir = cache_dir - if len(batch) == batch_size: - yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) - batch = [] - batch_idxes = [] + self._ledger = copy.deepcopy(initial_ledger) + self.shards = shards - if len(batch) > 0: - yield _Batch(status.shard_name, batch_idxes, ray.put(batch)) + self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore + self._tree_store.trim_to_size(self._ledger.total_num_rows) - logger.info(f"Finished generating shard {status.shard_name} with {row_idx} rows") - yield _ShardFinished(status.shard_name, row_idx) + @property + def ledger(self): + return self._ledger + + # we have both versions b/c we need this one for actors + def get_ledger(self): + return self._ledger + + @property + def is_finished(self): + return self._ledger.is_finished + + def finish_shard(self, shard_name: str, num_rows: int): + if shard_name not in self.shards: + raise ValueError(f"Shard {shard_name} not in tracked shards") + + current_rows = self._ledger.shard_rows.get(shard_name, 0) + if current_rows != num_rows: + raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") + + self._ledger.finished_shards.append(shard_name) + self._ledger._serialize_and_commit(self.cache_dir) + + def write_prepared_batch(self, shard_name: str, row_count: int, batch: PyTree[PreparedBatch]): + if self.is_finished: + raise RuntimeError("Cannot write to a finished cache") + self._tree_store.extend_with_batch(batch) + + if shard_name not in self.shards: + raise ValueError(f"Shard {shard_name} not in tracked shards") + self._ledger.shard_rows[shard_name] += row_count + self._ledger.total_num_rows += row_count + + self._ledger._serialize_and_commit(self.cache_dir) + + def finish(self): + if len(self._ledger.finished_shards) != len(self.shards): + raise ValueError("Not all shards are finished") + + self._ledger.is_finished = True + self._ledger._serialize_and_commit(self.cache_dir) + # ensure all tracked shards are finished + + return self._tree_store + + +class _RestrictedShardedDataSource(ShardedDataSource): + def __init__(self, source: ShardedDataSource, shards: list[str]): + self._source = source + self._shards = shards + + @property + def shard_names(self): + return self._shards + + def open_shard_at_row(self, shard_name, row): + return self._source.open_shard_at_row(shard_name, row) + + +def _randomize_shards(shards: Sequence[T], seed: int) -> list[T]: + prng = random.Random(seed) + shuffled = list(shards) + prng.shuffle(shuffled) + return shuffled def _canonicalize_batch(batch: Union[dict, List[dict]]) -> List[dict]: @@ -1360,8 +1413,13 @@ def _ledger_to_metrics(ledger: CacheLedger) -> InProgressCacheMetrics: ) -def _get_shard_statuses(ledger: CacheLedger, source: ShardedDataSource): - return [ - _ShardStatus(name, ledger.shard_rows.get(name, 0), name in ledger.finished_shards) - for name in source.shard_names - ] +def _try_load(path): + try: + ledger = CacheLedger.load(path) + if ledger.is_finished: + return ledger + else: + logger.debug(f"Cache exists but is not finished at {path}.") + return None + except FileNotFoundError: + return None diff --git a/src/levanter/store/tree_store.py b/src/levanter/store/tree_store.py index 03355a8d2..83d6c88b0 100644 --- a/src/levanter/store/tree_store.py +++ b/src/levanter/store/tree_store.py @@ -172,6 +172,9 @@ def get_batch_sync(self, indices) -> List[T]: return out + async def async_len(self) -> int: + return await jax.tree.leaves(self.tree)[0].num_rows_async() + def _construct_builder_tree(exemplar, path, mode, cache_metadata): def open_builder(tree_path, item): diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index c1eb73670..82bf045c7 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -1,6 +1,4 @@ import asyncio -import copy -import os import tempfile from typing import Any, Dict, Iterator, Sequence @@ -10,17 +8,7 @@ from levanter.data import BatchProcessor, ShardedDataSource, batched from levanter.data.sharded_datasource import TextUrlDataSource -from levanter.store.cache import ( - LEDGER_FILE_NAME, - CacheLedger, - CacheOptions, - SerialCacheWriter, - ShardedCacheWriter, - TreeStore, - _get_builder_actor, - _serialize_json_and_commit, - build_or_load_cache, -) +from levanter.store.cache import CacheOptions, SerialCacheWriter, TreeStore, _get_builder_actor, build_or_load_cache from levanter.utils.py_utils import logical_cpu_core_count @@ -146,7 +134,7 @@ def test_full_end_to_end_cache(): options=CacheOptions.no_fanciness(8), ) - expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=2), 8) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=2)) all_data = ray_ds[:] @@ -162,15 +150,14 @@ def test_full_end_to_end_cache_with_groups(): SimpleShardSource(num_shards=5), TestProcessor(), await_finished=True, - options=CacheOptions(num_shard_groups=2, batch_size=8, shard_order_randomization_key=None), + options=CacheOptions(num_shard_groups=2, batch_size=8), ) - expected = process_interleave(TestProcessor(), SimpleShardSource(num_shards=5), 8) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=5)) all_data = ray_ds[:] - # check_datasets_equal(all_data, expected) - assert len(all_data) == len(list(expected)) + check_datasets_equal(all_data, expected) @pytest.mark.ray @@ -295,7 +282,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # now block until the cache is done cache.await_finished(timeout=30) - expected = process_interleave(processor, SlowShardSource(), 16) + expected = simple_process(processor, SlowShardSource()) check_datasets_equal(list(cache[:]), expected) @@ -334,9 +321,8 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: SlowShardSource(), TestProcessor(), await_finished=False, - force_flush=True, options=CacheOptions.no_fanciness(5), - ) # we need force_flush to ensure the cache is written to disk + ) # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] @@ -364,7 +350,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: cache.await_finished(timeout=10) -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") @pytest.mark.ray def test_shard_cache_crashes_if_processor_throws(): class ThrowingProcessor(SimpleProcessor): @@ -398,7 +383,6 @@ def test_shard_cache_fails_with_multiple_shards_with_the_same_name(): build_or_load_cache(tmpdir, dataset, TestProcessor(), await_finished=True) -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") @pytest.mark.ray @pytest.mark.asyncio async def test_shard_cache_fails_gracefully_with_unknown_file_type_async(): @@ -451,89 +435,3 @@ def test_shard_cache_fails_gracefully_with_unknown_file_type(): cache.await_finished(timeout=10) del cache - - -def test_sharded_cache_writer(): - with tempfile.TemporaryDirectory() as tmpdir: - source = SimpleShardSource(num_shards=4) - processor = SimpleProcessor() - ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(8)) - - exemplar = {"data": np.array([0], dtype=np.int64)} - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), ledger.metadata.options.batch_size): - writer.write_batch(shard_name, processor(ex)) - - for shard_name in source.shard_names: - writer.finish_shard(shard_name, source._rows_per_shard) - - store = writer.finish() - - data_path = store.path - - del store - - builder = TreeStore.open(exemplar, data_path, mode="r") - - assert len(builder) == 40 - - for i, x in enumerate(builder): - np.testing.assert_array_equal(x["data"], np.asarray([i % 10 + i // 10 * 10] * 10)) - - # check totals for the ledger - ledger = writer.ledger - assert ledger.total_num_rows == 40 - assert ledger.is_finished - - for shard_name in source.shard_names: - assert ledger.shard_rows[shard_name] == 10 - - -def test_sharded_cache_writer_trims_on_resume(): - with tempfile.TemporaryDirectory() as tmpdir: - source = SimpleShardSource(num_shards=4) - processor = SimpleProcessor() - - exemplar = {"data": np.array([0], dtype=np.int64)} - - ledger = CacheLedger.load_or_initialize(tmpdir, source, processor, CacheOptions.no_fanciness(batch_size=8)) - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - for shard_name in source.shard_names: - for ex in batched(source.open_shard(shard_name), 8): - writer.write_batch(shard_name, processor(ex)) - - for shard_name in source.shard_names: - writer.finish_shard(shard_name, 10) - - writer.finish() - - # now deliberately truncate the ledger a bit - ledger = copy.deepcopy(writer.ledger) - assert ledger.total_num_rows == 40 - assert ledger.is_finished - ledger.total_num_rows = 24 - ledger.shard_rows["shard_0"] = 8 - ledger.shard_rows["shard_1"] = 8 - ledger.shard_rows["shard_2"] = 8 - ledger.shard_rows["shard_3"] = 0 - ledger.is_finished = False - - _serialize_json_and_commit(os.path.join(tmpdir, LEDGER_FILE_NAME), ledger) - - writer = ShardedCacheWriter(tmpdir, ledger, exemplar) - - # ensure it got truncated - assert writer.ledger.total_num_rows == 24 - assert writer.ledger.is_finished is False - assert writer.ledger.shard_rows["shard_0"] == 8 - assert writer.ledger.shard_rows["shard_1"] == 8 - assert writer.ledger.shard_rows["shard_2"] == 8 - assert writer.ledger.shard_rows["shard_3"] == 0 - - new_store = writer._tree_store - new_data = new_store[:] - - assert len(new_data) == 24 From f742ba75d4216097d2931d2cdc44a169539220ff Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 09:54:53 -0700 Subject: [PATCH 33/66] crash the test for now --- tests/test_new_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index 82bf045c7..c61c66105 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -326,7 +326,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] - first_10 = list(await cache.get_batch(range(0, 10))) + first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=10.0)) for i, x in enumerate(first_10): np.testing.assert_array_equal(x["test"], np.array([i] * 10)) @@ -339,7 +339,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # now ensure we can get the next 10 elements, which will be # [{"test": np.array([i] * 10)} for i in range(10, 20)] - batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10) + batch = await asyncio.wait_for(cache.get_batch(range(10, 20)), timeout=10.0) for i, x in enumerate(batch): np.testing.assert_array_equal(x["test"], np.array([i + 10] * 10)) From e9c03a0e9f1516d7e707e86cbe367e497d7e8f23 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 11:56:30 -0700 Subject: [PATCH 34/66] cleanup temp filesqq --- src/levanter/data/text.py | 6 +- src/levanter/store/cache.py | 119 +++++------------------------ src/levanter/utils/fsspec_utils.py | 28 ++++++- 3 files changed, 49 insertions(+), 104 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 70c1fe4b3..de6980430 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -35,7 +35,7 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore -from levanter.utils.fsspec_utils import fsspec_expand_glob +from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -508,7 +508,7 @@ def urls_for_split(self, split): else: raise ValueError(f"Unknown split {split}") - urls = [globbed for url in urls for globbed in fsspec_expand_glob(url)] + urls = [globbed for url in urls for globbed in expand_glob(url)] return urls @@ -625,7 +625,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): import levanter.data - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + validation_urls = [url for url_pat in config.validation_urls for url in expand_glob(url_pat)] dataset = levanter.data.datasource_from_jsonl(validation_urls) input_field = config.input_field diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 62752fc55..4e8244882 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -13,7 +13,7 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core @@ -27,6 +27,7 @@ from fsspec import AbstractFileSystem from jaxtyping import PyTree from ray.actor import ActorHandle +from ray.runtime_env import RuntimeEnv from tqdm_loggable.auto import tqdm from levanter.data import batched @@ -35,6 +36,7 @@ from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource +from ..utils.fsspec_utils import async_remove from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -538,13 +540,6 @@ def empty(): return CacheMetadata() -@dataclass -class _ShardStatus: - shard_name: str - num_rows_committed: int - is_finished: bool - - class SerialCacheWriter(AbstractContextManager): """ Writes TreeCache-compatible caches to disk. This is a serial version of TreeCacheWriter that doesn't use Ray. @@ -602,91 +597,6 @@ def write_batch(self, batch: BatchResult): self._tree_store.extend(cbatch) -class ShardedCacheWriter: - """ - Similar to SerialCacheWriter, but tracks shard metadata. - - Similar to _OrderedCacheWriter, it also supports resuming, and it - groups together batches before writing (at some interval) in order to improve performance. - """ - - def __init__( - self, - cache_dir: str, - initial_ledger: CacheLedger, - exemplar: T, - on_write: Optional[Callable[[CacheLedger], None]] = None, - ): - self.cache_dir = cache_dir - self._on_write = on_write - - self._ledger = copy.deepcopy(initial_ledger) - - self._tree_store = TreeStore.open(exemplar, self.cache_dir, mode="a") # type: ignore - self._tree_store.trim_to_size(self._ledger.total_num_rows) - - @property - def ledger(self): - return self._ledger - - # we have both versions b/c we need this one for actors - def get_ledger(self): - return self._ledger - - @property - def is_finished(self): - return self._ledger.is_finished - - def finish_shard(self, shard_name: str, num_rows: int): - current_rows = self._ledger.shard_rows.get(shard_name, 0) - if current_rows != num_rows: - raise ValueError(f"Expected {num_rows} rows in finished shard {shard_name}, but found {current_rows}") - - self._ledger.finished_shards.append(shard_name) - self._ledger._serialize_and_commit(self.cache_dir) - - def write_prepared_batch(self, shard_counts: Mapping[str, int], batch: PyTree[PreparedBatch]): - if self.is_finished: - raise RuntimeError("Cannot write to a finished cache") - self._tree_store.extend_with_batch(batch) - - for shard, num_rows in shard_counts.items(): - self._ledger.shard_rows[shard] = self._ledger.shard_rows.get(shard, 0) + num_rows - - total_rows = self._ledger.total_num_rows + sum(shard_counts.values()) - self._ledger.total_num_rows = total_rows - self._ledger._serialize_and_commit(self.cache_dir) - - if self._on_write: - self._on_write(self._ledger) - - def write_batch(self, shard_name: str, batch: BatchResult): - if self.is_finished: - raise RuntimeError("Cannot write to a finished cache") - - if isinstance(batch, pa.RecordBatch): - raise NotImplementedError("Only non-RecordBatch batches are supported for now") - - batch = _canonicalize_batch(batch) # type: ignore - prepared = self._tree_store.batch_preparer(batch) - - return self.write_prepared_batch({shard_name: len(batch)}, prepared) - - def finish(self): - # if successful, write the ledger - logger.info("Finished writing cache") - # check that all shards are finished - if set(self._ledger.shard_rows.keys()) != set(self._ledger.finished_shards): - raise ValueError("Not all shards are finished") - - self._ledger.is_finished = True - self._ledger._serialize_and_commit(self.cache_dir) - if self._on_write: - self._on_write(self._ledger) - - return self._tree_store - - def _serialize_json_and_commit(path, obj): # just to be paranoid, we write to a temp file and then rename it # TODO: probably we could do better here @@ -709,7 +619,9 @@ def _serialize_json_and_commit(path, obj): pass -@ray.remote(num_cpus=0.1) # keep this small b/c it doesn't do a lot +@ray.remote( + num_cpus=0.1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"}) +) # keep this small b/c it doesn't do a lot class _TreeStoreCacheBuilder(SnitchRecipient): """ Actor that coordinates the building of a cache. It spins up a bunch of workers to read from each shard @@ -1013,7 +925,7 @@ def _core_writer_task( # update the data offset tree this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( - operator.add, data_offset_tree, jax.tree_map(lambda x: x.data_size, this_cache.tree) + operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) ) total_rows_from_caches += this_ledger.total_num_rows @@ -1025,6 +937,16 @@ def _core_writer_task( ledger.is_finished = True parent._notify_updated_ledger.remote(ledger) + # clean up the temporary caches + async def cleanup(): + futures = [] + for path in already_finished_paths: + futures.append(async_remove(path, recursive=True)) + + await asyncio.gather(*futures) + + asyncio.run(cleanup()) + def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: if num_groups is None or num_groups >= len(source.shard_names): @@ -1075,10 +997,8 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R """ with log_failures_to(parent): asyncio.run(_extend_cache_with_other_cache(dest_path, source_path, processor, data_offset_tree, rows_so_far)) - print("done copying", flush=True) if last_ref is not None: ray.wait([last_ref.ref], fetch_local=False) - print("done waiting", flush=True) permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) dest_ledger = CacheLedger.load(dest_path) source_ledger = CacheLedger.load(source_path) @@ -1089,13 +1009,12 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R for future in futures: future.result() - print("wrote rows", flush=True) - _merge_ledgers(dest_ledger, source_ledger) dest_ledger._serialize_and_commit(dest_path) assert not dest_ledger.is_finished + parent._notify_updated_ledger.remote(dest_ledger) - print("done", flush=True) + return dest_ledger diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index 64870443d..cc03c174b 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,5 +1,10 @@ +import asyncio + import braceexpand import fsspec +from fsspec.asyn import AsyncFileSystem + +from levanter.utils.thread_utils import _executor, blocking_wait def exists(url, **kwargs) -> bool: @@ -14,7 +19,7 @@ def mkdirs(path): fs.makedirs(path, exist_ok=True) -def fsspec_expand_glob(url): +def expand_glob(url): expanded_urls = braceexpand.braceexpand(url) for expanded_url in expanded_urls: if "*" in expanded_url: @@ -28,3 +33,24 @@ def fsspec_expand_glob(url): yield from [f"{protocol}://{path}" for path in globbed] else: yield expanded_url + + +def remove(url, *, recursive=False, **kwargs): + """Remove a file from a remote filesystem.""" + # TODO: better to use a STS deletion policy or job for this one. + fs, path = fsspec.core.url_to_fs(url, **kwargs) + + if isinstance(fs, AsyncFileSystem): + blocking_wait(fs._rm(path, recursive=recursive)) + else: + fs.rm(path, recursive=recursive) + + +async def async_remove(url, *, recursive=False, **kwargs): + """Remove a file from a remote filesystem.""" + fs, path = fsspec.core.url_to_fs(url, **kwargs) + + if isinstance(fs, AsyncFileSystem): + return await fs._rm(path, recursive=recursive) + else: + return await asyncio.wrap_future(_executor.submit(fs.rm, path, recursive=recursive)) From 20cae0c5d5ca5f7fb3d91e4a8c4a82b0539c3d99 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 11:59:26 -0700 Subject: [PATCH 35/66] more cleanup --- src/levanter/store/cache.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 4e8244882..aee2bc3d0 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -1038,8 +1038,6 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra # TODO: it'd be good if we just didn't expose the full data array (but only the used part) data_size = source_array.data_size data = source_array.data[0:data_size] - print(f"starting to write data. {data.read().result()=}", flush=True) - print(f"{row_offset=}", flush=True) futures: list[ts.Future] = [] # write_future = dest_array.data[data_offset : data_offset + source_array.data_size].write(data) @@ -1056,7 +1054,6 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra out_end = row_offset + source_num_rows shape_future = dest.with_transaction(txn)[row_offset:out_end].write(source_shapes) futures.append(shape_future) - print("done writing shapes", flush=True) source_offsets = source_array.offsets[1 : source_num_rows + 1][ts.d[:].translate_to[0]] source_offsets = _virtual_offset(source_offsets, data_offset) @@ -1066,14 +1063,9 @@ async def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArra out_end = row_offset + 1 + source_num_rows offset_future = dest.with_transaction(txn)[row_offset + 1 : out_end].write(source_offsets) - print("hi", flush=True) - print(f"done writing offsets {source_offsets.domain}", flush=True) - print(f"done writing offsets {dest[row_offset+1:out_end].read().result()}", flush=True) - futures.append(offset_future) out = await asyncio.gather(*futures) - print("done writing", flush=True) return out futures = jax.tree.map(_copy_one_array, dest.tree, source.tree, data_offset_tree) From cb85654799076b3db1b603cbe0c89294613d7c72 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 1 Nov 2024 16:47:58 -0700 Subject: [PATCH 36/66] ok, we are incremental! --- src/levanter/store/cache.py | 137 +++++++++++++++++++++++++++++------- tests/test_new_cache.py | 4 +- 2 files changed, 114 insertions(+), 27 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index aee2bc3d0..f01e3b881 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -13,11 +13,10 @@ from concurrent.futures import Future as threading_Future from contextlib import AbstractContextManager from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union import deepdiff import fsspec.core -import humanfriendly import jax import numpy as np import pyarrow as pa @@ -37,6 +36,7 @@ from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource from ..utils.fsspec_utils import async_remove +from ..utils.fsspec_utils import exists as fsspec_exists from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -83,6 +83,10 @@ class CacheOptions: @property def target_bytes_per_flush(self): + if isinstance(self.target_size_per_flush, int): + return self.target_size_per_flush + import humanfriendly + return humanfriendly.parse_size(self.target_size_per_flush) @staticmethod @@ -667,6 +671,9 @@ def __init__( # TODO: measure. memory=2 * self._options.target_bytes_per_flush, ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) + + self._tokenize_pbar = tqdm(total=len(source.shard_names), desc="Tokenizing", unit="shard") + except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here # this also propagates to the finished promise, so we can handle it there @@ -753,6 +760,19 @@ async def _do_notify_async(): asyncio.create_task(_do_notify_async()) + def _report_progress(self, report: "_ProgressReport"): + import humanfriendly + + self._tokenize_pbar.update(report.total_shards_completed) + mb_str = humanfriendly.format_size(report.total_bytes) + self._tokenize_pbar.set_postfix( + { + "rows": report.total_rows, + "shards": report.total_shards_completed, + "mb": mb_str, + } + ) + def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): name = f"lev_cache_manager::{split}::{cache_dir}" @@ -816,14 +836,22 @@ def _core_writer_task( # 1. write the 0th shard group to the output cache directly, updating metrics as we go # 2. in the background, start processing other shard groups to temporary caches # 3. once (1) is done, we start copying the temporary caches to the output cache (in order) - # for now we're going to punt on (1) + + # We notify the parent actor of progress and updates to the ledger. + # We special-case the 0'th ledger because we commit it to the output cache directly. + def report_fn(report: _ProgressReport, ledger: CacheLedger): + parent._report_progress.remote(report) + + def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): + parent._report_progress.remote(report) + parent._notify_updated_ledger.remote(ledger) + with log_failures_to(parent): temporary_cache_path = os.path.join(cache_dir, "___temp") paths: dict[str, str] = {} ledgers: dict[str, CacheLedger | None] = {} - already_finished_paths: list[str] = [] - refs: dict[str, ray.ObjectRef] = {} + write_refs: dict[str, ray.ObjectRef] = {} shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) @@ -832,23 +860,32 @@ def _core_writer_task( ) unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" - pbar = tqdm(total=len(shard_groups), desc="Tokenizing", unit=unit) processor_ref = ray.put(processor) source_ref = ray.put(source) + # We treat the first group specially: we tokenize it directly to the output cache (since it comes first) + # This enables us to expose data quickly + first_group = next(iter(shard_groups), None) + for group_name, shards in shard_groups.items(): - path = os.path.join(temporary_cache_path, group_name) - paths[group_name] = path + if group_name == first_group: + group_out_path = cache_dir + else: + group_out_path = os.path.join(temporary_cache_path, group_name) - ledger = _try_load(path) + paths[group_name] = group_out_path + + ledger = _try_load(group_out_path) ledgers[group_name] = ledger if ledger is not None: - already_finished_paths.append(path) - pbar.update(1) + if group_name == first_group: + parent._notify_updated_ledger.remote(ledger) continue + report_fn = report_fn_first_group if group_name == first_group else report_fn + ref = ( ray.remote(_tokenize_one_shard_group) .options( # type: ignore @@ -860,10 +897,10 @@ def _core_writer_task( retry_exceptions=True, max_retries=10, ) - .remote(os.path.join(temporary_cache_path, group_name), source_ref, shards, processor_ref, options) + .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn, parent) ) - refs[group_name] = ref + write_refs[group_name] = ref # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) # This logic is a bit hairy thanks to resumes. @@ -891,11 +928,13 @@ def _core_writer_task( copy_refs: dict[str, ray.ObjectRef] = {} last_ref: ray.ObjectRef | None = None + copying_pbar = tqdm(total=len(shard_groups), desc="Copying", unit=unit, leave=False, position=1) for group in shard_groups: # first make sure it's either done this run or already done - if refs.get(group) is not None: - this_ledger = ray.get(refs[group]) + if write_refs.get(group) is not None: + this_ledger = ray.get(write_refs[group]) + ledgers[group] = ledger else: this_ledger = ledgers[group] @@ -903,8 +942,10 @@ def _core_writer_task( assert this_ledger is not None # see if we already copied this group, meaning all the shards are in the permanent cache shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) - if shards_copied == len(shard_groups[group]): + if shards_copied == len(shard_groups[group]) or group == first_group: assert initial_ledger.total_num_rows >= total_rows_from_caches + copying_pbar.update(1) + elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now raise RuntimeError("Some shards were copied but not all. This should never happen.") @@ -922,30 +963,49 @@ def _core_writer_task( ) copy_refs[group] = last_ref - # update the data offset tree + # update the offset information: data offsets and total rows this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) ) total_rows_from_caches += this_ledger.total_num_rows + # this little bit is totally unnecessary but nice logging + for group in shard_groups: + if group == first_group: + continue + + if copy_refs.get(group) is not None: + ray.wait([copy_refs[group]], fetch_local=False) + copying_pbar.update(1) + + # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: ledger = ray.get(last_ref) else: ledger = initial_ledger ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) parent._notify_updated_ledger.remote(ledger) # clean up the temporary caches - async def cleanup(): - futures = [] - for path in already_finished_paths: + _clean_up_temp_caches(paths, first_group) + + +def _clean_up_temp_caches(paths, first_group): + async def cleanup(): + futures = [] + for group, path in paths.items(): + if group == first_group: + continue + + if fsspec_exists(path): futures.append(async_remove(path, recursive=True)) - await asyncio.gather(*futures) + await asyncio.gather(*futures) - asyncio.run(cleanup()) + asyncio.run(cleanup()) def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: @@ -1117,12 +1177,22 @@ def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore return +@dataclass +class _ProgressReport: + total_rows: int + total_bytes: float + total_shards_completed: int + # TODO: other counts + + def _tokenize_one_shard_group( temporary_cache_path: str, source: ShardedDataSource, shards: list[str], processor: BatchProcessor, options: CacheOptions, + report_fn: Callable[[_ProgressReport, CacheLedger], None], + force_unfinalized: bool, ) -> CacheLedger: # ray breaks if this is top level import humanfriendly @@ -1146,9 +1216,13 @@ def _tokenize_one_shard_group( total_rows = ledger.total_num_rows found_shard_with_rows = False + report = _ProgressReport(total_rows, 0, 0) + for shard_name in shards: if shard_name in ledger.finished_shards: logger.info(f"Shard {shard_name} already processed.") + report.total_shards_completed += 1 + report_fn(report, ledger) continue logger.debug(f"Processing {shard_name}.") @@ -1185,16 +1259,27 @@ def _tokenize_one_shard_group( if batch_byte_size > options.target_bytes_per_flush: writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) + report.total_rows += this_batch_size + report.total_bytes += batch_byte_size + + report_fn(report, writer.ledger) + nice_bytes = humanfriendly.format_size(batch_byte_size) logger.debug( f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" ) + # print(f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})", flush=True) this_batch_size = 0 prepared_batch = None if prepared_batch is not None: batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) nice_bytes = humanfriendly.format_size(batch_byte_size) + + report.total_rows += this_batch_size + report.total_bytes += batch_byte_size + report_fn(report, writer.ledger) + writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) logger.debug( f"Processed {rows_this_shard} rows. Wrote {this_batch_size} rows to {shard_name}. ({nice_bytes})" @@ -1202,11 +1287,13 @@ def _tokenize_one_shard_group( this_batch_size = 0 prepared_batch = None - total_rows += rows_this_shard - + report.total_shards_completed += 1 writer.finish_shard(shard_name, rows_this_shard) - writer.finish() + report_fn(report, writer.ledger) + + if not force_unfinalized: + writer.finish() logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index c61c66105..0edd1fb15 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -321,12 +321,12 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: SlowShardSource(), TestProcessor(), await_finished=False, - options=CacheOptions.no_fanciness(5), + options=CacheOptions(target_size_per_flush=1, batch_size=1), ) # read the first 10 elements # ensure the first 10 elements are [{"test": np.array([i] * 10)} for i in range(10)] - first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=10.0)) + first_10 = list(await asyncio.wait_for(cache.get_batch(range(0, 10)), timeout=30.0)) for i, x in enumerate(first_10): np.testing.assert_array_equal(x["test"], np.array([i] * 10)) From 47441c0fd12b91eada4e4133a515f007a074de16 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 2 Nov 2024 22:59:13 -0700 Subject: [PATCH 37/66] a bit worried the bookkeeping isn't quite right on resume, but we're almost there. --- config/gpt2_small_fast_pile.yaml | 2 +- src/levanter/store/cache.py | 122 ++++++++++++++++++----------- src/levanter/utils/fsspec_utils.py | 11 +-- 3 files changed, 79 insertions(+), 56 deletions(-) diff --git a/config/gpt2_small_fast_pile.yaml b/config/gpt2_small_fast_pile.yaml index 3a21732a7..291213d75 100644 --- a/config/gpt2_small_fast_pile.yaml +++ b/config/gpt2_small_fast_pile.yaml @@ -1,4 +1,4 @@ -data: !include data/pile_source_old.yaml +data: !include data/pile_mixture.yaml model: type: gpt2 hidden_dim: 768 diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index f01e3b881..cd59aef4c 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -35,8 +35,8 @@ from ..data._preprocessor import BatchProcessor, BatchResult, dict_from_record_batch from ..data.metrics_monitor import InProgressCacheMetrics, LoggerMetricsMonitor, MetricsMonitor from ..data.sharded_datasource import ShardedDataSource -from ..utils.fsspec_utils import async_remove from ..utils.fsspec_utils import exists as fsspec_exists +from ..utils.fsspec_utils import remove as fsspec_remove from ..utils.ray_utils import ( ExceptionInfo, RefBox, @@ -672,7 +672,13 @@ def __init__( memory=2 * self._options.target_bytes_per_flush, ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) - self._tokenize_pbar = tqdm(total=len(source.shard_names), desc="Tokenizing", unit="shard") + self._tokenize_pbar = tqdm( + total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard" + ) + self._copy_pbar = tqdm(total=len(source.shard_names), desc=f"{path_for_name}: copying", unit="shard") + self._report_totals = _ProgressReport(0, 0, 0) + self._copy_report_totals = _ProgressReport(0, 0, 0) + self._last_update = time.time() except Exception: # Ray behaves poorly if the constructor of an actor fails, so we catch and log here @@ -763,15 +769,39 @@ async def _do_notify_async(): def _report_progress(self, report: "_ProgressReport"): import humanfriendly - self._tokenize_pbar.update(report.total_shards_completed) - mb_str = humanfriendly.format_size(report.total_bytes) - self._tokenize_pbar.set_postfix( - { - "rows": report.total_rows, - "shards": report.total_shards_completed, - "mb": mb_str, - } - ) + self._tokenize_pbar.update(report.new_shards) + self._report_totals.new_shards += report.new_shards + self._report_totals.new_rows += report.new_rows + self._report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + + mb_str = humanfriendly.format_size(self._report_totals.new_bytes) + self._tokenize_pbar.set_postfix( + { + "rows": self._report_totals.new_rows, + "shards": self._report_totals.new_shards, + "size": mb_str, + } + ) + + def _report_copy_progress(self, report: "_ProgressReport"): + # TODO: log bytes copied + self._copy_pbar.update(report.new_shards) + self._copy_report_totals.new_shards += report.new_shards + self._copy_report_totals.new_rows += report.new_rows + self._copy_report_totals.new_bytes += report.new_bytes + + if time.time() - self._last_update > 10.0: + self._last_update = time.time() + self._copy_pbar.set_postfix( + { + "shards": report.new_shards, + "rows": report.new_rows, + # "size": humanfriendly.format_size(report.new_bytes), + } + ) def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): @@ -855,12 +885,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) - logger.info( + logger.debug( f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) - unit = "shard" if len(shard_groups) == len(source.shard_names) else "shard group" - processor_ref = ray.put(processor) source_ref = ray.put(source) @@ -928,7 +956,6 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): copy_refs: dict[str, ray.ObjectRef] = {} last_ref: ray.ObjectRef | None = None - copying_pbar = tqdm(total=len(shard_groups), desc="Copying", unit=unit, leave=False, position=1) for group in shard_groups: # first make sure it's either done this run or already done @@ -944,7 +971,9 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) if shards_copied == len(shard_groups[group]) or group == first_group: assert initial_ledger.total_num_rows >= total_rows_from_caches - copying_pbar.update(1) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=shards_copied, new_rows=initial_ledger.total_num_rows) + ) elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now @@ -976,8 +1005,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): continue if copy_refs.get(group) is not None: - ray.wait([copy_refs[group]], fetch_local=False) - copying_pbar.update(1) + ledger = ray.get(copy_refs[group]) + ledgers[group] = ledger + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) + ) # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: @@ -989,23 +1021,23 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ledger._serialize_and_commit(cache_dir) parent._notify_updated_ledger.remote(ledger) - # clean up the temporary caches _clean_up_temp_caches(paths, first_group) def _clean_up_temp_caches(paths, first_group): - async def cleanup(): - futures = [] - for group, path in paths.items(): - if group == first_group: - continue - - if fsspec_exists(path): - futures.append(async_remove(path, recursive=True)) - - await asyncio.gather(*futures) + for group, path in paths.items(): + if group == first_group: + continue - asyncio.run(cleanup()) + if fsspec_exists(path): + for i in range(10): + # this is crashy for some reason + try: + fsspec_remove(path, recursive=True) + break + except Exception: + logger.exception(f"Failed to remove {path} on attempt {i}") + time.sleep(1) def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) -> dict[str, Sequence[str]]: @@ -1074,6 +1106,9 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R assert not dest_ledger.is_finished parent._notify_updated_ledger.remote(dest_ledger) + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows) + ) return dest_ledger @@ -1179,9 +1214,9 @@ def _copy_one_array(dest_array: JaggedArrayStore, source_array: JaggedArrayStore @dataclass class _ProgressReport: - total_rows: int - total_bytes: float - total_shards_completed: int + new_rows: int = 0 + new_bytes: float = 0 + new_shards: int = 0 # TODO: other counts @@ -1216,13 +1251,13 @@ def _tokenize_one_shard_group( total_rows = ledger.total_num_rows found_shard_with_rows = False - report = _ProgressReport(total_rows, 0, 0) + if total_rows > 0: + report_fn(_ProgressReport(new_rows=total_rows), ledger) for shard_name in shards: if shard_name in ledger.finished_shards: logger.info(f"Shard {shard_name} already processed.") - report.total_shards_completed += 1 - report_fn(report, ledger) + report_fn(_ProgressReport(new_shards=1), ledger) continue logger.debug(f"Processing {shard_name}.") @@ -1247,6 +1282,7 @@ def _tokenize_one_shard_group( this_batch_size += len(batch) rows_this_shard += len(batch) + total_rows += len(batch) if prepared_batch is None: prepared_batch = this_prepared @@ -1259,10 +1295,7 @@ def _tokenize_one_shard_group( if batch_byte_size > options.target_bytes_per_flush: writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) - report.total_rows += this_batch_size - report.total_bytes += batch_byte_size - - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) nice_bytes = humanfriendly.format_size(batch_byte_size) logger.debug( @@ -1276,9 +1309,7 @@ def _tokenize_one_shard_group( batch_byte_size = sum(prepared_batch.byte_size for prepared_batch in jax.tree.leaves(prepared_batch)) nice_bytes = humanfriendly.format_size(batch_byte_size) - report.total_rows += this_batch_size - report.total_bytes += batch_byte_size - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_rows=this_batch_size, new_bytes=batch_byte_size), writer.ledger) writer.write_prepared_batch(shard_name, this_batch_size, prepared_batch) logger.debug( @@ -1287,15 +1318,14 @@ def _tokenize_one_shard_group( this_batch_size = 0 prepared_batch = None - report.total_shards_completed += 1 writer.finish_shard(shard_name, rows_this_shard) - report_fn(report, writer.ledger) + report_fn(_ProgressReport(new_shards=1), writer.ledger) if not force_unfinalized: writer.finish() - logger.info(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") + logger.debug(f"Finished processing {len(shards)} shards. Wrote {total_rows} rows.") return writer.ledger diff --git a/src/levanter/utils/fsspec_utils.py b/src/levanter/utils/fsspec_utils.py index cc03c174b..c8d3931fe 100644 --- a/src/levanter/utils/fsspec_utils.py +++ b/src/levanter/utils/fsspec_utils.py @@ -1,11 +1,7 @@ -import asyncio - import braceexpand import fsspec from fsspec.asyn import AsyncFileSystem -from levanter.utils.thread_utils import _executor, blocking_wait - def exists(url, **kwargs) -> bool: """Check if a file exists on a remote filesystem.""" @@ -40,10 +36,7 @@ def remove(url, *, recursive=False, **kwargs): # TODO: better to use a STS deletion policy or job for this one. fs, path = fsspec.core.url_to_fs(url, **kwargs) - if isinstance(fs, AsyncFileSystem): - blocking_wait(fs._rm(path, recursive=recursive)) - else: - fs.rm(path, recursive=recursive) + fs.rm(path, recursive=recursive) async def async_remove(url, *, recursive=False, **kwargs): @@ -53,4 +46,4 @@ async def async_remove(url, *, recursive=False, **kwargs): if isinstance(fs, AsyncFileSystem): return await fs._rm(path, recursive=recursive) else: - return await asyncio.wrap_future(_executor.submit(fs.rm, path, recursive=recursive)) + fs.rm(path, recursive=recursive) From dbdc2e49cd9cc373baec123f4e1ee0cb58531346 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 07:49:41 -0800 Subject: [PATCH 38/66] fix resume bookkeeping logic --- src/levanter/store/cache.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index cd59aef4c..51679921d 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -992,6 +992,11 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ) copy_refs[group] = last_ref + if group == first_group: + # this is the first group, so it's already in the cache and we don't need to + # increment the data offset tree etc. + continue + # update the offset information: data offsets and total rows this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) data_offset_tree = jax.tree.map( @@ -1045,14 +1050,15 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) return {shard_name: [shard_name] for shard_name in source.shard_names} shard_names = source.shard_names - num_shards_per_group = len(shard_names) // num_groups + num_shards_per_group = (len(shard_names) + num_groups - 1) // num_groups # if we have a remainder, we'll just add it to the last group out_groups = { f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) for i in range(num_groups) } - if len(shard_names) % num_shards_per_group != 0: - out_groups[f"group_{num_groups - 1}"].extend(shard_names[num_groups * num_shards_per_group :]) + + # make sure we got all the shards + assert sum(len(shards) for shards in out_groups.values()) == len(shard_names) return out_groups # type: ignore From fb44f035df1abc18211a1979c183558f550a8cf0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 10:02:42 -0800 Subject: [PATCH 39/66] wip --- src/levanter/store/cache.py | 229 ++++++++++++++++++++---------------- 1 file changed, 130 insertions(+), 99 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 51679921d..b1b69e2c3 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -670,7 +670,7 @@ def __init__( # (we get twice from we need to concatenate prepared batches into the accumulator) # TODO: measure. memory=2 * self._options.target_bytes_per_flush, - ).remote(current_actor_handle(), cache_dir, self._ledger, source, options, processor) + ).remote(current_actor_handle(), cache_dir, source, options, processor) self._tokenize_pbar = tqdm( total=len(source.shard_names), desc=f"{path_for_name}: tokenizing", unit="shard" @@ -787,7 +787,6 @@ def _report_progress(self, report: "_ProgressReport"): ) def _report_copy_progress(self, report: "_ProgressReport"): - # TODO: log bytes copied self._copy_pbar.update(report.new_shards) self._copy_report_totals.new_shards += report.new_shards self._copy_report_totals.new_rows += report.new_rows @@ -840,7 +839,6 @@ class _ShardFinished: def _core_writer_task( parent, cache_dir, - initial_ledger: CacheLedger, source: ShardedDataSource, options: CacheOptions, processor, @@ -859,9 +857,6 @@ def _core_writer_task( # append a small random number to the name to avoid collisions name += f"::{random.randint(0, 1000)}" - # We want to make sure it's there - initial_ledger._serialize_and_commit(cache_dir) - # we want to do the following: # 1. write the 0th shard group to the output cache directly, updating metrics as we go # 2. in the background, start processing other shard groups to temporary caches @@ -879,8 +874,8 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): with log_failures_to(parent): temporary_cache_path = os.path.join(cache_dir, "___temp") - paths: dict[str, str] = {} - ledgers: dict[str, CacheLedger | None] = {} + group_cache_paths: dict[str, str] = {} + group_ledgers: dict[str, CacheLedger | None] = {} write_refs: dict[str, ray.ObjectRef] = {} shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) @@ -902,10 +897,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): else: group_out_path = os.path.join(temporary_cache_path, group_name) - paths[group_name] = group_out_path + group_cache_paths[group_name] = group_out_path ledger = _try_load(group_out_path) - ledgers[group_name] = ledger + group_ledgers[group_name] = ledger if ledger is not None: if group_name == first_group: @@ -931,109 +926,145 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): write_refs[group_name] = ref # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) - # This logic is a bit hairy thanks to resumes. - # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these - # separately. We also need to update the ledger as we go. - # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. - # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. - # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. - - # * When we load the permanent cache, we have already written some number of groups to it. - # (We check this invariant with an assert) - # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. - # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" - # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate - # the data offsets for where the group goes in the permanent cache. This is just a running sum of the - # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree - # of integers, one for each array. - # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache - # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. - # * We also need to update the ledger with the total number of rows - permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) - # initialize the data offset tree - data_offset_tree = jax.tree_map(lambda x: 0, permanent_cache.tree) - total_rows_from_caches = 0 - - copy_refs: dict[str, ray.ObjectRef] = {} - last_ref: ray.ObjectRef | None = None - - for group in shard_groups: - # first make sure it's either done this run or already done - if write_refs.get(group) is not None: - this_ledger = ray.get(write_refs[group]) - - ledgers[group] = ledger - else: - this_ledger = ledgers[group] - - assert this_ledger is not None - # see if we already copied this group, meaning all the shards are in the permanent cache - shards_copied = sum(1 if shard in initial_ledger.finished_shards else 0 for shard in shard_groups[group]) - if shards_copied == len(shard_groups[group]) or group == first_group: - assert initial_ledger.total_num_rows >= total_rows_from_caches - parent._report_copy_progress.remote( - _ProgressReport(new_shards=shards_copied, new_rows=initial_ledger.total_num_rows) - ) - elif shards_copied > 0: - # In theory we can handle this, but it's a bit tricky, so we're going to punt for now - raise RuntimeError("Some shards were copied but not all. This should never happen.") - else: - # we need to copy this group - ref_to_send = None if last_ref is None else RefBox(last_ref) - last_ref = _copy_cache.remote( - cache_dir, - paths[group], - processor_ref, - data_offset_tree, - ref_to_send, - total_rows_from_caches, - parent, - ) - copy_refs[group] = last_ref + ledger = _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, + group_cache_paths, processor, processor_ref) - if group == first_group: - # this is the first group, so it's already in the cache and we don't need to - # increment the data offset tree etc. - continue + ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) + parent._notify_updated_ledger.remote(ledger) - # update the offset information: data offsets and total rows - this_cache = TreeStore.open(processor.output_exemplar, paths[group], mode="r", cache_metadata=True) - data_offset_tree = jax.tree.map( - operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) - ) - total_rows_from_caches += this_ledger.total_num_rows + temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir} + _clean_up_temp_caches(temporary_cache_paths) - # this little bit is totally unnecessary but nice logging - for group in shard_groups: - if group == first_group: - continue - if copy_refs.get(group) is not None: - ledger = ray.get(copy_refs[group]) - ledgers[group] = ledger - parent._report_copy_progress.remote( - _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) - ) +def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, group_cache_paths, processor, + processor_ref): + """ + Copy the temporary caches to the output cache, in order. (essentially concatenating them) - # refs form a linked list implicitly, so we can just wait on the last one - if last_ref is not None: - ledger = ray.get(last_ref) + Args: + parent: the parent actor handle (_TreeStoreCacheBuilder) + cache_dir: the output cache directory + shard_groups: a dict mapping group names to lists of shard names + first_group: the privileged group that is written directly to the output cache + write_refs: a dict mapping group names to ray.ObjectRefs of the cache building tasks + group_ledgers: a dict mapping group names to the ledgers for the groups. Mutated in place. + group_cache_paths: a dict mapping group names to the paths of the temporary caches + processor: the processor object + processor_ref: a ray.ObjectRef of the processor object + """ + # This logic is a bit hairy thanks to resumes. + # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these + # separately. We also need to update the ledger as we go. + # Second, note that JaggedArrayStores have two notions of length: the number of rows, and the data size. + # We store the number of rows in offsets[0], and the data size in offsets[offsets[0]], which is just the final offset. + # So we can keep a cache "locked" to a particular read size until we're ready by controlling the offsets. + + # * When we load the permanent cache, we have already written some number of groups to it. In + # particular, we have written the 0'th group to the permanent cache. + # * We enforce that we only commit a whole group to the ledger at a time. + # * We need to copy the remaining groups to the permanent cache, and update the ledger as we go. + # * To copy a group, we need to know the total number of rows in that group, as well as the "data offsets" + # for the data in the cache. We can get the total number of rows from the ledger, and we also calculate + # the data offsets for where the group goes in the permanent cache. This is just a running sum of the + # data sizes of the previous groups. Because we have multiple JaggedArrayStores, this can be a pytree + # of integers, one for each array. + # * Once we have finished the i'th cache and all caches < 1, we can "unlock" the data for the i'th cache + # by updating the offset[0] of the permanent cache to the total number of rows through the i'th cache. + # * We also need to update the ledger with the total number of rows + + # reload the ledger for the first group, which will be the sink for the other groups + assert first_group in write_refs + + group_ledgers[first_group] = ray.get(write_refs[first_group]) + overall_ledger = group_ledgers[first_group] + + # initialize the data offset tree + permanent_cache = TreeStore.open(processor.output_exemplar, cache_dir, mode="a", cache_metadata=False) + data_offset_tree = jax.tree_map(lambda x: x.data_size, permanent_cache.tree) + total_rows_from_caches = overall_ledger.total_num_rows + copy_refs: dict[str, ray.ObjectRef] = {} + last_ref: ray.ObjectRef | None = None + + found_one_to_copy = False + + for group in shard_groups: + # first make sure it's either done this run or already done + if write_refs.get(group) is not None: + this_ledger = ray.get(write_refs[group]) + group_ledgers[group] = this_ledger else: - ledger = initial_ledger + this_ledger = group_ledgers[group] - ledger.is_finished = True - ledger._serialize_and_commit(cache_dir) - parent._notify_updated_ledger.remote(ledger) + if group == first_group: + # this is the first group, so it's already in the cache and we don't need to + # increment the data offset tree etc. + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(overall_ledger.finished_shards), new_rows=overall_ledger.total_num_rows) + ) + continue - _clean_up_temp_caches(paths, first_group) + assert this_ledger is not None + # see if we already copied this group, meaning all the shards are in the permanent cache + shards_copied = sum(1 if shard in overall_ledger.finished_shards else 0 for shard in shard_groups[group]) + + if found_one_to_copy and shards_copied > 0: + raise RuntimeError("A previous group was copied, but this group was not. This should never happen.") + elif shards_copied == len(shard_groups[group]): + assert overall_ledger.total_num_rows >= total_rows_from_caches, f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" + continue # nothing to do + elif shards_copied > 0: + # In theory we can handle this, but it's a bit tricky, so we're going to punt for now + raise RuntimeError("Some shards were copied but not all. This should never happen.") + + found_one_to_copy = True + # we need to copy this group + + # we can't "commit" the group to the ledger (or the number of rows) + # until we've updated the ledger for all previous groups, so we block on the last ref + ref_to_send = None if last_ref is None else RefBox(last_ref) + + last_ref = _copy_cache.remote( + cache_dir, + group_cache_paths[group], + processor_ref, + data_offset_tree, + ref_to_send, + total_rows_from_caches, + parent, + ) + copy_refs[group] = last_ref + # update the offset information: data offsets and total rows + this_cache = TreeStore.open(processor.output_exemplar, group_cache_paths[group], mode="r", cache_metadata=True) + data_offset_tree = jax.tree.map( + operator.add, data_offset_tree, jax.tree.map(lambda x: x.data_size, this_cache.tree) + ) + total_rows_from_caches += this_ledger.total_num_rows -def _clean_up_temp_caches(paths, first_group): - for group, path in paths.items(): + # this little bit is totally unnecessary but nice logging + for group in shard_groups: if group == first_group: continue + if copy_refs.get(group) is not None: + ledger = ray.get(copy_refs[group]) + group_ledgers[group] = ledger + parent._report_copy_progress.remote( + _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) + ) + + # refs form a linked list implicitly, so we can just wait on the last one + if last_ref is not None: + ledger = ray.get(last_ref) + else: + ledger = overall_ledger + return ledger + + +def _clean_up_temp_caches(paths): + for path in paths: if fsspec_exists(path): for i in range(10): # this is crashy for some reason From 98e017093b054bf9ec4019f434f1460b1894d7a1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:15:35 -0800 Subject: [PATCH 40/66] reorg r --- src/levanter/store/cache.py | 7 ++----- tests/test_new_cache.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index b1b69e2c3..affe274eb 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -330,7 +330,6 @@ def build_or_load( shard_source=shard_source, processor=processor, options=options, - split=split, ) return TreeCache(cache_dir=cache_dir, exemplar=processor.output_exemplar, ledger=None, _broker=broker) @@ -637,7 +636,6 @@ def __init__( self, cache_dir: str, name: str, - split: str, # to workaround https://github.com/ray-project/ray/issues/44083 source: ShardedDataSource[T], processor: BatchProcessor[T, U], options: CacheOptions, @@ -803,14 +801,13 @@ def _report_copy_progress(self, report: "_ProgressReport"): ) -def _get_builder_actor(split, cache_dir, shard_source, processor, options=CacheOptions.default()): - name = f"lev_cache_manager::{split}::{cache_dir}" +def _get_builder_actor(cache_dir, shard_source, processor, options=CacheOptions.default()): + name = f"lev_cache_manager::{cache_dir}" path_for_name = os.path.join(*os.path.split(cache_dir)[-2:]) name_for_display = f"builder::{path_for_name}" return _TreeStoreCacheBuilder.options(name=name, get_if_exists=True).remote( # type: ignore name=name_for_display, - split=split, cache_dir=cache_dir, source=shard_source, processor=processor, diff --git a/tests/test_new_cache.py b/tests/test_new_cache.py index 0edd1fb15..086de48e1 100644 --- a/tests/test_new_cache.py +++ b/tests/test_new_cache.py @@ -128,13 +128,13 @@ def test_full_end_to_end_cache(): with td as tmpdir: ray_ds = build_or_load_cache( tmpdir, - SimpleShardSource(num_shards=2), + SimpleShardSource(num_shards=15), TestProcessor(), await_finished=True, - options=CacheOptions.no_fanciness(8), + options=CacheOptions(num_shard_groups=3, batch_size=8), ) - expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=2)) + expected = simple_process(TestProcessor(), SimpleShardSource(num_shards=15)) all_data = ray_ds[:] @@ -191,7 +191,6 @@ class _CustomException(Exception): @pytest.mark.ray -@pytest.mark.skip("This test segfaults in CI. I think a ray bug") def test_cache_recover_from_crash(): class CrashingShardSource(ShardedDataSource[list[int]]): def __init__(self, crash_point: int): @@ -205,7 +204,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: # parse the shard name to get the shard number shard_num = int(shard_name.split("_")[1]) for i in range(10): - if shard_num * 10 + i == self.crash_point: + if i == self.crash_point: raise _CustomException(f"Crashing at {shard_num} {i} {self.crash_point}") if i >= row: yield [shard_num * 10 + i] * 10 @@ -213,7 +212,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as tmpdir2: source = CrashingShardSource(4) with pytest.raises(_CustomException): - build_or_load_cache(tmpdir, source, TestProcessor()) + build_or_load_cache(tmpdir, source, TestProcessor(), CacheOptions(target_size_per_flush=1)) # kill the broker actor so that we can test recovery ray.kill( @@ -231,11 +230,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[list[int]]: ) # testing this doesn't throw - source = CrashingShardSource(1000) + source = CrashingShardSource(100000) reader1 = build_or_load_cache(tmpdir, source, TestProcessor(), await_finished=True) # compare to the original with no crash - reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(), TestProcessor(), await_finished=True) + reader2 = build_or_load_cache(tmpdir2, SimpleShardSource(num_shards=4), TestProcessor(), await_finished=True) check_datasets_equal(reader1, reader2) From bebf4038c503443fdace611b23eb9418fab3e4c0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:17:58 -0800 Subject: [PATCH 41/66] fix hf data loading for datasets>=3.1.0 --- src/levanter/data/sharded_datasource.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 186a0d9dd..90803df3e 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -197,7 +197,10 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: dataset = self._load_dataset() if isinstance(dataset, datasets.IterableDataset) and shard_name != "data": # ex_iterable has a key that gets discarded typically - shard = map(lambda t: t[1], dataset._ex_iterable.shard_data_sources(int(shard_name), dataset.n_shards)) + shard = map( + lambda t: t[1], + dataset._ex_iterable.shard_data_sources(index=int(shard_name), num_shards=dataset.n_shards), + ) else: shard = dataset From ad0c3573a558a7afa236425f260bf53d02ba09ad Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 3 Nov 2024 16:18:27 -0800 Subject: [PATCH 42/66] go ahead and bump datasets --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 19fb077bf..0831605cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", - "datasets>=2.18,<4.0", + "datasets>=3.1.0,<4.0", "gcsfs>=2024.2,<2024.10", "braceexpand>=0.1.7", "jmp>=0.0.3", From fa00824eb129d279139d7814f6403b38b46a605a Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 4 Nov 2024 20:57:59 -0800 Subject: [PATCH 43/66] wip --- src/levanter/store/cache.py | 49 +++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index affe274eb..794af8e23 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -742,11 +742,18 @@ def _notify_updated_ledger(self, ledger: CacheLedger): Called by the cache writer when it has updated the ledger. """ was_finished = self._ledger.is_finished - self._ledger = ledger + # ensure the ledger is "monotonic" meaning that we only expect it to grow + if ledger.total_num_rows < self._ledger.total_num_rows: + raise RuntimeError(f"Ledger went backwards: {ledger.total_num_rows} < {self._ledger.total_num_rows}") + + for shard, rows in ledger.shard_rows.items(): + if rows < self._ledger.shard_rows.get(shard, 0): + raise RuntimeError(f"Shard {shard} went backwards: {rows} < {self._ledger.shard_rows.get(shard, 0)}") if was_finished: raise RuntimeError("Ledger was already finished") + self._ledger = ledger if self._ledger.is_finished: logger.info(f"Finalizing cache {self._cache_dir}...") # guard against invalid state errors @@ -767,7 +774,8 @@ async def _do_notify_async(): def _report_progress(self, report: "_ProgressReport"): import humanfriendly - self._tokenize_pbar.update(report.new_shards) + if report.new_shards > 0: + self._tokenize_pbar.update(report.new_shards) self._report_totals.new_shards += report.new_shards self._report_totals.new_rows += report.new_rows self._report_totals.new_bytes += report.new_bytes @@ -866,7 +874,7 @@ def report_fn(report: _ProgressReport, ledger: CacheLedger): def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): parent._report_progress.remote(report) - parent._notify_updated_ledger.remote(ledger) + ray.get(parent._notify_updated_ledger.remote(ledger)) with log_failures_to(parent): temporary_cache_path = os.path.join(cache_dir, "___temp") @@ -877,6 +885,9 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) + for name, group in shard_groups.items(): + assert len(group) > 0 + logger.debug( f"Tokenizing {len(source.shard_names)} shards in {len(shard_groups)} groups to {temporary_cache_path}." ) @@ -901,10 +912,10 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): if ledger is not None: if group_name == first_group: - parent._notify_updated_ledger.remote(ledger) + ray.get(parent._notify_updated_ledger.remote(ledger)) continue - report_fn = report_fn_first_group if group_name == first_group else report_fn + report_fn_to_use = report_fn_first_group if group_name == first_group else report_fn ref = ( ray.remote(_tokenize_one_shard_group) @@ -917,7 +928,7 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): retry_exceptions=True, max_retries=10, ) - .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn, parent) + .remote(group_out_path, source_ref, shards, processor_ref, options, report_fn_to_use, parent) ) write_refs[group_name] = ref @@ -929,7 +940,7 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): ledger.is_finished = True ledger._serialize_and_commit(cache_dir) - parent._notify_updated_ledger.remote(ledger) + ray.get(parent._notify_updated_ledger.remote(ledger)) temporary_cache_paths = set(group_cache_paths.values()) - {cache_dir} _clean_up_temp_caches(temporary_cache_paths) @@ -1078,12 +1089,16 @@ def _assign_shards_to_groups(source: ShardedDataSource, num_groups: int | None) return {shard_name: [shard_name] for shard_name in source.shard_names} shard_names = source.shard_names - num_shards_per_group = (len(shard_names) + num_groups - 1) // num_groups - # if we have a remainder, we'll just add it to the last group - out_groups = { - f"group_{i}": list(shard_names[i * num_shards_per_group : (i + 1) * num_shards_per_group]) - for i in range(num_groups) - } + num_shards_per_group = (len(shard_names)) // num_groups + num_groups_with_extra = len(shard_names) % num_groups + + # if we have a remainder, we want to distribute the extra shards evenly + out_groups: dict[str, list[str]] = {} + start = 0 + for i in range(num_groups): + num_shards = num_shards_per_group + (1 if i < num_groups_with_extra else 0) + out_groups[f"group_{i}"] = list(shard_names[start : start + num_shards]) + start += num_shards # make sure we got all the shards assert sum(len(shards) for shards in out_groups.values()) == len(shard_names) @@ -1099,7 +1114,9 @@ def _merge_ledgers(dest, source): dest.shard_rows[shard] = rows dest.finished_shards.extend(source.finished_shards) - dest.field_counts.update(source.field_counts) + for field, count in source.field_counts.items(): + dest.field_counts[field] = dest.field_counts.get(field, 0) + count + return dest @@ -1126,7 +1143,6 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R if last_ref is not None: ray.wait([last_ref.ref], fetch_local=False) permanent_cache = TreeStore.open(processor.output_exemplar, dest_path, mode="a", cache_metadata=False) - dest_ledger = CacheLedger.load(dest_path) source_ledger = CacheLedger.load(source_path) new_num_rows = source_ledger.total_num_rows + rows_so_far @@ -1135,11 +1151,12 @@ def _copy_cache(dest_path, source_path, processor, data_offset_tree, last_ref: R for future in futures: future.result() + dest_ledger = CacheLedger.load(dest_path) _merge_ledgers(dest_ledger, source_ledger) dest_ledger._serialize_and_commit(dest_path) assert not dest_ledger.is_finished - parent._notify_updated_ledger.remote(dest_ledger) + ray.get(parent._notify_updated_ledger.remote(dest_ledger)) parent._report_copy_progress.remote( _ProgressReport(new_shards=len(source_ledger.shard_rows), new_rows=source_ledger.total_num_rows) ) From 91383f31b1ac17c77823b1e8df1646641edc812e Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 4 Nov 2024 23:05:43 -0800 Subject: [PATCH 44/66] fix wandb key (#785) turns out ray doesn't merge things when you use .options, which... what. --- src/levanter/infra/ray_tpu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index 57f484770..b04648079 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -11,6 +11,7 @@ from typing import Callable, Optional, Sequence import draccus +import mergedeep import ray from ray._private.accelerators import TPUAcceleratorManager from ray.dashboard.modules.job.sdk import JobSubmissionClient @@ -198,10 +199,15 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> my-tpu num_tpus_per_host = TPUAcceleratorManager.get_current_node_num_accelerators() # -> 8 + # ray doesn't merge the runtime envs properly, so we have to do it ourselves + # we need to do a deep merge + runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE) + remote_fn = remote_fn.options( runtime_env=runtime_env, resources={tpu_name: 1, "TPU": num_tpus_per_host}, ) + logger.info(f"Running on TPU {tpu_name} with {num_hosts} hosts and {num_tpus_per_host} TPUs per host") return remote_fn, tpu_name From 45d3e70f3a61fadf693da75fd322072947caf3f1 Mon Sep 17 00:00:00 2001 From: William Held Date: Tue, 5 Nov 2024 12:21:15 -0500 Subject: [PATCH 45/66] Fix Llama 3 Tests (#782) HuggingFace seems to have changed a few things around in what info they expect to be stored in the config leading the Llama 3 roundtrip tests to hit errors. AFAICT, the Torch tests aren't running in CI so this just fixes the regression! ![image](https://github.com/user-attachments/assets/eec1f2a3-ceb9-443a-911e-ea6476fa91bf) --- src/levanter/models/rotary.py | 1 + tests/test_llama3.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/rotary.py b/src/levanter/models/rotary.py index 07657e5ff..55bbf3fcb 100644 --- a/src/levanter/models/rotary.py +++ b/src/levanter/models/rotary.py @@ -157,6 +157,7 @@ def to_hf_config(self) -> tuple[float, dict]: "low_freq_factor": self.low_freq_factor, "high_freq_factor": self.high_freq_factor, "original_max_position_embeddings": self.original_max_position_embeddings, + "rope_type": "llama3", } diff --git a/tests/test_llama3.py b/tests/test_llama3.py index 2fae326d1..653ba723c 100644 --- a/tests/test_llama3.py +++ b/tests/test_llama3.py @@ -26,9 +26,10 @@ def get_config(vocab_size=1000): "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, + "head_dim": 64, "initializer_range": 0.02, "intermediate_size": 14336, - "max_position_embeddings": 8192, + "max_position_embeddings": 131072, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, @@ -55,6 +56,7 @@ def get_config(vocab_size=1000): llama3_8b_config.hidden_size = 16 llama3_8b_config.intermediate_size = 64 llama3_8b_config.num_attention_heads = 4 + llama3_8b_config.head_dim = 4 llama3_8b_config.num_hidden_layers = 4 llama3_8b_config.num_key_value_heads = 2 llama3_8b_config.max_position_embeddings = 128 From b51a3802dc89a8eb9965e0a8bde4bf6bf2b4fec5 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 5 Nov 2024 11:03:55 -0800 Subject: [PATCH 46/66] was updating too many times --- src/levanter/store/cache.py | 47 ++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 794af8e23..d49f4553b 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -933,10 +933,17 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): write_refs[group_name] = ref - # now we start copying the temporary caches to the output cache, in order. (essentially concatenating them) - - ledger = _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, - group_cache_paths, processor, processor_ref) + ledger = _start_copies( + parent, + cache_dir, + shard_groups, + first_group, + write_refs, + group_ledgers, + group_cache_paths, + processor, + processor_ref, + ) ledger.is_finished = True ledger._serialize_and_commit(cache_dir) @@ -946,8 +953,17 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): _clean_up_temp_caches(temporary_cache_paths) -def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, group_ledgers, group_cache_paths, processor, - processor_ref): +def _start_copies( + parent, + cache_dir, + shard_groups, + first_group, + write_refs, + group_ledgers, + group_cache_paths, + processor, + processor_ref, +): """ Copy the temporary caches to the output cache, in order. (essentially concatenating them) @@ -961,6 +977,9 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou group_cache_paths: a dict mapping group names to the paths of the temporary caches processor: the processor object processor_ref: a ray.ObjectRef of the processor object + + Returns: + The final ledger """ # This logic is a bit hairy thanks to resumes. # First, note that each TreeCache is a tree of JaggedArrayStores, and we need to copy each of these @@ -1020,7 +1039,9 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou if found_one_to_copy and shards_copied > 0: raise RuntimeError("A previous group was copied, but this group was not. This should never happen.") elif shards_copied == len(shard_groups[group]): - assert overall_ledger.total_num_rows >= total_rows_from_caches, f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" + assert ( + overall_ledger.total_num_rows >= total_rows_from_caches + ), f"{overall_ledger.total_num_rows} < {total_rows_from_caches}. {group}" continue # nothing to do elif shards_copied > 0: # In theory we can handle this, but it's a bit tricky, so we're going to punt for now @@ -1051,18 +1072,6 @@ def _start_copies(parent, cache_dir, shard_groups, first_group, write_refs, grou ) total_rows_from_caches += this_ledger.total_num_rows - # this little bit is totally unnecessary but nice logging - for group in shard_groups: - if group == first_group: - continue - - if copy_refs.get(group) is not None: - ledger = ray.get(copy_refs[group]) - group_ledgers[group] = ledger - parent._report_copy_progress.remote( - _ProgressReport(new_shards=len(ledger.finished_shards), new_rows=ledger.total_num_rows) - ) - # refs form a linked list implicitly, so we can just wait on the last one if last_ref is not None: ledger = ray.get(last_ref) From f53c99180915050988090a57b90e1d4a8b8a4d77 Mon Sep 17 00:00:00 2001 From: Kamyar Salahi Date: Tue, 5 Nov 2024 19:26:51 -0800 Subject: [PATCH 47/66] Internal eval fixes (#788) Allowing internal supervised eval to work without separate eval set --------- Co-authored-by: David Hall --- config/gpt2_small_fast_supervised.yaml | 1 + src/levanter/data/text.py | 2 +- src/levanter/main/train_lm.py | 8 ++++---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_supervised.yaml index d71e1267e..93675366d 100644 --- a/config/gpt2_small_fast_supervised.yaml +++ b/config/gpt2_small_fast_supervised.yaml @@ -15,6 +15,7 @@ data: supervised_data: validation_urls: - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz" + - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz" cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/" input_field: "input" output_field: "output" diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 70c1fe4b3..f2bea44b2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -631,7 +631,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain input_field = config.input_field output_field = config.output_field - output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((), dtype=np.int32)} + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} dataset = dataset.map_batches(lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar) # type: ignore dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # type: ignore diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index fe5e5dd35..79095d601 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -160,13 +160,13 @@ def main(config: TrainLmConfig): levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) + max_eval_examples_per_ds = config.trainer.max_eval_batches + if max_eval_examples_per_ds is not None: + max_eval_examples_per_ds *= config.trainer.eval_batch_size + if len(tagged_eval_datasets) == 0: logger.warning("No evaluation datasets provided.") else: - max_eval_examples_per_ds = config.trainer.max_eval_batches - if max_eval_examples_per_ds is not None: - max_eval_examples_per_ds *= config.trainer.eval_batch_size - causal_datasets = [ (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags) for ds, tags in tagged_eval_datasets From 20ff94c80479df25805a5c616c526077efc3620f Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 5 Nov 2024 20:54:40 -0800 Subject: [PATCH 48/66] fix empty shards --- src/levanter/store/cache.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index d49f4553b..5e8657fc0 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -883,6 +883,14 @@ def report_fn_first_group(report: _ProgressReport, ledger: CacheLedger): group_ledgers: dict[str, CacheLedger | None] = {} write_refs: dict[str, ray.ObjectRef] = {} + if len(source.shard_names) == 0: + logger.info("No shards to process. Writing empty ledger.") + ledger = CacheLedger.load_or_initialize(cache_dir, source, processor) + ledger.is_finished = True + ledger._serialize_and_commit(cache_dir) + ray.get(parent._notify_updated_ledger.remote(ledger)) + return + shard_groups = _assign_shards_to_groups(source, options.num_shard_groups) for name, group in shard_groups.items(): From 0e6a6b4ecf89e245ac80090a488c739512ce215b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 5 Nov 2024 21:20:43 -0800 Subject: [PATCH 49/66] correct total byte calculation for bpb (#789) ... when there are no tags --- src/levanter/eval.py | 52 +++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 555dd1466..99e132dc2 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -63,7 +63,7 @@ def __init__( self.datasets = [] tag_index: dict[str, int] = {} for i, (dataset, tags) in enumerate(datasets): - if tags is None: + if not tags and len(datasets) > 1: warnings.warn("Dataset has no tags. Giving it an index") tags = [f"domain_{i}"] for tag in tags: @@ -204,14 +204,16 @@ def eval_callback(step: StepInfo): } logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}") - for tag, loss in result.tag_macro_losses.items(): - # don't log leaf tag macro losses because it doesn't mean anything different than micro loss - if tag in evaluator.dataset.tag_to_index: - continue - if not tag: - continue - log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss - logger.info(f"{tag} macro loss: {loss:.3f}") + has_tags = len(evaluator.dataset.tag_to_index) > 1 # 1 tag means there's no difference between micro and macro + if has_tags: + for tag, loss in result.tag_macro_losses.items(): + # don't log leaf tag macro losses because it doesn't mean anything different than micro loss + if tag in evaluator.dataset.tag_to_index: + continue + if not tag: + continue + log_dict[_join_prefix(prefix, tag) + "/macro_loss"] = loss + logger.info(f"{tag} macro loss: {loss:.3f}") for tag, loss in result.tag_micro_losses.items(): if not tag: @@ -225,11 +227,14 @@ def eval_callback(step: StepInfo): if tokenizer is not None: log_dict[_join_prefix(prefix, "bpb")] = result.micro_bpb - log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb + if has_tags: + log_dict[_join_prefix(prefix, "macro_bpb")] = result.macro_bpb for tag, bpb in result.tag_micro_bpb.items(): log_dict[_join_prefix(prefix, tag) + "/bpb"] = bpb - for tag, bpb in result.tag_macro_bpb.items(): - log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb + + if has_tags: + for tag, bpb in result.tag_macro_bpb.items(): + log_dict[_join_prefix(prefix, tag) + "/macro_bpb"] = bpb levanter.tracker.log_metrics(log_dict, step=step.step) @@ -304,26 +309,29 @@ def accum_for_batch(m: LmHeadModel, state: _EvalRunningMeans, batch: LmExample, this_loss_per_tag = hax.einsum("-> tag", mask, losses, tags) # [Tag] mean = state.token_avg_loss.add(this_loss / this_tokens, this_tokens) - # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag - safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) - mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) + state = dataclasses.replace(state, token_avg_loss=mean) - state = dataclasses.replace(state, token_avg_loss=mean, loss_per_tag=mean_per_tag) + if len(self.dataset.tag_to_index) > 0: + # careful: this_tokens_per_tag can be 0 if there are no tokens for that tag + safe_mean = hax.where(this_tokens_per_tag, this_loss_per_tag / this_tokens_per_tag, 0.0) + mean_per_tag = state.loss_per_tag.add(safe_mean, this_tokens_per_tag) + state = dataclasses.replace(state, loss_per_tag=mean_per_tag) if self.bytes_per_token is not None: next_tokens = hax.roll(batch.tokens, -1, m.Pos) # [Batch, Pos], rolled by 1 for next token task bytes_per_pos = self.bytes_per_token.take("vocab", next_tokens) # [Batch, Pos] - bytes_per_pos = bytes_per_pos * mask # [Batch, Pos] - bytes_per_tag = hax.einsum("-> tag", bytes_per_pos, tags) # [Tag] - total_bytes = hax.sum(bytes_per_tag) + bytes_per_tag = hax.einsum("-> tag", mask, bytes_per_pos, tags) # [Tag] + this_bytes = hax.einsum("->", bytes_per_pos, mask) # Scalar # log loss -> bits is log2(e) * loss bpb_per_tag = this_loss_per_tag / hax.maximum(bytes_per_tag, 1) * jnp.log2(jnp.e) - bpb = this_loss / hax.maximum(total_bytes, 1) * jnp.log2(jnp.e) + bpb = this_loss / hax.maximum(this_bytes, 1) * jnp.log2(jnp.e) bpb_mean = state.bpb.add(bpb, this_tokens) - bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) - state = dataclasses.replace(state, bpb=bpb_mean, bpb_per_tag=bpb_per_tag_mean) + state = dataclasses.replace(state, bpb=bpb_mean) + if len(self.dataset.tag_to_index) > 0: + bpb_per_tag_mean = state.bpb_per_tag.add(bpb_per_tag, this_tokens_per_tag) + state = dataclasses.replace(state, bpb_per_tag=bpb_per_tag_mean) return state From 1c43256463edac6f4692b640b7f01573d076ae8e Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 5 Nov 2024 23:34:29 -0800 Subject: [PATCH 50/66] prepare lora for state dict change (#791) --- tests/test_lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_lora.py b/tests/test_lora.py index f7d852531..b6933f935 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -74,8 +74,8 @@ def __call__(self, x): @staticmethod def init(*, key): k1, k2 = jax.random.split(key) - first = hnn.Linear.init(In, Mid, key=k1) - second = hnn.Linear.init(Mid, In, key=k2) + first = hnn.Linear.init(In, Mid, key=k1, out_first=True) + second = hnn.Linear.init(Mid, In, key=k2, out_first=True) return Module(first, second) Layers = hax.Axis("Layers", 3) @@ -91,7 +91,7 @@ def init(*, key): assert loraized.stacked.first.lora.lora_A.weight.axes == (Layers, hax.Axis("LORA_R", 8), In) assert loraized.stacked.first.lora.lora_B.weight.axes == (Layers, Mid, hax.Axis("LORA_R", 8)) - assert loraized.stacked.second.weight.axes == (Layers, Mid, In) + assert loraized.stacked.second.weight.axes == (Layers, In, Mid) input = hax.random.normal(k0, (In,)) assert not hax.all(hax.isclose(module.fold(input), loraized.fold(input))) From a7e42ecc5972c88abde51231239a981c2ebb4fce Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 6 Nov 2024 09:49:56 -0800 Subject: [PATCH 51/66] Add "blocked"/"flash" cross entropy (#790) to mitigate large tokenizers limiting blocksize (e.g. llama3) imposes a kind of not ideal refactor on LMModel, but it's not the worst. FYI @Helw150 --------- Co-authored-by: Ivan Zhou Co-authored-by: Abhinav Garg --- config/llama3_small_fast.yaml | 32 +++ config/llama_7b_with_dclm.yaml | 2 +- src/levanter/infra/ray_tpu.py | 3 +- src/levanter/models/backpack.py | 7 +- src/levanter/models/gemma.py | 8 +- src/levanter/models/gpt2.py | 8 +- src/levanter/models/llama.py | 25 ++ src/levanter/models/lm_model.py | 65 ++++- src/levanter/models/loss.py | 408 +++++++++++++++++++++++++++++++- src/levanter/models/mistral.py | 9 +- src/levanter/models/mpt.py | 7 +- src/levanter/store/cache.py | 2 +- src/levanter/trainer.py | 3 +- tests/test_hf_gpt2_serialize.py | 10 +- tests/test_loss.py | 325 +++++++++++++++++++++++++ tests/test_text.py | 10 +- 16 files changed, 880 insertions(+), 44 deletions(-) create mode 100644 config/llama3_small_fast.yaml create mode 100644 tests/test_loss.py diff --git a/config/llama3_small_fast.yaml b/config/llama3_small_fast.yaml new file mode 100644 index 000000000..df1df9f96 --- /dev/null +++ b/config/llama3_small_fast.yaml @@ -0,0 +1,32 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/" + tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF" +model: + type: llama + hidden_dim: 768 + intermediate_dim: 2048 + num_heads: 12 + num_kv_heads: 12 + num_layers: 12 + seq_len: 1024 + gradient_checkpointing: true +trainer: + tracker: + - type: wandb + project: "levanter" + tags: [ "openwebtext", "llama", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 256 + num_train_steps: 20000 +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/config/llama_7b_with_dclm.yaml b/config/llama_7b_with_dclm.yaml index 980e64e41..11a182f09 100644 --- a/config/llama_7b_with_dclm.yaml +++ b/config/llama_7b_with_dclm.yaml @@ -17,7 +17,7 @@ trainer: mp: p=f32,c=bfloat16 train_batch_size: 2048 - num_train_steps: 70000 # 280B / 4M + num_train_steps: 480000 # 2T / 4M steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" diff --git a/src/levanter/infra/ray_tpu.py b/src/levanter/infra/ray_tpu.py index b04648079..1a9342c54 100644 --- a/src/levanter/infra/ray_tpu.py +++ b/src/levanter/infra/ray_tpu.py @@ -201,7 +201,8 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env): # ray doesn't merge the runtime envs properly, so we have to do it ourselves # we need to do a deep merge - runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE) + sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None] + runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE) remote_fn = remote_fn.options( runtime_env=runtime_env, diff --git a/src/levanter/models/backpack.py b/src/levanter/models/backpack.py index 2a955395f..4de8accc7 100644 --- a/src/levanter/models/backpack.py +++ b/src/levanter/models/backpack.py @@ -401,7 +401,7 @@ def init(Vocab: Axis, config: BackpackConfig, *, key): ) @named_call - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: k_embed, k_transformer, k_senses, k_sa = haliax.jax_utils.maybe_rng_split(key, 4) @@ -428,9 +428,10 @@ def __call__( scale = self.config.Senses.size hidden_states = hidden_states / scale - lm_logits = self.embeddings.unembed(hidden_states) + return hidden_states - return lm_logits + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None): new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/models/gemma.py b/src/levanter/models/gemma.py index af5cc44be..c38acf5ef 100644 --- a/src/levanter/models/gemma.py +++ b/src/levanter/models/gemma.py @@ -339,6 +339,9 @@ def vocab_size(self) -> int: def Vocab(self) -> Axis: return self.embeddings.Vocab + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings.weight + @classmethod def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel": k_t, k_emb = jrandom.split(key, 2) @@ -346,7 +349,7 @@ def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel": embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb) return GemmaLMHeadModel(transformer, embeddings) - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, @@ -363,8 +366,7 @@ def __call__( """ x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, key=key) - lm_logits = self.embeddings.unembed(x) - return lm_logits + return x def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]": new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/models/gpt2.py b/src/levanter/models/gpt2.py index a921074e9..28e878193 100644 --- a/src/levanter/models/gpt2.py +++ b/src/levanter/models/gpt2.py @@ -391,15 +391,17 @@ def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel": return Gpt2LMHeadModel(transformer, embeddings) - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: k_embed, k_transformer = haliax.jax_utils.maybe_rng_split(key, 2) x = self.embeddings.embed(input_ids, key=k_embed) x = self.transformer(x, attn_mask, key=k_transformer) - lm_logits = self.embeddings.unembed(x) - return lm_logits + return x + + def get_lm_head(self) -> hax.NamedArray: + return self.embeddings.token_embeddings.weight def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None) -> "Gpt2LMHeadModel": new_embeddings = self.embeddings.resize_embeddings(new_size, key=key) diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index 1e09ffbc5..85861da6a 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -557,6 +557,31 @@ def __call__( lm_logits = self.embeddings.unembed(x) return lm_logits + def activations( + self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None + ) -> NamedArray: + """ + Compute the activations for the next token in a sequence. + Args: + input_ids: token IDs with shape {Pos} + attn_mask: attention mask with shape {Pos, KeyPos} + key: PRNGKey for random number generation + + Returns: + NamedArray: activations with shape {Pos, Embed} + + """ + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=key) + + return x + + def get_lm_head(self) -> hax.NamedArray: + if self.lm_head is None: + return self.embeddings.token_embeddings.weight + else: + return self.lm_head.weight + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]": new_Vocab = self.Vocab.resize(new_size) k1, k2 = maybe_rng_split(key, 2) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 468f6a4a4..911e74b09 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -64,6 +64,18 @@ def KeyPos(self) -> Axis: def Pos(self) -> Axis: pass + @property + @abc.abstractmethod + def Embed(self) -> Axis: + pass + + cross_entropy_block_size: Optional[int] = 64000 + """ + The block size for computing cross-entropy loss. This is the number of tokens that are processed together + in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large + value because it usually faster to compute the loss in larger blocks. + """ + def flops_per_token(self, vocab_size: int) -> Optional[float]: return None @@ -94,17 +106,58 @@ def Pos(self) -> Axis: def KeyPos(self) -> Axis: return self.config.KeyPos + @property + def Embed(self) -> Axis: + return self.config.Embed + @classmethod @abc.abstractmethod def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[LmConfigT]": pass - @abc.abstractmethod def __call__( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: + """ + Compute the logits for the next token in a sequence. + Args: + input_ids: token IDs with shape [..., Pos] + attn_mask: attention mask with shape [..., Pos, KeyPos] + key: PRNGKey for random number generation + + Returns: + NamedArray: logits with shape [..., Pos, Vocab] + + """ + x = self.activations(input_ids, attn_mask, key=key) + lm_logits = hax.dot(x, self.get_lm_head(), axis=self.Embed) + + return lm_logits + + @abc.abstractmethod + def activations( + self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None + ) -> NamedArray: + """ + Compute the activations for the next token in a sequence. + Args: + input_ids: token IDs with shape {Pos} + attn_mask: attention mask with shape {Pos, KeyPos} + key: PRNGKey for random number generation + + Returns: + NamedArray: activations with shape {Pos, Embed} + + """ pass + @abc.abstractmethod + def get_lm_head(self) -> hax.NamedArray: + """ + The language modeling head of the model. Should have shape {Embed, Vocab}. + """ + raise NotImplementedError("get_lm_head not implemented") + @abc.abstractmethod def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]": """ @@ -133,19 +186,21 @@ def compute_next_token_loss( across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not reduced, and the result is a named array with axes (*batch axes, sequence_length). """ - logits = model(example.tokens, example.attn_mask, key=key) - if loss_dtype is not None: - logits = logits.astype(loss_dtype) + activations = model.activations(example.tokens, example.attn_mask, key=key) loss = next_token_loss( model.Pos, + model.Embed, model.Vocab, - logits, + activations, + model.get_lm_head(), example.tokens, loss_mask=example.loss_mask, reduction=reduction, reduction_axis=reduction_axis, logsumexp_weight=logsumexp_weight, + dtype=loss_dtype, + block_size=model.config.cross_entropy_block_size, ) return loss diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index 1ef7e81f9..154fc66ac 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -1,5 +1,8 @@ +import functools from typing import Optional +import equinox +import jax import jax.numpy as jnp import haliax as hax @@ -9,34 +12,77 @@ def next_token_loss( Pos: hax.AxisSelector, + Embed: hax.AxisSelector, Vocab: hax.AxisSelector, - pred_ids: NamedArray, + pred_embeddings: NamedArray, + pred_lm_head: NamedArray, true_ids: NamedArray, loss_mask: Optional[NamedArray] = None, reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, logsumexp_weight: Optional[float] = None, -): - Pos, Vocab = pred_ids.resolve_axis((Pos, Vocab)) - # need to roll the target tokens back by one so that each token is predicting the next token + block_size: Optional[int] = None, + dtype: Optional[jnp.dtype] = jnp.float32, +) -> NamedArray: + """ + Compute the next token loss with optional block-wise processing. + + Args: + Pos (hax.AxisSelector): Position axis selector. + Vocab (hax.AxisSelector): Vocabulary axis selector. + pred_embeddings (NamedArray): Predicted embeddings. + pred_lm_head (NamedArray): Language model head weights. + true_ids (NamedArray): True token IDs. + loss_mask (Optional[NamedArray]): Mask to apply to the loss. + reduction (Optional[hax.ReductionFunction]): Reduction function. + reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction. + logsumexp_weight (Optional[float]): Weight for logsumexp penalty. + block_size (Optional[int]): Size of each block for processing. + + Returns: + NamedArray: Computed loss. + """ + # Resolve axes + Pos = pred_embeddings.resolve_axis(Pos) + Vocab = pred_lm_head.resolve_axis(Vocab) + + # Shift target tokens to predict the next token target_y = hax.roll(true_ids, -1, Pos) - target_y = hax.nn.one_hot(target_y, Vocab, dtype=pred_ids.dtype) # type: ignore - # one everywhere except the last token + # Create a mask that excludes the last token not_last_loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) # type: ignore if loss_mask is not None: loss_mask = loss_mask * not_last_loss_mask else: loss_mask = not_last_loss_mask - return cross_entropy_and_logsumexp_penalty( - pred_ids, - Vocab, - target_y, + if block_size is None: + # Full softmax computation + logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype) + target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype) + return cross_entropy_and_logsumexp_penalty( + logits, + Vocab, + target_y_full, + reduction=reduction, + reduction_axis=reduction_axis, + where=loss_mask, + logsumexp_weight=logsumexp_weight, + ) + + # Compute the loss with optional block-wise processing + return fused_cross_entropy_loss_and_logsumexp_penalty( + pred_embeddings, + pred_lm_head, + Contract=Embed, + Label=Vocab, + target_y=target_y, reduction=reduction, reduction_axis=reduction_axis, where=loss_mask, logsumexp_weight=logsumexp_weight, + block_size=block_size, + dtype=dtype, ) @@ -58,3 +104,345 @@ def cross_entropy_and_logsumexp_penalty( loss = loss + logsumexp_weight * (log_normalizers**2) return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where) + + +def fused_cross_entropy_loss_and_logsumexp_penalty( + pred_embeddings: NamedArray, + pred_lm_head: NamedArray, + Contract: hax.AxisSelector, + Label: hax.AxisSelector, + target_y: NamedArray, + *, + reduction: Optional[hax.ReductionFunction] = hax.mean, + reduction_axis: Optional[hax.AxisSelection] = None, + where: Optional[NamedArray] = None, + logsumexp_weight: float | None = 0.0, + block_size: int, + dtype: Optional[jnp.dtype] = jnp.float32, +) -> NamedArray: + """ + Compute the cross-entropy loss and logsumexp penalty using embeddings and lm_head, + with optional block-wise processing. + + Args: + pred_embeddings (NamedArray): Predicted embeddings. + pred_lm_head (NamedArray): Language model head weights. + Contract (hax.AxisSelector): Axis to contract over. + Label (hax.AxisSelector): Label (Vocab) axis. + target_y (NamedArray): One-hot encoded target tokens. + reduction (Optional[hax.ReductionFunction]): Reduction function. + reduction_axis (Optional[hax.AxisSelection]): Axis to apply reduction. + where (Optional[NamedArray]): Mask to apply to the loss. + logsumexp_weight (float): Weight for logsumexp penalty. + block_size (int): Size of each block for processing. + dtype (Optional[jnp.dtype]): Data type for the loss. + + Returns: + NamedArray: Computed loss. + """ + + # Block-wise softmax computation + loss, log_normalizers = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), Contract, Label, target_y, block_size, dtype=dtype + ) + + if logsumexp_weight is not None and (not isinstance(logsumexp_weight, (int, float)) or logsumexp_weight != 0.0): + loss = loss + logsumexp_weight * (log_normalizers**2) + + return hax.nn.loss.maybe_reduce_loss(loss, reduction, reduction_axis, where) + + +@equinox.filter_custom_vjp +def _blockwise_cross_entropy_loss( + # pred_embeddings: NamedArray, + # pred_lm_head: NamedArray, + pred: tuple[NamedArray, NamedArray], + Contract: hax.Axis, + Label: hax.Axis, + labels_y: NamedArray, + block_size: int, + dtype: Optional[jnp.dtype], +) -> tuple[NamedArray, NamedArray]: + """ + Compute cross-entropy loss and log normalizers in a block-wise manner without materializing the full logits. + + Args: + pred_embeddings (NamedArray): Predicted embeddings. + pred_lm_head (NamedArray): Language model head weights. + Contract (hax.Axis): Axis to contract over. + Label (hax.AxisSelector): Label (Vocab) axis. + labels_y (NamedArray): label tensor. + block_size (int): Size of each block for processing. + dtype (Optional[jnp.dtype]): Data type for the loss. + + Notes: + labels_y being anything other than the label tensor would remove any benefits + + TODO: but if XLA smart enough to optimize it out? + + Returns: + tuple[NamedArray, NamedArray]: tuple of loss and log_normalizers. + """ + + return _block_cross_entropy_forward(None, pred, Contract, Label, labels_y, block_size, dtype)[0] + + +def _block_cross_entropy_forward( + ignore, + pred: tuple[NamedArray, NamedArray], + Contract: hax.Axis, + Label: hax.Axis, + labels_y: NamedArray, + block_size: int, + dtype: Optional[jnp.dtype], +) -> tuple[tuple[NamedArray, NamedArray], tuple[NamedArray]]: + """ + Forward pass for block-wise cross-entropy loss. + + This function computes the cross-entropy loss and log-sum-exp (`log_z`) in a block-wise manner + to maintain memory efficiency by processing subsets of the vocabulary at a time. + + Args: + ignore: Placeholder argument (unused). + pred (Tuple[NamedArray, NamedArray]): Tuple containing predicted embeddings and language model head weights. + Contract (hax.Axis): Axis to contract over (e.g., embedding axis). + Label (hax.Axis): Label axis (e.g., vocabulary axis). + labels_y (NamedArray): True target labels [Batch, Seq]. + block_size (int): Number of vocabulary tokens per block. + dtype (Optional[jnp.dtype]): Data type for the computations. + + Returns: + Tuple: + - Tuple[NamedArray, NamedArray]: Computed loss and logsumexp. + - Tuple[NamedArray]: Residuals needed for the backward pass. + """ + vocab_size = Label.size + + pred_embeddings, pred_lm_head = pred + + # + # if num_blocks == 1: + # # No need for block-wise processing + # logits = hax.dot(pred_embeddings, pred_lm_head, axis=Contract) + # labels_y = hax.nn.one_hot(labels_y, Label, dtype=pred_embeddings.dtype) + # return cross_entropy_loss_and_log_normalizers(logits, Label, labels_y) + # + # ensure block size divides vocab size + if vocab_size % block_size != 0: + has_stragglers = True + else: + has_stragglers = False + + num_blocks = vocab_size // block_size + + # Initialize accumulators: loss, logsumexp, max_logits + initial_O = hax.zeros(labels_y.axes) + initial_logsumexp = hax.full(labels_y.axes, -jnp.inf) + initial_max = hax.full(labels_y.axes, -jnp.inf) + # We don't need this b/c we're using one-hot targets + # initial_sumV = hax.full(labels_y.axes, 0.0) + + def process_block(block_idx, acc, current_block_size): + """ + Process a single block of the Vocab dimension. + + Args: + block_idx (int): Index of the current block. + acc (tuple[NamedArray, NamedArray, jnp.ndarray]): Accumulators for loss, logsumexp, and max logits. + current_block_size (int): Size of the current block (used for stragglers). + + Returns: + tuple[NamedArray, NamedArray, jnp.ndarray]: Updated accumulators + """ + loss, logsumexp_prev, max_logit_prev = acc + + start = block_idx * block_size + Block = Label.resize(current_block_size) + + # Materialize the logits for the current block + lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block] + logits_b = hax.dot( + pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype + ) # [Batch, Seq, Block] + + # Update max and logsumexp + max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block)) # [Batch, Seq] + # reweight the previous logsumexp by the new max, fold in the new logits' contribution + logsumexp = max_logit + hax.log( + hax.exp(logsumexp_prev - max_logit) + hax.sum(hax.exp(logits_b - max_logit), axis=Block) + ) # [Batch, Seq] + + # Materialize the target for the current block (one-hot) + target_y_b = _block_one_hot(Block, start, labels_y, logits_b.dtype) # [Batch, Seq, Block] + + # Update sumV. This is actually unnecessary if we're using one-hot targets + # sV = sV_prev + hax.sum(target_y_b, axis=Label.name) + + loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype) # [Batch, Seq] + + return loss, logsumexp, max_logit # , sV + + if num_blocks == 0: + o = initial_O + log_z = initial_logsumexp + max_logits = initial_max + elif num_blocks == 1: + o, log_z, max_logits = process_block(0, (initial_O, initial_logsumexp, initial_max), vocab_size) + else: + (o, log_z, max_logits) = jax.lax.fori_loop( + lower=0, + upper=num_blocks, + body_fun=functools.partial(process_block, current_block_size=block_size), + init_val=(initial_O, initial_logsumexp, initial_max), # , initial_sumV + ) + + if has_stragglers: + # Handle the stragglers + remainder_size = vocab_size - num_blocks * block_size + o, log_z, _ = process_block(num_blocks, (o, log_z, max_logits), remainder_size) + + # unnecessary if we're using one-hot targets + # logz_outer = hax.einsum("->...", log_z, sum_v) + o = log_z - o + + return (o, log_z), (log_z,) + + +def _block_cross_entropy_backward( + residuals: tuple[NamedArray,], + grad_in: tuple[NamedArray, NamedArray], + ignore, + pred: tuple[NamedArray, NamedArray], + Contract: hax.Axis, + Label: hax.Axis, + labels_y: NamedArray, + block_size: int, + dtype: Optional[jnp.dtype], +) -> tuple[NamedArray, NamedArray]: + """ + Compute the gradients of the block-wise cross-entropy loss. + + Args: + residuals (tuple[NamedArray, NamedArray]): Residuals from the forward pass. + grad_in (tuple[NamedArray, NamedArray]): Incoming gradients. + pred (tuple[NamedArray, NamedArray]): Predictions. + Contract (hax.Axis): Axis to contract over. + Label (hax.Axis): Label axis. + labels_y (NamedArray): Target labels. + block_size (int): Size of each block. + dtype (Optional[jnp.dtype]): Data type for the loss. + + Returns: + tuple[NamedArray, NamedArray]: Gradients. + """ + + (log_z,) = residuals + grad_loss, grad_log_z = grad_in + + vocab_size = Label.size + + pred_embeddings, pred_lm_head = pred + + if vocab_size % block_size != 0: + has_stragglers = True + else: + has_stragglers = False + + num_blocks = vocab_size // block_size + + grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype) + grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype) + + def process_block(block_idx, acc, current_block_size): + """ + Process a single block of the Vocab dimension. + + Args: + block_idx (int): Index of the current block. + acc (tuple[NamedArray, NamedArray]): Accumulators for gradients. + current_block_size (int): Size of the current block (used for stragglers). + + Returns: + tuple[NamedArray, NamedArray]: Updated accumulators. + """ + grad_embeddings_prev, grad_lm_head_prev = acc + + start = block_idx * block_size + Block = Label.resize(current_block_size) + + # Materialize the logits for the current block + lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block] + logits_b = hax.dot( + pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype + ) # [Batch, Seq, Block] + + # Materialize the target for the current block (one-hot) + target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype) # [Batch, Seq, Block] + + # materialize the softmax for the current block + p_b = hax.exp(logits_b - log_z) # [Batch, Seq, Block] + + delta_b = p_b - target_y_block + + # # dLoss/dL = g_loss * delta_b + g_log_z * probs_b + # # = g_loss * (probs_b - Y) + g_log_z * probs_b + # # = (g_loss + g_log_z) * probs_b - g_loss * Y + + # Compute gradients. We get None if the gradient is not provided. + if grad_loss.array is not None: + dLoss = grad_loss * delta_b # [Batch, Seq, Block] + else: + dLoss = 0.0 + + # Add the gradient of the logsumexp term (should be None if not provided) + if grad_log_z.array is not None: + dLoss += grad_log_z * p_b # [Batch, Seq, Block] + + # Compute gradients for the current block + # embeddings has shape [Batch, Seq, Embed], so we need to eliminate Block + g_embeddings_b = hax.dot( + dLoss, lm_head_b, axis=Block, preferred_element_type=grad_embeddings.dtype + ) # [Batch, Seq, Embed] + + # lm_head has shape [Block, Embed], so we need to eliminate Batch, Seq, etc. + eliminated_axes_W = hax.axis.without_axes(pred_embeddings.axes, lm_head_b.axes) + g_lm_head_b = hax.dot( + dLoss, pred_embeddings, axis=eliminated_axes_W, preferred_element_type=grad_lm_head_prev.dtype + ) # [Block, Embed] + + g_lm_head = grad_lm_head_prev.at[Label, hax.dslice(start, Block)].set(g_lm_head_b) + g_embeddings = grad_embeddings_prev + g_embeddings_b + + return g_embeddings, g_lm_head + + if num_blocks == 0: + pass + elif num_blocks == 1: + grad_embeddings, grad_lm_head = process_block(0, (grad_embeddings, grad_lm_head), vocab_size) + else: + grad_embeddings, grad_lm_head = jax.lax.fori_loop( + lower=0, + upper=num_blocks, + body_fun=functools.partial(process_block, current_block_size=block_size), + init_val=(grad_embeddings, grad_lm_head), + ) + + if has_stragglers: + # Handle the stragglers + remainder_size = vocab_size - num_blocks * block_size + grad_embeddings, grad_lm_head = process_block(num_blocks, (grad_embeddings, grad_lm_head), remainder_size) + + return grad_embeddings.astype(pred_embeddings.dtype), grad_lm_head.astype(pred_lm_head.dtype) + + +_blockwise_cross_entropy_loss.def_fwd(_block_cross_entropy_forward) +_blockwise_cross_entropy_loss.def_bwd(_block_cross_entropy_backward) + + +def _block_one_hot(LBlock, block_start, labels, dtype): + end = block_start + LBlock.size + target_is_in_this_block = hax.logical_and(labels >= block_start, labels < end) + target_y_block = hax.nn.one_hot(labels - block_start, LBlock, dtype=dtype) + # 0 out the logits that are not in this block + target_y_block *= target_is_in_this_block + return target_y_block diff --git a/src/levanter/models/mistral.py b/src/levanter/models/mistral.py index b48bfbe91..764e18aea 100644 --- a/src/levanter/models/mistral.py +++ b/src/levanter/models/mistral.py @@ -175,7 +175,11 @@ def init(cls, Vocab: Axis, config: MistralConfig, *, key) -> "MistralLMHeadModel lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True) return MistralLMHeadModel(transformer, embeddings, lm_head) - def __call__( + def get_lm_head(self) -> hax.NamedArray: + assert self.lm_head.bias is None + return self.lm_head.weight + + def activations( self, input_ids: NamedArray, attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, @@ -193,8 +197,7 @@ def __call__( k_t, k_head = maybe_rng_split(key, 2) x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, key=k_t) - lm_logits = self.lm_head(x, key=k_head) - return lm_logits + return x def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[MistralConfig]": new_Vocab = self.Vocab.resize(new_size) diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index 00044a4ed..0809d9d23 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -447,14 +447,15 @@ def init(cls, Vocab: Axis, config: MptConfig, *, key): return MptLmHeadModel(wte, transformer, config) @named_call - def __call__( + def activations( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray], *, key=None ) -> NamedArray: hidden_states = self.wte.embed(input_ids) hidden_states = self.transformer(hidden_states, attention_mask=attn_mask, key=key) - output_logits = self.wte.unembed(hidden_states) + return hidden_states - return output_logits + def get_lm_head(self) -> hax.NamedArray: + return self.wte.weight def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "MptLmHeadModel": if new_size == self.vocab_size: diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 5e8657fc0..558bbfceb 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -840,7 +840,7 @@ class _ShardFinished: path_to_shard: str -@ray.remote(num_cpus=1) +@ray.remote(num_cpus=1, runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "cpu"})) def _core_writer_task( parent, cache_dir, diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 8e98eaedb..92d7af4ac 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -499,7 +499,8 @@ def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwa grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) mbs = self.config.microbatch_size grad_fn = microbatched(grad_fn, self.TrainBatch, mbs, self.parameter_axis_mapping, self.compute_axis_mapping) - return grad_fn(model, *batch, **batch_kwargs) + with hax.axis_mapping(self.compute_axis_mapping): + return grad_fn(model, *batch, **batch_kwargs) def _initialize_global_tracker(config, run_id): diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index 7a5475738..a0002b1c1 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -19,7 +19,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter, RepoRef from levanter.models.attention import AttentionMask from levanter.models.gpt2 import Gpt2Config, Gpt2LMHeadModel -from levanter.models.loss import next_token_loss +from levanter.models.lm_model import LmExample, LmHeadModel, compute_next_token_loss from levanter.optim import AdamConfig from levanter.utils.tree_utils import inference_mode from test_utils import arrays_only, skip_if_no_torch @@ -132,12 +132,10 @@ def torch_loss(model, input_ids) -> torch.Tensor: return model(input_ids, labels=input_ids)[0] torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input.array)).to(torch.int64).unsqueeze(0)) - causal_mask = AttentionMask.causal() - def compute_loss(model, input_ids): - pred_y = model(input_ids, key=None, attn_mask=causal_mask) - - return next_token_loss(model.Pos, model.Vocab, pred_y, input_ids).scalar() + def compute_loss(model: LmHeadModel, input_ids): + example = LmExample.causal(input_ids) + return compute_next_token_loss(model, example, key=None).scalar() jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False) jax_grad: Gpt2LMHeadModel diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100644 index 000000000..30d140ede --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,325 @@ +# test_cross_entropy.py +import math + +import equinox +import jax.numpy as jnp +import jax.random +import pytest + +import haliax as hax +from haliax import NamedArray + +# Import the functions from your module +# Replace 'your_module' with the actual module name where your functions are defined +from levanter.models.loss import _blockwise_cross_entropy_loss, cross_entropy_loss_and_log_normalizers +from levanter.utils.jax_utils import key_iterator + + +Batch = hax.Axis("batch", size=2) +Seq = hax.Axis("seq", size=3) +Embed = hax.Axis("embed", size=8) +Vocab = hax.Axis("vocab", size=16) + + +@pytest.fixture +def test_data(): + """ + Create synthetic test data for cross-entropy loss computation. + """ + + key = key_iterator(jax.random.PRNGKey(0)) + + # Initialize pred_embeddings with ones + pred_embeddings = hax.random.normal(next(key), (Batch, Seq, Embed), dtype=jnp.float32) / math.sqrt(Embed.size) + + # Initialize pred_lm_head with ones + pred_lm_head = hax.random.normal(next(key), (Vocab, Embed), dtype=jnp.float32) / math.sqrt(Embed.size) + + # Define true_ids such that the target is always the first token in vocab + true_ids = hax.random.randint(next(key), (Batch, Seq), 0, Vocab.size) + + return pred_embeddings, pred_lm_head, true_ids + + +def test_basic_equivalence(test_data): + """ + Test that block-wise loss equals full loss when block_size perfectly divides vocab_size. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute full loss + logits_full = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + target_y_full = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss_full, norm_full = cross_entropy_loss_and_log_normalizers(logits_full, Vocab, target_y_full) + + loss_block, norm_this = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=8, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Block-wise loss does not match full loss." + + +def test_single_block(test_data): + """ + Test behavior when vocab_size equals block_size. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute full loss + loss_full, sumexp_full = _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids) + + # Compute block-wise loss with block_size=4 (vocab_size=4) + with jax.disable_jit(): + loss_block, sumexp_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=Vocab.size, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(sumexp_full, sumexp_block, atol=1e-3, rtol=1e-3) + ), "Single block-wise sumexp does not match full sumexp." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Single block-wise loss does not match full loss." + + +def _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids): + logits_full = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + target_y_full = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss_full, sumexp_full = cross_entropy_loss_and_log_normalizers(logits_full, Vocab, target_y_full) + return loss_full, sumexp_full + + +def test_multiple_blocks(test_data): + """ + Test block-wise loss with multiple blocks. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute full loss + loss_full, logz_full = _compute_full(Vocab, pred_embeddings, pred_lm_head, true_ids) + + # Compute block-wise loss with block_size=1 (vocab_size=4) + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=1, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Multiple block-wise logz does not match full logz." + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Multiple block-wise loss does not match full loss." + + +def test_block_size_not_dividing_vocab(test_data): + pred_embeddings, pred_lm_head, true_ids = test_data + + # Set block_size that does not divide vocab_size + block_size = 3 # vocab_size=4 + + # should be fine now + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=block_size, + dtype=pred_embeddings.dtype, + ) + + # Compute full loss + loss_full, logz_full = cross_entropy_loss_and_log_normalizers( + pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"), + Label=Vocab, + target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype), + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Block-wise loss does not match full loss." + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Block-wise logz does not match full logz." + + +def test_vocab_size_less_than_block_size(test_data): + """ + Test behavior when vocab_size is less than block_size. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Set block_size greater than vocab_size + block_size = 5 # vocab_size=4 + + # should be fine now + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=block_size, + dtype=pred_embeddings.dtype, + ) + + # Compute full loss + loss_full, logz_full = cross_entropy_loss_and_log_normalizers( + pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"), + Label=Vocab, + target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype), + ) + + # Assert that the losses are close + assert hax.all(hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3)), "loss does not match full loss." + assert hax.all(hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3)), "logz does not match full logz." + + +def test_large_vocab(): + """ + Test block-wise loss with a larger vocabulary. + """ + Batch = hax.Axis("batch", size=4) + Seq = hax.Axis("seq", size=5) + Embed = hax.Axis("embed", size=6) + Vocab = hax.Axis("vocab", size=12) + + pred_embeddings = NamedArray( + jnp.ones((Batch.size, Seq.size, Embed.size)), + axes=(Batch, Seq, Embed), + ) + pred_lm_head = NamedArray( + jnp.ones((Embed.size, Vocab.size)), + axes=(Embed, Vocab), + ) + true_ids = NamedArray( + jnp.zeros((Batch.size, Seq.size), dtype=jnp.int32), + axes=(Batch, Seq), + ) + + # Compute full loss + loss_full, logz_full = cross_entropy_loss_and_log_normalizers( + pred_y=hax.dot(pred_embeddings, pred_lm_head, axis="embed"), + Label=Vocab, + target_y=hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype), + ) + + # Compute block-wise loss with block_size=3 (vocab_size=12 is divisible by 3) + loss_block, logz_block = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=3, + dtype=pred_embeddings.dtype, + ) + + # Assert that the losses are close + assert hax.all( + hax.isclose(loss_full, loss_block, atol=1e-3, rtol=1e-3) + ), "Large vocab block-wise loss does not match full loss." + assert hax.all( + hax.isclose(logz_full, logz_block, atol=1e-3, rtol=1e-3) + ), "Large vocab block-wise logz does not match full logz." + + +@pytest.mark.parametrize("block_size", [1, 2, 3, 4, 5]) +def test_gradient_block_cross_entropy(block_size, test_data): + """ + Test the gradient of block-wise cross-entropy loss. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute block-wise loss + def custom_fn(pred): + pred_embeddings, pred_lm_head = pred + a, b = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=block_size, + dtype=pred_embeddings.dtype, + ) + + return (a.mean() + b.mean()).scalar() + + g_embed, g_head, = equinox.filter_grad( + custom_fn + )((pred_embeddings, pred_lm_head)) + + # compute directly + + def direct_fn(pred): + pred_embeddings, pred_lm_head = pred + logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed") + target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss, logz = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y) + return (loss.mean() + logz.mean()).scalar() + + g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head)) + + assert hax.all( + hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3) + ), "Gradient of embeddings does not match." + assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match." + + +def test_grad_loss_without_logz(test_data): + """ + Test the gradient of block-wise cross-entropy loss without logz. + """ + pred_embeddings, pred_lm_head, true_ids = test_data + + # Compute block-wise loss + def custom_fn(pred): + pred_embeddings, pred_lm_head = pred + a, b = _blockwise_cross_entropy_loss( + (pred_embeddings, pred_lm_head), + Contract=Embed, + Label=Vocab, + labels_y=true_ids, + block_size=2, + dtype=pred_embeddings.dtype, + ) + + return a.mean().scalar() + + g_embed, g_head, = equinox.filter_grad( + custom_fn + )((pred_embeddings, pred_lm_head)) + + # compute directly + + def direct_fn(pred): + pred_embeddings, pred_lm_head = pred + logits = hax.dot(pred_embeddings, pred_lm_head, axis="embed", preferred_element_type=pred_embeddings.dtype) + target_y = hax.nn.one_hot(true_ids, Vocab, dtype=pred_embeddings.dtype) + loss, _ = cross_entropy_loss_and_log_normalizers(logits, Vocab, target_y) + return loss.mean().scalar() + + g_embed_direct, g_head_direct = equinox.filter_grad(direct_fn)((pred_embeddings, pred_lm_head)) + + assert hax.all( + hax.isclose(g_embed, g_embed_direct, atol=1e-3, rtol=1e-3) + ), "Gradient of embeddings does not match." + assert hax.all(hax.isclose(g_head, g_head_direct, atol=1e-3, rtol=1e-3)), "Gradient of lm_head does not match." diff --git a/tests/test_text.py b/tests/test_text.py index a2645c1f9..e4e51acbc 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -26,6 +26,7 @@ def test_dont_blow_up_without_validation_set(): def test_lm_example_handles_ignore_id(): Pos = hax.Axis("Pos", 10) Vocab = hax.Axis("vocab", Pos.size + 1) + Embed = hax.Axis("embed", 10) tokens = hax.arange(Pos, dtype=jnp.int32) ignore_id = 6 @@ -34,11 +35,12 @@ def test_lm_example_handles_ignore_id(): ex_no_ignore = LmExample.causal(tokens) assert ex_ignore.loss_mask[Pos, ignore_id - 1] == 0 - distr = -100 * hax.nn.one_hot(ignore_id, Vocab) - distr = distr.broadcast_axis(Pos) + logits = hax.ones((Pos, Embed)) + lm_head = hax.zeros((Embed, Vocab)) + lm_head = lm_head.at[Vocab, ignore_id].set(-100) - ignored_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_ignore.loss_mask) - no_ignore_loss = next_token_loss(Pos, Vocab, distr, tokens, loss_mask=ex_no_ignore.loss_mask) + ignored_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_ignore.loss_mask) + no_ignore_loss = next_token_loss(Pos, Embed, Vocab, logits, lm_head, tokens, loss_mask=ex_no_ignore.loss_mask) assert no_ignore_loss.item() >= ignored_loss.item() + 100 / Pos.size From 812accb4f5b29f6f0bda84bcf87a4d4f7c538091 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 6 Nov 2024 10:08:51 -0800 Subject: [PATCH 52/66] load data from marin sources --- examples/sft/sft.py | 45 ++++++++---- examples/sft/tulu-llama-sft.yaml | 51 ++++++++++++++ src/levanter/data/sharded_datasource.py | 91 +++++++++++++++++++++++++ src/levanter/data/text.py | 78 +++++++++++++++++++++ 4 files changed, 253 insertions(+), 12 deletions(-) create mode 100644 examples/sft/tulu-llama-sft.yaml diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 9813184b9..2ced8591c 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -1,6 +1,8 @@ import logging import os from dataclasses import dataclass +from enum import Enum +from typing import List, Optional import jax.random as jrandom import transformers @@ -13,11 +15,10 @@ from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset -from levanter.data.text import EpochDataset, mk_supervised_dataset +from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset from levanter.main.train_lm import TrainLmConfig from levanter.models.lm_model import LmHeadModel, compute_next_token_loss from levanter.trainer import Trainer -from levanter.utils.py_utils import non_caching_cycle logger = logging.getLogger(__name__) @@ -29,12 +30,26 @@ DEFAULT_UNK_TOKEN = "" +class DatasetType(str, Enum): + """Type of dataset to use""" + + HUGGINGFACE = "huggingface" # Use HF dataset + CHAT_JSONL = "chat_jsonl" # Use JSONL files with chat format + + @dataclass class SFTConfig(TrainLmConfig): # inherit most of the config from TrainLmConfig - max_tune_length: int = 2048 # maximum length of the input to the model during tuning + max_tune_length: int = 2048 model_name_or_path: str = "meta-llama/Llama-2-7b-hf" - tokenizer: str = "meta-llama/Llama-2-7b-hf" # Tokenizer to use + tokenizer: str = "meta-llama/Llama-2-7b-hf" + + # Add dataset type and chat-specific fields + dataset_type: DatasetType = DatasetType.HUGGINGFACE + chat_train_urls: Optional[List[str]] = None + messages_field: str = "messages" + input_role: str = "user" + output_role: str = "assistant" def train(config: SFTConfig): @@ -79,19 +94,26 @@ def train(config: SFTConfig): logger.info(f"Overriding data seed with {config.data_seed}") data_key = jrandom.PRNGKey(config.data_seed) - # Configure supervised dataset - supervised_config = config.supervised_data - # Create supervised dataset using generic machinery logger.info("Creating supervised dataset") - train_dataset = mk_supervised_dataset(supervised_config, tokenizer) + if config.dataset_type == DatasetType.CHAT_JSONL: + chat_config = ChatSFTDatasetConfig( + cache_dir=config.supervised_data.cache_dir, + train_urls=config.chat_train_urls, # No validation in this config + messages_field=config.messages_field, + input_role=config.input_role, + output_role=config.output_role, + ) + train_dataset = mk_chat_sft_dataset(chat_config, tokenizer) + else: + train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) # Then wrap for epochs - # if config.epoch > 0: - # logger.info(f"Wrapping dataset for {config.epoch} epochs") - # train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch) + if config.epoch > 0: + logger.info(f"Wrapping dataset for {config.epoch} epochs") + train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch) logger.info("Creating optimizer") optimizer = config.optimizer.build(config.trainer.num_train_steps) @@ -134,7 +156,6 @@ def train(config: SFTConfig): ) loader = trainer.data_loader(train_dataset, trainer.TrainBatch) - loader = non_caching_cycle(loader) if int(state.step) != 0: logger.info(f"Resuming training from step {state.step}") diff --git a/examples/sft/tulu-llama-sft.yaml b/examples/sft/tulu-llama-sft.yaml new file mode 100644 index 000000000..6086e624d --- /dev/null +++ b/examples/sft/tulu-llama-sft.yaml @@ -0,0 +1,51 @@ +# Model configuration +model: + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: true + flash_attention_block_size: 512 + use_bias: false + use_layer_norm_weight: false + +# Training configuration +trainer: + mp: p=f32,c=bfloat16 + tracker: + type: wandb + project: "levanter-sft" + tags: ["llama", "sft"] + num_train_steps: 750000 + train_batch_size: 64 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" + steps_per_eval: 1000 + +# Optimizer settings +optimizer: + learning_rate: 2e-5 + weight_decay: 0.0 + min_lr_ratio: 0.1 + warmup: 100 + +# Supervised data configuration +dataset_type: chat_jsonl +chat_train_urls: + - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz" +supervised_data: + cache_dir: "gs://levanter-checkpoints/marin/sft_cache/chat-data" +messages_field: "messages" +input_role: "user" +output_role: "assistant" + +# Additional settings +tokenizer: "EleutherAI/gpt-neox-20b" +max_tune_length: 2048 +epoch: 0 + +initialize_from_hf: false diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 186a0d9dd..333ddf768 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -1,6 +1,7 @@ import io import json import os +import re import warnings from typing import ( TYPE_CHECKING, @@ -16,11 +17,13 @@ Tuple, TypeVar, ) +from urllib.parse import urlparse import datasets import fsspec import numpy as np import pyarrow.parquet as pq +from google.cloud import storage from levanter.utils import fsspec_utils @@ -144,6 +147,68 @@ def map_batches( ) +def gcs_glob(pattern: str) -> list[str]: + """Glob files in Google Cloud Storage. + + Args: + pattern: GCS path pattern (gs://bucket/path/*) + + Returns: + List of matching GCS URLs + """ + if not pattern.startswith("gs://"): + # Handle local files + import glob + + return glob.glob(pattern) + + # Parse bucket and prefix from gs:// URL + parsed = urlparse(pattern) + bucket_name = parsed.netloc + prefix = parsed.path.lstrip("/") + + # Convert glob pattern to regex + prefix_no_glob = prefix.split("*")[0] + pattern_as_regex = re.compile(re.escape(prefix).replace("\\*", ".*")) + + # Initialize GCS client + client = storage.Client() + bucket = client.bucket(bucket_name) + + # List matching blobs + matching_urls = [] + for blob in bucket.list_blobs(prefix=prefix_no_glob): + if pattern_as_regex.match(blob.name): + matching_urls.append(f"gs://{bucket_name}/{blob.name}") + + return matching_urls + + +def datasource_from_chat_jsonl( + urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" +) -> "ShardedDataSource[dict]": + """Creates a ShardedDataSource from JSONL files containing chat messages. + + Args: + urls: Sequence of URLs or glob patterns pointing to JSONL files + messages_field: Field name containing the messages in each JSON object + input_role: Role identifier for input messages + output_role: Role identifier for output messages + + Returns: + ShardedDataSource configured for chat data + """ + # Expand any glob patterns in the URLs + expanded_urls = [] + for url in urls: + if any(c in url for c in "*?[]"): + expanded_urls.extend(gcs_glob(url)) + else: + expanded_urls.append(url) + + return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) + + def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]: """ Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset. @@ -463,6 +528,32 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: return iter(data[row:]) +class ChatJsonlDataSource(JsonlDataSource): + """DataSource that reads JSONL files containing OpenAI chat format messages.""" + + def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str): + super().__init__(urls) + self.messages_field = messages_field + self.input_role = input_role + self.output_role = output_role + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: + url = self._shard_name_to_url_mapping[shard_name] + i = 0 + with fsspec.open(url, "r", compression="infer") as f: + for line in f: + if i >= row: + data = json.loads(line) + messages = data[self.messages_field] + + # Extract input/output from messages + input_msg = next(m["content"] for m in messages if m["role"] == self.input_role) + output_msg = next(m["content"] for m in messages if m["role"] == self.output_role) + + yield {"input": input_msg, "output": output_msg} + i += 1 + + class ParquetDataSource(ShardedDataSource[dict]): def __init__(self, urls): self.urls = urls diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 0181889d9..b42bcf5f6 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -743,6 +743,84 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) +@dataclass +class ChatSFTDatasetConfig(LMSupervisedDatasetConfig): + """Config for loading JSONL files in OpenAI chat format for supervised fine-tuning.""" + + # Chat format specific fields + messages_field: str = "messages" + input_role: str = "user" + output_role: str = "assistant" + train_urls: List[str] = field(default_factory=list) # Add this line + + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + import levanter.data + + """Gets ShardedDataSource for either training or validation data.""" + urls = self.validation_urls if split == "validation" else self.train_urls + + if not urls: + return None + + # Use the datasource_from_chat_jsonl function from sharded_datasource + return levanter.data.sharded_datasource.datasource_from_chat_jsonl( + urls, messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role + ) + + +def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: + """ + Preprocess chat examples to match the format of preprocess_supervised_example. + Returns a dict with input_ids and sources_len like the supervised case. + """ + # Get sources (inputs) and targets (outputs) from the batch + sources = [example["input"] for example in batch] + targets = [example["output"] for example in batch] + + # Tokenize sources alone first to get the source lengths + sources_tokenized = tokenizer(sources, padding=False, truncation=True) + + # Combine source and target for full examples + full_examples = [f"{s}{t}" for s, t in zip(sources, targets)] + examples_tokenized = tokenizer(full_examples, padding=False, truncation=True) + + # Get source lengths to mask loss appropriately + source_lens = [len(s) for s in sources_tokenized["input_ids"]] + + return { + "input_ids": [np.array(example, dtype=np.int32) for example in examples_tokenized["input_ids"]], + "sources_len": np.array(source_lens, dtype=np.int32), + } + + +def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]: + """Creates a dataset from JSONL files containing chat format data for SFT.""" + source = config.get_shard_source("train") + if source is None: + raise ValueError("No training data source found") + + # Set up example structure matching supervised case + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + + # Process the dataset + dataset = source.map_batches( + lambda ex: preprocess_chat_example(ex, tokenizer), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) + + # Cache the processed data + dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + + # Ensure padding token is set (needed by _prepare_supervised_example) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Reuse the supervised prepare function directly + return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + + @dataclass class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig): """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls""" From caf0a38bd730ded21a8f8654cdf5153a34d82c21 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 6 Nov 2024 16:27:01 -0800 Subject: [PATCH 53/66] merge main --- examples/sft/alpaca-llama-fix.yaml | 55 ------------------------------ 1 file changed, 55 deletions(-) delete mode 100644 examples/sft/alpaca-llama-fix.yaml diff --git a/examples/sft/alpaca-llama-fix.yaml b/examples/sft/alpaca-llama-fix.yaml deleted file mode 100644 index 1590b7184..000000000 --- a/examples/sft/alpaca-llama-fix.yaml +++ /dev/null @@ -1,55 +0,0 @@ -# Model configuration -model: - activation_function: silu - gradient_checkpointing: true - hidden_dim: 4096 - initializer_range: 0.02 - intermediate_dim: 11008 - layer_norm_epsilon: 1.0e-05 - num_heads: 32 - num_kv_heads: 32 - num_layers: 32 - reference_checkpoint: meta-llama/Llama-2-7b-hf - seq_len: 4096 - type: llama - use_bias: false - use_layer_norm_weight: false - -# Training configuration -trainer: - mp: p=f32,c=bfloat16 - tracker: - type: wandb - project: "levanter-sft" - tags: ["llama", "sft"] - num_train_steps: 1218 - train_batch_size: 64 - tensor_parallel_axes: ["mlp", "heads"] - fsdp_axis: "embed" - batch_axis: "batch" - steps_per_eval: 1000 - -# Optimizer settings -optimizer: - learning_rate: 2e-5 - weight_decay: 0.0 - min_lr_ratio: 0.1 - warmup: 100 - -# Supervised data configuration -supervised_data: - cache_dir: "gs://levanter-checkpoints/marin/sft_cache/alpaca-olmo" - input_field: "instruction" - output_field: "output" - hf_dataset_name: "tatsu-lab/alpaca" # Changed from id - hf_dataset_split: "train" - name: "alpaca" # Optional metadata - tags: ["instruction-tuning"] # Optional metadata - validation_urls: [] # Empty list for no validation files - -# Additional settings -tokenizer: "allenai/OLMo-1B" -max_tune_length: 2048 -epoch: 3 - -initialize_from_hf: false \ No newline at end of file From be80580039dd1e5111136caf15fd313b775279e0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 7 Nov 2024 10:01:08 -0800 Subject: [PATCH 54/66] fix epochs in type signature, fix type checker (#792) --- examples/sft/sft.py | 15 +++++++++++---- src/levanter/data/text.py | 39 +++++++++++++++++++++++++++------------ 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 2ced8591c..629b556c2 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -97,6 +97,8 @@ def train(config: SFTConfig): # Create supervised dataset using generic machinery logger.info("Creating supervised dataset") if config.dataset_type == DatasetType.CHAT_JSONL: + assert config.chat_train_urls is not None + assert config.supervised_data is not None chat_config = ChatSFTDatasetConfig( cache_dir=config.supervised_data.cache_dir, train_urls=config.chat_train_urls, # No validation in this config @@ -106,6 +108,7 @@ def train(config: SFTConfig): ) train_dataset = mk_chat_sft_dataset(chat_config, tokenizer) else: + assert config.supervised_data is not None train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) @@ -122,7 +125,7 @@ def train(config: SFTConfig): # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) # 3. Sets the global metrics tracker - with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: + with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: # type: ignore parameter_axis_mapping = trainer.parameter_axis_mapping # We have two axis_mappings: one for storing the model and optimizer states, and one for compute @@ -141,7 +144,7 @@ def train(config: SFTConfig): logger.info(f"Loading pretrained model from {converter.reference_checkpoint}") model: LmHeadModel = converter.load_pretrained( model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype - ) + ) # type: ignore model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model) state = trainer.initial_state(training_key, model=model) else: @@ -163,10 +166,14 @@ def train(config: SFTConfig): next(loader) if config.hf_save_path is not None: - full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + # bit gross to reach this far into the config, but it's fine + if config.trainer.checkpointer.append_run_id_to_base_path: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + else: + full_save_path = config.hf_save_path trainer.add_hook( - save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload), + save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False), every=config.hf_save_steps, ) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index e0bf93466..0654d1dfa 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -35,6 +35,7 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore +from levanter.utils import fsspec_utils from levanter.utils.fsspec_utils import expand_glob from levanter.utils.hf_utils import num_cpus_used_by_tokenizer @@ -616,7 +617,12 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase: @abc.abstractmethod def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray], + epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: pass @@ -717,7 +723,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split) else: # Using local files - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)] + validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)] if not validation_urls: raise ValueError("Must specify either hf_dataset_name or validation_urls") dataset = levanter.data.datasource_from_jsonl(validation_urls) @@ -735,12 +741,12 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain output_exemplar=output_exemplar, ) - dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) @dataclass @@ -811,14 +817,14 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken ) # Cache the processed data - dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) # Ensure padding token is set (needed by _prepare_supervised_example) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Reuse the supervised prepare function directly - return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) @dataclass @@ -833,18 +839,19 @@ def train_set( monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, - epochs: int = 0, + epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: - ds = self.token_seq_dataset("train", seq_len, monitors) - if epochs: - logger.info("Wrapping dataset in epoch dataset") - ds = EpochDataset(ds, max_epochs=epochs) + ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors) # add epoch flag here. if ds is None: raise ValueError("No training set!") + if epochs: + logger.info("Wrapping dataset in epoch dataset") + ds = EpochDataset(ds, max_epochs=epochs) + if self.shuffle is True: ds = ds.shuffle(key) elif isinstance(self.shuffle, int) and self.shuffle > 0: @@ -989,11 +996,19 @@ def __post_init__(self): ) def train_set( - self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] + self, + seq_len: int, + monitors: Union[bool, List[MetricsMonitor]] = True, + *, + key: Optional[PRNGKeyArray], + epochs: Optional[int] = None, ) -> AsyncDataset[np.ndarray]: doc_caches = self.build_caches("train", monitors=monitors) token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()} + if epochs: + raise ValueError("Epochs are not supported for mixture datasets") + if key is None: key = jax.random.PRNGKey(0) From 74b4108e7f4f45eb9e2c262f21232458554255f2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Thu, 7 Nov 2024 21:42:57 -0800 Subject: [PATCH 55/66] fix internal_eval lengths (#794) previously we were padding to max tokenizer length, which is real bad with llama3 --- examples/sft/sft.py | 5 +++-- src/levanter/data/audio.py | 1 + src/levanter/data/dataset.py | 13 ++++++++++++- src/levanter/data/mixture.py | 1 + src/levanter/data/permutation.py | 2 ++ src/levanter/data/text.py | 24 ++++++++++++++++-------- src/levanter/eval.py | 1 + src/levanter/main/train_lm.py | 2 +- src/levanter/store/cache.py | 1 + tests/test_doremi.py | 1 + tests/test_new_loader.py | 2 ++ tests/test_supervised.py | 3 ++- 12 files changed, 43 insertions(+), 13 deletions(-) diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 629b556c2..152781b0b 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -80,6 +80,7 @@ def train(config: SFTConfig): raise ValueError("Must specify either --initialize_from_hf or --initialize_from") else: converter = None + model_config = config.model levanter.initialize(config) @@ -106,10 +107,10 @@ def train(config: SFTConfig): input_role=config.input_role, output_role=config.output_role, ) - train_dataset = mk_chat_sft_dataset(chat_config, tokenizer) + train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos) else: assert config.supervised_data is not None - train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer) + train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 12695a20b..fbb118cfe 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -270,6 +270,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]): """ def __init__(self, cache: TreeCache[AudioTextDict]): + super().__init__() self.cache = cache async def async_len(self) -> int: diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index def0c158a..4d71241d4 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -48,6 +48,9 @@ class AsyncDataset(DatasetBase[T_co]): * `current_len`: Returns the current length of the dataset. This may be None if no current length is known. """ + def __init__(self): + self._min_known_len = 0 + @abc.abstractmethod async def async_len(self) -> int: raise NotImplementedError @@ -95,7 +98,12 @@ async def wait_until_len_at_least(self, length: int) -> int: The default implementation is a naive busy-wait loop. You should override this method for more efficient implementations. """ - return await naive_busy_wait_until_len_at_least(self, length) + if length <= self._min_known_len: + return self._min_known_len + + res_len = await naive_busy_wait_until_len_at_least(self, length) + self._min_known_len = max(self._min_known_len, res_len) + return res_len def as_sync_dataset(self): return SyncifiedDataset(self) @@ -206,6 +214,7 @@ def __getitem__(self, index: int) -> T_co: class AsyncifiedDataset(AsyncDataset[T_co]): def __init__(self, dataset: SyncDataset[T_co]): + super().__init__() self.dataset = dataset async def async_len(self) -> int: @@ -239,6 +248,7 @@ class ListAsyncDataset(AsyncDataset[T]): """ def __init__(self, data: list[T], is_complete: bool = False): + super().__init__() self.data = data self.is_complete = is_complete if not is_complete: @@ -315,6 +325,7 @@ def __init__( *extra_args, **extra_kwargs, ): + super().__init__() self.dataset = dataset self.fn = fn self._extra_args = extra_args diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index eb1bdfaaf..63c623e4b 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -53,6 +53,7 @@ def __init__( key: PRNGKeyArray | int, stop_strategy: str = StopStrategy.RESTART_STRATEGY, ): + super().__init__() self.weights = MixtureDataset._normalize_weights(weights) self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0} self.dataset_index = Index(self.datasets.keys()) diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py index 6599d4974..66a1887fd 100644 --- a/src/levanter/data/permutation.py +++ b/src/levanter/data/permutation.py @@ -14,6 +14,7 @@ class PermutationDataset(AsyncDataset[T_co]): # TODO: add epoch reshuffling def __init__(self, dataset: AsyncDataset[T_co], key: jax.random.PRNGKey): + super().__init__() self.dataset = dataset self.key = key self._permutation: Optional[Permutation] = None @@ -72,6 +73,7 @@ class EraShufflingDataset(AsyncDataset[T_co]): """ def __init__(self, dataset: AsyncDataset[T_co], era_length: int, *, key: jax.random.PRNGKey): + super().__init__() self.dataset = dataset self.era_length = era_length self.key = key diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 0654d1dfa..7e92d200b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -75,6 +75,7 @@ class EpochDataset(AsyncDataset[T_co]): """ def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None): + super().__init__() self.dataset = dataset self.max_epochs = max_epochs @@ -154,6 +155,7 @@ class TokenSeqDataset(AsyncDataset[np.ndarray]): """ def __init__(self, doc_cache: TreeCache[dict], seq_len: int): + super().__init__() self.doc_cache = doc_cache self.seq_len = seq_len self._store: Optional[TreeStore] = None @@ -687,7 +689,7 @@ def preprocess_supervised_example( } -def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample: +def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample: """ Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. @@ -699,11 +701,15 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> """ with local_cpu_mesh(): # annoyingly, pad expects things to be batched so we have to prepend a batch axis - ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length") + ex = tokenizer.pad( + {k: np.expand_dims(v, 0) for k, v in ex.items()}, + return_tensors="np", + padding="max_length", + max_length=Pos.size, + ) ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], "position") + input_ids = hax.named(ex["input_ids"], Pos) # mask out padding and anything before the start of the target - Pos = input_ids.resolve_axis("position") loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 # don't predict the padding @@ -714,7 +720,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> return lm_ex -def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): +def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis): import levanter.data # Choose data source based on config @@ -746,7 +752,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) @dataclass @@ -799,7 +805,9 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: } -def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]: +def mk_chat_sft_dataset( + config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis +) -> AsyncDataset[LmExample]: """Creates a dataset from JSONL files containing chat format data for SFT.""" source = config.get_shard_source("train") if source is None: @@ -824,7 +832,7 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken tokenizer.pad_token = tokenizer.eos_token # Reuse the supervised prepare function directly - return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) @dataclass diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 99e132dc2..16342be4d 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -60,6 +60,7 @@ def tags(self): def __init__( self, datasets: Sequence[tuple[AsyncDataset[T], Sequence[str]]], max_examples_per_dataset: Optional[int] = None ): + super().__init__() self.datasets = [] tag_index: dict[str, int] = {} for i, (dataset, tags) in enumerate(datasets): diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index b1b5d4aaa..b411bd59e 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -209,7 +209,7 @@ def main(config: TrainLmConfig): if config.supervised_data is not None: logger.info("Using supervised data") - supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")] + supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer, Pos), "")] # TODO Add tags cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 558bbfceb..e7a5306d4 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -191,6 +191,7 @@ def __init__( ledger: Optional["CacheLedger"], _broker, # handle of _TreeStoreCacheBuilder ): + super().__init__() self.cache_dir = cache_dir self.ledger = ledger self._was_already_finished = ledger is not None and ledger.is_finished diff --git a/tests/test_doremi.py b/tests/test_doremi.py index d2cf8b590..bbab04f52 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -38,6 +38,7 @@ def platform_of_array(x): class LogitDataset(AsyncDataset[Example]): def __init__(self, W, noise, x_mask, x_bias, *, key): + super().__init__() self.W = W self.noise = noise self.x_mask = x_mask diff --git a/tests/test_new_loader.py b/tests/test_new_loader.py index e6f9a3dd7..94b5238b2 100644 --- a/tests/test_new_loader.py +++ b/tests/test_new_loader.py @@ -64,6 +64,7 @@ def test_local_batched_data_loading_model_axis_1(): class StructuredDataset(AsyncDataset): def __init__(self, seq_len): + super().__init__() self.seq_len = seq_len self.begin = 0 self.end = 256 @@ -138,6 +139,7 @@ def test_structured_batches_model_axis_2(): class StructuredDatasetWithNames(AsyncDataset): def __init__(self, Height: Axis, Width: Axis, begin, end, stride): + super().__init__() self.Height = Height self.Width = Width self.begin = begin diff --git a/tests/test_supervised.py b/tests/test_supervised.py index b8bec4f45..54c99a102 100644 --- a/tests/test_supervised.py +++ b/tests/test_supervised.py @@ -2,6 +2,7 @@ from transformers import AutoTokenizer import haliax +from haliax import Axis from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example @@ -76,7 +77,7 @@ def test_supervised_eval(): "sources_len": np.array(45, dtype=np.int32), } - lm_ex = _prepare_supervised_example(ex, tokenizer) + lm_ex = _prepare_supervised_example(ex, tokenizer, Axis("position", 128)) assert lm_ex.loss_mask["position", 44] assert haliax.sum(lm_ex.loss_mask) == 1 From e5deb47895a901eb47f7bf7422088d2ad4c1b871 Mon Sep 17 00:00:00 2001 From: Jennifer Zhou Date: Fri, 8 Nov 2024 18:02:11 +0000 Subject: [PATCH 56/66] Fix transformer-engine attention import (#795) Renamed upstream --- src/levanter/models/attention.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 633feee68..0ae9a79f7 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -310,8 +310,8 @@ def _te_flash_attention( precision: PrecisionLike = None, block_size: Optional[int] = None, ): - from transformer_engine.jax.fused_attn import fused_attn # noqa: F401 - from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType # noqa: F401 + from transformer_engine.jax.attention import fused_attn # noqa: F401 + from transformer_engine.jax.attention import AttnBiasType, AttnMaskType, QKVLayout # noqa: F401 attention_dtype = attention_dtype or query.dtype query = query.astype(attention_dtype) @@ -358,14 +358,13 @@ def _te_flash_attention( raise NotImplementedError("Using bias with flash attention on GPU is not currently implemented.") attn_output = fused_attn( - q=q_, - k=k_, - v=v_, + qkv=(q_, k_, v_), bias=fused_attn_bias, mask=fused_attn_mask, seed=prng, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=QKVLayout.BSHD_BSHD_BSHD, scaling_factor=scaling_factor, dropout_probability=dropout, is_training=is_training, @@ -402,7 +401,7 @@ def _te_flash_attention( def _te_materialize_mask(KPos, QPos, batch_size, mask): - from transformer_engine.jax.fused_attn import AttnMaskType + from transformer_engine.jax.attention import AttnMaskType if isinstance(mask, NamedArray): raise NotImplementedError( From 0503001ff0a37f9752109ce4ab3df7cc7ddc5e23 Mon Sep 17 00:00:00 2001 From: David Hall Date: Fri, 8 Nov 2024 12:02:55 -0800 Subject: [PATCH 57/66] pretty sure we just don't need scipy (#797) --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0831605cb..b10358d07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ dependencies = [ "transformers>=4.41.2", "optax>=0.1.9", "wandb>=0.17.8", - "scipy<=1.12.0", "draccus>=0.8.0", "pyarrow>=11.0.0", "zstandard>=0.20.0", From b0d53a02d06e82bc7cab65e40461e6ea66ee27bd Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 12 Nov 2024 11:58:39 -0800 Subject: [PATCH 58/66] tracker.finish to deal with subprocess stuff (#801) we use subprocess in marin to invoke levanter, but subprocesses don't wait on other subprocesses somehow, and so wandb doesn't get a chance to finish. This solves this --- src/levanter/main/train_lm.py | 3 +++ src/levanter/tracker/tensorboard.py | 3 +++ src/levanter/tracker/tracker.py | 22 ++++++++++++++++++++++ src/levanter/tracker/wandb.py | 4 ++++ 4 files changed, 32 insertions(+) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index b411bd59e..f2ad3e7ce 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -268,6 +268,9 @@ def compute_log_probs(model, example): checkpointer = trainer.config.checkpointer.create(trainer.run_id) checkpointer.wait_until_finished() + # This isn't necessary except when Levanter is run in a subprocess (as happens w/ ray) + trainer.tracker.finish() + if __name__ == "__main__": levanter.config.main(main)() diff --git a/src/levanter/tracker/tensorboard.py b/src/levanter/tracker/tensorboard.py index 360c32171..e819d6459 100644 --- a/src/levanter/tracker/tensorboard.py +++ b/src/levanter/tracker/tensorboard.py @@ -43,6 +43,9 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio pylogger.exception(f"Error logging artifact {artifact_path} to {log_path}") return + def finish(self): + self.writer.close() + @TrackerConfig.register_subclass("tensorboard") @dataclass diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 8b6816f17..99fd217e5 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -46,6 +46,14 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass + @abc.abstractmethod + def finish(self): + """ + Finish the tracker. This is called when the tracker is no longer needed. This can, e.g., + force a commit of all metrics. + """ + pass + def __enter__(self): import levanter.tracker.tracker_fns as tracker_fns @@ -81,6 +89,17 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio for tracker in self.loggers: tracker.log_artifact(artifact_path, name=name, type=type) + def finish(self): + excs = [] + for tracker in self.loggers: + try: + tracker.finish() + except Exception as e: + excs.append(e) + + if excs: + raise RuntimeError("Errors occurred when finishing trackers") from excs[0] + class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" @@ -109,6 +128,9 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass + def finish(self): + pass + @TrackerConfig.register_subclass("noop") @dataclasses.dataclass diff --git a/src/levanter/tracker/wandb.py b/src/levanter/tracker/wandb.py index 18f0251ec..981bebf83 100644 --- a/src/levanter/tracker/wandb.py +++ b/src/levanter/tracker/wandb.py @@ -72,6 +72,10 @@ def log_summary(self, metrics: dict[str, Any]): def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): self.run.log_artifact(artifact_path, name=name, type=type) + def finish(self): + logger.info("Finishing wandb run...") + self.run.finish() + def is_wandb_available(): try: From cfa4fd0b18bd72b6dea5fbe4a7951f3d9c5939c1 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 12 Nov 2024 12:39:40 -0800 Subject: [PATCH 59/66] bulk delete using STS (#799) --- scripts/clean_old_checkpoints.py | 16 +++- scripts/gcs_bulk_delete.py | 151 +++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 5 deletions(-) create mode 100644 scripts/gcs_bulk_delete.py diff --git a/scripts/clean_old_checkpoints.py b/scripts/clean_old_checkpoints.py index f9074d036..fbd037b7b 100644 --- a/scripts/clean_old_checkpoints.py +++ b/scripts/clean_old_checkpoints.py @@ -19,7 +19,7 @@ def is_dir_of_checkpoints(path): return any("step-" in child for child in children) -def list_deletable_directories(base_dir): +def list_deletable_directories(base_dir, age): fs = fsspec.filesystem("gcs") run_ids = fs.ls(base_dir) @@ -58,8 +58,8 @@ def list_deletable_directories(base_dir): details = fs.ls(f"{path}/{file}", detail=True) if details: mtime = details[0]["mtime"] - age = (datetime.now(timezone.utc) - mtime).days - if age < AGE: + this_age = (datetime.now(timezone.utc) - mtime).days + if this_age < age: new = True break @@ -74,9 +74,15 @@ def list_deletable_directories(base_dir): # Usage example: if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="List directories that can be deleted.") + parser.add_argument("base_dir", help="The base directory to clean up.", type=str, nargs="+") + parser.add_argument("--age", help="The age in days of the checkpoints to delete.", type=int, default=30) + args = parser.parse_args() if len(sys.argv) < 2: print("Usage: python clean_old_checkpoints.py ") sys.exit(1) - for base_dir in sys.argv[1:]: - for path in list_deletable_directories(base_dir): + for base_dir in args.base_dir: + for path in list_deletable_directories(base_dir, args.age): print(f"gs://{path}") diff --git a/scripts/gcs_bulk_delete.py b/scripts/gcs_bulk_delete.py new file mode 100644 index 000000000..564e3cd60 --- /dev/null +++ b/scripts/gcs_bulk_delete.py @@ -0,0 +1,151 @@ +import re +import sys +import time +from datetime import datetime + +import google.auth +from google.api_core import operations_v1 +from google.cloud import storage_transfer_v1 +from google.type.date_pb2 import Date +from google.type.timeofday_pb2 import TimeOfDay + + +EMPTY_BUCKET = "levanter-empty" + + +def schedule_gcs_deletion_job(project_id, gcs_bucket_name, path_to_delete): + """ + Schedules an STS job to delete all files in a GCS path and waits for completion. + + This function uses a "trick" in STS to delete all files in a GCS path by transferring files from an empty bucket to + the target path with the `delete_objects_unique_in_sink` option enabled. This will delete all objects in the target + path that do not exist in the source path (which is empty). + + """ + + client = storage_transfer_v1.StorageTransferServiceClient() + + # Define the transfer job + transfer_job = storage_transfer_v1.types.TransferJob( + project_id=project_id, + transfer_spec=storage_transfer_v1.types.TransferSpec( + gcs_data_sink=storage_transfer_v1.types.GcsData(bucket_name=gcs_bucket_name, path=path_to_delete), + gcs_data_source=storage_transfer_v1.types.GcsData(bucket_name=EMPTY_BUCKET), + transfer_options=storage_transfer_v1.types.TransferOptions(delete_objects_unique_in_sink=True), + ), + schedule=storage_transfer_v1.types.Schedule( + schedule_start_date=Date( + year=datetime.utcnow().year, month=datetime.utcnow().month, day=datetime.utcnow().day + ), + start_time_of_day=TimeOfDay( + hours=datetime.utcnow().hour, minutes=datetime.utcnow().minute + 2 # Start in 2 minutes + ), + ), + status=storage_transfer_v1.types.TransferJob.Status.ENABLED, + description=f"Delete all files in {gcs_bucket_name}/{path_to_delete}", + ) + + # Create the transfer job + response = client.create_transfer_job(request={"transfer_job": transfer_job}) + print(f"Created transfer job: {response.name}") + + # Wait for job completion + wait_for_transfer_job(response.name, timeout=3600, poll_interval=2, project_id=project_id) + + +def wait_for_transfer_job(job_name: str, timeout: int, poll_interval: int, project_id: str): + """ + Waits for a Transfer Job to complete by polling the job status every 10 seconds. Raises a `TimeoutError` if the + job does not complete within the specified `timeout` (default: 30 minutes). + + Parameters: + job_name (str): The name of the Transfer Job to wait for. + timeout (int): The maximum number of seconds to wait for the job to complete. + poll_interval (int): The number of seconds to wait between polling the job status. + + Raises: + TimeoutError: If the Transfer Job does not complete within the specified `timeout`. + """ + print(f"[*] Waiting for Transfer Job :: {job_name}") + + transfer_client = storage_transfer_v1.StorageTransferServiceClient() + channel = transfer_client.transport.grpc_channel + operations_client = operations_v1.OperationsClient(channel) + start_time = time.time() + + from tqdm import tqdm + + pbar = tqdm(desc=f"Waiting for Transfer Job :: {job_name}", unit="B", unit_scale=True) + while time.time() - start_time < timeout: + if (time.time() - start_time) % poll_interval == 0: + # Prepare the filter string to get the operations for the job + filter_string = f'{{"project_id": "{project_id}", "job_names": ["{job_name}"]}}' + + # List transfer operations for the job + transfer_operations = operations_client.list_operations("transferOperations", filter_string) + + # Get the latest operation + latest_operation = None + for operation in transfer_operations: + if operation.metadata is not None: + latest_operation = operation + + if latest_operation: + # Extract relevant counters + # Unpack the Any type to get TransferOperation + metadata = storage_transfer_v1.types.TransferOperation() + # Access the descriptor from the _pb2 module + if latest_operation.metadata.Is(metadata._pb.DESCRIPTOR): + latest_operation.metadata.Unpack(metadata._pb) + + objects_deleted = metadata.counters.objects_deleted_from_sink + objects_found = metadata.counters.objects_found_only_from_sink + bytes_found_only_from_sink = metadata.counters.bytes_found_only_from_sink + bytes_deleted_from_sink = metadata.counters.bytes_deleted_from_sink + + # Update the progress bar + pbar.total = bytes_found_only_from_sink + pbar.n = bytes_deleted_from_sink + pbar.set_postfix( + objects_deleted=objects_deleted, + objects_found=objects_found, + ) + pbar.update(0) + + if latest_operation.done: + print(f"[*] Transfer Job Completed :: {job_name}") + pbar.close() + return + + raise TimeoutError(f"Transfer Job did not complete within {timeout} seconds; check status for {job_name}") + + +def parse_gcs_url(gcs_url): + """Parse the GCS URL and return the bucket name and prefix path.""" + match = re.match(r"gs://([^/]+)/(.+)", gcs_url) + if match: + bucket_name, path_prefix = match.groups() + return bucket_name, path_prefix + else: + raise ValueError(f"Invalid GCS URL format: {gcs_url}") + + +if __name__ == "__main__": + # Check for correct usage + if len(sys.argv) != 2: + print("Usage: python gcs_bulk_delete.py gs://bucket_name/path/to/delete") + sys.exit(1) + + # Parse the GCS URL + gcs_url = sys.argv[1] + try: + bucket_name, path_prefix = parse_gcs_url(gcs_url) + except ValueError as e: + print(e) + sys.exit(1) + + # Get the project ID + credentials, project_id = google.auth.default() + + # Schedule the deletion job + schedule_gcs_deletion_job(project_id, bucket_name, path_prefix) From 2195263016ef63eaacdaec5aeb9782f6c0206c54 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 12 Nov 2024 13:54:38 -0800 Subject: [PATCH 60/66] tweaks: truncate after pad for supervised (#800) --- src/levanter/data/text.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 7e92d200b..4cc000e59 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -708,7 +708,9 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po max_length=Pos.size, ) ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], Pos) + # padding doesn't do truncation, so we have to do it ourselves. + # Truncate from the left since we want to predict the last tokens + input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos) # mask out padding and anything before the start of the target loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 From 433490e3858b87b5c3332d2bff7feaf48bc6124f Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 12 Nov 2024 13:54:50 -0800 Subject: [PATCH 61/66] Misc fixes from sweep (disable blocked CE by default) (#798) blocked CE is worse except it enables larger batch sizes --- src/levanter/doremi.py | 10 ++++++---- src/levanter/main/doremi_lm.py | 4 ++-- src/levanter/models/lm_model.py | 4 +++- src/levanter/models/loss.py | 22 +++++++++++++--------- src/levanter/models/mpt.py | 2 +- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/levanter/doremi.py b/src/levanter/doremi.py index 9d048b24f..6d9165cfc 100644 --- a/src/levanter/doremi.py +++ b/src/levanter/doremi.py @@ -1,6 +1,6 @@ import dataclasses import logging -from typing import Optional, Tuple, TypeVar +from typing import Mapping, Optional, Tuple, TypeVar import equinox as eqx import jax.numpy as jnp @@ -56,7 +56,7 @@ def estimate_mixture_weights( loss_fn: ComputeLossFunction[M, T], initial_proxy: M, ref: M, - data_sources: dict[str, AsyncDataset[T]], + data_sources: Mapping[str, AsyncDataset[T]], sampling_weights: Optional[dict[str, float]] = None, *, validation_sets: Optional[dict[str, AsyncDataset[T]]] = None, @@ -184,7 +184,9 @@ def doremi_step(state: DoremiState, ref, batch, domains): # we're not actually going to use the trainer for very much but it holds hooks and sets up contexts with trainer: - tagged_mixture = domain_tagged_mixture(data_sources, sampling_weights, domain_to_index, key=data_key) + tagged_mixture: MixtureDataset = domain_tagged_mixture( + data_sources, sampling_weights, domain_to_index, key=data_key + ) state = load_checkpoint_or_initialize( DoremiState.init, trainer.checkpoint_path, @@ -263,7 +265,7 @@ def _prepare_ref_model(ref, trainer): def domain_tagged_mixture( - data_sources: dict[str, AsyncDataset[T]], + data_sources: Mapping[str, AsyncDataset[T]], weights: dict[str, float], domain_to_index: dict[str, int], *, diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py index 12b3e6ae0..742c3229c 100644 --- a/src/levanter/main/doremi_lm.py +++ b/src/levanter/main/doremi_lm.py @@ -109,7 +109,7 @@ def init_proxy_model(): train_datasets = config.data.training_sets(ref_model.Pos.size) valid_datasets = config.data.validation_sets(ref_model.Pos.size) - train_datasets = { + causal_train_datasets = { k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) for k, v in train_datasets.items() } @@ -122,7 +122,7 @@ def init_proxy_model(): loss_function, proxy_model, ref=ref_model, - data_sources=train_datasets, + data_sources=causal_train_datasets, trainer_config=config.trainer, optimizer=optimizer, domain_weight_step_size=config.doremi.domain_weight_step_size, diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 911e74b09..1a82aa7be 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -1,4 +1,5 @@ import abc +from dataclasses import dataclass from typing import Generic, Optional, Type, TypeVar import draccus @@ -48,6 +49,7 @@ def causal( # TODO: for some reason, mypy doesn't like the discover_packages_path argument? +@dataclass(frozen=True) class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property @abc.abstractmethod @@ -69,7 +71,7 @@ def Pos(self) -> Axis: def Embed(self) -> Axis: pass - cross_entropy_block_size: Optional[int] = 64000 + cross_entropy_block_size: Optional[int] = None """ The block size for computing cross-entropy loss. This is the number of tokens that are processed together in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large diff --git a/src/levanter/models/loss.py b/src/levanter/models/loss.py index 154fc66ac..d705eda4d 100644 --- a/src/levanter/models/loss.py +++ b/src/levanter/models/loss.py @@ -58,7 +58,9 @@ def next_token_loss( if block_size is None: # Full softmax computation - logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed, preferred_element_type=dtype) + logits = hax.dot(pred_embeddings, pred_lm_head, axis=Embed) + if dtype is not None: + logits = logits.astype(dtype) target_y_full = hax.nn.one_hot(target_y, Vocab, dtype=pred_embeddings.dtype) return cross_entropy_and_logsumexp_penalty( logits, @@ -261,9 +263,10 @@ def process_block(block_idx, acc, current_block_size): # Materialize the logits for the current block lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block] - logits_b = hax.dot( - pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype - ) # [Batch, Seq, Block] + logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract) # [Batch, Seq, Block] + + if dtype is not None: + logits_b = logits_b.astype(dtype) # Update max and logsumexp max_logit = hax.maximum(max_logit_prev, hax.max(logits_b, axis=Block)) # [Batch, Seq] @@ -278,7 +281,7 @@ def process_block(block_idx, acc, current_block_size): # Update sumV. This is actually unnecessary if we're using one-hot targets # sV = sV_prev + hax.sum(target_y_b, axis=Label.name) - loss += hax.dot(logits_b, target_y_b, axis=Block, preferred_element_type=dtype) # [Batch, Seq] + loss += hax.dot(logits_b, target_y_b, axis=Block) # [Batch, Seq] return loss, logsumexp, max_logit # , sV @@ -351,7 +354,7 @@ def _block_cross_entropy_backward( num_blocks = vocab_size // block_size grad_embeddings = hax.zeros(pred_embeddings.axes, dtype=pred_embeddings.dtype) - grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_embeddings.dtype) + grad_lm_head = hax.zeros(pred_lm_head.axes, dtype=pred_lm_head.dtype) def process_block(block_idx, acc, current_block_size): """ @@ -372,14 +375,15 @@ def process_block(block_idx, acc, current_block_size): # Materialize the logits for the current block lm_head_b = pred_lm_head[Label, hax.dslice(start, Block)] # [Contract, Block] - logits_b = hax.dot( - pred_embeddings, lm_head_b, axis=Contract, preferred_element_type=dtype - ) # [Batch, Seq, Block] + logits_b = hax.dot(pred_embeddings, lm_head_b, axis=Contract) # [Batch, Seq, Block] # Materialize the target for the current block (one-hot) target_y_block = _block_one_hot(Block, start, labels_y, logits_b.dtype) # [Batch, Seq, Block] # materialize the softmax for the current block + if dtype is not None: + logits_b = logits_b.astype(dtype) + p_b = hax.exp(logits_b - log_z) # [Batch, Seq, Block] delta_b = p_b - target_y_block diff --git a/src/levanter/models/mpt.py b/src/levanter/models/mpt.py index 0809d9d23..97b61f1dc 100644 --- a/src/levanter/models/mpt.py +++ b/src/levanter/models/mpt.py @@ -107,7 +107,7 @@ def from_hf(config: HfMptAttentionConfig): @LmConfig.register_subclass("mpt") -@dataclass +@dataclass(frozen=True) class MptConfig(HFCompatConfig): d_model: int = 768 n_heads: int = 12 From ba08164946df86bd6c80391630b8791b38fda194 Mon Sep 17 00:00:00 2001 From: Jennifer Zhou Date: Tue, 12 Nov 2024 19:54:17 -0800 Subject: [PATCH 62/66] Nit: typing (#802) Fixes a silly type checker complaint. (plus black wanted to reformat some things?) --- src/levanter/trainer.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index c7c1a5285..1d057484a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -65,7 +65,7 @@ X = TypeVar("X") # Input S = TypeVar("S", bound=TrainerState) -DEFAULT_JAX_CONFIG = { +DEFAULT_JAX_CONFIG: Dict[str, JsonAtom] = { "jax_threefry_partitionable": True, "jax_softmax_custom_jvp": True, } @@ -331,7 +331,12 @@ def init_state_and_model(model_init, training_key): model = model_init() # only force trainable params to param precision. Other params are cast to compute precision state = TrainerState.init( - self.optimizer, model, key=training_key, is_trainable=is_trainable, mp=self.mp, fp8=self.fp8 + self.optimizer, + model, + key=training_key, + is_trainable=is_trainable, + mp=self.mp, + fp8=self.fp8, ) return state @@ -444,7 +449,10 @@ def eval_loss(model, *batch, **batch_kwargs): self.add_hook( callbacks.compute_validation_loss( - eval_loss, eval_loader, max_batches=self.config.max_eval_batches, name=name + eval_loss, + eval_loader, + max_batches=self.config.max_eval_batches, + name=name, ), every=self.config.steps_per_eval, ) @@ -497,7 +505,13 @@ def obj_fun(trainable_model): def _compute_gradients_microbatched(self, loss_fn, model: M, *batch, **batch_kwargs) -> tuple[Scalar, M]: grad_fn = eqx.filter_value_and_grad(loss_fn, has_aux=False) mbs = self.config.microbatch_size - grad_fn = microbatched(grad_fn, self.TrainBatch, mbs, self.parameter_axis_mapping, self.compute_axis_mapping) + grad_fn = microbatched( + grad_fn, + self.TrainBatch, + mbs, + self.parameter_axis_mapping, + self.compute_axis_mapping, + ) with hax.axis_mapping(self.compute_axis_mapping): return grad_fn(model, *batch, **batch_kwargs) @@ -569,7 +583,7 @@ class TrainerConfig: """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path.""" initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from - jax_config: Dict[str, JsonAtom] = field( + jax_config: Mapping[str, JsonAtom] = field( default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG) ) # config to pass to jax.config.update @@ -597,7 +611,10 @@ def microbatch_size(self): def __post_init__(self): if self.wandb is not None: - warnings.warn("wandb is deprecated. use tracker with type wandb instead", DeprecationWarning) + warnings.warn( + "wandb is deprecated. use tracker with type wandb instead", + DeprecationWarning, + ) self.tracker = self.wandb def initialize(self): From 63f2f3aa02dcf211a9067e69f7ca6b1a675ac3ee Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 13 Nov 2024 15:33:14 -0800 Subject: [PATCH 63/66] Support multiple supervisde evals, some cleanup around that (#803) --- ...ervised.yaml => gpt2_small_fast_eval.yaml} | 17 +- examples/sft/sft.py | 6 +- pyproject.toml | 2 +- src/levanter/compat/hf_checkpoints.py | 7 +- src/levanter/data/_preprocessor.py | 2 +- src/levanter/data/sharded_datasource.py | 112 ++++---- src/levanter/data/text.py | 254 +++++++++++++++--- src/levanter/eval.py | 3 +- src/levanter/main/train_lm.py | 20 +- src/levanter/utils/hf_utils.py | 5 + tests/test_supervised.py | 4 +- 11 files changed, 309 insertions(+), 123 deletions(-) rename config/{gpt2_small_fast_supervised.yaml => gpt2_small_fast_eval.yaml} (65%) diff --git a/config/gpt2_small_fast_supervised.yaml b/config/gpt2_small_fast_eval.yaml similarity index 65% rename from config/gpt2_small_fast_supervised.yaml rename to config/gpt2_small_fast_eval.yaml index 93675366d..14638db1b 100644 --- a/config/gpt2_small_fast_supervised.yaml +++ b/config/gpt2_small_fast_eval.yaml @@ -13,12 +13,17 @@ data: tokenizer: gpt2 cache_dir: "gs://levanter-data/tokenized/data_mix" supervised_data: - validation_urls: - - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-dev-evaluation.jsonl.gz" - - "gs://marin-us-central2/benchmarks/mmlu/mmlu-*-validation-evaluation.jsonl.gz" - cache_dir: "gs://marin-us-central2/benchmarks/tokenized-gpt2/mmlu/" - input_field: "input" - output_field: "output" + mmlu: + validation_urls: + - "gs://marin-us-central2/evaluation/mmlu-eval-subject-2eb39e/cais/*-validation-evaluation.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized-gpt2/mmlu/" + tags: [ "e"] + arc_easy: + validation_urls: + - "gs://marin-us-central2/evaluation/arc-easy-b39e70/allenai/ai2_arc-ARC-Easy-validation-evaluation.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized-gpt2/arc_easy/" + tags: [ "arc", "e"] + model: type: gpt2 hidden_dim: 768 diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 152781b0b..173c79212 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -15,7 +15,7 @@ from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset -from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset +from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset from levanter.main.train_lm import TrainLmConfig from levanter.models.lm_model import LmHeadModel, compute_next_token_loss from levanter.trainer import Trainer @@ -100,7 +100,7 @@ def train(config: SFTConfig): if config.dataset_type == DatasetType.CHAT_JSONL: assert config.chat_train_urls is not None assert config.supervised_data is not None - chat_config = ChatSFTDatasetConfig( + chat_config = ChatUrlDataSourceConfig( cache_dir=config.supervised_data.cache_dir, train_urls=config.chat_train_urls, # No validation in this config messages_field=config.messages_field, @@ -110,7 +110,7 @@ def train(config: SFTConfig): train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos) else: assert config.supervised_data is not None - train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos) + train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) diff --git a/pyproject.toml b/pyproject.toml index b10358d07..abca1405d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "transformers>=4.41.2", "optax>=0.1.9", "wandb>=0.17.8", - "draccus>=0.8.0", + "draccus>=0.9.3", "pyarrow>=11.0.0", "zstandard>=0.20.0", "datasets>=3.1.0,<4.0", diff --git a/src/levanter/compat/hf_checkpoints.py b/src/levanter/compat/hf_checkpoints.py index ce267041c..7a116acae 100644 --- a/src/levanter/compat/hf_checkpoints.py +++ b/src/levanter/compat/hf_checkpoints.py @@ -39,6 +39,7 @@ from levanter.trainer import StepInfo from levanter.utils import jax_utils from levanter.utils.cloud_utils import temp_dir_before_upload +from levanter.utils.hf_utils import HfTokenizer from levanter.utils.jax_utils import best_effort_sharding, local_cpu_mesh, use_cpu_device from levanter.utils.py_utils import dataclass_with_default_init, logical_cpu_memory_size @@ -872,7 +873,7 @@ def cb(step: StepInfo): def arbitrary_load_from_hf( model_name_or_path, from_pretrained_lambda, revision=None, local_cache_dir=None, trust_remote_code=True -) -> Union[PreTrainedTokenizerBase | ProcessorMixin]: +) -> Union[HfTokenizer | ProcessorMixin]: is_url_like = urlparse(model_name_or_path).scheme != "" if is_url_like: if revision is not None: @@ -889,9 +890,7 @@ def arbitrary_load_from_hf( return from_pretrained_lambda(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code) -def load_tokenizer( - model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True -) -> PreTrainedTokenizerBase: +def load_tokenizer(model_name_or_path, revision=None, local_cache_dir=None, trust_remote_code=True) -> HfTokenizer: """Like AutoTokenizer.from_pretrained, but works with gs:// paths or anything on fsspec""" return arbitrary_load_from_hf( model_name_or_path, diff --git a/src/levanter/data/_preprocessor.py b/src/levanter/data/_preprocessor.py index 77b91617f..dd6578667 100644 --- a/src/levanter/data/_preprocessor.py +++ b/src/levanter/data/_preprocessor.py @@ -197,7 +197,7 @@ def __call__(self, batch): match transform: case _MapTransform(fn=fn): - batch = map(fn, batch) + batch = [fn(x) for x in batch] case _BatchMapTransform(fn=fn): batch = fn(batch) is_soa_form = isinstance(batch, dict) or isinstance(batch, pa.RecordBatch) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index cca3156b8..9dca9b618 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -184,31 +184,6 @@ def gcs_glob(pattern: str) -> list[str]: return matching_urls -def datasource_from_chat_jsonl( - urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" -) -> "ShardedDataSource[dict]": - """Creates a ShardedDataSource from JSONL files containing chat messages. - - Args: - urls: Sequence of URLs or glob patterns pointing to JSONL files - messages_field: Field name containing the messages in each JSON object - input_role: Role identifier for input messages - output_role: Role identifier for output messages - - Returns: - ShardedDataSource configured for chat data - """ - # Expand any glob patterns in the URLs - expanded_urls = [] - for url in urls: - if any(c in url for c in "*?[]"): - expanded_urls.extend(gcs_glob(url)) - else: - expanded_urls.append(url) - - return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) - - def datasource_from_hf(id: str, *, split, **kwargs) -> ShardedDataSource[dict]: """ Create a ShardedDataset from a HuggingFace dataset. Arguments are passed to load_dataset. @@ -288,14 +263,49 @@ class TextUrlDataSource(ShardedDataSource[str]): def __init__(self, urls, text_key="text"): self.urls = urls - self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) self.text_key = text_key + self.base_ds = UrlDataSource(urls, columns=[text_key]) @property def shard_names(self) -> Sequence[str]: - return list(self._shard_name_to_url_mapping.keys()) + return self.base_ds.shard_names def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: + url = self.base_ds._shard_name_to_url_mapping[shard_name] + i = 0 + compression = "infer" + if url.endswith(".zstd"): # hacky way to detect zstd + compression = "zstd" + + format = _sniff_format_for_dataset(url) + + # special case for txt files + if format == ".txt": + with fsspec.open(url, "r", compression=compression) as f: + for line in f: + if i >= row: + yield line + i += 1 + else: + for doc in self.base_ds.open_shard_at_row(shard_name, row): + yield doc[self.text_key] + + +class UrlDataSource(ShardedDataSource[dict]): + """ + Dataset for various dict-like formats. + """ + + def __init__(self, urls, columns=None): + self.urls = urls + self._shard_name_to_url_mapping = _mk_shard_name_mapping(urls) + self.columns = columns + + @property + def shard_names(self) -> Sequence[str]: + return list(self._shard_name_to_url_mapping.keys()) + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: url = self._shard_name_to_url_mapping[shard_name] i = 0 compression = "infer" @@ -310,19 +320,18 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: # which is not nothing, but not ideal. for line in f: if i >= row: - yield json.loads(line)[self.text_key] - i += 1 - case ".txt": - with fsspec.open(url, "r", compression=compression) as f: - for line in f: - if i >= row: - yield line + obj = json.loads(line) + if self.columns: + yield {col: obj[col] for col in self.columns} i += 1 case ".json": with fsspec.open(url, "r", compression=compression) as f: data = json.load(f) for doc in data[row:]: - yield doc[self.text_key] + if self.columns: + yield {col: doc[col] for col in self.columns} + else: + yield doc case ".parquet": with fsspec.open(url, "rb", compression=compression) as f: parquet_file = pq.ParquetFile(f) @@ -347,11 +356,11 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[str]: # Read from the starting row group onwards for rg_idx in range(row_group_index, parquet_file.num_row_groups): - table = parquet_file.read_row_group(rg_idx, columns=[self.text_key]) + table = parquet_file.read_row_group(rg_idx, columns=self.columns) if rg_idx == row_group_index: table = table.slice(start_row_in_group) for record in table.to_pylist(): - yield record[self.text_key] + yield record case _: raise ValueError(f"Unknown format {format}") @@ -531,32 +540,6 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: return iter(data[row:]) -class ChatJsonlDataSource(JsonlDataSource): - """DataSource that reads JSONL files containing OpenAI chat format messages.""" - - def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str): - super().__init__(urls) - self.messages_field = messages_field - self.input_role = input_role - self.output_role = output_role - - def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: - url = self._shard_name_to_url_mapping[shard_name] - i = 0 - with fsspec.open(url, "r", compression="infer") as f: - for line in f: - if i >= row: - data = json.loads(line) - messages = data[self.messages_field] - - # Extract input/output from messages - input_msg = next(m["content"] for m in messages if m["role"] == self.input_role) - output_msg = next(m["content"] for m in messages if m["role"] == self.output_role) - - yield {"input": input_msg, "output": output_msg} - i += 1 - - class ParquetDataSource(ShardedDataSource[dict]): def __init__(self, urls): self.urls = urls @@ -650,7 +633,8 @@ def shard_names(self) -> Sequence[str]: return self.source.shard_names def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[T]: - return map(self.fn, self.source.open_shard_at_row(shard_name, row)) + for doc in self.source.open_shard_at_row(shard_name, row): + yield self.fn(doc) class _BatchMappedShardedDataSource(ShardedDataSource[T], _TransformedDataset): diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 4cc000e59..f7764a8b2 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -3,21 +3,23 @@ import copy import dataclasses import functools +import json import logging import os from dataclasses import dataclass from functools import cached_property from itertools import chain -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Protocol, Sequence, Tuple, TypeAlias, TypeVar, Union import datasets import equinox as eqx +import fsspec import jax import numpy as np import regex import tensorstore as ts from draccus import field -from jax._src.random import PRNGKey +from jax.random import PRNGKey from jaxtyping import PRNGKeyArray from tokenizers import normalizers @@ -35,9 +37,8 @@ from levanter.store.cache import CacheOptions, TreeCache from levanter.store.jagged_array import JaggedArrayStore from levanter.store.tree_store import TreeStore -from levanter.utils import fsspec_utils from levanter.utils.fsspec_utils import expand_glob -from levanter.utils.hf_utils import num_cpus_used_by_tokenizer +from levanter.utils.hf_utils import HfTokenizer, num_cpus_used_by_tokenizer silence_transformer_nag() # noqa @@ -46,7 +47,14 @@ from levanter.compat.hf_checkpoints import load_tokenizer # noqa from levanter.data._preprocessor import BatchProcessor, U, dict_from_record_batch # noqa from levanter.data.metrics_monitor import LoggerMetricsMonitor, LoggingMetricsMonitor, MetricsMonitor # noqa -from levanter.data.sharded_datasource import ShardedDataSource, TextUrlDataSource, WrappedHFDataSource # noqa +from levanter.data.sharded_datasource import ( # noqa + JsonlDataSource, + ShardedDataSource, + TextUrlDataSource, + UrlDataSource, + WrappedHFDataSource, + gcs_glob, +) from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa from levanter.store.cache import build_or_load_cache # noqa from levanter.utils.jax_utils import key_iterator, local_cpu_mesh, use_cpu_device # noqa @@ -328,9 +336,18 @@ def __call__(self, batch: Sequence[str]) -> list[dict]: needs_merge = [] if self.padding is not False: - encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False, padding=self.padding, max_length=self.max_length, truncation=True) # type: ignore + encoding = self.tokenizer( + batch, + return_attention_mask=self.return_attention_mask, + verbose=False, + padding=self.padding, + max_length=self.max_length, + truncation=True, + ) # type: ignore else: - encoding = self.tokenizer(batch, return_attention_mask=self.return_attention_mask, verbose=False) # type: ignore + encoding = self.tokenizer( + batch, return_attention_mask=self.return_attention_mask, verbose=False + ) # type: ignore if needs_merge: new_encoding = self._merge_split_encodings(batch, encoding, needs_merge) @@ -611,7 +628,7 @@ class LMTaskConfig(abc.ABC): If you want to shuffle in eras, set this to the era length""" @cached_property - def the_tokenizer(self) -> PreTrainedTokenizerBase: + def the_tokenizer(self) -> HfTokenizer: if self.tokenizer == "passthrough": return PassthroughTokenizer(self.vocab_size) else: @@ -648,6 +665,10 @@ def tagged_eval_sets( return [(eval_sets[name], tags[name]) for name in eval_sets] +CANONICAL_INPUT_FIELD = "prompt" +CANONICAL_OUTPUT_FIELD = "response" + + @dataclass class LMSupervisedDatasetConfig: """Config for supervised fine-tuning datasets""" @@ -662,15 +683,68 @@ class LMSupervisedDatasetConfig: validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files # Field names in the data - input_field: str = "prompt" # name of the input field - output_field: str = "response" # name of output field + input_field: str = CANONICAL_INPUT_FIELD # name of the input field + output_field: str = CANONICAL_OUTPUT_FIELD # name of output field # Optional metadata tags: Optional[List[str]] = None - name: Optional[str] = None -def preprocess_supervised_example( +class SupervisedSourceConfigBase(Protocol): + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + raise NotImplementedError + + input_field: str + output_field: str + tags: Optional[List[str]] + cache_dir: str + + +@dataclass(frozen=True) +class SupervisedHfSourceConfig(SupervisedSourceConfigBase): + cache_dir: str + id: str + name: str | None = None + + streaming: bool = True + + input_field: str = CANONICAL_INPUT_FIELD + output_field: str = CANONICAL_OUTPUT_FIELD + tags: Optional[List[str]] = None + + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + return WrappedHFDataSource(self.id, split=split, name=self.name, streaming=self.streaming).map( + lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} + ) + + +@dataclass(frozen=True) +class SupervisedUrlSourceConfig(SupervisedSourceConfigBase): + cache_dir: str + train_urls: list[str] = dataclasses.field(default_factory=list) + validation_urls: list[str] = dataclasses.field(default_factory=list) + + input_field: str = CANONICAL_INPUT_FIELD + output_field: str = CANONICAL_OUTPUT_FIELD + tags: Optional[List[str]] = None + + def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: + urls = self.train_urls if split == "train" else self.validation_urls + if not urls: + return None + + urls = [globbed for url in urls for globbed in expand_glob(url)] + + source = UrlDataSource(urls, columns=[self.input_field, self.output_field]) + return source.map( + lambda x: {CANONICAL_INPUT_FIELD: x[self.input_field], CANONICAL_OUTPUT_FIELD: x[self.output_field]} + ) + + +SupervisedSourceConfig: TypeAlias = Union[SupervisedHfSourceConfig, SupervisedUrlSourceConfig] + + +def _preprocess_supervised_example( batch, tokenizer: PreTrainedTokenizerBase, input_field: str, output_field: str ) -> dict: sources = [example[input_field] for example in batch] @@ -722,28 +796,69 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po return lm_ex -def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis): - import levanter.data +def mk_supervised_datasets( + sources: Mapping[str, SupervisedSourceConfigBase] | SupervisedSourceConfigBase, + split: str, + tokenizer: PreTrainedTokenizerBase, + Pos: hax.Axis, +) -> dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]]: + """ + Create supervised datasets from a set of sources. + + Returns: + A dictionary of dataset names to tuples of the dataset and the tags associated with the dataset. + """ + out: dict[str, tuple[AsyncDataset[LmExample], Sequence[str]]] = {} + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + if isinstance(sources, Mapping): + for name, config in sources.items(): + source = config.get_shard_source(split) + if source is None: + continue + + ds = _cache_supervised_set( + source, config.cache_dir, tokenizer, Pos, config.input_field, config.output_field + ) + + if config.tags is None: + tags = [name] + else: + tags = config.tags + [name] - # Choose data source based on config - if config.hf_dataset_name is not None: - # Using HF dataset - dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split) + out[name] = (ds, tags) else: - # Using local files - validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)] - if not validation_urls: - raise ValueError("Must specify either hf_dataset_name or validation_urls") - dataset = levanter.data.datasource_from_jsonl(validation_urls) + source = sources.get_shard_source(split) # type: ignore + if source is not None: + ds = _cache_supervised_set( + source, sources.cache_dir, tokenizer, Pos, sources.input_field, sources.output_field + ) + tags = sources.tags or [] + if isinstance(sources, SupervisedHfSourceConfig): + name = sources.id + if sources.name is not None: + name = f"{name}/{sources.name}" + + tags = [name] + tags + else: + name = "default" + out[name] = (ds, tags) + + return out + + +def mk_supervised_dataset( + config: SupervisedSourceConfigBase, split: str, tokenizer: HfTokenizer, Pos: hax.Axis +) -> AsyncDataset[LmExample]: - input_field = config.input_field - output_field = config.output_field + source = config.get_shard_source(split) output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} - # Use the same preprocessing as before - dataset = dataset.map_batches( - lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field), + dataset = source.map_batches( # type: ignore + lambda ex: _preprocess_supervised_example(ex, tokenizer, config.input_field, config.output_field), batch_size=128, num_cpus=num_cpus_used_by_tokenizer(tokenizer), output_exemplar=output_exemplar, @@ -757,19 +872,36 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) -@dataclass -class ChatSFTDatasetConfig(LMSupervisedDatasetConfig): +def _cache_supervised_set(source, cache_dir, tokenizer, Pos, input_field, output_field): + """ + Cache a supervised dataset into input_ids and sources_len. + """ + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + dataset = source.map_batches( + lambda ex: _preprocess_supervised_example(ex, tokenizer, input_field, output_field), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(cache_dir, await_finished=True) + ds = cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) + return ds + + +@dataclass(frozen=True) +class ChatUrlDataSourceConfig: """Config for loading JSONL files in OpenAI chat format for supervised fine-tuning.""" + cache_dir: str + train_urls: List[str] = field(default_factory=list) + validation_urls: List[str] = field(default_factory=list) + # Chat format specific fields messages_field: str = "messages" input_role: str = "user" output_role: str = "assistant" - train_urls: List[str] = field(default_factory=list) # Add this line def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: - import levanter.data - """Gets ShardedDataSource for either training or validation data.""" urls = self.validation_urls if split == "validation" else self.train_urls @@ -777,7 +909,7 @@ def get_shard_source(self, split: str) -> Optional[ShardedDataSource[dict]]: return None # Use the datasource_from_chat_jsonl function from sharded_datasource - return levanter.data.sharded_datasource.datasource_from_chat_jsonl( + return datasource_from_chat_jsonl( urls, messages_field=self.messages_field, input_role=self.input_role, output_role=self.output_role ) @@ -808,7 +940,7 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: def mk_chat_sft_dataset( - config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis + config: ChatUrlDataSourceConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis ) -> AsyncDataset[LmExample]: """Creates a dataset from JSONL files containing chat format data for SFT.""" source = config.get_shard_source("train") @@ -1117,3 +1249,55 @@ def build_caches( @property def sources(self) -> Mapping[str, LMDatasetSourceConfig]: return self.configs + + +def datasource_from_chat_jsonl( + urls: Sequence[str], messages_field: str = "messages", input_role: str = "user", output_role: str = "assistant" +) -> "ShardedDataSource[dict]": + """Creates a ShardedDataSource from JSONL files containing chat messages. + + Args: + urls: Sequence of URLs or glob patterns pointing to JSONL files + messages_field: Field name containing the messages in each JSON object + input_role: Role identifier for input messages + output_role: Role identifier for output messages + + Returns: + ShardedDataSource configured for chat data + """ + # Expand any glob patterns in the URLs + expanded_urls = [] + for url in urls: + if any(c in url for c in "*?[]"): + expanded_urls.extend(gcs_glob(url)) + else: + expanded_urls.append(url) + + return ChatJsonlDataSource(expanded_urls, messages_field, input_role, output_role) + + +# TODO: switch to actual multi-turn +class ChatJsonlDataSource(JsonlDataSource): + """DataSource that reads JSONL files containing OpenAI chat format messages.""" + + def __init__(self, urls: Sequence[str], messages_field: str, input_role: str, output_role: str): + super().__init__(urls) + self.messages_field = messages_field + self.input_role = input_role + self.output_role = output_role + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: + url = self._shard_name_to_url_mapping[shard_name] + i = 0 + with fsspec.open(url, "r", compression="infer") as f: + for line in f: + if i >= row: + data = json.loads(line) + messages = data[self.messages_field] + + # Extract input/output from messages + input_msg = next(m["content"] for m in messages if m["role"] == self.input_role) + output_msg = next(m["content"] for m in messages if m["role"] == self.output_role) + + yield {"input": input_msg, "output": output_msg} + i += 1 diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 16342be4d..9fe9ab0d7 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -199,7 +199,6 @@ def eval_callback(step: StepInfo): log_dict = { # log micro average as just "loss" _join_prefix(prefix, "loss"): result.micro_avg_loss, - _join_prefix(prefix, "macro_loss"): result.macro_avg_loss, _join_prefix(prefix, "loading_time"): result.total_eval_loading_time, _join_prefix(prefix, "total_time"): time_fn(), } @@ -207,6 +206,8 @@ def eval_callback(step: StepInfo): logger.info(f"{prefix} loss: {result.micro_avg_loss:.3f}") has_tags = len(evaluator.dataset.tag_to_index) > 1 # 1 tag means there's no difference between micro and macro if has_tags: + log_dict[_join_prefix(prefix, "macro_loss")] = result.macro_avg_loss + for tag, loss in result.tag_macro_losses.items(): # don't log leaf tag macro losses because it doesn't mean anything different than micro loss if tag in evaluator.dataset.tag_to_index: diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index f2ad3e7ce..99165c017 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -16,7 +16,13 @@ from levanter import callbacks from levanter.checkpoint import EpochCheckpointer, load_checkpoint from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback -from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig +from levanter.data.text import ( + CausalLmDataset, + LMDatasetConfig, + LMMixtureDatasetConfig, + SupervisedSourceConfig, + mk_supervised_datasets, +) from levanter.models.gpt2 import Gpt2Config from levanter.models.lm_model import LmConfig, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig @@ -30,7 +36,7 @@ @dataclass class TrainLmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) - supervised_data: Optional[LMSupervisedDatasetConfig] = None + supervised_data: Optional[SupervisedSourceConfig | dict[str, SupervisedSourceConfig]] = None trainer: TrainerConfig = field(default_factory=TrainerConfig) model: LmConfig = field(default_factory=Gpt2Config) optimizer: OptimizerConfig = field(default_factory=AdamConfig) @@ -208,12 +214,14 @@ def main(config: TrainLmConfig): trainer.add_hook(cb, every=config.trainer.steps_per_eval) if config.supervised_data is not None: - logger.info("Using supervised data") - supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer, Pos), "")] - # TODO Add tags + logger.info("Using supervised data for evals") + supervised_eval = mk_supervised_datasets(config.supervised_data, "validation", tokenizer, Pos) + + evals = list(supervised_eval.values()) + cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, - supervised_eval, + evals, tokenizer, trainer.device_mesh, compute_axis_mapping, diff --git a/src/levanter/utils/hf_utils.py b/src/levanter/utils/hf_utils.py index 922de4830..41e4488d4 100644 --- a/src/levanter/utils/hf_utils.py +++ b/src/levanter/utils/hf_utils.py @@ -13,6 +13,11 @@ _HF_TOKENIZER_OFF_VALUES = {"off", "false", "f", "no", "n", "0"} HfTokenizer: TypeAlias = PreTrainedTokenizerFast | PreTrainedTokenizer +""" +Type alias for a Hugging Face tokenizer. This is a union of the two tokenizer types. +While there is PreTrainedTokenizerBase, it doesn't have all methods that are implemented in both +PreTrainedTokenizer and PreTrainedTokenizerFast. grumble grumble. +""" def num_cpus_used_by_tokenizer(tokenizer) -> int: diff --git a/tests/test_supervised.py b/tests/test_supervised.py index 54c99a102..23f9e240c 100644 --- a/tests/test_supervised.py +++ b/tests/test_supervised.py @@ -4,7 +4,7 @@ import haliax from haliax import Axis -from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example +from levanter.data.text import _prepare_supervised_example, _preprocess_supervised_example def test_supervised_eval(): @@ -19,7 +19,7 @@ def test_supervised_eval(): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - output = preprocess_supervised_example(examples, tokenizer, "input", "output") + output = _preprocess_supervised_example(examples, tokenizer, "input", "output") assert len(output["input_ids"][0]) == output["sources_len"][0] + 1 ex = { From 55db7fe42432abcf28c31fd8ff33e200aee46903 Mon Sep 17 00:00:00 2001 From: David Hall Date: Wed, 13 Nov 2024 23:53:57 -0800 Subject: [PATCH 64/66] QOL: remove need for use_cpu_mesh in data loading functions Fixes #748 --- src/levanter/data/audio.py | 9 +++-- src/levanter/data/text.py | 74 +++++++++++++++++++------------------- 2 files changed, 40 insertions(+), 43 deletions(-) diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index fbb118cfe..9bfc1e142 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -33,7 +33,7 @@ from levanter.logging import silence_transformer_nag from levanter.models.asr_model import AudioTextExample from levanter.store.cache import CacheOptions, TreeCache, build_or_load_cache -from levanter.utils.jax_utils import key_iterator, local_cpu_mesh +from levanter.utils.jax_utils import key_iterator silence_transformer_nag() # noqa @@ -460,10 +460,9 @@ def __init__( @functools.partial(eqx.filter_jit, out_shardings=sharding) def _convert_example(inputs: AudioTextDict) -> "AudioTextExample": - with local_cpu_mesh(): - tokens = hax.named(inputs["input_ids"], self.TextPos) - audio_features = hax.named(inputs["input_features"], self.AudioPos) - return AudioTextExample.init(audio_features, tokens, ignore_id=self.ignore_id) + tokens = hax.named(inputs["input_ids"], self.TextPos) + audio_features = hax.named(inputs["input_features"], self.AudioPos) + return AudioTextExample.init(audio_features, tokens, ignore_id=self.ignore_id) super().__init__(self.dataset, _convert_example) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index f7764a8b2..234b5af14 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -57,7 +57,7 @@ ) from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa from levanter.store.cache import build_or_load_cache # noqa -from levanter.utils.jax_utils import key_iterator, local_cpu_mesh, use_cpu_device # noqa +from levanter.utils.jax_utils import key_iterator, use_cpu_device # noqa T_co = TypeVar("T_co", covariant=True) @@ -239,22 +239,21 @@ def __init__( @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_lm_example(tokens, key): - with local_cpu_mesh(): - tokens = hax.named(tokens, self.QPos) - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None - this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) - - return example + tokens = hax.named(tokens, self.QPos) + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) + + return example super().__init__(self.dataset, _create_lm_example, key=key) @@ -773,27 +772,26 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po 2. Mask out the input and prompt if requested. 3. Create an LmExample with the input_ids as the input and the next token as the target. """ - with local_cpu_mesh(): - # annoyingly, pad expects things to be batched so we have to prepend a batch axis - ex = tokenizer.pad( - {k: np.expand_dims(v, 0) for k, v in ex.items()}, - return_tensors="np", - padding="max_length", - max_length=Pos.size, - ) - ex = {k: v[0] for k, v in ex.items()} - # padding doesn't do truncation, so we have to do it ourselves. - # Truncate from the left since we want to predict the last tokens - input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos) - # mask out padding and anything before the start of the target - loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 - - # don't predict the padding - targets = hax.roll(input_ids, -1, Pos) - loss_mask = loss_mask & (targets != tokenizer.pad_token_id) - loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) - return lm_ex + # annoyingly, pad expects things to be batched so we have to prepend a batch axis + ex = tokenizer.pad( + {k: np.expand_dims(v, 0) for k, v in ex.items()}, + return_tensors="np", + padding="max_length", + max_length=Pos.size, + ) + ex = {k: v[0] for k, v in ex.items()} + # padding doesn't do truncation, so we have to do it ourselves. + # Truncate from the left since we want to predict the last tokens + input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos) + # mask out padding and anything before the start of the target + loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 + + # don't predict the padding + targets = hax.roll(input_ids, -1, Pos) + loss_mask = loss_mask & (targets != tokenizer.pad_token_id) + loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + return lm_ex def mk_supervised_datasets( From a885f2061ed412de8f973d9ba1396007fe2899e2 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Thu, 14 Nov 2024 00:00:05 -0800 Subject: [PATCH 65/66] support auto hsdp (#804) --- src/levanter/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 1d057484a..eee27cdeb 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -783,6 +783,10 @@ def _validate_and_set_defaults(self): if self.per_device_eval_parallelism == -1: self.per_device_eval_parallelism = self.per_device_parallelism + if self.replica_dcn_axis_size == -1: + self.replica_dcn_axis_size = self.num_slices + logger.info(f"Setting replica_dcn_axis_size to {self.replica_dcn_axis_size}") + class AllConfig(Protocol): trainer: TrainerConfig From f8ab21abd070436d5c9e3a0fcafb561e4ae8e1a5 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 14 Nov 2024 09:39:09 -0800 Subject: [PATCH 66/66] Revise SFT File (#793) PR to revise SFT file to avoid breaking changes to marin and for a request from @dlwh --- config/llama_sft_hf_ckpt.yaml | 13 ++++ {examples/sft => src/levanter/main}/sft.py | 72 ++++++++++++++++------ 2 files changed, 67 insertions(+), 18 deletions(-) create mode 100644 config/llama_sft_hf_ckpt.yaml rename {examples/sft => src/levanter/main}/sft.py (77%) diff --git a/config/llama_sft_hf_ckpt.yaml b/config/llama_sft_hf_ckpt.yaml new file mode 100644 index 000000000..a5742486c --- /dev/null +++ b/config/llama_sft_hf_ckpt.yaml @@ -0,0 +1,13 @@ +# Model configuration +model: + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: true + flash_attention_block_size: 512 + use_bias: false + use_layer_norm_weight: false diff --git a/examples/sft/sft.py b/src/levanter/main/sft.py similarity index 77% rename from examples/sft/sft.py rename to src/levanter/main/sft.py index 173c79212..b3ff0e74c 100644 --- a/examples/sft/sft.py +++ b/src/levanter/main/sft.py @@ -1,8 +1,8 @@ import logging import os -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional +from typing import List, Optional, Union import jax.random as jrandom import transformers @@ -15,10 +15,17 @@ from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset -from levanter.data.text import ChatUrlDataSourceConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset -from levanter.main.train_lm import TrainLmConfig -from levanter.models.lm_model import LmHeadModel, compute_next_token_loss -from levanter.trainer import Trainer +from levanter.data.text import ( + ChatUrlDataSourceConfig, + EpochDataset, + SupervisedSourceConfig, + mk_chat_sft_dataset, + mk_supervised_dataset, +) +from levanter.models.llama import LlamaConfig +from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig logger = logging.getLogger(__name__) @@ -38,24 +45,40 @@ class DatasetType(str, Enum): @dataclass -class SFTConfig(TrainLmConfig): +class SFTConfig: # inherit most of the config from TrainLmConfig - max_tune_length: int = 2048 + trainer: TrainerConfig = field(default_factory=TrainerConfig) + model: LmConfig = field(default_factory=LlamaConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + supervised_data: Optional[SupervisedSourceConfig | dict[str, SupervisedSourceConfig]] = None + + # config related to continued pretraining + initialize_from_hf: Union[bool, str] = False + hf_save_path: Optional[str] = None + hf_upload: Optional[str] = None + hf_save_steps: int = 0 + + max_seq_len: int = 2048 model_name_or_path: str = "meta-llama/Llama-2-7b-hf" tokenizer: str = "meta-llama/Llama-2-7b-hf" # Add dataset type and chat-specific fields - dataset_type: DatasetType = DatasetType.HUGGINGFACE + dataset_type: DatasetType = DatasetType.CHAT_JSONL chat_train_urls: Optional[List[str]] = None messages_field: str = "messages" input_role: str = "user" output_role: str = "assistant" + data_seed: Optional[int] = None # if provided, will override the data seed from the trainer + + # if provided, will initialize from this checkpoint, used for llama style data mixture + epoch: int = 0 + def train(config: SFTConfig): tokenizer = transformers.AutoTokenizer.from_pretrained( config.tokenizer, - model_max_length=config.max_tune_length, + model_max_length=config.max_seq_len, padding_side="right", trust_remote_code=True, ) @@ -79,7 +102,11 @@ def train(config: SFTConfig): elif config.trainer.initialize_from is None: raise ValueError("Must specify either --initialize_from_hf or --initialize_from") else: - converter = None + if config.hf_save_steps: + converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True) + converter = converter.replaced(tokenizer=tokenizer) + else: + converter = None model_config = config.model levanter.initialize(config) @@ -100,8 +127,16 @@ def train(config: SFTConfig): if config.dataset_type == DatasetType.CHAT_JSONL: assert config.chat_train_urls is not None assert config.supervised_data is not None + + # Get the cache_dir safely + cache_dir = ( + config.supervised_data.cache_dir + if not isinstance(config.supervised_data, dict) + else next(iter(config.supervised_data.values())).cache_dir + ) + chat_config = ChatUrlDataSourceConfig( - cache_dir=config.supervised_data.cache_dir, + cache_dir=cache_dir, train_urls=config.chat_train_urls, # No validation in this config messages_field=config.messages_field, input_role=config.input_role, @@ -110,7 +145,13 @@ def train(config: SFTConfig): train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos) else: assert config.supervised_data is not None - train_dataset = mk_supervised_dataset(config.supervised_data, "train", tokenizer, model_config.Pos) + if isinstance(config.supervised_data, dict): + # TODO: figure out what actually makes sense here + # for marin we will just use the url code path + config_to_use = next(iter(config.supervised_data.values())) + else: + config_to_use = config.supervised_data + train_dataset = mk_supervised_dataset(config_to_use, "train", tokenizer, model_config.Pos) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) @@ -161,11 +202,6 @@ def train(config: SFTConfig): loader = trainer.data_loader(train_dataset, trainer.TrainBatch) - if int(state.step) != 0: - logger.info(f"Resuming training from step {state.step}") - for i in range(state.step): - next(loader) - if config.hf_save_path is not None: # bit gross to reach this far into the config, but it's fine if config.trainer.checkpointer.append_run_id_to_base_path: