Add segment ids/sequence masking #827

Merged · 14 commits · Dec 3, 2024
2 changes: 1 addition & 1 deletion .github/workflows/tpu_unit_tests.yaml
@@ -39,7 +39,7 @@ jobs:
- name: Run most tests
run: |
export TPU_NAME=ci-run-${{ github.run_id }}
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'"
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "JAX_TRACEBACK_FILTERING=off PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'"
# Something's wrong with these
#
# - name: Run forked tests
2 changes: 1 addition & 1 deletion examples/alpaca/alpaca.py
@@ -171,7 +171,7 @@ def _prepare_example(ex: dict) -> LmExample:
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
else:
loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=tokenizer.eos_token_id)
return lm_ex

return dataset.map(_prepare_example)
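The change in these example scripts is just threading `eos_id` through to `LmExample.causal`. As a rough sketch of why that matters for this PR (how `LmExample.causal` uses the id is not shown in this diff, so the construction below is an assumption for illustration only), an EOS id lets segment ids be derived from document boundaries so packed sequences cannot attend across documents:

import jax.numpy as jnp

tokens = jnp.array([5, 9, 2, 7, 4, 2, 8, 3])  # pretend 2 is the EOS id
eos = tokens == 2
# one segment id per packed document: count of EOS tokens seen before each position
segment_ids = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(eos)[:-1]])
# -> [0, 0, 0, 1, 1, 1, 2, 2]; combined with the causal mask, attention cannot cross documents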
2 changes: 1 addition & 1 deletion examples/gsm8k-lora/gsm8k_lora.py
@@ -119,7 +119,7 @@ def _prepare_example(ex: dict) -> LmExample:
loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
else:
loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=tokenizer.eos_token_id)
return lm_ex

return dataset.map(_prepare_example)
13 changes: 8 additions & 5 deletions src/levanter/data/text.py
@@ -222,15 +222,18 @@ def __init__(
dataset: AsyncDataset[np.ndarray],
QPos: Axis,
KPos: Axis,
*,
fcm_prob: float = 0.0,
key: Optional[PRNGKey] = None,
ignore_index: Optional[int] = None,
eos_id: Optional[int] = None,
):
self.dataset = dataset
self.QPos = QPos
self.KPos = KPos
self.fcm_prob = fcm_prob
self.ignore_id = ignore_index
self.eos_id = eos_id
self.key = key

if self.fcm_prob > 0.0 and self.key is None:
@@ -241,7 +244,7 @@ def __init__(
@functools.partial(eqx.filter_jit, out_shardings=sharding)
def _create_lm_example(tokens, key):
tokens = hax.named(tokens, self.QPos)
example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)
example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id, eos_id=eos_id)

if self.fcm_prob > 0:
# masks for attention
@@ -802,22 +805,22 @@ def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerB

out = []
for ids, len in zip(truncated, lens):
causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id)
causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id, tokenizer.eos_token_id)

out.append(causal)

return out


@functools.partial(jax.jit, static_argnums=(0, 3))
def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id):
@functools.partial(jax.jit, static_argnums=(0, 3, 4))
def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id, eos_id):
# mask out padding and anything before the start of the target
loss_mask = hax.arange(Pos) >= sources_len - 1
# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
loss_mask = loss_mask & (targets != pad_token_id)
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
return LmExample.causal(input_ids, loss_mask=loss_mask)
return LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=eos_id)


def mk_supervised_datasets(
16 changes: 14 additions & 2 deletions src/levanter/main/doremi_lm.py
@@ -110,11 +110,23 @@ def init_proxy_model():
valid_datasets = config.data.validation_sets(ref_model.Pos.size)

causal_train_datasets = {
k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id)
k: CausalLmDataset(
v,
config.model.Pos,
config.model.KeyPos,
ignore_index=config.data.ignore_token_id,
eos_id=tokenizer.eos_token_id,
)
for k, v in train_datasets.items()
}
valid_datasets = {
k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id)
k: CausalLmDataset(
v,
config.model.Pos,
config.model.KeyPos,
ignore_index=config.data.ignore_token_id,
eos_id=tokenizer.eos_token_id,
)
for k, v in valid_datasets.items()
}

4 changes: 3 additions & 1 deletion src/levanter/main/eval_lm.py
@@ -49,7 +49,9 @@ def main(config: EvalLmConfig):
KeyPos = config.model.KeyPos

if config.eval_on_train:
raw_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=jax.random.PRNGKey(0)), Pos, KeyPos)
raw_dataset = CausalLmDataset(
config.data.train_set(Pos.size, key=jax.random.PRNGKey(0)), Pos, KeyPos, eos_id=tokenizer.eos_token_id
)
else:
validation_set = config.data.validation_set(Pos.size)
if validation_set is None:
10 changes: 7 additions & 3 deletions src/levanter/main/lora_lm.py
@@ -85,7 +85,9 @@ def main(config: LoraLmConfig):
if len(eval_datasets) == 0:
logger.warning("No evaluation datasets provided.")

train_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=data_key), Pos, KeyPos)
train_dataset = CausalLmDataset(
config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, eos_id=tokenizer.eos_token_id
)
train_loader = trainer.data_loader(train_dataset, Batch)

# load the underlying hf model
@@ -121,15 +123,17 @@ def loraize_hf_model(model):
logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}")

for name, eval_dataset in eval_datasets.items():
eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos)
eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
trainer.add_eval_hook(eval_dataset, name=name)

# boilerplate hooks and such
if len(eval_datasets) == 0:
logger.warning("No evaluation datasets provided.")

for name, eval_dataset in eval_datasets.items():
eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id)
eval_dataset = CausalLmDataset(
eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id
)
trainer.add_eval_hook(eval_dataset, name=name)

trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1)
8 changes: 7 additions & 1 deletion src/levanter/main/train_lm.py
@@ -131,6 +131,7 @@ def main(config: TrainLmConfig):
Pos,
KeyPos,
ignore_index=config.data.ignore_token_id,
eos_id=tokenizer.eos_token_id,
)

# add epoch logging if epochs specified
@@ -199,7 +200,12 @@ def main(config: TrainLmConfig):
logger.warning("No evaluation datasets provided.")
else:
causal_datasets = [
(CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags)
(
CausalLmDataset(
ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id
),
tags,
)
for ds, tags in tagged_eval_datasets
]
cb = levanter.eval.cb_tagged_lm_evaluate(
2 changes: 1 addition & 1 deletion src/levanter/main/viz_logprobs.py
@@ -46,7 +46,7 @@ def main(config: VizGpt2Config):

eval_loader = DataLoader(
EvalBatch,
CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore
CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos, eos_id=tokenizer.eos_token_id), # type: ignore
32,
config.trainer.device_mesh,
config.trainer.compute_axis_mapping,
111 changes: 93 additions & 18 deletions src/levanter/models/attention.py
@@ -7,6 +7,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec
from jaxtyping import PRNGKeyArray
@@ -439,9 +440,9 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key):
NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes
of Q, K, and V into the right bins to match that format.

NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD
the size of the axes is a bit flexible,
with the following conditions:
NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD.

The size of the axes is a bit flexible, with the following conditions:
- B must be the same for all (TODO: is this true?)
- S must be the same for K and V. Q's S can be different
- H: Q's H must be a multiple of K's H (for GQA or MQA)
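As a tiny illustration of the layout difference described above (plain arrays rather than named axes), BSHD and BHSD differ only in the order of the sequence and head dimensions:

import jax.numpy as jnp

B, S, H, D = 2, 128, 8, 64
q_bshd = jnp.zeros((B, S, H, D))              # NVTE layout: (Batch, Sequence, Head, Embed)
q_bhsd = jnp.transpose(q_bshd, (0, 2, 1, 3))  # Splash Attention layout: (Batch, Head, Sequence, Embed)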
@@ -543,6 +544,17 @@ def _unflatten_bshd(attn_output, q_class, v_class):
return attn_output


def _materialize_segment_mask(segment_ids, QPos, KPos, q_slice, k_slice) -> NamedArray:
"""
Make a segment mask for attention. This is a mask that prevents attention between different segments.
"""
kv_segment_ids = segment_ids.rename({QPos: KPos})[KPos, k_slice]
q_segment_ids = segment_ids[QPos, q_slice]
sub_KPos = kv_segment_ids.resolve_axis(KPos.name)

return q_segment_ids.broadcast_axis(sub_KPos) == kv_segment_ids


class AttentionMask(eqx.Module):
"""

@@ -552,18 +564,27 @@ class AttentionMask(eqx.Module):
Represents an attention mask in a structured way to make it easier to optimize attention for particular use cases
(causal, prefix, etc.). It is anticipated that this will be extended with new types of masks as needed.

The abstraction is based on two concepts:

1) Materialization: An AttentionMask can be materialized for a particular slice of the query and key position axes.
Most naively, you can just get the whole mask as a NamedArray. However, in some cases, you might want to
only get a particular chunk (e.g. for flash attention).
2) Combination: AttentionMasks are represented as an implicit conjunction of multiple masks, each with different
kinds of structure. You can combine masks with `&` and `|`. Due to the way jit works, we don't use inheritance
or similar to represent different kinds of masks. Instead, we use a single class with different fields.

In general, it should be safe to batch Attention Masks, but it is important that *all members of a batch have the
same sequence of combined masks*. Otherwise, the batching will not work and you'll get weird errors
same set of combined masks*. Otherwise, the batching will not work and you'll get weird errors

The interface exposed by this class is designed to work well with the attention functions in this module as
well as something like flash attention.
(Perhaps it's ok to use inheritance here? I'm not sure. Splash attention landed on inheritance, so maybe
that's a good sign.)

A mask can be materialized, in which case it returns the mask as a NamedArray.
"""

is_causal: bool = eqx.static_field()
explicit_mask: Optional[NamedArray] = None
# TODO: add sequence packing
segment_ids: Optional[NamedArray] = None
# CF https://github.com/jax-ml/jax/blob/47858c4ac2fd4757a3b6fc5bb2981b71a71f00c2/jax/experimental/pallas/ops/tpu/flash_attention.py#L34
# TODO: add prefixlm
# cf https://github.com/google-research/t5x/blob/51a99bff8696c373cc03918707ada1e98cbca407/t5x/examples/decoder_only/layers.py#L978
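A short usage sketch of the two concepts the docstring describes, combination and materialization, using the `with_segment_ids` helper added below. The axis names and the `hax.dslice` windowing are illustrative assumptions, not taken from this diff:

import jax.numpy as jnp
import haliax as hax

Pos = hax.Axis("position", 8)
KPos = Pos.alias("key_position")
segments = hax.named(jnp.array([0, 0, 0, 0, 1, 1, 1, 1]), Pos)

mask = AttentionMask.causal().with_segment_ids(segments)  # combination: causal AND same-segment
full = mask.materialize(Pos, KPos)                         # materialization: the whole mask
block = mask.materialize(Pos, KPos, q_slice=hax.dslice(0, 4), k_slice=hax.dslice(4, 4))  # one attention block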

@@ -589,7 +610,13 @@ def materialize(
else:
explicit = None

return combine_masks_and(causal, explicit)
mask = combine_masks_and(causal, explicit)

if self.segment_ids is not None:
segment_mask = _materialize_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice)
mask = combine_masks_and(mask, segment_mask)

return mask

@staticmethod
def causal() -> "AttentionMask":
@@ -599,15 +626,38 @@ def causal() -> "AttentionMask":
def explicit(mask: NamedArray) -> "AttentionMask":
return AttentionMask(is_causal=False, explicit_mask=mask)

def with_segment_ids(self, segment_ids: NamedArray) -> "AttentionMask":
return AttentionMask(is_causal=self.is_causal, explicit_mask=self.explicit_mask, segment_ids=segment_ids)

def __and__(self, other) -> "AttentionMask":
is_causal = self.is_causal and other.is_causal
is_causal = self.is_causal or other.is_causal
explicit_mask = combine_masks_and(self.explicit_mask, other.explicit_mask)
return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask)
segment_ids = self._check_for_same_segment_ids(other)

return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids)

def __or__(self, other) -> "AttentionMask":
is_causal = self.is_causal or other.is_causal
is_causal = self.is_causal and other.is_causal
explicit_mask = combine_masks_or(self.explicit_mask, other.explicit_mask)
return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask)
segment_ids = self._check_for_same_segment_ids(other)
return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids)
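# Note on the flipped flags above: `&` intersects the allowed positions, so the result is
# causal if either operand is causal; `|` unions them, so it can only safely be marked
# causal when both operands are.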

def _check_for_same_segment_ids(self, other):
if self.segment_ids is not None and other.segment_ids is not None:
# only one segment mask is allowed
# b/c we might do this in jit, we use eqx.error_if
# in theory we can do this one by just assigning unique ids to each unique pair...
# (but i don't really anticipate needing this)
segment_ids = eqx.error_if(
self.segment_ids,
not haliax.all(self.segment_ids == other.segment_ids),
"Only one segment mask is allowed",
)
elif self.segment_ids is not None:
segment_ids = self.segment_ids
else:
segment_ids = other.segment_ids
return segment_ids


@overload
@@ -661,7 +711,6 @@ def materialize_mask(

# TODO: padding mask
# TODO: FCM mask?
# TODO: sequence packing mask


def _try_tpu_splash_attention(
@@ -819,6 +868,24 @@ def flatten(axes):
physical_axes_k = _physical_axis_for_binning(k_class)
physical_axes_v = _physical_axis_for_binning(v_class)

# segment_ids
segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None
physical_axes_segments = pspec_for_axis(segment_ids.axes) if segment_ids is not None else None
# do we have a batch axis in segment_ids? (needed for vmap below)
if segment_ids is not None:
index_of_seq_dim = segment_ids.axes.index(QPos)
other_indices = [i for i in range(len(segment_ids.axes)) if i != index_of_seq_dim]
if len(other_indices) > 1:
raise NotImplementedError(
f"Only one batch axis is supported in segment_ids right now (got {segment_ids.axes})"
)
elif len(other_indices) == 1:
segment_batch_axis = other_indices[0]
else:
segment_batch_axis = None
else:
segment_batch_axis = None
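# (Illustrative note) segment_ids is expected to carry the sequence axis plus at most one
# batch axis, e.g. ("batch", "position"); the index found above is passed to vmap below as
# the batch axis of the ids.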

# MaxText uses a block size of 512
block_size = block_size or 512

@@ -830,11 +897,12 @@ def flatten(axes):
physical_axes_q,
physical_axes_k,
physical_axes_v,
physical_axes_segments,
),
out_specs=physical_axes_q,
check_rep=False,
)
def wrap_flash_attention(q, k, v):
def wrap_flash_attention(q, k, v, segment_ids):
# NB: inside the function, q, k, and v are partitioned, so in general the lengths of dims are not the same
Sq = q.shape[2]
Sk = k.shape[2]
@@ -850,17 +918,22 @@ def wrap_flash_attention(q, k, v):
block_kv_dq=min(block_size, Sq),
)

if mask.segment_ids is not None:
# for now only support self attention
segment_ids = segment_ids.array
segment_ids = SegmentIds(segment_ids, segment_ids)
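# Note: splash attention's SegmentIds holds separate id arrays for queries and keys/values;
# self-attention reuses the same array for both, hence it is passed twice above.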

if mask is None:
base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))
elif isinstance(mask, AttentionMask):
if mask.is_causal:
base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk))
else:
base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk))

# This is going to be a pain to support
if mask.explicit_mask is not None:
raise NotImplementedError("Explicit masks are not yet supported for splash attention")

elif isinstance(mask, NamedArray):
raise NotImplementedError("NamedArray masks are not yet supported for splash attention")
else:
@@ -876,9 +949,11 @@ def wrap_flash_attention(q, k, v):
q = q.astype(attention_dtype)
k = k.astype(attention_dtype)
v = v.astype(attention_dtype)
return jax.vmap(splash_kernel)(q, k, v, segment_ids=None)
return jax.vmap(
lambda q, k, v, si: splash_kernel(q, k, v, segment_ids=si), in_axes=(0, 0, 0, segment_batch_axis)
)(q, k, v, segment_ids)

attn_output = wrap_flash_attention(q_, k_, v_)
attn_output = wrap_flash_attention(q_, k_, v_, segment_ids)

attn_output = haliax.named(attn_output, ("B", "H", "S", "D"))
# the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v