Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: aiproxy rerank #5235

Merged
merged 4 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion service/aiproxy/controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
return nil, err
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := controller.RelayErrorHandler(resp)
err := controller.RelayErrorHandler(resp, meta.Mode)
return &err.Error, errors.New(err.Error.Message)
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
Expand Down
2 changes: 2 additions & 0 deletions service/aiproxy/controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
fallthrough
case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
case relaymode.Rerank:
err = controller.RerankHelper(c)
default:
err = controller.RelayTextHelper(c)
}
Expand Down
4 changes: 4 additions & 0 deletions service/aiproxy/middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ func Distribute(c *gin.Context) {
return
}
requestModel := c.GetString(ctxkey.RequestModel)
if requestModel == "" {
abortWithMessage(c, http.StatusBadRequest, "no model provided")
return
}
var channel *model.Channel
channelID, ok := c.Get(ctxkey.SpecificChannelID)
if ok {
Expand Down
2 changes: 0 additions & 2 deletions service/aiproxy/middleware/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ func getRequestModel(c *gin.Context) (string, error) {
switch {
case strings.HasPrefix(path, "/v1/moderations"):
return "text-moderation-stable", nil
case strings.HasSuffix(path, "embeddings"):
return c.Param("model"), nil
case strings.HasPrefix(path, "/v1/images/generations"):
return "dall-e-2", nil
case strings.HasPrefix(path, "/v1/audio/transcriptions"), strings.HasPrefix(path, "/v1/audio/translations"):
Expand Down
9 changes: 8 additions & 1 deletion service/aiproxy/relay/adaptor/cohere/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (

"github.com/gin-gonic/gin"
"github.com/labring/sealos/service/aiproxy/relay/adaptor"
"github.com/labring/sealos/service/aiproxy/relay/adaptor/openai"
"github.com/labring/sealos/service/aiproxy/relay/meta"
"github.com/labring/sealos/service/aiproxy/relay/model"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
)

type Adaptor struct{}
Expand Down Expand Up @@ -53,7 +55,12 @@ func (a *Adaptor) ConvertTTSRequest(*model.TextToSpeechRequest) (any, error) {
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
return
}
switch meta.Mode {
case relaymode.Rerank:
err, usage = openai.RerankHandler(c, resp, meta.PromptTokens, meta)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
Expand Down
24 changes: 13 additions & 11 deletions service/aiproxy/relay/adaptor/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
usage.PromptTokens = meta.PromptTokens
usage.CompletionTokens = usage.TotalTokens - meta.PromptTokens
}
} else {
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
case relaymode.AudioTranscription:
err, usage = STTHandler(c, resp, meta, a.responseFormat)
case relaymode.AudioSpeech:
err, usage = TTSHandler(c, resp, meta)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
switch meta.Mode {
case relaymode.ImagesGenerations:
err, _ = ImageHandler(c, resp)
case relaymode.AudioTranscription:
err, usage = STTHandler(c, resp, meta, a.responseFormat)
case relaymode.AudioSpeech:
err, usage = TTSHandler(c, resp, meta)
case relaymode.Rerank:
err, usage = RerankHandler(c, resp, meta.PromptTokens, meta)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
Expand Down
44 changes: 38 additions & 6 deletions service/aiproxy/relay/adaptor/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
}

func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
var textResponse SlimTextResponse
defer resp.Body.Close()

responseBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
var textResponse SlimTextResponse
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
Expand All @@ -126,18 +127,49 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
}

resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
defer resp.Body.Close()

for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)

_, _ = io.Copy(c.Writer, resp.Body)
_, _ = c.Writer.Write(responseBody)
return nil, &textResponse.Usage
}

// RerankHandler forwards a non-streaming rerank response to the client and
// reports the token usage to bill for it.
//
// The upstream body is read in full and decoded into SlimRerankResponse, then
// written to the client unchanged together with the upstream status code.
// Usage is taken from the response's meta.tokens object when present; if that
// object is absent, or its input token count is not positive, promptTokens is
// used as the prompt token count instead.
//
// It returns a non-nil error (and nil usage) when the body cannot be read or
// is not valid JSON; in that case nothing has been written to the client.
func RerankHandler(c *gin.Context, resp *http.Response, promptTokens int, _ *meta.Meta) (*model.ErrorWithStatusCode, *model.Usage) {
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return ErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError), nil
	}

	var parsed SlimRerankResponse
	if decodeErr := json.Unmarshal(body, &parsed); decodeErr != nil {
		return ErrorWrapper(decodeErr, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	// Relay the upstream payload verbatim; a failed write to the client is
	// deliberately ignored so usage is still reported.
	c.Writer.WriteHeader(resp.StatusCode)
	_, _ = c.Writer.Write(body)

	tokens := parsed.Meta.Tokens
	if tokens == nil {
		// Upstream reported no token counts: bill the locally computed prompt size.
		return nil, &model.Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: 0,
			TotalTokens:      promptTokens,
		}
	}

	inputTokens := tokens.InputTokens
	if inputTokens <= 0 {
		inputTokens = promptTokens
	}
	outputTokens := tokens.OutputTokens

	return nil, &model.Usage{
		PromptTokens:     inputTokens,
		CompletionTokens: outputTokens,
		TotalTokens:      inputTokens + outputTokens,
	}
}

