From 153ea25683fd9553e099ae1a52325c4339fbbaea Mon Sep 17 00:00:00 2001
From: Zhaohan Dong <65422392+zhaohan-dong@users.noreply.github.com>
Date: Thu, 10 Oct 2024 23:58:00 +0000
Subject: [PATCH 1/2] feat: Adding LiteLLM Router support to dspy.LM with RoutedLM

Signed-off-by: Zhaohan Dong <65422392+zhaohan-dong@users.noreply.github.com>
---
 dspy/clients/__init__.py |  2 +-
 dspy/clients/lm.py       | 94 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py
index 079dc5420..1cfcb6485 100644
--- a/dspy/clients/__init__.py
+++ b/dspy/clients/__init__.py
@@ -1 +1 @@
-from .lm import LM
\ No newline at end of file
+from .lm import LM, RoutedLM
\ No newline at end of file
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 994c8bc41..83ba7d8da 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -135,3 +135,97 @@ def _inspect_history(lm, n: int = 1):
         print(_red(choices_text, end=""))
 
     print("\n\n\n")
+
+
+class RoutedLM(LM):
+    """LM which uses LiteLLM Router to perform completion requests"""
+
+    def __init__(self, model, router, **kwargs):
+        # Type checking that router must be a litellm.Router instance, with model in router.model_names
+        if not isinstance(router, litellm.router.Router):
+            raise TypeError(
+                f"The 'router' argument must be an instance of {litellm.router.Router.__name__}, but received a type '{type(router).__name__}' instead."
+            )
+        # Check if model is supported by the router
+        available_models = router.get_model_names()
+        if model not in available_models:
+            raise ValueError(
+                f"The model '{model}' must be one of the router's model_names. Available models on router: {available_models}"
+            )
+
+        super().__init__(model, **kwargs)
+        self.router = router
+
+    def __call__(self, prompt=None, messages=None, **kwargs):
+        cache = kwargs.pop("cache", self.cache)
+        messages = messages or [{"role": "user", "content": prompt}]
+        kwargs = {**self.kwargs, **kwargs}
+
+        if self.model_type == "chat":
+            completion = (
+                self._cached_router_completion if cache else self._router_completion
+            )
+        else:
+            completion = (
+                self._cached_router_text_completion
+                if cache
+                else self._router_text_completion
+            )
+
+        response = completion(
+            ujson.dumps(dict(model=self.model, messages=messages, **kwargs))
+        )
+        outputs = [
+            c.message.content if hasattr(c, "message") else c["text"]
+            for c in response["choices"]
+        ]
+
+        # Follow LM's logging
+        kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
+        entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
+        entry = dict(**entry, outputs=outputs, usage=dict(response["usage"]))
+        entry = dict(
+            **entry, cost=response.get("_hidden_params", {}).get("response_cost")
+        )
+        entry = dict(
+            **entry,
+            timestamp=datetime.now().isoformat(),
+            uuid=str(uuid.uuid4()),
+            model=self.model,
+            model_type=self.model_type,
+        )
+        self.history.append(entry)
+        return outputs
+
+    @functools.lru_cache(maxsize=None)
+    def _cached_router_completion(self, request):
+        """Cache-enabled completion method that uses the router."""
+        return self._router_completion(
+            request, cache={"no-cache": False, "no-store": False}
+        )
+
+    def _router_completion(self, request, cache={"no-cache": True, "no-store": True}):
+        """Actual completion logic using the router."""
+        kwargs = ujson.loads(request)
+        return self.router.completion(cache=cache, **kwargs)
+
+    @functools.lru_cache(maxsize=None)
+    def _cached_router_text_completion(self, request):
+        return self._router_text_completion(
+            request, cache={"no-cache": False, "no-store": False}
+        )
+
+    def _router_text_completion(
+        self, request, cache={"no-cache": True, "no-store": True}
+    ):
+        kwargs = ujson.loads(request)
+
+        # The model alias for litellm.Router assigned by user, not the official model name
+        model_name = kwargs.pop("model")
+        prompt = "\n\n".join(
+            [x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]
+        )
+
+        return self.router.text_completion(
+            cache=cache, model=model_name, prompt=prompt, **kwargs
+        )

From a02d3b24ed11abc3343cdc0dbd67984fe68f0394 Mon Sep 17 00:00:00 2001
From: Zhaohan Dong <65422392+zhaohan-dong@users.noreply.github.com>
Date: Sun, 13 Oct 2024 20:11:36 +0000
Subject: [PATCH 2/2] Extract completion function selection

Signed-off-by: Zhaohan Dong <65422392+zhaohan-dong@users.noreply.github.com>
---
 dspy/clients/lm.py | 50 ++++++++++------------------------------------
 1 file changed, 11 insertions(+), 39 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 83ba7d8da..aa1080173 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -36,10 +36,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         kwargs = {**self.kwargs, **kwargs}
 
         # Make the request and handle LRU & disk caching.
-        if self.model_type == "chat":
-            completion = cached_litellm_completion if cache else litellm_completion
-        else:
-            completion = cached_litellm_text_completion if cache else litellm_text_completion
+        completion = self._get_completion_func(cache=cache)
 
         response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -57,12 +54,18 @@ def __call__(self, prompt=None, messages=None, **kwargs):
             model_type=self.model_type,
         )
         self.history.append(entry)
-        
+
         return outputs
 
     def inspect_history(self, n: int = 1):
         _inspect_history(self, n)
 
+    def _get_completion_func(self, cache: bool = False):
+        if self.model_type == "chat":
+            return cached_litellm_completion if cache else litellm_completion
+        else:
+            return cached_litellm_text_completion if cache else litellm_text_completion
+
 
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request):
@@ -156,47 +159,16 @@ def __init__(self, model, router, **kwargs):
         super().__init__(model, **kwargs)
         self.router = router
 
-    def __call__(self, prompt=None, messages=None, **kwargs):
-        cache = kwargs.pop("cache", self.cache)
-        messages = messages or [{"role": "user", "content": prompt}]
-        kwargs = {**self.kwargs, **kwargs}
-
+    def _get_completion_func(self, cache):
         if self.model_type == "chat":
-            completion = (
-                self._cached_router_completion if cache else self._router_completion
-            )
+            return self._cached_router_completion if cache else self._router_completion
         else:
-            completion = (
+            return (
                 self._cached_router_text_completion
                 if cache
                 else self._router_text_completion
            )
 
-        response = completion(
-            ujson.dumps(dict(model=self.model, messages=messages, **kwargs))
-        )
-        outputs = [
-            c.message.content if hasattr(c, "message") else c["text"]
-            for c in response["choices"]
-        ]
-
-        # Follow LM's logging
-        kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
-        entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
-        entry = dict(**entry, outputs=outputs, usage=dict(response["usage"]))
-        entry = dict(
-            **entry, cost=response.get("_hidden_params", {}).get("response_cost")
-        )
-        entry = dict(
-            **entry,
-            timestamp=datetime.now().isoformat(),
-            uuid=str(uuid.uuid4()),
-            model=self.model,
-            model_type=self.model_type,
-        )
-        self.history.append(entry)
-        return outputs
-
     @functools.lru_cache(maxsize=None)
     def _cached_router_completion(self, request):
         """Cache-enabled completion method that uses the router."""
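
Usage sketch (not part of either patch): the snippet below shows how the new RoutedLM is intended to be driven by a litellm.Router. The router deployments, model alias, and API keys are illustrative assumptions; only the RoutedLM(model, router, **kwargs) signature and the dspy.clients export come from the patches themselves.

    import dspy
    import litellm
    from dspy.clients import RoutedLM

    # Illustrative router: one alias ("gpt-4o-mini") backed by two hypothetical deployments.
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4o-mini",  # alias; must appear in router.get_model_names()
                "litellm_params": {"model": "openai/gpt-4o-mini", "api_key": "sk-..."},
            },
            {
                "model_name": "gpt-4o-mini",
                "litellm_params": {
                    "model": "azure/my-gpt-4o-mini",  # hypothetical Azure deployment name
                    "api_key": "...",
                    "api_base": "https://example.openai.azure.com",
                },
            },
        ]
    )

    # `model` must be one of the router's model_names, otherwise RoutedLM raises ValueError.
    lm = RoutedLM("gpt-4o-mini", router=router)
    dspy.configure(lm=lm)

    # __call__ returns a list of completions, mirroring the base LM class.
    outputs = lm("Say hello in one word.")
    print(outputs[0])

After the second patch, RoutedLM overrides only _get_completion_func(), so caching, history logging, and inspect_history() behave exactly as in the base LM class; the router merely replaces the transport used for completion requests.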