diff --git a/python/composio/cli/add.py b/python/composio/cli/add.py index 31fa614763..a036597dd7 100644 --- a/python/composio/cli/add.py +++ b/python/composio/cli/add.py @@ -18,6 +18,7 @@ AppAuthScheme, AppModel, AuthSchemeField, + AuthSchemeType, IntegrationModel, ) from composio.client.exceptions import ComposioClientError @@ -227,12 +228,13 @@ def add_integration( ) if auth_mode is not None: + auth_mode = t.cast(AuthSchemeType, auth_mode) auth_scheme = auth_modes[auth_mode] elif len(auth_modes) == 1: ((auth_mode, auth_scheme),) = auth_modes.items() else: auth_mode = t.cast( - str, + AuthSchemeType, click.prompt( "Select auth mode: ", type=click.Choice(choices=list(auth_modes)), diff --git a/python/composio/client/__init__.py b/python/composio/client/__init__.py index ad2d075899..89828ff0f2 100644 --- a/python/composio/client/__init__.py +++ b/python/composio/client/__init__.py @@ -11,10 +11,12 @@ import requests from composio.client.collections import ( + AUTH_SCHEMES, Actions, ActiveTriggerModel, ActiveTriggers, Apps, + AuthSchemeType, ConnectedAccountModel, ConnectedAccounts, ConnectionRequestModel, @@ -406,6 +408,11 @@ def initiate_connection( app = self.client.apps.get(name=app_name) timestamp = datetime.now().strftime("%Y%m%d%H%M%S") if integration is None and auth_mode is not None: + if auth_mode not in AUTH_SCHEMES: + raise ComposioClientError( + f"'auth_mode' should be one of {AUTH_SCHEMES}" + ) + auth_mode = t.cast(AuthSchemeType, auth_mode) if "OAUTH" not in auth_mode: use_composio_auth = False integration = self.client.integrations.create( diff --git a/python/composio/client/collections.py b/python/composio/client/collections.py index 831cab38dc..65a2a81301 100644 --- a/python/composio/client/collections.py +++ b/python/composio/client/collections.py @@ -40,6 +40,9 @@ if t.TYPE_CHECKING: from composio.client import Composio +AUTH_SCHEMES = ("OAUTH2", "OAUTH1", "API_KEY", "BASIC", "BEARER_TOKEN") +AuthSchemeType = t.Literal["OAUTH2", "OAUTH1", "API_KEY", "BASIC", "BEARER_TOKEN"] + def to_trigger_names( triggers: t.Union[t.List[str], t.List[Trigger], t.List[TriggerType]], @@ -282,7 +285,7 @@ class AppAuthScheme(BaseModel): """App authenticatio scheme.""" scheme_name: str - auth_mode: str + auth_mode: AuthSchemeType fields: t.List[AuthSchemeField] proxy: t.Optional[t.Dict] = None @@ -1355,7 +1358,7 @@ def create( self, app_id: str, name: t.Optional[str] = None, - auth_mode: t.Optional[str] = None, + auth_mode: t.Optional["AuthSchemeType"] = None, auth_config: t.Optional[t.Dict[str, t.Any]] = None, use_composio_auth: bool = False, force_new_integration: bool = False, diff --git a/python/composio/tools/toolset.py b/python/composio/tools/toolset.py index 1b25aa2014..5a4ec3517a 100644 --- a/python/composio/tools/toolset.py +++ b/python/composio/tools/toolset.py @@ -24,10 +24,12 @@ from composio import Action, ActionType, App, AppType, TagType from composio.client import Composio, Entity from composio.client.collections import ( + AUTH_SCHEMES, ActionModel, AppAuthScheme, AppModel, AuthSchemeField, + AuthSchemeType, ConnectedAccountModel, ConnectionParams, ConnectionRequestModel, @@ -76,7 +78,6 @@ MetadataType = t.Dict[_KeyType, t.Dict] ParamType = t.TypeVar("ParamType") ProcessorType = te.Literal["pre", "post", "schema"] -AuthSchemeType = t.Literal["OAUTH2", "OAUTH1", "API_KEY", "BASIC", "BEARER_TOKEN"] class IntegrationParams(te.TypedDict): @@ -1214,13 +1215,9 @@ def get_auth_scheme_for_app( if auth_scheme is not None: return auth_schemes[auth_scheme] - for scheme in ( - "OAUTH2", - "OAUTH1", - "API_KEY", - "BASIC", - ): + for scheme in AUTH_SCHEMES: if scheme in auth_schemes: + scheme = t.cast(AuthSchemeType, scheme) return auth_schemes[scheme] raise ComposioSDKError( @@ -1363,7 +1360,7 @@ def fetch_expected_integration_params( def create_integration( self, app: AppType, - auth_mode: t.Optional[str] = None, + auth_mode: t.Optional[AuthSchemeType] = None, auth_config: t.Optional[t.Dict[str, t.Any]] = None, use_composio_oauth_app: bool = True, force_new_integration: bool = False, @@ -1387,19 +1384,26 @@ def initiate_connection( entity_id: t.Optional[str] = None, redirect_url: t.Optional[str] = None, connected_account_params: t.Optional[t.Dict] = None, + *, + auth_scheme: t.Optional[AuthSchemeType] = None, ) -> ConnectionRequestModel: if integration_id is None and app is None: raise ComposioSDKError( message="Both `integration_id` and `app` cannot be None" ) + if auth_scheme is not None: + if auth_scheme not in AUTH_SCHEMES: + raise ComposioSDKError(f"'auth_scheme' must be one of {AUTH_SCHEMES}") + if integration_id is None: try: integration_id = self._get_integration_for_app( app=t.cast( AppType, app, - ) + ), + auth_scheme=auth_scheme, ).id except NoItemsFound as e: raise ComposioSDKError(