diff --git a/pyproject.toml b/pyproject.toml
index 4736fad..c2b01af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,7 +58,6 @@ include = [
     "instructlab.dolomite.hf_models.modeling_utils.normalization.layernorm",
     "instructlab.dolomite.hf_models.modeling_utils.normalization.rmsnorm",
     "instructlab.dolomite.hf_models.modeling_utils.position_embedding",
-    "instructlab.dolomite.gradient_checkpointing",
     "instructlab.dolomite.utils",
 ]
diff --git a/src/instructlab/dolomite/gradient_checkpointing/__init__.py b/src/instructlab/dolomite/gradient_checkpointing/__init__.py
deleted file mode 100644
index 8f648dd..0000000
--- a/src/instructlab/dolomite/gradient_checkpointing/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Third Party
-import torch
-
-# Local
-from ..enums import GradientCheckpointingMethod
-from .block import block_checkpointing
-
-_GRADIENT_CHECKPOINTING_METHODS = {
-    GradientCheckpointingMethod.block: block_checkpointing
-}
-
-
-def apply_gradient_checkpointing(
-    model: torch.nn.Module,
-    gradient_checkpointing_method: GradientCheckpointingMethod,
-    **kwargs,
-) -> None:
-    checkpointing_function = _GRADIENT_CHECKPOINTING_METHODS[
-        gradient_checkpointing_method
-    ]
-    checkpointing_function(model, **kwargs)
diff --git a/src/instructlab/dolomite/gradient_checkpointing/block.py b/src/instructlab/dolomite/gradient_checkpointing/block.py
deleted file mode 100644
index fc834e3..0000000
--- a/src/instructlab/dolomite/gradient_checkpointing/block.py
+++ /dev/null
@@ -1,47 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Standard
-from functools import partial
-
-# Third Party
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointImpl,
-    apply_activation_checkpointing,
-    checkpoint_wrapper,
-)
-import torch
-
-# Local
-from ..utils import get_module_class_from_name
-
-
-def block_checkpointing(
-    model: torch.nn.Module,
-    block_name: str,
-    checkpoint_every: int = 1,
-    use_reentrant: bool = False,
-) -> None:
-    block_class = get_module_class_from_name(model, block_name)
-    block_idx = 0
-
-    def _whether_to_checkpoint(submodule: torch.nn.Module) -> bool:
-        nonlocal block_idx
-
-        if isinstance(submodule, block_class):
-            block_idx += 1
-            if (block_idx - 1) % checkpoint_every == 0:
-                return True
-        return False
-
-    checkpoint_wrapper_function = checkpoint_wrapper
-    if use_reentrant:
-        checkpoint_wrapper_function = partial(
-            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
-        )
-
-    apply_activation_checkpointing(
-        model,
-        checkpoint_wrapper_fn=checkpoint_wrapper_function,
-        check_fn=_whether_to_checkpoint,
-    )
diff --git a/src/instructlab/dolomite/hf_models/__init__.py b/src/instructlab/dolomite/hf_models/__init__.py
index ddf2ee1..66b024c 100644
--- a/src/instructlab/dolomite/hf_models/__init__.py
+++ b/src/instructlab/dolomite/hf_models/__init__.py
@@ -2,8 +2,9 @@
 # Extracted from https://github.com/ibm-granite/dolomite-engine
 # ----------------------------------------------------------------
 # Local
+from .config import GPTDolomiteConfig
 from .model_conversion import export_to_huggingface, import_from_huggingface
-from .models import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel
+from .models import GPTDolomiteForCausalLM, GPTDolomiteModel
 from .register_hf import register_model_classes
 
 register_model_classes()
diff --git a/src/instructlab/dolomite/hf_models/config.py b/src/instructlab/dolomite/hf_models/config.py
index abfb722..0258489 100644
--- a/src/instructlab/dolomite/hf_models/config.py
+++ b/src/instructlab/dolomite/hf_models/config.py
@@ -8,7 +8,8 @@
 from .enums import AttentionHeadType, PositionEmbeddingType
 
 
-class CommonConfig(PretrainedConfig):
+class GPTDolomiteConfig(PretrainedConfig):
+    model_type = "gpt_dolomite"
     keys_to_ignore_at_inference = ["past_key_values"]
     attribute_map = {
         "hidden_size": "n_embd",
@@ -19,6 +20,8 @@
 
     # NOTE: initializer range is kept for backward compatiblity
    #       but it is not used anymore
+    #     : also rope_scaling is not used anymore but kept for
+    #       same reason.
 
     def __init__(
         self,
diff --git a/src/instructlab/dolomite/hf_models/defaults.py b/src/instructlab/dolomite/hf_models/defaults.py
deleted file mode 100644
index 01bd5fd..0000000
--- a/src/instructlab/dolomite/hf_models/defaults.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-DEFAULT_NORMALIZATION_IMPLEMENTATION = "torch"
diff --git a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py
index 8f21040..9bfd4da 100644
--- a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py
+++ b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py
@@ -8,8 +8,8 @@
 from transformers import AutoConfig, AutoTokenizer, GenerationConfig, GPTBigCodeConfig
 
 # Local
+from ..config import GPTDolomiteConfig
 from ..enums import AttentionHeadType, PositionEmbeddingType
-from ..models import GPTDolomiteConfig
 
 
 def import_from_huggingface_bigcode(
diff --git a/src/instructlab/dolomite/hf_models/model_conversion/llama.py b/src/instructlab/dolomite/hf_models/model_conversion/llama.py
index 4715788..ae6de92 100644
--- a/src/instructlab/dolomite/hf_models/model_conversion/llama.py
+++ b/src/instructlab/dolomite/hf_models/model_conversion/llama.py
@@ -6,12 +6,12 @@
 
 # Local
 from ...utils import SafeTensorsWeightsManager, download_repo
+from ..config import GPTDolomiteConfig
 from ..enums import AttentionHeadType
 from ..modeling_utils import (
     interleave_query_key_value_tensor_for_attention,
     split_query_key_value_tensor_for_attention,
 )
-from ..models import GPTDolomiteConfig
 from ..models.gpt_dolomite import (
     interleave_up_gate_tensor_for_mlp,
     split_up_gate_tensor_for_mlp,
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py
index ec0a4ee..1ad7d54 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py
@@ -14,7 +14,5 @@
     repeat_key_value,
     split_query_key_value_tensor_for_attention,
 )
-from .embedding import Embedding
-from .linear import Linear
 from .normalization import RMSNorm, get_normalization_function
-from .position_embedding import Alibi, RoPE, YaRNScaledRoPE, apply_rotary_pos_emb
+from .position_embedding import Alibi, RoPE, apply_rotary_pos_emb
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py
index e522908..fca22e7 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py
@@ -9,7 +9,7 @@
 import torch
 
 # Local
-from ...config import CommonConfig
+from ...config import GPTDolomiteConfig
 from ...enums import AttentionHeadType
 from .base import Attention
 from .flash import FlashAttention2
@@ -48,7 +48,7 @@
 
 
 def get_attention_module(
-    config: CommonConfig,
+    config: GPTDolomiteConfig,
     causal: bool,
     attention_implementation: str,
     use_padding_free_transformer: bool,
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py
index c6fe9cd..4c17941 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py
@@ -5,21 +5,21 @@
 from typing import Tuple
 
 # Third Party
+from torch.nn import Linear  # replaces ParameterizedLinear
 from transformers import DynamicCache
 import torch
 import torch.nn.functional as F
 
 # Local
-from ...config import CommonConfig
+from ...config import GPTDolomiteConfig
 from ...enums import AttentionHeadType, PositionEmbeddingType
-from ..linear import Linear
 from ..position_embedding import apply_rotary_pos_emb
 from .utils import repeat_key_value
 
 
 class Attention(torch.nn.Module):
     def __init__(
-        self, config: CommonConfig, causal: bool, layer_idx: int = None
+        self, config: GPTDolomiteConfig, causal: bool, layer_idx: int = None
     ) -> None:
         super().__init__()
 
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py
deleted file mode 100644
index e33cbb1..0000000
--- a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# pylint: disable=unused-import
-# Third Party
-from torch.nn import Embedding
-
-# NOTE: we have replaced ParameterizedEmbedding with torch.nn.Embedding
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py
deleted file mode 100644
index 0d193f6..0000000
--- a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# pylint: disable=unused-import
-# Third Party
-from torch.nn import Linear
-
-# NOTE: we have replaced ParameterizedLinear with torch.nn.Linear
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py
index b7e5c41..eb68644 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py
@@ -5,8 +5,7 @@
 import torch
 
 # Local
-from .layernorm import get_layernorm
-from .rmsnorm import RMSNorm, get_rmsnorm
+from .norms import RMSNorm, get_layernorm, get_rmsnorm
 
 _NORMALIZATION_FUNCTIONS = {
     "layernorm": get_layernorm,
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py
deleted file mode 100644
index d91707e..0000000
--- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Third Party
-import torch
-
-_LAYERNORM_MODULES = {
-    "torch": torch.nn.LayerNorm,
-}
-
-
-def get_layernorm(
-    normalized_shape: int,
-    eps: float,
-    normalization_implementation: str = "torch",
-) -> torch.nn.LayerNorm:
-    if normalization_implementation in _LAYERNORM_MODULES:
-        return _LAYERNORM_MODULES[normalization_implementation](
-            normalized_shape=normalized_shape, eps=eps
-        )
-
-    raise ValueError(
-        f"unexpected `normalization_implementation` {normalization_implementation}"
-    )
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py
new file mode 100644
index 0000000..f752a6a
--- /dev/null
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/norms.py
@@ -0,0 +1,81 @@
+# ----------------------------------------------------------------
+# Extracted from https://github.com/ibm-granite/dolomite-engine
+# ----------------------------------------------------------------
+
+# Standard
+import numbers
+
+# Third Party
+import torch
+
+# ---------------- LayerNorm ---------------
+
+_LAYERNORM_MODULES = {
+    "torch": torch.nn.LayerNorm,
+}
+
+
+def get_layernorm(
+    normalized_shape: int,
+    eps: float,
+    normalization_implementation: str = "torch",
+) -> torch.nn.LayerNorm:
+    if normalization_implementation in _LAYERNORM_MODULES:
+        return _LAYERNORM_MODULES[normalization_implementation](
+            normalized_shape=normalized_shape, eps=eps
+        )
+
+    raise ValueError(
+        f"unexpected `normalization_implementation` {normalization_implementation}"
+    )
+
+
+# --------------- RMS Norm ---------------
+# ----------------------------------------------------------------
+# Extracted from https://github.com/ibm-granite/dolomite-engine
+# ----------------------------------------------------------------
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
+        super().__init__()
+
+        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
+        self.eps = eps
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = normalized_shape
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input_dtype = input.dtype
+
+        input = input.to(torch.float32)
+        variance = input.pow(2).mean(-1, keepdim=True)
+        input = input * torch.rsqrt(variance + self.eps)
+
+        return self.weight * input.to(input_dtype)
+
+    def extra_repr(self) -> str:
+        return f"{self.normalized_shape}, eps={self.eps}"
+
+    def reset_parameters(self) -> None:
+        torch.nn.init.ones_(self.weight)
+
+
+_RMSNORM_MODULES = {"torch": RMSNorm}
+
+
+def get_rmsnorm(
+    normalized_shape: int,
+    eps: float,
+    normalization_implementation: str = "torch",
+) -> torch.nn.LayerNorm:
+    if normalization_implementation in _RMSNORM_MODULES:
+        return _RMSNORM_MODULES[normalization_implementation](
+            normalized_shape=normalized_shape, eps=eps
+        )
+
+    raise ValueError(
+        f"unexpected `normalization_implementation` {normalization_implementation}"
+    )
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py
deleted file mode 100644
index 0570548..0000000
--- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Third Party
-import torch
-
-# Local
-from .base import RMSNorm
-
-_RMSNORM_MODULES = {"torch": RMSNorm}
-
-
-def get_rmsnorm(
-    normalized_shape: int,
-    eps: float,
-    normalization_implementation: str = "torch",
-) -> torch.nn.LayerNorm:
-    if normalization_implementation in _RMSNORM_MODULES:
-        return _RMSNORM_MODULES[normalization_implementation](
-            normalized_shape=normalized_shape, eps=eps
-        )
-
-    raise ValueError(
-        f"unexpected `normalization_implementation` {normalization_implementation}"
-    )
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py
deleted file mode 100644
index 4a8feb8..0000000
--- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Standard
-import numbers
-
-# Third Party
-import torch
-
-
-class RMSNorm(torch.nn.Module):
-    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
-        super().__init__()
-
-        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
-        self.eps = eps
-
-        if isinstance(normalized_shape, numbers.Integral):
-            normalized_shape = (normalized_shape,)
-        self.normalized_shape = normalized_shape
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input_dtype = input.dtype
-
-        input = input.to(torch.float32)
-        variance = input.pow(2).mean(-1, keepdim=True)
-        input = input * torch.rsqrt(variance + self.eps)
-
-        return self.weight * input.to(input_dtype)
-
-    def extra_repr(self) -> str:
-        return f"{self.normalized_shape}, eps={self.eps}"
-
-    def reset_parameters(self) -> None:
-        torch.nn.init.ones_(self.weight)
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py
index a4d9689..a16ee80 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py
@@ -3,4 +3,4 @@
 # ----------------------------------------------------------------
 # Local
 from .alibi import Alibi
-from .rope import RoPE, YaRNScaledRoPE, apply_rotary_pos_emb
+from .rope import RoPE, apply_rotary_pos_emb
diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py
index f73b5eb..7dd9232 100644
--- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py
+++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py
@@ -5,7 +5,6 @@
 # Standard
 from typing import Tuple
-import math
 
 # Third Party
 import torch
 
@@ -72,66 +71,6 @@ def _set_cos_sin_cache(
         )
 
 
-class YaRNScaledRoPE(RoPE):
-    def __init__(
-        self,
-        head_dim: int,
-        max_position_embeddings: int = 2048,
-        base: int = 10000,
-        scale: float = 1,
-        original_max_position_embeddings: int = 2048,
-        extrapolation_factor: float = 1,
-        attn_factor: float = 1,
-        beta_fast: int = 32,
-        beta_slow: int = 1,
-    ) -> None:
-        torch.nn.Module.__init__(self)
-
-        self.head_dim = head_dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        self.scale = scale
-        self.original_max_position_embeddings = original_max_position_embeddings
-        self.extrapolation_factor = extrapolation_factor
-        self.attn_factor = attn_factor
-        self.beta_fast = beta_fast
-        self.beta_slow = beta_slow
-
-        # Get n-d magnitude scaling corrected for interpolation
-        self.mscale = _yarn_get_mscale(self.scale) * self.attn_factor
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        pos_freqs = self.base ** (
-            torch.arange(0, self.head_dim, 2).float() / self.head_dim
-        )
-        inv_freq_extrapolation = 1.0 / pos_freqs
-        inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)
-
-        low, high = _yarn_find_correction_range(
-            self.beta_fast,
-            self.beta_slow,
-            self.head_dim,
-            self.base,
-            self.original_max_position_embeddings,
-        )
-        inv_freq_mask = (
-            (1 - _yarn_linear_ramp_mask(low, high, self.head_dim // 2).float())
-            * self.extrapolation_factor
-        )  # Get n-d rotational scaling corrected for extrapolation
-        inv_freq = (
-            inv_freq_interpolation * (1 - inv_freq_mask)
-            + inv_freq_extrapolation * inv_freq_mask
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # pylint: disable=no-value-for-parameter
-        self._set_cos_sin_cache(
-            self.max_position_embeddings, dtype=torch.get_default_dtype()
-        )
-
-
 def apply_rotary_pos_emb(
     x: torch.Tensor, cos_sin: Tuple[torch.Tensor, torch.Tensor]
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -143,44 +82,3 @@ def apply_rotary_pos_emb(
 def _rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = torch.chunk(x, 2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)
-
-
-# Inverse dim formula to find dim based on number of rotations
-def _yarn_find_correction_dim(
-    num_rotations: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048
-) -> float:
-    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(base)
-    )
-
-
-# Find dim range bounds based on rotations
-def _yarn_find_correction_range(
-    low_rot: int,
-    high_rot: int,
-    dim: int,
-    base: int = 10000,
-    max_position_embeddings: int = 2048,
-) -> int:
-    low = math.floor(
-        _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
-    )
-    high = math.ceil(
-        _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
-    )
-    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
-
-
-def _yarn_linear_ramp_mask(min: float, max: float, dim: int) -> torch.Tensor:
-    if min == max:
-        max += 0.001  # Prevent singularity
-
-    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
-    ramp_func = torch.clamp(linear_func, 0, 1)
-    return ramp_func
-
-
-def _yarn_get_mscale(scale: float = 1) -> float:
-    if scale <= 1:
-        return 1.0
-    return 0.1 * math.log(scale) + 1.0
diff --git a/src/instructlab/dolomite/hf_models/models/__init__.py b/src/instructlab/dolomite/hf_models/models/__init__.py
index 8eb2025..684111f 100644
--- a/src/instructlab/dolomite/hf_models/models/__init__.py
+++ b/src/instructlab/dolomite/hf_models/models/__init__.py
@@ -2,4 +2,4 @@
 # Extracted from https://github.com/ibm-granite/dolomite-engine
 # ----------------------------------------------------------------
 # Local
-from .gpt_dolomite import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel
+from .gpt_dolomite import GPTDolomiteForCausalLM, GPTDolomiteModel
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py
index c4f5b19..d121b1e 100644
--- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py
+++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py
@@ -3,6 +3,5 @@
 # ----------------------------------------------------------------
 # Local
 from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel
-from .config import GPTDolomiteConfig
 from .main import GPTDolomiteForCausalLM
 from .mlp import interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py
index abf1696..e9753bb 100644
--- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py
+++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py
@@ -11,21 +11,14 @@
 import torch
 
 # Local
-from ...defaults import DEFAULT_NORMALIZATION_IMPLEMENTATION
+from ...config import GPTDolomiteConfig
 from ...enums import AttentionHeadType, PositionEmbeddingType
-from ...modeling_utils import (
-    Alibi,
-    Embedding,
-    Linear,
-    RMSNorm,
-    RoPE,
-    YaRNScaledRoPE,
-    get_normalization_function,
-)
+from ...modeling_utils import Alibi, RMSNorm, RoPE, get_normalization_function
 from ...utils import check_list_type, flatten_and_convert_to_tensors
-from .config import GPTDolomiteConfig
 from .layer import GPTDolomiteBlock
 
+DEFAULT_NORMALIZATION_IMPLEMENTATION = "torch"
+
 
 class GPTDolomitePreTrainedModel(PreTrainedModel):
     """
@@ -85,7 +78,15 @@ def __init__(self, config: GPTDolomiteConfig, *inputs, **kwargs):
 
     def _init_weights(self, module: torch.nn.Module) -> None:
         if isinstance(
-            module, (Embedding, Linear, torch.nn.LayerNorm, RMSNorm, Alibi, RoPE)
+            module,
+            (
+                torch.nn.Embedding,
+                torch.nn.Linear,
+                torch.nn.LayerNorm,
+                RMSNorm,
+                Alibi,
+                RoPE,
+            ),
         ):
             module.reset_parameters()
 
@@ -234,7 +235,7 @@ def __init__(self, config: GPTDolomiteConfig, **kwargs) -> None:
 
         self.head_dim = self.embed_dim // self.num_heads
 
-        self.wte = Embedding(config.vocab_size, self.embed_dim)
+        self.wte = torch.nn.Embedding(config.vocab_size, self.embed_dim)
 
         self.drop = (
             torch.nn.Identity()
@@ -268,10 +269,10 @@ def __init__(self, config: GPTDolomiteConfig, **kwargs) -> None:
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> Embedding:
+    def get_input_embeddings(self) -> torch.nn.Embedding:
         return self.wte
 
-    def set_input_embeddings(self, new_embeddings: Embedding) -> None:
+    def set_input_embeddings(self, new_embeddings: torch.nn.Embedding) -> None:
         self.wte = new_embeddings
 
     def forward(
@@ -709,7 +710,7 @@ def _setup_positional_encoding(self) -> None:
         max_position_embeddings = self.config.max_position_embeddings
 
         if self.position_embedding_type == PositionEmbeddingType.learned_absolute:
-            self.wpe = Embedding(max_position_embeddings, self.embed_dim)
+            self.wpe = torch.nn.Embedding(max_position_embeddings, self.embed_dim)
         elif self.position_embedding_type == PositionEmbeddingType.alibi:
             assert (
                 not self._use_flash_attention_2
@@ -717,22 +718,11 @@
 
             self.alibi = Alibi(self.num_heads)
         elif self.position_embedding_type == PositionEmbeddingType.rope:
-            if self.config.rope_scaling is None:
-                self.rope = RoPE(
-                    self.head_dim,
-                    max_position_embeddings=max_position_embeddings,
-                    base=self.config.rope_theta,
-                )
-            else:
-                self.rope = YaRNScaledRoPE(
-                    self.head_dim,
-                    max_position_embeddings=max_position_embeddings,
-                    base=self.config.rope_theta,
-                    scale=self.config.rope_scaling["factor"],
-                    original_max_position_embeddings=self.config.rope_scaling[
-                        "original_max_position_embeddings"
-                    ],
-                )
+            self.rope = RoPE(
+                self.head_dim,
+                max_position_embeddings=max_position_embeddings,
+                base=self.config.rope_theta,
+            )
         else:
             raise NotImplementedError()
 
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py
deleted file mode 100644
index 7928154..0000000
--- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Local
-from ...config import CommonConfig
-
-
-class GPTDolomiteConfig(CommonConfig):
-    model_type = "gpt_dolomite"
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py
index 68cc207..34fc3b3 100644
--- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py
+++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py
@@ -9,9 +9,9 @@
 import torch
 
 # Local
+from ...config import GPTDolomiteConfig
 from ...enums import AttentionHeadType
 from ...modeling_utils import get_attention_module, get_normalization_function
-from .config import GPTDolomiteConfig
 from .mlp import MLP
 
 
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py
index 193c822..8d84c28 100644
--- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py
+++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py
@@ -11,9 +11,8 @@
 import torch.nn.functional as F
 
 # Local
-from ...modeling_utils import Embedding, Linear
+from ...config import GPTDolomiteConfig
 from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel
-from .config import GPTDolomiteConfig
 
 
 class GPTDolomiteForCausalLM(GPTDolomitePreTrainedModel):
@@ -24,24 +23,24 @@ def __init__(self, config: GPTDolomiteConfig, **kwargs) -> None:
         self.transformer = GPTDolomiteModel(config, **kwargs)
 
         if not self._tied_word_embeddings:
-            self.lm_head = Linear(config.n_embd, config.vocab_size, bias=False)
+            self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
 
         self.m_width = config.m_width
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_input_embeddings(self) -> Embedding:
+    def get_input_embeddings(self) -> torch.nn.Embedding:
         return self.transformer.wte
 
-    def set_input_embeddings(self, value: Embedding) -> None:
+    def set_input_embeddings(self, value: torch.nn.Embedding) -> None:
         self.transformer.wte = value
 
-    def get_output_embeddings(self) -> Linear:
+    def get_output_embeddings(self) -> torch.nn.Linear:
         if not self._tied_word_embeddings:
             return self.lm_head
 
-    def set_output_embeddings(self, new_embeddings: Linear) -> None:
+    def set_output_embeddings(self, new_embeddings: torch.nn.Linear) -> None:
         if not self._tied_word_embeddings:
             self.lm_head = new_embeddings
 
diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py
index 7d6b5bc..7e5b214 100644
--- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py
+++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py
@@ -8,8 +8,8 @@
 import torch
 
 # Local
-from ...modeling_utils import Linear, get_activation_function, is_glu
-from .config import GPTDolomiteConfig
+from ...config import GPTDolomiteConfig
+from ...modeling_utils import get_activation_function, is_glu
 
 
 class MLP(torch.nn.Module):
@@ -22,7 +22,7 @@ def __init__(self, config: GPTDolomiteConfig) -> None:
         add_bias = config.add_bias
         residual_dropout = config.resid_pdrop
 
-        self.c_fc = Linear(
+        self.c_fc = torch.nn.Linear(
             hidden_size,
             2 * intermediate_size if is_glu(activation_function) else intermediate_size,
             bias=add_bias,
@@ -30,7 +30,7 @@ def __init__(self, config: GPTDolomiteConfig) -> None:
 
         self.act = get_activation_function(activation_function)
 
-        self.c_proj = Linear(intermediate_size, hidden_size, bias=add_bias)
+        self.c_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=add_bias)
 
         self.dropout = (
             torch.nn.Identity()
diff --git a/src/instructlab/dolomite/hf_models/register_hf.py b/src/instructlab/dolomite/hf_models/register_hf.py
index 667c6d3..d264fd8 100644
--- a/src/instructlab/dolomite/hf_models/register_hf.py
+++ b/src/instructlab/dolomite/hf_models/register_hf.py
@@ -5,7 +5,8 @@
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 
 # Local
-from .models import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel
+from .config import GPTDolomiteConfig
+from .models import GPTDolomiteForCausalLM, GPTDolomiteModel
 
 # (AutoConfig, AutoModel, AutoModelForCausalLM)
 _CUSTOM_MODEL_REGISTRY = [
diff --git a/src/instructlab/dolomite/utils/__init__.py b/src/instructlab/dolomite/utils/__init__.py
index 36d6fc0..6e928c5 100644
--- a/src/instructlab/dolomite/utils/__init__.py
+++ b/src/instructlab/dolomite/utils/__init__.py
@@ -4,7 +4,6 @@
 # Local
 from .hf_hub import download_repo
 from .safetensors import SafeTensorsWeightsManager
-from .wrapper import get_module_class_from_name
 
 try:
     # Third Party
diff --git a/src/instructlab/dolomite/utils/wrapper.py b/src/instructlab/dolomite/utils/wrapper.py
deleted file mode 100644
index 8d8548d..0000000
--- a/src/instructlab/dolomite/utils/wrapper.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# ----------------------------------------------------------------
-# Extracted from https://github.com/ibm-granite/dolomite-engine
-# ----------------------------------------------------------------
-# Standard
-from typing import List, Type
-
-# Third Party
-import torch
-
-
-def get_module_class_from_name(
-    model: torch.nn.Module, name: str
-) -> List[Type[torch.nn.Module]]:
-    modules_children = list(model.children())
-
-    if model.__class__.__name__ == name:
-        return model.__class__
-    elif len(modules_children) == 0:
-        return
-    else:
-        for child_module in modules_children:
-            module_class = get_module_class_from_name(child_module, name)
-            if module_class is not None:
-                return module_class
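The import surface after this patch can be illustrated with a minimal, hypothetical usage sketch (not part of the patch itself): GPTDolomiteConfig is now defined once in hf_models/config.py with model_type = "gpt_dolomite" and re-exported from the hf_models package root, so downstream code no longer imports it from hf_models.models. The snippet relies only on the re-exports visible in the diff above; constructing a config with its defaults is assumed to be valid, as with any transformers PretrainedConfig subclass.

# Hypothetical usage sketch based on the re-exports in hf_models/__init__.py shown above.
# register_model_classes() runs on package import and registers GPTDolomiteConfig,
# GPTDolomiteModel, and GPTDolomiteForCausalLM with the transformers Auto* classes.
from instructlab.dolomite.hf_models import (
    GPTDolomiteConfig,
    GPTDolomiteForCausalLM,
    GPTDolomiteModel,
)

config = GPTDolomiteConfig()  # rope_scaling is kept for compatibility but ignored; only plain RoPE remains
model = GPTDolomiteForCausalLM(config)  # embeddings and projections are now plain torch.nn modules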