Skip to content

Commit

Permalink
feat: Add Ollama provider
Browse files Browse the repository at this point in the history
  • Loading branch information
kaspermarstal committed Nov 21, 2024
1 parent 27609c0 commit 1b7e9e6
Show file tree
Hide file tree
Showing 13 changed files with 370 additions and 145 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ To get started, we recommend using Ollama with the Gemma 2 2B model:
1. Rename `appsettings.Ollama.json` to `appsettings.Local.json`,
2. Build and install Cellm.
3. Run the following command in the `docker/` directory:
```cmd
docker compose -f docker-compose.Ollama.yml up --detach
docker compose -f docker-compose.Ollama.yml exec backend ollama pull gemma2:2b
Expand Down
2 changes: 2 additions & 0 deletions src/Cellm/Models/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Cellm.AddIn.Exceptions;
using Cellm.Models.Anthropic;
using Cellm.Models.Llamafile;
using Cellm.Models.Ollama;
using Cellm.Models.OpenAi;
using Cellm.Prompts;
using MediatR;
Expand Down Expand Up @@ -37,6 +38,7 @@ public async Task<Prompt> Send(Prompt prompt, string? provider, Uri? baseAddress
{
Providers.Anthropic => await _sender.Send(new AnthropicRequest(prompt, provider, baseAddress)),
Providers.Llamafile => await _sender.Send(new LlamafileRequest(prompt)),
Providers.Ollama => await _sender.Send(new OllamaRequest(prompt, provider, baseAddress)),
Providers.OpenAi => await _sender.Send(new OpenAiRequest(prompt, provider, baseAddress)),
_ => throw new InvalidOperationException($"Provider {parsedProvider} is defined but not implemented")
};
Expand Down
143 changes: 16 additions & 127 deletions src/Cellm/Models/Llamafile/LlamafileRequestHandler.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using System.Diagnostics;
using System.Net.NetworkInformation;
using Cellm.AddIn;
using Cellm.AddIn.Exceptions;
using Cellm.Models.Local;
using Cellm.Models.OpenAi;
using MediatR;
using Microsoft.Extensions.Options;
Expand All @@ -14,41 +14,42 @@ private record Llamafile(string ModelPath, Uri BaseAddress, Process Process);

private readonly AsyncLazy<string> _llamafileExePath;
private readonly Dictionary<string, AsyncLazy<Llamafile>> _llamafiles;
private readonly LLamafileProcessManager _llamafileProcessManager;
private readonly ProcessManager _processManager;

private readonly CellmConfiguration _cellmConfiguration;
private readonly LlamafileConfiguration _llamafileConfiguration;
private readonly OpenAiConfiguration _openAiConfiguration;

private readonly ISender _sender;
private readonly HttpClient _httpClient;
private readonly LocalUtilities _localUtilities;

public LlamafileRequestHandler(IOptions<CellmConfiguration> cellmConfiguration,
IOptions<LlamafileConfiguration> llamafileConfiguration,
IOptions<OpenAiConfiguration> openAiConfiguration,
ISender sender,
HttpClient httpClient,
LLamafileProcessManager llamafileProcessManager)
LocalUtilities localUtilities,
ProcessManager processManager)
{
_cellmConfiguration = cellmConfiguration.Value;
_llamafileConfiguration = llamafileConfiguration.Value;
_openAiConfiguration = openAiConfiguration.Value;
_sender = sender;
_httpClient = httpClient;
_llamafileProcessManager = llamafileProcessManager;
_localUtilities = localUtilities;
_processManager = processManager;

_llamafileExePath = new AsyncLazy<string>(async () =>
{
return await DownloadFile(_llamafileConfiguration.LlamafileUrl, $"{nameof(Llamafile)}.exe");
var llamafileName = Path.GetFileName(_llamafileConfiguration.LlamafileUrl.Segments.Last());
return await _localUtilities.DownloadFile(_llamafileConfiguration.LlamafileUrl, $"{llamafileName}.exe");
});

_llamafiles = _llamafileConfiguration.Models.ToDictionary(x => x.Key, x => new AsyncLazy<Llamafile>(async () =>
{
// Download model
var modelPath = await DownloadFile(x.Value, CreateFilePath(CreateModelFileName(x.Key)));
var modelPath = await _localUtilities.DownloadFile(x.Value, _localUtilities.CreateCellmFilePath(CreateModelFileName(x.Key)));

// Run Llamafile
var baseAddress = CreateBaseAddress();
// Start server
var baseAddress = new UriBuilder("http", "localhost", _localUtilities.FindPort()).Uri;
var process = await StartProcess(modelPath, baseAddress);

return new Llamafile(modelPath, baseAddress, process);
Expand Down Expand Up @@ -101,130 +102,18 @@ private async Task<Process> StartProcess(string modelPath, Uri baseAddress)
process.BeginErrorReadLine();
}

await WaitForLlamafile(baseAddress, process);
var address = new Uri(baseAddress, "health");
await _localUtilities.WaitForServer(address, process);

// Kill the process when Excel exits or dies
_llamafileProcessManager.AssignProcessToExcel(process);
// Kill Llamafile when Excel exits or dies
_processManager.AssignProcessToExcel(process);

return process;
}

/// <summary>
/// Downloads <paramref name="uri"/> to <paramref name="filePath"/> unless the file already exists.
/// The payload is streamed into a ".part" file first, so an interrupted download
/// is never mistaken for a completed one.
/// </summary>
/// <param name="uri">Remote location of the file.</param>
/// <param name="filePath">Local destination path.</param>
/// <returns>The path of the downloaded (or pre-existing) file.</returns>
private async Task<string> DownloadFile(Uri uri, string filePath)
{
    // Nothing to do if a previous run already completed the download.
    if (File.Exists(filePath))
    {
        return filePath;
    }

    // Discard any stale partial download before starting over.
    var partialPath = filePath + ".part";

    if (File.Exists(partialPath))
    {
        File.Delete(partialPath);
    }

    // ResponseHeadersRead streams the body instead of buffering it in memory.
    var response = await _httpClient.GetAsync(uri, HttpCompletionOption.ResponseHeadersRead);
    response.EnsureSuccessStatusCode();

    using (var fileStream = File.Create(partialPath))
    using (var httpStream = await response.Content.ReadAsStreamAsync())
    {
        await httpStream.CopyToAsync(fileStream);
    }

    // Promote the finished download to its final name.
    File.Move(partialPath, filePath);

    return filePath;
}

/// <summary>
/// Polls the Llamafile server's /health endpoint until it responds 200 OK,
/// the process exits, or 30 seconds elapse. Throws on failure; kills the
/// process on timeout.
/// </summary>
/// <param name="baseAddress">Base address of the Llamafile server.</param>
/// <param name="process">The Llamafile server process being started.</param>
/// <exception cref="CellmException">The process exited or never became healthy.</exception>
private async Task WaitForLlamafile(Uri baseAddress, Process process)
{
    var startTime = DateTime.UtcNow;

    // Wait max 30 seconds to load model
    while ((DateTime.UtcNow - startTime).TotalSeconds < 30)
    {
        if (process.HasExited)
        {
            throw new CellmException($"Failed to run Llamafile, process exited. Exit code: {process.ExitCode}");
        }

        try
        {
            // Dispose the token source each attempt; the original leaked one
            // per iteration (its timer stays alive until finalized).
            using var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(1));
            var response = await _httpClient.GetAsync(new Uri(baseAddress, "health"), cancellationTokenSource.Token);
            if (response.StatusCode == System.Net.HttpStatusCode.OK)
            {
                // Server is ready
                return;
            }
        }
        catch (HttpRequestException)
        {
            // Server not accepting connections yet; retry.
        }
        catch (TaskCanceledException)
        {
            // Per-attempt 1-second timeout elapsed; retry.
        }

        // Wait before next attempt
        await Task.Delay(500);
    }

    process.Kill();

    throw new CellmException("Failed to run Llamafile, timeout waiting for Llamafile server to start");
}

/// <summary>
/// Builds a path under %APPDATA%\Cellm for <paramref name="fileName"/>,
/// creating the Cellm directory if it does not exist.
/// </summary>
/// <param name="fileName">Name of the file to place in the Cellm app-data folder.</param>
/// <returns>The absolute file path.</returns>
/// <exception cref="CellmException">The directory component could not be derived.</exception>
// Explicit access modifier added for consistency with the rest of the class.
private string CreateFilePath(string fileName)
{
    var filePath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), nameof(Cellm), fileName);
    Directory.CreateDirectory(Path.GetDirectoryName(filePath) ?? throw new CellmException("Failed to create Llamafile folder"));
    return filePath;
}

/// <summary>
/// Builds the cache file name used for a locally stored Llamafile model.
/// </summary>
/// <param name="modelName">Configured name of the model.</param>
/// <returns>A file name of the form "Llamafile-model-{modelName}".</returns>
private static string CreateModelFileName(string modelName)
{
    return string.Concat("Llamafile-model-", modelName);
}

/// <summary>
/// Derives the server base address from configuration, swapping in a currently
/// unused port so multiple local servers can run side by side.
/// </summary>
/// <returns>The configured base address with a free port substituted.</returns>
private Uri CreateBaseAddress()
{
    var builder = new UriBuilder(_llamafileConfiguration.BaseAddress);
    builder.Port = GetFirstUnusedPort();
    return builder.Uri;
}

/// <summary>
/// Returns the first TCP port in [<paramref name="min"/>, <paramref name="max"/>]
/// not currently used by any TCP connection, TCP listener, or UDP listener.
/// </summary>
/// <param name="min">Lowest candidate port (default: start of the ephemeral range).</param>
/// <param name="max">Highest candidate port.</param>
/// <exception cref="ArgumentException">max is smaller than min.</exception>
/// <exception cref="CellmException">Every port in the range is in use.</exception>
private static int GetFirstUnusedPort(ushort min = 49152, ushort max = 65535)
{
    if (max < min)
    {
        throw new ArgumentException("Max port must be larger than min port.");
    }

    var ipProperties = IPGlobalProperties.GetIPGlobalProperties();

    // HashSet gives O(1) membership checks; the original array made the scan O(n^2).
    var activePorts = ipProperties.GetActiveTcpConnections()
        .Where(connection => connection.State != TcpState.Closed)
        .Select(connection => connection.LocalEndPoint)
        .Concat(ipProperties.GetActiveTcpListeners())
        .Concat(ipProperties.GetActiveUdpListeners())
        .Select(endpoint => endpoint.Port)
        .ToHashSet();

    // Enumerable.Range takes (start, count). The original passed (min, max),
    // which enumerates min..(min + max - 1) — probing ports far above 65535.
    var firstInactivePort = Enumerable.Range(min, max - min + 1)
        .FirstOrDefault(port => !activePorts.Contains(port));

    // FirstOrDefault yields 0 when nothing matched; 0 can never be in the range.
    if (firstInactivePort == default)
    {
        throw new CellmException($"All local TCP ports between {min} and {max} are currently in use.");
    }

    return firstInactivePort;
}
}

150 changes: 150 additions & 0 deletions src/Cellm/Models/Local/LocalUtilities.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
using System.Diagnostics;
using System.IO.Compression;
using System.Net.NetworkInformation;
using Cellm.AddIn.Exceptions;
using Microsoft.Office.Interop.Excel;

namespace Cellm.Models.Local;

/// <summary>
/// Shared helpers for running local model servers (Llamafile, Ollama):
/// downloading files, waiting for a server to become healthy, locating the
/// Cellm app-data directory, finding free ports, and extracting archives.
/// </summary>
internal class LocalUtilities(HttpClient httpClient)
{
    /// <summary>
    /// Downloads <paramref name="uri"/> to <paramref name="filePath"/> unless it already exists.
    /// Streams into a ".part" file first so an interrupted download is never
    /// mistaken for a completed one.
    /// </summary>
    /// <returns>The path of the downloaded (or pre-existing) file.</returns>
    public async Task<string> DownloadFile(Uri uri, string filePath)
    {
        if (File.Exists(filePath))
        {
            return filePath;
        }

        // Discard any stale partial download before starting over.
        var filePathPart = $"{filePath}.part";

        if (File.Exists(filePathPart))
        {
            File.Delete(filePathPart);
        }

        // ResponseHeadersRead streams the body instead of buffering it in memory.
        var response = await httpClient.GetAsync(uri, HttpCompletionOption.ResponseHeadersRead);
        response.EnsureSuccessStatusCode();

        using (var fileStream = File.Create(filePathPart))
        using (var httpStream = await response.Content.ReadAsStreamAsync())
        {
            await httpStream.CopyToAsync(fileStream);
        }

        // Promote the finished download to its final name.
        File.Move(filePathPart, filePath);

        return filePath;
    }

    /// <summary>
    /// Polls <paramref name="endpoint"/> until it responds 200 OK, the process
    /// exits, or 30 seconds elapse. Throws on failure; kills the process on timeout.
    /// </summary>
    /// <param name="endpoint">Health-check endpoint of the local server.</param>
    /// <param name="process">The server process being started.</param>
    /// <exception cref="CellmException">The process exited or never became healthy.</exception>
    public async Task WaitForServer(Uri endpoint, Process process)
    {
        var startTime = DateTime.UtcNow;

        // Wait max 30 seconds to load model
        while ((DateTime.UtcNow - startTime).TotalSeconds < 30)
        {
            if (process.HasExited)
            {
                // Message generalized: this utility serves Ollama as well as Llamafile.
                throw new CellmException($"Failed to run local server, process exited. Exit code: {process.ExitCode}");
            }

            try
            {
                // Dispose the token source each attempt; the original leaked
                // one per iteration (its timer stays alive until finalized).
                using var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(1));
                var response = await httpClient.GetAsync(endpoint, cancellationTokenSource.Token);
                if (response.StatusCode == System.Net.HttpStatusCode.OK)
                {
                    // Server is ready
                    return;
                }
            }
            catch (HttpRequestException)
            {
                // Server not accepting connections yet; retry.
            }
            catch (TaskCanceledException)
            {
                // Per-attempt 1-second timeout elapsed; retry.
            }

            // Wait before next attempt
            await Task.Delay(500);
        }

        process.Kill();

        throw new CellmException("Failed to run local server, timeout waiting for server to start");
    }

    /// <summary>
    /// Ensures the Cellm app-data directory (%APPDATA%\Cellm, plus optional
    /// subfolders) exists and returns its path.
    /// </summary>
    /// <param name="subFolders">Optional subfolder segments beneath the Cellm directory.</param>
    public string CreateCellmDirectory(params string[] subFolders)
    {
        var folderPath = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), nameof(Cellm));

        if (subFolders.Length > 0)
        {
            folderPath = Path.Combine(subFolders.Prepend(folderPath).ToArray());
        }

        Directory.CreateDirectory(folderPath);
        return folderPath;
    }

    /// <summary>
    /// Returns the path of <paramref name="fileName"/> inside the Cellm
    /// app-data directory, creating the directory if necessary.
    /// </summary>
    public string CreateCellmFilePath(string fileName)
    {
        return Path.Combine(CreateCellmDirectory(), fileName);
    }

    /// <summary>
    /// Returns the first port in [<paramref name="min"/>, <paramref name="max"/>]
    /// not used by any TCP connection, TCP listener, or UDP listener.
    /// </summary>
    /// <exception cref="ArgumentException">max is smaller than min.</exception>
    /// <exception cref="CellmException">Every port in the range is in use.</exception>
    public int FindPort(ushort min = 49152, ushort max = 65535)
    {
        if (max < min)
        {
            throw new ArgumentException("Max port must be larger than min port.");
        }

        var ipProperties = IPGlobalProperties.GetIPGlobalProperties();

        // HashSet gives O(1) membership checks; the original array made the scan O(n^2).
        var activePorts = ipProperties.GetActiveTcpConnections()
            .Where(connection => connection.State != TcpState.Closed)
            .Select(connection => connection.LocalEndPoint)
            .Concat(ipProperties.GetActiveTcpListeners())
            .Concat(ipProperties.GetActiveUdpListeners())
            .Select(endpoint => endpoint.Port)
            .ToHashSet();

        // Enumerable.Range takes (start, count). The original passed (min, max),
        // which enumerates min..(min + max - 1) — probing ports far above 65535.
        var firstInactivePort = Enumerable.Range(min, max - min + 1)
            .FirstOrDefault(port => !activePorts.Contains(port));

        // FirstOrDefault yields 0 when nothing matched; 0 can never be in the range.
        if (firstInactivePort == default)
        {
            throw new CellmException($"All local TCP ports between {min} and {max} are currently in use.");
        }

        return firstInactivePort;
    }

    /// <summary>
    /// Extracts <paramref name="zipFilePath"/> into <paramref name="targetDirectory"/>
    /// if any archive entry is missing or has a different size on disk; otherwise
    /// the existing extraction is reused.
    /// </summary>
    /// <returns><paramref name="targetDirectory"/>.</returns>
    public string ExtractFile(string zipFilePath, string targetDirectory)
    {
        using (ZipArchive archive = ZipFile.OpenRead(zipFilePath))
        {
            foreach (ZipArchiveEntry entry in archive.Entries)
            {
                string destinationPath = Path.Combine(targetDirectory, entry.FullName);
                var fileInfo = new FileInfo(destinationPath);

                // Re-extract everything if any entry is missing or truncated.
                if (!fileInfo.Exists || fileInfo.Length != entry.Length)
                {
                    // overwriteFiles: true — the non-overwriting overload throws
                    // IOException when a partially-extracted directory already
                    // contains some of the entries.
                    ZipFile.ExtractToDirectory(zipFilePath, targetDirectory, overwriteFiles: true);
                    return targetDirectory;
                }
            }
        }

        return targetDirectory;
    }
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using System.Diagnostics;
using System.Runtime.InteropServices;

public class LLamafileProcessManager
public class ProcessManager
{
[DllImport("kernel32.dll", CharSet = CharSet.Unicode)]
static extern IntPtr CreateJobObject(IntPtr a, string lpName);
Expand Down Expand Up @@ -61,7 +61,7 @@ enum JobObjectInfoType

private IntPtr _jobObject;

public LLamafileProcessManager()
public ProcessManager()
{
_jobObject = CreateJobObject(IntPtr.Zero, string.Empty);

Expand Down
Loading

0 comments on commit 1b7e9e6

Please sign in to comment.