feat: Use Microsoft.Extensions.AI #66

Merged 8 commits on Nov 21, 2024
2 changes: 2 additions & 0 deletions src/Cellm/AddIn/CellmConfiguration.cs
@@ -15,5 +15,7 @@ public class CellmConfiguration
     public int HttpTimeoutInSeconds { get; init; }
 
     public int CacheTimeoutInSeconds { get; init; }
+
+    public bool EnableTools { get; init; }
 }

2 changes: 1 addition & 1 deletion src/Cellm/AddIn/CellmFunctions.cs
@@ -120,6 +120,6 @@ private static async Task<string> CallModelAsync(Prompt prompt, string? provider
     {
         var client = ServiceLocator.Get<Client>();
         var response = await client.Send(prompt, provider, baseAddress);
-        return response.Messages.Last().Content;
+        return response.Messages.Last().Text ?? throw new NullReferenceException("No text response");
     }
 }
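
In Microsoft.Extensions.AI, ChatMessage.Text is null when a message carries no text content (for example, a message holding only function-call content), which is why the bare .Content access becomes an explicit null check. A minimal sketch of the new shape, using only types from the abstractions package:

using Microsoft.Extensions.AI;

var withText = new ChatMessage(ChatRole.Assistant, "4");
Console.WriteLine(withText.Text); // "4"

// A message with no text content has a null Text, which the add-in now
// surfaces as an explicit error instead of silently writing an empty cell.
var withoutText = new ChatMessage(ChatRole.Assistant, contents: []);
Console.WriteLine(withoutText.Text ?? "No text response");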
5 changes: 5 additions & 0 deletions src/Cellm/Cellm.csproj
@@ -23,6 +23,11 @@
     <PackageReference Include="JsonPatch.Net" Version="3.1.1" />
     <PackageReference Include="JsonSchema.Net.Generation" Version="4.5.1" />
     <PackageReference Include="MediatR" Version="12.4.1" />
+    <PackageReference Include="Microsoft.Extensions.AI" Version="9.0.0-preview.9.24556.5" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24556.5" />
+    <PackageReference Include="Microsoft.Extensions.AI.Ollama" Version="9.0.0-preview.9.24556.5" />
+    <PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.0.0-preview.9.24556.5" />
+    <PackageReference Include="Microsoft.Extensions.Caching.Hybrid" Version="9.0.0-preview.9.24556.5" />
     <PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.0" />
     <PackageReference Include="Microsoft.Extensions.Configuration" Version="9.0.0" />
     <PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="8.0.1" />
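
These provider packages all target the same IChatClient abstraction, which is what lets Cellm route OpenAI, Ollama, and Llamafile traffic through one code path. A minimal sketch, assuming a local Ollama server on its default port with a pulled llama3.1 model (both assumptions, not part of the PR):

using Microsoft.Extensions.AI;

// Any provider package yields an IChatClient; swapping providers means
// swapping this one line, not the calling code.
IChatClient client = new OllamaChatClient(new Uri("http://localhost:11434"), "llama3.1");

var chatCompletion = await client.CompleteAsync("What is 2 + 2?");
Console.WriteLine(chatCompletion.Message.Text);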
9 changes: 5 additions & 4 deletions src/Cellm/Models/Anthropic/AnthropicRequestHandler.cs
@@ -6,6 +6,7 @@
 using Cellm.AddIn.Exceptions;
 using Cellm.Models.Anthropic.Models;
 using Cellm.Prompts;
+using Microsoft.Extensions.AI;
 using Microsoft.Extensions.Options;
 
 namespace Cellm.Models.Anthropic;

@@ -52,11 +53,11 @@ public string Serialize(AnthropicRequest request)
     {
         var requestBody = new AnthropicRequestBody
         {
-            System = request.Prompt.SystemMessage,
-            Messages = request.Prompt.Messages.Select(x => new AnthropicMessage { Content = x.Content, Role = x.Role.ToString().ToLower() }).ToList(),
-            Model = request.Prompt.Model ?? _anthropicConfiguration.DefaultModel,
+            System = request.Prompt.Messages.Where(x => x.Role == ChatRole.System).First().Text,
+            Messages = request.Prompt.Messages.Select(x => new AnthropicMessage { Content = x.Text, Role = x.Role.ToString().ToLower() }).ToList(),
+            Model = request.Prompt.Options.ModelId ?? _anthropicConfiguration.DefaultModel,
             MaxTokens = _cellmConfiguration.MaxOutputTokens,
-            Temperature = request.Prompt.Temperature
+            Temperature = request.Prompt.Options.Temperature ?? _cellmConfiguration.DefaultTemperature,
         };
 
         return _serde.Serialize(requestBody, new JsonSerializerOptions
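
The prompt no longer has a dedicated SystemMessage property: every message, including the system prompt, now lives in one Microsoft.Extensions.AI ChatMessage list, and Anthropic's separate top-level system field is recovered by filtering on role. A minimal sketch of that lookup, with illustrative message text:

using Microsoft.Extensions.AI;

var messages = new List<ChatMessage>
{
    new(ChatRole.System, "You are a helpful assistant."),
    new(ChatRole.User, "Summarize A1:B10.")
};

// Anthropic's API takes the system prompt out of band, so it is pulled
// back out of the unified list by role before serialization.
var system = messages.Where(x => x.Role == ChatRole.System).First().Text;
var rest = messages.Where(x => x.Role != ChatRole.System).ToList();

Note that First() throws if no system message is present, so this relies on the prompt always carrying one.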
3 changes: 0 additions & 3 deletions src/Cellm/Models/IModelRequestHandler.cs
@@ -5,7 +5,4 @@ namespace Cellm.Models;
 internal interface IModelRequestHandler<TRequest, TResponse> : IRequestHandler<TRequest, TResponse>
     where TRequest : IRequest<TResponse>
 {
-    public string Serialize(TRequest request);
-
-    public TResponse Deserialize(TRequest request, string response);
 }
4 changes: 2 additions & 2 deletions src/Cellm/Models/Llamafile/LlamafileRequestHandler.cs
@@ -57,8 +57,8 @@ public LlamafileRequestHandler(IOptions<CellmConfiguration> cellmConfiguration,
 
     public async Task<LlamafileResponse> Handle(LlamafileRequest request, CancellationToken cancellationToken)
     {
-        // Download model and start Llamafile on first call
-        var llamafile = await _llamafiles[request.Prompt.Model ?? _llamafileConfiguration.DefaultModel];
+        // Start server on first call
+        var llamafile = await _llamafiles[request.Prompt.Options.ModelId ?? _llamafileConfiguration.DefaultModel];
 
         var openAiResponse = await _sender.Send(new OpenAiRequest(request.Prompt, nameof(Llamafile), llamafile.BaseAddress), cancellationToken);
2 changes: 1 addition & 1 deletion src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs
@@ -6,7 +6,7 @@
 using Microsoft.Extensions.Caching.Memory;
 using Microsoft.Extensions.Options;
 
-namespace Cellm.Models.PipelineBehavior;
+namespace Cellm.Models.ModelRequestBehavior;
 
 internal class CachingBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
     where TRequest : IModelRequest<TResponse>
2 changes: 1 addition & 1 deletion src/Cellm/Models/ModelRequestBehavior/SentryBehavior.cs
@@ -1,6 +1,6 @@
 using MediatR;
 
-namespace Cellm.Models.PipelineBehavior;
+namespace Cellm.Models.ModelRequestBehavior;
 
 internal class SentryBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
     where TRequest : notnull
45 changes: 18 additions & 27 deletions src/Cellm/Models/ModelRequestBehavior/ToolBehavior.cs
@@ -1,46 +1,37 @@
-using Cellm.Prompts;
+using Cellm.AddIn;
 using Cellm.Tools;
 using MediatR;
+using Microsoft.Extensions.AI;
+using Microsoft.Extensions.Options;
 
-namespace Cellm.Models.PipelineBehavior;
+namespace Cellm.Models.ModelRequestBehavior;
 
 internal class ToolBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
     where TRequest : IModelRequest<TResponse>
     where TResponse : IModelResponse
 {
-    private readonly ISender _sender;
-    private readonly ToolRunner _toolRunner;
+    private readonly CellmConfiguration _cellmConfiguration;
+    private readonly Functions _functions;
+    private readonly List<AITool> _tools;
 
-    public ToolBehavior(ISender sender, ToolRunner toolRunner)
+    public ToolBehavior(IOptions<CellmConfiguration> cellmConfiguration, Functions functions)
     {
-        _sender = sender;
-        _toolRunner = toolRunner;
+        _cellmConfiguration = cellmConfiguration.Value;
+        _functions = functions;
+        _tools = [
+            AIFunctionFactory.Create(_functions.GlobRequest),
+            AIFunctionFactory.Create(_functions.FileReaderRequest)
+        ];
     }
 
     public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
     {
-        var response = await next();
-
-        var toolCalls = response.Prompt.Messages.LastOrDefault()?.ToolCalls;
-
-        if (toolCalls is not null)
+        if (_cellmConfiguration.EnableTools)
         {
-            // Model called tools, run tools and call model again
-            var message = await RunTools(toolCalls);
-            request.Prompt.Messages.Add(message);
-            response = await _sender.Send(request, cancellationToken);
+            request.Prompt.Options.Tools = _tools;
        }
 
-        return response;
-    }
-
-    private async Task<Message> RunTools(List<ToolCall> toolCalls)
-    {
-        var toolResults = await Task.WhenAll(toolCalls.Select(x => _toolRunner.Run(x)));
-        var toolCallsWithResults = toolCalls
-            .Zip(toolResults, (toolCall, toolResult) => toolCall with { Result = toolResult })
-            .ToList();
+        var response = await next();
 
-        return new Message(string.Empty, Roles.Tool, toolCallsWithResults);
+        return response;
     }
 }
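
The hand-written tool loop (run the calls, append a tool message, resend the request) is gone; the behavior now only attaches tools to ChatOptions and leaves execution to the function-invocation middleware in the chat client pipeline. A minimal sketch of that division of labor, with a hypothetical tool standing in for Cellm's Glob/FileReader functions:

using Microsoft.Extensions.AI;

// Hypothetical tool, not from the PR; AIFunctionFactory reflects over the
// method to build the JSON schema the model sees.
static string GetCellComment(string cell) => $"No comment on {cell}";

var options = new ChatOptions
{
    Tools = [AIFunctionFactory.Create(GetCellComment)]
};

// With .UseFunctionInvocation() in the client pipeline (see the
// OpenAiRequestHandler change below), the decorated client executes any
// tool calls the model emits and re-sends the conversation until a final
// text answer arrives, which is what the deleted RunTools loop did by hand.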
50 changes: 0 additions & 50 deletions src/Cellm/Models/OpenAi/Extensions.cs

This file was deleted.

53 changes: 0 additions & 53 deletions src/Cellm/Models/OpenAi/Models.cs

This file was deleted.

3 changes: 0 additions & 3 deletions src/Cellm/Models/OpenAi/OpenAiConfiguration.cs
@@ -10,13 +10,10 @@ internal class OpenAiConfiguration : IProviderConfiguration
 
     public string ApiKey { get; init; }
 
-    public bool EnableTools { get; init; }
-
     public OpenAiConfiguration()
     {
         BaseAddress = default!;
         DefaultModel = default!;
         ApiKey = default!;
-        EnableTools = default;
     }
 }
102 changes: 25 additions & 77 deletions src/Cellm/Models/OpenAi/OpenAiRequestHandler.cs
@@ -1,100 +1,48 @@
-using System.Text;
-using System.Text.Encodings.Web;
-using System.Text.Json;
-using System.Text.Json.Serialization;
-using Cellm.AddIn;
-using Cellm.AddIn.Exceptions;
-using Cellm.Models.OpenAi.Models;
+using System.ClientModel;
+using System.ClientModel.Primitives;
 using Cellm.Prompts;
-using Cellm.Tools;
+using Microsoft.Extensions.AI;
 using Microsoft.Extensions.Options;
+using OpenAI;
 
 namespace Cellm.Models.OpenAi;
 
-internal class OpenAiRequestHandler : IModelRequestHandler<OpenAiRequest, OpenAiResponse>
+internal class OpenAiRequestHandler(IOptions<OpenAiConfiguration> openAiConfiguration, HttpClient httpClient) : IModelRequestHandler<OpenAiRequest, OpenAiResponse>
 {
-    private readonly OpenAiConfiguration _openAiConfiguration;
-    private readonly CellmConfiguration _cellmConfiguration;
-    private readonly HttpClient _httpClient;
-    private readonly ToolRunner _toolRunner;
-    private readonly Serde _serde;
-
-    public OpenAiRequestHandler(
-        IOptions<OpenAiConfiguration> openAiConfiguration,
-        IOptions<CellmConfiguration> cellmConfiguration,
-        HttpClient httpClient,
-        ToolRunner toolRunner,
-        Serde serde)
-    {
-        _openAiConfiguration = openAiConfiguration.Value;
-        _cellmConfiguration = cellmConfiguration.Value;
-        _httpClient = httpClient;
-        _toolRunner = toolRunner;
-        _serde = serde;
-    }
+    private readonly OpenAiConfiguration _openAiConfiguration = openAiConfiguration.Value;
 
     public async Task<OpenAiResponse> Handle(OpenAiRequest request, CancellationToken cancellationToken)
     {
+        var modelId = request.Prompt.Options.ModelId ?? _openAiConfiguration.DefaultModel;
+
         const string path = "/v1/chat/completions";
         var address = request.BaseAddress is null ? new Uri(path, UriKind.Relative) : new Uri(request.BaseAddress, path);
 
-        var json = Serialize(request);
-        var jsonAsStringContent = new StringContent(json, Encoding.UTF8, "application/json");
-
-        var response = await _httpClient.PostAsync(address, jsonAsStringContent, cancellationToken);
-        var responseBodyAsString = await response.Content.ReadAsStringAsync(cancellationToken);
+        // Must instantiate manually because address can be set/changed only at instantiation
+        var chatClient = GetChatClient(address, modelId);
+        var chatCompletion = await chatClient.CompleteAsync(request.Prompt.Messages, request.Prompt.Options, cancellationToken);
 
-        if (!response.IsSuccessStatusCode)
-        {
-            throw new HttpRequestException($"{nameof(OpenAiRequest)} failed: {responseBodyAsString}", null, response.StatusCode);
-        }
-
-        return Deserialize(request, responseBodyAsString);
-    }
-
-    public string Serialize(OpenAiRequest request)
-    {
-        var openAiPrompt = new PromptBuilder(request.Prompt)
-            .AddSystemMessage()
+        var prompt = new PromptBuilder(request.Prompt)
+            .AddMessage(chatCompletion.Message)
            .Build();
 
-        var chatCompletionRequest = new OpenAiChatCompletionRequest(
-            openAiPrompt.Model,
-            openAiPrompt.ToOpenAiMessages(),
-            _cellmConfiguration.MaxOutputTokens,
-            openAiPrompt.Temperature,
-            _toolRunner.ToOpenAiTools(),
-            "auto");
-
-        return _serde.Serialize(chatCompletionRequest, new JsonSerializerOptions
-        {
-            PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
-            Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
-            DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
-        });
+        return new OpenAiResponse(prompt);
     }
 
-    public OpenAiResponse Deserialize(OpenAiRequest request, string responseBodyAsString)
+    private IChatClient GetChatClient(Uri address, string modelId)
     {
-        var responseBody = _serde.Deserialize<OpenAiChatCompletionResponse>(responseBodyAsString, new JsonSerializerOptions
-        {
-            PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
-            Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping,
-            DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
-        });
-
-        var choice = responseBody?.Choices?.FirstOrDefault() ?? throw new CellmException("Empty response from OpenAI API");
-        var toolCalls = choice.Message.ToolCalls?
-            .Select(x => new ToolCall(x.Id, x.Function.Name, x.Function.Arguments, null))
-            .ToList();
+        var openAiClientCredentials = new ApiKeyCredential(_openAiConfiguration.ApiKey);
+        var openAiClientOptions = new OpenAIClientOptions
        {
+            Transport = new HttpClientPipelineTransport(httpClient),
+            Endpoint = address
+        };
 
-        var content = choice.Message.Content;
-        var message = new Message(content, Roles.Assistant, toolCalls);
+        var openAiClient = new OpenAIClient(openAiClientCredentials, openAiClientOptions);
 
-        var prompt = new PromptBuilder(request.Prompt)
-            .AddMessage(message)
-            .Build();
-
-        return new OpenAiResponse(prompt);
+        return new ChatClientBuilder()
+            .UseLogging()
+            .UseFunctionInvocation()
+            .Use(openAiClient.AsChatClient(modelId));
     }
 }
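
For reference, a sketch of the consumption pattern the new handler settles on: callers build a plain ChatMessage list plus ChatOptions and make one CompleteAsync call, and the ChatClientBuilder decorators handle the rest. The method name and message text here are illustrative, not from the PR:

using Microsoft.Extensions.AI;

// chatClient stands in for whatever GetChatClient(address, modelId) returns.
static async Task<string> AskAsync(IChatClient chatClient)
{
    var messages = new List<ChatMessage>
    {
        new(ChatRole.System, "You are a helpful assistant."),
        new(ChatRole.User, "What is 2 + 2?")
    };

    var options = new ChatOptions { Temperature = 0f };

    // One call drives the whole decorator pipeline: logging, automatic
    // function invocation, then the OpenAI-compatible endpoint.
    var chatCompletion = await chatClient.CompleteAsync(messages, options);

    return chatCompletion.Message.Text ?? throw new NullReferenceException("No text response");
}

Because the endpoint can only be set when the OpenAIClient is constructed, as the comment in the diff notes, the handler builds the client per request instead of resolving a shared instance.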