diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py
index 079dc5420..1cfcb6485 100644
--- a/dspy/clients/__init__.py
+++ b/dspy/clients/__init__.py
@@ -1 +1 @@
-from .lm import LM
\ No newline at end of file
+from .lm import LM, RoutedLM
\ No newline at end of file
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index f2949372a..8017330c0 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -59,10 +59,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat":
-            completion = cached_litellm_completion if cache else litellm_completion
-        else:
-            completion = cached_litellm_text_completion if cache else litellm_text_completion
+        completion = self._get_completion_func(cache=cache)
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -80,7 +77,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
-        
+
         return outputs
 
     def inspect_history(self, n: int = 1):
@@ -148,6 +145,12 @@ def copy(self, **kwargs):
 
         return new_instance
 
+    def _get_completion_func(self, cache: bool = False):
+        if self.model_type == "chat":
+            return cached_litellm_completion if cache else litellm_completion
+        else:
+            return cached_litellm_text_completion if cache else litellm_text_completion
+
 
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request):
@@ -221,3 +224,66 @@ def _inspect_history(lm, n: int = 1):
         print(_red(choices_text, end=""))
 
     print("\n\n\n")
+
+
+class RoutedLM(LM):
+    """LM that uses a LiteLLM Router to perform completion requests."""
+
+    def __init__(self, model, router, **kwargs):
+        # The router must be a litellm.Router instance whose model_names include the requested model.
+        if not isinstance(router, litellm.router.Router):
+            raise TypeError(
+                f"The 'router' argument must be an instance of {litellm.router.Router.__name__}, but received type '{type(router).__name__}' instead."
+            )
+        # Check that the model is served by the router.
+        available_models = router.get_model_names()
+        if model not in available_models:
+            raise ValueError(
+                f"The model '{model}' must be one of the router's model_names. Available models on router: {available_models}"
+            )
+
+        super().__init__(model, **kwargs)
+        self.router = router
+
+    def _get_completion_func(self, cache):
+        if self.model_type == "chat":
+            return self._cached_router_completion if cache else self._router_completion
+        else:
+            return (
+                self._cached_router_text_completion
+                if cache
+                else self._router_text_completion
+            )
+
+    @functools.lru_cache(maxsize=None)
+    def _cached_router_completion(self, request):
+        """Cache-enabled completion method that uses the router."""
+        return self._router_completion(
+            request, cache={"no-cache": False, "no-store": False}
+        )
+
+    def _router_completion(self, request, cache={"no-cache": True, "no-store": True}):
+        """Actual completion logic using the router."""
+        kwargs = ujson.loads(request)
+        return self.router.completion(cache=cache, **kwargs)
+
+    @functools.lru_cache(maxsize=None)
+    def _cached_router_text_completion(self, request):
+        return self._router_text_completion(
+            request, cache={"no-cache": False, "no-store": False}
+        )
+
+    def _router_text_completion(
+        self, request, cache={"no-cache": True, "no-store": True}
+    ):
+        kwargs = ujson.loads(request)
+
+        # The user-assigned model alias on the litellm.Router, not the provider's official model name.
+        model_name = kwargs.pop("model")
+        prompt = "\n\n".join(
+            [x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]
+        )
+
+        return self.router.text_completion(
+            cache=cache, model=model_name, prompt=prompt, **kwargs
+        )
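For reviewers, a minimal usage sketch of the new class (illustration only, not part of the patch). It assumes DSPy-style configuration via dspy.configure(lm=...); the "gpt-4o-mini" alias, the single-deployment router config, and the OPENAI_API_KEY environment variable are placeholder values, not something this diff adds.

# Usage sketch (hypothetical deployment config); RoutedLM is the class added above.
import os

import litellm

import dspy
from dspy.clients import RoutedLM

# One deployment behind the router; "gpt-4o-mini" is the user-assigned alias
# that RoutedLM checks against router.get_model_names().
router = litellm.Router(
    model_list=[
        {
            "model_name": "gpt-4o-mini",
            "litellm_params": {
                "model": "openai/gpt-4o-mini",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        }
    ]
)

lm = RoutedLM("gpt-4o-mini", router)  # raises ValueError if the alias is not on the router
dspy.configure(lm=lm)  # or call lm(...) / lm(messages=[...]) directly

print(lm("Say hello in one word."))

As in the base LM, the cached paths key on the ujson-serialized request, so repeated identical calls on the same RoutedLM instance hit the in-process LRU cache before reaching the router.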