Skip to content

Commit

Permalink
Merge pull request #516 from latchbio/ayush/ast-parsing
Browse files Browse the repository at this point in the history
replace importing the wf with ast parsing
  • Loading branch information
ayushkamat authored Feb 27, 2025
2 parents b01ef19 + b0c9c67 commit aa8a34f
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 58 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ Types of changes

# Latch SDK Changelog

## 2.57.0 - 2024-02-27

### Changed

* During registration, we no longer execute workflow code outside of a container to check for syntax errors / parse out task-specific Dockerfiles as this was error prone. We now only parse the code into an AST to check for syntax errors, and inspect the AST to pull out any task specific `dockerfile` arguments.
* This is a breaking change for users that use the `dockerfile` argument for the `task` decorator:
* Old behavior: `dockerfile` accepts a `pathlib.Path` argument which will be resolved against the current working directory
* New behavior: `dockerfile` must be a `str` literal (no variable values or expressions) - this is so that we can pull it out from the AST without executing anything

## 2.56.10 - 2025-02-26

### Dependencies
Expand Down
2 changes: 1 addition & 1 deletion Justfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Setup

install:
uv sync
uv sync --no-cache --frozen

# Packaging

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ include = ["src/**/*.py", "src/latch_cli/services/init/*"]

[project]
name = "latch"
version = "2.56.10"
version = "2.57.0"
description = "The Latch SDK"
authors = [{ name = "Kenny Workman", email = "[email protected]" }]
maintainers = [
Expand Down
124 changes: 124 additions & 0 deletions src/latch_cli/centromere/ast_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import ast
import os
import traceback
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
from textwrap import dedent
from typing import Literal, Optional

import click

from latch.resources import tasks


@dataclass
class FlyteObject:
type: Literal["task", "workflow"]
name: str
dockerfile: Optional[Path] = None


task_decorators = set(filter(lambda x: x.endswith("task"), tasks.__dict__.keys()))


class Visitor(ast.NodeVisitor):
def __init__(self, file: Path, module_name: str):
self.file = file
self.module_name = module_name
self.flyte_objects: list[FlyteObject] = []

# todo(ayush): skip defs that arent run on import
def visit_FunctionDef(self, node: ast.FunctionDef): # noqa: N802
if len(node.decorator_list) == 0:
return self.generic_visit(node)

fqn = f"{self.module_name}.{node.name}"

# todo(ayush): |
# 1. support ast.Attribute (@latch.tasks.small_task)
# 2. normalize to fqn before checking whether or not its a task decorator
# 3. save fully qualified name for tasks (need to parse based on import graph)
for decorator in node.decorator_list:
if isinstance(decorator, ast.Name):
if decorator.id == "workflow":
self.flyte_objects.append(FlyteObject("workflow", fqn))
elif decorator.id in task_decorators:
self.flyte_objects.append(FlyteObject("task", fqn))

elif isinstance(decorator, ast.Call):
func = decorator.func
assert isinstance(func, ast.Name)

if func.id not in task_decorators and func.id != "workflow":
continue

if func.id == "workflow":
self.flyte_objects.append(FlyteObject("workflow", fqn))
continue

# note(ayush): this only works if `dockerfile` is a keyword arg - if someone
# is insane enough to pass in the 14 other arguments first then have `dockerfile`
# as a positional arg i will fix it
dockerfile: Optional[Path] = None
for kw in decorator.keywords:
if kw.arg != "dockerfile":
continue

try:
dockerfile = Path(ast.literal_eval(kw.value))
except ValueError as e:
click.secho(
dedent(f"""\
There was an issue parsing the `dockerfile` argument for task `{fqn}` in {self.file}.
Note that values passed to `dockerfile` must be string literals.
"""),
fg="red",
)

raise click.exceptions.Exit(1) from e

self.flyte_objects.append(FlyteObject("task", fqn, dockerfile))

return self.generic_visit(node)


def get_flyte_objects(module: Path) -> list[FlyteObject]:
res: list[FlyteObject] = []
queue: Queue[Path] = Queue()
queue.put(module)

while not queue.empty():
file = queue.get()

if file.is_dir():
for child in file.iterdir():
queue.put(child)

continue

# todo(ayush): follow the import graph instead
assert file.is_file()
if file.suffix != ".py":
continue

module_name = str(file.with_suffix("").relative_to(module.parent)).replace(
os.sep, "."
)

v = Visitor(file, module_name)

try:
parsed = ast.parse(file.read_text(), filename=file)
except SyntaxError as e:
traceback.print_exc()
click.secho(
"\nRegistration failed due to a syntax error (see above)", fg="red"
)
raise click.exceptions.Exit(1) from e

v.visit(parsed)

res.extend(v.flyte_objects)

return res
71 changes: 44 additions & 27 deletions src/latch_cli/centromere/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,22 @@
import paramiko
import paramiko.util
from docker.transport import SSHHTTPAdapter
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteEntities
from flytekit.core.workflow import PythonFunctionWorkflow

import latch_cli.tinyrequests as tinyrequests
from latch.utils import account_id_from_token, current_workspace, retrieve_or_login
from latch_cli.centromere.ast_parsing import get_flyte_objects
from latch_cli.centromere.utils import (
RemoteConnInfo,
_construct_dkr_client,
_construct_ssh_client,
_import_flyte_objects,
)
from latch_cli.constants import docker_image_name_illegal_pat, latch_constants
from latch_cli.docker_utils import get_default_dockerfile
from latch_cli.utils import (
WorkflowType,
generate_temporary_ssh_credentials,
hash_directory,
identifier_suffix_from_str,
)
from latch_sdk_config.latch import config

Expand Down Expand Up @@ -176,8 +174,8 @@ def __init__(

if self.workflow_type == WorkflowType.latchbiosdk:
try:
_import_flyte_objects([self.pkg_root], module_name=self.wf_module)
except ModuleNotFoundError:
flyte_objects = get_flyte_objects(self.pkg_root / self.wf_module)
except ModuleNotFoundError as e:
click.secho(
dedent(
f"""
Expand All @@ -189,14 +187,23 @@ def __init__(
),
fg="red",
)
raise click.exceptions.Exit(1)
raise click.exceptions.Exit(1) from e

wf_name: Optional[str] = None

name_path = pkg_root / latch_constants.pkg_workflow_name
if name_path.exists():
wf_name = name_path.read_text().strip()

for entity in FlyteEntities.entities:
if isinstance(entity, PythonFunctionWorkflow):
self.workflow_name = entity.name
if wf_name is None:
for obj in flyte_objects:
if obj.type != "workflow":
continue

wf_name = obj.name
break

if not hasattr(self, "workflow_name"):
if wf_name is None:
click.secho(
dedent("""\
Unable to locate workflow code. If you are a registering a Snakemake project, make sure to pass the Snakefile path with the --snakefile flag.
Expand All @@ -205,21 +212,30 @@ def __init__(
)
raise click.exceptions.Exit(1)

name_path = pkg_root / latch_constants.pkg_workflow_name
if name_path.exists():
self.workflow_name = name_path.read_text().strip()
self.workflow_name = wf_name

for entity in FlyteEntities.entities:
if isinstance(entity, PythonTask):
if (
hasattr(entity, "dockerfile_path")
and entity.dockerfile_path is not None
):
self.container_map[entity.name] = _Container(
dockerfile=entity.dockerfile_path,
image_name=self.task_image_name(entity.name),
pkg_dir=entity.dockerfile_path.parent,
)
for obj in flyte_objects:
if obj.type != "task" or obj.dockerfile is None:
continue

dockerfile = self.pkg_root / obj.dockerfile

if not dockerfile.exists():
click.secho(
f"""\
The `dockerfile` value (provided {obj.dockerfile}, resolved to {dockerfile}) for task `{obj.name}` does not exist.
Note that relative paths are resolved with respect to the package root.\
""",
fg="red",
)

raise click.exceptions.Exit(1)

self.container_map[obj.name] = _Container(
dockerfile=obj.dockerfile,
image_name=self.task_image_name(obj.name),
pkg_dir=obj.dockerfile.parent,
)

elif self.workflow_type == WorkflowType.snakemake:
assert snakefile is not None
Expand Down Expand Up @@ -438,8 +454,6 @@ def image(self):
else:
account_id = self.account_id

from ..utils import identifier_suffix_from_str

wf_name = identifier_suffix_from_str(self.workflow_name).lower()
wf_name = docker_image_name_illegal_pat.sub("_", wf_name)

Expand Down Expand Up @@ -479,6 +493,9 @@ def image_tagged(self):
return f"{self.image}:{self.version}"

def task_image_name(self, task_name: str) -> str:
task_name = identifier_suffix_from_str(task_name).lower()
task_name = docker_image_name_illegal_pat.sub("_", task_name)

return f"{self.image}:{task_name}-{self.version}"

@property
Expand Down
29 changes: 10 additions & 19 deletions src/latch_cli/centromere/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import builtins
import contextlib
import functools
import os
import random
import string
Expand All @@ -12,10 +11,8 @@
from typing import Callable, Iterator, List, Optional, TypeVar

import docker
import docker.errors
import paramiko
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools import module_loader
from typing_extensions import ParamSpec

from latch_cli.constants import latch_constants
Expand All @@ -42,6 +39,10 @@ def _add_sys_paths(paths: List[Path]) -> Iterator[None]:


def _import_flyte_objects(paths: List[Path], module_name: str = "wf"):
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools import module_loader

with _add_sys_paths(paths):

class FakeModule(ModuleType):
Expand Down Expand Up @@ -76,20 +77,15 @@ def __new__(*args, **kwargs):
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
try:
return real_import(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level,
name, globals=globals, locals=locals, fromlist=fromlist, level=level
)
except (ModuleNotFoundError, AttributeError) as e:
except (ModuleNotFoundError, AttributeError):
return FakeModule(name)

# Temporary ctx tells lytekit to skip local execution when
# inspecting objects
fap = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="foo"),
raw_output_prefix="bar",
local_sandbox_dir=tempfile.mkdtemp(prefix="foo"), raw_output_prefix="bar"
)
tmp_context = FlyteContext(fap, inspect_objects_only=True)

Expand Down Expand Up @@ -201,9 +197,7 @@ def _construct_ssh_client(
raise ConnectionError("unable to create connection to jump host")

sock = gateway_transport.open_channel(
kind="direct-tcpip",
dest_addr=(remote_conn_info.ip, 22),
src_addr=("", 0),
kind="direct-tcpip", dest_addr=(remote_conn_info.ip, 22), src_addr=("", 0)
)
else:
sock = None
Expand All @@ -214,10 +208,7 @@ def _construct_ssh_client(
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
ssh.connect(
remote_conn_info.ip,
username=remote_conn_info.username,
sock=sock,
pkey=pkey,
remote_conn_info.ip, username=remote_conn_info.username, sock=sock, pkey=pkey
)

transport = ssh.get_transport()
Expand Down
Loading

0 comments on commit aa8a34f

Please sign in to comment.