Refactor chat stream functions to return LLMAnswer struct instead of multiple values

- Introduced `LLMAnswer` struct in `models.go` to consolidate return values.
- Updated `chatStream`, `CompletionStream`, and `chatStreamClaude` functions to return `LLMAnswer` for consistency and clarity.
- Simplified error handling and return logic across streaming functions.
swuecho committed Feb 4, 2025
1 parent 9e90417 commit b659f0f
Showing 2 changed files with 40 additions and 27 deletions.
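
To make the intent of the refactor concrete, here is a minimal, self-contained sketch of the call-site pattern before and after this change. The `fakeChatStream` stub and its values are hypothetical stand-ins for `h.chatStream`; only the `LLMAnswer` fields come from the diff below.

```go
package main

import "fmt"

// LLMAnswer mirrors the struct added to api/models/models.go in this commit.
type LLMAnswer struct {
	AnswerId      string `json:"id"`
	Answer        string `json:"answer"`
	ReasonContent string `json:"reason_content"`
	ShouldReturn  bool   `json:"should_return"`
}

// fakeChatStream is a hypothetical stand-in for h.chatStream. The real handler
// streams from an LLM provider and writes HTTP error responses itself; this
// stub only demonstrates the new single-value return shape.
func fakeChatStream(fail bool) LLMAnswer {
	if fail {
		// Error path: an HTTP error response is assumed to have been written already.
		return LLMAnswer{ShouldReturn: true}
	}
	return LLMAnswer{AnswerId: "chatcmpl-123", Answer: "hello"}
}

func main() {
	// Before: answerText, answerID, shouldReturn := h.chatStream(...)
	// After:  one value carries the answer, its id, optional reasoning, and the early-return flag.
	answer := fakeChatStream(false)
	if answer.ShouldReturn {
		return
	}
	fmt.Println(answer.AnswerId, answer.Answer)
}
```
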
api/chat_main_handler.go: 60 changes (33 additions, 27 deletions)
@@ -477,41 +477,42 @@ func getPerWordStreamLimit() int {
return perWordStreamLimit
}

-func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, streamOutput bool) (string, string, bool) {
+func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, streamOutput bool) models.LLMAnswer {
// check per chat_model limit
+shouldReturn := models.LLMAnswer{ShouldReturn: true}

openAIRateLimiter.Wait(context.Background())

exceedPerModeRateLimitOrError := h.CheckModelAccess(w, chatSession.Uuid, chatSession.Model, chatSession.UserID)
if exceedPerModeRateLimitOrError {
return "", "", true
return shouldReturn
}

chatModel, err := h.service.q.ChatModelByName(context.Background(), chatSession.Model)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err)
return "", "", true
return shouldReturn
}

config, err := genOpenAIConfig(chatModel)
log.Printf("%+v", config)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "gen open ai config").Error(), err)
return "", "", true
return shouldReturn
}

client := openai.NewClientWithConfig(config)

chatFiles, err := h.chatfileService.q.ListChatFilesWithContentBySessionUUID(context.Background(), chatSession.Uuid)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "Error getting chat files").Error(), err)
return "", "", true
return shouldReturn
}

openai_req := NewChatCompletionRequest(chatSession, chat_compeletion_messages, chatFiles, streamOutput)
if len(openai_req.Messages) <= 1 {
RespondWithError(w, http.StatusInternalServerError, "error.system_message_notice", err)
return "", "", true
return shouldReturn
}
log.Printf("%+v", openai_req)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
@@ -522,19 +523,19 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
if err != nil {
log.Printf("fail to do request: %+v", err)
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_do_request", err)
return "", "", true
return shouldReturn
}
log.Printf("completion: %+v", completion)
data, _ := json.Marshal(completion)
fmt.Fprint(w, string(data))
-return completion.Choices[0].Message.Content, completion.ID, false
+return models.LLMAnswer{Answer: completion.Choices[0].Message.Content, AnswerId: completion.ID, ShouldReturn: false}
}
stream, err := client.CreateChatCompletionStream(ctx, openai_req)

if err != nil {
log.Printf("fail to do request: %+v", err)
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_do_request", err)
return "", "", true
return shouldReturn
}
defer stream.Close()

@@ -543,7 +544,7 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
flusher, ok := w.(http.Flusher)
if !ok {
RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
return "", "", true
return shouldReturn
}

var answer string
@@ -580,11 +581,11 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
flusher.Flush()
}
// no reason in the answer (so do not disrupt the context)
-return textBuffer.String("\n"), answer_id, false
+return models.LLMAnswer{Answer: textBuffer.String("\n"), AnswerId: answer_id, ShouldReturn: false, ReasonContent: reasonBuffer.String("\n")}
} else {
log.Printf("%v", err)
RespondWithError(w, http.StatusInternalServerError, fmt.Sprintf("Stream error: %v", err), nil)
return "", "", true
return shouldReturn
}
}
response := llm_openai.ChatCompletionStreamResponse{}
@@ -628,26 +629,27 @@ func (h *ChatHandler) chatStream(w http.ResponseWriter, chatSession sqlc_queries
}
}

-func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, streamOutput bool) (string, string, bool) {
+func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, streamOutput bool) models.LLMAnswer {
// check per chat_model limit

+shouldReturn := models.LLMAnswer{ShouldReturn: true}
openAIRateLimiter.Wait(context.Background())

exceedPerModeRateLimitOrError := h.CheckModelAccess(w, chatSession.Uuid, chatSession.Model, chatSession.UserID)
if exceedPerModeRateLimitOrError {
return "", "", true
return shouldReturn
}

chatModel, err := h.service.q.ChatModelByName(context.Background(), chatSession.Model)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err)
return "", "", true
return shouldReturn
}

config, err := genOpenAIConfig(chatModel)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "gen open ai config").Error(), err)
return "", "", true
return shouldReturn
}

client := openai.NewClientWithConfig(config)
@@ -673,7 +675,7 @@ func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_q
stream, err := client.CreateCompletionStream(ctx, req)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_do_request", err)
return "", "", true
return shouldReturn
}
defer stream.Close()

@@ -682,7 +684,7 @@ func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_q
flusher, ok := w.(http.Flusher)
if !ok {
RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
return "", "", true
return shouldReturn
}

var answer string
@@ -715,7 +717,7 @@ func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_q
}
if err != nil {
RespondWithError(w, http.StatusInternalServerError, fmt.Sprintf("Stream error: %v", err), nil)
return "", "", true
return shouldReturn
}
textIdx := response.Choices[0].Index
delta := response.Choices[0].Text
@@ -742,7 +744,7 @@ func (h *ChatHandler) CompletionStream(w http.ResponseWriter, chatSession sqlc_q
}
}
}
-return answer, answer_id, false
+return models.LLMAnswer{AnswerId: answer_id, Answer: answer, ShouldReturn: false}
}

type ClaudeResponse struct {
@@ -755,7 +757,7 @@ type ClaudeResponse struct {
Exception interface{} `json:"exception"`
}

-func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, stream bool) (string, string, bool) {
+func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []models.Message, chatUuid string, regenerate bool, stream bool) models.LLMAnswer {
// Obtain the API token (buffer 1, send to channel will block if there is a token in the buffer)
claudeRateLimiteToken <- struct{}{}
// Release the API token
@@ -764,7 +766,7 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
chatModel, err := h.service.q.ChatModelByName(context.Background(), chatSession.Model)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "get chat model").Error(), err)
return "", "", true
return models.LLMAnswer{ShouldReturn: true}
}

// OPENAI_API_KEY
@@ -791,7 +793,7 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q

if err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_make_request", err)
return "", "", true
return models.LLMAnswer{ShouldReturn: true}
}

// add headers to the request
@@ -815,7 +817,7 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
resp, err := client.Do(req)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_do_request", err)
return "", "", true
return models.LLMAnswer{ShouldReturn: true}
}

ioreader := bufio.NewReader(resp.Body)
@@ -829,7 +831,7 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
flusher, ok := w.(http.Flusher)
if !ok {
RespondWithError(w, http.StatusInternalServerError, "Streaming unsupported!", nil)
return "", "", true
return models.LLMAnswer{ShouldReturn: true}
}

var answer string
@@ -853,7 +855,7 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
fmt.Println("End of stream reached")
break // Exit loop if end of stream
}
return "", "", true
return models.LLMAnswer{ShouldReturn: true}
}
if !bytes.HasPrefix(line, headerData) {
continue
@@ -881,7 +883,7 @@ func (h *ChatHandler) chatStreamClaude(w http.ResponseWriter, chatSession sqlc_q
}
}

-return answer, answer_id, false
+return models.LLMAnswer{
+Answer: answer,
+AnswerId: answer_id,
+ShouldReturn: false,
+}
}

// claude-3-opus-20240229
api/models/models.go: 7 changes (7 additions, 0 deletions)
@@ -39,3 +39,10 @@ func (m *Message) SetTokenCount(tokenCount int32) *Message {
m.tokenCount = tokenCount
return m
}

+type LLMAnswer struct {
+AnswerId string `json:"id"`
+Answer string `json:"answer"`
+ReasonContent string `json:"reason_content"`
+ShouldReturn bool `json:"should_return"`
+}
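
Since the new struct carries JSON tags, here is a small sketch of the wire shape those tags imply. The field values are made up, and the struct is repeated from the diff above only to keep the example self-contained.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Copied from the diff above so the example compiles on its own.
type LLMAnswer struct {
	AnswerId      string `json:"id"`
	Answer        string `json:"answer"`
	ReasonContent string `json:"reason_content"`
	ShouldReturn  bool   `json:"should_return"`
}

func main() {
	a := LLMAnswer{
		AnswerId:      "ans-1", // illustrative value
		Answer:        "Hello from the model.",
		ReasonContent: "optional reasoning text", // illustrative value
	}
	out, _ := json.Marshal(a)
	fmt.Println(string(out))
	// Prints:
	// {"id":"ans-1","answer":"Hello from the model.","reason_content":"optional reasoning text","should_return":false}
}
```
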
