-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Use Microsoft.Extensions.AI (#66)
- Loading branch information
1 parent
40652b9
commit 27609c0
Showing
36 changed files
with
228 additions
and
566 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,37 @@ | ||
using Cellm.Prompts; | ||
using Cellm.AddIn; | ||
using Cellm.Tools; | ||
using MediatR; | ||
using Microsoft.Extensions.AI; | ||
using Microsoft.Extensions.Options; | ||
|
||
namespace Cellm.Models.PipelineBehavior; | ||
namespace Cellm.Models.ModelRequestBehavior; | ||
|
||
internal class ToolBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse> | ||
where TRequest : IModelRequest<TResponse> | ||
where TResponse : IModelResponse | ||
{ | ||
private readonly ISender _sender; | ||
private readonly ToolRunner _toolRunner; | ||
private readonly CellmConfiguration _cellmConfiguration; | ||
private readonly Functions _functions; | ||
private readonly List<AITool> _tools; | ||
|
||
public ToolBehavior(ISender sender, ToolRunner toolRunner) | ||
public ToolBehavior(IOptions<CellmConfiguration> cellmConfiguration, Functions functions) | ||
{ | ||
_sender = sender; | ||
_toolRunner = toolRunner; | ||
_cellmConfiguration = cellmConfiguration.Value; | ||
_functions = functions; | ||
_tools = [ | ||
AIFunctionFactory.Create(_functions.GlobRequest), | ||
AIFunctionFactory.Create(_functions.FileReaderRequest) | ||
]; | ||
} | ||
|
||
public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken) | ||
{ | ||
var response = await next(); | ||
|
||
var toolCalls = response.Prompt.Messages.LastOrDefault()?.ToolCalls; | ||
|
||
if (toolCalls is not null) | ||
if (_cellmConfiguration.EnableTools) | ||
{ | ||
// Model called tools, run tools and call model again | ||
var message = await RunTools(toolCalls); | ||
request.Prompt.Messages.Add(message); | ||
response = await _sender.Send(request, cancellationToken); | ||
request.Prompt.Options.Tools = _tools; | ||
} | ||
|
||
return response; | ||
} | ||
|
||
private async Task<Message> RunTools(List<ToolCall> toolCalls) | ||
{ | ||
var toolResults = await Task.WhenAll(toolCalls.Select(x => _toolRunner.Run(x))); | ||
var toolCallsWithResults = toolCalls | ||
.Zip(toolResults, (toolCall, toolResult) => toolCall with { Result = toolResult }) | ||
.ToList(); | ||
var response = await next(); | ||
|
||
return new Message(string.Empty, Roles.Tool, toolCallsWithResults); | ||
return response; | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,100 +1,48 @@ | ||
using System.Text; | ||
using System.Text.Encodings.Web; | ||
using System.Text.Json; | ||
using System.Text.Json.Serialization; | ||
using Cellm.AddIn; | ||
using Cellm.AddIn.Exceptions; | ||
using Cellm.Models.OpenAi.Models; | ||
using System.ClientModel; | ||
using System.ClientModel.Primitives; | ||
using Cellm.Prompts; | ||
using Cellm.Tools; | ||
using Microsoft.Extensions.AI; | ||
using Microsoft.Extensions.Options; | ||
using OpenAI; | ||
|
||
namespace Cellm.Models.OpenAi; | ||
|
||
internal class OpenAiRequestHandler : IModelRequestHandler<OpenAiRequest, OpenAiResponse> | ||
internal class OpenAiRequestHandler(IOptions<OpenAiConfiguration> openAiConfiguration, HttpClient httpClient) : IModelRequestHandler<OpenAiRequest, OpenAiResponse> | ||
{ | ||
private readonly OpenAiConfiguration _openAiConfiguration; | ||
private readonly CellmConfiguration _cellmConfiguration; | ||
private readonly HttpClient _httpClient; | ||
private readonly ToolRunner _toolRunner; | ||
private readonly Serde _serde; | ||
|
||
public OpenAiRequestHandler( | ||
IOptions<OpenAiConfiguration> openAiConfiguration, | ||
IOptions<CellmConfiguration> cellmConfiguration, | ||
HttpClient httpClient, | ||
ToolRunner toolRunner, | ||
Serde serde) | ||
{ | ||
_openAiConfiguration = openAiConfiguration.Value; | ||
_cellmConfiguration = cellmConfiguration.Value; | ||
_httpClient = httpClient; | ||
_toolRunner = toolRunner; | ||
_serde = serde; | ||
} | ||
private readonly OpenAiConfiguration _openAiConfiguration = openAiConfiguration.Value; | ||
|
||
public async Task<OpenAiResponse> Handle(OpenAiRequest request, CancellationToken cancellationToken) | ||
{ | ||
var modelId = request.Prompt.Options.ModelId ?? _openAiConfiguration.DefaultModel; | ||
|
||
const string path = "/v1/chat/completions"; | ||
var address = request.BaseAddress is null ? new Uri(path, UriKind.Relative) : new Uri(request.BaseAddress, path); | ||
|
||
var json = Serialize(request); | ||
var jsonAsStringContent = new StringContent(json, Encoding.UTF8, "application/json"); | ||
|
||
var response = await _httpClient.PostAsync(address, jsonAsStringContent, cancellationToken); | ||
var responseBodyAsString = await response.Content.ReadAsStringAsync(cancellationToken); | ||
// Must instantiate manually because address can be set/changed only at instantiation | ||
var chatClient = GetChatClient(address, modelId); | ||
var chatCompletion = await chatClient.CompleteAsync(request.Prompt.Messages, request.Prompt.Options, cancellationToken); | ||
|
||
if (!response.IsSuccessStatusCode) | ||
{ | ||
throw new HttpRequestException($"{nameof(OpenAiRequest)} failed: {responseBodyAsString}", null, response.StatusCode); | ||
} | ||
|
||
return Deserialize(request, responseBodyAsString); | ||
} | ||
|
||
public string Serialize(OpenAiRequest request) | ||
{ | ||
var openAiPrompt = new PromptBuilder(request.Prompt) | ||
.AddSystemMessage() | ||
var prompt = new PromptBuilder(request.Prompt) | ||
.AddMessage(chatCompletion.Message) | ||
.Build(); | ||
|
||
var chatCompletionRequest = new OpenAiChatCompletionRequest( | ||
openAiPrompt.Model, | ||
openAiPrompt.ToOpenAiMessages(), | ||
_cellmConfiguration.MaxOutputTokens, | ||
openAiPrompt.Temperature, | ||
_toolRunner.ToOpenAiTools(), | ||
"auto"); | ||
|
||
return _serde.Serialize(chatCompletionRequest, new JsonSerializerOptions | ||
{ | ||
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, | ||
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, | ||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull | ||
}); | ||
return new OpenAiResponse(prompt); | ||
} | ||
|
||
public OpenAiResponse Deserialize(OpenAiRequest request, string responseBodyAsString) | ||
private IChatClient GetChatClient(Uri address, string modelId) | ||
{ | ||
var responseBody = _serde.Deserialize<OpenAiChatCompletionResponse>(responseBodyAsString, new JsonSerializerOptions | ||
var openAiClientCredentials = new ApiKeyCredential(_openAiConfiguration.ApiKey); | ||
var openAiClientOptions = new OpenAIClientOptions | ||
{ | ||
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, | ||
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, | ||
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull | ||
}); | ||
|
||
var choice = responseBody?.Choices?.FirstOrDefault() ?? throw new CellmException("Empty response from OpenAI API"); | ||
var toolCalls = choice.Message.ToolCalls? | ||
.Select(x => new ToolCall(x.Id, x.Function.Name, x.Function.Arguments, null)) | ||
.ToList(); | ||
Transport = new HttpClientPipelineTransport(httpClient), | ||
Endpoint = address | ||
}; | ||
|
||
var content = choice.Message.Content; | ||
var message = new Message(content, Roles.Assistant, toolCalls); | ||
var openAiClient = new OpenAIClient(openAiClientCredentials, openAiClientOptions); | ||
|
||
var prompt = new PromptBuilder(request.Prompt) | ||
.AddMessage(message) | ||
.Build(); | ||
|
||
return new OpenAiResponse(prompt); | ||
return new ChatClientBuilder() | ||
.UseLogging() | ||
.UseFunctionInvocation() | ||
.Use(openAiClient.AsChatClient(modelId)); | ||
} | ||
} |
Oops, something went wrong.