From 82437efef6d87a9c584641dd2eb5b9aa3d1ee021 Mon Sep 17 00:00:00 2001 From: Kasper Marstal Date: Sun, 17 Nov 2024 22:20:25 +0100 Subject: [PATCH] fix: CacheBehavior caches value even when model request is mutated downstream (#56) --- src/Cellm/Models/Cache.cs | 42 ------------------- .../ModelRequestBehavior/CachingBehavior.cs | 39 ++++++++++++----- src/Cellm/Services/ServiceLocator.cs | 8 +--- 3 files changed, 30 insertions(+), 59 deletions(-) delete mode 100644 src/Cellm/Models/Cache.cs diff --git a/src/Cellm/Models/Cache.cs b/src/Cellm/Models/Cache.cs deleted file mode 100644 index 8df299b..0000000 --- a/src/Cellm/Models/Cache.cs +++ /dev/null @@ -1,42 +0,0 @@ -using System.Security.Cryptography; -using System.Text; -using System.Text.Json; -using Cellm.AddIn; -using Microsoft.Extensions.Caching.Memory; -using Microsoft.Extensions.Options; - -namespace Cellm.Models; - -internal class Cache -{ - private readonly IMemoryCache _memoryCache; - private readonly MemoryCacheEntryOptions _memoryCacheEntryOptions; - - public Cache(IMemoryCache memoryCache, IOptions _cellmConfiguration) - { - _memoryCache = memoryCache; - _memoryCacheEntryOptions = new() - { - SlidingExpiration = TimeSpan.FromSeconds(_cellmConfiguration.Value.CacheTimeoutInSeconds) - }; - } - - public void Set(T key, U value) - { - _memoryCache.Set(GetHash(key), value, _memoryCacheEntryOptions); - } - - public bool TryGetValue(T key, out object? value) - { - return _memoryCache.TryGetValue(GetHash(key), out value); - } - - private static string GetHash(T key) - { - var json = JsonSerializer.Serialize(key); - var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(json)); - var hash = Convert.ToHexString(bytes); - - return hash; - } -} diff --git a/src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs b/src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs index 3a319a5..0f35565 100644 --- a/src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs +++ b/src/Cellm/Models/ModelRequestBehavior/CachingBehavior.cs @@ -1,32 +1,51 @@ -using MediatR; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using Cellm.AddIn; +using MediatR; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Options; namespace Cellm.Models.PipelineBehavior; internal class CachingBehavior : IPipelineBehavior where TRequest : IModelRequest + where TResponse : IModelResponse { - private readonly Cache _cache; + private readonly IMemoryCache _memoryCache; + private readonly MemoryCacheEntryOptions _memoryCacheEntryOptions; - public CachingBehavior(Cache cache) + public CachingBehavior(IMemoryCache memoryCache, IOptions _cellmConfiguration) { - _cache = cache; + _memoryCache = memoryCache; + _memoryCacheEntryOptions = new() + { + SlidingExpiration = TimeSpan.FromSeconds(_cellmConfiguration.Value.CacheTimeoutInSeconds) + }; } public async Task Handle(TRequest request, RequestHandlerDelegate next, CancellationToken cancellationToken) { - if (_cache.TryGetValue(request, out object? value) && value is TResponse response) + var key = GetKey(request); + + if (_memoryCache.TryGetValue(key, out object? value) && value is TResponse response) { return response; } response = await next(); - // Tool results depend on state external to prompt and should not be cached - if (!request.Prompt.Messages.Any(x => x.Role == Prompts.Roles.Tool)) - { - _cache.Set(request, response); - } + _memoryCache.Set(key, response, _memoryCacheEntryOptions); return response; } + + private static string GetKey(T key) + { + var json = JsonSerializer.Serialize(key); + var bytes = SHA256.HashData(Encoding.UTF8.GetBytes(json)); + var hash = Convert.ToHexString(bytes); + + return hash; + } } diff --git a/src/Cellm/Services/ServiceLocator.cs b/src/Cellm/Services/ServiceLocator.cs index d7a4d64..f4a8c2c 100644 --- a/src/Cellm/Services/ServiceLocator.cs +++ b/src/Cellm/Services/ServiceLocator.cs @@ -10,14 +10,12 @@ using Cellm.Services.Configuration; using Cellm.Tools; using Cellm.Tools.FileReader; -using Cellm.Tools.Glob; using ExcelDna.Integration; using MediatR; using Microsoft.Extensions.Caching.Memory; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Sentry.Profiling; namespace Cellm.Services; @@ -89,16 +87,12 @@ private static IServiceCollection ConfigureServices(IServiceCollection services) // Internals services .AddSingleton(configuration) + .AddMemoryCache() .AddMediatR(cfg => cfg.RegisterServicesFromAssembly(Assembly.GetExecutingAssembly())) .AddTransient() .AddSingleton() .AddSingleton(); - // Cache - services - .AddMemoryCache() - .AddSingleton(); - // Tools services .AddSingleton()