server: add repeat penalty sigmoid #9076
Open
Summary
This PR for `server` allows applying a sigmoid function (to be precise, the logistic curve) to the `repeat_penalty` over the `repeat_last_n` range.

It may be useful to apply more penalty to tokens that are closer to the end of the text, and less penalty to tokens at the beginning of the penalty range. This makes it possible to set higher penalty values that apply only to the recent tokens, while older tokens receive a lower penalty and the model can use them more freely during inference. This feature was inspired by KoboldAI's repetition penalty slope parameter, which in turn came from NovelAI. However, the implementation in this PR works slightly differently (explained below), so I gave it a different name to avoid confusion.
Math
The new parameter `repeat_penalty_sigmoid_growth` is added to the `server` API. It only affects `repeat_penalty`, not the other penalties. Wikipedia calls this parameter `B`, but let's call it `growth` here.

- `growth = 0`: the feature is disabled (default). The repetition penalty is constant across the entire penalty range.
- `growth = 1`: the penalty changes linearly within the `repeat_last_n` range, from `1` to `repeat_penalty`.
- `growth > 1`: the usual logistic curve is applied to the penalty, so it grows slowly at the start, rises rapidly in the middle, and slows down again towards the end of the range. The formula is `k = 1 / (1 + exp((-x + 0.5) * growth))`, where `x` is the normalized token position from the start of the penalty range, and `k` is the coefficient applied to the penalty (explained below).
- `0 < growth < 1`: a regular sigmoid function makes almost no difference within this range, but I wanted the range to be useful somehow, so I "invented" what the source code calls a "mirrored sigmoid". For `growth` in `(0;1)` the logistic function is mirrored relative to the `k=x` diagonal. The formula is `k = 0.5 - log((1 - x) / x) / growth`.
- `growth < 0`: basically the same as above, but mirrored vertically (relative to the `k=0.5` line).

All `x` and `k` are normalized to the range `[0;1]`. In the current implementation the mirrored sigmoid is technically not smooth at `x=0` and `x=1`, but I don't think it matters in practice.

`k` is applied to the initial penalty so that the resulting penalty changes from `1` to `repeat_penalty`. For example, if `k = 0.9` and `repeat_penalty = 1.5`, the resulting penalty is `1.45`. If `k = 0.9` and `repeat_penalty = 0.5`, the resulting penalty is `0.55`.
Graphs
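As a rough illustration of the formulas in the Math section, here is a minimal standalone sketch, not the code from this PR. The function names are hypothetical. The interpolation `1 + k * (repeat_penalty - 1)` is inferred from the two worked examples, and the min-max rescaling in `normalized_logistic_k` is only one assumed way to bring `k` into `[0;1]`; the actual normalization in the PR may differ.

```cpp
#include <cassert>
#include <cmath>

// Raw logistic curve as written in the description:
//   k = 1 / (1 + exp((-x + 0.5) * growth))
// x is the normalized token position in [0, 1].
inline double logistic_k(double x, double growth) {
    return 1.0 / (1.0 + std::exp((-x + 0.5) * growth));
}

// "Mirrored sigmoid" as written in the description:
//   k = 0.5 - log((1 - x) / x) / growth
// Algebraically this is the inverse of logistic_k, i.e. the logistic curve
// reflected across the k = x diagonal. Only defined for 0 < x < 1.
inline double mirrored_k(double x, double growth) {
    assert(x > 0.0 && x < 1.0 && growth != 0.0);
    return 0.5 - std::log((1.0 - x) / x) / growth;
}

// ASSUMPTION (not taken from the PR): min-max rescaling so that
// k(0) = 0 and k(1) = 1 exactly.
inline double normalized_logistic_k(double x, double growth) {
    assert(growth != 0.0);
    const double lo = logistic_k(0.0, growth);
    const double hi = logistic_k(1.0, growth);
    return (logistic_k(x, growth) - lo) / (hi - lo);
}

// Blend between "no penalty" (1) and the full repeat_penalty using k.
// Inferred from the examples: k = 0.9, repeat_penalty = 1.5 gives 1.45.
inline double scaled_penalty(double k, double repeat_penalty) {
    return 1.0 + k * (repeat_penalty - 1.0);
}
```

With these helpers, `scaled_penalty(0.9, 1.5)` and `scaled_penalty(0.9, 0.5)` reproduce the `1.45` and `0.55` values quoted above, up to floating-point rounding.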
Notes
- If the "mirrored sigmoid" is too weird, I can remove it.
- I put all the code in the `sigmoid` struct to organize it better. This also makes it quick to add the same sigmoid to the other penalties (presence and frequency) if needed. Since it is only used in one function, I put the struct right into that function.
- In the `sigmoid` constructor I initialize all the fields even if they are not used afterwards (when `enabled=false`), because otherwise the compiler prints lots of warnings about possibly uninitialized fields.
- The new code uses a long identifier name, `penalty_repeat_sigmoid_growth`, and it does align with some of the existing formatting.
- The position of the penalized token (`x`) is the position of the last occurrence of this token in the penalty range.
- I measured the sampling speed with and without this functionality and didn't observe any measurable impact.
- Some tests are added to `tests/test-sampling.cpp`.
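To make the note about the token position concrete, the sketch below shows one way the "last occurrence" rule could be computed. It is not the code from this PR: `last_occurrence_positions` is a hypothetical helper, and the normalization `x = i / (n - 1)` (oldest token at `0`, newest at `1`) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical helper: maps each token id in the penalty range to the
// normalized position x of its LAST occurrence.
// last_tokens holds the penalty range, oldest token first.
// ASSUMPTION: x = i / (n - 1), so the oldest token is at 0 and the newest at 1.
inline std::unordered_map<int, double>
last_occurrence_positions(const std::vector<int> & last_tokens) {
    std::unordered_map<int, double> pos;
    const std::size_t n = last_tokens.size();
    for (std::size_t i = 0; i < n; ++i) {
        const double x = n > 1
            ? static_cast<double>(i) / static_cast<double>(n - 1)
            : 1.0; // a single-token range is treated as "most recent"
        assert(x >= 0.0 && x <= 1.0);
        pos[last_tokens[i]] = x; // later occurrences overwrite earlier ones
    }
    return pos;
}
```

Because the loop walks from oldest to newest and overwrites, a token that repeats several times ends up with the position of its most recent occurrence, which is the one that would receive the stronger penalty under a rising curve.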