QOL: remove need for use_cpu_mesh in data loading functions
Fixes #748
dlwh committed Nov 14, 2024
1 parent 63f2f3a commit 55db7fe
Showing 2 changed files with 40 additions and 43 deletions.
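In short: the per-example conversion functions were already jitted with `out_shardings`, so building their inputs under `local_cpu_mesh()` was unnecessary; the jitted function places its outputs directly. Below is a minimal sketch of the resulting pattern, using plain `jax.jit` in place of the `eqx.filter_jit` call shown in the diff (the `sharding` and `convert_example` names here are illustrative, not Levanter's API):

```python
import functools

import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding

# Stand-in for the sharding the data loader would pass in.
sharding = SingleDeviceSharding(jax.devices()[0])

# Before: the body of a converter like this was wrapped in local_cpu_mesh()
# to keep intermediates on CPU. After: out_shardings on the jitted function
# already controls where the result lands, so the wrapper can go away.
@functools.partial(jax.jit, out_shardings=sharding)
def convert_example(tokens: jnp.ndarray) -> jnp.ndarray:
    return tokens.astype(jnp.int32)

example = convert_example(jnp.arange(8))
```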
9 changes: 4 additions & 5 deletions src/levanter/data/audio.py
@@ -33,7 +33,7 @@
from levanter.logging import silence_transformer_nag
from levanter.models.asr_model import AudioTextExample
from levanter.store.cache import CacheOptions, TreeCache, build_or_load_cache
-from levanter.utils.jax_utils import key_iterator, local_cpu_mesh
+from levanter.utils.jax_utils import key_iterator


silence_transformer_nag() # noqa
@@ -460,10 +460,9 @@ def __init__(

@functools.partial(eqx.filter_jit, out_shardings=sharding)
def _convert_example(inputs: AudioTextDict) -> "AudioTextExample":
-    with local_cpu_mesh():
-        tokens = hax.named(inputs["input_ids"], self.TextPos)
-        audio_features = hax.named(inputs["input_features"], self.AudioPos)
-        return AudioTextExample.init(audio_features, tokens, ignore_id=self.ignore_id)
+    tokens = hax.named(inputs["input_ids"], self.TextPos)
+    audio_features = hax.named(inputs["input_features"], self.AudioPos)
+    return AudioTextExample.init(audio_features, tokens, ignore_id=self.ignore_id)

super().__init__(self.dataset, _convert_example)

74 changes: 36 additions & 38 deletions src/levanter/data/text.py
@@ -57,7 +57,7 @@
)
from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa
from levanter.store.cache import build_or_load_cache # noqa
-from levanter.utils.jax_utils import key_iterator, local_cpu_mesh, use_cpu_device # noqa
+from levanter.utils.jax_utils import key_iterator, use_cpu_device # noqa


T_co = TypeVar("T_co", covariant=True)
@@ -239,22 +239,21 @@ def __init__(

@functools.partial(eqx.filter_jit, out_shardings=sharding)
def _create_lm_example(tokens, key):
-    with local_cpu_mesh():
-        tokens = hax.named(tokens, self.QPos)
-        example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)
-
-        if self.fcm_prob > 0:
-            # masks for attention
-            # We support forgetful causal masking (FCM) which is a technique that improves training speed by
-            # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention
-            # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432
-            assert self.key is not None
-            this_key, key = jax.random.split(key)
-            fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key)
-            attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask)
-            example = dataclasses.replace(example, attn_mask=attn_mask)
-
-        return example
+    tokens = hax.named(tokens, self.QPos)
+    example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)
+
+    if self.fcm_prob > 0:
+        # masks for attention
+        # We support forgetful causal masking (FCM) which is a technique that improves training speed by
+        # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention
+        # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432
+        assert self.key is not None
+        this_key, key = jax.random.split(key)
+        fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key)
+        attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask)
+        example = dataclasses.replace(example, attn_mask=attn_mask)
+
+    return example

super().__init__(self.dataset, _create_lm_example, key=key)
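The FCM comment in the hunk above describes AND-ing a random "forget" mask into the causal attention mask. A rough plain-JAX sketch of that idea follows (this is not haliax's `forgetful_causal_mask`; always keeping the diagonal so no query row ends up fully masked is an assumption of the sketch):

```python
import jax
import jax.numpy as jnp

def fcm_attention_mask(seq_len: int, forget_prob: float, key) -> jnp.ndarray:
    """Boolean [query, key] mask: causal, with random past keys 'forgotten'."""
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    # Forget each key column with probability forget_prob.
    keep = jax.random.bernoulli(key, 1.0 - forget_prob, (seq_len,))
    # Always let a position attend to itself (sketch assumption, to avoid an
    # all-False row); then combine with the causal mask, as the diff does with &.
    keep = keep | jnp.eye(seq_len, dtype=bool)
    return causal & keep

mask = fcm_attention_mask(8, forget_prob=0.15, key=jax.random.PRNGKey(0))
```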

@@ -773,27 +772,26 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Po
2. Mask out the input and prompt if requested.
3. Create an LmExample with the input_ids as the input and the next token as the target.
"""
-    with local_cpu_mesh():
-        # annoyingly, pad expects things to be batched so we have to prepend a batch axis
-        ex = tokenizer.pad(
-            {k: np.expand_dims(v, 0) for k, v in ex.items()},
-            return_tensors="np",
-            padding="max_length",
-            max_length=Pos.size,
-        )
-        ex = {k: v[0] for k, v in ex.items()}
-        # padding doesn't do truncation, so we have to do it ourselves.
-        # Truncate from the left since we want to predict the last tokens
-        input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
-        # mask out padding and anything before the start of the target
-        loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1
-
-        # don't predict the padding
-        targets = hax.roll(input_ids, -1, Pos)
-        loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
-        loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
-        lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
-        return lm_ex
+    # annoyingly, pad expects things to be batched so we have to prepend a batch axis
+    ex = tokenizer.pad(
+        {k: np.expand_dims(v, 0) for k, v in ex.items()},
+        return_tensors="np",
+        padding="max_length",
+        max_length=Pos.size,
+    )
+    ex = {k: v[0] for k, v in ex.items()}
+    # padding doesn't do truncation, so we have to do it ourselves.
+    # Truncate from the left since we want to predict the last tokens
+    input_ids = hax.named(ex["input_ids"][-Pos.size :], Pos)
+    # mask out padding and anything before the start of the target
+    loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1
+
+    # don't predict the padding
+    targets = hax.roll(input_ids, -1, Pos)
+    loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
+    loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
+    lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
+    return lm_ex


def mk_supervised_datasets(
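The supervised masking above is easier to follow in isolation. Here is a numpy-only sketch of the same loss-mask construction (a hypothetical helper, not Levanter code; the real function builds haliax NamedArrays and returns an LmExample):

```python
import numpy as np

def supervised_loss_mask(input_ids, sources_len, pad_token_id, pos_size):
    """Mirror the masking in _prepare_supervised_example with plain numpy."""
    # Truncate from the left so we keep (and predict) the last pos_size tokens.
    ids = np.asarray(input_ids)[-pos_size:]
    positions = np.arange(len(ids))
    # Supervision starts at the last prompt token, whose target is the first answer token.
    mask = positions >= sources_len - 1
    # The target at position i is the token at i + 1; don't supervise pad targets.
    targets = np.roll(ids, -1)
    mask &= targets != pad_token_id
    # The final position has no real next token (np.roll wrapped around), so drop it.
    mask[-1] = False
    return ids, mask

ids, mask = supervised_loss_mask([5, 6, 7, 8, 0, 0], sources_len=2, pad_token_id=0, pos_size=6)
```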
