Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP - dspy.RM/retrieve refactor #1739

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion dsp/modules/colbertv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

# TODO: Ideally, this takes the name of the index and looks up its port.


class ColBERTv2:
"""Wrapper for the ColBERTv2 Retrieval."""

Expand Down
3 changes: 2 additions & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .predict import *
from .primitives import *
from .retrieve import *
from .retriever import *
from .signatures import *

# Functional must be imported after primitives, predict and signatures
Expand All @@ -24,7 +25,7 @@
Mistral = dsp.Mistral
Databricks = dsp.Databricks
Cohere = dsp.Cohere
ColBERTv2 = dsp.ColBERTv2
ColBERTv2 = ColBERTv2
ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal
ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal
Pyserini = dsp.PyseriniRetriever
Expand Down
1 change: 0 additions & 1 deletion dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pydantic.fields import FieldInfo
from typing import Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dspy.adapters.base import Adapter
from ..signatures.field import OutputField
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
Expand Down
48 changes: 27 additions & 21 deletions dspy/clients/embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import litellm
import numpy as np

from typing import Callable, List, Union

class Embedding:
"""DSPy embedding class.
Expand All @@ -13,25 +13,24 @@ class Embedding:
For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use
litellm to handle the API calls and caching.

For custom embedding models, pass a callable function that:
For custom embedding models, pass a callable as `embedding_model` that:
- Takes a list of strings as input.
- Returns embeddings as either:
- A 2D numpy array of float32 values
- A 2D list of float32 values
- Each row should represent one embedding vector
- A 2D numpy array of float32 values.
- A 2D list of float32 values.
- Each row represents one embedding vector.

Args:
model: The embedding model to use. This can be either a string (representing the name of the hosted embedding
model, must be an embedding model supported by litellm) or a callable that represents a custom embedding
model.
embedding_model: The embedding model to use, either a string (for hosted models supported by litellm) or
a callable that returns custom embeddings.

Examples:
Example 1: Using a hosted model.

```python
import dspy

embedder = dspy.Embedding("openai/text-embedding-3-small")
embedder = dspy.Embedding(embedding_model="openai/text-embedding-3-small")
embeddings = embedder(["hello", "world"])

assert embeddings.shape == (2, 1536)
Expand All @@ -41,37 +40,44 @@ class Embedding:

```python
import dspy
import numpy as np

def my_embedder(texts):
return np.random.rand(len(texts), 10)

embedder = dspy.Embedding(my_embedder)
embedder = dspy.Embedding(embedding_model=my_embedder)
embeddings = embedder(["hello", "world"])

assert embeddings.shape == (2, 10)
```
"""

def __init__(self, model):
self.model = model
def __init__(self, embedding_model: Union[str, Callable] = 'openai/text-embedding-3-small'):
self.embedding_model = embedding_model

def default_embedding_model(self, texts: List[str], caching: bool = True, **kwargs) -> List[List[float]]:
embeddings_response = litellm.embedding(model=self.embedding_model, input=texts, caching=caching, **kwargs)
return [data['embedding'] for data in embeddings_response.data]

def __call__(self, inputs, caching=True, **kwargs):
def __call__(self, inputs: Union[str, List[str]], caching: bool = True, **kwargs) -> np.ndarray:
"""Compute embeddings for the given inputs.

Args:
inputs: The inputs to compute embeddings for, can be a single string or a list of strings.
caching: Whether to cache the embedding response, only valid when using a hosted embedding model.
inputs: Query inputs to compute embeddings for, can be a single string or a list of strings.
caching: Cache flag for embedding response when using an embedding model.
kwargs: Additional keyword arguments to pass to the embedding model.

Returns:
A 2-D numpy array of embeddings, one embedding per row.
"""
if isinstance(inputs, str):
inputs = [inputs]
if isinstance(self.model, str):
embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs)
return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32)
elif callable(self.model):
return np.array(self.model(inputs, **kwargs), dtype=np.float32)
if callable(self.embedding_model):
embeddings = self.embedding_model(inputs, **kwargs)
elif isinstance(self.embedding_model, str):
embeddings = self.default_embedding_model(inputs, caching=caching, **kwargs)
else:
raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")
raise ValueError(
f"`embedding_model` must be a string or a callable, but got type: {type(self.embedding_model)}."
)
return np.array(embeddings, dtype=np.float32)
14 changes: 13 additions & 1 deletion dspy/retrieve/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import random
from typing import Dict, List, Optional, Union
import logging
from functools import lru_cache

import dsp
from dspy.predict.parameter import Parameter
Expand All @@ -16,13 +18,22 @@ def single_query_passage(passages):
passages_dict["passages"] = passages_dict.pop("long_text")
return Prediction(**passages_dict)

# Module-level logger for this retrieval module.
logger = logging.getLogger(__name__)

@lru_cache(maxsize=None)
def warn_once(msg: str) -> None:
    """Log ``msg`` at WARNING level only the first time it is seen.

    ``lru_cache`` memoizes on the message string, so repeated calls with the
    same text become no-ops. The cache grows with each unique message, which
    is fine here since messages are a small set of fixed deprecation notices.
    """
    logger.warning(msg)

class Retrieve(Parameter):
name = "Search"
input_variable = "query"
desc = "takes a search query and returns one or more potentially relevant passages from a corpus"

