From d9d128e2709ef1397179c41fa93fd84696aee97e Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 19:59:48 +0800 Subject: [PATCH 01/11] Show ssh key length and name in debug logs --- tracecat/secrets/service.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tracecat/secrets/service.py b/tracecat/secrets/service.py index 2c50b4738..fdd716aeb 100644 --- a/tracecat/secrets/service.py +++ b/tracecat/secrets/service.py @@ -279,14 +279,14 @@ async def get_ssh_key( key_name: str = GIT_SSH_KEY_SECRET_NAME, environment: str | None = None, ) -> SecretKeyValue: - # NOTE: Don't set the workspace_id, as we want to search for - # organization secrets if it's not set. logger.info("Getting SSH key", key_name=key_name, role=self.role) try: secret = await self.get_org_secret_by_name(key_name, environment) + key = self.decrypt_keys(secret.encrypted_keys)[0] + logger.debug("SSH key found", key_name=key_name, key_length=len(key.value)) + return key except TracecatNotFoundError as e: raise TracecatNotFoundError( f"SSH key {key_name} not found. Please check whether this key exists.\n\n" " If not, please create a key in your organization's credentials page and try again." ) from e - return self.decrypt_keys(secret.encrypted_keys)[0] From 0c7c63807eb4cb63fb8a9b062aeb04cd12cd5edf Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:02:07 +0800 Subject: [PATCH 02/11] Validate actions directly in Temporal activity --- tracecat/dsl/action.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tracecat/dsl/action.py b/tracecat/dsl/action.py index 9fa7ff1e4..34e882d59 100644 --- a/tracecat/dsl/action.py +++ b/tracecat/dsl/action.py @@ -9,13 +9,15 @@ from temporalio.exceptions import ApplicationError from tracecat.contexts import ctx_logger, ctx_run +from tracecat.db.engine import get_async_session_context_manager from tracecat.dsl.common import context_locator from tracecat.dsl.models import ActionErrorInfo, ActionStatement, RunActionInput from tracecat.executor.client import ExecutorClient from tracecat.logger import logger from tracecat.registry.actions.models import RegistryActionValidateResponse from tracecat.types.auth import Role -from tracecat.types.exceptions import ExecutorClientError +from tracecat.types.exceptions import ExecutorClientError, RegistryError +from tracecat.validation.service import validate_registry_action_args def contextualize_message( @@ -61,10 +63,25 @@ async def validate_action_activity( - Validate the action arguments against the UDF spec. - Return the validated arguments. """ - client = ExecutorClient(role=input.role) - return await client.validate_action( - action_name=input.task.action, args=input.task.args - ) + try: + async with get_async_session_context_manager() as session: + result = await validate_registry_action_args( + session=session, + action_name=input.task.action, + args=input.task.args, + ) + + if result.status == "error": + logger.warning( + "Error validating UDF args", + message=result.msg, + details=result.detail, + ) + return RegistryActionValidateResponse.from_validation_result(result) + except KeyError as e: + raise RegistryError( + f"Action {input.task.action!r} not found in registry", + ) from e @staticmethod @activity.defn From 032559c2f8505152452c0265fef8dd437366ada1 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:07:57 +0800 Subject: [PATCH 03/11] Pull model executor basic implmenentation --- tracecat/executor/models.py | 10 ++++++ tracecat/executor/router.py | 14 +++++--- tracecat/executor/service.py | 60 ++++++++++++++++++++++++--------- tracecat/registry/repository.py | 50 +++++++++++---------------- 4 files changed, 83 insertions(+), 51 deletions(-) diff --git a/tracecat/executor/models.py b/tracecat/executor/models.py index b6bc87d18..c796320f8 100644 --- a/tracecat/executor/models.py +++ b/tracecat/executor/models.py @@ -1,10 +1,13 @@ from __future__ import annotations import traceback +from dataclasses import dataclass from pydantic import UUID4, BaseModel from tracecat.config import TRACECAT__APP_ENV +from tracecat.git import GitUrl +from tracecat.types.auth import Role class ExecutorSyncInput(BaseModel): @@ -56,3 +59,10 @@ def from_exc(e: Exception, action_name: str) -> ExecutorActionErrorInfo: function=tb.name, lineno=tb.lineno, ) + + +@dataclass +class DispatchActionContext: + role: Role + ssh_command: str | None = None + git_url: GitUrl | None = None diff --git a/tracecat/executor/router.py b/tracecat/executor/router.py index 298758bb2..fcd312078 100644 --- a/tracecat/executor/router.py +++ b/tracecat/executor/router.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from tracecat.auth.credentials import RoleACL -from tracecat.contexts import ctx_logger, ctx_role +from tracecat.contexts import ctx_logger from tracecat.db.dependencies import AsyncDBSession from tracecat.dsl.models import RunActionInput from tracecat.executor.models import ExecutorActionErrorInfo, ExecutorSyncInput @@ -16,7 +16,7 @@ ) from tracecat.registry.repository import RegistryReposService, Repository from tracecat.types.auth import Role -from tracecat.types.exceptions import WrappedExecutionError +from tracecat.types.exceptions import TracecatSettingsError, WrappedExecutionError from tracecat.validation.service import validate_registry_action_args router = APIRouter() @@ -53,18 +53,24 @@ async def run_action( allow_service=True, # Only services can execute actions require_workspace="no", ), + session: AsyncDBSession, action_name: str, action_input: RunActionInput, ) -> Any: """Execute a registry action.""" ref = action_input.task.ref - ctx_role.set(role) act_logger = logger.bind(role=role, action_name=action_name, ref=ref) ctx_logger.set(act_logger) act_logger.info("Starting action") + try: - return await dispatch_action_on_cluster(input=action_input, role=role) + return await dispatch_action_on_cluster(input=action_input, session=session) + except TracecatSettingsError as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"message": str(e)}, + ) from e except WrappedExecutionError as e: # This is an error that occurred inside an executing action err = e.error diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index f4c982d7d..7ccbda8eb 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -1,8 +1,3 @@ -"""Functions for executing actions and templates. - -NOTE: This is only used in the API server, not the worker -""" - from __future__ import annotations import asyncio @@ -13,6 +8,8 @@ import ray import uvloop from ray.exceptions import RayTaskError +from ray.runtime_env import RuntimeEnv +from sqlmodel.ext.asyncio.session import AsyncSession from tracecat import config from tracecat.auth.sandbox import AuthSandbox @@ -27,22 +24,22 @@ RunActionInput, ) from tracecat.executor.engine import EXECUTION_TIMEOUT -from tracecat.executor.models import ExecutorActionErrorInfo +from tracecat.executor.models import DispatchActionContext, ExecutorActionErrorInfo from tracecat.expressions.common import ExprContext, ExprOperand from tracecat.expressions.eval import ( eval_templated_object, extract_templated_secrets, get_iterables_from_expression, ) +from tracecat.git import GitUrl, prepare_git_url from tracecat.logger import logger from tracecat.parse import traverse_leaves -from tracecat.registry.actions.models import ( - BoundRegistryAction, -) +from tracecat.registry.actions.models import BoundRegistryAction from tracecat.registry.actions.service import RegistryActionsService from tracecat.secrets.common import apply_masks_object from tracecat.secrets.constants import DEFAULT_SECRETS_ENVIRONMENT from tracecat.secrets.secrets_manager import env_sandbox +from tracecat.ssh import opt_temp_key_file from tracecat.types.auth import Role from tracecat.types.exceptions import TracecatException, WrappedExecutionError @@ -319,15 +316,27 @@ def run_action_task(input: RunActionInput, role: Role) -> ExecutionResult: async def run_action_on_ray_cluster( - input: RunActionInput, role: Role + input: RunActionInput, ctx: DispatchActionContext ) -> ExecutionResult: """Run an action on the ray cluster. If any exceptions are thrown here, they're platform level errors. All application/user level errors are caught by the executor and returned as values. """ + # Initialize runtime environment variables + env_vars = {"GIT_SSH_COMMAND": ctx.ssh_command} if ctx.ssh_command else {} + additional_vars: dict[str, Any] = {} + + # Add git URL to pip dependencies if SHA is present + if ctx.git_url and ctx.git_url.ref: + url = ctx.git_url.to_url() + additional_vars["pip"] = [url] + logger.info("Adding git URL to runtime env", git_url=ctx.git_url, url=url) - obj_ref = run_action_task.remote(input, role) + runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars) + + logger.info("Running action on ray cluster", runtime_env=runtime_env) + obj_ref = run_action_task.options(runtime_env=runtime_env).remote(input, ctx.role) try: coro = asyncio.to_thread(ray.get, obj_ref) exec_result = await asyncio.wait_for(coro, timeout=EXECUTION_TIMEOUT) @@ -349,7 +358,12 @@ async def run_action_on_ray_cluster( return exec_result -async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: +async def dispatch_action_on_cluster( + input: RunActionInput, + *, + session: AsyncSession, + git_url: GitUrl | None = None, +) -> Any: """Schedule actions on the ray cluster. This function handles dispatching actions to be executed on a Ray cluster. It supports @@ -358,7 +372,7 @@ async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: Args: input: The RunActionInput containing the task definition and execution context role: The Role used for authorization - + git_url: The Git URL to use for the action Returns: Any: For single actions, returns the ExecutionResult. For for_each loops, returns a list of results from all parallel executions. @@ -367,12 +381,26 @@ async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: TracecatException: If there are errors evaluating for_each expressions or during execution ExecutorErrorWrapper: If there are errors from the executor itself """ + git_url = await prepare_git_url() - task = input.task + role = ctx_role.get() + + async with opt_temp_key_file(git_url=git_url, session=session) as ssh_command: + logger.info("SSH command", ssh_command=ssh_command) + ctx = DispatchActionContext(role=role, git_url=git_url, ssh_command=ssh_command) + result = await _dispatch_action(input=input, ctx=ctx) + return result + +async def _dispatch_action( + input: RunActionInput, + ctx: DispatchActionContext, +) -> Any: + task = input.task + logger.info("Preparing runtime environment", ctx=ctx) # If there's no for_each, execute normally if not task.for_each: - return await run_action_on_ray_cluster(input, role) + return await run_action_on_ray_cluster(input, ctx) logger.info("Running for_each on action in parallel", action=task.action) @@ -383,7 +411,7 @@ async def dispatch_action_on_cluster(input: RunActionInput, role: Role) -> Any: iterators = get_iterables_from_expression(expr=task.for_each, operand=base_context) async def coro(patched_input: RunActionInput): - return await run_action_on_ray_cluster(patched_input, role) + return await run_action_on_ray_cluster(patched_input, ctx) try: async with GatheringTaskGroup() as tg: diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index 830c60e3a..b44e7a114 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -29,9 +29,10 @@ from tracecat import config from tracecat.contexts import ctx_role +from tracecat.db.engine import get_async_session_context_manager from tracecat.expressions.expectations import create_expectation_model from tracecat.expressions.validation import TemplateValidator -from tracecat.git import get_git_repository_sha, parse_git_url +from tracecat.git import GitUrl, get_git_repository_sha, parse_git_url from tracecat.logger import logger from tracecat.parse import safe_url from tracecat.registry.actions.models import BoundRegistryAction, TemplateAction @@ -41,14 +42,8 @@ ) from tracecat.registry.repositories.models import RegistryRepositoryCreate from tracecat.registry.repositories.service import RegistryReposService -from tracecat.secrets.service import SecretsService from tracecat.settings.service import get_setting -from tracecat.ssh import ( - SshEnv, - add_host_to_known_hosts, - add_ssh_key_to_agent, - temporary_ssh_agent, -) +from tracecat.ssh import SshEnv, ssh_context from tracecat.types.auth import Role from tracecat.types.exceptions import RegistryError @@ -117,10 +112,6 @@ def get(self, name: str) -> BoundRegistryAction[ArgsClsT]: """Retrieve a registered udf.""" return self._store[name] - def safe_remote_url(self, remote_registry_url: str) -> str: - """Clean a remote registry url.""" - return safe_url(remote_registry_url) - def init(self, include_base: bool = True, include_templates: bool = True) -> None: """Initialize the registry.""" if not self._is_initialized: @@ -219,11 +210,6 @@ def register_template_action( origin=origin, ) - def _reset(self) -> None: - logger.warning("Resetting registry") - self._store = {} - self._is_initialized = False - def _load_base_udfs(self) -> None: """Load all udfs and template actions into the registry.""" # Load udfs @@ -258,8 +244,9 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: or {"github.com"}, ) + cleaned_url = safe_url(self._origin) try: - git_url = parse_git_url(self._origin, allowed_domains=allowed_domains) + git_url = parse_git_url(cleaned_url, allowed_domains=allowed_domains) host = git_url.host org = git_url.org repo_name = git_url.repo @@ -284,35 +271,36 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None: package_name=package_name, ) - cleaned_url = self.safe_remote_url(self._origin) - logger.debug("Cleaned URL", url=cleaned_url) + logger.debug("Git URL", git_url=git_url) commit_sha = await self._install_remote_repository( - host=host, repo_url=cleaned_url, commit_sha=commit_sha + git_url=git_url, commit_sha=commit_sha ) module = await self._load_remote_repository(cleaned_url, package_name) logger.info( "Imported and reloaded remote repository", module_name=module.__name__, package_name=package_name, + commit_sha=commit_sha, ) return commit_sha async def _install_remote_repository( - self, host: str, repo_url: str, commit_sha: str | None = None + self, git_url: GitUrl, commit_sha: str | None = None ) -> str: """Install the remote repository into the filesystem and return the commit sha.""" - logger.info("Getting SSH key", role=self.role) - async with SecretsService.with_session(role=self.role) as service: - secret = await service.get_ssh_key() + logger.info("Getting SSH key", role=self.role, git_url=git_url) - async with temporary_ssh_agent() as env: - logger.info("Entered temporary SSH agent context") - await add_ssh_key_to_agent(secret.reveal().value, env=env) - await add_host_to_known_hosts(host, env=env) + url = git_url.to_url() + async with ( + get_async_session_context_manager() as session, + ssh_context(role=self.role, git_url=git_url, session=session) as env, + ): + if env is None: + raise RegistryError("No SSH key found") if commit_sha is None: - commit_sha = await get_git_repository_sha(repo_url, env=env) - await install_remote_repository(repo_url, commit_sha=commit_sha, env=env) + commit_sha = await get_git_repository_sha(url, env=env) + await install_remote_repository(url, commit_sha=commit_sha, env=env) return commit_sha async def _load_remote_repository( From 2cf71f0a049845c3b3aacdf772440e9ecabd82b2 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:16:48 +0800 Subject: [PATCH 04/11] Don't sync executor on start --- tracecat/api/executor.py | 43 ---------------------------------------- 1 file changed, 43 deletions(-) diff --git a/tracecat/api/executor.py b/tracecat/api/executor.py index 1983368b2..8877638f7 100644 --- a/tracecat/api/executor.py +++ b/tracecat/api/executor.py @@ -6,7 +6,6 @@ from tracecat import config from tracecat.api.common import ( - bootstrap_role, custom_generate_unique_id, generic_exception_handler, setup_oss_models, @@ -16,9 +15,6 @@ from tracecat.executor.router import router as executor_router from tracecat.logger import logger from tracecat.middleware import RequestLoggingMiddleware -from tracecat.registry.repositories.service import RegistryReposService -from tracecat.registry.repository import Repository -from tracecat.settings.service import get_setting from tracecat.types.exceptions import TracecatException @@ -28,49 +24,10 @@ async def lifespan(app: FastAPI): await setup_oss_models() except Exception as e: logger.error("Failed to preload OSS models", error=e) - try: - await setup_custom_remote_repository() - except Exception as e: - logger.error("Error setting up custom remote repository", exc=e) - with setup_ray(): yield -async def setup_custom_remote_repository(): - """Install the remote repository if it is set. - - Steps - ----- - 1. Get the SHA of the remote repository from the DB - 2. If it doesn't exist, create it - 3. If it does exist, sync it - """ - role = bootstrap_role() - url = await get_setting( - "git_repo_url", - role=role, - # TODO: Deprecate in future version - default=config.TRACECAT__REMOTE_REPOSITORY_URL, - ) - if not url: - logger.info("Remote repository URL not set, skipping") - return - logger.info("Remote repository URL found", url=url) - async with RegistryReposService.with_session(role) as service: - db_repo = await service.get_repository(url) - # If it doesn't exist, do nothing - if db_repo is None: - logger.warning("Remote repository not found in DB, skipping") - return - # If it does exist, sync it - if db_repo.last_synced_at is None: - logger.info("Remote repository not synced, skipping") - return - repo = Repository(db_repo.origin, role=role) - await repo.load_from_origin(commit_sha=db_repo.commit_sha) - - def create_app(**kwargs) -> FastAPI: if config.TRACECAT__ALLOW_ORIGINS is not None: allow_origins = config.TRACECAT__ALLOW_ORIGINS.split(",") From ee06591c0e6bb9f12fde22c91b22f18354758b33 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:17:57 +0800 Subject: [PATCH 05/11] Remove obsolete executor endpoints --- tracecat/executor/client.py | 66 +----------------------- tracecat/executor/router.py | 61 +--------------------- tracecat/registry/repositories/router.py | 36 ------------- 3 files changed, 2 insertions(+), 161 deletions(-) diff --git a/tracecat/executor/client.py b/tracecat/executor/client.py index afb88d304..c16215d21 100644 --- a/tracecat/executor/client.py +++ b/tracecat/executor/client.py @@ -7,19 +7,12 @@ import httpx import orjson -from pydantic import UUID4 -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) from tracecat import config from tracecat.clients import AuthenticatedServiceClient from tracecat.contexts import ctx_role from tracecat.dsl.models import RunActionInput -from tracecat.executor.models import ExecutorActionErrorInfo, ExecutorSyncInput +from tracecat.executor.models import ExecutorActionErrorInfo from tracecat.logger import logger from tracecat.registry.actions.models import ( RegistryActionValidateResponse, @@ -118,63 +111,6 @@ async def validate_action( f"Unexpected error while listing registries: {str(e)}" ) from e - # === Management === - - async def sync_executor( - self, repository_id: UUID4, *, max_attempts: int = 3 - ) -> None: - """Sync the executor from the registry. - - Args: - origin: The origin of the sync request - - Raises: - RegistryError: If the sync fails after all retries - """ - - @retry( - stop=stop_after_attempt(max_attempts), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type( - ( - httpx.HTTPStatusError, - httpx.RequestError, - httpx.TimeoutException, - httpx.ConnectError, - ) - ), - ) - async def _sync_request() -> None: - try: - async with self._client() as client: - response = await client.post( - "/sync", - content=ExecutorSyncInput( - repository_id=repository_id - ).model_dump_json(), - headers={"Content-Type": "application/json"}, - ) - response.raise_for_status() - except Exception as e: - logger.error("Error syncing executor", error=e) - raise - - try: - logger.info("Syncing executor", repository_id=repository_id) - _ = await _sync_request() - except httpx.HTTPStatusError as e: - raise RegistryError( - f"Failed to sync executor: HTTP {e.response.status_code}" - ) from e - except httpx.RequestError as e: - raise RegistryError( - f"Network error while syncing executor: {str(e)}" - ) from e - except Exception as e: - raise RegistryError( - f"Unexpected error while syncing executor: {str(e)}" - ) from e - # === Utility === def _handle_http_status_error( diff --git a/tracecat/executor/router.py b/tracecat/executor/router.py index fcd312078..f7002c730 100644 --- a/tracecat/executor/router.py +++ b/tracecat/executor/router.py @@ -7,44 +7,15 @@ from tracecat.contexts import ctx_logger from tracecat.db.dependencies import AsyncDBSession from tracecat.dsl.models import RunActionInput -from tracecat.executor.models import ExecutorActionErrorInfo, ExecutorSyncInput +from tracecat.executor.models import ExecutorActionErrorInfo from tracecat.executor.service import dispatch_action_on_cluster from tracecat.logger import logger -from tracecat.registry.actions.models import ( - RegistryActionValidate, - RegistryActionValidateResponse, -) -from tracecat.registry.repository import RegistryReposService, Repository from tracecat.types.auth import Role from tracecat.types.exceptions import TracecatSettingsError, WrappedExecutionError -from tracecat.validation.service import validate_registry_action_args router = APIRouter() -@router.post("/sync") -async def sync_executor( - *, - role: Role = RoleACL( - allow_user=False, # XXX(authz): Users cannot sync the executor - allow_service=True, # Only services can sync the executor - require_workspace="no", - ), - session: AsyncDBSession, - input: ExecutorSyncInput, -) -> None: - """Sync the executor from the registry.""" - rr_service = RegistryReposService(session, role=role) - db_repo = await rr_service.get_repository_by_id(input.repository_id) - # If it doesn't exist, do nothing - if db_repo is None: - logger.info("Remote repository not found in DB, skipping") - return - # If it does exist, sync it - repo = Repository(db_repo.origin, role=role) - await repo.load_from_origin(commit_sha=db_repo.commit_sha) - - @router.post("/run/{action_name}", tags=["execution"]) async def run_action( *, @@ -91,33 +62,3 @@ async def run_action( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=err_info_dict, ) from e - - -@router.post("/validate/{action_name}") -async def validate_action( - *, - role: Role = RoleACL( - allow_user=False, # XXX(authz): Users cannot validate actions - allow_service=True, # Only services can validate actions - require_workspace="no", - ), - session: AsyncDBSession, - action_name: str, - params: RegistryActionValidate, -) -> RegistryActionValidateResponse: - """Validate a registry action.""" - try: - result = await validate_registry_action_args( - session=session, action_name=action_name, args=params.args - ) - - if result.status == "error": - logger.warning( - "Error validating UDF args", message=result.msg, details=result.detail - ) - return RegistryActionValidateResponse.from_validation_result(result) - except KeyError as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Action {action_name!r} not found in registry", - ) from e diff --git a/tracecat/registry/repositories/router.py b/tracecat/registry/repositories/router.py index 19f70c1dd..85fd41fea 100644 --- a/tracecat/registry/repositories/router.py +++ b/tracecat/registry/repositories/router.py @@ -6,7 +6,6 @@ from tracecat.auth.credentials import RoleACL from tracecat.db.dependencies import AsyncDBSession -from tracecat.executor.client import ExecutorClient from tracecat.logger import logger from tracecat.registry.actions.models import RegistryActionRead from tracecat.registry.actions.service import RegistryActionsService @@ -102,41 +101,6 @@ async def sync_registry_repository( ) from e -@router.post("/{repository_id}/sync-executor", status_code=status.HTTP_204_NO_CONTENT) -async def sync_executor_from_registry_repository( - *, - role: Role = RoleACL( - allow_user=True, - allow_service=False, - require_workspace="no", - min_access_level=AccessLevel.ADMIN, - ), - session: AsyncDBSession, - repository_id: UUID4, -): - # # We might want to update the executor's view of the repository here - # # (3) Update the executor's view of the repository - rr_service = RegistryReposService(session, role) - try: - repo = await rr_service.get_repository_by_id(repository_id) - except NoResultFound as e: - logger.error("Registry repository not found", repository_id=repository_id) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Registry repository not found", - ) from e - logger.info("Syncing executor", origin=repo.origin) - client = ExecutorClient(role=role) - try: - await client.sync_executor(repository_id=repo.id) - except RegistryError as e: - logger.warning("Error syncing executor", exc=e) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Error while syncing executor {repo.origin!r}: {e}", - ) from e - - @router.get("") async def list_registry_repositories( *, From dd5abd593ceef80bf84b67206eadd8963f6424c3 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:18:05 +0800 Subject: [PATCH 06/11] Update openapi client --- frontend/src/client/services.gen.ts | 24 ------------------------ frontend/src/client/types.gen.ts | 22 ---------------------- 2 files changed, 46 deletions(-) diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 5c136b687..fbcd5d760 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -78,8 +78,6 @@ import type { RegistryRepositoriesGetRegistryRepositoryResponse, RegistryRepositoriesListRegistryRepositoriesResponse, RegistryRepositoriesReloadRegistryRepositoriesResponse, - RegistryRepositoriesSyncExecutorFromRegistryRepositoryData, - RegistryRepositoriesSyncExecutorFromRegistryRepositoryResponse, RegistryRepositoriesSyncRegistryRepositoryData, RegistryRepositoriesSyncRegistryRepositoryResponse, RegistryRepositoriesUpdateRegistryRepositoryData, @@ -1894,28 +1892,6 @@ export const registryRepositoriesSyncRegistryRepository = ( }) } -/** - * Sync Executor From Registry Repository - * @param data The data for the request. - * @param data.repositoryId - * @returns void Successful Response - * @throws ApiError - */ -export const registryRepositoriesSyncExecutorFromRegistryRepository = ( - data: RegistryRepositoriesSyncExecutorFromRegistryRepositoryData -): CancelablePromise => { - return __request(OpenAPI, { - method: "POST", - url: "/registry/repos/{repository_id}/sync-executor", - path: { - repository_id: data.repositoryId, - }, - errors: { - 422: "Validation Error", - }, - }) -} - /** * List Registry Repositories * List all registry repositories. diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 9f4f0e949..db1f33e16 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -1884,13 +1884,6 @@ export type RegistryRepositoriesSyncRegistryRepositoryData = { export type RegistryRepositoriesSyncRegistryRepositoryResponse = void -export type RegistryRepositoriesSyncExecutorFromRegistryRepositoryData = { - repositoryId: string -} - -export type RegistryRepositoriesSyncExecutorFromRegistryRepositoryResponse = - void - export type RegistryRepositoriesListRegistryRepositoriesResponse = Array @@ -3015,21 +3008,6 @@ export type $OpenApiTs = { } } } - "/registry/repos/{repository_id}/sync-executor": { - post: { - req: RegistryRepositoriesSyncExecutorFromRegistryRepositoryData - res: { - /** - * Successful Response - */ - 204: void - /** - * Validation Error - */ - 422: HTTPValidationError - } - } - } "/registry/repos": { get: { res: { From f5cd95adda740a4ed467935436afdb20d928a13f Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:18:20 +0800 Subject: [PATCH 07/11] Add dep --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index cad00840e..9b5c7dc24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dependencies = [ "tenacity==8.3.0", "uv==0.4.10", "uvicorn==0.29.0", + "virtualenv==20.27.0", ] dynamic = ["version"] From e8fe8ff96d4780830925700a0c2193949f8ac093 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:26:36 +0800 Subject: [PATCH 08/11] Update UI --- .../registry/registry-repos-table.tsx | 139 +----------------- frontend/src/lib/hooks.tsx | 39 ----- 2 files changed, 4 insertions(+), 174 deletions(-) diff --git a/frontend/src/components/registry/registry-repos-table.tsx b/frontend/src/components/registry/registry-repos-table.tsx index c2f44f9a2..85db1bc82 100644 --- a/frontend/src/components/registry/registry-repos-table.tsx +++ b/frontend/src/components/registry/registry-repos-table.tsx @@ -5,7 +5,6 @@ import { RegistryRepositoryReadMinimal } from "@/client" import { DropdownMenuLabel } from "@radix-ui/react-dropdown-menu" import { DotsHorizontalIcon } from "@radix-ui/react-icons" import { - ArrowRightToLineIcon, CopyIcon, LoaderCircleIcon, RefreshCcw, @@ -54,8 +53,6 @@ export function RegistryRepositoriesTable() { syncRepo, syncRepoIsPending, deleteRepo, - syncExecutor, - syncExecutorIsPending, } = useRegistryRepositories() const [selectedRepo, setSelectedRepo] = useState(null) @@ -101,7 +98,7 @@ export function RegistryRepositoriesTable() { label: (
- Sync only + Sync
), action: async () => { @@ -130,120 +127,6 @@ export function RegistryRepositoriesTable() { } }, }, - { - label: ( -
- - Sync and push to executor -
- ), - action: async () => { - if (!selectedRepo) { - console.error("No repository selected") - return - } - console.log("Reloading repository", selectedRepo.origin) - try { - await syncRepo({ repositoryId: selectedRepo.id }) - toast({ - title: "Successfully synced repository", - description: ( -
-
- Successfully reloaded actions from{" "} - {selectedRepo.origin} -
-
- ), - }) - await syncExecutor({ repositoryId: selectedRepo.id }) - toast({ - title: "Successfully pushed to executor", - description: ( -
-
- Successfully pushed actions from{" "} - {selectedRepo.origin} -
-
- ), - }) - } catch (error) { - console.error("Error reloading repository", error) - } finally { - setSelectedRepo(null) - } - }, - }, - ], - } - case AlertAction.SYNC_EXECUTOR: - return { - title: "Push to executor", - description: ( -
- - You are about to push the current version of the repository{" "} - - - {selectedRepo?.origin} - - to the executor. - {selectedRepo?.commit_sha && ( -
- Current SHA: - - {selectedRepo.commit_sha} - -
- )} - {selectedRepo?.last_synced_at && ( -
- Last synced: - - {new Date(selectedRepo.last_synced_at).toLocaleString()} - -
- )} -

- Are you sure you want to proceed? This will reload all existing - modules from this repository on the executor. -

-
- ), - actions: [ - { - label: ( -
- - Push to executor -
- ), - action: async () => { - if (!selectedRepo) { - console.error("No repository selected") - return - } - try { - await syncExecutor({ repositoryId: selectedRepo.id }) - toast({ - title: "Successfully synced executor", - description: ( -
-
- Successfully reloaded actions from{" "} - {selectedRepo.origin} -
-
- ), - }) - } catch (error) { - console.error("Error syncing executor", error) - } finally { - setSelectedRepo(null) - } - }, - }, ], } case AlertAction.DELETE: @@ -402,11 +285,11 @@ export function RegistryRepositoriesTable() { > Open menu {row.original.id === selectedRepo?.id && - (syncRepoIsPending || syncExecutorIsPending) ? ( + syncRepoIsPending ? (
- {syncRepoIsPending ? "Pulling..." : "Pushing..."} + Pulling...
) : ( @@ -473,20 +356,6 @@ export function RegistryRepositoriesTable() { Sync from remote - {row.original.last_synced_at !== null && ( - { - e.stopPropagation() // Prevent row click - setSelectedRepo(row.original) - setAlertAction(AlertAction.SYNC_EXECUTOR) - setAlertOpen(true) - }} - > - - Push to executor - - )} { @@ -526,7 +395,7 @@ export function RegistryRepositoriesTable() { setAlertOpen(false) await action.action() }} - disabled={syncRepoIsPending || syncExecutorIsPending} + disabled={syncRepoIsPending} > {action.label} diff --git a/frontend/src/lib/hooks.tsx b/frontend/src/lib/hooks.tsx index aac44d3e5..ec9c75f42 100644 --- a/frontend/src/lib/hooks.tsx +++ b/frontend/src/lib/hooks.tsx @@ -37,8 +37,6 @@ import { RegistryRepositoriesDeleteRegistryRepositoryData, registryRepositoriesListRegistryRepositories, registryRepositoriesReloadRegistryRepositories, - registryRepositoriesSyncExecutorFromRegistryRepository, - RegistryRepositoriesSyncExecutorFromRegistryRepositoryData, registryRepositoriesSyncRegistryRepository, RegistryRepositoriesSyncRegistryRepositoryData, RegistryRepositoryReadMinimal, @@ -1159,40 +1157,6 @@ export function useRegistryRepositories() { }, }) - const { - mutateAsync: syncExecutor, - isPending: syncExecutorIsPending, - error: syncExecutorError, - } = useMutation({ - mutationFn: async ( - params: RegistryRepositoriesSyncExecutorFromRegistryRepositoryData - ) => await registryRepositoriesSyncExecutorFromRegistryRepository(params), - onSuccess: () => { - queryClient.invalidateQueries({ queryKey: ["registry_repositories"] }) - queryClient.invalidateQueries({ queryKey: ["registry_actions"] }) - toast({ - title: "Synced executor", - description: "Executor synced successfully.", - }) - }, - onError: (error: TracecatApiError) => { - const apiError = error as TracecatApiError - switch (apiError.status) { - case 403: - toast({ - title: "You cannot perform this action", - description: `${apiError.message}: ${apiError.body.detail}`, - }) - break - default: - toast({ - title: "Failed to sync executor", - description: `An unexpected error occurred while syncing the executor. ${apiError.message}: ${apiError.body.detail}`, - }) - } - }, - }) - return { repos, reposIsLoading, @@ -1203,9 +1167,6 @@ export function useRegistryRepositories() { deleteRepo, deleteRepoIsPending, deleteRepoError, - syncExecutor, - syncExecutorIsPending, - syncExecutorError, } } From abcd5477415a919a91d3459b0dca354bcff5cd6a Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:43:19 +0800 Subject: [PATCH 09/11] Use ExpressionStr instead of typing.Annotated with string --- tracecat/dsl/models.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/tracecat/dsl/models.py b/tracecat/dsl/models.py index 2be334583..a19264c21 100644 --- a/tracecat/dsl/models.py +++ b/tracecat/dsl/models.py @@ -2,14 +2,14 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Annotated, Any, Literal, TypedDict +from typing import Any, Literal, TypedDict from pydantic import BaseModel, Field from tracecat.dsl.constants import DEFAULT_ACTION_TIMEOUT from tracecat.dsl.enums import JoinStrategy from tracecat.expressions.common import ExprContext -from tracecat.expressions.validation import ExpressionStr, TemplateValidator +from tracecat.expressions.validation import ExpressionStr from tracecat.identifiers import WorkflowExecutionID, WorkflowID, WorkflowRunID from tracecat.secrets.constants import DEFAULT_SECRETS_ENVIRONMENT @@ -94,20 +94,13 @@ class ActionStatement(BaseModel): """Control flow options""" - run_if: Annotated[ - str | None, - Field(default=None, description="Condition to run the task"), - TemplateValidator(), - ] - - for_each: Annotated[ - str | list[str] | None, - Field( - default=None, - description="Iterate over a list of items and run the task for each item.", - ), - TemplateValidator(), - ] + run_if: ExpressionStr | None = Field( + default=None, description="Condition to run the task" + ) + for_each: ExpressionStr | list[ExpressionStr] | None = Field( + default=None, + description="Iterate over a list of items and run the task for each item.", + ) retry_policy: ActionRetryPolicy = Field( default_factory=ActionRetryPolicy, description="Retry policy for the action." ) From 292ac1974c199468c9d786df6eed8b005a3428d5 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:43:33 +0800 Subject: [PATCH 10/11] Add tests --- tests/unit/test_executor_service.py | 145 ++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 tests/unit/test_executor_service.py diff --git a/tests/unit/test_executor_service.py b/tests/unit/test_executor_service.py new file mode 100644 index 000000000..85d1f6bd2 --- /dev/null +++ b/tests/unit/test_executor_service.py @@ -0,0 +1,145 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest + +from tracecat.dsl.models import ActionStatement, RunActionInput, RunContext +from tracecat.executor.models import DispatchActionContext +from tracecat.executor.service import _dispatch_action, dispatch_action_on_cluster +from tracecat.expressions.common import ExprContext +from tracecat.git import GitUrl +from tracecat.types.auth import Role + + +@pytest.fixture +def mock_session(): + return AsyncMock() + + +@pytest.fixture +def basic_task_input(): + """Fixture that provides a basic RunActionInput without looping.""" + wf_id = "wf-" + uuid.uuid4().hex + wf_exec_id = wf_id + ":exec-test" + wf_run_id = uuid.uuid4() + return RunActionInput( + task=ActionStatement( + action="test_action", + args={"key": "value"}, + ref="test_ref", + ), + exec_context={ + ExprContext.ACTIONS: { + "test_action": { + "args": {"key": "value"}, + "ref": "test-ref", + } + } + }, + run_context=RunContext( + wf_id=wf_id, + wf_exec_id=wf_exec_id, + wf_run_id=wf_run_id, + environment="test-env", + ), + ) + + +@pytest.fixture +def basic_looped_task_input(): + wf_id = "wf-" + uuid.uuid4().hex + wf_exec_id = wf_id + ":exec-test" + wf_run_id = uuid.uuid4() + return RunActionInput( + task=ActionStatement( + action="test_action", + args={"key": "value"}, + ref="test_ref", + for_each="${{ for var.x in [1,2,3] }}", + ), + exec_context={ + ExprContext.ACTIONS: { + "test_action": { + "args": {"key": "value"}, + "ref": "test-ref", + } + } + }, + run_context=RunContext( + wf_id=wf_id, + wf_exec_id=wf_exec_id, + wf_run_id=wf_run_id, + environment="test-env", + ), + ) + + +@pytest.fixture +def dispatch_context(): + return DispatchActionContext( + role=Role(type="service", service_id="tracecat-executor"), + ssh_command="ssh -i /tmp/key", + git_url=GitUrl(host="github.com", org="org", repo="repo", ref="abc123"), + ) + + +@pytest.mark.anyio +async def test_dispatch_action_basic(mock_session, basic_task_input, dispatch_context): + with patch("tracecat.executor.service.run_action_on_ray_cluster") as mock_ray: + mock_ray.return_value = {"result": "success"} + + result = await _dispatch_action(input=basic_task_input, ctx=dispatch_context) + + assert result == {"result": "success"} + mock_ray.assert_called_once_with(basic_task_input, dispatch_context) + + +@pytest.mark.anyio +async def test_dispatch_action_with_foreach( + mock_session, basic_looped_task_input, dispatch_context +): + with patch("tracecat.executor.service.run_action_on_ray_cluster") as mock_ray: + mock_ray.return_value = {"result": "success"} + + result = await _dispatch_action( + input=basic_looped_task_input, ctx=dispatch_context + ) + + assert result == [{"result": "success"}] * 3 + + # Assert the number of calls + assert mock_ray.call_count == 3 + + # Get all calls and their arguments + calls = mock_ray.call_args_list + + # Verify each call's arguments + for i, call in enumerate(calls, 1): + args, kwargs = call + input_arg = args[0] + # Verify the loop variable 'x' was set to different values (1, 2, 3) + assert input_arg.task.args["key"] == "value" + assert input_arg.exec_context[ExprContext.LOCAL_VARS] == {"x": i} + assert args[1] == dispatch_context + + +@pytest.mark.anyio +async def test_dispatch_action_with_git_url(mock_session, basic_task_input): + with ( + patch("tracecat.executor.service.prepare_git_url") as mock_git_url, + patch("tracecat.executor.service._dispatch_action") as mock_dispatch, + patch("tracecat.executor.service.opt_temp_key_file") as mock_key_file, + ): + mock_git_url.return_value = GitUrl( + host="github.com", org="org", repo="repo", ref="abc123" + ) + mock_key_file.return_value.__aenter__.return_value = "ssh -i /tmp/key" + mock_dispatch.return_value = {"result": "success"} + + result = await dispatch_action_on_cluster( + input=basic_task_input, session=mock_session + ) + + assert result == {"result": "success"} + mock_git_url.assert_called_once() + mock_key_file.return_value.__aenter__.assert_called_once() From c502b5b4c176827c8ee343752ae29d1620723024 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Wed, 15 Jan 2025 20:48:09 +0800 Subject: [PATCH 11/11] Adjust log levels --- tracecat/executor/service.py | 4 ++-- tracecat/registry/repository.py | 2 -- tracecat/secrets/service.py | 1 - tracecat/ssh.py | 1 - 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index 7ccbda8eb..28ba3a69a 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -331,7 +331,7 @@ async def run_action_on_ray_cluster( if ctx.git_url and ctx.git_url.ref: url = ctx.git_url.to_url() additional_vars["pip"] = [url] - logger.info("Adding git URL to runtime env", git_url=ctx.git_url, url=url) + logger.trace("Adding git URL to runtime env", git_url=ctx.git_url, url=url) runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars) @@ -386,7 +386,7 @@ async def dispatch_action_on_cluster( role = ctx_role.get() async with opt_temp_key_file(git_url=git_url, session=session) as ssh_command: - logger.info("SSH command", ssh_command=ssh_command) + logger.trace("SSH command", ssh_command=ssh_command) ctx = DispatchActionContext(role=role, git_url=git_url, ssh_command=ssh_command) result = await _dispatch_action(input=input, ctx=ctx) return result diff --git a/tracecat/registry/repository.py b/tracecat/registry/repository.py index b44e7a114..750b496db 100644 --- a/tracecat/registry/repository.py +++ b/tracecat/registry/repository.py @@ -289,8 +289,6 @@ async def _install_remote_repository( ) -> str: """Install the remote repository into the filesystem and return the commit sha.""" - logger.info("Getting SSH key", role=self.role, git_url=git_url) - url = git_url.to_url() async with ( get_async_session_context_manager() as session, diff --git a/tracecat/secrets/service.py b/tracecat/secrets/service.py index fdd716aeb..c8df473e5 100644 --- a/tracecat/secrets/service.py +++ b/tracecat/secrets/service.py @@ -279,7 +279,6 @@ async def get_ssh_key( key_name: str = GIT_SSH_KEY_SECRET_NAME, environment: str | None = None, ) -> SecretKeyValue: - logger.info("Getting SSH key", key_name=key_name, role=self.role) try: secret = await self.get_org_secret_by_name(key_name, environment) key = self.decrypt_keys(secret.encrypted_keys)[0] diff --git a/tracecat/ssh.py b/tracecat/ssh.py index 8f5979e0e..799ceb122 100644 --- a/tracecat/ssh.py +++ b/tracecat/ssh.py @@ -239,7 +239,6 @@ async def ssh_context( if git_url is None: yield None else: - logger.info("Getting SSH key", role=role, git_url=git_url) sec_svc = SecretsService(session, role=role) secret = await sec_svc.get_ssh_key() async with temporary_ssh_agent() as env: