Skip to content

Commit

Permalink
fix: CacheBehavior caches value even when model request is mutated do…
Browse files Browse the repository at this point in the history
…wnstream
  • Loading branch information
kaspermarstal committed Nov 17, 2024
1 parent 2518502 commit 4c2fac0
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 59 deletions.
42 changes: 0 additions & 42 deletions src/Cellm/Models/Cache.cs

This file was deleted.

39 changes: 29 additions & 10 deletions src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs
Original file line number Diff line number Diff line change
@@ -1,32 +1,51 @@
using MediatR;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using Cellm.AddIn;
using MediatR;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Options;

namespace Cellm.Models.PipelineBehavior;

internal class CachingBehavior<TRequest, TResponse> : IPipelineBehavior<TRequest, TResponse>
where TRequest : IModelRequest<TResponse>
where TResponse : IModelResponse
{
private readonly Cache _cache;
private readonly IMemoryCache _memoryCache;
private readonly MemoryCacheEntryOptions _memoryCacheEntryOptions;

public CachingBehavior(Cache cache)
public CachingBehavior(IMemoryCache memoryCache, IOptions<CellmConfiguration> _cellmConfiguration)
{
_cache = cache;
_memoryCache = memoryCache;
_memoryCacheEntryOptions = new()
{
SlidingExpiration = TimeSpan.FromSeconds(_cellmConfiguration.Value.CacheTimeoutInSeconds)
};
}

public async Task<TResponse> Handle(TRequest request, RequestHandlerDelegate<TResponse> next, CancellationToken cancellationToken)
{
if (_cache.TryGetValue(request, out object? value) && value is TResponse response)
var key = GetKey(request);

if (_memoryCache.TryGetValue(key, out object? value) && value is TResponse response)
{
return response;
}

response = await next();

// Tool results depend on state external to prompt and should not be cached
if (!request.Prompt.Messages.Any(x => x.Role == Prompts.Roles.Tool))
{
_cache.Set(request, response);
}
_memoryCache.Set(key, response, _memoryCacheEntryOptions);

return response;
}

private static string GetKey<T>(T key)
{
var json = JsonSerializer.Serialize(key);
var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(json));
var hash = Convert.ToHexString(bytes);

return hash;
}
}
8 changes: 1 addition & 7 deletions src/Cellm/Services/ServiceLocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
using Cellm.Services.Configuration;
using Cellm.Tools;
using Cellm.Tools.FileReader;
using Cellm.Tools.Glob;
using ExcelDna.Integration;
using MediatR;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Sentry.Profiling;

namespace Cellm.Services;

Expand Down Expand Up @@ -89,16 +87,12 @@ private static IServiceCollection ConfigureServices(IServiceCollection services)
// Internals
services
.AddSingleton(configuration)
.AddMemoryCache()
.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(Assembly.GetExecutingAssembly()))
.AddTransient<PromptWithArgumentParser>()
.AddSingleton<Client>()
.AddSingleton<Serde>();

// Cache
services
.AddMemoryCache()
.AddSingleton<Cache>();

// Tools
services
.AddSingleton<ToolRunner>()
Expand Down

0 comments on commit 4c2fac0

Please sign in to comment.