-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add embedding model * force return type to be numpy array --------- Co-authored-by: Omar Khattab <[email protected]>
- Loading branch information
1 parent
e4e7e0b
commit 7e78199
Showing
4 changed files
with
155 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,15 @@ | ||
from .lm import LM | ||
from .base_lm import BaseLM, inspect_history | ||
from .embedding import Embedding | ||
import litellm | ||
import os | ||
from pathlib import Path | ||
from litellm.caching import Cache | ||
|
||
# Configure litellm with a disk-backed cache. The cache directory is taken from
# the DSPY_CACHEDIR environment variable; when that is unset or empty, it falls
# back to ".dspy_cache" inside the user's home directory.
DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
litellm.telemetry = False

# Accessed at run time by litellm; i.e., fine to set after import.
# `setdefault` leaves any value the user has already exported untouched.
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import litellm | ||
import numpy as np | ||
|
||
|
||
class Embedding:
    """DSPy embedding class.

    The class for computing embeddings for text inputs. This class provides a unified interface for both:

    1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration
    2. Custom embedding functions that you provide

    For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use
    litellm to handle the API calls and caching.

    For custom embedding models, pass a callable function that:
    - Takes a list of strings as input.
    - Returns embeddings as either:
        - A 2D numpy array of float32 values
        - A 2D list of float32 values
    - Each row should represent one embedding vector

    Args:
        model: The embedding model to use. This can be either a string (representing the name of the hosted embedding
            model, must be an embedding model supported by litellm) or a callable that represents a custom embedding
            model.

    Examples:
        Example 1: Using a hosted model.

        ```python
        import dspy

        embedder = dspy.Embedding("openai/text-embedding-3-small")
        embeddings = embedder(["hello", "world"])

        assert embeddings.shape == (2, 1536)
        ```

        Example 2: Using a custom function.

        ```python
        import dspy
        import numpy as np

        def my_embedder(texts):
            return np.random.rand(len(texts), 10)

        embedder = dspy.Embedding(my_embedder)
        embeddings = embedder(["hello", "world"])

        assert embeddings.shape == (2, 10)
        ```
    """

    def __init__(self, model):
        # The model is stored as-is; its type is checked when the instance is called.
        self.model = model

    def __call__(self, inputs, caching=True, **kwargs):
        """Compute embeddings for the given inputs.

        Args:
            inputs: The inputs to compute embeddings for, can be a single string or a list of strings.
            caching: Whether to cache the embedding response, only valid when using a hosted embedding model.
                It is not forwarded to a custom callable.
            kwargs: Additional keyword arguments to pass to the embedding model.

        Returns:
            A 2-D numpy array of embeddings, one embedding per row.

        Raises:
            ValueError: If `model` is neither a string nor a callable.
        """
        # Normalize a single string into a one-element batch so both backends
        # always receive a list.
        if isinstance(inputs, str):
            inputs = [inputs]

        if isinstance(self.model, str):
            # Hosted model: delegate to litellm and collect the "embedding"
            # entry of every item in the response's `data`.
            embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs)
            return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32)

        if callable(self.model):
            # Custom model: whatever the callable returns (list or array) is
            # converted to a float32 numpy array.
            return np.array(self.model(inputs, **kwargs), dtype=np.float32)

        raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
from unittest.mock import Mock, patch | ||
import numpy as np | ||
|
||
from dspy.clients.embedding import Embedding | ||
|
||
|
||
# Mock response format similar to litellm's embedding response.
class MockEmbeddingResponse:
    """Minimal stand-in for the object returned by `litellm.embedding`.

    Only `data` is read by `Embedding`; the remaining attributes are filled
    with placeholder values.

    Args:
        embeddings: An iterable of embedding vectors, one per input text.
    """

    def __init__(self, embeddings):
        # Each item is wrapped in a dict under the "embedding" key, which is
        # how `Embedding.__call__` reads vectors out of the response.
        self.data = [{"embedding": emb} for emb in embeddings]
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        self.model = "mock_model"
        self.object = "list"
|
||
|
||
def test_litellm_embedding():
    """A string model name is routed through `litellm.embedding`."""
    model = "text-embedding-ada-002"
    inputs = ["hello", "world"]
    mock_embeddings = [
        [0.1, 0.2, 0.3],  # embedding for "hello"
        [0.4, 0.5, 0.6],  # embedding for "world"
    ]

    with patch("litellm.embedding") as mock_litellm:
        # Configure mock to return proper response format.
        mock_litellm.return_value = MockEmbeddingResponse(mock_embeddings)

        # Create embedding instance and call it.
        embedding = Embedding(model)
        result = embedding(inputs)

        # Verify litellm was called with correct parameters.
        mock_litellm.assert_called_once_with(model=model, input=inputs, caching=True)

        # The result must be a float32 numpy array with one row per input.
        assert isinstance(result, np.ndarray)
        assert result.dtype == np.float32
        assert len(result) == len(inputs)
        # The result is float32 while the expected values are float64 literals,
        # so use a tolerance comfortably above float32 rounding error instead
        # of relying on the much tighter default.
        np.testing.assert_allclose(result, mock_embeddings, rtol=1e-6)
|
||
|
||
def test_callable_embedding():
    """A callable model is invoked directly and its output converted to an array."""
    inputs = ["hello", "world", "test"]

    expected_embeddings = [
        [0.1, 0.2, 0.3],  # embedding for "hello"
        [0.4, 0.5, 0.6],  # embedding for "world"
        [0.7, 0.8, 0.9],  # embedding for "test"
    ]

    def mock_embedding_fn(texts):
        # Simple callable that returns fixed embeddings regardless of `texts`.
        return expected_embeddings

    # Create embedding instance with callable
    embedding = Embedding(mock_embedding_fn)
    result = embedding(inputs)

    # The plain nested list returned above must come back as a float32 array.
    assert isinstance(result, np.ndarray)
    assert result.dtype == np.float32
    # float32 result vs. float64 literals: use an explicit tolerance above
    # float32 rounding error rather than the tighter default.
    np.testing.assert_allclose(result, expected_embeddings, rtol=1e-6)
|
||
|
||
def test_invalid_model_type():
    """A model that is neither a string nor a callable is rejected with ValueError."""
    # Construction is kept inside the `raises` block as in the original test,
    # so the test passes whether the error surfaces at construction or on call.
    # `match` pins the message so an unrelated ValueError cannot satisfy the test.
    with pytest.raises(ValueError, match="must be a string or a callable"):
        embedding = Embedding(123)  # Invalid model type
        embedding(["test"])