func TTSHandler(c *gin.Context, resp *http.Response, meta *meta.Meta) (*model.ErrorWithStatusCode, *model.Usage) {
defer resp.Body.Close()

Expand Down
4 changes: 4 additions & 0 deletions service/aiproxy/relay/adaptor/openai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ type SlimTextResponse struct {
model.Usage `json:"usage"`
}

// SlimRerankResponse is a partial view of a rerank response body: it decodes
// only the top-level "meta" object and ignores every other field.
type SlimRerankResponse struct {
// Meta is the "meta" object of the response; RerankHandler reads token
// counts from it.
Meta model.RerankMeta `json:"meta"`
}

type TextResponseChoice struct {
FinishReason string `json:"finish_reason"`
model.Message `json:"message"`
Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/controller/audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}

if isErrorHappened(meta, resp) {
err := RelayErrorHandler(resp)
err := RelayErrorHandler(resp, meta.Mode)
ConsumeWaitGroup.Add(1)
go postConsumeAmount(context.Background(),
&ConsumeWaitGroup,
Expand Down
50 changes: 48 additions & 2 deletions service/aiproxy/relay/controller/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package controller

import (
"fmt"
"io"
"net/http"
"strconv"
"strings"

json "github.com/json-iterator/go"
"github.com/labring/sealos/service/aiproxy/common/config"
"github.com/labring/sealos/service/aiproxy/common/conv"
"github.com/labring/sealos/service/aiproxy/common/logger"
"github.com/labring/sealos/service/aiproxy/relay/model"
"github.com/labring/sealos/service/aiproxy/relay/relaymode"
)

type GeneralErrorResponse struct {
Expand Down Expand Up @@ -52,7 +56,7 @@ func (e GeneralErrorResponse) ToMessage() string {
return ""
}

func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
func RelayErrorHandler(resp *http.Response, relayMode int) *model.ErrorWithStatusCode {
if resp == nil {
return &model.ErrorWithStatusCode{
StatusCode: 500,
Expand All @@ -63,7 +67,49 @@ func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
},
}
}
switch relayMode {
case relaymode.Rerank:
return RerankErrorHandler(resp)
default:
return RelayDefaultErrorHanlder(resp)
}
}

// RerankErrorHandler converts a failed rerank upstream response into an
// ErrorWithStatusCode.
//
// The returned error always carries the upstream HTTP status code, type
// "upstream_error" and code "bad_response". Its message is the raw response
// body with any leading and trailing double quotes removed, or the read
// error's text when the body cannot be read. resp must be non-nil; its body is
// closed before returning.
func RerankErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
	defer resp.Body.Close()

	// Everything except the message is identical on both paths.
	result := &model.ErrorWithStatusCode{
		StatusCode: resp.StatusCode,
		Error: model.Error{
			Type: "upstream_error",
			Code: "bad_response",
		},
	}

	raw, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		result.Error.Message = readErr.Error()
		return result
	}

	// The body is used as-is rather than parsed as JSON; only surrounding
	// quote characters are stripped.
	result.Error.Message = strings.Trim(conv.BytesToString(raw), "\"")
	return result
}

func RelayDefaultErrorHanlder(resp *http.Response) *model.ErrorWithStatusCode {
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: model.Error{
Message: err.Error(),
},
}
}

ErrorWithStatusCode := &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Expand All @@ -75,7 +121,7 @@ func RelayErrorHandler(resp *http.Response) *model.ErrorWithStatusCode {
},
}
var errResponse GeneralErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errResponse)
err = json.Unmarshal(respBody, &errResponse)
if err != nil {
return ErrorWithStatusCode
}
Expand Down
23 changes: 16 additions & 7 deletions service/aiproxy/relay/controller/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,29 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int
return 0
}

func getPreConsumedAmount(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, price float64) float64 {
preConsumedTokens := int64(promptTokens)
if textRequest.MaxTokens != 0 {
preConsumedTokens += int64(textRequest.MaxTokens)
// PreCheckGroupBalanceReq bundles the inputs used to estimate how much of a
// group's balance a request will consume before it is relayed.
type PreCheckGroupBalanceReq struct {
// PromptTokens is the number of prompt tokens counted for the request.
PromptTokens int
// MaxTokens is added to PromptTokens when non-zero; zero means no extra
// tokens are reserved.
MaxTokens int
// Price is multiplied by the token total and divided by
// billingprice.PriceUnit; a zero price yields a zero pre-consumed amount.
Price float64
}

func getPreConsumedAmount(req *PreCheckGroupBalanceReq) float64 {
if req.Price == 0 || (req.PromptTokens == 0 && req.MaxTokens == 0) {
return 0
}
preConsumedTokens := int64(req.PromptTokens)
if req.MaxTokens != 0 {
preConsumedTokens += int64(req.MaxTokens)
}
return decimal.
NewFromInt(preConsumedTokens).
Mul(decimal.NewFromFloat(price)).
Mul(decimal.NewFromFloat(req.Price)).
Div(decimal.NewFromInt(billingprice.PriceUnit)).
InexactFloat64()
}

func preCheckGroupBalance(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, price float64, meta *meta.Meta) (bool, balance.PostGroupConsumer, error) {
preConsumedAmount := getPreConsumedAmount(textRequest, promptTokens, price)
func preCheckGroupBalance(ctx context.Context, req *PreCheckGroupBalanceReq, meta *meta.Meta) (bool, balance.PostGroupConsumer, error) {
preConsumedAmount := getPreConsumedAmount(req)

groupRemainBalance, postGroupConsumer, err := balance.Default.GetGroupRemainBalance(ctx, meta.Group)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion service/aiproxy/relay/controller/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func RelayImageHelper(c *gin.Context, _ int) *relaymodel.ErrorWithStatusCode {
}

if isErrorHappened(meta, resp) {
err := RelayErrorHandler(resp)
err := RelayErrorHandler(resp, meta.Mode)
ConsumeWaitGroup.Add(1)
go postConsumeAmount(context.Background(),
&ConsumeWaitGroup,
Expand Down
Loading
Loading