
Merge branch 'litellm-router' into main
zhaohan-dong authored Oct 20, 2024
2 parents ef83201 + a4f5b8e commit 9b4b9f9
Showing 2 changed files with 72 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dspy/clients/__init__.py
@@ -1 +1 @@
-from .lm import LM
+from .lm import LM, RoutedLM
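
The practical effect of this one-line change is that RoutedLM can now be imported from the package root alongside LM:

    from dspy.clients import LM, RoutedLM
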
76 changes: 71 additions & 5 deletions dspy/clients/lm.py
@@ -59,10 +59,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat":
-            completion = cached_litellm_completion if cache else litellm_completion
-        else:
-            completion = cached_litellm_text_completion if cache else litellm_text_completion
+        completion = self._get_completion_func(cache=cache)
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -80,7 +77,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
 
         return outputs
 
     def inspect_history(self, n: int = 1):
@@ -148,6 +145,12 @@ def copy(self, **kwargs):
         return new_instance
 
 
+    def _get_completion_func(self, cache: bool = False):
+        if self.model_type == "chat":
+            return cached_litellm_completion if cache else litellm_completion
+        else:
+            return cached_litellm_text_completion if cache else litellm_text_completion
+
 
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request):
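
The refactor above turns completion-function selection into an overridable hook rather than inline branching in __call__. A minimal sketch of the dispatch, assuming LM's constructor accepts a model_type argument as the surrounding code suggests (the model name is a placeholder):

    lm = LM("openai/gpt-4o-mini", model_type="chat")

    # With caching requested, the chat path resolves to the lru_cache-wrapped helper.
    assert lm._get_completion_func(cache=True) is cached_litellm_completion
    assert lm._get_completion_func(cache=False) is litellm_completion

Subclasses can override _get_completion_func to swap in their own transport, which is exactly what RoutedLM does below.
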
@@ -221,3 +224,66 @@ def _inspect_history(lm, n: int = 1):
         print(_red(choices_text, end=""))
 
     print("\n\n\n")
+
+
+class RoutedLM(LM):
+    """LM that uses a LiteLLM Router to perform completion requests."""
+
+    def __init__(self, model, router, **kwargs):
+        # The router must be a litellm.Router instance whose model_names include the given model.
+        if not isinstance(router, litellm.router.Router):
+            raise TypeError(
+                f"The 'router' argument must be an instance of {litellm.router.Router.__name__}, but received type '{type(router).__name__}' instead."
+            )
+        # Check that the model is supported by the router.
+        available_models = router.get_model_names()
+        if model not in available_models:
+            raise ValueError(
+                f"The model '{model}' must be one of the router's model_names. Available models on router: {available_models}"
+            )
+
+        super().__init__(model, **kwargs)
+        self.router = router
+
+    def _get_completion_func(self, cache):
+        if self.model_type == "chat":
+            return self._cached_router_completion if cache else self._router_completion
+        else:
+            return (
+                self._cached_router_text_completion
+                if cache
+                else self._router_text_completion
+            )
+
+    @functools.lru_cache(maxsize=None)
+    def _cached_router_completion(self, request):
+        """Cache-enabled completion method that uses the router."""
+        return self._router_completion(
+            request, cache={"no-cache": False, "no-store": False}
+        )
+
+    def _router_completion(self, request, cache={"no-cache": True, "no-store": True}):
+        """Actual completion logic using the router."""
+        kwargs = ujson.loads(request)
+        return self.router.completion(cache=cache, **kwargs)
+
+    @functools.lru_cache(maxsize=None)
+    def _cached_router_text_completion(self, request):
+        return self._router_text_completion(
+            request, cache={"no-cache": False, "no-store": False}
+        )
+
+    def _router_text_completion(
+        self, request, cache={"no-cache": True, "no-store": True}
+    ):
+        kwargs = ujson.loads(request)
+
+        # The model alias assigned by the user in litellm.Router, not the official model name.
+        model_name = kwargs.pop("model")
+        prompt = "\n\n".join(
+            [x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]
+        )
+
+        return self.router.text_completion(
+            cache=cache, model=model_name, prompt=prompt, **kwargs
+        )
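
For context, here is a minimal usage sketch of the new class. It assumes litellm's documented Router model_list format; the model names, deployments, and keys are placeholders:

    import litellm
    from dspy.clients import RoutedLM

    # Two deployments behind one alias; the router load-balances between them.
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",  # alias that RoutedLM must be given
                "litellm_params": {"model": "openai/gpt-4o-mini", "api_key": "sk-..."},
            },
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {"model": "azure/gpt-4o-mini", "api_key": "..."},
            },
        ]
    )

    lm = RoutedLM("gpt-4o-mini", router=router)
    outputs = lm(messages=[{"role": "user", "content": "Say hello."}])
    print(outputs[0])

Note the caching design: the cached variants are wrapped in functools.lru_cache keyed on the serialized request string, and they pass {"no-cache": False, "no-store": False} so litellm's own cache may also serve or store the result, while the uncached paths default to no-cache/no-store and bypass litellm's cache entirely.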
