Commit

feat: Remove support for embedded Ollama and Llamafile servers (#85)
* refactor: Make argument provider an enum

* refactor: Use ExcelAsyncUtil to run task

* feat: Remove support for embedded Ollama and Llamafile servers
kaspermarstal authored Jan 19, 2025
1 parent 3a85d6f commit 26da547
Showing 15 changed files with 22 additions and 628 deletions.
2 changes: 1 addition & 1 deletion src/Cellm/AddIn/ExcelAddin.cs
@@ -3,7 +3,7 @@

namespace Cellm.AddIn;

public class ExcelAddin : IExcelAddIn
public class ExcelAddIn : IExcelAddIn
{
public void AutoOpen()
{
3 changes: 1 addition & 2 deletions src/Cellm/Models/Behaviors/ToolBehavior.cs
@@ -1,5 +1,4 @@
using Cellm.Models.Behaviors;
using Cellm.Models.Providers;
using Cellm.Models.Providers;
using MediatR;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Options;
12 changes: 0 additions & 12 deletions src/Cellm/Models/Providers/Llamafile/LlamafileConfiguration.cs
@@ -2,25 +2,13 @@

internal class LlamafileConfiguration : IProviderConfiguration
{
public Uri LlamafileUrl { get; init; }

public Uri BaseAddress { get; init; }

public Dictionary<string, Uri> Models { get; init; }

public string DefaultModel { get; init; }

public bool Gpu { get; init; }

public int GpuLayers { get; init; }

public LlamafileConfiguration()
{
LlamafileUrl = default!;
BaseAddress = default!;
Models = default!;
DefaultModel = default!;
Gpu = false;
GpuLayers = 999;
}
}
118 changes: 4 additions & 114 deletions src/Cellm/Models/Providers/Llamafile/LlamafileRequestHandler.cs
@@ -1,129 +1,19 @@
using System.Diagnostics;
using Cellm.Models.Exceptions;
using Cellm.Models.Local.Utilities;
using Cellm.Models.Providers.OpenAiCompatible;
using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Llamafile;

internal class LlamafileRequestHandler : IProviderRequestHandler<LlamafileRequest, LlamafileResponse>
internal class LlamafileRequestHandler(ISender sender, IOptions<LlamafileConfiguration> llamafileConfiguration) : IProviderRequestHandler<LlamafileRequest, LlamafileResponse>
{
private record Llamafile(string ModelPath, Uri BaseAddress, Process Process);

private readonly AsyncLazy<string> _llamafileExePath;
private readonly Dictionary<string, AsyncLazy<Llamafile>> _llamafiles;
private readonly ProcessManager _processManager;
private readonly FileManager _fileManager;
private readonly ServerManager _serverManager;

private readonly LlamafileConfiguration _llamafileConfiguration;

private readonly ISender _sender;
private readonly ILogger<LlamafileRequestHandler> _logger;

public LlamafileRequestHandler(
IOptions<LlamafileConfiguration> llamafileConfiguration,
ISender sender,
HttpClient httpClient,
FileManager fileManager,
ProcessManager processManager,
ServerManager serverManager,
ILogger<LlamafileRequestHandler> logger)
{
_llamafileConfiguration = llamafileConfiguration.Value;
_sender = sender;
_fileManager = fileManager;
_processManager = processManager;
_serverManager = serverManager;
_logger = logger;

_llamafileExePath = new AsyncLazy<string>(async () =>
{
var llamafileName = Path.GetFileName(_llamafileConfiguration.LlamafileUrl.Segments.Last());
return await _fileManager.DownloadFileIfNotExists(_llamafileConfiguration.LlamafileUrl, _fileManager.CreateCellmFilePath(CreateModelFileName($"{llamafileName}.exe"), "Llamafile"));
});

_llamafiles = _llamafileConfiguration.Models.ToDictionary(x => x.Key, x => new AsyncLazy<Llamafile>(async () =>
{
// Download Llamafile
var exePath = await _llamafileExePath;

// Download model
var modelPath = await _fileManager.DownloadFileIfNotExists(x.Value, _fileManager.CreateCellmFilePath(CreateModelFileName(x.Key), "Llamafile"));

// Start server
var baseAddress = new UriBuilder(
_llamafileConfiguration.BaseAddress.Scheme,
_llamafileConfiguration.BaseAddress.Host,
_serverManager.FindPort(),
_llamafileConfiguration.BaseAddress.AbsolutePath).Uri;

var process = await StartProcess(exePath, modelPath, baseAddress);

return new Llamafile(modelPath, baseAddress, process);
}));
}

public async Task<LlamafileResponse> Handle(LlamafileRequest request, CancellationToken cancellationToken)
{
// Start server on first call
var llamafile = await _llamafiles[request.Prompt.Options.ModelId ?? _llamafileConfiguration.DefaultModel];
request.Prompt.Options.ModelId ??= llamafileConfiguration.Value.DefaultModel;

var openAiResponse = await _sender.Send(new OpenAiCompatibleRequest(request.Prompt, llamafile.BaseAddress), cancellationToken);
var openAiResponse = await sender.Send(new OpenAiCompatibleRequest(request.Prompt, llamafileConfiguration.Value.BaseAddress), cancellationToken);

return new LlamafileResponse(openAiResponse.Prompt);
}

private async Task<Process> StartProcess(string exePath, string modelPath, Uri baseAddress)
{
var processStartInfo = new ProcessStartInfo(exePath);

processStartInfo.ArgumentList.Add("--server");
processStartInfo.ArgumentList.Add("--nobrowser");
processStartInfo.ArgumentList.Add("-m");
processStartInfo.ArgumentList.Add(modelPath);
processStartInfo.ArgumentList.Add("--host");
processStartInfo.ArgumentList.Add(baseAddress.Host);
processStartInfo.ArgumentList.Add("--port");
processStartInfo.ArgumentList.Add(baseAddress.Port.ToString());

if (_llamafileConfiguration.Gpu)
{
processStartInfo.Arguments += $"-ngl {_llamafileConfiguration.GpuLayers} ";
}

processStartInfo.UseShellExecute = false;
processStartInfo.CreateNoWindow = true;
processStartInfo.RedirectStandardError = true;
processStartInfo.RedirectStandardOutput = true;

var process = Process.Start(processStartInfo) ?? throw new CellmModelException("Failed to run Llamafile");

process.OutputDataReceived += (sender, e) =>
{
if (!string.IsNullOrEmpty(e.Data))
{
_logger.LogDebug(e.Data);
}
};

process.BeginOutputReadLine();
process.BeginErrorReadLine();

var uriBuilder = new UriBuilder(baseAddress.Scheme, baseAddress.Host, baseAddress.Port, "/health");
await _serverManager.WaitForServer(uriBuilder.Uri, process);

// Kill Llamafile when Excel exits or dies
_processManager.AssignProcessToExcel(process);

return process;
}

private static string CreateModelFileName(string modelName)
{
return $"Llamafile-{modelName}";
}
}
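
The handler that survives this deletion is a thin pass-through to the OpenAI-compatible pipeline. A minimal sketch assembled from the lines the diff keeps (the exact set of retained using directives is not fully visible in the hunk, so those are an assumption):

using Cellm.Models.Providers.OpenAiCompatible;
using MediatR;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Llamafile;

internal class LlamafileRequestHandler(
    ISender sender,
    IOptions<LlamafileConfiguration> llamafileConfiguration)
    : IProviderRequestHandler<LlamafileRequest, LlamafileResponse>
{
    public async Task<LlamafileResponse> Handle(LlamafileRequest request, CancellationToken cancellationToken)
    {
        // Fall back to the configured default model when the prompt does not name one.
        request.Prompt.Options.ModelId ??= llamafileConfiguration.Value.DefaultModel;

        // Forward the prompt to whatever OpenAI-compatible server BaseAddress points at;
        // Cellm no longer downloads a Llamafile or spawns a server process itself.
        var openAiResponse = await sender.Send(
            new OpenAiCompatibleRequest(request.Prompt, llamafileConfiguration.Value.BaseAddress),
            cancellationToken);

        return new LlamafileResponse(openAiResponse.Prompt);
    }
}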

163 changes: 4 additions & 159 deletions src/Cellm/Models/Providers/Ollama/OllamaRequestHandler.cs
@@ -1,98 +1,15 @@
using System.Diagnostics;
using System.Net.Http.Json;
using System.Text;
using System.Text.Json;
using Cellm.Models.Exceptions;
using Cellm.Models.Local.Utilities;
using Cellm.Models.Prompts;
using Cellm.Models.Prompts;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Cellm.Models.Providers.Ollama;

internal class OllamaRequestHandler : IModelRequestHandler<OllamaRequest, OllamaResponse>
internal class OllamaRequestHandler(
[FromKeyedServices(Provider.Ollama)] IChatClient chatClient) : IModelRequestHandler<OllamaRequest, OllamaResponse>
{
private record OllamaServer(Uri BaseAddress, Process Process);

record Tags(List<Model> Models);
record Model(string Name);
record Progress(string Status);

private readonly IChatClient _chatClient;
private readonly OllamaConfiguration _ollamaConfiguration;
private readonly HttpClient _httpClient;
private readonly FileManager _fileManager;
private readonly ProcessManager _processManager;
private readonly ServerManager _serverManager;
private readonly ILogger<OllamaRequestHandler> _logger;

private readonly AsyncLazy<string> _ollamaExePath;
private readonly AsyncLazy<OllamaServer> _ollamaServer;

public OllamaRequestHandler(
[FromKeyedServices(Provider.Ollama)] IChatClient chatClient,
IHttpClientFactory httpClientFactory,
IOptions<OllamaConfiguration> ollamaConfiguration,
FileManager fileManager,
ProcessManager processManager,
ServerManager serverManager,
ILogger<OllamaRequestHandler> logger)
{
_chatClient = chatClient;
_httpClient = httpClientFactory.CreateClient(nameof(Provider.Ollama));
_ollamaConfiguration = ollamaConfiguration.Value;
_fileManager = fileManager;
_processManager = processManager;
_serverManager = serverManager;
_logger = logger;

_ollamaExePath = new AsyncLazy<string>(async () =>
{
var zipFileName = string.Join("-", _ollamaConfiguration.ZipUrl.Segments.Select(x => x.Replace("/", string.Empty)).TakeLast(2));
var zipFilePath = _fileManager.CreateCellmFilePath(zipFileName);

await _fileManager.DownloadFileIfNotExists(
_ollamaConfiguration.ZipUrl,
zipFilePath);

var ollamaPath = _fileManager.ExtractZipFileIfNotExtracted(
zipFilePath,
_fileManager.CreateCellmDirectory(nameof(Ollama), Path.GetFileNameWithoutExtension(zipFileName)));

return Path.Combine(ollamaPath, "ollama.exe");
});

_ollamaServer = new AsyncLazy<OllamaServer>(async () =>
{
var ollamaExePath = await _ollamaExePath;
var process = await StartProcess(ollamaExePath, _ollamaConfiguration.BaseAddress);

return new OllamaServer(_ollamaConfiguration.BaseAddress, process);
});
}

public async Task<OllamaResponse> Handle(OllamaRequest request, CancellationToken cancellationToken)
{
var serverIsRunning = await ServerIsRunning(_ollamaConfiguration.BaseAddress);
if (_ollamaConfiguration.EnableServer && !serverIsRunning)
{
_ = await _ollamaServer;
}

var modelIsDownloaded = await ModelIsDownloaded(
_ollamaConfiguration.BaseAddress,
request.Prompt.Options.ModelId ?? _ollamaConfiguration.DefaultModel);

if (!modelIsDownloaded)
{
await DownloadModel(
_ollamaConfiguration.BaseAddress,
request.Prompt.Options.ModelId ?? _ollamaConfiguration.DefaultModel);
}

var chatCompletion = await _chatClient.CompleteAsync(
var chatCompletion = await chatClient.CompleteAsync(
request.Prompt.Messages,
request.Prompt.Options,
cancellationToken);
@@ -103,76 +20,4 @@ await DownloadModel(

return new OllamaResponse(prompt);
}

private async Task<bool> ServerIsRunning(Uri baseAddress)
{
var response = await _httpClient.GetAsync(baseAddress);

return response.IsSuccessStatusCode;
}

private async Task<bool> ModelIsDownloaded(Uri baseAddress, string modelId)
{
var tags = await _httpClient.GetFromJsonAsync<Tags>("api/tags") ?? throw new CellmModelException();

return tags.Models.Select(x => x.Name).Contains(modelId);
}

private async Task DownloadModel(Uri baseAddress, string modelId)
{
try
{
var modelName = JsonSerializer.Serialize(new { name = modelId });
var modelStringContent = new StringContent(modelName, Encoding.UTF8, "application/json");
var response = await _httpClient.PostAsync("api/pull", modelStringContent);

response.EnsureSuccessStatusCode();

var progress = await response.Content.ReadFromJsonAsync<List<Progress>>();

if (progress is null || progress.Last().Status != "success")
{
throw new CellmModelException($"Ollama failed to download model {modelId}");
}
}
catch (HttpRequestException ex)
{
throw new CellmModelException($"Ollama failed to download model {modelId} or {modelId} does not exist", ex);
}
}

private async Task<Process> StartProcess(string ollamaExePath, Uri baseAddress)
{
var processStartInfo = new ProcessStartInfo(await _ollamaExePath);

processStartInfo.ArgumentList.Add("serve");
processStartInfo.EnvironmentVariables.Add("OLLAMA_HOST", baseAddress.ToString());

processStartInfo.UseShellExecute = false;
processStartInfo.CreateNoWindow = true;
processStartInfo.RedirectStandardError = true;
processStartInfo.RedirectStandardOutput = true;

var process = Process.Start(processStartInfo) ?? throw new CellmModelException("Failed to run Ollama");

process.OutputDataReceived += (sender, e) =>
{
if (!string.IsNullOrEmpty(e.Data))
{
_logger.LogDebug(e.Data);
Debug.WriteLine(e.Data);
}
};

process.BeginOutputReadLine();
process.BeginErrorReadLine();

var address = new Uri(baseAddress, "/v1/models");
await _serverManager.WaitForServer(address, process);

// Kill Ollama when Excel exits or dies
_processManager.AssignProcessToExcel(process);

return process;
}
}
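
The Ollama handler collapses the same way: the IChatClient keyed to Provider.Ollama is injected directly, and all server probing, downloading, and process management is gone. A minimal sketch under that assumption — the lines that turn chatCompletion back into the returned prompt are elided in the hunk above, so a placeholder stands in for them:

using Cellm.Models.Prompts;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;

namespace Cellm.Models.Providers.Ollama;

internal class OllamaRequestHandler(
    [FromKeyedServices(Provider.Ollama)] IChatClient chatClient)
    : IModelRequestHandler<OllamaRequest, OllamaResponse>
{
    public async Task<OllamaResponse> Handle(OllamaRequest request, CancellationToken cancellationToken)
    {
        // Talk straight to the configured Ollama endpoint; the handler no longer starts a server,
        // checks whether one is running, or pulls missing models.
        var chatCompletion = await chatClient.CompleteAsync(
            request.Prompt.Messages,
            request.Prompt.Options,
            cancellationToken);

        // The diff elides how chatCompletion is folded back into a Prompt; reusing the request
        // prompt here is a placeholder, not the repository's actual code.
        var prompt = request.Prompt;

        return new OllamaResponse(prompt);
    }
}
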
src/Cellm/Models/Providers/OpenAiCompatible/OpenAiCompatibleRequest.cs
@@ -4,5 +4,5 @@ namespace Cellm.Models.Providers.OpenAiCompatible;

internal record OpenAiCompatibleRequest(
Prompt Prompt,
Uri? BaseAddress = null,
Uri BaseAddress,
string? ApiKey = null) : IModelRequest<OpenAiCompatibleResponse>;
src/Cellm/Models/Providers/OpenAiCompatible/OpenAiCompatibleRequestHandler.cs
@@ -8,14 +8,13 @@ internal class OpenAiCompatibleRequestHandler(
IOptions<OpenAiCompatibleConfiguration> openAiCompatibleConfiguration)
: IModelRequestHandler<OpenAiCompatibleRequest, OpenAiCompatibleResponse>
{
private readonly OpenAiCompatibleConfiguration _openAiCompatibleConfiguration = openAiCompatibleConfiguration.Value;

public async Task<OpenAiCompatibleResponse> Handle(OpenAiCompatibleRequest request, CancellationToken cancellationToken)
{
var chatClient = openAiCompatibleChatClientFactory.Create(
request.BaseAddress ?? _openAiCompatibleConfiguration.BaseAddress,
request.Prompt.Options.ModelId ?? _openAiCompatibleConfiguration.DefaultModel,
request.ApiKey ?? _openAiCompatibleConfiguration.ApiKey);
request.BaseAddress,
request.Prompt.Options.ModelId ?? string.Empty,
request.ApiKey ?? "API_KEY");

var chatCompletion = await chatClient.CompleteAsync(request.Prompt.Messages, request.Prompt.Options, cancellationToken);

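
With BaseAddress promoted to a required positional parameter on OpenAiCompatibleRequest, and the handler no longer falling back to OpenAiCompatibleConfiguration (the source of the GitHub Actions build warning that the openAiCompatibleConfiguration parameter is now unread), callers must supply the server address themselves. A hedged illustration of such a call; the address is a placeholder, and `prompt`, `sender`, and `cancellationToken` are assumed to already be in scope:

// Illustrative only: the URI below is a placeholder, not a value from the repository.
var request = new OpenAiCompatibleRequest(
    prompt,
    new Uri("http://localhost:8080/v1")); // BaseAddress is now required; ApiKey still defaults to null

var response = await sender.Send(request, cancellationToken);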