Skip to content

Commit

Permalink
Implement Llama, Mistral, and GPT-NeoX transformer variants.
Browse files Browse the repository at this point in the history
Adds support for Llama, Mistral, and GPT-NeoX transformer models.
These models can be converted from their corresponding pretrained
model classes from HuggingFace's `transformers` library (the
`XForCausalLM` classes).

Also adds tests to ensure that the converted versions are consistent
with the HuggingFace implementations.
  • Loading branch information
danieldjohnson committed Jun 25, 2024
1 parent 9814374 commit a547365
Show file tree
Hide file tree
Showing 11 changed files with 1,683 additions and 504 deletions.
514 changes: 68 additions & 446 deletions penzai/experimental/v2/models/transformer/variants/gemma.py

Large diffs are not rendered by default.

549 changes: 549 additions & 0 deletions penzai/experimental/v2/models/transformer/variants/gpt_neox.py

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions penzai/experimental/v2/models/transformer/variants/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2024 The Penzai Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Llama architecture transformer variant."""

from __future__ import annotations

from typing import Any

import jax
import jax.numpy as jnp
from penzai.experimental.v2 import pz
from penzai.experimental.v2.models.transformer import model_parts
from penzai.experimental.v2.models.transformer.variants import llamalike_common


# Annotation-only alias for the HuggingFace `LlamaForCausalLM` model class.
# It is bound to `Any` rather than the real class, presumably so this module
# can be imported without `transformers` installed (TODO confirm).
LlamaForCausalLM = Any


def llama_from_huggingface_model(
    model: LlamaForCausalLM,
    upcast_activations_to_float32: bool = False,
    use_layer_stack: bool = False,
) -> model_parts.Transformer:
  """Converts a HuggingFace Llama model to a Penzai model.

  This function converts Llama models from their HuggingFace
  implementations to Penzai. (Other models with the same architecture may also
  be supported if they use the same configuration, but this has not been
  tested.)

  Before converting, the model's config is checked against the settings that
  the conversion assumes. Options that the config object does not define at
  all are skipped instead of failing with an `AttributeError`.

  Args:
    model: The HuggingFace Llama model.
    upcast_activations_to_float32: Whether to cast activations to float32 when
      the model runs. This allows analyzing activations at higher precision
      without consuming additional memory for parameters.
    use_layer_stack: Whether to use a layer stack for the decoder blocks.

  Returns:
    A Transformer model containing the loaded parameters.

  Raises:
    ValueError: If `model` is not an instance of a class named
      `LlamaForCausalLM`, or if one of the checked config options is present
      with a value other than the one the conversion requires.
  """
  # The class is matched by name because the HuggingFace class itself is not
  # imported in this module.
  if type(model).__name__ != "LlamaForCausalLM":
    raise ValueError(
        "llama_from_huggingface_model should be called with a"
        f" LlamaForCausalLM instance, but got {type(model).__name__}."
    )
  # Checkpoint conversion assumes these configuration arguments are set:
  hf_config = model.config
  checked_config_args = dict(
      hidden_act="silu",
      rms_norm_eps=1e-6,
      tie_word_embeddings=False,
      rope_scaling=None,
      attention_bias=False,
      attention_dropout=0.0,
      mlp_bias=False,
  )
  for k, v in checked_config_args.items():
    try:
      actual_value = getattr(hf_config, k)
    except AttributeError:
      # The config does not define this option at all, so there is no
      # conflicting setting to reject; skip it rather than crash.
      continue
    if actual_value != v:
      raise ValueError(
          f"Conversion of a LlamaForCausalLM requires config.{k}={repr(v)}, but"
          f" got {actual_value}"
      )

  return llamalike_common.llamalike_from_huggingface_model(
      model,
      upcast_activations_to_float32=upcast_activations_to_float32,
      use_layer_stack=use_layer_stack,
  )
Loading

0 comments on commit a547365

Please sign in to comment.