diff --git a/latch/types/directory.py b/latch/types/directory.py index 0570a58c..71463e71 100644 --- a/latch/types/directory.py +++ b/latch/types/directory.py @@ -7,6 +7,7 @@ from flytekit.core.annotation import FlyteAnnotation from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.exceptions.user import FlyteUserException from flytekit.models.literals import Literal from flytekit.types.directory.types import ( FlyteDirectory, @@ -21,25 +22,41 @@ from latch_cli.utils.path import normalize_path -class Child(TypedDict): +class IterdirChild(TypedDict): type: str name: str -class ChildLdataTreeEdge(TypedDict): - child: Child +class IterdirChildLdataTreeEdge(TypedDict): + child: IterdirChild -class ChildLdataTreeEdges(TypedDict): - nodes: List[ChildLdataTreeEdge] +class IterdirChildLdataTreeEdges(TypedDict): + nodes: List[IterdirChildLdataTreeEdge] -class LDataResolvePathFinalLinkTarget(TypedDict): - childLdataTreeEdges: ChildLdataTreeEdges +class IterDirLDataResolvePathFinalLinkTarget(TypedDict): + childLdataTreeEdges: IterdirChildLdataTreeEdges -class LdataResolvePathData(TypedDict): - finalLinkTarget: LDataResolvePathFinalLinkTarget +class IterdirLdataResolvePathData(TypedDict): + finalLinkTarget: IterDirLDataResolvePathFinalLinkTarget + + +class NodeDescendantsNode(TypedDict): + relPath: str + + +class NodeDescendantsDescendants(TypedDict): + nodes: List[NodeDescendantsNode] + + +class NodeDescendantsFinalLinkTarget(TypedDict): + descendants: NodeDescendantsDescendants + + +class NodeDescendantsLDataResolvePathData(TypedDict): + finalLinkTarget: NodeDescendantsFinalLinkTarget class LatchDir(FlyteDirectory): @@ -93,6 +110,8 @@ def __init__( else: self.path = str(path) + self._path_generated = False + if _is_valid_url(self.path) and remote_path is None: self._remote_directory = self.path else: @@ -111,7 +130,8 @@ def downloader(): # todo(kenny) is this necessary? and ctx.inspect_objects_only is False ): - self.path = ctx.file_access.get_random_local_directory() + self._idempotent_set_path() + return ctx.file_access.get_data( self._remote_directory, self.path, @@ -120,6 +140,17 @@ def downloader(): super().__init__(self.path, downloader, self._remote_directory) + def _idempotent_set_path(self): + if self._path_generated: + return + + ctx = FlyteContextManager.current_context() + if ctx is None: + return + + self.path = ctx.file_access.get_random_local_directory() + self._path_generated = True + def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: ret: List[Union[LatchFile, "LatchDir"]] = [] @@ -132,7 +163,7 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: return ret - res: Optional[LdataResolvePathData] = execute( + res: Optional[IterdirLdataResolvePathData] = execute( gql.gql(""" query LDataChildren($argPath: String!) { ldataResolvePathData(argPath: $argPath) { @@ -165,6 +196,37 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]: return ret + def _create_imposters(self): + self._idempotent_set_path() + + res: Optional[NodeDescendantsLDataResolvePathData] = execute( + gql.gql(""" + query NodeDescendantsQuery($path: String!) { + ldataResolvePathData(argPath: $path) { + finalLinkTarget { + descendants { + nodes { + relPath + } + } + } + } + } + """), + {"path": self._remote_directory}, + )["ldataResolvePathData"] + + if res is None: + # todo(ayush): proper error message + exit + raise FlyteUserException(f"No directory at {self._remote_directory}") + + root = Path(self.path) + for x in res["finalLinkTarget"]["descendants"]["nodes"]: + p = root / x["relPath"] + + p.parent.mkdir(exist_ok=True, parents=True) + p.touch(exist_ok=True) + @property def local_path(self) -> str: """File path local to the environment executing the task.""" diff --git a/latch/types/file.py b/latch/types/file.py index 6af46837..2f37d62e 100644 --- a/latch/types/file.py +++ b/latch/types/file.py @@ -1,5 +1,7 @@ +import os import re from os import PathLike +from pathlib import Path from typing import Optional, Type, Union from urllib.parse import urlparse @@ -73,6 +75,8 @@ def __init__( else: self.path = str(path) + self._path_generated = False + if _is_valid_url(self.path) and remote_path is None: self._remote_path = str(path) else: @@ -107,7 +111,8 @@ def downloader(): if data is not None and data["name"] is not None: local_path_hint = data["name"] - self.path = ctx.file_access.get_random_local_path(local_path_hint) + self._idempotent_set_path(local_path_hint) + return ctx.file_access.get_data( self._remote_path, self.path, @@ -116,6 +121,24 @@ def downloader(): super().__init__(self.path, downloader, self._remote_path) + def _idempotent_set_path(self, hint: Optional[str] = None): + if self._path_generated: + return + + ctx = FlyteContextManager.current_context() + if ctx is None: + return + + self.path = ctx.file_access.get_random_local_path(hint) + self._path_generated = True + + def _create_imposters(self): + self._idempotent_set_path() + + p = Path(self.path) + p.parent.mkdir(exist_ok=True, parents=True) + p.touch(exist_ok=True) + @property def local_path(self) -> str: """File path local to the environment executing the task.""" diff --git a/latch_cli/snakemake/serialize.py b/latch_cli/snakemake/serialize.py index 382bddcd..e6e4816d 100644 --- a/latch_cli/snakemake/serialize.py +++ b/latch_cli/snakemake/serialize.py @@ -4,7 +4,7 @@ import traceback from pathlib import Path from textwrap import dedent -from typing import List, Optional, Set, Union, get_args +from typing import Dict, List, Optional, Set, Union, get_args import click from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan @@ -124,9 +124,7 @@ def extract_dag(self): priorityfiles=set(), ) - self.persistence = Persistence( - dag=dag, - ) + self._persistence = Persistence(dag=dag) dag.init() dag.update_checkpoint_dependencies() @@ -191,12 +189,15 @@ def snakemake_workflow_extractor( def extract_snakemake_workflow( - pkg_root: Path, snakefile: Path, version: Optional[str] = None + pkg_root: Path, + snakefile: Path, + version: Optional[str] = None, + local_to_remote_path_mapping: Optional[Dict[str, str]] = None, ) -> SnakemakeWorkflow: extractor = snakemake_workflow_extractor(pkg_root, snakefile, version) with extractor: dag = extractor.extract_dag() - wf = SnakemakeWorkflow(dag, version) + wf = SnakemakeWorkflow(dag, version, local_to_remote_path_mapping) wf.compile() return wf @@ -321,7 +322,7 @@ def generate_snakemake_entrypoint( import shutil import subprocess from subprocess import CalledProcessError - from typing import NamedTuple + from typing import NamedTuple, Dict import stat import sys @@ -332,9 +333,19 @@ def generate_snakemake_entrypoint( from latch.types.directory import LatchDir from latch.types.file import LatchFile + from latch_cli.utils import urljoins + sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) + def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): + if local.is_file(): + mapping[str(local)] = remote + return + + for p in local.iterdir(): + update_mapping(p, urljoins(remote, p.name), mapping) + def check_exists_and_rename(old: Path, new: Path): if new.exists(): print(f"A file already exists at {new} and will be overwritten.") @@ -390,7 +401,7 @@ def generate_jit_register_code( from functools import partial from pathlib import Path import shutil - from typing import List, NamedTuple, Optional, TypedDict + from typing import List, NamedTuple, Optional, TypedDict, Dict import hashlib from urllib.parse import urljoin @@ -419,6 +430,7 @@ def generate_jit_register_code( serialize_snakemake, ) import latch_cli.snakemake + from latch_cli.utils import urljoins from latch import small_task from latch_sdk_gql.execute import execute @@ -428,6 +440,14 @@ def generate_jit_register_code( sys.stdout.reconfigure(line_buffering=True) sys.stderr.reconfigure(line_buffering=True) + def update_mapping(local: Path, remote: str, mapping: Dict[str, str]): + if local.is_file(): + mapping[str(local)] = remote + return + + for p in local.iterdir(): + update_mapping(p, urljoins(remote, p.name), mapping) + def check_exists_and_rename(old: Path, new: Path): if new.exists(): print(f"A file already exists at {new} and will be overwritten.") diff --git a/latch_cli/snakemake/workflow.py b/latch_cli/snakemake/workflow.py index 102878a1..a122cf8a 100644 --- a/latch_cli/snakemake/workflow.py +++ b/latch_cli/snakemake/workflow.py @@ -124,9 +124,12 @@ class RemoteFile: def snakemake_dag_to_interface( - dag: DAG, wf_name: str, docstring: Optional[Docstring] = None + dag: DAG, + wf_name: str, + docstring: Optional[Docstring] = None, + local_to_remote_path_mapping: Optional[Dict[str, str]] = None, ) -> Tuple[Interface, LiteralMap, List[RemoteFile]]: - outputs: Dict[str, LatchFile] = {} + outputs: Dict[str, Union[Type[LatchFile], Type[LatchDir]]] = {} for target in dag.targetjobs: for desired in target.input: param = variable_name_for_value(desired, target.input) @@ -141,7 +144,7 @@ def snakemake_dag_to_interface( outputs[param] = LatchFile literals: Dict[str, Literal] = {} - inputs: Dict[str, Tuple[LatchFile, None]] = {} + inputs: Dict[str, Tuple[Type[LatchFile], None]] = {} return_files: List[RemoteFile] = [] for job in dag.jobs: dep_outputs = [] @@ -157,11 +160,31 @@ def snakemake_dag_to_interface( LatchFile, None, ) + + print(x) + + print(local_to_remote_path_mapping) + remote_path = ( Path("/.snakemake_latch") / "workflows" / wf_name / "inputs" / x ) - remote_url = f"latch://{remote_path}" - return_files.append(RemoteFile(local_path=x, remote_path=remote_url)) + use_original_remote_path: bool = ( + local_to_remote_path_mapping is not None + and x in local_to_remote_path_mapping + ) + + if use_original_remote_path: + remote_path = local_to_remote_path_mapping.get(x) + + remote_url = ( + urlparse(str(remote_path))._replace(scheme="latch").geturl() + ) + + if not use_original_remote_path: + return_files.append( + RemoteFile(local_path=x, remote_path=remote_url) + ) + literals[param] = Literal( scalar=Scalar( blob=Blob( @@ -347,6 +370,13 @@ def get_fn_code( code_block = "" code_block += self.get_fn_interface(fn_name=task_name) + code_block += reindent( + r""" + local_to_remote_path_mapping = {} + """, + 1, + ) + for param, t in self.python_interface.inputs.items(): if t in (LatchFile, LatchDir): param_meta = self.parameter_metadata[param] @@ -357,7 +387,8 @@ def get_fn_code( {param}_dst_p = Path("{param_meta.path}") print(f"Downloading {param}: {{{param}.remote_path}}") - {param}_p = Path({param}).resolve() + {param}._create_imposters() + {param}_p = Path({param}.path) print(f" {{file_name_and_size({param}_p)}}") """, @@ -382,6 +413,8 @@ def get_fn_code( {param}_dst_p ) + update_mapping({param}_dst_p, {param}.remote_path, local_to_remote_path_mapping) + """, 1, ) @@ -407,7 +440,7 @@ def get_fn_code( exec_id_hash.update(os.environ["FLYTE_INTERNAL_EXECUTION_ID"].encode("utf-8")) version = exec_id_hash.hexdigest()[:16] - wf = extract_snakemake_workflow(pkg_root, snakefile, version) + wf = extract_snakemake_workflow(pkg_root, snakefile, version, local_to_remote_path_mapping) wf_name = wf.name generate_snakemake_entrypoint(wf, pkg_root, snakefile, {repr(remote_output_url)}) @@ -582,7 +615,7 @@ class _WorkflowInfoNode(TypedDict): _interface_request = { "workflow_id": wf_id, "params": params, - "snakemake_jit": True, + # "snakemake_jit": True, } response = requests.post(urljoin(config.nucleus_url, "/api/create-execution"), headers=headers, json=_interface_request) @@ -599,6 +632,7 @@ def __init__( self, dag: DAG, version: Optional[str] = None, + local_to_remote_path_mapping: Optional[Dict[str, str]] = None, ): assert metadata._snakemake_metadata is not None name = metadata._snakemake_metadata.name @@ -606,7 +640,10 @@ def __init__( assert name is not None native_interface, literal_map, return_files = snakemake_dag_to_interface( - dag, name, None + dag, + name, + None, + local_to_remote_path_mapping, ) self.literal_map = literal_map self.return_files = return_files