Skip to content

Commit

Permalink
feat: aiproxy rerank (#5235)
Browse files Browse the repository at this point in the history
* feat: aiproxy rerank support

* fix: aiproxy rerank default value

* fix: aiproxy lint

* fix: aiproxy get model name
  • Loading branch information
zijiren233 authored Nov 25, 2024
1 parent 097bdd3 commit 1782032
Show file tree
Hide file tree
Showing 18 changed files with 345 additions and 34 deletions.
2 changes: 1 addition & 1 deletion service/aiproxy/controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
return nil, err
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := controller.RelayErrorHandler(resp)
err := controller.RelayErrorHandler(resp, meta.Mode)
return &err.Error, errors.New(err.Error.Message)
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
Expand Down
2 changes: 2 additions & 0 deletions service/aiproxy/controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
fallthrough
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
case relaymode.Rerank:
err = controller.RerankHelper(c)
default:
err = controller.RelayTextHelper(c)
}
Expand Down
4 changes: 4 additions & 0 deletions service/aiproxy/middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func Distribute(c *gin.Context) {
return
}
requestModel := c.GetString(ctxkey.RequestModel)
if requestModel == "" {
abortWithMessage(c, http.StatusBadRequest, "no model provided")
return
}
var channel *model.Channel
channelID, ok := c.Get(ctxkey.SpecificChannelID)
if ok {
Expand Down
2 changes: 0 additions & 2 deletions service/aiproxy/middleware/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ func getRequestModel(c *gin.Context) (string, error) {
switch {
case strings.HasPrefix(path, "/v1/moderations"):
return "text-moderation-stable", nil
case strings.HasSuffix(path, "embeddings"):
return c.Param("model"), nil
case strings.HasPrefix(path, "/v1/images/generations"):
return "dall-e-2", nil
case strings.HasPrefix(path, "/v1/audio/transcriptions"), strings.HasPrefix(path, "/v1/audio/translations"):
Expand Down
9 changes: 8 additions & 1 deletion service/aiproxy/relay/adaptor/cohere/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (

"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/relay/adaptor"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/meta"
"github.com/labring/sealos/service/aiproxy/relay/model"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
)

type Adaptor struct{}
Expand Down Expand Up @@ -53,7 +55,12 @@ func (a *Adaptor) ConvertTTSRequest(*model.TextToSpeechRequest) (any, error) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
return
}
switch meta.Mode {
case relaymode.Rerank:
err, usage = openai.RerankHandler(c, resp, meta.PromptTokens, meta)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
Expand Down
24 changes: 13 additions & 11 deletions service/aiproxy/relay/adaptor/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
usage.PromptTokens = meta.PromptTokens
usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens
}
} else {
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
case relaymode.AudioTranscription:
err, usage = STTHandler(c, resp, meta, a.responseFormat)
case relaymode.AudioSpeech:
err, usage = TTSHandler(c, resp, meta)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
case relaymode.AudioTranscription:
err, usage = STTHandler(c, resp, meta, a.responseFormat)
case relaymode.AudioSpeech:
err, usage = TTSHandler(c, resp, meta)
case relaymode.Rerank:
err, usage = RerankHandler(c, resp, meta.PromptTokens, meta)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
Expand Down
44 changes: 38 additions & 6 deletions service/aiproxy/relay/adaptor/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
var textResponse SlimTextResponse
defer resp.Body.Close()

responseBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
var textResponse SlimTextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
Expand All @@ -126,18 +127,49 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
}

resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
defer resp.Body.Close()

for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)

_, _ = io.Copy(c.Writer, resp.Body)
_, _ = c.Writer.Write(responseBody)
return nil, &textResponse.Usage
}

// RerankHandler forwards an upstream rerank response body to the client and
// extracts token usage for billing.
//
// The body is fully read and parsed *before* anything is written to the
// client, so a malformed upstream payload can still be reported as an error.
// promptTokens is the locally counted prompt size, used as a fallback when
// the upstream response carries no usable token metadata.
// The final *meta.Meta parameter is unused and kept only for signature
// parity with the other response handlers.
func RerankHandler(c *gin.Context, resp *http.Response, promptTokens int, _ *meta.Meta) (*model.ErrorWithStatusCode, *model.Usage) {
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	var rerankResponse SlimRerankResponse
	err = json.Unmarshal(responseBody, &rerankResponse)
	if err != nil {
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	// Fix: declare the payload type explicitly. Without a Content-Type
	// header Go sniffs the body and serves the JSON as
	// "text/plain; charset=utf-8".
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)

	_, _ = c.Writer.Write(responseBody)

	// No upstream token metadata at all: bill the locally counted prompt.
	if rerankResponse.Meta.Tokens == nil {
		return nil, &model.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: 0,
			TotalTokens:      promptTokens,
		}
	}
	// Metadata present but the input count is missing or invalid: fall back
	// to the local count while keeping the reported output count.
	if rerankResponse.Meta.Tokens.InputTokens <= 0 {
		rerankResponse.Meta.Tokens.InputTokens = promptTokens
	}
	return nil, &model.Usage{
		PromptTokens:     rerankResponse.Meta.Tokens.InputTokens,
		CompletionTokens: rerankResponse.Meta.Tokens.OutputTokens,
		TotalTokens:      rerankResponse.Meta.Tokens.InputTokens + rerankResponse.Meta.Tokens.OutputTokens,
	}
}

