Make it work
cg123 committed Jan 25, 2025
1 parent 1ac9058 · commit d251080
Showing 2 changed files with 77 additions and 6 deletions.
mergekit/io/lazy_tensor_loader.py (10 changes: 8 additions, 2 deletions)
@@ -114,7 +114,11 @@ def __init__(self, index: ShardedTensorIndex, lazy_unpickle: bool = True):
         self.lazy_unpickle = lazy_unpickle

     def get_tensor(
-        self, key: str, device: str = "cpu", aliases: Optional[List[str]] = None
+        self,
+        key: str,
+        device: str = "cpu",
+        aliases: Optional[List[str]] = None,
+        raise_on_missing: bool = True,
     ) -> Optional[Tensor]:
         if aliases and key not in self.index.tensor_paths:
             for alias in aliases:
@@ -124,7 +128,9 @@ def get_tensor(

         if self.current_shard is None or key not in self.current_shard.keys():
             if key not in self.index.tensor_paths:
-                raise KeyError(key)
+                if raise_on_missing:
+                    raise KeyError(key)
+                return None

             self.current_shard = None
             self.current_keys = None
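For orientation, the new raise_on_missing flag turns a hard failure into a probe: a caller can now check whether a tensor exists and fall back gracefully instead of catching KeyError. A minimal sketch of the calling pattern, assuming ShardedTensorIndex.from_disk for index construction (the checkpoint path is a placeholder, not part of this diff):

from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex

# Placeholder path; point at a real sharded checkpoint directory.
index = ShardedTensorIndex.from_disk("/path/to/model")
loader = LazyTensorLoader(index, lazy_unpickle=True)

# Before this change a missing key always raised KeyError; with the new
# flag the loader returns None, so optional weights (e.g. a tied lm_head)
# can be probed without try/except.
weight = loader.get_tensor("lm_head.weight", raise_on_missing=False)
if weight is None:
    print("lm_head.weight absent; the model likely ties it to embed_tokens")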
mergekit/scripts/extract_lora.py (73 changes: 69 additions, 4 deletions)
@@ -12,6 +12,7 @@
 from transformers import AutoModelForCausalLM
 from transformers.pytorch_utils import Conv1D

+from mergekit.architecture import WeightInfo, get_architecture_info
 from mergekit.card import generate_card_lora
 from mergekit.common import ModelReference
 from mergekit.io import LazyTensorLoader
@@ -200,7 +201,7 @@ def validate_and_combine_details(
         if base_type == "embedding" or base_type == "output":
             if not extend_vocab:
                 logging.warning(
-                    f"Finetuned module '{base_name}' will have {finetuned_size[0] - base_size[0]} rows truncated for weight decomposition! To preserve all embeddings, invoke script with --extend-vocab"
+                    f"Finetuned module '{base_name}' will have {finetuned_size[0] - base_size[0]} rows truncated for weight decomposition! To preserve all embeddings, invoke script with --extend-vocab or --save-module={base_name}."
                 )
             else:
                 logging.warning(
@@ -211,7 +212,7 @@
                     f"Finetuned module '{base_name}' will have {finetuned_size[0] - base_size[0]} rows truncated for weight decomposition!"
                 )

-        else:
+        elif base_type != "to_save":
             assert (
                 base_size == finetuned_size
             ), f"Dimension mismatch in layer '{base_name}': {base_size} != {finetuned_size}"
@@ -221,6 +222,64 @@
     return module_details, base_model_embedding_size, finetuned_model_embedding_size


+def build_wi_map(base_model_ref: ModelReference, trust_remote_code: bool = False):
+    weight_info_map = {}
+    base_cfg = base_model_ref.config(trust_remote_code=trust_remote_code)
+    try:
+        arch_info = get_architecture_info(base_cfg)
+    except RuntimeError as e:
+        logging.error(
+            f"Failed to load architecture info for model {base_model_ref}: {e}"
+        )
+        return {}
+    for weight_info in arch_info.all_weights(base_cfg):
+        weight_info_map[weight_info.name] = weight_info
+    return weight_info_map
+
+
+def load_weights(
+    wi_map: Dict[str, WeightInfo],
+    base_loader: LazyTensorLoader,
+    finetuned_loader: LazyTensorLoader,
+    module_name: str,
+):
+    optional = False
+    aliases = None
+    tied_names = None
+    if weight_info := wi_map.get(module_name + ".weight"):
+        if weight_info.optional:
+            optional = True
+        if weight_info.aliases:
+            aliases = weight_info.aliases
+        if weight_info.tied_names:
+            tied_names = weight_info.tied_names
+
+    base_weight = base_loader.get_tensor(
+        f"{module_name}.weight", aliases=aliases, raise_on_missing=False
+    )
+    finetuned_weight = finetuned_loader.get_tensor(
+        f"{module_name}.weight", aliases=aliases, raise_on_missing=False
+    )
+    if optional and (base_weight is None and finetuned_weight is None):
+        return None, None
+    if tied_names:
+        if base_weight is None:
+            base_weight = base_loader.get_tensor(
+                f"{module_name}.weight", aliases=tied_names, raise_on_missing=False
+            )
+        if finetuned_weight is None:
+            finetuned_weight = finetuned_loader.get_tensor(
+                f"{module_name}.weight", aliases=tied_names, raise_on_missing=False
+            )
+    if base_weight is None:
+        raise RuntimeError(f"Missing base weight for {module_name}")
+    if finetuned_weight is None:
+        if optional:
+            return None, None
+        raise RuntimeError(f"Missing finetuned weight for {module_name}")
+    return base_weight, finetuned_weight
+
+
 def extract_lora(
     module_details: List[Tuple[str, str]],
     base_model_ref: ModelReference,
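Together, build_wi_map and load_weights let the extractor consult the architecture's WeightInfo metadata (optional weights, aliases, tied names) rather than assuming every module has a .weight tensor in both checkpoints. A rough sketch of the intended call sequence; the model paths are placeholders, and the lazy_loader convenience constructor is an assumption about mergekit's ModelReference API rather than part of this diff:

from mergekit.common import ModelReference

# Placeholders: substitute real base / finetuned checkpoints.
base_ref = ModelReference.parse("/path/to/base-model")
ft_ref = ModelReference.parse("/path/to/finetuned-model")

wi_map = build_wi_map(base_ref)
base_loader = base_ref.lazy_loader(lazy_unpickle=True)       # assumed helper
finetuned_loader = ft_ref.lazy_loader(lazy_unpickle=True)    # assumed helper

# For a tied-embedding model, "lm_head.weight" may be missing from the
# checkpoint; load_weights retries under WeightInfo.tied_names (e.g.
# "model.embed_tokens.weight") and returns (None, None) only when the
# weight is marked optional.
base_w, ft_w = load_weights(wi_map, base_loader, finetuned_loader, "lm_head")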
@@ -252,9 +311,14 @@ def extract_lora(
     lora_weights = {}
     ranks = {}

+    wi_map = build_wi_map(base_model_ref)
+
     for module_type, module_name in tqdm(module_details):
-        base_weight = base_loader.get_tensor(f"{module_name}.weight")
-        finetuned_weight = finetuned_loader.get_tensor(f"{module_name}.weight")
+        base_weight, finetuned_weight = load_weights(
+            wi_map, base_loader, finetuned_loader, module_name
+        )
+        if base_weight is None and finetuned_weight is None:
+            continue

         if module_type == "to_save":
             lora_weights[
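Downstream of this loop (outside the hunk), each finetuned_weight - base_weight delta is factored into a low-rank pair. As a reminder of the standard technique, and not necessarily this script's exact implementation, a truncated-SVD sketch:

import torch

def lowrank_decompose(delta: torch.Tensor, rank: int):
    # Truncated SVD: delta ~= lora_B @ lora_A, with the top `rank`
    # singular values absorbed into lora_B.
    U, S, Vh = torch.linalg.svd(delta.float(), full_matrices=False)
    lora_B = U[:, :rank] * S[:rank]  # (out_features, rank)
    lora_A = Vh[:rank, :]            # (rank, in_features)
    return lora_A, lora_B

delta = torch.randn(1024, 4096)
A, B = lowrank_decompose(delta, rank=32)
assert (B @ A).shape == delta.shape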
@@ -594,6 +658,7 @@ def main(
         ModelReference.parse(finetuned_model).model.path,
         skip_undecomposable,
         extend_vocab,
+        modules_to_save=modules_to_save,
     )

     lora_weights, ranks = extract_lora(
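With this wiring, modules_to_save (populated by the --save-module flag referenced in the warnings above) flows from main() into the module-validation call, whose site is truncated in this view, so user-selected modules are carried through whole rather than decomposed. A plausible invocation, assuming the usual mergekit-extract-lora entry point and argument order (finetuned model, base model, output path): mergekit-extract-lora /path/to/finetuned /path/to/base ./extracted-lora --rank=32 --save-module=lm_head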
