Skip to content

Commit

Permalink
jit touch
Browse files Browse the repository at this point in the history
Signed-off-by: Ayush Kamat <[email protected]>
  • Loading branch information
ayushkamat committed Oct 17, 2023
1 parent 29ababe commit 8c1cb00
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 13 deletions.
32 changes: 27 additions & 5 deletions latch_cli/snakemake/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import traceback
from pathlib import Path
from textwrap import dedent
from typing import List, Optional, Set, Union, get_args
from typing import Dict, List, Optional, Set, Union, get_args

import click
from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan
Expand Down Expand Up @@ -189,12 +189,15 @@ def snakemake_workflow_extractor(


def extract_snakemake_workflow(
pkg_root: Path, snakefile: Path, version: Optional[str] = None
pkg_root: Path,
snakefile: Path,
version: Optional[str] = None,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
) -> SnakemakeWorkflow:
extractor = snakemake_workflow_extractor(pkg_root, snakefile, version)
with extractor:
dag = extractor.extract_dag()
wf = SnakemakeWorkflow(dag, version)
wf = SnakemakeWorkflow(dag, version, local_to_remote_path_mapping)
wf.compile()

return wf
Expand Down Expand Up @@ -319,7 +322,7 @@ def generate_snakemake_entrypoint(
import shutil
import subprocess
from subprocess import CalledProcessError
from typing import NamedTuple
from typing import NamedTuple, Dict
import stat
import sys
Expand All @@ -330,9 +333,19 @@ def generate_snakemake_entrypoint(
from latch.types.directory import LatchDir
from latch.types.file import LatchFile
from latch_cli.utils import urljoins
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
def update_mapping(local: Path, remote: str, mapping: Dict[str, str]):
if local.is_file():
mapping[str(local)] = remote
return
for p in local.iterdir():
update_mapping(p, urljoins(remote, p.name), mapping)
def check_exists_and_rename(old: Path, new: Path):
if new.exists():
print(f"A file already exists at {new} and will be overwritten.")
Expand Down Expand Up @@ -388,7 +401,7 @@ def generate_jit_register_code(
from functools import partial
from pathlib import Path
import shutil
from typing import List, NamedTuple, Optional, TypedDict
from typing import List, NamedTuple, Optional, TypedDict, Dict
import hashlib
from urllib.parse import urljoin
Expand Down Expand Up @@ -417,6 +430,7 @@ def generate_jit_register_code(
serialize_snakemake,
)
import latch_cli.snakemake
from latch_cli.utils import urljoins
from latch import small_task
from latch_sdk_gql.execute import execute
Expand All @@ -426,6 +440,14 @@ def generate_jit_register_code(
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
def update_mapping(local: Path, remote: str, mapping: Dict[str, str]):
if local.is_file():
mapping[str(local)] = remote
return
for p in local.iterdir():
update_mapping(p, urljoins(remote, p.name), mapping)
def check_exists_and_rename(old: Path, new: Path):
if new.exists():
print(f"A file already exists at {new} and will be overwritten.")
Expand Down
52 changes: 44 additions & 8 deletions latch_cli/snakemake/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,12 @@ class RemoteFile:


def snakemake_dag_to_interface(
dag: DAG, wf_name: str, docstring: Optional[Docstring] = None
dag: DAG,
wf_name: str,
docstring: Optional[Docstring] = None,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
) -> Tuple[Interface, LiteralMap, List[RemoteFile]]:
outputs: Dict[str, LatchFile] = {}
outputs: Dict[str, Union[Type[LatchFile], Type[LatchDir]]] = {}
for target in dag.targetjobs:
for desired in target.input:
param = variable_name_for_value(desired, target.input)
Expand All @@ -141,7 +144,7 @@ def snakemake_dag_to_interface(
outputs[param] = LatchFile

literals: Dict[str, Literal] = {}
inputs: Dict[str, Tuple[LatchFile, None]] = {}
inputs: Dict[str, Tuple[Type[LatchFile], None]] = {}
return_files: List[RemoteFile] = []
for job in dag.jobs:
dep_outputs = []
Expand All @@ -157,11 +160,31 @@ def snakemake_dag_to_interface(
LatchFile,
None,
)

print(x)

print(local_to_remote_path_mapping)

remote_path = (
Path("/.snakemake_latch") / "workflows" / wf_name / "inputs" / x
)
remote_url = f"latch://{remote_path}"
return_files.append(RemoteFile(local_path=x, remote_path=remote_url))
use_original_remote_path: bool = (
local_to_remote_path_mapping is not None
and x in local_to_remote_path_mapping
)

if use_original_remote_path:
remote_path = local_to_remote_path_mapping.get(x)

remote_url = (
urlparse(str(remote_path))._replace(scheme="latch").geturl()
)

if not use_original_remote_path:
return_files.append(
RemoteFile(local_path=x, remote_path=remote_url)
)

literals[param] = Literal(
scalar=Scalar(
blob=Blob(
Expand Down Expand Up @@ -347,6 +370,13 @@ def get_fn_code(
code_block = ""
code_block += self.get_fn_interface(fn_name=task_name)

code_block += reindent(
r"""
local_to_remote_path_mapping = {}
""",
1,
)

for param, t in self.python_interface.inputs.items():
if t in (LatchFile, LatchDir):
param_meta = self.parameter_metadata[param]
Expand Down Expand Up @@ -383,6 +413,8 @@ def get_fn_code(
{param}_dst_p
)
update_mapping({param}_dst_p, {param}.remote_path, local_to_remote_path_mapping)
""",
1,
)
Expand All @@ -408,7 +440,7 @@ def get_fn_code(
exec_id_hash.update(os.environ["FLYTE_INTERNAL_EXECUTION_ID"].encode("utf-8"))
version = exec_id_hash.hexdigest()[:16]
wf = extract_snakemake_workflow(pkg_root, snakefile, version)
wf = extract_snakemake_workflow(pkg_root, snakefile, version, local_to_remote_path_mapping)
wf_name = wf.name
generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)})
Expand Down Expand Up @@ -583,7 +615,7 @@ class _WorkflowInfoNode(TypedDict):
_interface_request = {
"workflow_id": wf_id,
"params": params,
"snakemake_jit": True,
# "snakemake_jit": True,
}
response = requests.post(urljoin(config.nucleus_url, "/api/create-execution"), headers=headers, json=_interface_request)
Expand All @@ -600,14 +632,18 @@ def __init__(
self,
dag: DAG,
version: Optional[str] = None,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
):
assert metadata._snakemake_metadata is not None
name = metadata._snakemake_metadata.name

assert name is not None

native_interface, literal_map, return_files = snakemake_dag_to_interface(
dag, name, None
dag,
name,
None,
local_to_remote_path_mapping,
)
self.literal_map = literal_map
self.return_files = return_files
Expand Down

0 comments on commit 8c1cb00

Please sign in to comment.