Skip to content

Commit

Permalink
Add ssh context managers for key files
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Jan 15, 2025
1 parent b2dc4ba commit e3c9ae1
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions tracecat/ssh.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import os
import subprocess
Expand All @@ -6,10 +8,19 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import aiofiles
import paramiko
from sqlmodel.ext.asyncio.session import AsyncSession

from tracecat.contexts import ctx_role
from tracecat.logger import logger
from tracecat.secrets.service import SecretsService
from tracecat.types.auth import Role

if TYPE_CHECKING:
from tracecat.git import GitUrl


@dataclass
Expand Down Expand Up @@ -168,3 +179,70 @@ def add_ssh_key_to_agent_sync(key_data: str, env: SshEnv) -> None:
async def add_ssh_key_to_agent(key_data: str, env: SshEnv) -> None:
"""Asynchronously add the SSH key to the agent then remove it."""
return await asyncio.to_thread(add_ssh_key_to_agent_sync, key_data, env)


@asynccontextmanager
async def temp_key_file(key_content: str) -> AsyncIterator[str]:
"""Create a temporary file containing an SSH key with secure permissions.
Args:
key_content: The SSH key content to write to the temporary file
Returns:
An SSH command string configured to use the temporary key file
Raises:
OSError: If unable to create temp file or set permissions
"""
async with aiofiles.tempfile.NamedTemporaryFile(mode="w", delete=True) as f:
# Write key content
await f.write(key_content)
await f.flush()

# Set strict permissions (important!)
os.chmod(f.name, 0o600)

# Use the key file in SSH command with more permissive host key checking
ssh_cmd = (
f"ssh -i {f.name} -o IdentitiesOnly=yes "
"-o StrictHostKeyChecking=accept-new "
f"-o UserKnownHostsFile={Path.home().joinpath('.ssh/known_hosts')!s}"
)
yield ssh_cmd


@asynccontextmanager
async def opt_temp_key_file(
git_url: GitUrl | None,
session: AsyncSession,
role: Role | None = None,
) -> AsyncIterator[str | None]:
"""Context manager for optional SSH key file."""
if git_url is None:
yield None
else:
role = role or ctx_role.get()
service = SecretsService(session=session, role=role)
ssh_key = await service.get_ssh_key()
async with temp_key_file(key_content=ssh_key.reveal().value) as ssh_cmd:
yield ssh_cmd


@asynccontextmanager
async def ssh_context(
*,
git_url: GitUrl | None = None,
session: AsyncSession,
role: Role | None = None,
) -> AsyncIterator[SshEnv | None]:
"""Context manager for SSH environment variables."""
if git_url is None:
yield None
else:
logger.info("Getting SSH key", role=role, git_url=git_url)
sec_svc = SecretsService(session, role=role)
secret = await sec_svc.get_ssh_key()
async with temporary_ssh_agent() as env:
await add_ssh_key_to_agent(secret.reveal().value, env=env)
await add_host_to_known_hosts(git_url.host, env=env)
yield env

0 comments on commit e3c9ae1

Please sign in to comment.