Merge pull request #51 from edenartlab/jmill/work
Jmill/work
genekogan authored Jan 7, 2025
2 parents f51b9cd + 2ceda04 commit 6107b0d
Showing 8 changed files with 154 additions and 41 deletions.
1 change: 1 addition & 0 deletions eve/api/api_requests.py
@@ -35,6 +35,7 @@ class ChatRequest(BaseModel):
thread_id: Optional[str] = None
update_config: Optional[UpdateConfig] = None
force_reply: bool = False
model: Optional[str] = None


class CronSchedule(BaseModel):
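
The new optional model field lets API callers choose the LLM per request instead of relying on the previously hard-coded Claude default. A minimal sketch (fields outside this hunk are elided; values are hypothetical):

    # Sketch only: omitting model defers to DEFAULT_MODEL in eve/llm.py.
    request = ChatRequest(
        thread_id=None,
        force_reply=True,
        model="gpt-4o-mini",
    )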
9 changes: 5 additions & 4 deletions eve/api/handlers.py
@@ -37,6 +37,7 @@

logger = logging.getLogger(__name__)
db = os.getenv("DB", "STAGE").upper()
env = "prod" if db == "PROD" else "stage"


async def handle_create(request: TaskRequest):
@@ -76,7 +77,7 @@ async def run_prompt():
user_messages=request.user_message,
tools=tools,
force_reply=request.force_reply,
model="claude-3-5-sonnet-20241022",
model=request.model,
stream=False,
):
data = {
@@ -116,7 +117,7 @@ async def event_generator():
user_messages=request.user_message,
tools=tools,
force_reply=request.force_reply,
model="claude-3-5-sonnet-20241022",
model=request.model,
stream=True,
):
data = {"type": update.type}
@@ -158,9 +159,9 @@ async def handle_deployment_create(request: CreateDeploymentRequest):
if request.credentials:
create_modal_secrets(
request.credentials,
f"{request.agent_key}-secrets-{db}",
f"{request.agent_key}-secrets-{env}",
)
deploy_client(request.agent_key, request.platform.value)
deploy_client(request.agent_key, request.platform.value, env)
return {
"status": "success",
"message": f"Deployed {request.platform.value} client",
17 changes: 9 additions & 8 deletions eve/api/helpers.py
@@ -56,14 +56,15 @@ def serialize_for_json(obj):
async def emit_update(
update_config: UpdateConfig, update_channel: AblyRealtime, data: dict
):
if update_config and update_config.update_endpoint:
raise ValueError("update_endpoint and sub_channel_name cannot be used together")
elif update_config.update_endpoint:
await emit_http_update(update_config, data)
elif update_config.sub_channel_name:
await emit_channel_update(update_channel, data)
else:
raise ValueError("One of update_endpoint or sub_channel_name must be provided")
if update_config:
if update_config.update_endpoint and update_config.sub_channel_name:
raise ValueError(
"update_endpoint and sub_channel_name cannot be used together"
)
elif update_config.update_endpoint:
await emit_http_update(update_config, data)
elif update_config.sub_channel_name:
await emit_channel_update(update_channel, data)


async def emit_http_update(update_config: UpdateConfig, data: dict):
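
The rewrite changes emit_update in three ways: the old version raised whenever update_endpoint was set (the first branch shadowed the HTTP path) and crashed with an AttributeError when update_config was None; the new version raises only when update_endpoint and sub_channel_name are both set, dispatches to HTTP or Ably otherwise, and treats a missing config as a no-op. A behavioral sketch (endpoint and channel values hypothetical):

    await emit_update(None, channel, data)  # silently does nothing
    await emit_update(UpdateConfig(update_endpoint="https://example.com/hook"), channel, data)  # HTTP POST
    await emit_update(UpdateConfig(sub_channel_name="eve_discord_DEV"), channel, data)  # Ably publish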
2 changes: 1 addition & 1 deletion eve/cli/deploy_cli.py
@@ -189,7 +189,7 @@ def deploy(agent: str, all: bool, db: str):
for agent_name in agents:
click.echo(click.style(f"\nProcessing agent: {agent_name}", fg="blue"))
agent_path = root_dir / "eve" / "agents" / env / agent_name / "api.yaml"
process_agent(agent_path)
process_agent(agent_path, env)

else:
if not agent:
8 changes: 6 additions & 2 deletions eve/clients/common.py
@@ -1,10 +1,9 @@
import os
import time

import modal

from eve.models import ClientType

db = os.getenv("DB", "STAGE")

HOUR_IMAGE_LIMIT = 50
HOUR_VIDEO_LIMIT = 10
@@ -105,3 +104,8 @@ def register_tool_call(user, tool_name):
def get_ably_channel_name(agent_username: str, client_platform: ClientType):
env = os.getenv("UPDATE_CHANNEL_ENV", "DEV")
return f"{agent_username.lower()}_{client_platform.value}_{env}"


def get_eden_creation_url(creation_id: str):
root_url = "beta.eden.art" if db == "PROD" else "staging2.app.eden.art"
return f"https://{root_url}/creations/{creation_id}"
29 changes: 25 additions & 4 deletions eve/clients/discord/client.py
@@ -130,8 +130,27 @@ async def async_callback(message):
elif update_type == UpdateType.TOOL_COMPLETE:
result = data.get("result", {})
result["result"] = prepare_result(result["result"])
url = result["result"][0]["output"][0]["url"]
await self.send_message(channel, url, reference=reference)
output = result["result"][0]["output"][0]
url = output["url"]

# Get creation ID from the output
creation_id = str(output.get("creation"))

if creation_id:
eden_url = common.get_eden_creation_url(creation_id)
view = discord.ui.View()
view.add_item(
discord.ui.Button(
label="View on Eden",
url=eden_url,
style=discord.ButtonStyle.link,
)
)
await self.send_message(
channel, url, reference=reference, view=view
)
else:
await self.send_message(channel, url, reference=reference)

elif update_type == UpdateType.END_PROMPT:
await self.stop_typing(channel)
@@ -243,10 +262,12 @@ async def on_message(self, message: discord.Message) -> None:
async def on_member_join(self, member):
print(f"{member} has joined the guild id: {member.guild.id}")

async def send_message(self, channel, content, reference=None, limit=2000):
async def send_message(
self, channel, content, reference=None, limit=2000, view=None
):
for i in range(0, len(content), limit):
chunk = content[i : i + limit]
await channel.send(chunk, reference=reference)
await channel.send(chunk, reference=reference, view=view)

async def start_typing(self, channel):
"""
26 changes: 18 additions & 8 deletions eve/deploy.py
@@ -9,7 +9,6 @@
REPO_BRANCH = "main"
DEPLOYMENT_ENV_NAME = "deployments"
db = os.getenv("DB", "STAGE").upper()
env = "prod" if db == "PROD" else "stage"


def authenticate_modal_key() -> bool:
@@ -65,19 +64,19 @@ def clone_repo(temp_dir: str):
)


def prepare_client_file(file_path: str, agent_key: str) -> None:
def prepare_client_file(file_path: str, agent_key: str, env: str) -> None:
"""Modify the client file to use correct secret name and fix pyproject path"""
with open(file_path, "r") as f:
content = f.read()

# Get the repo root directory (three levels up from the client file)
repo_root = Path(__file__).parent.parent.parent
repo_root = Path(__file__).parent.parent
pyproject_path = repo_root / "pyproject.toml"

# Replace the static secret name with the dynamic one
modified_content = content.replace(
'modal.Secret.from_name("client-secrets")',
f'modal.Secret.from_name("{agent_key}-secrets-{db}")',
f'modal.Secret.from_name("{agent_key}-secrets-{env}")',
)

# Fix pyproject.toml path to use absolute path
@@ -94,7 +93,7 @@ def prepare_client_file(file_path: str, agent_key: str) -> None:
return str(temp_file)


def deploy_client(agent_key: str, client_name: str):
def deploy_client(agent_key: str, client_name: str, env: str):
with tempfile.TemporaryDirectory() as temp_dir:
# Clone the repo
clone_repo(temp_dir)
@@ -105,9 +104,20 @@
)
if os.path.exists(client_path):
# Modify the client file to use the correct secret name
prepare_client_file(client_path, agent_key)
temp_file = prepare_client_file(client_path, agent_key, env)
app_name = f"{agent_key}-{client_name}-{env}"

subprocess.run(
["modal", "deploy", client_path, "-e", DEPLOYMENT_ENV_NAME], check=True
[
"modal",
"deploy",
"--name",
app_name,
temp_file,
"-e",
DEPLOYMENT_ENV_NAME,
],
check=True,
)
else:
raise Exception(f"Client modal file not found: {client_path}")
@@ -119,7 +129,7 @@ def stop_client(agent_key: str, client_name: str):
"modal",
"app",
"stop",
f"{agent_key}-{client_name}-{db}",
f"{agent_key}-{client_name}-{db.lower()}",
"-e",
DEPLOYMENT_ENV_NAME,
],
103 changes: 89 additions & 14 deletions eve/llm.py
@@ -13,6 +13,7 @@
from pydantic.config import ConfigDict
from instructor.function_calls import openai_schema
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception
import json

from . import sentry_sdk
from .tool import Tool
@@ -33,6 +34,7 @@ class UpdateType(str, Enum):


models = ["claude-3-5-sonnet-20241022", "gpt-4o-mini", "gpt-4o-2024-08-06"]
DEFAULT_MODEL = "claude-3-5-sonnet-20241022"
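
DEFAULT_MODEL backs the new optional ChatRequest.model: the API handlers now forward request.model, and async_prompt_thread below falls back to this constant when it receives None.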


async def async_anthropic_prompt(
Expand Down Expand Up @@ -67,9 +69,7 @@ async def async_anthropic_prompt(
[r.text for r in response.content if r.type == "text" and r.text]
)
tool_calls = [
ToolCall.from_anthropic(r)
for r in response.content
if r.type == "tool_use"
ToolCall.from_anthropic(r) for r in response.content if r.type == "tool_use"
]
stop = response.stop_reason == "end_turn"
return (content, tool_calls, stop)
@@ -116,9 +116,7 @@ async def async_anthropic_prompt_stream(
# Handle tool use
elif chunk.type == "content_block_stop" and hasattr(chunk, "content_block"):
if chunk.content_block.type == "tool_use":
tool_calls.append(
ToolCall.from_anthropic(chunk.content_block)
)
tool_calls.append(ToolCall.from_anthropic(chunk.content_block))

# Stop reason
elif chunk.type == "message_delta" and hasattr(chunk.delta, "stop_reason"):
@@ -171,6 +169,80 @@ async def async_openai_prompt(
return content, tool_calls, stop


async def async_openai_prompt_stream(
messages: List[Union[UserMessage, AssistantMessage]],
system_message: Optional[str],
model: str,
response_model: Optional[type[BaseModel]],
tools: Dict[str, Tool],
) -> AsyncGenerator[Tuple[UpdateType, str], None]:
"""Yields partial tokens (ASSISTANT_TOKEN, partial_text) for streaming."""
if not os.getenv("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY env is not set")

messages_json = [item for msg in messages for item in msg.openai_schema()]
if system_message:
messages_json = [{"role": "system", "content": system_message}] + messages_json

openai_client = openai.AsyncOpenAI()
tools_schema = (
[t.openai_schema(exclude_hidden=True) for t in tools.values()]
if tools
else None
)

if response_model:
# Response models not supported in streaming mode for OpenAI
raise NotImplementedError(
"Response models not supported in streaming mode for OpenAI"
)

stream = await openai_client.chat.completions.create(
model=model, messages=messages_json, tools=tools_schema, stream=True
)

tool_calls = []
current_tool_call = None

async for chunk in stream:
delta = chunk.choices[0].delta

# Handle text content
if delta.content:
yield (UpdateType.ASSISTANT_TOKEN, delta.content)

# Handle tool calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
if tool_call.index is not None:
# Ensure we have a list long enough
while len(tool_calls) <= tool_call.index:
tool_calls.append(None)

if tool_calls[tool_call.index] is None:
tool_calls[tool_call.index] = ToolCall(
tool=tool_call.function.name, args={}
)

if tool_call.function.arguments:
current_args = tool_calls[tool_call.index].args
# Merge new arguments with existing ones
try:
new_args = json.loads(tool_call.function.arguments)
current_args.update(new_args)
except json.JSONDecodeError:
pass

# Handle finish reason
if chunk.choices[0].finish_reason:
yield (UpdateType.ASSISTANT_STOP, chunk.choices[0].finish_reason)

# Yield any accumulated tool calls at the end
for tool_call in tool_calls:
if tool_call:
yield (UpdateType.TOOL_CALL, tool_call)
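
The new OpenAI streaming path mirrors the Anthropic one: it yields (ASSISTANT_TOKEN, text) as deltas arrive, accumulates tool-call arguments by index, and emits the collected TOOL_CALL updates once the stream ends. A minimal consumption sketch (assumes OPENAI_API_KEY is set; empty tools dict):

    async for update_type, payload in async_openai_prompt_stream(
        messages=[UserMessage(content="hello")],
        system_message=None,
        model="gpt-4o-mini",
        response_model=None,
        tools={},
    ):
        if update_type == UpdateType.ASSISTANT_TOKEN:
            print(payload, end="", flush=True)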


@retry(
retry=retry_if_exception(
lambda e: isinstance(e, (openai.RateLimitError, anthropic.RateLimitError))
@@ -253,14 +325,15 @@ async def async_prompt_stream(
Add a similar function for OpenAI if you need streaming from GPT-based models.
"""
if model.startswith("claude"):
# Stream from Anthropic
async for chunk in async_anthropic_prompt_stream(
messages, system_message, model, response_model, tools
):
yield chunk
else:
# NOTE: for streaming with OpenAI, implement a similar function if desired
raise NotImplementedError("Streaming not implemented for model: " + model)
async for chunk in async_openai_prompt_stream(
messages, system_message, model, response_model, tools
):
yield chunk


def anthropic_prompt(messages, system_message, model, response_model=None, tools=None):
@@ -331,9 +404,11 @@ async def async_prompt_thread(
user_messages: Union[UserMessage, List[UserMessage]],
tools: Dict[str, Tool],
force_reply: bool = True,
model: Literal[tuple(models)] = "claude-3-5-sonnet-20241022",
model: Literal[tuple(models)] = None,
stream: bool = False,
):
if not model:
model = DEFAULT_MODEL
print("================================================")
print(user_messages)
print("================================================")
@@ -398,6 +473,8 @@
model=model,
tools=tools,
):
print("UPDATE TYPE", update_type)
print("CONTENT", content)
# stream an individual token
if update_type == UpdateType.ASSISTANT_TOKEN:
if not content: # Skip empty content
@@ -413,7 +490,7 @@

# detect stop call
elif update_type == UpdateType.ASSISTANT_STOP:
stop = content == "end_turn"
stop = content == "end_turn" or content == "stop"

# Create assistant message from accumulated content
content = "".join(content_chunks)
@@ -484,9 +561,7 @@
raise Exception(f"Tool {tool_call.tool} not found.")

# start task
task = await tool.async_start_task(
user.id, agent.id, tool_call.args
)
task = await tool.async_start_task(user.id, agent.id, tool_call.args)

# update tool call with task id and status
thread.update_tool_call(