Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub committed Oct 31, 2024
1 parent 6fe6935 commit 2ce32fc
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 2 deletions.
3 changes: 2 additions & 1 deletion dspy/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .lm import LM
from .lm import LM
from .embedding import Embedding
25 changes: 25 additions & 0 deletions dspy/clients/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os
from pathlib import Path

import litellm
from litellm.caching import Cache

DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
litellm.telemetry = False

if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"


class Embedding:
def __init__(self, model):
self.model = model

def __call__(self, inputs, caching=True, **kwargs):
if isinstance(self.model, str):
return litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs)
elif isinstance(self.model, callable):
return self.model(inputs, **kwargs)
else:
raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.")
3 changes: 2 additions & 1 deletion dspy/retrieve/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .retrieve import Retrieve, RetrieveThenRerank
from .retrieve import Retrieve, RetrieveThenRerank
from dspy.retrieve.retrieve_v2 import RetrieveV2
43 changes: 43 additions & 0 deletions dspy/retrieve/retrieve_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging
import os

import cloudpickle

from dspy.utils.callback import with_callbacks

logger = logging.getLogger(__name__)


class RetrieveV2:
def __init__(self, index=None, use_local_index=False, documents=None, callbacks=None):
self.use_local_index = use_local_index
self.documents = documents
self.index = index
if self.use_local_index:
self.index = self.build_local_index(documents)

def save(self, path):
if self.use_local_index:
with open(os.path.join(path, "documents.pkl"), "wb") as file:
cloudpickle.dump(self.documents, file)

def load(self, path):
if self.use_local_index:
file_path = os.path.join(path, "documents.pkl")
if not os.path.exists(file_path):
logger.warning(f"File {file_path} does not exist, nothing to load.")
return

with open(file_path, "rb") as file:
self.documents = cloudpickle.load(file)
self.index = self.build_local_index(self.documents)

def build_local_index(self, documents):
pass

@with_callbacks
def __call__(self, query, k=None, **kwargs):
return self.forward(query, k=k, **kwargs)

def forward(self, query, k=None, **kwargs):
raise NotImplementedError

0 comments on commit 2ce32fc

Please sign in to comment.