Skip to content

Commit

Permalink
Add NearSwap (#480)
Browse files Browse the repository at this point in the history
# NearSwap Algorithm

NearSwap retains most of the weights of the base model, but when a
weight is similar between the two models, it is interpolated to the
secondary model value. A parameter t specifies the sameness threshold.
When the distance between two values is at or below t, the weight from
the secondary model is used; for larger distances the base weight is
moved toward the secondary value by a factor of t / distance.

This PR implements the NearSwap algorithm from
[here](https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001)

---------

Co-authored-by: Elliot Stein <[email protected]>
Co-authored-by: Charles O. Goddard <[email protected]>
  • Loading branch information
3 people authored Jan 25, 2025
1 parent 526c5a8 commit 84c83f8
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ A quick overview of the currently supported merge methods:
| ------------------------------------------------------------------------------------------------ | -------------------- | ----------- | --------------- |
| Linear ([Model Soups](https://arxiv.org/abs/2203.05482)) | `linear` | ✅ | ❌ |
| SLERP | `slerp` | ❌ | ✅ |
| Nearswap | `nearswap` | ❌ | ✅ |
| [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `task_arithmetic` | ✅ | ✅ |
| [TIES](https://arxiv.org/abs/2306.01708) | `ties` | ✅ | ✅ |
| [DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708) | `dare_ties` | ✅ | ✅ |
Expand Down Expand Up @@ -272,6 +273,14 @@ Parameters:

- `t` - interpolation factor. At `t=0` will return `base_model`, at `t=1` will return the other one.

### Nearswap

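Interpolates the base model toward a secondary model based on how close their weights are. Where the absolute difference between two weights is at most `t`, the secondary model's value is used; where it is larger, the base weight is moved toward the secondary value by a factor of `t / |difference|`. Accepts two models, one of which must be the `base_model`.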

Parameters:

- `t` - sameness threshold on the absolute difference between weights. Must be greater than zero.

### [Task Arithmetic](https://arxiv.org/abs/2212.04089)

Computes "task vectors" for each model by subtracting a base model. Merges the task vectors linearly and adds back the base. Works great for models that were fine tuned from a common ancestor. Also a super useful mental framework for several of the more involved merge methods.
Expand Down
126 changes: 126 additions & 0 deletions mergekit/merge_methods/nearswap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (C) 2025 Charles Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional, Union

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes


class NearSwapTask(Task[torch.Tensor]):
    """Task that merges one weight from two models with NearSwap.

    The tensor belonging to ``base_model`` is treated as the base weights
    and the remaining tensor as the secondary weights.
    """

    # Task that yields the per-model tensors passed to ``execute``.
    gather_tensors: MergeTensorInput
    # Reference to the model whose tensor is used as the base.
    base_model: ModelReference
    # Sameness threshold; ``execute`` rejects values <= 0.
    t: float
    # Passed to ``rectify_embed_sizes`` together with the tensor pair.
    weight_info: WeightInfo

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """Merge the gathered tensors.

        Parameters:
            tensors: Mapping from model reference to that model's tensor.
        Returns:
            The single input tensor unchanged when only one is given;
            otherwise the NearSwap result, converted to the base tensor's
            dtype and device.
        Raises:
            RuntimeError: If ``t`` is not positive, if the number of
                tensors is neither one nor two, or if the base model is
                not among the inputs.
        """
        if self.t <= 0:
            raise RuntimeError(f"Threshold cannot be <= zero, got {self.t}")

        model_count = len(tensors)
        if model_count == 1:
            # Only one model has this weight: nothing to merge.
            return next(iter(tensors.values()))
        if model_count != 2:
            raise RuntimeError(
                f"Nearswap merge expects exactly two models, got {len(tensors)}"
            )
        if self.base_model not in tensors:
            raise RuntimeError("Base model not in input tensors")

        # Order the pair so the base model's tensor comes first.
        (first_ref, first_tensor), (_, second_tensor) = tensors.items()
        if first_ref != self.base_model:
            first_tensor, second_tensor = second_tensor, first_tensor
        pair = [first_tensor, second_tensor]

        # Called with the list itself; the entries are re-read afterwards.
        rectify_embed_sizes(self.weight_info, pair)
        base_tensor, other_tensor = pair

        merged = nearswap(self.t, base_tensor, other_tensor)
        return merged.to(base_tensor.dtype).to(base_tensor.device)


class NearSwapMerge(MergeMethod):
    """Merge method exposing the NearSwap algorithm under the name ``nearswap``."""

    def name(self) -> str:
        """Return the identifier of this merge method."""
        return "nearswap"

    def pretty_name(self) -> Optional[str]:
        """Return the human-readable name of this merge method."""
        return "NearSwap"

    def reference_url(self) -> Optional[str]:
        """Return the URL of the model card the algorithm was adapted from."""
        return "https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001"

    def parameters(self) -> List[ConfigParameterDef]:
        """Declare the single required parameter ``t``."""
        threshold_def = ConfigParameterDef(name="t", required=True)
        return [threshold_def]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: MergeTensorInput,
        parameters: ImmutableMap[str, Any],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        """Build the ``NearSwapTask`` for one output weight.

        Parameters:
            output_weight: Description of the weight being produced.
            tensors: Task providing the input tensors.
            parameters: Configuration values; ``t`` is read from it.
            base_model: Reference to the base model.
        Returns:
            The task that performs the merge.
        """
        threshold = parameters["t"]
        return NearSwapTask(
            t=threshold,
            weight_info=output_weight,
            base_model=base_model,
            gather_tensors=tensors,
        )


def nearswap(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """
    Blend base weights toward secondary weights where the two are close.

    Each element gets a blend factor of ``t / |v0 - v1|`` limited to
    ``[0, 1]``, so elements whose distance is at most ``t`` take the value
    from ``v1`` and more distant elements stay closer to ``v0``.
    Adapted from: https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001
    Parameters:
        t (float): The sameness threshold.
        v0 (torch.Tensor): Weights from the base model.
        v1 (torch.Tensor): Weights from the secondary model.
    Returns:
        torch.Tensor: Resulting interpolated weights.
    """
    # Element-wise distance between the two sets of weights.
    distance = (v0 - v1).abs()

    # Identical elements divide by zero; the resulting NaN/inf entries are
    # mapped to a blend factor of 1.0 before limiting the range to [0, 1].
    ratio = torch.nan_to_num(t / distance, nan=1.0, posinf=1.0, neginf=1.0)
    blend = ratio.clamp(min=0.0, max=1.0)

    # Weighted mix: blend == 1 yields v1, blend == 0 yields v0.
    return blend * v1 + (1 - blend) * v0
2 changes: 2 additions & 0 deletions mergekit/merge_methods/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from mergekit.merge_methods.linear import LinearMerge
from mergekit.merge_methods.model_stock import ModelStockMerge
from mergekit.merge_methods.nearswap import NearSwapMerge
from mergekit.merge_methods.nuslerp import NuSlerpMerge
from mergekit.merge_methods.passthrough import PassthroughMerge
from mergekit.merge_methods.sce import SCEMerge
Expand All @@ -35,6 +36,7 @@
PassthroughMerge(),
ModelStockMerge(),
SCEMerge(),
NearSwapMerge(),
# generalized task arithmetic methods
GeneralizedTaskArithmeticMerge(
consensus_method=None,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_basic_merges.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def test_slerp_merge(self, model_a, model_b):
config.parameters = {"t": 0.35}
run_and_check_merge(config)

def test_nearswap_merge(self, model_a, model_b):
    """Run a two-model nearswap merge with model_a as the base model."""
    nearswap_config = self.two_model_config(
        model_a,
        model_b,
        merge_method="nearswap",
        base_model=model_a,
    )
    nearswap_config.parameters = {"t": 0.0001}
    run_and_check_merge(nearswap_config)

def test_nuslerp_merges(self, model_a, model_b, model_c):
for base_model in [None, model_c]:
for row_wise in [False, True]:
Expand Down

0 comments on commit 84c83f8

Please sign in to comment.