From 10ea9aff1b1061fd730cd3291973b210b57b0984 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Sun, 26 Jan 2025 02:49:01 +0000 Subject: [PATCH] feat(engine): Pull all dependent secrets outside template action loop (#800) --- .../tracecat_registry/_internal/models.py | 11 + tests/unit/test_templates.py | 355 ++++++++++++++++++ tracecat/executor/service.py | 71 +--- tracecat/registry/actions/models.py | 15 +- tracecat/registry/actions/service.py | 48 ++- tracecat/registry/loaders.py | 20 +- tracecat/types/exceptions.py | 2 +- 7 files changed, 449 insertions(+), 73 deletions(-) create mode 100644 tests/unit/test_templates.py diff --git a/registry/tracecat_registry/_internal/models.py b/registry/tracecat_registry/_internal/models.py index 3ae226d57..796afb4ee 100644 --- a/registry/tracecat_registry/_internal/models.py +++ b/registry/tracecat_registry/_internal/models.py @@ -29,3 +29,14 @@ def validate_keys(cls, v): "At least one of 'keys' or 'optional_keys' must be specified" ) return v + + def __hash__(self) -> int: + """Custom hash implementation based on relevant fields.""" + return hash( + ( + self.name, + tuple(self.keys) if self.keys else None, + tuple(self.optional_keys) if self.optional_keys else None, + self.optional, + ) + ) diff --git a/tests/unit/test_templates.py b/tests/unit/test_templates.py new file mode 100644 index 000000000..e0beb079c --- /dev/null +++ b/tests/unit/test_templates.py @@ -0,0 +1,355 @@ +import os +import sys +import textwrap +import uuid +from importlib.machinery import ModuleSpec +from types import ModuleType +from typing import Any + +import pytest +from pydantic import BaseModel, SecretStr, TypeAdapter + +from tests.shared import TEST_WF_ID, generate_test_exec_id +from tracecat import config +from tracecat.dsl.models import ( + ActionStatement, + RunActionInput, + RunContext, +) +from tracecat.executor.service import run_action_from_input +from tracecat.expressions.expectations import ExpectedField +from tracecat.registry.actions.models import ( + ActionStep, + BoundRegistryAction, + RegistryActionCreate, + RegistrySecret, + TemplateAction, + TemplateActionDefinition, +) +from tracecat.registry.actions.service import RegistryActionsService +from tracecat.registry.repository import Repository +from tracecat.secrets.models import SecretCreate, SecretKeyValue +from tracecat.secrets.service import SecretsService +from tracecat.types.exceptions import TracecatValidationError + + +@pytest.fixture +def mock_package(tmp_path): + """Pytest fixture that creates a mock package with files and cleans up after the test.""" + + # Create a new module + test_module = ModuleType("test_module") + + # Create a module spec for the test module + module_spec = ModuleSpec("test_module", None) + test_module.__spec__ = module_spec + # Set __path__ to the temporary directory + test_module.__path__ = [str(tmp_path)] + + try: + # Add the module to sys.modules + sys.modules["test_module"] = test_module + with open(os.path.join(tmp_path, "has_secret.py"), "w") as f: + f.write( + textwrap.dedent( + """ + from tracecat_registry import registry, RegistrySecret, secrets + + secret = RegistrySecret( + name="the_secret", + keys=["THE_SECRET_KEY"], + ) + + @registry.register( + description="This is a deprecated function", + namespace="testing", + secrets=[secret], + ) + def has_secret() -> str: + return secrets.get("THE_SECRET_KEY") + """ + ) + ) + + yield test_module + + finally: + # Clean up + del sys.modules["test_module"] + + +@pytest.mark.integration +@pytest.mark.anyio +async def test_template_action_with_nested_secrets_can_be_fetched( + test_role, + monkeypatch, + db_session_with_repo, + mock_package, +): + """Test template action with secrets. + + The test verifies: + 1. Template action with secrets executes successfully + """ + + monkeypatch.setattr(config, "TRACECAT__UNSAFE_DISABLE_SM_MASKING", True) + + # Arrange + # 1. Register test udfs + repo = Repository() + + session, db_repo_id = db_session_with_repo + repo.init(include_base=True, include_templates=False) + repo._register_udfs_from_package(mock_package) + assert repo.get("testing.has_secret") is not None + + template_action_registered = TemplateAction( + type="action", + definition=TemplateActionDefinition( + title="Test Action Registered", + description="Test template registered in the registry", + name="template_action_registered", + namespace="testing", + display_group="Testing", + expects={ + "num": ExpectedField( + type="int", + description="Number to add 100 to", + ) + }, + secrets=[ + RegistrySecret( + name="template_secret_registered", + keys=["TEMPLATE_SECRET_KEY_REGISTERED"], + ) # This secret isn't used but we just pull it to verify it's fetched + ], + steps=[ + ActionStep( + ref="base", + action="core.transform.reshape", + args={ + "value": "${{ inputs.num + 100 }}", + }, + ), + ActionStep( + ref="secret", + action="testing.has_secret", + args={}, + ), + ], + # Return the secret value from the secret step + returns="${{ steps.secret.result }}", + ), + ) + repo.register_template_action(template_action_registered) + + # It then returns the fetched secret + template_action = TemplateAction( + type="action", + definition=TemplateActionDefinition( + title="Test Action", + description="This is just a test", + name="template_action", + namespace="testing", + display_group="Testing", + expects={ + "num": ExpectedField( + type="int", + description="Number to add 100 to", + ) + }, + secrets=[ + RegistrySecret( + name="template_secret", + keys=["TEMPLATE_SECRET_KEY"], + ) + ], + steps=[ + ActionStep( + ref="base", + action="core.transform.reshape", + args={ + "value": "${{ inputs.num + 100 }}", + }, + ), + ActionStep( + ref="secret", + action="testing.has_secret", + args={}, + ), + ActionStep( + ref="template_secret_registered", + action="testing.template_action_registered", + args={ + "num": "${{ inputs.num }}", + }, + ), + ], + returns={ + "secret_step": "${{ steps.secret.result }}", + "nested_secret_step": "${{ steps.template_secret_registered.result }}", + }, + ), + ) + + # We expect the secret to be fetched + def get_secrets(action: BoundRegistryAction) -> list[RegistrySecret]: + """Recursively fetch secrets from the template action.""" + secrets = [] + # Base case + if action.type == "udf": + if action.secrets: + secrets.extend(action.secrets) + elif action.type == "template": + assert action.template_action is not None + if template_secrets := action.template_action.definition.secrets: + secrets.extend(template_secrets) + for step in action.template_action.definition.steps: + step_action = repo.get(step.action) + step_secrets = get_secrets(step_action) + secrets.extend(step_secrets) + return secrets + + bound_action = BoundRegistryAction( + fn=lambda: None, + type="template", + name="template_action", + namespace="testing", + description="This is just a test", + secrets=[], + args_docs={}, + rtype_cls=Any, + rtype_adapter=TypeAdapter(Any), + default_title="Test Action", + display_group="Testing", + doc_url=None, + author=None, + deprecated=None, + include_in_schema=True, + template_action=template_action, + origin="testing.template_action", + args_cls=BaseModel, + ) + assert set(get_secrets(bound_action)) == { + RegistrySecret( + name="template_secret", + keys=["TEMPLATE_SECRET_KEY"], + ), + RegistrySecret( + name="the_secret", + keys=["THE_SECRET_KEY"], + ), + RegistrySecret( + name="template_secret_registered", + keys=["TEMPLATE_SECRET_KEY_REGISTERED"], + ), + } + + # Now run the action + + repo.register_template_action(template_action) + + assert "testing.template_action" in repo + + ra_service = RegistryActionsService(session, role=test_role) + # create actions for each step + action_names = {step.action for step in template_action.definition.steps} | { + "testing.template_action", + } + for action_name in action_names: + if action_name.startswith("testing"): + step_create_params = RegistryActionCreate.from_bound( + repo.get(action_name), db_repo_id + ) + await ra_service.create_action(step_create_params) + # Add secrets to the db + sec_service = SecretsService(session, role=test_role) + # Add secret for the UDF + await sec_service.create_secret( + SecretCreate( + name="the_secret", + environment="default", + keys=[ + SecretKeyValue( + key="THE_SECRET_KEY", value=SecretStr("UDF_SECRET_VALUE") + ) + ], + ) + ) + # Add secret for the registered template action + await sec_service.create_secret( + SecretCreate( + name="template_secret_registered", + environment="default", + keys=[ + SecretKeyValue( + key="TEMPLATE_SECRET_KEY_REGISTERED", + value=SecretStr("REGISTERED_SECRET_VALUE"), + ) + ], + ) + ) + # Add secret for the main template action + await sec_service.create_secret( + SecretCreate( + name="template_secret", + environment="default", + keys=[ + SecretKeyValue( + key="TEMPLATE_SECRET_KEY", + value=SecretStr("TEMPLATE_SECRET_VALUE"), + ) + ], + ) + ) + + input = RunActionInput( + task=ActionStatement( + ref="test_action", + action="testing.template_action", + args={"num": 123123}, + ), + exec_context={}, + run_context=RunContext( + wf_id=TEST_WF_ID, + wf_exec_id=generate_test_exec_id("test_template_action_with_secrets"), + wf_run_id=uuid.uuid4(), + environment="default", + ), + ) + result = await run_action_from_input(input=input, role=test_role) + assert result == { + "secret_step": "UDF_SECRET_VALUE", + "nested_secret_step": "UDF_SECRET_VALUE", + } + + +def test_template_action_definition_validates_self_reference(): + """Test that TemplateActionDefinition validates against self-referential steps. + + The test verifies: + 1. A template action cannot reference itself in its steps + 2. The validation error message is descriptive + """ + with pytest.raises(TracecatValidationError) as exc_info: + TemplateActionDefinition( + title="Self Referential Action", + description="This action tries to reference itself", + name="self_ref", + namespace="testing", + display_group="Testing", + expects={}, + steps=[ + ActionStep( + ref="self_ref_step", + action="testing.self_ref", # This references the template itself + args={}, + ), + ], + returns="${{ steps.self_ref_step.result }}", + ) + + assert "Steps cannot reference the template action itself: testing.self_ref" in str( + exc_info.value + ) + assert "1 steps reference the template action" in str(exc_info.value) diff --git a/tracecat/executor/service.py b/tracecat/executor/service.py index 28ba3a69a..e2b48da6d 100644 --- a/tracecat/executor/service.py +++ b/tracecat/executor/service.py @@ -107,37 +107,13 @@ async def run_single_action( context: ExecutionContext, ) -> Any: """Run a UDF async.""" - - # Here, we pass in context - # For this action, check whether its dependent secrets are already in the context - # For any that aren't, pull them in - - action_secret_names = set() - optional_secrets = set() - secrets = context.get(ExprContext.SECRETS, {}) - - for secret in action.secrets or []: - # Only add if not already pulled - if secret.name not in secrets: - if secret.optional: - optional_secrets.add(secret.name) - action_secret_names.add(secret.name) - - args_secret_refs = set(extract_templated_secrets(args)) - async with AuthSandbox( - secrets=list(action_secret_names | args_secret_refs), - target="context", - environment=get_runtime_env(), - optional_secrets=list(optional_secrets), - ) as sandbox: - secrets |= sandbox.secrets.copy() - - context[ExprContext.SECRETS] = context.get(ExprContext.SECRETS, {}) | secrets if action.is_template: logger.info("Running template action async", action=action.name) result = await run_template_action(action=action, args=args, context=context) else: logger.trace("Running UDF async", action=action.name) + # Get secrets from context + secrets = context.get(ExprContext.SECRETS, {}) flat_secrets = flatten_secrets(secrets) with env_sandbox(flat_secrets): result = await _run_action_direct(action=action, args=args) @@ -189,7 +165,7 @@ async def run_template_action( ) async with RegistryActionsService.with_session() as service: step_action = await service.load_action_impl(action_name=step.action) - logger.trace("Running action step", step_ation=step_action.action) + logger.trace("Running action step", step_action=step_action.action) result = await run_single_action( action=step_action, args=evaled_args, @@ -217,34 +193,25 @@ async def run_action_from_input(input: RunActionInput, role: Role) -> Any: task = input.task action_name = task.action - # Multi-phase expression resolution - # --------------------------------- - # 1. Resolve all expressions in all shared (non action-local) contexts - # 2. Enter loop iteration (if any) - # 3. Resolve all action-local expressions - - # Set - # If there's a for loop, we need to process this action in parallel - - # Evaluate `SECRETS` context (XXX: You likely should use the secrets manager instead) - # -------------------------- - # Securely inject secrets into the task arguments - # 1. Find all secrets in the task arguments - # 2. Load the secrets - # 3. Inject the secrets into the task arguments using an enriched context - # NOTE: Regardless of loop iteration, we should only make this call/substitution once!! - async with RegistryActionsService.with_session() as service: - action = await service.load_action_impl(action_name=action_name) + reg_action = await service.get_action(action_name) + action_secrets = await service.fetch_all_action_secrets(reg_action) + action = service.get_bound(reg_action) + + args_secrets = set(extract_templated_secrets(task.args)) + optional_secrets = {s.name for s in action_secrets if s.optional} + required_secrets = {s.name for s in action_secrets if not s.optional} + + logger.info( + "Required secrets", + required_secrets=required_secrets, + optional_secrets=optional_secrets, + args_secrets=args_secrets, + ) - action_secret_names = {secret.name for secret in action.secrets or []} - optional_secrets = { - secret.name for secret in action.secrets or [] if secret.optional - } - args_secret_refs = set(extract_templated_secrets(task.args)) + # Get all secrets in one call async with AuthSandbox( - secrets=list(action_secret_names | args_secret_refs), - target="context", + secrets=list(required_secrets | args_secrets), environment=get_runtime_env(), optional_secrets=list(optional_secrets), ) as sandbox: diff --git a/tracecat/registry/actions/models.py b/tracecat/registry/actions/models.py index 59d23ae70..bbc49f887 100644 --- a/tracecat/registry/actions/models.py +++ b/tracecat/registry/actions/models.py @@ -198,9 +198,20 @@ def validate_steps(self): step_refs = [step.ref for step in self.steps] unique_step_refs = set(step_refs) + # Check for duplicate step refs if len(step_refs) != len(unique_step_refs): duplicate_step_refs = [ref for ref in step_refs if step_refs.count(ref) > 1] - raise ValueError(f"Duplicate step references found: {duplicate_step_refs}") + raise TracecatValidationError( + f"Duplicate step references found: {duplicate_step_refs}" + ) + + # Check if any step action references the template action + template_action = f"{self.namespace}.{self.name}" + if violating_steps := [s for s in self.steps if s.action == template_action]: + raise TracecatValidationError( + f"Steps cannot reference the template action itself: {template_action}." + f"{len(violating_steps)} steps reference the template action: {violating_steps}" + ) return self @@ -506,7 +517,7 @@ class RegistryActionTemplateImpl(BaseModel): ] RegistryActionImplValidator: TypeAdapter[RegistryActionImpl] = TypeAdapter( AnnotatedRegistryActionImpl -) +) # type: ignore class model_converters: diff --git a/tracecat/registry/actions/service.py b/tracecat/registry/actions/service.py index 074a4118d..572047bab 100644 --- a/tracecat/registry/actions/service.py +++ b/tracecat/registry/actions/service.py @@ -61,7 +61,7 @@ async def list_actions( result = await self.session.exec(statement) return result.all() - async def get_action(self, *, action_name: str) -> RegistryAction: + async def get_action(self, action_name: str) -> RegistryAction: """Get an action by name.""" namespace, name = action_name.rsplit(".", maxsplit=1) statement = select(RegistryAction).where( @@ -75,6 +75,17 @@ async def get_action(self, *, action_name: str) -> RegistryAction: raise RegistryError(f"Action {namespace}.{name} not found in repository") return action + async def get_actions(self, action_names: list[str]) -> Sequence[RegistryAction]: + """Get actions by name.""" + statement = select(RegistryAction).where( + RegistryAction.owner_id == config.TRACECAT__DEFAULT_ORG_ID, + func.concat(RegistryAction.namespace, ".", RegistryAction.name).in_( + action_names + ), + ) + result = await self.session.exec(statement) + return result.all() + async def create_action( self, params: RegistryActionCreate, @@ -256,3 +267,38 @@ async def read_action_with_implicit_secrets( ) -> RegistryActionRead: extra_secrets = await self.get_action_implicit_secrets(action) return RegistryActionRead.from_database(action, extra_secrets) + + async def fetch_all_action_secrets( + self, action: RegistryAction + ) -> set[RegistrySecret]: + """Recursively fetch all secrets from the action and its template steps. + + Args: + action: The registry action to fetch secrets from + + Returns: + set[RegistrySecret]: A set of secret names used by the action and its template steps + """ + secrets = set() + impl = RegistryActionImplValidator.validate_python(action.implementation) + if impl.type == "udf": + if action.secrets: + secrets.update(RegistrySecret(**secret) for secret in action.secrets) + elif impl.type == "template": + ta = impl.template_action + if ta is None: + raise ValueError("Template action is not defined") + # Add secrets from the template action itself + if template_secrets := ta.definition.secrets: + secrets.update(template_secrets) + # Recursively fetch secrets from each step + step_action_names = [step.action for step in ta.definition.steps] + step_ras = await self.get_actions(step_action_names) + for step_ra in step_ras: + step_secrets = await self.fetch_all_action_secrets(step_ra) + secrets.update(step_secrets) + return secrets + + def get_bound(self, action: RegistryAction) -> BoundRegistryAction: + """Get the bound action for a registry action.""" + return get_bound_action_impl(action) diff --git a/tracecat/registry/loaders.py b/tracecat/registry/loaders.py index 876028862..8f43289e1 100644 --- a/tracecat/registry/loaders.py +++ b/tracecat/registry/loaders.py @@ -13,10 +13,7 @@ from tracecat.logger import logger from tracecat.registry.actions.models import ( BoundRegistryAction, - RegistryActionImpl, RegistryActionImplValidator, - RegistryActionTemplateImpl, - RegistryActionType, RegistryActionUDFImpl, ) from tracecat.registry.repository import ( @@ -31,12 +28,11 @@ def get_bound_action_impl( action: RegistryAction, -) -> BoundRegistryAction[type[BaseModel]]: +) -> BoundRegistryAction: impl = RegistryActionImplValidator.validate_python(action.implementation) - impl_loader = _LOADERS[impl.type] - fn: F = impl_loader(impl) secrets = [RegistrySecret(**secret) for secret in action.secrets or []] if impl.type == "udf": + fn = load_udf_impl(impl) key = getattr(fn, "__tracecat_udf_key") kwargs = getattr(fn, "__tracecat_udf_kwargs") logger.trace("Binding UDF", key=key, name=action.name, kwargs=kwargs) @@ -70,7 +66,7 @@ def get_bound_action_impl( logger.trace("Binding template action", name=action.name) defn = impl.template_action.definition return BoundRegistryAction( - fn=fn, + fn=_not_implemented, type=impl.type, name=action.name, namespace=action.namespace, @@ -106,17 +102,7 @@ def load_udf_impl(impl: RegistryActionUDFImpl) -> F: return fn -def load_template_impl(impl: RegistryActionTemplateImpl) -> F: - return _not_implemented - - def _not_implemented() -> NoReturn: raise NotImplementedError( "This is a template action, it must be run with concrete arguments" ) - - -_LOADERS: dict[RegistryActionType, Callable[[RegistryActionImpl], F]] = { - "udf": load_udf_impl, # type: ignore - "template": load_template_impl, # type: ignore -} diff --git a/tracecat/types/exceptions.py b/tracecat/types/exceptions.py index af519efbe..f2293079a 100644 --- a/tracecat/types/exceptions.py +++ b/tracecat/types/exceptions.py @@ -21,7 +21,7 @@ def __init__(self, *args, detail: Any | None = None, **kwargs): class TracecatValidationError(TracecatException): - """Tracecat user-facting validation error""" + """Tracecat user-facing validation error""" class TracecatDSLError(TracecatValidationError):