From d15022b4fb7d1cd41f338f610e2517fb89556063 Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Sun, 27 Oct 2024 07:45:32 -0700
Subject: [PATCH] fix issue https://github.com/SylphAI-Inc/AdalFlow/issues/237
---
adalflow/adalflow/core/generator.py | 35 +++++++++++++++++++++--------
adalflow/adalflow/utils/cache.py | 8 +++++--
adalflow/tests/test_generator.py | 27 ++++++++++++++++++++++
3 files changed, 59 insertions(+), 11 deletions(-)
diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py
index d3662063f..309f954fe 100644
--- a/adalflow/adalflow/core/generator.py
+++ b/adalflow/adalflow/core/generator.py
@@ -2,8 +2,9 @@
It is a pipeline that consists of three subcomponents."""
-import os
import json
+import re
+from pathlib import Path
from typing import Any, Dict, Optional, Union, Callable, Tuple, List
import logging
@@ -114,16 +115,14 @@ def __init__(
template = template or DEFAULT_LIGHTRAG_SYSTEM_PROMPT
- # Cache
- model_str = (
- f"{model_client.__class__.__name__}_{model_kwargs.get('model', 'default')}"
- )
- _cache_path = (
- get_adalflow_default_root_path() if cache_path is None else cache_path
+ # create the cache path and initialize the cache engine
+
+ self.set_cache_path(
+ cache_path, model_client, model_kwargs.get("model", "default")
)
- self.cache_path = os.path.join(_cache_path, f"cache_{model_str}.db")
CachedEngine.__init__(self, cache_path=self.cache_path)
+
Component.__init__(self)
GradComponent.__init__(self)
CallbackManager.__init__(self)
@@ -148,7 +147,6 @@ def __init__(
self.mock_output_data: str = "mock data"
self.data_map_func: Callable = None
self.set_data_map_func()
- self.model_str = model_str
self._use_cache = use_cache
self._kwargs = {
@@ -163,6 +161,25 @@ def __init__(
}
self._teacher: Optional["Generator"] = None
+ def set_cache_path(self, cache_path: str, model_client: object, model: str):
+ """Set the cache path for the generator."""
+
+ # Construct a valid model string using the client class name and model
+ self.model_str = f"{model_client.__class__.__name__}_{model}"
+
+ # Remove any characters that are not allowed in file names (cross-platform)
+ # On Windows, characters like `:<>?/\|*` are prohibited.
+ self.model_str = re.sub(r"[^a-zA-Z0-9_\-]", "_", self.model_str)
+
+ _cache_path = (
+ get_adalflow_default_root_path() if cache_path is None else cache_path
+ )
+
+        # Use pathlib to handle paths more safely across operating systems
+ self.cache_path = Path(_cache_path) / f"cache_{self.model_str}.db"
+
+ log.debug(f"Cache path set to: {self.cache_path}")
+
def get_cache_path(self) -> str:
r"""Get the cache path for the generator."""
return self.cache_path
diff --git a/adalflow/adalflow/utils/cache.py b/adalflow/adalflow/utils/cache.py
index 31fccdaa9..c330cdc89 100644
--- a/adalflow/adalflow/utils/cache.py
+++ b/adalflow/adalflow/utils/cache.py
@@ -1,5 +1,7 @@
import hashlib
import diskcache as dc
+from pathlib import Path
+from typing import Union
def hash_text(text: str):
@@ -15,9 +17,11 @@ def direct(text: str):
class CachedEngine:
- def __init__(self, cache_path: str):
+ def __init__(self, cache_path: Union[str, Path]):
super().__init__()
- self.cache_path = cache_path
+ self.cache_path = Path(cache_path)
+ self.cache_path.parent.mkdir(parents=True, exist_ok=True)
+
self.cache = dc.Cache(cache_path)
def _check_cache(self, prompt: str):
diff --git a/adalflow/tests/test_generator.py b/adalflow/tests/test_generator.py
index 5ea8b76d5..a15c302a5 100644
--- a/adalflow/tests/test_generator.py
+++ b/adalflow/tests/test_generator.py
@@ -3,6 +3,7 @@
import unittest
import os
import shutil
+from pathlib import Path
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
@@ -55,6 +56,32 @@ def test_generator_call(self):
print(f"output: {output}")
# self.assertEqual(output.data, "Generated text response")
+ def test_cache_path(self):
+ prompt_kwargs = {"input_str": "Hello, world!"}
+ model_kwargs = {"model": "phi3.5:latest"}
+
+ self.test_generator = Generator(
+ model_client=self.mock_api_client,
+ prompt_kwargs=prompt_kwargs,
+ model_kwargs=model_kwargs,
+ use_cache=True,
+ )
+
+        # Convert the Path to a string so the substring assertion below works
+ cache_path = self.test_generator.get_cache_path()
+ cache_path_str = str(cache_path)
+
+ print(f"cache path: {cache_path}")
+
+ # Check if the sanitized model string is in the cache path
+ self.assertIn("phi3_5_latest", cache_path_str)
+
+        # Check that the cache file was actually created on disk
+
+ self.assertTrue(
+ Path(cache_path).exists(), f"Cache path {cache_path_str} does not exist"
+ )
+
def test_generator_prompt_logger_first_record(self):
# prompt_kwargs = {"input_str": "Hello, world!"}
# model_kwargs = {"model": "gpt-3.5-turbo"}