From 039f3193529d8d1fe4cc0a755b28d24b46201cad Mon Sep 17 00:00:00 2001 From: David Hall Date: Sat, 30 Nov 2024 23:57:13 -0800 Subject: [PATCH 01/14] add segment ids --- src/levanter/models/attention.py | 21 ++++++++++++++++++-- tests/test_attention.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 0ae9a79f7..b8f370848 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -543,6 +543,16 @@ def _unflatten_bshd(attn_output, q_class, v_class): return attn_output +def _make_segment_mask(segment_ids, QPos, KPos, q_slice, k_slice) -> NamedArray: + """ + Make a segment mask for the attention. This is a mask that prevents attention between different segments. + """ + kv_segment_ids = segment_ids.rename({QPos: KPos})[KPos, k_slice] + q_segment_ids = segment_ids[QPos, q_slice] + + return q_segment_ids.broadcast_axis(KPos) == kv_segment_ids + + class AttentionMask(eqx.Module): """ @@ -563,7 +573,8 @@ class AttentionMask(eqx.Module): is_causal: bool = eqx.static_field() explicit_mask: Optional[NamedArray] = None - # TODO: add sequence packing + segment_ids: Optional[NamedArray] = None + # CF https://github.com/jax-ml/jax/blob/47858c4ac2fd4757a3b6fc5bb2981b71a71f00c2/jax/experimental/pallas/ops/tpu/flash_attention.py#L34 # TODO: add prefixlm # cf https://github.com/google-research/t5x/blob/51a99bff8696c373cc03918707ada1e98cbca407/t5x/examples/decoder_only/layers.py#L978 @@ -589,7 +600,13 @@ def materialize( else: explicit = None - return combine_masks_and(causal, explicit) + mask = combine_masks_and(causal, explicit) + + if self.segment_ids is not None: + segment_mask = _make_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice) + mask = combine_masks_and(mask, segment_mask) + + return mask @staticmethod def causal() -> "AttentionMask": diff --git a/tests/test_attention.py b/tests/test_attention.py index 7defcb4a0..ecbbd5735 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,16 +1,19 @@ import jax import jax.numpy as jnp import jax.random as jrandom +import numpy as np import pytest from chex import assert_trees_all_close import haliax as hax +from haliax import Axis from levanter.models.attention import ( AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention, _tpu_splash_attention, + dot_product_attention, ) from test_utils import skip_if_module_missing @@ -214,3 +217,34 @@ def test_tpu_splash_attention(): hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos)) assert hax_out.axes == flash_out.axes assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) + + +def test_segment_ids_are_respected(): + # test that we can't attend to something outside of the range + D = 2 + L = 10 + Pos = Axis("Pos", L) + Head = Axis("Head", D) + + keys = np.zeros((L, D), dtype=np.float32) + keys[0, 0] = 100.0 # really want to attend to this + values = np.zeros((L, D), dtype=np.float32) + values[0, 1] = 300.0 # check if we did attend + KPos = Pos.alias("KPos") + + query = np.ones((L, D), dtype=np.float32) + + query = hax.named(query, (Pos, Head)) + keys = hax.named(keys, (KPos, Head)) + values = hax.named(values, (KPos, Head)) + + segment_ids = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.int32) + segment_ids = hax.named(segment_ids, (Pos,)) + mask = AttentionMask(is_causal=True, segment_ids=segment_ids) + + result = dot_product_attention(Pos, KPos, Head, query, keys, values, mask=mask) + + # the first 3 positions should all have a value of 300.0 + assert_trees_all_close(result.array[0:3, 1], 300.0) + # the rest should be 0 + assert_trees_all_close(result.array[3:, 1], 0.0) From 913107886523dbd5ae11840c7d99b643ecc08bda Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 1 Dec 2024 10:28:54 -0800 Subject: [PATCH 02/14] better test, combine segment masks --- src/levanter/models/attention.py | 43 ++++++++++++++++++++++++++------ tests/test_attention.py | 12 ++++++--- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index b8f370848..6a207dfe9 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -7,6 +7,7 @@ import equinox as eqx import jax import jax.numpy as jnp +from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds from jax.experimental.shard_map import shard_map from jax.sharding import PartitionSpec from jaxtyping import PRNGKeyArray @@ -543,14 +544,15 @@ def _unflatten_bshd(attn_output, q_class, v_class): return attn_output -def _make_segment_mask(segment_ids, QPos, KPos, q_slice, k_slice) -> NamedArray: +def _materialize_segment_mask(segment_ids, QPos, KPos, q_slice, k_slice) -> NamedArray: """ - Make a segment mask for the attention. This is a mask that prevents attention between different segments. + Make a segment mask for attention. This is a mask that prevents attention between different segments. """ kv_segment_ids = segment_ids.rename({QPos: KPos})[KPos, k_slice] q_segment_ids = segment_ids[QPos, q_slice] + sub_KPos = kv_segment_ids.resolve_axis(KPos.name) - return q_segment_ids.broadcast_axis(KPos) == kv_segment_ids + return q_segment_ids.broadcast_axis(sub_KPos) == kv_segment_ids class AttentionMask(eqx.Module): @@ -603,7 +605,7 @@ def materialize( mask = combine_masks_and(causal, explicit) if self.segment_ids is not None: - segment_mask = _make_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice) + segment_mask = _materialize_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice) mask = combine_masks_and(mask, segment_mask) return mask @@ -619,12 +621,32 @@ def explicit(mask: NamedArray) -> "AttentionMask": def __and__(self, other) -> "AttentionMask": is_causal = self.is_causal and other.is_causal explicit_mask = combine_masks_and(self.explicit_mask, other.explicit_mask) - return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask) + segment_ids = self._check_for_same_segment_ids(other) + + return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids) def __or__(self, other) -> "AttentionMask": is_causal = self.is_causal or other.is_causal explicit_mask = combine_masks_or(self.explicit_mask, other.explicit_mask) - return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask) + segment_ids = self._check_for_same_segment_ids(other) + return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids) + + def _check_for_same_segment_ids(self, other): + if self.segment_ids is not None and other.segment_ids is not None: + # only one segment mask is allowed + # b/c we might do this in jit, we use eqx.error_if + # in theory we can do this one by just assigning unique ids to each unique pair... + # (but i don't really anticipate needing this) + segment_ids = eqx.error_if( + self.segment_ids, + not haliax.all(self.segment_ids == other.segment_ids), + "Only one segment mask is allowed", + ) + elif self.segment_ids is not None: + segment_ids = self.segment_ids + else: + segment_ids = other.segment_ids + return segment_ids @overload @@ -867,6 +889,8 @@ def wrap_flash_attention(q, k, v): block_kv_dq=min(block_size, Sq), ) + segment_ids = None + if mask is None: base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) elif isinstance(mask, AttentionMask): @@ -878,6 +902,11 @@ def wrap_flash_attention(q, k, v): # This is going to be a pain to support if mask.explicit_mask is not None: raise NotImplementedError("Explicit masks are not yet supported for splash attention") + + if mask.segment_ids is not None: + # for now only support self attention + segment_ids = mask.segment_ids.array + segment_ids = SegmentIds(segment_ids, segment_ids) elif isinstance(mask, NamedArray): raise NotImplementedError("NamedArray masks are not yet supported for splash attention") else: @@ -893,7 +922,7 @@ def wrap_flash_attention(q, k, v): q = q.astype(attention_dtype) k = k.astype(attention_dtype) v = v.astype(attention_dtype) - return jax.vmap(splash_kernel)(q, k, v, segment_ids=None) + return jax.vmap(splash_kernel)(q, k, v, segment_ids=segment_ids) attn_output = wrap_flash_attention(q_, k_, v_) diff --git a/tests/test_attention.py b/tests/test_attention.py index ecbbd5735..af07eceed 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -9,6 +9,7 @@ from haliax import Axis from levanter.models.attention import ( + AttentionBackend, AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention, @@ -219,10 +220,11 @@ def test_tpu_splash_attention(): assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) -def test_segment_ids_are_respected(): +@pytest.mark.parametrize("impl", ["default", "jax_flash", "vanilla"]) +def test_segment_ids_are_respected(impl): # test that we can't attend to something outside of the range D = 2 - L = 10 + L = 256 Pos = Axis("Pos", L) Head = Axis("Head", D) @@ -238,11 +240,13 @@ def test_segment_ids_are_respected(): keys = hax.named(keys, (KPos, Head)) values = hax.named(values, (KPos, Head)) - segment_ids = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.int32) + segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) segment_ids = hax.named(segment_ids, (Pos,)) mask = AttentionMask(is_causal=True, segment_ids=segment_ids) - result = dot_product_attention(Pos, KPos, Head, query, keys, values, mask=mask) + result = dot_product_attention( + Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 + ) # the first 3 positions should all have a value of 300.0 assert_trees_all_close(result.array[0:3, 1], 300.0) From 039ecd7d0552787d9587b40a3b0534eaf2aba5d2 Mon Sep 17 00:00:00 2001 From: David Hall Date: Sun, 1 Dec 2024 19:40:27 -0800 Subject: [PATCH 03/14] add with_segment_ids --- src/levanter/models/attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 6a207dfe9..f2f333965 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -618,6 +618,9 @@ def causal() -> "AttentionMask": def explicit(mask: NamedArray) -> "AttentionMask": return AttentionMask(is_causal=False, explicit_mask=mask) + def with_segment_ids(self, segment_ids: NamedArray) -> "AttentionMask": + return AttentionMask(is_causal=self.is_causal, explicit_mask=self.explicit_mask, segment_ids=segment_ids) + def __and__(self, other) -> "AttentionMask": is_causal = self.is_causal and other.is_causal explicit_mask = combine_masks_and(self.explicit_mask, other.explicit_mask) @@ -700,7 +703,6 @@ def materialize_mask( # TODO: padding mask # TODO: FCM mask? -# TODO: sequence packing mask def _try_tpu_splash_attention( From 8bf34bd9ea916651048aa175c72a723ed71a747b Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 08:49:53 -0800 Subject: [PATCH 04/14] actually use segment_ids --- examples/alpaca/alpaca.py | 2 +- examples/gsm8k-lora/gsm8k_lora.py | 2 +- src/levanter/data/text.py | 13 ++++++++----- src/levanter/main/doremi_lm.py | 16 ++++++++++++++-- src/levanter/main/eval_lm.py | 4 +++- src/levanter/main/lora_lm.py | 10 +++++++--- src/levanter/main/train_lm.py | 8 +++++++- src/levanter/main/viz_logprobs.py | 2 +- src/levanter/models/attention.py | 18 ++++++++++-------- src/levanter/models/lm_model.py | 15 ++++++++++++++- tests/test_hf_gpt2_serialize.py | 2 +- tests/test_text.py | 5 +++-- 12 files changed, 70 insertions(+), 27 deletions(-) diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index e8f805cde..13039f859 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -171,7 +171,7 @@ def _prepare_example(ex: dict) -> LmExample: loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=tokenizer.eos_token_id) return lm_ex return dataset.map(_prepare_example) diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 97a9c06ab..1c0ff3446 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -119,7 +119,7 @@ def _prepare_example(ex: dict) -> LmExample: loss_mask = loss_mask & (targets != tokenizer.pad_token_id) else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=tokenizer.eos_token_id) return lm_ex return dataset.map(_prepare_example) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 1532a7d06..13c7ea44b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -222,15 +222,18 @@ def __init__( dataset: AsyncDataset[np.ndarray], QPos: Axis, KPos: Axis, + *, fcm_prob: float = 0.0, key: Optional[PRNGKey] = None, ignore_index: Optional[int] = None, + eos_id: Optional[int] = None, ): self.dataset = dataset self.QPos = QPos self.KPos = KPos self.fcm_prob = fcm_prob self.ignore_id = ignore_index + self.eos_id = eos_id self.key = key if self.fcm_prob > 0.0 and self.key is None: @@ -241,7 +244,7 @@ def __init__( @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_lm_example(tokens, key): tokens = hax.named(tokens, self.QPos) - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id, eos_id=eos_id) if self.fcm_prob > 0: # masks for attention @@ -802,22 +805,22 @@ def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerB out = [] for ids, len in zip(truncated, lens): - causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id) + causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id, tokenizer.eos_token_id) out.append(causal) return out -@functools.partial(jax.jit, static_argnums=(0, 3)) -def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id): +@functools.partial(jax.jit, static_argnums=(0, 3, 4)) +def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id, eos_id): # mask out padding and anything before the start of the target loss_mask = hax.arange(Pos) >= sources_len - 1 # don't predict the padding targets = hax.roll(input_ids, -1, Pos) loss_mask = loss_mask & (targets != pad_token_id) loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - return LmExample.causal(input_ids, loss_mask=loss_mask) + return LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=eos_id) def mk_supervised_datasets( diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py index 742c3229c..d84e362f9 100644 --- a/src/levanter/main/doremi_lm.py +++ b/src/levanter/main/doremi_lm.py @@ -110,11 +110,23 @@ def init_proxy_model(): valid_datasets = config.data.validation_sets(ref_model.Pos.size) causal_train_datasets = { - k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + k: CausalLmDataset( + v, + config.model.Pos, + config.model.KeyPos, + ignore_index=config.data.ignore_token_id, + eos_id=tokenizer.eos_token_id, + ) for k, v in train_datasets.items() } valid_datasets = { - k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + k: CausalLmDataset( + v, + config.model.Pos, + config.model.KeyPos, + ignore_index=config.data.ignore_token_id, + eos_id=tokenizer.eos_token_id, + ) for k, v in valid_datasets.items() } diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 116a08f18..a4b3a9516 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -49,7 +49,9 @@ def main(config: EvalLmConfig): KeyPos = config.model.KeyPos if config.eval_on_train: - raw_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=jax.random.PRNGKey(0)), Pos, KeyPos) + raw_dataset = CausalLmDataset( + config.data.train_set(Pos.size, key=jax.random.PRNGKey(0)), Pos, KeyPos, eos_id=tokenizer.eos_token_id + ) else: validation_set = config.data.validation_set(Pos.size) if validation_set is None: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 9eee109fe..966b4fc1e 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -85,7 +85,9 @@ def main(config: LoraLmConfig): if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") - train_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=data_key), Pos, KeyPos) + train_dataset = CausalLmDataset( + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, eos_id=tokenizer.eos_token_id + ) train_loader = trainer.data_loader(train_dataset, Batch) # load the underlying hf model @@ -121,7 +123,7 @@ def loraize_hf_model(model): logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}") for name, eval_dataset in eval_datasets.items(): - eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) trainer.add_eval_hook(eval_dataset, name=name) # boilerplate hooks and such @@ -129,7 +131,9 @@ def loraize_hf_model(model): logger.warning("No evaluation datasets provided.") for name, eval_dataset in eval_datasets.items(): - eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) + eval_dataset = CausalLmDataset( + eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id + ) trainer.add_eval_hook(eval_dataset, name=name) trainer.add_hook(callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size), every=1) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 99165c017..423ecb938 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -131,6 +131,7 @@ def main(config: TrainLmConfig): Pos, KeyPos, ignore_index=config.data.ignore_token_id, + eos_id=tokenizer.eos_token_id, ) # add epoch logging if epochs specified @@ -199,7 +200,12 @@ def main(config: TrainLmConfig): logger.warning("No evaluation datasets provided.") else: causal_datasets = [ - (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags) + ( + CausalLmDataset( + ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id + ), + tags, + ) for ds, tags in tagged_eval_datasets ] cb = levanter.eval.cb_tagged_lm_evaluate( diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index b00ba61d5..e0aa1596d 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -46,7 +46,7 @@ def main(config: VizGpt2Config): eval_loader = DataLoader( EvalBatch, - CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore + CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos, eos_id=tokenizer.eos_token_id), # type: ignore 32, config.trainer.device_mesh, config.trainer.compute_axis_mapping, diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index f2f333965..9202b3b99 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -855,10 +855,13 @@ def flatten(axes): return PartitionSpec(b_out, h_out, s_out, d_out) + segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None + # BHSD physical_axes_q = _physical_axis_for_binning(q_class) physical_axes_k = _physical_axis_for_binning(k_class) physical_axes_v = _physical_axis_for_binning(v_class) + physical_axes_segments = pspec_for_axis(segment_ids.axes) if segment_ids is not None else None # MaxText uses a block size of 512 block_size = block_size or 512 @@ -871,11 +874,12 @@ def flatten(axes): physical_axes_q, physical_axes_k, physical_axes_v, + physical_axes_segments, ), out_specs=physical_axes_q, check_rep=False, ) - def wrap_flash_attention(q, k, v): + def wrap_flash_attention(q, k, v, segment_ids): # NB: inside the function, q, k, and v are partitioned, so in general the lengths of dims are not the same Sq = q.shape[2] Sk = k.shape[2] @@ -891,7 +895,10 @@ def wrap_flash_attention(q, k, v): block_kv_dq=min(block_size, Sq), ) - segment_ids = None + if mask.segment_ids is not None: + # for now only support self attention + segment_ids = segment_ids.array + segment_ids = SegmentIds(segment_ids, segment_ids) if mask is None: base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) @@ -900,15 +907,10 @@ def wrap_flash_attention(q, k, v): base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk)) else: base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) - # This is going to be a pain to support if mask.explicit_mask is not None: raise NotImplementedError("Explicit masks are not yet supported for splash attention") - if mask.segment_ids is not None: - # for now only support self attention - segment_ids = mask.segment_ids.array - segment_ids = SegmentIds(segment_ids, segment_ids) elif isinstance(mask, NamedArray): raise NotImplementedError("NamedArray masks are not yet supported for splash attention") else: @@ -926,7 +928,7 @@ def wrap_flash_attention(q, k, v): v = v.astype(attention_dtype) return jax.vmap(splash_kernel)(q, k, v, segment_ids=segment_ids) - attn_output = wrap_flash_attention(q_, k_, v_) + attn_output = wrap_flash_attention(q_, k_, v_, segment_ids) attn_output = haliax.named(attn_output, ("B", "H", "S", "D")) # the output shape is B, S_q, H_q, D_v. Right now we're requiring D_k == D_v diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 7e51ecadd..a77baef0d 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -25,7 +25,11 @@ class LmExample(eqx.Module): @staticmethod def causal( - tokens: hax.NamedArray, *, loss_mask: Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None + tokens: hax.NamedArray, + *, + loss_mask: Optional[hax.NamedArray] = None, + ignore_id: Optional[int] = None, + eos_id: Optional[int] = None, ) -> "LmExample": if tokens.ndim != 1: raise ValueError("tokens must be a 1D array") @@ -45,6 +49,15 @@ def causal( loss_mask = loss_mask * ignore_mask attn_mask = AttentionMask.causal() + + if eos_id is not None: + # the next token after an eos token is in a new segment + eos_mask = hax.roll(tokens, 1, Pos) == eos_id + # first token is always in segment 0 + eos_mask = eos_mask.at[Pos, 0].set(False).astype(jnp.int32) + segment_ids = hax.cumsum(eos_mask, axis=Pos) + attn_mask = attn_mask.with_segment_ids(segment_ids) + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index 87279d25b..9cc46ca0d 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -134,7 +134,7 @@ def torch_loss(model, input_ids) -> torch.Tensor: torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input.array)).to(torch.int64).unsqueeze(0)) def compute_loss(model: LmHeadModel, input_ids): - example = LmExample.causal(input_ids) + example = LmExample.causal(input_ids, eos_id=converter.tokenizer.eos_token_id) return compute_next_token_loss(model, example, key=None).scalar() jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False) diff --git a/tests/test_text.py b/tests/test_text.py index f293a9429..63d0afedb 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -30,9 +30,10 @@ def test_lm_example_handles_ignore_id(): tokens = hax.arange(Pos, dtype=jnp.int32) ignore_id = 6 + eos_id = 10 - ex_ignore = LmExample.causal(tokens, ignore_id=ignore_id) - ex_no_ignore = LmExample.causal(tokens) + ex_ignore = LmExample.causal(tokens, ignore_id=ignore_id, eos_id=eos_id) + ex_no_ignore = LmExample.causal(tokens, eos_id=eos_id) assert ex_ignore.loss_mask[Pos, ignore_id - 1] == 0 logits = hax.ones((Pos, Embed)) From 92247acd6e7eaaf545da2371aea249cd7610dce4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 08:57:39 -0800 Subject: [PATCH 05/14] comments --- src/levanter/models/attention.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 9202b3b99..cdc344c54 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -564,13 +564,21 @@ class AttentionMask(eqx.Module): Represents an attention mask in a structured way to make it easier to optimize attention for particular use cases (causal, prefix, etc.). It is anticipated that this will be extended with new types of masks as needed. + The abstraction is based on two concepts: + + 1) Materialization: An AttentionMask can be materialized for a particular slice of the query and key position axes. + Most naively, you can just get the whole mask as a NamedArray. However, in some cases, you might want to + only get a particular chunk (e.g. for flash attention). + 2) Combination: AttentionMasks are represented as an implicit conjunction of multiple masks, each with different + kinds of structure. You can combine masks with `&` and `|`. Due to the way jit works, we don't use inheritance + or similar to represent different kinds of masks. Instead, we use a single class with different fields. + In general, it should be safe to batch Attention Masks, but it is important that *all members of a batch have the - same sequence of combined masks*. Otherwise, the batching will not work and you'll get weird errors + same set of combined masks*. Otherwise, the batching will not work and you'll get weird errors - The interface exposed by this class is designed to work well with the attention functions in this module as - well as something like flash attention. + (Perhaps it's ok to use inheritance here? I'm not sure. Splash attention landed on inheritance, so maybe + that's a good sign.) - A mask can be materialized, in which case it returns the mask as a NamedArray. """ is_causal: bool = eqx.static_field() @@ -622,14 +630,14 @@ def with_segment_ids(self, segment_ids: NamedArray) -> "AttentionMask": return AttentionMask(is_causal=self.is_causal, explicit_mask=self.explicit_mask, segment_ids=segment_ids) def __and__(self, other) -> "AttentionMask": - is_causal = self.is_causal and other.is_causal + is_causal = self.is_causal or other.is_causal explicit_mask = combine_masks_and(self.explicit_mask, other.explicit_mask) segment_ids = self._check_for_same_segment_ids(other) return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids) def __or__(self, other) -> "AttentionMask": - is_causal = self.is_causal or other.is_causal + is_causal = self.is_causal and other.is_causal explicit_mask = combine_masks_or(self.explicit_mask, other.explicit_mask) segment_ids = self._check_for_same_segment_ids(other) return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids) From 381ce3ac0f0d9bcd7682721ae8e732120ef67165 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 09:02:42 -0800 Subject: [PATCH 06/14] cleanup --- src/levanter/models/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index cdc344c54..39f3cb3df 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -440,9 +440,9 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key): NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes of Q, K, and V into the right bins to match that format. - NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD - the size of the axes is a bit flexible, - with the following conditions: + NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD. + + The size of the axes is a bit flexible, with the following conditions: - B must be the same for all (TODO: is this true?) - S must be the same for K and V. Q's S can be different - H: Q's H must be a multiple of K's H (for GQA or MQA) From de04991cf6f9575c60bcbc4b0b51434cdd6cf94a Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 10:48:44 -0800 Subject: [PATCH 07/14] sigh --- tests/test_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index af07eceed..37df8ffaf 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -223,7 +223,8 @@ def test_tpu_splash_attention(): @pytest.mark.parametrize("impl", ["default", "jax_flash", "vanilla"]) def test_segment_ids_are_respected(impl): # test that we can't attend to something outside of the range - D = 2 + # splash needs 128 + D = 128 if impl == "default" else 2 L = 256 Pos = Axis("Pos", L) Head = Axis("Head", D) From bf14a2bc9019b7d3655abc0079a750dfa002f0bc Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 11:03:38 -0800 Subject: [PATCH 08/14] tol --- tests/test_attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 37df8ffaf..3e830d1c8 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -250,6 +250,6 @@ def test_segment_ids_are_respected(impl): ) # the first 3 positions should all have a value of 300.0 - assert_trees_all_close(result.array[0:3, 1], 300.0) + assert_trees_all_close(result.array[0:3, 1], 300.0, atol=1e-3, rtol=1e-3) # the rest should be 0 - assert_trees_all_close(result.array[3:, 1], 0.0) + assert_trees_all_close(result.array[3:, 1], 0.0, atol=1e-3, rtol=1e-3) From df0f0ebf9d14956835b291271b86bdc911a73c99 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 11:32:11 -0800 Subject: [PATCH 09/14] sigh --- .github/workflows/tpu_unit_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tpu_unit_tests.yaml b/.github/workflows/tpu_unit_tests.yaml index abc0be76e..44d69db36 100644 --- a/.github/workflows/tpu_unit_tests.yaml +++ b/.github/workflows/tpu_unit_tests.yaml @@ -39,7 +39,7 @@ jobs: - name: Run most tests run: | export TPU_NAME=ci-run-${{ github.run_id }} - gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'" + gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "JAX_TRACEBACK_FILTERING=off PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'" # Something's wrong with these # # - name: Run forked tests From c29e07198412cb24737f62552b61265694365a66 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 11:54:12 -0800 Subject: [PATCH 10/14] kk --- tests/test_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 3e830d1c8..70aefe2f9 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -245,7 +245,7 @@ def test_segment_ids_are_respected(impl): segment_ids = hax.named(segment_ids, (Pos,)) mask = AttentionMask(is_causal=True, segment_ids=segment_ids) - result = dot_product_attention( + result = hax.named_jit(dot_product_attention)( Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 ) From 1bbb411fcecbc20cdbf8e956af37bd439e7002dc Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 12:44:20 -0800 Subject: [PATCH 11/14] better vmap --- src/levanter/models/attention.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 39f3cb3df..7044200ba 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -863,13 +863,28 @@ def flatten(axes): return PartitionSpec(b_out, h_out, s_out, d_out) - segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None - # BHSD physical_axes_q = _physical_axis_for_binning(q_class) physical_axes_k = _physical_axis_for_binning(k_class) physical_axes_v = _physical_axis_for_binning(v_class) + + # segment_ids + segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None physical_axes_segments = pspec_for_axis(segment_ids.axes) if segment_ids is not None else None + # do we have a batch axis in segment_ids? (needed for vmap below) + if segment_ids is not None: + index_of_seq_dim = segment_ids.axes.index(QPos) + other_indices = [i for i in range(len(segment_ids.axes)) if i != index_of_seq_dim] + if len(other_indices) > 1: + raise NotImplementedError( + f"Only one batch axis is supported in segment_ids right now (got {segment_ids.axes})" + ) + elif len(other_indices) == 1: + segment_batch_axis = other_indices[0] + else: + segment_batch_axis = None + else: + segment_batch_axis = None # MaxText uses a block size of 512 block_size = block_size or 512 @@ -934,7 +949,9 @@ def wrap_flash_attention(q, k, v, segment_ids): q = q.astype(attention_dtype) k = k.astype(attention_dtype) v = v.astype(attention_dtype) - return jax.vmap(splash_kernel)(q, k, v, segment_ids=segment_ids) + return jax.vmap( + lambda q, k, v, si: splash_kernel(q, k, v, segment_ids=si), in_axes=(0, 0, 0, segment_batch_axis) + )(q, k, v, segment_ids) attn_output = wrap_flash_attention(q_, k_, v_, segment_ids) From 4fff4c339a6eee7870079c60dfb9f030b8156894 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 13:18:00 -0800 Subject: [PATCH 12/14] this? --- tests/test_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index 70aefe2f9..15dfbcebb 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -242,6 +242,7 @@ def test_segment_ids_are_respected(impl): values = hax.named(values, (KPos, Head)) segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) + segment_ids = jax.device_put(segment_ids) segment_ids = hax.named(segment_ids, (Pos,)) mask = AttentionMask(is_causal=True, segment_ids=segment_ids) From 5052f363fde363129da2f1660e64ecbffd2014f0 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 15:55:22 -0800 Subject: [PATCH 13/14] this? --- tests/test_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index 15dfbcebb..4dde69579 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -241,6 +241,8 @@ def test_segment_ids_are_respected(impl): keys = hax.named(keys, (KPos, Head)) values = hax.named(values, (KPos, Head)) + query, keys, values = jax.device_put([query, keys, values]) + segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) segment_ids = jax.device_put(segment_ids) segment_ids = hax.named(segment_ids, (Pos,)) From 1a7729a19c23c3b17fb813ddecedca6f45c33cd6 Mon Sep 17 00:00:00 2001 From: David Hall Date: Mon, 2 Dec 2024 22:03:51 -0800 Subject: [PATCH 14/14] enough device puts and we're good --- tests/test_attention.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 4dde69579..140e3cf7c 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -4,6 +4,7 @@ import numpy as np import pytest from chex import assert_trees_all_close +from jax.sharding import Mesh import haliax as hax from haliax import Axis @@ -241,16 +242,21 @@ def test_segment_ids_are_respected(impl): keys = hax.named(keys, (KPos, Head)) values = hax.named(values, (KPos, Head)) - query, keys, values = jax.device_put([query, keys, values]) + query, keys, values = jax.device_put( + [query, keys, values], jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1) + ) segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) - segment_ids = jax.device_put(segment_ids) + segment_ids = jax.device_put(segment_ids, jax.sharding.PositionalSharding(jax.devices())) segment_ids = hax.named(segment_ids, (Pos,)) mask = AttentionMask(is_causal=True, segment_ids=segment_ids) - result = hax.named_jit(dot_product_attention)( - Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 - ) + devices = jax.devices() + + with Mesh(devices, ("dp",)): + result = hax.named_jit(dot_product_attention)( + Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 + ) # the first 3 positions should all have a value of 300.0 assert_trees_all_close(result.array[0:3, 1], 300.0, atol=1e-3, rtol=1e-3)