Commit: operator level cache (#65)

operator level caching

StanChan03 authored Dec 27, 2024
1 parent 1ec4d8a commit 9761855
Showing 16 changed files with 195 additions and 16 deletions.
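For orientation, here is a minimal usage sketch of the new feature, assembled from this commit's own tests and example code (it is not itself an excerpt of the diff): an LM is backed by a SQLite cache, both cache layers are enabled, and a repeated operator call is then served from the operator cache.

import pandas as pd

import lotus
from lotus.cache import CacheConfig, CacheFactory, CacheType
from lotus.models import LM

# Back the model with a SQLite cache and enable both layers:
# per-message caching and the new operator-level caching.
cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
cache = CacheFactory.create_cache(cache_config)
lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)

df = pd.DataFrame({"Course Name": ["Optimization Methods in Engineering"]})
instruction = "What is a similar course to {Course Name}. Please just output the course name."

df.sem_map(instruction)  # computed by the model, then stored in the operator cache
df.sem_map(instruction)  # identical call: answered from the operator cache
assert lm.stats.total_usage.operator_cache_hits == 1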
126 changes: 122 additions & 4 deletions .github/tests/lm_tests.py
@@ -5,6 +5,7 @@
from tokenizers import Tokenizer

import lotus
from lotus.cache import CacheConfig, CacheFactory, CacheType
from lotus.models import LM, SentenceTransformersRM
from lotus.types import CascadeArgs

@@ -398,7 +399,7 @@ def test_custom_tokenizer():
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=True)
lotus.settings.configure(lm=lm, enable_message_cache=True)

# Check that "What is the capital of France?" becomes cached
first_batch = [
@@ -427,7 +428,7 @@ def test_cache(setup_models, model):
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_disable_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=False)
lotus.settings.configure(lm=lm, enable_message_cache=False)

batch = [
[{"role": "user", "content": "Hello, world!"}],
@@ -439,7 +440,7 @@ def test_disable_cache(setup_models, model):
assert lm.stats.total_usage.cache_hits == 0

# Now enable cache. Note that the first batch is not cached.
lotus.settings.configure(enable_cache=True)
lotus.settings.configure(enable_message_cache=True)
first_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 0
second_responses = lm(batch).outputs
@@ -450,7 +451,7 @@ def test_reset_cache(setup_models, model):
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_reset_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=True)
lotus.settings.configure(lm=lm, enable_message_cache=True)

batch = [
[{"role": "user", "content": "Hello, world!"}],
@@ -472,3 +473,120 @@ def test_reset_cache(setup_models, model):
assert lm.stats.total_usage.cache_hits == 3
lm(batch)
assert lm.stats.total_usage.cache_hits == 3


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_operator_cache(setup_models, model):
cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=True)

data = {
"Course Name": [
"Dynamics and Control of Chemical Processes",
"Optimization Methods in Engineering",
"Chemical Kinetics and Catalysis",
"Transport Phenomena and Separations",
]
}

expected_response = pd.DataFrame(
{
"Course Name": [
"Dynamics and Control of Chemical Processes",
"Optimization Methods in Engineering",
"Chemical Kinetics and Catalysis",
"Transport Phenomena and Separations",
],
"_map": [
"Process Dynamics and Control",
"Advanced Optimization Techniques in Engineering",
"Reaction Kinetics and Mechanisms",
"Fluid Mechanics and Mass Transfer",
],
}
)

df = pd.DataFrame(data)
user_instruction = "What is a similar course to {Course Name}. Please just output the course name."

first_response = df.sem_map(user_instruction)
assert lm.stats.total_usage.operator_cache_hits == 0

second_response = df.sem_map(user_instruction)
assert lm.stats.total_usage.operator_cache_hits == 1

first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()

pd.testing.assert_frame_equal(first_response, second_response)
pd.testing.assert_frame_equal(first_response, expected_response)
pd.testing.assert_frame_equal(second_response, expected_response)

lm.reset_cache()
lm.reset_stats()
assert lm.stats.total_usage.operator_cache_hits == 0


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_disable_operator_cache(setup_models, model):
cache_config = CacheConfig(cache_type=CacheType.SQLITE, max_size=1000)
cache = CacheFactory.create_cache(cache_config)

lm = LM(model="gpt-4o-mini", cache=cache)
lotus.settings.configure(lm=lm, enable_message_cache=True, enable_operator_cache=False)

data = {
"Course Name": [
"Dynamics and Control of Chemical Processes",
"Optimization Methods in Engineering",
"Chemical Kinetics and Catalysis",
"Transport Phenomena and Separations",
]
}

expected_response = pd.DataFrame(
{
"Course Name": [
"Dynamics and Control of Chemical Processes",
"Optimization Methods in Engineering",
"Chemical Kinetics and Catalysis",
"Transport Phenomena and Separations",
],
"_map": [
"Process Dynamics and Control",
"Advanced Optimization Techniques in Engineering",
"Reaction Kinetics and Mechanisms",
"Fluid Mechanics and Mass Transfer",
],
}
)

df = pd.DataFrame(data)
user_instruction = "What is a similar course to {Course Name}. Please just output the course name."

first_response = df.sem_map(user_instruction)
assert lm.stats.total_usage.operator_cache_hits == 0

second_response = df.sem_map(user_instruction)
assert lm.stats.total_usage.operator_cache_hits == 0

pd.testing.assert_frame_equal(first_response, second_response)

# Now enable operator cache.
lotus.settings.configure(enable_operator_cache=True)
first_responses = df.sem_map(user_instruction)
first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0
second_responses = df.sem_map(user_instruction)
second_responses["_map"] = second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 1

expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()

pd.testing.assert_frame_equal(first_responses, second_responses)
pd.testing.assert_frame_equal(first_responses, expected_response)
pd.testing.assert_frame_equal(second_responses, expected_response)
4 changes: 2 additions & 2 deletions docs/configurations.rst
@@ -21,12 +21,12 @@ Using the Settings module
Configurable Parameters
--------------------------

1. enable_cache:
1. enable_message_cache:
* Description: Enables or disables caching mechanisms
* Default: False
.. code-block:: python
lotus.settings.configure(enable_cache=True)
lotus.settings.configure(enable_message_cache=True)
2. setting RM:
* Description: Configures the retrieval model
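The hunk above documents only the message-level flag; the operator-level flag added by this commit is not yet covered in the docs. A parallel entry, sketched here in the same reST style (a suggestion, not part of this diff; the stated default mirrors enable_message_cache and is an assumption), could read:

enable_operator_cache:
    * Description: Enables or disables caching of entire operator calls (e.g. sem_map)
    * Default: False (assumed, matching enable_message_cache)
    .. code-block:: python
        lotus.settings.configure(enable_operator_cache=True)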
2 changes: 1 addition & 1 deletion examples/model_examples/cache.py
@@ -11,7 +11,7 @@

lm = LM(model="gpt-4o-mini", cache=cache)

lotus.settings.configure(lm=lm, enable_cache=True) # default caching is False
lotus.settings.configure(lm=lm, enable_message_cache=True) # default caching is False
data = {
"Course Name": [
"Probability and Random Processes",
45 changes: 44 additions & 1 deletion lotus/cache.py
@@ -1,3 +1,5 @@
import hashlib
import json
import os
import pickle
import sqlite3
@@ -8,6 +10,8 @@
from functools import wraps
from typing import Any, Callable

import pandas as pd

import lotus


@@ -16,13 +20,52 @@ def require_cache_enabled(func: Callable) -> Callable:

@wraps(func)
def wrapper(self, *args, **kwargs):
if not lotus.settings.enable_cache:
if not lotus.settings.enable_message_cache:
return None
return func(self, *args, **kwargs)

return wrapper


def operator_cache(func: Callable) -> Callable:
"""Decorator to add operator level caching."""

@wraps(func)
def wrapper(self, *args, **kwargs):
model = lotus.settings.lm
use_operator_cache = lotus.settings.enable_operator_cache

if use_operator_cache and model.cache:

def serialize(value):
if isinstance(value, pd.DataFrame):
return value.to_json()
elif hasattr(value, "dict"):
return value.dict()
return value

serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
serialized_args = [serialize(arg) for arg in args]
cache_key = hashlib.sha256(
json.dumps({"args": serialized_args, "kwargs": serialized_kwargs}, sort_keys=True).encode()
).hexdigest()

cached_result = model.cache.get(cache_key)
if cached_result is not None:
lotus.logger.debug(f"Cache hit for {cache_key}")
model.stats.total_usage.operator_cache_hits += 1
return cached_result
lotus.logger.debug(f"Cache miss for {cache_key}")

result = func(self, *args, **kwargs)
model.cache.insert(cache_key, result)
return result

return func(self, *args, **kwargs)

return wrapper


class CacheType(Enum):
IN_MEMORY = "in_memory"
SQLITE = "sqlite"
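To make the key scheme of the operator_cache decorator above concrete, here is a standalone sketch (not part of the diff) of how a cache key is derived: DataFrame arguments are replaced by their JSON form, the whole call is dumped with sorted keys, and the SHA-256 digest is used for lookup.

import hashlib
import json

import pandas as pd

# DataFrames serialize via to_json(); other values pass through (or via
# .dict() when available). Sorted keys keep the digest order-independent.
df = pd.DataFrame({"Course Name": ["Chemical Kinetics and Catalysis"]})
serialized = {"args": [df.to_json()], "kwargs": {"user_instruction": "..."}}
cache_key = hashlib.sha256(
    json.dumps(serialized, sort_keys=True).encode()
).hexdigest()
print(cache_key)  # identical arguments produce the identical key, hence a cache hit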
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_agg.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus.models
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticAggOutput

@@ -148,6 +149,7 @@ def process_group(args):
group, user_instruction, all_cols, suffix, progress_bar_desc = args
return group.sem_agg(user_instruction, all_cols, suffix, None, progress_bar_desc=progress_bar_desc)

@operator_cache
def __call__(
self,
user_instruction: str,
4 changes: 3 additions & 1 deletion lotus/sem_ops/sem_cluster_by.py
@@ -4,6 +4,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache


@pd.api.extensions.register_dataframe_accessor("sem_cluster_by")
@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
col_name: str,
@@ -52,7 +54,7 @@ def __call__(
self._obj["cluster_id"] = pd.Series(indices, index=self._obj.index)
# if return_scores:
# self._obj["centroid_sim_score"] = pd.Series(scores, index=self._obj.index)

# if return_centroids:
# return self._obj, centroids
# else:
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_extract.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.models import LM
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticExtractOutput, SemanticExtractPostprocessOutput
@@ -33,7 +34,6 @@ def sem_extract(
Returns:
SemanticExtractOutput: The outputs, raw outputs, and quotes.
"""

# prepare model inputs
inputs = []
for doc in docs:
@@ -72,6 +72,7 @@ def _validate(obj: pd.DataFrame) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
input_cols: list[str],
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_filter.py
@@ -5,6 +5,7 @@
from numpy.typing import NDArray

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, LMOutput, LogprobsForFilterCascade, SemanticFilterOutput
from lotus.utils import show_safe_mode
@@ -134,6 +135,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
user_instruction: str,
3 changes: 2 additions & 1 deletion lotus/sem_ops/sem_join.py
@@ -4,6 +4,7 @@
from tqdm import tqdm

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import CascadeArgs, SemanticJoinOutput
from lotus.utils import show_safe_mode
@@ -234,7 +235,6 @@ def sem_join_cascade(
cot_reasoning=cot_reasoning,
default=default,
strategy=strategy,
show_progress_bar=False,
)
pbar.update(num_large)
pbar.close()
@@ -545,6 +545,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
other: pd.DataFrame | pd.Series,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_map.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.templates import task_instructions
from lotus.types import LMOutput, SemanticMapOutput, SemanticMapPostprocessOutput
from lotus.utils import show_safe_mode
@@ -80,6 +81,7 @@ def _validate(obj: pd.DataFrame) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
user_instruction: str,
2 changes: 2 additions & 0 deletions lotus/sem_ops/sem_search.py
@@ -3,6 +3,7 @@
import pandas as pd

import lotus
from lotus.cache import operator_cache
from lotus.types import RerankerOutput, RMOutput


@@ -19,6 +20,7 @@ def _validate(obj: Any) -> None:
if not isinstance(obj, pd.DataFrame):
raise AttributeError("Must be a DataFrame")

@operator_cache
def __call__(
self,
col_name: str,