def __init__(self, k=3, callbacks=None):
warn_once(
"Existing retriever integrations under dspy/retrieve inheriting `dspy.Retrieve` are deprecated and will be removed in the DSPy 2.7 release. \n"
"For future retriever integrations, please use the `dspy.Retriever` interface under dspy/retriever/retriever.py and reference any of the custom integrations supported in dspy/retriever/"
)
self.stage = random.randbytes(8).hex()
self.k = k
self.callbacks = callbacks or []
Expand Down Expand Up @@ -104,6 +115,7 @@ def forward(
# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too.


#TODO potentially add for deprecation/removal in 2.7
class RetrieveThenRerank(Parameter):
name = "Search"
input_variable = "query"
Expand Down Expand Up @@ -163,4 +175,4 @@ def forward(
pred_returns.append(Prediction(**passages_dict))
return pred_returns
elif isinstance(passages[0], Dict):
return single_query_passage(passages=passages)
return single_query_passage(passages=passages)
2 changes: 2 additions & 0 deletions dspy/retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .retriever import Retriever
from .colbertv2 import ColBERTv2
80 changes: 80 additions & 0 deletions dspy/retriever/colbertv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, Union, Optional, List
from dspy.retriever import Retriever
from dsp.utils import dotdict
import requests
import functools

class ColBERTv2(Retriever):
    """
    ColBERTv2 Retriever for retrieval of top-k most relevant text passages for given query.

    Args:
        url (str): Base URL endpoint for the ColBERTv2 server.
        port (Union[str, int], optional): Port number for server. Appended to URL if provided.
        post_requests (bool, optional): Determines if POST requests should be used instead of GET requests for querying the server. Defaults to False.
        k (int, optional): Default number of top passages to retrieve. Defaults to 10.
        callbacks (Optional[List[Any]]): List of callback functions to be called during retrieval.
        cache (bool, optional): Enable retrieval caching. Disabled by default.

    Returns:
        An object containing the retrieved passages.

    Example:
        import dspy
        results = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')(query, k=10).passages
        print(results)
    """

    def __init__(
        self,
        url: str = "http://0.0.0.0",
        port: Optional[Union[str, int]] = None,
        post_requests: bool = False,
        k: int = 10,
        callbacks: Optional[List[Any]] = None,
        cache: bool = False,
    ):
        super().__init__(embedder=None, k=k, callbacks=callbacks, cache=cache)
        # Store k locally as well; harmless if the base class already does
        # (NOTE(review): assumes Retriever.__init__ records k — confirm).
        self.k = k
        self.post_requests = post_requests
        # Append the port to the base URL only when one is supplied.
        self.url = f"{url}:{port}" if port else url

    def forward(self, query: str, k: Optional[int] = None) -> Any:
        """Retrieve the top-k passages for ``query`` from the ColBERTv2 server.

        Args:
            query: The search query string.
            k: Number of passages to return. Falls back to the instance-level
               ``k`` configured in ``__init__`` when not given (that default
               is 10, so existing callers see unchanged behavior).

        Returns:
            A ``dotdict`` with a ``passages`` list of ``dotdict`` passage records.
        """
        # Bug fix: the previous signature hard-coded `k=10`, silently ignoring
        # the `k` passed to the constructor.
        k = self.k if k is None else k
        if self.post_requests:
            topk = colbertv2_post_request(self.url, query, k)
        else:
            topk = colbertv2_get_request(self.url, query, k)
        return dotdict({'passages': [dotdict(psg) for psg in topk]})


from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory
@CacheMemory.cache
def colbertv2_get_request_v2(url: str, query: str, k: int):
    """Query the ColBERTv2 server via GET and return the top-k passage dicts.

    Each returned passage dict gains a ``long_text`` key mirroring its
    ``text`` field, which downstream DSPy code reads. Results are disk-cached
    by the ``CacheMemory.cache`` decorator, keyed on (url, query, k).
    """
    assert (
        k <= 100
    ), "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."

    payload = {"query": query, "k": k}
    res = requests.get(url, params=payload, timeout=10)

    # The list is truncated to k here, so the original's second `[:k]` slice
    # on the return was redundant and has been removed.
    topk = res.json()["topk"][:k]
    return [{**d, "long_text": d["text"]} for d in topk]


@functools.cache
@NotebookCacheMemory.cache
def colbertv2_get_request_v2_wrapped(*args, **kwargs):
    """Add in-process (``functools.cache``) and notebook-level caching on top
    of ``colbertv2_get_request_v2``, which is itself disk-cached."""
    return colbertv2_get_request_v2(*args, **kwargs)


# Public GET entry point used by ColBERTv2.forward.
colbertv2_get_request = colbertv2_get_request_v2_wrapped


@CacheMemory.cache
def colbertv2_post_request_v2(url: str, query: str, k: int):
    """Query the ColBERTv2 server via POST and return the top-k passage dicts.

    Results are disk-cached by ``CacheMemory.cache``, keyed on (url, query, k).
    """
    response = requests.post(
        url,
        json={"query": query, "k": k},
        headers={"Content-Type": "application/json; charset=utf-8"},
        timeout=10,
    )
    return response.json()["topk"][:k]


@functools.cache
@NotebookCacheMemory.cache
def colbertv2_post_request_v2_wrapped(*args, **kwargs):
    """Add in-process (``functools.cache``) and notebook-level caching on top
    of ``colbertv2_post_request_v2``, which is itself disk-cached."""
    return colbertv2_post_request_v2(*args, **kwargs)


# Public POST entry point used by ColBERTv2.forward.
colbertv2_post_request = colbertv2_post_request_v2_wrapped
Loading
Loading