forked from TracecatHQ/tracecat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(integration): S3 getter and Ansible playbook runner (TracecatHQ#573
- Loading branch information
Showing
5 changed files
with
218 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Generic interface to Ansible Python API.""" | ||
|
||
import asyncio | ||
import tempfile | ||
from pathlib import Path | ||
from typing import Annotated, Any | ||
|
||
import orjson | ||
from ansible_runner import run_async | ||
from pydantic import Field | ||
|
||
from tracecat_registry import RegistrySecret, registry, secrets | ||
|
||
ansible_secret = RegistrySecret( | ||
name="ansible", | ||
optional_keys=[ | ||
"ANSIBLE_SSH_KEY", | ||
"ANSIBLE_PASSWORDS", | ||
], | ||
) | ||
"""Ansible Runner secret. | ||
- name: `ansible` | ||
- optional_keys: | ||
- `ANSIBLE_SSH_KEY` | ||
- `ANSIBLE_PASSWORDS` | ||
`ANSIBLE_SSH_KEY` should be the private key string, not the path to the file. | ||
`ANSIBLE_PASSWORDS` should be a JSON object mapping password prompts to their responses (e.g. `{"Vault password": "password"}`). | ||
""" | ||
|
||
|
||
@registry.register( | ||
default_title="Run Ansible Playbook", | ||
description="Run an Ansible playbook", | ||
display_group="Ansible", | ||
namespace="integrations.ansible", | ||
secrets=[ansible_secret], | ||
) | ||
async def run_ansible_playbook( | ||
playbook: Annotated[ | ||
list[dict[str, Any]], Field(..., description="List of plays to run") | ||
], | ||
extra_vars: Annotated[ | ||
dict[str, Any], | ||
Field(description="Extra variables to pass to the playbook"), | ||
] = None, | ||
runner_kwargs: Annotated[ | ||
dict[str, Any], | ||
Field(description="Additional keyword arguments to pass to the Ansible runner"), | ||
] = None, | ||
) -> dict[str, Any]: | ||
ssh_key = secrets.get("ANSIBLE_SSH_KEY") | ||
passwords = secrets.get("ANSIBLE_PASSWORDS") | ||
|
||
if not ssh_key and not passwords: | ||
raise ValueError( | ||
"Either `ANSIBLE_SSH_KEY` or `ANSIBLE_PASSWORDS` must be provided" | ||
) | ||
|
||
runner_kwargs = runner_kwargs or {} | ||
|
||
with tempfile.TemporaryDirectory() as temp_dir: | ||
if ssh_key: | ||
ssh_key_path = Path(temp_dir) / "id_rsa" | ||
with ssh_key_path.open("w") as f: | ||
f.write(ssh_key) | ||
runner_kwargs["ssh_key"] = str(ssh_key_path.resolve()) | ||
|
||
if passwords: | ||
runner_kwargs["passwords"] = orjson.loads(passwords) | ||
|
||
loop = asyncio.get_running_loop() | ||
|
||
def run(): | ||
return run_async( | ||
private_data_dir=temp_dir, | ||
playbook=playbook, | ||
extravars=extra_vars, | ||
**runner_kwargs, | ||
) | ||
|
||
result = await loop.run_in_executor(None, run) | ||
if result is None: | ||
raise ValueError("Ansible runner returned no result.") | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""S3 integration to download files and return contents as a string.""" | ||
|
||
import re | ||
from typing import Annotated | ||
|
||
from pydantic import Field | ||
|
||
from tracecat_registry import RegistrySecret, registry | ||
from tracecat_registry.integrations.boto3 import get_session | ||
|
||
# Add this at the top with other constants | ||
BUCKET_REGEX = re.compile(r"^[a-z0-9][a-z0-9.-]*[a-z0-9]$") | ||
|
||
s3_secret = RegistrySecret( | ||
name="s3", | ||
optional_keys=[ | ||
"AWS_ACCESS_KEY_ID", | ||
"AWS_SECRET_ACCESS_KEY", | ||
"AWS_REGION", | ||
"AWS_PROFILE_NAME", | ||
"AWS_ROLE_ARN", | ||
"AWS_ROLE_SESSION_NAME", | ||
], | ||
) | ||
"""AWS secret. | ||
Secret | ||
------ | ||
- name: `aws` | ||
- optional_keys: | ||
Either: | ||
- `AWS_ACCESS_KEY_ID` | ||
- `AWS_SECRET_ACCESS_KEY` | ||
- `AWS_REGION` | ||
Or: | ||
- `AWS_PROFILE_NAME` | ||
Or: | ||
- `AWS_ROLE_ARN` | ||
- `AWS_ROLE_SESSION_NAME` | ||
""" | ||
|
||
|
||
@registry.register( | ||
default_title="Parse S3 URI", | ||
description="Parse an S3 URI into a bucket and key.", | ||
display_group="AWS S3", | ||
namespace="integrations.aws_s3", | ||
) | ||
async def parse_uri(uri: str) -> tuple[str, str]: | ||
uri = str(uri).strip() | ||
if not uri.startswith("s3://"): | ||
raise ValueError("S3 URI must start with 's3://'") | ||
|
||
uri_path = uri.replace("s3://", "") | ||
uri_paths = uri_path.split("/") | ||
bucket = uri_paths.pop(0) | ||
key = "/".join(uri_paths) | ||
|
||
return bucket, key | ||
|
||
|
||
@registry.register( | ||
default_title="Download S3 Object", | ||
description="Download an object from S3 and return its body as a string.", | ||
display_group="AWS S3", | ||
namespace="integrations.aws_s3", | ||
secrets=[s3_secret], | ||
) | ||
async def download_object( | ||
bucket: Annotated[str, Field(..., description="S3 bucket name.")], | ||
key: Annotated[str, Field(..., description="S3 object key.")], | ||
) -> str: | ||
session = await get_session() | ||
async with session.client("s3") as s3_client: | ||
obj = await s3_client.get_object(Bucket=bucket, Key=key) | ||
body = await obj["Body"].read() | ||
# Defensively handle different types of bodies | ||
if isinstance(body, bytes): | ||
return body.decode("utf-8") | ||
return body |
34 changes: 34 additions & 0 deletions
34
registry/tracecat_registry/templates/ansible/run_playbook_from_s3.yml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
type: action | ||
definition: | ||
name: run_playbook_from_s3 | ||
namespace: integrations.ansible | ||
title: Run Ansible Playbook from S3 | ||
description: Download an Ansible playbook from S3 and run it | ||
display_group: Ansible | ||
expects: | ||
playbook_path: | ||
type: str | ||
description: Path to the playbook in S3 | ||
extra_vars: | ||
type: dict[str, any] | ||
description: Extra variables to pass to the playbook | ||
runner_kwargs: | ||
type: dict[str, any] | ||
description: Additional keyword arguments to pass to the Ansible runner | ||
steps: | ||
- ref: parse_uri | ||
action: integrations.aws_s3.parse_uri | ||
args: | ||
uri: ${{ inputs.playbook_path }} | ||
- ref: download_playbook | ||
action: integrations.aws_s3.download_object | ||
args: | ||
bucket: ${{ steps.parse_uri.result.bucket }} | ||
key: ${{ steps.parse_uri.result.key }} | ||
- ref: run_playbook | ||
action: integrations.ansible.run_playbook | ||
args: | ||
playbook: ${{ steps.download_playbook.result }} | ||
extra_vars: ${{ inputs.extra_vars }} | ||
runner_kwargs: ${{ inputs.runner_kwargs }} | ||
returns: ${{ steps.run_playbook.result }} |