
Question about the implementation of sign election in TIES #474

Closed
kobayashikanna01 opened this issue Dec 20, 2024 · 2 comments

Comments

@kobayashikanna01

kobayashikanna01 commented Dec 20, 2024

Hi,

Thank you for the great contributions! I recently tried to reproduce the TIES experiments, and I found that a small fraction of the merged parameters (less than 3%) deviate from my expectations.

Compared to Algorithm 1 in the TIES paper (page 4 of the arXiv version), this toolkit adds a weight variable for each model. Intuitively, I expected weight to affect only the disjoint merging step, replacing the equation $\tau_m^p=\frac{1}{\vert A_p\vert}\sum_{t\in A_p}\hat{\tau}^p_t$ with $\tau_m^p=\frac{1}{\sum_{t\in A_p} w_t}\sum_{t\in A_p}w_t\cdot\hat{\tau}^p_t$ (where $w_1, \cdots, w_n$ denote the weights of models 1 to $n$).
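
For concreteness, here is a minimal PyTorch sketch of that weighted disjoint merge as I understand it (my own illustration, not mergekit's actual code), assuming the elected sign for each parameter is already known:

import torch

def weighted_disjoint_merge(task_vectors, weights, elected_sign):
    # task_vectors: list of per-model delta tensors (already sparsified)
    # weights: list of per-model weights w_t
    # elected_sign: tensor of +1/-1 per parameter
    stacked = torch.stack(task_vectors)
    w = torch.tensor(weights).view(-1, *([1] * (stacked.dim() - 1)))
    agrees = torch.sign(stacked) == elected_sign   # membership in A_p
    numerator = (w * stacked * agrees).sum(dim=0)
    denominator = (w * agrees).sum(dim=0).clamp_min(1e-12)
    return numerator / denominator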

However, I find that weight affects not only the averaging of retained parameters but also the behavior of sign election. Carefully reviewing the code, I find this toolkit implements TIES on top of Task Arithmetic, which introduces a rescaling variable $\lambda$ (i.e., weight here) for the task vectors.

Following Task Arithmetic, this toolkit multiplies each task vector by its weight right after sparsification. This means the magnitude of each parameter has already been rescaled when sign election is performed. For example, I merge a French model and a Math model with the following config:

models:
  - model: llama-2-french
    parameters:
      density: 1.0
      weight: 0.2
  - model: llama-2-mathsft
    parameters:
      density: 0.4
      weight: 0.8
merge_method: ties
base_model: llama-2-7b-hf
parameters:
  normalize: true
  int8_mask: true
dtype: float16

Then, for this toolkit to choose the sign of the French model, the original magnitude of a delta parameter in llama-2-french must be more than 4 times larger than the magnitude of the corresponding one in llama-2-mathsft.
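
To make this concrete, here is a tiny single-parameter illustration (hypothetical numbers, not taken from any real checkpoint):

import torch

# Hypothetical deltas for one parameter after sparsification.
delta_french = torch.tensor(0.30)   # positive delta in llama-2-french
delta_math = torch.tensor(-0.10)    # negative delta in llama-2-mathsft

# Election on raw magnitudes (Algorithm 1 in the paper): French wins, sign = +1.
raw_sign = torch.sign(delta_french + delta_math)

# Election after multiplying by weight (current behavior): Math wins, sign = -1,
# since 0.2 * 0.30 < 0.8 * 0.10. The French sign only wins when its magnitude
# is more than 4x that of the Math delta.
weighted_sign = torch.sign(0.2 * delta_french + 0.8 * delta_math)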

I am not sure whether this is the intended extension of TIES. I suspect some people may expect the weight variables to take effect only during averaging and not to affect sign election.

Thanks!

@cg123
Collaborator

cg123 commented Dec 28, 2024

Thanks for the question!

You're right that per-model weight is an extension of the method defined in the paper. This was a common request when I first implemented it, and I kinda just came up with what I thought made sense.

In the case where weight is equal between all models, the approach in mergekit should give identical results to the algorithm as defined. I suspect the small percentage deviation you're seeing can be attributed to floating point precision. If you try the merge with dtype: float32 (with or without out_dtype: float16), do you see results in line with what you expect?

As for the unequal contribution to sign election from models with different weight values, this was a deliberate choice based on what I think people probably want: consider a model that disagrees extensively with the consensus, but is weighted very low (think a few percent). Should it really have as much sway as a model with weight 1.0? My intuition says no.

This is totally a judgement call though and either interpretation could be valid. If that's specifically the behavior you want then it shouldn't be too hard to introduce a variant that does sign election before rescaling.
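
For what it's worth, a rough sketch of what such a variant could look like (illustrative only, with assumed names and tensor shapes rather than mergekit's real internals):

import torch

def ties_merge_sketch(task_vectors, weights, elect_before_rescale=False):
    stacked = torch.stack(task_vectors)   # sparsified per-model deltas
    w = torch.tensor(weights).view(-1, *([1] * (stacked.dim() - 1)))
    scaled = w * stacked                  # weight-rescaled deltas

    if elect_before_rescale:
        # Variant discussed above: elect signs on the raw deltas,
        # so weight does not sway the vote.
        elected = torch.sign(stacked.sum(dim=0))
    else:
        # Current behavior: elect signs on the rescaled deltas.
        elected = torch.sign(scaled.sum(dim=0))

    # Disjoint merge: weighted mean over entries agreeing with the elected sign.
    agrees = torch.sign(stacked) == elected
    numerator = (scaled * agrees).sum(dim=0)
    denominator = (w * agrees).sum(dim=0).clamp_min(1e-12)
    return numerator / denominator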

@kobayashikanna01
Author

Thank you for the detailed replies!

I will check whether floating point precision may result in the deviation between the two implementations.

I agree that it is easy to adjust the order of sign election and rescaling. I have also tried both approaches, and the results show that rescaling before sign election performs better on my datasets.
