Skip to content

Commit

Permalink
Merge pull request #181 from ComposioHQ/feat/pusher
Browse files Browse the repository at this point in the history
Add support for listening to trigger events
  • Loading branch information
angrybayblade authored Jun 19, 2024
2 parents 6ca8c20 + fb4d1d5 commit ba01730
Show file tree
Hide file tree
Showing 19 changed files with 297 additions and 44 deletions.
18 changes: 4 additions & 14 deletions composio/cli/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
CLI Context.
"""

import os
import typing as t
from functools import update_wrapper
from pathlib import Path
Expand All @@ -13,11 +12,7 @@
from rich.console import Console

from composio.client import Composio
from composio.constants import (
ENV_COMPOSIO_API_KEY,
LOCAL_CACHE_DIRECTORY_NAME,
USER_DATA_FILE_NAME,
)
from composio.constants import LOCAL_CACHE_DIRECTORY_NAME, USER_DATA_FILE_NAME
from composio.storage.user import UserData


Expand All @@ -32,6 +27,8 @@ class Context:
_cache_dir: t.Optional[Path] = None
_console: t.Optional[Console] = None

using_api_key_from_env: bool = False

@property
def click_ctx(self) -> click.Context:
"""Click runtime context."""
Expand Down Expand Up @@ -59,17 +56,10 @@ def user_data(self) -> UserData:
path = self.cache_dir / USER_DATA_FILE_NAME
if not path.exists():
self._user_data = UserData(path=path)
self._user_data.api_key = os.environ.get(
ENV_COMPOSIO_API_KEY,
self._user_data.api_key,
)
self._user_data.store()

if self._user_data is None:
self._user_data = UserData.load(path=path)
self._user_data.api_key = os.environ.get(
ENV_COMPOSIO_API_KEY,
self._user_data.api_key,
)
return self._user_data

@property
Expand Down
1 change: 1 addition & 0 deletions composio/cli/logout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
composio logout
"""


import click

from composio.cli.context import Context, pass_context
Expand Down
2 changes: 1 addition & 1 deletion composio/cli/triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _disable_trigger(context: Context, id: str) -> None:
try:
response = context.client.triggers.disable(id=id)
if response["status"] == "success":
context.console.print(f"Enabled trigger with ID: [green]{id}[/green]")
context.console.print(f"Disabled trigger with ID: [green]{id}[/green]")
return
raise click.ClickException(f"Could not disable trigger with ID: {id}")
except ComposioSDKError as e:
Expand Down
1 change: 1 addition & 0 deletions composio/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@ class BaseClient:
"""Composio client abstraction."""

