Commit 5343096

pre commit

ahmeda14960 committed Oct 28, 2024
1 parent dde75ac commit 5343096

Showing 4 changed files with 41 additions and 30 deletions.
9 changes: 5 additions & 4 deletions src/levanter/callbacks.py
@@ -18,7 +18,7 @@
from tqdm_loggable.auto import tqdm

import levanter.tracker
from levanter.data import DataLoader, AsyncDataset
from levanter.data import AsyncDataset, DataLoader
from levanter.logging import save_xla_dumps_to_wandb
from levanter.tracker.helpers import log_optimizer_hyperparams
from levanter.tracker.wandb import WandbConfig
@@ -27,8 +27,10 @@
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.visualization import compute_and_visualize_log_probs as viz_probs


logger = pylogging.getLogger(__name__)


def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size, max_epochs: Optional[int] = None):
total_tokens = None

@@ -43,18 +45,17 @@ def log_epoch(step_info: StepInfo):

# Get the total processed tokens from the metrics logged by log_performance_stats
processed_tokens = tokens_per_example * batch_size * step_info.step

# If we're doing multiple epochs, adjust the denominator
total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens
current_epoch = processed_tokens / total_tokens_for_epochs

levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step)

return log_epoch


def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int):

def log_length():
# If ds.async_len() is the only option, run it in an event loop inside the thread
import asyncio
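
A small standalone sketch of the arithmetic the log_epoch hook above performs. The numbers are made up for illustration; in the real callback they come from get_total_dataset_tokens, Pos.size, and trainer.config.train_batch_size.

    # Illustrative values only (with these numbers one epoch is ~122 steps).
    total_tokens = 1_000_000      # tokens in one full pass over the dataset
    tokens_per_example = 1024     # sequence length
    batch_size = 8                # examples per training step
    max_epochs = 2                # config.epoch

    for step in (0, 61, 122, 244):
        processed_tokens = tokens_per_example * batch_size * step
        total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens
        current_epoch = processed_tokens / total_tokens_for_epochs
        print(f"step {step}: train/current_epoch = {current_epoch:.3f}")
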
22 changes: 13 additions & 9 deletions src/levanter/checkpoint.py
@@ -261,17 +261,21 @@ def _async_checkpoint_remover(self):
self._do_rm_checkpoint(checkpoint)
self._checkpoint_being_removed = None


# In callbacks.py - Add a new callback that handles epoch checkpointing
class EpochCheckpointer:
"""
A separate checkpointing system that saves based on epochs.
Works alongside the regular step-based checkpointer without modifying core state.
"""
def __init__(self,
checkpointer: Checkpointer,
every_n_epochs: int = 1,
total_dataset_size: Optional[int] = None,
batch_size: int = 1):

def __init__(
self,
checkpointer: Checkpointer,
every_n_epochs: int = 1,
total_dataset_size: Optional[int] = None,
batch_size: int = 1,
):
self.checkpointer = checkpointer
self.every_n_epochs = every_n_epochs
self.total_dataset_size = total_dataset_size
@@ -281,20 +285,20 @@ def __init__(self,
def __call__(self, step_info):
if self.total_dataset_size is None:
return # Can't calculate epochs without dataset size

# Calculate current epoch from steps without modifying StepInfo
current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size

# Only save if we've moved to a new epoch and it matches our interval
if (current_epoch > self._last_saved_epoch and
current_epoch % self.every_n_epochs == 0):
if current_epoch > self._last_saved_epoch and current_epoch % self.every_n_epochs == 0:
# Use existing checkpointer's save_checkpoint method
self.checkpointer.save_checkpoint(
step_info,
f"epoch-{current_epoch}",
)
self._last_saved_epoch = current_epoch


def save_checkpoint(
tree: M,
step: int,
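
A toy run of the saving condition in EpochCheckpointer.__call__ above, with made-up sizes. It assumes _last_saved_epoch starts at 0; the attribute's initialization is not part of the hunk shown.

    total_dataset_size = 10_000   # whatever unit the caller uses (train_lm.py passes total tokens)
    batch_size = 1_000
    every_n_epochs = 1
    last_saved_epoch = 0          # assumed starting value

    for step in range(1, 31):
        current_epoch = (step * batch_size) // total_dataset_size
        if current_epoch > last_saved_epoch and current_epoch % every_n_epochs == 0:
            # the real callback delegates to checkpointer.save_checkpoint(step_info, f"epoch-{current_epoch}")
            print(f"step {step}: would save 'epoch-{current_epoch}'")   # fires at steps 10, 20, 30
            last_saved_epoch = current_epoch
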
28 changes: 17 additions & 11 deletions src/levanter/data/text.py
@@ -72,6 +72,7 @@ class EpochDataset(AsyncDataset[T_co]):
:param dataset: The dataset to wrap.
:param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
"""

def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
self.dataset = dataset
self.max_epochs = max_epochs
@@ -111,7 +112,9 @@ async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:

# If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs
if self.max_epochs is not None and epoch >= self.max_epochs:
raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}")
raise StopIteration(
f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}"
)

# Wrap the indices within the bounds of the dataset length
wrapped_indices = [idx % ds_len for idx in indices]
@@ -139,7 +142,8 @@ async def wait_until_len_at_least(self, length: int) -> int:
return self.max_epochs * base_length

return base_length



class TokenSeqDataset(AsyncDataset[np.ndarray]):
"""
A dataset that yields sequences of tokens of fixed length from an underlying TreeCache.
@@ -639,19 +643,20 @@ def tagged_eval_sets(
@dataclass
class LMSupervisedDatasetConfig:
"""Config for supervised fine-tuning datasets"""

cache_dir: str = "cache/"

# HF dataset config
hf_dataset_name: Optional[str] = None # e.g. "tatsu-lab/alpaca" or "OpenAssistant/oasst1"
hf_dataset_split: str = "train" # which split to use

# Local files config
validation_urls: List[str] = field(default_factory=list) # paths to jsonl/json files

# Field names in the data
input_field: str = "prompt" # name of the input field
output_field: str = "response" # name of output field

# Optional metadata
tags: Optional[List[str]] = None
name: Optional[str] = None
@@ -705,7 +710,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) ->

def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
import levanter.data

# Choose data source based on config
if config.hf_dataset_name is not None:
# Using HF dataset
@@ -725,18 +730,19 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
# Use the same preprocessing as before
dataset = dataset.map_batches(
lambda ex: preprocess_supervised_example(ex, tokenizer, input_field, output_field),
batch_size=128,
batch_size=128,
num_cpus=num_cpus_used_by_tokenizer(tokenizer),
output_exemplar=output_exemplar
output_exemplar=output_exemplar,
)

dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))


@dataclass
class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
"""This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
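
Two short sketches for the pieces added in this file. First, the index wrapping that EpochDataset.get_batch performs, with a toy dataset length (how the epoch counter itself is derived is elided from the hunk above):

    ds_len = 5                         # length of the wrapped dataset
    max_epochs = 3
    indices = [4, 5, 6, 12]            # indices that run past a single pass

    wrapped_indices = [idx % ds_len for idx in indices]
    print(wrapped_indices)             # [4, 0, 1, 2]
    print(max_epochs * ds_len)         # 15 -- the length reported once max_epochs is set

Second, a hypothetical use of LMSupervisedDatasetConfig and mk_supervised_dataset. The field names come from the dataclass above; the dataset name, field values, and tokenizer choice are only illustrative.

    from transformers import AutoTokenizer

    from levanter.data.text import LMSupervisedDatasetConfig, mk_supervised_dataset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    config = LMSupervisedDatasetConfig(
        cache_dir="cache/alpaca",
        hf_dataset_name="tatsu-lab/alpaca",   # or leave None and set validation_urls instead
        hf_dataset_split="train",
        input_field="instruction",            # defaults are "prompt" / "response"
        output_field="output",
    )
    supervised_ds = mk_supervised_dataset(config, tokenizer)
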
12 changes: 6 additions & 6 deletions src/levanter/main/train_lm.py
@@ -54,7 +54,7 @@ class TrainLmConfig:
data_seed: Optional[int] = None # if provided, will override the data seed from the trainer
initialize_from_checkpoint_path: Optional[str] = None
# if provided, will initialize from this checkpoint, used for llama style data mixture
epoch: int = 0
epoch: int = 0


def main(config: TrainLmConfig):
@@ -127,20 +127,22 @@ def main(config: TrainLmConfig):
ignore_index=config.data.ignore_token_id,
)


# add epoch logging if epochs specified
if config.epoch > 0:
total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
trainer.add_hook(
callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch), every=1
callbacks.log_epoch_progress(
total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch
),
every=1,
)

# Add epoch checkpoint callback
epoch_checkpointer = EpochCheckpointer(
checkpointer=trainer.config.checkpointer.create(trainer.run_id),
every_n_epochs=1, # Or configure as needed
total_dataset_size=total_tokens_future.result(),
batch_size=trainer.config.train_batch_size
batch_size=trainer.config.train_batch_size,
)
trainer.add_hook(epoch_checkpointer, every=1)

@@ -260,8 +262,6 @@ def compute_log_probs(model, example):
## OK, actually run training!
last_info = trainer.train(state, train_loader)



# If running EpochDataset save latest checkpoint by default
if trainer.config.checkpointer is not None and config.epoch > 0:
trainer.run_hooks(last_info, force=True)
