Skip to content

Commit

Permalink
OpenAI Chat & Assistant hook functions (apache#38736)
Browse files Browse the repository at this point in the history
* Adding hook functions for chats and assistants
* v1.16 supports filtering messages by run_id
* Updated provider dependencies for OpenAI
  • Loading branch information
nathadfield authored Apr 19, 2024
1 parent a6f612d commit 2674a69
Show file tree
Hide file tree
Showing 4 changed files with 410 additions and 3 deletions.
176 changes: 175 additions & 1 deletion airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,22 @@
from __future__ import annotations

from functools import cached_property
from typing import Any
from typing import TYPE_CHECKING, Any, Literal

from openai import OpenAI

if TYPE_CHECKING:
from openai.types.beta import Assistant, AssistantDeleted, Thread, ThreadDeleted
from openai.types.beta.threads import Message, Run
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionMessage,
ChatCompletionSystemMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)

from airflow.hooks.base import BaseHook


Expand Down Expand Up @@ -77,6 +89,168 @@ def get_conn(self) -> OpenAI:
**openai_client_kwargs,
)

def create_chat_completion(
self,
messages: list[
ChatCompletionSystemMessageParam
| ChatCompletionUserMessageParam
| ChatCompletionAssistantMessageParam
| ChatCompletionToolMessageParam
| ChatCompletionFunctionMessageParam
],
model: str = "gpt-3.5-turbo",
**kwargs: Any,
) -> list[ChatCompletionMessage]:
"""
Create a model response for the given chat conversation and returns a list of chat completions.
:param messages: A list of messages comprising the conversation so far
:param model: ID of the model to use
"""
response = self.conn.chat.completions.create(model=model, messages=messages, **kwargs)
return response.choices

def create_assistant(self, model: str = "gpt-3.5-turbo", **kwargs: Any) -> Assistant:
"""Create an OpenAI assistant using the given model.
:param model: The OpenAI model for the assistant to use.
"""
assistant = self.conn.beta.assistants.create(model=model, **kwargs)
return assistant

def get_assistant(self, assistant_id: str) -> Assistant:
"""
Get an OpenAI assistant.
:param assistant_id: The ID of the assistant to retrieve.
"""
assistant = self.conn.beta.assistants.retrieve(assistant_id=assistant_id)
return assistant

def get_assistants(self, **kwargs: Any) -> list[Assistant]:
"""Get a list of Assistant objects."""
assistants = self.conn.beta.assistants.list(**kwargs)
return assistants.data

def get_assistant_by_name(self, assistant_name: str) -> Assistant | None:
"""Get an OpenAI Assistant object for a given name.
:param assistant_name: The name of the assistant to retrieve
"""
response = self.get_assistants()
for assistant in response:
if assistant.name == assistant_name:
return assistant
return None

def modify_assistant(self, assistant_id: str, **kwargs: Any) -> Assistant:
"""Modify an existing Assistant object.
:param assistant_id: The ID of the assistant to be modified.
"""
assistant = self.conn.beta.assistants.update(assistant_id=assistant_id, **kwargs)
return assistant

def delete_assistant(self, assistant_id: str) -> AssistantDeleted:
"""Delete an OpenAI Assistant for a given ID.
:param assistant_id: The ID of the assistant to delete.
"""
response = self.conn.beta.assistants.delete(assistant_id=assistant_id)
return response

def create_thread(self, **kwargs: Any) -> Thread:
"""Create an OpenAI thread."""
thread = self.conn.beta.threads.create(**kwargs)
return thread

def modify_thread(self, thread_id: str, metadata: dict[str, Any]) -> Thread:
"""Modify an existing Thread object.
:param thread_id: The ID of the thread to modify.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
"""
thread = self.conn.beta.threads.update(thread_id=thread_id, metadata=metadata)
return thread

def delete_thread(self, thread_id: str) -> ThreadDeleted:
"""Delete an OpenAI thread for a given thread_id.
:param thread_id: The ID of the thread to delete.
"""
response = self.conn.beta.threads.delete(thread_id=thread_id)
return response

def create_message(
self, thread_id: str, role: Literal["user", "assistant"], content: str, **kwargs: Any
) -> Message:
"""Create a message for a given Thread.
:param thread_id: The ID of the thread to create a message for.
:param role: The role of the entity that is creating the message. Allowed values include: 'user', 'assistant'.
:param content: The content of the message.
"""
thread_message = self.conn.beta.threads.messages.create(
thread_id=thread_id, role=role, content=content, **kwargs
)
return thread_message

def get_messages(self, thread_id: str, **kwargs: Any) -> list[Message]:
"""Return a list of messages for a given Thread.
:param thread_id: The ID of the thread the messages belong to.
"""
messages = self.conn.beta.threads.messages.list(thread_id=thread_id, **kwargs)
return messages.data

def modify_message(self, thread_id: str, message_id, **kwargs: Any) -> Message:
"""Modify an existing message for a given Thread.
:param thread_id: The ID of the thread to which this message belongs.
:param message_id: The ID of the message to modify.
"""
thread_message = self.conn.beta.threads.messages.update(
thread_id=thread_id, message_id=message_id, **kwargs
)
return thread_message

def create_run(self, thread_id: str, assistant_id: str, **kwargs: Any) -> Run:
"""Create a run for a given thread and assistant.
:param thread_id: The ID of the thread to run.
:param assistant_id: The ID of the assistant to use to execute this run.
"""
run = self.conn.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id, **kwargs)
return run

def get_run(self, thread_id: str, run_id: str) -> Run:
"""Retrieve a run for a given thread and run.
:param thread_id: The ID of the thread that was run.
:param run_id: The ID of the run to retrieve.
"""
run = self.conn.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
return run

def get_runs(self, thread_id: str, **kwargs: Any) -> list[Run]:
"""
Return a list of runs belonging to a thread.
:param thread_id: The ID of the thread the run belongs to.
"""
runs = self.conn.beta.threads.runs.list(thread_id=thread_id, **kwargs)
return runs.data

def modify_run(self, thread_id: str, run_id: str, **kwargs: Any) -> Run:
"""
Modify a run on a given thread.
:param thread_id: The ID of the thread that was run.
:param run_id: The ID of the run to modify.
"""
run = self.conn.beta.threads.runs.update(thread_id=thread_id, run_id=run_id, **kwargs)
return run

def create_embeddings(
self,
text: str | list[str] | list[int] | list[list[int]],
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/openai/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ integrations:

dependencies:
- apache-airflow>=2.6.0
- openai[datalib]>=1.0
- openai[datalib]>=1.16

hooks:
- integration-name: OpenAI
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@
"openai": {
"deps": [
"apache-airflow>=2.6.0",
"openai[datalib]>=1.0"
"openai[datalib]>=1.16"
],
"devel-deps": [],
"cross-providers-deps": [],
Expand Down
Loading

0 comments on commit 2674a69

Please sign in to comment.