Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - dspy.RM/retrieve refactor #1739

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dsp/modules/colbertv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# TODO: Ideally, this takes the name of the index and looks up its port.


#TODO remove references of ColBERTv2 from here now that it is supported in retrieve/
class ColBERTv2:
"""Wrapper for the ColBERTv2 Retrieval."""

Expand Down
3 changes: 2 additions & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .predict import *
from .primitives import *
from .retrieve import *
from .retriever import *
from .signatures import *

# Functional must be imported after primitives, predict and signatures
Expand All @@ -24,7 +25,7 @@
Mistral = dsp.Mistral
Databricks = dsp.Databricks
Cohere = dsp.Cohere
ColBERTv2 = dsp.ColBERTv2
# ColBERTv2 = dsp.ColBERTv2
ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal
ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal
Pyserini = dsp.PyseriniRetriever
Expand Down
1 change: 0 additions & 1 deletion dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pydantic.fields import FieldInfo
from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dspy.adapters.base import Adapter
from ..signatures.field import OutputField
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
Expand Down
2 changes: 1 addition & 1 deletion dspy/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .lm import LM
from .provider import Provider, TrainingJob
from .base_lm import BaseLM, inspect_history
from .embedding import Embedding
from .embedding import Embedder
import litellm
import os
from pathlib import Path
Expand Down
43 changes: 26 additions & 17 deletions dspy/clients/embedding.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import litellm
import numpy as np
from typing import Callable, List, Union, Optional


class Embedding:
class Embedder:
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
"""DSPy embedding class.

The class for computing embeddings for text inputs. This class provides a unified interface for both:
Expand All @@ -13,25 +13,26 @@ class Embedding:
For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use
litellm to handle the API calls and caching.

For custom embedding models, pass a callable function that:
For custom embedding models, pass a callable function to `embedding_function` that:
- Takes a list of strings as input.
- Returns embeddings as either:
- A 2D numpy array of float32 values
- A 2D list of float32 values
- Each row should represent one embedding vector

Args:
model: The embedding model to use. This can be either a string (representing the name of the hosted embedding
model, must be an embedding model supported by litellm) or a callable that represents a custom embedding
model.
embedding_model: The embedding model to use, either a string (for hosted models supported by litellm) or
a callable function that returns custom embeddings.
embedding_function: An optional custom embedding function. If not provided, defaults to litellm
for hosted models when `embedding_model` is a string.

Examples:
Example 1: Using a hosted model.

```python
import dspy

embedder = dspy.Embedding("openai/text-embedding-3-small")
embedder = dspy.Embedding(embedding_model="openai/text-embedding-3-small")
embeddings = embedder(["hello", "world"])

assert embeddings.shape == (2, 1536)
Expand All @@ -41,21 +42,27 @@ class Embedding:

```python
import dspy
import numpy as np

def my_embedder(texts):
return np.random.rand(len(texts), 10)

embedder = dspy.Embedding(my_embedder)
embedder = dspy.Embedding(embedding_function=my_embedder)
embeddings = embedder(["hello", "world"])

assert embeddings.shape == (2, 10)
```
"""

def __init__(self, model):
self.model = model
def __init__(self, embedding_model: Union[str, Callable[[List[str]], List[List[float]]]] = 'text-embedding-ada-002', embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None):
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
self.embedding_model = embedding_model
self.embedding_function = embedding_function or self.default_embedding_function

def default_embedding_function(self, texts: List[str], caching: bool = True, **kwargs) -> List[List[float]]:
embeddings_response = litellm.embedding(model=self.embedding_model, input=texts, caching=caching, **kwargs)
return [data['embedding'] for data in embeddings_response.data]

def __call__(self, inputs, caching=True, **kwargs):
def __call__(self, inputs: Union[str, List[str]], caching: bool = True, **kwargs) -> np.ndarray:
"""Compute embeddings for the given inputs.

Args:
Expand All @@ -68,10 +75,12 @@ def __call__(self, inputs, caching=True, **kwargs):
"""
if isinstance(inputs, str):
inputs = [inputs]
if isinstance(self.model, str):
embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs)
return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32)
elif callable(self.model):
return np.array(self.model(inputs, **kwargs), dtype=np.float32)
if callable(self.embedding_function):
embeddings = self.embedding_function(inputs, **kwargs)
elif isinstance(self.embedding_model, str):
embeddings = self.default_embedding_function(inputs, caching=caching, **kwargs)
else:
raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")
raise ValueError(
f"`embedding_model` must be a string or `embedding_function` must be a callable, but got types: `embedding_model`={type(self.embedding_model)}, `embedding_function`={type(self.embedding_function)}."
)
return np.array(embeddings, dtype=np.float32)
12 changes: 11 additions & 1 deletion dspy/retrieve/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import random
from typing import Dict, List, Optional, Union
import logging
from functools import lru_cache

