refactor: Update target framework to net8.0 and use Microsoft.Extensions.AI #61

Closed · wants to merge 1 commit
2 changes: 1 addition & 1 deletion src/Cellm.Tests/Cellm.Tests.csproj
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0-windows</TargetFramework>
<TargetFramework>net8.0-windows</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<RollForward>LatestMajor</RollForward>
2 changes: 1 addition & 1 deletion src/Cellm.Tests/packages.lock.json
@@ -1,7 +1,7 @@
{
"version": 1,
"dependencies": {
"net6.0-windows7.0": {
"net8.0-windows7.0": {
"ExcelDna.Testing": {
"type": "Direct",
"requested": "[1.8.0, )",
2 changes: 2 additions & 0 deletions src/Cellm/AddIn/CellmConfiguration.cs
@@ -15,5 +15,7 @@ public class CellmConfiguration
public int HttpTimeoutInSeconds { get; init; }

public int CacheTimeoutInSeconds { get; init; }

public bool EnableTools { get; init; }
}
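
The new `EnableTools` flag is consumed through the options pattern — `ToolBehavior` below takes `IOptions<CellmConfiguration>`. A minimal sketch of binding and reading it; the section name `"CellmConfiguration"` and the in-memory value are illustrative assumptions, not taken from this PR:

```csharp
// Minimal sketch: bind and read CellmConfiguration.EnableTools via the options pattern.
// The section name "CellmConfiguration" is an assumption for illustration.
using System;
using System.Collections.Generic;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;

var configuration = new ConfigurationBuilder()
    .AddInMemoryCollection(new Dictionary<string, string?>
    {
        ["CellmConfiguration:EnableTools"] = "true",
    })
    .Build();

var services = new ServiceCollection();
services.Configure<CellmConfiguration>(configuration.GetSection("CellmConfiguration"));

using var provider = services.BuildServiceProvider();
var enableTools = provider.GetRequiredService<IOptions<CellmConfiguration>>().Value.EnableTools;
Console.WriteLine(enableTools); // True

// Stand-in for the class in this diff, trimmed to the property the sketch uses.
public class CellmConfiguration
{
    public bool EnableTools { get; init; }
}
```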

27 changes: 7 additions & 20 deletions src/Cellm/AddIn/Functions.cs → src/Cellm/AddIn/CellmFunctions.cs
@@ -10,7 +10,7 @@

namespace Cellm.AddIn;

public static class Functions
public static class CellmFunctions
{
/// <summary>
/// Sends a prompt to the default model configured in CellmConfiguration.
@@ -73,7 +73,7 @@ public static object PromptWith(
{
try
{
var arguments = ServiceLocator.Get<PromptWithArgumentParser>()
var arguments = ServiceLocator.Get<PromptArgumentParser>()
.AddProvider(providerAndModel)
.AddModel(providerAndModel)
.AddInstructionsOrContext(instructionsOrContext)
@@ -88,8 +88,8 @@ public static object PromptWith(

var prompt = new PromptBuilder()
.SetModel(arguments.Model)
.SetSystemMessage(SystemMessages.SystemMessage)
.SetTemperature(arguments.Temperature)
.AddSystemMessage(SystemMessages.SystemMessage)
.AddUserMessage(userMessage)
.Build();

@@ -102,6 +102,7 @@ public static object PromptWith(
catch (CellmException ex)
{
SentrySdk.CaptureException(ex);
Debug.WriteLine(ex);
return ex.Message;
}
}
@@ -117,22 +117,8 @@ public static object PromptWith(

private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
{
try
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
var content = response.Messages.Last().Content;
return content;
}
catch (CellmException ex)
{
Debug.WriteLine(ex);
throw;
}
catch (Exception ex)
{
Debug.WriteLine(ex);
throw new CellmException("An unexpected error occurred", ex);
}
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
return response.Messages.Last().Text ?? throw new NullReferenceException("No text response");
}
}
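
The simplified `CallModelAsync` above leans on `Microsoft.Extensions.AI`'s message types: the cell result is the `Text` of the last chat message. A small self-contained sketch of that message shape using `ChatMessage`/`ChatRole`; the message contents are placeholders, and Cellm's `PromptBuilder` and `Client` are not reproduced here:

```csharp
// Sketch of the Microsoft.Extensions.AI message shape that CallModelAsync now relies on.
// Message contents are placeholders; the real system prompt is SystemMessages.SystemMessage.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.AI;

var messages = new List<ChatMessage>
{
    new(ChatRole.System, "You are a helpful assistant."),
    new(ChatRole.User, "Extract keywords"),
};

// A provider's reply is appended as an assistant message; its Text becomes the cell value,
// mirroring `response.Messages.Last().Text ?? throw ...` above.
messages.Add(new ChatMessage(ChatRole.Assistant, "keywords: ..."));

var cellValue = messages.Last().Text ?? throw new NullReferenceException("No text response");
Console.WriteLine(cellValue);
```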
@@ -1,5 +1,6 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Prompts;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;
@@ -8,7 +9,7 @@ namespace Cellm.AddIn;

public record Arguments(string Provider, string Model, string Context, string Instructions, double Temperature);

public class PromptWithArgumentParser
public class PromptArgumentParser
{
private string? _provider;
private string? _model;
@@ -18,12 +19,12 @@ public class PromptWithArgumentParser

private readonly IConfiguration _configuration;

public PromptWithArgumentParser(IConfiguration configuration)
public PromptArgumentParser(IConfiguration configuration)
{
_configuration = configuration;
}

public PromptWithArgumentParser AddProvider(object providerAndModel)
public PromptArgumentParser AddProvider(object providerAndModel)
{
_provider = providerAndModel switch
{
@@ -35,7 +36,7 @@ public PromptWithArgumentParser AddProvider(object providerAndModel)
return this;
}

public PromptWithArgumentParser AddModel(object providerAndModel)
public PromptArgumentParser AddModel(object providerAndModel)
{
_model = providerAndModel switch
{
@@ -47,21 +48,21 @@ public PromptWithArgumentParser AddModel(object providerAndModel)
return this;
}

public PromptWithArgumentParser AddInstructionsOrContext(object instructionsOrContext)
public PromptArgumentParser AddInstructionsOrContext(object instructionsOrContext)
{
_instructionsOrContext = instructionsOrContext;

return this;
}

public PromptWithArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
public PromptArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
{
_instructionsOrTemperature = instructionsOrTemperature;

return this;
}

public PromptWithArgumentParser AddTemperature(object temperature)
public PromptArgumentParser AddTemperature(object temperature)
{
_temperature = temperature;

@@ -92,25 +93,25 @@ public Arguments Parse()
// "=PROMPT("Extract keywords", 0.7)
(string instructions, double temperature, ExcelMissing) => new Arguments(provider, model, string.Empty, RenderInstructions(instructions), ParseTemperature(temperature)),
// "=PROMPT(A1:B2)
(ExcelReference context, ExcelMissing, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, ExcelMissing, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, 0.7)
(ExcelReference context, double temperature, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, double temperature, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, "Extract keywords")
(ExcelReference context, string instructions, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, string instructions, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, "Extract keywords", 0.7)
(ExcelReference context, string instructions, double temperature) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(temperature)),
(ExcelReference context, string instructions, double temperature) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(temperature)),
// "=PROMPT(A1:B2, C1:D2)
(ExcelReference context, ExcelReference instructions, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(defaultTemperature)),
(ExcelReference context, ExcelReference instructions, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, C1:D2, 0.7)
(ExcelReference context, ExcelReference instructions, double temperature) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(temperature)),
(ExcelReference context, ExcelReference instructions, double temperature) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(temperature)),
// Anything else
_ => throw new ArgumentException($"Invalid arguments ({_instructionsOrContext?.GetType().Name}, {_instructionsOrTemperature?.GetType().Name}, {_temperature?.GetType().Name})")
};
}

private static string GetProvider(string providerAndModel)
{
var index = providerAndModel.IndexOf("/");
var index = providerAndModel.IndexOf('/');

if (index < 0)
{
@@ -122,7 +123,7 @@ private static string GetProvider(string providerAndModel)

private static string GetModel(string providerAndModel)
{
var index = providerAndModel.IndexOf("/");
var index = providerAndModel.IndexOf('/');

if (index < 0)
{
@@ -203,7 +204,7 @@ private static string GetRowName(int rowNumber)
return (rowNumber + 1).ToString();
}

private static string RenderContext(string context)
private static string RenderCells(string context)
{
return new StringBuilder()
.AppendLine("<context>")
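
For context on the renamed parser: the call site in `CellmFunctions` chains the `Add*` methods and finishes with `Parse()`, which pattern-matches the Excel argument combinations listed in the comments above. A usage sketch within the add-in — the provider/model string and argument values are illustrative only, not taken from this PR:

```csharp
// Illustrative call chain for the renamed PromptArgumentParser, mirroring the call site
// in CellmFunctions. Values are placeholders; in the add-in they arrive from Excel as
// string, double, ExcelReference, or ExcelMissing.
var arguments = ServiceLocator.Get<PromptArgumentParser>()
    .AddProvider("openai/some-model")             // "provider/model", or a bare model name
    .AddModel("openai/some-model")
    .AddInstructionsOrContext("Extract keywords") // matches the =PROMPT("Extract keywords", 0.7) case
    .AddInstructionsOrTemperature(0.7)
    .AddTemperature(ExcelMissing.Value)
    .Parse();                                     // Arguments(Provider, Model, Context, Instructions, Temperature)
```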
7 changes: 6 additions & 1 deletion src/Cellm/Cellm.csproj
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0-windows</TargetFramework>
<TargetFramework>net8.0-windows</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<RollForward>LatestMajor</RollForward>
@@ -23,6 +23,11 @@
<PackageReference Include="JsonPatch.Net" Version="3.1.1" />
<PackageReference Include="JsonSchema.Net.Generation" Version="4.5.1" />
<PackageReference Include="MediatR" Version="12.4.1" />
<PackageReference Include="Microsoft.Extensions.AI" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.AI.Ollama" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.Caching.Hybrid" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="8.0.1" />
9 changes: 5 additions & 4 deletions src/Cellm/Models/Anthropic/AnthropicRequestHandler.cs
@@ -6,6 +6,7 @@
using Cellm.AddIn.Exceptions;
using Cellm.Models.Anthropic.Models;
using Cellm.Prompts;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Anthropic;
@@ -52,11 +53,11 @@ public string Serialize(AnthropicRequest request)
{
var requestBody = new AnthropicRequestBody
{
System = request.Prompt.SystemMessage,
Messages = request.Prompt.Messages.Select(x => new AnthropicMessage { Content = x.Content, Role = x.Role.ToString().ToLower() }).ToList(),
Model = request.Prompt.Model ?? _anthropicConfiguration.DefaultModel,
System = request.Prompt.Messages.Where(x => x.Role == ChatRole.System).First().Text,
Messages = request.Prompt.Messages.Select(x => new AnthropicMessage { Content = x.Text, Role = x.Role.ToString().ToLower() }).ToList(),
Model = request.Prompt.Options.ModelId ?? _anthropicConfiguration.DefaultModel,
MaxTokens = _cellmConfiguration.MaxOutputTokens,
Temperature = request.Prompt.Temperature
Temperature = request.Prompt.Options.Temperature ?? _cellmConfiguration.DefaultTemperature,
};

return _serde.Serialize(requestBody, new JsonSerializerOptions
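
The handler now reads the system prompt, model, and temperature from `Microsoft.Extensions.AI` types (`ChatRole`-tagged messages and `ChatOptions`) rather than dedicated fields on Cellm's `Prompt`. A self-contained sketch of that extraction; the model name and fallback values are placeholders, not values from this PR:

```csharp
// Sketch of pulling system text, model, and temperature from Microsoft.Extensions.AI types,
// mirroring the Serialize() changes above. Model name and fallbacks are placeholders.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.AI;

var messages = new List<ChatMessage>
{
    new(ChatRole.System, "You are a helpful assistant."),
    new(ChatRole.User, "Extract keywords"),
};
var options = new ChatOptions { ModelId = "some-anthropic-model", Temperature = 0.3f };

var system = messages.Where(x => x.Role == ChatRole.System).First().Text;
var model = options.ModelId ?? "default-model-from-configuration";
var temperature = options.Temperature ?? 0.0f;

var body = messages
    .Select(x => new { Content = x.Text, Role = x.Role.ToString().ToLower() })
    .ToList();

Console.WriteLine($"{model} (temperature {temperature}): system = {system}, {body.Count} messages");
```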
3 changes: 0 additions & 3 deletions src/Cellm/Models/IModelRequestHandler.cs
@@ -5,7 +5,4 @@ namespace Cellm.Models;
internal interface IModelRequestHandler<TRequest, TResponse> : IRequestHandler<TRequest, TResponse>
where TRequest : IRequest<TResponse>
{
public string Serialize(TRequest request);

public TResponse Deserialize(TRequest request, string response);
}
42 changes: 17 additions & 25 deletions src/Cellm/Models/ModelRequestBehavior/ToolBehavior.cs
@@ -1,46 +1,38 @@
using Cellm.Prompts;
using Cellm.AddIn;
using Cellm.Tools;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;

namespace Cellm.Models.PipelineBehavior;

internal class ToolBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : IModelRequest<TResponse>
where TResponse : IModelResponse
{
private readonly ISender _sender;
private readonly ToolRunner _toolRunner;
private readonly CellmConfiguration _cellmConfiguration;
private readonly Functions _functions;
private readonly List<AITool> _tools;

public ToolBehavior(ISender sender, ToolRunner toolRunner)
public ToolBehavior(IOptions<CellmConfiguration> cellmConfiguration, Functions functions)
{
_sender = sender;
_toolRunner = toolRunner;
_cellmConfiguration = cellmConfiguration.Value;
_functions = functions;
_tools = [
AIFunctionFactory.Create(_functions.GlobRequest),
AIFunctionFactory.Create(_functions.FileReaderRequest)
];
}

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
var response = await next();

var toolCalls = response.Prompt.Messages.LastOrDefault()?.ToolCalls;

if (toolCalls is not null)
if (_cellmConfiguration.EnableTools)
{
// Model called tools, run tools and call model again
var message = await RunTools(toolCalls);
request.Prompt.Messages.Add(message);
response = await _sender.Send(request, cancellationToken);
}

return response;
request.Prompt.Options.Tools = _tools;
}

private async Task<Message> RunTools(List<ToolCall> toolCalls)
{
var toolResults = await Task.WhenAll(toolCalls.Select(x => _toolRunner.Run(x)));
var toolCallsWithResults = toolCalls
.Zip(toolResults, (toolCall, toolResult) => toolCall with { Result = toolResult })
.ToList();
var response = await next();

return new Message(string.Empty, Roles.Tool, toolCallsWithResults);
return response;
}
}
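
The rewritten behavior no longer executes tool calls itself; when `EnableTools` is set it attaches Cellm's functions as `AITool`s on the request's `ChatOptions` and leaves invocation to the underlying client pipeline. A self-contained sketch of the `AIFunctionFactory` / `ChatOptions.Tools` pattern — the sample delegate is a placeholder, not Cellm's `GlobRequest`/`FileReaderRequest` tools:

```csharp
// Sketch of exposing a .NET method as an AI tool and attaching it to ChatOptions,
// the same pattern ToolBehavior uses above. ReadGreeting is a placeholder tool.
using System;
using System.Collections.Generic;
using Microsoft.Extensions.AI;

static string ReadGreeting(string name) => $"Hello, {name}!";

List<AITool> tools =
[
    AIFunctionFactory.Create(ReadGreeting),
];

// ToolBehavior assigns these to request.Prompt.Options.Tools when EnableTools is true,
// so the chat client can advertise them to the model.
var options = new ChatOptions { Tools = tools };
Console.WriteLine(options.Tools!.Count); // 1
```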
50 changes: 0 additions & 50 deletions src/Cellm/Models/OpenAi/Extensions.cs

This file was deleted.
