Skip to content

Commit

Permalink
Add proxy change to return o1 response as SSE events
Browse files Browse the repository at this point in the history
  • Loading branch information
StrongMonkey committed Jan 24, 2025
1 parent a38138d commit aa67dc3
Showing 1 changed file with 88 additions and 5 deletions.
93 changes: 88 additions & 5 deletions openai-model-provider/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ import (
openai "github.com/gptscript-ai/chat-completion-client"
)

// Upstream identifiers for the official OpenAI API. The o1
// request/response rewriting in this proxy is applied only when the
// upstream host and path match these values. They are fixed values,
// so declare them as constants rather than vars.
const (
	openaiBaseHostName = "api.openai.com"

	chatCompletionsPath = "/v1/chat/completions"
)

type Config struct {
APIKey string
Port string
Expand All @@ -33,7 +39,7 @@ func Run(cfg *Config) error {
cfg.Port = "8000"
}
if cfg.UpstreamHost == "" {
cfg.UpstreamHost = "api.openai.com"
cfg.UpstreamHost = openaiBaseHostName
cfg.UseTLS = true
}

Expand Down Expand Up @@ -68,7 +74,8 @@ func Run(cfg *Config) error {
ModifyResponse: cfg.RewriteModelsFn,
})
mux.Handle("/v1/", &httputil.ReverseProxy{
Director: s.proxyDirector,
Director: s.proxyDirector,
ModifyResponse: s.modifyResponse,
})

httpServer := &http.Server{
Expand Down Expand Up @@ -102,7 +109,7 @@ func (s *server) proxyDirector(req *http.Request) {
req.URL.Path = s.cfg.PathPrefix + req.URL.Path
}

if req.Body == nil || req.Method != http.MethodPost {
if req.Body == nil || s.cfg.UpstreamHost != openaiBaseHostName || req.URL.Path != chatCompletionsPath {
return
}

Expand All @@ -113,8 +120,7 @@ func (s *server) proxyDirector(req *http.Request) {
}

var reqBody openai.ChatCompletionRequest
// ignore errors here, because the request can be something other than ChatCompletionRequest
if err := json.Unmarshal(bodyBytes, &reqBody); err == nil && reqBody.Model == "o1" {
if err := json.Unmarshal(bodyBytes, &reqBody); err == nil && isModelO1(reqBody.Model) {
modifyRequestBodyForO1(req, &reqBody)
} else {
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
Expand All @@ -133,11 +139,88 @@ func modifyRequestBodyForO1(req *http.Request, reqBody *openai.ChatCompletionReq
if err == nil {
req.Body = io.NopCloser(bytes.NewBuffer(modifiedBodyBytes))
req.ContentLength = int64(len(modifiedBodyBytes))
req.Header.Set("Accept", "application/json")
req.Header.Set("Accept-Encoding", "")
req.Header.Set("Content-Type", "application/json")
} else {
fmt.Println("failed to marshal request body after modification and skipping, error: ", err.Error())
}
}

// modifyResponse rewrites successful chat-completion responses coming
// back from the upstream OpenAI API. o1 models are sent a non-streaming
// request by the director, so their JSON response is converted here into
// a single SSE data event followed by the standard [DONE] terminator,
// letting clients that expect a stream keep working. Any other response
// passes through untouched.
func (s *server) modifyResponse(resp *http.Response) error {
	if resp.StatusCode != http.StatusOK || resp.Body == nil {
		return nil
	}

	// Only chat completions served by the official OpenAI host are rewritten.
	if resp.Request.URL.Path != chatCompletionsPath || resp.Request.URL.Host != openaiBaseHostName {
		return nil
	}

	// Streaming responses arrive as text/event-stream; only JSON bodies can
	// be a non-streaming o1 response. Match by media-type prefix so an
	// "application/json; charset=utf-8" header is not missed.
	if !strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
		return nil
	}

	rawBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("failed to read response body: %w", err)
	}
	// The original body is fully consumed and replaced below; close it so
	// the transport can reuse the upstream connection.
	resp.Body.Close()

	var respBody openai.ChatCompletionResponse
	// Unmarshal errors are deliberately tolerated: the body may be a JSON
	// payload other than a ChatCompletionResponse, which is passed through.
	if err := json.Unmarshal(rawBody, &respBody); err != nil || !isModelO1(respBody.Model) {
		resp.Body = io.NopCloser(bytes.NewReader(rawBody))
		return nil
	}

	// Re-shape the single completion into stream-choice deltas.
	choices := make([]openai.ChatCompletionStreamChoice, 0, len(respBody.Choices))
	for _, choice := range respBody.Choices {
		choices = append(choices, openai.ChatCompletionStreamChoice{
			Index: choice.Index,
			Delta: openai.ChatCompletionStreamChoiceDelta{
				Content:      choice.Message.Content,
				Role:         choice.Message.Role,
				FunctionCall: choice.Message.FunctionCall,
				ToolCalls:    choice.Message.ToolCalls,
			},
			FinishReason: choice.FinishReason,
		})
	}
	streamResponse := openai.ChatCompletionStreamResponse{
		ID:      respBody.ID,
		Object:  respBody.Object,
		Created: respBody.Created,
		Model:   respBody.Model,
		Usage:   respBody.Usage,
		Choices: choices,
	}

	sseData, err := json.Marshal(streamResponse)
	if err != nil {
		return fmt.Errorf("failed to marshal stream response: %w", err)
	}
	sseFormattedData := fmt.Sprintf("data: %s\n\nevent: close\ndata: [DONE]\n\n", sseData)

	resp.Header.Set("Content-Type", "text/event-stream")
	resp.Header.Set("Cache-Control", "no-cache")
	resp.Header.Set("Connection", "keep-alive")
	// The body length changed; drop the stale upstream Content-Length so
	// the proxy frames the rewritten body correctly.
	resp.Header.Del("Content-Length")
	resp.ContentLength = int64(len(sseFormattedData))
	// The payload is already fully materialized, so a plain in-memory
	// reader replaces the original io.Pipe + goroutine (whose write error
	// was also silently ignored).
	resp.Body = io.NopCloser(strings.NewReader(sseFormattedData))
	return nil
}

// isModelO1 reports whether model names a member of the o1 family,
// excluding the o1-mini variants (which do not need the proxy's special
// request/response handling).
func isModelO1(model string) bool {
	switch {
	case model == "o1":
		return true
	case strings.HasPrefix(model, "o1-mini"):
		return false
	default:
		return strings.HasPrefix(model, "o1-")
	}
}

func Validate(cfg *Config) error {
if cfg.ValidateFn == nil {
return nil
Expand Down

0 comments on commit aa67dc3

Please sign in to comment.