switch to adapter based linear layers in vision trunk
ajtejankar committed Nov 9, 2024
1 parent c279221 commit 50c84f6
Showing 3 changed files with 134 additions and 47 deletions.
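
All three diffs below apply the same pattern: each plain tensor-parallel linear layer in the vision trunk is wrapped in its adapter-aware counterpart under a "{model_type}_{layer_name}" key, and adapter_data is threaded through every vision forward call so per-request LoRA weights can reach these layers. As a condensed sketch of that pattern (the helper name build_adapter_mlp is illustrative only; the individual calls mirror the MllamaVisionMLP changes in the diff below):

from lorax_server.utils.layers import (
    TensorParallelAdapterRowLinear,
    TensorParallelColumnLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelRowLinear,
)
from lorax_server.utils.lora import FC1, FC2


def build_adapter_mlp(config, prefix, weights, layer_id, model_type):
    # Column-parallel fc1, wrapped so per-request LoRA deltas can be applied.
    fc1 = TensorParallelColumnLinear.load_multi(
        config, prefixes=[f"{prefix}.fc1"], weights=weights, dim=0, bias=True
    )
    out_size = fc1.linear.weight.shape[-1] * weights.process_group.size()
    fc1 = TensorParallelMultiAdapterLinear.load(
        fc1,
        layer_id,
        [f"{model_type}_{FC1}"],
        sizes=[out_size],
        process_group=weights.process_group,
    )
    # Row-parallel fc2, wrapped with a single adapter slot.
    fc2 = TensorParallelAdapterRowLinear.load(
        TensorParallelRowLinear.load(
            config, prefix=f"{prefix}.fc2", weights=weights, bias=True
        ),
        layer_id,
        f"{model_type}_{FC2}",
        process_group=weights.process_group,
    )
    return fc1, fc2

The attention projections receive the same treatment through the new load_attention helper, and the affected forward methods gain an adapter_data: AdapterBatchData argument.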
2 changes: 1 addition & 1 deletion in server/lorax_server/adapters/lora.py
@@ -155,7 +155,7 @@ def load(
# for vlm models we need to return list of layers
# so nlayers is a list of ints in this case but in others its just an int
nlayers = model.get_num_layers_for_type(layer_type)
if type(nlayers) == int:
if type(nlayers) is int:
lora_a_list = [None] * nlayers
lora_b_list = [None] * nlayers
layer_ids = list(range(nlayers))
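
The only change in this file replaces a type-equality comparison with an identity check. A standalone illustration of the difference, in plain Python rather than repository code:

# type(x) == int compares type objects for equality and is flagged by
# pycodestyle (E721); type(x) is int is the idiomatic exact-type check.
nlayers = 32
print(type(nlayers) == int)   # True, but linted against
print(type(nlayers) is int)   # True, the preferred identity check

# Both differ from isinstance() once subclasses are involved:
print(isinstance(True, int))  # True  (bool subclasses int)
print(type(True) is int)      # False (exact type check excludes bool)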
158 changes: 129 additions & 29 deletions in server/lorax_server/models/custom_modeling/mllama.py
@@ -24,17 +24,27 @@
from transformers.activations import ACT2FN

from lorax_server.adapters.weights import AdapterBatchData
from lorax_server.layers import (
from lorax_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
FlashLlamaLayer,
)
from lorax_server.utils.attention.common import Seqlen
from lorax_server.utils.layers import (
FastLinear,
TensorParallelAdapterRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelMultiAdapterLinear,
TensorParallelRowLinear,
)
from lorax_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
FlashLlamaLayer,
from lorax_server.utils.lora import (
FC1,
FC2,
K_PROJ,
O_PROJ,
Q_PROJ,
V_PROJ,
)
from lorax_server.utils.attention.common import Seqlen


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@@ -200,27 +210,76 @@ def _prepare_cross_attention_mask(

# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
def __init__(self, *, prefix, config, weights, layer_id, model_type):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True)
self.fc2 = TensorParallelRowLinear.load(prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
fc1 = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.fc1"],
weights=weights,
dim=0,
bias=True,
)

out_size = fc1.linear.weight.shape[-1] * weights.process_group.size()
self.fc1 = TensorParallelMultiAdapterLinear.load(
fc1,
layer_id,
[f'{model_type}_{FC1}'],
sizes=[out_size],
process_group=weights.process_group
)
self.fc2 = TensorParallelAdapterRowLinear.load(
TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=True,
),
layer_id,
f'{model_type}_{FC2}',
process_group=weights.process_group,
)

def forward(self, hidden_states: torch.Tensor, adapter_data: AdapterBatchData) -> torch.Tensor:
hidden_states = self.fc1(hidden_states, adapter_data)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = self.fc2(hidden_states, adapter_data)
return hidden_states


def load_attention(config, prefix, weights, layer_id, model_type, head_dim, n_head, n_head_kv):
base_layer = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
[f'{model_type}_{Q_PROJ}', f'{model_type}_{K_PROJ}', f'{model_type}_{V_PROJ}'],
sizes=[
head_dim * n_head,
head_dim * n_head_kv,
head_dim * n_head_kv,
],
process_group=weights.process_group,
)


class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
def __init__(self, *, prefix, config, weights, layer_id, model_type):
super().__init__()

self.embed_dim = config.hidden_size
self.head_dim = config.hidden_size // config.attention_heads
self.num_heads = config.attention_heads // weights.process_group.size()
self.head_size = config.hidden_size // self.num_heads
self.num_key_value_heads = getattr(config, "n_head_kv", None) or self.num_heads

self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
@@ -229,19 +288,35 @@ def __init__(self, *, prefix, config, weights):
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
self.qkv_proj = load_attention(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
prefix,
weights,
layer_id,
model_type,
self.head_size,
self.num_heads,
self.num_key_value_heads,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
),
layer_id,
f'{model_type}_{O_PROJ}',
process_group=weights.process_group,
)

def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
adapter_data: AdapterBatchData = None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_state)
qkv = self.qkv_proj(hidden_state, adapter_data)
query, key, value = qkv.split(
[
self.head_dim * self.num_heads,
@@ -267,21 +342,33 @@ def forward(
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)

output = self.o_proj(attn_output)
output = self.o_proj(attn_output, adapter_data)
return output


class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool):
def __init__(self, *, prefix, config, weights, is_gated: bool, layer_id: int, model_type: str):
super().__init__()

self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size

self.self_attn = MllamaVisionSdpaAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
self.mlp = MllamaVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.self_attn = MllamaVisionSdpaAttention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
model_type=model_type,
)
self.mlp = MllamaVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
layer_id=layer_id,
model_type=model_type,
)

self.input_layernorm = nn.LayerNorm.load(prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05)
self.post_attention_layernorm = nn.LayerNorm.load(
@@ -297,47 +384,52 @@ def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
adapter_data: AdapterBatchData = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
hidden_state = self.self_attn(hidden_state, attention_mask, adapter_data)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state

# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
hidden_state = self.mlp(hidden_state, adapter_data)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state


class MllamaVisionEncoder(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int, model_type: str):
super().__init__()
self.config = config
self.layers = [
MllamaVisionEncoderLayer(
prefix=f"{prefix}.layers.{i}",
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
is_gated=is_gated,
layer_id=layer_id,
model_type=model_type,
)
for i in range(num_layers)
for layer_id in range(num_layers)
]

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
adapter_data: AdapterBatchData = None,
):
encoder_states = [hidden_states]
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
adapter_data,
)

hidden_states = layer_outputs
@@ -465,13 +557,15 @@ def __init__(self, *, prefix, config, weights):
weights=weights,
is_gated=False,
num_layers=config.num_hidden_layers,
model_type='VISION_TRANSFORMER',
)
self.global_transformer = MllamaVisionEncoder(
prefix=f"{prefix}.global_transformer",
config=config,
weights=weights,
is_gated=True,
num_layers=config.num_global_layers,
model_type='VISION_GLOBAL_TRANSFORMER',
)

def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -485,6 +579,7 @@ def forward(
pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
adapter_data: AdapterBatchData,
) -> torch.Tensor:
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape

@@ -538,6 +633,7 @@ def forward(
hidden_state, all_intermediate_hidden_states = self.transformer(
hidden_state,
attention_mask=attention_mask,
adapter_data=adapter_data,
)
intermediate_hidden_states = [
hidden_state
@@ -560,7 +656,11 @@ def forward(
num_tiles * (num_patches + num_padding_patches),
dim,
)
hidden_state, _ = self.global_transformer(hidden_state, attention_mask=attention_mask)
hidden_state, _ = self.global_transformer(
hidden_state,
attention_mask=attention_mask,
adapter_data=adapter_data,
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
@@ -854,12 +954,12 @@ def create_layer(layer_id, prefix, config, weights):
self.dtype = weights.dtype
self.device = weights.device

def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask, adapter_data):
if aspect_ratio_ids is None:
raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask)
vision_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask, adapter_data)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
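
For reference, the adapter slots registered by the rewritten vision modules are the cross product of the two model_type strings above ('VISION_TRANSFORMER' and 'VISION_GLOBAL_TRANSFORMER') and the projection constants imported from lorax_server.utils.lora. A small illustration; the concrete suffixes depend on those constant values, assumed here to be the usual lowercase module names:

from lorax_server.utils.lora import FC1, FC2, K_PROJ, O_PROJ, Q_PROJ, V_PROJ

# The two vision trunks instrumented by this commit, as named in MllamaVisionModel.
VISION_MODEL_TYPES = ("VISION_TRANSFORMER", "VISION_GLOBAL_TRANSFORMER")


def vision_adapter_keys():
    # Mirrors the f"{model_type}_{...}" names passed to the adapter layers above.
    return [
        f"{model_type}_{proj}"
        for model_type in VISION_MODEL_TYPES
        for proj in (Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2)
    ]

# e.g. "VISION_TRANSFORMER_q_proj", ..., "VISION_GLOBAL_TRANSFORMER_fc2",
# assuming Q_PROJ == "q_proj", FC2 == "fc2", and so on.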
21 changes: 4 additions & 17 deletions in server/lorax_server/models/mllama.py
@@ -10,26 +10,14 @@
PreTrainedTokenizerBase,
)

from lorax_server.models.custom_modeling.mllama import FlashLlamaCrossLayer
from lorax_server.models.metadata_kernels import block_tables_to_ragged
from lorax_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
from lorax_server.pb import generate_pb2
from lorax_server.utils.attention.common import Seqlen
from lorax_server.models.metadata_kernels import block_tables_to_ragged
from lorax_server.utils.lora import DOWN_PROJ, FC1, FC2, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ
from lorax_server.utils.state import PREFIX_CACHING
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.models.custom_modeling.mllama import FlashLlamaCrossLayer

from lorax_server.utils.lora import (
DOWN_PROJ,
GATE_PROJ,
K_PROJ,
LM_HEAD,
O_PROJ,
Q_PROJ,
UP_PROJ,
V_PROJ,
FC1,
FC2
)

tracer = trace.get_tracer(__name__)

@@ -246,8 +234,6 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights[(i, f'TEXT_{DOWN_PROJ}')] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
layer_weights[(0, f'TEXT_{LM_HEAD}')] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)

# base_model.model.vision_model.transformer.layers.17.self_attn.v_proj.lora_A.weight
# vision_model.transformer.layers.17.self_attn.v_proj
vision_layer_mappings = [
("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
("vision_model.transformer.layers", self.model.vision_model.transformer.layers),
@@ -361,6 +347,7 @@ def forward(
pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
adapter_data=adapter_data,
)
batch.cross_attention_states = cross_attention_states

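
Taken together, the serving path after this commit looks roughly like the sketch below. model, batch, and adapter_data are assumed to be the objects LoRAX already builds for a VlmCausalLM step, the call target is assumed to be the vision_forward method shown in the custom_modeling diff above, and the surrounding control flow is abbreviated:

# The same AdapterBatchData that drives LoRA in the text stack is now handed
# to the vision trunk, so adapters can target the vision q/k/v/o and fc1/fc2.
if batch.pixel_values is not None:  # abbreviated guard; the real condition may differ
    cross_attention_states = model.vision_forward(
        pixel_values=batch.pixel_values,
        aspect_ratio_ids=batch.aspect_ratio_ids,
        aspect_ratio_mask=batch.aspect_ratio_mask,
        adapter_data=adapter_data,  # new argument added by this commit
    )
    batch.cross_attention_states = cross_attention_states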
