From 477ef3ac869adb5fc5a98b2199853b13ee6e779a Mon Sep 17 00:00:00 2001 From: Rodion Mostovoi Date: Fri, 3 Nov 2023 23:52:08 +0800 Subject: [PATCH] Update OpenAI library with modified client creation and improved structured response This update makes several modifications to the OpenAI .NET library. First, the creation of OpenAI client instances was simplified by removing unnecessary calls to IHttpClientFactory, while credential setup was shifted to HttpClient extension for reusability. The ServiceCollectionExtensions.cs and ChatGPTFactory.cs were updated to reflect these changes. In addition, support for multiple serialized examples was added to the GetStructuredResponse method extension for OpenAiClient, along with custom JSON serializer options. Now, the expected structured response can include several examples, providing more accurate API responses. Lastly, the version of the library in Directory.Build.props was incremented to 2.9.1, reflecting these improvements. --- OpenAI_DotNet.sln | 5 ++ src/Directory.Build.props | 2 +- .../ChatGPTFactory.cs | 23 +------ .../Extensions/ServiceCollectionExtensions.cs | 59 +++++++---------- .../Models/OpenAICredentials.cs | 7 ++ .../Extensions/ServiceCollectionExtensions.cs | 10 ++- ...iClientExtensions.GetStructuredResponse.cs | 66 ++++++++++++++++--- .../ChatGPTTranslatorService.cs | 2 +- .../OpenAiClient_GetStructuredResponse.cs | 7 +- .../ChatGptServicesIntegrationTests.cs | 2 +- 10 files changed, 100 insertions(+), 83 deletions(-) diff --git a/OpenAI_DotNet.sln b/OpenAI_DotNet.sln index 7a7456e..27b812e 100644 --- a/OpenAI_DotNet.sln +++ b/OpenAI_DotNet.sln @@ -33,6 +33,11 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenAI.ChatGpt.Modules.Stru EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{068E9E67-C2FC-4F8C-B27C-CB3A8FA44BD8}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "configs", "configs", "{77B5B4CD-2299-4FEE-B6C3-1090A8A8F2C2}" + ProjectSection(SolutionItems) = preProject + src\Directory.Build.props = src\Directory.Build.props + EndProjectSection +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 35874f0..dd9f6fd 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -1,6 +1,6 @@ - 2.9.0 + 2.9.1 enable enable 12 diff --git a/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs b/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs index 637c044..88539d0 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs @@ -1,16 +1,10 @@ using Microsoft.Extensions.Options; -using OpenAI.ChatGpt.AspNetCore.Models; namespace OpenAI.ChatGpt.AspNetCore; /// /// Factory for creating instances from DI. /// -/// -/// builder.Services.AddHttpClient<ChatGPTFactory<("OpenAiClient") -/// .AddPolicyHandler(GetRetryPolicy()) -/// .AddPolicyHandler(GetCircuitBreakerPolicy()); -/// [Fody.ConfigureAwait(false)] // ReSharper disable once InconsistentNaming public class ChatGPTFactory : IDisposable @@ -23,18 +17,15 @@ public class ChatGPTFactory : IDisposable private volatile bool _ensureStorageCreatedCalled; public ChatGPTFactory( - IHttpClientFactory httpClientFactory, - IOptions credentials, + IOpenAiClient client, IOptions config, IChatHistoryStorage chatHistoryStorage, ITimeProvider clock) { - if (httpClientFactory == null) throw new ArgumentNullException(nameof(httpClientFactory)); - if (credentials?.Value == null) throw new ArgumentNullException(nameof(credentials)); _config = config?.Value ?? throw new ArgumentNullException(nameof(config)); _chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage)); _clock = clock ?? throw new ArgumentNullException(nameof(clock)); - _client = CreateOpenAiClient(httpClientFactory, credentials); + _client = client ?? throw new ArgumentNullException(nameof(client)); _isHttpClientInjected = true; } @@ -65,16 +56,6 @@ public ChatGPTFactory( _clock = clock ?? new TimeProviderUtc(); } - private OpenAiClient CreateOpenAiClient( - IHttpClientFactory httpClientFactory, - IOptions credentials) - { - var httpClient = httpClientFactory.CreateClient(OpenAiClient.HttpClientName); - httpClient.DefaultRequestHeaders.Authorization = credentials.Value.GetAuthHeader(); - httpClient.BaseAddress = new Uri(credentials.Value.ApiHost); - return new OpenAiClient(httpClient); - } - public static ChatGPTFactory CreateInMemory( string apiKey, ChatGPTConfig? config = null, diff --git a/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs b/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs index b9b002e..c73c09a 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs @@ -1,19 +1,18 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; -using OpenAI.ChatGpt.AspNetCore.Models; namespace OpenAI.ChatGpt.AspNetCore.Extensions; public static class ServiceCollectionExtensions { public const string CredentialsConfigSectionPathDefault = "OpenAICredentials"; + // ReSharper disable once InconsistentNaming public const string ChatGPTConfigSectionPathDefault = "ChatGPTConfig"; - - public static IServiceCollection AddChatGptInMemoryIntegration( + + public static IHttpClientBuilder AddChatGptInMemoryIntegration( this IServiceCollection services, bool injectInMemoryChatService = true, - bool injectOpenAiClient = true, string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault) { @@ -23,23 +22,23 @@ public static IServiceCollection AddChatGptInMemoryIntegration( throw new ArgumentException("Value cannot be null or whitespace.", nameof(credentialsConfigSectionPath)); } + if (string.IsNullOrWhiteSpace(completionsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", nameof(completionsConfigSectionPath)); } - services.AddChatGptIntegrationCore( - credentialsConfigSectionPath, - completionsConfigSectionPath, - injectOpenAiClient: injectOpenAiClient - ); + services.AddSingleton(); - if(injectInMemoryChatService) + if (injectInMemoryChatService) { services.AddScoped(CreateChatService); } - - return services; + + return services.AddChatGptIntegrationCore( + credentialsConfigSectionPath, + completionsConfigSectionPath + ); } private static ChatService CreateChatService(IServiceProvider provider) @@ -47,12 +46,13 @@ private static ChatService CreateChatService(IServiceProvider provider) ArgumentNullException.ThrowIfNull(provider); var userId = Guid.Empty.ToString(); var storage = provider.GetRequiredService(); - if(storage is not InMemoryChatHistoryStorage) + if (storage is not InMemoryChatHistoryStorage) { throw new InvalidOperationException( $"Chat injection is supported only with {nameof(InMemoryChatHistoryStorage)} " + $"and is not supported for {storage.GetType().FullName}"); } + /* * .GetAwaiter().GetResult() are safe here because we use sync in memory storage */ @@ -64,13 +64,12 @@ private static ChatService CreateChatService(IServiceProvider provider) return chat; } - public static IServiceCollection AddChatGptIntegrationCore( - this IServiceCollection services, + public static IHttpClientBuilder AddChatGptIntegrationCore( + this IServiceCollection services, string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, - ServiceLifetime serviceLifetime = ServiceLifetime.Scoped, - bool injectOpenAiClient = true - ) + ServiceLifetime serviceLifetime = ServiceLifetime.Scoped + ) { ArgumentNullException.ThrowIfNull(services); if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath)) @@ -78,6 +77,7 @@ public static IServiceCollection AddChatGptIntegrationCore( throw new ArgumentException("Value cannot be null or whitespace.", nameof(credentialsConfigSectionPath)); } + if (string.IsNullOrWhiteSpace(completionsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", @@ -94,33 +94,18 @@ public static IServiceCollection AddChatGptIntegrationCore( .ValidateDataAnnotations() .ValidateOnStart(); - if (services.All(it => it.ServiceType != typeof(IHttpClientFactory))) - { - services.AddHttpClient(OpenAiClient.HttpClientName); - } - services.AddSingleton(); services.Add(new ServiceDescriptor(typeof(ChatGPTFactory), typeof(ChatGPTFactory), serviceLifetime)); - if (injectOpenAiClient) - { - AddOpenAiClient(services); - } - - return services; + return AddOpenAiClient(services); } - private static void AddOpenAiClient(IServiceCollection services) + private static IHttpClientBuilder AddOpenAiClient(IServiceCollection services) { - services.AddSingleton(provider => + return services.AddHttpClient((provider, httpClient) => { var credentials = provider.GetRequiredService>().Value; - var factory = provider.GetRequiredService(); - var httpClient = factory.CreateClient(OpenAiClient.HttpClientName); - httpClient.DefaultRequestHeaders.Authorization = credentials.GetAuthHeader(); - httpClient.BaseAddress = new Uri(credentials.ApiHost); - var client = new OpenAiClient(httpClient); - return client; + credentials.SetupHttpClient(httpClient); }); } } \ No newline at end of file diff --git a/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs b/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs index e02d36f..ca0bb04 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs @@ -26,4 +26,11 @@ public AuthenticationHeaderValue GetAuthHeader() { return new AuthenticationHeaderValue("Bearer", ApiKey); } + + public void SetupHttpClient(HttpClient httpClient) + { + ArgumentNullException.ThrowIfNull(httpClient); + httpClient.DefaultRequestHeaders.Authorization = GetAuthHeader(); + httpClient.BaseAddress = new Uri(ApiHost); + } } \ No newline at end of file diff --git a/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs b/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs index 6cb201d..8378aa3 100644 --- a/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs +++ b/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs @@ -9,13 +9,12 @@ public static class ServiceCollectionExtensions /// /// Adds the implementation using Entity Framework Core. /// - public static IServiceCollection AddChatGptEntityFrameworkIntegration( + public static IHttpClientBuilder AddChatGptEntityFrameworkIntegration( this IServiceCollection services, Action optionsAction, string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, - ServiceLifetime serviceLifetime = ServiceLifetime.Scoped, - bool injectOpenAiClient = true) + ServiceLifetime serviceLifetime = ServiceLifetime.Scoped) { ArgumentNullException.ThrowIfNull(services); ArgumentNullException.ThrowIfNull(optionsAction); @@ -30,8 +29,6 @@ public static IServiceCollection AddChatGptEntityFrameworkIntegration( nameof(completionsConfigSectionPath)); } - services.AddChatGptIntegrationCore( - credentialsConfigSectionPath, completionsConfigSectionPath, serviceLifetime, injectOpenAiClient); services.AddDbContext(optionsAction, serviceLifetime); switch (serviceLifetime) { @@ -48,6 +45,7 @@ public static IServiceCollection AddChatGptEntityFrameworkIntegration( throw new ArgumentOutOfRangeException(nameof(serviceLifetime), serviceLifetime, null); } - return services; + return services.AddChatGptIntegrationCore( + credentialsConfigSectionPath, completionsConfigSectionPath, serviceLifetime); } } \ No newline at end of file diff --git a/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs b/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs index e182c20..abaf2d3 100644 --- a/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs +++ b/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs @@ -4,6 +4,7 @@ using Json.Schema.Generation; using OpenAI.ChatGpt.Models.ChatCompletion; using OpenAI.ChatGpt.Models.ChatCompletion.Messaging; +using static OpenAI.ChatGpt.Models.ChatCompletion.Messaging.ChatCompletionMessage; namespace OpenAI.ChatGpt.Modules.StructuredResponse; @@ -21,6 +22,12 @@ public static class OpenAiClientExtensions WriteIndented = false }; + private static readonly JsonSerializerOptions JsonDefaultSerializerOptions = new() + { + Converters = { new JsonStringEnumConverter() }, + WriteIndented = false + }; + /// /// Asynchronously sends a chat completion request to the OpenAI API and deserializes the response to a specific object type. /// @@ -34,6 +41,8 @@ public static class OpenAiClientExtensions /// Optional. A function that can modify the chat completion request before it is sent to the API. /// Optional. A function that can access the raw API response. /// Optional. Custom JSON deserializer options for the deserialization. If not specified, default options with case insensitive property names are used. + /// Optional. Custom JSON serializer options for the serialization. + /// Optional. Example of the models those will be serialized using /// Optional. A cancellation token that can be used to cancel the operation. /// /// A task that represents the asynchronous operation. The task result contains the deserialized object from the API response. @@ -54,6 +63,8 @@ public static Task GetStructuredResponse( Action? requestModifier = null, Action? rawResponseGetter = null, JsonSerializerOptions? jsonDeserializerOptions = null, + JsonSerializerOptions? jsonSerializerOptions = null, + IEnumerable? examples = null, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(client); @@ -70,6 +81,8 @@ public static Task GetStructuredResponse( requestModifier: requestModifier, rawResponseGetter: rawResponseGetter, jsonDeserializerOptions: jsonDeserializerOptions, + jsonSerializerOptions: jsonSerializerOptions, + examples: examples, cancellationToken: cancellationToken ); } @@ -85,20 +98,22 @@ internal static async Task GetStructuredResponse( Action? requestModifier = null, Action? rawResponseGetter = null, JsonSerializerOptions? jsonDeserializerOptions = null, + JsonSerializerOptions? jsonSerializerOptions = null, + IEnumerable? examples = null, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(client); ArgumentNullException.ThrowIfNull(dialog); - var editMsg = dialog.GetMessages().FirstOrDefault(it => it is SystemMessage) + var editMsg = dialog.GetMessages() + .FirstOrDefault(it => it is SystemMessage) ?? dialog.GetMessages()[0]; var originalContent = editMsg.Content; try { - editMsg.Content += GetAdditionalJsonResponsePrompt(responseFormat); + editMsg.Content += GetAdditionalJsonResponsePrompt(responseFormat, examples, jsonSerializerOptions); - (model, maxTokens) = - ChatCompletionMessage.FindOptimalModelAndMaxToken(dialog.GetMessages(), model, maxTokens); + (model, maxTokens) = FindOptimalModelAndMaxToken(dialog.GetMessages(), model, maxTokens); var response = await client.GetChatCompletions( dialog, @@ -124,10 +139,19 @@ private static TObject DeserializeOrThrow(JsonSerializerOptions? jsonDe { ArgumentNullException.ThrowIfNull(response); response = response.Trim(); - if (response.StartsWith("```") && response.EndsWith("```")) + if (response.StartsWith("```json") && response.EndsWith("```")) + { + response = response[7..^3]; + } + else if (response.StartsWith("```") && response.EndsWith("```")) { response = response[3..^3]; } + if(!response.StartsWith('{') || !response.EndsWith('}')) + { + var (openBracketIndex, closeBracketIndex) = FindFirstAndLastBracket(response); + response = response[openBracketIndex..(closeBracketIndex + 1)]; + } jsonDeserializerOptions ??= new JsonSerializerOptions { @@ -151,19 +175,41 @@ private static TObject DeserializeOrThrow(JsonSerializerOptions? jsonDe } return deserialized; + + static (int openBracketIndex, int closeBracketIndex) FindFirstAndLastBracket(string response) + { + ArgumentNullException.ThrowIfNull(response); + var openBracketIndex = response.IndexOf('{'); + var closeBracketIndex = response.LastIndexOf('}'); + if (openBracketIndex < 0 || closeBracketIndex < 0) + { + throw new InvalidJsonException( + $"Failed to deserialize response to {typeof(TObject)}. Response: {response}.", response); + } + return (openBracketIndex, closeBracketIndex); + } } - private static string GetAdditionalJsonResponsePrompt(string responseFormat) + private static string GetAdditionalJsonResponsePrompt( + string responseFormat, IEnumerable? examples, JsonSerializerOptions? jsonSerializerOptions) { - return$"\n\nWrite your response in compact JSON format with escaped strings. " + - $"Here is the response structure (JSON Schema): {responseFormat}"; + var res = $"\n\nYour response MUST be STRICTLY compact JSON with escaped strings. " + + $"Here is the response structure (JSON Schema): \n```json{responseFormat}```"; + + if (examples is not null) + { + jsonSerializerOptions ??= JsonDefaultSerializerOptions; + var examplesString = string.Join("\n", examples.Select(it => JsonSerializer.Serialize(it, jsonSerializerOptions))); + res += $"\n\nHere are some examples:\n```json\n{examplesString}\n```"; + } + + return res; } internal static string CreateResponseFormatJson() { var schemaBuilder = new JsonSchemaBuilder(); - var schema = schemaBuilder.FromType(SchemaGeneratorConfiguration - ).Build();; + var schema = schemaBuilder.FromType(SchemaGeneratorConfiguration).Build(); var schemaString = JsonSerializer.Serialize(schema, JsonSchemaSerializerOptions); return schemaString; } diff --git a/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs b/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs index bd465c3..04f719f 100644 --- a/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs +++ b/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs @@ -148,7 +148,7 @@ public virtual async Task TranslateObject( requestModifier, rawResponseGetter, jsonDeserializerOptions, - cancellationToken + cancellationToken: cancellationToken ); return response; } diff --git a/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs b/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs index 3afc17a..86cb4bc 100644 --- a/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs +++ b/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs @@ -4,12 +4,7 @@ namespace OpenAI.ChatGpt.IntegrationTests.OpenAiClientTests; public class OpenAiClientGetStructuredResponseTests { - private readonly OpenAiClient _client; - - public OpenAiClientGetStructuredResponseTests() - { - _client = new OpenAiClient(Helpers.GetOpenAiKey()); - } + private readonly OpenAiClient _client = new(Helpers.GetOpenAiKey()); [Fact] public async void Get_simple_structured_response_from_ChatGPT() diff --git a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs index 3713630..9faba4d 100644 --- a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs +++ b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs @@ -26,9 +26,9 @@ public void AddChatGptCoreIntegration_added_expected_services() provider.GetRequiredService>(); provider.GetRequiredService>(); - provider.GetRequiredService(); provider.GetRequiredService(); provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact]