Skip to content

Commit

Permalink
feat: Use Microsoft.Extensions.AI (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaspermarstal authored Nov 21, 2024
1 parent 40652b9 commit 27609c0
Show file tree
Hide file tree
Showing 36 changed files with 228 additions and 566 deletions.
2 changes: 2 additions & 0 deletions src/Cellm/AddIn/CellmConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ public class CellmConfiguration
public int HttpTimeoutInSeconds { get; init; }

public int CacheTimeoutInSeconds { get; init; }

public bool EnableTools { get; init; }
}

2 changes: 1 addition & 1 deletion src/Cellm/AddIn/CellmFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,6 @@ private static async Task<string> CallModelAsync(Prompt prompt, string? provider
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
return response.Messages.Last().Content;
return response.Messages.Last().Text ?? throw new NullReferenceException("No text response");
}
}
5 changes: 5 additions & 0 deletions src/Cellm/Cellm.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
<PackageReference Include="JsonPatch.Net" Version="3.1.1" />
<PackageReference Include="JsonSchema.Net.Generation" Version="4.5.1" />
<PackageReference Include="MediatR" Version="12.4.1" />
<PackageReference Include="Microsoft.Extensions.AI" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.AI.Ollama" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.Caching.Hybrid" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="8.0.1" />
Expand Down
9 changes: 5 additions & 4 deletions src/Cellm/Models/Anthropic/AnthropicRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Cellm.AddIn.Exceptions;
using Cellm.Models.Anthropic.Models;
using Cellm.Prompts;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Anthropic;
Expand Down Expand Up @@ -52,11 +53,11 @@ public string Serialize(AnthropicRequest request)
{
var requestBody = new AnthropicRequestBody
{
System = request.Prompt.SystemMessage,
Messages = request.Prompt.Messages.Select(x => new AnthropicMessage { Content = x.Content, Role = x.Role.ToString().ToLower() }).ToList(),
Model = request.Prompt.Model ?? _anthropicConfiguration.DefaultModel,
System = request.Prompt.Messages.Where(x => x.Role == ChatRole.System).First().Text,
Messages = request.Prompt.Messages.Select(x => new AnthropicMessage { Content = x.Text, Role = x.Role.ToString().ToLower() }).ToList(),
Model = request.Prompt.Options.ModelId ?? _anthropicConfiguration.DefaultModel,
MaxTokens = _cellmConfiguration.MaxOutputTokens,
Temperature = request.Prompt.Temperature
Temperature = request.Prompt.Options.Temperature ?? _cellmConfiguration.DefaultTemperature,
};

return _serde.Serialize(requestBody, new JsonSerializerOptions
Expand Down
3 changes: 0 additions & 3 deletions src/Cellm/Models/IModelRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,4 @@ namespace Cellm.Models;
internal interface IModelRequestHandler<TRequest, TResponse> : IRequestHandler<TRequest, TResponse>
where TRequest : IRequest<TResponse>
{
public string Serialize(TRequest request);

public TResponse Deserialize(TRequest request, string response);
}
4 changes: 2 additions & 2 deletions src/Cellm/Models/Llamafile/LlamafileRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ public LlamafileRequestHandler(IOptions<CellmConfiguration> cellmConfiguration,

public async Task<LlamafileResponse> Handle(LlamafileRequest request, CancellationToken cancellationToken)
{
// Download model and start Llamafile on first call
var llamafile = await _llamafiles[request.Prompt.Model ?? _llamafileConfiguration.DefaultModel];
// Start server on first call
var llamafile = await _llamafiles[request.Prompt.Options.ModelId ?? _llamafileConfiguration.DefaultModel];

var openAiResponse = await _sender.Send(new OpenAiRequest(request.Prompt, nameof(Llamafile), llamafile.BaseAddress), cancellationToken);

Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;

namespace Cellm.Models.PipelineBehavior;
namespace Cellm.Models.ModelRequestBehavior;

internal class CachingBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : IModelRequest<TResponse>
Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/Models/ModelRequestBehavior/SentryBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using MediatR;

namespace Cellm.Models.PipelineBehavior;
namespace Cellm.Models.ModelRequestBehavior;

internal class SentryBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : notnull
Expand Down
45 changes: 18 additions & 27 deletions src/Cellm/Models/ModelRequestBehavior/ToolBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,46 +1,37 @@
using Cellm.Prompts;
using Cellm.AddIn;
using Cellm.Tools;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;

namespace Cellm.Models.PipelineBehavior;
namespace Cellm.Models.ModelRequestBehavior;

internal class ToolBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : IModelRequest<TResponse>
where TResponse : IModelResponse
{
private readonly ISender _sender;
private readonly ToolRunner _toolRunner;
private readonly CellmConfiguration _cellmConfiguration;
private readonly Functions _functions;
private readonly List<AITool> _tools;

public ToolBehavior(ISender sender, ToolRunner toolRunner)
public ToolBehavior(IOptions<CellmConfiguration> cellmConfiguration, Functions functions)
{
_sender = sender;
_toolRunner = toolRunner;
_cellmConfiguration = cellmConfiguration.Value;
_functions = functions;
_tools = [
AIFunctionFactory.Create(_functions.GlobRequest),
AIFunctionFactory.Create(_functions.FileReaderRequest)
];
}

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
var response = await next();

var toolCalls = response.Prompt.Messages.LastOrDefault()?.ToolCalls;

if (toolCalls is not null)
if (_cellmConfiguration.EnableTools)
{
// Model called tools, run tools and call model again
var message = await RunTools(toolCalls);
request.Prompt.Messages.Add(message);
response = await _sender.Send(request, cancellationToken);
request.Prompt.Options.Tools = _tools;
}

return response;
}

private async Task<Message> RunTools(List<ToolCall> toolCalls)
{
var toolResults = await Task.WhenAll(toolCalls.Select(x => _toolRunner.Run(x)));
var toolCallsWithResults = toolCalls
.Zip(toolResults, (toolCall, toolResult) => toolCall with { Result = toolResult })
.ToList();
var response = await next();

return new Message(string.Empty, Roles.Tool, toolCallsWithResults);
return response;
}
}
50 changes: 0 additions & 50 deletions src/Cellm/Models/OpenAi/Extensions.cs

This file was deleted.

53 changes: 0 additions & 53 deletions src/Cellm/Models/OpenAi/Models.cs

This file was deleted.

3 changes: 0 additions & 3 deletions src/Cellm/Models/OpenAi/OpenAiConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@ internal class OpenAiConfiguration : IProviderConfiguration

public string ApiKey { get; init; }

public bool EnableTools { get; init; }

public OpenAiConfiguration()
{
BaseAddress = default!;
DefaultModel = default!;
ApiKey = default!;
EnableTools = default;
}
}
102 changes: 25 additions & 77 deletions src/Cellm/Models/OpenAi/OpenAiRequestHandler.cs
Original file line number Diff line number Diff line change
@@ -1,100 +1,48 @@
using System.Text;
using System.Text.Encodings.Web;
using System.Text.Json;
using System.Text.Json.Serialization;
using Cellm.AddIn;
using Cellm.AddIn.Exceptions;
using Cellm.Models.OpenAi.Models;
using System.ClientModel;
using System.ClientModel.Primitives;
using Cellm.Prompts;
using Cellm.Tools;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;
using OpenAI;

namespace Cellm.Models.OpenAi;

internal class OpenAiRequestHandler : IModelRequestHandler<OpenAiRequest, OpenAiResponse>
internal class OpenAiRequestHandler(IOptions<OpenAiConfiguration> openAiConfiguration, HttpClient httpClient) : IModelRequestHandler<OpenAiRequest, OpenAiResponse>
{
private readonly OpenAiConfiguration _openAiConfiguration;
private readonly CellmConfiguration _cellmConfiguration;
private readonly HttpClient _httpClient;
private readonly ToolRunner _toolRunner;
private readonly Serde _serde;

public OpenAiRequestHandler(
IOptions<OpenAiConfiguration> openAiConfiguration,
IOptions<CellmConfiguration> cellmConfiguration,
HttpClient httpClient,
ToolRunner toolRunner,
Serde serde)
{
_openAiConfiguration = openAiConfiguration.Value;
_cellmConfiguration = cellmConfiguration.Value;
_httpClient = httpClient;
_toolRunner = toolRunner;
_serde = serde;
}
private readonly OpenAiConfiguration _openAiConfiguration = openAiConfiguration.Value;

public async Task<OpenAiResponse> Handle(OpenAiRequest request, CancellationToken cancellationToken)
{
var modelId = request.Prompt.Options.ModelId ?? _openAiConfiguration.DefaultModel;

const string path = "/v1/chat/completions";
var address = request.BaseAddress is null ? new Uri(path, UriKind.Relative) : new Uri(request.BaseAddress, path);

var json = Serialize(request);
var jsonAsStringContent = new StringContent(json, Encoding.UTF8, "application/json");

var response = await _httpClient.PostAsync(address, jsonAsStringContent, cancellationToken);
var responseBodyAsString = await response.Content.ReadAsStringAsync(cancellationToken);
// Must instantiate manually because address can be set/changed only at instantiation
var chatClient = GetChatClient(address, modelId);
var chatCompletion = await chatClient.CompleteAsync(request.Prompt.Messages, request.Prompt.Options, cancellationToken);

if (!response.IsSuccessStatusCode)
{
throw new HttpRequestException($"{nameof(OpenAiRequest)} failed: {responseBodyAsString}", null, response.StatusCode);
}

return Deserialize(request, responseBodyAsString);
}

public string Serialize(OpenAiRequest request)
{
var openAiPrompt = new PromptBuilder(request.Prompt)
.AddSystemMessage()
var prompt = new PromptBuilder(request.Prompt)
.AddMessage(chatCompletion.Message)
.Build();

var chatCompletionRequest = new OpenAiChatCompletionRequest(
openAiPrompt.Model,
openAiPrompt.ToOpenAiMessages(),
_cellmConfiguration.MaxOutputTokens,
openAiPrompt.Temperature,
_toolRunner.ToOpenAiTools(),
"auto");

return _serde.Serialize(chatCompletionRequest, new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
});
return new OpenAiResponse(prompt);
}

public OpenAiResponse Deserialize(OpenAiRequest request, string responseBodyAsString)
private IChatClient GetChatClient(Uri address, string modelId)
{
var responseBody = _serde.Deserialize<OpenAiChatCompletionResponse>(responseBodyAsString, new JsonSerializerOptions
var openAiClientCredentials = new ApiKeyCredential(_openAiConfiguration.ApiKey);
var openAiClientOptions = new OpenAIClientOptions
{
PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
});

var choice = responseBody?.Choices?.FirstOrDefault() ?? throw new CellmException("Empty response from OpenAI API");
var toolCalls = choice.Message.ToolCalls?
.Select(x => new ToolCall(x.Id, x.Function.Name, x.Function.Arguments, null))
.ToList();
Transport = new HttpClientPipelineTransport(httpClient),
Endpoint = address
};

var content = choice.Message.Content;
var message = new Message(content, Roles.Assistant, toolCalls);
var openAiClient = new OpenAIClient(openAiClientCredentials, openAiClientOptions);

var prompt = new PromptBuilder(request.Prompt)
.AddMessage(message)
.Build();

return new OpenAiResponse(prompt);
return new ChatClientBuilder()
.UseLogging()
.UseFunctionInvocation()
.Use(openAiClient.AsChatClient(modelId));
}
}
Loading

0 comments on commit 27609c0

Please sign in to comment.