
Restore json support in regular GPT4 and GPT3.5 models
rodion-m committed Nov 10, 2023
1 parent ae076a1 commit 743dd83
Showing 6 changed files with 94 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/Directory.Build.props
@@ -1,6 +1,6 @@
<Project>
<PropertyGroup>
-<Version>3.0.0</Version>
+<Version>3.1.0</Version>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<LangVersion>12</LangVersion>
35 changes: 28 additions & 7 deletions src/OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionModels.cs
@@ -85,6 +85,7 @@ public static class ChatCompletionModels
/// This model has a maximum token limit of 4,096.
/// The model was trained with data up to September 2021.
/// </summary>
[Obsolete("Legacy. Snapshot of gpt-3.5-turbo from June 13th 2023. Will be deprecated on June 13, 2024.")]
public const string Gpt3_5_Turbo_0613 = "gpt-3.5-turbo-0613";

/// <summary>
@@ -93,6 +94,7 @@ public static class ChatCompletionModels
/// This model has a maximum token limit of 16,384.
/// The model was trained with data up to September 2021.
/// </summary>
[Obsolete("Legacy. Snapshot of gpt-3.5-16k-turbo from June 13th 2023. Will be deprecated on June 13, 2024.")]
public const string Gpt3_5_Turbo_16k_0613 = "gpt-3.5-turbo-16k-0613";

/// <summary>
@@ -101,26 +103,29 @@ public static class ChatCompletionModels
/// Unlike gpt-4, this model will not receive updates,
/// and will only be supported for a three month period ending on June 14th 2023.
/// </summary>
[Obsolete("DISCONTINUATION DATE 09/13/2023")]
[Obsolete("Legacy. Snapshot of gpt-4 from March 14th 2023 with function calling support. This model version will be deprecated on June 13th 2024. Use Gpt4 instead.")]
public const string Gpt4_0314 = "gpt-4-0314";

/// <summary>
/// Snapshot of gpt-4-32 from March 14th 2023.
/// Unlike gpt-4-32k, this model will not receive updates,
/// and will only be supported for a three month period ending on June 14th 2023.
/// </summary>
[Obsolete("DISCONTINUATION DATE 09/13/2023. This model is available only by request. " +
"Link for joining waitlist: https://openai.com/waitlist/gpt-4-api")]
[Obsolete("Legacy. Snapshot of gpt-4-32k from March 14th 2023 with function calling support. This model version will be deprecated on June 13th 2024. Use Gpt432k instead.")]
public const string Gpt4_32k_0314 = "gpt-4-32k-0314";

/// <summary>
/// Snapshot of gpt-3.5-turbo from March 1st 2023.
/// Unlike gpt-3.5-turbo, this model will not receive updates,
/// and will only be supported for a three month period ending on June 1st 2023.
/// </summary>
[Obsolete("DISCONTINUATION DATE 09/13/2023")]
[Obsolete("Snapshot of gpt-3.5-turbo from March 1st 2023. Will be deprecated on June 13th 2024. Use Gpt3_5_Turbo instead.")]
public const string Gpt3_5_Turbo_0301 = "gpt-3.5-turbo-0301";

+private static readonly string[] ModelsSupportedJson = {
+Gpt4Turbo, Gpt3_5_Turbo_1106
+};

/// <summary>
/// The maximum number of tokens that can be processed by the model.
/// </summary>
@@ -132,10 +137,10 @@ public static class ChatCompletionModels
{ Gpt4_32k, 32_768 },
{ Gpt4_32k_0613, 32_768 },
{ Gpt3_5_Turbo, 4096 },
-{ Gpt3_5_Turbo_1106, 4096 },
-{ Gpt3_5_Turbo_16k, 16_384 },
+{ Gpt3_5_Turbo_1106, 16385 },
+{ Gpt3_5_Turbo_16k, 16_385 },
{ Gpt3_5_Turbo_0613, 4096 },
-{ Gpt3_5_Turbo_16k_0613, 16_384 },
+{ Gpt3_5_Turbo_16k_0613, 16_385 },
{ Gpt4_0314, 8192 },
{ Gpt4_32k_0314, 32_768 },
{ Gpt3_5_Turbo_0301, 4096 },
@@ -222,4 +227,20 @@ public static void EnsureMaxTokensIsSupportedByAnyModel(int maxTokens)
nameof(maxTokens), $"Max tokens must be less than or equal to {limit} but was {maxTokens}");
}
}

+/// <summary>
+/// Checks if the model name is supported for JSON mode
+/// </summary>
+/// <param name="model">GPT model name</param>
+/// <returns>True if the model is supported for JSON mode</returns>
+public static bool IsJsonModeSupported(string model)
+{
+ArgumentNullException.ThrowIfNull(model);
+return Array.IndexOf(ModelsSupportedJson, model) != -1;
+}
+
+internal static IReadOnlyList<string> GetModelsThatSupportJsonMode()
+{
+return ModelsSupportedJson;
+}
}
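For illustration only (not part of the commit), a minimal sketch of how the new helper behaves, given that ModelsSupportedJson above contains only Gpt4Turbo and Gpt3_5_Turbo_1106:

// Sketch: JSON-mode support is checked against the constants defined above.
bool a = ChatCompletionModels.IsJsonModeSupported(ChatCompletionModels.Gpt4Turbo);         // true
bool b = ChatCompletionModels.IsJsonModeSupported(ChatCompletionModels.Gpt3_5_Turbo_1106); // true
bool c = ChatCompletionModels.IsJsonModeSupported(ChatCompletionModels.Gpt4);              // false
bool d = ChatCompletionModels.IsJsonModeSupported(ChatCompletionModels.Gpt3_5_Turbo);      // false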
15 changes: 13 additions & 2 deletions src/OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionRequest.cs
@@ -145,7 +145,7 @@ public int MaxTokens
/// An object specifying the format that the model must output.
/// </summary>
[JsonPropertyName("response_format")]
-public ChatCompletionResponseFormat ResponseFormat { get; set; } = new();
+public ChatCompletionResponseFormat ResponseFormat { get; set; } = new(false);

/// <summary>
/// This feature is in Beta.
@@ -158,6 +158,11 @@ public int MaxTokens

public class ChatCompletionResponseFormat
{
+public ChatCompletionResponseFormat(bool jsonMode)
+{
+Type = jsonMode ? ResponseTypes.JsonObject : ResponseTypes.Text;
+}

/// <summary>
/// Setting to `json_object` enables JSON mode. This guarantees that the message the model generates is valid JSON.
/// Note that your system prompt must still instruct the model to produce JSON, and to help ensure you don't forget,
@@ -167,6 +172,12 @@ public class ChatCompletionResponseFormat
/// Must be one of `text` or `json_object`.
/// </summary>
[JsonPropertyName("type")]
public string Type { get; set; } = "text";
public string Type { get; set; }
}

+internal static class ResponseTypes
+{
+public const string Text = "text";
+public const string JsonObject = "json_object";
+}
}
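A short sketch (not part of the diff) of what the new constructor and constants produce on the wire; the field name comes from the JsonPropertyName attribute above:

// Sketch: jsonMode selects between the two allowed response_format types.
var format = new ChatCompletionRequest.ChatCompletionResponseFormat(jsonMode: true);
// format.Type == "json_object", serialized as "response_format": { "type": "json_object" }
// With jsonMode: false the request carries "response_format": { "type": "text" } instead.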
30 changes: 23 additions & 7 deletions src/OpenAI.ChatGpt/OpenAiClient.cs
@@ -9,11 +9,11 @@

namespace OpenAI.ChatGpt;

-/// <summary> Thread-safe OpenAI client. </summary>
+/// <summary>Thread-safe OpenAI client.</summary>
/// <remarks>https://github.com/openai/openai-openapi/blob/master/openapi.yaml</remarks>
[Fody.ConfigureAwait(false)]
public class OpenAiClient : IOpenAiClient, IDisposable
{
+internal const string HttpClientName = "OpenAiClient";
private const string DefaultHost = "https://api.openai.com/v1/";
private const string ChatCompletionsEndpoint = "chat/completions";

@@ -142,6 +142,7 @@ public async Task<string> GetChatCompletions(
{
if (dialog == null) throw new ArgumentNullException(nameof(dialog));
if (model == null) throw new ArgumentNullException(nameof(model));
+EnsureJsonModeIsSupported(model, jsonMode);
ThrowIfDisposed();
var request = CreateChatCompletionRequest(
dialog.GetMessages(),
@@ -174,6 +175,7 @@ public async Task<string> GetChatCompletions(
{
if (messages == null) throw new ArgumentNullException(nameof(messages));
if (model == null) throw new ArgumentNullException(nameof(model));
+EnsureJsonModeIsSupported(model, jsonMode);
ThrowIfDisposed();
var request = CreateChatCompletionRequest(
messages,
@@ -205,6 +207,7 @@ public async Task<ChatCompletionResponse> GetChatCompletionsRaw(
{
if (messages == null) throw new ArgumentNullException(nameof(messages));
if (model == null) throw new ArgumentNullException(nameof(model));
+EnsureJsonModeIsSupported(model, jsonMode);
ThrowIfDisposed();
var request = CreateChatCompletionRequest(
messages,
@@ -225,7 +228,7 @@ internal async Task<ChatCompletionResponse> GetChatCompletionsRaw(
ChatCompletionRequest request,
CancellationToken cancellationToken = default)
{
-if (request == null) throw new ArgumentNullException(nameof(request));
+ArgumentNullException.ThrowIfNull(request);
ThrowIfDisposed();
var response = await _httpClient.PostAsJsonAsync(
ChatCompletionsEndpoint,
@@ -258,6 +261,7 @@ public IAsyncEnumerable<string> StreamChatCompletions(
{
if (messages == null) throw new ArgumentNullException(nameof(messages));
if (model == null) throw new ArgumentNullException(nameof(model));
+EnsureJsonModeIsSupported(model, jsonMode);
ThrowIfDisposed();
var request = CreateChatCompletionRequest(
messages,
@@ -292,10 +296,7 @@ private static ChatCompletionRequest CreateChatCompletionRequest(
Stream = stream,
User = user,
Temperature = temperature,
-ResponseFormat = new ChatCompletionRequest.ChatCompletionResponseFormat()
-{
-Type = jsonMode ? "json_object" : "text"
-},
+ResponseFormat = new ChatCompletionRequest.ChatCompletionResponseFormat(jsonMode),
Seed = seed,
};
requestModifier?.Invoke(request);
@@ -316,6 +317,7 @@ public IAsyncEnumerable<string> StreamChatCompletions(
{
if (messages == null) throw new ArgumentNullException(nameof(messages));
if (model == null) throw new ArgumentNullException(nameof(model));
+EnsureJsonModeIsSupported(model, jsonMode);
ThrowIfDisposed();
var request = CreateChatCompletionRequest(messages.GetMessages(),
maxTokens,
@@ -335,7 +337,9 @@ public async IAsyncEnumerable<string> StreamChatCompletions(
ChatCompletionRequest request,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
-if (request == null) throw new ArgumentNullException(nameof(request));
+ArgumentNullException.ThrowIfNull(request);
+EnsureJsonModeIsSupported(request.Model, request.ResponseFormat.Type == ChatCompletionRequest.ResponseTypes.JsonObject);
ThrowIfDisposed();
request.Stream = true;
await foreach (var response in StreamChatCompletionsRaw(request, cancellationToken))
@@ -351,6 +355,7 @@ public IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(
ChatCompletionRequest request, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(request);
+EnsureJsonModeIsSupported(request.Model, request.ResponseFormat.Type == ChatCompletionRequest.ResponseTypes.JsonObject);
ThrowIfDisposed();
request.Stream = true;
return _httpClient.StreamUsingServerSentEvents<ChatCompletionRequest, ChatCompletionResponse>
@@ -361,4 +366,15 @@ public IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(
cancellationToken
);
}

+private static void EnsureJsonModeIsSupported(string model, bool jsonMode)
+{
+if(jsonMode && !ChatCompletionModels.IsJsonModeSupported(model))
+{
+throw new NotSupportedException(
+$"Model {model} does not support JSON mode. " +
+$"Supported models are: {string.Join(", ", ChatCompletionModels.GetModelsThatSupportJsonMode())}"
+);
+}
+}
}
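For illustration (not part of the commit), how the guard surfaces to callers; the dialog content is invented and the parameter names are read off the calls above:

// Sketch: jsonMode with an unsupported model now fails fast, before any HTTP call.
var client = new OpenAiClient("sk-..."); // placeholder API key
var dialog = Dialog.StartAsSystem("Reply in JSON.").ThenUser("My name is John.");

// Gpt4 is not in ModelsSupportedJson, so this throws NotSupportedException,
// listing the models that do support JSON mode:
await client.GetChatCompletions(dialog, model: ChatCompletionModels.Gpt4, jsonMode: true);

// The same call with Gpt4Turbo or Gpt3_5_Turbo_1106 goes through, and the
// request is sent with "response_format": { "type": "json_object" }.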
@@ -71,7 +71,8 @@ public static Task<TObject> GetStructuredResponse<TObject>(
ArgumentNullException.ThrowIfNull(dialog);
var responseFormat = CreateResponseFormatJson<TObject>();

-return client.GetStructuredResponse<TObject>(
+return GetStructuredResponse<TObject>(
+client,
dialog: dialog,
responseFormat: responseFormat,
maxTokens: maxTokens,
@@ -121,7 +122,7 @@ internal static async Task<TObject> GetStructuredResponse<TObject>(
model,
temperature,
user,
-true,
+ChatCompletionModels.IsJsonModeSupported(model),
null,
requestModifier,
rawResponseGetter,
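A sketch (not part of the diff) of the net effect on GetStructuredResponse, reusing the client variable from the sketch above and the UserInfo sample type from the tests below. JSON mode is now requested only when the model supports it, which is what re-enables plain gpt-4 and gpt-3.5-turbo here:

// Sketch: the same call now works for models with and without JSON mode.
var message = Dialog
    .StartAsSystem("What did user input?")
    .ThenUser("My name is John, my age is 30");
// Gpt4 -> response_format "text", JSON parsed from the reply;
// Gpt3_5_Turbo_1106 -> response_format "json_object" automatically.
var userInfo = await client.GetStructuredResponse<UserInfo>(message, model: ChatCompletionModels.Gpt4);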
@@ -6,27 +6,35 @@ public class OpenAiClientGetStructuredResponseTests
{
private readonly OpenAiClient _client = new(Helpers.GetOpenAiKey());

-[Fact]
-public async void Get_simple_structured_response_from_ChatGPT()
+[Theory]
+[InlineData(ChatCompletionModels.Gpt3_5_Turbo)]
+[InlineData(ChatCompletionModels.Gpt4Turbo)]
+[InlineData(ChatCompletionModels.Gpt4)]
+[InlineData(ChatCompletionModels.Gpt3_5_Turbo_1106)]
+public async void Get_simple_structured_response_from_ChatGPT(string model)
{
var message =
Dialog.StartAsSystem("What did user input?")
.ThenUser("My name is John, my age is 30, my email is [email protected]");
-var response = await _client.GetStructuredResponse<UserInfo>(message, model: ChatCompletionModels.Gpt4Turbo);
+var examples = new []{ new UserInfo() {Age = 0, Email = "[email protected]", Name = "Rodion"} };
+var response = await _client.GetStructuredResponse<UserInfo>(message, model: model, examples: examples);
response.Should().NotBeNull();
response.Name.Should().Be("John");
response.Age.Should().Be(30);
response.Email.Should().Be("[email protected]");
}

-[Fact]
-public async void Get_structured_response_with_ARRAY_from_ChatGPT()
+[Theory]
+[InlineData(ChatCompletionModels.Gpt4Turbo)]
+[InlineData(ChatCompletionModels.Gpt4)]
+[InlineData(ChatCompletionModels.Gpt3_5_Turbo_1106)]
+public async void Get_structured_response_with_ARRAY_from_ChatGPT(string model)
{
var message = Dialog
.StartAsSystem("What did user input?")
.ThenUser("My name is John, my age is 30, my email is [email protected]. " +
"I want to buy 2 apple and 3 orange.");
-var response = await _client.GetStructuredResponse<Order>(message, model: ChatCompletionModels.Gpt4Turbo);
+var response = await _client.GetStructuredResponse<Order>(message, model: model);
response.Should().NotBeNull();
response.UserInfo.Should().NotBeNull();
response.UserInfo!.Name.Should().Be("John");
@@ -40,24 +48,30 @@ public async void Get_structured_response_with_ARRAY_from_ChatGPT()
response.Items[1].Quantity.Should().Be(3);
}

-[Fact]
-public async void Get_structured_response_with_ENUM_from_ChatGPT()
+[Theory]
+[InlineData(ChatCompletionModels.Gpt4Turbo)]
+[InlineData(ChatCompletionModels.Gpt4)]
+[InlineData(ChatCompletionModels.Gpt3_5_Turbo_1106)]
+public async void Get_structured_response_with_ENUM_from_ChatGPT(string model)
{
var message = Dialog
.StartAsSystem("What did user input?")
.ThenUser("Мой любимый цвет - красный");
-var response = await _client.GetStructuredResponse<Thing>(message, model: ChatCompletionModels.Gpt4Turbo);
+var response = await _client.GetStructuredResponse<Thing>(message, model: model);
response.Should().NotBeNull();
response.Color.Should().Be(Thing.Colors.Red);
}

-[Fact]
-public async void Get_structured_response_with_extra_data_from_ChatGPT()
+[Theory]
+[InlineData(ChatCompletionModels.Gpt4Turbo)]
+[InlineData(ChatCompletionModels.Gpt4)]
+[InlineData(ChatCompletionModels.Gpt3_5_Turbo_1106)]
+public async void Get_structured_response_with_extra_data_from_ChatGPT(string model)
{
var message = Dialog
.StartAsSystem("Return requested data.")
.ThenUser("I need info about Almaty city");
-var response = await _client.GetStructuredResponse<City>(message, model: ChatCompletionModels.Gpt4Turbo);
+var response = await _client.GetStructuredResponse<City>(message, model: model);
response.Should().NotBeNull();
response.Name.Should().Be("Almaty");
response.Country.Should().Be("Kazakhstan");
