diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
index e843fdc7..d6c3b230 100644
--- a/server/lorax_server/adapters/lora.py
+++ b/server/lorax_server/adapters/lora.py
@@ -155,7 +155,7 @@ def load(
         # for vlm models we need to return list of layers
         # so nlayers is a list of ints in this case but in others its just an int
         nlayers = model.get_num_layers_for_type(layer_type)
-        if type(nlayers) == int:
+        if type(nlayers) is int:
             lora_a_list = [None] * nlayers
             lora_b_list = [None] * nlayers
             layer_ids = list(range(nlayers))
diff --git a/server/lorax_server/models/custom_modeling/mllama.py b/server/lorax_server/models/custom_modeling/mllama.py
index 7aa2e01e..c9ec9dbf 100644
--- a/server/lorax_server/models/custom_modeling/mllama.py
+++ b/server/lorax_server/models/custom_modeling/mllama.py
@@ -24,17 +24,27 @@
 from transformers.activations import ACT2FN
 
 from lorax_server.adapters.weights import AdapterBatchData
-from lorax_server.layers import (
+from lorax_server.models.custom_modeling.flash_llama_modeling import (
+    FlashLlamaForCausalLM,
+    FlashLlamaLayer,
+)
+from lorax_server.utils.attention.common import Seqlen
+from lorax_server.utils.layers import (
     FastLinear,
+    TensorParallelAdapterRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
+    TensorParallelMultiAdapterLinear,
     TensorParallelRowLinear,
 )
-from lorax_server.models.custom_modeling.flash_llama_modeling import (
-    FlashLlamaForCausalLM,
-    FlashLlamaLayer,
+from lorax_server.utils.lora import (
+    FC1,
+    FC2,
+    K_PROJ,
+    O_PROJ,
+    Q_PROJ,
+    V_PROJ,
 )
-from lorax_server.utils.attention.common import Seqlen
 
 
 # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@@ -200,27 +210,76 @@ def _prepare_cross_attention_mask(
 
 # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
 class MllamaVisionMLP(nn.Module):
-    def __init__(self, *, prefix, config, weights):
+    def __init__(self, *, prefix, config, weights, layer_id, model_type):
         super().__init__()
         self.config = config
         self.activation_fn = ACT2FN[config.hidden_act]
-        self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True)
-        self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.fc1(hidden_states)
+
+        fc1 = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{prefix}.fc1"],
+            weights=weights,
+            dim=0,
+            bias=True,
+        )
+
+        out_size = fc1.linear.weight.shape[-1] * weights.process_group.size()
+        self.fc1 = TensorParallelMultiAdapterLinear.load(
+            fc1,
+            layer_id,
+            [f'{model_type}_{FC1}'],
+            sizes=[out_size],
+            process_group=weights.process_group
+        )
+        self.fc2 = TensorParallelAdapterRowLinear.load(
+            TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.fc2",
+                weights=weights,
+                bias=True,
+            ),
+            layer_id,
+            f'{model_type}_{FC2}',
+            process_group=weights.process_group,
+        )
+
+    def forward(self, hidden_states: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states, adapter_data)
         hidden_states = self.activation_fn(hidden_states)
-        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.fc2(hidden_states, adapter_data)
         return hidden_states
 
 
+def load_attention(config, prefix, weights, layer_id, model_type, head_dim, n_head, n_head_kv):
+    base_layer = TensorParallelColumnLinear.load_multi(
+        config,
+        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
+        dim=0,
+        weights=weights,
+        bias=False,
+    )
+    return TensorParallelMultiAdapterLinear.load(
+        base_layer,
+        layer_id,
+        [f'{model_type}_{Q_PROJ}', f'{model_type}_{K_PROJ}', f'{model_type}_{V_PROJ}'],
+        sizes=[
+            head_dim * n_head,
+            head_dim * n_head_kv,
+            head_dim * n_head_kv,
+        ],
+        process_group=weights.process_group,
+    )
+
+
 class MllamaVisionSdpaAttention(nn.Module):
-    def __init__(self, *, prefix, config, weights):
+    def __init__(self, *, prefix, config, weights, layer_id, model_type):
         super().__init__()
 
         self.embed_dim = config.hidden_size
         self.head_dim = config.hidden_size // config.attention_heads
         self.num_heads = config.attention_heads // weights.process_group.size()
+        self.head_size = config.hidden_size // self.num_heads
+        self.num_key_value_heads = getattr(config, "n_head_kv", None) or self.num_heads
 
         self.qkv_proj = TensorParallelColumnLinear.load_multi(
             config,
@@ -229,19 +288,35 @@ def __init__(self, *, prefix, config, weights):
             weights=weights,
             bias=False,
         )
-        self.o_proj = TensorParallelRowLinear.load(
+        self.qkv_proj = load_attention(
             config,
-            prefix=f"{prefix}.o_proj",
-            weights=weights,
-            bias=False,
+            prefix,
+            weights,
+            layer_id,
+            model_type,
+            self.head_size,
+            self.num_heads,
+            self.num_key_value_heads,
+        )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            TensorParallelRowLinear.load(
+                config,
+                prefix=f"{prefix}.o_proj",
+                weights=weights,
+                bias=False,
+            ),
+            layer_id,
+            f'{model_type}_{O_PROJ}',
+            process_group=weights.process_group,
         )
 
     def forward(
         self,
         hidden_state: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: AdapterBatchData = None,
     ) -> torch.Tensor:
-        qkv = self.qkv_proj(hidden_state)
+        qkv = self.qkv_proj(hidden_state, adapter_data)
         query, key, value = qkv.split(
             [
                 self.head_dim * self.num_heads,
@@ -267,12 +342,12 @@ def forward(
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
 
-        output = self.o_proj(attn_output)
+        output = self.o_proj(attn_output, adapter_data)
         return output
 
 
 class MllamaVisionEncoderLayer(nn.Module):
-    def __init__(self, *, prefix, config, weights, is_gated: bool):
+    def __init__(self, *, prefix, config, weights, is_gated: bool, layer_id: int, model_type: str):
         super().__init__()
 
         self.hidden_size = config.hidden_size
@@ -280,8 +355,20 @@ def __init__(self, *, prefix, config, weights, is_gated: bool):
         self.is_gated = is_gated
         self.intermediate_size = config.intermediate_size
 
-        self.self_attn = MllamaVisionSdpaAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
-        self.mlp = MllamaVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.self_attn = MllamaVisionSdpaAttention(
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            weights=weights,
+            layer_id=layer_id,
+            model_type=model_type,
+        )
+        self.mlp = MllamaVisionMLP(
+            prefix=f"{prefix}.mlp",
+            config=config,
+            weights=weights,
+            layer_id=layer_id,
+            model_type=model_type,
+        )
 
         self.input_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05)
         self.post_attention_layernorm = nn.LayerNorm.load(
@@ -297,47 +384,52 @@ def forward(
         self,
         hidden_state: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: AdapterBatchData = None,
     ):
         # Self Attention
         residual = hidden_state
         hidden_state = self.input_layernorm(hidden_state)
-        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
+        hidden_state = self.self_attn(hidden_state, attention_mask, adapter_data)
         gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
         hidden_state = residual + gate_attn * hidden_state
 
         # Feed forward
         residual = hidden_state
         hidden_state = self.post_attention_layernorm(hidden_state)
-        hidden_state = self.mlp(hidden_state)
+        hidden_state = self.mlp(hidden_state, adapter_data)
         gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
         hidden_state = residual + gate_ffn * hidden_state
 
         return hidden_state
 
 
 class MllamaVisionEncoder(nn.Module):
-    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
+    def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int, model_type: str):
         super().__init__()
         self.config = config
         self.layers = [
             MllamaVisionEncoderLayer(
-                prefix=f"{prefix}.layers.{i}",
+                prefix=f"{prefix}.layers.{layer_id}",
                 config=config,
                 weights=weights,
                 is_gated=is_gated,
+                layer_id=layer_id,
+                model_type=model_type,
             )
-            for i in range(num_layers)
+            for layer_id in range(num_layers)
         ]
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
+        adapter_data: AdapterBatchData = None,
     ):
         encoder_states = [hidden_states]
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
                 attention_mask,
+                adapter_data,
             )
             hidden_states = layer_outputs
@@ -465,6 +557,7 @@ def __init__(self, *, prefix, config, weights):
             weights=weights,
             is_gated=False,
             num_layers=config.num_hidden_layers,
+            model_type='VISION_TRANSFORMER',
         )
         self.global_transformer = MllamaVisionEncoder(
             prefix=f"{prefix}.global_transformer",
@@ -472,6 +565,7 @@
             weights=weights,
             is_gated=True,
             num_layers=config.num_global_layers,
+            model_type='VISION_GLOBAL_TRANSFORMER',
         )
 
     def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -485,6 +579,7 @@ def forward(
         pixel_values: torch.Tensor,
         aspect_ratio_ids: torch.Tensor,
         attention_mask: torch.Tensor,
+        adapter_data: AdapterBatchData,
     ) -> torch.Tensor:
         batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape
 
@@ -538,6 +633,7 @@ def forward(
         hidden_state, all_intermediate_hidden_states = self.transformer(
             hidden_state,
             attention_mask=attention_mask,
+            adapter_data=adapter_data,
         )
         intermediate_hidden_states = [
             hidden_state
@@ -560,7 +656,11 @@ def forward(
             num_tiles * (num_patches + num_padding_patches),
             dim,
         )
-        hidden_state, _ = self.global_transformer(hidden_state, attention_mask=attention_mask)
+        hidden_state, _ = self.global_transformer(
+            hidden_state,
+            attention_mask=attention_mask,
+            adapter_data=adapter_data,
+        )
         hidden_state = hidden_state.reshape(
             batch_size * num_concurrent_media,
             num_tiles,
@@ -854,12 +954,12 @@ def create_layer(layer_id, prefix, config, weights):
         self.dtype = weights.dtype
         self.device = weights.device
 
-    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
+    def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask, adapter_data):
         if aspect_ratio_ids is None:
             raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
         # logger.info(f"PIxel values {pixel_values.shape}")
         batch_size = pixel_values.shape[0]
-        vision_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask)
+        vision_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask, adapter_data)
         cross_attention_states = self.multi_modal_projector(vision_states).reshape(
             -1, vision_states.shape[-2], self.hidden_size
         )
diff --git a/server/lorax_server/models/mllama.py b/server/lorax_server/models/mllama.py
index 673f4eac..504d8129 100644
--- a/server/lorax_server/models/mllama.py
+++ b/server/lorax_server/models/mllama.py
@@ -10,26 +10,14 @@
     PreTrainedTokenizerBase,
 )
 
+from lorax_server.models.custom_modeling.mllama import FlashLlamaCrossLayer
+from lorax_server.models.metadata_kernels import block_tables_to_ragged
 from lorax_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
 from lorax_server.pb import generate_pb2
 from lorax_server.utils.attention.common import Seqlen
-from lorax_server.models.metadata_kernels import block_tables_to_ragged
+from lorax_server.utils.lora import DOWN_PROJ, FC1, FC2, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ
 from lorax_server.utils.state import PREFIX_CACHING
 from lorax_server.utils.tokenizer import TokenizerManager
-from lorax_server.models.custom_modeling.mllama import FlashLlamaCrossLayer
-
-from lorax_server.utils.lora import (
-    DOWN_PROJ,
-    GATE_PROJ,
-    K_PROJ,
-    LM_HEAD,
-    O_PROJ,
-    Q_PROJ,
-    UP_PROJ,
-    V_PROJ,
-    FC1,
-    FC2
-)
 
 tracer = trace.get_tracer(__name__)
@@ -246,8 +234,6 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
             layer_weights[(i, f'TEXT_{DOWN_PROJ}')] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
         layer_weights[(0, f'TEXT_{LM_HEAD}')] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)
-        # base_model.model.vision_model.transformer.layers.17.self_attn.v_proj.lora_A.weight
-        # vision_model.transformer.layers.17.self_attn.v_proj
         vision_layer_mappings = [
             ("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
             ("vision_model.transformer.layers", self.model.vision_model.transformer.layers),
         ]
@@ -361,6 +347,7 @@ def forward(
                 pixel_values=batch.pixel_values,
                 aspect_ratio_ids=batch.aspect_ratio_ids,
                 aspect_ratio_mask=batch.aspect_ratio_mask,
+                adapter_data=adapter_data,
             )
 
             batch.cross_attention_states = cross_attention_states
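
The adapter weight names introduced above are composed from a per-module model_type prefix ('VISION_TRANSFORMER' or 'VISION_GLOBAL_TRANSFORMER') and a projection constant from lorax_server.utils.lora. The sketch below is illustrative only: the constant values are placeholders rather than the actual values exported by lorax_server.utils.lora, and the helper function is hypothetical, not part of the patch; it just shows how the (layer_id, weight_name) keys enumerate for a vision encoder.

    # Illustrative sketch, not part of the patch: placeholder constant values
    # stand in for the real ones exported by lorax_server.utils.lora.
    Q_PROJ, K_PROJ, V_PROJ = "q_proj", "k_proj", "v_proj"
    O_PROJ, FC1, FC2 = "o_proj", "fc1", "fc2"

    def vision_adapter_weight_keys(model_type: str, num_layers: int):
        """Enumerate (layer_id, weight_name) keys the way the modified vision
        encoder layers register their LoRA-capable projections."""
        projections = (Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2)
        return [
            (layer_id, f"{model_type}_{proj}")
            for layer_id in range(num_layers)
            for proj in projections
        ]

    # For example, a hypothetical two-layer global transformer yields keys from
    # (0, 'VISION_GLOBAL_TRANSFORMER_q_proj') through (1, 'VISION_GLOBAL_TRANSFORMER_fc2').
    print(vision_adapter_weight_keys("VISION_GLOBAL_TRANSFORMER", 2))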