Add LM Caching (#31)
Adds a `Cache` to the `LM` object. The cache implements LRU eviction.
Adds tests to verify and demonstrate the behavior.
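
As a minimal usage sketch (based on the API added in this commit; the model name and cache size are illustrative):

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini", max_cache_size=1024)   # per-instance cache capacity
lotus.settings.configure(lm=lm, enable_cache=True)  # caching is off unless enabled here

batch = [[{"role": "user", "content": "What is the capital of France?"}]]
lm(batch)                               # miss: the response is computed and inserted into the cache
lm(batch)                               # hit: the response is served from the cache
print(lm.stats.total_usage.cache_hits)  # 1

lm.reset_cache()                        # clear the cache; optionally resize with max_size=...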

---------

Co-authored-by: liana313 <[email protected]>
sidjha1 and liana313 authored Nov 12, 2024
1 parent 47bc97e commit 30ff6f7
Showing 6 changed files with 208 additions and 21 deletions.
83 changes: 83 additions & 0 deletions .github/tests/lm_tests.py
@@ -47,6 +47,7 @@ def print_usage_after_each_test(setup_models):
print(f"\nUsage stats for {model_name} after test:")
model.print_total_usage()
model.reset_stats()
model.reset_cache()


################################################################################
@@ -276,3 +277,85 @@ def test_custom_tokenizer():
tokens = custom_lm.count_tokens("Hello, world!")
assert custom_lm.count_tokens([{"role": "user", "content": "Hello, world!"}]) == tokens
assert tokens < 100


################################################################################
# Cache tests
################################################################################
@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=True)

# Check that "What is the capital of France?" becomes cached
first_batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]

first_responses = lm(first_batch).outputs
assert lm.stats.total_usage.cache_hits == 0

second_batch = [
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "What is the capital of Germany?"}],
]
second_responses = lm(second_batch).outputs
assert second_responses[0] == first_responses[1]
assert lm.stats.total_usage.cache_hits == 1

# Test clearing cache
lm.reset_cache()
lm.reset_stats()
lm(second_batch)
assert lm.stats.total_usage.cache_hits == 0


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_disable_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=False)

batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]
lm(batch)
assert lm.stats.total_usage.cache_hits == 0
lm(batch)
assert lm.stats.total_usage.cache_hits == 0

# Now enable cache. Note that the first batch is not cached.
lotus.settings.configure(enable_cache=True)
first_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 0
second_responses = lm(batch).outputs
assert lm.stats.total_usage.cache_hits == 2
assert first_responses == second_responses


@pytest.mark.parametrize("model", get_enabled("gpt-4o-mini"))
def test_reset_cache(setup_models, model):
lm = setup_models[model]
lotus.settings.configure(lm=lm, enable_cache=True)

batch = [
[{"role": "user", "content": "Hello, world!"}],
[{"role": "user", "content": "What is the capital of France?"}],
]
lm(batch)
assert lm.stats.total_usage.cache_hits == 0
lm(batch)
assert lm.stats.total_usage.cache_hits == 2

lm.reset_cache(max_size=1)
lm(batch)
assert lm.stats.total_usage.cache_hits == 2
lm(batch)
assert lm.stats.total_usage.cache_hits == 3

lm.reset_cache(max_size=0)
lm(batch)
assert lm.stats.total_usage.cache_hits == 3
lm(batch)
assert lm.stats.total_usage.cache_hits == 3
27 changes: 27 additions & 0 deletions examples/model_examples/cache.py
@@ -0,0 +1,27 @@
import pandas as pd

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")

lotus.settings.configure(lm=lm, enable_cache=True)  # caching is disabled by default
data = {
"Course Name": [
"Probability and Random Processes",
"Optimization Methods in Engineering",
"Digital Design and Integrated Circuits",
"Computer Security",
]
}
df = pd.DataFrame(data)
user_instruction = "{Course Name} requires a lot of math"
df = df.sem_filter(user_instruction)
print("====== intial run ======")
print(df)

# run a second time
df = df.sem_filter(user_instruction)
print("====== second run ======")
print(df)

43 changes: 43 additions & 0 deletions lotus/cache.py
@@ -0,0 +1,43 @@
from collections import OrderedDict
from functools import wraps
from typing import Any, Callable

import lotus


def require_cache_enabled(func: Callable) -> Callable:
"""Decorator to check if caching is enabled before calling the function."""

@wraps(func)
def wrapper(self, *args, **kwargs):
if not lotus.settings.enable_cache:
return None
return func(self, *args, **kwargs)

return wrapper


class Cache:
def __init__(self, max_size: int):
self.max_size = max_size
self.cache: OrderedDict[str, Any] = OrderedDict()

@require_cache_enabled
def get(self, key: str) -> Any | None:
if key in self.cache:
lotus.logger.debug(f"Cache hit for {key}")

return self.cache.get(key)

@require_cache_enabled
def insert(self, key: str, value: Any):
self.cache[key] = value

# LRU eviction
if len(self.cache) > self.max_size:
self.cache.popitem(last=False)

def reset(self, max_size: int | None = None):
self.cache.clear()
if max_size is not None:
self.max_size = max_size
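
As a brief standalone sketch of the `Cache` behavior (hypothetical keys and values; note that `get` and `insert` are no-ops while `enable_cache` is off):

import lotus
from lotus.cache import Cache

lotus.settings.configure(enable_cache=True)

cache = Cache(max_size=2)
cache.insert("k1", "v1")
cache.insert("k2", "v2")
cache.insert("k3", "v3")         # exceeds max_size, so the oldest entry ("k1") is evicted
assert cache.get("k1") is None
assert cache.get("k3") == "v3"

cache.reset(max_size=100)        # clear all entries and optionally change the capacity
assert cache.get("k3") is None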
73 changes: 53 additions & 20 deletions lotus/models/lm.py
@@ -1,3 +1,4 @@
import hashlib
from typing import Any

import litellm
@@ -9,6 +10,7 @@
from tokenizers import Tokenizer

import lotus
from lotus.cache import Cache
from lotus.types import LMOutput, LMStats, LogprobsForCascade, LogprobsForFilterCascade


@@ -21,6 +23,7 @@ def __init__(
max_tokens: int = 512,
max_batch_size: int = 64,
tokenizer: Tokenizer | None = None,
max_cache_size: int = 1024,
**kwargs: dict[str, Any],
):
self.model = model
@@ -31,40 +34,66 @@ def __init__(
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)

self.stats: LMStats = LMStats()
self.cache = Cache(max_cache_size)

def __call__(self, messages: list[list[dict[str, str]]], **kwargs: dict[str, Any]) -> LMOutput:
all_kwargs = {**self.kwargs, **kwargs}

# Set top_logprobs if logprobs requested
if all_kwargs.get("logprobs", False):
all_kwargs["top_logprobs"] = all_kwargs.get("top_logprobs", 10)

all_responses: list[ModelResponse] = []
for i in range(0, len(messages), self.max_batch_size):
batch = messages[i : i + self.max_batch_size]
responses: list[ModelResponse] = batch_completion(
self.model,
batch,
drop_params=True,
**all_kwargs, # type: ignore
)
all_responses.extend(responses)

# throw errors, if any
for resp in all_responses:
if isinstance(resp, OpenAIError):
raise resp
all_kwargs.setdefault("top_logprobs", 10)

# Check cache and separate cached and uncached messages
hashed_messages = [self._hash_messages(msg, all_kwargs) for msg in messages]
cached_responses = [self.cache.get(hash) for hash in hashed_messages]
uncached_data = [
(msg, hash) for msg, hash, resp in zip(messages, hashed_messages, cached_responses) if resp is None
]
self.stats.total_usage.cache_hits += len(messages) - len(uncached_data)

# Process uncached messages in batches
uncached_responses = self._process_uncached_messages(uncached_data, all_kwargs)

# Add new responses to cache
for resp, (_, hash) in zip(uncached_responses, uncached_data):
self._cache_response(resp, hash)

# Merge all responses in original order and extract outputs
all_responses = self._merge_responses(cached_responses, uncached_responses)
outputs = [self._get_top_choice(resp) for resp in all_responses]
logprobs = (
[self._get_top_choice_logprobs(resp) for resp in all_responses] if all_kwargs.get("logprobs") else None
)

for resp in all_responses:
self._update_stats(resp)

return LMOutput(outputs=outputs, logprobs=logprobs)

def _process_uncached_messages(self, uncached_data, all_kwargs):
"""Processes uncached messages in batches and returns responses."""
uncached_responses = []
for i in range(0, len(uncached_data), self.max_batch_size):
batch = [msg for msg, _ in uncached_data[i : i + self.max_batch_size]]
uncached_responses.extend(batch_completion(self.model, batch, drop_params=True, **all_kwargs))
return uncached_responses

def _cache_response(self, response, hash):
"""Caches a response and updates stats if successful."""
if isinstance(response, OpenAIError):
raise response
self._update_stats(response)
self.cache.insert(hash, response)

def _hash_messages(self, messages: list[dict[str, str]], kwargs: dict[str, Any]) -> str:
"""Hash messages and kwargs to create a unique key for the cache"""
to_hash = str(self.model) + str(messages) + str(kwargs)
return hashlib.sha256(to_hash.encode()).hexdigest()

def _merge_responses(
self, cached_responses: list[ModelResponse | None], uncached_responses: list[ModelResponse]
) -> list[ModelResponse]:
"""Merge cached and uncached responses, maintaining order"""
uncached_iter = iter(uncached_responses)
return [resp if resp is not None else next(uncached_iter) for resp in cached_responses]

def _update_stats(self, response: ModelResponse):
if not hasattr(response, "usage"):
return
@@ -155,8 +184,12 @@ def print_total_usage(self):
print(f"Total prompt tokens: {self.stats.total_usage.prompt_tokens}")
print(f"Total completion tokens: {self.stats.total_usage.completion_tokens}")
print(f"Total tokens: {self.stats.total_usage.total_tokens}")
print(f"Total cache hits: {self.stats.total_usage.cache_hits}")

def reset_stats(self):
self.stats = LMStats(
total_usage=LMStats.TotalUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0, total_cost=0.0)
)

def reset_cache(self, max_size: int | None = None):
self.cache.reset(max_size)
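
The ordering behavior of `_merge_responses` can be illustrated with plain values standing in for `ModelResponse` objects: cached slots keep their response, and `None` slots (cache misses) are filled from the uncached responses in request order.

cached = ["resp_A", None, "resp_C", None]   # None marks a cache miss
uncached = ["resp_B", "resp_D"]             # fresh responses, in request order

uncached_iter = iter(uncached)
merged = [resp if resp is not None else next(uncached_iter) for resp in cached]
assert merged == ["resp_A", "resp_B", "resp_C", "resp_D"]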
2 changes: 1 addition & 1 deletion lotus/settings.py
@@ -115,4 +115,4 @@ def __repr__(self) -> str:

# set defaults
settings = Settings()
settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50)
settings.configure(cascade_is_weight=0.5, cascade_num_calibration_quantiles=50, enable_cache=False)
1 change: 1 addition & 0 deletions lotus/types.py
@@ -29,6 +29,7 @@ class TotalUsage(BaseModel):
completion_tokens: int = 0
total_tokens: int = 0
total_cost: float = 0.0
cache_hits: int = 0

total_usage: TotalUsage = TotalUsage()

