From aa67dc3074cfbb9f08c9826a9db53d75c21e6957 Mon Sep 17 00:00:00 2001 From: Daishan Peng Date: Fri, 24 Jan 2025 10:42:36 -0700 Subject: [PATCH] Add proxy change to return o1 response as SSE events --- openai-model-provider/proxy/proxy.go | 93 ++++++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 5 deletions(-) diff --git a/openai-model-provider/proxy/proxy.go b/openai-model-provider/proxy/proxy.go index 11cb45f0..882f61af 100644 --- a/openai-model-provider/proxy/proxy.go +++ b/openai-model-provider/proxy/proxy.go @@ -13,6 +13,12 @@ import ( openai "github.com/gptscript-ai/chat-completion-client" ) +var ( + openaiBaseHostName = "api.openai.com" + + chatCompletionsPath = "/v1/chat/completions" +) + type Config struct { APIKey string Port string @@ -33,7 +39,7 @@ func Run(cfg *Config) error { cfg.Port = "8000" } if cfg.UpstreamHost == "" { - cfg.UpstreamHost = "api.openai.com" + cfg.UpstreamHost = openaiBaseHostName cfg.UseTLS = true } @@ -68,7 +74,8 @@ func Run(cfg *Config) error { ModifyResponse: cfg.RewriteModelsFn, }) mux.Handle("/v1/", &httputil.ReverseProxy{ - Director: s.proxyDirector, + Director: s.proxyDirector, + ModifyResponse: s.modifyResponse, }) httpServer := &http.Server{ @@ -102,7 +109,7 @@ func (s *server) proxyDirector(req *http.Request) { req.URL.Path = s.cfg.PathPrefix + req.URL.Path } - if req.Body == nil || req.Method != http.MethodPost { + if req.Body == nil || s.cfg.UpstreamHost != openaiBaseHostName || req.URL.Path != chatCompletionsPath { return } @@ -113,8 +120,7 @@ func (s *server) proxyDirector(req *http.Request) { } var reqBody openai.ChatCompletionRequest - // ignore errors here, because the request can be something other than ChatCompletionRequest - if err := json.Unmarshal(bodyBytes, &reqBody); err == nil && reqBody.Model == "o1" { + if err := json.Unmarshal(bodyBytes, &reqBody); err == nil && isModelO1(reqBody.Model) { modifyRequestBodyForO1(req, &reqBody) } else { req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -133,11 +139,88 @@ func modifyRequestBodyForO1(req *http.Request, reqBody *openai.ChatCompletionReq if err == nil { req.Body = io.NopCloser(bytes.NewBuffer(modifiedBodyBytes)) req.ContentLength = int64(len(modifiedBodyBytes)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Encoding", "") + req.Header.Set("Content-Type", "application/json") } else { fmt.Println("failed to marshal request body after modification and skipping, error: ", err.Error()) } } +func (s *server) modifyResponse(resp *http.Response) error { + if resp.StatusCode != http.StatusOK || resp.Body == nil { + return nil + } + + if resp.Request.URL.Path != chatCompletionsPath || resp.Request.URL.Host != openaiBaseHostName { + return nil + } + + if resp.Header.Get("Content-Type") == "application/json" { + var respBody openai.ChatCompletionResponse + rawBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + if err := json.Unmarshal(rawBody, &respBody); err == nil && isModelO1(respBody.Model) { + // Convert non-streaming response to a single SSE for o1 model + streamResponse := openai.ChatCompletionStreamResponse{ + ID: respBody.ID, + Object: respBody.Object, + Created: respBody.Created, + Model: respBody.Model, + Usage: respBody.Usage, + Choices: func() []openai.ChatCompletionStreamChoice { + var choices []openai.ChatCompletionStreamChoice + for _, choice := range respBody.Choices { + choices = append(choices, openai.ChatCompletionStreamChoice{ + Index: choice.Index, + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: choice.Message.Content, + Role: choice.Message.Role, + FunctionCall: choice.Message.FunctionCall, + ToolCalls: choice.Message.ToolCalls, + }, + FinishReason: choice.FinishReason, + }) + } + return choices + }(), + } + + sseData, err := json.Marshal(streamResponse) + if err != nil { + return fmt.Errorf("failed to marshal stream response: %w", err) + } + + sseFormattedData := fmt.Sprintf("data: %s\n\nevent: close\ndata: [DONE]\n\n", sseData) + + pr, pw := io.Pipe() + go func() { + defer pw.Close() + pw.Write([]byte(sseFormattedData)) + }() + + resp.Header.Set("Content-Type", "text/event-stream") + resp.Header.Set("Cache-Control", "no-cache") + resp.Header.Set("Connection", "keep-alive") + resp.Body = pr + return nil + } else { + resp.Body = io.NopCloser(bytes.NewBuffer(rawBody)) + } + } + + return nil +} + +func isModelO1(model string) bool { + if model == "o1" { + return true + } + return strings.HasPrefix(model, "o1-") && !strings.HasPrefix(model, "o1-mini") +} + func Validate(cfg *Config) error { if cfg.ValidateFn == nil { return nil