AI: Support OpenAI o1-preview model. v5.15.22
winlinvip committed Sep 13, 2024
1 parent 6ec702b commit 6317c0d
Showing 5 changed files with 182 additions and 65 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -140,3 +140,4 @@ __pycache__
/.*.txt
/*.txt
.tmp
__debug_bin*
1 change: 1 addition & 0 deletions DEVELOPER.md
@@ -1285,6 +1285,7 @@ The following are the update records for the Oryx server.
* Dubbing: Fix bug when changing ASR segment size. v5.15.20
* Dubbing: Refine the window of text. [v5.15.20](https://github.com/ossrs/oryx/releases/tag/v5.15.20)
* Dubbing: Support space key to play/pause. v5.15.21
* AI: Support OpenAI o1-preview model. v5.15.22
* v5.14:
* Merge features and bugfix from releases. v5.14.1
* Dubbing: Support VoD dubbing for multiple languages. [v5.14.2](https://github.com/ossrs/oryx/releases/tag/v5.14.2)
210 changes: 145 additions & 65 deletions platform/ai-talk.go
@@ -24,6 +24,7 @@ import (
ohttp "github.com/ossrs/go-oryx-lib/http"
"github.com/ossrs/go-oryx-lib/logger"
"github.com/sashabaranov/go-openai"

// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
)
@@ -132,8 +133,18 @@ func (v *openaiChatService) RequestChat(ctx context.Context, sreq *StageRequest,

system := stage.prompt
system += fmt.Sprintf(" Keep your reply neat, limiting the reply to %v words.", stage.replyLimit)
messages := []openai.ChatCompletionMessage{
{Role: openai.ChatMessageRoleSystem, Content: system},
messages := []openai.ChatCompletionMessage{}

// If the model does not support system messages, use a user message instead.
model := stage.chatModel
if gptModelSupportSystem(model) {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem, Content: system,
})
} else {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser, Content: system,
})
}

messages = append(messages, stage.histories...)
@@ -142,30 +142,55 @@
Content: user.previousAsrText,
})

model := stage.chatModel
maxTokens := 1024
temperature := float32(0.9)
logger.Tf(ctx, "AIChat is baseURL=%v, org=%v, model=%v, maxTokens=%v, temperature=%v, window=%v, histories=%v, system is %v",
v.conf.BaseURL, v.conf.OrgID, model, maxTokens, temperature, stage.chatWindow, len(stage.histories), system)

client := openai.NewClientWithConfig(v.conf)
gptChatStream, err := client.CreateChatCompletionStream(
ctx, openai.ChatCompletionRequest{
Model: model,
Messages: messages,
Stream: true,
Temperature: temperature,
MaxTokens: maxTokens,
},
)
if err != nil {
return errors.Wrapf(err, "create chat")
gptReq := openai.ChatCompletionRequest{
Model: model,
Messages: messages,
// Some models may not support streaming.
Stream: gptModelSupportStream(model),
// Some models may not support MaxTokens.
MaxTokens: gptModelSupportMaxTokens(model, maxTokens),
// Some models may not support temperature.
Temperature: gptModelSupportTemperature(model, temperature),
}

// For OpenAI chat completion, without streaming.
if !gptModelSupportStream(model) {
client := openai.NewClientWithConfig(v.conf)
gptChat, err := client.CreateChatCompletion(ctx, gptReq)
if err != nil {
return errors.Wrapf(err, "create chat")
}

// For sync request, complete the task when finished.
defer taskCancel()

if err := v.handleSentence(ctx,
stage, sreq, gptChat.Choices[0].Message.Content, true, nil,
func(sentence string) {
stage.previousAssitant += sentence + " "
},
); err != nil {
return errors.Wrapf(err, "handle chat")
}

return nil
}

// Wait until the AI produces the first sentence of the response.
// For OpenAI chat streaming. Wait until the AI produces the first sentence of the response.
aiFirstResponseCtx, aiFirstResponseCancel := context.WithCancel(ctx)
defer aiFirstResponseCancel()

client := openai.NewClientWithConfig(v.conf)
gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq)
if err != nil {
return errors.Wrapf(err, "create chat")
}

go func() {
defer gptChatStream.Close()
if err := v.handle(ctx,
@@ -210,8 +246,18 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR

system := stage.postPrompt
system += fmt.Sprintf(" Keep your reply neat, limiting the reply to %v words.", stage.postReplyLimit)
messages := []openai.ChatCompletionMessage{
{Role: openai.ChatMessageRoleSystem, Content: system},
messages := []openai.ChatCompletionMessage{}

// If the model does not support system messages, use a user message instead.
model := stage.postChatModel
if gptModelSupportSystem(model) {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem, Content: system,
})
} else {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser, Content: system,
})
}

messages = append(messages, stage.postHistories...)
@@ -220,27 +266,49 @@
Content: stage.previousAssitant,
})

model := stage.postChatModel
maxTokens := 1024
temperature := float32(0.9)
logger.Tf(ctx, "AIPostProcess is baseURL=%v, org=%v, model=%v, maxTokens=%v, temperature=%v, window=%v, histories=%v, system is %v",
v.conf.BaseURL, v.conf.OrgID, model, maxTokens, temperature, stage.postChatWindow, len(stage.postHistories), system)

gptReq := openai.ChatCompletionRequest{
Model: model,
Messages: messages,
// Some models may not support streaming.
Stream: gptModelSupportStream(model),
// Some models may not support MaxTokens.
MaxTokens: gptModelSupportMaxTokens(model, maxTokens),
// Some models may not support temperature.
Temperature: gptModelSupportTemperature(model, temperature),
}

// For OpenAI chat completion, without streaming.
if !gptModelSupportStream(model) {
client := openai.NewClientWithConfig(v.conf)
gptChat, err := client.CreateChatCompletion(ctx, gptReq)
if err != nil {
return errors.Wrapf(err, "create post-process")
}

if err := v.handleSentence(ctx,
stage, sreq, gptChat.Choices[0].Message.Content, true, nil,
func(sentence string) {
stage.postPreviousAssitant += sentence + " "
},
); err != nil {
return errors.Wrapf(err, "handle post-process")
}

return nil
}

// For OpenAI chat streaming. Wait until the AI produces the first sentence of the response.
client := openai.NewClientWithConfig(v.conf)
gptChatStream, err := client.CreateChatCompletionStream(
ctx, openai.ChatCompletionRequest{
Model: model,
Messages: messages,
Stream: true,
Temperature: temperature,
MaxTokens: maxTokens,
},
)
gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq)
if err != nil {
return errors.Wrapf(err, "create post-process")
}

// Wait until the AI produces the first sentence of the response.
aiFirstResponseCtx, aiFirstResponseCancel := context.WithCancel(ctx)
defer aiFirstResponseCancel()

@@ -268,6 +336,48 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR
return nil
}

func (v *openaiChatService) handleSentence(
ctx context.Context, stage *Stage, sreq *StageRequest,
sentence string, firstSentense bool,
aiFirstResponseCancel context.CancelFunc, onSentence func(string),
) error {
// Use the sentence for prompt and logging.
if onSentence != nil && sentence != "" {
onSentence(sentence)
}

filteredSentence := sentence
if strings.TrimSpace(sentence) == "" {
return nil
}

if firstSentense {
if stage.prefix != "" {
filteredSentence = fmt.Sprintf("%v %v", stage.prefix, filteredSentence)
}
if v.onFirstResponse != nil {
v.onFirstResponse(ctx, filteredSentence)
}
}

segment := NewAnswerSegment(func(segment *AnswerSegment) {
segment.request = sreq
segment.text = filteredSentence
segment.first = firstSentense
})
stage.ttsWorker.SubmitSegment(ctx, stage, sreq, segment)

// We have committed the segment to the TTS worker, so we can return the response to the client
// and allow it to query audio segments immediately.
if firstSentense && aiFirstResponseCancel != nil {
aiFirstResponseCancel()
}

logger.Tf(ctx, "TTS: Commit segment rid=%v, asid=%v, first=%v, sentence is %v",
sreq.rid, segment.asid, firstSentense, filteredSentence)
return nil
}

func (v *openaiChatService) handle(
ctx context.Context, stage *Stage, user *StageUser, sreq *StageRequest,
gptChatStream *openai.ChatCompletionStream, aiFirstResponseCancel context.CancelFunc,
@@ -364,37 +474,8 @@ func (v *openaiChatService) handle(
return newSentence
}

commitAISentence := func(sentence string, firstSentense bool) {
filteredSentence := sentence
if strings.TrimSpace(sentence) == "" {
return
}

if firstSentense {
if stage.prefix != "" {
filteredSentence = fmt.Sprintf("%v %v", stage.prefix, filteredSentence)
}
if v.onFirstResponse != nil {
v.onFirstResponse(ctx, filteredSentence)
}
}

segment := NewAnswerSegment(func(segment *AnswerSegment) {
segment.request = sreq
segment.text = filteredSentence
segment.first = firstSentense
})
stage.ttsWorker.SubmitSegment(ctx, stage, sreq, segment)

// We have committed the segment to the TTS worker, so we can return the response to the client
// and allow it to query audio segments immediately.
if firstSentense {
aiFirstResponseCancel()
}

logger.Tf(ctx, "TTS: Commit segment rid=%v, asid=%v, first=%v, sentence is %v",
sreq.rid, segment.asid, firstSentense, filteredSentence)
return
commitAISentence := func(sentence string, firstSentense bool) error {
return v.handleSentence(ctx, stage, sreq, sentence, firstSentense, aiFirstResponseCancel, onSentence)
}

var sentence, lastWords string
Expand All @@ -413,12 +494,11 @@ func (v *openaiChatService) handle(
continue
}

// Use the sentence for prompt and logging.
if onSentence != nil && sentence != "" {
onSentence(sentence)
}
// Commit the sentence to the TTS worker and callbacks.
commitAISentence(sentence, firstSentense)
if err = commitAISentence(sentence, firstSentense); err != nil {
return errors.Wrapf(err, "commit")
}

// Reset the sentence, because we have committed it.
sentence, firstSentense = "", false
}
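
In summary, the new request path in RequestChat reads as the condensed sketch below. This is simplified from the diff above, not a drop-in replacement: logging, task cancellation, and the first-response wiring are elided, and the nil arguments stand in for the cancel function and sentence callback the real code passes. RequestPostProcess follows the same pattern with its post-process model and histories.

gptReq := openai.ChatCompletionRequest{
	Model:       model,
	Messages:    messages, // for o1-*, the system prompt is already appended as a user message
	Stream:      gptModelSupportStream(model),           // false for o1-*
	MaxTokens:   gptModelSupportMaxTokens(model, 1024),  // 0 for o1-*
	Temperature: gptModelSupportTemperature(model, 0.9), // 0 for o1-*
}

client := openai.NewClientWithConfig(v.conf)
if !gptModelSupportStream(model) {
	// Synchronous path: the whole reply arrives at once and is committed
	// to the TTS worker as a single, first sentence.
	gptChat, err := client.CreateChatCompletion(ctx, gptReq)
	if err != nil {
		return errors.Wrapf(err, "create chat")
	}
	return v.handleSentence(ctx, stage, sreq,
		gptChat.Choices[0].Message.Content, true, nil, nil)
}

// Streaming path: split the response into sentences and commit each to the
// TTS worker as chunks arrive, so playback can start before the reply ends.
gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq)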
34 changes: 34 additions & 0 deletions platform/openai.go
@@ -0,0 +1,34 @@
// Copyright (c) 2022-2024 Winlin
//
// SPDX-License-Identifier: MIT
package main

import "strings"

func gptModelSupportSystem(model string) bool {
if strings.HasPrefix(model, "o1-") {
return false
}
return true
}

func gptModelSupportStream(model string) bool {
if strings.HasPrefix(model, "o1-") {
return false
}
return true
}

func gptModelSupportMaxTokens(model string, maxTokens int) int {
if strings.HasPrefix(model, "o1-") {
return 0
}
return maxTokens
}

func gptModelSupportTemperature(model string, temperature float32) float32 {
if strings.HasPrefix(model, "o1-") {
return 0.0
}
return temperature
}
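
A note on the zero values above: the go-openai request fields carry omitempty JSON tags, so returning 0 for MaxTokens or Temperature drops the parameter from the request payload entirely rather than sending an explicit 0; this matters because o1-preview initially rejected max_tokens and non-default temperature values. A hypothetical sanity check of these helpers (not part of the commit; assumes fmt is imported in this package):

// Hypothetical check for the helpers above (same main package).
for _, model := range []string{"o1-preview", "o1-mini", "gpt-4o"} {
	fmt.Printf("model=%v system=%v stream=%v maxTokens=%v temperature=%v\n",
		model, gptModelSupportSystem(model), gptModelSupportStream(model),
		gptModelSupportMaxTokens(model, 1024), gptModelSupportTemperature(model, 0.9))
}
// o1-preview and o1-mini print system=false stream=false maxTokens=0 temperature=0;
// gpt-4o keeps the caller's 1024 and 0.9.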
1 change: 1 addition & 0 deletions platform/utils.go
@@ -28,6 +28,7 @@ import (

"github.com/ossrs/go-oryx-lib/errors"
"github.com/ossrs/go-oryx-lib/logger"

// Use v8 because we use Go 1.16+, while v9 requires Go 1.18+
"github.com/go-redis/redis/v8"
"github.com/golang-jwt/jwt/v4"