func TTSHandler(c *gin.Context, resp *http.Response, meta *meta.Meta) (*model.ErrorWithStatusCode, *model.Usage) {
defer resp.Body.Close()

Expand Down
4 changes: 4 additions & 0 deletions service/aiproxy/relay/adaptor/openai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ type SlimTextResponse struct {
model.Usage `json:"usage"`
}

// SlimRerankResponse captures only the part of an upstream rerank response
// needed for billing: the token-usage metadata under the "meta" key.
// The rest of the payload is passed through to the client unparsed.
type SlimRerankResponse struct {
	Meta model.RerankMeta `json:"meta"`
}

type TextResponseChoice struct {
FinishReason string `json:"finish_reason"`
model.Message `json:"message"`
Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/controller/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}

if isErrorHappened(meta, resp) {
err := RelayErrorHandler(resp)
err := RelayErrorHandler(resp, meta.Mode)
ConsumeWaitGroup.Add(1)
go postConsumeAmount(context.Background(),
&ConsumeWaitGroup,
Expand Down
50 changes: 48 additions & 2 deletions service/aiproxy/relay/controller/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package controller

import (
"fmt"
"io"
"net/http"
"strconv"
"strings"

json "github.com/json-iterator/go"
"github.com/labring/sealos/service/aiproxy/common/config"
"github.com/labring/sealos/service/aiproxy/common/conv"
"github.com/labring/sealos/service/aiproxy/common/logger"
"github.com/labring/sealos/service/aiproxy/relay/model"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
)

type GeneralErrorResponse struct {
Expand Down Expand Up @@ -52,7 +56,7 @@ func (e GeneralErrorResponse) ToMessage() string {
return ""
}

func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
func RelayErrorHandler(resp *http.Response, relayMode int) *model.ErrorWithStatusCode {
if resp == nil {
return &model.ErrorWithStatusCode{
StatusCode: 500,
Expand All @@ -63,7 +67,49 @@ func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
},
}
}
switch relayMode {
case relaymode.Rerank:
return RerankErrorHandler(resp)
default:
return RelayDefaultErrorHanlder(resp)
}
}

// RerankErrorHandler converts a failed upstream rerank response into an
// ErrorWithStatusCode, using the raw response body (or the read error) as
// the error message. The upstream status code is preserved.
func RerankErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
	defer resp.Body.Close()

	var message string
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		message = err.Error()
	} else {
		// Rerank error bodies are often a bare JSON string; strip the
		// surrounding quotes so the message reads cleanly.
		message = strings.Trim(conv.BytesToString(respBody), "\"")
	}
	return &model.ErrorWithStatusCode{
		StatusCode: resp.StatusCode,
		Error: model.Error{
			Message: message,
			Type:    "upstream_error",
			Code:    "bad_response",
		},
	}
}

func RelayDefaultErrorHanlder(resp *http.Response) *model.ErrorWithStatusCode {
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: model.Error{
Message: err.Error(),
},
}
}

ErrorWithStatusCode := &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Expand All @@ -75,7 +121,7 @@ func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
},
}
var errResponse GeneralErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errResponse)
err = json.Unmarshal(respBody, &errResponse)
if err != nil {
return ErrorWithStatusCode
}
Expand Down
23 changes: 16 additions & 7 deletions service/aiproxy/relay/controller/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,29 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int
return 0
}

func getPreConsumedAmount(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, price float64) float64 {
preConsumedTokens := int64(promptTokens)
if textRequest.MaxTokens != 0 {
preConsumedTokens += int64(textRequest.MaxTokens)
// PreCheckGroupBalanceReq carries the inputs used to estimate how much to
// pre-charge a group before relaying a request upstream.
type PreCheckGroupBalanceReq struct {
	// PromptTokens is the already-computed token count of the prompt.
	PromptTokens int
	// MaxTokens is the request's max_tokens value; 0 when unset.
	MaxTokens int
	// Price is the per-token price for the requested model
	// (scaled by billingprice.PriceUnit when the amount is computed).
	Price float64
}

// getPreConsumedAmount estimates the amount to pre-charge for a request:
// (prompt tokens + max tokens) * price / PriceUnit, computed with decimal
// arithmetic to avoid float rounding. Returns 0 when the price is zero or
// there are no tokens to count.
func getPreConsumedAmount(req *PreCheckGroupBalanceReq) float64 {
	if req.Price == 0 || (req.PromptTokens == 0 && req.MaxTokens == 0) {
		return 0
	}
	// Summing MaxTokens unconditionally is equivalent to the guarded add:
	// when unset it is zero and contributes nothing.
	estimatedTokens := int64(req.PromptTokens) + int64(req.MaxTokens)
	amount := decimal.NewFromInt(estimatedTokens).
		Mul(decimal.NewFromFloat(req.Price)).
		Div(decimal.NewFromInt(billingprice.PriceUnit))
	return amount.InexactFloat64()
}

func preCheckGroupBalance(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, price float64, meta *meta.Meta) (bool, balance.PostGroupConsumer, error) {
preConsumedAmount := getPreConsumedAmount(textRequest, promptTokens, price)
func preCheckGroupBalance(ctx context.Context, req *PreCheckGroupBalanceReq, meta *meta.Meta) (bool, balance.PostGroupConsumer, error) {
preConsumedAmount := getPreConsumedAmount(req)

groupRemainBalance, postGroupConsumer, err := balance.Default.GetGroupRemainBalance(ctx, meta.Group)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/controller/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func RelayImageHelper(c *gin.Context, _ int) *relaymodel.ErrorWithStatusCode {
}

if isErrorHappened(meta, resp) {
err := RelayErrorHandler(resp)
err := RelayErrorHandler(resp, meta.Mode)
ConsumeWaitGroup.Add(1)
go postConsumeAmount(context.Background(),
&ConsumeWaitGroup,
Expand Down
Loading

0 comments on commit 1782032

Please sign in to comment.