Improvements to v2 transformer implementation
- Rename Transformer to TransformerLM
- Infer attention mask from token positions
- Add support for sliding window attention
- Add support for custom RMS norm epsilon

PiperOrigin-RevId: 647423781
danieldjohnson authored and Penzai Developers committed Jun 27, 2024
1 parent 19fd13f commit b553f1f
Showing 14 changed files with 335 additions and 168 deletions.
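To illustrate the renamed API described above, here is a minimal usage sketch (the `score_tokens` helper is hypothetical, not part of this commit). With this change, token positions default to 0, 1, 2, ... along the "seq" axis and the causal mask is applied by layers inside the model, so ordinary unpadded sequences can be scored without building any side inputs by hand:

from penzai.experimental.v2 import pz
from penzai.experimental.v2.models.transformer import model_parts


def score_tokens(
    model: model_parts.TransformerLM, tokens: pz.nx.NamedArray
) -> pz.nx.NamedArray:
  """Scores log-probabilities for ordinary (unpadded, unpacked) sequences."""
  # Token positions are inferred as 0, 1, 2, ... along "seq", and the causal
  # mask is applied by the ApplyCausalAttentionMask layers inside the model,
  # so no explicit side inputs are needed.
  return model(tokens)
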
4 changes: 3 additions & 1 deletion docs/api/penzai.experimental.v2.rst
@@ -151,7 +151,9 @@ Language Modeling
.. autosummary::
pz.nn.Attention
pz.nn.KVCachingAttention
pz.nn.ApplyAttentionMask
pz.nn.ApplyExplicitAttentionMask
pz.nn.ApplyCausalAttentionMask
pz.nn.ApplyCausalSlidingWindowAttentionMask
pz.nn.EmbeddingTable
pz.nn.EmbeddingLookup
pz.nn.EmbeddingDecode
79 changes: 16 additions & 63 deletions penzai/experimental/v2/models/transformer/model_parts.py
@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core components of a Transformer model.
"""Core components of a Transformer language model.
Specific instantiations of the Transformer model will use the following axis
Specific instantiations of the TransformerLM model will use the following axis
naming conventions:
* "seq" is the temporal axis of the token sequence, i.e. the axis along which
@@ -42,29 +42,21 @@
of the keys and values in an attention matrix.
* "neurons" is the axis for the neurons in the MLP blocks, which have an
activation function (GEGLU) applied elementwise and therefore have a
activation function applied elementwise and therefore have a
privileged basis.
Additionally, they use the following side input names:
* "token_positions" is the name of the side input that provides the position of
each token for the purposes of positional embeddings.
* "attn_mask" is the name of the side input that provides the attention mask
for each attention layer.
* Where applicable, "cache_end_index" is the name of the side input that
identifies the current length of the key/value cache state. This determines
where the new keys and values are inserted into the cache. The top-level
`KVCachingTransformer` class will usually handle this for you.
each token for the purposes of positional embeddings and causal attention
masking. -1 indicates a padding token.
The KV caching logic is defined in the separate module `sampling_mode`.
"""

from __future__ import annotations

import dataclasses
from typing import Any

import jax
from penzai.experimental.v2 import pz
@@ -111,7 +103,7 @@ class TransformerBlock(pz.nn.Sequential):


@pz.pytree_dataclass
class Transformer(pz.nn.Layer):
class TransformerLM(pz.nn.Layer):
"""Top-level transformer decoder wrapper.
This class is a simple wrapper that holds configuration data and runs safety
@@ -130,64 +122,25 @@ class Transformer(pz.nn.Layer):
def __call__(
self,
tokens: pz.nx.NamedArray,
token_positions: pz.nx.NamedArray,
attn_mask: pz.nx.NamedArray,
**extra_side_inputs,
*,
token_positions: pz.nx.NamedArray | None = None,
**side_inputs,
) -> pz.nx.NamedArray:
"""Scores log-probabilities for the given inputs.
For simple sequences, the ``token_positions`` and ``attn_mask`` arguments
can be computed via `simple_causal_side_inputs`.
Args:
tokens: Array of token IDs, as an integer named array with a "seq" axis
and possibly batch axes. Usually starts with the beginning-of-sequence
token.
token_positions: Sequence of token positions, as an integer named array
with a "seq" axis and possibly batch axes. Usually starts from 0 and
increments along the "seq" axis, but can be different to support e.g.
example packing.
attn_mask: Boolean attention mask with "seq" and "kv_seq" axes of the same
length, and possibly batch axes. Usually a causal mask, but can be
different to support e.g. example packing or dropping out inputs.
**extra_side_inputs: Other side inputs, which will be forwarded to the
body.
token_positions: Array of token positions, as an integer named array with
a "seq" axis and possibly batch axes. Usually starts with 0. Inferred to
start from 0 and increment along the "seq" axis if not provided.
**side_inputs: Side inputs, which will be forwarded to the body.
Returns:
The final matrix of logits from the embedding decoding layer, which
(in the normal configuration) will have axes "seq" and "vocabulary".
"""

return self.body(
tokens,
token_positions=token_positions,
attn_mask=attn_mask,
**extra_side_inputs,
)

def simple_causal_side_inputs(
self, tokens: pz.nx.NamedArray
) -> dict[str, Any]:
"""Builds a side-input dictionary for a batch of single segments.
This can be used to process inputs that do not need advanced position or
attention mask handling, and which just consist of ordinary sequences that
are not packed together or padded.
Args:
tokens: Sequence of tokens, as an integer named array with a "seq" axis
and possibly batch axes, which starts with the beginning-of-sequence
token. Each 1d vector along the "seq" axis should represent an unpadded
sequence.
Returns:
A dictionary with key "positions" mapping to a simple incrementing
position array and key "attn_mask" mapping to a causal mask, suitable for
passing as side (keyword) arguments to this model.
"""
seq = tokens.named_shape["seq"]
# Query tokens can attend to keys/values if the query position is larger.
return {
"token_positions": pz.nx.arange("seq", seq),
"attn_mask": pz.nx.arange("seq", seq) >= pz.nx.arange("kv_seq", seq),
}
if token_positions is None:
token_positions = pz.nx.arange("seq", tokens.named_shape["seq"])
return self.body(tokens, token_positions=token_positions, **side_inputs)
70 changes: 45 additions & 25 deletions penzai/experimental/v2/models/transformer/sampling_mode.py
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sampling-mode adapters for Transformer models.
"""Sampling-mode adapters for TransformerLM models.
This file includes the kv-cache sampling mode of the base Transformer model.
This mode is intended to be hot-swapped for the main Transformer implementation:
you should generally start by loading a `model_parts.Transformer` and then
converting it to a `KVCachingTransformer` using
`KVCachingTransformer.from_uncached`.
This file includes the kv-cache sampling mode of the base TransformerLM model.
This mode is intended to be hot-swapped for the main TransformerLM
implementation: you should generally start by loading a
`model_parts.TransformerLM` and then converting it to a `KVCachingTransformerLM`
using `KVCachingTransformerLM.from_uncached`.
The layers defined here follow the same conventions documented in the module
docstring for `model_parts`.
docstring for `model_parts`. In addition:
* Where applicable, "kv_token_positions" is the name of the side input that
provides the position of each token for the purposes of positional embeddings.
* Where applicable, "cache_end_index" is the name of the side input that
identifies the current length of the key/value cache state.
"""

from __future__ import annotations
@@ -36,14 +42,17 @@


@pz.pytree_dataclass
class KVCachingTransformer(pz.nn.Layer):
class KVCachingTransformerLM(pz.nn.Layer):
"""Top-level transformer in (stateful) cached autoregressive sampling mode.
This class represents the sampling mode of the model, and manages the sampling
state. It is designed to be loaded from an existing `Transformer`. If you want
to load this from the pretrained checkpoint, first load a `Transformer`, then
call `KVCachingTransformer.from_uncached`.
This class handles and automatically increments token positions based on the
tokens it has generated so far.
Attributes:
body: The implementation of the transformer. Usually a nested set of state
and side-effect handlers wrapping the main sequence of transformer blocks,
@@ -81,8 +90,8 @@ def __call__(
and possibly batch axes. The batch axes must match the `batch_axes`
attribute. Padding tokens are ignored.
**extra_side_inputs: Extra side inputs, which will be forwarded on to the
body. The "token_positions", "attn_mask", and "cache_end_index" inputs
will be added automatically and do not need to be provided.
body. The "token_positions", "kv_token_positions", and "cache_end_index"
inputs will be added automatically and do not need to be provided.
Returns:
Matrix of logits from the embedding decoding layer, which (in the
@@ -108,27 +117,19 @@ def __call__(
kv_nonpad_so_far_inclusive = pz.nx.nmap(jnp.cumsum)(
kv_nonpad_mask.untag("seq"), dtype=jnp.int32
).tag("seq")
kv_nonpad_so_far_exclusive = (
kv_nonpad_so_far_inclusive - kv_nonpad_mask.astype(jnp.int32)
key_value_positions = pz.nx.nmap(jnp.where)(
kv_nonpad_mask, kv_nonpad_so_far_inclusive - 1, -1
)
query_positions = pz.nx.nmap(jax.lax.dynamic_slice)(
kv_nonpad_so_far_exclusive.untag("seq"),
key_value_positions.untag("seq"),
(self.cache_end_index.value,),
(tokens.named_shape["seq"],),
).tag("seq")
key_value_positions = kv_nonpad_so_far_exclusive.untag("seq").tag("kv_seq")
# Tokens can attend to any kv-token position that they are after, as long as
# it was NOT padding.
attention_mask = (
(query_positions >= key_value_positions)
& (tokens != self.pad_id)
& kv_nonpad_mask.untag("seq").tag("kv_seq")
)
# Run the model.
outs = self.body(
tokens,
token_positions=query_positions,
attn_mask=attention_mask,
kv_token_positions=key_value_positions,
cache_end_index=self.cache_end_index.value,
**extra_side_inputs,
)
@@ -143,12 +144,12 @@ def __call__(
@classmethod
def from_uncached(
cls,
uncached: model_parts.Transformer,
uncached: model_parts.TransformerLM,
cache_len: int,
batch_axes: dict[str, int],
pad_id: int = 0,
variable_name_prefix: str = "sampler",
) -> KVCachingTransformer:
) -> KVCachingTransformerLM:
"""Transforms a `Transformer` into cached sampling mode.
This constructor hot-swaps all `pz.nn.Attention` layers in the
@@ -167,6 +168,17 @@ def from_uncached(
Returns:
A KVCachingTransformer.
"""

def _fix_attn_mask(masker):
if masker.kv_positions_input_name != "token_positions":
raise ValueError(
"Could not automatically convert attention mask layer with"
f" non-standard positions input name: {masker}"
)
return dataclasses.replace(
masker, kv_positions_input_name="kv_token_positions"
)

cached_axes = {
**batch_axes,
**uncached.metadata.common_head_axes,
@@ -175,8 +187,16 @@
attn_sel = pz.select(uncached.body).at_instances_of(pz.nn.Attention)
fixed_attns = {}
for ix, (keypath, attn) in enumerate(attn_sel.selected_by_path.items()):
attn_with_new_kv_positions = (
pz.select(attn)
.at_instances_of(
pz.nn.ApplyCausalAttentionMask
| pz.nn.ApplyCausalSlidingWindowAttentionMask
)
.apply(_fix_attn_mask)
)
fixed_attns[keypath] = pz.nn.KVCachingAttention.from_uncached(
attn,
attn_with_new_kv_positions,
cache_len=cache_len,
cached_axes=cached_axes,
cache_dtype=uncached.metadata.activation_dtype,
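Tying the renames together, here is a hedged conversion sketch (the `make_sampler` helper and the "batch" axis name are assumptions, not part of this commit) showing how a loaded `TransformerLM` would be hot-swapped into caching mode via `from_uncached`:

from penzai.experimental.v2.models.transformer import model_parts
from penzai.experimental.v2.models.transformer import sampling_mode


def make_sampler(
    model: model_parts.TransformerLM, cache_len: int, batch_size: int
) -> sampling_mode.KVCachingTransformerLM:
  """Hot-swaps a TransformerLM into stateful KV-caching sampling mode."""
  # from_uncached replaces each pz.nn.Attention with a KVCachingAttention and
  # rewires the causal-mask layers to read the "kv_token_positions" side input.
  return sampling_mode.KVCachingTransformerLM.from_uncached(
      model,
      cache_len=cache_len,
      batch_axes={"batch": batch_size},  # the "batch" axis name is an assumption
      pad_id=0,
  )
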
@@ -26,7 +26,7 @@


def temperature_sample_pyloop(
model: sampling_mode.KVCachingTransformer,
model: sampling_mode.KVCachingTransformerLM,
prompt: pz.nx.NamedArray,
rng: jax.Array,
temperature: float = 1.0,
4 changes: 2 additions & 2 deletions penzai/experimental/v2/models/transformer/variants/gemma.py
@@ -35,7 +35,7 @@ def gemma_from_pretrained_checkpoint(
ckpt_params: dict[str, Any],
upcast_activations_to_float32: bool = False,
use_layer_stack: bool = False,
) -> model_parts.Transformer:
) -> model_parts.TransformerLM:
"""Builds a Gemma model from a pretrained checkpoint.
The parameters of the loaded ``Transformer`` will be close to those in
@@ -77,7 +77,7 @@ def gemma_from_pretrained_checkpoint(
else:
activation_dtype = attn_0_einsum_param.dtype

config = llamalike_common.LLamalikeTransformerConfig(
config = llamalike_common.LlamalikeTransformerConfig(
num_kv_heads=1 if single_kv_head else num_heads,
query_head_multiplier=num_heads if single_kv_head else 1,
embedding_dim=embed_dim,
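As a usage sketch for the updated Gemma signature (the `load_gemma_lm` helper is hypothetical; obtaining `ckpt_params` from the official Gemma checkpoint utilities is assumed and not shown here):

from typing import Any

from penzai.experimental.v2.models.transformer import model_parts
from penzai.experimental.v2.models.transformer.variants import gemma


def load_gemma_lm(ckpt_params: dict[str, Any]) -> model_parts.TransformerLM:
  """Builds the renamed TransformerLM type from raw Gemma checkpoint params."""
  # ckpt_params is assumed to be the nested parameter dictionary produced by
  # the official Gemma checkpoint loading utilities (loading code not shown).
  return gemma.gemma_from_pretrained_checkpoint(
      ckpt_params, upcast_activations_to_float32=True
  )
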
@@ -193,8 +193,7 @@ def build_gpt_neox_attention(
),
{"seq": "tq", "heads": "h", "kv_seq": "tkv"},
),
pz.nn.ApplyAttentionMask(
mask_input_name="attn_mask",
pz.nn.ApplyCausalAttentionMask(
masked_out_value=masked_out_value,
),
pz.nn.Softmax("kv_seq"),
@@ -276,7 +275,7 @@ def build_gpt_neox_transformer(
config: GPTNeoXTransformerConfig,
init_base_rng: jax.Array | None = None,
name: str = "transformer",
) -> model_parts.Transformer:
) -> model_parts.TransformerLM:
"""Builds a Llama-like transformer model from a configuration.
Args:
@@ -335,7 +334,7 @@
)
)

return model_parts.Transformer(
return model_parts.TransformerLM(
metadata=model_parts.TransformerMetadata(
common_head_axes={"heads": config.num_attention_heads},
query_only_head_axes={},
@@ -356,7 +355,7 @@ def gpt_neox_from_huggingface_model(
model: GPTNeoXForCausalLM,
upcast_activations_to_float32: bool = False,
use_layer_stack: bool = False,
) -> model_parts.Transformer:
) -> model_parts.TransformerLM:
"""Converts a GPT-NeoX model to a Penzai model.
This function converts GPT-NeoX models from their HuggingFace implementations
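A hedged conversion sketch for the updated GPT-NeoX entry point; the module path and the example checkpoint name are assumptions, since the file name is not shown in this excerpt:

from transformers import GPTNeoXForCausalLM

from penzai.experimental.v2.models.transformer import model_parts
# The variants.gpt_neox module path is assumed by analogy with gemma and llama.
from penzai.experimental.v2.models.transformer.variants import gpt_neox

hf_model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m")
pz_model: model_parts.TransformerLM = gpt_neox.gpt_neox_from_huggingface_model(
    hf_model
)
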
3 changes: 1 addition & 2 deletions penzai/experimental/v2/models/transformer/variants/llama.py
@@ -29,7 +29,7 @@ def llama_from_huggingface_model(
model: LlamaForCausalLM,
upcast_activations_to_float32: bool = False,
use_layer_stack: bool = False,
) -> model_parts.Transformer:
) -> model_parts.TransformerLM:
"""Converts a HuggingFace Llama model to a Penzai model.
This function converts Llama models from their HuggingFace
@@ -56,7 +56,6 @@
hf_config = model.config
checked_config_args = dict(
hidden_act="silu",
rms_norm_eps=1e-6,
tie_word_embeddings=False,
rope_scaling=None,
attention_bias=False,
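Finally, a hedged sketch of the Llama conversion path; the checkpoint name is only an example. Because the hard-coded `rms_norm_eps=1e-6` check is removed here, checkpoints using a different RMSNorm epsilon should now convert as well:

from transformers import LlamaForCausalLM

from penzai.experimental.v2.models.transformer import model_parts
from penzai.experimental.v2.models.transformer.variants import llama

# The checkpoint name is an assumption; any Llama-architecture checkpoint
# should work, regardless of its RMSNorm epsilon.
hf_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
pz_model: model_parts.TransformerLM = llama.llama_from_huggingface_model(
    hf_model, upcast_activations_to_float32=False
)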