From 9f43944ada8219a5cc7ca3c0de4484b7d7a3ca93 Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> Date: Sat, 2 Nov 2024 06:42:16 -0700 Subject: [PATCH 01/31] update link again (#1742) --- docs/docs/building-blocks/1-language_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/docs/building-blocks/1-language_models.md b/docs/docs/building-blocks/1-language_models.md index 41d90cfc9..e2a4633c7 100644 --- a/docs/docs/building-blocks/1-language_models.md +++ b/docs/docs/building-blocks/1-language_models.md @@ -118,7 +118,7 @@ Any OpenAI-compatible endpoint is easy to set up with an `openai/` prefix as wel #### Setting up SGLang -1. **Install SGLang (adapted from SGLang [documentation](https://sgl-project.github.io/starts/install.html)):** +1. **Install SGLang (adapted from SGLang [documentation](https://sgl-project.github.io/start/install.html)):** ```bash pip install "sglang[all]" From 649be4a6f49f3f074afe5fc44b9e6aeebad7d454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Viktor=20Sz=C3=A9pe?= Date: Sun, 3 Nov 2024 02:17:29 +0100 Subject: [PATCH 02/31] Fix typos (#1707) * Fix typos * Revert false positives * Fix "fullset" --------- Co-authored-by: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> --- .github/workflows/precommits_check.yml | 2 +- README.md | 2 +- docs/docs/cheatsheet.md | 6 +-- .../data-handling/loading-custom-data.md | 4 +- .../lm_local_models/HFClientTGI.md | 2 +- docs/docs/deep-dive/modules/predict.md | 2 +- docs/docs/deep-dive/optimizers/bfrs.md | 2 +- docs/docs/deep-dive/optimizers/miprov2.md | 10 ++--- .../retrieval_models_clients/MilvusRM.md | 2 +- .../retrieval_models_clients/SnowflakeRM.md | 2 +- docs/docs/dspy-usecases.md | 40 +++++++++---------- docs/docs/quick-start/getting-started-01.md | 4 +- docs/mkdocs.yml | 2 +- dsp/adapters/utils.py | 2 +- dsp/modules/dummy_lm.py | 2 +- dsp/modules/groq_client.py | 4 +- dsp/modules/hf.py | 2 +- dsp/modules/premai.py | 2 +- dsp/modules/sentence_vectorizer.py | 2 +- dsp/primitives/demonstrate.py | 2 +- dsp/utils/ann_utils.py | 6 +-- dspy/clients/openai.py | 2 +- dspy/experimental/synthesizer/synthesizer.py | 8 ++-- dspy/experimental/synthetic_data.py | 2 +- dspy/functional/functional.py | 4 +- dspy/retrieve/snowflake_rm.py | 6 +-- dspy/signatures/signature.py | 2 +- dspy/teleprompt/finetune_teleprompter.py | 4 +- dspy/teleprompt/mipro_optimizer.py | 2 +- dspy/teleprompt/signature_opt_typed.py | 4 +- dspy/utils/dummies.py | 4 +- internals/build-and-release.md | 2 +- internals/release-checklist.md | 4 +- testing/README.md | 2 +- testing/tasks/heart_disease.py | 2 +- tests/dsp_LM/functional/test_functional.py | 6 +-- tests/dsp_LM/predict/test_react.py | 8 ++-- tests/dsp_LM/predict/test_retry.py | 2 +- .../dsp_LM/teleprompt/test_mipro_optimizer.py | 2 +- tests/functional/test_functional.py | 6 +-- tests/predict/test_react.py | 8 ++-- tests/predict/test_retry.py | 2 +- tests/primitives/test_program.py | 2 +- 43 files changed, 93 insertions(+), 93 deletions(-) diff --git a/.github/workflows/precommits_check.yml b/.github/workflows/precommits_check.yml index 7ac80deda..20c9c8368 100644 --- a/.github/workflows/precommits_check.yml +++ b/.github/workflows/precommits_check.yml @@ -24,7 +24,7 @@ jobs: run: | echo "Changed files" echo ${{ steps.files.outputs.all }} - echo "Github Client version" + echo "GitHub Client version" echo $(gh --version) - name: Pre-Commit Checks run: | diff --git a/README.md b/README.md index 9a9eb4384..d892e1806 100644 --- a/README.md +++ b/README.md @@ -415,7 +415,7 @@ Guidance, LMQL, RELM, and Outlines are all exciting new libraries for controllin This is very useful in many settings, but it's generally focused on low-level, structured control of a single LM call. It doesn't help ensure the JSON (or structured output) you get is going to be correct or useful for your task. -In contrast, **DSPy** automatically optimizes the prompts in your programs to align them with various task needs, which may also include producing valid structured ouputs. That said, we are considering allowing **Signatures** in **DSPy** to express regex-like constraints that are implemented by these libraries. +In contrast, **DSPy** automatically optimizes the prompts in your programs to align them with various task needs, which may also include producing valid structured outputs. That said, we are considering allowing **Signatures** in **DSPy** to express regex-like constraints that are implemented by these libraries. ## Testing diff --git a/docs/docs/cheatsheet.md b/docs/docs/cheatsheet.md index 754d5295e..c6d3a8f0c 100644 --- a/docs/docs/cheatsheet.md +++ b/docs/docs/cheatsheet.md @@ -230,7 +230,7 @@ def gsm8k_metric(gold, pred, trace=None) -> int: class FactJudge(dspy.Signature): """Judge if the answer is factually correct based on the context.""" - context = dspy.InputField(desc="Context for the prediciton") + context = dspy.InputField(desc="Context for the prediction") question = dspy.InputField(desc="Question to be answered") answer = dspy.InputField(desc="Answer for the question") factually_correct = dspy.OutputField(desc="Is the answer factually correct based on the context?", prefix="Factual[Yes/No]:") @@ -417,7 +417,7 @@ optimized_program = teleprompter.compile( optimized_program.save(f"mipro_optimized") # Evaluate optimized program -print(f"Evluate optimized program...") +print(f"Evaluate optimized program...") evaluate(optimized_program, devset=devset[:]) ``` @@ -446,7 +446,7 @@ optimized_program = teleprompter.compile( optimized_program.save(f"mipro_optimized") # Evaluate optimized program -print(f"Evluate optimized program...") +print(f"Evaluate optimized program...") evaluate(optimized_program, devset=devset[:]) ``` ### Signature Optimizer with Types diff --git a/docs/docs/deep-dive/data-handling/loading-custom-data.md b/docs/docs/deep-dive/data-handling/loading-custom-data.md index 405a90f48..196d0119b 100644 --- a/docs/docs/deep-dive/data-handling/loading-custom-data.md +++ b/docs/docs/deep-dive/data-handling/loading-custom-data.md @@ -86,7 +86,7 @@ Using the Dataset base class now makes loading custom datasets incredibly easy a !!! caution - We did not populate `_test` attribute in the above code, which is fine and won't cause any unneccesary error as such. However it'll give you an error if you try to access the test split. + We did not populate `_test` attribute in the above code, which is fine and won't cause any unnecessary error as such. However it'll give you an error if you try to access the test split. ```python dataset.test[:5] @@ -110,6 +110,6 @@ Using the Dataset base class now makes loading custom datasets incredibly easy a To prevent that you'll just need to make sure `_test` is not `None` and populated with the appropriate data. -You can overide the methods in `Dataset` class to customize your class even more. +You can override the methods in `Dataset` class to customize your class even more. In summary, the Dataset base class provides a simplistic way to load and preprocess custom datasets with minimal code! diff --git a/docs/docs/deep-dive/language_model_clients/lm_local_models/HFClientTGI.md b/docs/docs/deep-dive/language_model_clients/lm_local_models/HFClientTGI.md index 0dd420056..3c136e0de 100644 --- a/docs/docs/deep-dive/language_model_clients/lm_local_models/HFClientTGI.md +++ b/docs/docs/deep-dive/language_model_clients/lm_local_models/HFClientTGI.md @@ -81,7 +81,7 @@ The constructor initializes the `HFModel` base class to support the handling of - `model` (_str_): ID of Hugging Face model connected to the TGI server. - `port` (_int_ or _list_): Port for communicating to the TGI server. This can be a single port number (`8080`) or a list of TGI ports (`[8080, 8081, 8082]`) to route the requests to. - `url` (_str_): Base URL of hosted TGI server. This will often be `"http://localhost"`. -- `http_request_kwargs` (_dict_): Dictionary of additional keyword agruments to pass to the HTTP request function to the TGI server. This is `None` by default. +- `http_request_kwargs` (_dict_): Dictionary of additional keyword arguments to pass to the HTTP request function to the TGI server. This is `None` by default. - `**kwargs`: Additional keyword arguments to configure the TGI client. Example of the TGI constructor: diff --git a/docs/docs/deep-dive/modules/predict.md b/docs/docs/deep-dive/modules/predict.md index 46f0a09ee..b5169ea12 100644 --- a/docs/docs/deep-dive/modules/predict.md +++ b/docs/docs/deep-dive/modules/predict.md @@ -46,7 +46,7 @@ class Predict(Parameter): This method serves as a wrapper for the `forward` method. It allows making predictions using the `Predict` class by providing keyword arguments. -**Paramters:** +**Parameters:** - `**kwargs`: Keyword arguments required for prediction. **Returns:** diff --git a/docs/docs/deep-dive/optimizers/bfrs.md b/docs/docs/deep-dive/optimizers/bfrs.md index 8728398a3..c15004af1 100644 --- a/docs/docs/deep-dive/optimizers/bfrs.md +++ b/docs/docs/deep-dive/optimizers/bfrs.md @@ -23,7 +23,7 @@ In terms of API `BootstrapFewShotWithRandomSearch` teleprompter is quite similar ## Working Example -Let's take the example of optimizing a simple CoT pipeline for GSM8k dataset, we'll take the example in [BootstrapFewShot](/deep-dive/optimizers/bootstrap-fewshot) as our running example for optimizers. We're gonna assume our data and pipeline is same as the on in `BootstrapFewShot` article. So let's start by intializing the optimizer: +Let's take the example of optimizing a simple CoT pipeline for GSM8k dataset, we'll take the example in [BootstrapFewShot](/deep-dive/optimizers/bootstrap-fewshot) as our running example for optimizers. We're gonna assume our data and pipeline is same as the on in `BootstrapFewShot` article. So let's start by initializing the optimizer: ```python from dspy.teleprompt import BootstrapFewShotWithRandomSearch diff --git a/docs/docs/deep-dive/optimizers/miprov2.md b/docs/docs/deep-dive/optimizers/miprov2.md index 142eef084..4d60167cf 100644 --- a/docs/docs/deep-dive/optimizers/miprov2.md +++ b/docs/docs/deep-dive/optimizers/miprov2.md @@ -88,7 +88,7 @@ optimized_program = teleprompter.compile( optimized_program.save(f"mipro_optimized") # Evaluate optimized program -print(f"Evluate optimized program...") +print(f"Evaluate optimized program...") evaluate(optimized_program, devset=devset[:]) ``` @@ -119,7 +119,7 @@ zeroshot_optimized_program = teleprompter.compile( zeroshot_optimized_program.save(f"mipro_zeroshot_optimized") # Evaluate optimized program -print(f"Evluate optimized program...") +print(f"Evaluate optimized program...") evaluate(zeroshot_optimized_program, devset=devset[:]) ``` @@ -156,7 +156,7 @@ optimized_program = teleprompter.compile( optimized_program.save(f"mipro_optimized") # Evaluate optimized program -print(f"Evluate optimized program...") +print(f"Evaluate optimized program...") evaluate(optimized_program, devset=devset[:]) ``` @@ -170,7 +170,7 @@ evaluate(optimized_program, devset=devset[:]) | `prompt_model` | `dspy.LM` | LM specified in `dspy.settings` | Model used for prompt generation. | | `task_model` | `dspy.LM` | LM specified in `dspy.settings` | Model used for task execution. | | `auto` | `Optional[str]` | None | If set to `light`, `medium`, or `heavy`, this will automatically configure the following hyperparameters: `num_candidates`, `num_trials`, `minibatch`, and will also cap the size of `valset` up to 100, 300, and 1000 for `light`, `medium`, and `heavy` runs respectively. | -| `num_candidates` | `int` | `10` | Number of candidate instructions & few-shot examples to generate and evaluate for each predictor. If `num_candidates=10`, this means for a 2 module LM program we'll be optimizing over 10 candidates x 2 modules x 2 variables (few-shot ex. and instructions for each module)= 40 total variables. Therfore, if we increase `num_candidates`, we will probably want to increase `num_trials` as well (see Compile parameters). | +| `num_candidates` | `int` | `10` | Number of candidate instructions & few-shot examples to generate and evaluate for each predictor. If `num_candidates=10`, this means for a 2 module LM program we'll be optimizing over 10 candidates x 2 modules x 2 variables (few-shot ex. and instructions for each module)= 40 total variables. Therefore, if we increase `num_candidates`, we will probably want to increase `num_trials` as well (see Compile parameters). | | `num_threads` | `int` | `6` | Threads to use for evaluation. | | `max_errors` | `int` | `10` | Maximum errors during an evaluation run that can be made before throwing an Exception. | | `teacher_settings` | `dict` | `{}` | Settings to use for the teacher model that bootstraps few-shot examples. An example dict would be `{lm=}`. If your LM program with your default model is struggling to bootstrap any examples, it could be worth using a more powerful teacher model for bootstrapping. | @@ -210,7 +210,7 @@ At a high level, `MIPROv2` works by creating both few-shot examples and new inst These steps are broken down in more detail below: 1) **Bootstrap Few-Shot Examples**: The same bootstrapping technique used in `BootstrapFewshotWithRandomSearch` is used to create few-shot examples. This works by randomly sampling examples from your training set, which are then run through your LM program. If the output from the program is correct for this example, it is kept as a valid few-shot example candidate. Otherwise, we try another example until we've curated the specified amount of few-shot example candidates. This step creates `num_candidates` sets of `max_bootstrapped_demos` bootstrapped examples and `max_labeled_demos` basic examples sampled from the training set. 2) **Propose Instruction Candidates**. Next, we propose instruction candidates for each predictor in the program. This is done using another LM program as a proposer, which bootstraps & summarizes relevant information about the task to generate high quality instructions. Specifically, the instruction proposer includes (1) a generated summary of properties of the training dataset, (2) a generated summary of your LM program's code and the specific predictor that an instruction is being generated for, (3) the previously bootstrapped few-shot examples to show reference inputs / outputs for a given predictor and (4) a randomly sampled tip for generation (i.e. "be creative", "be concise", etc.) to help explore the feature space of potential instructions. -3. **Find an Optimized Combination of Few-Shot Examples & Instructions**. Finally, now that we've created these few-shot examples and instructions, we use Bayesian Optimization to choose which set of these would work best for each predictor in our program. This works by running a series of `num_trials` trials, where a new set of prompts are evaluated over our validation set at each trial. This helps the Bayesian Optimizer learn which combination of prompts work best over time. If `minibatch` is set to `True` (which it is by default), then the new set of prompts are only evaluated on a minibatch of size `minibatch_size` at each trial which generally allows for more efficient exploration / exploitation. The best averaging set of prompts is then evalauted on the full validation set every `minibatch_full_eval_steps` get a less noisey performance benchmark. At the end of the optimization process, the LM program with the set of prompts that performed best on the full validation set is returned. +3. **Find an Optimized Combination of Few-Shot Examples & Instructions**. Finally, now that we've created these few-shot examples and instructions, we use Bayesian Optimization to choose which set of these would work best for each predictor in our program. This works by running a series of `num_trials` trials, where a new set of prompts are evaluated over our validation set at each trial. This helps the Bayesian Optimizer learn which combination of prompts work best over time. If `minibatch` is set to `True` (which it is by default), then the new set of prompts are only evaluated on a minibatch of size `minibatch_size` at each trial which generally allows for more efficient exploration / exploitation. The best averaging set of prompts is then evaluated on the full validation set every `minibatch_full_eval_steps` get a less noisey performance benchmark. At the end of the optimization process, the LM program with the set of prompts that performed best on the full validation set is returned. For those interested in more details, more information on `MIPROv2` along with a study on `MIPROv2` compared with other DSPy optimizers can be found in [this paper](https://arxiv.org/abs/2406.11695). \ No newline at end of file diff --git a/docs/docs/deep-dive/retrieval_models_clients/MilvusRM.md b/docs/docs/deep-dive/retrieval_models_clients/MilvusRM.md index 8e290eb48..acd3cd4ea 100644 --- a/docs/docs/deep-dive/retrieval_models_clients/MilvusRM.md +++ b/docs/docs/deep-dive/retrieval_models_clients/MilvusRM.md @@ -51,7 +51,7 @@ Search the Milvus collection for the top `k` passages matching the given query o from dspy.retrieve.milvus_rm import MilvusRM import os -os.envrion["OPENAI_API_KEY"] = "" +os.environ["OPENAI_API_KEY"] = "" retriever_model = MilvusRM( collection_name="", diff --git a/docs/docs/deep-dive/retrieval_models_clients/SnowflakeRM.md b/docs/docs/deep-dive/retrieval_models_clients/SnowflakeRM.md index e6689ea07..28211afe5 100644 --- a/docs/docs/deep-dive/retrieval_models_clients/SnowflakeRM.md +++ b/docs/docs/deep-dive/retrieval_models_clients/SnowflakeRM.md @@ -64,7 +64,7 @@ connection_parameters = { snowpark = Session.builder.configs(connection_parameters).create() snowflake_retriever = SnowflakeRM(snowflake_session=snowpark, - cortex_search_service="", + cortex_search_service="", snowflake_database="", snowflake_schema="", auto_filter=True, diff --git a/docs/docs/dspy-usecases.md b/docs/docs/dspy-usecases.md index a8cfc6208..e791d0ca9 100644 --- a/docs/docs/dspy-usecases.md +++ b/docs/docs/dspy-usecases.md @@ -58,27 +58,27 @@ WIP. This list mainly includes companies that have public posts or have OKed bei | **Name** | **Description/Link** | |---|---| -| **Stanford CS 224U Homework** | [Github](https://github.com/cgpotts/cs224u/blob/main/hw_openqa.ipynb) | -| **STORM Report Generation (10,000 GitHub stars)** | [Github](https://github.com/stanford-oval/storm) | -| **DSPy Redteaming** | [Github](https://github.com/haizelabs/dspy-redteam) | -| **DSPy Theory of Mind** | [Github](https://github.com/plastic-labs/dspy-opentom) | -| **Indic cross-lingual Natural Language Inference** | [Github](https://github.com/saifulhaq95/DSPy-Indic/blob/main/indicxlni.ipynb) | -| **Optimizing LM for Text2SQL using DSPy** | [Github](https://github.com/jjovalle99/DSPy-Text2SQL) | +| **Stanford CS 224U Homework** | [GitHub](https://github.com/cgpotts/cs224u/blob/main/hw_openqa.ipynb) | +| **STORM Report Generation (10,000 GitHub stars)** | [GitHub](https://github.com/stanford-oval/storm) | +| **DSPy Redteaming** | [GitHub](https://github.com/haizelabs/dspy-redteam) | +| **DSPy Theory of Mind** | [GitHub](https://github.com/plastic-labs/dspy-opentom) | +| **Indic cross-lingual Natural Language Inference** | [GitHub](https://github.com/saifulhaq95/DSPy-Indic/blob/main/indicxlni.ipynb) | +| **Optimizing LM for Text2SQL using DSPy** | [GitHub](https://github.com/jjovalle99/DSPy-Text2SQL) | | **DSPy PII Masking Demo by Eric Ness** | [Colab](https://colab.research.google.com/drive/1KZR1sGTp_RLWUJPAiK1FKPKI-Qn9neUm?usp=sharing) | -| **DSPy on BIG-Bench Hard Example** | [Github](https://drchrislevy.github.io/posts/dspy/dspy.html) | -| **Building a chess playing agent using DSPy** | [Github](https://medium.com/thoughts-on-machine-learning/building-a-chess-playing-agent-using-dspy-9b87c868f71e) | -| **Ittia Research Fact Checking** | [Github](https://github.com/ittia-research/check) | -| **Strategic Debate via Tree-of-Thought** | [Github](https://github.com/zbambergerNLP/strategic-debate-tot) | -| **Sanskrit to English Translation App**| [Github](https://github.com/ganarajpr/sanskrit-translator-dspy) | -| **DSPy for extracting features from PDFs on arXiv**| [Github](https://github.com/S1M0N38/dspy-arxiv) | -| **DSPygen: DSPy in Ruby on Rails**| [Github](https://github.com/seanchatmangpt/dspygen) | -| **DSPy Inspector**| [Github](https://github.com/Neoxelox/dspy-inspector) | -| **DSPy with FastAPI**| [Github](https://github.com/diicellman/dspy-rag-fastapi) | -| **DSPy for Indian Languages**| [Github](https://github.com/saifulhaq95/DSPy-Indic) | -| **Hurricane: Blog Posts with Generative Feedback Loops!**| [Github](https://github.com/weaviate-tutorials/Hurricane) | -| **RAG example using DSPy, Gradio, FastAPI, and Ollama**| [Github](https://github.com/diicellman/dspy-gradio-rag) | -| **Synthetic Data Generation**| [Github](https://colab.research.google.com/drive/1CweVOu0qhTC0yOfW5QkLDRIKuAuWJKEr?usp=sharing) | -| **Self Discover**| [Github](https://colab.research.google.com/drive/1GkAQKmw1XQgg5UNzzy8OncRe79V6pADB?usp=sharing) | +| **DSPy on BIG-Bench Hard Example** | [GitHub](https://drchrislevy.github.io/posts/dspy/dspy.html) | +| **Building a chess playing agent using DSPy** | [GitHub](https://medium.com/thoughts-on-machine-learning/building-a-chess-playing-agent-using-dspy-9b87c868f71e) | +| **Ittia Research Fact Checking** | [GitHub](https://github.com/ittia-research/check) | +| **Strategic Debate via Tree-of-Thought** | [GitHub](https://github.com/zbambergerNLP/strategic-debate-tot) | +| **Sanskrit to English Translation App**| [GitHub](https://github.com/ganarajpr/sanskrit-translator-dspy) | +| **DSPy for extracting features from PDFs on arXiv**| [GitHub](https://github.com/S1M0N38/dspy-arxiv) | +| **DSPygen: DSPy in Ruby on Rails**| [GitHub](https://github.com/seanchatmangpt/dspygen) | +| **DSPy Inspector**| [GitHub](https://github.com/Neoxelox/dspy-inspector) | +| **DSPy with FastAPI**| [GitHub](https://github.com/diicellman/dspy-rag-fastapi) | +| **DSPy for Indian Languages**| [GitHub](https://github.com/saifulhaq95/DSPy-Indic) | +| **Hurricane: Blog Posts with Generative Feedback Loops!**| [GitHub](https://github.com/weaviate-tutorials/Hurricane) | +| **RAG example using DSPy, Gradio, FastAPI, and Ollama**| [GitHub](https://github.com/diicellman/dspy-gradio-rag) | +| **Synthetic Data Generation**| [GitHub](https://colab.research.google.com/drive/1CweVOu0qhTC0yOfW5QkLDRIKuAuWJKEr?usp=sharing) | +| **Self Discover**| [GitHub](https://colab.research.google.com/drive/1GkAQKmw1XQgg5UNzzy8OncRe79V6pADB?usp=sharing) | TODO: This list in particular is highly incomplete. There are a couple dozen other good ones. diff --git a/docs/docs/quick-start/getting-started-01.md b/docs/docs/quick-start/getting-started-01.md index b276210ed..bd65cd5f2 100644 --- a/docs/docs/quick-start/getting-started-01.md +++ b/docs/docs/quick-start/getting-started-01.md @@ -167,7 +167,7 @@ pred = cot(**example.inputs()) score = metric(example, pred) print(f"Question: \t {example.question}\n") -print(f"Gold Reponse: \t {example.response}\n") +print(f"Gold Response: \t {example.response}\n") print(f"Predicted Response: \t {pred.response}\n") print(f"Semantic F1 Score: {score:.2f}") ``` @@ -176,7 +176,7 @@ print(f"Semantic F1 Score: {score:.2f}") ``` Question: what are high memory and low memory on linux? -Gold Reponse: "High Memory" refers to the application or user space, the memory that user programs can use and which isn't permanently mapped in the kernel's space, while "Low Memory" is the kernel's space, which the kernel can address directly and is permanently mapped. +Gold Response: "High Memory" refers to the application or user space, the memory that user programs can use and which isn't permanently mapped in the kernel's space, while "Low Memory" is the kernel's space, which the kernel can address directly and is permanently mapped. The user cannot access the Low Memory as it is set aside for the required kernel programs. Predicted Response: In Linux, "low memory" refers to the memory that is directly accessible by the kernel and user processes, typically the first 4GB on a 32-bit system. "High memory" refers to memory above this limit, which is not directly accessible by the kernel in a 32-bit environment. This distinction is crucial for memory management, particularly in systems with large amounts of RAM, as it influences how memory is allocated and accessed. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index a98084ba6..e487236b9 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -33,7 +33,7 @@ nav: - MultiChainComparison: deep-dive/modules/multi-chain-comparison.md - ProgramOfThought: deep-dive/modules/program-of-thought.md - Assertions: deep-dive/assertions.md - - Retreive: deep-dive/modules/retrieve.md + - Retrieve: deep-dive/modules/retrieve.md - Modules Guide: deep-dive/modules/guide.md - Metrics and Assertions: - Overview: building-blocks/7-assertions.md diff --git a/dsp/adapters/utils.py b/dsp/adapters/utils.py index 4bc15c058..1cbe604d4 100644 --- a/dsp/adapters/utils.py +++ b/dsp/adapters/utils.py @@ -53,7 +53,7 @@ def format_answers(answers: Union[str, list]) -> Optional[str]: ValueError: when is not of type list or str Returns: - _type_: Optiona[str] + _type_: Optional[str] """ if isinstance(answers, list): if len(answers) >= 1: diff --git a/dsp/modules/dummy_lm.py b/dsp/modules/dummy_lm.py index 49f35fa72..0e8c25fda 100644 --- a/dsp/modules/dummy_lm.py +++ b/dsp/modules/dummy_lm.py @@ -87,5 +87,5 @@ def __call__(self, prompt, _only_completed=True, _return_sorted=False, **kwargs) return [choice["text"] for choice in choices] def get_convo(self, index) -> str: - """Get the prompt + anwer from the ith message.""" + """Get the prompt + answer from the ith message.""" return self.history[index]["prompt"] + " " + self.history[index]["response"]["choices"][0]["text"] diff --git a/dsp/modules/groq_client.py b/dsp/modules/groq_client.py index e22c505ac..196a2b726 100644 --- a/dsp/modules/groq_client.py +++ b/dsp/modules/groq_client.py @@ -96,7 +96,7 @@ def basic_request(self, prompt: str, **kwargs): on_backoff=backoff_hdlr, ) def request(self, prompt: str, **kwargs): - """Handles retreival of model completions whilst handling rate limiting and caching.""" + """Handles retrieval of model completions whilst handling rate limiting and caching.""" if "model_type" in kwargs: del kwargs["model_type"] @@ -106,7 +106,7 @@ def _get_choice_text(self, choice) -> str: return choice.message.content def chat_request(self, **kwargs): - """Handles retreival of model completions whilst handling rate limiting and caching.""" + """Handles retrieval of model completions whilst handling rate limiting and caching.""" response = self.client.chat.completions.create(**kwargs) return response diff --git a/dsp/modules/hf.py b/dsp/modules/hf.py index 57f57d259..fc16c2ae9 100644 --- a/dsp/modules/hf.py +++ b/dsp/modules/hf.py @@ -49,7 +49,7 @@ def __init__( checkpoint (str, optional): load specific checkpoints of the model. Defaults to None. is_client (bool, optional): whether to access models via client. Defaults to False. hf_device_map (str, optional): HF config strategy to load the model. - Recommeded to use "auto", which will help loading large models using accelerate. Defaults to "auto". + Recommended to use "auto", which will help loading large models using accelerate. Defaults to "auto". model_kwargs (dict, optional): additional kwargs to pass to the model constructor. Defaults to empty dict. """ diff --git a/dsp/modules/premai.py b/dsp/modules/premai.py index 17e9d151f..15968c7c5 100644 --- a/dsp/modules/premai.py +++ b/dsp/modules/premai.py @@ -63,7 +63,7 @@ def __init__( api_key: Optional[str] Prem AI API key, to connect with the API. If not provided then it will check from env var by the name PREMAI_API_KEY - kwargs: Optional[dict] For any additional paramters + kwargs: Optional[dict] For any additional parameters """ self.model = "default" if model is None else model super().__init__(self.model) diff --git a/dsp/modules/sentence_vectorizer.py b/dsp/modules/sentence_vectorizer.py index 84be68420..61788eb98 100644 --- a/dsp/modules/sentence_vectorizer.py +++ b/dsp/modules/sentence_vectorizer.py @@ -211,7 +211,7 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray: class FastEmbedVectorizer(BaseSentenceVectorizer): - """Sentence vectorizer implementaion using FastEmbed - https://qdrant.github.io/fastembed.""" + """Sentence vectorizer implementation using FastEmbed - https://qdrant.github.io/fastembed.""" def __init__( self, diff --git a/dsp/primitives/demonstrate.py b/dsp/primitives/demonstrate.py index 331c31517..9b3ff9584 100644 --- a/dsp/primitives/demonstrate.py +++ b/dsp/primitives/demonstrate.py @@ -155,7 +155,7 @@ def knn( Args: train: a bunch of questions to put in index & search later - cast: function that contructs text before vectorization. By default, + cast: function that constructs text before vectorization. By default, it uses only question. Check `cast_naive_get_question_and_answer` for more details. n_probe: number of closest IVF-clusters to check for neighbours. Doesn't affect bruteforce-based search. diff --git a/dsp/utils/ann_utils.py b/dsp/utils/ann_utils.py index dcd3f09ce..a5378d3f7 100644 --- a/dsp/utils/ann_utils.py +++ b/dsp/utils/ann_utils.py @@ -100,12 +100,12 @@ def create_faiss_index( the difference between a vector and the reconstruction that can be decoded from its representation in the index. in_list_dist_type: type of distance to calculate simmilarities within one IVF. - Can be `IP` (for inner product) or `L2` distance. Case insensetive. + Can be `IP` (for inner product) or `L2` distance. Case insensitive. If the index type is bruteforce (`n_objects` < 20_000), this variable will define - the distane type for that bruteforce index. `centroid_dist_type` will be ignored. + the distance type for that bruteforce index. `centroid_dist_type` will be ignored. centroid_dist_type: type of distance to calculate simmilarities between a query and cluster centroids. Can be `IP` (for inner product) or `L2` distance. - Case insensetive. + Case insensitive. Returns: untrained FAISS-index """ if n_objects < 20_000: diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py index 0522475f8..f5d08c283 100644 --- a/dspy/clients/openai.py +++ b/dspy/clients/openai.py @@ -33,7 +33,7 @@ def is_openai_model(model: str) -> bool: if model in valid_model_names: return True - # Check if the model is a fine-tuned OpneAI model. Fine-tuned OpenAI models + # Check if the model is a fine-tuned OpenAI model. Fine-tuned OpenAI models # have the prefix "ft::", followed by a string specifying # the fine-tuned model. The following RegEx pattern is used to match the # base model name. diff --git a/dspy/experimental/synthesizer/synthesizer.py b/dspy/experimental/synthesizer/synthesizer.py index bd70e6d7a..2fea0f85d 100644 --- a/dspy/experimental/synthesizer/synthesizer.py +++ b/dspy/experimental/synthesizer/synthesizer.py @@ -244,17 +244,17 @@ def generate( return data def export(self, data: List[dspy.Example], path: str, mode: str = None, **kwargs): - extention = mode or path.split(".")[-1] + extension = mode or path.split(".")[-1] dataset = Dataset.from_list( [example.toDict() for example in data], ) - if extention == "csv": + if extension == "csv": dataset.to_csv(path_or_buf=path, **kwargs) - elif extention == "json": + elif extension == "json": dataset.to_json(path_or_buf=path, **kwargs) - elif extention == "arrow" or extention == "hf": + elif extension == "arrow" or extension == "hf": dataset.save_to_disk(path) diff --git a/dspy/experimental/synthetic_data.py b/dspy/experimental/synthetic_data.py index 16bc5e797..9f965bcba 100644 --- a/dspy/experimental/synthetic_data.py +++ b/dspy/experimental/synthetic_data.py @@ -41,7 +41,7 @@ def generate(self, sample_size: int) -> List[dspy.Example]: def _define_or_infer_fields(self): """Define fields to generate if a schema class is provided. - Infer fields to generate if an inital sample of examples is provided. + Infer fields to generate if an initial sample of examples is provided. Returns: dict: dictionary of fields to generate diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 8bb826044..03d859763 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -62,7 +62,7 @@ def forward(self, **kwargs): class FunctionalModule(dspy.Module): - """To use the @cot and @predictor decorators, your module needs to inheret form this class.""" + """To use the @cot and @predictor decorators, your module needs to inherit form this class.""" def __init__(self): super().__init__() @@ -208,7 +208,7 @@ class Signature(dspy.Signature): task_description: str = dspy.InputField(desc="What I asked the model to do") language_model_output: str = dspy.InputField(desc="The output of the model") - error: str = dspy.InputField(desc="The validation error trigged by the models output") + error: str = dspy.InputField(desc="The validation error triggered by the models output") explanation: str = dspy.OutputField(desc="Explain what the model did wrong") advice: str = dspy.OutputField( desc="Instructions for the model to do better next time. A single paragraph.", diff --git a/dspy/retrieve/snowflake_rm.py b/dspy/retrieve/snowflake_rm.py index f58e205b7..7485a1954 100644 --- a/dspy/retrieve/snowflake_rm.py +++ b/dspy/retrieve/snowflake_rm.py @@ -17,14 +17,14 @@ class SnowflakeRM(dspy.Retrieve): - """A retrieval module that uses Snowlfake's Cortex Search service to return the top relevant passages for a given query. + """A retrieval module that uses Snowflake's Cortex Search service to return the top relevant passages for a given query. Assumes that a Snowflake Cortex Search endpoint has been configured by the use. For more information on configuring the Cortex Search service, visit: https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview Args: - snowflake_sesssion (object): Snowflake Snowpark session for accessing the service. + snowflake_session (object): Snowflake Snowpark session for accessing the service. cortex_search_service(str): Name of the Cortex Search service to be used. snowflake_database (str): The name of the Snowflake table containing document embeddings. snowflake_schema (str): The name of the Snowflake table containing document embeddings. @@ -241,7 +241,7 @@ class GenerateFilter(dspy.Signature): Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]} Answer: {"@or":[{"@eq":{"year":"2021"}},{"@eq":{"year":"2022"}},{"@eq":{"year":"2023"}},{"@eq":{"year":"2024"}}]} - Query: Wha is the sentiment of Biotech CEO's of companies based in New York? + Query: What is the sentiment of Biotech CEO's of companies based in New York? Attributes: industry,hq,date Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]} Answer: {"@and": [ { "@eq": { "industry"": "biotechnology" } }, { "@eq": { "HQ": "NY,US" } }]} diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 4ad7bab1b..88812be0c 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -141,7 +141,7 @@ def append(cls, name, field, type_=None) -> Type["Signature"]: return cls.insert(-1, name, field, type_) def insert(cls, index: int, name: str, field, type_: Type = None) -> Type["Signature"]: - # It's posisble to set the type as annotation=type in pydantic.Field(...) + # It's possible to set the type as annotation=type in pydantic.Field(...) # But this may be annoying for users, so we allow them to pass the type if type_ is None: type_ = field.annotation diff --git a/dspy/teleprompt/finetune_teleprompter.py b/dspy/teleprompt/finetune_teleprompter.py index 315434aba..57246fea6 100644 --- a/dspy/teleprompt/finetune_teleprompter.py +++ b/dspy/teleprompt/finetune_teleprompter.py @@ -62,10 +62,10 @@ def convert_to_module_level_message_data( prompt_completion_data = [] for data_dict in data: trace = data_dict["trace"] - trace_prompt_comletion_data = build_messages_from_trace( + trace_prompt_completion_data = build_messages_from_trace( trace=trace, exclude_demos=exclude_demos, try_to_record_lm_kwargs=try_to_record_lm_kwargs, program=program ) - for prompt_completion_dict in trace_prompt_comletion_data: + for prompt_completion_dict in trace_prompt_completion_data: if keep_data_keys: prompt_completion_dict = {**data_dict, **prompt_completion_dict} prompt_completion_data.append(prompt_completion_dict) diff --git a/dspy/teleprompt/mipro_optimizer.py b/dspy/teleprompt/mipro_optimizer.py index f0d96cc13..128ede37e 100644 --- a/dspy/teleprompt/mipro_optimizer.py +++ b/dspy/teleprompt/mipro_optimizer.py @@ -419,7 +419,7 @@ def compile( module = student.deepcopy() evaluate = Evaluate(devset=trainset, metric=self.metric, **eval_kwargs) - # In the case where the bootstrapped and labeled demos are set to 0, we'll stil bootstrap examples to use in our meta prompt + # In the case where the bootstrapped and labeled demos are set to 0, we'll still bootstrap examples to use in our meta prompt if ( max_bootstrapped_demos == 0 and max_labeled_demos == 0 ): # TODO: address case when max_bootstrapped alone is 0 diff --git a/dspy/teleprompt/signature_opt_typed.py b/dspy/teleprompt/signature_opt_typed.py index 3e5c7587f..cc1bb341c 100644 --- a/dspy/teleprompt/signature_opt_typed.py +++ b/dspy/teleprompt/signature_opt_typed.py @@ -19,7 +19,7 @@ def make_info(signature: type[Signature]) -> BaseModel: """Creates a SignatureInfo pydantic type, that describes the Signature. - Returns an instnce of this type, with the instructions and field descriptions of the input type. + Returns an instance of this type, with the instructions and field descriptions of the input type. """ # First, create the SignatureInfo type fields = { @@ -82,7 +82,7 @@ class GenerateInstructionInitial(Signature, Generic[T]): - You are an expert mathematician. - You are a professor of mathematics. Task Descriptions: - - Be consise in your answer. + - Be concise in your answer. - Be as clear as possible. - Use lots of creativity. Closers: diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index 90f5def66..39ea37c61 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -95,7 +95,7 @@ def __call__(self, prompt, _only_completed=True, _return_sorted=False, **kwargs) return [choice["text"] for choice in choices] def get_convo(self, index) -> str: - """Get the prompt + anwer from the ith message.""" + """Get the prompt + answer from the ith message.""" return self.history[index]["prompt"] + " " + self.history[index]["response"]["choices"][0]["text"] @@ -209,7 +209,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]): return outputs def get_convo(self, index): - """Get the prompt + anwer from the ith message.""" + """Get the prompt + answer from the ith message.""" return self.history[index]["messages"], self.history[index]["outputs"] diff --git a/internals/build-and-release.md b/internals/build-and-release.md index 76b3b400d..fb000ed61 100644 --- a/internals/build-and-release.md +++ b/internals/build-and-release.md @@ -56,4 +56,4 @@ Builds and publishes the package to pypi. 1. Publishes the package to pypi. -\* The package name is updated by the worfklow to allow the same files to be used to build both the pypi and test-pypi packages. \ No newline at end of file +\* The package name is updated by the workflow to allow the same files to be used to build both the pypi and test-pypi packages. \ No newline at end of file diff --git a/internals/release-checklist.md b/internals/release-checklist.md index 862ab7a5c..0213a52b7 100644 --- a/internals/release-checklist.md +++ b/internals/release-checklist.md @@ -9,10 +9,10 @@ * [ ] Confirm the tests pass and the package has been published to pypi. * If the tests fail, you can remove the tag from your local and github repo using: ```bash - git push origin --delete X.Y.Z # Delete on Github + git push origin --delete X.Y.Z # Delete on GitHub git tag -d X.Y.Z # Delete locally ``` - * Fix the errors and then repeat the steps above to recreate the tag locally and push to Github to restart the process. + * Fix the errors and then repeat the steps above to recreate the tag locally and push to GitHub to restart the process. * Note that the github action takes care of incrementing the release version on test-pypi automatically by adding a pre-release identifier in the scenario where the tests fail and you need to delete and push the same tag again. * [ ] [Create a release](https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository) * [ ] Add release notes. You can make use of [automatically generated release notes](https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes) diff --git a/testing/README.md b/testing/README.md index 6fc730072..78441af70 100644 --- a/testing/README.md +++ b/testing/README.md @@ -12,7 +12,7 @@ from optimizer_tester import OptimizerTester tester = OptimizerTester() ``` -The default verison (no parameters) expects a llama model hosted on ports [7140, 7141, 7142, 7143] and OpenAI keys stored in a .env file (OPENAI_API_KEY and OPENAI_API_BASE). +The default version (no parameters) expects a llama model hosted on ports [7140, 7141, 7142, 7143] and OpenAI keys stored in a .env file (OPENAI_API_KEY and OPENAI_API_BASE). If you prefer to specify your own model parameters then you can pass models into the OptimizerTester diff --git a/testing/tasks/heart_disease.py b/testing/tasks/heart_disease.py index f64026172..0885f475c 100644 --- a/testing/tasks/heart_disease.py +++ b/testing/tasks/heart_disease.py @@ -53,7 +53,7 @@ class HeartDiseaseInput(dspy.Signature): trestbps = dspy.InputField( desc="Resting blood pressure (in mm Hg on admission to the hospital)" ) - chol = dspy.InputField(desc="Serum cholestoral in mg/dl") + chol = dspy.InputField(desc="Serum cholesterol in mg/dl") fbs = dspy.InputField(desc="Fasting blood sugar > 120 mg/dl (true or false)") restecg = dspy.InputField( desc="Resting electrocardiographic results (normal, ST-T wave abnormality, left ventricular hypertrophy)" diff --git a/tests/dsp_LM/functional/test_functional.py b/tests/dsp_LM/functional/test_functional.py index 5e5274567..e71b41fc0 100644 --- a/tests/dsp_LM/functional/test_functional.py +++ b/tests/dsp_LM/functional/test_functional.py @@ -530,7 +530,7 @@ def f() -> Literal["2", "3"]: assert f() == "2" -def test_literal_missmatch(): +def test_literal_mismatch(): lm = DSPDummyLM([f'"{i}"' for i in range(5, 100)]) dspy.settings.configure(lm=lm) @@ -555,7 +555,7 @@ def f() -> Literal[2, 3]: assert f() == 2 -def test_literal_int_missmatch(): +def test_literal_int_mismatch(): lm = DSPDummyLM([f"{i}" for i in range(5, 100)]) dspy.settings.configure(lm=lm) @@ -893,7 +893,7 @@ class MySignature(dspy.Signature): category: str = dspy.OutputField() @model_validator(mode="after") - def check_cateogry(self): + def check_category(self): if self.category not in self.allowed_categories: raise ValueError(f"category not in {self.allowed_categories}") return self diff --git a/tests/dsp_LM/predict/test_react.py b/tests/dsp_LM/predict/test_react.py index 607105aba..6c8bbf70e 100644 --- a/tests/dsp_LM/predict/test_react.py +++ b/tests/dsp_LM/predict/test_react.py @@ -37,7 +37,7 @@ # def test_example_search(): -# # Createa a simple dataset which the model will use with the Retrieve tool. +# # Create a simple dataset which the model will use with the Retrieve tool. # lm = DSPDummyLM( # [ # "Initial thoughts", # Thought_1 @@ -49,7 +49,7 @@ # rm = dummy_rm( # [ # "We all know the color of the sky is blue.", -# "Somethng about the sky colors", +# "Something about the sky colors", # "This sentence is completely irellevant to answer the question.", # "Let's add some more sentences to act as summy passages.", # "Let's add some more sentences to act as summy passages.", @@ -79,7 +79,7 @@ # "Action 1: Search[the color of the sky]\n\n" # "Observation 1:\n" # "[1] «We all know the color of the sky is blue.»\n" -# "[2] «Somethng about the sky colors»\n" +# "[2] «Something about the sky colors»\n" # "[3] «This sentence is completely irellevant to answer the question.»\n\n" # "Thought 2: More thoughts\n\n" # "Action 2: finish[blue]" @@ -151,4 +151,4 @@ # "Observation 2: tool 2 output\n\n" # "Thought 3: Even more thoughts\n\n" # "Action 3: finish[baz]" -# ) +# ) \ No newline at end of file diff --git a/tests/dsp_LM/predict/test_retry.py b/tests/dsp_LM/predict/test_retry.py index bd22984d4..89cac67c9 100644 --- a/tests/dsp_LM/predict/test_retry.py +++ b/tests/dsp_LM/predict/test_retry.py @@ -72,7 +72,7 @@ def test_retry_forward_with_typed_predictor(): dspy.settings.configure(lm=lm, trace=[]) class AnswerQuestion(dspy.Signature): - """Answer questions with succint responses.""" + """Answer questions with succinct responses.""" class Input(pydantic.BaseModel): question: str diff --git a/tests/dsp_LM/teleprompt/test_mipro_optimizer.py b/tests/dsp_LM/teleprompt/test_mipro_optimizer.py index 86d8c00d0..c699be2eb 100644 --- a/tests/dsp_LM/teleprompt/test_mipro_optimizer.py +++ b/tests/dsp_LM/teleprompt/test_mipro_optimizer.py @@ -111,7 +111,7 @@ def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs): return [choice["text"] for choice in response["choices"]] def get_convo(self, index): - """get the prompt + anwer from the ith message""" + """get the prompt + answer from the ith message""" return self.history[index]["prompt"] + " " + self.history[index]["response"]["choices"][0]["text"] diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 9674cca19..8623c6f55 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -502,7 +502,7 @@ def f() -> Literal["2", "3"]: assert f() == "2" -def test_literal_missmatch(): +def test_literal_mismatch(): lm = DummyLM([{"f": f"{i}"} for i in range(5, 100)]) dspy.settings.configure(lm=lm) @@ -527,7 +527,7 @@ def f() -> Literal[2, 3]: assert f() == 2 -def test_literal_int_missmatch(): +def test_literal_int_mismatch(): lm = DummyLM([{"f": f"{i}"} for i in range(5, 100)]) dspy.settings.configure(lm=lm) @@ -823,7 +823,7 @@ class MySignature(dspy.Signature): category: str = dspy.OutputField() @model_validator(mode="after") - def check_cateogry(self): + def check_category(self): if self.category not in self.allowed_categories: raise ValueError(f"category not in {self.allowed_categories}") return self diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 1ba36b21b..8435f86a9 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -5,7 +5,7 @@ # def test_example_no_tools(): -# # Createa a simple dataset which the model will use with the Retrieve tool. +# # Create a simple dataset which the model will use with the Retrieve tool. # lm = DummyLM( # [ # {"Thought_1": "Initial thoughts", "Action_1": "Finish[blue]"}, @@ -25,7 +25,7 @@ # def test_example_search(): -# # Createa a simple dataset which the model will use with the Retrieve tool. +# # Create a simple dataset which the model will use with the Retrieve tool. # lm = DummyLM( # [ # {"Thought_1": "Initial thoughts", "Action_1": "Search[the color of the sky]"}, @@ -35,7 +35,7 @@ # rm = dummy_rm( # [ # "We all know the color of the sky is blue.", -# "Somethng about the sky colors", +# "Something about the sky colors", # "This sentence is completely irellevant to answer the question.", # "Let's add some more sentences to act as summy passages.", # "Let's add some more sentences to act as summy passages.", @@ -121,4 +121,4 @@ # react = dspy.ReAct(ExampleSignature) # assert react.react[0].signature.instructions is not None -# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.") +# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.") \ No newline at end of file diff --git a/tests/predict/test_retry.py b/tests/predict/test_retry.py index 687a18dbf..4289ab75e 100644 --- a/tests/predict/test_retry.py +++ b/tests/predict/test_retry.py @@ -59,7 +59,7 @@ def test_retry_forward_with_typed_predictor(): dspy.settings.configure(lm=lm, trace=[]) class AnswerQuestion(dspy.Signature): - """Answer questions with succint responses.""" + """Answer questions with succinct responses.""" class Input(pydantic.BaseModel): question: str diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py index ea6633682..546fc5971 100644 --- a/tests/primitives/test_program.py +++ b/tests/primitives/test_program.py @@ -132,7 +132,7 @@ def test_complex_module_traversal(): "self.sub_module.nested_list[0]", "self.sub_module.nested_list[1][key]", # NOTE: named_sub_modules allows recursive structures "self.sub_module.nested_tuple[0]", - "self.sub_module.nested_tuple[1][0]", # NEW: named_sub_modules allows recursive structures, but named_prameters does not + "self.sub_module.nested_tuple[1][0]", # NEW: named_sub_modules allows recursive structures, but named_parameters does not # "self.sub_module.nested_tuple[1][1]", This should not be included, as it's the same module as the previous one } found_names = {name for name, _ in root.named_sub_modules()} From 00691955679773a93f4af294de0583f6b9bb02b8 Mon Sep 17 00:00:00 2001 From: sushmanth reddy <73489688+sushmanthreddy@users.noreply.github.com> Date: Sun, 3 Nov 2024 10:50:20 +0530 Subject: [PATCH 03/31] add the tests and load method (#1741) --- dspy/predict/predict.py | 30 +++++++++++++++++------ tests/predict/test_predict.py | 45 +++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 7 deletions(-) diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 7188044ed..0f93fd278 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -10,15 +10,12 @@ from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature, signature_to_template from dspy.utils.callback import with_callbacks - - from dspy.adapters.image_utils import Image @lru_cache(maxsize=None) def warn_once(msg: str): logging.warning(msg) - class Predict(Module, Parameter): def __init__(self, signature, _parse_values=True, callbacks=None, **config): self.stage = random.randbytes(8).hex() @@ -71,10 +68,13 @@ def load_state(self, state, use_legacy_loading=False): state (dict): The saved state of a `Predict` object. use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a saved state from a version of DSPy prior to v2.5.3. + Returns: + self: Returns self to allow method chaining """ if use_legacy_loading: self._load_state_legacy(state) - return + return self + if "signature" not in state: # Check if the state is from a version of DSPy prior to v2.5.3. raise ValueError( @@ -102,10 +102,14 @@ def load_state(self, state, use_legacy_loading=False): if "extended_signature" in state: self.extended_signature = self.extended_signature.load_state(state["extended_signature"]) + return self + def _load_state_legacy(self, state): """Legacy state loading for backwards compatibility. This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3. + Returns: + self: Returns self to allow method chaining """ for name, value in state.items(): setattr(self, name, value) @@ -130,6 +134,21 @@ def _load_state_legacy(self, state): *_, last_key = self.extended_signature.fields.keys() self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix) + return self + + def load(self, path, return_self=False): + """Load a saved state from a file. + + Args: + path (str): Path to the saved state file + return_self (bool): If True, returns self to allow method chaining. Default is False for backwards compatibility. + + Returns: + Union[None, Predict]: Returns None if return_self is False (default), returns self if return_self is True + """ + super().load(path) + return self if return_self else None + @with_callbacks def __call__(self, **kwargs): return self.forward(**kwargs) @@ -213,8 +232,6 @@ def old_generate(demos, signature, kwargs, config, lm, stage): with dsp.settings.context(lm=lm, query_only=True): x, C = dsp.generate(template, **config)(x, stage=stage) - # assert stage in x, "The generated (input, output) example was not stored" - completions = [] for c in C: @@ -279,7 +296,6 @@ def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True): lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values ) - # TODO: get some defaults during init from the context window? # # TODO: FIXME: Hmm, I guess expected behavior is that contexts can # affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates. diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py index 8d0121eed..3fb4fce4e 100644 --- a/tests/predict/test_predict.py +++ b/tests/predict/test_predict.py @@ -218,3 +218,48 @@ class OutputOnlySignature(dspy.Signature): lm = DummyLM([{"output": "short answer"}]) dspy.settings.configure(lm=lm) assert predictor().output == "short answer" + + + +def test_chainable_load(tmp_path): + """Test both traditional and chainable load methods.""" + + file_path = tmp_path / "test_chainable.json" + + + original = Predict("question -> answer") + original.demos = [{"question": "test", "answer": "response"}] + original.save(file_path) + + + traditional = Predict("question -> answer") + traditional.load(file_path) + assert traditional.demos == original.demos + + + chainable = Predict("question -> answer").load(file_path, return_self=True) + assert chainable is not None + assert chainable.demos == original.demos + + + assert chainable.signature.dump_state() == original.signature.dump_state() + + + result = Predict("question -> answer").load(file_path) + assert result is None + +def test_load_state_chaining(): + """Test that load_state returns self for chaining.""" + original = Predict("question -> answer") + original.demos = [{"question": "test", "answer": "response"}] + state = original.dump_state() + + + new_instance = Predict("question -> answer").load_state(state) + assert new_instance is not None + assert new_instance.demos == original.demos + + + legacy_instance = Predict("question -> answer").load_state(state, use_legacy_loading=True) + assert legacy_instance is not None + assert legacy_instance.demos == original.demos \ No newline at end of file From 2d3ed8d20d0ea3e404906101cd7955ddc20449fe Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Sun, 3 Nov 2024 09:17:41 -0800 Subject: [PATCH 04/31] fix flaky test --- tests/modules/test_hf_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/modules/test_hf_model.py b/tests/modules/test_hf_model.py index 0b06c8042..7a3287b80 100644 --- a/tests/modules/test_hf_model.py +++ b/tests/modules/test_hf_model.py @@ -23,6 +23,8 @@ def test_load_gated_model(mocker: MockerFixture): def test_load_ungated_model(mocker: MockerFixture): conf = MockConfig(architectures=["ConditionalGeneration"]) + # Mock the environment to ensure no default token is used + mocker.patch.dict('os.environ', {}, clear=True) # Clear environment variables mocker.patch("transformers.AutoModelForSeq2SeqLM.from_pretrained") mocker.patch("transformers.AutoConfig.from_pretrained", return_value=conf) mocker.patch("transformers.AutoTokenizer.from_pretrained") From 917065856791db14e07d2c026961f6eefd09fdb0 Mon Sep 17 00:00:00 2001 From: Isaac Miller <17116851+isaacbmiller@users.noreply.github.com> Date: Sun, 3 Nov 2024 13:50:21 -0800 Subject: [PATCH 05/31] Fix broken inspect_history and broken prompt cache (#1744) * Fix broken inspect_history and broken prompt cache * Remove errant print statement * Move global inspect history into base_lm * Remove skip parameter * Delete examples/temp.py * Minor adjustment to make adapters go back to original behavior --------- Co-authored-by: Omar Khattab --- dspy/__init__.py | 9 +-- dspy/adapters/chat_adapter.py | 3 +- dspy/clients/__init__.py | 2 +- dspy/clients/base_lm.py | 26 ++++++-- dspy/clients/lm.py | 4 +- dspy/utils/dummies.py | 1 + tests/clients/test_inspect_global_history.py | 68 ++++++++++++++++++++ 7 files changed, 94 insertions(+), 19 deletions(-) create mode 100644 tests/clients/test_inspect_global_history.py diff --git a/dspy/__init__.py b/dspy/__init__.py index 84d360065..f80c8237f 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -12,7 +12,6 @@ from dspy.clients import * # isort: skip from dspy.adapters import * # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging - settings = dsp.settings configure_dspy_loggers(__name__) @@ -70,10 +69,4 @@ BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch COPRO = dspy.teleprompt.COPRO MIPROv2 = dspy.teleprompt.MIPROv2 -Ensemble = dspy.teleprompt.Ensemble - - -# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program. -def inspect_history(*args, **kwargs): - from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history - return _inspect_history(GLOBAL_HISTORY, *args, **kwargs) \ No newline at end of file +Ensemble = dspy.teleprompt.Ensemble \ No newline at end of file diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 27ef87ecd..6727cd245 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -209,7 +209,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= else: output[-1]["text"] += formatted_field_value["text"] if assume_text: - return "\n\n".join(output) + return "\n\n".join(output).strip() else: return output @@ -396,7 +396,6 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): parts.append(format_signature_fields_for_instructions(signature.input_fields)) parts.append(format_signature_fields_for_instructions(signature.output_fields)) parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True)) - instructions = textwrap.dedent(signature.instructions) objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) parts.append(f"In adhering to this structure, your objective is: {objective}") diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 6a63509f5..0056db046 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,2 +1,2 @@ from .lm import LM -from .base_lm import BaseLM \ No newline at end of file +from .base_lm import BaseLM, inspect_history diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index b6f13d0ca..d71d384b6 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +GLOBAL_HISTORY = [] class BaseLM(ABC): def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs): @@ -14,7 +15,10 @@ def __call__(self, prompt=None, messages=None, **kwargs): pass def inspect_history(self, n: int = 1): - _inspect_history(self, n) + _inspect_history(self.history, n) + + def update_global_history(self, entry): + GLOBAL_HISTORY.append(entry) def _green(text: str, end: str = "\n"): @@ -24,15 +28,21 @@ def _green(text: str, end: str = "\n"): def _red(text: str, end: str = "\n"): return "\x1b[31m" + str(text) + "\x1b[0m" + end +def _blue(text: str, end: str = "\n"): + return "\x1b[34m" + str(text) + "\x1b[0m" + end + -def _inspect_history(lm, n: int = 1): +def _inspect_history(history, n: int = 1): """Prints the last n prompts and their completions.""" - for item in reversed(lm.history[-n:]): + for item in history[-n:]: messages = item["messages"] or [{"role": "user", "content": item["prompt"]}] outputs = item["outputs"] + timestamp = item.get("timestamp", "Unknown time") print("\n\n\n") + print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") + for msg in messages: print(_red(f"{msg['role'].capitalize()} message:")) if isinstance(msg["content"], str): @@ -43,11 +53,13 @@ def _inspect_history(lm, n: int = 1): if c["type"] == "text": print(c["text"].strip()) elif c["type"] == "image_url": + image_str = "" if "base64" in c["image_url"].get("url", ""): len_base64 = len(c["image_url"]["url"].split("base64,")[1]) - print(f"<{c['image_url']['url'].split('base64,')[0]}base64,") + image_str = f"<{c['image_url']['url'].split('base64,')[0]}base64," else: - print(f"") + image_str = f"" + print(_blue(image_str.strip())) print("\n") print(_red("Response:")) @@ -58,3 +70,7 @@ def _inspect_history(lm, n: int = 1): print(_red(choices_text, end="")) print("\n\n\n") + +def inspect_history(n: int = 1): + """The global history shared across all LMs.""" + return _inspect_history(GLOBAL_HISTORY, n) \ No newline at end of file diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 22e37019f..c8eba2d37 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -23,8 +23,6 @@ if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -GLOBAL_HISTORY = [] - logger = logging.getLogger(__name__) class LM(BaseLM): @@ -109,7 +107,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): model_type=self.model_type, ) self.history.append(entry) - GLOBAL_HISTORY.append(entry) + self.update_global_history(entry) return outputs diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index 39ea37c61..cadcc3371 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -205,6 +205,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]): entry = dict(**entry, outputs=outputs, usage=0) entry = dict(**entry, cost=0) self.history.append(entry) + self.update_global_history(entry) return outputs diff --git a/tests/clients/test_inspect_global_history.py b/tests/clients/test_inspect_global_history.py new file mode 100644 index 000000000..f3bcf210d --- /dev/null +++ b/tests/clients/test_inspect_global_history.py @@ -0,0 +1,68 @@ +import pytest +from dspy.utils.dummies import DummyLM +from dspy.clients.base_lm import GLOBAL_HISTORY +import dspy + +@pytest.fixture(autouse=True) +def clear_history(): + GLOBAL_HISTORY.clear() + yield + +def test_inspect_history_basic(capsys): + # Configure a DummyLM with some predefined responses + lm = DummyLM([{"response": "Hello"}, {"response": "How are you?"}]) + dspy.settings.configure(lm=lm) + + # Make some calls to generate history + predictor = dspy.Predict("query: str -> response: str") + predictor(query="Hi") + predictor(query="What's up?") + + # Test inspecting all history + history = GLOBAL_HISTORY + print(capsys) + assert len(history) > 0 + assert isinstance(history, list) + assert all(isinstance(entry, dict) for entry in history) + assert all("messages" in entry for entry in history) + +def test_inspect_history_with_n(capsys): + lm = DummyLM([{"response": "One"}, {"response": "Two"}, {"response": "Three"}]) + dspy.settings.configure(lm=lm) + + # Generate some history + predictor = dspy.Predict("query: str -> response: str") + predictor(query="First") + predictor(query="Second") + predictor(query="Third") + + dspy.inspect_history(n=2) + # Test getting last 2 entries + out, err = capsys.readouterr() + assert not "First" in out + assert "Second" in out + assert "Third" in out + +def test_inspect_empty_history(capsys): + # Configure fresh DummyLM + lm = DummyLM([]) + dspy.settings.configure(lm=lm) + + # Test inspecting empty history + dspy.inspect_history() + history = GLOBAL_HISTORY + assert len(history) == 0 + assert isinstance(history, list) + +def test_inspect_history_n_larger_than_history(capsys): + lm = DummyLM([{"response": "First"}, {"response": "Second"}]) + dspy.settings.configure(lm=lm) + + predictor = dspy.Predict("query: str -> response: str") + predictor(query="Query 1") + predictor(query="Query 2") + + # Request more entries than exist + dspy.inspect_history(n=5) + history = GLOBAL_HISTORY + assert len(history) == 2 # Should return all available entries From e4e7e0be375686ecae91de4ffd2711f2990a09df Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Sun, 3 Nov 2024 14:10:17 -0800 Subject: [PATCH 06/31] add note --- testing/tasks/heart_disease.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/tasks/heart_disease.py b/testing/tasks/heart_disease.py index 0885f475c..02a56a316 100644 --- a/testing/tasks/heart_disease.py +++ b/testing/tasks/heart_disease.py @@ -53,7 +53,7 @@ class HeartDiseaseInput(dspy.Signature): trestbps = dspy.InputField( desc="Resting blood pressure (in mm Hg on admission to the hospital)" ) - chol = dspy.InputField(desc="Serum cholesterol in mg/dl") + chol = dspy.InputField(desc="Serum cholesterol in mg/dl") # Nov 2nd, 2024: fixed typo from `cholesteral` fbs = dspy.InputField(desc="Fasting blood sugar > 120 mg/dl (true or false)") restecg = dspy.InputField( desc="Resting electrocardiographic results (normal, ST-T wave abnormality, left ventricular hypertrophy)" From 7e781992e5fa6359f82bdb0cf0c78171e58066b4 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Sun, 3 Nov 2024 14:23:43 -0800 Subject: [PATCH 07/31] Add `dspy.Embedding` (#1735) * Add embedding model * force return type to be numpy array --------- Co-authored-by: Omar Khattab --- dspy/clients/__init__.py | 13 ++++++ dspy/clients/embedding.py | 77 +++++++++++++++++++++++++++++++++ dspy/clients/lm.py | 9 +--- tests/clients/test_embedding.py | 64 +++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 8 deletions(-) create mode 100644 dspy/clients/embedding.py create mode 100644 tests/clients/test_embedding.py diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 0056db046..ef9a8da5f 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,2 +1,15 @@ from .lm import LM from .base_lm import BaseLM, inspect_history +from .embedding import Embedding +import litellm +import os +from pathlib import Path +from litellm.caching import Cache + +DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") +litellm.telemetry = False + +if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: + # accessed at run time by litellm; i.e., fine to keep after import + os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" diff --git a/dspy/clients/embedding.py b/dspy/clients/embedding.py new file mode 100644 index 000000000..eec41c32b --- /dev/null +++ b/dspy/clients/embedding.py @@ -0,0 +1,77 @@ +import litellm +import numpy as np + + +class Embedding: + """DSPy embedding class. + + The class for computing embeddings for text inputs. This class provides a unified interface for both: + + 1. Hosted embedding models (e.g. OpenAI's text-embedding-3-small) via litellm integration + 2. Custom embedding functions that you provide + + For hosted models, simply pass the model name as a string (e.g. "openai/text-embedding-3-small"). The class will use + litellm to handle the API calls and caching. + + For custom embedding models, pass a callable function that: + - Takes a list of strings as input. + - Returns embeddings as either: + - A 2D numpy array of float32 values + - A 2D list of float32 values + - Each row should represent one embedding vector + + Args: + model: The embedding model to use. This can be either a string (representing the name of the hosted embedding + model, must be an embedding model supported by litellm) or a callable that represents a custom embedding + model. + + Examples: + Example 1: Using a hosted model. + + ```python + import dspy + + embedder = dspy.Embedding("openai/text-embedding-3-small") + embeddings = embedder(["hello", "world"]) + + assert embeddings.shape == (2, 1536) + ``` + + Example 2: Using a custom function. + + ```python + import dspy + + def my_embedder(texts): + return np.random.rand(len(texts), 10) + + embedder = dspy.Embedding(my_embedder) + embeddings = embedder(["hello", "world"]) + + assert embeddings.shape == (2, 10) + ``` + """ + + def __init__(self, model): + self.model = model + + def __call__(self, inputs, caching=True, **kwargs): + """Compute embeddings for the given inputs. + + Args: + inputs: The inputs to compute embeddings for, can be a single string or a list of strings. + caching: Whether to cache the embedding response, only valid when using a hosted embedding model. + kwargs: Additional keyword arguments to pass to the embedding model. + + Returns: + A 2-D numpy array of embeddings, one embedding per row. + """ + if isinstance(inputs, str): + inputs = [inputs] + if isinstance(self.model, str): + embedding_response = litellm.embedding(model=self.model, input=inputs, caching=caching, **kwargs) + return np.array([data["embedding"] for data in embedding_response.data], dtype=np.float32) + elif callable(self.model): + return np.array(self.model(inputs, **kwargs), dtype=np.float32) + else: + raise ValueError(f"`model` in `dspy.Embedding` must be a string or a callable, but got {type(self.model)}.") diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index c8eba2d37..567178432 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -5,26 +5,19 @@ import uuid from concurrent.futures import ThreadPoolExecutor from datetime import datetime -from pathlib import Path from typing import Any, Dict, List, Literal, Optional import litellm import ujson -from litellm.caching import Cache from dspy.clients.finetune import FinetuneJob, TrainingMethod from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class from dspy.utils.callback import BaseCallback, with_callbacks -DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") -litellm.telemetry = False - -if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" logger = logging.getLogger(__name__) + class LM(BaseLM): """ A language model supporting chat or text completion requests for use with DSPy modules. diff --git a/tests/clients/test_embedding.py b/tests/clients/test_embedding.py new file mode 100644 index 000000000..d12850e52 --- /dev/null +++ b/tests/clients/test_embedding.py @@ -0,0 +1,64 @@ +import pytest +from unittest.mock import Mock, patch +import numpy as np + +from dspy.clients.embedding import Embedding + + +# Mock response format similar to litellm's embedding response. +class MockEmbeddingResponse: + def __init__(self, embeddings): + self.data = [{"embedding": emb} for emb in embeddings] + self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + self.model = "mock_model" + self.object = "list" + + +def test_litellm_embedding(): + model = "text-embedding-ada-002" + inputs = ["hello", "world"] + mock_embeddings = [ + [0.1, 0.2, 0.3], # embedding for "hello" + [0.4, 0.5, 0.6], # embedding for "world" + ] + + with patch("litellm.embedding") as mock_litellm: + # Configure mock to return proper response format. + mock_litellm.return_value = MockEmbeddingResponse(mock_embeddings) + + # Create embedding instance and call it. + embedding = Embedding(model) + result = embedding(inputs) + + # Verify litellm was called with correct parameters. + mock_litellm.assert_called_once_with(model=model, input=inputs, caching=True) + + assert len(result) == len(inputs) + np.testing.assert_allclose(result, mock_embeddings) + + +def test_callable_embedding(): + inputs = ["hello", "world", "test"] + + expected_embeddings = [ + [0.1, 0.2, 0.3], # embedding for "hello" + [0.4, 0.5, 0.6], # embedding for "world" + [0.7, 0.8, 0.9], # embedding for "test" + ] + + def mock_embedding_fn(texts): + # Simple callable that returns random embeddings. + return expected_embeddings + + # Create embedding instance with callable + embedding = Embedding(mock_embedding_fn) + result = embedding(inputs) + + np.testing.assert_allclose(result, expected_embeddings) + + +def test_invalid_model_type(): + # Test that invalid model type raises ValueError + with pytest.raises(ValueError): + embedding = Embedding(123) # Invalid model type + embedding(["test"]) From 3b688c6ba1f8f269c1a24573490e0b51f33c851b Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> Date: Sun, 3 Nov 2024 14:25:08 -0800 Subject: [PATCH 08/31] update rationale to reasoning for docs for 2.5+ default CoT behavior (#1747) --- docs/docs/building-blocks/2-signatures.md | 6 ++-- docs/docs/building-blocks/3-modules.md | 8 ++--- .../deep-dive/modules/chain-of-thought.md | 29 ++++++++++++++----- docs/docs/deep-dive/modules/guide.md | 13 +++++---- docs/docs/deep-dive/optimizers/copro.md | 2 +- 5 files changed, 37 insertions(+), 21 deletions(-) diff --git a/docs/docs/building-blocks/2-signatures.md b/docs/docs/building-blocks/2-signatures.md index c2355cb6f..bca787e93 100644 --- a/docs/docs/building-blocks/2-signatures.md +++ b/docs/docs/building-blocks/2-signatures.md @@ -78,10 +78,10 @@ The 21-year-old Lee made seven appearances and scored one goal for West Ham last Many DSPy modules (except `dspy.Predict`) return auxiliary information by expanding your signature under the hood. -For example, `dspy.ChainOfThought` also adds a `rationale` field that includes the LM's reasoning before it generates the output `summary`. +For example, `dspy.ChainOfThought` also adds a `reasoning` field that includes the LM's reasoning before it generates the output `summary`. ```python -print("Rationale:", response.rationale) +print("Rationale:", response.reasoning) ``` **Output:** ```text @@ -147,7 +147,7 @@ faithfulness(context=context, text=text) **Output:** ```text Prediction( - rationale="produce the faithfulness. We know that Lee had two loan spells in League One last term, with Blackpool and then Colchester United. He scored twice for the U's but was unable to save them from relegation. However, there is no mention of him scoring three goals for Colchester United.", + reasoning="produce the faithfulness. We know that Lee had two loan spells in League One last term, with Blackpool and then Colchester United. He scored twice for the U's but was unable to save them from relegation. However, there is no mention of him scoring three goals for Colchester United.", faithfulness='False' ) ``` diff --git a/docs/docs/building-blocks/3-modules.md b/docs/docs/building-blocks/3-modules.md index b10428511..c3cc5975f 100644 --- a/docs/docs/building-blocks/3-modules.md +++ b/docs/docs/building-blocks/3-modules.md @@ -67,12 +67,12 @@ response.completions.answer Let's discuss the output object here. -The `dspy.ChainOfThought` module will generally inject a `rationale` before the output field(s) of your signature. +The `dspy.ChainOfThought` module will generally inject a `reasoning` before the output field(s) of your signature. -Let's inspect the (first) rationale and answer! +Let's inspect the (first) reasoning and answer! ```python -print(f"Rationale: {response.rationale}") +print(f"Reasoning: {response.reasoning}") print(f"Answer: {response.answer}") ``` **Output:** @@ -86,7 +86,7 @@ This is accessible whether we request one or many completions. We can also access the different completions as a list of `Prediction`s or as several lists, one for each field. ```python -response.completions[3].rationale == response.completions.rationale[3] +response.completions[3].reasoning == response.completions.reasoning[3] ``` **Output:** ```text diff --git a/docs/docs/deep-dive/modules/chain-of-thought.md b/docs/docs/deep-dive/modules/chain-of-thought.md index 9d7939283..1efb0fe94 100644 --- a/docs/docs/deep-dive/modules/chain-of-thought.md +++ b/docs/docs/deep-dive/modules/chain-of-thought.md @@ -13,15 +13,30 @@ class ChainOfThought(Predict): self.activated = activated - signature = ensure_signature(self.signature) + self.signature = signature = ensure_signature(signature) *_keys, last_key = signature.output_fields.keys() - rationale_type = rationale_type or dspy.OutputField( - prefix="Reasoning: Let's think step by step in order to", - desc="${produce the " + last_key + "}. We ...", - ) - - self.extended_signature = signature.prepend("rationale", rationale_type, type_=str) + prefix = "Reasoning: Let's think step by step in order to" + + if isinstance(dspy.settings.lm, dspy.LM): + desc = "${reasoning}" + elif dspy.settings.experimental: + desc = "${produce the output fields}. We ..." + else: + # For dspy <2.5 + desc = f"${{produce the {last_key}}}. We ..." + + rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) + + # Add "rationale" field to the output signature. + if isinstance(dspy.settings.lm, dspy.LM) or dspy.settings.experimental: + extended_signature = signature.prepend("reasoning", rationale_type, type_=str) + else: + # For dspy <2.5 + extended_signature = signature.prepend("rationale", rationale_type, type_=str) + + self._predict = dspy.Predict(extended_signature, **config) + self._predict.extended_signature = extended_signature ``` **Parameters:** diff --git a/docs/docs/deep-dive/modules/guide.md b/docs/docs/deep-dive/modules/guide.md index f4edd79bd..f0e6beb33 100644 --- a/docs/docs/deep-dive/modules/guide.md +++ b/docs/docs/deep-dive/modules/guide.md @@ -87,18 +87,19 @@ response.completions.answer Let's discuss the output object here. -The `dspy.ChainOfThought` module will generally inject a `rationale` before the output field(s) of your signature. +The `dspy.ChainOfThought` module will generally inject a `reasoning` before the output field(s) of your signature. -Let's inspect the (first) rationale and answer! +Let's inspect the (first) reasoning and answer! ```python -print(f"Rationale: {response.rationale}") +print(f"Reasoning: {response.reasoning}") print(f"Answer: {response.answer}") ``` - Rationale: produce the answer. We can consider the fact that ColBERT has shown to outperform other state-of-the-art retrieval models in terms of efficiency and effectiveness. It uses contextualized embeddings and performs document retrieval in a way that is both accurate and scalable. - Answer: One great thing about the ColBERT retrieval model is its superior efficiency and effectiveness compared to other models. + Reasoning: ColBERT (Contextualized Late Interaction over BERT) is a retrieval model that effectively combines the strengths of dense retrieval and traditional BM25 methods. One of its great features is that it allows for efficient and scalable retrieval by using late interaction techniques, which enables the model to leverage the contextual embeddings generated by BERT while still maintaining a fast retrieval speed. This means that it can handle large document collections more effectively than many other models, providing both high relevance in search results and efficiency in processing time. + Answer: A great feature of the ColBERT retrieval model is its ability to efficiently combine contextualized embeddings from BERT with a late interaction mechanism, allowing for scalable and high-performance document retrieval. + This is accessible whether we request one or many completions. @@ -107,7 +108,7 @@ We can also access the different completions as a list of `Prediction`s or as se ```python -response.completions[3].rationale == response.completions.rationale[3] +response.completions[3].reasoning == response.completions.reasoning[3] ``` ```text diff --git a/docs/docs/deep-dive/optimizers/copro.md b/docs/docs/deep-dive/optimizers/copro.md index 0595c7037..c604a3260 100644 --- a/docs/docs/deep-dive/optimizers/copro.md +++ b/docs/docs/deep-dive/optimizers/copro.md @@ -47,7 +47,7 @@ class CoTPipeline(dspy.Module): result = self.predictor(question=question) return dspy.Prediction( answer=result.answer, - reasoning=result.rationale, + reasoning=result.reasoning, ) ``` From 03748368f6f2481ac68a2f6c7b824e3af5c1fc31 Mon Sep 17 00:00:00 2001 From: Alberto Romero Date: Mon, 4 Nov 2024 14:18:19 +0000 Subject: [PATCH 09/31] Added new resources on prompt optimization leveraging G-Eval metrics (#1750) --- docs/docs/tutorials/other_tutorial.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/docs/tutorials/other_tutorial.md b/docs/docs/tutorials/other_tutorial.md index 9f2b2ac6d..b3faaa29e 100644 --- a/docs/docs/tutorials/other_tutorial.md +++ b/docs/docs/tutorials/other_tutorial.md @@ -25,4 +25,5 @@ sidebar_position: 99999 - Hands-on Overviews of DSPy by the community: [DSPy Explained! by Connor Shorten](https://www.youtube.com/watch?v=41EfOY0Ldkc), [DSPy explained by code_your_own_ai](https://www.youtube.com/watch?v=ycfnKPxBMck), [DSPy Crash Course by AI Bites](https://youtu.be/5-zgASQKkKQ?si=3gnmVouT5_rpk_nu) - Interviews: [Weaviate Podcast in-person](https://www.youtube.com/watch?v=CDung1LnLbY), and you can find 6-7 other remote podcasts on YouTube from a few different perspectives/audiences. - **Tracing in DSPy** with Arize Phoenix: [Tutorial for tracing your prompts and the steps of your DSPy programs](https://colab.research.google.com/github/Arize-ai/phoenix/blob/main/tutorials/tracing/dspy_tracing_tutorial.ipynb) -- **Tracing & Optimization Tracking in DSPy** with Parea AI: [Tutorial on tracing & evaluating a DSPy RAG program](https://docs.parea.ai/tutorials/dspy-rag-trace-evaluate/tutorial) \ No newline at end of file +- **Tracing & Optimization Tracking in DSPy** with Parea AI: [Tutorial on tracing & evaluating a DSPy RAG program](https://docs.parea.ai/tutorials/dspy-rag-trace-evaluate/tutorial) +- **Prompt Optimization with DSPy and G-Eval Metrics** by Alberto Romero: [Medium article](https://medium.com/@a-romero/prompt-optimization-with-dspy-and-g-eval-metrics-e7d0bdd21b8b), [Repo](https://github.com/a-romero/dspy-risk-assessment), [Video](https://youtu.be/kK30U-XiiNI) \ No newline at end of file From ff8355fe01fdb558d3d0603ceca2ded9972591b1 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 4 Nov 2024 06:53:53 -0800 Subject: [PATCH 10/31] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f902bcd21..4b788a515 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ #replace_package_name_marker name="dspy", #replace_package_version_marker - version="2.5.15", + version="2.5.25", description="DSPy", long_description=long_description, long_description_content_type="text/markdown", From bb5105d96d1655b47e8396450d2e2e07c970e4e9 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 4 Nov 2024 06:54:12 -0800 Subject: [PATCH 11/31] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0b5875a6c..9fa15fe21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" #replace_package_name_marker name="dspy" #replace_package_version_marker -version="2.5.15" +version="2.5.25" description = "DSPy" readme = "README.md" authors = [{ name = "Omar Khattab", email = "okhattab@stanford.edu" }] From 6fbbd0312cbef101e01ff66d6fcea3a1ae24dc80 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 4 Nov 2024 07:22:47 -0800 Subject: [PATCH 12/31] Update setup.py --- dspy/.internal_dspyai/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/.internal_dspyai/setup.py b/dspy/.internal_dspyai/setup.py index 23ee045f3..981085483 100644 --- a/dspy/.internal_dspyai/setup.py +++ b/dspy/.internal_dspyai/setup.py @@ -8,7 +8,7 @@ #replace_package_name_marker name="dspy-ai", #replace_package_version_marker - version="2.5.8", + version="2.5.25", description="DSPy", long_description=long_description, long_description_content_type="text/markdown", From 2a788e849fe0e9adfab322871ec7d24f02cc2f98 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Mon, 4 Nov 2024 12:33:28 -0800 Subject: [PATCH 13/31] Remove `structlog` dependency (#1751) * remove structlog dependency * update poetry.lock --- poetry.lock | 19 +------------------ pyproject.toml | 2 -- requirements.txt | 1 - 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index 09912eef8..a8ee6afa8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6504,23 +6504,6 @@ files = [ {file = "striprtf-0.0.26.tar.gz", hash = "sha256:fdb2bba7ac440072d1c41eab50d8d74ae88f60a8b6575c6e2c7805dc462093aa"}, ] -[[package]] -name = "structlog" -version = "24.4.0" -description = "Structured Logging for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "structlog-24.4.0-py3-none-any.whl", hash = "sha256:597f61e80a91cc0749a9fd2a098ed76715a1c8a01f73e336b746504d1aad7610"}, - {file = "structlog-24.4.0.tar.gz", hash = "sha256:b27bfecede327a6d2da5fbc96bd859f114ecc398a6389d664f62085ee7ae6fc4"}, -] - -[package.extras] -dev = ["freezegun (>=0.2.8)", "mypy (>=1.4)", "pretend", "pytest (>=6.0)", "pytest-asyncio (>=0.17)", "rich", "simplejson", "twisted"] -docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-mermaid", "sphinxext-opengraph", "twisted"] -tests = ["freezegun (>=0.2.8)", "pretend", "pytest (>=6.0)", "pytest-asyncio (>=0.17)", "simplejson"] -typing = ["mypy (>=1.4)", "rich", "twisted"] - [[package]] name = "sympy" version = "1.13.1" @@ -7879,4 +7862,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "03caefc8740625d260d1092ca69a1dfdb56d6fb7ae92603613b81d466a70ebf9" +content-hash = "4c0c0eda720efe7fbc74f58ade43fcf01f61ee8295154dd74a1a70d6ddc30280" diff --git a/pyproject.toml b/pyproject.toml index 9fa15fe21..329bccfe6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ dependencies = [ "requests", "optuna", "pydantic~=2.0", - "structlog", "jinja2", "magicattr~=0.1.6", "litellm", @@ -130,7 +129,6 @@ groq = { version = "^0.4.2", optional = true } rich = "^13.7.1" psycopg2 = { version = "^2.9.9", optional = true } pgvector = { version = "^0.2.5", optional = true } -structlog = "^24.1.0" llama-index = {version = "^0.10.30", optional = true} snowflake-snowpark-python = { version = "*",optional=true, python = ">=3.9,<3.12" } jinja2 = "^3.1.3" diff --git a/requirements.txt b/requirements.txt index 1cdb62951..e454fcbb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,6 @@ pandas pydantic~=2.0 regex requests -structlog tenacity>=8.2.3 tqdm ujson From cadd6190dc4e02eda8978f8cb1504a8306112556 Mon Sep 17 00:00:00 2001 From: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Date: Tue, 5 Nov 2024 07:29:18 -0800 Subject: [PATCH 14/31] LMs: retry with exponential backoff for a limited set of error codes (#1753) * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * progress Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar --------- Signed-off-by: dbczumar --- dspy/clients/lm.py | 60 +++++++++++++++++++--- tests/clients/test_lm.py | 107 ++++++++++++++++++++++++++++++++++----- 2 files changed, 146 insertions(+), 21 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 567178432..6a55e4cfd 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,5 +1,4 @@ import functools -from .base_lm import BaseLM import logging import os import uuid @@ -7,13 +6,15 @@ from datetime import datetime from typing import Any, Dict, List, Literal, Optional -import litellm import ujson +from litellm import Router +from litellm.router import RetryPolicy from dspy.clients.finetune import FinetuneJob, TrainingMethod from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class from dspy.utils.callback import BaseCallback, with_callbacks +from .base_lm import BaseLM logger = logging.getLogger(__name__) @@ -32,7 +33,7 @@ def __init__( cache: bool = True, launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, - num_retries: int = 3, + num_retries: int = 8, **kwargs, ): """ @@ -174,13 +175,55 @@ def cached_litellm_completion(request, num_retries: int): def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): kwargs = ujson.loads(request) - return litellm.completion( - num_retries=num_retries, + router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries) + return router.completion( cache=cache, **kwargs, ) +@functools.lru_cache(maxsize=None) +def _get_litellm_router(model: str, num_retries: int) -> Router: + """ + Get a LiteLLM router for the given model with the specified number of retries + for transient errors. + + Args: + model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4'). + num_retries: The number of times to retry a request if it fails transiently due to + network error, rate limiting, etc. Requests are retried with exponential + backoff. + Returns: + A LiteLLM router instance that can be used to query the given model. + """ + retry_policy = RetryPolicy( + TimeoutErrorRetries=num_retries, + RateLimitErrorRetries=num_retries, + InternalServerErrorRetries=num_retries, + # We don't retry on errors that are unlikely to be transient + # (e.g. bad request, invalid auth credentials) + BadRequestErrorRetries=0, + AuthenticationErrorRetries=0, + ContentPolicyViolationErrorRetries=0, + ) + + return Router( + # LiteLLM routers must specify a `model_list`, which maps model names passed + # to `completions()` into actual LiteLLM model names. For our purposes, the + # model name is the same as the LiteLLM model name, so we add a single + # entry to the `model_list` that maps the model name to itself + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": model, + }, + } + ], + retry_policy=retry_policy, + ) + + @functools.lru_cache(maxsize=None) def cached_litellm_text_completion(request, num_retries: int): return litellm_text_completion( @@ -197,6 +240,7 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, # TODO: Not all the models are in the format of "provider/model" model = kwargs.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] + text_completion_model_name = f"text-completion-openai/{model}" # Use the API key and base from the kwargs, or from the environment. api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") @@ -205,12 +249,12 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, # Build the prompt from the messages. prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) - return litellm.text_completion( + router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries) + return router.text_completion( cache=cache, - model=f"text-completion-openai/{model}", + model=text_completion_model_name, api_key=api_key, api_base=api_base, prompt=prompt, - num_retries=num_retries, **kwargs, ) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index ef5d85a9c..61e5828ae 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -1,25 +1,106 @@ from unittest import mock -from dspy.clients.lm import LM +from litellm.router import RetryPolicy + +from dspy.clients.lm import LM, _get_litellm_router def test_lm_chat_respects_max_retries(): - lm = LM(model="openai/gpt4o", model_type="chat", max_retries=17) + model_name = "openai/gpt4o" + num_retries = 17 + temperature = 0.5 + max_tokens = 100 + prompt = "Hello, world!" + + lm = LM( + model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens + ) + + MockRouter = mock.MagicMock() + mock_completion = mock.MagicMock() + MockRouter.completion = mock_completion - with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api: - lm(messages=[{"content": "Hello, world!", "role": "user"}]) + with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor: + lm(prompt=prompt) - assert litellm_completion_api.call_count == 1 - assert litellm_completion_api.call_args[1]["max_retries"] == 17 - # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry" + MockRouterConstructor.assert_called_once_with( + model_list=[ + { + "model_name": model_name, + "litellm_params": { + "model": model_name, + }, + } + ], + retry_policy=RetryPolicy( + TimeoutErrorRetries=num_retries, + RateLimitErrorRetries=num_retries, + InternalServerErrorRetries=num_retries, + BadRequestErrorRetries=0, + AuthenticationErrorRetries=0, + ContentPolicyViolationErrorRetries=0, + ), + ) + mock_completion.assert_called_once_with( + model=model_name, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + max_tokens=max_tokens, + cache=mock.ANY, + ) def test_lm_completions_respects_max_retries(): - lm = LM(model="openai/gpt-3.5-turbo", model_type="completions", max_retries=17) + model_name = "openai/gpt-3.5-turbo" + expected_model = "text-completion-" + model_name + num_retries = 17 + temperature = 0.5 + max_tokens = 100 + prompt = "Hello, world!" + api_base = "http://test.com" + api_key = "apikey" + + lm = LM( + model=model_name, + model_type="text", + num_retries=num_retries, + temperature=temperature, + max_tokens=max_tokens, + api_base=api_base, + api_key=api_key, + ) + + MockRouter = mock.MagicMock() + mock_text_completion = mock.MagicMock() + MockRouter.text_completion = mock_text_completion - with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api: - lm(prompt="Hello, world!") + with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor: + lm(prompt=prompt) - assert litellm_completion_api.call_count == 1 - assert litellm_completion_api.call_args[1]["max_retries"] == 17 - # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry" + MockRouterConstructor.assert_called_once_with( + model_list=[ + { + "model_name": expected_model, + "litellm_params": { + "model": expected_model, + }, + } + ], + retry_policy=RetryPolicy( + TimeoutErrorRetries=num_retries, + RateLimitErrorRetries=num_retries, + InternalServerErrorRetries=num_retries, + BadRequestErrorRetries=0, + AuthenticationErrorRetries=0, + ContentPolicyViolationErrorRetries=0, + ), + ) + mock_text_completion.assert_called_once_with( + model=expected_model, + prompt=prompt + "\n\nBEGIN RESPONSE:", + temperature=temperature, + max_tokens=max_tokens, + api_key=api_key, + api_base=api_base, + cache=mock.ANY, + ) From d7d6faed071673dbc3e755fcfbc952018908bd30 Mon Sep 17 00:00:00 2001 From: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:54:36 -0800 Subject: [PATCH 15/31] Fixes for OpenAI / Azure OpenAI compatibility with LiteLLM router (#1760) * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * Fix Signed-off-by: dbczumar --------- Signed-off-by: dbczumar --- dspy/clients/lm.py | 110 +++++++++++++++++++++++++++++---------- tests/clients/test_lm.py | 77 ++++++++++++++++++++------- 2 files changed, 141 insertions(+), 46 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 6a55e4cfd..3880d35f8 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -3,6 +3,7 @@ import os import uuid from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from datetime import datetime from typing import Any, Dict, List, Literal, Optional @@ -164,6 +165,55 @@ def copy(self, **kwargs): return new_instance +@dataclass(frozen=True) +class _ProviderAPIConfig: + """ + API configurations for a provider (e.g. OpenAI, Azure OpenAI) + """ + + api_key: Optional[str] + api_base: Optional[str] + api_version: Optional[str] + # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication. + # For all other providers, this field is empty + azure_ad_token: Optional[str] + + +def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig: + """ + Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the + provider corresponding to the given model. + + Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating + the input dictionary. + """ + provider = _get_provider(model) + api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY") + api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE") + api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION") + if "azure" in provider: + azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN") + else: + azure_ad_token = None + return _ProviderAPIConfig( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + ) + + +def _get_provider(model: str) -> str: + """ + Extract the provider name from the model string of the format "/", + e.g. "openai/gpt-4". + + TODO: Not all the models are in the format of "provider/model" + """ + model = model.split("/", 1) + return model[0] if len(model) > 1 else "openai" + + @functools.lru_cache(maxsize=None) def cached_litellm_completion(request, num_retries: int): return litellm_completion( @@ -175,7 +225,8 @@ def cached_litellm_completion(request, num_retries: int): def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): kwargs = ujson.loads(request) - router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries) + api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs) + router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config) return router.completion( cache=cache, **kwargs, @@ -183,7 +234,7 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s @functools.lru_cache(maxsize=None) -def _get_litellm_router(model: str, num_retries: int) -> Router: +def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router: """ Get a LiteLLM router for the given model with the specified number of retries for transient errors. @@ -193,6 +244,9 @@ def _get_litellm_router(model: str, num_retries: int) -> Router: num_retries: The number of times to retry a request if it fails transiently due to network error, rate limiting, etc. Requests are retried with exponential backoff. + api_config: The API configurations (keys, base URL, etc.) for the provider + (OpenAI, Azure OpenAI, etc.) corresponding to the given model. + Returns: A LiteLLM router instance that can be used to query the given model. """ @@ -207,19 +261,29 @@ def _get_litellm_router(model: str, num_retries: int) -> Router: ContentPolicyViolationErrorRetries=0, ) + # LiteLLM routers must specify a `model_list`, which maps model names passed + # to `completions()` into actual LiteLLM model names. For our purposes, the + # model name is the same as the LiteLLM model name, so we add a single + # entry to the `model_list` that maps the model name to itself + litellm_params = { + "model": model, + } + if api_config.api_key is not None: + litellm_params["api_key"] = api_config.api_key + if api_config.api_base is not None: + litellm_params["api_base"] = api_config.api_base + if api_config.api_version is not None: + litellm_params["api_version"] = api_config.api_version + if api_config.azure_ad_token is not None: + litellm_params["azure_ad_token"] = api_config.azure_ad_token + model_list = [ + { + "model_name": model, + "litellm_params": litellm_params, + } + ] return Router( - # LiteLLM routers must specify a `model_list`, which maps model names passed - # to `completions()` into actual LiteLLM model names. For our purposes, the - # model name is the same as the LiteLLM model name, so we add a single - # entry to the `model_list` that maps the model name to itself - model_list=[ - { - "model_name": model, - "litellm_params": { - "model": model, - }, - } - ], + model_list=model_list, retry_policy=retry_policy, ) @@ -235,26 +299,18 @@ def cached_litellm_text_completion(request, num_retries: int): def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): kwargs = ujson.loads(request) - - # Extract the provider and model from the model string. - # TODO: Not all the models are in the format of "provider/model" - model = kwargs.pop("model").split("/", 1) - provider, model = model[0] if len(model) > 1 else "openai", model[-1] - text_completion_model_name = f"text-completion-openai/{model}" - - # Use the API key and base from the kwargs, or from the environment. - api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") - api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") + model = kwargs.pop("model") + api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs) + model_name = model.split("/", 1)[-1] + text_completion_model_name = f"text-completion-openai/{model_name}" # Build the prompt from the messages. prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) - router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries) + router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config) return router.text_completion( cache=cache, model=text_completion_model_name, - api_key=api_key, - api_base=api_base, prompt=prompt, **kwargs, ) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 61e5828ae..7b3a02481 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -1,20 +1,40 @@ from unittest import mock +import pytest from litellm.router import RetryPolicy from dspy.clients.lm import LM, _get_litellm_router -def test_lm_chat_respects_max_retries(): +@pytest.mark.parametrize("keys_in_env_vars", [True, False]) +def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch): model_name = "openai/gpt4o" num_retries = 17 temperature = 0.5 max_tokens = 100 prompt = "Hello, world!" + api_version = "2024-02-01" + api_key = "apikey" + + lm_kwargs = { + "model": model_name, + "model_type": "chat", + "num_retries": num_retries, + "temperature": temperature, + "max_tokens": max_tokens, + } + if keys_in_env_vars: + api_base = "http://testfromenv.com" + monkeypatch.setenv("OPENAI_API_KEY", api_key) + monkeypatch.setenv("OPENAI_API_BASE", api_base) + monkeypatch.setenv("OPENAI_API_VERSION", api_version) + else: + api_base = "http://test.com" + lm_kwargs["api_key"] = api_key + lm_kwargs["api_base"] = api_base + lm_kwargs["api_version"] = api_version - lm = LM( - model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens - ) + lm = LM(**lm_kwargs) MockRouter = mock.MagicMock() mock_completion = mock.MagicMock() @@ -29,6 +49,9 @@ def test_lm_chat_respects_max_retries(): "model_name": model_name, "litellm_params": { "model": model_name, + "api_key": api_key, + "api_base": api_base, + "api_version": api_version, }, } ], @@ -50,25 +73,39 @@ def test_lm_chat_respects_max_retries(): ) -def test_lm_completions_respects_max_retries(): - model_name = "openai/gpt-3.5-turbo" - expected_model = "text-completion-" + model_name +@pytest.mark.parametrize("keys_in_env_vars", [True, False]) +def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch): + model_name = "azure/gpt-3.5-turbo" + expected_model = "text-completion-openai/" + model_name.split("/")[-1] num_retries = 17 temperature = 0.5 max_tokens = 100 prompt = "Hello, world!" - api_base = "http://test.com" + api_version = "2024-02-01" api_key = "apikey" + azure_ad_token = "adtoken" + + lm_kwargs = { + "model": model_name, + "model_type": "text", + "num_retries": num_retries, + "temperature": temperature, + "max_tokens": max_tokens, + } + if keys_in_env_vars: + api_base = "http://testfromenv.com" + monkeypatch.setenv("AZURE_API_KEY", api_key) + monkeypatch.setenv("AZURE_API_BASE", api_base) + monkeypatch.setenv("AZURE_API_VERSION", api_version) + monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token) + else: + api_base = "http://test.com" + lm_kwargs["api_key"] = api_key + lm_kwargs["api_base"] = api_base + lm_kwargs["api_version"] = api_version + lm_kwargs["azure_ad_token"] = azure_ad_token - lm = LM( - model=model_name, - model_type="text", - num_retries=num_retries, - temperature=temperature, - max_tokens=max_tokens, - api_base=api_base, - api_key=api_key, - ) + lm = LM(**lm_kwargs) MockRouter = mock.MagicMock() mock_text_completion = mock.MagicMock() @@ -83,6 +120,10 @@ def test_lm_completions_respects_max_retries(): "model_name": expected_model, "litellm_params": { "model": expected_model, + "api_key": api_key, + "api_base": api_base, + "api_version": api_version, + "azure_ad_token": azure_ad_token, }, } ], @@ -100,7 +141,5 @@ def test_lm_completions_respects_max_retries(): prompt=prompt + "\n\nBEGIN RESPONSE:", temperature=temperature, max_tokens=max_tokens, - api_key=api_key, - api_base=api_base, cache=mock.ANY, ) From 8c0be198496d46dcbab8f7c2cdd379079a557611 Mon Sep 17 00:00:00 2001 From: Hanna Moazam Date: Wed, 6 Nov 2024 02:37:19 +0000 Subject: [PATCH 16/31] Fixed by exluding attestations which were automatically enabled in an update to the release/v1 action (#1761) --- .github/workflows/build_and_release.yml | 10 +++++++--- dspy/.internal_dspyai/setup.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_and_release.yml b/.github/workflows/build_and_release.yml index da234b787..bc3e8b284 100644 --- a/.github/workflows/build_and_release.yml +++ b/.github/workflows/build_and_release.yml @@ -84,15 +84,19 @@ jobs: run: python3 setup.py sdist bdist_wheel - name: Publish distribution 📦 to PyPI (dspy) uses: pypa/gh-action-pypi-publish@release/v1 # This requires a trusted publisher to be setup in pypi - # Publish to dspy-ai + with: + attestations: false + # Publish to dspy-ai - name: Update version in setup.py (dspy-ai) run: sed -i '/#replace_package_version_marker/{n;s/version="[^"]*"/version="${{ needs.extract-tag.outputs.version }}"/;}' ./dspy/.internal_dspyai/setup.py - name: Update package name in setup.py run: sed -i '/#replace_package_name_marker/{n;s/name="[^"]*"/name="dspy-ai"/;}' ./dspy/.internal_dspyai/setup.py - name: Update dspy dependency version in setup.py run: | - sed -i '/#replace_dspy_version_marker/{n;s/dspy==[^"]*/dspy==${{ needs.extract-tag.outputs.version }}/;}' ./dspy/.internal_dspyai/setup.py + sed -i '/#replace_dspy_version_marker/{n;s/dspy==[^"]*/dspy>=${{ needs.extract-tag.outputs.version }}/;}' ./dspy/.internal_dspyai/setup.py - name: Build a binary wheel run: python3 ./dspy/.internal_dspyai/setup.py sdist bdist_wheel - name: Publish distribution 📦 to PyPI (dspy-ai) - uses: pypa/gh-action-pypi-publish@release/v1 # This requires a trusted publisher to be setup in pypi \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 # This requires a trusted publisher to be setup in pypi + with: + attestations: false \ No newline at end of file diff --git a/dspy/.internal_dspyai/setup.py b/dspy/.internal_dspyai/setup.py index 981085483..0996fbb1f 100644 --- a/dspy/.internal_dspyai/setup.py +++ b/dspy/.internal_dspyai/setup.py @@ -19,5 +19,5 @@ packages=find_packages(include=["dsp.*", "dspy.*", "dsp", "dspy"]), python_requires=">=3.9", #replace_dspy_version_marker - install_requires=["dspy==2.5.3"] + install_requires=["dspy>=2.5.3"] ) From 1cc9b003b89fda1ad0d4acc1e680e2d764e15e2b Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 5 Nov 2024 18:54:04 -0800 Subject: [PATCH 17/31] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e454fcbb9..e8078b811 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diskcache httpx joblib~=1.3 json-repair -litellm<=1.49.1 +litellm==1.51.0 magicattr~=0.1.6 openai optuna From d86ff73f689b8bc98c06a4d1ed04d3ed84ccd529 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 5 Nov 2024 18:54:22 -0800 Subject: [PATCH 18/31] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4b788a515..358382c1a 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ #replace_package_name_marker name="dspy", #replace_package_version_marker - version="2.5.25", + version="2.5.27", description="DSPy", long_description=long_description, long_description_content_type="text/markdown", From 189cb7956584c871776a6ca18781794a566e5903 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 5 Nov 2024 18:54:49 -0800 Subject: [PATCH 19/31] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 329bccfe6..da6d11ea4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" #replace_package_name_marker name="dspy" #replace_package_version_marker -version="2.5.25" +version="2.5.27" description = "DSPy" readme = "README.md" authors = [{ name = "Omar Khattab", email = "okhattab@stanford.edu" }] @@ -133,7 +133,7 @@ llama-index = {version = "^0.10.30", optional = true} snowflake-snowpark-python = { version = "*",optional=true, python = ">=3.9,<3.12" } jinja2 = "^3.1.3" magicattr = "^0.1.6" -litellm = "1.49.1" +litellm = "1.51.0" diskcache = "^5.6.0" json-repair = "^0.30.0" tenacity = ">=8.2.3" From e7d2161cf4fef8ec728fc975ba30c6457c13d7cc Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 5 Nov 2024 18:57:31 -0800 Subject: [PATCH 20/31] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e8078b811..70c3b992e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diskcache httpx joblib~=1.3 json-repair -litellm==1.51.0 +litellm>=1.50.0 magicattr~=0.1.6 openai optuna From 1864a27ce2ebaee4048f963d658fd4b3fc17cf09 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 5 Nov 2024 18:58:11 -0800 Subject: [PATCH 21/31] Update requirements.txt --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 70c3b992e..e8078b811 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diskcache httpx joblib~=1.3 json-repair -litellm>=1.50.0 +litellm==1.51.0 magicattr~=0.1.6 openai optuna From e654062d3a3ac40b738a046d16737ba533c809ee Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Tue, 5 Nov 2024 19:25:50 -0800 Subject: [PATCH 22/31] Revert LiteLLM Router-based retries and upgrade poetry lock for litellm 1.51.0 (#1762) * Revert LiteLLM Router-based retries and upgrade poetry lock for litellm 1.51.0 * Temporarily remove retry tests * fix test --- dspy/clients/lm.py | 138 +++---------------- poetry.lock | 12 +- tests/clients/test_lm.py | 284 +++++++++++++++++++-------------------- 3 files changed, 167 insertions(+), 267 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 3880d35f8..567178432 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,21 +1,19 @@ import functools +from .base_lm import BaseLM import logging import os import uuid from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass from datetime import datetime from typing import Any, Dict, List, Literal, Optional +import litellm import ujson -from litellm import Router -from litellm.router import RetryPolicy from dspy.clients.finetune import FinetuneJob, TrainingMethod from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class from dspy.utils.callback import BaseCallback, with_callbacks -from .base_lm import BaseLM logger = logging.getLogger(__name__) @@ -34,7 +32,7 @@ def __init__( cache: bool = True, launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, - num_retries: int = 8, + num_retries: int = 3, **kwargs, ): """ @@ -165,55 +163,6 @@ def copy(self, **kwargs): return new_instance -@dataclass(frozen=True) -class _ProviderAPIConfig: - """ - API configurations for a provider (e.g. OpenAI, Azure OpenAI) - """ - - api_key: Optional[str] - api_base: Optional[str] - api_version: Optional[str] - # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication. - # For all other providers, this field is empty - azure_ad_token: Optional[str] - - -def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig: - """ - Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the - provider corresponding to the given model. - - Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating - the input dictionary. - """ - provider = _get_provider(model) - api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY") - api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE") - api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION") - if "azure" in provider: - azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN") - else: - azure_ad_token = None - return _ProviderAPIConfig( - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - ) - - -def _get_provider(model: str) -> str: - """ - Extract the provider name from the model string of the format "/", - e.g. "openai/gpt-4". - - TODO: Not all the models are in the format of "provider/model" - """ - model = model.split("/", 1) - return model[0] if len(model) > 1 else "openai" - - @functools.lru_cache(maxsize=None) def cached_litellm_completion(request, num_retries: int): return litellm_completion( @@ -225,69 +174,13 @@ def cached_litellm_completion(request, num_retries: int): def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): kwargs = ujson.loads(request) - api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs) - router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config) - return router.completion( + return litellm.completion( + num_retries=num_retries, cache=cache, **kwargs, ) -@functools.lru_cache(maxsize=None) -def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router: - """ - Get a LiteLLM router for the given model with the specified number of retries - for transient errors. - - Args: - model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4'). - num_retries: The number of times to retry a request if it fails transiently due to - network error, rate limiting, etc. Requests are retried with exponential - backoff. - api_config: The API configurations (keys, base URL, etc.) for the provider - (OpenAI, Azure OpenAI, etc.) corresponding to the given model. - - Returns: - A LiteLLM router instance that can be used to query the given model. - """ - retry_policy = RetryPolicy( - TimeoutErrorRetries=num_retries, - RateLimitErrorRetries=num_retries, - InternalServerErrorRetries=num_retries, - # We don't retry on errors that are unlikely to be transient - # (e.g. bad request, invalid auth credentials) - BadRequestErrorRetries=0, - AuthenticationErrorRetries=0, - ContentPolicyViolationErrorRetries=0, - ) - - # LiteLLM routers must specify a `model_list`, which maps model names passed - # to `completions()` into actual LiteLLM model names. For our purposes, the - # model name is the same as the LiteLLM model name, so we add a single - # entry to the `model_list` that maps the model name to itself - litellm_params = { - "model": model, - } - if api_config.api_key is not None: - litellm_params["api_key"] = api_config.api_key - if api_config.api_base is not None: - litellm_params["api_base"] = api_config.api_base - if api_config.api_version is not None: - litellm_params["api_version"] = api_config.api_version - if api_config.azure_ad_token is not None: - litellm_params["azure_ad_token"] = api_config.azure_ad_token - model_list = [ - { - "model_name": model, - "litellm_params": litellm_params, - } - ] - return Router( - model_list=model_list, - retry_policy=retry_policy, - ) - - @functools.lru_cache(maxsize=None) def cached_litellm_text_completion(request, num_retries: int): return litellm_text_completion( @@ -299,18 +192,25 @@ def cached_litellm_text_completion(request, num_retries: int): def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): kwargs = ujson.loads(request) - model = kwargs.pop("model") - api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs) - model_name = model.split("/", 1)[-1] - text_completion_model_name = f"text-completion-openai/{model_name}" + + # Extract the provider and model from the model string. + # TODO: Not all the models are in the format of "provider/model" + model = kwargs.pop("model").split("/", 1) + provider, model = model[0] if len(model) > 1 else "openai", model[-1] + + # Use the API key and base from the kwargs, or from the environment. + api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") + api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE") # Build the prompt from the messages. prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) - router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config) - return router.text_completion( + return litellm.text_completion( cache=cache, - model=text_completion_model_name, + model=f"text-completion-openai/{model}", + api_key=api_key, + api_base=api_base, prompt=prompt, + num_retries=num_retries, **kwargs, ) diff --git a/poetry.lock b/poetry.lock index a8ee6afa8..54ccaf2df 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -2411,13 +2411,13 @@ tests = ["aiohttp", "boto3", "duckdb", "pandas (>=1.4)", "polars (>=0.19)", "pyt [[package]] name = "litellm" -version = "1.49.1" +version = "1.51.0" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.49.1-py3-none-any.whl", hash = "sha256:2ba6689fe4ea3b0d69f56f2843caff6422497489e6252943b13ef1463f016728"}, - {file = "litellm-1.49.1.tar.gz", hash = "sha256:f51450ad823c8bdf057017009ae8bcce1a2810690b2f0d9dcdaff04ddc68209a"}, + {file = "litellm-1.51.0-py3-none-any.whl", hash = "sha256:0b2c20d116834166c8440e5698d7d927dbcc78fcaa08ce0c5cbea2d0de55ec6c"}, + {file = "litellm-1.51.0.tar.gz", hash = "sha256:8bf648677ee145a8fe5054a2e3f3a34895b9ab65a6015e4b94efca7ef406f466"}, ] [package.dependencies] @@ -2426,7 +2426,7 @@ click = "*" importlib-metadata = ">=6.8.0" jinja2 = ">=3.1.2,<4.0.0" jsonschema = ">=4.22.0,<5.0.0" -openai = ">=1.51.0" +openai = ">=1.52.0" pydantic = ">=2.0.0,<3.0.0" python-dotenv = ">=0.2.0" requests = ">=2.31.0,<3.0.0" @@ -7862,4 +7862,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "4c0c0eda720efe7fbc74f58ade43fcf01f61ee8295154dd74a1a70d6ddc30280" +content-hash = "92c91613bb51ec6d493672baf6c0d509ebb16ec0c7dd8d23f9cfd6e4654972cf" diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 7b3a02481..da4029e95 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -1,145 +1,145 @@ from unittest import mock import pytest -from litellm.router import RetryPolicy - -from dspy.clients.lm import LM, _get_litellm_router - - -@pytest.mark.parametrize("keys_in_env_vars", [True, False]) -def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch): - model_name = "openai/gpt4o" - num_retries = 17 - temperature = 0.5 - max_tokens = 100 - prompt = "Hello, world!" - api_version = "2024-02-01" - api_key = "apikey" - - lm_kwargs = { - "model": model_name, - "model_type": "chat", - "num_retries": num_retries, - "temperature": temperature, - "max_tokens": max_tokens, - } - if keys_in_env_vars: - api_base = "http://testfromenv.com" - monkeypatch.setenv("OPENAI_API_KEY", api_key) - monkeypatch.setenv("OPENAI_API_BASE", api_base) - monkeypatch.setenv("OPENAI_API_VERSION", api_version) - else: - api_base = "http://test.com" - lm_kwargs["api_key"] = api_key - lm_kwargs["api_base"] = api_base - lm_kwargs["api_version"] = api_version - - lm = LM(**lm_kwargs) - - MockRouter = mock.MagicMock() - mock_completion = mock.MagicMock() - MockRouter.completion = mock_completion - - with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor: - lm(prompt=prompt) - - MockRouterConstructor.assert_called_once_with( - model_list=[ - { - "model_name": model_name, - "litellm_params": { - "model": model_name, - "api_key": api_key, - "api_base": api_base, - "api_version": api_version, - }, - } - ], - retry_policy=RetryPolicy( - TimeoutErrorRetries=num_retries, - RateLimitErrorRetries=num_retries, - InternalServerErrorRetries=num_retries, - BadRequestErrorRetries=0, - AuthenticationErrorRetries=0, - ContentPolicyViolationErrorRetries=0, - ), - ) - mock_completion.assert_called_once_with( - model=model_name, - messages=[{"role": "user", "content": prompt}], - temperature=temperature, - max_tokens=max_tokens, - cache=mock.ANY, - ) - - -@pytest.mark.parametrize("keys_in_env_vars", [True, False]) -def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch): - model_name = "azure/gpt-3.5-turbo" - expected_model = "text-completion-openai/" + model_name.split("/")[-1] - num_retries = 17 - temperature = 0.5 - max_tokens = 100 - prompt = "Hello, world!" - api_version = "2024-02-01" - api_key = "apikey" - azure_ad_token = "adtoken" - - lm_kwargs = { - "model": model_name, - "model_type": "text", - "num_retries": num_retries, - "temperature": temperature, - "max_tokens": max_tokens, - } - if keys_in_env_vars: - api_base = "http://testfromenv.com" - monkeypatch.setenv("AZURE_API_KEY", api_key) - monkeypatch.setenv("AZURE_API_BASE", api_base) - monkeypatch.setenv("AZURE_API_VERSION", api_version) - monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token) - else: - api_base = "http://test.com" - lm_kwargs["api_key"] = api_key - lm_kwargs["api_base"] = api_base - lm_kwargs["api_version"] = api_version - lm_kwargs["azure_ad_token"] = azure_ad_token - - lm = LM(**lm_kwargs) - - MockRouter = mock.MagicMock() - mock_text_completion = mock.MagicMock() - MockRouter.text_completion = mock_text_completion - - with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor: - lm(prompt=prompt) - - MockRouterConstructor.assert_called_once_with( - model_list=[ - { - "model_name": expected_model, - "litellm_params": { - "model": expected_model, - "api_key": api_key, - "api_base": api_base, - "api_version": api_version, - "azure_ad_token": azure_ad_token, - }, - } - ], - retry_policy=RetryPolicy( - TimeoutErrorRetries=num_retries, - RateLimitErrorRetries=num_retries, - InternalServerErrorRetries=num_retries, - BadRequestErrorRetries=0, - AuthenticationErrorRetries=0, - ContentPolicyViolationErrorRetries=0, - ), - ) - mock_text_completion.assert_called_once_with( - model=expected_model, - prompt=prompt + "\n\nBEGIN RESPONSE:", - temperature=temperature, - max_tokens=max_tokens, - cache=mock.ANY, - ) +# from litellm.router import RetryPolicy + +# from dspy.clients.lm import LM, _get_litellm_router + + +# @pytest.mark.parametrize("keys_in_env_vars", [True, False]) +# def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch): +# model_name = "openai/gpt4o" +# num_retries = 17 +# temperature = 0.5 +# max_tokens = 100 +# prompt = "Hello, world!" +# api_version = "2024-02-01" +# api_key = "apikey" + +# lm_kwargs = { +# "model": model_name, +# "model_type": "chat", +# "num_retries": num_retries, +# "temperature": temperature, +# "max_tokens": max_tokens, +# } +# if keys_in_env_vars: +# api_base = "http://testfromenv.com" +# monkeypatch.setenv("OPENAI_API_KEY", api_key) +# monkeypatch.setenv("OPENAI_API_BASE", api_base) +# monkeypatch.setenv("OPENAI_API_VERSION", api_version) +# else: +# api_base = "http://test.com" +# lm_kwargs["api_key"] = api_key +# lm_kwargs["api_base"] = api_base +# lm_kwargs["api_version"] = api_version + +# lm = LM(**lm_kwargs) + +# MockRouter = mock.MagicMock() +# mock_completion = mock.MagicMock() +# MockRouter.completion = mock_completion + +# with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor: +# lm(prompt=prompt) + +# MockRouterConstructor.assert_called_once_with( +# model_list=[ +# { +# "model_name": model_name, +# "litellm_params": { +# "model": model_name, +# "api_key": api_key, +# "api_base": api_base, +# "api_version": api_version, +# }, +# } +# ], +# retry_policy=RetryPolicy( +# TimeoutErrorRetries=num_retries, +# RateLimitErrorRetries=num_retries, +# InternalServerErrorRetries=num_retries, +# BadRequestErrorRetries=0, +# AuthenticationErrorRetries=0, +# ContentPolicyViolationErrorRetries=0, +# ), +# ) +# mock_completion.assert_called_once_with( +# model=model_name, +# messages=[{"role": "user", "content": prompt}], +# temperature=temperature, +# max_tokens=max_tokens, +# cache=mock.ANY, +# ) + + +# @pytest.mark.parametrize("keys_in_env_vars", [True, False]) +# def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch): +# model_name = "azure/gpt-3.5-turbo" +# expected_model = "text-completion-openai/" + model_name.split("/")[-1] +# num_retries = 17 +# temperature = 0.5 +# max_tokens = 100 +# prompt = "Hello, world!" +# api_version = "2024-02-01" +# api_key = "apikey" +# azure_ad_token = "adtoken" + +# lm_kwargs = { +# "model": model_name, +# "model_type": "text", +# "num_retries": num_retries, +# "temperature": temperature, +# "max_tokens": max_tokens, +# } +# if keys_in_env_vars: +# api_base = "http://testfromenv.com" +# monkeypatch.setenv("AZURE_API_KEY", api_key) +# monkeypatch.setenv("AZURE_API_BASE", api_base) +# monkeypatch.setenv("AZURE_API_VERSION", api_version) +# monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token) +# else: +# api_base = "http://test.com" +# lm_kwargs["api_key"] = api_key +# lm_kwargs["api_base"] = api_base +# lm_kwargs["api_version"] = api_version +# lm_kwargs["azure_ad_token"] = azure_ad_token + +# lm = LM(**lm_kwargs) + +# MockRouter = mock.MagicMock() +# mock_text_completion = mock.MagicMock() +# MockRouter.text_completion = mock_text_completion + +# with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor: +# lm(prompt=prompt) + +# MockRouterConstructor.assert_called_once_with( +# model_list=[ +# { +# "model_name": expected_model, +# "litellm_params": { +# "model": expected_model, +# "api_key": api_key, +# "api_base": api_base, +# "api_version": api_version, +# "azure_ad_token": azure_ad_token, +# }, +# } +# ], +# retry_policy=RetryPolicy( +# TimeoutErrorRetries=num_retries, +# RateLimitErrorRetries=num_retries, +# InternalServerErrorRetries=num_retries, +# BadRequestErrorRetries=0, +# AuthenticationErrorRetries=0, +# ContentPolicyViolationErrorRetries=0, +# ), +# ) +# mock_text_completion.assert_called_once_with( +# model=expected_model, +# prompt=prompt + "\n\nBEGIN RESPONSE:", +# temperature=temperature, +# max_tokens=max_tokens, +# cache=mock.ANY, +# ) From 8bc3439052eb80ba4e5ba340c348a6e3b2c94d7c Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:01:20 -0800 Subject: [PATCH 23/31] update broken minimal example links (#1765) --- README.md | 2 +- docs/docs/deep-dive/optimizers/bootstrap-fewshot.md | 2 +- docs/docs/deep-dive/optimizers/miprov2.md | 2 +- docs/docs/tutorials/other_tutorial.md | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d892e1806..477d33b02 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ The DSPy documentation is divided into **tutorials** (step-by-step illustration | **Level** | **Tutorial** | **Run in Colab** | **Description** | | --- | ------------- | ------------- | ------------- | | Beginner | [**Getting Started**](intro.ipynb) | [](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/intro.ipynb) | Introduces the basic building blocks in DSPy. Tackles the task of complex question answering with HotPotQA. | -| Beginner | [**Minimal Working Example**](https://dspy-docs.vercel.app/docs/quick-start/minimal-example) | N/A | Builds and optimizes a very simple chain-of-thought program in DSPy for math question answering. Very short. | +| Beginner | [**Minimal Working Example**](/docs/docs/quick-start/getting-started-01.md) | N/A | Builds a very simple chain-of-thought program in DSPy for question answering. Very short. | | Beginner | [**Compiling for Tricky Tasks**](examples/nli/scone/scone.ipynb) | N/A | Teaches LMs to reason about logical statements and negation. Uses GPT-4 to bootstrap few-shot CoT demonstrations for GPT-3.5. Establishes a state-of-the-art result on [ScoNe](https://arxiv.org/abs/2305.19426). Contributed by [Chris Potts](https://twitter.com/ChrisGPotts/status/1740033519446057077). | | Beginner | [**Local Models & Custom Datasets**](examples/skycamp2023.ipynb) | [](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/examples/skycamp2023.ipynb) | Illustrates two different things together: how to use local models (Llama-2-13B in particular) and how to use your own data examples for training and development. | Intermediate | [**The DSPy Paper**](https://arxiv.org/abs/2310.03714) | N/A | Sections 3, 5, 6, and 7 of the DSPy paper can be consumed as a tutorial. They include explained code snippets, results, and discussions of the abstractions and API. diff --git a/docs/docs/deep-dive/optimizers/bootstrap-fewshot.md b/docs/docs/deep-dive/optimizers/bootstrap-fewshot.md index f7c6fca9d..2f963c717 100644 --- a/docs/docs/deep-dive/optimizers/bootstrap-fewshot.md +++ b/docs/docs/deep-dive/optimizers/bootstrap-fewshot.md @@ -8,7 +8,7 @@ When compiling a DSPy program, we generally invoke an optimizer that takes the p ## Setting up a Sample Pipeline -We will be making a basic answer generation pipeline over GSM8K dataset that we saw in the [Minimal Example](https://dspy-docs.vercel.app/docs/quick-start/minimal-example). We won't be changing anything in it! So let's start by configuring the LM which will be OpenAI LM client with `gpt-3.5-turbo` as the LLM in use. +We will be making a basic answer generation pipeline over the GSM8K dataset. We won't be changing anything in it! So let's start by configuring the LM which will be OpenAI LM client with `gpt-3.5-turbo` as the LLM in use. ```python import dspy diff --git a/docs/docs/deep-dive/optimizers/miprov2.md b/docs/docs/deep-dive/optimizers/miprov2.md index 4d60167cf..f0ec47522 100644 --- a/docs/docs/deep-dive/optimizers/miprov2.md +++ b/docs/docs/deep-dive/optimizers/miprov2.md @@ -12,7 +12,7 @@ sidebar_position: 6 ### Setting up a Sample Pipeline -We'll be making a basic answer generation pipeline over GSM8K dataset that we saw in the [Minimal Example](/quick-start/minimal-example), we won't be changing anything in it! So let's start by configuring the LM which will be OpenAI LM client with `gpt-3.5-turbo` as the LLM in use. +We'll be making a basic answer generation pipeline over the GSM8K dataset. So let's start by configuring the LM which will be OpenAI LM client with `gpt-3.5-turbo` as the LLM in use. ```python import dspy diff --git a/docs/docs/tutorials/other_tutorial.md b/docs/docs/tutorials/other_tutorial.md index b3faaa29e..c7cde9fce 100644 --- a/docs/docs/tutorials/other_tutorial.md +++ b/docs/docs/tutorials/other_tutorial.md @@ -9,7 +9,7 @@ sidebar_position: 99999 | **Level** | **Tutorial** | **Run in Colab** | **Description** | | --- | ------------- | ------------- | ------------- | | Beginner | [**Getting Started**](https://github.com/stanfordnlp/dspy/blob/main/intro.ipynb) | [](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/intro.ipynb) | Introduces the basic building blocks in DSPy. Tackles the task of complex question answering with HotPotQA. | -| Beginner | [**Minimal Working Example**](/quick-start/minimal-example) | N/A | Builds and optimizes a very simple chain-of-thought program in DSPy for math question answering. Very short. | +| Beginner | [**Minimal Working Example**](/docs/docs/quick-start/getting-started-01.md) | N/A | Builds a very simple chain-of-thought program in DSPy for question answering. Very short. | | Beginner | [**Compiling for Tricky Tasks**](https://github.com/stanfordnlp/dspy/blob/main/examples/nli/scone/scone.ipynb) | N/A | Teaches LMs to reason about logical statements and negation. Uses GPT-4 to bootstrap few-shot CoT demonstrations for GPT-3.5. Establishes a state-of-the-art result on [ScoNe](https://arxiv.org/abs/2305.19426). Contributed by [Chris Potts](https://twitter.com/ChrisGPotts/status/1740033519446057077). | | Beginner | [**Local Models & Custom Datasets**](https://github.com/stanfordnlp/dspy/blob/main/skycamp2023.ipynb) | [](https://colab.research.google.com/github/stanfordnlp/dspy/blob/main/skycamp2023.ipynb) | Illustrates two different things together: how to use local models (Llama-2-13B in particular) and how to use your own data examples for training and development. | Intermediate | [**The DSPy Paper**](https://arxiv.org/abs/2310.03714) | N/A | Sections 3, 5, 6, and 7 of the DSPy paper can be consumed as a tutorial. They include explained code snippets, results, and discussions of the abstractions and API. From 89c33f2f308f10a051826cb9ccf7071ecd874335 Mon Sep 17 00:00:00 2001 From: Tari Yekorogha Date: Thu, 7 Nov 2024 00:12:39 +0100 Subject: [PATCH 24/31] Add FalkordbRM Retriever Class for Enhanced Querying and Embedding Retrieval (#1653) * added Falkordb DSPY support * added Falkordbrm support * removed print and changed "neo4j" to falkordb * add falkordb dependencies to the extras * Resolve poetry.lock conflicts --- .../retrieval_models_clients/FalkordbRM.md | 90 ++++++++ dspy/retrieve/falkordb_rm.py | 213 ++++++++++++++++++ poetry.lock | 187 ++++++++------- pyproject.toml | 6 + setup.py | 3 +- 5 files changed, 418 insertions(+), 81 deletions(-) create mode 100644 docs/docs/deep-dive/retrieval_models_clients/FalkordbRM.md create mode 100644 dspy/retrieve/falkordb_rm.py diff --git a/docs/docs/deep-dive/retrieval_models_clients/FalkordbRM.md b/docs/docs/deep-dive/retrieval_models_clients/FalkordbRM.md new file mode 100644 index 000000000..ffb320229 --- /dev/null +++ b/docs/docs/deep-dive/retrieval_models_clients/FalkordbRM.md @@ -0,0 +1,90 @@ +# FalkordbRM + +### Constructor + +Initialize an instance of the `FalkordbRM` class. + +```python +FalkordbRM( + node_label: str, + text_node_property: str, + embedding_node_property: str, + k: int = 5, + retrieval_query: str, + embedding_provider: str = "openai", + embedding_model: str = "text-embedding-ada-002", +) +``` + +**Environment Variables:** + +You need to define the credentials as environment variables + +- `FALKORDB_HOST` (_str_): Specifies the host required for connecting with the Falkordb database. If not provided, the system will default to `localhost` + +- `FALKORDB_PORT` (_int_): Specifies the port required for connecting with the Falkordb database. If not provided, the system will default to `6379` + +- `FALKORDB_USERNAME` (_str_, optional): Specifies the username required for authenticating with a [Falkordb Cloud](https://app.falkordb.cloud/signin) database. + +- `FALKORDB_PASSWORD` (_str_, optional): Specifies the password required for authenticating with a [Falkordb Cloud](https://app.falkordb.cloud/signin) database. + +- `FALKORDB_DATABASE` (_str_, optional): Specifies the name of the database to connect to within the Falkordb instance. If not provided, the systems defaults to using a randomly generated four ascii_letters character string e.g "tari". + +- `OPENAI_API_KEY` (_str_): Specifies the API key required for authenticating with OpenAI's services. + +**Parameters:** + +- `node_label` (_str_): Specifies the label of the node to be used within Falkordb for organizing and querying data. +- `text_node_property` (_str_, _optional_): Defines the specific text property of the node that will be returned. +- `embedding_node_property` (_str_): Defines the specific embedding property of the node that will be used within Falkordb for querying data. +- `k` (_int_, _optional_): The number of top results to return from the retrieval operation. It defaults to 5 if not explicitly specified. +- `retrieval_query` (_str_, _optional_): A custom query string provided for retrieving data. If not provided, a default query tailored to the `text_node_property` will be used. +- `embedding_provider` (_str_, _optional_): The name of the service provider for generating embeddings. Only "openai" is supported. +- `embedding_model` (_str_, _optional_): The specific embedding model to use from the provider. By default, it uses the "text-embedding-ada-002" model from OpenAI. + + +### Methods + +#### `forward(self, query: [str], k: Optional[int] = None) -> dspy.Prediction` + +Search the Falkordb vector index for the top `k` passages matching the given query or queries, using embeddings generated via the specified `embedding_model`. + +**Parameters:** + +- `query` (str\_): The query. +- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization. + +**Returns:** + +- `dspy.Prediction`: Contains the retrieved passages as a list of string with the prediction signature. + +ex: + +```python +Prediction( + passages=['Passage 1 Lorem Ipsum awesom', 'Passage 2 Lorem Ipsum Youppidoo', 'Passage 3 Lorem Ipsum Yassssss'] +) +``` + +### Quick Example how to use Falkordb in a local environment. + +```python +from dspy.retrieve.falkordb_rm import FalkordbRM +import os + + +os.environ["FALKORDB_HOST"] = 'localhost' +os.environ["FALKORDB_PORT"] = 6379 +os.environ["OPENAI_API_KEY"] = 'sk-' + +retriever_model = FalkordbRM( + node_label="myIndex", + text_node_property="text", + embedding_node_property="embedding" +) + +results = retriever_model("Explore the significance of quantum computing", k=3) + +for passage in results: + print("Document:", passage, "\n") +``` \ No newline at end of file diff --git a/dspy/retrieve/falkordb_rm.py b/dspy/retrieve/falkordb_rm.py new file mode 100644 index 000000000..f56144bb3 --- /dev/null +++ b/dspy/retrieve/falkordb_rm.py @@ -0,0 +1,213 @@ +import os +from typing import List, Optional, Union +import string +import random + +import backoff +from openai import ( + APITimeoutError, + InternalServerError, + OpenAI, + RateLimitError, + UnprocessableEntityError, +) + +from dspy import Retrieve, Prediction +from dsp.utils.settings import settings +from dsp.utils import dotdict + +try: + import falkordb +except ImportError: + raise ImportError( + "Please install the falkordb package by running `pip install dspy-ai[falkordb]`" + ) +import redis.exceptions + + +def generate_random_string(length: int) -> str: + characters = string.ascii_letters + random_string = "".join(random.choice(characters) for _ in range(length)) + return random_string + + +class Embedder: + def __init__(self, provider: str, model: str): + self.provider = provider + if self.provider == "openai": + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "Environment variable OPENAI_API_KEY must be set to" + "use openai as embedding provider" + ) + self.client = OpenAI() + self.model = model + else: + raise ValueError(f"Unsupported provider: {provider}") + + @backoff.on_exception( + backoff.expo, + ( + APITimeoutError, + InternalServerError, + RateLimitError, + UnprocessableEntityError, + ), + max_time=settings.backoff_time, + ) + def __call__(self, queries: Union[str, List[str]]) -> List[List[float]]: + if isinstance(queries, str): + queries = [queries] + + if self.provider == "openai": + embedding = self.client.embeddings.create(input=queries, model=self.model) + return [result.embedding for result in embedding.data] + + +DEFAULT_INDEX_QUERY = "CALL db.idx.vector.queryNodes($node_label, $embedding_node_property, $k, vecf32($embedding)) YIELD node, score " + + +class FalkordbRM(Retrieve): + """ + Implements a retriever that utilizes FalkorDB for retrieving passages. + This class manages a connection to a FalkorDB database using official FalkorDB Python drivers and requires + the database credentials. That is, if using a local FalkorDB session, host and port else if using a FalkorDB cloud session, + host, port, username, and password to be set as environment variables and optionally the database name. + Additionally, it utilizes an embedding provider (defaulting to OpenAI's services) to compute query embeddings, + which are then used to find the most relevant nodes in the FalkorDB graph based on the specified node property or custom retrieval query. + + Returns a list of passages in the form of `dspy.Prediction` objects + + Args: + Args: + node_label (str): The label of the node in the FalkorDB database to query against + text_node_property (str): The property of the node containing the text. + embedding_node_property (List[float]): The property of the node containing the embeddings. + k (Optional[int]): The default number of top passages to retrieve. Defaults to 5. + retrieval_query (Optional[str]): Custom Cypher query for retrieving passages. + embedding_provider (str): The provider of the embedding service. Defaults to "openai". + embedding_model (str): The model identifier for generating embeddings. Defaults to "text-embedding-ada-002". + + Examples: + Below is a code snippet showcasing how to initialize FalkordbRM with environment variables for the database connection and OpenAI as the embedding provider: + + ```python + import os + + import dspy + import openai + + os.environ["FALKORDB_HOST"] = "localhost" + os.environ["FALORDB_PORT"] = "6379" + os.environ["OPENAI_API_KEY"] = "sk-" (Only if using openai as embedding's provider) + + # Uncomment and set the following if you are using FalkorDB cloud + # os.environ["FALKORDB_USERNAME"] = "falkordb" + # os.environ["FALKORDB_PASSWORD"] = "password" + + + falkordb_retriever = FalkordbRM( + node_label="myIndex", + text_node_property="text", + k=10, + embedding_provider="openai", + embedding_model="text-embedding-ada-002", + ) + + dspy.settings.configure(rm=falkordb_retriever) + ``` + + In this example, `FalkordbRM` is configured to retrieve nodes based on the "text" property from an index on a node labeled "myIndex", + using embeddings computed by OpenAI's "text-embedding-ada-002" model. + """ + + def __init__( + self, + node_label: str, + text_node_property: str = None, + embedding_node_property: str = None, + k: int = 5, + retrieval_query: Optional[str] = None, + embedding_provider: str = "openai", + embedding_model: str = "text-embedding-ada-002", + ): + super().__init__(k=k) + self.node_label = node_label + self.username = os.getenv("FALKORDB_USERNAME", None) + self.password = os.getenv("FALKORDB_PASSWORD", None) + self.host = os.getenv("FALKORDB_HOST", "localhost") + self.port = int(os.getenv("FALKORDB_PORT", 6379)) + + self.database = os.getenv("FALKORDB_DATABASE", generate_random_string(4)) + self.k = k + self.retrieval_query = retrieval_query + self.text_node_property = text_node_property + self.embedding_node_property = embedding_node_property + if not self.text_node_property and not self.retrieval_query: + raise ValueError( + "Either `text_node_property` or `retrieval_query` must be set" + ) + if not embedding_node_property: + raise ValueError("`embedding_node_property` must be set") + try: + self.driver = falkordb.FalkorDB( + host=self.host, + port=self.port, + username=self.username, + password=self.password, + ).select_graph(self.database) + + except ( + redis.exceptions.ConnectionError, + redis.exceptions.AuthenticationError, + ) as e: + raise ConnectionError("Failed to connect to FalkorDB database") from e + + self.embedder = Embedder(provider=embedding_provider, model=embedding_model) + + def forward( + self, query_or_queries: Union[str, List[str]], k: Optional[int] + ) -> Prediction: + if not isinstance(query_or_queries, list): + query_or_queries = [query_or_queries] + query_vectors = self.embedder(query_or_queries) + contents = [] + retrieval_query = ( + self.retrieval_query + or f"RETURN node.{self.text_node_property} AS text, score" + ) + if not k: + k = self.k + + for vector in query_vectors: + params = { + "embedding": vector, + "node_label": self.node_label, + "text_node_property": self.text_node_property, + "embedding_node_property": self.embedding_node_property, + "k": k, + } + try: + records = self.driver.query( + DEFAULT_INDEX_QUERY + retrieval_query, + params=params, + ).result_set + except Exception as e: + if "Invalid arguments" in str(e): + raise ValueError( + f"There is no vector index on node label, {self.node_label}" + f" and node property, {self.embedding_node_property}" + ) + contents.extend( + [ + {"passage": dotdict({"long_text": r[1]}), "score": r[0]} + for r in records + ] + ) + sorted_passages = sorted( + contents, + key=lambda x: x["score"], + reverse=True, + )[:k] + return [el["passage"] for el in sorted_passages] diff --git a/poetry.lock b/poetry.lock index 54ccaf2df..9b890c757 100644 --- a/poetry.lock +++ b/poetry.lock @@ -180,13 +180,13 @@ files = [ [[package]] name = "anthropic" -version = "0.37.1" +version = "0.38.0" description = "The official Python library for the anthropic API" optional = true python-versions = ">=3.7" files = [ - {file = "anthropic-0.37.1-py3-none-any.whl", hash = "sha256:8f550f88906823752e2abf99fbe491fbc8d40bce4cb26b9663abdf7be990d721"}, - {file = "anthropic-0.37.1.tar.gz", hash = "sha256:99f688265795daa7ba9256ee68eaf2f05d53cd99d7417f4a0c2dc292c106d00a"}, + {file = "anthropic-0.38.0-py3-none-any.whl", hash = "sha256:2c8117b53da7051d8ab65f4e8e05925bd53c53380183115802ace77bde14d4eb"}, + {file = "anthropic-0.38.0.tar.gz", hash = "sha256:417e1bdecc2e3b5a1f122be950d6ac570bba62d1cdefb33efbac3797413ec5f1"}, ] [package.dependencies] @@ -476,17 +476,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.35.53" +version = "1.35.54" description = "The AWS SDK for Python" optional = true python-versions = ">=3.8" files = [ - {file = "boto3-1.35.53-py3-none-any.whl", hash = "sha256:a9c0955df0b52b43749d81bde159343a40ea2a3537a46049336fe8193871b18e"}, - {file = "boto3-1.35.53.tar.gz", hash = "sha256:f4124548bb831e13504e805f2fbbfcee06df42fffea0655862c6eb9b95d6d1be"}, + {file = "boto3-1.35.54-py3-none-any.whl", hash = "sha256:2d5e160b614db55fbee7981001c54476cb827c441cef65b2fcb2c52a62019909"}, + {file = "boto3-1.35.54.tar.gz", hash = "sha256:7d9c359bbbc858a60b51c86328db813353c8bd1940212cdbd0a7da835291c2e1"}, ] [package.dependencies] -botocore = ">=1.35.53,<1.36.0" +botocore = ">=1.35.54,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -495,13 +495,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.53" +version = "1.35.54" description = "Low-level, data-driven core of boto 3." optional = true python-versions = ">=3.8" files = [ - {file = "botocore-1.35.53-py3-none-any.whl", hash = "sha256:12869640f2f9fab3152ea312a6906d5bc6ae15522cc74b6367ee1c273269a28b"}, - {file = "botocore-1.35.53.tar.gz", hash = "sha256:e610ae076ad1eaed5680d3990493659bbabdffd67b15c61d8373a23e4bc41062"}, + {file = "botocore-1.35.54-py3-none-any.whl", hash = "sha256:9cca1811094b6cdc144c2c063a3ec2db6d7c88194b04d4277cd34fc8e3473aff"}, + {file = "botocore-1.35.54.tar.gz", hash = "sha256:131bb59ce59c8a939b31e8e647242d70cf11d32d4529fa4dca01feea1e891a76"}, ] [package.dependencies] @@ -1220,6 +1220,19 @@ files = [ [package.extras] tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"] +[[package]] +name = "falkordb" +version = "1.0.9" +description = "Python client for interacting with FalkorDB database" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "falkordb-1.0.9.tar.gz", hash = "sha256:177008e63c7e4d9ebbdfeb8cad24b0e49175bb0f6e96cac9b4ffb641c0eff0f1"}, +] + +[package.dependencies] +redis = ">=5.0.1,<6.0.0" + [[package]] name = "fastapi" version = "0.115.4" @@ -2917,13 +2930,13 @@ urllib3 = ">=1.26.0,<2.0.0" [[package]] name = "marshmallow" -version = "3.23.0" +version = "3.23.1" description = "A lightweight library for converting complex datatypes to and from native Python datatypes." optional = true python-versions = ">=3.9" files = [ - {file = "marshmallow-3.23.0-py3-none-any.whl", hash = "sha256:82f20a2397834fe6d9611b241f2f7e7b680ed89c49f84728a1ad937be6b4bdf4"}, - {file = "marshmallow-3.23.0.tar.gz", hash = "sha256:98d8827a9f10c03d44ead298d2e99c6aea8197df18ccfad360dae7f89a50da2e"}, + {file = "marshmallow-3.23.1-py3-none-any.whl", hash = "sha256:fece2eb2c941180ea1b7fcbd4a83c51bfdd50093fdd3ad2585ee5e1df2508491"}, + {file = "marshmallow-3.23.1.tar.gz", hash = "sha256:3a8dfda6edd8dcdbf216c0ede1d1e78d230a6dc9c5a088f58c4083b974a0d468"}, ] [package.dependencies] @@ -2931,7 +2944,7 @@ packaging = ">=17.0" [package.extras] dev = ["marshmallow[tests]", "pre-commit (>=3.5,<5.0)", "tox"] -docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.13)", "sphinx (==8.1.3)", "sphinx-issues (==5.0.0)", "sphinx-version-warning (==1.1.2)"] +docs = ["alabaster (==1.0.0)", "autodocsumm (==0.2.14)", "sphinx (==8.1.3)", "sphinx-issues (==5.0.0)", "sphinx-version-warning (==1.1.2)"] tests = ["pytest", "simplejson"] [[package]] @@ -3869,36 +3882,32 @@ reference = ["Pillow", "google-re2"] [[package]] name = "onnxruntime" -version = "1.19.2" +version = "1.20.0" description = "ONNX Runtime is a runtime accelerator for Machine Learning models" optional = true python-versions = "*" files = [ - {file = "onnxruntime-1.19.2-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:84fa57369c06cadd3c2a538ae2a26d76d583e7c34bdecd5769d71ca5c0fc750e"}, - {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdc471a66df0c1cdef774accef69e9f2ca168c851ab5e4f2f3341512c7ef4666"}, - {file = "onnxruntime-1.19.2-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e3a4ce906105d99ebbe817f536d50a91ed8a4d1592553f49b3c23c4be2560ae6"}, - {file = "onnxruntime-1.19.2-cp310-cp310-win32.whl", hash = "sha256:4b3d723cc154c8ddeb9f6d0a8c0d6243774c6b5930847cc83170bfe4678fafb3"}, - {file = "onnxruntime-1.19.2-cp310-cp310-win_amd64.whl", hash = "sha256:17ed7382d2c58d4b7354fb2b301ff30b9bf308a1c7eac9546449cd122d21cae5"}, - {file = "onnxruntime-1.19.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:d863e8acdc7232d705d49e41087e10b274c42f09e259016a46f32c34e06dc4fd"}, - {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c1dfe4f660a71b31caa81fc298a25f9612815215a47b286236e61d540350d7b6"}, - {file = "onnxruntime-1.19.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36511dc07c5c964b916697e42e366fa43c48cdb3d3503578d78cef30417cb84"}, - {file = "onnxruntime-1.19.2-cp311-cp311-win32.whl", hash = "sha256:50cbb8dc69d6befad4746a69760e5b00cc3ff0a59c6c3fb27f8afa20e2cab7e7"}, - {file = "onnxruntime-1.19.2-cp311-cp311-win_amd64.whl", hash = "sha256:1c3e5d415b78337fa0b1b75291e9ea9fb2a4c1f148eb5811e7212fed02cfffa8"}, - {file = "onnxruntime-1.19.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:68e7051bef9cfefcbb858d2d2646536829894d72a4130c24019219442b1dd2ed"}, - {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d2d366fbcc205ce68a8a3bde2185fd15c604d9645888703785b61ef174265168"}, - {file = "onnxruntime-1.19.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:477b93df4db467e9cbf34051662a4b27c18e131fa1836e05974eae0d6e4cf29b"}, - {file = "onnxruntime-1.19.2-cp312-cp312-win32.whl", hash = "sha256:9a174073dc5608fad05f7cf7f320b52e8035e73d80b0a23c80f840e5a97c0147"}, - {file = "onnxruntime-1.19.2-cp312-cp312-win_amd64.whl", hash = "sha256:190103273ea4507638ffc31d66a980594b237874b65379e273125150eb044857"}, - {file = "onnxruntime-1.19.2-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:636bc1d4cc051d40bc52e1f9da87fbb9c57d9d47164695dfb1c41646ea51ea66"}, - {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5bd8b875757ea941cbcfe01582970cc299893d1b65bd56731e326a8333f638a3"}, - {file = "onnxruntime-1.19.2-cp38-cp38-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b2046fc9560f97947bbc1acbe4c6d48585ef0f12742744307d3364b131ac5778"}, - {file = "onnxruntime-1.19.2-cp38-cp38-win32.whl", hash = "sha256:31c12840b1cde4ac1f7d27d540c44e13e34f2345cf3642762d2a3333621abb6a"}, - {file = "onnxruntime-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:016229660adea180e9a32ce218b95f8f84860a200f0f13b50070d7d90e92956c"}, - {file = "onnxruntime-1.19.2-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:006c8d326835c017a9e9f74c9c77ebb570a71174a1e89fe078b29a557d9c3848"}, - {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df2a94179a42d530b936f154615b54748239c2908ee44f0d722cb4df10670f68"}, - {file = "onnxruntime-1.19.2-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fae4b4de45894b9ce7ae418c5484cbf0341db6813effec01bb2216091c52f7fb"}, - {file = "onnxruntime-1.19.2-cp39-cp39-win32.whl", hash = "sha256:dc5430f473e8706fff837ae01323be9dcfddd3ea471c900a91fa7c9b807ec5d3"}, - {file = "onnxruntime-1.19.2-cp39-cp39-win_amd64.whl", hash = "sha256:38475e29a95c5f6c62c2c603d69fc7d4c6ccbf4df602bd567b86ae1138881c49"}, + {file = "onnxruntime-1.20.0-cp310-cp310-macosx_13_0_universal2.whl", hash = "sha256:2ac38bc6cbf7bb8527ded58711af6ef2c8c59d070f0fde58f83824422526922a"}, + {file = "onnxruntime-1.20.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5cfd5a22abc11b273ec76fa773e22db19b749e27bf1ed05dd50d207f1817aae1"}, + {file = "onnxruntime-1.20.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6b5daee2d03909b589f1a9ab24c325cc3c33ab7f736228158784fb1a97a92308"}, + {file = "onnxruntime-1.20.0-cp310-cp310-win32.whl", hash = "sha256:e1eb08c13f91f830eb8df4f4e17a2a2652d1165f50bbed4f28f2afbf425c55d7"}, + {file = "onnxruntime-1.20.0-cp310-cp310-win_amd64.whl", hash = "sha256:cfcc1d21a12076bcc213441b405c48e1f21dedb36943e31eb93cb7a12b34678e"}, + {file = "onnxruntime-1.20.0-cp311-cp311-macosx_13_0_universal2.whl", hash = "sha256:3398354e9145c68edc09dbc72265401150027e76716ae758e8d9b52e6a7ddca0"}, + {file = "onnxruntime-1.20.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8a831b720d0a7be8241a230cb06f592e8bb66652d7cea54ce02d83769651fdee"}, + {file = "onnxruntime-1.20.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:041fefe60af844ebd90f344c84f908201490555cd0a6d78dd0a7acdc27b59972"}, + {file = "onnxruntime-1.20.0-cp311-cp311-win32.whl", hash = "sha256:83da64d2824809d0f6977db8bfc5091f742c26f09dfd66a3934e673780f5f87a"}, + {file = "onnxruntime-1.20.0-cp311-cp311-win_amd64.whl", hash = "sha256:bfa390046332f5fca6f8af8c9d17164621ac52e66b11518e187278b19364800c"}, + {file = "onnxruntime-1.20.0-cp312-cp312-macosx_13_0_universal2.whl", hash = "sha256:97c2b91bfea063f9c3457422d28a336bfd2859001cd880645adfa7184e29dd79"}, + {file = "onnxruntime-1.20.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:51e7b34e398089c4ed8d0f50722d7a64a4d5f11b38c4a42576458a03c6dbc72e"}, + {file = "onnxruntime-1.20.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e259378ff2843321e0bf4552adcbee48822c91d77d42dde78b87dcdf10ad01f"}, + {file = "onnxruntime-1.20.0-cp312-cp312-win32.whl", hash = "sha256:428abc1f7d8eb425887e2b7726044f2af7b5a098359455e7d2d92343f04ad0ff"}, + {file = "onnxruntime-1.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:d5f23cbfeb546e16ffea81c28d2e796a53197fdc6c92540648e2aa53a7c7a637"}, + {file = "onnxruntime-1.20.0-cp313-cp313-macosx_13_0_universal2.whl", hash = "sha256:95b91126bc3e1754868da1d3d2d08a7a10279b8ff5cea5e34e92fbe3fd691dcf"}, + {file = "onnxruntime-1.20.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d57c10d7729347d6663f32b3f569f33d69a95e150d37ff6af4be9b9ab1ffdc25"}, + {file = "onnxruntime-1.20.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b9c38735dac127d0eeb957ec312c8f1ae90ecae2779a55b2fa279aa7bd116cbd"}, + {file = "onnxruntime-1.20.0-cp313-cp313-win_amd64.whl", hash = "sha256:25514cec4ea251d492aa1e38a7395d8801e64a4c940a154aef84cfad97ae4628"}, + {file = "onnxruntime-1.20.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:640ad9ea72d322f0325a51544eddb54f4fa843c4348573c88a9cb44f46678f3f"}, + {file = "onnxruntime-1.20.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dc4e7c10c98c1f407835448c26a7e14ebff3234f131e1fbc53bd9500c828df89"}, ] [package.dependencies] @@ -4873,13 +4882,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pydantic-settings" -version = "2.6.0" +version = "2.6.1" description = "Settings management using Pydantic" optional = true python-versions = ">=3.8" files = [ - {file = "pydantic_settings-2.6.0-py3-none-any.whl", hash = "sha256:4a819166f119b74d7f8c765196b165f95cc7487ce58ea27dec8a5a26be0970e0"}, - {file = "pydantic_settings-2.6.0.tar.gz", hash = "sha256:44a1804abffac9e6a30372bb45f6cafab945ef5af25e66b1c634c01dd39e0188"}, + {file = "pydantic_settings-2.6.1-py3-none-any.whl", hash = "sha256:7fb0637c786a558d3103436278a7c4f1cfd29ba8973238a50c5bb9a55387da87"}, + {file = "pydantic_settings-2.6.1.tar.gz", hash = "sha256:e0f92546d8a9923cb8941689abf85d6601a8c19a23e97a34b2964a2e3f813ca0"}, ] [package.dependencies] @@ -5430,6 +5439,24 @@ files = [ [package.extras] test = ["pytest (>=3.0)", "pytest-asyncio"] +[[package]] +name = "redis" +version = "5.2.0" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.8" +files = [ + {file = "redis-5.2.0-py3-none-any.whl", hash = "sha256:ae174f2bb3b1bf2b09d54bf3e51fbc1469cf6c10aa03e21141f51969801a7897"}, + {file = "redis-5.2.0.tar.gz", hash = "sha256:0b1087665a771b1ff2e003aa5bdd354f15a70c9e25d5a7dbf9c722c16528a7b0"}, +] + +[package.dependencies] +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} + +[package.extras] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] + [[package]] name = "referencing" version = "0.35.1" @@ -5585,13 +5612,13 @@ py = ">=1.4.26,<2.0.0" [[package]] name = "rich" -version = "13.9.3" +version = "13.9.4" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.8.0" files = [ - {file = "rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283"}, - {file = "rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e"}, + {file = "rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90"}, + {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, ] [package.dependencies] @@ -7229,41 +7256,41 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "watchdog" -version = "5.0.3" +version = "6.0.0" description = "Filesystem events monitoring" optional = false python-versions = ">=3.9" files = [ - {file = "watchdog-5.0.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:85527b882f3facda0579bce9d743ff7f10c3e1e0db0a0d0e28170a7d0e5ce2ea"}, - {file = "watchdog-5.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:53adf73dcdc0ef04f7735066b4a57a4cd3e49ef135daae41d77395f0b5b692cb"}, - {file = "watchdog-5.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e25adddab85f674acac303cf1f5835951345a56c5f7f582987d266679979c75b"}, - {file = "watchdog-5.0.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f01f4a3565a387080dc49bdd1fefe4ecc77f894991b88ef927edbfa45eb10818"}, - {file = "watchdog-5.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:91b522adc25614cdeaf91f7897800b82c13b4b8ac68a42ca959f992f6990c490"}, - {file = "watchdog-5.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d52db5beb5e476e6853da2e2d24dbbbed6797b449c8bf7ea118a4ee0d2c9040e"}, - {file = "watchdog-5.0.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:94d11b07c64f63f49876e0ab8042ae034674c8653bfcdaa8c4b32e71cfff87e8"}, - {file = "watchdog-5.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:349c9488e1d85d0a58e8cb14222d2c51cbc801ce11ac3936ab4c3af986536926"}, - {file = "watchdog-5.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:53a3f10b62c2d569e260f96e8d966463dec1a50fa4f1b22aec69e3f91025060e"}, - {file = "watchdog-5.0.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:950f531ec6e03696a2414b6308f5c6ff9dab7821a768c9d5788b1314e9a46ca7"}, - {file = "watchdog-5.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ae6deb336cba5d71476caa029ceb6e88047fc1dc74b62b7c4012639c0b563906"}, - {file = "watchdog-5.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1021223c08ba8d2d38d71ec1704496471ffd7be42cfb26b87cd5059323a389a1"}, - {file = "watchdog-5.0.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:752fb40efc7cc8d88ebc332b8f4bcbe2b5cc7e881bccfeb8e25054c00c994ee3"}, - {file = "watchdog-5.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a2e8f3f955d68471fa37b0e3add18500790d129cc7efe89971b8a4cc6fdeb0b2"}, - {file = "watchdog-5.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b8ca4d854adcf480bdfd80f46fdd6fb49f91dd020ae11c89b3a79e19454ec627"}, - {file = "watchdog-5.0.3-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:90a67d7857adb1d985aca232cc9905dd5bc4803ed85cfcdcfcf707e52049eda7"}, - {file = "watchdog-5.0.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:720ef9d3a4f9ca575a780af283c8fd3a0674b307651c1976714745090da5a9e8"}, - {file = "watchdog-5.0.3-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:223160bb359281bb8e31c8f1068bf71a6b16a8ad3d9524ca6f523ac666bb6a1e"}, - {file = "watchdog-5.0.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:560135542c91eaa74247a2e8430cf83c4342b29e8ad4f520ae14f0c8a19cfb5b"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:dd021efa85970bd4824acacbb922066159d0f9e546389a4743d56919b6758b91"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_armv7l.whl", hash = "sha256:78864cc8f23dbee55be34cc1494632a7ba30263951b5b2e8fc8286b95845f82c"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_i686.whl", hash = "sha256:1e9679245e3ea6498494b3028b90c7b25dbb2abe65c7d07423ecfc2d6218ff7c"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_ppc64.whl", hash = "sha256:9413384f26b5d050b6978e6fcd0c1e7f0539be7a4f1a885061473c5deaa57221"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:294b7a598974b8e2c6123d19ef15de9abcd282b0fbbdbc4d23dfa812959a9e05"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_s390x.whl", hash = "sha256:26dd201857d702bdf9d78c273cafcab5871dd29343748524695cecffa44a8d97"}, - {file = "watchdog-5.0.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:0f9332243355643d567697c3e3fa07330a1d1abf981611654a1f2bf2175612b7"}, - {file = "watchdog-5.0.3-py3-none-win32.whl", hash = "sha256:c66f80ee5b602a9c7ab66e3c9f36026590a0902db3aea414d59a2f55188c1f49"}, - {file = "watchdog-5.0.3-py3-none-win_amd64.whl", hash = "sha256:f00b4cf737f568be9665563347a910f8bdc76f88c2970121c86243c8cfdf90e9"}, - {file = "watchdog-5.0.3-py3-none-win_ia64.whl", hash = "sha256:49f4d36cb315c25ea0d946e018c01bb028048023b9e103d3d3943f58e109dd45"}, - {file = "watchdog-5.0.3.tar.gz", hash = "sha256:108f42a7f0345042a854d4d0ad0834b741d421330d5f575b81cb27b883500176"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d1cdb490583ebd691c012b3d6dae011000fe42edb7a82ece80965b42abd61f26"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bc64ab3bdb6a04d69d4023b29422170b74681784ffb9463ed4870cf2f3e66112"}, + {file = "watchdog-6.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c897ac1b55c5a1461e16dae288d22bb2e412ba9807df8397a635d88f671d36c3"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6eb11feb5a0d452ee41f824e271ca311a09e250441c262ca2fd7ebcf2461a06c"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ef810fbf7b781a5a593894e4f439773830bdecb885e6880d957d5b9382a960d2"}, + {file = "watchdog-6.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:afd0fe1b2270917c5e23c2a65ce50c2a4abb63daafb0d419fde368e272a76b7c"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860"}, + {file = "watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134"}, + {file = "watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e6f0e77c9417e7cd62af82529b10563db3423625c5fce018430b249bf977f9e8"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:90c8e78f3b94014f7aaae121e6b909674df5b46ec24d6bebc45c44c56729af2a"}, + {file = "watchdog-6.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7631a77ffb1f7d2eefa4445ebbee491c720a5661ddf6df3498ebecae5ed375c"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881"}, + {file = "watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7a0e56874cfbc4b9b05c60c8a1926fedf56324bb08cfbc188969777940aef3aa"}, + {file = "watchdog-6.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6439e374fc012255b4ec786ae3c4bc838cd7309a540e5fe0952d03687d8804e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c"}, + {file = "watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2"}, + {file = "watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a"}, + {file = "watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680"}, + {file = "watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f"}, + {file = "watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282"}, ] [package.extras] @@ -7377,13 +7404,13 @@ files = [ [[package]] name = "weaviate-client" -version = "4.9.0" +version = "4.9.2" description = "A python native Weaviate client" optional = true python-versions = ">=3.9" files = [ - {file = "weaviate_client-4.9.0-py3-none-any.whl", hash = "sha256:922a3a83c6946b6ea017d495af5980e90089f97004be4025a3d250a6c40ffaab"}, - {file = "weaviate_client-4.9.0.tar.gz", hash = "sha256:87b2995fd403f6106bd4cc8a9baa77280bdb95617ed6b9a60b0b34b5faeda999"}, + {file = "weaviate_client-4.9.2-py3-none-any.whl", hash = "sha256:7cebbfea29b7aa79354b728bfa8682bdf7183031a2cf749563dbed6a7b7fa6db"}, + {file = "weaviate_client-4.9.2.tar.gz", hash = "sha256:963484383218e0f8bd101f3dd8d2590056a7ca2853e8e0a987fa43ef4e7e4499"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index da6d11ea4..54d22897a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,9 +135,15 @@ jinja2 = "^3.1.3" magicattr = "^0.1.6" litellm = "1.51.0" diskcache = "^5.6.0" + +redis = "^5.1.1" +falkordb = "^1.0.9" + + json-repair = "^0.30.0" tenacity = ">=8.2.3" + [tool.poetry.group.dev.dependencies] pytest = "^6.2.5" transformers = "^4.38.2" diff --git a/setup.py b/setup.py index 358382c1a..7341b8a59 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,8 @@ "fastembed": ["fastembed"], "groq": ["groq~=0.8.0"], "langfuse": ["langfuse~=2.36.1"], - "pgvector": ["psycopg2~=2.9.9","pgvector~=0.2.5"] + "pgvector": ["psycopg2~=2.9.9","pgvector~=0.2.5"], + "falkordb": ["falkordb", "redis", "async-timeout"] }, classifiers=[ "Development Status :: 3 - Alpha", From 863ebededb6836afc7a0045473f8f68978d6ac2e Mon Sep 17 00:00:00 2001 From: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Date: Wed, 6 Nov 2024 16:29:27 -0800 Subject: [PATCH 25/31] Add test coverage for caching against a LiteLLM test server (#1769) * initial Signed-off-by: dbczumar * Impl Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * Add cache file Signed-off-by: dbczumar * proxy Signed-off-by: dbczumar * fix Signed-off-by: dbczumar * fix test Signed-off-by: dbczumar --------- Signed-off-by: dbczumar --- poetry.lock | 420 +++++++++++++++--- pyproject.toml | 1 + tests/caching/example_cache/cache.db | Bin 0 -> 32768 bytes tests/caching/test_caching.py | 90 ++++ tests/clients/test_lm.py | 153 +------ tests/test_utils/__init__.py | 0 tests/test_utils/server/__init__.py | 86 ++++ tests/test_utils/server/litellm_server.py | 47 ++ .../server/litellm_server_config.yaml | 14 + 9 files changed, 608 insertions(+), 203 deletions(-) create mode 100644 tests/caching/example_cache/cache.db create mode 100644 tests/caching/test_caching.py create mode 100644 tests/test_utils/__init__.py create mode 100644 tests/test_utils/server/__init__.py create mode 100644 tests/test_utils/server/litellm_server.py create mode 100644 tests/test_utils/server/litellm_server_config.yaml diff --git a/poetry.lock b/poetry.lock index 9b890c757..88441a414 100644 --- a/poetry.lock +++ b/poetry.lock @@ -236,6 +236,34 @@ files = [ {file = "appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee"}, ] +[[package]] +name = "apscheduler" +version = "3.10.4" +description = "In-process task scheduler with Cron-like capabilities" +optional = false +python-versions = ">=3.6" +files = [ + {file = "APScheduler-3.10.4-py3-none-any.whl", hash = "sha256:fb91e8a768632a4756a585f79ec834e0e27aad5860bac7eaa523d9ccefd87661"}, + {file = "APScheduler-3.10.4.tar.gz", hash = "sha256:e6df071b27d9be898e486bc7940a7be50b4af2e9da7c08f0744a96d4bd4cef4a"}, +] + +[package.dependencies] +pytz = "*" +six = ">=1.4.0" +tzlocal = ">=2.0,<3.dev0 || >=4.dev0" + +[package.extras] +doc = ["sphinx", "sphinx-rtd-theme"] +gevent = ["gevent"] +mongodb = ["pymongo (>=3.0)"] +redis = ["redis (>=3.0)"] +rethinkdb = ["rethinkdb (>=2.4.0)"] +sqlalchemy = ["sqlalchemy (>=1.4)"] +testing = ["pytest", "pytest-asyncio", "pytest-cov", "pytest-tornado5"] +tornado = ["tornado (>=4.3)"] +twisted = ["twisted"] +zookeeper = ["kazoo"] + [[package]] name = "asn1crypto" version = "1.5.1" @@ -898,38 +926,43 @@ test = ["pytest"] [[package]] name = "cryptography" -version = "43.0.3" +version = "42.0.8" description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." -optional = true +optional = false python-versions = ">=3.7" files = [ - {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, - {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"}, - {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e1ce50266f4f70bf41a2c6dc4358afadae90e2a1e5342d3c08883df1675374f"}, - {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:443c4a81bb10daed9a8f334365fe52542771f25aedaf889fd323a853ce7377d6"}, - {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:74f57f24754fe349223792466a709f8e0c093205ff0dca557af51072ff47ab18"}, - {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9762ea51a8fc2a88b70cf2995e5675b38d93bf36bd67d91721c309df184f49bd"}, - {file = "cryptography-43.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:81ef806b1fef6b06dcebad789f988d3b37ccaee225695cf3e07648eee0fc6b73"}, - {file = "cryptography-43.0.3-cp37-abi3-win32.whl", hash = "sha256:cbeb489927bd7af4aa98d4b261af9a5bc025bd87f0e3547e11584be9e9427be2"}, - {file = "cryptography-43.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:f46304d6f0c6ab8e52770addfa2fc41e6629495548862279641972b6215451cd"}, - {file = "cryptography-43.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:8ac43ae87929a5982f5948ceda07001ee5e83227fd69cf55b109144938d96984"}, - {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:846da004a5804145a5f441b8530b4bf35afbf7da70f82409f151695b127213d5"}, - {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f996e7268af62598f2fc1204afa98a3b5712313a55c4c9d434aef49cadc91d4"}, - {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f7b178f11ed3664fd0e995a47ed2b5ff0a12d893e41dd0494f406d1cf555cab7"}, - {file = "cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:c2e6fc39c4ab499049df3bdf567f768a723a5e8464816e8f009f121a5a9f4405"}, - {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:e1be4655c7ef6e1bbe6b5d0403526601323420bcf414598955968c9ef3eb7d16"}, - {file = "cryptography-43.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:df6b6c6d742395dd77a23ea3728ab62f98379eff8fb61be2744d4679ab678f73"}, - {file = "cryptography-43.0.3-cp39-abi3-win32.whl", hash = "sha256:d56e96520b1020449bbace2b78b603442e7e378a9b3bd68de65c782db1507995"}, - {file = "cryptography-43.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:0c580952eef9bf68c4747774cde7ec1d85a6e61de97281f2dba83c7d2c806362"}, - {file = "cryptography-43.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:d03b5621a135bffecad2c73e9f4deb1a0f977b9a8ffe6f8e002bf6c9d07b918c"}, - {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a2a431ee15799d6db9fe80c82b055bae5a752bef645bba795e8e52687c69efe3"}, - {file = "cryptography-43.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:281c945d0e28c92ca5e5930664c1cefd85efe80e5c0d2bc58dd63383fda29f83"}, - {file = "cryptography-43.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:f18c716be16bc1fea8e95def49edf46b82fccaa88587a45f8dc0ff6ab5d8e0a7"}, - {file = "cryptography-43.0.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:4a02ded6cd4f0a5562a8887df8b3bd14e822a90f97ac5e544c162899bc467664"}, - {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:53a583b6637ab4c4e3591a15bc9db855b8d9dee9a669b550f311480acab6eb08"}, - {file = "cryptography-43.0.3-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:1ec0bcf7e17c0c5669d881b1cd38c4972fade441b27bda1051665faaa89bdcaa"}, - {file = "cryptography-43.0.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2ce6fae5bdad59577b44e4dfed356944fbf1d925269114c28be377692643b4ff"}, - {file = "cryptography-43.0.3.tar.gz", hash = "sha256:315b9001266a492a6ff443b61238f956b214dbec9910a081ba5b6646a055a805"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_universal2.whl", hash = "sha256:81d8a521705787afe7a18d5bfb47ea9d9cc068206270aad0b96a725022e18d2e"}, + {file = "cryptography-42.0.8-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:961e61cefdcb06e0c6d7e3a1b22ebe8b996eb2bf50614e89384be54c48c6b63d"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3ec3672626e1b9e55afd0df6d774ff0e953452886e06e0f1eb7eb0c832e8902"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e599b53fd95357d92304510fb7bda8523ed1f79ca98dce2f43c115950aa78801"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5226d5d21ab681f432a9c1cf8b658c0cb02533eece706b155e5fbd8a0cdd3949"}, + {file = "cryptography-42.0.8-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6b7c4f03ce01afd3b76cf69a5455caa9cfa3de8c8f493e0d3ab7d20611c8dae9"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:2346b911eb349ab547076f47f2e035fc8ff2c02380a7cbbf8d87114fa0f1c583"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:ad803773e9df0b92e0a817d22fd8a3675493f690b96130a5e24f1b8fabbea9c7"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2f66d9cd9147ee495a8374a45ca445819f8929a3efcd2e3df6428e46c3cbb10b"}, + {file = "cryptography-42.0.8-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d45b940883a03e19e944456a558b67a41160e367a719833c53de6911cabba2b7"}, + {file = "cryptography-42.0.8-cp37-abi3-win32.whl", hash = "sha256:a0c5b2b0585b6af82d7e385f55a8bc568abff8923af147ee3c07bd8b42cda8b2"}, + {file = "cryptography-42.0.8-cp37-abi3-win_amd64.whl", hash = "sha256:57080dee41209e556a9a4ce60d229244f7a66ef52750f813bfbe18959770cfba"}, + {file = "cryptography-42.0.8-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:dea567d1b0e8bc5764b9443858b673b734100c2871dc93163f58c46a97a83d28"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4783183f7cb757b73b2ae9aed6599b96338eb957233c58ca8f49a49cc32fd5e"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0608251135d0e03111152e41f0cc2392d1e74e35703960d4190b2e0f4ca9c70"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:dc0fdf6787f37b1c6b08e6dfc892d9d068b5bdb671198c72072828b80bd5fe4c"}, + {file = "cryptography-42.0.8-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:9c0c1716c8447ee7dbf08d6db2e5c41c688544c61074b54fc4564196f55c25a7"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:fff12c88a672ab9c9c1cf7b0c80e3ad9e2ebd9d828d955c126be4fd3e5578c9e"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cafb92b2bc622cd1aa6a1dce4b93307792633f4c5fe1f46c6b97cf67073ec961"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:31f721658a29331f895a5a54e7e82075554ccfb8b163a18719d342f5ffe5ecb1"}, + {file = "cryptography-42.0.8-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b297f90c5723d04bcc8265fc2a0f86d4ea2e0f7ab4b6994459548d3a6b992a14"}, + {file = "cryptography-42.0.8-cp39-abi3-win32.whl", hash = "sha256:2f88d197e66c65be5e42cd72e5c18afbfae3f741742070e3019ac8f4ac57262c"}, + {file = "cryptography-42.0.8-cp39-abi3-win_amd64.whl", hash = "sha256:fa76fbb7596cc5839320000cdd5d0955313696d9511debab7ee7278fc8b5c84a"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ba4f0a211697362e89ad822e667d8d340b4d8d55fae72cdd619389fb5912eefe"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:81884c4d096c272f00aeb1f11cf62ccd39763581645b0812e99a91505fa48e0c"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c9bb2ae11bfbab395bdd072985abde58ea9860ed84e59dbc0463a5d0159f5b71"}, + {file = "cryptography-42.0.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7016f837e15b0a1c119d27ecd89b3515f01f90a8615ed5e9427e30d9cdbfed3d"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5a94eccb2a81a309806027e1670a358b99b8fe8bfe9f8d329f27d72c094dde8c"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dec9b018df185f08483f294cae6ccac29e7a6e0678996587363dc352dc65c842"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:343728aac38decfdeecf55ecab3264b015be68fc2816ca800db649607aeee648"}, + {file = "cryptography-42.0.8-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:013629ae70b40af70c9a7a5db40abe5d9054e6f4380e50ce769947b73bf3caad"}, + {file = "cryptography-42.0.8.tar.gz", hash = "sha256:8d09d05439ce7baa8e9e95b07ec5b6c886f548deb7e0f69ef25f64b3bce842f2"}, ] [package.dependencies] @@ -942,7 +975,7 @@ nox = ["nox"] pep8test = ["check-sdist", "click", "mypy", "ruff"] sdist = ["build"] ssh = ["bcrypt (>=3.1.5)"] -test = ["certifi", "cryptography-vectors (==43.0.3)", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] +test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] [[package]] @@ -1144,7 +1177,7 @@ files = [ name = "dnspython" version = "2.7.0" description = "DNS toolkit" -optional = true +optional = false python-versions = ">=3.9" files = [ {file = "dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86"}, @@ -1171,6 +1204,21 @@ files = [ {file = "docutils-0.16.tar.gz", hash = "sha256:c2de3a60e9e7d07be26b7f2b00ca0309c207e06c100f9cc2a94931fc75a478fc"}, ] +[[package]] +name = "email-validator" +version = "2.2.0" +description = "A robust email address syntax and deliverability validation library." +optional = false +python-versions = ">=3.8" +files = [ + {file = "email_validator-2.2.0-py3-none-any.whl", hash = "sha256:561977c2d73ce3611850a06fa56b414621e0c8faa9d66f2611407d87465da631"}, + {file = "email_validator-2.2.0.tar.gz", hash = "sha256:cb690f344c617a714f22e66ae771445a1ceb46821152df8e165c5f9a364582b7"}, +] + +[package.dependencies] +dnspython = ">=2.0.0" +idna = ">=2.0.0" + [[package]] name = "environs" version = "9.5.0" @@ -1235,23 +1283,63 @@ redis = ">=5.0.1,<6.0.0" [[package]] name = "fastapi" -version = "0.115.4" +version = "0.111.1" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -optional = true +optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.115.4-py3-none-any.whl", hash = "sha256:0b504a063ffb3cf96a5e27dc1bc32c80ca743a2528574f9cdc77daa2d31b4742"}, - {file = "fastapi-0.115.4.tar.gz", hash = "sha256:db653475586b091cb8b2fec2ac54a680ac6a158e07406e1abae31679e8826349"}, + {file = "fastapi-0.111.1-py3-none-any.whl", hash = "sha256:4f51cfa25d72f9fbc3280832e84b32494cf186f50158d364a8765aabf22587bf"}, + {file = "fastapi-0.111.1.tar.gz", hash = "sha256:ddd1ac34cb1f76c2e2d7f8545a4bcb5463bce4834e81abf0b189e0c359ab2413"}, ] [package.dependencies] +email_validator = ">=2.0.0" +fastapi-cli = ">=0.0.2" +httpx = ">=0.23.0" +jinja2 = ">=2.11.2" pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.40.0,<0.42.0" +python-multipart = ">=0.0.7" +starlette = ">=0.37.2,<0.38.0" typing-extensions = ">=4.8.0" +uvicorn = {version = ">=0.12.0", extras = ["standard"]} [package.extras] -all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] -standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] +all = ["email_validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] + +[[package]] +name = "fastapi-cli" +version = "0.0.5" +description = "Run and manage FastAPI apps from the command line with FastAPI CLI. 🚀" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi_cli-0.0.5-py3-none-any.whl", hash = "sha256:e94d847524648c748a5350673546bbf9bcaeb086b33c24f2e82e021436866a46"}, + {file = "fastapi_cli-0.0.5.tar.gz", hash = "sha256:d30e1239c6f46fcb95e606f02cdda59a1e2fa778a54b64686b3ff27f6211ff9f"}, +] + +[package.dependencies] +typer = ">=0.12.3" +uvicorn = {version = ">=0.15.0", extras = ["standard"]} + +[package.extras] +standard = ["uvicorn[standard] (>=0.15.0)"] + +[[package]] +name = "fastapi-sso" +version = "0.10.0" +description = "FastAPI plugin to enable SSO to most common providers (such as Facebook login, Google login and login via Microsoft Office 365 Account)" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "fastapi_sso-0.10.0-py3-none-any.whl", hash = "sha256:579bbcf84157f394a9b30a45dbca74e623cd432054c6f63c55996a775711388e"}, + {file = "fastapi_sso-0.10.0.tar.gz", hash = "sha256:8029c2c58abd861268edc3710ac45827699789bae062a5be52bbbb7a6918c637"}, +] + +[package.dependencies] +fastapi = ">=0.80" +httpx = ">=0.23.0" +oauthlib = ">=3.1.0" +pydantic = {version = ">=1.8.0", extras = ["email"]} [[package]] name = "fastembed" @@ -1768,6 +1856,27 @@ grpcio = ">=1.67.1" protobuf = ">=5.26.1,<6.0dev" setuptools = "*" +[[package]] +name = "gunicorn" +version = "22.0.0" +description = "WSGI HTTP Server for UNIX" +optional = false +python-versions = ">=3.7" +files = [ + {file = "gunicorn-22.0.0-py3-none-any.whl", hash = "sha256:350679f91b24062c86e386e198a15438d53a7a8207235a78ba1b53df4c4378d9"}, + {file = "gunicorn-22.0.0.tar.gz", hash = "sha256:4a0b436239ff76fb33f11c07a16482c521a7e09c1ce3cc293c2330afe01bec63"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +eventlet = ["eventlet (>=0.24.1,!=0.36.0)"] +gevent = ["gevent (>=1.4.0)"] +setproctitle = ["setproctitle"] +testing = ["coverage", "eventlet", "gevent", "pytest", "pytest-cov"] +tornado = ["tornado (>=0.2)"] + [[package]] name = "h11" version = "0.14.0" @@ -1830,7 +1939,7 @@ trio = ["trio (>=0.22.0,<1.0)"] name = "httptools" version = "0.6.4" description = "A collection of framework independent HTTP protocol utils." -optional = true +optional = false python-versions = ">=3.8.0" files = [ {file = "httptools-0.6.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3c73ce323711a6ffb0d247dcd5a550b8babf0f757e86a52558fe5b86d6fefcc0"}, @@ -2435,21 +2544,48 @@ files = [ [package.dependencies] aiohttp = "*" +apscheduler = {version = ">=3.10.4,<4.0.0", optional = true, markers = "extra == \"proxy\""} +backoff = {version = "*", optional = true, markers = "extra == \"proxy\""} click = "*" +cryptography = {version = ">=42.0.5,<43.0.0", optional = true, markers = "extra == \"proxy\""} +fastapi = {version = ">=0.111.0,<0.112.0", optional = true, markers = "extra == \"proxy\""} +fastapi-sso = {version = ">=0.10.0,<0.11.0", optional = true, markers = "extra == \"proxy\""} +gunicorn = {version = ">=22.0.0,<23.0.0", optional = true, markers = "extra == \"proxy\""} importlib-metadata = ">=6.8.0" jinja2 = ">=3.1.2,<4.0.0" jsonschema = ">=4.22.0,<5.0.0" openai = ">=1.52.0" +orjson = {version = ">=3.9.7,<4.0.0", optional = true, markers = "extra == \"proxy\""} pydantic = ">=2.0.0,<3.0.0" +PyJWT = {version = ">=2.8.0,<3.0.0", optional = true, markers = "extra == \"proxy\""} +pynacl = {version = ">=1.5.0,<2.0.0", optional = true, markers = "extra == \"proxy\""} python-dotenv = ">=0.2.0" +python-multipart = {version = ">=0.0.9,<0.0.10", optional = true, markers = "extra == \"proxy\""} +pyyaml = {version = ">=6.0.1,<7.0.0", optional = true, markers = "extra == \"proxy\""} requests = ">=2.31.0,<3.0.0" +rq = {version = "*", optional = true, markers = "extra == \"proxy\""} tiktoken = ">=0.7.0" tokenizers = "*" +uvicorn = {version = ">=0.22.0,<0.23.0", optional = true, markers = "extra == \"proxy\""} [package.extras] extra-proxy = ["azure-identity (>=1.15.0,<2.0.0)", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "resend (>=0.8.0,<0.9.0)"] proxy = ["PyJWT (>=2.8.0,<3.0.0)", "apscheduler (>=3.10.4,<4.0.0)", "backoff", "cryptography (>=42.0.5,<43.0.0)", "fastapi (>=0.111.0,<0.112.0)", "fastapi-sso (>=0.10.0,<0.11.0)", "gunicorn (>=22.0.0,<23.0.0)", "orjson (>=3.9.7,<4.0.0)", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.9,<0.0.10)", "pyyaml (>=6.0.1,<7.0.0)", "rq", "uvicorn (>=0.22.0,<0.23.0)"] +[[package]] +name = "livereload" +version = "2.7.0" +description = "Python LiveReload is an awesome tool for web developers" +optional = true +python-versions = ">=3.7" +files = [ + {file = "livereload-2.7.0-py3-none-any.whl", hash = "sha256:19bee55aff51d5ade6ede0dc709189a0f904d3b906d3ea71641ed548acff3246"}, + {file = "livereload-2.7.0.tar.gz", hash = "sha256:f4ba199ef93248902841e298670eebfe1aa9e148e19b343bc57dbf1b74de0513"}, +] + +[package.dependencies] +tornado = "*" + [[package]] name = "llama-cloud" version = "0.1.4" @@ -3838,6 +3974,22 @@ files = [ {file = "nvidia_nvtx_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:641dccaaa1139f3ffb0d3164b4b84f9d253397e38246a4f2f36728b48566d485"}, ] +[[package]] +name = "oauthlib" +version = "3.2.2" +description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +optional = false +python-versions = ">=3.6" +files = [ + {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, + {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, +] + +[package.extras] +rsa = ["cryptography (>=3.0.0)"] +signals = ["blinker (>=1.4.0)"] +signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] + [[package]] name = "onnx" version = "1.17.0" @@ -3969,6 +4121,73 @@ document = ["ase", "cmaes (>=0.10.0)", "fvcore", "lightgbm", "matplotlib (!=3.6. optional = ["boto3", "cmaes (>=0.10.0)", "google-cloud-storage", "matplotlib (!=3.6.0)", "pandas", "plotly (>=4.9.0)", "redis", "scikit-learn (>=0.24.2)", "scipy", "torch"] test = ["coverage", "fakeredis[lua]", "kaleido", "moto", "pytest", "scipy (>=1.9.2)", "torch"] +[[package]] +name = "orjson" +version = "3.10.11" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +optional = false +python-versions = ">=3.8" +files = [ + {file = "orjson-3.10.11-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6dade64687f2bd7c090281652fe18f1151292d567a9302b34c2dbb92a3872f1f"}, + {file = "orjson-3.10.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82f07c550a6ccd2b9290849b22316a609023ed851a87ea888c0456485a7d196a"}, + {file = "orjson-3.10.11-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bd9a187742d3ead9df2e49240234d728c67c356516cf4db018833a86f20ec18c"}, + {file = "orjson-3.10.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:77b0fed6f209d76c1c39f032a70df2d7acf24b1812ca3e6078fd04e8972685a3"}, + {file = "orjson-3.10.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:63fc9d5fe1d4e8868f6aae547a7b8ba0a2e592929245fff61d633f4caccdcdd6"}, + {file = "orjson-3.10.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65cd3e3bb4fbb4eddc3c1e8dce10dc0b73e808fcb875f9fab40c81903dd9323e"}, + {file = "orjson-3.10.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f67c570602300c4befbda12d153113b8974a3340fdcf3d6de095ede86c06d92"}, + {file = "orjson-3.10.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:1f39728c7f7d766f1f5a769ce4d54b5aaa4c3f92d5b84817053cc9995b977acc"}, + {file = "orjson-3.10.11-cp310-none-win32.whl", hash = "sha256:1789d9db7968d805f3d94aae2c25d04014aae3a2fa65b1443117cd462c6da647"}, + {file = "orjson-3.10.11-cp310-none-win_amd64.whl", hash = "sha256:5576b1e5a53a5ba8f8df81872bb0878a112b3ebb1d392155f00f54dd86c83ff6"}, + {file = "orjson-3.10.11-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:1444f9cb7c14055d595de1036f74ecd6ce15f04a715e73f33bb6326c9cef01b6"}, + {file = "orjson-3.10.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cdec57fe3b4bdebcc08a946db3365630332dbe575125ff3d80a3272ebd0ddafe"}, + {file = "orjson-3.10.11-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4eed32f33a0ea6ef36ccc1d37f8d17f28a1d6e8eefae5928f76aff8f1df85e67"}, + {file = "orjson-3.10.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80df27dd8697242b904f4ea54820e2d98d3f51f91e97e358fc13359721233e4b"}, + {file = "orjson-3.10.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:705f03cee0cb797256d54de6695ef219e5bc8c8120b6654dd460848d57a9af3d"}, + {file = "orjson-3.10.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03246774131701de8e7059b2e382597da43144a9a7400f178b2a32feafc54bd5"}, + {file = "orjson-3.10.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8b5759063a6c940a69c728ea70d7c33583991c6982915a839c8da5f957e0103a"}, + {file = "orjson-3.10.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:677f23e32491520eebb19c99bb34675daf5410c449c13416f7f0d93e2cf5f981"}, + {file = "orjson-3.10.11-cp311-none-win32.whl", hash = "sha256:a11225d7b30468dcb099498296ffac36b4673a8398ca30fdaec1e6c20df6aa55"}, + {file = "orjson-3.10.11-cp311-none-win_amd64.whl", hash = "sha256:df8c677df2f9f385fcc85ab859704045fa88d4668bc9991a527c86e710392bec"}, + {file = "orjson-3.10.11-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:360a4e2c0943da7c21505e47cf6bd725588962ff1d739b99b14e2f7f3545ba51"}, + {file = "orjson-3.10.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:496e2cb45de21c369079ef2d662670a4892c81573bcc143c4205cae98282ba97"}, + {file = "orjson-3.10.11-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7dfa8db55c9792d53c5952900c6a919cfa377b4f4534c7a786484a6a4a350c19"}, + {file = "orjson-3.10.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51f3382415747e0dbda9dade6f1e1a01a9d37f630d8c9049a8ed0e385b7a90c0"}, + {file = "orjson-3.10.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f35a1b9f50a219f470e0e497ca30b285c9f34948d3c8160d5ad3a755d9299433"}, + {file = "orjson-3.10.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2f3b7c5803138e67028dde33450e054c87e0703afbe730c105f1fcd873496d5"}, + {file = "orjson-3.10.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f91d9eb554310472bd09f5347950b24442600594c2edc1421403d7610a0998fd"}, + {file = "orjson-3.10.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dfbb2d460a855c9744bbc8e36f9c3a997c4b27d842f3d5559ed54326e6911f9b"}, + {file = "orjson-3.10.11-cp312-none-win32.whl", hash = "sha256:d4a62c49c506d4d73f59514986cadebb7e8d186ad510c518f439176cf8d5359d"}, + {file = "orjson-3.10.11-cp312-none-win_amd64.whl", hash = "sha256:f1eec3421a558ff7a9b010a6c7effcfa0ade65327a71bb9b02a1c3b77a247284"}, + {file = "orjson-3.10.11-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c46294faa4e4d0eb73ab68f1a794d2cbf7bab33b1dda2ac2959ffb7c61591899"}, + {file = "orjson-3.10.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:52e5834d7d6e58a36846e059d00559cb9ed20410664f3ad156cd2cc239a11230"}, + {file = "orjson-3.10.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2fc947e5350fdce548bfc94f434e8760d5cafa97fb9c495d2fef6757aa02ec0"}, + {file = "orjson-3.10.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0efabbf839388a1dab5b72b5d3baedbd6039ac83f3b55736eb9934ea5494d258"}, + {file = "orjson-3.10.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a3f29634260708c200c4fe148e42b4aae97d7b9fee417fbdd74f8cfc265f15b0"}, + {file = "orjson-3.10.11-cp313-none-win32.whl", hash = "sha256:1a1222ffcee8a09476bbdd5d4f6f33d06d0d6642df2a3d78b7a195ca880d669b"}, + {file = "orjson-3.10.11-cp313-none-win_amd64.whl", hash = "sha256:bc274ac261cc69260913b2d1610760e55d3c0801bb3457ba7b9004420b6b4270"}, + {file = "orjson-3.10.11-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:19b3763e8bbf8ad797df6b6b5e0fc7c843ec2e2fc0621398534e0c6400098f87"}, + {file = "orjson-3.10.11-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1be83a13312e5e58d633580c5eb8d0495ae61f180da2722f20562974188af205"}, + {file = "orjson-3.10.11-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:afacfd1ab81f46dedd7f6001b6d4e8de23396e4884cd3c3436bd05defb1a6446"}, + {file = "orjson-3.10.11-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cb4d0bea56bba596723d73f074c420aec3b2e5d7d30698bc56e6048066bd560c"}, + {file = "orjson-3.10.11-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96ed1de70fcb15d5fed529a656df29f768187628727ee2788344e8a51e1c1350"}, + {file = "orjson-3.10.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4bfb30c891b530f3f80e801e3ad82ef150b964e5c38e1fb8482441c69c35c61c"}, + {file = "orjson-3.10.11-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d496c74fc2b61341e3cefda7eec21b7854c5f672ee350bc55d9a4997a8a95204"}, + {file = "orjson-3.10.11-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:655a493bac606655db9a47fe94d3d84fc7f3ad766d894197c94ccf0c5408e7d3"}, + {file = "orjson-3.10.11-cp38-none-win32.whl", hash = "sha256:b9546b278c9fb5d45380f4809e11b4dd9844ca7aaf1134024503e134ed226161"}, + {file = "orjson-3.10.11-cp38-none-win_amd64.whl", hash = "sha256:b592597fe551d518f42c5a2eb07422eb475aa8cfdc8c51e6da7054b836b26782"}, + {file = "orjson-3.10.11-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c95f2ecafe709b4e5c733b5e2768ac569bed308623c85806c395d9cca00e08af"}, + {file = "orjson-3.10.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:80c00d4acded0c51c98754fe8218cb49cb854f0f7eb39ea4641b7f71732d2cb7"}, + {file = "orjson-3.10.11-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:461311b693d3d0a060439aa669c74f3603264d4e7a08faa68c47ae5a863f352d"}, + {file = "orjson-3.10.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52ca832f17d86a78cbab86cdc25f8c13756ebe182b6fc1a97d534051c18a08de"}, + {file = "orjson-3.10.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f4c57ea78a753812f528178aa2f1c57da633754c91d2124cb28991dab4c79a54"}, + {file = "orjson-3.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7fcfc6f7ca046383fb954ba528587e0f9336828b568282b27579c49f8e16aad"}, + {file = "orjson-3.10.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:86b9dd983857970c29e4c71bb3e95ff085c07d3e83e7c46ebe959bac07ebd80b"}, + {file = "orjson-3.10.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4d83f87582d223e54efb2242a79547611ba4ebae3af8bae1e80fa9a0af83bb7f"}, + {file = "orjson-3.10.11-cp39-none-win32.whl", hash = "sha256:9fd0ad1c129bc9beb1154c2655f177620b5beaf9a11e0d10bac63ef3fce96950"}, + {file = "orjson-3.10.11-cp39-none-win_amd64.whl", hash = "sha256:10f416b2a017c8bd17f325fb9dee1fb5cdd7a54e814284896b7c3f2763faa017"}, + {file = "orjson-3.10.11.tar.gz", hash = "sha256:e35b6d730de6384d5b2dab5fd23f0d76fae8bbc8c353c2f78210aa5fa4beb3ef"}, +] + [[package]] name = "overrides" version = "7.7.0" @@ -4772,6 +4991,7 @@ files = [ [package.dependencies] annotated-types = ">=0.6.0" +email-validator = {version = ">=2.0.0", optional = true, markers = "extra == \"email\""} pydantic-core = "2.23.4" typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} @@ -4935,7 +5155,7 @@ windows-terminal = ["colorama (>=0.4.6)"] name = "pyjwt" version = "2.9.0" description = "JSON Web Token implementation in Python" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, @@ -5017,6 +5237,32 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] model = ["milvus-model (>=0.1.0)"] +[[package]] +name = "pynacl" +version = "1.5.0" +description = "Python binding to the Networking and Cryptography (NaCl) library" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d"}, + {file = "PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b"}, + {file = "PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543"}, + {file = "PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93"}, + {file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"}, +] + +[package.dependencies] +cffi = ">=1.4.1" + +[package.extras] +docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"] +tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"] + [[package]] name = "pyopenssl" version = "24.2.1" @@ -5163,6 +5409,20 @@ files = [ [package.extras] cli = ["click (>=5.0)"] +[[package]] +name = "python-multipart" +version = "0.0.9" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, + {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, +] + +[package.extras] +dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"] + [[package]] name = "pytz" version = "2024.2" @@ -5741,6 +6001,21 @@ files = [ {file = "rpds_py-0.20.1.tar.gz", hash = "sha256:e1791c4aabd117653530dccd24108fa03cc6baf21f58b950d0a73c3b3b29a350"}, ] +[[package]] +name = "rq" +version = "2.0.0" +description = "RQ is a simple, lightweight, library for creating background jobs, and processing them." +optional = false +python-versions = ">=3.8" +files = [ + {file = "rq-2.0.0-py3-none-any.whl", hash = "sha256:a3a767876675dcc42683bac1869494c5020ba7fcf5c026d1f6d36a8ab98573a6"}, + {file = "rq-2.0.0.tar.gz", hash = "sha256:76d2a4a27f8fd5c4cfa200cd442efe3c1fd73525c676af06f07fcc0b81bdb70d"}, +] + +[package.dependencies] +click = ">=5" +redis = ">=3.5" + [[package]] name = "ruff" version = "0.3.7" @@ -6003,7 +6278,7 @@ type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12 name = "shellingham" version = "1.5.4" description = "Tool to Detect Surrounding Shell" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, @@ -6192,25 +6467,22 @@ test = ["cython", "html5lib", "pytest (>=4.6)", "typed_ast"] [[package]] name = "sphinx-autobuild" -version = "2024.10.3" -description = "Rebuild Sphinx documentation on changes, with hot reloading in the browser." +version = "2024.2.4" +description = "Rebuild Sphinx documentation on changes, with live-reload in the browser." optional = true python-versions = ">=3.9" files = [ - {file = "sphinx_autobuild-2024.10.3-py3-none-any.whl", hash = "sha256:158e16c36f9d633e613c9aaf81c19b0fc458ca78b112533b20dafcda430d60fa"}, - {file = "sphinx_autobuild-2024.10.3.tar.gz", hash = "sha256:248150f8f333e825107b6d4b86113ab28fa51750e5f9ae63b59dc339be951fb1"}, + {file = "sphinx_autobuild-2024.2.4-py3-none-any.whl", hash = "sha256:63fd87ab7505872a89aef468ce6503f65e794a195f4ae62269db3b85b72d4854"}, + {file = "sphinx_autobuild-2024.2.4.tar.gz", hash = "sha256:cb9d2121a176d62d45471624872afc5fad7755ad662738abe400ecf4a7954303"}, ] [package.dependencies] -colorama = ">=0.4.6" +colorama = "*" +livereload = "*" sphinx = "*" -starlette = ">=0.35" -uvicorn = ">=0.25" -watchfiles = ">=0.20" -websockets = ">=11" [package.extras] -test = ["httpx", "pytest (>=6)"] +test = ["pytest (>=6.0)", "pytest-cov"] [[package]] name = "sphinx-automodapi" @@ -6504,13 +6776,13 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] [[package]] name = "starlette" -version = "0.41.2" +version = "0.37.2" description = "The little ASGI library that shines." -optional = true +optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d"}, - {file = "starlette-0.41.2.tar.gz", hash = "sha256:9834fd799d1a87fd346deb76158668cfa0b0d56f85caefe8268e2d97c3468b62"}, + {file = "starlette-0.37.2-py3-none-any.whl", hash = "sha256:6fe59f29268538e5d0d182f2791a479a0c64638e6935d1c6989e63fb2699c6ee"}, + {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"}, ] [package.dependencies] @@ -6976,7 +7248,7 @@ tutorials = ["matplotlib", "pandas", "tabulate"] name = "typer" version = "0.12.5" description = "Typer, build great CLIs. Easy to code. Based on Python type hints." -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "typer-0.12.5-py3-none-any.whl", hash = "sha256:62fe4e471711b147e3365034133904df3e235698399bc4de2b36c8579298d52b"}, @@ -7026,6 +7298,23 @@ files = [ {file = "tzdata-2024.2.tar.gz", hash = "sha256:7d85cc416e9382e69095b7bdf4afd9e3880418a2413feec7069d533d6b4e31cc"}, ] +[[package]] +name = "tzlocal" +version = "5.2" +description = "tzinfo object for the local timezone" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tzlocal-5.2-py3-none-any.whl", hash = "sha256:49816ef2fe65ea8ac19d19aa7a1ae0551c834303d5014c6d5a62e4cbda8047b8"}, + {file = "tzlocal-5.2.tar.gz", hash = "sha256:8d399205578f1a9342816409cc1e46a93ebd5755e39ea2d85334bea911bf0e6e"}, +] + +[package.dependencies] +tzdata = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] + [[package]] name = "ujson" version = "5.10.0" @@ -7131,13 +7420,13 @@ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] [[package]] name = "uvicorn" -version = "0.32.0" +version = "0.22.0" description = "The lightning-fast ASGI server." -optional = true -python-versions = ">=3.8" +optional = false +python-versions = ">=3.7" files = [ - {file = "uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82"}, - {file = "uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e"}, + {file = "uvicorn-0.22.0-py3-none-any.whl", hash = "sha256:e9434d3bbf05f310e762147f769c9f21235ee118ba2d2bf1155a7196448bd996"}, + {file = "uvicorn-0.22.0.tar.gz", hash = "sha256:79277ae03db57ce7d9aa0567830bbb51d7a612f54d6e1e3e92da3ef24c2c8ed8"}, ] [package.dependencies] @@ -7147,7 +7436,6 @@ h11 = ">=0.8" httptools = {version = ">=0.5.0", optional = true, markers = "extra == \"standard\""} python-dotenv = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} pyyaml = {version = ">=5.1", optional = true, markers = "extra == \"standard\""} -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} uvloop = {version = ">=0.14.0,<0.15.0 || >0.15.0,<0.15.1 || >0.15.1", optional = true, markers = "(sys_platform != \"win32\" and sys_platform != \"cygwin\") and platform_python_implementation != \"PyPy\" and extra == \"standard\""} watchfiles = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} websockets = {version = ">=10.4", optional = true, markers = "extra == \"standard\""} @@ -7159,7 +7447,7 @@ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", name = "uvloop" version = "0.21.0" description = "Fast implementation of asyncio event loop on top of libuv" -optional = true +optional = false python-versions = ">=3.8.0" files = [ {file = "uvloop-0.21.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ec7e6b09a6fdded42403182ab6b832b71f4edaf7f37a9a0e371a01db5f0cb45f"}, @@ -7300,7 +7588,7 @@ watchmedo = ["PyYAML (>=3.10)"] name = "watchfiles" version = "0.24.0" description = "Simple, modern and high performance file watching and code reload in python." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "watchfiles-0.24.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:083dc77dbdeef09fa44bb0f4d1df571d2e12d8a8f985dccde71ac3ac9ac067a0"}, @@ -7427,7 +7715,7 @@ validators = "0.34.0" name = "websockets" version = "13.1" description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "websockets-13.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f48c749857f8fb598fb890a75f540e3221d0976ed0bf879cf3c7eef34151acee"}, @@ -7889,4 +8177,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "92c91613bb51ec6d493672baf6c0d509ebb16ec0c7dd8d23f9cfd6e4654972cf" +content-hash = "22dfbbe02db1c0fcb0f13cbcf4083f54dd7218d7dd2b43a0b1e8dbeed9ae5002" diff --git a/pyproject.toml b/pyproject.toml index 54d22897a..dfd156a03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -155,6 +155,7 @@ pre-commit = "^3.7.0" ipykernel = "^6.29.4" semver = "^3.0.2" pillow = "^10.1.0" +litellm = {version = "1.51.0", extras = ["proxy"]} [tool.poetry.extras] chromadb = ["chromadb"] diff --git a/tests/caching/example_cache/cache.db b/tests/caching/example_cache/cache.db new file mode 100644 index 0000000000000000000000000000000000000000..47ea29f9f906f6ac79a20f5ad88f09a9dc57233f GIT binary patch literal 32768 zcmeI4&2Jk;6u@^gi9eFsld7a{RaKp=gql!0iDM^rQXw@T+cv73w2q^;LTI&KkCScJ zyYB9|X{aK#1;nMtAP)QioVXxy=?yMONJyNJIDrHl=nYg5hYB(K6+0y@kqV-ew~Djt zH*ep(`R(jP^WJ2ZFP3yhN!_v=DkC|eU+C%*P7xvq0^*+-|Aa>`zvvFWxNKTWqGC|lS+Bv%o({%cGDv>rz^Z@ zUM|U8b7_%yH~3F3h{eLL+2Zu8Iw*L}Uf^}w%ZbNvKalrQw;@{hQ zvOBwtS5eS!^3kWdy$9sJ+y4afC%avVL{gcQ=;&sR-t;m0;4O73l~>p(&=)bwrzKa3(|z&ID9vXqN!9$e6zvSsu2uJ?%3254o)+(B{>*% zhKOyg>osn=BD(|Dh4RArner8KPQH@%29w#+(yV7UR7sOI7{W)pQkrzMD3GX{Mjgk4 zzEk4$iO>`u15eMalqzH_O_;h${JOl+EcB;CqhZzPZgf00e*l5C8%|00;m9AOHlOa{>wO zUw`sb0sn-*!}swA_&t0Zui^7JkB?%M{4x1W@~h+*TmcRs00e*l5C8%|00;m9AOHk_ z01yBI|0;oaKRO(1=#E33C`JRZHJv%pB)@QT_OS#?#l8Igp?zpVtm)1*rKxMz45~D3 zi&>gw^c?C#S>%;8bW?FN_udLB;kGc)i?Z<`o9=9w+L~>dR?A7nQKmZ(HyUcwQ+G4! zhhr#rASk$MwQN&06nBNdx@sJZqS3CPNOKoHcsCJYAkwbFP1q|pRISx&Bzw@HsI?43 zG4zJcdiQsu!EU!f{`XV7vf~JyD9{_a#&pY6nwFt!8wOP!ma(ZuO=fIlyhRbUeO>5S zj4x$Syu|#1rZU}Oy5_ix`umd)1RTX5;aBi}{AKc&f00e*l5C8%|00{gC2=t=Eq8}#ckD&q44HIDZ!ktwo<0vJ1 zVS>b=Cxi+19Yk3_ey}4-(C3Q(93tp_31vd@c8DOJLb*`b86JoYpwUp?jtxXREO$f( zBFP<*fgbm3J223F0G$vY3kskFIwo$<|GN@sFy5a3yZL`H@sYrP-~a+Z00;m9AOHk_ z01yBIKmZ5;f#-_ATjxY9+)2cvLNq4CMB$xkF_+I}$8$xks#WvV{8X`;ud7))Q5~=6 z)Ldb*o~;(f>c#qGv8ayG?AYYQ)FiEHQ-$$do@Vn?GhcuHNc!RI)$rVXyXW0+gWcUW zw}$S;-F*u9l(Es=Iuqy&-HX}OX=EabKP5#UX!LJZB1obqiJL&>tn_7LS4&D zRr7^RK0l>qs&s5Tqh)i|!bEPeTC7h=X|8YS{J-+0>DO*ab=}mRHO1yRY|G@fc=o!< zRlKaVG(DD_Rv8mj6u`t0a<9U0oB&{}CCO0yXVJ*9AakEzS zD%ISI-KLIc7H=$dGw{5~*>D(b@D&6m-;Jed>#l!(&8@cIp}xYbYt-bvjq{fAwFr$S z+ZJWHh*^xU9tf3kT+*&95c6$1n5wtix5LWL!cFcSwfSNJchd=K2aVrqjkx)L5%azI zU?0Bc>MaLnfB+Bx0zd!=00AHX1b_e#00KY&2>1l#9v%dWKWpsApZ`A+@FU;~1b_e# z00KY&2mk>f00e*l5C8%|00_K@1P~J9{`~);fFHhyErd1!0U!VbfB+Bx0zd!=00AHX T1b_e#`2P~ Tuple[str, str]: + """ + Start a LiteLLM test server for a DSPy integration test case, and tear down the + server when the test case completes. + """ + with tempfile.TemporaryDirectory() as server_log_dir_path: + # Create a server log file used to store request logs + server_log_file_path = os.path.join(server_log_dir_path, "request_logs.jsonl") + open(server_log_file_path, "a").close() + + port = _get_random_port() + host = "127.0.0.1" + print(f"Starting LiteLLM proxy server on port {port}") + + process = subprocess.Popen( + ["litellm", "--host", host, "--port", str(port), "--config", _get_litellm_config_path()], + env={LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR: server_log_file_path, **os.environ.copy()}, + text=True, + ) + + try: + _wait_for_port(host=host, port=port) + except TimeoutError as e: + process.terminate() + raise e + + server_url = f"http://{host}:{port}" + yield server_url, server_log_file_path + + process.kill() + process.wait() + + +def read_litellm_test_server_request_logs(server_log_file_path: str) -> List[Dict[str, Any]]: + """ + Read request logs from a LiteLLM server used during DSPy integration tests. + + Args: + server_log_file_path: The filesystem path to the LiteLLM server request logs jsonlines file. + Return: + A list of log entries, where each entry corresponds to one request handled by the server. + """ + data = [] + with open(server_log_file_path, "r") as f: + for line in f: + data.append(json.loads(line)) + + return data + + +def _get_litellm_config_path(): + module_dir = os.path.dirname(os.path.abspath(__file__)) + return os.path.join(module_dir, "litellm_server_config.yaml") + + +def _get_random_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _wait_for_port(host, port, timeout=10): + start_time = time.time() + while time.time() - start_time < timeout: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + try: + sock.connect((host, port)) + return True + except ConnectionRefusedError: + time.sleep(0.5) # Wait briefly before trying again + raise TimeoutError(f"Server on port {port} did not become ready within {timeout} seconds.") diff --git a/tests/test_utils/server/litellm_server.py b/tests/test_utils/server/litellm_server.py new file mode 100644 index 000000000..ed7ba4f3f --- /dev/null +++ b/tests/test_utils/server/litellm_server.py @@ -0,0 +1,47 @@ +import json +import os + +import litellm +from litellm import CustomLLM + +LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR = "LITELLM_TEST_SERVER_LOG_FILE_PATH" + + +class DSPyTestModel(CustomLLM): + def completion(self, *args, **kwargs) -> litellm.ModelResponse: + _append_request_to_log_file(kwargs) + return _get_mock_llm_response() + + async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse: + _append_request_to_log_file(kwargs) + return _get_mock_llm_response() + + +def _get_mock_llm_response(): + return litellm.completion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Hello world"}], + mock_response="Hi!", + ) + + +def _append_request_to_log_file(completion_kwargs): + log_file_path = os.environ.get(LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR) + if log_file_path is None: + raise ValueError( + f"Server logs file path is not defined! Please set the path using the" + + f" {LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR} environment variable." + ) + + with open(log_file_path, "a") as f: + log_blob = ( + { + "model": completion_kwargs["model"], + "messages": completion_kwargs["messages"], + }, + ) + json.dump(log_blob, f) + f.write("\n") + + +dspy_test_model = DSPyTestModel() diff --git a/tests/test_utils/server/litellm_server_config.yaml b/tests/test_utils/server/litellm_server_config.yaml new file mode 100644 index 000000000..6e9743663 --- /dev/null +++ b/tests/test_utils/server/litellm_server_config.yaml @@ -0,0 +1,14 @@ +model_list: + - model_name: "dspy-test-model" + litellm_params: + model: "dspy-test-provider/dspy-test-model" + - model_name: "dspy-test-model-2" + litellm_params: + model: "dspy-test-provider/dspy-test-model" + +litellm_settings: + custom_provider_map: + - { + "provider": "dspy-test-provider", + "custom_handler": litellm_server.dspy_test_model, + } From 08d9eceb472200b8aca9d39451c27be664d10065 Mon Sep 17 00:00:00 2001 From: arnavsinghvi11 <54859892+arnavsinghvi11@users.noreply.github.com> Date: Wed, 6 Nov 2024 17:07:33 -0800 Subject: [PATCH 26/31] add flag to suppress litellm logs in dspy.LM (#1768) * add flag to suppress litellm logs in dspy.LM * add suppress flag to dspy.configure --- dsp/utils/settings.py | 1 + dspy/clients/lm.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index ea302f35a..1fed72368 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -24,6 +24,7 @@ experimental=False, backoff_time=10, callbacks=[], + suppress_debug_info=True, ) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 567178432..41b450abc 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import Any, Dict, List, Literal, Optional - +import dspy import litellm import ujson @@ -61,6 +61,9 @@ def __init__( self.callbacks = callbacks or [] self.num_retries = num_retries + #turned off by default to avoid LiteLLM logging during every LM call + litellm.suppress_debug_info = dspy.settings.suppress_debug_info + # TODO: Arbitrary model strings could include the substring "o1-". We # should find a more robust way to check for the "o1-" family models. if "o1-" in model: From 1df86fa3d01941c01b5bf943c66247e91d6f19a9 Mon Sep 17 00:00:00 2001 From: xucailiang <74602715+xucailiang@users.noreply.github.com> Date: Thu, 7 Nov 2024 22:49:28 +0800 Subject: [PATCH 27/31] neo4j_rm support litellm embedding model (#1771) Co-authored-by: xucai --- dspy/retrieve/neo4j_rm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dspy/retrieve/neo4j_rm.py b/dspy/retrieve/neo4j_rm.py index f71dbb1cb..bcd576325 100644 --- a/dspy/retrieve/neo4j_rm.py +++ b/dspy/retrieve/neo4j_rm.py @@ -1,5 +1,5 @@ import os -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, Callable import backoff from openai import ( @@ -108,6 +108,7 @@ def __init__( retrieval_query: str = None, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", + embedding_function: Optional[Callable] = None, ): super().__init__(k=k) self.index_name = index_name @@ -136,7 +137,7 @@ def __init__( ) as e: raise ConnectionError("Failed to connect to Neo4j database") from e - self.embedder = Embedder(provider=embedding_provider, model=embedding_model) + self.embedder = embedding_function or Embedder(provider=embedding_provider, model=embedding_model) def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> Prediction: if not isinstance(query_or_queries, list): From 1e539b01f7db87d7514c722d633bc9718782d2c4 Mon Sep 17 00:00:00 2001 From: Dilara Soylu <21346670+dilarasoylu@users.noreply.github.com> Date: Thu, 7 Nov 2024 11:22:57 -0800 Subject: [PATCH 28/31] Dev finetune update (#1698) * Adapter updates * Client updates * Add provider * Add cache dir utils * Add BootstrapFinetune * Add BetterTogether draft * Prepare PR * Add AnyScale changes -- ruff * Teporarily remove BetterTogether * Remove BetterTogether import * Add comment * Replace OpenAI client call with library call * Add OpenAI models list to check valid models * Temporarily switch to print * Prepare BootstrapFinetune for BetterTogether * Add BetterTogether * Add dev notebook * Revamp ChainOfThoughtWithHint and adjust max auto valset of MIPROv2 * fix * ruff fixes * disable cot_hint tests * unsafe ruff fixes --------- Co-authored-by: Omar Khattab --- dsp/utils/settings_v2.py | 2 +- dspy/__init__.py | 2 + dspy/adapters/base.py | 28 +- dspy/adapters/chat_adapter.py | 17 +- dspy/clients/__init__.py | 1 + dspy/clients/anyscale.py | 26 +- dspy/clients/finetune.py | 132 -- dspy/clients/lm.py | 113 +- dspy/clients/lm_finetune_utils.py | 93 - dspy/clients/openai.py | 555 +++--- dspy/clients/provider.py | 67 + dspy/clients/utils_finetune.py | 147 ++ dspy/functional/functional.py | 32 +- dspy/predict/aggregation.py | 5 +- dspy/predict/chain_of_thought_with_hint.py | 44 +- dspy/predict/retry.py | 2 +- dspy/primitives/assertions.py | 4 +- dspy/teleprompt/__init__.py | 3 +- dspy/teleprompt/bettertogether.py | 170 ++ dspy/teleprompt/bootstrap_finetune.py | 287 ++++ dspy/teleprompt/finetune.py | 189 --- dspy/teleprompt/finetune_teleprompter.py | 148 -- dspy/teleprompt/mipro_optimizer_v2.py | 2 +- dspy/utils/__init__.py | 1 + dspy/utils/caching.py | 14 + .../finetune/_internal_finetune_demo.ipynb | 1505 +++++++++++++++++ .../test_chain_of_thought_with_hint.py | 74 +- .../test_chain_of_thought_with_hint.py | 50 +- 28 files changed, 2656 insertions(+), 1057 deletions(-) delete mode 100644 dspy/clients/finetune.py delete mode 100644 dspy/clients/lm_finetune_utils.py create mode 100644 dspy/clients/provider.py create mode 100644 dspy/clients/utils_finetune.py create mode 100644 dspy/teleprompt/bettertogether.py create mode 100644 dspy/teleprompt/bootstrap_finetune.py delete mode 100644 dspy/teleprompt/finetune.py delete mode 100644 dspy/teleprompt/finetune_teleprompter.py create mode 100644 dspy/utils/caching.py create mode 100644 examples/finetune/_internal_finetune_demo.ipynb diff --git a/dsp/utils/settings_v2.py b/dsp/utils/settings_v2.py index 6652474d3..c1b93a895 100644 --- a/dsp/utils/settings_v2.py +++ b/dsp/utils/settings_v2.py @@ -73,7 +73,7 @@ def main(): futures = {executor.submit(thread_wrapper, sample_program, parent_tid, arg) for arg in range(3)} for future in as_completed(futures): - res = future.result() + future.result() print(f"Main thread {parent_tid} config after threads: {dsp_settings._get_current_config()}") diff --git a/dspy/__init__.py b/dspy/__init__.py index f80c8237f..3ba977eb8 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -67,6 +67,8 @@ BootstrapFewShot = dspy.teleprompt.BootstrapFewShot BootstrapFewShotWithRandomSearch = dspy.teleprompt.BootstrapFewShotWithRandomSearch BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch +BootstrapFinetune = dspy.teleprompt.BootstrapFinetune +BetterTogether = dspy.teleprompt.BetterTogether COPRO = dspy.teleprompt.COPRO MIPROv2 = dspy.teleprompt.MIPROv2 Ensemble = dspy.teleprompt.Ensemble \ No newline at end of file diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index c25ac61a1..1edf562ee 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,19 +1,8 @@ -import abc -from dspy.utils.callback import with_callbacks - -class Adapter: - @abc.abstractmethod - def format(self, signature, demos, inputs): - """ - Format the input data for the LLM. - """ +from abc import ABC, abstractmethod - @abc.abstractmethod - def parse(self, signature, completion): - """ - Parse the output data from the LLM. - """ +from dspy.utils.callback import with_callbacks +class Adapter(ABC): def __init__(self, callbacks=None): self.callbacks = callbacks or [] @@ -31,7 +20,6 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): outputs = lm(**inputs_, **lm_kwargs) values = [] - try: for output in outputs: value = self.parse(signature, output, _parse_values=_parse_values) @@ -45,3 +33,13 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values) raise e + @abstractmethod + def format(self, signature, demos, inputs): + raise NotImplementedError + + @abstractmethod + def parse(self, signature, completion, _parse_values): + raise NotImplementedError + + def format_finetune_data(self, signature, demos, inputs, outputs): + raise NotImplementedError diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 6727cd245..f8be79241 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -34,10 +34,6 @@ class FieldInfoWithName(NamedTuple): BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField()) class ChatAdapter(Adapter): - """ - ChatAdapter is used to format and parse data for chat-based LLMs. - """ - def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] @@ -90,6 +86,19 @@ def parse(self, signature, completion, _parse_values=True): return fields + # TODO(PR): Looks ok? + def format_finetune_data(self, signature, demos, inputs, outputs): + # Get system + user messages + messages = self.format(signature, demos, inputs) + + # Add the assistant message + role = "assistant" + incomplete = False + assistant_message = format_turn(signature, outputs, role, incomplete) + messages.append(assistant_message) + + # Wrap the messages in a dictionary with a "messages" key + return dict(messages=messages) def format_turn(self, signature, values, role, incomplete=False): return format_turn(signature, values, role, incomplete) diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index ef9a8da5f..9692a9a97 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,4 +1,5 @@ from .lm import LM +from .provider import Provider, TrainingJob from .base_lm import BaseLM, inspect_history from .embedding import Embedding import litellm diff --git a/dspy/clients/anyscale.py b/dspy/clients/anyscale.py index d6867fe26..3d3022867 100644 --- a/dspy/clients/anyscale.py +++ b/dspy/clients/anyscale.py @@ -7,7 +7,7 @@ from dspy.clients.finetune import ( FinetuneJob, - TrainingMethod, + # TrainingMethod, save_data, ) from dspy.clients.openai import openai_data_validation @@ -32,7 +32,7 @@ def is_anyscale_model(model: str) -> bool: """Check if the model is an AnyScale model.""" # TODO: This needs to be implemented to support fine-tuning - logger.info("Is AnyScale model is not implemented, returning False as a default to not break lm.py") + print("Is AnyScale model is not implemented, returning False as a default to not break lm.py") return False @@ -103,9 +103,9 @@ def finetune_anyscale( def wait_for_training(job_id): """Wait for the training to complete.""" - logger.info("[Finetune] Waiting for training to complete...") + print("[Finetune] Waiting for training to complete...") anyscale.job.wait(id=job_id) - logger.info("[Finetune] Training completed.") + print("[Finetune] Training completed.") def update_serve_model_config(lora_dynamic_path: str, serve_config_path: str): @@ -126,7 +126,7 @@ def update_serve_model_config(lora_dynamic_path: str, serve_config_path: str): def verify_dataset(dataset: List[dict[str, Any]]) -> bool: """Verify the training arguments before starting training.""" - logger.info("[Finetune] Verifying dataset...") + print("[Finetune] Verifying dataset...") dataset_validation = openai_data_validation(dataset) if dataset_validation: @@ -138,11 +138,11 @@ def verify_dataset(dataset: List[dict[str, Any]]) -> bool: def submit_data(train_path: str, job_config: Dict[str, Any]): """Upload the data to cloud storage.""" - logger.info("[Finetune] Submitting data to remote storage...") + print("[Finetune] Submitting data to remote storage...") dataset_suffix = os.path.basename(train_path).split(".")[0] dataset_name = f"dataset-{job_config.get('name', dataset_suffix)}" train_path_remote = anyscale.llm.dataset.upload(train_path, name=dataset_name, cloud=job_config.get("cloud", None)).storage_uri - logger.info(f"[Finetune] Data submitted. Remote train path: {train_path_remote}") + print(f"[Finetune] Data submitted. Remote train path: {train_path_remote}") return train_path_remote @@ -158,7 +158,7 @@ def generate_config_files(train_path: str, llmforge_config_path: str, job_config llmforge_config["train_path"] = train_path llmforge_config = {k: v for k, v in llmforge_config.items() if v is not None} - logger.info(f"Model config data: {llmforge_config}") + print(f"Model config data: {llmforge_config}") yaml.safe_dump(llmforge_config, open(llmforge_config_path, "w")) if not job_config_dict.get("env_vars", None): @@ -176,21 +176,21 @@ def generate_config_files(train_path: str, llmforge_config_path: str, job_config def start_remote_training(job_config) -> str: - logger.info("[Finetune] Starting remote training...") + print("[Finetune] Starting remote training...") job_id: str = anyscale.job.submit(job_config) - logger.info(f"[Finetune] Remote training started. Job ID: {job_id}") + print(f"[Finetune] Remote training started. Job ID: {job_id}") return job_id def wait_for_training(job_id): - logger.info("Waiting for training to complete") + print("Waiting for training to complete") anyscale.job.wait(id=job_id, timeout_s=18000) def get_model_info(job_id): - logger.info("[Finetune] Retrieving model information from Anyscale Models SDK...") + print("[Finetune] Retrieving model information from Anyscale Models SDK...") info = anyscale.llm.model.get(job_id=job_id).to_dict() - logger.info(f"[Finetune] Model info retrieved: {info}") + print(f"[Finetune] Model info retrieved: {info}") return info def read_jsonl(filename): diff --git a/dspy/clients/finetune.py b/dspy/clients/finetune.py deleted file mode 100644 index 54ba7b4ba..000000000 --- a/dspy/clients/finetune.py +++ /dev/null @@ -1,132 +0,0 @@ -import logging -import os -from abc import abstractmethod -from concurrent.futures import Future -from enum import Enum -from pathlib import Path -from typing import Any, Dict, List, Optional - -import ujson -from datasets.fingerprint import Hasher - -logger = logging.getLogger(__name__) - -def get_finetune_directory() -> str: - """Get the directory to save the fine-tuned models.""" - # TODO: Move to a centralized location with all the other env variables - dspy_cachedir = os.environ.get("DSPY_CACHEDIR") - dspy_cachedir = dspy_cachedir or os.path.join(Path.home(), ".dspy_cache") - finetune_dir = os.path.join(dspy_cachedir, "finetune") - finetune_dir = os.path.abspath(finetune_dir) - return finetune_dir - - -FINETUNE_DIRECTORY = get_finetune_directory() - - -class TrainingMethod(str, Enum): - """Enum class for training methods. - - When comparing enums, Python checks for object IDs, which means that the - enums can't be compared directly. Subclassing the Enum class along with the - str class allows for direct comparison of the enums. - """ - - SFT = "SFT" - Preference = "Preference" - - -class TrainingStatus(str, Enum): - """Enum class for remote training status.""" - - not_started = "not_started" - pending = "pending" - running = "running" - succeeded = "succeeded" - failed = "failed" - cancelled = "cancelled" - - -"""Dictionary mapping training methods to the data keys they require.""" -TRAINING_METHOD_TO_DATA_KEYS = { - TrainingMethod.SFT: ["prompt", "completion"], - TrainingMethod.Preference: ["prompt", "chosen", "rejected"], -} - - -class FinetuneJob(Future): - def __init__( - self, - model: str, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, - provider: str = "openai", - ): - self.model = model - self.train_data = train_data - self.train_kwargs: Dict[str, Any] = train_kwargs or {} - self.train_method = train_method - self.provider = provider - super().__init__() - - def get_kwargs(self): - return dict( - model=self.model, - train_data=self.train_data, - train_kwargs=self.train_kwargs, - train_method=self.train_method, - provider=self.provider, - ) - - def __repr__(self): - return str(self) - - # Subclasses should override the cancel method to cancel the finetune job; - # then call the super's cancel method so that the future can be cancelled. - def cancel(self): - """Cancel the finetune job.""" - super().cancel() - - @abstractmethod - def status(self): - """Get the status of the finetune job.""" - raise NotImplementedError("Method `status` is not implemented.") - - -def validate_finetune_data(data: List[Dict[str, Any]], train_method: TrainingMethod): - """Validate the finetune data based on the training method.""" - # Get the required data keys for the training method - required_keys = TRAINING_METHOD_TO_DATA_KEYS[train_method] - - # Check if the training data has the required keys - for ind, data_dict in enumerate(data): - if not all([key in data_dict for key in required_keys]): - raise ValueError( - f"The datapoint at index {ind} is missing the keys required for {train_method} training. Expected: " - f"{required_keys}, Found: {data_dict.keys()}" - ) - - -def save_data( - data: List[Dict[str, Any]], - provider_name: Optional[str] = None, -) -> str: - """Save the fine-tuning data to a file.""" - logger.info("[Finetune] Converting data to JSONL format...") - # Construct the file name based on the data hash - hash = Hasher.hash(data) - file_name = f"{hash}.jsonl" - file_name = f"{provider_name}_{file_name}" if provider_name else file_name - - # Find the directory to save the fine-tuning data - finetune_parent_dir = get_finetune_directory() - os.makedirs(finetune_parent_dir, exist_ok=True) - - # Save the data to a file - file_path = os.path.join(finetune_parent_dir, file_name) - file_path = os.path.abspath(file_path) - with open(file_path, "w") as f: - for item in data: - f.write(ujson.dumps(item) + "\n") - return file_path diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 41b450abc..b73d272ff 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -3,15 +3,22 @@ import logging import os import uuid -from concurrent.futures import ThreadPoolExecutor from datetime import datetime +import threading from typing import Any, Dict, List, Literal, Optional import dspy import litellm import ujson -from dspy.clients.finetune import FinetuneJob, TrainingMethod -from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class +from dspy.adapters.base import Adapter +from dspy.clients.provider import Provider, TrainingJob +from dspy.clients.openai import OpenAIProvider +from dspy.clients.utils_finetune import ( + DataFormat, + validate_data_format, + infer_data_format +) + from dspy.utils.callback import BaseCallback, with_callbacks @@ -33,6 +40,7 @@ def __init__( launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, num_retries: int = 3, + provider=None, **kwargs, ): """ @@ -56,6 +64,8 @@ def __init__( self.model_type = model_type self.cache = cache self.launch_kwargs = launch_kwargs or {} + self.provider = provider or self.infer_provider() + self.callbacks = callbacks or [] self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] self.callbacks = callbacks or [] @@ -64,8 +74,8 @@ def __init__( #turned off by default to avoid LiteLLM logging during every LM call litellm.suppress_debug_info = dspy.settings.suppress_debug_info - # TODO: Arbitrary model strings could include the substring "o1-". We - # should find a more robust way to check for the "o1-" family models. + # TODO(bug): Arbitrary model strings could include the substring "o1-". + # We should find a more robust way to check for the "o1-" family models. if "o1-" in model: assert ( max_tokens >= 5000 and temperature == 1.0 @@ -108,47 +118,86 @@ def __call__(self, prompt=None, messages=None, **kwargs): return outputs def launch(self): - """Send a request to the provider to launch the model, if needed.""" - msg = f"`launch()` is called for the auto-launched model {self.model}" - msg += " -- no action is taken!" - logger.info(msg) + self.provider.launch(self.model, self.launch_kwargs) def kill(self): - """Send a request to the provider to kill the model, if needed.""" - msg = f"`kill()` is called for the auto-launched model {self.model}" - msg += " -- no action is taken!" - logger.info(msg) + self.provider.kill(self.model, self.launch_kwargs) def finetune( - self, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, - provider: str = "openai", - cache_finetune: bool = True, - ) -> FinetuneJob: - """Start model fine-tuning, if supported.""" + self, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]]=None, + data_format: Optional[DataFormat] = None, + ) -> TrainingJob: from dspy import settings as settings err = "Fine-tuning is an experimental feature." err += " Set `dspy.settings.experimental` to `True` to use it." assert settings.experimental, err - FinetuneJobClass = get_provider_finetune_job_class(provider=provider) - finetune_job = FinetuneJobClass( + err = f"Provider {self.provider} does not support fine-tuning." + assert self.provider.finetunable, err + + # Perform data validation before starting the thread to fail early + train_kwargs = train_kwargs or {} + if not data_format: + adapter = self.infer_adapter() + data_format = infer_data_format(adapter) + validate_data_format(data=train_data, data_format=data_format) + + # TODO(PR): We can quickly add caching, but doing so requires + # adding functions that just call other functions as we had in the last + # iteration, unless people have other ideas. + def thread_function_wrapper(): + return self._run_finetune_job(job) + + thread = threading.Thread(target=thread_function_wrapper) + job = self.provider.TrainingJob( + thread=thread, model=self.model, train_data=train_data, train_kwargs=train_kwargs, - train_method=train_method, - provider=provider, + data_format=data_format ) - - executor = ThreadPoolExecutor(max_workers=1) - executor.submit(execute_finetune_job, finetune_job, lm=self, cache_finetune=cache_finetune) - executor.shutdown(wait=False) - - return finetune_job - + thread.start() + + return job + + def _run_finetune_job(self, job: TrainingJob): + # TODO(enhance): We should listen for keyboard interrupts somewhere. + # Requires TrainingJob.cancel() to be implemented for each provider. + try: + model = self.provider.finetune( + job=job, + model=job.model, + train_data=job.train_data, + train_kwargs=job.train_kwargs, + data_format=job.data_format + ) + lm = self.copy(model=model) + job.set_result(lm) + except Exception as err: + logger.error(err) + job.set_result(err) + + def infer_provider(self) -> Provider: + if OpenAIProvider.is_provider_model(self.model): + return OpenAIProvider() + # TODO(PR): Keeping this function here will require us to import all + # providers in this file. Is this okay? + return Provider() + + def infer_adapter(self) -> Adapter: + import dspy + if dspy.settings.adapter: + return dspy.settings.adapter + + model_type_to_adapter = { + "chat": dspy.ChatAdapter(), + } + model_type = self.model_type + return model_type_to_adapter[model_type] + def copy(self, **kwargs): """Returns a copy of the language model with possibly updated parameters.""" diff --git a/dspy/clients/lm_finetune_utils.py b/dspy/clients/lm_finetune_utils.py deleted file mode 100644 index 23f300aff..000000000 --- a/dspy/clients/lm_finetune_utils.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging -from typing import Any, Dict, List, Optional, Type, Union - -from dspy.clients.anyscale import FinetuneJobAnyScale, finetune_anyscale -from dspy.clients.finetune import FinetuneJob, TrainingMethod -from dspy.clients.openai import FinetuneJobOpenAI, finetune_openai - -logger = logging.getLogger(__name__) - -_PROVIDER_ANYSCALE = "anyscale" -_PROVIDER_OPENAI = "openai" - - -def get_provider_finetune_job_class(provider: str) -> Type[FinetuneJob]: - """Get the FinetuneJob class for the provider.""" - provider_to_job_class = { - _PROVIDER_ANYSCALE: FinetuneJobAnyScale, - _PROVIDER_OPENAI: FinetuneJobOpenAI, - } - return provider_to_job_class[provider] - - -def get_provider_finetune_function(provider: str) -> callable: - """Return the finetune function for the given model.""" - provider_to_finetune_function = { - _PROVIDER_ANYSCALE: finetune_anyscale, - _PROVIDER_OPENAI: finetune_openai, - } - return provider_to_finetune_function[provider] - - -# Note: Type of LM should be LM. We aren't importing it here to avoid -# circular imports. -def execute_finetune_job(job: FinetuneJob, lm: Any, cache_finetune: bool = True): - """Execute the finetune job in a blocking manner.""" - try: - job_kwargs = job.get_kwargs() - if cache_finetune: - model = cached_finetune(job=job, **job_kwargs) - else: - model = finetune(job=job, **job_kwargs) - lm = lm.copy(model=model) - job.set_result(lm) - except Exception as err: - logger.error(err) - job.set_result(err) - - -# TODO: Add DiskCache, ignore job -def cached_finetune( - job, - model: str, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, - provider: str = "openai", -) -> Union[str, Exception]: - return finetune( - job=job, - model=model, - train_data=train_data, - train_kwargs=train_kwargs, - train_method=train_method, - provider=provider, - ) - - -def finetune( - job, - model: str, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, - provider: str = "openai", -) -> Union[str, Exception]: - """Fine-tune a new model based on the given model.""" - # Get the fine-tuning provider - try: - # Get the finetune function - provider_finetune_function = get_provider_finetune_function(provider) - - # Fine-tune a new model based on the given model - model = provider_finetune_function( - job=job, - model=model, - train_data=train_data, - train_kwargs=train_kwargs, - train_method=train_method, - ) - except Exception as err: - raise err - - return model diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py index f5d08c283..77a540a12 100644 --- a/dspy/clients/openai.py +++ b/dspy/clients/openai.py @@ -1,63 +1,67 @@ -import logging import re import time -from collections import defaultdict -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import openai -from dspy.clients.finetune import ( - FinetuneJob, - TrainingMethod, - TrainingStatus, - save_data, - validate_finetune_data, -) - -# Provider name -PROVIDER_OPENAI = "openai" - -logger = logging.getLogger(__name__) - - -def is_openai_model(model: str) -> bool: - """Check if the model is an OpenAI model.""" - # Filter the provider_prefix, if exists - provider_prefix = f"{PROVIDER_OPENAI}/" - if model.startswith(provider_prefix): - model = model[len(provider_prefix) :] - - client = openai.OpenAI() - valid_model_names = [model.id for model in client.models.list().data] - # Check if the model is a base OpenAI model - if model in valid_model_names: - return True - - # Check if the model is a fine-tuned OpenAI model. Fine-tuned OpenAI models - # have the prefix "ft::", followed by a string specifying - # the fine-tuned model. The following RegEx pattern is used to match the - # base model name. - # TODO: This part can be updated to match the actual fine-tuned model names - # by making a call to the OpenAI API to be more exact, but this might - # require an API key with the right permissions. - match = re.match(r"ft:([^:]+):", model) - if match and match.group(1) in valid_model_names: - return True - - return False +from dspy.clients.provider import TrainingJob, Provider +from dspy.clients.utils_finetune import DataFormat, TrainingStatus, save_data + + +_OPENAI_MODELS = [ + 'gpt-4-turbo', + 'gpt-4-turbo-2024-04-09', + 'tts-1', + 'tts-1-1106', + 'chatgpt-4o-latest', + 'dall-e-2', + 'whisper-1', + 'gpt-3.5-turbo-instruct', + 'gpt-3.5-turbo', + 'gpt-3.5-turbo-0125', + 'babbage-002', + 'davinci-002', + 'gpt-4o-mini-2024-07-18', + 'gpt-4o', + 'dall-e-3', + 'gpt-4o-mini', + 'gpt-4o-2024-08-06', + 'gpt-4o-2024-05-13', + 'o1-preview', + 'gpt-4o-audio-preview-2024-10-01', + 'o1-mini-2024-09-12', + 'gpt-4o-audio-preview', + 'tts-1-hd', + 'tts-1-hd-1106', + 'o1-preview-2024-09-12', + 'o1-mini', + 'gpt-4-1106-preview', + 'text-embedding-ada-002', + 'gpt-3.5-turbo-16k', + 'text-embedding-3-small', + 'text-embedding-3-large', + 'gpt-4o-realtime-preview-2024-10-01', + 'gpt-4o-realtime-preview', + 'gpt-3.5-turbo-1106', + 'gpt-4-0613', + 'gpt-4-turbo-preview', + 'gpt-4-0125-preview', + 'gpt-4', + 'gpt-3.5-turbo-instruct-0914' +] -class FinetuneJobOpenAI(FinetuneJob): +class TrainingJobOpenAI(TrainingJob): def __init__(self, *args, **kwargs): - self.provider_file_id = None # TODO: Can we get this using the job_id? - self.provider_job_id = None super().__init__(*args, **kwargs) + self.provider_file_id = None + self.provider_job_id = None def cancel(self): # Cancel the provider job - if _does_job_exist(self.provider_job_id): - status = _get_training_status(self.provider_job_id) - if _is_terminal_training_status(status): + if OpenAIProvider.does_job_exist(self.provider_job_id): + status = self.status() + if OpenAIProvider.is_terminal_training_status(status): err_msg = "Jobs that are complete cannot be canceled." err_msg += f" Job with ID {self.provider_job_id} is done." raise Exception(err_msg) @@ -65,9 +69,8 @@ def cancel(self): self.provider_job_id = None # Delete the provider file - # TODO: Should there be a separate clean method? if self.provider_file_id is not None: - if _does_file_exist(self.provider_file_id): + if OpenAIProvider.does_file_exist(self.provider_file_id): openai.files.delete(self.provider_file_id) self.provider_file_id = None @@ -75,286 +78,190 @@ def cancel(self): super().cancel() def status(self) -> TrainingStatus: - status = _get_training_status(self.provider_job_id) + status = OpenAIProvider.get_training_status(self.provider_job_id) return status -def finetune_openai( - job: FinetuneJobOpenAI, - model: str, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, -) -> str: - train_kwargs = train_kwargs or {} - train_method = TrainingMethod.SFT # Note: This could be an argument; ignoring method - - # Validate train data and method - logger.info("[Finetune] Validating the formatting of the data") - _validate_data(train_data, train_method) - logger.info("[Finetune] Done!") - - # Convert to the OpenAI format - logger.info("[Finetune] Converting the data to the OpenAI format") - # TODO: Should we use the system prompt? - train_data = _convert_data(train_data) - logger.info("[Finetune] Done!") - - # Save to a file - logger.info("[Finetune] Saving the data to a file") - data_path = save_data(train_data, provider_name=PROVIDER_OPENAI) - logger.info("[Finetune] Done!") - - # Upload the data to the cloud - logger.info("[Finetune] Uploading the data to the provider") - provider_file_id = _upload_data(data_path) - job.provider_file_id = provider_file_id - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Start remote training") - # We utilize model and train_kwargs here - provider_job_id = _start_remote_training( - train_file_id=job.provider_file_id, - model=model, - train_kwargs=train_kwargs, - ) - job.provider_job_id = provider_job_id - # job.provider_job_id = "ftjob-ZdEL1mUDk0dwdDuZJQOng8Vv" - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Wait for training to complete") - # TODO: Would it be possible to stream the logs? - _wait_for_job(job) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Get trained model if the run was a success") - model = _get_trained_model(job) - logger.info("[Finetune] Done!") - - return model - - -_SUPPORTED_TRAINING_METHODS = [ - TrainingMethod.SFT, -] - - -def _get_training_status(job_id: str) -> Union[TrainingStatus, Exception]: - # TODO: Should this type be shared across all fine-tune clients? - provider_status_to_training_status = { - "validating_files": TrainingStatus.pending, - "queued": TrainingStatus.pending, - "running": TrainingStatus.running, - "succeeded": TrainingStatus.succeeded, - "failed": TrainingStatus.failed, - "cancelled": TrainingStatus.cancelled, - } - - # Check if there is an active job - if job_id is None: - logger.info("There is no active job.") - return TrainingStatus.not_started - - err_msg = f"Job with ID {job_id} does not exist." - assert _does_job_exist(job_id), err_msg - - # Retrieve the provider's job and report the status - provider_job = openai.fine_tuning.jobs.retrieve(job_id) - provider_status = provider_job.status - status = provider_status_to_training_status[provider_status] - - return status - - -def _does_job_exist(job_id: str) -> bool: - try: - # TODO: Error handling is vague - openai.fine_tuning.jobs.retrieve(job_id) - return True - except Exception: - return False - +class OpenAIProvider(Provider): + + def __init__(self): + super().__init__() + self.finetunable = True + self.TrainingJob = TrainingJobOpenAI + + @staticmethod + def is_provider_model(model: str) -> bool: + # Filter the provider_prefix, if exists + provider_prefix = "openai/" + if model.startswith(provider_prefix): + model = model[len(provider_prefix):] + + # Check if the model is a base OpenAI model + # TODO(enhance) The following list can be replaced with + # openai.models.list(), but doing so might require a key. Is there a + # way to get the list of models without a key? + valid_model_names = _OPENAI_MODELS + if model in valid_model_names: + return True + + # Check if the model is a fine-tuned OpneAI model. Fine-tuned OpenAI + # models have the prefix "ft::", followed by a string + # specifying the fine-tuned model. The following RegEx pattern is used + # to match the base model name. + # TODO(enhance): This part can be updated to match the actual fine-tuned + # model names by making a call to the OpenAI API to be more exact, but + # this might require an API key with the right permissions. + match = re.match(r"ft:([^:]+):", model) + if match and match.group(1) in valid_model_names: + return True -def _does_file_exist(file_id: str) -> bool: - try: - # TODO: Error handling is vague - openai.files.retrieve(file_id) - return True - except Exception: return False - -def _is_terminal_training_status(status: TrainingStatus) -> bool: - return status in [ - TrainingStatus.succeeded, - TrainingStatus.failed, - TrainingStatus.cancelled, - ] - - -def _validate_data(data: Dict[str, str], train_method: TrainingMethod) -> Optional[Exception]: - # Check if this train method is supported - if train_method not in _SUPPORTED_TRAINING_METHODS: - err_msg = f"OpenAI does not support the training method {train_method}." - raise ValueError(err_msg) - - validate_finetune_data(data, train_method) - - -def _convert_data( - data: List[Dict[str, str]], - system_prompt: Optional[str] = None, -) -> Union[List[Dict[str, Any]], Exception]: - # Item-wise conversion function - def _row_converter(d): - messages = [{"role": "user", "content": d["prompt"]}, {"role": "assistant", "content": d["completion"]}] - if system_prompt: - messages.insert(0, {"role": "system", "content": system_prompt}) - messages_dict = {"messages": messages} - return messages_dict - - # Convert the data to the OpenAI format; validate the converted data - converted_data = list(map(_row_converter, data)) - openai_data_validation(converted_data) - return converted_data - - -def _upload_data(data_path: str) -> str: - # Upload the data to the provider - provider_file = openai.files.create( - file=open(data_path, "rb"), - purpose="fine-tune", - ) - return provider_file.id - - -def _start_remote_training(train_file_id: str, model: id, train_kwargs: Optional[Dict[str, Any]] = None) -> str: - train_kwargs = train_kwargs or {} - provider_job = openai.fine_tuning.jobs.create( - model=model, - training_file=train_file_id, - hyperparameters=train_kwargs, - ) - return provider_job.id - - -def _wait_for_job( - job: FinetuneJobOpenAI, - poll_frequency: int = 60, -): - while not _is_terminal_training_status(job.status()): - time.sleep(poll_frequency) - - -def _get_trained_model(job): - status = job.status() - if status != TrainingStatus.succeeded: - err_msg = f"Job status is {status}." - err_msg += f" Must be {TrainingStatus.succeeded} to retrieve the model." - logger.error(err_msg) - raise Exception(err_msg) - - provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id) - finetuned_model = provider_job.fine_tuned_model - return finetuned_model - - -# Adapted from https://cookbook.openai.com/examples/chat_finetuning_data_prep -def openai_data_validation(dataset: List[dict[str, Any]]): - format_errors = defaultdict(int) - for ex in dataset: - if not isinstance(ex, dict): - format_errors["data_type"] += 1 - continue - - messages = ex.get("messages", None) - if not messages: - format_errors["missing_messages_list"] += 1 - continue - - for message in messages: - if "role" not in message or "content" not in message: - format_errors["message_missing_key"] += 1 - - if any(k not in ("role", "content", "name", "function_call", "weight") for k in message): - format_errors["message_unrecognized_key"] += 1 - - if message.get("role", None) not in ("system", "user", "assistant", "function"): - format_errors["unrecognized_role"] += 1 - - content = message.get("content", None) - function_call = message.get("function_call", None) - - if (not content and not function_call) or not isinstance(content, str): - format_errors["missing_content"] += 1 - - if not any(message.get("role", None) == "assistant" for message in messages): - format_errors["example_missing_assistant_message"] += 1 - - # Raise an error if there are any format errors - if format_errors: - err_msg = "Found errors in the dataset format using the OpenAI API." - err_msg += " Here are the number of datapoints for each error type:" - for k, v in format_errors.items(): - err_msg += "\n {k}: {v}" - raise ValueError(err_msg) - - -def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]: - n_missing_system = 0 - n_missing_user = 0 - n_messages = [] - convo_lens = [] - assistant_message_lens = [] - - for ex in dataset: - messages = ex["messages"] - if not any(message["role"] == "system" for message in messages): - n_missing_system += 1 - if not any(message["role"] == "user" for message in messages): - n_missing_user += 1 - n_messages.append(len(messages)) - convo_lens.append(num_tokens_from_messages(messages)) - assistant_message_lens.append(num_assistant_tokens_from_messages(messages)) - n_too_long = sum([length > 16385 for length in convo_lens]) - - if n_too_long > 0: - logger.info( - f"There are {n_too_long} examples that may be over the 16,385 token limit, they will be truncated during fine-tuning." + @staticmethod + def finetune( + job: TrainingJobOpenAI, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[DataFormat] = None, + ) -> str: + print("[OpenAI Provider] Validating the data format") + OpenAIProvider.validate_data_format(data_format) + + print("[OpenAI Provider] Saving the data to a file") + data_path = save_data(train_data) + print(f"[OpenAI Provider] Data saved to {data_path}") + + print("[OpenAI Provider] Uploading the data to the provider") + provider_file_id = OpenAIProvider.upload_data(data_path) + job.provider_file_id = provider_file_id + + print("[OpenAI Provider] Starting remote training") + provider_job_id = OpenAIProvider.start_remote_training( + train_file_id=job.provider_file_id, + model=model, + train_kwargs=train_kwargs, ) + job.provider_job_id = provider_job_id + print(f"[OpenAI Provider] Job started with the OpenAI Job ID {provider_job_id}") + + print("[OpenAI Provider] Waiting for training to complete") + # TODO(feature): Could we stream OAI logs? + OpenAIProvider.wait_for_job(job) + + print("[OpenAI Provider] Attempting to retrieve the trained model") + model = OpenAIProvider.get_trained_model(job) + print(f"[OpenAI Provider] Model retrieved: {model}") + + return model + + @staticmethod + def does_job_exist(job_id: str) -> bool: + try: + # TODO(nit): This call may fail for other reasons. We should check + # the error message to ensure that the job does not exist. + openai.fine_tuning.jobs.retrieve(job_id) + return True + except Exception: + return False + + @staticmethod + def does_file_exist(file_id: str) -> bool: + try: + # TODO(nit): This call may fail for other reasons. We should check + # the error message to ensure that the file does not exist. + openai.files.retrieve(file_id) + return True + except Exception: + return False + + + @staticmethod + def is_terminal_training_status(status: TrainingStatus) -> bool: + return status in [ + TrainingStatus.succeeded, + TrainingStatus.failed, + TrainingStatus.cancelled, + ] + + @staticmethod + def get_training_status(job_id: str) -> TrainingStatus: + provider_status_to_training_status = { + "validating_files": TrainingStatus.pending, + "queued": TrainingStatus.pending, + "running": TrainingStatus.running, + "succeeded": TrainingStatus.succeeded, + "failed": TrainingStatus.failed, + "cancelled": TrainingStatus.cancelled, + } + + # Check if there is an active job + if job_id is None: + print("There is no active job.") + return TrainingStatus.not_started + + err_msg = f"Job with ID {job_id} does not exist." + assert OpenAIProvider.does_job_exist(job_id), err_msg + + # Retrieve the provider's job and report the status + provider_job = openai.fine_tuning.jobs.retrieve(job_id) + provider_status = provider_job.status + status = provider_status_to_training_status[provider_status] - if n_missing_system > 0: - logger.info(f"There are {n_missing_system} examples that are missing a system message.") - - if n_missing_user > 0: - logger.info(f"There are {n_missing_user} examples that are missing a user message.") - - return convo_lens - - -def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1): - import tiktoken - - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 - return num_tokens - - -def num_assistant_tokens_from_messages(messages): - import tiktoken - - encoding = tiktoken.get_encoding("cl100k_base") + return status - num_tokens = 0 - for message in messages: - if message["role"] == "assistant": - num_tokens += len(encoding.encode(message["content"])) - return num_tokens + @staticmethod + def validate_data_format(data_format: DataFormat): + supported_data_formats = [ + DataFormat.chat, + DataFormat.completion, + ] + if data_format not in supported_data_formats: + err_msg = f"OpenAI does not support the data format {data_format}." + raise ValueError(err_msg) + + @staticmethod + def upload_data(data_path: str) -> str: + # Upload the data to the provider + provider_file = openai.files.create( + file=open(data_path, "rb"), + purpose="fine-tune", + ) + return provider_file.id + + @staticmethod + def start_remote_training( + train_file_id: str, + model: id, + train_kwargs: Optional[Dict[str, Any]] = None + ) -> str: + train_kwargs = train_kwargs or {} + provider_job = openai.fine_tuning.jobs.create( + model=model, + training_file=train_file_id, + hyperparameters=train_kwargs, + ) + return provider_job.id + + @staticmethod + def wait_for_job( + job: TrainingJobOpenAI, + poll_frequency: int = 20, + ): + done = False + while not done: + done = OpenAIProvider.is_terminal_training_status(job.status()) + time.sleep(poll_frequency) + + + @staticmethod + def get_trained_model(job): + status = job.status() + if status != TrainingStatus.succeeded: + err_msg = f"Job status is {status}." + err_msg += f" Must be {TrainingStatus.succeeded} to retrieve model." + raise Exception(err_msg) + + provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id) + finetuned_model = provider_job.fine_tuned_model + return finetuned_model diff --git a/dspy/clients/provider.py b/dspy/clients/provider.py new file mode 100644 index 000000000..9eb02be1f --- /dev/null +++ b/dspy/clients/provider.py @@ -0,0 +1,67 @@ +from concurrent.futures import Future +from abc import abstractmethod +from threading import Thread +from typing import Any, Dict, List, Optional + +from dspy.clients.utils_finetune import DataFormat + + +class TrainingJob(Future): + def __init__( + self, + thread: Thread, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[DataFormat] = None, + ): + self.thread = thread + self.model = model + self.train_data = train_data + self.train_kwargs = train_kwargs or {} + self.data_format = data_format + super().__init__() + + # Subclasses should override the cancel method to cancel the job; then call + # the super's cancel method so that the future can be cancelled. + def cancel(self): + super().cancel() + + @abstractmethod + def status(self): + raise NotImplementedError + + +class Provider: + + def __init__(self): + self.finetunable = False + self.TrainingJob = TrainingJob + + @staticmethod + def is_provider_model(model: str) -> bool: + # Subclasses should actually check whether a model is supported if they + # want to have the model provider auto-discovered. + return False + + @staticmethod + def launch(model: str, launch_kwargs: Optional[Dict[str, Any]]=None): + msg = f"`launch()` is called for the auto-launched model `{model}`" + msg += " -- no action is taken!" + print(msg) + + @staticmethod + def kill(model: str, launch_kwargs: Optional[Dict[str, Any]]=None): + msg = f"`kill()` is called for the auto-launched model `{model}`" + msg += " -- no action is taken!" + print(msg) + + @staticmethod + def finetune( + job: TrainingJob, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[DataFormat] = None, + ) -> str: + raise NotImplementedError diff --git a/dspy/clients/utils_finetune.py b/dspy/clients/utils_finetune.py new file mode 100644 index 000000000..7a8b6b4d2 --- /dev/null +++ b/dspy/clients/utils_finetune.py @@ -0,0 +1,147 @@ +import os +from enum import Enum +from typing import Any, Dict, List, Optional + +import ujson +from datasets.fingerprint import Hasher + +import dspy +from dspy.adapters.base import Adapter +from dspy.utils.caching import create_subdir_in_cachedir + + +class DataFormat(str, Enum): + chat = "chat" + completion = "completion" + + +class TrainingStatus(str, Enum): + not_started = "not_started" + pending = "pending" + running = "running" + succeeded = "succeeded" + failed = "failed" + cancelled = "cancelled" + + +def get_finetune_directory() -> str: + return create_subdir_in_cachedir(subdir="finetune") + + +def infer_data_format(adapter: Adapter) -> str: + if isinstance(adapter, dspy.ChatAdapter): + return DataFormat.chat + raise ValueError(f"Could not infer the data format for: {adapter}") + + +def write_lines(file_path, data): + with open(file_path, "w") as f: + for item in data: + f.write(ujson.dumps(item) + "\n") + + +def save_data( + data: List[Dict[str, Any]], +) -> str: + # Assign a unique name to the file based on the data hash + hash = Hasher.hash(data) + file_name = f"{hash}.jsonl" + + finetune_dir = get_finetune_directory() + file_path = os.path.join(finetune_dir, file_name) + file_path = os.path.abspath(file_path) + with open(file_path, "w") as f: + for item in data: + f.write(ujson.dumps(item) + "\n") + return file_path + + +def validate_data_format( + data: List[Dict[str, Any]], + data_format: DataFormat, + ): + find_err_funcs = { + DataFormat.chat: find_data_error_chat, + DataFormat.completion: find_data_errors_completion + } + err = f"Data format {data_format} is not supported." + assert data_format in find_err_funcs, err + find_err_func = find_err_funcs[data_format] + + if not isinstance(data, list): + err = f"Data is not a list. Found data type: {type(data)}" + raise ValueError(err) + + data_dict_errors = [] + for ind, data_dict in enumerate(data): + err = f"Not a dictionary -- found data type: {type(data_dict)}" + if isinstance(data_dict, dict): + err = find_err_func(data_dict) + if err: + err_dict = dict(index=ind, error=err) + data_dict_errors.append(err_dict) + + if data_dict_errors: + finetune_dir = get_finetune_directory() + log_path = os.path.join(finetune_dir, "data_format_errors.log") + log_path = os.path.abspath(log_path) + write_lines(log_path, data_dict_errors) + + err = f"Data format errors found. For more details, see the log file: {log_path}" + raise ValueError(err) + + +def find_data_errors_completion( + data_dict: Dict[str, str] + ) -> Optional[str]: + keys = ["prompt", "completion"] + + assert isinstance(data_dict, dict) + expected_keys = sorted(keys) + found_keys = sorted(data_dict.keys()) + if set(expected_keys) != set(found_keys): + return f"Expected Keys: {expected_keys}; Found Keys: {found_keys}" + + for key in keys: + if not isinstance(data_dict[key], str): + return f"Expected `{key}` to be of type `str`. Found: {type(data_dict[key])}" + + +# Following functions are modified from the OpenAI cookbook: +# https://cookbook.openai.com/examples/chat_finetuning_data_prep +def find_data_error_chat( + messages: Dict[str, Any] + ) -> Optional[str]: + assert isinstance(messages, dict) + + expected_keys = ["messages"] + found_keys = sorted(messages.keys()) + if set(expected_keys) != set(found_keys): + return f"Expected Keys: {expected_keys}; Found Keys: {found_keys}" + + if not isinstance(messages["messages"], list): + return f"The value of the `messages` key should be a list instance. Found: {type(messages['messages'])}" + + for ind, message in enumerate(messages["messages"]): + err = find_data_error_chat_message(message) + if err: + return f"Error in message at index {ind}: {err}" + + +def find_data_error_chat_message( + message: Dict[str, Any] + ) -> Optional[str]: + assert isinstance(message, dict) + + message_keys = sorted(["role", "content"]) + found_keys = sorted(message.keys()) + if set(message_keys) != set(found_keys): + return f"Expected Keys: {message_keys}; Found Keys: {found_keys}" + + expected_roles = sorted(["assistant", "system", "user"]) + found_role = message["role"] + if found_role not in expected_roles: + return f"Expected Roles: {expected_roles}; Found Role: {found_role}" + + if not isinstance(message["content"], str): + return f"Expected Content Type: `str`; Found Content Type: {type(message['content'])}" diff --git a/dspy/functional/functional.py b/dspy/functional/functional.py index 03d859763..7beda5eae 100644 --- a/dspy/functional/functional.py +++ b/dspy/functional/functional.py @@ -261,16 +261,21 @@ def parse(x): and typing.get_origin(type_) not in (list, tuple) # To support Python 3.9 and issubclass(type_, pydantic.BaseModel) ): - to_json = lambda x: x.model_dump_json() - from_json = lambda x, type_=type_: type_.model_validate_json(x) + def to_json(x): + return x.model_dump_json() + def from_json(x, type_=type_): + return type_.model_validate_json(x) schema = json.dumps(type_.model_json_schema()) else: adapter = pydantic.TypeAdapter(type_) - to_json = lambda x: adapter.serializer.to_json(x) - from_json = lambda x, type_=adapter: type_.validate_json(x) + def to_json(x): + return adapter.serializer.to_json(x) + def from_json(x, type_=adapter): + return type_.validate_json(x) schema = json.dumps(adapter.json_schema()) if self.wrap_json: - to_json = lambda x, inner=to_json: "```json\n" + inner(x) + "\n```\n" + def to_json(x, inner=to_json): + return "```json\n" + inner(x) + "\n```\n" schema = "```json\n" + schema + "\n```" signature = signature.with_updated_fields( name, @@ -283,24 +288,27 @@ def parse(x): ) else: # If input field is_json = False - format_ = lambda x: x if isinstance(x, str) else str(x) + def format_(x): + return x if isinstance(x, str) else str(x) if type_ in (List[str], list[str], Tuple[str], tuple[str]): format_ = passages2text # Special formatting for lists of known types. Maybe the output fields sohuld have this too? elif typing.get_origin(type_) in (List, list, Tuple, tuple): (inner_type,) = typing.get_args(type_) if inspect.isclass(inner_type) and issubclass(inner_type, pydantic.BaseModel): - format_ = ( - lambda x: x if isinstance(x, str) else "[" + ",".join(i.model_dump_json() for i in x) + "]" - ) + def format_(x): + return x if isinstance(x, str) else "[" + ",".join(i.model_dump_json() for i in x) + "]" else: - format_ = lambda x: x if isinstance(x, str) else json.dumps(x) + def format_(x): + return x if isinstance(x, str) else json.dumps(x) is_json = True elif inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): - format_ = lambda x: x if isinstance(x, str) else x.model_dump_json() + def format_(x): + return x if isinstance(x, str) else x.model_dump_json() is_json = True if self.wrap_json and is_json: - format_ = lambda x, inner=format_: x if isinstance(x, str) else "```json\n" + inner(x) + "\n```\n" + def format_(x, inner=format_): + return x if isinstance(x, str) else "```json\n" + inner(x) + "\n```\n" signature = signature.with_updated_fields(name, format=format_) return signature diff --git a/dspy/predict/aggregation.py b/dspy/predict/aggregation.py index 4cb6df3f9..bc37e203f 100644 --- a/dspy/predict/aggregation.py +++ b/dspy/predict/aggregation.py @@ -1,7 +1,8 @@ from dsp.utils import normalize_text from dspy.primitives.prediction import Completions, Prediction -default_normalize = lambda s: normalize_text(s) or None +def default_normalize(s): + return normalize_text(s) or None def majority(prediction_or_completions, normalize=default_normalize, field=None): @@ -12,7 +13,7 @@ def majority(prediction_or_completions, normalize=default_normalize, field=None) """ assert any(isinstance(prediction_or_completions, t) for t in [Prediction, Completions, list]) - input_type = type(prediction_or_completions) + type(prediction_or_completions) # Get the completions if isinstance(prediction_or_completions, Prediction): diff --git a/dspy/predict/chain_of_thought_with_hint.py b/dspy/predict/chain_of_thought_with_hint.py index 484a2a38a..fa33ca1c6 100644 --- a/dspy/predict/chain_of_thought_with_hint.py +++ b/dspy/predict/chain_of_thought_with_hint.py @@ -1,36 +1,30 @@ -import dsp import dspy -from .predict import Predict +from .predict import Module # TODO: FIXME: Insert this right before the *first* output field. Also rewrite this to use the new signature system. -class ChainOfThoughtWithHint(Predict): - def __init__(self, signature, rationale_type=None, activated=True, **config): - super().__init__(signature, **config) - self.activated = activated - signature = self.signature - - *keys, last_key = signature.fields.keys() - rationale_type = rationale_type or dspy.OutputField( - prefix="Reasoning: Let's think step by step in order to", - desc="${produce the " + last_key + "}. We ...", - ) - self.extended_signature1 = self.signature.insert(-2, "rationale", rationale_type, type_=str) - - DEFAULT_HINT_TYPE = dspy.OutputField() - self.extended_signature2 = self.extended_signature1.insert(-2, "hint", DEFAULT_HINT_TYPE, type_=str) +class ChainOfThoughtWithHint(Module): + def __init__(self, signature, rationale_type=None, **config): + self.signature = dspy.ensure_signature(signature) + self.module = dspy.ChainOfThought(signature, rationale_type=rationale_type, **config) def forward(self, **kwargs): - signature = self.signature - - if self.activated is True or (self.activated is None and isinstance(dsp.settings.lm, dsp.GPT3)): - if 'hint' in kwargs and kwargs['hint']: - signature = self.extended_signature2 - else: - signature = self.extended_signature1 + if 'hint' in kwargs and kwargs['hint']: + hint = f"\n\t\t(secret hint: {kwargs.pop('hint')})" + original_kwargs = kwargs.copy() + + # Convert the first field's value to string and append the hint + last_key = list(self.signature.input_fields.keys())[-1] + kwargs[last_key] = str(kwargs[last_key]) + hint + + # Run CoT then update the trace with original kwargs, i.e. without the hint. + pred = self.module(**kwargs) + this_trace = dspy.settings.trace[-1] + dspy.settings.trace[-1] = (this_trace[0], original_kwargs, this_trace[2]) + return pred - return super().forward(signature=signature, **kwargs) + return self.module(**kwargs) """ diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py index 88e2a431d..ec3ba4679 100644 --- a/dspy/predict/retry.py +++ b/dspy/predict/retry.py @@ -51,7 +51,7 @@ def forward(self, *, past_outputs, **kwargs): return self.original_forward(**kwargs) def __call__(self, **kwargs): - cached_kwargs = copy.deepcopy(kwargs) + copy.deepcopy(kwargs) kwargs["_trace"] = False kwargs.setdefault("demos", self.demos if self.demos is not None else []) diff --git a/dspy/primitives/assertions.py b/dspy/primitives/assertions.py index d3f3c4cd5..a878effe6 100644 --- a/dspy/primitives/assertions.py +++ b/dspy/primitives/assertions.py @@ -224,7 +224,7 @@ def wrapper(*args, **kwargs): except (DSPySuggestionError, DSPyAssertionError) as e: if not current_error: current_error = e - error_id, error_msg, error_target_module, error_state = ( + _error_id, error_msg, error_target_module, error_state = ( e.id, e.msg, e.target_module, @@ -273,7 +273,7 @@ def wrapper(*args, **kwargs): ) # save latest failure trace for predictor per suggestion - error_ip = error_state[1] + error_state[1] error_op = error_state[2].__dict__["_store"] error_op.pop("_assert_feedback", None) error_op.pop("_assert_traces", None) diff --git a/dspy/teleprompt/__init__.py b/dspy/teleprompt/__init__.py index 6e54b5676..5952c69fe 100644 --- a/dspy/teleprompt/__init__.py +++ b/dspy/teleprompt/__init__.py @@ -1,8 +1,9 @@ from .avatar_optimizer import * +from .bettertogether import BetterTogether from .bootstrap import * +from .bootstrap_finetune import BootstrapFinetune from .copro_optimizer import COPRO from .ensemble import * -from .finetune import * from .knn_fewshot import * from .mipro_optimizer import MIPRO from .mipro_optimizer_v2 import MIPROv2 diff --git a/dspy/teleprompt/bettertogether.py b/dspy/teleprompt/bettertogether.py new file mode 100644 index 000000000..42132baf0 --- /dev/null +++ b/dspy/teleprompt/bettertogether.py @@ -0,0 +1,170 @@ +import random +from typing import Callable, List, Optional + +import dspy +from dspy.clients.lm import LM +from dspy.primitives.example import Example +from dspy.primitives.program import Program +from dspy.teleprompt.teleprompt import Teleprompter + + +from dspy.teleprompt.bootstrap_finetune import ( + BootstrapFinetune, set_missing_predictor_lms, prepare_student +) +from dspy.teleprompt.random_search import BootstrapFewShotWithRandomSearch + + +class BetterTogether(Teleprompter): + + STRAT_SEP = " -> " + + def __init__(self, + metric: Callable, + prompt_optimizer: Optional[Teleprompter] = None, + weight_optimizer: Optional[Teleprompter] = None, + seed: Optional[int] = None, + ): + err = "This is an experimental optimizer." + err += " Set `dspy.settings.experimental` to `True` to use it." + if not dspy.settings.experimental: + raise ValueError(err) + + # TODO: Note that the BetterTogether optimizer is meaningful when + # BootstrapFinetune uses a metric to filter the training data before + # fine-tuning. However, one can also choose to run this optimizer with + # a BoostrapFinetune without a metric, say, if there aren't labels + # available for the training data. Should this be noted somewhere? + # TODO: We should re-consider if the metric should be required. + self.prompt_optimizer = prompt_optimizer if prompt_optimizer else BootstrapFewShotWithRandomSearch(metric=metric) + self.weight_optimizer = weight_optimizer if weight_optimizer else BootstrapFinetune(metric=metric) + + is_supported_prompt = isinstance(self.prompt_optimizer, BootstrapFewShotWithRandomSearch) + is_supported_weight = isinstance(self.weight_optimizer, BootstrapFinetune) + if not is_supported_prompt or not is_supported_weight: + err = "The BetterTogether optimizer supports the following optimizers for now: BootstrapFinetune, BootstrapFewShotWithRandomSearch." + raise ValueError(err) + + self.rng = random.Random(seed) + + def compile( + self, + student: Program, + trainset: List[Example], + strategy: str = "p -> w -> p", + valset_ratio = 0.1, + ) -> Program: + # TODO: We could record acc on a different valset to pick the best + # strategy within the provided strategy + print("[BetterTogether] Validating the strategy") + parsed_strategy = strategy.lower().split(self.STRAT_SEP) + err = f"The strategy should be a sequence of 'p' and 'w' separated by '{self.STRAT_SEP}', but found: {strategy}" + assert all([s in ["p", "w"] for s in parsed_strategy]), err + + print("[BetterTogether] Preparing the student program...") + # TODO: Prepare student returns student.reset_copy(), which is what gets + # optimized. We should make this clear in the doc comments. + student = prepare_student(student) + set_missing_predictor_lms(student) + + print("[BetterTogether] Compiling the student program...") + student = self._run_strategies(parsed_strategy, student, trainset, valset_ratio) + + print("[BetterTogether] BetterTogether has finished compiling the student program.") + return student + + def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> Program: + # Keep track of all the partial strategies/programs in parsed_strategy + # "" corresponds to the initial student program + candidate_programs = [] + candidate_programs.append(("", student)) + + for ind, step_code in enumerate(parsed_strategy): + current_strategy = self.STRAT_SEP.join(parsed_strategy[:ind + 1]) + print(f"[BetterTogether] Step {ind + 1} of {len(parsed_strategy)} - Strategy `{current_strategy}`") + + print("[BetterTogether] Shuffling the trainset...") + self.rng.shuffle(trainset) + + # TODO: Should we reset or just deepcopy? How does resetting affect + # the predictor LMs? + student = student.deepcopy() + student._compiled = False + if step_code == "p": + student = self._compile_prompt_optimizer(student, trainset, valset_ratio) + elif step_code == "w": + student = self._compile_weight_optimizer(student, trainset) + + # Record the program corresponding to the current strategy + candidate_programs.append((current_strategy, student)) + + student.candidate_programs = candidate_programs + return student + + def _compile_prompt_optimizer(self, student, trainset, valset_ratio) -> Program: + print("[BetterTogether] Preparing for prompt optimization...") + + # Sampling a validation set from the trainset for the prompt optimizer + num_val = int(valset_ratio * len(trainset)) + prompt_valset = trainset[:num_val] + prompt_trainset = trainset[num_val:] + + print("[BetterTogether] Launching the program LMs for sampling...") + self._launch_lms(student) + + # TODO: To make this optimizer general, we need to ensure that all the + # prompt optimizers are accepting a valset or encode a way to check if + # a valset should be passed to an optimizer's compile method. + # TODO: We should ensure that the prompt optimizers in DSPy respect the + # predictor.lm attributes. In particular, + # BootstrapFewShotWithRandomSearch seems to be resetting these. We are + # manually re-setting the LMs here to circumvent this issue, but we + # should consider adressing it in BFRS. + print("[BetterTogether] Compiling the prompt optimizer...") + pred_lms = [pred.lm for pred in student.predictors()] + student = self.prompt_optimizer.compile(student, trainset=prompt_trainset, valset=prompt_valset) + for pred, lm in zip(student.predictors(), pred_lms): + pred.lm = lm + + print("[BetterTogether] Killing the LMs used for sampling...") + self._kill_lms(student) + + return student + + def _compile_weight_optimizer(self, student, trainset) -> Program: + print("[BetterTogether] Preparing for weight optimization...") + + # Saving the LMs before compiling the weight optimizer + original_lms = [pred.lm for pred in student.predictors()] + + # TODO: To make this optimizer general, we need to ensure that all the + # prompt optimizers are accepting a valset or encode a way to check if + # a valset should be passed to an optimizer's compile. + print("[BetterTogether] Compiling the weight optimizer...") + student = self.weight_optimizer.compile(student, trainset=trainset) + + # Updating the train kwargs for the new LMs. This is needed because the + # train_kwargs of the optimizer is configured for the original LMs. + new_lms = [pred.lm for pred in student.predictors()] + for original_lm, new_lm in zip(original_lms, new_lms): + original_params = self.weight_optimizer.train_kwargs[original_lm] + self.weight_optimizer.train_kwargs[new_lm] = original_params + + return student + + @staticmethod + def _get_unique_lms(program: Program) -> List[LM]: + lms = [pred.lm for pred in program.predictors()] + lms = list(set(lms)) + return lms + + @staticmethod + def _launch_lms(program: Program): + lms = BetterTogether._get_unique_lms(program) + for lm in lms: + lm.launch() + + @staticmethod + def _kill_lms(program: Program): + lms = BetterTogether._get_unique_lms(program) + for lm in lms: + lm.kill() diff --git a/dspy/teleprompt/bootstrap_finetune.py b/dspy/teleprompt/bootstrap_finetune.py new file mode 100644 index 000000000..418c8f0a4 --- /dev/null +++ b/dspy/teleprompt/bootstrap_finetune.py @@ -0,0 +1,287 @@ +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Union + +import dspy +from dspy import LM # TODO: Remove after the old LM class is removed +from dspy.adapters.base import Adapter +from dspy.clients.utils_finetune import infer_data_format +from dspy.evaluate.evaluate import Evaluate +from dspy.primitives.example import Example +from dspy.predict.predict import Predict +from dspy.primitives.program import Program +from dspy.teleprompt.teleprompt import Teleprompter + + +class FinetuneTeleprompter(Teleprompter): + + def __init__( + self, + train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None, + ): + self.train_kwargs: Dict[LM, Any] = self.convert_to_lm_dict(train_kwargs or {}) + + @staticmethod + def convert_to_lm_dict(arg) -> Dict[LM, Any]: + non_empty_dict = arg and isinstance(arg, dict) + if non_empty_dict and all(isinstance(k, LM) for k in arg.keys()): + return arg + # Default to using the same value for all LMs + return defaultdict(lambda: arg) + + +class BootstrapFinetune(FinetuneTeleprompter): + + # TODO(PR) check with team + def __init__( + self, + metric: Optional[Callable] = None, + multitask: bool = True, + train_kwargs: Optional[Union[Dict[str, Any], Dict[LM, Dict[str, Any]]]] = None, + adapter: Optional[Union[Adapter, Dict[LM, Adapter]]] = None, + exclude_demos: bool = False, + num_threads: int = 6, + ): + # TODO(feature): Inputs train_kwargs (a dict with string keys) and + # adapter (Adapter) can depend on the LM they are used with. We are + # takingthese as parameters for the time being. However, they can be + # attached to LMs themselves -- an LM could know which adapter it should + # be used with along with the train_kwargs. This will lead the only + # required argument for LM.finetune() to be the train dataset. + err = "This is an experimental optimizer." + err += " Set `dspy.settings.experimental` to `True` to use it." + err += " Constructor arguments subject to change." + assert dspy.settings.experimental, err + + super().__init__(train_kwargs=train_kwargs) + self.metric = metric + self.multitask = multitask + self.adapter: Dict[LM, Adapter] = self.convert_to_lm_dict(adapter) + self.exclude_demos = exclude_demos + self.num_threads = num_threads + + def compile(self, student: Program, trainset: List[Example], teacher: Optional[Program] = None) -> Program: + # TODO: Print statements can be converted to logger.info if we ensure + # that the default DSPy logger logs info level messages in notebook + # environments. + print("[BootstrapFinetune] Preparing the student and teacher programs...") + student = prepare_student(student) + teacher = prepare_teacher(student, teacher) + set_missing_predictor_lms(student) + set_missing_predictor_lms(teacher) + + print("[BootstrapFinetune] Bootstrapping data...") + trace_data = bootstrap_trace_data(program=teacher, dataset=trainset, metric=self.metric, num_threads=self.num_threads) + + print("[BootstrapFinetune] Preparing the train data...") + key_to_data = {} + for pred_ind, pred in enumerate(student.predictors()): + data_pred_ind = None if self.multitask else pred_ind + training_key = (pred.lm, data_pred_ind) + if training_key not in key_to_data: + train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind) + print(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}") + finetune_kwargs = dict(lm=pred.lm, train_data=train_data, train_kwargs=self.train_kwargs[pred.lm], data_format=data_format) + key_to_data[training_key] = finetune_kwargs + + print("[BootstrapFinetune] Starting LM fine-tuning...") + # TODO(feature): We could run batches of fine-tuning jobs in sequence + # to avoid exceeding the number of threads. + err = f"BootstrapFinetune requires `num_threads` to be bigger than or equal to the number of fine-tuning jobs. There are {len(key_to_data)} fine-tuning jobs to start, but the number of threads is: {self.num_threads}! If the `multitask` flag is set to False, the number of fine-tuning jobs will be equal to the number of predictors in the student program. If the `multitask` flag is set to True, the number of fine-tuning jobs will be equal to: 1 if there is only a context LM, or the number of unique LMs attached to the predictors in the student program. In any case, the number of fine-tuning jobs will be less than or equal to the number of predictors." + assert len(key_to_data) <= self.num_threads, err + print(f"[BootstrapFinetune] {len(key_to_data)} fine-tuning job(s) to start") + key_to_lm = self.finetune_lms(key_to_data) + + print("[BootstrapFinetune] Updating the student program with the fine-tuned LMs...") + for pred_ind, pred in enumerate(student.predictors()): + data_pred_ind = None if self.multitask else pred_ind + training_key = (pred.lm, data_pred_ind) + pred.lm = key_to_lm[training_key] + # TODO: What should the correct behavior be here? Should + # BootstrapFinetune modify the prompt demos according to the + # train data? + pred.demos = [] if self.exclude_demos else pred.demos + + print("[BootstrapFinetune] BootstrapFinetune has finished compiling the student program") + student._compiled = True + return student + + @staticmethod + def finetune_lms(finetune_dict) -> Dict[Any, LM]: + num_jobs = len(finetune_dict) + print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning jobs...") + # TODO(nit) Pass an identifier to the job so that we can tell the logs + # coming from different fine-tune threads. + + key_to_job = {} + for key, finetune_kwargs in finetune_dict.items(): + lm = finetune_kwargs.pop("lm") + key_to_job[key] = lm.finetune(**finetune_kwargs) + + key_to_lm = {} + for ind, (key, job) in enumerate(key_to_job.items()): + key_to_lm[key] = job.result() + job.thread.join() + print(f"Job {ind + 1}/{num_jobs} completed.") + + return key_to_lm + + def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_ind: Optional[int] = None): + # TODO(nit) Log dataset details/size; make logs nicer + if self.metric: + print(f"[BootstrapFinetune] Collected data for {len(trace_data)} examples") + trace_data = [d for d in trace_data if d["score"]] + print(f"[BootstrapFinetune] After filtering for score, {len(trace_data)} examples remain") + + data = [] + adapter = self.adapter[lm] or lm.infer_adapter() + data_format = infer_data_format(adapter) + for item in trace_data: + for pred_ind, _ in enumerate(item['trace']): + include_data = pred_ind is None or pred_ind == pred_ind + if include_data: + call_data = build_call_data_from_trace(trace=item['trace'], pred_ind=pred_ind, adapter=adapter, exclude_demos=self.exclude_demos) + data.append(call_data) + + return data, data_format + + +def build_call_data_from_trace( + trace: List[Dict], + pred_ind: int, + adapter: Optional[Adapter] = None, + exclude_demos: bool = False, +) -> Dict[str, List[Dict[str, Any]]]: + # Find data that's relevant to the predictor + pred, inputs, outputs = trace[pred_ind] # assuming that the order is kept + + if not adapter: + # TODO(feature): A trace is collected using a particular adapter. It + # would be nice to get this adapter information from the trace (e.g. + # pred.lm.adapter) as opposed to using the inference method below. + adapter = pred.lm.infer_adapter() + + demos = [] if exclude_demos else pred.demos + call_data = adapter.format_finetune_data( + signature=pred.signature, + demos=demos, + inputs=inputs, + outputs=outputs, + ) + return call_data + + +def bootstrap_trace_data( + program: Program, + dataset: List[Example], + metric: Optional[Callable] = None, + num_threads=6, +) -> List[Dict[str, Any]]: + # Return a list of dicts with the following keys: + # example_ind, example, prediction, trace, and score (if metric != None) + evaluator = Evaluate( + devset=dataset, num_threads=num_threads, display_progress=True, + provide_traceback=True # TODO(check with team) + ) + # TODO(PR): Should "trace" not be included in the lambda function? + _metric = metric if metric else lambda example, prediction: 1 + evaluator(program, metric=_metric) + + data = [] + for example_ind, example in enumerate(dataset): + data_dict = bootstrap_trace_data_one_example( + example=example, program=program, metric=metric + ) + data_dict["example_ind"] = example_ind + data.append(data_dict) + + return data + + +# TODO(PR) check with team +def bootstrap_trace_data_one_example( + example: Example, + program: Program, + metric: Optional[Callable] = None +) -> Dict[str, Any]: + # Return a dict with the following keys: + # example, prediction, trace, and score (if metric != None) + with dspy.context(trace=[]): + prediction = program(**example.inputs()) + trace = dspy.settings.trace + score = metric(example, prediction, trace) if metric else None + + data_dict = dict( + example=example, + prediction=prediction, + trace=trace, + ) + if metric: + data_dict["score"] = score + + return data_dict + + +# Note: Shared below are useful functions for preparing student/teacher programs +# Similar methods are implemented separately and used by other DSPy +# teleprompters. These can be moved to shared locations. +def set_missing_predictor_lms(program: Program) -> Program: + # If the predictors do not have LMs, set them to the global LM + for pred in program.predictors(): + if not pred.lm: + pred.lm = dspy.settings.lm + + return program + + +def prepare_student(student: Program) -> Program: + print("Ensuring that the student is not compiled") + assert not student._compiled, "The student program should not be compiled" + + # TODO: Should we use reset_copy here? How would it affect the student + # program's predictor LMs, if they are set? + student = student.deepcopy() + return student + + +def prepare_teacher(student: Program, teacher: Program = None) -> Program: + if teacher is None: + print("No teacher provided. Using a copy of the student program as the teacher.") + return student.deepcopy() + else: + teacher = teacher.deepcopy() + + print("Ensuring that the student and teacher are are structurally equivalent.") + assert_structural_equivalency(student, teacher) + + print("Ensuring that the student and teacher programs do not share predictors.") + assert_no_shared_predictor(student, teacher) + + return teacher + + +def assert_structural_equivalency(program1: object, program2: object): + assert isinstance(program1, Program) + assert isinstance(program2, Program) + + num1 = len(program1.predictors()) + num2 = len(program2.predictors()) + err = f"Structurally equivalent programs must have the the number of predictors. The number of predictors for the two modules do not match: {num1} != {num2}" + assert num1 == num2, err + + pzip = zip(program1.named_predictors(), program2.named_predictors()) + for ind, ((name1, pred1), (name2, pred2)) in enumerate(pzip): + err = f"Program predictor names must match at corresponding indices for structural equivalency. The predictor names for the programs do not match at index {ind}: '{name1}' != '{name2}'" + assert name1 == name2, err + assert isinstance(pred1, Predict) + assert isinstance(pred2, Predict) + assert pred1.signature.equals(pred2.signature) + + +def assert_no_shared_predictor(program1: Program, program2: Program): + id_to_name1 = {id(p): n for n, p in program1.named_predictors()} + id_to_name2 = {id(p): n for n, p in program2.named_predictors()} + shared_ids = set(id_to_name1.keys()) & set(id_to_name2.keys()) + + pred_names = ", ".join(id_to_name1[id] for id in shared_ids) + err = f"The programs share the following predictor(s) with each other: {pred_names}" + assert not shared_ids, err diff --git a/dspy/teleprompt/finetune.py b/dspy/teleprompt/finetune.py deleted file mode 100644 index 4ab233479..000000000 --- a/dspy/teleprompt/finetune.py +++ /dev/null @@ -1,189 +0,0 @@ -import os -import random -import time - -import ujson -from datasets.fingerprint import Hasher - -import dsp -from dspy.signatures.signature import signature_to_template - -from .bootstrap import BootstrapFewShot - -# from dspy.primitives import Example -from .teleprompt import Teleprompter - -# from .vanilla import LabeledFewShot - -# from dspy.evaluate.evaluate import Evaluate - - -if os.environ.get("DSP_NOTEBOOK_CACHEDIR"): - training_data_directory = os.path.join(os.environ.get("DSP_NOTEBOOK_CACHEDIR"), "compiler") - print(training_data_directory) -else: - training_data_directory = "local_cache/compiler" - - -""" -TODO: Reduce and document the dependencies. - -# !pip install evaluate -# !pip install tensorboardX -# !pip install transformers[torch] -# !pip install accelerate -U -# !pip install rouge_score - - -fewshot_teleprompter = BootstrapFewShot(metric=lambda gold, prediction, trace: gold.answer == prediction.answer, - max_bootstrapped_demos=3, max_labeled_demos=16, - teacher_settings=dict(lm=turbo)) - -fewshot = fewshot_teleprompter.compile(MyMultiHop(passages_per_hop=2), trainset=trainset) - -""" - - -class BootstrapFinetune(Teleprompter): - def __init__(self, metric=None, teacher_settings={}, multitask=True): - self.metric = metric - self.teacher_settings = teacher_settings - self.multitask = multitask - - metric = metric or (lambda *args: True) - self.teleprompter = BootstrapFewShot( - metric=metric, - max_bootstrapped_demos=999999, - max_labeled_demos=0, # FIXME: TODO: Make this zero? or param, with default as 16 or 0? - teacher_settings=teacher_settings, - ) - - if not os.path.exists(training_data_directory): - os.makedirs(training_data_directory) - - def compile( - self, - student, - *, - teacher=None, - trainset, - valset=None, - target="t5-large", - bsize=12, - accumsteps=1, - lr=5e-5, - epochs=1, - bf16=False, - int8=False, - peft=False, - path_prefix=None, - ): - # It's usually better to supply a few-shot teacher, rather than uncompiled module (the student). - if teacher is None: - print( - "WARNING: Using a vanilla teacher. " - "Are you sure you want to use BootstrapFinetune without a compiled teacher?", - ) - - teachers = teacher if isinstance(teacher, list) else [teacher] - finetune_data = {} - - for teacher in teachers: - # Dummy compilation to get bootstraps. - compiled = self.teleprompter.compile(student, teacher=teacher, trainset=trainset) - multitask = self.multitask - - # Prepare finetune pairs. - for name, predictor in compiled.named_predictors(): - name_ = "all" if multitask else name - finetune_data[name_] = [] if name_ not in finetune_data else finetune_data[name_] - - for demo in predictor.demos: - demo = dict(demo) - - # TODO: FIXME: generalize. - template = signature_to_template(predictor.signature) - completion = demo.pop(template.fields[-1].output_variable) - prompt = template.query(dsp.Example(demos=[], **demo)).strip() - - finetune_data[name_].append(dict(prompt=prompt, completion=completion)) - - for name_ in finetune_data: - random.Random(0).shuffle(finetune_data[name_]) - print(name_, len(finetune_data[name_])) - - # - # Dump as files. - # - finetune_paths = {} - - for name in finetune_data: - data = finetune_data[name] - hashed_name = name + "." + Hasher.hash(data) - output_path = os.path.join(training_data_directory, f"{hashed_name}.jsonl") - print(output_path) - - with open(output_path, "w") as f: - for line in data: - f.write(ujson.dumps(line) + "\n") - - finetune_paths[name] = output_path - - # - # Train! - # - import string - - compiler_config = { - "save": "".join( - random.Random(time.time()).choices(string.ascii_uppercase + string.digits, k=13), - ), # https://stackoverflow.com/a/2257449/1493011 - "peft": peft, - "fp16": False, - "bf16": bf16, - "int8": int8, - "fid": False, - "rationale": False, - "batch_size": bsize, - "epochs": epochs, - "gradient_accumulation_steps": accumsteps, # 2, - "lr": lr, - } - - compiler_config["save"] = ( - os.path.join(path_prefix, compiler_config["save"]) if path_prefix else compiler_config["save"] - ) - - from dsp.modules.finetuning import finetune_hf - - target = target - finetune_models = {} - - for name in finetune_data: - training_data_path = finetune_paths[name] - compiler_config_ = dict(compiler_config) - compiler_config_["save"] = compiler_config["save"] + "." + name - best_ckpt_path = finetune_hf(training_data_path, target, compiler_config_) - - print(f"#> Best checkpoint path: {best_ckpt_path} for {name}") - finetune_models[name] = dsp.HFModel(model=target, checkpoint=best_ckpt_path) # best_ckpt_path - - # - # Set the LMs to the finetuned ones, per module - # - compiled2 = compiled.reset_copy() - - assert len(compiled.named_predictors()) == len(compiled2.named_predictors()) - - for (name, predictor), (name2, predictor2) in zip(compiled.named_predictors(), compiled2.named_predictors()): - assert name == name2 - name = "all" if multitask else name - - # TODO: FIXME: When we assign .lm, the Predict.forward will also set only_query=True. - # This is correct for here but we may want to make it more explicitly restricted to finetuned models. - print(f"Assigning the LM of predictor {name}.") - - predictor2.lm = finetune_models[name] - assert predictor2.demos == [] - - return compiled2 diff --git a/dspy/teleprompt/finetune_teleprompter.py b/dspy/teleprompt/finetune_teleprompter.py deleted file mode 100644 index 57246fea6..000000000 --- a/dspy/teleprompt/finetune_teleprompter.py +++ /dev/null @@ -1,148 +0,0 @@ -import logging -from typing import Any, Callable, Dict, List, Optional, Union - -import dspy -from dspy.evaluate.evaluate import Evaluate -from dspy.primitives.example import Example -from dspy.primitives.prediction import Prediction -from dspy.primitives.program import Program - -logger = logging.getLogger(__name__) - - -# TODO: Shared below are useful functions. Similar procedures are implemented -# separately and used by other DSPy teleprompters. These can be moved to shared -# locations. -def prepare_teacher(student: Program, teacher: Program = None) -> Program: - """Prepare the teacher program with respect to the student program. - Args: - student: The student program. - teacher: The teacher program. If `None`, a copy of the student program - is used as the teacher. Defaults to `None`. - """ - # If teacher is None, use a copy of the student program as the teacher - if teacher is None: - logger.info("No teacher provided. Using a copy of the student program as the teacher.") - teacher = student.deepcopy() - else: - teacher = teacher.deepcopy() - - # Ensure that the student and teacher programs have the same structure - logger.info("Ensuring that the student and teacher are are structurally equivalent.") - student._assert_structural_equivalency(teacher) - - # Ensure that the predictors of the programs point to different objects - logger.info("Ensuring that the student and teacher programs do not share predictors.") - student._assert_no_shared_predictor(teacher) - - # Ensure that the LM consistency property is satisfied - logger.info("Ensuring that the teacher program satisfies the LM consistency property.") - teacher._assert_lm_consistency() - - # If the global LM is being used, set it to the LMs of the copied teacher - # program predictors to to avoid handling the same edge cases later - if dspy.settings.lm: - teacher._set_all_predictor_lms(dspy.settings.lm) - - return teacher - - -def convert_to_module_level_message_data( - data: List[Dict], - keep_data_keys: bool = False, - exclude_demos: bool = False, - try_to_record_lm_kwargs: bool = False, - program: Program = None, -) -> List[Dict]: - """Wrapper around the function - `build_messages_from_trace`, calling it on the "trace" field - of each dictionary in the input data list and combiningin the results into - a list of prompt-completion data dictionaries.""" - - prompt_completion_data = [] - for data_dict in data: - trace = data_dict["trace"] - trace_prompt_completion_data = build_messages_from_trace( - trace=trace, exclude_demos=exclude_demos, try_to_record_lm_kwargs=try_to_record_lm_kwargs, program=program - ) - for prompt_completion_dict in trace_prompt_completion_data: - if keep_data_keys: - prompt_completion_dict = {**data_dict, **prompt_completion_dict} - prompt_completion_data.append(prompt_completion_dict) - return prompt_completion_data - - -def build_messages_from_trace( - trace: List[Dict], - exclude_demos: bool = False, - try_to_record_lm_kwargs: bool = False, - program: Program = None, -) -> Dict[str, List[Dict[str, Any]]]: - messages = [] - # If the program is provided, build the predictor index to name mapping - if program: - pred_ind_to_name = {ind: name for ind, (name, _) in enumerate(program.named_predictors())} - - # Build the prompt-completion data - - adapter = dspy.settings.adapter or dspy.ChatAdapter() - data = [] - - # TODO: Make sure that this works for multi-stage pipelines - for pred_ind, (pred, inputs, outputs) in enumerate(trace): - # Get the demos from the predictor if exclude_demos is False - demos = [] if exclude_demos else pred.demos - messages = adapter.format(pred.signature, demos, inputs) - messages.append( - adapter.format_turn(signature=pred.signature, values=outputs, role="assistant", incomplete=False) - ) - data.append(messages) - - return data - - -def bootstrap_data( - program: Program, - dataset: List[Example], - metric: Optional[Callable[[Example, Prediction, Optional[List]], Union[bool, int, float]]] = None, - num_threads=1, - max_errors: int = 0, -) -> List[Dict[str, Any]]: - """Bootstrap example, prediction, trace, example_ind, score data for the program using the dataset.""" - data = [] - - # Use Evaluate to call the program have the responses cached - cname = program.__class__.__name__ - info = f"Bootstrapping data on {len(dataset)} examples with the program {cname}, with {num_threads} threads" - logger.info(info) - evaluator = Evaluate( - devset=dataset, - num_threads=num_threads, - display_progress=True, - max_errors=max_errors, - provide_traceback=True, - ) - evaluator(program, metric=metric) - - data = [] - for example_ind, example in enumerate(dataset): - data_dict = bootstrap_one_example(example=example, example_ind=example_ind, program=program, metric=metric) - if data_dict is not None: - data.append(data_dict) - - return data - - -def bootstrap_one_example( - example: Any, example_ind: int, program: Program, metric: Optional[Callable] = None -) -> Dict[str, Any]: - with dspy.context(trace=[]): - prediction = program(**example.inputs()) - trace = dspy.settings.trace - score = metric(example, prediction, trace) if metric else None - - data_dict = {"example": example, "prediction": prediction, "trace": trace, "example_ind": example_ind} - if metric: - data_dict["score"] = score - - return data_dict diff --git a/dspy/teleprompt/mipro_optimizer_v2.py b/dspy/teleprompt/mipro_optimizer_v2.py index 6f1ffb762..08ed6ea81 100644 --- a/dspy/teleprompt/mipro_optimizer_v2.py +++ b/dspy/teleprompt/mipro_optimizer_v2.py @@ -239,7 +239,7 @@ def _set_and_validate_datasets(self, trainset: List, valset: Optional[List]): raise ValueError( "Trainset must have at least 2 examples if no valset specified." ) - valset_size = min(500, max(1, int(len(trainset) * 0.80))) + valset_size = min(1000, max(1, int(len(trainset) * 0.80))) cutoff = len(trainset) - valset_size valset = trainset[cutoff:] trainset = trainset[:cutoff] diff --git a/dspy/utils/__init__.py b/dspy/utils/__init__.py index f16061eb8..f12b34b18 100644 --- a/dspy/utils/__init__.py +++ b/dspy/utils/__init__.py @@ -1,3 +1,4 @@ from dspy.utils.callback import BaseCallback, with_callbacks from dspy.utils.dummies import * +from dspy.utils.caching import * from dspy.utils.logging_utils import * diff --git a/dspy/utils/caching.py b/dspy/utils/caching.py new file mode 100644 index 000000000..7bd0d50c4 --- /dev/null +++ b/dspy/utils/caching.py @@ -0,0 +1,14 @@ +import os +from pathlib import Path + + +_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".dspy_cache") +DSPY_CACHEDIR = os.environ.get("DSPY_CACHEDIR") or _DEFAULT_CACHE_DIR + + +def create_subdir_in_cachedir(subdir: str) -> str: + """Create a subdirectory in the DSPy cache directory.""" + subdir = os.path.join(DSPY_CACHEDIR, subdir) + subdir = os.path.abspath(subdir) + os.makedirs(subdir, exist_ok=True) + return subdir diff --git a/examples/finetune/_internal_finetune_demo.ipynb b/examples/finetune/_internal_finetune_demo.ipynb new file mode 100644 index 000000000..9b95cf301 --- /dev/null +++ b/examples/finetune/_internal_finetune_demo.ipynb @@ -0,0 +1,1505 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/scr-ssd/dilara/.miniconda3/envs/dspyprod/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "# Enable reloading on code changes\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "# Setting the environment variables\n", + "import os\n", + "\n", + "# os.environ[\"DSPY_CACHEDIR\"] =\n", + "# os.environ[\"DSP_CACHEDIR\"] =\n", + "# os.environ[\"OPENAI_API_KEY\"] =\n", + "\n", + "# Import the library\n", + "import dspy\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## I. Showcasing `LM.finetune()`" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import time" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[\"The average distance from the Earth to the Moon is about 238,855 miles (384,400 kilometers). However, this distance can vary slightly due to the Moon's elliptical orbit, ranging from approximately 225,623 miles (363,104 kilometers) at its closest (perigee) to about 252,088 miles (405,696 kilometers) at its farthest (apogee).\"]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Example call to an LM before fine-tuning\n", + "lm = dspy.LM('gpt-4o-mini-2024-07-18')\n", + "lm(\"How far is the Moon from Earth?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Using LM.finetune(), BSFT, and BetterTogether requires this flag\n", + "dspy.settings.experimental = True" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Finetune] Validating the data format" + ] + }, + { + "data": { + "text/plain": [ + "dspy.clients.openai.TrainingJobOpenAI" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "[Finetune] Saving the data to a file\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Finetune] Uploading the data to the provider\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Finetune] Start remote training\n", + "[Finetune] Wait for training to complete\n", + "[Finetune] Get trained model if the run was a success\n" + ] + } + ], + "source": [ + "# Let's construct a dummy dataset\n", + "message = {\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": \"Marv is a factual chatbot that is also sarcastic.\"},\n", + " {\"role\": \"user\", \"content\": \"How far is the Moon from Earth?\"},\n", + " {\"role\": \"assistant\", \"content\": \"384,400 kilometers\"},\n", + " ]\n", + "}\n", + "training_data = [message] * 20\n", + "\n", + "# Let's finetune the model\n", + "train_kwargs = {\n", + " \"n_epochs\": 1,\n", + "}\n", + "\n", + "job = lm.finetune(\n", + " train_data=training_data,\n", + " train_kwargs=train_kwargs,\n", + " data_format=\"chat\", # Could be left empty, inferred from \"lm.model_type\" as a default\n", + ")\n", + "type(job)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "Running the cell below immediately after the cell above returns `False`, indicating that the job is not done." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This will return False until the job is complete\n", + "job.done()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once started, a `job` object can be polled for status, assuming that a provider has implemented the status checking.\n", + "Note: It takes a bit for the `job.done()` to update once `job.status()` turns to succeeded." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TrainingStatus.pending\n", + "TrainingStatus.pending\n", + "TrainingStatus.pending\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.running\n", + "TrainingStatus.succeeded\n", + "TrainingStatus.succeeded\n", + "TrainingStatus.succeeded\n" + ] + } + ], + "source": [ + "while not job.done():\n", + " print(job.status())\n", + " time.sleep(10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Base model: gpt-4o-mini-2024-07-18\n", + "Fine-tuned model: ft:gpt-4o-mini-2024-07-18:stanford::AOHSK6y9\n" + ] + } + ], + "source": [ + "# Once the job is complete, the fine-tuned LM can be obtained via job.result()\n", + "finetuned_lm = job.result()\n", + "print(finetuned_lm)\n", + "\n", + "# We can look at the model IDs to ensure that the fine-tuned model is different\n", + "print(f\"Base model: {lm.model}\")\n", + "print(f\"Fine-tuned model: {finetuned_lm.model}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['384,400 kilometers']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# We can check how the fine-tuned LM responds to the query we used for\n", + "# fine-tuning.\n", + "finetuned_lm(\"How far is the Moon from Earth?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## II. LM fine-tuning with a custom `Provider`" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from typing import Any, Dict, List, Optional\n", + "from dspy.clients.provider import Provider, TrainingJob, DataFormat\n", + "\n", + "# Using LM.finetune(), BSFT, and BetterTogether requires this flag\n", + "dspy.settings.experimental = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we define a custom provider with a dummy fine-tune method." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomProvider(Provider):\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.finetunable = True\n", + "\n", + " @staticmethod\n", + " def finetune(\n", + " job: TrainingJob,\n", + " model: str,\n", + " train_data: List[Dict[str, Any]],\n", + " train_kwargs: Optional[Dict[str, Any]] = None,\n", + " data_format: Optional[DataFormat] = None,\n", + " ) -> str:\n", + "\n", + " # Fake fine-tuning\n", + " print(\"Fake fine-tuning has started!!\")\n", + " time.sleep(15)\n", + " print(\"Done\")\n", + "\n", + " # Return the new model name; we are hard-coding an OpenAI model as a\n", + " # demo placeholder\n", + " model = \"ft:gpt-4o-mini-2024-07-18:stanford::AMDsC653\"\n", + " return model\n", + " \n", + " # # We could also override the launch/kill methods if needed\n", + " # def launch(model: str, launch_kwargs: dict):\n", + " # pass\n", + "\n", + " # def kill(model: str, launch_kwargs: dict):\n", + " # pass\n", + "\n", + "\n", + "# We could also create a custom TrainingJob class to implement\n", + "# .status() and .cancel() methods, but we don't have to." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "`launch()` is called for the auto-launched model openai/MyAmazingCustomModel -- no action is taken!\n", + "`kill()` is called for the auto-launched model openai/MyAmazingCustomModel -- no action is taken!\n" + ] + } + ], + "source": [ + "# We could also pass launch_kwargs if this model needs to be launched before\n", + "# use, assuming that the launch and kill methods are implemented by the\n", + "# custom provider.\n", + "launch_kwargs = {\n", + " \"gpu\": 1,\n", + " \"max_prompt_length\": 1000,\n", + "}\n", + "\n", + "# Create the LM we want to fine-tune, using a dummy model name\n", + "model = \"openai/MyAmazingCustomModel\"\n", + "provider = CustomProvider()\n", + "lm = dspy.LM(model, provider=provider, launch_kwargs=launch_kwargs)\n", + "lm.launch()\n", + "\n", + "# Query the model -- commented out because the model is not real\n", + "# lm(\"How far is the Moon from Earth?\")\n", + "\n", + "# kill the model once done\n", + "lm.kill()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fake fine-tuning has started!!" + ] + }, + { + "data": { + "text/plain": [ + "dspy.clients.provider.TrainingJob" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Try fine-tune\n", + "dspy.settings.experimental = True\n", + "\n", + "# Let's construct a dummy dataset\n", + "message = {\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": \"Marv is a factual chatbot that is also sarcastic.\"},\n", + " {\"role\": \"user\", \"content\": \"How far is the Moon from Earth?\"},\n", + " {\"role\": \"assistant\", \"content\": \"384,400 kilometers\"},\n", + " ]\n", + "}\n", + "training_data = [message] * 20\n", + "\n", + "# Let's finetune the model\n", + "train_kwargs = {\n", + " \"n_epochs\": 1,\n", + "}\n", + "\n", + "job = lm.finetune(\n", + " train_data=training_data,\n", + " train_kwargs=train_kwargs,\n", + " data_format=\"chat\"\n", + ")\n", + "type(job)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0s] job.done(): False\n", + "[20s] job.done(): True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Done\n" + ] + } + ], + "source": [ + "# Running the command below immediately after the cell above returns `False`, indicating that the job is not done.\n", + "print(f\"[0s] job.done(): {job.done()}\")\n", + "\n", + "# Wait\n", + "time.sleep(20)\n", + "\n", + "# Check again\n", + "print(f\"[20s] job.done(): {job.done()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can access the fine-tuned model as before" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['384,400 kilometers']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm = job.result()\n", + "lm(\"How far is the Moon from Earth?\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## III. Showcasing `BootstrapFinetune`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### i. Task Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example setup using HotPotQA" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "import dspy\n", + "from dspy.datasets import HotPotQA\n", + "from dspy.evaluate import Evaluate\n", + "from dsp.utils.utils import deduplicate\n", + "\n", + "\n", + "# We are setting the experimental flag to True to make use of the fine-tuning\n", + "# features that are still in development.\n", + "dspy.settings.configure(experimental=True)\n", + "\n", + "# Define the program\n", + "class BasicMH(dspy.Module):\n", + " def __init__(self, passages_per_hop=3, num_hops=2):\n", + " super().__init__()\n", + " self.num_hops = 2\n", + " self.retrieve = dspy.Retrieve(k=passages_per_hop)\n", + " self.generate_query = [dspy.ChainOfThought(\"context, question -> search_query\") for _ in range(self.num_hops)]\n", + " self.generate_answer = dspy.ChainOfThought(\"context, question -> answer\")\n", + " \n", + " def forward(self, question):\n", + " context = []\n", + " \n", + " for hop in range(self.num_hops):\n", + " search_query = self.generate_query[hop](context=context, question=question).search_query\n", + " passages = self.retrieve(search_query).passages\n", + " context = deduplicate(context + passages)\n", + "\n", + " answer = self.generate_answer(context=context, question=question).copy(context=context)\n", + " return answer\n", + "\n", + "# Prepare the dataset\n", + "TRAIN_SIZE = 1000\n", + "DEV_SIZE = 500\n", + "dataset = HotPotQA(train_seed=1, eval_seed=2023, test_size=0, only_hard_examples=True)\n", + "trainset = [x.with_inputs('question') for x in dataset.train][:TRAIN_SIZE]\n", + "devset = [x.with_inputs('question') for x in dataset.dev][:DEV_SIZE]\n", + "\n", + "# Prepare the metric and evaluator\n", + "NUM_THREADS = 12\n", + "metric = dspy.evaluate.answer_exact_match\n", + "evaluate = Evaluate(devset=devset, metric=metric, num_threads=NUM_THREADS, display_progress=True)\n", + "\n", + "# Prepare the retriever model\n", + "COLBERT_V2_ENDPOINT = \"http://20.102.90.50:2017/wiki17_abstracts\"\n", + "retriever = dspy.ColBERTv2(url=COLBERT_V2_ENDPOINT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ii. Demo of `BootstrapFinetune`" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# Using LM.finetune(), BSFT, and BetterTogether requires this flag\n", + "dspy.settings.experimental = True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The cell below shows the different ways the `BootstrapFinetune` can be optimized." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# (1) BSFT can be initilized with no arguments!\n", + "weight_optimizer = dspy.BootstrapFinetune()\n", + "\n", + "# (2) Better to optimize with a metric to be used for filtering data before\n", + "# fine-tuning\n", + "weight_optimizer = dspy.BootstrapFinetune(\n", + " metric=metric\n", + ")\n", + "\n", + "# (3) Bootstrap fine-tune accepts other parameters as well, as shown below\n", + "train_kwargs = {\n", + " \"n_epochs\": 1,\n", + "}\n", + "adapter = dspy.ChatAdapter()\n", + "\n", + "weight_optimizer = dspy.BootstrapFinetune(\n", + " metric=metric, # Can be left empty, leads to no filtering\n", + " multitask=True, # We can also handle False!\n", + " train_kwargs=train_kwargs, # Can be left empty\n", + " adapter=adapter, # Can be left empty, leads to adapters inferred from the LM\n", + " exclude_demos=False, # Can be left empty\n", + " num_threads = 1, # Can be left empty\n", + ")\n", + "\n", + "\n", + "# (4) The adapter and train_kwargs arguments could be passed as dictionaries\n", + "# mapping LMs to their respective adapters/train_kwargs. This is useful when the\n", + "# predictors of the program point to different LMs.\n", + "lm = dspy.LM('gpt-4o-mini-2024-07-18')\n", + "adapter = dspy.ChatAdapter()\n", + "\n", + "train_kwargs = {\n", + " lm: {\n", + " \"n_epochs\": 1,\n", + " },\n", + " # lm2: train_kwargs2,\n", + "}\n", + "adapter = {\n", + " lm: adapter,\n", + " # lm2: adapter2,\n", + "}\n", + "\n", + "weight_optimizer = dspy.BootstrapFinetune(\n", + " metric=metric, # Can be left empty, leads to no filtering\n", + " multitask=True, # We can also handle False!\n", + " train_kwargs=train_kwargs, # Can be left empty\n", + " adapter=adapter, # Can be left empty, leads to adapters inferred from the LM\n", + " exclude_demos=False, # Can be left empty\n", + " num_threads = 1, # Can be left empty\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The cell below shows an example of running `BootstrapFinetune`" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preparing the student and teacher programs...\n", + "Ensuring that the student is not compiled.\n", + "No teacher provided. Using a copy of the student program as the teacher.\n", + "Bootstrapping data...\n", + "Average Metric: 5 / 10 (50.0): 100%|██████████| 10/10 [00:00<00:00, 416.94it/s]\n", + "Preparing the train data...\n", + "Collected data for 10 examples\n", + "After filtering for score, 5 examples remain\n", + "Using 15 data points for fine-tuning the model: gpt-4o-mini-2024-07-18\n", + "Starting LM fine-tuning...\n", + "1 fine-tuning job(s) to start.\n", + "Starting 1 fine-tuning jobs...\n", + "[OpenAI Provider] Validating the data format\n", + "[OpenAI Provider] Saving the data to a file\n", + "[OpenAI Provider] Data saved to /scr-ssd/dilara/.cache/dspy-new/finetune/4fd944783dbb3639.jsonl\n", + "[OpenAI Provider] Uploading the data to the provider\n", + "[OpenAI Provider] Start remote training\n", + "[OpenAI Provider] Job started with the OpenAI Job ID ftjob-gUXOgSEqEyV3v9Q6JgAqic9G\n", + "[OpenAI Provider] Wait for training to complete\n", + "[OpenAI Provider] Attempting to retrieve the trained model\n", + "[OpenAI Provider] Model retrieved: ft:gpt-4o-mini-2024-07-18:stanford::AOI4YHf2\n", + "Job 1/1 completed.\n", + "Updating the student program with the fine-tuned LMs...\n", + "BootstrapFinetune has finished compiling the student program.\n" + ] + } + ], + "source": [ + "# Using method (3) from above to create a weight-optimized program\n", + "train_kwargs = {\n", + " \"n_epochs\": 1,\n", + "}\n", + "adapter = dspy.ChatAdapter()\n", + "\n", + "weight_optimizer = dspy.BootstrapFinetune(\n", + " metric=metric, # Can be left empty, leads to no filtering\n", + " multitask=True, # We can also handle False!\n", + " train_kwargs=train_kwargs, # Can be left empty\n", + " adapter=adapter, # Can be left empty, leads to adapters inferred from the LM\n", + " exclude_demos=False, # Can be left empty\n", + " num_threads = 1, # Can be left empty\n", + ")\n", + "\n", + "lm = dspy.LM('gpt-4o-mini-2024-07-18')\n", + "small_trainset = trainset[:10] # Use a small subset of the training data\n", + "\n", + "with dspy.context(lm=lm, rm=retriever):\n", + " weight_optimized_program = weight_optimizer.compile(\n", + " student=BasicMH(),\n", + " trainset=small_trainset,\n", + " teacher=None, # Doesn't need to be set, student is used as the teacher by default\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ft:gpt-4o-mini-2024-07-18:stanford::AOI4YHf2\n", + "ft:gpt-4o-mini-2024-07-18:stanford::AOI4YHf2\n", + "ft:gpt-4o-mini-2024-07-18:stanford::AOI4YHf2\n" + ] + } + ], + "source": [ + "for p in weight_optimized_program.predictors():\n", + " print(p.lm.model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## IV. Demo of `BetterTogether`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### i. Task Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Example setup using HotPotQA" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import dspy\n", + "from dspy.datasets import HotPotQA\n", + "from dspy.evaluate import Evaluate\n", + "from dsp.utils.utils import deduplicate\n", + "\n", + "\n", + "# We are setting the experimental flag to True to make use of the fine-tuning\n", + "# features that are still in development.\n", + "dspy.settings.configure(experimental=True)\n", + "\n", + "# Define the program\n", + "class BasicMH(dspy.Module):\n", + " def __init__(self, passages_per_hop=3, num_hops=2):\n", + " super().__init__()\n", + " self.num_hops = 2\n", + " self.retrieve = dspy.Retrieve(k=passages_per_hop)\n", + " self.generate_query = [dspy.ChainOfThought(\"context, question -> search_query\") for _ in range(self.num_hops)]\n", + " self.generate_answer = dspy.ChainOfThought(\"context, question -> answer\")\n", + " \n", + " def forward(self, question):\n", + " context = []\n", + " \n", + " for hop in range(self.num_hops):\n", + " search_query = self.generate_query[hop](context=context, question=question).search_query\n", + " passages = self.retrieve(search_query).passages\n", + " context = deduplicate(context + passages)\n", + "\n", + " answer = self.generate_answer(context=context, question=question).copy(context=context)\n", + " return answer\n", + "\n", + "# Prepare the dataset\n", + "TRAIN_SIZE = 1000\n", + "DEV_SIZE = 500\n", + "dataset = HotPotQA(train_seed=1, eval_seed=2023, test_size=0, only_hard_examples=True)\n", + "trainset = [x.with_inputs('question') for x in dataset.train][:TRAIN_SIZE]\n", + "devset = [x.with_inputs('question') for x in dataset.dev][:DEV_SIZE]\n", + "\n", + "# Prepare the metric and evaluator\n", + "NUM_THREADS = 12\n", + "metric = dspy.evaluate.answer_exact_match\n", + "evaluate = Evaluate(devset=devset, metric=metric, num_threads=NUM_THREADS, display_progress=True)\n", + "\n", + "# Prepare the retriever model\n", + "COLBERT_V2_ENDPOINT = \"http://20.102.90.50:2017/wiki17_abstracts\"\n", + "retriever = dspy.ColBERTv2(url=COLBERT_V2_ENDPOINT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ii. Demo" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Using LM.finetune(), BSFT, and BetterTogether requires this flag\n", + "dspy.settings.experimental = True\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Going to sample between 1 and 4 traces per predictor.\n", + "Will attempt to bootstrap 16 candidate sets.\n" + ] + } + ], + "source": [ + "# (1) The only required argument we require for BetterTogether is the metric\n", + "better_together = dspy.BetterTogether(\n", + " metric=metric # This is the only metric we require!\n", + " # We could also consider not requiring it if BootstrapFewShotWithRandomSearch is modified.\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Going to sample between 1 and 3 traces per predictor.\n", + "Will attempt to bootstrap 6 candidate sets.\n" + ] + } + ], + "source": [ + "# (2) We can also pass the weight and prompt optimizers we initialized\n", + "train_kwargs = {\n", + " \"n_epochs\": 1,\n", + "}\n", + "adapter = dspy.ChatAdapter()\n", + "\n", + "weight_optimizer = dspy.BootstrapFinetune(\n", + " metric=metric, # Can be left empty, leads to no filtering\n", + " multitask=True, # We can also handle False!\n", + " train_kwargs=train_kwargs, # Can be left empty\n", + " adapter=adapter, # Can be left empty, leads to adapters inferred from the LM\n", + " exclude_demos=True, # We are dropping the demos for fine-tuning \n", + " num_threads = 1, # Can be left empty\n", + ")\n", + "\n", + "prompt_optimizer = dspy.BootstrapFewShotWithRandomSearch(\n", + " metric=metric,\n", + " max_bootstrapped_demos=3,\n", + " max_labeled_demos=3,\n", + " num_candidate_programs=6,\n", + " num_threads=6\n", + ")\n", + "\n", + "# Initialize BetterTogether\n", + "better_together = dspy.BetterTogether(\n", + " metric=metric,\n", + " weight_optimizer=weight_optimizer, # Can be left empty\n", + " prompt_optimizer=prompt_optimizer, # Can be left empty\n", + " seed=2023, # Can be left empty\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[BetterTogether] Validating the strategy\n", + "[BetterTogether] Preparing the student program...\n", + "Ensuring that the student is not compiled\n", + "[BetterTogether] Compiling the student program...\n", + "[BetterTogether] Step 1 of 3 - Strategy `p`\n", + "[BetterTogether] Shuffling the trainset...\n", + "[BetterTogether] Preparing for prompt optimization...\n", + "[BetterTogether] Launching the program LMs for sampling...\n", + "`launch()` is called for the auto-launched model `gpt-4o-mini-2024-07-18` -- no action is taken!\n", + "[BetterTogether] Compiling the prompt optimizer...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 2 / 5 (40.0): 100%|██████████| 5/5 [00:02<00:00, 2.05it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best score: 40.0 for seed -3\n", + "Scores so far: [40.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 2 / 5 (40.0): 100%|██████████| 5/5 [00:02<00:00, 2.30it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 11%|█ | 5/45 [00:11<01:28, 2.21s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 3 full traces after 6 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 1 / 5 (20.0): 100%|██████████| 5/5 [00:06<00:00, 1.25s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0, 20.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 11%|█ | 5/45 [00:18<02:25, 3.63s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 2 full traces after 6 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 1 / 5 (20.0): 100%|██████████| 5/5 [00:07<00:00, 1.43s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0, 20.0, 20.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 7%|▋ | 3/45 [00:05<01:10, 1.67s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 4 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 2 / 5 (40.0): 100%|██████████| 5/5 [00:05<00:00, 1.11s/it] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0, 20.0, 20.0, 40.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 1/45 [00:02<01:32, 2.10s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 2 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 1 / 5 (20.0): 100%|██████████| 5/5 [00:09<00:00, 1.84s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0, 20.0, 20.0, 40.0, 20.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 1/45 [00:01<01:04, 1.46s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 2 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 1 / 5 (20.0): 100%|██████████| 5/5 [00:06<00:00, 1.28s/it] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0, 20.0, 20.0, 40.0, 20.0, 20.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 1/45 [00:04<03:36, 4.91s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 2 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 2 / 5 (40.0): 100%|██████████| 5/5 [00:08<00:00, 1.60s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [40.0, 40.0, 20.0, 20.0, 40.0, 20.0, 20.0, 40.0]\n", + "Best score so far: 40.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 16%|█▌ | 7/45 [00:11<01:02, 1.65s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 3 full traces after 8 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:06<00:00, 1.37s/it] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best score: 60.0 for seed 5\n", + "Scores so far: [40.0, 40.0, 20.0, 20.0, 40.0, 20.0, 20.0, 40.0, 60.0]\n", + "Best score so far: 60.0\n", + "9 candidate programs found.\n", + "[BetterTogether] Killing the LMs used for sampling...\n", + "`kill()` is called for the auto-launched model `gpt-4o-mini-2024-07-18` -- no action is taken!\n", + "[BetterTogether] Step 2 of 3 - Strategy `p -> w`\n", + "[BetterTogether] Shuffling the trainset...\n", + "[BetterTogether] Preparing for weight optimization...\n", + "[BetterTogether] Compiling the weight optimizer...\n", + "[BootstrapFinetune] Preparing the student and teacher programs...\n", + "Ensuring that the student is not compiled\n", + "No teacher provided. Using a copy of the student program as the teacher.\n", + "[BootstrapFinetune] Bootstrapping data...\n", + "Average Metric: 28 / 50 (56.0): 100%|██████████| 50/50 [03:55<00:00, 4.72s/it]\n", + "[BootstrapFinetune] Preparing the train data...\n", + "[BootstrapFinetune] Collected data for 50 examples\n", + "[BootstrapFinetune] After filtering for score, 28 examples remain\n", + "Using 84 data points for fine-tuning the model: gpt-4o-mini-2024-07-18\n", + "[BootstrapFinetune] Starting LM fine-tuning...\n", + "[BootstrapFinetune] 1 fine-tuning job(s) to start\n", + "[BootstrapFinetune] Starting 1 fine-tuning jobs...\n", + "[OpenAI Provider] Validating the data format\n", + "[OpenAI Provider] Saving the data to a file\n", + "[OpenAI Provider] Data saved to /scr-ssd/dilara/.cache/dspy-new/finetune/c4bb7c084d3a7ad3.jsonl\n", + "[OpenAI Provider] Uploading the data to the provider\n", + "[OpenAI Provider] Starting remote training\n", + "[OpenAI Provider] Job started with the OpenAI Job ID ftjob-75MiMHGY9xW1gNMGZSs5ht8v\n", + "[OpenAI Provider] Waiting for training to complete\n", + "[OpenAI Provider] Attempting to retrieve the trained model\n", + "[OpenAI Provider] Model retrieved: ft:gpt-4o-mini-2024-07-18:stanford::AOIn7Dm8\n", + "Job 1/1 completed.\n", + "[BootstrapFinetune] Updating the student program with the fine-tuned LMs...\n", + "[BootstrapFinetune] BootstrapFinetune has finished compiling the student program\n", + "[BetterTogether] Step 3 of 3 - Strategy `p -> w -> p`\n", + "[BetterTogether] Shuffling the trainset...\n", + "[BetterTogether] Preparing for prompt optimization...\n", + "[BetterTogether] Launching the program LMs for sampling...\n", + "`launch()` is called for the auto-launched model `ft:gpt-4o-mini-2024-07-18:stanford::AOIn7Dm8` -- no action is taken!\n", + "[BetterTogether] Compiling the prompt optimizer...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 1 / 5 (20.0): 100%|██████████| 5/5 [00:02<00:00, 2.42it/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best score: 20.0 for seed -3\n", + "Scores so far: [20.0]\n", + "Best score so far: 20.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:01<00:00, 3.57it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best score: 60.0 for seed -2\n", + "Scores so far: [20.0, 60.0]\n", + "Best score so far: 60.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 9%|▉ | 4/45 [00:08<01:22, 2.02s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 3 full traces after 5 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:06<00:00, 1.21s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [20.0, 60.0, 60.0]\n", + "Best score so far: 60.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|▍ | 2/45 [00:02<01:03, 1.48s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 2 full traces after 3 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 4 / 5 (80.0): 100%|██████████| 5/5 [00:06<00:00, 1.22s/it] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best score: 80.0 for seed 0\n", + "Scores so far: [20.0, 60.0, 60.0, 80.0]\n", + "Best score so far: 80.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 1/45 [00:01<00:58, 1.32s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 2 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 1 / 5 (20.0): 100%|██████████| 5/5 [00:06<00:00, 1.39s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [20.0, 60.0, 60.0, 80.0, 20.0]\n", + "Best score so far: 80.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 7%|▋ | 3/45 [00:04<01:00, 1.43s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 4 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:06<00:00, 1.27s/it] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [20.0, 60.0, 60.0, 80.0, 20.0, 60.0]\n", + "Best score so far: 80.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 1/45 [00:01<00:47, 1.07s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 2 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:07<00:00, 1.43s/it] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [20.0, 60.0, 60.0, 80.0, 20.0, 60.0, 60.0]\n", + "Best score so far: 80.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 2%|▏ | 1/45 [00:03<02:19, 3.17s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 1 full traces after 2 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:05<00:00, 1.18s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [20.0, 60.0, 60.0, 80.0, 20.0, 60.0, 60.0, 60.0]\n", + "Best score so far: 80.0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 11%|█ | 5/45 [00:14<01:58, 2.96s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bootstrapped 3 full traces after 6 examples in round 0.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Average Metric: 3 / 5 (60.0): 100%|██████████| 5/5 [00:06<00:00, 1.31s/it] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scores so far: [20.0, 60.0, 60.0, 80.0, 20.0, 60.0, 60.0, 60.0, 60.0]\n", + "Best score so far: 80.0\n", + "9 candidate programs found.\n", + "[BetterTogether] Killing the LMs used for sampling...\n", + "`kill()` is called for the auto-launched model `ft:gpt-4o-mini-2024-07-18:stanford::AOIn7Dm8` -- no action is taken!\n", + "[BetterTogether] BetterTogether has finished compiling the student program.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Running BetterTogether on a small dataset\n", + "\n", + "lm = dspy.LM('gpt-4o-mini-2024-07-18')\n", + "small_trainset = trainset[:50] # Use a small subset of the training data\n", + "\n", + "with dspy.context(lm=lm, rm=retriever):\n", + " optimized_program = better_together.compile(\n", + " student=BasicMH(),\n", + " trainset=small_trainset,\n", + " strategy=\"p -> w -> p\",\n", + " valset_ratio=0.1\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ft:gpt-4o-mini-2024-07-18:stanford::AOIn7Dm8\n", + "ft:gpt-4o-mini-2024-07-18:stanford::AOIn7Dm8\n", + "ft:gpt-4o-mini-2024-07-18:stanford::AOIn7Dm8\n" + ] + } + ], + "source": [ + "for p in optimized_program.predictors():\n", + " print(p.lm.model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dspy25", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/dsp_LM/predict/test_chain_of_thought_with_hint.py b/tests/dsp_LM/predict/test_chain_of_thought_with_hint.py index 0afa06b1c..d06e72a36 100644 --- a/tests/dsp_LM/predict/test_chain_of_thought_with_hint.py +++ b/tests/dsp_LM/predict/test_chain_of_thought_with_hint.py @@ -1,43 +1,43 @@ -import dspy -from dspy import ChainOfThoughtWithHint -from dspy.utils import DSPDummyLM +# import dspy +# from dspy import ChainOfThoughtWithHint +# from dspy.utils import DSPDummyLM -def test_cot_with_no_hint(): - lm = DSPDummyLM(["find the number after 1", "2"]) - dspy.settings.configure(lm=lm) - predict = ChainOfThoughtWithHint("question -> answer") - # Check output fields have the right order - assert list(predict.extended_signature2.output_fields.keys()) == [ - "rationale", - "hint", - "answer", - ] - assert predict(question="What is 1+1?").answer == "2" +# # def test_cot_with_no_hint(): +# # lm = DSPDummyLM(["find the number after 1", "2"]) +# # dspy.settings.configure(lm=lm) +# # predict = ChainOfThoughtWithHint("question -> answer") +# # # Check output fields have the right order +# # assert list(predict.extended_signature2.output_fields.keys()) == [ +# # "rationale", +# # "hint", +# # "answer", +# # ] +# # assert predict(question="What is 1+1?").answer == "2" - final_convo = lm.get_convo(-1) - assert final_convo.endswith( - "Question: What is 1+1?\n" - "Reasoning: Let's think step by step in order to find the number after 1\n" - "Answer: 2" - ) +# # final_convo = lm.get_convo(-1) +# # assert final_convo.endswith( +# # "Question: What is 1+1?\n" +# # "Reasoning: Let's think step by step in order to find the number after 1\n" +# # "Answer: 2" +# # ) -def test_cot_with_hint(): - lm = DSPDummyLM(["find the number after 1", "2"]) - dspy.settings.configure(lm=lm) - predict = ChainOfThoughtWithHint("question -> answer") - assert list(predict.extended_signature2.output_fields.keys()) == [ - "rationale", - "hint", - "answer", - ] - assert predict(question="What is 1+1?", hint="think small").answer == "2" +# # def test_cot_with_hint(): +# # lm = DSPDummyLM(["find the number after 1", "2"]) +# # dspy.settings.configure(lm=lm) +# # predict = ChainOfThoughtWithHint("question -> answer") +# # assert list(predict.extended_signature2.output_fields.keys()) == [ +# # "rationale", +# # "hint", +# # "answer", +# # ] +# # assert predict(question="What is 1+1?", hint="think small").answer == "2" - final_convo = lm.get_convo(-1) - assert final_convo.endswith( - "Question: What is 1+1?\n\n" - "Reasoning: Let's think step by step in order to find the number after 1\n\n" - "Hint: think small\n\n" - "Answer: 2" - ) +# # final_convo = lm.get_convo(-1) +# # assert final_convo.endswith( +# # "Question: What is 1+1?\n\n" +# # "Reasoning: Let's think step by step in order to find the number after 1\n\n" +# # "Hint: think small\n\n" +# # "Answer: 2" +# # ) diff --git a/tests/predict/test_chain_of_thought_with_hint.py b/tests/predict/test_chain_of_thought_with_hint.py index 77c4f9ac2..68ede17dc 100644 --- a/tests/predict/test_chain_of_thought_with_hint.py +++ b/tests/predict/test_chain_of_thought_with_hint.py @@ -1,30 +1,30 @@ -import textwrap +# import textwrap -import dspy -from dspy import ChainOfThoughtWithHint -from dspy.utils import DummyLM +# import dspy +# from dspy import ChainOfThoughtWithHint +# from dspy.utils import DummyLM -def test_cot_with_no_hint(): - lm = DummyLM([{"rationale": "find the number after 1", "answer": "2"}]) - dspy.settings.configure(lm=lm) - predict = ChainOfThoughtWithHint("question -> answer") - # Check output fields have the right order - assert list(predict.extended_signature2.output_fields.keys()) == [ - "rationale", - "hint", - "answer", - ] - assert predict(question="What is 1+1?").answer == "2" +# def test_cot_with_no_hint(): +# lm = DummyLM([{"rationale": "find the number after 1", "answer": "2"}]) +# dspy.settings.configure(lm=lm) +# predict = ChainOfThoughtWithHint("question -> answer") +# # Check output fields have the right order +# assert list(predict.extended_signature2.output_fields.keys()) == [ +# "rationale", +# "hint", +# "answer", +# ] +# assert predict(question="What is 1+1?").answer == "2" -def test_cot_with_hint(): - lm = DummyLM([{"rationale": "find the number after 1", "hint": "Is it helicopter?", "answer": "2"}]) - dspy.settings.configure(lm=lm) - predict = ChainOfThoughtWithHint("question -> answer") - assert list(predict.extended_signature2.output_fields.keys()) == [ - "rationale", - "hint", - "answer", - ] - assert predict(question="What is 1+1?", hint="think small").answer == "2" +# def test_cot_with_hint(): +# lm = DummyLM([{"rationale": "find the number after 1", "hint": "Is it helicopter?", "answer": "2"}]) +# dspy.settings.configure(lm=lm) +# predict = ChainOfThoughtWithHint("question -> answer") +# assert list(predict.extended_signature2.output_fields.keys()) == [ +# "rationale", +# "hint", +# "answer", +# ] +# assert predict(question="What is 1+1?", hint="think small").answer == "2" From c64b9756dde794e03cdf72033283a0be8da39c15 Mon Sep 17 00:00:00 2001 From: Isaac Miller <17116851+isaacbmiller@users.noreply.github.com> Date: Thu, 7 Nov 2024 12:11:28 -0800 Subject: [PATCH 29/31] Fix ruff error (#1772) * Trigger Build * Update ruff version * Fix ruff errors and give more verbose debug messages --- .github/workflows/run_tests.yml | 2 +- examples/finetune/_internal_finetune_demo.ipynb | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 0e5583d23..d5eacbae8 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -24,7 +24,7 @@ jobs: id: ruff_fix uses: chartboost/ruff-action@v1 with: - args: check --fix-only --exit-non-zero-on-fix + args: check --fix-only --diff --exit-non-zero-on-fix continue-on-error: true - name: Fail Workflow if Ruff Fix Failed if: steps.ruff_fix.outcome == 'failure' diff --git a/examples/finetune/_internal_finetune_demo.ipynb b/examples/finetune/_internal_finetune_demo.ipynb index 9b95cf301..bde92e4e4 100644 --- a/examples/finetune/_internal_finetune_demo.ipynb +++ b/examples/finetune/_internal_finetune_demo.ipynb @@ -20,7 +20,7 @@ "%autoreload 2\n", "\n", "# Setting the environment variables\n", - "import os\n", + "import os # noqa\n", "\n", "# os.environ[\"DSPY_CACHEDIR\"] =\n", "# os.environ[\"DSP_CACHEDIR\"] =\n", @@ -527,8 +527,8 @@ "source": [ "import dspy\n", "from dspy.datasets import HotPotQA\n", - "from dspy.evaluate import Evaluate\n", - "from dsp.utils.utils import deduplicate\n", + "from dspy.evaluate import Evaluate # noqa\n", + "from dsp.utils.utils import deduplicate # noqa\n", "\n", "\n", "# We are setting the experimental flag to True to make use of the fine-tuning\n", @@ -773,8 +773,8 @@ "source": [ "import dspy\n", "from dspy.datasets import HotPotQA\n", - "from dspy.evaluate import Evaluate\n", - "from dsp.utils.utils import deduplicate\n", + "from dspy.evaluate import Evaluate # noqa\n", + "from dsp.utils.utils import deduplicate # noqa\n", "\n", "\n", "# We are setting the experimental flag to True to make use of the fine-tuning\n", From d46feed43e4a46d489eb8cf41d5d04256a4b5ce4 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Thu, 7 Nov 2024 14:57:07 -0800 Subject: [PATCH 30/31] Fixes for dspy.majority --- dspy/predict/aggregation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dspy/predict/aggregation.py b/dspy/predict/aggregation.py index bc37e203f..852e339c9 100644 --- a/dspy/predict/aggregation.py +++ b/dspy/predict/aggregation.py @@ -23,12 +23,12 @@ def majority(prediction_or_completions, normalize=default_normalize, field=None) try: signature = completions.signature - except: + except Exception: signature = None if not field: if signature: - field = signature.output_fields[-1] + field = list(signature.output_fields.keys())[-1] else: field = list(completions[0].keys())[-1] From cdee3b4888abfc446c33aaa1d86bd07d768d1933 Mon Sep 17 00:00:00 2001 From: Chen Qian Date: Thu, 7 Nov 2024 17:19:25 -0800 Subject: [PATCH 31/31] Databricks finetuning integration (#1770) * Databricks finetuning small fix some fixes * add examples --- dspy/clients/databricks.py | 361 +++++++++++++++++++++ dspy/clients/lm.py | 57 ++-- dspy/clients/provider.py | 21 +- dspy/teleprompt/bootstrap_finetune.py | 5 +- examples/finetune/databricks_finetuning.py | 70 ++++ tests/clients/test_databricks.py | 94 ++++++ 6 files changed, 565 insertions(+), 43 deletions(-) create mode 100644 dspy/clients/databricks.py create mode 100644 examples/finetune/databricks_finetuning.py create mode 100644 tests/clients/test_databricks.py diff --git a/dspy/clients/databricks.py b/dspy/clients/databricks.py new file mode 100644 index 000000000..a56e91ac1 --- /dev/null +++ b/dspy/clients/databricks.py @@ -0,0 +1,361 @@ +import logging +import os +import re +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import requests +import ujson + +from dspy.clients.provider import Provider, TrainingJob +from dspy.clients.utils_finetune import DataFormat, get_finetune_directory + +if TYPE_CHECKING: + from databricks.sdk import WorkspaceClient + +logger = logging.getLogger(__name__) + + +class TrainingJobDatabricks(TrainingJob): + def __init__(self, finetuning_run=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.finetuning_run = finetuning_run + self.launch_started = False + self.launch_completed = False + self.endpoint_name = None + + def status(self): + if not self.finetuning_run: + return None + try: + from databricks.model_training import foundation_model as fm + except ImportError: + raise ImportError( + "To use Databricks finetuning, please install the databricks_genai package via " + "`pip install databricks_genai`." + ) + run = fm.get(self.finetuning_run) + return run.status + + +class DatabricksProvider(Provider): + finetunable = True + TrainingJob = TrainingJobDatabricks + + @staticmethod + def is_provider_model(model: str) -> bool: + # We don't automatically infer Databricks models because Databricks is not a proprietary model provider. + return False + + @staticmethod + def deploy_finetuned_model( + model: str, + data_format: Optional[DataFormat] = None, + databricks_host: Optional[str] = None, + databricks_token: Optional[str] = None, + deploy_timeout: int = 900, + ): + workspace_client = _get_workspace_client() + model_version = next(workspace_client.model_versions.list(model)).version + + # Allow users to override the host and token. This is useful on Databricks hosted runtime. + databricks_host = databricks_host or workspace_client.config.host + databricks_token = databricks_token or workspace_client.config.token + + headers = {"Context-Type": "text/json", "Authorization": f"Bearer {databricks_token}"} + + optimizable_info = requests.get( + url=f"{databricks_host}/api/2.0/serving-endpoints/get-model-optimization-info/{model}/{model_version}", + headers=headers, + ).json() + + if "optimizable" not in optimizable_info or not optimizable_info["optimizable"]: + raise ValueError(f"Model is not eligible for provisioned throughput: {optimizable_info}") + + chunk_size = optimizable_info["throughput_chunk_size"] + + # Minimum desired provisioned throughput + min_provisioned_throughput = 0 + + # Maximum desired provisioned throughput + max_provisioned_throughput = chunk_size + + # Databricks serving endpoint names cannot contain ".". + model_name = model.replace(".", "_") + + get_endpoint_response = requests.get( + url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers + ) + + if get_endpoint_response.status_code == 200: + logger.info( + f"Serving endpoint {model_name} already exists, updating it instead of creating a new one." + ) + # The serving endpoint already exists, we will update it instead of creating a new one. + data = { + "served_entities": [ + { + "name": model_name, + "entity_name": model, + "entity_version": model_version, + "min_provisioned_throughput": min_provisioned_throughput, + "max_provisioned_throughput": max_provisioned_throughput, + } + ] + } + + response = requests.put( + url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}/config", + json=data, + headers=headers, + ) + else: + logger.info( + f"Creating serving endpoint {model_name} on Databricks model serving!" + ) + # Send the POST request to create the serving endpoint + data = { + "name": model_name, + "config": { + "served_entities": [ + { + "name": model_name, + "entity_name": model, + "entity_version": model_version, + "min_provisioned_throughput": min_provisioned_throughput, + "max_provisioned_throughput": max_provisioned_throughput, + } + ] + }, + } + + response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers) + + if response.status_code == 200: + logger.info( + f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!" + ) + else: + raise ValueError(f"Failed to create serving endpoint: {response.json()}.") + + logger.info( + f"Waiting for serving endpoint {model_name} to be ready, this might take a few minutes... You can check " + f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}" + ) + from openai import OpenAI + + client = OpenAI( + api_key=databricks_token, + base_url=f"{databricks_host}/serving-endpoints", + ) + # Wait for the deployment to be ready. + num_retries = deploy_timeout // 60 + for _ in range(num_retries): + try: + if data_format == DataFormat.chat: + client.chat.completions.create( + messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1 + ) + elif data_format == DataFormat.completion: + client.completions.create(prompt="hi", model=model_name, max_tokens=1) + logger.info(f"Databricks model serving endpoint {model_name} is ready!") + return + except Exception: + time.sleep(60) + + raise ValueError( + f"Failed to create serving endpoint {model_name} on Databricks model serving platform within " + f"{deploy_timeout} seconds." + ) + + @staticmethod + def finetune( + job: TrainingJobDatabricks, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[Union[DataFormat, str]] = None, + ) -> str: + if isinstance(data_format, str): + if data_format == "chat": + data_format = DataFormat.chat + elif data_format == "completion": + data_format = DataFormat.completion + else: + raise ValueError( + f"String `data_format` must be one of 'chat' or 'completion', but received: {data_format}." + ) + + if "train_data_path" not in train_kwargs: + raise ValueError("The `train_data_path` must be provided to finetune on Databricks.") + # Add the file name to the directory path. + train_kwargs["train_data_path"] = DatabricksProvider.upload_data( + train_data, train_kwargs["train_data_path"], data_format + ) + + try: + from databricks.model_training import foundation_model as fm + except ImportError: + raise ImportError( + "To use Databricks finetuning, please install the databricks_genai package via " + "`pip install databricks_genai`." + ) + + if "register_to" not in train_kwargs: + raise ValueError("The `register_to` must be provided to finetune on Databricks.") + + # Allow users to override the host and token. This is useful on Databricks hosted runtime. + databricks_host = train_kwargs.pop("databricks_host", None) + databricks_token = train_kwargs.pop("databricks_token", None) + + skip_deploy = train_kwargs.pop("skip_deploy", False) + deploy_timeout = train_kwargs.pop("deploy_timeout", 900) + + logger.info("Starting finetuning on Databricks... this might take a few minutes to finish.") + finetuning_run = fm.create( + model=model, + **train_kwargs, + ) + + job.run = finetuning_run + + # Wait for the finetuning run to be ready. + while True: + job.run = fm.get(job.run) + if job.run.status.display_name == "Completed": + logger.info("Finetuning run completed successfully!") + break + elif job.run.status.display_name == "Failed": + raise ValueError( + f"Finetuning run failed with status: {job.run.status.display_name}. Please check the Databricks " + f"workspace for more details. Finetuning job's metadata: {job.run}." + ) + else: + time.sleep(60) + + if skip_deploy: + return None + + job.launch_started = True + model_to_deploy = train_kwargs.get("register_to") + job.endpoint_name = model_to_deploy.replace(".", "_") + DatabricksProvider.deploy_finetuned_model( + model_to_deploy, data_format, databricks_host, databricks_token, deploy_timeout + ) + job.launch_completed = True + # The finetuned model name should be in the format: "databricks/". + return f"databricks/{job.endpoint_name}" + + @staticmethod + def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: DataFormat): + logger.info("Uploading finetuning data to Databricks Unity Catalog...") + file_path = _save_data_to_local_file(train_data, data_format) + + w = _get_workspace_client() + _create_directory_in_databricks_unity_catalog(w, databricks_unity_catalog_path) + + try: + with open(file_path, "rb") as f: + target_path = os.path.join(databricks_unity_catalog_path, os.path.basename(file_path)) + w.files.upload(target_path, f, overwrite=True) + logger.info("Successfully uploaded finetuning data to Databricks Unity Catalog!") + return target_path + except Exception as e: + raise ValueError(f"Failed to upload finetuning data to Databricks Unity Catalog: {e}") + + +def _get_workspace_client() -> "WorkspaceClient": + try: + from databricks.sdk import WorkspaceClient + except ImportError: + raise ImportError( + "To use Databricks finetuning, please install the databricks-sdk package via " + "`pip install databricks-sdk`." + ) + return WorkspaceClient() + + +def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databricks_unity_catalog_path: str): + pattern = r"^/Volumes/(?P[^/]+)/(?P[^/]+)/(?P[^/]+)(/[^/]+)+$" + match = re.match(pattern, databricks_unity_catalog_path) + if not match: + raise ValueError( + f"Databricks Unity Catalog path must be in the format '/Volumes////...', but " + f"received: {databricks_unity_catalog_path}." + ) + + catalog = match.group("catalog") + schema = match.group("schema") + volume = match.group("volume") + + try: + volume_path = f"{catalog}.{schema}.{volume}" + w.volumes.read(volume_path) + except Exception: + raise ValueError( + f"Databricks Unity Catalog volume does not exist: {volume_path}, please create it on the Databricks " + "workspace." + ) + + try: + w.files.get_directory_metadata(databricks_unity_catalog_path) + logger.info(f"Directory {databricks_unity_catalog_path} already exists, skip creating it.") + except Exception: + # Create the directory if it doesn't exist, we don't raise an error because this is a common case. + logger.info(f"Creating directory {databricks_unity_catalog_path} in Databricks Unity Catalog...") + w.files.create_directory(databricks_unity_catalog_path) + logger.info(f"Successfully created directory {databricks_unity_catalog_path} in Databricks Unity Catalog!") + + +def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: DataFormat): + import uuid + + file_name = f"finetuning_{uuid.uuid4()}.jsonl" + + finetune_dir = get_finetune_directory() + file_path = os.path.join(finetune_dir, file_name) + file_path = os.path.abspath(file_path) + with open(file_path, "w") as f: + for item in train_data: + if data_format == DataFormat.chat: + _validate_chat_data(item) + elif data_format == DataFormat.completion: + _validate_completion_data(item) + + f.write(ujson.dumps(item) + "\n") + return file_path + + +def _validate_chat_data(data: Dict[str, Any]): + if "messages" not in data: + raise ValueError( + "Each finetuning data must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but " + f"received: {data}" + ) + + if not isinstance(data["messages"], list): + raise ValueError( + "The value of the 'messages' key in each finetuning data must be a list of dicts with keys 'role' and " + f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}" + ) + + for message in data["messages"]: + if "role" not in message: + raise ValueError(f"Each message in the 'messages' list must contain a 'role' key, but received: {message}.") + if "content" not in message: + raise ValueError( + f"Each message in the 'messages' list must contain a 'content' key, but received: {message}." + ) + + +def _validate_completion_data(data: Dict[str, Any]): + if "prompt" not in data: + raise ValueError( + "Each finetuning data must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but " + f"received: {data}" + ) + if "response" not in data and "completion" not in data: + raise ValueError( + "Each finetuning data must be a dict with a 'response' or 'completion' key when " + f"`task=INSTRUCTION_FINETUNE`, but received: {data}" + ) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index b73d272ff..ac891841c 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,26 +1,22 @@ import functools -from .base_lm import BaseLM import logging import os +import threading import uuid from datetime import datetime -import threading from typing import Any, Dict, List, Literal, Optional -import dspy + import litellm import ujson +import dspy from dspy.adapters.base import Adapter -from dspy.clients.provider import Provider, TrainingJob from dspy.clients.openai import OpenAIProvider -from dspy.clients.utils_finetune import ( - DataFormat, - validate_data_format, - infer_data_format -) - +from dspy.clients.provider import Provider, TrainingJob +from dspy.clients.utils_finetune import DataFormat, infer_data_format, validate_data_format from dspy.utils.callback import BaseCallback, with_callbacks +from .base_lm import BaseLM logger = logging.getLogger(__name__) @@ -37,10 +33,10 @@ def __init__( temperature: float = 0.0, max_tokens: int = 1000, cache: bool = True, - launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, num_retries: int = 3, provider=None, + finetuning_model: Optional[str] = None, **kwargs, ): """ @@ -58,18 +54,21 @@ def __init__( num_retries: The number of times to retry a request if it fails transiently due to network error, rate limiting, etc. Requests are retried with exponential backoff. + provider: The provider to use. If not specified, the provider will be inferred from the model. + finetuning_model: The model to finetune. In some providers, the models available for finetuning is different + from the models available for inference. """ # Remember to update LM.copy() if you modify the constructor! self.model = model self.model_type = model_type self.cache = cache - self.launch_kwargs = launch_kwargs or {} self.provider = provider or self.infer_provider() self.callbacks = callbacks or [] self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] self.callbacks = callbacks or [] self.num_retries = num_retries + self.finetuning_model = finetuning_model #turned off by default to avoid LiteLLM logging during every LM call litellm.suppress_debug_info = dspy.settings.suppress_debug_info @@ -117,18 +116,18 @@ def __call__(self, prompt=None, messages=None, **kwargs): return outputs - def launch(self): - self.provider.launch(self.model, self.launch_kwargs) + def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None): + self.provider.launch(self.model, **launch_kwargs) - def kill(self): - self.provider.kill(self.model, self.launch_kwargs) + def kill(self, kill_kwargs: Optional[Dict[str, Any]] = None): + self.provider.kill(self.model, **kill_kwargs) def finetune( - self, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]]=None, - data_format: Optional[DataFormat] = None, - ) -> TrainingJob: + self, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[DataFormat] = None, + ) -> TrainingJob: from dspy import settings as settings err = "Fine-tuning is an experimental feature." @@ -152,12 +151,9 @@ def thread_function_wrapper(): return self._run_finetune_job(job) thread = threading.Thread(target=thread_function_wrapper) + model_to_finetune = self.finetuning_model or self.model job = self.provider.TrainingJob( - thread=thread, - model=self.model, - train_data=train_data, - train_kwargs=train_kwargs, - data_format=data_format + thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format ) thread.start() @@ -172,23 +168,24 @@ def _run_finetune_job(self, job: TrainingJob): model=job.model, train_data=job.train_data, train_kwargs=job.train_kwargs, - data_format=job.data_format + data_format=job.data_format, ) lm = self.copy(model=model) job.set_result(lm) except Exception as err: logger.error(err) job.set_result(err) - + def infer_provider(self) -> Provider: if OpenAIProvider.is_provider_model(self.model): return OpenAIProvider() # TODO(PR): Keeping this function here will require us to import all # providers in this file. Is this okay? return Provider() - + def infer_adapter(self) -> Adapter: import dspy + if dspy.settings.adapter: return dspy.settings.adapter @@ -197,7 +194,7 @@ def infer_adapter(self) -> Adapter: } model_type = self.model_type return model_type_to_adapter[model_type] - + def copy(self, **kwargs): """Returns a copy of the language model with possibly updated parameters.""" diff --git a/dspy/clients/provider.py b/dspy/clients/provider.py index 9eb02be1f..cc4e7147b 100644 --- a/dspy/clients/provider.py +++ b/dspy/clients/provider.py @@ -1,7 +1,7 @@ -from concurrent.futures import Future from abc import abstractmethod +from concurrent.futures import Future from threading import Thread -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from dspy.clients.utils_finetune import DataFormat @@ -9,9 +9,9 @@ class TrainingJob(Future): def __init__( self, - thread: Thread, - model: str, - train_data: List[Dict[str, Any]], + thread: Optional[Thread] = None, + model: Optional[str] = None, + train_data: Optional[List[Dict[str, Any]]] = None, train_kwargs: Optional[Dict[str, Any]] = None, data_format: Optional[DataFormat] = None, ): @@ -33,7 +33,6 @@ def status(self): class Provider: - def __init__(self): self.finetunable = False self.TrainingJob = TrainingJob @@ -45,23 +44,23 @@ def is_provider_model(model: str) -> bool: return False @staticmethod - def launch(model: str, launch_kwargs: Optional[Dict[str, Any]]=None): + def launch(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): msg = f"`launch()` is called for the auto-launched model `{model}`" msg += " -- no action is taken!" print(msg) - + @staticmethod - def kill(model: str, launch_kwargs: Optional[Dict[str, Any]]=None): + def kill(model: str, kill_kwargs: Optional[Dict[str, Any]] = None): msg = f"`kill()` is called for the auto-launched model `{model}`" msg += " -- no action is taken!" print(msg) - + @staticmethod def finetune( job: TrainingJob, model: str, train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]] = None, - data_format: Optional[DataFormat] = None, + data_format: Optional[Union[DataFormat, str]] = None, ) -> str: raise NotImplementedError diff --git a/dspy/teleprompt/bootstrap_finetune.py b/dspy/teleprompt/bootstrap_finetune.py index 418c8f0a4..32f0ca58f 100644 --- a/dspy/teleprompt/bootstrap_finetune.py +++ b/dspy/teleprompt/bootstrap_finetune.py @@ -6,8 +6,8 @@ from dspy.adapters.base import Adapter from dspy.clients.utils_finetune import infer_data_format from dspy.evaluate.evaluate import Evaluate -from dspy.primitives.example import Example from dspy.predict.predict import Predict +from dspy.primitives.example import Example from dspy.primitives.program import Program from dspy.teleprompt.teleprompt import Teleprompter @@ -235,7 +235,8 @@ def set_missing_predictor_lms(program: Program) -> Program: def prepare_student(student: Program) -> Program: print("Ensuring that the student is not compiled") - assert not student._compiled, "The student program should not be compiled" + if getattr(student, "_compiled", False): + raise ValueError("The student program should not be compiled.") # TODO: Should we use reset_copy here? How would it affect the student # program's predictor LMs, if they are set? diff --git a/examples/finetune/databricks_finetuning.py b/examples/finetune/databricks_finetuning.py new file mode 100644 index 000000000..372c0121e --- /dev/null +++ b/examples/finetune/databricks_finetuning.py @@ -0,0 +1,70 @@ +from typing import Literal + +from datasets import load_dataset + +import dspy +from dspy.clients.databricks import DatabricksProvider + +# Define the range as a tuple of valid integers +CLASSES = tuple(range(77)) + +ds = load_dataset("PolyAI/banking77") +trainset_hf = ds["train"][:100] +trainset = [] + +for text, label in zip(trainset_hf["text"], trainset_hf["label"]): + # Each example should have two fields, `inputs` and `answer`, with `inputs` as the input field, + # and `answer` as the output field. + trainset.append(dspy.Example(text=text, answer=label).with_inputs("text")) + +gold = {text: label for text, label in zip(trainset_hf["text"], trainset_hf["label"])} + +lm = dspy.LM( + model="databricks/databricks-meta-llama-3-1-70b-instruct", + provider=DatabricksProvider, + finetuning_model="meta-llama/Llama-3.2-3B", +) + +dspy.settings.configure(lm=lm) +dspy.settings.experimental = True + + +def accuracy(example, pred, trace=None): + return int(example.answer == int(pred.answer)) + + +class Classify(dspy.Signature): + """As a part of a banking issue traiging system, classify the intent of a natural language query.""" + + text = dspy.InputField() + answer: Literal[CLASSES] = dspy.OutputField() + + +class Program(dspy.Module): + def __init__(self, oracle=False): + self.oracle = oracle + self.classify = dspy.ChainOfThoughtWithHint(Classify) + + def forward(self, text): + if self.oracle and text in gold: + hint = f"the right label is {gold[text]}" + else: + hint = None + return self.classify(text=text, hint=hint) + + +model = Program(oracle=True) +print("Try the original model: ", model("I am still waiting on my card?")) + +train_kwargs = { + "train_data_path": "/Volumes/main/chenmoney/testing/dspy_testing/classification", + "register_to": "main.chenmoney.finetuned_model_classification", + "task_type": "CHAT_COMPLETION", +} + +optimized = dspy.BootstrapFinetune(metric=accuracy, num_threads=10, train_kwargs=train_kwargs).compile( + student=model, trainset=trainset +) +optimized.oracle = False + +print("Try the optimized model: ", optimized("I am still waiting on my card?")) diff --git a/tests/clients/test_databricks.py b/tests/clients/test_databricks.py new file mode 100644 index 000000000..eb209e9f6 --- /dev/null +++ b/tests/clients/test_databricks.py @@ -0,0 +1,94 @@ +"""Test the Databricks finetuning and deployment. + +This test requires valid Databricks credentials, so it is skipped on github actions. Right now it is only used for +manual testing. +""" + +from dspy.clients.databricks import ( + DatabricksProvider, + _create_directory_in_databricks_unity_catalog, + TrainingJobDatabricks, +) + +import pytest +import dspy + +try: + from databricks.sdk import WorkspaceClient + + WorkspaceClient() +except (ImportError, Exception): + # Skip the test if the Databricks SDK is not configured or credentials are not available. + pytestmark = pytest.mark.skip(reason="Databricks SDK not configured or credentials not available") + + +def test_create_directory_in_databricks_unity_catalog(): + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + with pytest.raises( + ValueError, + match=( + "Databricks Unity Catalog path must be in the format '/Volumes////...', " + "but received: /badstring/whatever" + ), + ): + _create_directory_in_databricks_unity_catalog(w, "/badstring/whatever") + + _create_directory_in_databricks_unity_catalog(w, "/Volumes/main/chenmoney/testing/dspy_testing") + # Check that the directory was created successfully, otherwise `get_directory_metadata` will raise an exception. + w.files.get_directory_metadata("/Volumes/main/chenmoney/testing/dspy_testing") + + +def test_create_finetuning_job(): + fake_training_data = [ + { + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ] + }, + { + "messages": [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "Paris!"}, + ] + }, + { + "messages": [ + {"role": "user", "content": "What is the capital of Germany?"}, + {"role": "assistant", "content": "Berlin!"}, + ] + }, + ] + dspy.settings.experimental = True + + job = TrainingJobDatabricks() + + finetuned_model = DatabricksProvider.finetune( + job=job, + model="meta-llama/Llama-3.2-1B", + train_data=fake_training_data, + data_format="chat", + train_kwargs={ + "train_data_path": "/Volumes/main/chenmoney/testing/dspy_testing", + "register_to": "main.chenmoney.finetuned_model", + "task_type": "CHAT_COMPLETION", + "skip_deploy": True, + }, + ) + assert job.finetuning_run.status.display_name is not None + + +def test_deploy_finetuned_model(): + dspy.settings.experimental = True + model_to_deploy = "main.chenmoney.finetuned_model" + + DatabricksProvider.deploy_finetuned_model( + model=model_to_deploy, + data_format="chat", + ) + + lm = dspy.LM(model="databricks/main_chenmoney_finetuned_model") + lm("what is 2 + 2?")