diff --git a/airflow/providers/openai/hooks/openai.py b/airflow/providers/openai/hooks/openai.py index f57c41c2b9f4b..31a4c16b9f29c 100644 --- a/airflow/providers/openai/hooks/openai.py +++ b/airflow/providers/openai/hooks/openai.py @@ -18,10 +18,22 @@ from __future__ import annotations from functools import cached_property -from typing import Any +from typing import TYPE_CHECKING, Any, Literal from openai import OpenAI +if TYPE_CHECKING: + from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted + from openai.types.beta.threads import Message, Run + from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionFunctionMessageParam, + ChatCompletionMessage, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionUserMessageParam, + ) + from airflow.hooks.base import BaseHook @@ -77,6 +89,168 @@ def get_conn(self) -> OpenAI: **openai_client_kwargs, ) + def create_chat_completion( + self, + messages: list[ + ChatCompletionSystemMessageParam + | ChatCompletionUserMessageParam + | ChatCompletionAssistantMessageParam + | ChatCompletionToolMessageParam + | ChatCompletionFunctionMessageParam + ], + model: str = "gpt-3.5-turbo", + **kwargs: Any, + ) -> list[ChatCompletionMessage]: + """ + Create a model response for the given chat conversation and returns a list of chat completions. + + :param messages: A list of messages comprising the conversation so far + :param model: ID of the model to use + """ + response = self.conn.chat.completions.create(model=model, messages=messages, **kwargs) + return response.choices + + def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant: + """Create an OpenAI assistant using the given model. + + :param model: The OpenAI model for the assistant to use. + """ + assistant = self.conn.beta.assistants.create(model=model, **kwargs) + return assistant + + def get_assistant(self, assistant_id: str) -> Assistant: + """ + Get an OpenAI assistant. + + :param assistant_id: The ID of the assistant to retrieve. + """ + assistant = self.conn.beta.assistants.retrieve(assistant_id=assistant_id) + return assistant + + def get_assistants(self, **kwargs: Any) -> list[Assistant]: + """Get a list of Assistant objects.""" + assistants = self.conn.beta.assistants.list(**kwargs) + return assistants.data + + def get_assistant_by_name(self, assistant_name: str) -> Assistant | None: + """Get an OpenAI Assistant object for a given name. + + :param assistant_name: The name of the assistant to retrieve + """ + response = self.get_assistants() + for assistant in response: + if assistant.name == assistant_name: + return assistant + return None + + def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant: + """Modify an existing Assistant object. + + :param assistant_id: The ID of the assistant to be modified. + """ + assistant = self.conn.beta.assistants.update(assistant_id=assistant_id, **kwargs) + return assistant + + def delete_assistant(self, assistant_id: str) -> AssistantDeleted: + """Delete an OpenAI Assistant for a given ID. + + :param assistant_id: The ID of the assistant to delete. + """ + response = self.conn.beta.assistants.delete(assistant_id=assistant_id) + return response + + def create_thread(self, **kwargs: Any) -> Thread: + """Create an OpenAI thread.""" + thread = self.conn.beta.threads.create(**kwargs) + return thread + + def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread: + """Modify an existing Thread object. + + :param thread_id: The ID of the thread to modify. + :param metadata: Set of 16 key-value pairs that can be attached to an object. + """ + thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata) + return thread + + def delete_thread(self, thread_id: str) -> ThreadDeleted: + """Delete an OpenAI thread for a given thread_id. + + :param thread_id: The ID of the thread to delete. + """ + response = self.conn.beta.threads.delete(thread_id=thread_id) + return response + + def create_message( + self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any + ) -> Message: + """Create a message for a given Thread. + + :param thread_id: The ID of the thread to create a message for. + :param role: The role of the entity that is creating the message. Allowed values include: 'user', 'assistant'. + :param content: The content of the message. + """ + thread_message = self.conn.beta.threads.messages.create( + thread_id=thread_id, role=role, content=content, **kwargs + ) + return thread_message + + def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]: + """Return a list of messages for a given Thread. + + :param thread_id: The ID of the thread the messages belong to. + """ + messages = self.conn.beta.threads.messages.list(thread_id=thread_id, **kwargs) + return messages.data + + def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message: + """Modify an existing message for a given Thread. + + :param thread_id: The ID of the thread to which this message belongs. + :param message_id: The ID of the message to modify. + """ + thread_message = self.conn.beta.threads.messages.update( + thread_id=thread_id, message_id=message_id, **kwargs + ) + return thread_message + + def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run: + """Create a run for a given thread and assistant. + + :param thread_id: The ID of the thread to run. + :param assistant_id: The ID of the assistant to use to execute this run. + """ + run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs) + return run + + def get_run(self, thread_id: str, run_id: str) -> Run: + """Retrieve a run for a given thread and run. + + :param thread_id: The ID of the thread that was run. + :param run_id: The ID of the run to retrieve. + """ + run = self.conn.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id) + return run + + def get_runs(self, thread_id: str, **kwargs: Any) -> list[Run]: + """ + Return a list of runs belonging to a thread. + + :param thread_id: The ID of the thread the run belongs to. + """ + runs = self.conn.beta.threads.runs.list(thread_id=thread_id, **kwargs) + return runs.data + + def modify_run(self, thread_id: str, run_id: str, **kwargs: Any) -> Run: + """ + Modify a run on a given thread. + + :param thread_id: The ID of the thread that was run. + :param run_id: The ID of the run to modify. + """ + run = self.conn.beta.threads.runs.update(thread_id=thread_id, run_id=run_id, **kwargs) + return run + def create_embeddings( self, text: str | list[str] | list[int] | list[list[int]], diff --git a/airflow/providers/openai/provider.yaml b/airflow/providers/openai/provider.yaml index 05c472b5a0971..3039c65c92a95 100644 --- a/airflow/providers/openai/provider.yaml +++ b/airflow/providers/openai/provider.yaml @@ -41,7 +41,7 @@ integrations: dependencies: - apache-airflow>=2.6.0 - - openai[datalib]>=1.0 + - openai[datalib]>=1.16 hooks: - integration-name: OpenAI diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index ec9ebadebe084..6d88417d50143 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -805,7 +805,7 @@ "openai": { "deps": [ "apache-airflow>=2.6.0", - "openai[datalib]>=1.0" + "openai[datalib]>=1.16" ], "devel-deps": [], "cross-providers-deps": [], diff --git a/tests/providers/openai/hooks/test_openai.py b/tests/providers/openai/hooks/test_openai.py index c3f17f7105ca9..aa7a479bbbe0b 100644 --- a/tests/providers/openai/hooks/test_openai.py +++ b/tests/providers/openai/hooks/test_openai.py @@ -23,11 +23,24 @@ openai = pytest.importorskip("openai") +from openai.pagination import SyncCursorPage from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted +from openai.types.beta.threads import Message, Run +from openai.types.chat import ChatCompletion from airflow.models import Connection from airflow.providers.openai.hooks.openai import OpenAIHook +ASSISTANT_ID = "test_assistant_abc123" +ASSISTANT_NAME = "Test Assistant" +ASSISTANT_INSTRUCTIONS = "You are a test assistant." +THREAD_ID = "test_thread_abc123" +MESSAGE_ID = "test_message_abc123" +RUN_ID = "test_run_abc123" +MODEL = "gpt-4" +METADATA = {"modified": "true", "user": "abc123"} + @pytest.fixture def mock_openai_connection(): @@ -56,6 +69,226 @@ def mock_embeddings_response(): ) +@pytest.fixture +def mock_completion(): + return ChatCompletion( + id="chatcmpl-123", + object="chat.completion", + created=1677652288, + model=MODEL, + choices=[ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?", + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + ) + + +@pytest.fixture +def mock_assistant(): + return Assistant( + id=ASSISTANT_ID, + name=ASSISTANT_NAME, + object="assistant", + created_at=1677652288, + model=MODEL, + instructions=ASSISTANT_INSTRUCTIONS, + tools=[], + file_ids=[], + metadata={}, + ) + + +@pytest.fixture +def mock_assistant_list(mock_assistant): + return SyncCursorPage[Assistant](data=[mock_assistant]) + + +@pytest.fixture +def mock_thread(): + return Thread(id=THREAD_ID, object="thread", created_at=1698984975, metadata={}) + + +@pytest.fixture +def mock_message(): + return Message( + id=MESSAGE_ID, + object="thread.message", + created_at=1698984975, + thread_id=THREAD_ID, + status="completed", + role="user", + content=[{"type": "text", "text": {"value": "Tell me something interesting.", "annotations": []}}], + assistant_id=ASSISTANT_ID, + run_id=RUN_ID, + file_ids=[], + metadata={}, + ) + + +@pytest.fixture +def mock_message_list(mock_message): + return SyncCursorPage[Message](data=[mock_message]) + + +@pytest.fixture +def mock_run(): + return Run( + id=RUN_ID, + object="thread.run", + created_at=1698107661, + assistant_id=ASSISTANT_ID, + thread_id=THREAD_ID, + status="completed", + started_at=1699073476, + completed_at=1699073476, + model=MODEL, + instructions="You are a test assistant.", + tools=[], + file_ids=[], + metadata={}, + ) + + +@pytest.fixture +def mock_run_list(mock_run): + return SyncCursorPage[Run](data=[mock_run]) + + +def test_create_chat_completion(mock_openai_hook, mock_completion): + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ] + + mock_openai_hook.conn.chat.completions.create.return_value = mock_completion + completion = mock_openai_hook.create_chat_completion(model=MODEL, messages=messages) + choice = completion[0] + assert choice.message.content == "Hello there, how may I assist you today?" + + +def test_create_assistant(mock_openai_hook, mock_assistant): + mock_openai_hook.conn.beta.assistants.create.return_value = mock_assistant + assistant = mock_openai_hook.create_assistant( + name=ASSISTANT_NAME, model=MODEL, instructions=ASSISTANT_INSTRUCTIONS + ) + assert assistant.name == ASSISTANT_NAME + assert assistant.model == MODEL + assert assistant.instructions == ASSISTANT_INSTRUCTIONS + + +def test_get_assistant(mock_openai_hook, mock_assistant): + mock_openai_hook.conn.beta.assistants.retrieve.return_value = mock_assistant + assistant = mock_openai_hook.get_assistant(assistant_id=ASSISTANT_ID) + assert assistant.name == ASSISTANT_NAME + assert assistant.model == MODEL + assert assistant.instructions == ASSISTANT_INSTRUCTIONS + + +def test_get_assistants(mock_openai_hook, mock_assistant_list): + mock_openai_hook.conn.beta.assistants.list.return_value = mock_assistant_list + assistants = mock_openai_hook.get_assistants() + assert isinstance(assistants, list) + + +def test_get_assistant_by_name(mock_openai_hook, mock_assistant_list): + mock_openai_hook.conn.beta.assistants.list.return_value = mock_assistant_list + assistant = mock_openai_hook.get_assistant_by_name(assistant_name=ASSISTANT_NAME) + assert assistant.name == ASSISTANT_NAME + + +def test_modify_assistant(mock_openai_hook, mock_assistant): + new_assistant_name = "New Test Assistant" + mock_assistant.name = new_assistant_name + mock_openai_hook.conn.beta.assistants.update.return_value = mock_assistant + assistant = mock_openai_hook.modify_assistant(assistant_id=ASSISTANT_ID, name=new_assistant_name) + assert assistant.name == new_assistant_name + + +def test_delete_assistant(mock_openai_hook): + delete_response = AssistantDeleted(id=ASSISTANT_ID, object="assistant.deleted", deleted=True) + mock_openai_hook.conn.beta.assistants.delete.return_value = delete_response + assistant_deleted = mock_openai_hook.delete_assistant(assistant_id=ASSISTANT_ID) + assert assistant_deleted.deleted + + +def test_create_thread(mock_openai_hook, mock_thread): + mock_openai_hook.conn.beta.threads.create.return_value = mock_thread + thread = mock_openai_hook.create_thread() + assert thread.id == THREAD_ID + + +def test_modify_thread(mock_openai_hook, mock_thread): + mock_thread.metadata = METADATA + mock_openai_hook.conn.beta.threads.update.return_value = mock_thread + thread = mock_openai_hook.modify_thread(thread_id=THREAD_ID, metadata=METADATA) + assert thread.metadata.get("modified") == "true" + assert thread.metadata.get("user") == "abc123" + + +def test_delete_thread(mock_openai_hook): + delete_response = ThreadDeleted(id=THREAD_ID, object="thread.deleted", deleted=True) + mock_openai_hook.conn.beta.threads.delete.return_value = delete_response + thread_deleted = mock_openai_hook.delete_thread(thread_id=THREAD_ID) + assert thread_deleted.deleted + + +def test_create_message(mock_openai_hook, mock_message): + role = "user" + content = "Tell me something interesting." + mock_openai_hook.conn.beta.threads.messages.create.return_value = mock_message + message = mock_openai_hook.create_message(thread_id=THREAD_ID, content=content, role=role) + assert message.id == MESSAGE_ID + + +def test_get_messages(mock_openai_hook, mock_message_list): + mock_openai_hook.conn.beta.threads.messages.list.return_value = mock_message_list + messages = mock_openai_hook.get_messages(thread_id=THREAD_ID) + assert isinstance(messages, list) + + +def test_modify_messages(mock_openai_hook, mock_message): + mock_message.metadata = METADATA + mock_openai_hook.conn.beta.threads.messages.update.return_value = mock_message + message = mock_openai_hook.modify_message(thread_id=THREAD_ID, message_id=MESSAGE_ID, metadata=METADATA) + assert message.metadata.get("modified") == "true" + assert message.metadata.get("user") == "abc123" + + +def test_create_run(mock_openai_hook, mock_run): + thread_id = THREAD_ID + assistant_id = ASSISTANT_ID + mock_openai_hook.conn.beta.threads.runs.create.return_value = mock_run + run = mock_openai_hook.create_run(thread_id=thread_id, assistant_id=assistant_id) + assert run.id == RUN_ID + + +def test_get_runs(mock_openai_hook, mock_run_list): + mock_openai_hook.conn.beta.threads.runs.list.return_value = mock_run_list + runs = mock_openai_hook.get_runs(thread_id=THREAD_ID) + assert isinstance(runs, list) + + +def test_get_run_with_run_id(mock_openai_hook, mock_run): + mock_openai_hook.conn.beta.threads.runs.retrieve.return_value = mock_run + run = mock_openai_hook.get_run(thread_id=THREAD_ID, run_id=RUN_ID) + assert run.id == RUN_ID + + +def test_modify_run(mock_openai_hook, mock_run): + mock_run.metadata = METADATA + mock_openai_hook.conn.beta.threads.runs.update.return_value = mock_run + message = mock_openai_hook.modify_run(thread_id=THREAD_ID, run_id=RUN_ID, metadata=METADATA) + assert message.metadata.get("modified") == "true" + assert message.metadata.get("user") == "abc123" + + def test_create_embeddings(mock_openai_hook, mock_embeddings_response): text = "Sample text" mock_openai_hook.conn.embeddings.create.return_value = mock_embeddings_response