diff --git a/README.md b/README.md index cbf93b12..b041c9f2 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,39 @@ `mergekit` is a toolkit for merging pre-trained language models. `mergekit` uses an out-of-core approach to perform unreasonably elaborate merges in resource-constrained situations. Merges can be run entirely on CPU or accelerated with as little as 8 GB of VRAM. Many merging algorithms are supported, with more coming as they catch my attention. -Features: +## Contents + +- [Why Merge Models?](#why-merge-models) +- [Features](#features) +- [Installation](#installation) +- [Usage](#usage) +- [Merge Configuration](#merge-configuration) + - [Parameter Specification](#parameter-specification) + - [Tokenizer Configuration](#tokenizer-configuration) + - [Chat Template Configuration](#chat-template-configuration) + - [Examples](#examples) +- [Merge Methods](#merge-methods) +- [LoRA extraction](#lora-extraction) +- [Mixture of Experts merging](#mixture-of-experts-merging) +- [Evolutionary merge methods](#evolutionary-merge-methods) +- [Merge in the Cloud](#-merge-in-the-cloud-) +- [Citation](#citation) + +## Why Merge Models? + +Model merging is a powerful technique that allows combining the strengths of different models without the computational overhead of ensembling or the need for additional training. By operating directly in the weight space of models, merging can: + +- Combine multiple specialized models into a single versatile model +- Transfer capabilities between models without access to training data +- Find optimal trade-offs between different model behaviors +- Improve performance while maintaining inference costs +- Create new capabilities through creative model combinations + +Unlike traditional ensembling which requires running multiple models, merged models maintain the same inference cost as a single model while often achieving comparable or superior performance. + +## Features + +Key features of `mergekit` include: - Supports Llama, Mistral, GPT-NeoX, StableLM, and more - Many [merge methods](#merge-methods) @@ -52,7 +84,7 @@ When you have a merged model you're happy with, you may want to share it on the Once you're happy with your model card and merged model, you can upload it to the Hugging Face Hub using the [huggingface_hub](https://huggingface.co/docs/huggingface_hub/index) Python library. -``` +```sh # log in to huggingface with an access token (must have write permission) huggingface-cli login # upload your model @@ -72,7 +104,8 @@ Below are the primary elements of a configuration file: - `base_model`: Specifies the base model used in some merging methods. - `parameters`: Holds various parameters such as weights and densities, which can also be specified at different levels of the configuration. - `dtype`: Specifies the data type used for the merging operation. -- `tokenizer_source`: Determines how to construct a tokenizer for the merged model. +- `tokenizer` or `tokenizer_source`: Determines how to construct a tokenizer for the merged model. +- `chat_template`: Specifies a chat template for the merged model. ### Parameter Specification @@ -90,23 +123,112 @@ The parameters can be set at different levels, with decreasing precedence as fol 3. `models.*.parameters` or `input_model_parameters` - applying to any tensors coming from specific input models 4. 
`parameters` - catchall -### Tokenizer Source +### Tokenizer Configuration + +The tokenizer behavior can be configured in two ways: using the new `tokenizer` field (recommended) or the legacy `tokenizer_source` field (maintained for backward compatibility). These fields are mutually exclusive - you should use one or the other, not both. -The `tokenizer_source` field of a configuration file determines what tokenizer is used by the merged model. This also effects how embeddings and language model heads are merged. +#### Modern Configuration (tokenizer) + +The `tokenizer` field provides fine-grained control over vocabulary and embeddings: + +```yaml +tokenizer: + source: "union" # or "base" or a specific model path + tokens: # Optional: configure specific tokens + : + source: ... # Specify embedding source + force: false # Optional: force this embedding for all models + pad_to_multiple_of: null # Optional: pad vocabulary size +``` -This functionality is still experimental and may break. Please file an issue if you encounter any issues with it. +##### Tokenizer Source -Valid values: +The `source` field determines the vocabulary of the output model: -- `base`: use the tokenizer from the base model -- `union`: construct a tokenizer with all tokens from all models -- `model:`: use the tokenizer from a specific model +- `union`: Combine vocabularies from all input models (default) +- `base`: Use vocabulary from the base model +- `"path/to/model"`: Use vocabulary from a specific model -If set, mergekit will find a mapping between each model's vocabulary and the output tokenizer. This allows models with different vocabularies or added tokens to be meaningfully merged. +##### Token Embedding Handling + +When merging models with different vocabularies, mergekit uses smart defaults to handle token embeddings: + +- If a token exists in the base model, its embedding is used as the default +- If only one model has the token, that model's embedding is used +- Otherwise, an average of all available embeddings is used + +You can override these defaults for specific tokens: + +```yaml +tokenizer: + source: union + tokens: + # Use embedding from a specific model + <|im_start|>: + source: "path/to/chatml/model" + + # Force a specific embedding for all models + <|special|>: + source: "path/to/model" + force: true + + # Map a token to another model's token embedding + <|renamed_token|>: + source: + kind: "model_token" + model: "path/to/model" + token: "<|original_token|>" # or use token_id: 1234 +``` + +##### Practical Example + +Here's how you might preserve both Llama 3 Instruct and ChatML prompt formats when merging models: + +```yaml +tokenizer: + source: union + tokens: + # ChatML tokens + <|im_start|>: + source: "chatml_model" + <|im_end|>: + source: "chatml_model" + + # Llama 3 tokens - force original embeddings + <|start_header_id|>: + source: "llama3_model" + force: true + <|end_header_id|>: + source: "llama3_model" + force: true + <|eot_id|>: + source: "llama3_model" + force: true +``` + +#### Legacy Configuration (tokenizer_source) + +For backward compatibility, the `tokenizer_source` field is still supported: + +```yaml +tokenizer_source: "union" # or "base" or a model path +``` -`tokenizer_source` is compatible with all merge methods, but when used `lm_head`/`embed_tokens` will be merged linearly. For two-model merges, the `embed_slerp` parameter can be set to `true` to use SLERP instead. +This provides basic tokenizer selection but lacks the fine-grained control of the modern `tokenizer` field. 
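+To summarize the two styles, here is a minimal illustrative sketch (the values simply mirror the examples above; use one style or the other in a given configuration, never both):
+
+```yaml
+# legacy style: vocabulary selection only
+tokenizer_source: union
+
+# modern style: same vocabulary choice, plus optional per-token control
+tokenizer:
+  source: union
+  tokens:
+    <|im_start|>:
+      source: "path/to/chatml/model"
+```
+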
-If the `tokenizer_source` field is not set, mergekit will fall back to its legacy default behavior. The tokenizer for the base model (or first model in the merge, if no base model is specified) will be copied to the output directory. The parameter matrices for `lm_head`/`embed_tokens` will be truncated to the smallest size present in the merge. In _most_ cases this corresponds to using the tokenizer for the base model. +### Chat Template Configuration + +The optional `chat_template` field allows overriding the chat template used for the merged model. + +```yaml +chat_template: "auto" # or a template name or Jinja2 template +``` + +Options include: + +- `"auto"`: Automatically select the most common template among input models +- Built-in templates: `"alpaca"`, `"chatml"`, `"llama3"`, `"mistral"`, `"exaone"` +- A Jinja2 template string for custom formatting ### Examples @@ -120,6 +242,7 @@ A quick overview of the currently supported merge methods: | ------------------------------------------------------------------------------------------------ | -------------------- | ----------- | --------------- | | Linear ([Model Soups](https://arxiv.org/abs/2203.05482)) | `linear` | ✅ | ❌ | | SLERP | `slerp` | ❌ | ✅ | +| Nearswap | `nearswap` | ❌ | ✅ | | [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `task_arithmetic` | ✅ | ✅ | | [TIES](https://arxiv.org/abs/2306.01708) | `ties` | ✅ | ✅ | | [DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708) | `dare_ties` | ✅ | ✅ | @@ -128,8 +251,11 @@ A quick overview of the currently supported merge methods: | [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) | `breadcrumbs` | ✅ | ✅ | | [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) + [TIES](https://arxiv.org/abs/2306.01708) | `breadcrumbs_ties` | ✅ | ✅ | | [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ | -| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ | -| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ | +| NuSLERP | `nuslerp` | ❌ | ✅ | +| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ | +| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ | +| [SCE](https://arxiv.org/abs/2408.07990) | `sce` | ✅ | ✅ | + ### Linear The classic merge method - a simple weighted average. @@ -147,6 +273,14 @@ Parameters: - `t` - interpolation factor. At `t=0` will return `base_model`, at `t=1` will return the other one. +### Nearswap + +Interpolates base model with secondary model if similarity is below t. Accepts two models. + +Parameters: + +- `t` - similarity threshold + ### [Task Arithmetic](https://arxiv.org/abs/2212.04089) Computes "task vectors" for each model by subtracting a base model. Merges the task vectors linearly and adds back the base. Works great for models that were fine tuned from a common ancestor. Also a super useful mental framework for several of the more involved merge methods. @@ -173,7 +307,7 @@ Parameters: same as [TIES](#ties) for `dare_ties`, or [Linear](#linear) for `dar ### [Model Breadcrumbs](https://arxiv.org/abs/2312.06795) -An extension of task arithmetic that discards both small and and extremely large differences from the base model. As with DARE, the Model Breadcrumbs algorithm can be used with (`breadcrumbs_ties`) or without (`breadcrumbs`) the sign consensus algorithm of TIES. 
+An extension of task arithmetic that discards both small and extremely large differences from the base model. As with DARE, the Model Breadcrumbs algorithm can be used with (`breadcrumbs_ties`) or without (`breadcrumbs`) the sign consensus algorithm of TIES. Parameters: same as [Linear](#linear), plus: @@ -190,15 +324,36 @@ Parameters: - `filter_wise`: if true, weight calculation will be per-row rather than per-tensor. Not recommended. +### NuSLERP + +Spherically interpolate between parameters, but with more options and more sensical configuration! Does not require a base model, but can use one to do spherical interpolation of task vectors. Only works with either two models or two plus a base model. + +Parameters: + +- `weight`: relative weighting of a given tensor +- `nuslerp_flatten`: set to false to do row-wise/column-wise interpolation instead of treating tensors as vectors +- `nuslerp_row_wise`: SLERP row vectors instead of column vectors + +To replicate the behavior of the original `slerp` method, set `weight` to `1-t` and `t` for your first and second model respectively. + ### [DELLA](https://arxiv.org/abs/2406.11617) Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters in each row of delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain more important changes while reducing interference. After pruning, it rescales the remaining parameters similar to [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign elect step of TIES Parameters: same as [Linear](#linear), plus: + - `density` - fraction of weights in differences from the base model to retain - `epsilon` - maximum change in drop probability based on magnitude. Drop probabilities assigned will range from `density - epsilon` to `density + epsilon`. (When selecting values for `density` and `epsilon`, ensure that the range of probabilities falls within 0 to 1) - `lambda` - scaling factor for the final merged delta parameters before merging with the base parameters. +### [SCE](https://arxiv.org/abs/2408.07990) + +SCE introduces adaptive matrix-level merging weights based on parameter variances. SCE first selects the top-k% elements from each parameter matrix that exhibit high variance across all delta parameters. Following this selection, SCE calculates matrix-level merging weights based on the sum of squares of elements in the delta parameters. Finally, it erases minority elements, a step similar to the sign election process in TIES. + +Parameters: + +- `select_topk` - fraction of elements with the highest variance in the delta parameters to retain. + ## LoRA extraction Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned models. @@ -215,7 +370,7 @@ The `mergekit-moe` script supports merging multiple dense models into a mixture ## Evolutionary merge methods -See `docs/evolve.md` for details. +See [`docs/evolve.md`](docs/evolve.md) for details. 
## ✨ Merge in the Cloud ✨ @@ -224,7 +379,7 @@ We host merging on Arcee's cloud GPUs - you can launch a cloud merge in the [Arc `export ARCEE_API_KEY=` `pip install -q arcee-py` -``` +```python import arcee arcee.merge_yaml("bio-merge","./examples/bio-merge.yml") ``` @@ -233,7 +388,7 @@ Check your merge status at the [Arcee App](https://app.arcee.ai) When complete, either deploy your merge: -``` +```python arcee.start_deployment("bio-merge", merging="bio-merge") ``` @@ -241,16 +396,32 @@ Or download your merge: `!arcee merging download bio-merge` - ## Citation -We now have a [paper](https://arxiv.org/abs/2403.13257) you can cite for the MergeKit library: +If you find `mergekit` useful in your research, please consider citing the [paper](https://aclanthology.org/2024.emnlp-industry.36/): ```bibtex -@article{goddard2024arcee, - title={Arcee's MergeKit: A Toolkit for Merging Large Language Models}, - author={Goddard, Charles and Siriwardhana, Shamane and Ehghaghi, Malikeh and Meyers, Luke and Karpukhin, Vlad and Benedict, Brian and McQuade, Mark and Solawetz, Jacob}, - journal={arXiv preprint arXiv:2403.13257}, - year={2024} +@inproceedings{goddard-etal-2024-arcees, + title = "Arcee{'}s {M}erge{K}it: A Toolkit for Merging Large Language Models", + author = "Goddard, Charles and + Siriwardhana, Shamane and + Ehghaghi, Malikeh and + Meyers, Luke and + Karpukhin, Vladimir and + Benedict, Brian and + McQuade, Mark and + Solawetz, Jacob", + editor = "Dernoncourt, Franck and + Preo{\c{t}}iuc-Pietro, Daniel and + Shimorina, Anastasia", + booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing: Industry Track", + month = nov, + year = "2024", + address = "Miami, Florida, US", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2024.emnlp-industry.36", + doi = "10.18653/v1/2024.emnlp-industry.36", + pages = "477--485", + abstract = "The rapid growth of open-source language models provides the opportunity to merge model checkpoints, combining their parameters to improve performance and versatility. Advances in transfer learning have led to numerous task-specific models, which model merging can integrate into powerful multitask models without additional training. MergeKit is an open-source library designed to support this process with an efficient and extensible framework suitable for any hardware. It has facilitated the merging of thousands of models, contributing to some of the world{'}s most powerful open-source model checkpoints. The library is accessible at: https://github.com/arcee-ai/mergekit.", } ``` diff --git a/docs/create_a_merge_method.md b/docs/create_a_merge_method.md new file mode 100644 index 00000000..f6cf539f --- /dev/null +++ b/docs/create_a_merge_method.md @@ -0,0 +1,236 @@ +# Extending MergeKit with Custom Merge Methods + +## Overview + +MergeKit offers two different paths for implementing custom merge methods: + +| | Decorator API | Class-based API | +| ---------------------- | --------------------- | --------------------------- | +| **Complexity** | Simple function-based | Full class implementation | +| **Abstraction Level** | Higher-level | Lower-level | +| **Parameter Handling** | Automatic validation | Manual configuration | +| **Execution Flow** | Single-step | Arbitrary computation graph | +| **Best For** | Simple tensor ops | Complex merge strategies | + +Either approach benefits from MergeKit's underlying task system for resource management and execution control. 
The question of which to use largely depends on the complexity of the merge operation and the level of control needed. + +### Core Task System Features + +MergeKit's computational graph infrastructure provides sophisticated resource management that all merge methods inherit: + +- **Smart Memory Management** + - Automatic return value lifecycle tracking + - Early value eviction when no longer needed + - Optimized shard loading based on task groups + +- **Device Management** + - Automatic tensor movement between compute and storage devices + - Support for both CPU and GPU execution + +- **Task Scheduling** + - Tasks grouped by tensor shard to minimize memory usage + - Loads deferred until last possible moment (via priority system) + - Execution ordered to optimize shard residency + +### Decorator API +Best for straightforward merge operations that can be expressed as a single tensor transformation. Features: +- Parameter validation and type checking +- Configuration schema generation +- Simplified base model handling +- Default GPU acceleration opt-in + +### Class-based API +Choose when you need: +- Multi-stage merge operations +- Custom computation graphs +- Direct access to weight metadata +- Complex parameter types +- Fine-grained control over execution + +## Decorator API Implementation + +### Basic Workflow +1. Define a type-annotated Python function with your merge logic +2. Add the `@merge_method` decorator with configuration +3. Register by importing in `mergekit/merge_methods/__init__.py` + +### Example: Weighted Average + +```python +from mergekit.merge_methods.easy_define import merge_method +from typing import List +import torch + +@merge_method( + name="weighted_average", + pretty_name="Weighted Average", # Optional: human-readable name + reference_url="https://example.com/docs", # Optional: documentation link +) +def average_merge( + tensors: List[torch.Tensor], # Required: input tensors + weight: List[float], # Vector parameter (per-model) + normalize: bool = True, # Scalar parameter with default +) -> torch.Tensor: + if normalize: + total = sum(weight) + weight = [w / total for w in weight] + + return sum(t * w for t, w in zip(tensors, weight)) +``` + +This enables configurations like: +```yaml +merge_method: weighted_average +models: + - model: model1 + parameters: + weight: 0.3 + - model: model2 + parameters: + weight: 0.7 +parameters: + normalize: true +``` + +### Parameter Types and Handling + +The decorator supports three parameter categories: + +1. **Scalar Parameters** + - Types: `bool`, `float`, or `int` + - Defined in top-level `parameters` section + - Without defaults they become required parameters + - Example: `normalize: bool = True` + +2. **Vector Parameters** + - Types: `List[float]` or `List[int]` only + - Configured per-model in their `parameters` section + - Default values must be single numbers, not lists, as they are broadcasted + - Example: `weights: List[float]` + +3. 
**Base Model Integration** + - Via `base_tensor` parameter annotation: + * `torch.Tensor`: Base model required + * `Optional[torch.Tensor]`: Base model optional + - Without `base_tensor`: Base model tensor goes first in `tensors` list if present + +## Class-based API + +For complex merges requiring granular control, implement `MergeMethod` and `Task` classes: + +### Example Implementation + +```python +from mergekit.merge_methods.base import MergeMethod, ConfigParameterDef +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from typing import Any, Dict, List + + +class CustomMergeTask(Task[torch.Tensor]): + gather_tensors: MergeTensorInput + parameters: ImmutableMap[str, Any] + weight_info: WeightInfo + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def priority(self) -> int: + return 1 # Optional: higher priority = earlier execution + + def group_label(self) -> str: + return self.weight_info.name # Optional: modify task grouping + + def uses_accelerator(self) -> bool: + return True # Enable GPU acceleration + + def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: + # Implementation using weight info and parameters + result = ... + return result + + +class CustomMerge(MergeMethod): + def name(self) -> str: + return "custom_merge" + + def pretty_name(self) -> str: + return "Custom Merge" + + def reference_url(self) -> str: + return "https://example.com/custom" + + def parameters(self) -> List[ConfigParameterDef]: + return [ + ConfigParameterDef("threshold", float, required=False, default_value=0.5) + ] + + def tensor_parameters(self) -> List[ConfigParameterDef]: + return [ConfigParameterDef("weight", float, required=True)] + + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: MergeTensorInput, + parameters: ImmutableMap[str, Any], + tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], + **kwargs, + ) -> Task: + return CustomMergeTask( + gather_tensors=tensors, + parameters=parameters, + weight_info=output_weight, + ) +``` + +### Task Scheduling System + +The class-based API provides fine-grained control over execution: + +- **Priority Control**: Override `priority()` to influence execution order within groups +- **Task Grouping**: Use `group_label()` to batch similar operations +- **Resource Management**: + - Automatic tensor lifecycle tracking + - Memory optimization via early tensor eviction + - Smart device placement for computation vs storage +- **Computation Graph**: Build complex flows by connecting multiple tasks + +### Implementation Requirements + +1. Task Class: + - Must implement `execute()` with proper type annotations + - Must implement `arguments()` to declare dependencies + - Optionally override `priority()`, `group_label()`, `uses_accelerator()` + +2. Method Class: + - Must implement core methods: `name()`, `make_task()` + - Optional methods: `pretty_name()`, `reference_url()` + - Define parameters via `parameters()` and `tensor_parameters()` + +### Registration + +Add class-based methods to `STATIC_MERGE_METHODS` in `mergekit/merge_methods/registry.py`: + +```python +from mergekit.merge_methods.my_module import CustomMerge + +STATIC_MERGE_METHODS: List[MergeMethod] = [ + CustomMerge(), + # other methods... +] +``` + +## Reference Implementations + +1. **Linear Merge** (`mergekit.merge_methods.linear`): + - Basic weighted averaging + - Good example of class-based implementation + +2. 
**Multi-SLERP** (`mergekit.merge_methods.multislerp`): + - Hypersphere interpolation + - Complex decorator usage example + +3. **Task Arithmetic** (`mergekit.merge_methods.task_arithmetic`): + - Advanced graph-based implementation + - TIES/Magnitude pruning example diff --git a/docs/evolve.md b/docs/evolve.md index 2ac164a9..930fc279 100644 --- a/docs/evolve.md +++ b/docs/evolve.md @@ -121,7 +121,7 @@ Assigns an actor to each GPU in your cluster and guarantees merges and evaluatio #### `buffered` -Maintains a buffer of tasks scheduled to ensure that there is always a model mergign or ready to evaluate for each gpu. Allows for concurrent merging and evaluation of models on the same GPU if enough VRAM is available. Only suitable for a single-node setup or when `--storage-path` points to a fast shared filesystem. +Maintains a buffer of tasks scheduled to ensure that there is always a model merging or ready to evaluate for each GPU. Allows for concurrent merging and evaluation of models on the same GPU if enough VRAM is available. Only suitable for a single-node setup or when `--storage-path` points to a fast shared filesystem. #### `serial` diff --git a/mergekit/_data/architectures/bert-masked-lm.json b/mergekit/_data/architectures/bert-masked-lm.json index 3b0620fb..d6430e40 100644 --- a/mergekit/_data/architectures/bert-masked-lm.json +++ b/mergekit/_data/architectures/bert-masked-lm.json @@ -44,7 +44,8 @@ }, { "name": "cls.predictions.decoder.weight", - "aliases": [ + "optional": true, + "tied_names": [ "bert.embeddings.word_embeddings.weight" ], "is_embed": true diff --git a/mergekit/_data/architectures/distilbert-masked-lm.json b/mergekit/_data/architectures/distilbert-masked-lm.json index 6828cca2..1a079811 100644 --- a/mergekit/_data/architectures/distilbert-masked-lm.json +++ b/mergekit/_data/architectures/distilbert-masked-lm.json @@ -40,7 +40,8 @@ { "name": "vocab_projector.weight", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "distilbert.embeddings.word_embeddings.weight" ] }, diff --git a/mergekit/_data/architectures/gemma2.json b/mergekit/_data/architectures/gemma2.json index 0c6372f0..52505245 100644 --- a/mergekit/_data/architectures/gemma2.json +++ b/mergekit/_data/architectures/gemma2.json @@ -54,7 +54,10 @@ { "name": "lm_head.weight", "is_embed": true, - "optional": true + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] } ] } diff --git a/mergekit/_data/architectures/gpt2.json b/mergekit/_data/architectures/gpt2.json index 64a04e9d..fc7a3201 100644 --- a/mergekit/_data/architectures/gpt2.json +++ b/mergekit/_data/architectures/gpt2.json @@ -5,59 +5,116 @@ ], "pre_weights": [ { - "name": "wte.weight", - "is_embed": true + "name": "transformer.wte.weight", + "is_embed": true, + "aliases": [ + "wte.weight" + ] }, { - "name": "wpe.weight" + "name": "transformer.wpe.weight", + "aliases": [ + "wpe.weight" + ] } ], "post_weights": [ { - "name": "ln_f.weight" + "name": "transformer.ln_f.weight", + "aliases": [ + "ln_f.weight" + ] }, { - "name": "ln_f.bias" + "name": "transformer.ln_f.bias", + "aliases": [ + "ln_f.bias" + ] + }, + { + "name": "lm_head.weight", + "is_embed": true, + "optional": true, + "tied_names": [ + "transformer.wte.weight", + "wte.weight" + ] } ], "num_layers_config_key": "n_layer", "layer_templates": { "weights": [ { - "name": "h.${layer_index}.attn.c_attn.weight" + "name": "transformer.h.${layer_index}.attn.c_attn.weight", + "aliases": [ + "h.${layer_index}.attn.c_attn.weight" + ] }, { - "name": 
"h.${layer_index}.attn.c_attn.bias" + "name": "transformer.h.${layer_index}.attn.c_attn.bias", + "aliases": [ + "h.${layer_index}.attn.c_attn.bias" + ] }, { - "name": "h.${layer_index}.attn.c_proj.weight" + "name": "transformer.h.${layer_index}.attn.c_proj.weight", + "aliases": [ + "h.${layer_index}.attn.c_proj.weight" + ] }, { - "name": "h.${layer_index}.attn.c_proj.bias" + "name": "transformer.h.${layer_index}.attn.c_proj.bias", + "aliases": [ + "h.${layer_index}.attn.c_proj.bias" + ] }, { - "name": "h.${layer_index}.ln_1.weight" + "name": "transformer.h.${layer_index}.ln_1.weight", + "aliases": [ + "h.${layer_index}.ln_1.weight" + ] }, { - "name": "h.${layer_index}.ln_1.bias" + "name": "transformer.h.${layer_index}.ln_1.bias", + "aliases": [ + "h.${layer_index}.ln_1.bias" + ] }, { - "name": "h.${layer_index}.ln_2.weight" + "name": "transformer.h.${layer_index}.ln_2.weight", + "aliases": [ + "h.${layer_index}.ln_2.weight" + ] }, { - "name": "h.${layer_index}.ln_2.bias" + "name": "transformer.h.${layer_index}.ln_2.bias", + "aliases": [ + "h.${layer_index}.ln_2.bias" + ] }, { - "name": "h.${layer_index}.mlp.c_proj.weight" + "name": "transformer.h.${layer_index}.mlp.c_proj.weight", + "aliases": [ + "h.${layer_index}.mlp.c_proj.weight" + ] }, { - "name": "h.${layer_index}.mlp.c_proj.bias" + "name": "transformer.h.${layer_index}.mlp.c_proj.bias", + "aliases": [ + "h.${layer_index}.mlp.c_proj.bias" + ] }, { - "name": "h.${layer_index}.mlp.c_fc.weight" + "name": "transformer.h.${layer_index}.mlp.c_fc.weight", + "aliases": [ + "h.${layer_index}.mlp.c_fc.weight" + ] }, { - "name": "h.${layer_index}.mlp.c_fc.bias" + "name": "transformer.h.${layer_index}.mlp.c_fc.bias", + "aliases": [ + "h.${layer_index}.mlp.c_fc.bias" + ] } ] } diff --git a/mergekit/_data/architectures/gptbigcode.json b/mergekit/_data/architectures/gptbigcode.json index 4b086278..c12bac5c 100644 --- a/mergekit/_data/architectures/gptbigcode.json +++ b/mergekit/_data/architectures/gptbigcode.json @@ -21,7 +21,9 @@ }, { "name": "lm_head.weight", - "aliases": [ + "is_embed": true, + "optional": true, + "tied_names": [ "transformer.wte.weight" ] } diff --git a/mergekit/_data/architectures/internlm2.json b/mergekit/_data/architectures/internlm2.json index 057bc649..888faa48 100644 --- a/mergekit/_data/architectures/internlm2.json +++ b/mergekit/_data/architectures/internlm2.json @@ -16,7 +16,8 @@ { "name": "output.weight", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "model.tok_embeddings.weight" ] } diff --git a/mergekit/_data/architectures/llama.json b/mergekit/_data/architectures/llama.json index 7106806b..00918a2c 100644 --- a/mergekit/_data/architectures/llama.json +++ b/mergekit/_data/architectures/llama.json @@ -74,7 +74,10 @@ "name": "lm_head.weight", "input_space": "running_residual", "is_embed": true, - "optional": true + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] } ] } diff --git a/mergekit/_data/architectures/mamba.json b/mergekit/_data/architectures/mamba.json index b3727dba..1c473532 100644 --- a/mergekit/_data/architectures/mamba.json +++ b/mergekit/_data/architectures/mamba.json @@ -16,7 +16,10 @@ { "name": "lm_head.weight", "is_embed": true, - "aliases": ["backbone.embeddings.weight"] + "optional": true, + "tied_names": [ + "backbone.embeddings.weight" + ] } ], "num_layers_config_key": "num_hidden_layers", diff --git a/mergekit/_data/architectures/phi3-small.json b/mergekit/_data/architectures/phi3-small.json index 7b3a1e80..f27dfac4 100644 --- 
a/mergekit/_data/architectures/phi3-small.json +++ b/mergekit/_data/architectures/phi3-small.json @@ -12,8 +12,9 @@ "post_weights": [ { "name": "lm_head.weight", - "is_embed":true, - "aliases": [ + "is_embed": true, + "optional": true, + "tied_names": [ "model.embed_tokens.weight" ] }, diff --git a/mergekit/_data/architectures/qwen2.json b/mergekit/_data/architectures/qwen2.json index 638b3630..c7131523 100644 --- a/mergekit/_data/architectures/qwen2.json +++ b/mergekit/_data/architectures/qwen2.json @@ -16,7 +16,8 @@ { "name": "lm_head.weight", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "model.embed_tokens.weight" ] } diff --git a/mergekit/_data/architectures/roberta-masked-lm.json b/mergekit/_data/architectures/roberta-masked-lm.json index 492127a5..1aae76a1 100644 --- a/mergekit/_data/architectures/roberta-masked-lm.json +++ b/mergekit/_data/architectures/roberta-masked-lm.json @@ -8,7 +8,8 @@ "name": "roberta.embeddings.position_embeddings.weight" }, { - "name": "roberta.embeddings.word_embeddings.weight" + "name": "roberta.embeddings.word_embeddings.weight", + "is_embed": true }, { "name": "roberta.embeddings.token_type_embeddings.weight" @@ -43,7 +44,9 @@ }, { "name": "lm_head.decoder.weight", - "aliases": [ + "is_embed": true, + "optional": true, + "tied_names": [ "roberta.embeddings.word_embeddings.weight" ] } diff --git a/mergekit/_data/architectures/solar.json b/mergekit/_data/architectures/solar.json index 7bd6a751..78fd5998 100644 --- a/mergekit/_data/architectures/solar.json +++ b/mergekit/_data/architectures/solar.json @@ -73,7 +73,8 @@ "name": "lm_head.weight", "input_space": "running_residual", "is_embed": true, - "aliases": [ + "optional": true, + "tied_names": [ "model.lm_head.weight" ] } diff --git a/mergekit/_data/architectures/starcoder2.json b/mergekit/_data/architectures/starcoder2.json index 851fdd1a..c2266899 100644 --- a/mergekit/_data/architectures/starcoder2.json +++ b/mergekit/_data/architectures/starcoder2.json @@ -13,7 +13,10 @@ { "name": "lm_head.weight", "is_embed": true, - "aliases": ["model.embed_tokens.weight"] + "optional": true, + "tied_names": [ + "model.embed_tokens.weight" + ] }, { "name": "model.norm.bias" diff --git a/mergekit/architecture.py b/mergekit/architecture.py index 4c7b4625..40872160 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture.py @@ -41,6 +41,8 @@ class WeightInfo(BaseModel, frozen=True): Indicates whether the weight can be omitted from a model. aliases (Optional[List[str]]): List of alternative names for the weight, if applicable. + tied_names (Optional[List[str]]): + List of names for weights that are tied to this weight, if applicable. force_dtype (Optional[str]): Mandatory dtype for the weight, if applicable. 
""" @@ -50,7 +52,9 @@ class WeightInfo(BaseModel, frozen=True): input_space: Optional[str] = None output_space: Optional[str] = None optional: bool = False + tied: bool = False aliases: Optional[Tuple[str, ...]] = None + tied_names: Optional[Tuple[str, ...]] = None force_dtype: Optional[str] = None head_split: Literal[None, "input", "output"] = None is_kq: Optional[bool] = False diff --git a/mergekit/card.py b/mergekit/card.py index bf0a2d0a..85eeba1e 100644 --- a/mergekit/card.py +++ b/mergekit/card.py @@ -22,6 +22,7 @@ from huggingface_hub.utils import HFValidationError from yaml.nodes import SequenceNode as SequenceNode +from mergekit import merge_methods from mergekit.config import MergeConfiguration, ModelReference CARD_TEMPLATE = """--- @@ -110,16 +111,15 @@ def method_md(merge_method: str) -> str: Args: merge_method: A string indicating the merge method used. """ - methods = { - "linear": "[linear](https://arxiv.org/abs/2203.05482)", - "ties": "[TIES](https://arxiv.org/abs/2306.01708)", - "slerp": "SLERP", - "task_arithmetic": "[task arithmetic](https://arxiv.org/abs/2212.04089)", - "dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)", - "dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)", - "model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)", - } - return methods.get(merge_method, merge_method) + try: + method = merge_methods.get(merge_method) + except RuntimeError: + return merge_method + ref_url = method.reference_url() + name = method.pretty_name() or method.name() + if ref_url and ref_url.strip(): + return f"[{name}]({ref_url})" + return name def maybe_link_hf(path: str) -> str: diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py index 0f9f42fc..0e107aeb 100644 --- a/mergekit/evo/actors.py +++ b/mergekit/evo/actors.py @@ -213,7 +213,7 @@ def _maybe_init_model(self, config: MergeConfiguration): tokenizer_donor = self.genome.definition.base_model if tokenizer_donor is None: logging.warning( - f"Base model not set, using tokenizer from first model in genome" + "Base model not set, using tokenizer from first model in genome" ) tokenizer_donor = self.genome.definition.models[0] tok = transformers.AutoTokenizer.from_pretrained( diff --git a/mergekit/io/lazy_tensor_loader.py b/mergekit/io/lazy_tensor_loader.py index e79c5714..d491f8dd 100644 --- a/mergekit/io/lazy_tensor_loader.py +++ b/mergekit/io/lazy_tensor_loader.py @@ -114,7 +114,11 @@ def __init__(self, index: ShardedTensorIndex, lazy_unpickle: bool = True): self.lazy_unpickle = lazy_unpickle def get_tensor( - self, key: str, device: str = "cpu", aliases: Optional[List[str]] = None + self, + key: str, + device: str = "cpu", + aliases: Optional[List[str]] = None, + raise_on_missing: bool = True, ) -> Optional[Tensor]: if aliases and key not in self.index.tensor_paths: for alias in aliases: @@ -124,7 +128,9 @@ def get_tensor( if self.current_shard is None or key not in self.current_shard.keys(): if key not in self.index.tensor_paths: - raise KeyError(key) + if raise_on_missing: + raise KeyError(key) + return None self.current_shard = None self.current_keys = None diff --git a/mergekit/io/tasks.py b/mergekit/io/tasks.py index 70dffc41..499ad4c0 100644 --- a/mergekit/io/tasks.py +++ b/mergekit/io/tasks.py @@ -67,12 +67,15 @@ class LoadTensor(Task[Optional[torch.Tensor]]): device: Optional[str] = None optional: bool = False aliases: Optional[Tuple[str, ...]] = None + tied_names: Optional[Tuple[str, ...]] = None def arguments(self) -> Dict[str, Task]: 
return {} def _resolve_name(self, loader: LazyTensorLoader) -> Optional[str]: - all_names = [self.tensor] + list(self.aliases or []) + all_names = ( + [self.tensor] + list(self.aliases or []) + list(self.tied_names or []) + ) for name in all_names: if name in loader.index.tensor_paths: return name @@ -120,6 +123,7 @@ def arguments(self) -> Dict[str, Task]: device=self.device, optional=wi.optional, aliases=wi.aliases, + tied_names=wi.tied_names, ) for (model, wi) in self.weight_info.items() } diff --git a/mergekit/io/tensor_writer.py b/mergekit/io/tensor_writer.py index 1483a3c3..9ea58222 100644 --- a/mergekit/io/tensor_writer.py +++ b/mergekit/io/tensor_writer.py @@ -121,7 +121,7 @@ def finalize(self): json.dump( { "metadata": { - "mergekit_version": "0.0.4.4", + "mergekit_version": "0.0.5.2", "total_size": self.total_size, }, "weight_map": self.weight_map, diff --git a/mergekit/merge.py b/mergekit/merge.py index 60189f44..2d659505 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -98,7 +98,10 @@ def run_merge( tokenizer = value.tokenizer if tokenizer: - _update_config_vocab(cfg_out, tokenizer) + pad_to_multiple_of = None + if merge_config.tokenizer and merge_config.tokenizer.pad_to_multiple_of: + pad_to_multiple_of = merge_config.tokenizer.pad_to_multiple_of + _update_config_vocab(cfg_out, tokenizer, pad_to_multiple_of=pad_to_multiple_of) logging.info("Saving config") cfg_out.save_pretrained(out_path) @@ -263,9 +266,13 @@ def _model_out_config( def _update_config_vocab( config: transformers.PretrainedConfig, tokenizer: transformers.PreTrainedTokenizerBase, + pad_to_multiple_of: Optional[int] = None, ): + vocab_size = len(tokenizer.get_vocab()) + if pad_to_multiple_of and vocab_size % pad_to_multiple_of: + vocab_size = vocab_size + pad_to_multiple_of - (vocab_size % pad_to_multiple_of) try: - config.vocab_size = len(tokenizer.get_vocab()) + config.vocab_size = vocab_size except Exception as e: logging.warning( "Unable to set vocabulary size in output config - you may need to manually correct it.", diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 007e163e..e9cf802a 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -13,86 +13,20 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see http://www.gnu.org/licenses/. 
+import mergekit.merge_methods.multislerp from mergekit.merge_methods.base import MergeMethod from mergekit.merge_methods.generalized_task_arithmetic import ( - ConsensusMethod, GeneralizedTaskArithmeticMerge, - SparsificationMethod, ) from mergekit.merge_methods.linear import LinearMerge -from mergekit.merge_methods.model_stock import ModelStockMerge from mergekit.merge_methods.passthrough import PassthroughMerge +from mergekit.merge_methods.registry import REGISTERED_MERGE_METHODS from mergekit.merge_methods.slerp import SlerpMerge -from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge def get(method: str) -> MergeMethod: - if method == "linear": - return LinearMerge() - elif method == "slerp": - return SlerpMerge() - elif method == "passthrough": - return PassthroughMerge() - elif method == "task_arithmetic": - return GeneralizedTaskArithmeticMerge( - consensus_method=None, - sparsification_method=None, - default_normalize=False, - default_rescale=False, - ) - elif method == "ties": - return GeneralizedTaskArithmeticMerge( - consensus_method=ConsensusMethod.sum, - sparsification_method=SparsificationMethod.magnitude, - default_normalize=True, - default_rescale=False, - ) - elif method == "dare_ties": - return GeneralizedTaskArithmeticMerge( - consensus_method=ConsensusMethod.sum, - sparsification_method=SparsificationMethod.random, - default_normalize=False, - default_rescale=True, - ) - elif method == "dare_linear": - return GeneralizedTaskArithmeticMerge( - consensus_method=None, - sparsification_method=SparsificationMethod.random, - default_normalize=False, - default_rescale=True, - ) - elif method == "breadcrumbs": - return GeneralizedTaskArithmeticMerge( - consensus_method=None, - sparsification_method=SparsificationMethod.magnitude_outliers, - default_normalize=False, - default_rescale=False, - ) - elif method == "breadcrumbs_ties": - return GeneralizedTaskArithmeticMerge( - consensus_method=ConsensusMethod.sum, - sparsification_method=SparsificationMethod.magnitude_outliers, - default_normalize=False, - default_rescale=False, - ) - elif method == "model_stock": - return ModelStockMerge() - - elif method == "della": - return GeneralizedTaskArithmeticMerge( - consensus_method=ConsensusMethod.sum, - sparsification_method=SparsificationMethod.rank_magnitude_sampling, - default_normalize=True, - default_rescale=True, - ) - - elif method == "della_linear": - return GeneralizedTaskArithmeticMerge( - consensus_method=None, - sparsification_method=SparsificationMethod.rank_magnitude_sampling, - default_normalize=False, - default_rescale=True, - ) + if method in REGISTERED_MERGE_METHODS: + return REGISTERED_MERGE_METHODS[method] raise RuntimeError(f"Unimplemented merge method {method}") @@ -100,8 +34,9 @@ def get(method: str) -> MergeMethod: "MergeMethod", "get", "LinearMerge", + "SCEMerge", "SlerpMerge", "PassthroughMerge", "GeneralizedTaskArithmeticMerge", - "TokenizerPermutationMerge", + "REGISTERED_MERGE_METHODS", ] diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py index 917ed089..aabe5999 100644 --- a/mergekit/merge_methods/base.py +++ b/mergekit/merge_methods/base.py @@ -41,6 +41,16 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: def parameters(self) -> List[ConfigParameterDef]: return [] + @abstractmethod + def name(self) -> str: + ... 
+ + def pretty_name(self) -> Optional[str]: + return None + + def reference_url(self) -> Optional[str]: + return None + @abstractmethod def make_task( self, diff --git a/mergekit/merge_methods/easy_define.py b/mergekit/merge_methods/easy_define.py new file mode 100644 index 00000000..b6aa839e --- /dev/null +++ b/mergekit/merge_methods/easy_define.py @@ -0,0 +1,322 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +import inspect +import typing +from typing import Any, ClassVar, Dict, List, Optional + +import pydantic +import torch +from pydantic import Field +from typing_extensions import Callable + +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) +from mergekit.merge_methods.registry import REGISTERED_MERGE_METHODS + +STANDARD_KWARGS = {"output_weight", "base_model"} + + +def __merge_method( + func: Callable, + name: str, + reference_url: Optional[str] = None, + pretty_name: Optional[str] = None, +) -> Callable: + use_base_tensor_arg = False + require_base_tensor = False + used_kwargs = set() + parameters: List[ConfigParameterDef] = [] + tensor_parameters: List[ConfigParameterDef] = [] + sig = inspect.signature(func) + if "tensors" not in sig.parameters: + raise ValueError("Merge methods must have a 'tensors' parameter") + tensor_param = sig.parameters["tensors"] + if ( + (tensor_param.annotation is None) + or (not hasattr(tensor_param.annotation, "__origin__")) + or not ( + tensor_param.annotation.__origin__ == list + and tensor_param.annotation.__args__ == (torch.Tensor,) + ) + ): + raise ValueError("'tensors' must be annotated with List[torch.Tensor]") + if "base_tensor" in sig.parameters: + bt_param = sig.parameters["base_tensor"] + if bt_param.annotation == torch.Tensor: + require_base_tensor = True + elif ( + hasattr(bt_param.annotation, "__origin__") + and bt_param.annotation.__origin__ == typing.Union + and ( + bt_param.annotation.__args__ == (torch.Tensor, type(None)) + or bt_param.annotation.__args__ == (type(None), torch.Tensor) + ) + ): + require_base_tensor = False + else: + raise ValueError( + "'base_tensor' must be annotated either torch.Tensor or Optional[torch.Tensor]" + ) + use_base_tensor_arg = True + for arg, arg_info in sig.parameters.items(): + if arg in ("base_tensor", "tensors"): + continue + if arg in STANDARD_KWARGS: + used_kwargs.add(arg) + else: + if arg_info.annotation is None: + raise ValueError( + "All merge method arguments must have type annotations" + ) + elif arg_info.annotation in (int, float, bool): + default_value = arg_info.default + if default_value == inspect.Parameter.empty: + default_value = None + required = True + else: + required = False + parameters.append( + ConfigParameterDef( + name=arg, 
required=required, default_value=default_value + ) + ) + elif ( + hasattr(arg_info.annotation, "__origin__") + and arg_info.annotation.__origin__ == list + and arg_info.annotation.__args__[0] in (float, int) + ): + default_value = arg_info.default + if default_value == inspect.Parameter.empty: + default_value = None + required = True + else: + required = False + if (not required) and (not isinstance(default_value, (int, float))): + raise ValueError( + f"Unexpected default for presumed tensor parameter '{arg}' - should be single number, got {repr(default_value)}" + ) + tensor_parameters.append( + ConfigParameterDef( + name=arg, required=required, default_value=default_value + ) + ) + + tt_fields = {} + tt_fields["gather_tensors"] = (MergeTensorInput, Field(...)) + if ("base_model" in used_kwargs) or use_base_tensor_arg: + bm_ty = ModelReference if require_base_tensor else Optional[ModelReference] + field_kwargs = {"default": None} if not require_base_tensor else {} + tt_fields["base_model"] = (bm_ty, Field(**field_kwargs)) + if "output_weight" in used_kwargs: + tt_fields["output_weight"] = (WeightInfo, Field(...)) + if parameters: + tt_fields["parameters"] = (ImmutableMap[str, Any], Field(...)) + if tensor_parameters: + tt_fields["tensor_parameters"] = ( + ImmutableMap[ModelReference, ImmutableMap[str, Any]], + Field(...), + ) + + def _arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + tt_fields["arguments"] = _arguments + + def _group_label(self) -> Optional[str]: + return self.gather_tensors.group_label() + + tt_fields["group_label"] = _group_label + + def _uses_accelerator(self) -> bool: + return True + + tt_fields["uses_accelerator"] = _uses_accelerator + + def _execute(self, tensors: Dict[ModelReference, torch.Tensor], **_kwargs): + model_refs = set(tensors.keys()) + base_model = getattr(self, "base_model", None) + if base_model and base_model in model_refs: + model_refs.remove(base_model) + if not use_base_tensor_arg: + model_refs = [base_model] + list(model_refs) + else: + model_refs = list(model_refs) + base_tensor = tensors.get(base_model, None) + tensors = [tensors[key] for key in model_refs] + inner_kwargs = {} + for key in used_kwargs: + inner_kwargs[key] = getattr(self, key) + if use_base_tensor_arg: + inner_kwargs["base_tensor"] = base_tensor + if require_base_tensor and (inner_kwargs["base_tensor"] is None): + raise ValueError(f"Base model tensor required but not present") + for key in parameters: + inner_kwargs[key.name] = self.parameters[key.name] + for key in tensor_parameters: + inner_kwargs[key.name] = [ + self.tensor_parameters[ref][key.name] for ref in model_refs + ] + return func(tensors=tensors, **inner_kwargs) + + tt_fields["execute"] = _execute + + tt_name = f"{name.title().replace(' ','')}MergeTask" + tt_cls = pydantic.create_model(tt_name, __base__=Task[torch.Tensor], **tt_fields) + + mm_fields = {} + + def _make_task( + self, + *, + output_weight: WeightInfo, + tensors: MergeTensorInput, + parameters: ImmutableMap[str, Any], + tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], + base_model: Optional[ModelReference], + **_kwargs, + ) -> Task: + tt_kwargs = {"gather_tensors": tensors} + if "base_model" in tt_fields: + tt_kwargs["base_model"] = base_model + if "output_weight" in tt_fields: + tt_kwargs["output_weight"] = output_weight + if "parameters" in tt_fields: + tt_kwargs["parameters"] = parameters + if "tensor_parameters" in tt_fields: + tt_kwargs["tensor_parameters"] = tensor_parameters + return 
tt_cls(**tt_kwargs) + + mm_fields["make_task"] = _make_task + + def _name(self) -> str: + return name + + mm_fields["name"] = _name + + def _pretty_name(self) -> Optional[str]: + return pretty_name + + mm_fields["pretty_name"] = _pretty_name + + def _reference_url(self) -> Optional[str]: + return reference_url + + mm_fields["reference_url"] = _reference_url + + def _tensor_parameters(self) -> List[ConfigParameterDef]: + return tensor_parameters + + mm_fields["tensor_parameters"] = _tensor_parameters + + def _parameters(self) -> List[ConfigParameterDef]: + return parameters + + mm_fields["parameters"] = _parameters + + mm_name = f"{name.title().replace(' ','')}MergeMethod" + mm_cls = type(mm_name, (MergeMethod,), mm_fields) + REGISTERED_MERGE_METHODS[name] = mm_cls() + return func + + +def merge_method( + name: str, + reference_url: Optional[str] = None, + pretty_name: Optional[str] = None, +) -> Callable: + """Decorator for registering custom model merging algorithms. + + Enables creation of new merge algorithms that can be specified in merge configurations + and executed through mergekit's processing pipeline. Handles parameter validation, task + creation, and registration in the mergekit system. + + Args: + name: Unique identifier for the merge method (lowercase, snake_case recommended) + reference_url: Optional URL to paper/documentation explaining the method (used in generated READMEs) + pretty_name: Human-readable display name (used in generated READMEs) + + Returns: + A decorator that registers the function as a merge method implementation + + Notes: + The decorated function must meet these requirements: + - First parameter must be `tensors: List[torch.Tensor]` + - Must return a single `torch.Tensor` + - All parameters must have type annotations + + Key behavioral considerations: + + *Base Model Handling:* + - If the method includes a `base_tensor` parameter: + * `torch.Tensor` annotation: Requires `base_model` in config, receives its tensor + * `Optional[torch.Tensor]` annotation: `base_model` optional, `None` if not provided + * Non-base model tensors passed in `tensors` list + - Without `base_tensor` parameter: + * Base model tensor (if specified) will be first in `tensors` list + + *Parameter Types:* + - Standard parameters (auto-populated): + * `base_tensor`: Tensor from base model (type determines requirement) + * `output_weight`: WeightInfo with output configuration + * `base_model`: ModelReference if using base model logic + - Scalar parameters (global config): + * `float`, `int`, or `bool` types specified in top-level `parameters` + - Tensor parameters (per-model weights): + * Annotated as `List[float]` or `List[int]` + * Configured per-model in their `parameters` section + * Collected into lists ordered by input models + + Example: + ```python + @merge_method( + name="average", + pretty_name="Simple Average", + reference_url="https://example.com/mean-merge" + ) + def average_merge( + tensors: List[torch.Tensor], # Input tensors to merge + weights: List[float], # Per-model weights (tensor parameter) + normalize: bool = True # Scalar parameter + ) -> torch.Tensor: + if normalize: + weights = [w / sum(weights) for w in weights] + return sum(t * w for t, w in zip(tensors, weights)) + ``` + + This would enable merge configurations like: + ```yaml + merge_method: average + parameters: + normalize: true + tensor_parameters: + weights: [0.3, 0.7] + ``` + + Raises: + ValueError: If function signature doesn't meet requirements + TypeError: For invalid parameter annotations + """ + + 
def _wrap(func: Callable) -> Callable: + return __merge_method(func, name, reference_url, pretty_name) + + return _wrap diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 214726b7..70d25af3 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -19,7 +19,7 @@ import torch from pydantic import BaseModel -from typing_extensions import Literal +from typing_extensions import Literal, override from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference @@ -29,7 +29,7 @@ MergeMethod, MergeTensorInput, ) -from mergekit.sparsify import SparsificationMethod, sparsify +from mergekit.sparsify import SparsificationMethod, get_tall_mask, sparsify class ConsensusMethod(str, Enum): @@ -42,6 +42,20 @@ class GeneralizedTaskArithmeticMerge(MergeMethod, BaseModel, frozen=True): sparsification_method: Optional[SparsificationMethod] default_normalize: bool default_rescale: bool + method_name: str + method_pretty_name: Optional[str] + method_reference_url: Optional[str] + + def name(self) -> str: + return self.method_name + + @override + def pretty_name(self) -> Optional[str]: + return self.method_pretty_name + + @override + def reference_url(self) -> Optional[str]: + return self.method_reference_url def parameters(self) -> List[ConfigParameterDef]: return [ @@ -79,6 +93,22 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: default_value=1.0, ) ) + if ( + self.sparsification_method == SparsificationMethod.consensus_ta + or self.sparsification_method == SparsificationMethod.consensus_ties + ): + res.append( + ConfigParameterDef( + name="k", + default_value=1, + ) + ) + res.append( + ConfigParameterDef( + name="lambda", + default_value=1.0, + ) + ) return res def make_task( @@ -133,7 +163,10 @@ def execute( return base # sparsify - if self.method.sparsification_method: + if ( + self.method.sparsification_method + and self.method.sparsification_method != SparsificationMethod.consensus_ta + ): for tv_info in tvs: kwargs = {} if "gamma" in tv_info: @@ -142,7 +175,7 @@ def execute( if "epsilon" in tv_info: kwargs["epsilon"] = tv_info["epsilon"] - tv_info["delta"] = sparsify( + tv_info["sparsified_delta"] = sparsify( tv_info["delta"], density=tv_info["density"], method=self.method.sparsification_method, @@ -150,7 +183,9 @@ def execute( **kwargs, ) - deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) + deltas = torch.stack([tv["sparsified_delta"] for tv in tvs], dim=0) + else: + deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) weights = torch.tensor( [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device ) @@ -185,6 +220,20 @@ def execute( lambda_factor = tvs[0]["lambda"] mixed_delta *= lambda_factor + if ( + self.method.sparsification_method == SparsificationMethod.consensus_ta + or self.method.sparsification_method == SparsificationMethod.consensus_ties + ): + for tv_info in tvs: + tv_info["tall_mask"] = get_tall_mask( + tv_info["delta"], + tv_info["lambda"], + mixed_delta, + ) + tall_masks = torch.stack([tv["tall_mask"] for tv in tvs], dim=0) + consensus_mask = tall_masks.sum(dim=0) >= tvs[0]["k"] + mixed_delta = mixed_delta * consensus_mask + return (base + mixed_delta).to(base.dtype) def group_label(self) -> Optional[str]: diff --git a/mergekit/merge_methods/linear.py b/mergekit/merge_methods/linear.py index 48224bb8..234de877 100644 --- a/mergekit/merge_methods/linear.py +++ 
b/mergekit/merge_methods/linear.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional import torch +from typing_extensions import override from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference @@ -72,6 +73,17 @@ def group_label(self) -> Optional[str]: class LinearMerge(MergeMethod): + def name(self) -> str: + return "linear" + + @override + def pretty_name(self) -> Optional[str]: + return "Linear" + + @override + def reference_url(self) -> Optional[str]: + return "https://arxiv.org/abs/2203.05482" + def parameters(self) -> List[ConfigParameterDef]: return [ ConfigParameterDef(name="normalize", required=False, default_value=True), diff --git a/mergekit/merge_methods/model_stock.py b/mergekit/merge_methods/model_stock.py index 94b1e05b..8d8002b5 100644 --- a/mergekit/merge_methods/model_stock.py +++ b/mergekit/merge_methods/model_stock.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Optional import torch +from typing_extensions import override from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference @@ -114,6 +115,17 @@ def group_label(self) -> Optional[str]: class ModelStockMerge(MergeMethod): + def name(self) -> str: + return "model_stock" + + @override + def pretty_name(self) -> Optional[str]: + return "Model Stock" + + @override + def reference_url(self): + return "https://arxiv.org/abs/2403.19522" + def parameters(self) -> List[ConfigParameterDef]: return [ ConfigParameterDef(name="filter_wise", required=False, default_value=False) diff --git a/mergekit/merge_methods/multislerp.py b/mergekit/merge_methods/multislerp.py new file mode 100644 index 00000000..ddbdcada --- /dev/null +++ b/mergekit/merge_methods/multislerp.py @@ -0,0 +1,97 @@ +from typing import List, Optional + +import torch + +from mergekit.merge_methods.easy_define import merge_method + + +@merge_method( + name="multislerp", + pretty_name="Multi-SLERP", + reference_url="https://goddard.blog/posts/multislerp-wow-what-a-cool-idea", +) +def multislerp( + tensors: List[torch.Tensor], + weight: List[float], + base_tensor: Optional[torch.Tensor] = None, + normalize_weights: bool = True, + eps: float = 1e-8, +): + """ + Implements barycentric interpolation on a hypersphere. + + The approach: + 1. Project points onto a tangent space at their weighted Euclidean mean. + 2. Perform the interpolation in the tangent space. + 3. Project the result back to the hypersphere. + + Limitations: + - The weighted sum of the input tensors must not be zero. + - The tensors must not be all parallel or antiparallel. 
+
+    Args:
+        tensors: List of tensors to interpolate
+        weight: List of weights for each tensor
+        base_tensor: Optional tensor defining the origin of the hypersphere
+        normalize_weights: If True, the weights will be normalized to sum to 1
+        eps: Small constant for numerical stability
+    """
+    if len(tensors) == 1:
+        # No interpolation needed
+        return tensors[0]
+
+    tensors = torch.stack(tensors, dim=0)
+    if base_tensor is not None:
+        tensors -= base_tensor
+
+    tensors_flat = tensors.view(tensors.shape[0], -1)
+
+    weights = torch.tensor(weight, dtype=tensors.dtype, device=tensors.device)
+    if normalize_weights:
+        weights = weights / weights.sum()
+
+    # Project to unit hypersphere
+    norms = torch.norm(tensors_flat, dim=-1, keepdim=True)
+    unit_tensors = tensors_flat / (norms + eps)
+
+    mean = (unit_tensors * weights.view(-1, 1)).sum(0)
+    mean_norm = torch.norm(mean)
+    if mean_norm < eps:
+        if tensors.shape[0] == 2:
+            # fallback to linear interpolation
+            res = (tensors[0] * weights[0] + tensors[1] * weights[1]).view(
+                tensors.shape[1:]
+            )
+            if base_tensor is not None:
+                res = res + base_tensor
+            return res
+        raise ValueError(
+            "The weighted sum of the input tensors is zero. This occurs when "
+            "antipodal vectors or sets of vectors have weights that exactly "
+            "balance out (e.g., vectors a,-a with equal weights). Try using "
+            "different weights if you have antipodal vectors."
+        )
+    mean = mean / mean_norm
+
+    # Project to tangent space
+    dots = (unit_tensors * mean).sum(-1, keepdim=True)
+    tangent_vectors = unit_tensors - dots * mean
+
+    # Interpolate
+    tangent_result = (tangent_vectors * weights.view(-1, 1)).sum(0)
+
+    # Project back to sphere using exponential map
+    tangent_norm = torch.norm(tangent_result) + eps
+    result = mean * torch.cos(tangent_norm) + tangent_result * (
+        torch.sin(tangent_norm) / tangent_norm
+    )
+
+    avg_norm = (norms.squeeze(-1) * weights).sum()
+    result = result * avg_norm
+    result = result.view(tensors.shape[1:])
+
+    if base_tensor is not None:
+        result = result + base_tensor
+
+    return result
diff --git a/mergekit/merge_methods/nearswap.py b/mergekit/merge_methods/nearswap.py
new file mode 100644
index 00000000..1a636007
--- /dev/null
+++ b/mergekit/merge_methods/nearswap.py
@@ -0,0 +1,126 @@
+# Copyright (C) 2025 Charles Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+ +from typing import Any, Dict, List, Optional, Union + +import torch + +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) +from mergekit.merge_methods.rectify_embed import rectify_embed_sizes + + +class NearSwapTask(Task[torch.Tensor]): + gather_tensors: MergeTensorInput + base_model: ModelReference + t: float + weight_info: WeightInfo + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.gather_tensors} + + def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor: + if self.t <= 0: + raise RuntimeError(f"Threshold cannot be <= zero, got {self.t}") + if len(tensors) == 1: + return list(tensors.values())[0] + elif len(tensors) != 2: + raise RuntimeError( + f"Nearswap merge expects exactly two models, got {len(tensors)}" + ) + elif self.base_model not in tensors: + raise RuntimeError("Base model not in input tensors") + + [a, b] = list(tensors.items()) + if a[0] != self.base_model: + [a, b] = [b, a] + prepped_tensors = [a[1], b[1]] + + rectify_embed_sizes(self.weight_info, prepped_tensors) + + return ( + nearswap( + self.t, + prepped_tensors[0], + prepped_tensors[1], + ) + .to(prepped_tensors[0].dtype) + .to(prepped_tensors[0].device) + ) + + +class NearSwapMerge(MergeMethod): + def name(self) -> str: + return "nearswap" + + def pretty_name(self) -> Optional[str]: + return "NearSwap" + + def reference_url(self) -> Optional[str]: + return "https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001" + + def parameters(self) -> List[ConfigParameterDef]: + return [ConfigParameterDef(name="t", required=True)] + + def make_task( + self, + *, + output_weight: WeightInfo, + tensors: MergeTensorInput, + parameters: ImmutableMap[str, Any], + base_model: Optional[ModelReference], + **_kwargs, + ) -> Task: + return NearSwapTask( + gather_tensors=tensors, + base_model=base_model, + weight_info=output_weight, + t=parameters["t"], + ) + + +def nearswap(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor: + """ + NearSwap implementation using PyTorch. + + Adapted from: https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001 + + Parameters: + t (float): The sameness threshold. + v0 (torch.Tensor): Weights from the base model. + v1 (torch.Tensor): Weights from the secondary model. + + Returns: + torch.Tensor: Resulting interpolated weights. + """ + # Compute the absolute difference + lweight = torch.abs(v0 - v1) + + # Compute the interpolation factor + lweight = t / lweight + lweight = torch.nan_to_num(lweight, nan=1.0, posinf=1.0, neginf=1.0) + lweight = torch.clamp(lweight, min=0.0, max=1.0) + + # Linearly interpolate between v0 and v1 + return lweight * v1 + (1 - lweight) * v0 diff --git a/mergekit/merge_methods/nuslerp.py b/mergekit/merge_methods/nuslerp.py new file mode 100644 index 00000000..40773c29 --- /dev/null +++ b/mergekit/merge_methods/nuslerp.py @@ -0,0 +1,179 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. 
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch._tensor import Tensor
+from typing_extensions import override
+
+from mergekit.architecture import WeightInfo
+from mergekit.common import ImmutableMap, ModelReference
+from mergekit.graph import Task
+from mergekit.merge_methods.base import (
+    ConfigParameterDef,
+    MergeMethod,
+    MergeTensorInput,
+)
+from mergekit.merge_methods.rectify_embed import rectify_embed_sizes
+
+
+class NuSlerpTask(Task[torch.Tensor]):
+    gather_tensors: MergeTensorInput
+    tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]]
+    weight_info: WeightInfo
+    row_wise: bool
+    flatten: bool
+    base_model: Optional[ModelReference]
+
+    def uses_accelerator(self) -> bool:
+        return True
+
+    def arguments(self) -> Dict[str, Task]:
+        return {"tensors": self.gather_tensors}
+
+    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
+        if len(tensors) == 1:
+            return list(tensors.values())[0]
+
+        if self.base_model is not None:
+            if len(tensors) != 3:
+                raise RuntimeError(
+                    "NuSlerp base model can not be one of the two models to merge"
+                )
+            base_tensor = tensors.pop(self.base_model)
+        else:
+            base_tensor = None
+
+        keys = list(tensors.keys())
+        tensors = [tensors[key] for key in keys]
+        weights = [self.tensor_parameters[key]["weight"] for key in keys]
+
+        if len(tensors) != 2:
+            raise RuntimeError(
+                "NuSlerp merge expects exactly two models (plus optional base model)"
+            )
+
+        if abs(sum(weights)) < 1e-6:
+            # this is fairly arbitrary, but it's more sane than exploding
+            t = 0.5
+        else:
+            t = weights[1] / sum(weights)
+
+        if base_tensor is not None:
+            tensors.append(base_tensor)
+        rectify_embed_sizes(self.weight_info, tensors)
+
+        if base_tensor is not None:
+            base_tensor = tensors.pop()
+            return base_tensor + nuslerp(
+                t,
+                tensors[0] - base_tensor,
+                tensors[1] - base_tensor,
+                dim=0 if self.row_wise else -1,
+                flatten=self.flatten,
+            )
+        return nuslerp(
+            t,
+            tensors[0],
+            tensors[1],
+            dim=0 if self.row_wise else -1,
+            flatten=self.flatten,
+        )
+
+
+class NuSlerpMerge(MergeMethod):
+    def name(self) -> str:
+        return "nuslerp"
+
+    @override
+    def pretty_name(self):
+        return "NuSLERP"
+
+    def parameters(self) -> List[ConfigParameterDef]:
+        return [
+            ConfigParameterDef(
+                name="nuslerp_row_wise",
+                required=False,
+                default_value=False,
+            ),
+            ConfigParameterDef(
+                name="nuslerp_flatten",
+                required=False,
+                default_value=True,
+            ),
+        ]
+
+    def tensor_parameters(self) -> List[ConfigParameterDef]:
+        return [ConfigParameterDef(name="weight", required=True)]
+
+    def make_task(
+        self,
+        *,
+        output_weight: WeightInfo,
+        tensors: MergeTensorInput,
+        base_model: Optional[ModelReference],
+        parameters: ImmutableMap[str, Any],
+        tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
+        **_kwargs,
+    ) -> Task:
+        return NuSlerpTask(
+            gather_tensors=tensors,
+            tensor_parameters=tensor_parameters,
+            weight_info=output_weight,
+            row_wise=parameters["nuslerp_row_wise"],
+            flatten=parameters["nuslerp_flatten"],
+            base_model=base_model,
+        )
+
+
+def nuslerp(
+    t: float,
+    v0: 
torch.Tensor, + v1: torch.Tensor, + dim: int = -1, + eps: float = 1e-8, + flatten: bool = False, +): + out_shape = v0.shape + + def _normalize(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + return x / torch.norm(x, dim=-1, keepdim=True).clamp(min=eps) + + if flatten: + v0 = v0.view(-1) + v1 = v1.view(-1) + elif dim != -1: + v0 = v0.transpose(dim, -1) + v1 = v1.transpose(dim, -1) + + v0_u = _normalize(v0) + v1_u = _normalize(v1) + + cos_theta = torch.sum(v0_u * v1_u, dim=-1, keepdim=True) + theta = torch.acos(cos_theta.clamp(-1, 1)) + sin_theta = torch.sin(theta) + + colinear = (sin_theta.abs() < eps).squeeze() + + res = (torch.sin((1 - t) * theta) * v0 + torch.sin(t * theta) * v1) / sin_theta + # Use linear interpolation for (nearly) colinear vectors + res[colinear] = (1 - t) * v0[colinear] + t * v1[colinear] + + if dim != -1 and not flatten: + res = res.transpose(dim, -1) + return res.view(out_shape) diff --git a/mergekit/merge_methods/passthrough.py b/mergekit/merge_methods/passthrough.py index 62b0bf12..219eb9aa 100644 --- a/mergekit/merge_methods/passthrough.py +++ b/mergekit/merge_methods/passthrough.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List, Optional import torch +from typing_extensions import override from mergekit.common import ImmutableMap, ModelReference from mergekit.graph import Task @@ -49,6 +50,13 @@ def group_label(self) -> Optional[str]: class PassthroughMerge(MergeMethod): + def name(self) -> str: + return "passthrough" + + @override + def pretty_name(self) -> Optional[str]: + return "Passthrough" + def tensor_parameters(self) -> List[ConfigParameterDef]: return [ConfigParameterDef(name="scale", required=False, default_value=None)] diff --git a/mergekit/merge_methods/registry.py b/mergekit/merge_methods/registry.py new file mode 100644 index 00000000..e55cb252 --- /dev/null +++ b/mergekit/merge_methods/registry.py @@ -0,0 +1,135 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. 
+ +from typing import Dict, List + +from mergekit.merge_methods.base import MergeMethod +from mergekit.merge_methods.generalized_task_arithmetic import ( + ConsensusMethod, + GeneralizedTaskArithmeticMerge, +) +from mergekit.merge_methods.linear import LinearMerge +from mergekit.merge_methods.model_stock import ModelStockMerge +from mergekit.merge_methods.nearswap import NearSwapMerge +from mergekit.merge_methods.nuslerp import NuSlerpMerge +from mergekit.merge_methods.passthrough import PassthroughMerge +from mergekit.merge_methods.sce import SCEMerge +from mergekit.merge_methods.slerp import SlerpMerge +from mergekit.sparsify import SparsificationMethod + +STATIC_MERGE_METHODS: List[MergeMethod] = [ + LinearMerge(), + SlerpMerge(), + NuSlerpMerge(), + PassthroughMerge(), + ModelStockMerge(), + SCEMerge(), + NearSwapMerge(), + # generalized task arithmetic methods + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=None, + default_normalize=False, + default_rescale=False, + method_name="task_arithmetic", + method_pretty_name="Task Arithmetic", + method_reference_url="https://arxiv.org/abs/2212.04089", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.magnitude, + default_normalize=True, + default_rescale=False, + method_name="ties", + method_pretty_name="TIES", + method_reference_url="https://arxiv.org/abs/2306.01708", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.random, + default_normalize=False, + default_rescale=True, + method_name="dare_ties", + method_pretty_name="DARE TIES", + method_reference_url="https://arxiv.org/abs/2311.03099", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.random, + default_normalize=False, + default_rescale=True, + method_name="dare_linear", + method_pretty_name="Linear DARE", + method_reference_url="https://arxiv.org/abs/2311.03099", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.magnitude_outliers, + default_normalize=False, + default_rescale=False, + method_name="breadcrumbs", + method_pretty_name="Model Breadcrumbs", + method_reference_url="https://arxiv.org/abs/2312.06795", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.magnitude_outliers, + default_normalize=False, + default_rescale=False, + method_name="breadcrumbs_ties", + method_pretty_name="Model Breadcrumbs with TIES", + method_reference_url="https://arxiv.org/abs/2312.06795", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.rank_magnitude_sampling, + default_normalize=True, + default_rescale=True, + method_name="della", + method_pretty_name="DELLA", + method_reference_url="https://arxiv.org/abs/2406.11617", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.rank_magnitude_sampling, + default_normalize=False, + default_rescale=True, + method_name="della_linear", + method_pretty_name="Linear DELLA", + method_reference_url="https://arxiv.org/abs/2406.11617", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.consensus_ta, + default_normalize=False, + default_rescale=False, + method_name="consensus_ta", + method_pretty_name="Consensus 
Task Arithmetic", + method_reference_url="https://arxiv.org/abs/2405.07813", + ), + GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.consensus_ties, + default_normalize=True, + default_rescale=False, + method_name="consensus_ties", + method_pretty_name="Consensus TIES", + method_reference_url="https://arxiv.org/abs/2405.07813", + ), +] + +REGISTERED_MERGE_METHODS: Dict[str, MergeMethod] = { + method.name(): method for method in STATIC_MERGE_METHODS +} diff --git a/mergekit/merge_methods/sce.py b/mergekit/merge_methods/sce.py new file mode 100644 index 00000000..b716f44d --- /dev/null +++ b/mergekit/merge_methods/sce.py @@ -0,0 +1,214 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import torch +from pydantic import BaseModel +from typing_extensions import override + +from mergekit.architecture import WeightInfo +from mergekit.common import ImmutableMap, ModelReference +from mergekit.graph import Task +from mergekit.merge_methods.base import ( + ConfigParameterDef, + MergeMethod, + MergeTensorInput, +) + + +class SCEMerge(MergeMethod, BaseModel, frozen=True): + def name(self) -> str: + return "sce" + + @override + def pretty_name(self) -> str: + return "SCE" + + @override + def reference_url(self) -> str: + return "https://arxiv.org/abs/2408.07990" + + def parameters(self) -> List[ConfigParameterDef]: + return [ + ConfigParameterDef(name="int8_mask", required=False, default_value=False), + ConfigParameterDef(name="select_topk", required=False, default_value=1.0), + ] + + def make_task( + self, + output_weight: WeightInfo, + tensors: MergeTensorInput, + base_model: Optional[ModelReference], + parameters: ImmutableMap[str, Any], + **_kwargs, + ) -> Task: + return SCETask( + tensors=tensors, + base_model=base_model, + int8_mask=parameters["int8_mask"], + select_topk=parameters["select_topk"], + weight_info=output_weight, + ) + + +class SCETask(Task[torch.Tensor]): + tensors: MergeTensorInput + base_model: ModelReference + weight_info: WeightInfo + int8_mask: bool + select_topk: float + + def uses_accelerator(self) -> bool: + return True + + def arguments(self) -> Dict[str, Task]: + return {"tensors": self.tensors} + + def execute( + self, + tensors: Dict[ModelReference, torch.Tensor], + **_kwargs, + ) -> torch.Tensor: + # collect task vectors + tvs, base = get_task_vectors(self.weight_info, self.base_model, tensors) + if not tvs: + return base + + deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) + mask_dtype = torch.int8 if self.int8_mask else base.dtype + + # Select the top τ% elements with high variance + if self.select_topk < 1: + mask = get_sce_mask(deltas, self.select_topk, mask_dtype) + mask_expanded = mask.unsqueeze(0).expand_as(deltas) + deltas = deltas * mask_expanded + + # Calculate 
matrix level merging coefficient + weights = get_sce_weight(deltas) + weights = torch.tensor(weights, dtype=deltas.dtype, device=deltas.device) + while len(deltas.shape) > len(weights.shape): + weights.unsqueeze_(-1) + + # Erase elements with minority directions + erase_mask = get_erase_mask( + deltas, + mask_dtype=mask_dtype, + ) + erased_weights = weights * erase_mask + mixed_delta = (deltas * erased_weights).sum(dim=0) + + # Normalize + divisor = (erased_weights).sum(dim=0) + divisor[divisor == 0] = 1 + mixed_delta /= divisor + + return (base + mixed_delta).to(base.dtype) + + +def get_task_vectors( + weight_info: WeightInfo, + base_model: ModelReference, + tensors: ImmutableMap[ModelReference, torch.Tensor], +) -> Tuple[List[Dict[str, Any]], torch.Tensor]: + keys = list(tensors.keys()) + base = tensors[base_model] + + parameter_name = weight_info.name + + res = [] + for model in keys: + if model == base_model: + continue + + x = tensors[model].to(base.dtype) + if x.shape != base.shape: + if weight_info.is_embed: + x = x[: base.shape[0], : base.shape[1]] + logging.warning(f"Using submatrix of {model}:{parameter_name}") + else: + logging.warning( + f"skipping {model}:{parameter_name} due to size mismatch" + ) + continue + + delta = x - base + del x + del tensors[model] + + d = {} + d["model"] = model + d["delta"] = delta + res.append(d) + return res, base + + +def get_erase_mask( + delta: torch.Tensor, + mask_dtype: Optional[torch.dtype] = None, +): + """Returns a mask determining which delta vectors should be merged + into the final model. + """ + if mask_dtype is None: + mask_dtype = delta.dtype + + sign = delta.sign().to(mask_dtype) + + sign_weight = delta.sum(dim=0) + majority_sign = (sign_weight >= 0).to(mask_dtype) * 2 - 1 + del sign_weight + + return sign == majority_sign + + +def get_sce_mask( + deltas: torch.Tensor, + density: float, + mask_dtype: Optional[torch.dtype] = None, +): + if mask_dtype is None: + mask_dtype = deltas.dtype + # Calculate variance along the first dimension + variance = torch.var(deltas, dim=0, unbiased=False) + # Count non-zero positions in variance + non_zero_positions_count = torch.count_nonzero(variance) + # Calculate the number of top elements to select + k = int(abs(density) * non_zero_positions_count) + mask = torch.zeros_like(variance, dtype=mask_dtype) + if k == 0: + return mask + assert k > 0, "not gonna zero out the whole tensor buddy" + + # Get the indices of the top k elements with the highest absolute variance + topk_indices = torch.topk(variance.abs().view(-1), k=k, largest=True).indices + + mask.view(-1)[topk_indices] = 1 + return mask + + +def get_sce_weight(deltas): + # Calculate the squared sum of each delta and normalize by the number of elements + weights = [torch.sum(delta**2).item() / delta.numel() for delta in deltas] + + # Normalize the weights + sum_weights = sum(weights) + if sum_weights == 0: + weights = [1.0 / len(weights)] * len(weights) + else: + weights = [w / sum_weights for w in weights] + + return weights diff --git a/mergekit/merge_methods/slerp.py b/mergekit/merge_methods/slerp.py index d33dd5a9..06017545 100644 --- a/mergekit/merge_methods/slerp.py +++ b/mergekit/merge_methods/slerp.py @@ -17,6 +17,7 @@ import numpy as np import torch +from typing_extensions import override from mergekit.architecture import WeightInfo from mergekit.common import ImmutableMap, ModelReference @@ -71,6 +72,17 @@ def group_label(self) -> Optional[str]: class SlerpMerge(MergeMethod): + def name(self) -> str: + return "slerp" + + @override + 
def pretty_name(self) -> Optional[str]: + return "SLERP" + + @override + def reference_url(self): + return "https://en.wikipedia.org/wiki/Slerp" + def parameters(self) -> List[ConfigParameterDef]: return [ConfigParameterDef(name="t", required=True)] diff --git a/mergekit/merge_methods/tokenizer_permute.py b/mergekit/merge_methods/tokenizer_permute.py deleted file mode 100644 index 07c6f9c5..00000000 --- a/mergekit/merge_methods/tokenizer_permute.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (C) 2024 Charles O. Goddard -# -# This software is free software: you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This software is distributed in the hope that it will be useful, but -# WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see http://www.gnu.org/licenses/. - -from typing import Any, Dict, List, Optional - -import torch -from pydantic import BaseModel - -from mergekit.common import ImmutableMap, ModelReference -from mergekit.graph import Task -from mergekit.merge_methods.base import ( - ConfigParameterDef, - MergeMethod, - MergeTensorInput, -) -from mergekit.merge_methods.slerp import slerp -from mergekit.tokenizer import BuildTokenizer, TokenizerInfo - - -class TokenizerPermutationMergeTask(Task[torch.Tensor]): - tokenizer_task: BuildTokenizer - gather_tensors: MergeTensorInput - base_model: Optional[ModelReference] - use_slerp: bool - slerp_t: Optional[float] - tensor_parameters: ImmutableMap[ModelReference, Any] - - def uses_accelerator(self) -> bool: - return True - - def arguments(self) -> Dict[str, Task]: - return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors} - - def execute( - self, tokenizer_info: TokenizerInfo, tensors: Dict[ModelReference, torch.Tensor] - ) -> torch.Tensor: - if not tensors: - return None - if len(tensors) == 1: - return list(tensors.values())[0] - - if self.use_slerp and self.slerp_t is None: - raise RuntimeError("Must set t to use embed_slerp") - - models = [] - expanded = [] - masks = [] - weights = [] - for model in tensors: - models.append(model) - - x = tensors[model] - p = tokenizer_info.permutations[model] - - xp = torch.zeros((len(p), x.shape[-1]), dtype=x.dtype, device=x.device) - mask = torch.zeros((len(p),), dtype=torch.bool, device=x.device) - for out_idx in p: - in_idx = p[out_idx] - if in_idx < 0: - continue - - xp[out_idx, :] = x[in_idx, :] - mask[out_idx] = 1 - - expanded.append(xp) - masks.append(mask) - - is_base = model == self.base_model - if self.use_slerp: - weight = (1.0 - self.slerp_t) if is_base else self.slerp_t - else: - weight = self.tensor_parameters[model]["weight"] - - weights.append(weight) - - expanded = torch.stack(expanded, dim=0) - masks = torch.stack(masks, dim=0).unsqueeze(-1) - weights = ( - torch.tensor(weights, dtype=expanded.dtype, device=expanded.device) - .unsqueeze(-1) - .unsqueeze(-1) - ) - - total_weight = (masks * weights).sum(dim=0) - scale = 1 / total_weight - scale[total_weight.abs() < 1e-8] = 0 - - linear_merged = (expanded * weights * masks).sum(dim=0) * scale - - if self.use_slerp: - if expanded.shape[0] != 2: - raise RuntimeError("SLERP takes exactly two models") - 
- if models[0] == self.base_model: - v0 = expanded[0, ...] - v1 = expanded[1, ...] - else: - v0 = expanded[1, ...] - v1 = expanded[0, ...] - - res = slerp(self.slerp_t, v0, v1) - need_linear = (masks.sum(dim=0) != 2).squeeze(dim=-1) - res[need_linear, :] = linear_merged[need_linear, :].to( - device=res.device, dtype=res.dtype - ) - return res - - return linear_merged - - -class TokenizerPermutationMerge(MergeMethod, BaseModel): - tokenizer_task: BuildTokenizer - - def parameters(self) -> List[ConfigParameterDef]: - return [ - ConfigParameterDef(name="t", required=False), - ConfigParameterDef(name="embed_slerp", required=False, default_value=False), - ] - - def tensor_parameters(self) -> List[ConfigParameterDef]: - return [ - ConfigParameterDef(name="weight", required=False), - ] - - def make_task( - self, - *, - tensors: MergeTensorInput, - parameters: Dict[str, Any], - tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]], - base_model: Optional[ModelReference], - **_kwargs, - ) -> Task: - return TokenizerPermutationMergeTask( - base_model=base_model, - tokenizer_task=self.tokenizer_task, - gather_tensors=tensors, - use_slerp=parameters["embed_slerp"], - slerp_t=parameters["t"], - tensor_parameters=tensor_parameters, - ) diff --git a/mergekit/moe/common.py b/mergekit/moe/common.py index 4a0df69c..a5970b4a 100644 --- a/mergekit/moe/common.py +++ b/mergekit/moe/common.py @@ -13,12 +13,14 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see http://www.gnu.org/licenses/. +import logging from typing import Dict, Optional, Tuple import torch import tqdm import transformers +from mergekit.architecture import WeightInfo from mergekit.common import ModelReference, dtype_from_name from mergekit.io import LazyTensorLoader, TensorWriter from mergekit.merge import MergeOptions @@ -73,3 +75,31 @@ def noise_and_scale( if is_residual and expert.residual_scale is not None: tensor = tensor * expert.residual_scale return tensor + + +def copy_tensor_out( + weight_info: WeightInfo, + loader: LazyTensorLoader, + writer: TensorWriter, + expert: Optional[Expert] = None, + is_residual: bool = False, + output_name: Optional[str] = None, + out_dtype: Optional[torch.dtype] = None, + clone: bool = False, +): + out_tensor_name = output_name or weight_info.name + try: + tensor = loader.get_tensor(weight_info.name, aliases=weight_info.aliases) + except KeyError: + tensor = None + if tensor is None and not weight_info.optional: + logging.error(f"Missing weight: {weight_info.name} / {out_tensor_name}") + raise KeyError(out_tensor_name) + + if expert: + tensor = noise_and_scale(tensor, expert, is_residual=is_residual) + writer.save_tensor( + out_tensor_name, + tensor.to(dtype=out_dtype), + clone=clone, + ) diff --git a/mergekit/moe/deepseek.py b/mergekit/moe/deepseek.py index 1f7226fb..4ce62865 100644 --- a/mergekit/moe/deepseek.py +++ b/mergekit/moe/deepseek.py @@ -24,7 +24,7 @@ from mergekit.architecture import get_architecture_info from mergekit.moe.arch import MoEOutputArchitecture -from mergekit.moe.common import initialize_io, noise_and_scale, select_dtype +from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig from mergekit.options import MergeOptions @@ -148,39 +148,36 @@ def write_model( ".mlp.", f".mlp.experts.{expert_idx}." 
) expert_loader = loaders.get(expert.source_model) - tensor = expert_loader.get_tensor( - weight_info.name, aliases=weight_info.aliases - ) - tensor = noise_and_scale( - tensor, expert, is_residual="down_proj" in tensor_name - ) - writer.save_tensor( - expert_name, - tensor.to(dtype=out_dtype), + copy_tensor_out( + weight_info, + expert_loader, + writer, + expert=expert, + is_residual="down_proj" in tensor_name, + output_name=expert_name, + out_dtype=out_dtype, clone=merge_options.clone_tensors, ) if shared_def is not None: - shared_tensor = shared_loader.get_tensor( - weight_info.name, aliases=weight_info.aliases - ) - shared_tensor = noise_and_scale( - shared_tensor, - shared_def, + copy_tensor_out( + weight_info, + shared_loader, + writer, + expert=shared_def, is_residual="down_proj" in tensor_name, - ) - writer.save_tensor( - tensor_name.replace(".mlp.", ".mlp.shared_experts."), - shared_tensor.to(dtype=out_dtype), + output_name=tensor_name.replace( + ".mlp.", ".mlp.shared_experts." + ), + out_dtype=out_dtype, clone=merge_options.clone_tensors, ) else: - tensor = base_loader.get_tensor( - tensor_name, aliases=weight_info.aliases - ) - writer.save_tensor( - tensor_name, - tensor.to(dtype=out_dtype), + copy_tensor_out( + weight_info, + base_loader, + writer, + out_dtype=out_dtype, clone=merge_options.clone_tensors, ) diff --git a/mergekit/moe/mixtral.py b/mergekit/moe/mixtral.py index 538cb701..f3fe97df 100644 --- a/mergekit/moe/mixtral.py +++ b/mergekit/moe/mixtral.py @@ -22,7 +22,7 @@ from mergekit.architecture import MISTRAL_INFO, WeightInfo from mergekit.moe.arch import MoEOutputArchitecture -from mergekit.moe.common import initialize_io, noise_and_scale, select_dtype +from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig from mergekit.options import MergeOptions @@ -145,24 +145,22 @@ def write_model( for expert_index, expert in enumerate(config.experts): expert_name = tensor_name.replace("{expert_idx}", str(expert_index)) expert_loader = loaders.get(expert.source_model) - tensor = expert_loader.get_tensor( - weight_info.name, aliases=weight_info.aliases - ) - tensor = noise_and_scale( - tensor, expert, is_residual="down_proj" in tensor_name - ) - writer.save_tensor( - expert_name, - tensor.to(dtype=out_dtype), + copy_tensor_out( + weight_info, + expert_loader, + writer, + expert=expert, + out_dtype=out_dtype, + output_name=expert_name, clone=merge_options.clone_tensors, + is_residual="down_proj" in tensor_name, ) else: - tensor = base_loader.get_tensor( - tensor_name, aliases=weight_info.aliases - ) - writer.save_tensor( - tensor_name, - tensor.to(dtype=out_dtype), + copy_tensor_out( + weight_info, + base_loader, + writer, + out_dtype=out_dtype, clone=merge_options.clone_tensors, ) diff --git a/mergekit/moe/qwen.py b/mergekit/moe/qwen.py index ab94f7d5..65337a0a 100644 --- a/mergekit/moe/qwen.py +++ b/mergekit/moe/qwen.py @@ -26,7 +26,7 @@ from mergekit.architecture import QWEN2_INFO from mergekit.moe.arch import MoEOutputArchitecture -from mergekit.moe.common import initialize_io, noise_and_scale, select_dtype +from mergekit.moe.common import copy_tensor_out, initialize_io, select_dtype from mergekit.moe.config import MoEMergeConfig from mergekit.options import MergeOptions @@ -137,29 +137,25 @@ def write_model( ".mlp.", f".mlp.experts.{expert_idx}." 
) expert_loader = loaders.get(expert.source_model) - tensor = expert_loader.get_tensor( - weight_info.name, aliases=weight_info.aliases - ) - tensor = noise_and_scale( - tensor, expert, is_residual="down_proj" in tensor_name - ) - writer.save_tensor( - expert_name, - tensor.to(dtype=out_dtype), + copy_tensor_out( + weight_info, + expert_loader, + writer, + expert=expert, + is_residual="down_proj" in tensor_name, + output_name=expert_name, + out_dtype=out_dtype, clone=merge_options.clone_tensors, ) - shared_tensor = shared_loader.get_tensor( - weight_info.name, aliases=weight_info.aliases - ) - shared_tensor = noise_and_scale( - shared_tensor, - shared_def, + copy_tensor_out( + weight_info, + shared_loader, + writer, + expert=shared_def, is_residual="down_proj" in tensor_name, - ) - writer.save_tensor( - tensor_name.replace(".mlp.", ".mlp.shared_expert."), - shared_tensor.to(dtype=out_dtype), + output_name=tensor_name.replace(".mlp.", ".mlp.shared_expert."), + out_dtype=out_dtype, clone=merge_options.clone_tensors, ) else: @@ -180,6 +176,8 @@ def write_model( else out_cfg.num_attention_heads ) tensor = torch.zeros(num_heads * head_dim, dtype=out_dtype) + elif weight_info.optional: + continue else: raise diff --git a/mergekit/plan.py b/mergekit/plan.py index bdcd7004..3e407be1 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -139,7 +139,10 @@ def plan_tensor( any_weight = False for model, w_in in zip(models, weights_in): index = LoaderCache().get(model).index - if w_in.name in index.tensor_paths: + if any( + name in index.tensor_paths + for name in [w_in.name] + (w_in.aliases or []) + ): any_weight = True break @@ -179,12 +182,15 @@ def plan_tensor( tensor_input_task = gather_tensors if self._tokenizer_task and weight.is_embed: token_cfg = {} + pad_to_multiple = None if cfg_reader.config.tokenizer: token_cfg = cfg_reader.config.tokenizer.tokens + pad_to_multiple = cfg_reader.config.tokenizer.pad_to_multiple_of tensor_input_task = PermutedEmbeddings( gather_tensors=gather_tensors, tokenizer_task=self._tokenizer_task, tokens=token_cfg, + pad_to_multiple_of=pad_to_multiple, base_model=base_model, ) diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py index ff063232..be361302 100644 --- a/mergekit/scripts/extract_lora.py +++ b/mergekit/scripts/extract_lora.py @@ -8,12 +8,11 @@ import torch from peft.tuners.lora import QuantLinear from safetensors.torch import save_file -from torch.nn.functional import pad from tqdm import tqdm from transformers import AutoModelForCausalLM -from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import Conv1D +from mergekit.architecture import WeightInfo, get_architecture_info from mergekit.card import generate_card_lora from mergekit.common import ModelReference from mergekit.io import LazyTensorLoader @@ -83,7 +82,9 @@ def decompose_delta_weight( def get_model_details( - model_id: str, skip_undecomposable: bool + model_id: str, + skip_undecomposable: bool, + modules_to_save: Optional[List[str]] = None, ) -> List[Tuple[str, str, torch.Size]]: """ Retrieve architectural details of a given pre-trained model. @@ -101,10 +102,17 @@ def get_model_details( model_id, state_dict={}, device_map="meta" ) + vocab_size = pretrained_model.get_input_embeddings().weight.shape[0] + module_details = [] + modules_to_save = set(modules_to_save or []) for name, module in pretrained_model.named_modules(): - if module == pretrained_model.get_input_embeddings(): + if name in modules_to_save or ( + "." 
in name and name.split(".")[-1] in modules_to_save + ): + module_details.append(("to_save", name, module.weight.size())) + elif module == pretrained_model.get_input_embeddings(): # if isinstance(module, torch.nn.Embedding): module_details.append(("embedding", name, module.weight.size())) elif module == pretrained_model.get_output_embeddings(): @@ -136,7 +144,7 @@ def get_model_details( else: logging.info(f"Skipping undecomposable module '{name}'.") - return module_details + return module_details, vocab_size def validate_and_combine_details( @@ -144,7 +152,8 @@ def validate_and_combine_details( finetuned_model_id: str, skip_undecomposable: bool, extend_vocab: bool, -) -> List[Tuple[str, str]]: + modules_to_save: Optional[List[str]] = None, +) -> Tuple[List[Tuple[str, str]], int, int]: """ Validate and combine details from a base model and a fine-tuned model. @@ -154,14 +163,15 @@ def validate_and_combine_details( :return: A list of tuples with the type and name of the validated/combined model layers """ - base_model_details = get_model_details(base_model_id, skip_undecomposable) - finetuned_model_details = get_model_details(finetuned_model_id, skip_undecomposable) + base_model_details, base_vocab_size = get_model_details( + base_model_id, skip_undecomposable, modules_to_save=modules_to_save + ) + finetuned_model_details, finetuned_vocab_size = get_model_details( + finetuned_model_id, skip_undecomposable, modules_to_save=modules_to_save + ) module_details = [] - base_model_embedding_size = None - finetuned_model_embedding_size = None - for i, (base_layer, finetuned_layer) in enumerate( zip(base_model_details, finetuned_model_details) ): @@ -175,12 +185,6 @@ def validate_and_combine_details( base_name == finetuned_name ), f"Layer name mismatch: {base_name} != {finetuned_name}" - if base_type == "embedding": - base_model_embedding_size = base_size[0] - - if finetuned_type == "embedding": - finetuned_model_embedding_size = finetuned_size[0] - # Fine-tuned models with added vocab will have have their extra rows truncated unless `extend_vocab` is specified if base_type != "to_save" and finetuned_size[0] > base_size[0]: assert ( @@ -190,7 +194,7 @@ def validate_and_combine_details( if base_type == "embedding" or base_type == "output": if not extend_vocab: logging.warning( - f"Finetuned module '{base_name}' will have {finetuned_size[0] - base_size[0]} rows truncated for weight decomposition! To preserve all embeddings, invoke script with --extend-vocab" + f"Finetuned module '{base_name}' will have {finetuned_size[0] - base_size[0]} rows truncated for weight decomposition! To preserve all embeddings, invoke script with --extend-vocab or --save-module={base_name}." ) else: logging.warning( @@ -201,14 +205,72 @@ def validate_and_combine_details( f"Finetuned module '{base_name}' will have {finetuned_size[0] - base_size[0]} rows truncated for weight decomposition!" 
) - else: + elif base_type != "to_save": assert ( base_size == finetuned_size ), f"Dimension mismatch in layer '{base_name}': {base_size} != {finetuned_size}" module_details.append((base_type, base_name)) - return module_details, base_model_embedding_size, finetuned_model_embedding_size + return module_details, base_vocab_size, finetuned_vocab_size + + +def build_wi_map(base_model_ref: ModelReference, trust_remote_code: bool = False): + weight_info_map = {} + base_cfg = base_model_ref.config(trust_remote_code=trust_remote_code) + try: + arch_info = get_architecture_info(base_cfg) + except RuntimeError as e: + logging.error( + f"Failed to load architecture info for model {base_model_ref}: {e}" + ) + return {} + for weight_info in arch_info.all_weights(base_cfg): + weight_info_map[weight_info.name] = weight_info + return weight_info_map + + +def load_weights( + wi_map: Dict[str, WeightInfo], + base_loader: LazyTensorLoader, + finetuned_loader: LazyTensorLoader, + module_name: str, +): + optional = False + aliases = None + tied_names = None + if weight_info := wi_map.get(module_name + ".weight"): + if weight_info.optional: + optional = True + if weight_info.aliases: + aliases = weight_info.aliases + if weight_info.tied_names: + tied_names = weight_info.tied_names + + base_weight = base_loader.get_tensor( + f"{module_name}.weight", aliases=aliases, raise_on_missing=False + ) + finetuned_weight = finetuned_loader.get_tensor( + f"{module_name}.weight", aliases=aliases, raise_on_missing=False + ) + if optional and (base_weight is None and finetuned_weight is None): + return None, None + if tied_names: + if base_weight is None: + base_weight = base_loader.get_tensor( + f"{module_name}.weight", aliases=tied_names, raise_on_missing=False + ) + if finetuned_weight is None: + finetuned_weight = finetuned_loader.get_tensor( + f"{module_name}.weight", aliases=tied_names, raise_on_missing=False + ) + if base_weight is None: + raise RuntimeError(f"Missing base weight for {module_name}") + if finetuned_weight is None: + if optional: + return None, None + raise RuntimeError(f"Missing finetuned weight for {module_name}") + return base_weight, finetuned_weight def extract_lora( @@ -219,6 +281,7 @@ def extract_lora( extend_vocab: bool, no_lazy_unpickle: bool, device: Optional[str], + trust_remote_code: bool = False, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: """ Process module details to decompose weights and generate LoRA weights and ranks. 
@@ -242,9 +305,15 @@ def extract_lora( lora_weights = {} ranks = {} + wi_map = build_wi_map(base_model_ref, trust_remote_code) + for module_type, module_name in tqdm(module_details): - base_weight = base_loader.get_tensor(f"{module_name}.weight") - finetuned_weight = finetuned_loader.get_tensor(f"{module_name}.weight") + base_weight, finetuned_weight = load_weights( + wi_map, base_loader, finetuned_loader, module_name + ) + if base_weight is None and finetuned_weight is None: + logging.info(f"[{module_type}] {module_name}: optional weight not found") + continue if module_type == "to_save": lora_weights[ @@ -352,6 +421,9 @@ def reconstruct_invocation(args: Dict[str, Any]) -> str: invocation += f" --device={args['device']}" if args.get("verbose"): invocation += " --verbose" + if args.get("modules_to_save"): + for module in args["modules_to_save"]: + invocation += f" --save-module={module}" return invocation @@ -520,6 +592,19 @@ def save_model_and_config( @click.option( "--verbose", "-v", type=bool, is_flag=True, default=False, help="Verbose logging" ) +@click.option( + "--save-module", + "modules_to_save", + type=str, + multiple=True, + default=[], + help="Save the specified module(s) at full rank", +) +@click.option( + "--trust-remote-code/--no-trust-remote-code", + default=False, + help="Trust remote code when loading model configurations", +) def main( finetuned_model: str, base_model: str, @@ -531,6 +616,8 @@ def main( model_name: str, device: str, verbose: bool, + modules_to_save: List[str], + trust_remote_code: bool, ) -> None: """ Decomposes delta weights between a base model and a finetuned model, saving a PEFT model to the specified output path. @@ -553,6 +640,7 @@ def main( "no_lazy_unpickle": no_lazy_unpickle, "skip_undecomposable": skip_undecomposable, "verbose": verbose, + "modules_to_save": modules_to_save or None, } logging.basicConfig(level=logging.INFO if verbose else logging.WARNING) @@ -564,14 +652,17 @@ def main( ( module_details, - base_model_embedding_size, - finetuned_model_embedding_size, + base_vocab_size, + finetuned_vocab_size, ) = validate_and_combine_details( - ModelReference.parse(base_model).model.path, - ModelReference.parse(finetuned_model).model.path, + base_model_ref.model.path, + finetuned_model_ref.model.path, skip_undecomposable, extend_vocab, + modules_to_save=modules_to_save, ) + logging.info(f"Base model vocab size: {base_vocab_size}") + logging.info(f"Finetuned model vocab size: {finetuned_vocab_size}") lora_weights, ranks = extract_lora( module_details, @@ -581,13 +672,14 @@ def main( extend_vocab, no_lazy_unpickle, device, + trust_remote_code, ) save_model_and_config( lora_weights, ranks, - finetuned_model_embedding_size > base_model_embedding_size and extend_vocab, - finetuned_model_embedding_size if extend_vocab else base_model_embedding_size, + finetuned_vocab_size > base_vocab_size and extend_vocab, + finetuned_vocab_size if extend_vocab else base_vocab_size, module_details, invocation_args, ) diff --git a/mergekit/scripts/tokensurgeon.py b/mergekit/scripts/tokensurgeon.py index a6715643..31d38fdf 100644 --- a/mergekit/scripts/tokensurgeon.py +++ b/mergekit/scripts/tokensurgeon.py @@ -147,26 +147,42 @@ def main( ) if lm_head_info: - old_lm_head = cache.get(model).get_tensor( - lm_head_info.name, aliases=lm_head_info.aliases, device=device - ) - donor_lm_head = cache.get(donor).get_tensor( - donor_lm_head_info.name, aliases=donor_lm_head_info.aliases, device=device - ) + try: + old_lm_head = cache.get(model).get_tensor( + lm_head_info.name, 
aliases=lm_head_info.aliases, device=device + ) + except KeyError: + if lm_head_info.optional: + logging.info(f"LM head tensor {lm_head_info.name} not found, skipping") + else: + report_issue( + f"Could not load LM head tensor {lm_head_info.name}", + error=True, + ) + old_lm_head = None - LOG.info("Computing new lm_head embeddings") - new_lm_head = get_embeddings( - old_lm_head, - donor_lm_head, - old_vocab, - new_vocab, - common_tokens, - accept_prefix=True, - k=k, - barycentric=barycentric, - cosine_similarity=cosine_similarity, - name=lm_head_info.name, - ) + if old_lm_head is not None: + donor_lm_head = cache.get(donor).get_tensor( + donor_lm_head_info.name, + aliases=donor_lm_head_info.aliases, + device=device, + ) + + LOG.info("Computing new lm_head embeddings") + new_lm_head = get_embeddings( + old_lm_head, + donor_lm_head, + old_vocab, + new_vocab, + common_tokens, + accept_prefix=True, + k=k, + barycentric=barycentric, + cosine_similarity=cosine_similarity, + name=lm_head_info.name, + ) + else: + new_lm_head = None # Save out the new model LOG.info(f"Saving new model to {out_path}") @@ -184,13 +200,17 @@ def main( tensor = cache.get(model).get_tensor( weight_info.name, aliases=weight_info.aliases ) + if tensor is None: + if weight_info.optional: + continue + report_issue(f"Could not load weight tensor {weight_info.name}", error=True) writer.save_tensor(weight_info.name, tensor, clone=merge_options.clone_tensors) writer.finalize() tokenizer.save_pretrained(out_path) cfg_out = arch_info.config try: - cfg_out.vocab_size = tokenizer.vocab_size + cfg_out.vocab_size = new_embed.shape[0] except AttributeError: LOG.error( "Could not set vocab size in config.json - you may need to update it manually." diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index ee6477c3..f782247f 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -23,6 +23,8 @@ class SparsificationMethod(str, Enum): random = "random" magnitude_outliers = "magnitude_outliers" rank_magnitude_sampling = "rank_magnitude_sampling" + consensus_ta = "consensus_ta" + consensus_ties = "consensus_ties" def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor): @@ -177,7 +179,10 @@ def sparsify( rescale: bool = False, epsilon: float = 0.15, ) -> torch.Tensor: - if method == SparsificationMethod.magnitude: + if ( + method == SparsificationMethod.magnitude + or method == SparsificationMethod.consensus_ties + ): return magnitude(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.random: return bernoulli(tensor, density=density, rescale=rescale) @@ -187,3 +192,12 @@ def sparsify( return rank_magnitude(tensor, density=density, rescale=rescale, epsilon=epsilon) else: raise NotImplementedError(method) + + +def get_tall_mask( + delta: torch.Tensor, # individual task vectors + lambda_factor: float, # hyper-parameter lambda for generating TALL masks + mixed_delta: torch.Tensor, # multi-task vector +): + mask = delta.abs() > lambda_factor * (mixed_delta - delta).abs() + return mask diff --git a/mergekit/tokenizer/config.py b/mergekit/tokenizer/config.py index 94208385..7bdaeca2 100644 --- a/mergekit/tokenizer/config.py +++ b/mergekit/tokenizer/config.py @@ -49,3 +49,4 @@ class TokenEmbeddingConfig(BaseModel, frozen=True): class TokenizerConfig(BaseModel, frozen=True): source: Union[ModelReference, Literal["union"], Literal["base"]] = "union" tokens: Optional[Dict[str, TokenEmbeddingConfig]] = None + pad_to_multiple_of: Optional[int] = None diff --git a/mergekit/tokenizer/embed.py 
b/mergekit/tokenizer/embed.py index 3cdb1840..a853d1af 100644 --- a/mergekit/tokenizer/embed.py +++ b/mergekit/tokenizer/embed.py @@ -33,6 +33,7 @@ class PermutedEmbeddings(Task[Dict[ModelReference, torch.Tensor]]): gather_tensors: GatherTensors tokenizer_task: BuildTokenizer tokens: Optional[ImmutableMap[str, TokenEmbeddingConfig]] + pad_to_multiple_of: Optional[int] base_model: Optional[ModelReference] def arguments(self) -> Dict[str, Task]: @@ -51,6 +52,10 @@ def execute( vocab = tokenizer.get_vocab() vocab_size = len(vocab) + if self.pad_to_multiple_of and vocab_size % self.pad_to_multiple_of: + vocab_size = ( + vocab_size // self.pad_to_multiple_of + 1 + ) * self.pad_to_multiple_of embed_size = tensors[models[0]].shape[1] assert all( t.shape[1] == embed_size for t in tensors.values() @@ -59,7 +64,7 @@ def execute( dtype = tensors[models[0]].dtype device = tensors[models[0]].device - token_configs = dict(**self.tokens) or {} + token_configs = dict(**(self.tokens or {})) tokens_to_average = self.assign_embedding_sources( permutations, models, vocab, token_configs ) @@ -105,6 +110,11 @@ def execute( logging.error( f"No embedding for token {repr(token)} in model {model}!" ) + + if vocab_size > len(vocab): + # as suggested by https://nlp.stanford.edu/~johnhew/vocab-expansion.html + avg_embed = torch.mean(new_embed[: len(vocab), :], dim=0) + new_embed[len(vocab) :, :] = avg_embed result[model] = new_embed return result diff --git a/pyproject.toml b/pyproject.toml index 9bb09a7d..e04fd464 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "mergekit" description = "Tools for merging pre-trained large language models" readme = "README.md" license = { text = "LGPL-3.0-or-later" } -version = "0.0.4.4" +version = "0.0.5.2" authors = [{ name = "Charles Goddard", email = "chargoddard@gmail.com" }] dependencies = [ "torch>=2.0.0", diff --git a/tests/common.py b/tests/common.py index a542cd44..23f63b25 100644 --- a/tests/common.py +++ b/tests/common.py @@ -45,8 +45,12 @@ def run_and_check_merge( index = ShardedTensorIndex.from_disk(tmpdir) for weight_info in arch_info.all_weights(config): - if weight_info.name not in index.tensor_paths: - raise RuntimeError(f"Output missing tensor {tensor_name}") + if weight_info.optional: + continue + if weight_info.name not in index.tensor_paths and not any( + a in index.tensor_paths for a in weight_info.aliases + ): + raise RuntimeError(f"Output missing tensor {weight_info.name}") if validate: validate(tmpdir) diff --git a/tests/test_basic_merges.py b/tests/test_basic_merges.py index ae54de43..797be01d 100644 --- a/tests/test_basic_merges.py +++ b/tests/test_basic_merges.py @@ -99,6 +99,45 @@ def test_slerp_merge(self, model_a, model_b): config.parameters = {"t": 0.35} run_and_check_merge(config) + def test_nearswap_merge(self, model_a, model_b): + config = self.two_model_config( + model_a, model_b, merge_method="nearswap", base_model=model_a + ) + config.parameters = {"t": 0.0001} + run_and_check_merge(config) + + def test_nuslerp_merges(self, model_a, model_b, model_c): + for base_model in [None, model_c]: + for row_wise in [False, True]: + for flatten in [False, True]: + print( + f"Testing nuslerp with row_wise={row_wise}, flatten={flatten}, base_model={base_model}" + ) + run_and_check_merge( + self.two_model_config( + model_a, + model_b, + merge_method="nuslerp", + base_model=base_model, + params={ + "nuslerp_row_wise": row_wise, + "nuslerp_flatten": flatten, + }, + ) + ) + + # test weights that sum to zero + config = 
self.two_model_config( + model_a, + model_b, + merge_method="nuslerp", + base_model=model_c, + params={"nuslerp_row_wise": False, "nuslerp_flatten": False}, + ) + config.models[0].parameters["weight"] = -0.5 + config.models[1].parameters["weight"] = 0.5 + run_and_check_merge(config) + def test_task_arithmetic_merge(self, model_a, model_b, model_c): config = self.two_model_config( model_a, model_b, merge_method="task_arithmetic", base_model=model_c @@ -121,6 +160,15 @@ def test_ties_merge(self, model_a, model_b, model_c): ) run_and_check_merge(config) + def test_multislerp_merge(self, model_a, model_b, model_c): + config = self.two_model_config( + model_a, + model_b, + merge_method="multislerp", + base_model=model_c, + ) + run_and_check_merge(config) + def test_dare_ties_merge(self, model_a, model_b, model_c): config = self.two_model_config( model_a, diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index 17fafcc8..a799e8c4 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -7,7 +7,7 @@ import tokenizers import torch from common import make_picollama, run_and_check_merge -from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase +from transformers import LlamaConfig, LlamaTokenizerFast, PreTrainedTokenizerBase from mergekit.config import InputModelDefinition, MergeConfiguration from mergekit.io import LazyTensorLoader @@ -270,6 +270,36 @@ def _check_embed(model_path: str): run_and_check_merge(config, validate=_check_embed) + def test_pad_to_multiple_of(self, model_chatml: str): + config = self.make_config( + [model_chatml], + base_model=model_chatml, + merge_method="linear", + tokenizer_config=TokenizerConfig( + source="base", + pad_to_multiple_of=16, + ), + ) + real_vocab_size = 64 + 2 + padded_size = (real_vocab_size // 16 + 1) * 16 + + def _check_result(model_path: str): + cfg = LlamaConfig.from_pretrained(model_path) + assert ( + cfg.vocab_size == padded_size + ), f"Expected vocab size {padded_size}, got {cfg.vocab_size}" + check_tokenizer( + expected_size=real_vocab_size, + must_contain=["<|im_start|>", "<|im_end|>"], + )(model_path) + + emb_out = ModelEmbeddings(model_path) + assert ( + emb_out.embed_tokens.shape[0] == padded_size + ), "Embedding size mismatch" + + run_and_check_merge(config, validate=_check_result) + def make_config( self, models: List[str],