Skip to content

Commit

Permalink
Add dspy.Embedding (#1735)
Browse files Browse the repository at this point in the history
* Add embedding model

* force return type to be numpy array

---------

Co-authored-by: Omar Khattab <[email protected]>
  • Loading branch information
chenmoneygithub and okhat authored Nov 3, 2024
1 parent e4e7e0b commit 7e78199
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 8 deletions.
13 changes: 13 additions & 0 deletions dspy/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,15 @@
from .lm import LM
from .base_lm import BaseLM, inspect_history
from .embedding import Embedding
import litellm
import os
from pathlib import Path
from litellm.caching import Cache

# Where litellm's disk cache lives: a non-empty DSPY_CACHEDIR wins, otherwise fall back to
# ".dspy_cache" inside the user's home directory. The home directory is only looked up when
# the environment variable is missing or empty.
DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or str(Path.home() / ".dspy_cache")

# Install a disk-backed cache on litellm and switch its telemetry flag off.
litellm.cache = Cache(type="disk", disk_cache_dir=DISK_CACHE_DIR)
litellm.telemetry = False

# Provide a default of "True" without overriding a value the user already exported.
# litellm reads this variable at run time, so setting it after the import above is fine.
os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True")
77 changes: 77 additions & 0 deletions dspy/clients/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import litellm
import numpy as np


class Embedding:
    """Unified text-embedding front end for DSPy.

    An instance wraps one of two kinds of backend:

    1. A hosted embedding model, identified by its litellm model name
       (e.g. "openai/text-embedding-3-small"). Requests go through
       ``litellm.embedding``, and the ``caching`` flag is forwarded to it.
    2. A custom embedding function supplied by the caller.

    A custom embedding function must:
        - Accept a list of strings as its first argument.
        - Return one embedding vector per input string (one vector per row), as either:
            - a 2D numpy array of float32 values, or
            - a 2D list of float32 values.

    Whichever backend is used, the result is converted to a float32 numpy array.

    Args:
        model: Either a string naming a hosted embedding model (it must be an embedding model
            supported by litellm) or a callable implementing a custom embedding model.

    Examples:
        Example 1: Using a hosted model.

        ```python
        import dspy

        embedder = dspy.Embedding("openai/text-embedding-3-small")
        embeddings = embedder(["hello", "world"])

        assert embeddings.shape == (2, 1536)
        ```

        Example 2: Using a custom function.

        ```python
        import dspy
        import numpy as np

        def my_embedder(texts):
            return np.random.rand(len(texts), 10)

        embedder = dspy.Embedding(my_embedder)
        embeddings = embedder(["hello", "world"])

        assert embeddings.shape == (2, 10)
        ```
    """

    def __init__(self, model):
        # Stored as-is; the string-or-callable check happens when the instance is called.
        self.model = model

    def __call__(self, inputs, caching=True, **kwargs):
        """Embed one string or a list of strings.

        Args:
            inputs: A single string or a list of strings to embed. A single string is wrapped
                in a one-element list before being handed to the backend.
            caching: Whether to cache the embedding response. Only forwarded to hosted models;
                a custom callable never receives it.
            kwargs: Extra keyword arguments passed through to the embedding backend.

        Returns:
            A 2-D float32 numpy array of embeddings, one embedding per row.

        Raises:
            ValueError: If ``model`` is neither a string nor a callable.
        """
        batch = [inputs] if isinstance(inputs, str) else inputs
        backend = self.model

        if isinstance(backend, str):
            # Hosted model: litellm performs the request; each response item carries one vector.
            response = litellm.embedding(model=backend, input=batch, caching=caching, **kwargs)
            vectors = [item["embedding"] for item in response.data]
        elif callable(backend):
            # Custom model: call it directly with the batch and any extra keyword arguments.
            vectors = backend(batch, **kwargs)
        else:
            raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")

        return np.array(vectors, dtype=np.float32)
9 changes: 1 addition & 8 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,19 @@
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import litellm
import ujson
from litellm.caching import Cache

from dspy.clients.finetune import FinetuneJob, TrainingMethod
from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
from dspy.utils.callback import BaseCallback, with_callbacks

# Disk cache location: DSPY_CACHEDIR when set to a non-empty value, otherwise ".dspy_cache" under the home directory.
DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
# Give litellm a disk-backed cache rooted at DISK_CACHE_DIR and turn its telemetry flag off.
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
litellm.telemetry = False

# Only supply a default; a value already present in the environment is left untouched.
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class LM(BaseLM):
"""
A language model supporting chat or text completion requests for use with DSPy modules.
Expand Down
64 changes: 64 additions & 0 deletions tests/clients/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import pytest
from unittest.mock import Mock, patch
import numpy as np

from dspy.clients.embedding import Embedding


# Mock response format similar to litellm's embedding response.
class MockEmbeddingResponse:
    """Minimal stand-in for the response object the tests expect from ``litellm.embedding``.

    Each embedding vector is wrapped as ``{"embedding": vector}`` inside ``data``; the remaining
    attributes are fixed placeholder values.
    """

    def __init__(self, embeddings):
        wrapped = []
        for vector in embeddings:
            wrapped.append({"embedding": vector})
        self.data = wrapped
        # All token counters are zero: no real request is ever made.
        self.usage = dict.fromkeys(("prompt_tokens", "completion_tokens", "total_tokens"), 0)
        self.model = "mock_model"
        self.object = "list"


def test_litellm_embedding():
    """A string model is routed through ``litellm.embedding`` and its vectors come back row by row."""
    model_name = "text-embedding-ada-002"
    texts = ["hello", "world"]
    fake_vectors = [
        [0.1, 0.2, 0.3],  # vector for "hello"
        [0.4, 0.5, 0.6],  # vector for "world"
    ]
    # Canned response in the shape the embedding class reads: items carrying an "embedding" key.
    canned_response = MockEmbeddingResponse(fake_vectors)

    with patch("litellm.embedding", return_value=canned_response) as litellm_stub:
        embedder = Embedding(model_name)
        output = embedder(texts)

        # Exactly one call, with the inputs forwarded unchanged and caching left at its default.
        litellm_stub.assert_called_once_with(model=model_name, input=texts, caching=True)

        assert len(output) == len(texts)
        np.testing.assert_allclose(output, fake_vectors)


def test_callable_embedding():
    """A callable model is invoked directly and its return value becomes the result."""
    texts = ["hello", "world", "test"]
    fixed_vectors = [
        [0.1, 0.2, 0.3],  # vector for "hello"
        [0.4, 0.5, 0.6],  # vector for "world"
        [0.7, 0.8, 0.9],  # vector for "test"
    ]

    def fixed_embedder(_texts):
        # Ignores its argument and always hands back the same fixed vectors.
        return fixed_vectors

    output = Embedding(fixed_embedder)(texts)

    np.testing.assert_allclose(output, fixed_vectors)


def test_invalid_model_type():
    """Using a model that is neither a string nor a callable must raise ``ValueError``."""
    bad_model = 123  # an int is neither a model name nor a callable
    with pytest.raises(ValueError):
        # Construction and the call both sit inside the context, as either may raise.
        Embedding(bad_model)(["test"])

0 comments on commit 7e78199

Please sign in to comment.