diff --git a/.github/tests/lm_tests.py b/.github/tests/lm_tests.py
index 94fcd59..3a18bf8 100644
--- a/.github/tests/lm_tests.py
+++ b/.github/tests/lm_tests.py
@@ -5,6 +5,7 @@ from tokenizers import Tokenizer
 
 import lotus
+from lotus.cache import CacheConfig, CacheFactory, CacheType
 from lotus.models import LM, SentenceTransformersRM
 from lotus.types import CascadeArgs
 
@@ -398,7 +399,7 @@ def test_custom_tokenizer():
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_cache=True)
+    lotus.settings.configure(lm=lm, enable_message_cache=True)
 
     # Check that "What is the capital of France?" becomes cached
     first_batch = [
@@ -427,7 +428,7 @@ def test_cache(setup_models, model):
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_disable_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_cache=False)
+    lotus.settings.configure(lm=lm, enable_message_cache=False)
 
     batch = [
         [{"role": "user", "content": "Hello, world!"}],
@@ -439,7 +440,7 @@ def test_disable_cache(setup_models, model):
     assert lm.stats.total_usage.cache_hits == 0
 
     # Now enable cache. Note that the first batch is not cached.
-    lotus.settings.configure(enable_cache=True)
+    lotus.settings.configure(enable_message_cache=True)
     first_responses = lm(batch).outputs
     assert lm.stats.total_usage.cache_hits == 0
     second_responses = lm(batch).outputs
@@ -450,7 +451,7 @@ def test_disable_cache(setup_models, model):
 @pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
 def test_reset_cache(setup_models, model):
     lm = setup_models[model]
-    lotus.settings.configure(lm=lm, enable_cache=True)
+    lotus.settings.configure(lm=lm, enable_message_cache=True)
 
     batch = [
         [{"role": "user", "content": "Hello, world!"}],
@@ -472,3 +473,120 @@ def test_reset_cache(setup_models, model):
     assert lm.stats.total_usage.cache_hits == 3
     lm(batch)
     assert lm.stats.total_usage.cache_hits == 3
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_operator_cache(setup_models, model):
+    cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
+    cache = CacheFactory.create_cache(cache_config)
+
+    lm = LM(model="gpt-4o-mini", cache=cache)
+    lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)
+
+    data = {
+        "Course Name": [
+            "Dynamics and Control of Chemical Processes",
+            "Optimization Methods in Engineering",
+            "Chemical Kinetics and Catalysis",
+            "Transport Phenomena and Separations",
+        ]
+    }
+
+    expected_response = pd.DataFrame(
+        {
+            "Course Name": [
+                "Dynamics and Control of Chemical Processes",
+                "Optimization Methods in Engineering",
+                "Chemical Kinetics and Catalysis",
+                "Transport Phenomena and Separations",
+            ],
+            "_map": [
+                "Process Dynamics and Control",
+                "Advanced Optimization Techniques in Engineering",
+                "Reaction Kinetics and Mechanisms",
+                "Fluid Mechanics and Mass Transfer",
+            ],
+        }
+    )
+
+    df = pd.DataFrame(data)
+    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."
+
+    first_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+    second_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 1
+
+    first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+
+    pd.testing.assert_frame_equal(first_response, second_response)
+    pd.testing.assert_frame_equal(first_response, expected_response)
+    pd.testing.assert_frame_equal(second_response, expected_response)
+
+    lm.reset_cache()
+    lm.reset_stats()
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+
+@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
+def test_disable_operator_cache(setup_models, model):
+    cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
+    cache = CacheFactory.create_cache(cache_config)
+
+    lm = LM(model="gpt-4o-mini", cache=cache)
+    lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False)
+
+    data = {
+        "Course Name": [
+            "Dynamics and Control of Chemical Processes",
+            "Optimization Methods in Engineering",
+            "Chemical Kinetics and Catalysis",
+            "Transport Phenomena and Separations",
+        ]
+    }
+
+    expected_response = pd.DataFrame(
+        {
+            "Course Name": [
+                "Dynamics and Control of Chemical Processes",
+                "Optimization Methods in Engineering",
+                "Chemical Kinetics and Catalysis",
+                "Transport Phenomena and Separations",
+            ],
+            "_map": [
+                "Process Dynamics and Control",
+                "Advanced Optimization Techniques in Engineering",
+                "Reaction Kinetics and Mechanisms",
+                "Fluid Mechanics and Mass Transfer",
+            ],
+        }
+    )
+
+    df = pd.DataFrame(data)
+    user_instruction = "What is a similar course to {Course Name}. Please just output the course name."
+
+    first_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+    second_response = df.sem_map(user_instruction)
+    assert lm.stats.total_usage.operator_cache_hits == 0
+
+    pd.testing.assert_frame_equal(first_response, second_response)
+
+    # Now enable operator cache.
+    lotus.settings.configure(enable_operator_cache=True)
+    first_responses = df.sem_map(user_instruction)
+    first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    assert lm.stats.total_usage.operator_cache_hits == 0
+    second_responses = df.sem_map(user_instruction)
+    second_responses["_map"] = second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+    assert lm.stats.total_usage.operator_cache_hits == 1
+
+    expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
+
+    pd.testing.assert_frame_equal(first_responses, second_responses)
+    pd.testing.assert_frame_equal(first_responses, expected_response)
+    pd.testing.assert_frame_equal(second_responses, expected_response)
diff --git a/docs/configurations.rst b/docs/configurations.rst
index 79c5a46..0ee0172 100644
--- a/docs/configurations.rst
+++ b/docs/configurations.rst
@@ -21,12 +21,12 @@ Using the Settings module
 Configurable Parameters
 --------------------------
 
