From 95ea976d3028ee4126eb3f6b46a01ea8b3abd615 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Wed, 4 Oct 2023 17:51:38 -0700 Subject: [PATCH 1/5] save state Signed-off-by: Ayush Kamat --- latch/types/directory.py | 80 ++++++++++++++++++++++++++++++++-------- latch/types/file.py | 10 +++++ 2 files changed, 75 insertions(+), 15 deletions(-) diff --git a/latch/types/directory.py b/latch/types/directory.py index 0570a58c..45a8e6aa 100644 --- a/latch/types/directory.py +++ b/latch/types/directory.py @@ -7,6 +7,7 @@ from flytekit.core.annotation import FlyteAnnotation from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.exceptions.user import FlyteUserException from flytekit.models.literals import Literal from flytekit.types.directory.types import ( FlyteDirectory, @@ -21,25 +22,41 @@ from latch_cli.utils.path import normalize_path -class Child(TypedDict): +class IterdirChild(TypedDict): type: str name: str -class ChildLdataTreeEdge(TypedDict): - child: Child +class IterdirChildLdataTreeEdge(TypedDict): + child: IterdirChild -class ChildLdataTreeEdges(TypedDict): - nodes: List[ChildLdataTreeEdge] +class IterdirChildLdataTreeEdges(TypedDict): + nodes: List[IterdirChildLdataTreeEdge] -class LDataResolvePathFinalLinkTarget(TypedDict): - childLdataTreeEdges: ChildLdataTreeEdges +class IterDirLDataResolvePathFinalLinkTarget(TypedDict): + childLdataTreeEdges: IterdirChildLdataTreeEdges -class LdataResolvePathData(TypedDict): - finalLinkTarget: LDataResolvePathFinalLinkTarget +class IterdirLdataResolvePathData(TypedDict): + finalLinkTarget: IterDirLDataResolvePathFinalLinkTarget + + +class NodeDescendantsNode(TypedDict): + relPath: str + + +class NodeDescendantsDescendants(TypedDict): + nodes: List[NodeDescendantsNode] + + +class NodeDescendantsFinalLinkTarget(TypedDict): + descendants: NodeDescendantsDescendants + + +class NodeDescendantsLDataResolvePathData(TypedDict): + finalLinkTarget: NodeDescendantsFinalLinkTarget class LatchDir(FlyteDirectory): @@ -78,6 +95,8 @@ def __init__( self, path: Union[str, PathLike], remote_path: Optional[PathLike] = None, + *, + do_download: bool = True, **kwargs, ): if path is None: @@ -112,11 +131,42 @@ def downloader(): and ctx.inspect_objects_only is False ): self.path = ctx.file_access.get_random_local_directory() - return ctx.file_access.get_data( - self._remote_directory, - self.path, - is_multipart=True, - ) + + if do_download: + return ctx.file_access.get_data( + self._remote_directory, + self.path, + is_multipart=True, + ) + + res: Optional[NodeDescendantsLDataResolvePathData] = execute( + gql.gql(""" + query NodeDescendantsQuery($path: String!) { + ldataResolvePathData(argPath: $path) { + finalLinkTarget { + descendants { + nodes { + relPath + } + } + } + } + } + """), + {"path": self._remote_directory}, + )["ldataResolvePathData"] + + if res is None: + # todo(ayush): proper error message + exit + raise FlyteUserException( + f"No directory at {self._remote_directory}" + ) + + for x in res["finalLinkTarget"]["descendants"]["nodes"]: + p = Path(self.path) / x["relPath"] + + p.parent.mkdir(exist_ok=True, parents=True) + p.touch(exist_ok=True) super().__init__(self.path, downloader, self._remote_directory) @@ -132,7 +182,7 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: return ret - res: Optional[LdataResolvePathData] = execute( + res: Optional[IterdirLdataResolvePathData] = execute( gql.gql(""" query LDataChildren($argPath: String!) { ldataResolvePathData(argPath: $argPath) { diff --git a/latch/types/file.py b/latch/types/file.py index 6af46837..061fbd25 100644 --- a/latch/types/file.py +++ b/latch/types/file.py @@ -1,5 +1,7 @@ +import os import re from os import PathLike +from pathlib import Path from typing import Optional, Type, Union from urllib.parse import urlparse @@ -58,6 +60,8 @@ def __init__( self, path: Union[str, PathLike], remote_path: Optional[Union[str, PathLike]] = None, + *, + do_download: bool = True, **kwargs, ): if path is None: @@ -108,6 +112,12 @@ def downloader(): local_path_hint = data["name"] self.path = ctx.file_access.get_random_local_path(local_path_hint) + + if not do_download: + Path(self.path).parent.mkdir(parents=True, exist_ok=True) + Path(self.path).touch(exist_ok=True) + return + return ctx.file_access.get_data( self._remote_path, self.path, From 19ae7ad2a22654279ff2ebe74341d9012a8fbc67 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Sat, 7 Oct 2023 13:07:39 -0700 Subject: [PATCH 2/5] init Signed-off-by: Ayush Kamat --- latch/types/directory.py | 85 ++++++++++++++++++-------------- latch/types/file.py | 29 ++++++++--- latch_cli/snakemake/serialize.py | 4 +- latch_cli/snakemake/workflow.py | 3 +- 4 files changed, 72 insertions(+), 49 deletions(-) diff --git a/latch/types/directory.py b/latch/types/directory.py index 45a8e6aa..c716e8dd 100644 --- a/latch/types/directory.py +++ b/latch/types/directory.py @@ -95,8 +95,6 @@ def __init__( self, path: Union[str, PathLike], remote_path: Optional[PathLike] = None, - *, - do_download: bool = True, **kwargs, ): if path is None: @@ -112,6 +110,8 @@ def __init__( else: self.path = str(path) + self._path_witness = False + if _is_valid_url(self.path) and remote_path is None: self._remote_directory = self.path else: @@ -130,45 +130,26 @@ def downloader(): # todo(kenny) is this necessary? and ctx.inspect_objects_only is False ): - self.path = ctx.file_access.get_random_local_directory() - - if do_download: - return ctx.file_access.get_data( - self._remote_directory, - self.path, - is_multipart=True, - ) - - res: Optional[NodeDescendantsLDataResolvePathData] = execute( - gql.gql(""" - query NodeDescendantsQuery($path: String!) { - ldataResolvePathData(argPath: $path) { - finalLinkTarget { - descendants { - nodes { - relPath - } - } - } - } - } - """), - {"path": self._remote_directory}, - )["ldataResolvePathData"] + self._idempotent_set_path() - if res is None: - # todo(ayush): proper error message + exit - raise FlyteUserException( - f"No directory at {self._remote_directory}" - ) + return ctx.file_access.get_data( + self._remote_directory, + self.path, + is_multipart=True, + ) - for x in res["finalLinkTarget"]["descendants"]["nodes"]: - p = Path(self.path) / x["relPath"] + super().__init__(self.path, downloader, self._remote_directory) - p.parent.mkdir(exist_ok=True, parents=True) - p.touch(exist_ok=True) + def _idempotent_set_path(self): + if self._path_witness: + return - super().__init__(self.path, downloader, self._remote_directory) + ctx = FlyteContextManager.current_context() + if ctx is None: + return + + self.path = ctx.file_access.get_random_local_directory() + self._path_witness = True def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: ret: List[Union[LatchFile, "LatchDir"]] = [] @@ -215,6 +196,36 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: return ret + def touch(self): # fixme(ayush): better name + self._idempotent_set_path() + + res: Optional[NodeDescendantsLDataResolvePathData] = execute( + gql.gql(""" + query NodeDescendantsQuery($path: String!) { + ldataResolvePathData(argPath: $path) { + finalLinkTarget { + descendants { + nodes { + relPath + } + } + } + } + } + """), + {"path": self._remote_directory}, + )["ldataResolvePathData"] + + if res is None: + # todo(ayush): proper error message + exit + raise FlyteUserException(f"No directory at {self._remote_directory}") + + for x in res["finalLinkTarget"]["descendants"]["nodes"]: + p = Path(self.path) / x["relPath"] + + p.parent.mkdir(exist_ok=True, parents=True) + p.touch(exist_ok=True) + @property def local_path(self) -> str: """File path local to the environment executing the task.""" diff --git a/latch/types/file.py b/latch/types/file.py index 061fbd25..5888ca3f 100644 --- a/latch/types/file.py +++ b/latch/types/file.py @@ -60,8 +60,6 @@ def __init__( self, path: Union[str, PathLike], remote_path: Optional[Union[str, PathLike]] = None, - *, - do_download: bool = True, **kwargs, ): if path is None: @@ -77,6 +75,8 @@ def __init__( else: self.path = str(path) + self._path_witness = False + if _is_valid_url(self.path) and remote_path is None: self._remote_path = str(path) else: @@ -111,12 +111,7 @@ def downloader(): if data is not None and data["name"] is not None: local_path_hint = data["name"] - self.path = ctx.file_access.get_random_local_path(local_path_hint) - - if not do_download: - Path(self.path).parent.mkdir(parents=True, exist_ok=True) - Path(self.path).touch(exist_ok=True) - return + self._idempotent_set_path(local_path_hint) return ctx.file_access.get_data( self._remote_path, @@ -126,6 +121,24 @@ def downloader(): super().__init__(self.path, downloader, self._remote_path) + def _idempotent_set_path(self, hint: Optional[str] = None): + if self._path_witness: + return + + ctx = FlyteContextManager.current_context() + if ctx is None: + return + + self.path = ctx.file_access.get_random_local_path(hint) + self._path_witness = True + + def touch(self): + self._idempotent_set_path() + + p = Path(self.path) + p.parent.mkdir(exist_ok=True, parents=True) + p.touch(exist_ok=True) + @property def local_path(self) -> str: """File path local to the environment executing the task.""" diff --git a/latch_cli/snakemake/serialize.py b/latch_cli/snakemake/serialize.py index 382bddcd..6b7ed075 100644 --- a/latch_cli/snakemake/serialize.py +++ b/latch_cli/snakemake/serialize.py @@ -124,9 +124,7 @@ def extract_dag(self): priorityfiles=set(), ) - self.persistence = Persistence( - dag=dag, - ) + self._persistence = Persistence(dag=dag) dag.init() dag.update_checkpoint_dependencies() diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 102878a1..8241f5a7 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -357,7 +357,8 @@ def get_fn_code( {param}_dst_p = Path("{param_meta.path}") print(f"Downloading {param}: {{{param}.remote_path}}") - {param}_p = Path({param}).resolve() + {param}.touch() + {param}_p = Path({param}.path) print(f" {{file_name_and_size({param}_p)}}") """, From 4c5294beee34fc64fdaa2ed3e73672529ec366e0 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Fri, 13 Oct 2023 13:55:20 -0700 Subject: [PATCH 3/5] rename things Signed-off-by: Ayush Kamat --- latch/types/directory.py | 8 ++++---- latch/types/file.py | 8 ++++---- latch_cli/snakemake/workflow.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/latch/types/directory.py b/latch/types/directory.py index c716e8dd..90823af7 100644 --- a/latch/types/directory.py +++ b/latch/types/directory.py @@ -110,7 +110,7 @@ def __init__( else: self.path = str(path) - self._path_witness = False + self._path_generated = False if _is_valid_url(self.path) and remote_path is None: self._remote_directory = self.path @@ -141,7 +141,7 @@ def downloader(): super().__init__(self.path, downloader, self._remote_directory) def _idempotent_set_path(self): - if self._path_witness: + if self._path_generated: return ctx = FlyteContextManager.current_context() @@ -149,7 +149,7 @@ def _idempotent_set_path(self): return self.path = ctx.file_access.get_random_local_directory() - self._path_witness = True + self._path_generated = True def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: ret: List[Union[LatchFile, "LatchDir"]] = [] @@ -196,7 +196,7 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: return ret - def touch(self): # fixme(ayush): better name + def _create_imposters(self): self._idempotent_set_path() res: Optional[NodeDescendantsLDataResolvePathData] = execute( diff --git a/latch/types/file.py b/latch/types/file.py index 5888ca3f..2f37d62e 100644 --- a/latch/types/file.py +++ b/latch/types/file.py @@ -75,7 +75,7 @@ def __init__( else: self.path = str(path) - self._path_witness = False + self._path_generated = False if _is_valid_url(self.path) and remote_path is None: self._remote_path = str(path) @@ -122,7 +122,7 @@ def downloader(): super().__init__(self.path, downloader, self._remote_path) def _idempotent_set_path(self, hint: Optional[str] = None): - if self._path_witness: + if self._path_generated: return ctx = FlyteContextManager.current_context() @@ -130,9 +130,9 @@ def _idempotent_set_path(self, hint: Optional[str] = None): return self.path = ctx.file_access.get_random_local_path(hint) - self._path_witness = True + self._path_generated = True - def touch(self): + def _create_imposters(self): self._idempotent_set_path() p = Path(self.path) diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 8241f5a7..2ac0c3ae 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -357,7 +357,7 @@ def get_fn_code( {param}_dst_p = Path("{param_meta.path}") print(f"Downloading {param}: {{{param}.remote_path}}") - {param}.touch() + {param}._create_imposters() {param}_p = Path({param}.path) print(f" {{file_name_and_size({param}_p)}}") From 29ababe54075d4e2412c22637c1b0fd0f55f1b34 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Fri, 13 Oct 2023 13:57:38 -0700 Subject: [PATCH 4/5] only initialize path once Signed-off-by: Ayush Kamat --- latch/types/directory.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/latch/types/directory.py b/latch/types/directory.py index 90823af7..71463e71 100644 --- a/latch/types/directory.py +++ b/latch/types/directory.py @@ -220,8 +220,9 @@ def _create_imposters(self): # todo(ayush): proper error message + exit raise FlyteUserException(f"No directory at {self._remote_directory}") + root = Path(self.path) for x in res["finalLinkTarget"]["descendants"]["nodes"]: - p = Path(self.path) / x["relPath"] + p = root / x["relPath"] p.parent.mkdir(exist_ok=True, parents=True) p.touch(exist_ok=True) From 8c1cb0041109e28f55d4911a692515aa58d6ce68 Mon Sep 17 00:00:00 2001 From: Ayush Kamat Date: Tue, 17 Oct 2023 14:27:39 -0700 Subject: [PATCH 5/5] jit touch Signed-off-by: Ayush Kamat --- latch_cli/snakemake/serialize.py | 32 +++++++++++++++++--- latch_cli/snakemake/workflow.py | 52 +++++++++++++++++++++++++++----- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/latch_cli/snakemake/serialize.py b/latch_cli/snakemake/serialize.py index 6b7ed075..e6e4816d 100644 --- a/latch_cli/snakemake/serialize.py +++ b/latch_cli/snakemake/serialize.py @@ -4,7 +4,7 @@ import traceback from pathlib import Path from textwrap import dedent -from typing import List, Optional, Set, Union, get_args +from typing import Dict, List, Optional, Set, Union, get_args import click from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan @@ -189,12 +189,15 @@ def snakemake_workflow_extractor( def extract_snakemake_workflow( - pkg_root: Path, snakefile: Path, version: Optional[str] = None + pkg_root: Path, + snakefile: Path, + version: Optional[str] = None, + local_to_remote_path_mapping: Optional[Dict[str, str]] = None, ) -> SnakemakeWorkflow: extractor = snakemake_workflow_extractor(pkg_root, snakefile, version) with extractor: dag = extractor.extract_dag() - wf = SnakemakeWorkflow(dag, version) + wf = SnakemakeWorkflow(dag, version, local_to_remote_path_mapping) wf.compile() return wf @@ -319,7 +322,7 @@ def generate_snakemake_entrypoint( import shutil import subprocess from subprocess import CalledProcessError - from typing import NamedTuple + from typing import NamedTuple, Dict import stat import sys @@ -330,9 +333,19 @@ def generate_snakemake_entrypoint( from latch.types.directory import LatchDir from latch.types.file import LatchFile + from latch_cli.utils import urljoins + sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) + def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): + if local.is_file(): + mapping[str(local)] = remote + return + + for p in local.iterdir(): + update_mapping(p, urljoins(remote, p.name), mapping) + def check_exists_and_rename(old: Path, new: Path): if new.exists(): print(f"A file already exists at {new} and will be overwritten.") @@ -388,7 +401,7 @@ def generate_jit_register_code( from functools import partial from pathlib import Path import shutil - from typing import List, NamedTuple, Optional, TypedDict + from typing import List, NamedTuple, Optional, TypedDict, Dict import hashlib from urllib.parse import urljoin @@ -417,6 +430,7 @@ def generate_jit_register_code( serialize_snakemake, ) import latch_cli.snakemake + from latch_cli.utils import urljoins from latch import small_task from latch_sdk_gql.execute import execute @@ -426,6 +440,14 @@ def generate_jit_register_code( sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) + def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): + if local.is_file(): + mapping[str(local)] = remote + return + + for p in local.iterdir(): + update_mapping(p, urljoins(remote, p.name), mapping) + def check_exists_and_rename(old: Path, new: Path): if new.exists(): print(f"A file already exists at {new} and will be overwritten.") diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 2ac0c3ae..a122cf8a 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -124,9 +124,12 @@ class RemoteFile: def snakemake_dag_to_interface( - dag: DAG, wf_name: str, docstring: Optional[Docstring] = None + dag: DAG, + wf_name: str, + docstring: Optional[Docstring] = None, + local_to_remote_path_mapping: Optional[Dict[str, str]] = None, ) -> Tuple[Interface, LiteralMap, List[RemoteFile]]: - outputs: Dict[str, LatchFile] = {} + outputs: Dict[str, Union[Type[LatchFile], Type[LatchDir]]] = {} for target in dag.targetjobs: for desired in target.input: param = variable_name_for_value(desired, target.input) @@ -141,7 +144,7 @@ def snakemake_dag_to_interface( outputs[param] = LatchFile literals: Dict[str, Literal] = {} - inputs: Dict[str, Tuple[LatchFile, None]] = {} + inputs: Dict[str, Tuple[Type[LatchFile], None]] = {} return_files: List[RemoteFile] = [] for job in dag.jobs: dep_outputs = [] @@ -157,11 +160,31 @@ def snakemake_dag_to_interface( LatchFile, None, ) + + print(x) + + print(local_to_remote_path_mapping) + remote_path = ( Path("/.snakemake_latch") / "workflows" / wf_name / "inputs" / x ) - remote_url = f"latch://{remote_path}" - return_files.append(RemoteFile(local_path=x, remote_path=remote_url)) + use_original_remote_path: bool = ( + local_to_remote_path_mapping is not None + and x in local_to_remote_path_mapping + ) + + if use_original_remote_path: + remote_path = local_to_remote_path_mapping.get(x) + + remote_url = ( + urlparse(str(remote_path))._replace(scheme="latch").geturl() + ) + + if not use_original_remote_path: + return_files.append( + RemoteFile(local_path=x, remote_path=remote_url) + ) + literals[param] = Literal( scalar=Scalar( blob=Blob( @@ -347,6 +370,13 @@ def get_fn_code( code_block = "" code_block += self.get_fn_interface(fn_name=task_name) + code_block += reindent( + r""" + local_to_remote_path_mapping = {} + """, + 1, + ) + for param, t in self.python_interface.inputs.items(): if t in (LatchFile, LatchDir): param_meta = self.parameter_metadata[param] @@ -383,6 +413,8 @@ def get_fn_code( {param}_dst_p ) + update_mapping({param}_dst_p, {param}.remote_path, local_to_remote_path_mapping) + """, 1, ) @@ -408,7 +440,7 @@ def get_fn_code( exec_id_hash.update(os.environ["FLYTE_INTERNAL_EXECUTION_ID"].encode("utf-8")) version = exec_id_hash.hexdigest()[:16] - wf = extract_snakemake_workflow(pkg_root, snakefile, version) + wf = extract_snakemake_workflow(pkg_root, snakefile, version, local_to_remote_path_mapping) wf_name = wf.name generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}) @@ -583,7 +615,7 @@ class _WorkflowInfoNode(TypedDict): _interface_request = { "workflow_id": wf_id, "params": params, - "snakemake_jit": True, + # "snakemake_jit": True, } response = requests.post(urljoin(config.nucleus_url, "/api/create-execution"), headers=headers, json=_interface_request) @@ -600,6 +632,7 @@ def __init__( self, dag: DAG, version: Optional[str] = None, + local_to_remote_path_mapping: Optional[Dict[str, str]] = None, ): assert metadata._snakemake_metadata is not None name = metadata._snakemake_metadata.name @@ -607,7 +640,10 @@ def __init__( assert name is not None native_interface, literal_map, return_files = snakemake_dag_to_interface( - dag, name, None + dag, + name, + None, + local_to_remote_path_mapping, ) self.literal_map = literal_map self.return_files = return_files