http: HttpClient
api_key: str
252 changes: 250 additions & 2 deletions composio/client/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,22 @@
"""

import base64
import json
import os
import time
import typing as t
import warnings

import pysher
import typing_extensions as te
from pydantic import BaseModel, ConfigDict
from pysher.channel import Channel

from composio.client.base import BaseClient, Collection
from composio.client.endpoints import v1
from composio.client.enums import Action, App, Tag, Trigger
from composio.client.exceptions import ComposioClientError
from composio.constants import PUSHER_CLUSTER, PUSHER_KEY

from .local_handler import LocalToolHandler

Expand Down Expand Up @@ -402,6 +407,226 @@ class FileModel(BaseModel):
content: bytes


class Connection(BaseModel):
id: str
integrationId: str
clientUniqueUserId: str
status: str


class Metadata(BaseModel):
id: str
connectionId: str
triggerName: str
triggerData: str
triggerConfig: t.Dict[str, t.Any]
connection: Connection


class TriggerEventData(BaseModel):
"""Trigger event payload."""

appName: str
payload: dict
originalPayload: t.Dict[str, t.Any]
metadata: Metadata

clientId: t.Optional[int] = None


class _ChunkedTriggerEventData(BaseModel):
"""Cunked trigger event data model."""

id: str
index: int
chunk: str
final: bool


class _TriggerEventFilters(te.TypedDict):
"""Trigger event filterset."""

app_name: te.NotRequired[str]
trigger_id: te.NotRequired[str]
connection_id: te.NotRequired[str]
trigger_name: te.NotRequired[str]
entity_id: te.NotRequired[str]
integration_id: te.NotRequired[str]


TriggerCallback = t.Callable[[TriggerEventData], None]


class TriggerSubscription:
"""Trigger subscription."""

_channel: Channel
_alive: bool

def __init__(self) -> None:
"""Initialize subscription object."""
self._alive = False
self._chunks: t.Dict[str, t.Dict[int, str]] = {}
self._callbacks: t.List[t.Tuple[TriggerCallback, _TriggerEventFilters]] = []

def callback(
self,
filters: t.Optional[_TriggerEventFilters] = None,
) -> t.Callable[[TriggerCallback], TriggerCallback]:
"""Register a trigger callaback."""

def _wrap(f: TriggerCallback) -> TriggerCallback:
self._callbacks.append((f, filters or {}))
return f

return _wrap

def _validate_filter(
self,
check: t.Any,
name: str,
filters: _TriggerEventFilters,
) -> None:
"""Check if filter is provided and raise if the values does not match."""
value = filters.get(name)
if value is None:
return
if value != check:
raise ValueError(
f"Skipping since `{name}` filter does not match the event",
)

def _handle_callback(
self,
callback: TriggerCallback,
data: TriggerEventData,
filters: _TriggerEventFilters,
) -> None:
"""Handle callback."""
for name, check in (
("app_name", data.appName),
("trigger_id", data.metadata.id),
("connection_id", data.metadata.connectionId),
("trigger_name", data.metadata.triggerName),
("entity_id", data.metadata.connection.clientUniqueUserId),
("integration_id", data.metadata.connection.integrationId),
):
self._validate_filter(
check=check,
name=name,
filters=filters,
)
callback(data)

def handle_event(self, event: str) -> None:
"""Filter events and call the callback function."""
try:
data = TriggerEventData(**json.loads(event))
except Exception as e:
print(f"Error decoding payload: {e}")
try:
for callback, filters in self._callbacks:
self._handle_callback(
callback=callback,
data=data,
filters=filters,
)
except BaseException as e:
print(f"Erorr handling event `{data.metadata.id}`: {e}")

def handle_chunked_events(self, event: str) -> None:
"""Handle chunked events."""
data = _ChunkedTriggerEventData(**json.loads(event))
if data.id not in self._chunks:
self._chunks[data.id] = {}

self._chunks[data.id][data.index] = data.chunk
if data.final:
_chunks = self._chunks.pop(data.id)
self.handle_event(
event="".join([_chunks[idx] for idx in sorted(_chunks)]),
)

def is_alive(self) -> bool:
"""Check if subscription is live."""
return self._alive

def set_alive(self) -> None:
"""Set `_alive` to True."""
self._alive = True

def listen(self) -> None:
"""Wait infinitely."""
while True:
time.sleep(1)


class _PusherClient:
"""Pusher client for Composio SDK."""

def __init__(self, client_id: str, base_url: str, api_key: str) -> None:
"""Initialize pusher client."""
self.client_id = client_id
self.base_url = base_url
self.api_key = api_key
self.subscription = TriggerSubscription()

def _get_connection_handler(
self,
client_id: str,
pusher: pysher.Pusher,
subscription: TriggerSubscription,
) -> t.Callable[[str], None]:
def _connection_handler(_: str) -> None:
channel = t.cast(
Channel,
pusher.subscribe(
channel_name=f"private-{client_id}_triggers",
),
)
channel.bind(
event_name="trigger_to_client",
callback=subscription.handle_event,
)
channel.bind(
event_name="chunked-trigger_to_client",
callback=subscription.handle_chunked_events,
)
subscription.set_alive()

return _connection_handler

def connect(self, timeout: float = 15.0) -> TriggerSubscription:
"""Connect to Pusher channel for given client ID."""
pusher = pysher.Pusher(
key=PUSHER_KEY,
cluster=PUSHER_CLUSTER,
auth_endpoint=f"{self.base_url}/v1/client/auth/pusher_auth?fromPython=true",
auth_endpoint_headers={
"x-api-key": self.api_key,
},
)
pusher.connection.bind(
"pusher:connection_established",
self._get_connection_handler(
client_id=self.client_id,
pusher=pusher,
subscription=self.subscription,
),
)
pusher.connect()

# Wait for connection to get established
deadline = time.time() + timeout
while time.time() < deadline:
if self.subscription.is_alive():
return self.subscription
time.sleep(0.5)
raise TimeoutError(
"Timed out while waiting for trigger listener to be established"
)


class Triggers(Collection[TriggerModel]):
"""Collection of triggers."""

Expand Down Expand Up @@ -462,12 +687,35 @@ def disable(self, id: str) -> t.Dict:
:param connected_account_id: ID of the relevant connected account
"""
response = self._raise_if_required(
self.client.http.post(
url=str(self.endpoint.disable / id),
self.client.http.patch(
url=str(self.endpoint / "instance" / id / "status"),
json={
"enabled": False,
},
)
)
return response.json()

def subscribe(self, timeout: float = 15.0) -> TriggerSubscription:
"""Subscribe to a trigger and receive trigger events."""
response = self._raise_if_required(
response=self.client.http.get(
url="/v1/client/auth/client_info",
)
)
client_id = response.json().get("client", {}).get("id")
if client_id is None:
raise ComposioClientError("Error fetching client ID")

pusher = _PusherClient(
client_id=client_id,
base_url=self.client.http.base_url,
api_key=self.client.api_key,
)
return pusher.connect(
timeout=timeout,
)


class ActiveTriggerModel(BaseModel):
"""Active trigger data model."""
Expand Down
2 changes: 1 addition & 1 deletion composio/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ def request(url: str, **kwargs: t.Any) -> t.Any:
return request

def __getattribute__(self, name: str) -> t.Any:
if name in ("get", "post", "put", "delete"):
if name in ("get", "post", "put", "delete", "patch"):
return self._wrap(super().__getattribute__(name))
return super().__getattribute__(name)
10 changes: 10 additions & 0 deletions composio/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@
"""
Composio API server base url -> web url mappings.
"""

PUSHER_KEY = "ff9f18c208855d77a152"
"""
API Key for Pusher subscriptions.
"""

PUSHER_CLUSTER = "mt1"
"""
Name of the pusher cluster.
"""
5 changes: 5 additions & 0 deletions composio/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ActionModel,
FileModel,
SuccessExecuteActionResponseModel,
TriggerSubscription,
)
from composio.client.enums import Action, App, Tag
from composio.client.local_handler import LocalToolHandler
Expand Down Expand Up @@ -185,3 +186,7 @@ def get_action_schemas(
items = items + remote_items

return items

def create_trigger_listener(self, timeout: float = 15.0) -> TriggerSubscription:
"""Create trigger subscription."""
return self.client.triggers.subscribe(timeout=timeout)
Loading

0 comments on commit ba01730

Please sign in to comment.