diff --git a/.github/workflows/tpu_unit_tests.yaml b/.github/workflows/tpu_unit_tests.yaml index abc0be76e..44d69db36 100644 --- a/.github/workflows/tpu_unit_tests.yaml +++ b/.github/workflows/tpu_unit_tests.yaml @@ -39,7 +39,7 @@ jobs: - name: Run most tests run: | export TPU_NAME=ci-run-${{ github.run_id }} - gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'" + gcloud compute tpus tpu-vm ssh $TPU_NAME --zone ${TPU_ZONE} --command "JAX_TRACEBACK_FILTERING=off PYTHONPATH=$PYTHONPATH:levanter/tests CI=1 bash levanter/infra/run.sh pytest levanter/tests -m 'not entry'" # Something's wrong with these # # - name: Run forked tests diff --git a/examples/alpaca/alpaca.py b/examples/alpaca/alpaca.py index e8f805cde..13039f859 100644 --- a/examples/alpaca/alpaca.py +++ b/examples/alpaca/alpaca.py @@ -171,7 +171,7 @@ def _prepare_example(ex: dict) -> LmExample: loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=tokenizer.eos_token_id) return lm_ex return dataset.map(_prepare_example) diff --git a/examples/gsm8k-lora/gsm8k_lora.py b/examples/gsm8k-lora/gsm8k_lora.py index 97a9c06ab..1c0ff3446 100644 --- a/examples/gsm8k-lora/gsm8k_lora.py +++ b/examples/gsm8k-lora/gsm8k_lora.py @@ -119,7 +119,7 @@ def _prepare_example(ex: dict) -> LmExample: loss_mask = loss_mask & (targets != tokenizer.pad_token_id) else: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32) - lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask) + lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=tokenizer.eos_token_id) return lm_ex return dataset.map(_prepare_example) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 1532a7d06..13c7ea44b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -222,15 +222,18 @@ def __init__( dataset: AsyncDataset[np.ndarray], QPos: Axis, KPos: Axis, + *, fcm_prob: float = 0.0, key: Optional[PRNGKey] = None, ignore_index: Optional[int] = None, + eos_id: Optional[int] = None, ): self.dataset = dataset self.QPos = QPos self.KPos = KPos self.fcm_prob = fcm_prob self.ignore_id = ignore_index + self.eos_id = eos_id self.key = key if self.fcm_prob > 0.0 and self.key is None: @@ -241,7 +244,7 @@ def __init__( @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_lm_example(tokens, key): tokens = hax.named(tokens, self.QPos) - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id, eos_id=eos_id) if self.fcm_prob > 0: # masks for attention @@ -802,22 +805,22 @@ def _prepare_supervised_examples(ex: list[dict], tokenizer: PreTrainedTokenizerB out = [] for ids, len in zip(truncated, lens): - causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id) + causal = _mk_sup_example_jit(Pos, hax.named(ids, Pos), len, tokenizer.pad_token_id, tokenizer.eos_token_id) out.append(causal) return out -@functools.partial(jax.jit, static_argnums=(0, 3)) -def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id): +@functools.partial(jax.jit, static_argnums=(0, 3, 4)) +def _mk_sup_example_jit(Pos, input_ids: hax.NamedArray, sources_len, pad_token_id, eos_id): # mask out 
padding and anything before the start of the target loss_mask = hax.arange(Pos) >= sources_len - 1 # don't predict the padding targets = hax.roll(input_ids, -1, Pos) loss_mask = loss_mask & (targets != pad_token_id) loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_)) - return LmExample.causal(input_ids, loss_mask=loss_mask) + return LmExample.causal(input_ids, loss_mask=loss_mask, eos_id=eos_id) def mk_supervised_datasets( diff --git a/src/levanter/main/doremi_lm.py b/src/levanter/main/doremi_lm.py index 742c3229c..d84e362f9 100644 --- a/src/levanter/main/doremi_lm.py +++ b/src/levanter/main/doremi_lm.py @@ -110,11 +110,23 @@ def init_proxy_model(): valid_datasets = config.data.validation_sets(ref_model.Pos.size) causal_train_datasets = { - k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + k: CausalLmDataset( + v, + config.model.Pos, + config.model.KeyPos, + ignore_index=config.data.ignore_token_id, + eos_id=tokenizer.eos_token_id, + ) for k, v in train_datasets.items() } valid_datasets = { - k: CausalLmDataset(v, config.model.Pos, config.model.KeyPos, ignore_index=config.data.ignore_token_id) + k: CausalLmDataset( + v, + config.model.Pos, + config.model.KeyPos, + ignore_index=config.data.ignore_token_id, + eos_id=tokenizer.eos_token_id, + ) for k, v in valid_datasets.items() } diff --git a/src/levanter/main/eval_lm.py b/src/levanter/main/eval_lm.py index 116a08f18..a4b3a9516 100644 --- a/src/levanter/main/eval_lm.py +++ b/src/levanter/main/eval_lm.py @@ -49,7 +49,9 @@ def main(config: EvalLmConfig): KeyPos = config.model.KeyPos if config.eval_on_train: - raw_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=jax.random.PRNGKey(0)), Pos, KeyPos) + raw_dataset = CausalLmDataset( + config.data.train_set(Pos.size, key=jax.random.PRNGKey(0)), Pos, KeyPos, eos_id=tokenizer.eos_token_id + ) else: validation_set = config.data.validation_set(Pos.size) if validation_set is None: diff --git a/src/levanter/main/lora_lm.py b/src/levanter/main/lora_lm.py index 9eee109fe..966b4fc1e 100644 --- a/src/levanter/main/lora_lm.py +++ b/src/levanter/main/lora_lm.py @@ -85,7 +85,9 @@ def main(config: LoraLmConfig): if len(eval_datasets) == 0: logger.warning("No evaluation datasets provided.") - train_dataset = CausalLmDataset(config.data.train_set(Pos.size, key=data_key), Pos, KeyPos) + train_dataset = CausalLmDataset( + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, eos_id=tokenizer.eos_token_id + ) train_loader = trainer.data_loader(train_dataset, Batch) # load the underlying hf model @@ -121,7 +123,7 @@ def loraize_hf_model(model): logger.info(f"Fraction of parameters that are trainable: {just_lora_params * 1.0 / all_param_count:.3e}") for name, eval_dataset in eval_datasets.items(): - eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos) + eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) trainer.add_eval_hook(eval_dataset, name=name) # boilerplate hooks and such @@ -129,7 +131,9 @@ def loraize_hf_model(model): logger.warning("No evaluation datasets provided.") for name, eval_dataset in eval_datasets.items(): - eval_dataset = CausalLmDataset(eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id) + eval_dataset = CausalLmDataset( + eval_dataset, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id + ) trainer.add_eval_hook(eval_dataset, name=name) trainer.add_hook(callbacks.log_performance_stats(Pos.size, 
trainer.config.train_batch_size), every=1) diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 99165c017..423ecb938 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -131,6 +131,7 @@ def main(config: TrainLmConfig): Pos, KeyPos, ignore_index=config.data.ignore_token_id, + eos_id=tokenizer.eos_token_id, ) # add epoch logging if epochs specified @@ -199,7 +200,12 @@ def main(config: TrainLmConfig): logger.warning("No evaluation datasets provided.") else: causal_datasets = [ - (CausalLmDataset(ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id), tags) + ( + CausalLmDataset( + ds, Pos, KeyPos, ignore_index=config.data.ignore_token_id, eos_id=tokenizer.eos_token_id + ), + tags, + ) for ds, tags in tagged_eval_datasets ] cb = levanter.eval.cb_tagged_lm_evaluate( diff --git a/src/levanter/main/viz_logprobs.py b/src/levanter/main/viz_logprobs.py index b00ba61d5..e0aa1596d 100644 --- a/src/levanter/main/viz_logprobs.py +++ b/src/levanter/main/viz_logprobs.py @@ -46,7 +46,7 @@ def main(config: VizGpt2Config): eval_loader = DataLoader( EvalBatch, - CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos), # type: ignore + CausalLmDataset(config.data.validation_set(Pos.size), Pos, KeyPos, eos_id=tokenizer.eos_token_id), # type: ignore 32, config.trainer.device_mesh, config.trainer.compute_axis_mapping, diff --git a/src/levanter/models/attention.py b/src/levanter/models/attention.py index 0ae9a79f7..7044200ba 100644 --- a/src/levanter/models/attention.py +++ b/src/levanter/models/attention.py @@ -7,6 +7,7 @@ import equinox as eqx import jax import jax.numpy as jnp +from jax.experimental.pallas.ops.tpu.splash_attention import SegmentIds from jax.experimental.shard_map import shard_map from jax.sharding import PartitionSpec from jaxtyping import PRNGKeyArray @@ -439,9 +440,9 @@ def _bin_and_group_axes_by_function(q, k, v, QPos, KPos, Key): NVTE and the Splash Attention kernel require the Q, K, and V to be in a specific format. This function groups the axes of Q, K, and V into the right bins to match that format. - NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD - the size of the axes is a bit flexible, - with the following conditions: + NVTE requires Q, K, and V to have shape BSHD (Batch, Sequence, Head, Embed), while Splash Attention requires BHSD. + + The size of the axes is a bit flexible, with the following conditions: - B must be the same for all (TODO: is this true?) - S must be the same for K and V. Q's S can be different - H: Q's H must be a multiple of K's H (for GQA or MQA) @@ -543,6 +544,17 @@ def _unflatten_bshd(attn_output, q_class, v_class): return attn_output +def _materialize_segment_mask(segment_ids, QPos, KPos, q_slice, k_slice) -> NamedArray: + """ + Make a segment mask for attention. This is a mask that prevents attention between different segments. + """ + kv_segment_ids = segment_ids.rename({QPos: KPos})[KPos, k_slice] + q_segment_ids = segment_ids[QPos, q_slice] + sub_KPos = kv_segment_ids.resolve_axis(KPos.name) + + return q_segment_ids.broadcast_axis(sub_KPos) == kv_segment_ids + + class AttentionMask(eqx.Module): """ @@ -552,18 +564,27 @@ class AttentionMask(eqx.Module): Represents an attention mask in a structured way to make it easier to optimize attention for particular use cases (causal, prefix, etc.). It is anticipated that this will be extended with new types of masks as needed. 
+ The abstraction is based on two concepts: + + 1) Materialization: An AttentionMask can be materialized for a particular slice of the query and key position axes. + Most naively, you can just get the whole mask as a NamedArray. However, in some cases, you might want to + only get a particular chunk (e.g. for flash attention). + 2) Combination: AttentionMasks are represented as an implicit conjunction of multiple masks, each with different + kinds of structure. You can combine masks with `&` and `|`. Due to the way jit works, we don't use inheritance + or similar to represent different kinds of masks. Instead, we use a single class with different fields. + In general, it should be safe to batch Attention Masks, but it is important that *all members of a batch have the - same sequence of combined masks*. Otherwise, the batching will not work and you'll get weird errors + same set of combined masks*. Otherwise, the batching will not work and you'll get weird errors - The interface exposed by this class is designed to work well with the attention functions in this module as - well as something like flash attention. + (Perhaps it's ok to use inheritance here? I'm not sure. Splash attention landed on inheritance, so maybe + that's a good sign.) - A mask can be materialized, in which case it returns the mask as a NamedArray. """ is_causal: bool = eqx.static_field() explicit_mask: Optional[NamedArray] = None - # TODO: add sequence packing + segment_ids: Optional[NamedArray] = None + # CF https://github.com/jax-ml/jax/blob/47858c4ac2fd4757a3b6fc5bb2981b71a71f00c2/jax/experimental/pallas/ops/tpu/flash_attention.py#L34 # TODO: add prefixlm # cf https://github.com/google-research/t5x/blob/51a99bff8696c373cc03918707ada1e98cbca407/t5x/examples/decoder_only/layers.py#L978 @@ -589,7 +610,13 @@ def materialize( else: explicit = None - return combine_masks_and(causal, explicit) + mask = combine_masks_and(causal, explicit) + + if self.segment_ids is not None: + segment_mask = _materialize_segment_mask(self.segment_ids, QPos, KPos, q_slice, k_slice) + mask = combine_masks_and(mask, segment_mask) + + return mask @staticmethod def causal() -> "AttentionMask": @@ -599,15 +626,38 @@ def causal() -> "AttentionMask": def explicit(mask: NamedArray) -> "AttentionMask": return AttentionMask(is_causal=False, explicit_mask=mask) + def with_segment_ids(self, segment_ids: NamedArray) -> "AttentionMask": + return AttentionMask(is_causal=self.is_causal, explicit_mask=self.explicit_mask, segment_ids=segment_ids) + def __and__(self, other) -> "AttentionMask": - is_causal = self.is_causal and other.is_causal + is_causal = self.is_causal or other.is_causal explicit_mask = combine_masks_and(self.explicit_mask, other.explicit_mask) - return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask) + segment_ids = self._check_for_same_segment_ids(other) + + return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids) def __or__(self, other) -> "AttentionMask": - is_causal = self.is_causal or other.is_causal + is_causal = self.is_causal and other.is_causal explicit_mask = combine_masks_or(self.explicit_mask, other.explicit_mask) - return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask) + segment_ids = self._check_for_same_segment_ids(other) + return AttentionMask(is_causal=is_causal, explicit_mask=explicit_mask, segment_ids=segment_ids) + + def _check_for_same_segment_ids(self, other): + if self.segment_ids is not None and other.segment_ids is not None: + # only one segment 
mask is allowed + # b/c we might do this in jit, we use eqx.error_if + # in theory we can do this one by just assigning unique ids to each unique pair... + # (but i don't really anticipate needing this) + segment_ids = eqx.error_if( + self.segment_ids, + not haliax.all(self.segment_ids == other.segment_ids), + "Only one segment mask is allowed", + ) + elif self.segment_ids is not None: + segment_ids = self.segment_ids + else: + segment_ids = other.segment_ids + return segment_ids @overload @@ -661,7 +711,6 @@ def materialize_mask( # TODO: padding mask # TODO: FCM mask? -# TODO: sequence packing mask def _try_tpu_splash_attention( @@ -819,6 +868,24 @@ def flatten(axes): physical_axes_k = _physical_axis_for_binning(k_class) physical_axes_v = _physical_axis_for_binning(v_class) + # segment_ids + segment_ids = mask.segment_ids if isinstance(mask, AttentionMask) else None + physical_axes_segments = pspec_for_axis(segment_ids.axes) if segment_ids is not None else None + # do we have a batch axis in segment_ids? (needed for vmap below) + if segment_ids is not None: + index_of_seq_dim = segment_ids.axes.index(QPos) + other_indices = [i for i in range(len(segment_ids.axes)) if i != index_of_seq_dim] + if len(other_indices) > 1: + raise NotImplementedError( + f"Only one batch axis is supported in segment_ids right now (got {segment_ids.axes})" + ) + elif len(other_indices) == 1: + segment_batch_axis = other_indices[0] + else: + segment_batch_axis = None + else: + segment_batch_axis = None + # MaxText uses a block size of 512 block_size = block_size or 512 @@ -830,11 +897,12 @@ def flatten(axes): physical_axes_q, physical_axes_k, physical_axes_v, + physical_axes_segments, ), out_specs=physical_axes_q, check_rep=False, ) - def wrap_flash_attention(q, k, v): + def wrap_flash_attention(q, k, v, segment_ids): # NB: inside the function, q, k, and v are partitioned, so in general the lengths of dims are not the same Sq = q.shape[2] Sk = k.shape[2] @@ -850,6 +918,11 @@ def wrap_flash_attention(q, k, v): block_kv_dq=min(block_size, Sq), ) + if mask.segment_ids is not None: + # for now only support self attention + segment_ids = segment_ids.array + segment_ids = SegmentIds(segment_ids, segment_ids) + if mask is None: base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) elif isinstance(mask, AttentionMask): @@ -857,10 +930,10 @@ def wrap_flash_attention(q, k, v): base_mask = splash_attention_mask.CausalMask(shape=(Sq, Sk)) else: base_mask = splash_attention_mask.FullMask(_shape=(Sq, Sk)) - # This is going to be a pain to support if mask.explicit_mask is not None: raise NotImplementedError("Explicit masks are not yet supported for splash attention") + elif isinstance(mask, NamedArray): raise NotImplementedError("NamedArray masks are not yet supported for splash attention") else: @@ -876,9 +949,11 @@ def wrap_flash_attention(q, k, v): q = q.astype(attention_dtype) k = k.astype(attention_dtype) v = v.astype(attention_dtype) - return jax.vmap(splash_kernel)(q, k, v, segment_ids=None) + return jax.vmap( + lambda q, k, v, si: splash_kernel(q, k, v, segment_ids=si), in_axes=(0, 0, 0, segment_batch_axis) + )(q, k, v, segment_ids) - attn_output = wrap_flash_attention(q_, k_, v_) + attn_output = wrap_flash_attention(q_, k_, v_, segment_ids) attn_output = haliax.named(attn_output, ("B", "H", "S", "D")) # the output shape is B, S_q, H_q, D_v. 
Right now we're requiring D_k == D_v diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 7e51ecadd..a77baef0d 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -25,7 +25,11 @@ class LmExample(eqx.Module): @staticmethod def causal( - tokens: hax.NamedArray, *, loss_mask: Optional[hax.NamedArray] = None, ignore_id: Optional[int] = None + tokens: hax.NamedArray, + *, + loss_mask: Optional[hax.NamedArray] = None, + ignore_id: Optional[int] = None, + eos_id: Optional[int] = None, ) -> "LmExample": if tokens.ndim != 1: raise ValueError("tokens must be a 1D array") @@ -45,6 +49,15 @@ def causal( loss_mask = loss_mask * ignore_mask attn_mask = AttentionMask.causal() + + if eos_id is not None: + # the next token after an eos token is in a new segment + eos_mask = hax.roll(tokens, 1, Pos) == eos_id + # first token is always in segment 0 + eos_mask = eos_mask.at[Pos, 0].set(False).astype(jnp.int32) + segment_ids = hax.cumsum(eos_mask, axis=Pos) + attn_mask = attn_mask.with_segment_ids(segment_ids) + return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) diff --git a/tests/test_attention.py b/tests/test_attention.py index 7defcb4a0..140e3cf7c 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,16 +1,21 @@ import jax import jax.numpy as jnp import jax.random as jrandom +import numpy as np import pytest from chex import assert_trees_all_close +from jax.sharding import Mesh import haliax as hax +from haliax import Axis from levanter.models.attention import ( + AttentionBackend, AttentionMask, _bin_and_group_axes_by_function, _te_flash_attention, _tpu_splash_attention, + dot_product_attention, ) from test_utils import skip_if_module_missing @@ -214,3 +219,46 @@ def test_tpu_splash_attention(): hax_out = hax.nn.attention.dot_product_attention(KPos, Key, q, k, v, mask=mask.materialize(QPos, KPos)) assert hax_out.axes == flash_out.axes assert_trees_all_close(hax_out.array, flash_out.array, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("impl", ["default", "jax_flash", "vanilla"]) +def test_segment_ids_are_respected(impl): + # test that we can't attend to something outside of the range + # splash needs 128 + D = 128 if impl == "default" else 2 + L = 256 + Pos = Axis("Pos", L) + Head = Axis("Head", D) + + keys = np.zeros((L, D), dtype=np.float32) + keys[0, 0] = 100.0 # really want to attend to this + values = np.zeros((L, D), dtype=np.float32) + values[0, 1] = 300.0 # check if we did attend + KPos = Pos.alias("KPos") + + query = np.ones((L, D), dtype=np.float32) + + query = hax.named(query, (Pos, Head)) + keys = hax.named(keys, (KPos, Head)) + values = hax.named(values, (KPos, Head)) + + query, keys, values = jax.device_put( + [query, keys, values], jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1) + ) + + segment_ids = np.array([0, 0, 0] + [1] * (L - 3), dtype=np.int32) + segment_ids = jax.device_put(segment_ids, jax.sharding.PositionalSharding(jax.devices())) + segment_ids = hax.named(segment_ids, (Pos,)) + mask = AttentionMask(is_causal=True, segment_ids=segment_ids) + + devices = jax.devices() + + with Mesh(devices, ("dp",)): + result = hax.named_jit(dot_product_attention)( + Pos, KPos, Head, query, keys, values, attn_backend=AttentionBackend(impl), mask=mask, flash_block_size=128 + ) + + # the first 3 positions should all have a value of 300.0 + assert_trees_all_close(result.array[0:3, 1], 300.0, atol=1e-3, rtol=1e-3) + # the rest should be 0 + 
assert_trees_all_close(result.array[3:, 1], 0.0, atol=1e-3, rtol=1e-3) diff --git a/tests/test_hf_gpt2_serialize.py b/tests/test_hf_gpt2_serialize.py index 87279d25b..9cc46ca0d 100644 --- a/tests/test_hf_gpt2_serialize.py +++ b/tests/test_hf_gpt2_serialize.py @@ -134,7 +134,7 @@ def torch_loss(model, input_ids) -> torch.Tensor: torch_out = torch_loss(torch_model, torch.from_numpy(onp.array(input.array)).to(torch.int64).unsqueeze(0)) def compute_loss(model: LmHeadModel, input_ids): - example = LmExample.causal(input_ids) + example = LmExample.causal(input_ids, eos_id=converter.tokenizer.eos_token_id) return compute_next_token_loss(model, example, key=None).scalar() jax_compute_grad = equinox.filter_value_and_grad(compute_loss, has_aux=False) diff --git a/tests/test_text.py b/tests/test_text.py index f293a9429..63d0afedb 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -30,9 +30,10 @@ def test_lm_example_handles_ignore_id(): tokens = hax.arange(Pos, dtype=jnp.int32) ignore_id = 6 + eos_id = 10 - ex_ignore = LmExample.causal(tokens, ignore_id=ignore_id) - ex_no_ignore = LmExample.causal(tokens) + ex_ignore = LmExample.causal(tokens, ignore_id=ignore_id, eos_id=eos_id) + ex_no_ignore = LmExample.causal(tokens, eos_id=eos_id) assert ex_ignore.loss_mask[Pos, ignore_id - 1] == 0 logits = hax.ones((Pos, Embed))
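
As a quick illustration of the sequence-packing behaviour this patch introduces, here is a minimal sketch that relies only on the APIs added above (LmExample.causal(..., eos_id=...), AttentionMask.with_segment_ids, and AttentionMask.materialize). The axis name, token values, and eos id are arbitrary placeholders for illustration, not values taken from the patch:

    import jax.numpy as jnp
    import haliax as hax
    from levanter.models.lm_model import LmExample

    Pos = hax.Axis("position", 8)
    eos_id = 2  # placeholder eos token id, for illustration only

    # two packed documents, [5, 6, 2] and [7, 8, 9, 2], followed by one pad token
    tokens = hax.named(jnp.array([5, 6, 2, 7, 8, 9, 2, 0], dtype=jnp.int32), Pos)

    ex = LmExample.causal(tokens, eos_id=eos_id)

    # each token after an eos starts a new segment, so the expected segment ids
    # are [0, 0, 0, 1, 1, 1, 1, 2]
    print(ex.attn_mask.segment_ids.array)

    # materializing combines the causal mask with the segment mask, so tokens in
    # the second document cannot attend back into the first
    KPos = Pos.alias("key_position")
    dense_mask = ex.attn_mask.materialize(Pos, KPos)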