import dsp
from dspy.predict.parameter import Parameter
Expand All @@ -16,13 +18,20 @@ def single_query_passage(passages):
passages_dict["passages"] = passages_dict.pop("long_text")
return Prediction(**passages_dict)

@lru_cache(maxsize=None)
def warn_once(msg: str):
logging.warning(msg)
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved

class Retrieve(Parameter):
name = "Search"
input_variable = "query"
desc = "takes a search query and returns one or more potentially relevant passages from a corpus"

def __init__(self, k=3, callbacks=None):
warn_once(
"Existing retriever integrations under dspy/retrieve inheriting `dspy.Retrieve` are deprecated and will be removed in DSPy 2.6+. \n"
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
"For future retriever integrations, please use the `dspy.Retriever` interface under dspy/retriever/retriever.py and reference any of the custom integrations supported in dspy/retriever/"
)
self.stage = random.randbytes(8).hex()
self.k = k
self.callbacks = callbacks or []
Expand Down Expand Up @@ -104,6 +113,7 @@ def forward(
# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too.


#TODO potentially add for deprecation/removal in 2.6+
class RetrieveThenRerank(Parameter):
name = "Search"
input_variable = "query"
Expand Down Expand Up @@ -163,4 +173,4 @@ def forward(
pred_returns.append(Prediction(**passages_dict))
return pred_returns
elif isinstance(passages[0], Dict):
return single_query_passage(passages=passages)
return single_query_passage(passages=passages)
1 change: 1 addition & 0 deletions dspy/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .retriever import Retriever
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
76 changes: 76 additions & 0 deletions dspy/retriever/colbertv2_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any, Union, Optional
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
import dspy
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
from dsp.utils import dotdict
import requests
import functools

class ColBERTv2(dspy.Retriever):
"""
ColBERTv2 Retriever for retrieval of top-k most relevant text passages for given query.

Args:
post_requests (bool): Determines if POST requests should be used
arnavsinghvi11 marked this conversation as resolved.
Show resolved Hide resolved
instead of GET requests for querying the server.
url (str): URL endpoint for ColBERTv2 server

Returns:
An object containing the retrieved passages.

Example:
from dspy.retriever.colbertv2_retriever import ColBERTv2
results = ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=5).passages
print(results)
"""
def __init__(self, url: str = "http://0.0.0.0", port: Optional[Union[str, int]] = None, post_requests: bool = False):
super().__init__(embedder=None)
self.post_requests = post_requests
self.url = f"{url}:{port}" if port else url
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this port required? If so let's raise an ValueError for clear action.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just an optional i believe. our examples typically use 'http://20.102.90.50:2017/wiki17_abstracts' where the port can be specified in the URL


def forward(self, query: str, k: int = 10) -> Any:
if self.post_requests:
topk = colbertv2_post_request(self.url, query, k)
else:
topk = colbertv2_get_request(self.url, query, k)
return dotdict({'passages': [dotdict(psg) for psg in topk]})


from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory
@CacheMemory.cache
def colbertv2_get_request_v2(url: str, query: str, k: int):
assert (
k <= 100
), "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."

payload = {"query": query, "k": k}
res = requests.get(url, params=payload, timeout=10)

topk = res.json()["topk"][:k]
topk = [{**d, "long_text": d["text"]} for d in topk]
return topk[:k]


@functools.cache
@NotebookCacheMemory.cache
def colbertv2_get_request_v2_wrapped(*args, **kwargs):
return colbertv2_get_request_v2(*args, **kwargs)


colbertv2_get_request = colbertv2_get_request_v2_wrapped


@CacheMemory.cache
def colbertv2_post_request_v2(url: str, query: str, k: int):
headers = {"Content-Type": "application/json; charset=utf-8"}
payload = {"query": query, "k": k}
res = requests.post(url, json=payload, headers=headers, timeout=10)

return res.json()["topk"][:k]


@functools.cache
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this _wrapped function? seems it's not adding extra logic except from some cache, call we move the annotations to colbertv2_post_request_v2?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea i think we can simplify it but will defer to @okhat as this was existing behavior from colbertv2.py

@NotebookCacheMemory.cache
def colbertv2_post_request_v2_wrapped(*args, **kwargs):
return colbertv2_post_request_v2(*args, **kwargs)


colbertv2_post_request = colbertv2_post_request_v2_wrapped
Loading
Loading