diff --git a/OpenAI_DotNet.sln b/OpenAI_DotNet.sln index 7a7456e..27b812e 100644 --- a/OpenAI_DotNet.sln +++ b/OpenAI_DotNet.sln @@ -33,6 +33,11 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenAI.ChatGpt.Modules.Stru EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "modules", "modules", "{068E9E67-C2FC-4F8C-B27C-CB3A8FA44BD8}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "configs", "configs", "{77B5B4CD-2299-4FEE-B6C3-1090A8A8F2C2}" + ProjectSection(SolutionItems) = preProject + src\Directory.Build.props = src\Directory.Build.props + EndProjectSection +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 35874f0..dd9f6fd 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -1,6 +1,6 @@ - 2.9.0 + 2.9.1 enable enable 12 diff --git a/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs b/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs index 637c044..88539d0 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs @@ -1,16 +1,10 @@ using Microsoft.Extensions.Options; -using OpenAI.ChatGpt.AspNetCore.Models; namespace OpenAI.ChatGpt.AspNetCore; /// /// Factory for creating instances from DI. /// -/// -/// builder.Services.AddHttpClient<ChatGPTFactory<("OpenAiClient") -/// .AddPolicyHandler(GetRetryPolicy()) -/// .AddPolicyHandler(GetCircuitBreakerPolicy()); -/// [Fody.ConfigureAwait(false)] // ReSharper disable once InconsistentNaming public class ChatGPTFactory : IDisposable @@ -23,18 +17,15 @@ public class ChatGPTFactory : IDisposable private volatile bool _ensureStorageCreatedCalled; public ChatGPTFactory( - IHttpClientFactory httpClientFactory, - IOptions credentials, + IOpenAiClient client, IOptions config, IChatHistoryStorage chatHistoryStorage, ITimeProvider clock) { - if (httpClientFactory == null) throw new ArgumentNullException(nameof(httpClientFactory)); - if (credentials?.Value == null) throw new ArgumentNullException(nameof(credentials)); _config = config?.Value ?? throw new ArgumentNullException(nameof(config)); _chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage)); _clock = clock ?? throw new ArgumentNullException(nameof(clock)); - _client = CreateOpenAiClient(httpClientFactory, credentials); + _client = client ?? throw new ArgumentNullException(nameof(client)); _isHttpClientInjected = true; } @@ -65,16 +56,6 @@ public ChatGPTFactory( _clock = clock ?? new TimeProviderUtc(); } - private OpenAiClient CreateOpenAiClient( - IHttpClientFactory httpClientFactory, - IOptions credentials) - { - var httpClient = httpClientFactory.CreateClient(OpenAiClient.HttpClientName); - httpClient.DefaultRequestHeaders.Authorization = credentials.Value.GetAuthHeader(); - httpClient.BaseAddress = new Uri(credentials.Value.ApiHost); - return new OpenAiClient(httpClient); - } - public static ChatGPTFactory CreateInMemory( string apiKey, ChatGPTConfig? config = null, diff --git a/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs b/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs index b9b002e..c73c09a 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/Extensions/ServiceCollectionExtensions.cs @@ -1,19 +1,18 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; -using OpenAI.ChatGpt.AspNetCore.Models; namespace OpenAI.ChatGpt.AspNetCore.Extensions; public static class ServiceCollectionExtensions { public const string CredentialsConfigSectionPathDefault = "OpenAICredentials"; + // ReSharper disable once InconsistentNaming public const string ChatGPTConfigSectionPathDefault = "ChatGPTConfig"; - - public static IServiceCollection AddChatGptInMemoryIntegration( + + public static IHttpClientBuilder AddChatGptInMemoryIntegration( this IServiceCollection services, bool injectInMemoryChatService = true, - bool injectOpenAiClient = true, string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault) { @@ -23,23 +22,23 @@ public static IServiceCollection AddChatGptInMemoryIntegration( throw new ArgumentException("Value cannot be null or whitespace.", nameof(credentialsConfigSectionPath)); } + if (string.IsNullOrWhiteSpace(completionsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", nameof(completionsConfigSectionPath)); } - services.AddChatGptIntegrationCore( - credentialsConfigSectionPath, - completionsConfigSectionPath, - injectOpenAiClient: injectOpenAiClient - ); + services.AddSingleton(); - if(injectInMemoryChatService) + if (injectInMemoryChatService) { services.AddScoped(CreateChatService); } - - return services; + + return services.AddChatGptIntegrationCore( + credentialsConfigSectionPath, + completionsConfigSectionPath + ); } private static ChatService CreateChatService(IServiceProvider provider) @@ -47,12 +46,13 @@ private static ChatService CreateChatService(IServiceProvider provider) ArgumentNullException.ThrowIfNull(provider); var userId = Guid.Empty.ToString(); var storage = provider.GetRequiredService(); - if(storage is not InMemoryChatHistoryStorage) + if (storage is not InMemoryChatHistoryStorage) { throw new InvalidOperationException( $"Chat injection is supported only with {nameof(InMemoryChatHistoryStorage)} " + $"and is not supported for {storage.GetType().FullName}"); } + /* * .GetAwaiter().GetResult() are safe here because we use sync in memory storage */ @@ -64,13 +64,12 @@ private static ChatService CreateChatService(IServiceProvider provider) return chat; } - public static IServiceCollection AddChatGptIntegrationCore( - this IServiceCollection services, + public static IHttpClientBuilder AddChatGptIntegrationCore( + this IServiceCollection services, string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, - ServiceLifetime serviceLifetime = ServiceLifetime.Scoped, - bool injectOpenAiClient = true - ) + ServiceLifetime serviceLifetime = ServiceLifetime.Scoped + ) { ArgumentNullException.ThrowIfNull(services); if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath)) @@ -78,6 +77,7 @@ public static IServiceCollection AddChatGptIntegrationCore( throw new ArgumentException("Value cannot be null or whitespace.", nameof(credentialsConfigSectionPath)); } + if (string.IsNullOrWhiteSpace(completionsConfigSectionPath)) { throw new ArgumentException("Value cannot be null or whitespace.", @@ -94,33 +94,18 @@ public static IServiceCollection AddChatGptIntegrationCore( .ValidateDataAnnotations() .ValidateOnStart(); - if (services.All(it => it.ServiceType != typeof(IHttpClientFactory))) - { - services.AddHttpClient(OpenAiClient.HttpClientName); - } - services.AddSingleton(); services.Add(new ServiceDescriptor(typeof(ChatGPTFactory), typeof(ChatGPTFactory), serviceLifetime)); - if (injectOpenAiClient) - { - AddOpenAiClient(services); - } - - return services; + return AddOpenAiClient(services); } - private static void AddOpenAiClient(IServiceCollection services) + private static IHttpClientBuilder AddOpenAiClient(IServiceCollection services) { - services.AddSingleton(provider => + return services.AddHttpClient((provider, httpClient) => { var credentials = provider.GetRequiredService>().Value; - var factory = provider.GetRequiredService(); - var httpClient = factory.CreateClient(OpenAiClient.HttpClientName); - httpClient.DefaultRequestHeaders.Authorization = credentials.GetAuthHeader(); - httpClient.BaseAddress = new Uri(credentials.ApiHost); - var client = new OpenAiClient(httpClient); - return client; + credentials.SetupHttpClient(httpClient); }); } } \ No newline at end of file diff --git a/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs b/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs index e02d36f..ca0bb04 100644 --- a/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs +++ b/src/OpenAI.ChatGpt.AspNetCore/Models/OpenAICredentials.cs @@ -26,4 +26,11 @@ public AuthenticationHeaderValue GetAuthHeader() { return new AuthenticationHeaderValue("Bearer", ApiKey); } + + public void SetupHttpClient(HttpClient httpClient) + { + ArgumentNullException.ThrowIfNull(httpClient); + httpClient.DefaultRequestHeaders.Authorization = GetAuthHeader(); + httpClient.BaseAddress = new Uri(ApiHost); + } } \ No newline at end of file diff --git a/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs b/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs index 6cb201d..8378aa3 100644 --- a/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs +++ b/src/OpenAI.ChatGpt.EntityFrameworkCore/Extensions/ServiceCollectionExtensions.cs @@ -9,13 +9,12 @@ public static class ServiceCollectionExtensions /// /// Adds the implementation using Entity Framework Core. /// - public static IServiceCollection AddChatGptEntityFrameworkIntegration( + public static IHttpClientBuilder AddChatGptEntityFrameworkIntegration( this IServiceCollection services, Action optionsAction, string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault, string completionsConfigSectionPath = ChatGPTConfigSectionPathDefault, - ServiceLifetime serviceLifetime = ServiceLifetime.Scoped, - bool injectOpenAiClient = true) + ServiceLifetime serviceLifetime = ServiceLifetime.Scoped) { ArgumentNullException.ThrowIfNull(services); ArgumentNullException.ThrowIfNull(optionsAction); @@ -30,8 +29,6 @@ public static IServiceCollection AddChatGptEntityFrameworkIntegration( nameof(completionsConfigSectionPath)); } - services.AddChatGptIntegrationCore( - credentialsConfigSectionPath, completionsConfigSectionPath, serviceLifetime, injectOpenAiClient); services.AddDbContext(optionsAction, serviceLifetime); switch (serviceLifetime) { @@ -48,6 +45,7 @@ public static IServiceCollection AddChatGptEntityFrameworkIntegration( throw new ArgumentOutOfRangeException(nameof(serviceLifetime), serviceLifetime, null); } - return services; + return services.AddChatGptIntegrationCore( + credentialsConfigSectionPath, completionsConfigSectionPath, serviceLifetime); } } \ No newline at end of file diff --git a/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs b/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs index e182c20..abaf2d3 100644 --- a/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs +++ b/src/modules/OpenAI.ChatGpt.Modules.StructuredResponse/OpenAiClientExtensions.GetStructuredResponse.cs @@ -4,6 +4,7 @@ using Json.Schema.Generation; using OpenAI.ChatGpt.Models.ChatCompletion; using OpenAI.ChatGpt.Models.ChatCompletion.Messaging; +using static OpenAI.ChatGpt.Models.ChatCompletion.Messaging.ChatCompletionMessage; namespace OpenAI.ChatGpt.Modules.StructuredResponse; @@ -21,6 +22,12 @@ public static class OpenAiClientExtensions WriteIndented = false }; + private static readonly JsonSerializerOptions JsonDefaultSerializerOptions = new() + { + Converters = { new JsonStringEnumConverter() }, + WriteIndented = false + }; + /// /// Asynchronously sends a chat completion request to the OpenAI API and deserializes the response to a specific object type. /// @@ -34,6 +41,8 @@ public static class OpenAiClientExtensions /// Optional. A function that can modify the chat completion request before it is sent to the API. /// Optional. A function that can access the raw API response. /// Optional. Custom JSON deserializer options for the deserialization. If not specified, default options with case insensitive property names are used. + /// Optional. Custom JSON serializer options for the serialization. + /// Optional. Example of the models those will be serialized using /// Optional. A cancellation token that can be used to cancel the operation. /// /// A task that represents the asynchronous operation. The task result contains the deserialized object from the API response. @@ -54,6 +63,8 @@ public static Task GetStructuredResponse( Action? requestModifier = null, Action? rawResponseGetter = null, JsonSerializerOptions? jsonDeserializerOptions = null, + JsonSerializerOptions? jsonSerializerOptions = null, + IEnumerable? examples = null, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(client); @@ -70,6 +81,8 @@ public static Task GetStructuredResponse( requestModifier: requestModifier, rawResponseGetter: rawResponseGetter, jsonDeserializerOptions: jsonDeserializerOptions, + jsonSerializerOptions: jsonSerializerOptions, + examples: examples, cancellationToken: cancellationToken ); } @@ -85,20 +98,22 @@ internal static async Task GetStructuredResponse( Action? requestModifier = null, Action? rawResponseGetter = null, JsonSerializerOptions? jsonDeserializerOptions = null, + JsonSerializerOptions? jsonSerializerOptions = null, + IEnumerable? examples = null, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(client); ArgumentNullException.ThrowIfNull(dialog); - var editMsg = dialog.GetMessages().FirstOrDefault(it => it is SystemMessage) + var editMsg = dialog.GetMessages() + .FirstOrDefault(it => it is SystemMessage) ?? dialog.GetMessages()[0]; var originalContent = editMsg.Content; try { - editMsg.Content += GetAdditionalJsonResponsePrompt(responseFormat); + editMsg.Content += GetAdditionalJsonResponsePrompt(responseFormat, examples, jsonSerializerOptions); - (model, maxTokens) = - ChatCompletionMessage.FindOptimalModelAndMaxToken(dialog.GetMessages(), model, maxTokens); + (model, maxTokens) = FindOptimalModelAndMaxToken(dialog.GetMessages(), model, maxTokens); var response = await client.GetChatCompletions( dialog, @@ -124,10 +139,19 @@ private static TObject DeserializeOrThrow(JsonSerializerOptions? jsonDe { ArgumentNullException.ThrowIfNull(response); response = response.Trim(); - if (response.StartsWith("```") && response.EndsWith("```")) + if (response.StartsWith("```json") && response.EndsWith("```")) + { + response = response[7..^3]; + } + else if (response.StartsWith("```") && response.EndsWith("```")) { response = response[3..^3]; } + if(!response.StartsWith('{') || !response.EndsWith('}')) + { + var (openBracketIndex, closeBracketIndex) = FindFirstAndLastBracket(response); + response = response[openBracketIndex..(closeBracketIndex + 1)]; + } jsonDeserializerOptions ??= new JsonSerializerOptions { @@ -151,19 +175,41 @@ private static TObject DeserializeOrThrow(JsonSerializerOptions? jsonDe } return deserialized; + + static (int openBracketIndex, int closeBracketIndex) FindFirstAndLastBracket(string response) + { + ArgumentNullException.ThrowIfNull(response); + var openBracketIndex = response.IndexOf('{'); + var closeBracketIndex = response.LastIndexOf('}'); + if (openBracketIndex < 0 || closeBracketIndex < 0) + { + throw new InvalidJsonException( + $"Failed to deserialize response to {typeof(TObject)}. Response: {response}.", response); + } + return (openBracketIndex, closeBracketIndex); + } } - private static string GetAdditionalJsonResponsePrompt(string responseFormat) + private static string GetAdditionalJsonResponsePrompt( + string responseFormat, IEnumerable? examples, JsonSerializerOptions? jsonSerializerOptions) { - return$"\n\nWrite your response in compact JSON format with escaped strings. " + - $"Here is the response structure (JSON Schema): {responseFormat}"; + var res = $"\n\nYour response MUST be STRICTLY compact JSON with escaped strings. " + + $"Here is the response structure (JSON Schema): \n```json{responseFormat}```"; + + if (examples is not null) + { + jsonSerializerOptions ??= JsonDefaultSerializerOptions; + var examplesString = string.Join("\n", examples.Select(it => JsonSerializer.Serialize(it, jsonSerializerOptions))); + res += $"\n\nHere are some examples:\n```json\n{examplesString}\n```"; + } + + return res; } internal static string CreateResponseFormatJson() { var schemaBuilder = new JsonSchemaBuilder(); - var schema = schemaBuilder.FromType(SchemaGeneratorConfiguration - ).Build();; + var schema = schemaBuilder.FromType(SchemaGeneratorConfiguration).Build(); var schemaString = JsonSerializer.Serialize(schema, JsonSchemaSerializerOptions); return schemaString; } diff --git a/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs b/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs index bd465c3..04f719f 100644 --- a/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs +++ b/src/modules/OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs @@ -148,7 +148,7 @@ public virtual async Task TranslateObject( requestModifier, rawResponseGetter, jsonDeserializerOptions, - cancellationToken + cancellationToken: cancellationToken ); return response; } diff --git a/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs b/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs index 3afc17a..86cb4bc 100644 --- a/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs +++ b/tests/OpenAI.ChatGpt.IntegrationTests/OpenAiClientTests/OpenAiClient_GetStructuredResponse.cs @@ -4,12 +4,7 @@ namespace OpenAI.ChatGpt.IntegrationTests.OpenAiClientTests; public class OpenAiClientGetStructuredResponseTests { - private readonly OpenAiClient _client; - - public OpenAiClientGetStructuredResponseTests() - { - _client = new OpenAiClient(Helpers.GetOpenAiKey()); - } + private readonly OpenAiClient _client = new(Helpers.GetOpenAiKey()); [Fact] public async void Get_simple_structured_response_from_ChatGPT() diff --git a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs index 3713630..9faba4d 100644 --- a/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs +++ b/tests/OpenAI.ChatGpt.UnitTests/DependencyInjectionTests/ChatGptServicesIntegrationTests.cs @@ -26,9 +26,9 @@ public void AddChatGptCoreIntegration_added_expected_services() provider.GetRequiredService>(); provider.GetRequiredService>(); - provider.GetRequiredService(); provider.GetRequiredService(); provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact]