From 613d5f6d73656eac59c61ae18a2733e689f9c84d Mon Sep 17 00:00:00 2001
From: Rodion Mostovoi
Date: Sun, 17 Dec 2023 15:22:23 +0800
Subject: [PATCH] Update default max tokens and model selection in AI chat methods

Removed the default max-tokens value and the hard-coded model selection
from the AI chat code. The AI service can now determine the optimal model
and the number of tokens based on the dialog input, which makes the
responses more flexible and adaptable. The default max tokens value is now
nullable, offering further flexibility. The version has been bumped from
4.0.0-alpha to 4.1.0-alpha.
---
 src/Directory.Build.props                      |  2 +-
 .../AiClientFromConfiguration.cs               | 21 +++++++++----
 src/OpenAI.ChatGpt/ChatService.cs              | 16 +++-------
 src/OpenAI.ChatGpt/IAiClient.cs                | 24 +++++++++++----
 .../ChatCompletion/ChatCompletionRequest.cs    | 14 +--------
 .../Messaging/ChatCompletionMessage.cs         | 30 -------------------
 src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs     |  1 -
 src/OpenAI.ChatGpt/OpenAiClient.cs             | 30 +++++++++++++------
 .../GeneratedClientsFactory.cs                 |  3 +-
 ...iClientExtensions.GetStructuredResponse.cs  | 10 ++-----
 .../ChatGPTTranslatorService.cs                | 15 ++++------
 .../ClientTests/AzureOpenAiClientTests.cs      | 13 ++++++--
 12 files changed, 82 insertions(+), 97 deletions(-)

diff --git a/src/Directory.Build.props b/src/Directory.Build.props
index 9728ba1..fa9f1ad 100644
--- a/src/Directory.Build.props
+++ b/src/Directory.Build.props
@@ -1,6 +1,6 @@
-    4.0.0-alpha
+    4.1.0-alpha
     net6.0;net7.0;net8.0
     enable
     enable
diff --git a/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs b/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs
index 25d58b3..9f64fd3 100644
--- a/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs
+++ b/src/OpenAI.ChatGpt.AspNetCore/AiClientFromConfiguration.cs
@@ -47,7 +47,7 @@ private static void ThrowUnkownProviderException(string provider)
     ///
     public Task GetChatCompletions(
         UserOrSystemMessage dialog,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null, bool jsonMode = false, long? seed = null,
         Action? requestModifier = null,
@@ -60,7 +60,7 @@ public Task GetChatCompletions(
     ///
     public Task GetChatCompletions(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null, bool jsonMode = false, long? seed = null,
         Action? requestModifier = null,
@@ -73,7 +73,7 @@ public Task GetChatCompletions(
     ///
     public Task GetChatCompletionsRaw(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null, bool jsonMode = false, long? seed = null,
         Action? requestModifier = null,
@@ -86,7 +86,7 @@ public Task GetChatCompletionsRaw(
     ///
     public IAsyncEnumerable StreamChatCompletions(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null, bool jsonMode = false, long? seed = null,
         Action? requestModifier = null,
@@ -99,7 +99,7 @@ public IAsyncEnumerable StreamChatCompletions(
     ///
     public IAsyncEnumerable StreamChatCompletions(
         UserOrSystemMessage messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault, string model = ChatCompletionModels.Default,
+        int? maxTokens = null, string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null, bool jsonMode = false, long? seed = null,
         Action? requestModifier = null,
         CancellationToken cancellationToken = default)
@@ -108,6 +108,12 @@ public IAsyncEnumerable StreamChatCompletions(
             messages, maxTokens, model, temperature, user, jsonMode, seed, requestModifier, cancellationToken);
     }
 
+    ///
+    public int? GetDefaultMaxTokens(string model)
+    {
+        return _client.GetDefaultMaxTokens(model);
+    }
+
     ///
     public IAsyncEnumerable StreamChatCompletions(
         ChatCompletionRequest request,
@@ -124,5 +130,10 @@ public IAsyncEnumerable StreamChatCompletionsRaw(
         return _client.StreamChatCompletionsRaw(request, cancellationToken);
     }
 
+    public string GetOptimalModel(ChatCompletionMessage[] messages)
+    {
+        return _client.GetOptimalModel(messages);
+    }
+
     internal IAiClient GetInnerClient() => _client;
 }
\ No newline at end of file
diff --git a/src/OpenAI.ChatGpt/ChatService.cs b/src/OpenAI.ChatGpt/ChatService.cs
index cd6c6d6..9f0d0ae 100644
--- a/src/OpenAI.ChatGpt/ChatService.cs
+++ b/src/OpenAI.ChatGpt/ChatService.cs
@@ -102,11 +102,10 @@ private async Task GetNextMessageResponse(
         IsWriting = true;
         try
         {
-            var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
             var response = await _client.GetChatCompletionsRaw(
                 messages,
-                maxTokens: maxTokens,
-                model: model,
+                maxTokens: Topic.Config.MaxTokens,
+                model: Topic.Config.Model ?? _client.GetOptimalModel(message),
                 user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
                 requestModifier: Topic.Config.ModifyRequest,
                 cancellationToken: cancellationToken
@@ -125,12 +124,6 @@ await _chatHistoryStorage.SaveMessages(
         }
     }
 
-    private (string model, int maxTokens) FindOptimalModelAndMaxToken(ChatCompletionMessage[] messages)
-    {
-        return ChatCompletionMessage.FindOptimalModelAndMaxToken(
-            messages, Topic.Config.Model, Topic.Config.MaxTokens);
-    }
-
     public IAsyncEnumerable StreamNextMessageResponse(
         string message,
         bool throwOnCancellation = true,
@@ -159,11 +152,10 @@ private async IAsyncEnumerable StreamNextMessageResponse(
         var messages = history.Append(message).ToArray();
         var sb = new StringBuilder();
         IsWriting = true;
-        var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
         var stream = _client.StreamChatCompletions(
             messages,
-            maxTokens: maxTokens,
-            model: model,
+            maxTokens: Topic.Config.MaxTokens,
+            model: Topic.Config.Model ?? _client.GetOptimalModel(message),
             user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
             requestModifier: Topic.Config.ModifyRequest,
             cancellationToken: cancellationToken
diff --git a/src/OpenAI.ChatGpt/IAiClient.cs b/src/OpenAI.ChatGpt/IAiClient.cs
index 5b22fab..a1ceaf4 100644
--- a/src/OpenAI.ChatGpt/IAiClient.cs
+++ b/src/OpenAI.ChatGpt/IAiClient.cs
@@ -8,6 +8,17 @@ namespace OpenAI.ChatGpt;
 ///
 public interface IAiClient
 {
+    ///
+    /// Retrieves the default maximum number of tokens for a given model.
+    ///
+    ///
+    /// The model name for which to retrieve the maximum number of tokens.
+    ///
+    ///
+    /// The default maximum number of tokens, or null when the decision is delegated to the AI service.
+    ///
+    int? GetDefaultMaxTokens(string model);
+
     ///
     /// Get a chat completion response as a string
     ///
@@ -41,7 +52,7 @@ public interface IAiClient
     /// The chat completion response as a string
     Task GetChatCompletions(
         UserOrSystemMessage dialog,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -84,7 +95,7 @@ Task GetChatCompletions(
     /// The chat completion response as a string
     Task GetChatCompletions(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -126,7 +137,7 @@ Task GetChatCompletions(
     /// The raw chat completion response
     Task GetChatCompletionsRaw(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -167,7 +178,7 @@ Task GetChatCompletionsRaw(
     /// Chunks of LLM's response, one by one.
     IAsyncEnumerable StreamChatCompletions(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -201,7 +212,7 @@ IAsyncEnumerable StreamChatCompletions(
     /// Chunks of LLM's response, one by one
     IAsyncEnumerable StreamChatCompletions(
         UserOrSystemMessage messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -227,4 +238,7 @@ IAsyncEnumerable StreamChatCompletions(
     /// A stream of raw chat completion responses
     IAsyncEnumerable StreamChatCompletionsRaw(
         ChatCompletionRequest request, CancellationToken cancellationToken = default);
+
+    string GetOptimalModel(ChatCompletionMessage[] messages);
+    string GetOptimalModel(UserOrSystemMessage dialog) => GetOptimalModel(dialog.GetMessages().ToArray());
 }
\ No newline at end of file
diff --git a/src/OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionRequest.cs b/src/OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionRequest.cs
index d73bd67..5a94414 100644
--- a/src/OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionRequest.cs
+++ b/src/OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionRequest.cs
@@ -14,9 +14,6 @@ namespace OpenAI.ChatGpt.Models.ChatCompletion;
 ///
 public class ChatCompletionRequest
 {
-    public const int MaxTokensDefault = 64;
-
-    private int _maxTokens = MaxTokensDefault;
     private string _model = ChatCompletionModels.Default;
     private float _temperature = ChatCompletionTemperatures.Default;
     private IEnumerable _messages;
@@ -87,7 +84,6 @@ public float Temperature
     ///
     /// The maximum number of tokens allowed for the generated answer.
-    /// Defaults to .
     /// This value is validated and limited with method.
     /// It's possible to calculate approximately tokens count using method.
     ///
@@ -98,15 +94,7 @@ public float Temperature
     /// Encoding algorithm can be found here: https://github.com/latitudegames/GPT-3-Encoder
     ///
     [JsonPropertyName("max_tokens")]
-    public int MaxTokens
-    {
-        get => _maxTokens;
-        set
-        {
-            ChatCompletionModels.EnsureMaxTokensIsSupported(Model, value);
-            _maxTokens = value;
-        }
-    }
+    public int? MaxTokens { get; set; } = null;
 
     ///
     /// Number between -2.0 and 2.0.
diff --git a/src/OpenAI.ChatGpt/Models/ChatCompletion/Messaging/ChatCompletionMessage.cs b/src/OpenAI.ChatGpt/Models/ChatCompletion/Messaging/ChatCompletionMessage.cs
index d2af0ca..63c6e1c 100644
--- a/src/OpenAI.ChatGpt/Models/ChatCompletion/Messaging/ChatCompletionMessage.cs
+++ b/src/OpenAI.ChatGpt/Models/ChatCompletion/Messaging/ChatCompletionMessage.cs
@@ -113,34 +113,4 @@ public override string ToString()
             ? $"{Role}: {Content}"
             : string.Join(Environment.NewLine, _messages.Select(m => $"{m.Role}: {m.Content}"));
     }
-
-    public static (string model, int maxTokens) FindOptimalModelAndMaxToken(
-        IEnumerable messages,
-        string? model,
-        int? maxTokens,
-        string smallModel = ChatCompletionModels.Default,
-        string bigModel = ChatCompletionModels.Gpt3_5_Turbo_16k,
-        bool useMaxPossibleTokens = true)
-    {
-        var tokenCount = CalculateApproxTotalTokenCount(messages);
-        switch (model, maxTokens)
-        {
-            case (null, null):
-            {
-                model = tokenCount > 6000 ? bigModel : smallModel;
-                maxTokens = GetMaxPossibleTokens(model);
-                break;
-            }
-            case (null, _):
-                model = smallModel;
-                break;
-            case (_, null):
-                maxTokens = useMaxPossibleTokens ? GetMaxPossibleTokens(model) : ChatCompletionRequest.MaxTokensDefault;
-                break;
-        }
-
-        return (model, maxTokens.Value);
-
-        int GetMaxPossibleTokens(string s) => ChatCompletionModels.GetMaxTokensLimitForModel(s) - tokenCount - 500;
-    }
 }
\ No newline at end of file
diff --git a/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs b/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs
index b9a7bdb..e43ccc9 100644
--- a/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs
+++ b/src/OpenAI.ChatGpt/Models/ChatGPTConfig.cs
@@ -56,7 +56,6 @@ public class ChatGPTConfig
     ///
     /// The maximum number of tokens allowed for the generated answer.
-    /// Defaults to .
     /// This value is validated and limited with method.
     /// It's possible to calculate approximately tokens count using method.
     /// Maps to:
diff --git a/src/OpenAI.ChatGpt/OpenAiClient.cs b/src/OpenAI.ChatGpt/OpenAiClient.cs
index e98513d..e592825 100644
--- a/src/OpenAI.ChatGpt/OpenAiClient.cs
+++ b/src/OpenAI.ChatGpt/OpenAiClient.cs
@@ -152,7 +152,7 @@ private static void ValidateHttpClient(
     ///
     public async Task GetChatCompletions(
         UserOrSystemMessage dialog,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -185,7 +185,7 @@ public async Task GetChatCompletions(
     ///
     public async Task GetChatCompletions(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -218,7 +218,7 @@ public async Task GetChatCompletions(
     ///
     public async Task GetChatCompletionsRaw(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -277,7 +277,7 @@ protected virtual string GetChatCompletionsEndpoint()
     ///
     public IAsyncEnumerable StreamChatCompletions(
         IEnumerable messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -304,9 +304,9 @@ public IAsyncEnumerable StreamChatCompletions(
         return StreamChatCompletions(request, cancellationToken);
     }
 
-    private static ChatCompletionRequest CreateChatCompletionRequest(
+    private ChatCompletionRequest CreateChatCompletionRequest(
         IEnumerable messages,
-        int maxTokens,
+        int? maxTokens,
         string model,
         float temperature,
         string? user,
@@ -316,6 +316,7 @@ private static ChatCompletionRequest CreateChatCompletionRequest(
         Action? requestModifier)
     {
         ArgumentNullException.ThrowIfNull(messages);
+        maxTokens ??= GetDefaultMaxTokens(model);
         var request = new ChatCompletionRequest(messages)
         {
             Model = model,
@@ -330,10 +331,15 @@ private static ChatCompletionRequest CreateChatCompletionRequest(
         return request;
     }
 
+    public int? GetDefaultMaxTokens(string model)
+    {
+        return null;
+    }
+
     ///
     public IAsyncEnumerable StreamChatCompletions(
         UserOrSystemMessage messages,
-        int maxTokens = ChatCompletionRequest.MaxTokensDefault,
+        int? maxTokens = null,
         string model = ChatCompletionModels.Default,
         float temperature = ChatCompletionTemperatures.Default,
         string? user = null,
@@ -346,7 +352,8 @@ public IAsyncEnumerable StreamChatCompletions(
         if (model == null) throw new ArgumentNullException(nameof(model));
         EnsureJsonModeIsSupported(model, jsonMode);
         ThrowIfDisposed();
-        var request = CreateChatCompletionRequest(messages.GetMessages(),
+        var request = CreateChatCompletionRequest(
+            messages.GetMessages(),
             maxTokens,
             model,
             temperature,
@@ -393,7 +400,12 @@ public IAsyncEnumerable StreamChatCompletionsRaw(
             cancellationToken
         );
     }
-
+
+    public string GetOptimalModel(ChatCompletionMessage[] messages)
+    {
+        return ChatCompletionModels.Gpt4Turbo;
+    }
+
     private static void EnsureJsonModeIsSupported(string model, bool jsonMode)
     {
         if(jsonMode && !ChatCompletionModels.IsJsonModeSupported(model))
diff --git a/src/internal/OpenAI.GeneratedKiotaClient/GeneratedClientsFactory.cs b/src/internal/OpenAI.GeneratedKiotaClient/GeneratedClientsFactory.cs
index 36dc1d7..9e27fca 100644
--- a/src/internal/OpenAI.GeneratedKiotaClient/GeneratedClientsFactory.cs
+++ b/src/internal/OpenAI.GeneratedKiotaClient/GeneratedClientsFactory.cs
@@ -11,7 +11,8 @@ public static GeneratedOpenAiClient CreateGeneratedOpenAiClient(HttpClient httpC
         ArgumentNullException.ThrowIfNull(httpClient);
         var authProvider = new AnonymousAuthenticationProvider();
         var adapter = new HttpClientRequestAdapter(authProvider, httpClient: httpClient);
-        return new GeneratedOpenAiClient(adapter);
+        var openAiClient = new GeneratedOpenAiClient(adapter);
+        return openAiClient;
     }
 
     public static GeneratedAzureOpenAiClient CreateGeneratedAzureOpenAiClient(HttpClient httpClient)
diff --git a/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs b/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs
index 2748624..8097b4b 100644
--- a/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs
+++ b/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs
@@ -114,17 +114,11 @@ internal static async Task GetStructuredResponse(
     {
         editMsg.Content += GetAdditionalJsonResponsePrompt(responseFormat, examples, jsonSerializerOptions);
 
-        (model, maxTokens) = FindOptimalModelAndMaxToken(
-            dialog.GetMessages(),
-            model,
-            maxTokens,
-            smallModel: ChatCompletionModels.Gpt4,
-            bigModel: ChatCompletionModels.Gpt4
-        );
+        model ??= client.GetOptimalModel(dialog);
 
         var response = await client.GetChatCompletions(
             dialog,
-            maxTokens.Value,
+            maxTokens,
             model,
             temperature,
             user,
diff --git a/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs b/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs
index 6f67934..a28f8c1 100644
--- a/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs
+++ b/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs
@@ -78,10 +78,10 @@ public async Task TranslateText(
         var prompt = CreateTextTranslationPrompt(sourceLanguageOrDefault, targetLanguageOrDefault);
         var messages = Dialog.StartAsSystem(prompt).ThenUser(text).GetMessages().ToArray();
-        (model, maxTokens) = ChatCompletionMessage.FindOptimalModelAndMaxToken(messages, model, maxTokens);
+        model ??= _client.GetOptimalModel(messages);
         var response = await _client.GetChatCompletions(
             messages,
-            maxTokens.Value,
+            maxTokens,
             model,
             temperature,
             user,
@@ -140,16 +140,11 @@ public virtual async Task TranslateObject(
         var objectJson = JsonSerializer.Serialize(objectToTranslate, jsonSerializerOptions);
         var dialog = Dialog.StartAsSystem(prompt).ThenUser(objectJson);
         var messages = dialog.GetMessages().ToArray();
-        (model, maxTokens) = ChatCompletionMessage.FindOptimalModelAndMaxToken(
-            messages,
-            model,
-            maxTokens,
-            smallModel: ChatCompletionModels.Gpt4,
-            bigModel: ChatCompletionModels.Gpt4
-        );
+        model ??= _client.GetOptimalModel(messages);
+
         var response = await _client.GetStructuredResponse(
             dialog,
-            maxTokens.Value,
+            maxTokens,
             model,
             temperature,
             user,
diff --git a/tests/OpenAI.ChatGpt.IntegrationTests/ClientTests/AzureOpenAiClientTests.cs b/tests/OpenAI.ChatGpt.IntegrationTests/ClientTests/AzureOpenAiClientTests.cs
index 6d53b50..3d8e187 100644
--- a/tests/OpenAI.ChatGpt.IntegrationTests/ClientTests/AzureOpenAiClientTests.cs
+++ b/tests/OpenAI.ChatGpt.IntegrationTests/ClientTests/AzureOpenAiClientTests.cs
@@ -14,13 +14,22 @@ public AzureOpenAiClientTests(ITestOutputHelper outputHelper, AzureOpenAiClientF
     }
 
     [Fact]
-    public async void Get_response_from_gpt4_32k_model_for_one_message_works()
+    public async void Get_response_from_GPT4_32k_model_for_one_message_works()
     {
         string text = "Who are you? In two words.";
 #pragma warning disable CS0618 // Type or member is obsolete
-        string response = await _client.GetChatCompletions(new UserMessage(text), 64, model: ChatCompletionModels.Gpt4_32k);
+        string response = await _client.GetChatCompletions(new UserMessage(text), model: ChatCompletionModels.Gpt4_32k);
 #pragma warning restore CS0618 // Type or member is obsolete
         _outputHelper.WriteLine(response);
         response.Should().NotBeNullOrEmpty();
     }
+
+    [Fact]
+    public async void Get_long_response_from_gpt4_Turbo_model()
+    {
+        string text = "Describe who you are in a very detailed way. At least 300 words.";
+        string response = await _client.GetChatCompletions(new UserMessage(text), model: ChatCompletionModels.Gpt4Turbo);
+        _outputHelper.WriteLine(response);
+        response.Should().NotBeNullOrEmpty();
+    }
 }
\ No newline at end of file
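
For reference, a minimal caller-facing sketch of how the updated surface is expected to be used after this change: maxTokens can simply be omitted (it now defaults to null and is resolved through GetDefaultMaxTokens), and a model can be picked through the new GetOptimalModel hook. The OpenAiClient constructor argument, the using directives, and the prompt text below are assumptions made for the sketch, not part of this patch.

    // Minimal usage sketch (C#, top-level statements); assumes a valid OpenAI API key
    // and that Dialog lives in the Messaging namespace.
    using OpenAI.ChatGpt;
    using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;

    IAiClient client = new OpenAiClient("sk-...");   // hypothetical key value

    var dialog = Dialog.StartAsSystem("You are a terse assistant.")
        .ThenUser("Summarize the change in one sentence.");

    // maxTokens no longer defaults to 64; omitting it passes null, so the client
    // (OpenAiClient.GetDefaultMaxTokens returns null) defers the limit to the service.
    string answer = await client.GetChatCompletions(
        dialog,
        model: client.GetOptimalModel(dialog));   // hook added by this patch

    // An explicit limit is still honored and serialized as "max_tokens".
    string shortAnswer = await client.GetChatCompletions(dialog, maxTokens: 32);

    Console.WriteLine(answer);
    Console.WriteLine(shortAnswer);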
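
The removed FindOptimalModelAndMaxToken heuristic (switch to a larger-context model when the approximate prompt size crosses a threshold) is no longer built in: OpenAiClient.GetOptimalModel now always returns ChatCompletionModels.Gpt4Turbo and GetDefaultMaxTokens returns null. A caller that still wants size-based selection could keep a small helper of its own along these lines. The helper name and the 6000-token threshold are illustrative (the threshold mirrors the removed code), and the sketch assumes ChatCompletionMessage.CalculateApproxTotalTokenCount, which the removed heuristic used, is publicly accessible.

    // Hypothetical helper reproducing the spirit of the removed heuristic.
    using OpenAI.ChatGpt.Models.ChatCompletion;
    using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;

    internal static class ModelSelection
    {
        public static string ChooseModelByPromptSize(ChatCompletionMessage[] messages)
        {
            // Approximate prompt size, as the removed FindOptimalModelAndMaxToken did.
            var promptTokens = ChatCompletionMessage.CalculateApproxTotalTokenCount(messages);

            // Large prompts go to a large-context model; small ones use the default.
            return promptTokens > 6000
                ? ChatCompletionModels.Gpt4Turbo
                : ChatCompletionModels.Default;
        }
    }

    // Example: the chosen model is passed explicitly, since model is an ordinary parameter:
    //   var model = ModelSelection.ChooseModelByPromptSize(messages);
    //   var answer = await client.GetChatCompletions(messages, model: model);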