add support for adapter loading in mllama
ajtejankar committed Nov 8, 2024
1 parent a2c1fc1 · commit c279221
Showing 3 changed files with 119 additions and 6 deletions.
server/lorax_server/adapters/lora.py (13 additions, 5 deletions)
@@ -152,11 +152,19 @@ def load(
         layer_type: str,
         unused_weight_names: Set[str],
     ) -> Optional[AdapterWeights]:
+        # for vlm models we need to return list of layers
+        # so nlayers is a list of ints in this case but in others its just an int
         nlayers = model.get_num_layers_for_type(layer_type)
-        lora_a_list = [None] * nlayers
-        lora_b_list = [None] * nlayers
+        if type(nlayers) == int:
+            lora_a_list = [None] * nlayers
+            lora_b_list = [None] * nlayers
+            layer_ids = list(range(nlayers))
+        else:
+            lora_a_list = [None] * len(nlayers)
+            lora_b_list = [None] * len(nlayers)
+            layer_ids = nlayers

-        for layer_id in range(nlayers):
+        for i, layer_id in enumerate(layer_ids):
             key = (layer_id, layer_type)
             weight_name, layer = model.target_to_layer[key]

@@ -184,8 +192,8 @@ def load(

             # Merge scaling factor into lora_b due to associativity of matrix multiplication:
             # (A * B) * C = A * (B * C)
-            lora_a_list[layer_id] = lora_a.transpose(0, 1)
-            lora_b_list[layer_id] = lora_b.transpose(0, 1) * scale
+            lora_a_list[i] = lora_a.transpose(0, 1)
+            lora_b_list[i] = lora_b.transpose(0, 1) * scale

         # pad lora ranks to be compatible with sgmv
         lora_a_list = [pad_rank(w, dim=1, world_size=model.world_size) for w in lora_a_list]
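The change above lets LoraWeights.load() handle models whose get_num_layers_for_type returns a list of layer ids instead of a plain count, and it indexes the weight lists by the loop position i rather than by layer_id, so the lists stay dense even when the layer ids are sparse. A standalone sketch of that dispatch follows (the helper name and the toy layer ids are illustrative, not part of the commit):

# Standalone sketch, not part of the commit: how the int-vs-list branch in
# LoraWeights.load() resolves the layer ids to iterate over. The helper name
# and the toy layer ids below are illustrative only.

from typing import List, Union


def resolve_layer_ids(nlayers: Union[int, List[int]]) -> List[int]:
    # An int means a dense stack of layers; a list means only those layer ids
    # accept this adapter layer type (mllama's text stack skips its
    # cross-attention layers, so it reports a list of ids).
    if isinstance(nlayers, int):
        return list(range(nlayers))
    return nlayers


# A text-only model reports a plain count...
assert resolve_layer_ids(4) == [0, 1, 2, 3]

# ...while an mllama-style text stack reports only the self-attention layer ids
# (this toy model has cross-attention layers at ids 3 and 8, so those are absent).
assert resolve_layer_ids([0, 1, 2, 4, 5, 6, 7, 9]) == [0, 1, 2, 4, 5, 6, 7, 9]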
server/lorax_server/models/mllama.py (103 additions, 1 deletion)
@@ -13,12 +13,28 @@
 from lorax_server.models.vlm_causal_lm import VlmCausalLM, VlmCausalLMBatch
 from lorax_server.pb import generate_pb2
 from lorax_server.utils.attention.common import Seqlen
-from lorax_server.utils.attention.utils import block_tables_to_ragged
+from lorax_server.models.metadata_kernels import block_tables_to_ragged
 from lorax_server.utils.state import PREFIX_CACHING
 from lorax_server.utils.tokenizer import TokenizerManager
+from lorax_server.models.custom_modeling.mllama import FlashLlamaCrossLayer
+
+from lorax_server.utils.lora import (
+    DOWN_PROJ,
+    GATE_PROJ,
+    K_PROJ,
+    LM_HEAD,
+    O_PROJ,
+    Q_PROJ,
+    UP_PROJ,
+    V_PROJ,
+    FC1,
+    FC2
+)

 tracer = trace.get_tracer(__name__)

+TEXT_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
+VISION_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2]

 @dataclass
 class MllamaCausalLMBatch(VlmCausalLMBatch):

@@ -175,6 +191,92 @@ def from_pb(


 class MllamaCausalLM(VlmCausalLM):
+
+    @property
+    def supports_adapter_loading(self) -> bool:
+        return True
+
+    @property
+    def adapter_layers(self) -> List[str]:
+        return [f'TEXT_{layer_type}' for layer_type in TEXT_ADAPTER_LAYERS] \
+            + [f'VISION_GLOBAL_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS] \
+            + [f'VISION_TRANSFORMER_{layer_type}' for layer_type in VISION_ADAPTER_LAYERS]
+
+    @property
+    def default_traced_adapter_layers(self) -> List[str]:
+        return [Q_PROJ, V_PROJ]
+
+    def get_num_layers_for_type(self, layer_type: str) -> int:
+        if 'LM_HEAD' in layer_type:
+            return 1
+        if 'TEXT_' in layer_type:
+            return [
+                layer_id
+                for layer_id, layer in enumerate(self.model.text_model.model.layers)
+                if not isinstance(layer, FlashLlamaCrossLayer)
+            ]
+        if 'VISION_GLOBAL_TRANSFORMER_' in layer_type:
+            return len(self.model.vision_model.global_transformer.layers)
+        if 'VISION_TRANSFORMER_' in layer_type:
+            return len(self.model.vision_model.transformer.layers)
+
+    def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
+        layer_weights = {}
+
+        prefix = "language_model.model.layers"
+        for i, layer in enumerate(self.model.text_model.model.layers):
+            if isinstance(layer, FlashLlamaCrossLayer):
+                continue
+            layer_weights[(i, f'TEXT_{Q_PROJ}')] = (
+                f"{prefix}.{i}.self_attn.q_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, f'TEXT_{K_PROJ}')] = (
+                f"{prefix}.{i}.self_attn.k_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, f'TEXT_{V_PROJ}')] = (
+                f"{prefix}.{i}.self_attn.v_proj",
+                layer.self_attn.query_key_value,
+            )
+            layer_weights[(i, f'TEXT_{O_PROJ}')] = (f"{prefix}.{i}.self_attn.o_proj", layer.self_attn.o_proj)
+
+            layer_weights[(i, f'TEXT_{GATE_PROJ}')] = (f"{prefix}.{i}.mlp.gate_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, f'TEXT_{UP_PROJ}')] = (f"{prefix}.{i}.mlp.up_proj", layer.mlp.gate_up_proj)
+            layer_weights[(i, f'TEXT_{DOWN_PROJ}')] = (f"{prefix}.{i}.mlp.down_proj", layer.mlp.down_proj)
+        layer_weights[(0, f'TEXT_{LM_HEAD}')] = ("base_model.model.language_model.lm_head", self.model.text_model.lm_head)
+
+        # base_model.model.vision_model.transformer.layers.17.self_attn.v_proj.lora_A.weight
+        # vision_model.transformer.layers.17.self_attn.v_proj
+        vision_layer_mappings = [
+            ("vision_model.global_transformer.layers", self.model.vision_model.global_transformer.layers),
+            ("vision_model.transformer.layers", self.model.vision_model.transformer.layers),
+        ]
+        for prefix, layer_list in vision_layer_mappings:
+            layer_type_prefix = 'VISION_GLOBAL_TRANSFORMER' if 'global_transformer' in prefix else 'VISION_TRANSFORMER'
+            for i, layer in enumerate(layer_list):
+                layer_weights[(i, f'{layer_type_prefix}_{Q_PROJ}')] = (
+                    f"{prefix}.{i}.self_attn.q_proj",
+                    layer.self_attn.qkv_proj,
+                )
+                layer_weights[(i, f'{layer_type_prefix}_{K_PROJ}')] = (
+                    f"{prefix}.{i}.self_attn.k_proj",
+                    layer.self_attn.qkv_proj,
+                )
+                layer_weights[(i, f'{layer_type_prefix}_{V_PROJ}')] = (
+                    f"{prefix}.{i}.self_attn.v_proj",
+                    layer.self_attn.qkv_proj,
+                )
+                layer_weights[(i, f'{layer_type_prefix}_{O_PROJ}')] = (
+                    f"{prefix}.{i}.self_attn.o_proj",
+                    layer.self_attn.o_proj
+                )
+
+                layer_weights[(i, f'{layer_type_prefix}_{FC1}')] = (f"{prefix}.{i}.mlp.fc1", layer.mlp.fc1)
+                layer_weights[(i, f'{layer_type_prefix}_{FC2}')] = (f"{prefix}.{i}.mlp.fc2", layer.mlp.fc2)
+
+        return layer_weights
+
     def forward(
         self,
         batch: VlmCausalLMBatch,
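The methods added above expose adapter targets under TEXT_, VISION_GLOBAL_TRANSFORMER_ and VISION_TRANSFORMER_ prefixes so the same projection name can be traced separately in the language and vision towers, with q/k/v entries mapped onto the fused query_key_value (text) and qkv_proj (vision) modules. A standalone sketch of the naming scheme (constant values are reproduced so the snippet runs on its own; projections not shown in the utils/lora.py diff below are assumed to follow the same lower-case pattern):

# Standalone sketch, not part of the commit: the prefixed layer-type names that
# MllamaCausalLM.adapter_layers exposes. Constant values are copied here so the
# snippet runs on its own; q_proj/k_proj/v_proj/o_proj/gate_proj are assumptions
# following the lower-case convention of the constants visible in the diff.

Q_PROJ, K_PROJ, V_PROJ, O_PROJ = "q_proj", "k_proj", "v_proj", "o_proj"
GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD = "gate_proj", "up_proj", "down_proj", "lm_head"
FC1, FC2 = "fc1", "fc2"

TEXT_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, GATE_PROJ, UP_PROJ, DOWN_PROJ, LM_HEAD]
VISION_ADAPTER_LAYERS = [Q_PROJ, K_PROJ, V_PROJ, O_PROJ, FC1, FC2]

adapter_layers = (
    [f"TEXT_{t}" for t in TEXT_ADAPTER_LAYERS]
    + [f"VISION_GLOBAL_TRANSFORMER_{t}" for t in VISION_ADAPTER_LAYERS]
    + [f"VISION_TRANSFORMER_{t}" for t in VISION_ADAPTER_LAYERS]
)

# 8 text layer types + 2 * 6 vision layer types = 20 adapter layer types.
assert len(adapter_layers) == 20
assert "TEXT_q_proj" in adapter_layers
assert "VISION_TRANSFORMER_fc1" in adapter_layers

adapter_target_to_layer() then maps each (layer_id, prefixed layer type) key to the module-path string expected in the adapter checkpoint and the target module that receives the LoRA weights.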
server/lorax_server/utils/lora.py (3 additions, 0 deletions)
@@ -8,4 +8,7 @@
 UP_PROJ = "up_proj"
 DOWN_PROJ = "down_proj"

+FC1 = 'fc1'
+FC2 = 'fc2'
+
 LM_HEAD = "lm_head"