From 646ed3ecc6f672c710da41f52941da78c5a02e52 Mon Sep 17 00:00:00 2001 From: Ding Yang Date: Tue, 1 Oct 2024 12:31:10 +0000 Subject: [PATCH 1/2] add consensus_ta method from https://arxiv.org/abs/2405.07813 --- mergekit/merge_methods/__init__.py | 8 +++++ .../generalized_task_arithmetic.py | 31 +++++++++++++++++-- mergekit/sparsify.py | 9 ++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index 007e163e..dd42a6fd 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -93,6 +93,14 @@ def get(method: str) -> MergeMethod: default_normalize=False, default_rescale=True, ) + + elif method == "consensus_ta": + return GeneralizedTaskArithmeticMerge( + consensus_method=None, + sparsification_method=SparsificationMethod.consensus_ta, + default_normalize=False, + default_rescale=False, + ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 214726b7..783b634a 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -29,7 +29,7 @@ MergeMethod, MergeTensorInput, ) -from mergekit.sparsify import SparsificationMethod, sparsify +from mergekit.sparsify import SparsificationMethod, sparsify, get_tall_mask class ConsensusMethod(str, Enum): @@ -79,6 +79,19 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: default_value=1.0, ) ) + if self.sparsification_method == SparsificationMethod.consensus_ta: + res.append( + ConfigParameterDef( + name="k", + default_value=1, + ) + ) + res.append( + ConfigParameterDef( + name="lambda", + default_value=1.0, + ) + ) return res def make_task( @@ -133,7 +146,7 @@ def execute( return base # sparsify - if self.method.sparsification_method: + if self.method.sparsification_method and self.method.sparsification_method != SparsificationMethod.consensus_ta: for tv_info in tvs: kwargs = {} if "gamma" in tv_info: @@ -184,6 +197,20 @@ def execute( ): lambda_factor = tvs[0]["lambda"] mixed_delta *= lambda_factor + + if ( + self.method.sparsification_method + == SparsificationMethod.consensus_ta + ): + for tv_info in tvs: + tv_info["tall_mask"] = get_tall_mask( + tv_info["delta"], + tv_info["lambda"], + mixed_delta, + ) + tall_masks = torch.stack([tv["tall_mask"] for tv in tvs], dim=0) + consensus_mask = tall_masks.sum(dim=0) >= tvs[0]["k"] + mixed_delta = mixed_delta * consensus_mask return (base + mixed_delta).to(base.dtype) diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index ee6477c3..280a6d9c 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -23,6 +23,7 @@ class SparsificationMethod(str, Enum): random = "random" magnitude_outliers = "magnitude_outliers" rank_magnitude_sampling = "rank_magnitude_sampling" + consensus_ta = "consensus_ta" def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor): @@ -187,3 +188,11 @@ def sparsify( return rank_magnitude(tensor, density=density, rescale=rescale, epsilon=epsilon) else: raise NotImplementedError(method) + +def get_tall_mask( + delta: torch.Tensor, # individual task vectors + lambda_factor: float, # hyper-parameter lambda for generating TALL masks + mixed_delta: torch.Tensor, # multi-task vector +): + mask = delta.abs() > lambda_factor * (mixed_delta - delta).abs() + return mask \ No newline at end of file From eafaa6248e2647c21797bd57ceb1d694ec94b47d Mon Sep 17 00:00:00 2001 From: Ding Yang Date: Tue, 1 Oct 2024 13:22:01 +0000 Subject: [PATCH 2/2] add consensus_ties method --- mergekit/merge_methods/__init__.py | 8 ++++++++ .../merge_methods/generalized_task_arithmetic.py | 12 +++++++----- mergekit/sparsify.py | 3 ++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/mergekit/merge_methods/__init__.py b/mergekit/merge_methods/__init__.py index dd42a6fd..6c7f1aad 100644 --- a/mergekit/merge_methods/__init__.py +++ b/mergekit/merge_methods/__init__.py @@ -101,6 +101,14 @@ def get(method: str) -> MergeMethod: default_normalize=False, default_rescale=False, ) + + elif method == "consensus_ties": + return GeneralizedTaskArithmeticMerge( + consensus_method=ConsensusMethod.sum, + sparsification_method=SparsificationMethod.consensus_ties, + default_normalize=True, + default_rescale=False, + ) raise RuntimeError(f"Unimplemented merge method {method}") diff --git a/mergekit/merge_methods/generalized_task_arithmetic.py b/mergekit/merge_methods/generalized_task_arithmetic.py index 783b634a..cdbc6152 100644 --- a/mergekit/merge_methods/generalized_task_arithmetic.py +++ b/mergekit/merge_methods/generalized_task_arithmetic.py @@ -79,7 +79,7 @@ def tensor_parameters(self) -> List[ConfigParameterDef]: default_value=1.0, ) ) - if self.sparsification_method == SparsificationMethod.consensus_ta: + if self.sparsification_method == SparsificationMethod.consensus_ta or self.sparsification_method == SparsificationMethod.consensus_ties: res.append( ConfigParameterDef( name="k", @@ -155,7 +155,7 @@ def execute( if "epsilon" in tv_info: kwargs["epsilon"] = tv_info["epsilon"] - tv_info["delta"] = sparsify( + tv_info["sparsified_delta"] = sparsify( tv_info["delta"], density=tv_info["density"], method=self.method.sparsification_method, @@ -163,7 +163,9 @@ def execute( **kwargs, ) - deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) + deltas = torch.stack([tv["sparsified_delta"] for tv in tvs], dim=0) + else: + deltas = torch.stack([tv["delta"] for tv in tvs], dim=0) weights = torch.tensor( [tv["weight"] for tv in tvs], dtype=deltas.dtype, device=deltas.device ) @@ -199,8 +201,8 @@ def execute( mixed_delta *= lambda_factor if ( - self.method.sparsification_method - == SparsificationMethod.consensus_ta + self.method.sparsification_method== SparsificationMethod.consensus_ta + or self.method.sparsification_method == SparsificationMethod.consensus_ties ): for tv_info in tvs: tv_info["tall_mask"] = get_tall_mask( diff --git a/mergekit/sparsify.py b/mergekit/sparsify.py index 280a6d9c..a7e20506 100644 --- a/mergekit/sparsify.py +++ b/mergekit/sparsify.py @@ -24,6 +24,7 @@ class SparsificationMethod(str, Enum): magnitude_outliers = "magnitude_outliers" rank_magnitude_sampling = "rank_magnitude_sampling" consensus_ta = "consensus_ta" + consensus_ties = "consensus_ties" def rescale_sum(tensor: torch.Tensor, mask: torch.Tensor): @@ -178,7 +179,7 @@ def sparsify( rescale: bool = False, epsilon: float = 0.15, ) -> torch.Tensor: - if method == SparsificationMethod.magnitude: + if method == SparsificationMethod.magnitude or method == SparsificationMethod.consensus_ties: return magnitude(tensor, density=density, rescale=rescale) elif method == SparsificationMethod.random: return bernoulli(tensor, density=density, rescale=rescale)