From 6317c0d314741ae70f994438fb0800a0f0adad70 Mon Sep 17 00:00:00 2001 From: winlin Date: Fri, 13 Sep 2024 09:00:47 +0800 Subject: [PATCH] AI: Support OpenAI o1-preview model. v5.15.22 --- .gitignore | 1 + DEVELOPER.md | 1 + platform/ai-talk.go | 210 ++++++++++++++++++++++++++++++-------------- platform/openai.go | 34 +++++++ platform/utils.go | 1 + 5 files changed, 182 insertions(+), 65 deletions(-) create mode 100644 platform/openai.go diff --git a/.gitignore b/.gitignore index d0bbd55e..115237b8 100644 --- a/.gitignore +++ b/.gitignore @@ -140,3 +140,4 @@ __pycache__ /.*.txt /*.txt .tmp +__debug_bin* diff --git a/DEVELOPER.md b/DEVELOPER.md index 3870e133..ba3015fb 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -1285,6 +1285,7 @@ The following are the update records for the Oryx server. * Dubbing: Fix bug when changing ASR segment size. v5.15.20 * Dubbing: Refine the window of text. [v5.15.20](https://github.com/ossrs/oryx/releases/tag/v5.15.20) * Dubbing: Support space key to play/pause. v5.15.21 + * AI: Support OpenAI o1-preview model. v5.15.22 * v5.14: * Merge features and bugfix from releases. v5.14.1 * Dubbing: Support VoD dubbing for multiple languages. [v5.14.2](https://github.com/ossrs/oryx/releases/tag/v5.14.2) diff --git a/platform/ai-talk.go b/platform/ai-talk.go index d75a5a1d..deb62dcb 100644 --- a/platform/ai-talk.go +++ b/platform/ai-talk.go @@ -24,6 +24,7 @@ import ( ohttp "github.com/ossrs/go-oryx-lib/http" "github.com/ossrs/go-oryx-lib/logger" "github.com/sashabaranov/go-openai" + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ "github.com/go-redis/redis/v8" ) @@ -132,8 +133,18 @@ func (v *openaiChatService) RequestChat(ctx context.Context, sreq *StageRequest, system := stage.prompt system += fmt.Sprintf(" Keep your reply neat, limiting the reply to %v words.", stage.replyLimit) - messages := []openai.ChatCompletionMessage{ - {Role: openai.ChatMessageRoleSystem, Content: system}, + messages := []openai.ChatCompletionMessage{} + + // If not support system message, use User message. + model := stage.chatModel + if gptModelSupportSystem(model) { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, Content: system, + }) + } else { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, Content: system, + }) } messages = append(messages, stage.histories...) @@ -142,30 +153,55 @@ func (v *openaiChatService) RequestChat(ctx context.Context, sreq *StageRequest, Content: user.previousAsrText, }) - model := stage.chatModel maxTokens := 1024 temperature := float32(0.9) logger.Tf(ctx, "AIChat is baseURL=%v, org=%v, model=%v, maxTokens=%v, temperature=%v, window=%v, histories=%v, system is %v", v.conf.BaseURL, v.conf.OrgID, model, maxTokens, temperature, stage.chatWindow, len(stage.histories), system) - client := openai.NewClientWithConfig(v.conf) - gptChatStream, err := client.CreateChatCompletionStream( - ctx, openai.ChatCompletionRequest{ - Model: model, - Messages: messages, - Stream: true, - Temperature: temperature, - MaxTokens: maxTokens, - }, - ) - if err != nil { - return errors.Wrapf(err, "create chat") + gptReq := openai.ChatCompletionRequest{ + Model: model, + Messages: messages, + // Some model may not support stream. + Stream: gptModelSupportStream(model), + // Some model may not support MaxTokens. + MaxTokens: gptModelSupportMaxTokens(model, maxTokens), + // Some model may not support temporature. + Temperature: gptModelSupportTemperature(model, temperature), + } + + // For OpenAI chat completion, without stream. + if !gptModelSupportStream(model) { + client := openai.NewClientWithConfig(v.conf) + gptChat, err := client.CreateChatCompletion(ctx, gptReq) + if err != nil { + return errors.Wrapf(err, "create chat") + } + + // For sync request, complete the task when finished. + defer taskCancel() + + if err := v.handleSentence(ctx, + stage, sreq, gptChat.Choices[0].Message.Content, true, nil, + func(sentence string) { + stage.previousAssitant += sentence + " " + }, + ); err != nil { + return errors.Wrapf(err, "handle chat") + } + + return nil } - // Wait for AI got the first sentence response. + // For OpenAI chat stream. Wait for AI got the first sentence response. aiFirstResponseCtx, aiFirstResponseCancel := context.WithCancel(ctx) defer aiFirstResponseCancel() + client := openai.NewClientWithConfig(v.conf) + gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq) + if err != nil { + return errors.Wrapf(err, "create chat") + } + go func() { defer gptChatStream.Close() if err := v.handle(ctx, @@ -210,8 +246,18 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR system := stage.postPrompt system += fmt.Sprintf(" Keep your reply neat, limiting the reply to %v words.", stage.postReplyLimit) - messages := []openai.ChatCompletionMessage{ - {Role: openai.ChatMessageRoleSystem, Content: system}, + messages := []openai.ChatCompletionMessage{} + + // If not support system message, use User message. + model := stage.chatModel + if gptModelSupportSystem(model) { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, Content: system, + }) + } else { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, Content: system, + }) } messages = append(messages, stage.postHistories...) @@ -220,27 +266,49 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR Content: stage.previousAssitant, }) - model := stage.postChatModel maxTokens := 1024 temperature := float32(0.9) logger.Tf(ctx, "AIPostProcess is baseURL=%v, org=%v, model=%v, maxTokens=%v, temperature=%v, window=%v, histories=%v, system is %v", v.conf.BaseURL, v.conf.OrgID, model, maxTokens, temperature, stage.postChatWindow, len(stage.postHistories), system) + gptReq := openai.ChatCompletionRequest{ + Model: model, + Messages: messages, + // Some model may not support stream. + Stream: gptModelSupportStream(model), + // Some model may not support MaxTokens. + MaxTokens: gptModelSupportMaxTokens(model, maxTokens), + // Some model may not support temporature. + Temperature: gptModelSupportTemperature(model, temperature), + } + + // For OpenAI chat completion, without stream. + if !gptModelSupportStream(model) { + client := openai.NewClientWithConfig(v.conf) + gptChat, err := client.CreateChatCompletion(ctx, gptReq) + if err != nil { + return errors.Wrapf(err, "create post-process") + } + + if err := v.handleSentence(ctx, + stage, sreq, gptChat.Choices[0].Message.Content, true, nil, + func(sentence string) { + stage.postPreviousAssitant += sentence + " " + }, + ); err != nil { + return errors.Wrapf(err, "handle post-process") + } + + return nil + } + + // For OpenAI chat stream. Wait for AI got the first sentence response. client := openai.NewClientWithConfig(v.conf) - gptChatStream, err := client.CreateChatCompletionStream( - ctx, openai.ChatCompletionRequest{ - Model: model, - Messages: messages, - Stream: true, - Temperature: temperature, - MaxTokens: maxTokens, - }, - ) + gptChatStream, err := client.CreateChatCompletionStream(ctx, gptReq) if err != nil { return errors.Wrapf(err, "create post-process") } - // Wait for AI got the first sentence response. aiFirstResponseCtx, aiFirstResponseCancel := context.WithCancel(ctx) defer aiFirstResponseCancel() @@ -268,6 +336,48 @@ func (v *openaiChatService) RequestPostProcess(ctx context.Context, sreq *StageR return nil } +func (v *openaiChatService) handleSentence( + ctx context.Context, stage *Stage, sreq *StageRequest, + sentence string, firstSentense bool, + aiFirstResponseCancel context.CancelFunc, onSentence func(string), +) error { + // Use the sentence for prompt and logging. + if onSentence != nil && sentence != "" { + onSentence(sentence) + } + + filteredSentence := sentence + if strings.TrimSpace(sentence) == "" { + return nil + } + + if firstSentense { + if stage.prefix != "" { + filteredSentence = fmt.Sprintf("%v %v", stage.prefix, filteredSentence) + } + if v.onFirstResponse != nil { + v.onFirstResponse(ctx, filteredSentence) + } + } + + segment := NewAnswerSegment(func(segment *AnswerSegment) { + segment.request = sreq + segment.text = filteredSentence + segment.first = firstSentense + }) + stage.ttsWorker.SubmitSegment(ctx, stage, sreq, segment) + + // We have commit the segment to TTS worker, so we can return the response to client and allow + // it to query audio segments immediately. + if firstSentense && aiFirstResponseCancel != nil { + aiFirstResponseCancel() + } + + logger.Tf(ctx, "TTS: Commit segment rid=%v, asid=%v, first=%v, sentence is %v", + sreq.rid, segment.asid, firstSentense, filteredSentence) + return nil +} + func (v *openaiChatService) handle( ctx context.Context, stage *Stage, user *StageUser, sreq *StageRequest, gptChatStream *openai.ChatCompletionStream, aiFirstResponseCancel context.CancelFunc, @@ -364,37 +474,8 @@ func (v *openaiChatService) handle( return newSentence } - commitAISentence := func(sentence string, firstSentense bool) { - filteredSentence := sentence - if strings.TrimSpace(sentence) == "" { - return - } - - if firstSentense { - if stage.prefix != "" { - filteredSentence = fmt.Sprintf("%v %v", stage.prefix, filteredSentence) - } - if v.onFirstResponse != nil { - v.onFirstResponse(ctx, filteredSentence) - } - } - - segment := NewAnswerSegment(func(segment *AnswerSegment) { - segment.request = sreq - segment.text = filteredSentence - segment.first = firstSentense - }) - stage.ttsWorker.SubmitSegment(ctx, stage, sreq, segment) - - // We have commit the segment to TTS worker, so we can return the response to client and allow - // it to query audio segments immediately. - if firstSentense { - aiFirstResponseCancel() - } - - logger.Tf(ctx, "TTS: Commit segment rid=%v, asid=%v, first=%v, sentence is %v", - sreq.rid, segment.asid, firstSentense, filteredSentence) - return + commitAISentence := func(sentence string, firstSentense bool) error { + return v.handleSentence(ctx, stage, sreq, sentence, firstSentense, aiFirstResponseCancel, onSentence) } var sentence, lastWords string @@ -413,12 +494,11 @@ func (v *openaiChatService) handle( continue } - // Use the sentence for prompt and logging. - if onSentence != nil && sentence != "" { - onSentence(sentence) - } // Commit the sentense to TTS worker and callbacks. - commitAISentence(sentence, firstSentense) + if err = commitAISentence(sentence, firstSentense); err != nil { + return errors.Wrapf(err, "commit") + } + // Reset the sentence, because we have committed it. sentence, firstSentense = "", false } diff --git a/platform/openai.go b/platform/openai.go new file mode 100644 index 00000000..2345d7a2 --- /dev/null +++ b/platform/openai.go @@ -0,0 +1,34 @@ +// Copyright (c) 2022-2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import "strings" + +func gptModelSupportSystem(model string) bool { + if strings.HasPrefix(model, "o1-") { + return false + } + return true +} + +func gptModelSupportStream(model string) bool { + if strings.HasPrefix(model, "o1-") { + return false + } + return true +} + +func gptModelSupportMaxTokens(model string, maxTokens int) int { + if strings.HasPrefix(model, "o1-") { + return 0 + } + return maxTokens +} + +func gptModelSupportTemperature(model string, temperature float32) float32 { + if strings.HasPrefix(model, "o1-") { + return 0.0 + } + return temperature +} diff --git a/platform/utils.go b/platform/utils.go index 5a116582..c8283a7f 100644 --- a/platform/utils.go +++ b/platform/utils.go @@ -28,6 +28,7 @@ import ( "github.com/ossrs/go-oryx-lib/errors" "github.com/ossrs/go-oryx-lib/logger" + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ "github.com/go-redis/redis/v8" "github.com/golang-jwt/jwt/v4"