Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Models #75

Merged
merged 8 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/Cellm.Tests/IntegrationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Cellm.Tests;

[ExcelTestSettings(AddIn = @"..\..\..\..\Cellm\bin\Debug\net6.0-windows\Cellm-AddIn")]
[ExcelTestSettings(AddIn = @"..\..\..\..\Cellm\bin\Debug\net8.0-windows\Cellm-AddIn")]
public class ExcelTests : IDisposable
{
readonly Workbook _testWorkbook;
Expand Down Expand Up @@ -54,17 +54,17 @@ public void TestPromptWith()
Worksheet ws = (Worksheet)_testWorkbook.Sheets[1];
ws.Range["A1"].Value = "Respond with \"Hello World\"";
ws.Range["A2"].Formula = "=PROMPTWITH(\"Anthropic/claude-3-haiku-20240307\",A1)";
ExcelTestHelper.WaitForCellValue(ws.Range["A2"]);
Automation.Wait(5000);
Assert.Equal("Hello World", ws.Range["A2"].Text);

ws.Range["B1"].Value = "Respond with \"Hello World\"";
ws.Range["B2"].Formula = "=PROMPTWITH(\"OpenAI/gpt-4o-mini\",B1)";
ExcelTestHelper.WaitForCellValue(ws.Range["B2"]);
Automation.Wait(5000);
Assert.Equal("Hello World", ws.Range["B2"].Text);

ws.Range["C1"].Value = "Respond with \"Hello World\"";
ws.Range["C2"].Formula = "=PROMPTWITH(\"OpenAI/gemini-1.5-flash-latest\",C1)";
ExcelTestHelper.WaitForCellValue(ws.Range["C2"]);
Automation.Wait(5000);
Assert.Equal("Hello World", ws.Range["C2"].Text);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Prompts;
using Cellm.Services.Configuration;
using Cellm.Models.Providers;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;
Expand All @@ -10,7 +9,7 @@ namespace Cellm.AddIn;

public record Arguments(string Provider, string Model, string Context, string Instructions, double Temperature);

public class PromptArgumentParser
public class ArgumentParser
{
private string? _provider;
private string? _model;
Expand All @@ -20,12 +19,12 @@ public class PromptArgumentParser

private readonly IConfiguration _configuration;

public PromptArgumentParser(IConfiguration configuration)
public ArgumentParser(IConfiguration configuration)
{
_configuration = configuration;
}

public PromptArgumentParser AddProvider(object providerAndModel)
public ArgumentParser AddProvider(object providerAndModel)
{
_provider = providerAndModel switch
{
Expand All @@ -37,7 +36,7 @@ public PromptArgumentParser AddProvider(object providerAndModel)
return this;
}

public PromptArgumentParser AddModel(object providerAndModel)
public ArgumentParser AddModel(object providerAndModel)
{
_model = providerAndModel switch
{
Expand All @@ -49,21 +48,21 @@ public PromptArgumentParser AddModel(object providerAndModel)
return this;
}

public PromptArgumentParser AddInstructionsOrContext(object instructionsOrContext)
public ArgumentParser AddInstructionsOrContext(object instructionsOrContext)
{
_instructionsOrContext = instructionsOrContext;

return this;
}

public PromptArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
public ArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
{
_instructionsOrTemperature = instructionsOrTemperature;

return this;
}

public PromptArgumentParser AddTemperature(object temperature)
public ArgumentParser AddTemperature(object temperature)
{
_temperature = temperature;

Expand All @@ -73,19 +72,19 @@ public PromptArgumentParser AddTemperature(object temperature)
public Arguments Parse()
{
var provider = _provider ?? _configuration
.GetSection(nameof(CellmConfiguration))
.GetValue<string>(nameof(CellmConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultProvider));
.GetSection(nameof(ProviderConfiguration))
.GetValue<string>(nameof(ProviderConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultProvider));

var model = _model ?? _configuration
.GetSection(nameof(CellmConfiguration))
.GetValue<string>(nameof(CellmConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultModel));
.GetSection($"{provider}Configuration")
.GetValue<string>(nameof(IProviderConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(IProviderConfiguration.DefaultModel));

var defaultTemperature = _configuration
.GetSection(nameof(CellmConfiguration))
.GetValue<double?>(nameof(CellmConfiguration.DefaultTemperature))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultTemperature));
.GetSection(nameof(ProviderConfiguration))
.GetValue<double?>(nameof(ProviderConfiguration.DefaultTemperature))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultTemperature));

return (_instructionsOrContext, _instructionsOrTemperature, _temperature) switch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

namespace Cellm.AddIn;

public class CellmAddIn : IExcelAddIn
public class ExcelAddin : IExcelAddIn
{
public void AutoOpen()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Models;
using Cellm.Prompts;
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using Cellm.Services;
using Cellm.Services.Configuration;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;

namespace Cellm.AddIn;

public static class CellmFunctions
public static class ExcelFunctions
{
/// <summary>
/// Sends a prompt to the default model configured in CellmConfiguration.
Expand All @@ -35,8 +35,8 @@ public static object Prompt(
{
var configuration = ServiceLocator.Get<IConfiguration>();

var provider = configuration.GetSection(nameof(CellmConfiguration)).GetValue<string>(nameof(CellmConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(CellmConfiguration.DefaultProvider));
var provider = configuration.GetSection(nameof(ProviderConfiguration)).GetValue<string>(nameof(ProviderConfiguration.DefaultProvider))
?? throw new ArgumentException(nameof(ProviderConfiguration.DefaultProvider));

var model = configuration.GetSection($"{provider}Configuration").GetValue<string>(nameof(IProviderConfiguration.DefaultModel))
?? throw new ArgumentException(nameof(IProviderConfiguration.DefaultModel));
Expand Down Expand Up @@ -73,7 +73,7 @@ public static object PromptWith(
{
try
{
var arguments = ServiceLocator.Get<PromptArgumentParser>()
var arguments = ServiceLocator.Get<ArgumentParser>()
.AddProvider(providerAndModel)
.AddModel(providerAndModel)
.AddInstructionsOrContext(instructionsOrContext)
Expand Down Expand Up @@ -116,7 +116,7 @@ public static object PromptWith(
/// <returns>A task that represents the asynchronous operation. The task result contains the model's response as a string.</returns>
/// <exception cref="CellmException">Thrown when an unexpected error occurs during the operation.</exception>

private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
internal static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress, CancellationToken.None);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace Cellm.Prompts;
namespace Cellm.AddIn;

internal static class SystemMessages
{
Expand Down
1 change: 1 addition & 0 deletions src/Cellm/Cellm.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" Version="9.0.1-preview.1.24570.5" />
<PackageReference Include="Microsoft.Extensions.Caching.Hybrid" Version="9.0.0-preview.9.24556.5" />
<PackageReference Include="Microsoft.Extensions.Configuration" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.Abstractions" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.FileSystemGlobbing" Version="9.0.0" />
Expand Down
5 changes: 0 additions & 5 deletions src/Cellm/Models/Anthropic/AnthropicResponse.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
using System.Text.Json;
using Cellm.Services.Configuration;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.Caching.Hybrid;
using Microsoft.Extensions.Options;

namespace Cellm.Models.ModelRequestBehavior;
namespace Cellm.Models.Behaviors;

internal class CachingBehavior<TRequest, TResponse>(HybridCache cache, IOptions<CellmConfiguration> cellmConfiguration) : IPipelineBehavior<TRequest, TResponse>
internal class CacheBehavior<TRequest, TResponse>(HybridCache cache, IOptions<ProviderConfiguration> providerConfiguration) : IPipelineBehavior<TRequest, TResponse>
where TRequest : IModelRequest<TResponse>
where TResponse : IModelResponse
{
private readonly HybridCacheEntryOptions _cacheEntryOptions = new()
{
Expiration = TimeSpan.FromSeconds(cellmConfiguration.Value.CacheTimeoutInSeconds)
Expiration = TimeSpan.FromSeconds(providerConfiguration.Value.CacheTimeoutInSeconds)
};

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (cellmConfiguration.Value.EnableCache)
if (providerConfiguration.Value.EnableCache)
{
return await cache.GetOrCreateAsync(
JsonSerializer.Serialize(request.Prompt),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using MediatR;

namespace Cellm.Models.ModelRequestBehavior;
namespace Cellm.Models.Behaviors;

internal class SentryBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : notnull
Expand Down
24 changes: 24 additions & 0 deletions src/Cellm/Models/Behaviors/ToolBehavior.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using Cellm.Models.Behaviors;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Tools;

internal class ToolBehavior<TRequest, TResponse>(IOptions<ProviderConfiguration> providerConfiguration, IEnumerable<AIFunction> functions)
: IPipelineBehavior<TRequest, TResponse> where TRequest : IModelRequest<TResponse>
{
private readonly ProviderConfiguration _providerConfiguration = providerConfiguration.Value;
private readonly List<AITool> _tools = new(functions);

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (_providerConfiguration.EnableTools)
{
request.Prompt.Options.Tools = _tools;
}

return await next();
}
}
57 changes: 22 additions & 35 deletions src/Cellm/Models/Client.cs
Original file line number Diff line number Diff line change
@@ -1,73 +1,60 @@
using System.Text.Json;
using Cellm.AddIn.Exceptions;
using Cellm.Models.Anthropic;
using Cellm.Models.Llamafile;
using Cellm.Models.Ollama;
using Cellm.Models.OpenAi;
using Cellm.Models.OpenAiCompatible;
using Cellm.Prompts;
using Cellm.Services.Configuration;
using Cellm.Models.Exceptions;
using Cellm.Models.Prompts;
using Cellm.Models.Providers;
using Cellm.Models.Providers.Anthropic;
using Cellm.Models.Providers.Llamafile;
using Cellm.Models.Providers.Ollama;
using Cellm.Models.Providers.OpenAi;
using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Options;
using Polly.Timeout;

namespace Cellm.Models;

internal class Client(ISender sender, IOptions<CellmConfiguration> cellmConfiguration)
public class Client(ISender sender, IOptions<ProviderConfiguration> providerConfiguration)
{
private readonly CellmConfiguration _cellmConfiguration = cellmConfiguration.Value;
private readonly ProviderConfiguration _providerConfiguration = providerConfiguration.Value;

public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress, CancellationToken cancellationToken)
{
try
{
provider ??= _cellmConfiguration.DefaultProvider;
provider ??= _providerConfiguration.DefaultProvider;

if (!Enum.TryParse<Providers>(provider, true, out var parsedProvider))
if (!Enum.TryParse<Provider>(provider, true, out var parsedProvider))
{
throw new ArgumentException($"Unsupported provider: {provider}");
}

IModelResponse response = parsedProvider switch
{
Providers.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
Providers.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
Providers.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
Providers.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
Providers.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
Provider.Anthropic => await sender.Send(new AnthropicRequest(prompt, provider, baseAddress), cancellationToken),
Provider.Llamafile => await sender.Send(new LlamafileRequest(prompt), cancellationToken),
Provider.Ollama => await sender.Send(new OllamaRequest(prompt), cancellationToken),
Provider.OpenAi => await sender.Send(new OpenAiRequest(prompt), cancellationToken),
Provider.OpenAiCompatible => await sender.Send(new OpenAiCompatibleRequest(prompt, baseAddress), cancellationToken),
_ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
};

return response.Prompt;
}
catch (HttpRequestException ex)
{
throw new CellmException($"HTTP request failed: {ex.Message}", ex);
}
catch (JsonException ex)
{
throw new CellmException($"JSON processing failed: {ex.Message}", ex);
}
catch (NotSupportedException ex)
{
throw new CellmException($"Method not supported: {ex.Message}", ex);
}
catch (FileReaderException ex)
{
throw new CellmException($"File could not be read: {ex.Message}", ex);
throw new CellmModelException($"HTTP request failed: {ex.Message}", ex);
}
catch (NullReferenceException ex)
{
throw new CellmException($"Null reference error: {ex.Message}", ex);
throw new CellmModelException($"Null reference error: {ex.Message}", ex);
}
catch (TimeoutRejectedException ex)
{
throw new CellmException($"Request timed out: {ex.Message}", ex);
throw new CellmModelException($"Request timed out: {ex.Message}", ex);
}
catch (Exception ex) when (ex is not CellmException)
catch (Exception ex) when (ex is not CellmModelException)
{
// Handle any other unexpected exceptions
throw new CellmException($"An unexpected error occurred: {ex.Message}", ex);
throw new CellmModelException($"An unexpected error occurred: {ex.Message}", ex);
}
}
}
10 changes: 10 additions & 0 deletions src/Cellm/Models/Exceptions/CellmModelException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Cellm.Models.Exceptions;

public class CellmModelException : Exception
{
public CellmModelException(string message = "#CELLM_ERROR?")
: base(message) { }

public CellmModelException(string message, Exception inner)
: base(message, inner) { }
}
5 changes: 0 additions & 5 deletions src/Cellm/Models/Llamafile/LlamafileResponse.cs

This file was deleted.

Loading