Skip to content

Commit

Permalink
Replace branch field with sha
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Jan 15, 2025
1 parent e95e00b commit b2dc4ba
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 50 deletions.
12 changes: 6 additions & 6 deletions tests/unit/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,17 @@ async def test_registry_async_function_can_be_called(mock_package):
host="github.com",
org="org",
repo="repo",
branch="main",
ref=None,
),
),
# GitHub (with branch)
# GitHub (with branch/sha)
(
"git+ssh://[email protected]/org/repo@branch",
"git+ssh://[email protected]/org/repo@branchOrSHAOrTag",
GitUrl(
host="github.com",
org="org",
repo="repo",
branch="branch",
ref="branchOrSHAOrTag",
),
),
# GitLab
Expand All @@ -174,7 +174,7 @@ async def test_registry_async_function_can_be_called(mock_package):
host="gitlab.com",
org="org",
repo="repo",
branch="main",
ref=None,
),
),
# GitLab (with branch)
Expand All @@ -184,7 +184,7 @@ async def test_registry_async_function_can_be_called(mock_package):
host="gitlab.com",
org="org",
repo="repo",
branch="branch",
ref="branch",
),
),
],
Expand Down
14 changes: 13 additions & 1 deletion tracecat/executor/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,20 @@ async def run_action_on_ray_cluster(
If any exceptions are thrown here, they're platform level errors.
All application/user level errors are caught by the executor and returned as values.
"""
# Initialize runtime environment variables
env_vars = {"GIT_SSH_COMMAND": ctx.ssh_command} if ctx.ssh_command else {}
additional_vars: dict[str, Any] = {}

obj_ref = run_action_task.remote(input, role)
# Add git URL to pip dependencies if SHA is present
if ctx.git_url and ctx.git_url.ref:
url = ctx.git_url.to_url()
additional_vars["pip"] = [url]
logger.warning("Adding git URL to runtime env", git_url=ctx.git_url, url=url)

runtime_env = RuntimeEnv(env_vars=env_vars, **additional_vars)

logger.info("Running action on ray cluster", runtime_env=runtime_env)
obj_ref = run_action_task.options(runtime_env=runtime_env).remote(input, ctx.role)
try:
coro = asyncio.to_thread(ray.get, obj_ref)
exec_result = await asyncio.wait_for(coro, timeout=EXECUTION_TIMEOUT)
Expand Down
19 changes: 9 additions & 10 deletions tracecat/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,21 @@
from tracecat.types.exceptions import TracecatSettingsError

GIT_SSH_URL_REGEX = re.compile(
r"^git\+ssh://git@(?P<host>[^/]+)/(?P<org>[^/]+)/(?P<repo>[^/@]+?)(?:\.git)?(?:@(?P<sha>[^/]+))?$"
r"^git\+ssh://git@(?P<host>[^/]+)/(?P<org>[^/]+)/(?P<repo>[^/@]+?)(?:\.git)?(?:@(?P<ref>[^/]+))?$"
)
"""Git SSH URL with git user and optional sha."""
"""Git SSH URL with git user and optional ref."""


@dataclass
class GitUrl:
host: str
org: str
repo: str
branch: str | None = None
sha: str | None = None
ref: str | None = None

def to_url(self) -> str:
base = f"git+ssh://git@{self.host}/{self.org}/{self.repo}.git"
return f"{base}@{self.sha}" if self.sha else base
return f"{base}@{self.ref}" if self.ref else base


async def get_git_repository_sha(repo_url: str, env: SshEnv) -> str:
Expand All @@ -51,8 +50,8 @@ async def get_git_repository_sha(repo_url: str, env: SshEnv) -> str:
raise RuntimeError(f"Failed to get repository SHA: {error_message}")

# The output format is: "<SHA>\tHEAD"
sha = stdout.decode().split()[0]
return sha
ref = stdout.decode().split()[0]
return ref

except Exception as e:
logger.error("Error getting repository SHA", error=str(e))
Expand All @@ -78,7 +77,7 @@ def parse_git_url(url: str, *, allowed_domains: set[str] | None = None) -> GitUr
host = match.group("host")
org = match.group("org")
repo = match.group("repo")
sha = match.group("sha")
ref = match.group("ref")

if (
not isinstance(host, str)
Expand All @@ -92,7 +91,7 @@ def parse_git_url(url: str, *, allowed_domains: set[str] | None = None) -> GitUr
f"Domain {host} not in allowed domains. Must be configured in `git_allowed_domains` organization setting."
)

return GitUrl(host=host, org=org, repo=repo, sha=sha)
return GitUrl(host=host, org=org, repo=repo, ref=ref)

raise ValueError(f"Unsupported URL format: {url}. Must be a valid Git SSH URL.")

Expand Down Expand Up @@ -144,5 +143,5 @@ async def prepare_git_url(role: Role | None = None) -> GitUrl | None:
raise TracecatSettingsError(
"Invalid Git repository URL. Please provide a valid Git SSH URL (git+ssh)."
) from e
git_url.sha = sha
git_url.ref = sha
return git_url
52 changes: 19 additions & 33 deletions tracecat/registry/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@

from tracecat import config
from tracecat.contexts import ctx_role
from tracecat.db.engine import get_async_session_context_manager
from tracecat.expressions.expectations import create_expectation_model
from tracecat.expressions.validation import TemplateValidator
from tracecat.git import get_git_repository_sha, parse_git_url
from tracecat.git import GitUrl, get_git_repository_sha, parse_git_url
from tracecat.logger import logger
from tracecat.parse import safe_url
from tracecat.registry.actions.models import BoundRegistryAction, TemplateAction
Expand All @@ -41,14 +42,8 @@
)
from tracecat.registry.repositories.models import RegistryRepositoryCreate
from tracecat.registry.repositories.service import RegistryReposService
from tracecat.secrets.service import SecretsService
from tracecat.settings.service import get_setting
from tracecat.ssh import (
SshEnv,
add_host_to_known_hosts,
add_ssh_key_to_agent,
temporary_ssh_agent,
)
from tracecat.ssh import SshEnv, ssh_context
from tracecat.types.auth import Role
from tracecat.types.exceptions import RegistryError

Expand Down Expand Up @@ -117,10 +112,6 @@ def get(self, name: str) -> BoundRegistryAction[ArgsClsT]:
"""Retrieve a registered udf."""
return self._store[name]

def safe_remote_url(self, remote_registry_url: str) -> str:
"""Clean a remote registry url."""
return safe_url(remote_registry_url)

def init(self, include_base: bool = True, include_templates: bool = True) -> None:
"""Initialize the registry."""
if not self._is_initialized:
Expand Down Expand Up @@ -219,11 +210,6 @@ def register_template_action(
origin=origin,
)

def _reset(self) -> None:
logger.warning("Resetting registry")
self._store = {}
self._is_initialized = False

def _load_base_udfs(self) -> None:
"""Load all udfs and template actions into the registry."""
# Load udfs
Expand Down Expand Up @@ -258,12 +244,12 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
or {"github.com"},
)

cleaned_url = safe_url(self._origin)
try:
git_url = parse_git_url(self._origin, allowed_domains=allowed_domains)
git_url = parse_git_url(cleaned_url, allowed_domains=allowed_domains)
host = git_url.host
org = git_url.org
repo_name = git_url.repo
branch = git_url.branch
except ValueError as e:
raise RegistryError(
"Invalid Git repository URL. Please provide a valid Git SSH URL (git+ssh)."
Expand All @@ -283,38 +269,38 @@ async def load_from_origin(self, commit_sha: str | None = None) -> str | None:
org=org,
repo=repo_name,
package_name=package_name,
ref=branch,
)

cleaned_url = self.safe_remote_url(self._origin)
logger.debug("Cleaned URL", url=cleaned_url)
logger.debug("Git URL", git_url=git_url)
commit_sha = await self._install_remote_repository(
host=host, repo_url=cleaned_url, commit_sha=commit_sha
git_url=git_url, commit_sha=commit_sha
)
module = await self._load_remote_repository(cleaned_url, package_name)
logger.info(
"Imported and reloaded remote repository",
module_name=module.__name__,
package_name=package_name,
commit_sha=commit_sha,
)
return commit_sha

async def _install_remote_repository(
self, host: str, repo_url: str, commit_sha: str | None = None
self, git_url: GitUrl, commit_sha: str | None = None
) -> str:
"""Install the remote repository into the filesystem and return the commit sha."""

logger.info("Getting SSH key", role=self.role)
async with SecretsService.with_session(role=self.role) as service:
secret = await service.get_ssh_key()
logger.info("Getting SSH key", role=self.role, git_url=git_url)

async with temporary_ssh_agent() as env:
logger.info("Entered temporary SSH agent context")
await add_ssh_key_to_agent(secret.reveal().value, env=env)
await add_host_to_known_hosts(host, env=env)
url = git_url.to_url()
async with (
get_async_session_context_manager() as session,
ssh_context(role=self.role, git_url=git_url, session=session) as env,
):
if env is None:
raise RegistryError("No SSH key found")
if commit_sha is None:
commit_sha = await get_git_repository_sha(repo_url, env=env)
await install_remote_repository(repo_url, commit_sha=commit_sha, env=env)
commit_sha = await get_git_repository_sha(url, env=env)
await install_remote_repository(url, commit_sha=commit_sha, env=env)
return commit_sha

async def _load_remote_repository(
Expand Down

0 comments on commit b2dc4ba

Please sign in to comment.