From c2ed3ee23f2075ce86ef2724faadad8aa4f911fb Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Wed, 16 Oct 2024 16:37:40 -0700
Subject: [PATCH] fix ci

---
 config/llama_7b_tulu.yaml             | 39 +++++++++++++++++++++++++++
 config/llama_7b_with_olmo_config.yaml |  4 +--
 src/levanter/data/text.py             | 10 +++----
 src/levanter/main/train_lm.py         |  4 +--
 4 files changed, 48 insertions(+), 9 deletions(-)
 create mode 100644 config/llama_7b_tulu.yaml

diff --git a/config/llama_7b_tulu.yaml b/config/llama_7b_tulu.yaml
new file mode 100644
index 000000000..1c059a509
--- /dev/null
+++ b/config/llama_7b_tulu.yaml
@@ -0,0 +1,39 @@
+data:
+  train_urls:
+    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz"
+    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz"
+    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz"
+  cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/"
+  tokenizer: "allenai/OLMo-1B"
+model: # 7B class model
+  type: llama
+  seq_len: 4096
+  hidden_dim: 4096
+  intermediate_dim: 11008
+  num_layers: 32
+  num_heads: 32
+  num_kv_heads: 32
+  use_flash_attention: True
+  flash_attention_block_size: 1024
+  use_bias: false
+  use_layer_norm_weight: false
+trainer:
+  tracker:
+    type: wandb
+    project: "marin"
+    tags: ["dolma", "olmo", "llama"]
+
+  mp: p=f32,c=bfloat16
+  train_batch_size: 256
+  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
+  steps_per_eval: 1000
+  tensor_parallel_axes: ["mlp", "heads"]
+  fsdp_axis: "embed"
+  batch_axis: "batch"
+optimizer:
+  learning_rate: 4E-4
+  weight_decay: 0.1
+  min_lr_ratio: 0.1
+  warmup: 5000
+
+epoch: False
\ No newline at end of file
diff --git a/config/llama_7b_with_olmo_config.yaml b/config/llama_7b_with_olmo_config.yaml
index 595fabe68..e41f7dbc2 100644
--- a/config/llama_7b_with_olmo_config.yaml
+++ b/config/llama_7b_with_olmo_config.yaml
@@ -31,5 +31,5 @@ optimizer:
   weight_decay: 0.1
   min_lr_ratio: 0.1
   warmup: 0.01
-
-  data_shuffle: true
\ No newline at end of file
+
+  data_shuffle: true
diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 88c278de2..9605ff74c 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -93,10 +93,10 @@ async def current_len(self) -> Optional[int]:
     async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
         token_arrays = await self._await_token_cache()
         dataset_len = await self.async_len()
-
+
         wrapped_indices = [idx % dataset_len for idx in indices]
         offsets = np.array(wrapped_indices) * self.seq_len
-
+
         with ts.Batch():
             out = []
             for offset in offsets:
@@ -691,9 +691,9 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     cache_dir: Optional[str] = "cache/"

     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
+        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
     ) -> AsyncDataset[np.ndarray]:
-
+
         if epochs:
             ds = self.token_epoch_dataset("train", seq_len, monitors)
         else:
@@ -750,7 +750,7 @@ def token_seq_dataset(
         if cache is None:
             return None
         return TokenSeqDataset(cache, seq_len)
-
+
     def token_epoch_dataset(
         self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
     ) -> Optional[TokenSeqDataset]:
diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py
index 9c511d31a..9134591f2 100644
--- a/src/levanter/main/train_lm.py
+++ b/src/levanter/main/train_lm.py
@@ -119,7 +119,7 @@ def main(config: TrainLmConfig):
     # TODO: fix this
     tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
     # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)
-
+
     train_dataset = CausalLmDataset(
         config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id
     )
@@ -244,7 +244,7 @@ def compute_log_probs(model, example):
     ## OK, actually run training!

     trainer.train(state, train_loader)
-
+
     # checkpointer.on_step(last_step, force=True)
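A quick way to sanity-check the new config before launching a run is to load it directly. A minimal sketch using PyYAML (the `yaml` package and the working directory are assumptions; the asserted values simply mirror fields visible in the patch above):

```python
# Minimal sanity check of config/llama_7b_tulu.yaml, assuming PyYAML is
# installed and the repo root is the working directory.
import yaml

with open("config/llama_7b_tulu.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["epoch"] is False                    # new top-level flag, read as config.epoch
assert cfg["model"]["seq_len"] == 4096
assert cfg["data"]["tokenizer"] == "allenai/OLMo-1B"
print(cfg["trainer"]["num_train_steps"])        # 750000
```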
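The core of the epoch support is the index wraparound in `get_batch` in `text.py`: indices past the end of the token cache wrap back to the start, so one fixed cache can serve any number of epochs. A self-contained sketch of that logic (the class and its constructor are illustrative stand-ins, not Levanter's actual types):

```python
import numpy as np

class EpochedTokenSeqs:
    """Illustrative stand-in for a token-sequence dataset with epoch wraparound."""

    def __init__(self, tokens: np.ndarray, seq_len: int):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self) -> int:
        # number of complete seq_len-sized sequences in one pass over the cache
        return len(self.tokens) // self.seq_len

    def get_batch(self, indices):
        dataset_len = len(self)
        # same trick as get_batch in the patch: wrap indices modulo the dataset length
        wrapped_indices = [idx % dataset_len for idx in indices]
        offsets = np.array(wrapped_indices) * self.seq_len
        return [self.tokens[o : o + self.seq_len] for o in offsets]

ds = EpochedTokenSeqs(np.arange(12), seq_len=4)  # 3 sequences per epoch
batch = ds.get_batch([0, 2, 3, 4])               # 3 wraps to 0, 4 wraps to 1
print([seq.tolist() for seq in batch])
```

Because the wraparound happens at read time, no data is copied to form later epochs; index 3 simply reads the same offsets as index 0.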
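On the consumer side, `train_lm.py` threads the new top-level `epoch` flag through `config.data.train_set(..., epochs=config.epoch)`, which selects `token_epoch_dataset` over `token_seq_dataset`. A condensed sketch of that dispatch with simplified stand-in types (only `train_set`, `epochs`, and the two dataset-factory names come from the patch; everything else is illustrative):

```python
from dataclasses import dataclass, field

@dataclass
class DataCfg:
    # stand-in for LMDatasetConfig; only the epochs dispatch is modeled
    def token_seq_dataset(self, split: str, seq_len: int) -> str:
        return f"<single-pass {split} dataset, seq_len={seq_len}>"

    def token_epoch_dataset(self, split: str, seq_len: int) -> str:
        return f"<epoched {split} dataset, seq_len={seq_len}>"

    def train_set(self, seq_len: int, *, epochs: bool = False) -> str:
        # mirrors the branch added to LMDatasetConfig.train_set
        if epochs:
            return self.token_epoch_dataset("train", seq_len)
        return self.token_seq_dataset("train", seq_len)

@dataclass
class TrainCfg:
    data: DataCfg = field(default_factory=DataCfg)
    epoch: bool = False  # mirrors the new `epoch: False` key in the YAML

def build_train_dataset(config: TrainCfg, pos_size: int = 4096) -> str:
    # mirrors: config.data.train_set(Pos.size, key=data_key, epochs=config.epoch)
    return config.data.train_set(pos_size, epochs=config.epoch)

print(build_train_dataset(TrainCfg(epoch=False)))  # single-pass
print(build_train_dataset(TrainCfg(epoch=True)))   # epoched
```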