Add DeepSeek-R1 integration using the OpenAI client
phi-jkim committed Jan 29, 2025
1 parent fe10865 commit 6306633
Showing 21 changed files with 2,932 additions and 1,961 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-test.yml
@@ -14,7 +14,7 @@ jobs:
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
-      - uses: actions/checkout@v3 # Updated to the latest version
+      - uses: actions/checkout@v4 # Updated to the latest version
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4 # Updated to the latest version
with:
@@ -37,7 +37,7 @@ jobs:
poetry run pytest
- name: Upload pytest results as an artifact (optional)
-        uses: actions/upload-artifact@v3 # Updated to the latest version
+        uses: actions/upload-artifact@v4 # Updated to the latest version
if: always() # Always run this step to ensure test results are saved even if previous steps fail
with:
name: pytest-results
19 changes: 13 additions & 6 deletions adalflow/CHANGELOG.md
@@ -1,16 +1,23 @@
-## [0.2.7] - 2024-09-23
+## [0.2.7] - 2025-01-16

### Improved
- Better diagnose report for `Trainer.diagnose`.
- Multi-hop RAG with cycle handling.

## [0.2.7] - To Be Released
### Added
- `Memory` is completed with `call` and `add_dialog_turn` methods (a usage sketch follows this changelog section).
- Integrated `LanceDB` in the `Retriever`.
- Multi-modal (image input and generation) in `OpenAIClient`, along with tests.
- `ComponentList` to support a list of components registered in a component. Added `test_componentlist` to test the `ComponentList`.

### Improved
- Better diagnose report for `Trainer.diagnose`.
- `BedrockAPIClient` added more details on setup, though it is still in the experimental stage.
- `AzureAPIClient` added more details on setup, though it is still in the experimental stage.
- `Retriever` class:
- Support a data `id` field.
- `GradComponent`: Support pass-through gradient for the `forward` method.

### Optimization
- Aggregated all backward engine prompts in `backward_engine_prompt`.
- Added `TGDData` for the optimizer to support reasoning when proposing a new prompt.
- Added `sequential_order` in the `Trainer` to support the sequential training order. Reorganized the trainer code.
## [0.2.6] - 2024-11-25
### Improved
- Added a default `max_tokens=512` to the `AnthropicAPIClient` to avoid errors when the user does not provide `max_tokens` in the prompt.
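To make the `Memory` completion noted above concrete, here is a minimal usage sketch. The import path and keyword names are assumptions inferred from the changelog entry, not confirmed API; verify them against the adalflow source.

```python
from adalflow.components.memory.memory import Memory  # assumed path

# Sketch of the completed Memory API: record dialog turns, then read them back.
memory = Memory()
memory.add_dialog_turn(
    user_query="What is AdalFlow?",
    assistant_response="A library for building and auto-optimizing LLM pipelines.",
)

# `call` returns the accumulated conversation, ready to drop into a prompt.
print(memory.call())
```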
4 changes: 3 additions & 1 deletion adalflow/adalflow/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.2.6"
__version__ = "0.2.7"

from adalflow.core.component import Component, fun_to_component
from adalflow.core.container import Sequential, ComponentList
@@ -61,6 +61,7 @@
AnthropicAPIClient,
CohereAPIClient,
BedrockAPIClient,
DeepSeekClient,
)

# data pipeline
@@ -124,6 +125,7 @@
"OpenAIClient",
"GoogleGenAIClient",
"GroqAPIClient",
"DeepSeekClient",
"OllamaClient",
"TransformersClient",
"AnthropicAPIClient",
7 changes: 7 additions & 0 deletions adalflow/adalflow/components/model_client/__init__.py
@@ -52,6 +52,12 @@
"adalflow.components.model_client.ollama_client.OllamaClient",
OptionalPackages.OLLAMA,
)

# DeepSeek needs no extra optional package: it reuses the OpenAI SDK, so this
# LazyImport carries no OptionalPackages requirement.
DeepSeekClient = LazyImport(
    "adalflow.components.model_client.deepseek_client.DeepSeekClient",
    None,
)
get_first_message_content = LazyImport(
    "adalflow.components.model_client.openai_client.get_first_message_content",
    OptionalPackages.OPENAI,
@@ -76,6 +82,7 @@
"GroqAPIClient",
"OpenAIClient",
"GoogleGenAIClient",
"DeepSeekClient",
]

for name in __all__:
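Because `DeepSeekClient` is registered through `LazyImport`, the underlying `deepseek_client` module is only imported the first time the symbol is used. A minimal sketch of that deferral (the lazy-resolution timing, not the public names, is the assumption here):

```python
import adalflow as adal

# At import time, DeepSeekClient is still a LazyImport placeholder; the
# deepseek_client module has not been loaded yet. The first instantiation
# resolves the real class and forwards the call to it (and would only fail
# here if the openai package it builds on were missing).
client = adal.DeepSeekClient()
print(type(client).__name__)  # "DeepSeekClient" once resolved
```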
74 changes: 62 additions & 12 deletions adalflow/adalflow/components/model_client/bedrock_client.py
@@ -1,7 +1,8 @@
"""AWS Bedrock ModelClient integration."""

import json
import os
-from typing import Dict, Optional, Any, Callable
+from typing import Dict, Optional, Any, Callable, Generator as GeneratorType
import backoff
import logging

@@ -26,7 +27,6 @@ def get_first_message_content(completion: Dict) -> str:
r"""When we only need the content of the first message.
It is the default parser for chat completion."""
return completion["output"]["message"]["content"][0]["text"]
return completion["output"]["message"]["content"][0]["text"]


__all__ = [
@@ -117,6 +117,7 @@ def __init__(
self._aws_connection_timeout = aws_connection_timeout
self._aws_read_timeout = aws_read_timeout

self._client = None
self.session = None
self.sync_client = self.init_sync_client()
self.chat_completion_parser = (
@@ -158,16 +159,51 @@ def init_sync_client(self):
def init_async_client(self):
raise NotImplementedError("Async call not implemented yet.")

-    def parse_chat_completion(self, completion):
-        log.debug(f"completion: {completion}")
def handle_stream_response(self, stream: dict) -> GeneratorType:
r"""Handle the stream response from bedrock. Yield the chunks.
Args:
stream (dict): The stream response generator from bedrock.
Returns:
GeneratorType: A generator that yields the chunks from bedrock stream.
"""
try:
stream: GeneratorType = stream["stream"]
for chunk in stream:
log.debug(f"Raw chunk: {chunk}")
yield chunk
except Exception as e:
log.debug(f"Error in handle_stream_response: {e}") # Debug print
raise

def parse_chat_completion(self, completion: dict) -> "GeneratorOutput":
r"""Parse the completion, and assign it into the raw_response attribute.
If the completion is a stream, it will be handled by the handle_stream_response
method that returns a Generator. Otherwise, the completion will be parsed using
the get_first_message_content method.
Args:
completion (dict): The completion response from bedrock API call.
Returns:
GeneratorOutput: A generator output object with the parsed completion. May
return a generator if the completion is a stream.
"""
try:
data = completion["output"]["message"]["content"][0]["text"]
usage = self.track_completion_usage(completion)
return GeneratorOutput(data=None, usage=usage, raw_response=data)
usage = None
data = self.chat_completion_parser(completion)
if not isinstance(data, GeneratorType):
# Streaming completion usage tracking is not implemented.
usage = self.track_completion_usage(completion)
return GeneratorOutput(
data=None, error=None, raw_response=data, usage=usage
)
except Exception as e:
log.error(f"Error parsing completion: {e}")
log.error(f"Error parsing the completion: {e}")
return GeneratorOutput(
-                data=None, error=str(e), raw_response=str(completion)
+                data=None, error=str(e), raw_response=json.dumps(completion)
)

def track_completion_usage(self, completion: Dict) -> CompletionUsage:
@@ -191,6 +227,7 @@ def list_models(self):
print(f" Description: {model['description']}")
print(f" Provider: {model['provider']}")
print("")

except Exception as e:
print(f"Error listing models: {e}")

@@ -222,14 +259,27 @@ def convert_inputs_to_api_kwargs(
bedrock_runtime_exceptions.ModelErrorException,
bedrock_runtime_exceptions.ValidationException,
),
-        max_time=5,
+        max_time=2,
)
-    def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
def call(
self,
api_kwargs: Dict = {},
model_type: ModelType = ModelType.UNDEFINED,
) -> dict:
"""
kwargs is the combined input and model_kwargs
"""
if model_type == ModelType.LLM:
-            return self.sync_client.converse(**api_kwargs)
if "stream" in api_kwargs and api_kwargs.get("stream", False):
log.debug("Streaming call")
api_kwargs.pop(
"stream", None
) # stream is not a valid parameter for bedrock
self.chat_completion_parser = self.handle_stream_response
return self.sync_client.converse_stream(**api_kwargs)
else:
api_kwargs.pop("stream", None)
return self.sync_client.converse(**api_kwargs)
else:
raise ValueError(f"model_type {model_type} is not supported")

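The new streaming path can be exercised end to end with a sketch like the following. It assumes AWS credentials with Bedrock access are already configured and that the chosen `modelId` is enabled in the account; the chunk key names follow Bedrock's `converse_stream` event format and should be treated as assumptions to verify.

```python
from adalflow.components.model_client import BedrockAPIClient
from adalflow.core.types import ModelType

client = BedrockAPIClient()  # assumes AWS credentials are configured

api_kwargs = client.convert_inputs_to_api_kwargs(
    input="Explain retrieval-augmented generation in one sentence.",
    model_kwargs={
        "modelId": "anthropic.claude-3-haiku-20240307-v1:0",  # example model id
        "stream": True,  # routes call() to converse_stream(); popped before the call
    },
    model_type=ModelType.LLM,
)

completion = client.call(api_kwargs=api_kwargs, model_type=ModelType.LLM)

# With streaming, parse_chat_completion returns a GeneratorOutput whose
# raw_response is the chunk generator produced by handle_stream_response.
output = client.parse_chat_completion(completion)
for chunk in output.raw_response:
    # Chunk structure follows Bedrock converse_stream events (assumption).
    delta = chunk.get("contentBlockDelta", {}).get("delta", {})
    if "text" in delta:
        print(delta["text"], end="", flush=True)
```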
72 changes: 72 additions & 0 deletions adalflow/adalflow/components/model_client/deepseek_client.py
@@ -0,0 +1,72 @@
from typing import Optional, Any, Callable, Literal

from adalflow.utils.lazy_import import safe_import, OptionalPackages
from adalflow.components.model_client.openai_client import OpenAIClient

# Ensure the optional openai package is installed (failing early with a clear
# message if not) before importing types from it.
openai = safe_import(OptionalPackages.OPENAI.value[0], OptionalPackages.OPENAI.value[1])

from openai.types import Completion

class DeepSeekClient(OpenAIClient):
"""
A component wrapper for the DeepSeek API client.
DeepSeek's API is compatible with OpenAI's API, making it possible to use OpenAI SDKs
or OpenAI-compatible software with DeepSeek by adjusting the API base URL.
This client extends `OpenAIClient` but modifies the default `base_url` to use DeepSeek's API.
Documentation: https://api-docs.deepseek.com/guides/reasoning_model
Args:
api_key (Optional[str], optional): DeepSeek API key. Defaults to `None`.
chat_completion_parser (Callable[[Completion], Any], optional): A function to parse API responses.
input_type (Literal["text", "messages"], optional): Defines how input is handled. Defaults to `"text"`.
base_url (str, optional): API base URL, defaults to `"https://api.deepseek.com/v1/"`.
"""

def __init__(
    self,
    api_key: Optional[str] = None,
    chat_completion_parser: Callable[[Completion], Any] = None,
    input_type: Literal["text", "messages"] = "messages",
    base_url: str = "https://api.deepseek.com/v1/",
    env_api_key_name: str = "DEEPSEEK_API_KEY",
):
    """Initializes the DeepSeek client with DeepSeek's base URL. The input_type
    defaults to "messages" for compatibility with the DeepSeek reasoner."""
    super().__init__(
        api_key=api_key,
        chat_completion_parser=chat_completion_parser,
        input_type=input_type,
        base_url=base_url,
        env_api_key_name=env_api_key_name,
    )

# Example usage:
if __name__ == "__main__":
from adalflow.core import Generator
from adalflow.utils import setup_env, get_logger

log = get_logger(level="DEBUG")

prompt_kwargs = {"input_str": "What is the meaning of life?"}

setup_env()

gen = Generator(
model_client=DeepSeekClient(),
model_kwargs={"model": "deepseek-reasoner", "stream": True},
)

gen_response = gen(prompt_kwargs)
print(f"gen_response: {gen_response}")

for genout in gen_response.data:
print(f"genout: {genout}")
