fix internal_eval lengths #794

Merged 2 commits on Nov 8, 2024

5 changes: 3 additions & 2 deletions examples/sft/sft.py
@@ -80,6 +80,7 @@ def train(config: SFTConfig):
raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
else:
converter = None
+ model_config = config.model

levanter.initialize(config)

@@ -106,10 +107,10 @@ def train(config: SFTConfig):
input_role=config.input_role,
output_role=config.output_role,
)
- train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
+ train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
else:
assert config.supervised_data is not None
- train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
+ train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos)
logger.info("Supervised dataset created")
train_dataset = PermutationDataset(train_dataset, data_key)

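The model_config.Pos threaded into the dataset builders above is a haliax named axis carrying the model's maximum sequence length. A minimal sketch of the idea, with a toy size standing in for a real model config:

import haliax as hax

# An Axis pairs a name with a size; levanter model configs expose the
# model's sequence length as an axis like this (toy size here).
Pos = hax.Axis("position", 128)
print(Pos.name, Pos.size)  # -> position 128

Passing this axis into mk_chat_sft_dataset and mk_supervised_dataset means every example is padded to the model's sequence length at dataset-build time, rather than to whatever default the tokenizer happens to use.
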
1 change: 1 addition & 0 deletions src/levanter/data/audio.py
@@ -270,6 +270,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]):
"""

def __init__(self, cache: TreeCache[AudioTextDict]):
+ super().__init__()
self.cache = cache

async def async_len(self) -> int:
13 changes: 12 additions & 1 deletion src/levanter/data/dataset.py
@@ -48,6 +48,9 @@ class AsyncDataset(DatasetBase[T_co]):
* `current_len`: Returns the current length of the dataset. This may be None if no current length is known.
"""

+ def __init__(self):
+     self._min_known_len = 0
+
@abc.abstractmethod
async def async_len(self) -> int:
raise NotImplementedError
@@ -95,7 +98,12 @@ async def wait_until_len_at_least(self, length: int) -> int:
The default implementation is a naive busy-wait loop. You should override this method for more efficient
implementations.
"""
- return await naive_busy_wait_until_len_at_least(self, length)
+ if length <= self._min_known_len:
+     return self._min_known_len
+
+ res_len = await naive_busy_wait_until_len_at_least(self, length)
+ self._min_known_len = max(self._min_known_len, res_len)
+ return res_len

def as_sync_dataset(self):
return SyncifiedDataset(self)
@@ -206,6 +214,7 @@ def __getitem__(self, index: int) -> T_co:

class AsyncifiedDataset(AsyncDataset[T_co]):
def __init__(self, dataset: SyncDataset[T_co]):
+ super().__init__()
self.dataset = dataset

async def async_len(self) -> int:
@@ -239,6 +248,7 @@ class ListAsyncDataset(AsyncDataset[T]):
"""

def __init__(self, data: list[T], is_complete: bool = False):
+ super().__init__()
self.data = data
self.is_complete = is_complete
if not is_complete:
@@ -315,6 +325,7 @@ def __init__(
*extra_args,
**extra_kwargs,
):
+ super().__init__()
self.dataset = dataset
self.fn = fn
self._extra_args = extra_args
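This is the heart of the fix on the dataset side: wait_until_len_at_least now caches the largest length it has observed, so repeated calls (as internal_eval makes) stop busy-waiting once the dataset is known to be long enough. A self-contained sketch of the pattern, with a toy GrowingDataset standing in for Levanter's AsyncDataset hierarchy:

import asyncio

async def naive_busy_wait_until_len_at_least(dataset, length: int) -> int:
    # Simplified stand-in for Levanter's fallback: poll until enough elements exist.
    while True:
        current = await dataset.current_len()
        if current is not None and current >= length:
            return current
        await asyncio.sleep(0.01)

class GrowingDataset:
    def __init__(self):
        self._data = list(range(4))
        self._min_known_len = 0  # cached lower bound on the length

    async def current_len(self):
        return len(self._data)

    async def wait_until_len_at_least(self, length: int) -> int:
        # Fast path: an observed length never shrinks, so answer from the cache.
        if length <= self._min_known_len:
            return self._min_known_len
        res_len = await naive_busy_wait_until_len_at_least(self, length)
        self._min_known_len = max(self._min_known_len, res_len)
        return res_len

async def main():
    ds = GrowingDataset()
    print(await ds.wait_until_len_at_least(3))  # polls once, caches 4
    print(await ds.wait_until_len_at_least(2))  # served from the cache, no polling

asyncio.run(main())

Because the cache lives in a field set by AsyncDataset.__init__, every subclass touched in this PR gains a super().__init__() call; a subclass that skipped it would hit an AttributeError on its first wait_until_len_at_least call.
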
1 change: 1 addition & 0 deletions src/levanter/data/mixture.py
@@ -53,6 +53,7 @@ def __init__(
key: PRNGKeyArray | int,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
):
+ super().__init__()
self.weights = MixtureDataset._normalize_weights(weights)
self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0}
self.dataset_index = Index(self.datasets.keys())
2 changes: 2 additions & 0 deletions src/levanter/data/permutation.py
@@ -14,6 +14,7 @@ class PermutationDataset(AsyncDataset[T_co]):
# TODO: add epoch reshuffling

def __init__(self, dataset: AsyncDataset[T_co], key: jax.random.PRNGKey):
+ super().__init__()
self.dataset = dataset
self.key = key
self._permutation: Optional[Permutation] = None
@@ -72,6 +73,7 @@ class EraShufflingDataset(AsyncDataset[T_co]):
"""

def __init__(self, dataset: AsyncDataset[T_co], era_length: int, *, key: jax.random.PRNGKey):
+ super().__init__()
self.dataset = dataset
self.era_length = era_length
self.key = key
24 changes: 16 additions & 8 deletions src/levanter/data/text.py
@@ -75,6 +75,7 @@ class EpochDataset(AsyncDataset[T_co]):
"""

def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
+ super().__init__()
self.dataset = dataset
self.max_epochs = max_epochs

@@ -154,6 +155,7 @@ class TokenSeqDataset(AsyncDataset[np.ndarray]):
"""

def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
+ super().__init__()
self.doc_cache = doc_cache
self.seq_len = seq_len
self._store: Optional[TreeStore] = None
@@ -687,7 +689,7 @@ def preprocess_supervised_example(
}


- def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
+ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample:
"""
Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.

@@ -699,11 +701,15 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) ->
"""
with local_cpu_mesh():
# annoyingly, pad expects things to be batched so we have to prepend a batch axis
- ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length")
+ ex = tokenizer.pad(
+     {k: np.expand_dims(v, 0) for k, v in ex.items()},
+     return_tensors="np",
+     padding="max_length",
+     max_length=Pos.size,
+ )
ex = {k: v[0] for k, v in ex.items()}
- input_ids = hax.named(ex["input_ids"], "position")
+ input_ids = hax.named(ex["input_ids"], Pos)
# mask out padding and anything before the start of the target
- Pos = input_ids.resolve_axis("position")
loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

# don't predict the padding
@@ -714,7 +720,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) ->
return lm_ex


- def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
+ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis):
import levanter.data

# Choose data source based on config
@@ -746,7 +752,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

- return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+ return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))


@dataclass
@@ -799,7 +805,9 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict:
}


- def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]:
+ def mk_chat_sft_dataset(
+     config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis
+ ) -> AsyncDataset[LmExample]:
"""Creates a dataset from JSONL files containing chat format data for SFT."""
source = config.get_shard_source("train")
if source is None:
@@ -824,7 +832,7 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken
tokenizer.pad_token = tokenizer.eos_token

# Reuse the supervised prepare function directly
- return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+ return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))


@dataclass
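The substance of the length fix: _prepare_supervised_example now pads to the model's Pos axis and builds the loss mask against that axis, instead of resolving whatever length the tokenizer happened to produce. A runnable sketch of the mask logic under toy values (the real code gets its arrays from tokenizer.pad and runs inside a local CPU mesh):

import numpy as np
import haliax as hax

Pos = hax.Axis("position", 16)

# Stand-in for tokenizer.pad(..., padding="max_length", max_length=Pos.size):
input_ids_np = np.zeros(Pos.size, dtype=np.int32)
input_ids_np[:6] = [5, 6, 7, 8, 9, 2]  # prompt + target tokens, then pad id 0
sources_len = 3                        # prompt occupies positions 0..2

input_ids = hax.named(input_ids_np, Pos)
# position i predicts token i + 1, so the mask starts at sources_len - 1
loss_mask = hax.arange(Pos) >= sources_len - 1
# don't predict the padding
targets = hax.named(np.roll(input_ids_np, -1), Pos)
loss_mask = loss_mask & (targets != 0)
# the final position has no next token; the real code masks it out as well
loss_mask = loss_mask & (hax.arange(Pos) < Pos.size - 1)

print(loss_mask.array.astype(int))  # 1s exactly over the target span

This mirrors the assertion in test_supervised.py below: with sources_len 45, the mask is set at position 44, the position whose next-token prediction is the first target token.
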
1 change: 1 addition & 0 deletions src/levanter/eval.py
@@ -60,6 +60,7 @@ def tags(self):
def __init__(
self, datasets: Sequence[tuple[AsyncDataset[T], Sequence[str]]], max_examples_per_dataset: Optional[int] = None
):
+ super().__init__()
self.datasets = []
tag_index: dict[str, int] = {}
for i, (dataset, tags) in enumerate(datasets):
2 changes: 1 addition & 1 deletion src/levanter/main/train_lm.py
@@ -209,7 +209,7 @@ def main(config: TrainLmConfig):

if config.supervised_data is not None:
logger.info("Using supervised data")
- supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")]
+ supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer, Pos), "")]
# TODO Add tags
cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch,
1 change: 1 addition & 0 deletions src/levanter/store/cache.py
@@ -191,6 +191,7 @@ def __init__(
ledger: Optional["CacheLedger"],
_broker, # handle of _TreeStoreCacheBuilder
):
+ super().__init__()
self.cache_dir = cache_dir
self.ledger = ledger
self._was_already_finished = ledger is not None and ledger.is_finished
1 change: 1 addition & 0 deletions tests/test_doremi.py
@@ -38,6 +38,7 @@ def platform_of_array(x):

class LogitDataset(AsyncDataset[Example]):
def __init__(self, W, noise, x_mask, x_bias, *, key):
+ super().__init__()
self.W = W
self.noise = noise
self.x_mask = x_mask
2 changes: 2 additions & 0 deletions tests/test_new_loader.py
@@ -64,6 +64,7 @@ def test_local_batched_data_loading_model_axis_1():

class StructuredDataset(AsyncDataset):
def __init__(self, seq_len):
+ super().__init__()
self.seq_len = seq_len
self.begin = 0
self.end = 256
@@ -138,6 +139,7 @@ def test_structured_batches_model_axis_2():

class StructuredDatasetWithNames(AsyncDataset):
def __init__(self, Height: Axis, Width: Axis, begin, end, stride):
+ super().__init__()
self.Height = Height
self.Width = Width
self.begin = begin
3 changes: 2 additions & 1 deletion tests/test_supervised.py
@@ -2,6 +2,7 @@
from transformers import AutoTokenizer

import haliax
+ from haliax import Axis

from levanter.data.text import _prepare_supervised_example, preprocess_supervised_example

@@ -76,7 +77,7 @@ def test_supervised_eval():
"sources_len": np.array(45, dtype=np.int32),
}

- lm_ex = _prepare_supervised_example(ex, tokenizer)
+ lm_ex = _prepare_supervised_example(ex, tokenizer, Axis("position", 128))

assert lm_ex.loss_mask["position", 44]
assert haliax.sum(lm_ex.loss_mask) == 1