Skip to content

Commit

Permalink
refactor: Clean up src/Cellm/AddIn (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaspermarstal authored Nov 21, 2024
1 parent 49ca9ec commit 40652b9
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 38 deletions.
27 changes: 7 additions & 20 deletions src/Cellm/AddIn/Functions.cs → src/Cellm/AddIn/CellmFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Cellm.AddIn;

public static class Functions
public static class CellmFunctions
{
/// <summary>
/// Sends a prompt to the default model configured in CellmConfiguration.
Expand Down Expand Up @@ -73,7 +73,7 @@ public static object PromptWith(
{
try
{
var arguments = ServiceLocator.Get<PromptWithArgumentParser>()
var arguments = ServiceLocator.Get<PromptArgumentParser>()
.AddProvider(providerAndModel)
.AddModel(providerAndModel)
.AddInstructionsOrContext(instructionsOrContext)
Expand All @@ -88,8 +88,8 @@ public static object PromptWith(

var prompt = new PromptBuilder()
.SetModel(arguments.Model)
.SetSystemMessage(SystemMessages.SystemMessage)
.SetTemperature(arguments.Temperature)
.AddSystemMessage(SystemMessages.SystemMessage)
.AddUserMessage(userMessage)
.Build();

Expand All @@ -102,6 +102,7 @@ public static object PromptWith(
catch (CellmException ex)
{
SentrySdk.CaptureException(ex);
Debug.WriteLine(ex);
return ex.Message;
}
}
Expand All @@ -117,22 +118,8 @@ public static object PromptWith(

private static async Task<string> CallModelAsync(Prompt prompt, string? provider = null, Uri? baseAddress = null)
{
try
{
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
var content = response.Messages.Last().Content;
return content;
}
catch (CellmException ex)
{
Debug.WriteLine(ex);
throw;
}
catch (Exception ex)
{
Debug.WriteLine(ex);
throw new CellmException("An unexpected error occurred", ex);
}
var client = ServiceLocator.Get<Client>();
var response = await client.Send(prompt, provider, baseAddress);
return response.Messages.Last().Content;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Text;
using Cellm.AddIn.Exceptions;
using Cellm.Prompts;
using ExcelDna.Integration;
using Microsoft.Extensions.Configuration;
using Microsoft.Office.Interop.Excel;
Expand All @@ -8,7 +9,7 @@ namespace Cellm.AddIn;

public record Arguments(string Provider, string Model, string Context, string Instructions, double Temperature);

public class PromptWithArgumentParser
public class PromptArgumentParser
{
private string? _provider;
private string? _model;
Expand All @@ -18,12 +19,12 @@ public class PromptWithArgumentParser

private readonly IConfiguration _configuration;

public PromptWithArgumentParser(IConfiguration configuration)
public PromptArgumentParser(IConfiguration configuration)
{
_configuration = configuration;
}

public PromptWithArgumentParser AddProvider(object providerAndModel)
public PromptArgumentParser AddProvider(object providerAndModel)
{
_provider = providerAndModel switch
{
Expand All @@ -35,7 +36,7 @@ public PromptWithArgumentParser AddProvider(object providerAndModel)
return this;
}

public PromptWithArgumentParser AddModel(object providerAndModel)
public PromptArgumentParser AddModel(object providerAndModel)
{
_model = providerAndModel switch
{
Expand All @@ -47,21 +48,21 @@ public PromptWithArgumentParser AddModel(object providerAndModel)
return this;
}

public PromptWithArgumentParser AddInstructionsOrContext(object instructionsOrContext)
public PromptArgumentParser AddInstructionsOrContext(object instructionsOrContext)
{
_instructionsOrContext = instructionsOrContext;

return this;
}

public PromptWithArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
public PromptArgumentParser AddInstructionsOrTemperature(object instructionsOrTemperature)
{
_instructionsOrTemperature = instructionsOrTemperature;

return this;
}

public PromptWithArgumentParser AddTemperature(object temperature)
public PromptArgumentParser AddTemperature(object temperature)
{
_temperature = temperature;

Expand Down Expand Up @@ -92,25 +93,25 @@ public Arguments Parse()
// "=PROMPT("Extract keywords", 0.7)
(string instructions, double temperature, ExcelMissing) => new Arguments(provider, model, string.Empty, RenderInstructions(instructions), ParseTemperature(temperature)),
// "=PROMPT(A1:B2)
(ExcelReference context, ExcelMissing, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, ExcelMissing, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, 0.7)
(ExcelReference context, double temperature, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, double temperature, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(SystemMessages.InlineInstructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, "Extract keywords")
(ExcelReference context, string instructions, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(defaultTemperature)),
(ExcelReference context, string instructions, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, "Extract keywords", 0.7)
(ExcelReference context, string instructions, double temperature) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(temperature)),
(ExcelReference context, string instructions, double temperature) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(instructions), ParseTemperature(temperature)),
// "=PROMPT(A1:B2, C1:D2)
(ExcelReference context, ExcelReference instructions, ExcelMissing) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(defaultTemperature)),
(ExcelReference context, ExcelReference instructions, ExcelMissing) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(defaultTemperature)),
// "=PROMPT(A1:B2, C1:D2, 0.7)
(ExcelReference context, ExcelReference instructions, double temperature) => new Arguments(provider, model, RenderContext(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(temperature)),
(ExcelReference context, ExcelReference instructions, double temperature) => new Arguments(provider, model, RenderCells(ParseCells(context)), RenderInstructions(ParseCells(instructions)), ParseTemperature(temperature)),
// Anything else
_ => throw new ArgumentException($"Invalid arguments ({_instructionsOrContext?.GetType().Name}, {_instructionsOrTemperature?.GetType().Name}, {_temperature?.GetType().Name})")
};
}

private static string GetProvider(string providerAndModel)
{
var index = providerAndModel.IndexOf("/");
var index = providerAndModel.IndexOf('/');

if (index < 0)
{
Expand All @@ -122,7 +123,7 @@ private static string GetProvider(string providerAndModel)

private static string GetModel(string providerAndModel)
{
var index = providerAndModel.IndexOf("/");
var index = providerAndModel.IndexOf('/');

if (index < 0)
{
Expand Down Expand Up @@ -203,7 +204,7 @@ private static string GetRowName(int rowNumber)
return (rowNumber + 1).ToString();
}

private static string RenderContext(string context)
private static string RenderCells(string context)
{
return new StringBuilder()
.AppendLine("<context>")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
namespace Cellm.AddIn;
namespace Cellm.Prompts;

internal static class SystemMessages
{
Expand Down
2 changes: 1 addition & 1 deletion src/Cellm/Services/ServiceLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
.AddSingleton(configuration)
.AddMemoryCache()
.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(Assembly.GetExecutingAssembly()))
.AddTransient<PromptWithArgumentParser>()
.AddTransient<PromptArgumentParser>()
.AddSingleton<Client>()
.AddSingleton<Serde>();

Expand Down

0 comments on commit 40652b9

Please sign in to comment.