From d5a893ff564cadf03591d4ecf16466959f927ac6 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Tue, 14 Jan 2025 23:47:27 +0800 Subject: [PATCH] Add ssh context managers for key files --- tracecat/ssh.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tracecat/ssh.py b/tracecat/ssh.py index 67c4c43cf..8f5979e0e 100644 --- a/tracecat/ssh.py +++ b/tracecat/ssh.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import os import subprocess @@ -6,10 +8,19 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING +import aiofiles import paramiko +from sqlmodel.ext.asyncio.session import AsyncSession +from tracecat.contexts import ctx_role from tracecat.logger import logger +from tracecat.secrets.service import SecretsService +from tracecat.types.auth import Role + +if TYPE_CHECKING: + from tracecat.git import GitUrl @dataclass @@ -168,3 +179,70 @@ def add_ssh_key_to_agent_sync(key_data: str, env: SshEnv) -> None: async def add_ssh_key_to_agent(key_data: str, env: SshEnv) -> None: """Asynchronously add the SSH key to the agent then remove it.""" return await asyncio.to_thread(add_ssh_key_to_agent_sync, key_data, env) + + +@asynccontextmanager +async def temp_key_file(key_content: str) -> AsyncIterator[str]: + """Create a temporary file containing an SSH key with secure permissions. + + Args: + key_content: The SSH key content to write to the temporary file + + Returns: + An SSH command string configured to use the temporary key file + + Raises: + OSError: If unable to create temp file or set permissions + """ + async with aiofiles.tempfile.NamedTemporaryFile(mode="w", delete=True) as f: + # Write key content + await f.write(key_content) + await f.flush() + + # Set strict permissions (important!) + os.chmod(f.name, 0o600) + + # Use the key file in SSH command with more permissive host key checking + ssh_cmd = ( + f"ssh -i {f.name} -o IdentitiesOnly=yes " + "-o StrictHostKeyChecking=accept-new " + f"-o UserKnownHostsFile={Path.home().joinpath('.ssh/known_hosts')!s}" + ) + yield ssh_cmd + + +@asynccontextmanager +async def opt_temp_key_file( + git_url: GitUrl | None, + session: AsyncSession, + role: Role | None = None, +) -> AsyncIterator[str | None]: + """Context manager for optional SSH key file.""" + if git_url is None: + yield None + else: + role = role or ctx_role.get() + service = SecretsService(session=session, role=role) + ssh_key = await service.get_ssh_key() + async with temp_key_file(key_content=ssh_key.reveal().value) as ssh_cmd: + yield ssh_cmd + + +@asynccontextmanager +async def ssh_context( + *, + git_url: GitUrl | None = None, + session: AsyncSession, + role: Role | None = None, +) -> AsyncIterator[SshEnv | None]: + """Context manager for SSH environment variables.""" + if git_url is None: + yield None + else: + logger.info("Getting SSH key", role=role, git_url=git_url) + sec_svc = SecretsService(session, role=role) + secret = await sec_svc.get_ssh_key() + async with temporary_ssh_agent() as env: + await add_ssh_key_to_agent(secret.reveal().value, env=env) + await add_host_to_known_hosts(git_url.host, env=env) + yield env