-1. enable_cache:
+1. enable_message_cache:
     * Description: Enables or disables caching mechanisms
    * Default: False
 
 .. code-block:: python
 
-    lotus.settings.configure(enable_cache=True)
+    lotus.settings.configure(enable_message_cache=True)
 
 2. setting RM:
     * Description: Configures the retrieval model
diff --git a/examples/model_examples/cache.py b/examples/model_examples/cache.py
index 5314ca5..95bc282 100644
--- a/examples/model_examples/cache.py
+++ b/examples/model_examples/cache.py
@@ -11,7 +11,7 @@
 lm = LM(model="gpt-4o-mini", cache=cache)
 
-lotus.settings.configure(lm=lm, enable_cache=True)  # default caching is False
+lotus.settings.configure(lm=lm, enable_message_cache=True)  # default caching is False
 
 data = {
     "Course Name": [
         "Probability and Random Processes",
diff --git a/lotus/cache.py b/lotus/cache.py
index 74cadd5..82c1c4c 100644
--- a/lotus/cache.py
+++ b/lotus/cache.py
@@ -1,3 +1,5 @@
+import hashlib
+import json
 import os
 import pickle
 import sqlite3
@@ -8,6 +10,8 @@ from functools import wraps
 from typing import Any, Callable
 
+import pandas as pd
+
 import lotus
 
 
@@ -16,13 +20,52 @@ def require_cache_enabled(func: Callable) -> Callable:
     @wraps(func)
     def wrapper(self, *args, **kwargs):
-        if not lotus.settings.enable_cache:
+        if not lotus.settings.enable_message_cache:
             return None
         return func(self, *args, **kwargs)
 
     return wrapper
 
 
+def operator_cache(func: Callable) -> Callable:
+    """Decorator to add operator level caching."""
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        model = lotus.settings.lm
+        use_operator_cache = lotus.settings.enable_operator_cache
+
+        if use_operator_cache and model.cache:
+
+            def serialize(value):
+                if isinstance(value, pd.DataFrame):
+                    return value.to_json()
+                elif hasattr(value, "dict"):
+                    return value.dict()
+                return value
+
+            serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
+            serialized_args = [serialize(arg) for arg in args]
+            cache_key = hashlib.sha256(
+                json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
+            ).hexdigest()
+
+            cached_result = model.cache.get(cache_key)
+            if cached_result is not None:
+                lotus.logger.debug(f"Cache hit for {cache_key}")
+                model.stats.total_usage.operator_cache_hits += 1
+                return cached_result
+            lotus.logger.debug(f"Cache miss for {cache_key}")
+
+            result = func(self, *args, **kwargs)
+            model.cache.insert(cache_key, result)
+            return result
+
+        return func(self, *args, **kwargs)
+
+    return wrapper
+
+
 class CacheType(Enum):
     IN_MEMORY = "in_memory"
     SQLITE = "sqlite"
diff --git a/lotus/sem_ops/sem_agg.py b/lotus/sem_ops/sem_agg.py
index 7bb67ab..706f12f 100644
--- a/lotus/sem_ops/sem_agg.py
+++ b/lotus/sem_ops/sem_agg.py
@@ -3,6 +3,7 @@ import pandas as pd
 
 import lotus.models
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticAggOutput
 
@@ -148,6 +149,7 @@ def process_group(args):
         group, user_instruction, all_cols, suffix, progress_bar_desc = args
         return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)
 
+    @operator_cache
     def __call__(
         self,
         user_instruction: str,
diff --git a/lotus/sem_ops/sem_cluster_by.py b/lotus/sem_ops/sem_cluster_by.py
index 5811101..fc8a9c6 100644
--- a/lotus/sem_ops/sem_cluster_by.py
+++ b/lotus/sem_ops/sem_cluster_by.py
@@ -4,6 +4,7 @@ import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 
 
 @pd.api.extensions.register_dataframe_accessor("sem_cluster_by")
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         col_name: str,
@@ -52,7 +54,7 @@ def __call__(
         self._obj["cluster_id"] = pd.Series(indices, index=self._obj.index)
         # if return_scores:
         #     self._obj["centroid_sim_score"] = pd.Series(scores, index=self._obj.index)
-        
+
         # if return_centroids:
         #     return self._obj, centroids
         # else:
diff --git a/lotus/sem_ops/sem_extract.py b/lotus/sem_ops/sem_extract.py
index 053dc5a..ed9619e 100644
--- a/lotus/sem_ops/sem_extract.py
+++ b/lotus/sem_ops/sem_extract.py
@@ -3,6 +3,7 @@ import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.models import LM
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
@@ -33,7 +34,6 @@ def sem_extract(
     Returns:
         SemanticExtractOutput: The outputs, raw outputs, and quotes.
     """
-
     # prepare model inputs
     inputs = []
     for doc in docs:
@@ -72,6 +72,7 @@ def _validate(obj: pd.DataFrame) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         input_cols: list[str],
diff --git a/lotus/sem_ops/sem_filter.py b/lotus/sem_ops/sem_filter.py
index c03a064..d6253b8 100644
--- a/lotus/sem_ops/sem_filter.py
+++ b/lotus/sem_ops/sem_filter.py
@@ -5,6 +5,7 @@ from numpy.typing import NDArray
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
 from lotus.utils import show_safe_mode
@@ -134,6 +135,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         user_instruction: str,
diff --git a/lotus/sem_ops/sem_join.py b/lotus/sem_ops/sem_join.py
index abb2765..0050f49 100644
--- a/lotus/sem_ops/sem_join.py
+++ b/lotus/sem_ops/sem_join.py
@@ -4,6 +4,7 @@ from tqdm import tqdm
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import CascadeArgs, SemanticJoinOutput
 from lotus.utils import show_safe_mode
@@ -234,7 +235,6 @@ def sem_join_cascade(
         cot_reasoning=cot_reasoning,
         default=default,
         strategy=strategy,
-        show_progress_bar=False,
     )
     pbar.update(num_large)
     pbar.close()
@@ -545,6 +545,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         other: pd.DataFrame | pd.Series,
diff --git a/lotus/sem_ops/sem_map.py b/lotus/sem_ops/sem_map.py
index 9e8b991..9708bb1 100644
--- a/lotus/sem_ops/sem_map.py
+++ b/lotus/sem_ops/sem_map.py
@@ -3,6 +3,7 @@ import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.templates import task_instructions
 from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput
 from lotus.utils import show_safe_mode
@@ -80,6 +81,7 @@ def _validate(obj: pd.DataFrame) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
 
+    @operator_cache
     def __call__(
         self,
         user_instruction: str,
diff --git a/lotus/sem_ops/sem_search.py b/lotus/sem_ops/sem_search.py
index 5846cc1..de2df35 100644
--- a/lotus/sem_ops/sem_search.py
+++ b/lotus/sem_ops/sem_search.py
@@ -3,6 +3,7 @@ import pandas as pd
 
 import lotus
+from lotus.cache import operator_cache
 from lotus.types import RerankerOutput, RMOutput
 
 
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
         if not isinstance(obj, pd.DataFrame):
             raise AttributeError("Must be a DataFrame")
DataFrame") + @operator_cache def __call__( self, col_name: str, diff --git a/lotus/sem_ops/sem_sim_join.py b/lotus/sem_ops/sem_sim_join.py index b1fd986..47d3cbe 100644 --- a/lotus/sem_ops/sem_sim_join.py +++ b/lotus/sem_ops/sem_sim_join.py @@ -3,6 +3,7 @@ import pandas as pd import lotus +from lotus.cache import operator_cache from lotus.models import RM from lotus.types import RMOutput @@ -20,6 +21,7 @@ def _validate(obj: Any) -> None: if not isinstance(obj, pd.DataFrame): raise AttributeError("Must be a DataFrame") + @operator_cache def __call__( self, other: pd.DataFrame, diff --git a/lotus/sem_ops/sem_topk.py b/lotus/sem_ops/sem_topk.py index c92e81d..b5ecd5e 100644 --- a/lotus/sem_ops/sem_topk.py +++ b/lotus/sem_ops/sem_topk.py @@ -7,6 +7,7 @@ from tqdm import tqdm import lotus +from lotus.cache import operator_cache from lotus.templates import task_instructions from lotus.types import LMOutput, SemanticTopKOutput from lotus.utils import show_safe_mode @@ -386,6 +387,7 @@ def process_group(args): return_stats=return_stats, ) + @operator_cache def __call__( self, user_instruction: str, @@ -438,7 +440,7 @@ def __call__( with ThreadPoolExecutor(max_workers=lotus.settings.parallel_groupby_max_threads) as executor: results = list(executor.map(SemTopKDataframe.process_group, group_args)) - + if return_stats: new_df = pd.concat([res[0] for res in results]) stats = {name: res[1] for name, res in zip(grouped.groups.keys(), results)} diff --git a/lotus/settings.py b/lotus/settings.py index ce12363..99e5944 100644 --- a/lotus/settings.py +++ b/lotus/settings.py @@ -12,7 +12,8 @@ class Settings: reranker: lotus.models.Reranker | None = None # Cache settings - enable_cache: bool = False + enable_message_cache: bool = False + enable_operator_cache: bool = False # Serialization setting serialization_format: SerializationFormat = SerializationFormat.DEFAULT diff --git a/lotus/types.py b/lotus/types.py index 96b9079..c4cbb6d 100644 --- a/lotus/types.py +++ b/lotus/types.py @@ -32,6 +32,7 @@ class TotalUsage(BaseModel): total_tokens: int = 0 total_cost: float = 0.0 cache_hits: int = 0 + operator_cache_hits: int = 0 total_usage: TotalUsage = TotalUsage() diff --git a/tests/test_settings.py b/tests/test_settings.py index dc6f871..4f251bb 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -13,12 +13,12 @@ def test_initial_values(self, settings): assert settings.rm is None assert settings.helper_lm is None assert settings.reranker is None - assert settings.enable_cache is False + assert settings.enable_message_cache is False assert settings.serialization_format == SerializationFormat.DEFAULT def test_configure_method(self, settings): - settings.configure(enable_cache=True) - assert settings.enable_cache is True + settings.configure(enable_message_cache=True) + assert settings.enable_message_cache is True def test_invalid_setting(self, settings): with pytest.raises(ValueError, match="Invalid setting: invalid_setting"):