Commit
fix ci
ahmeda14960 committed Oct 16, 2024
1 parent 49afb5d commit c2ed3ee
Showing 4 changed files with 48 additions and 9 deletions.
39 changes: 39 additions & 0 deletions config/llama_7b_tulu.yaml
@@ -0,0 +1,39 @@
data:
  train_urls:
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz"
  cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/"
  tokenizer: "allenai/OLMo-1B"
model: # 7B class model
  type: llama
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  flash_attention_block_size: 1024
  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 256
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 4E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 5000

epoch: False
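
As an illustrative aside (not part of the commit): the batch and sequence settings above determine the token budget per optimizer step, which can be sanity-checked by loading the new config with PyYAML. Only the file path comes from the commit; the rest is a minimal sketch for illustration.

# Illustrative sanity check, not part of the commit: load the new config with
# PyYAML and report the token budget implied by its batch and sequence settings.
import yaml

with open("config/llama_7b_tulu.yaml") as f:
    cfg = yaml.safe_load(f)

tokens_per_step = cfg["trainer"]["train_batch_size"] * cfg["model"]["seq_len"]
total_tokens = tokens_per_step * cfg["trainer"]["num_train_steps"]
print(f"tokens per step: {tokens_per_step:,}")  # 256 * 4096 = 1,048,576
print(f"total tokens:    {total_tokens:,}")     # over 750,000 steps
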
4 changes: 2 additions & 2 deletions config/llama_7b_with_olmo_config.yaml
@@ -31,5 +31,5 @@ optimizer:
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 0.01
data_shuffle: true

data_shuffle: true
10 changes: 5 additions & 5 deletions src/levanter/data/text.py
@@ -93,10 +93,10 @@ async def current_len(self) -> Optional[int]:
    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
        token_arrays = await self._await_token_cache()
        dataset_len = await self.async_len()

        wrapped_indices = [idx % dataset_len for idx in indices]
        offsets = np.array(wrapped_indices) * self.seq_len

        with ts.Batch():
            out = []
            for offset in offsets:
@@ -691,9 +691,9 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
    cache_dir: Optional[str] = "cache/"

    def train_set(
        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
    ) -> AsyncDataset[np.ndarray]:

        if epochs:
            ds = self.token_epoch_dataset("train", seq_len, monitors)
        else:
@@ -750,7 +750,7 @@ def token_seq_dataset(
        if cache is None:
            return None
        return TokenSeqDataset(cache, seq_len)

    def token_epoch_dataset(
        self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
    ) -> Optional[TokenSeqDataset]:
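
A minimal standalone sketch (an illustration, not Levanter's actual TokenSeqDataset) of the index wrapping that get_batch now performs: indices past the end of the cached data wrap back to the start, so the same fixed token cache can be read across multiple epochs. The helper name and toy values below are assumptions for illustration only.

# Toy sketch of the modulo wrapping used in get_batch above; `dataset_len` and
# `seq_len` stand in for the values the real dataset computes asynchronously.
import numpy as np

def wrapped_offsets(indices, dataset_len, seq_len):
    # Map arbitrary example indices into [0, dataset_len), then convert to token offsets.
    wrapped = [idx % dataset_len for idx in indices]
    return np.array(wrapped) * seq_len

# A dataset of 3 sequences, each 4 tokens long: indices 3 and 4 wrap around.
print(wrapped_offsets([0, 1, 2, 3, 4], dataset_len=3, seq_len=4))  # -> [0 4 8 0 4]
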
4 changes: 2 additions & 2 deletions src/levanter/main/train_lm.py
@@ -119,7 +119,7 @@ def main(config: TrainLmConfig):
        # TODO: fix this
        tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
        # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)

        train_dataset = CausalLmDataset(
            config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id
        )
@@ -244,7 +244,7 @@ def compute_log_probs(model, example):

        ## OK, actually run training!
        trainer.train(state, train_loader)

        # checkpointer.on_step(last_step, force=True)


