diff --git a/examples/sft/sft.py b/examples/sft/sft.py index 629b556c2..152781b0b 100644 --- a/examples/sft/sft.py +++ b/examples/sft/sft.py @@ -80,6 +80,7 @@ def train(config: SFTConfig): raise ValueError("Must specify either --initialize_from_hf or --initialize_from") else: converter = None + model_config = config.model levanter.initialize(config) @@ -106,10 +107,10 @@ def train(config: SFTConfig): input_role=config.input_role, output_role=config.output_role, ) - train_dataset = mk_chat_sft_dataset(chat_config, tokenizer) + train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos) else: assert config.supervised_data is not None - train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer) + train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos) logger.info("Supervised dataset created") train_dataset = PermutationDataset(train_dataset, data_key) diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 12695a20b..fbb118cfe 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -270,6 +270,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]): """ def __init__(self, cache: TreeCache[AudioTextDict]): + super().__init__() self.cache = cache async def async_len(self) -> int: diff --git a/src/levanter/data/dataset.py b/src/levanter/data/dataset.py index def0c158a..4d71241d4 100644 --- a/src/levanter/data/dataset.py +++ b/src/levanter/data/dataset.py @@ -48,6 +48,9 @@ class AsyncDataset(DatasetBase[T_co]): * `current_len`: Returns the current length of the dataset. This may be None if no current length is known. """ + def __init__(self): + self._min_known_len = 0 + @abc.abstractmethod async def async_len(self) -> int: raise NotImplementedError @@ -95,7 +98,12 @@ async def wait_until_len_at_least(self, length: int) -> int: The default implementation is a naive busy-wait loop. You should override this method for more efficient implementations. """ - return await naive_busy_wait_until_len_at_least(self, length) + if length <= self._min_known_len: + return self._min_known_len + + res_len = await naive_busy_wait_until_len_at_least(self, length) + self._min_known_len = max(self._min_known_len, res_len) + return res_len def as_sync_dataset(self): return SyncifiedDataset(self) @@ -206,6 +214,7 @@ def __getitem__(self, index: int) -> T_co: class AsyncifiedDataset(AsyncDataset[T_co]): def __init__(self, dataset: SyncDataset[T_co]): + super().__init__() self.dataset = dataset async def async_len(self) -> int: @@ -239,6 +248,7 @@ class ListAsyncDataset(AsyncDataset[T]): """ def __init__(self, data: list[T], is_complete: bool = False): + super().__init__() self.data = data self.is_complete = is_complete if not is_complete: @@ -315,6 +325,7 @@ def __init__( *extra_args, **extra_kwargs, ): + super().__init__() self.dataset = dataset self.fn = fn self._extra_args = extra_args diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index eb1bdfaaf..63c623e4b 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -53,6 +53,7 @@ def __init__( key: PRNGKeyArray | int, stop_strategy: str = StopStrategy.RESTART_STRATEGY, ): + super().__init__() self.weights = MixtureDataset._normalize_weights(weights) self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0} self.dataset_index = Index(self.datasets.keys()) diff --git a/src/levanter/data/permutation.py b/src/levanter/data/permutation.py index 6599d4974..66a1887fd 100644 --- a/src/levanter/data/permutation.py +++ b/src/levanter/data/permutation.py @@ -14,6 +14,7 @@ class PermutationDataset(AsyncDataset[T_co]): # TODO: add epoch reshuffling def __init__(self, dataset: AsyncDataset[T_co], key: jax.random.PRNGKey): + super().__init__() self.dataset = dataset self.key = key self._permutation: Optional[Permutation] = None @@ -72,6 +73,7 @@ class EraShufflingDataset(AsyncDataset[T_co]): """ def __init__(self, dataset: AsyncDataset[T_co], era_length: int, *, key: jax.random.PRNGKey): + super().__init__() self.dataset = dataset self.era_length = era_length self.key = key diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 0654d1dfa..7e92d200b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -75,6 +75,7 @@ class EpochDataset(AsyncDataset[T_co]): """ def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None): + super().__init__() self.dataset = dataset self.max_epochs = max_epochs @@ -154,6 +155,7 @@ class TokenSeqDataset(AsyncDataset[np.ndarray]): """ def __init__(self, doc_cache: TreeCache[dict], seq_len: int): + super().__init__() self.doc_cache = doc_cache self.seq_len = seq_len self._store: Optional[TreeStore] = None @@ -687,7 +689,7 @@ def preprocess_supervised_example( } -def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample: +def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample: """ Prepare an example for training. This function converts the (cached) batch encoding into an LmExample. @@ -699,11 +701,15 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> """ with local_cpu_mesh(): # annoyingly, pad expects things to be batched so we have to prepend a batch axis - ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length") + ex = tokenizer.pad( + {k: np.expand_dims(v, 0) for k, v in ex.items()}, + return_tensors="np", + padding="max_length", + max_length=Pos.size, + ) ex = {k: v[0] for k, v in ex.items()} - input_ids = hax.named(ex["input_ids"], "position") + input_ids = hax.named(ex["input_ids"], Pos) # mask out padding and anything before the start of the target - Pos = input_ids.resolve_axis("position") loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1 # don't predict the padding @@ -714,7 +720,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> return lm_ex -def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase): +def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis): import levanter.data # Choose data source based on config @@ -746,7 +752,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) @dataclass @@ -799,7 +805,9 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict: } -def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]: +def mk_chat_sft_dataset( + config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis +) -> AsyncDataset[LmExample]: """Creates a dataset from JSONL files containing chat format data for SFT.""" source = config.get_shard_source("train") if source is None: @@ -824,7 +832,7 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken tokenizer.pad_token = tokenizer.eos_token # Reuse the supervised prepare function directly - return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer)) + return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos)) @dataclass diff --git a/src/levanter/eval.py b/src/levanter/eval.py index 99e132dc2..16342be4d 100644 --- a/src/levanter/eval.py +++ b/src/levanter/eval.py @@ -60,6 +60,7 @@ def tags(self): def __init__( self, datasets: Sequence[tuple[AsyncDataset[T], Sequence[str]]], max_examples_per_dataset: Optional[int] = None ): + super().__init__() self.datasets = [] tag_index: dict[str, int] = {} for i, (dataset, tags) in enumerate(datasets): diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index b1b5d4aaa..b411bd59e 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -209,7 +209,7 @@ def main(config: TrainLmConfig): if config.supervised_data is not None: logger.info("Using supervised data") - supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")] + supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer, Pos), "")] # TODO Add tags cb = levanter.eval.cb_tagged_lm_evaluate( EvalBatch, diff --git a/src/levanter/store/cache.py b/src/levanter/store/cache.py index 558bbfceb..e7a5306d4 100644 --- a/src/levanter/store/cache.py +++ b/src/levanter/store/cache.py @@ -191,6 +191,7 @@ def __init__( ledger: Optional["CacheLedger"], _broker, # handle of _TreeStoreCacheBuilder ): + super().__init__() self.cache_dir = cache_dir self.ledger = ledger self._was_already_finished = ledger is not None and ledger.is_finished diff --git a/tests/test_doremi.py b/tests/test_doremi.py index d2cf8b590..bbab04f52 100644 --- a/tests/test_doremi.py +++ b/tests/test_doremi.py @@ -38,6 +38,7 @@ def platform_of_array(x): class LogitDataset(AsyncDataset[Example]): def __init__(self, W, noise, x_mask, x_bias, *, key): + super().__init__() self.W = W self.noise = noise self.x_mask = x_mask diff --git a/tests/test_new_loader.py b/tests/test_new_loader.py index e6f9a3dd7..94b5238b2 100644 --- a/tests/test_new_loader.py +++ b/tests/test_new_loader.py @@ -64,6 +64,7 @@ def test_local_batched_data_loading_model_axis_1(): class StructuredDataset(AsyncDataset): def __init__(self, seq_len): + super().__init__() self.seq_len = seq_len self.begin = 0 self.end = 256 @@ -138,6 +139,7 @@ def test_structured_batches_model_axis_2(): class StructuredDatasetWithNames(AsyncDataset): def __init__(self, Height: Axis, Width: Axis, begin, end, stride): + super().__init__() self.Height = Height self.Width = Width self.begin = begin