Skip to content

Commit

Permalink
Merge pull request #327 from latchbio/ayush/jit-touch
Browse files: browse the repository at this point in the history
  • Loading branch information
ayushkamat authored Oct 20, 2023
2 parents 2191676 + 8c1cb00 commit 0e445d0
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 29 deletions.
84 changes: 73 additions & 11 deletions latch/types/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.core.annotation import FlyteAnnotation
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer
from flytekit.exceptions.user import FlyteUserException
from flytekit.models.literals import Literal
from flytekit.types.directory.types import (
FlyteDirectory,
Expand All @@ -21,25 +22,41 @@
from latch_cli.utils.path import normalize_path


class Child(TypedDict):
class IterdirChild(TypedDict):
type: str
name: str


class ChildLdataTreeEdge(TypedDict):
child: Child
class IterdirChildLdataTreeEdge(TypedDict):
child: IterdirChild


class ChildLdataTreeEdges(TypedDict):
nodes: List[ChildLdataTreeEdge]
class IterdirChildLdataTreeEdges(TypedDict):
nodes: List[IterdirChildLdataTreeEdge]


class LDataResolvePathFinalLinkTarget(TypedDict):
childLdataTreeEdges: ChildLdataTreeEdges
class IterDirLDataResolvePathFinalLinkTarget(TypedDict):
childLdataTreeEdges: IterdirChildLdataTreeEdges


class LdataResolvePathData(TypedDict):
finalLinkTarget: LDataResolvePathFinalLinkTarget
class IterdirLdataResolvePathData(TypedDict):
finalLinkTarget: IterDirLDataResolvePathFinalLinkTarget


class NodeDescendantsNode(TypedDict):
    """One entry of ``descendants.nodes`` in the ``NodeDescendantsQuery`` response."""

    # Joined onto the local root directory by LatchDir._create_imposters;
    # presumably a path relative to the queried directory -- confirm against
    # the GraphQL schema.
    relPath: str


class NodeDescendantsDescendants(TypedDict):
    """The ``descendants`` object of the ``NodeDescendantsQuery`` response."""

    # Iterated by LatchDir._create_imposters, one placeholder per node.
    nodes: List[NodeDescendantsNode]


class NodeDescendantsFinalLinkTarget(TypedDict):
    """The ``finalLinkTarget`` object of the ``NodeDescendantsQuery`` response."""

    # Wrapper holding the list of descendant nodes.
    descendants: NodeDescendantsDescendants


class NodeDescendantsLDataResolvePathData(TypedDict):
    """The ``ldataResolvePathData`` payload of the ``NodeDescendantsQuery`` response.

    ``LatchDir._create_imposters`` treats the whole payload as optional and
    raises ``FlyteUserException`` when it is ``None``.
    """

    finalLinkTarget: NodeDescendantsFinalLinkTarget


class LatchDir(FlyteDirectory):
Expand Down Expand Up @@ -93,6 +110,8 @@ def __init__(
else:
self.path = str(path)

self._path_generated = False

if _is_valid_url(self.path) and remote_path is None:
self._remote_directory = self.path
else:
Expand All @@ -111,7 +130,8 @@ def downloader():
# todo(kenny) is this necessary?
and ctx.inspect_objects_only is False
):
self.path = ctx.file_access.get_random_local_directory()
self._idempotent_set_path()

return ctx.file_access.get_data(
self._remote_directory,
self.path,
Expand All @@ -120,6 +140,17 @@ def downloader():

super().__init__(self.path, downloader, self._remote_directory)

def _idempotent_set_path(self):
    """Point ``self.path`` at a random local directory, at most once.

    The first call made while a Flyte context is available asks the
    context's file access layer for a random local directory, stores it in
    ``self.path`` and sets ``self._path_generated``. Every later call is a
    no-op, so the downloader and ``_create_imposters`` agree on one
    location. Without a current context nothing is changed and the flag
    stays unset, so a later call will try again.
    """
    if not self._path_generated:
        current = FlyteContextManager.current_context()
        if current is not None:
            self.path = current.file_access.get_random_local_directory()
            self._path_generated = True

def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]:
ret: List[Union[LatchFile, "LatchDir"]] = []

Expand All @@ -132,7 +163,7 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]:

return ret

res: Optional[LdataResolvePathData] = execute(
res: Optional[IterdirLdataResolvePathData] = execute(
gql.gql("""
query LDataChildren($argPath: String!) {
ldataResolvePathData(argPath: $argPath) {
Expand Down Expand Up @@ -165,6 +196,37 @@ def iterdir(self) -> List[Union[LatchFile, "LatchDir"]]:

return ret

def _create_imposters(self):
    """Create empty local placeholder files for every remote descendant.

    Makes sure a local path has been chosen, queries the descendants of
    ``self._remote_directory`` and, for each returned ``relPath``, creates
    the parent directories and touches an empty file under ``self.path``.

    Raises:
        FlyteUserException: if the remote path does not resolve, i.e.
            ``ldataResolvePathData`` is ``None`` in the response.
    """
    self._idempotent_set_path()

    # The query text is reproduced exactly as it appears in this listing.
    document = gql.gql("""
query NodeDescendantsQuery($path: String!) {
ldataResolvePathData(argPath: $path) {
finalLinkTarget {
descendants {
nodes {
relPath
}
}
}
}
}
""")
    response = execute(document, {"path": self._remote_directory})
    data: Optional[NodeDescendantsLDataResolvePathData] = response[
        "ldataResolvePathData"
    ]

    if data is None:
        # todo(ayush): proper error message + exit
        raise FlyteUserException(f"No directory at {self._remote_directory}")

    local_root = Path(self.path)
    descendants = data["finalLinkTarget"]["descendants"]["nodes"]
    for node in descendants:
        placeholder = local_root / node["relPath"]

        # Intermediate directories must exist before the file can be touched.
        placeholder.parent.mkdir(parents=True, exist_ok=True)
        placeholder.touch(exist_ok=True)

@property
def local_path(self) -> str:
"""File path local to the environment executing the task."""
Expand Down
25 changes: 24 additions & 1 deletion latch/types/file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import re
from os import PathLike
from pathlib import Path
from typing import Optional, Type, Union
from urllib.parse import urlparse

Expand Down Expand Up @@ -73,6 +75,8 @@ def __init__(
else:
self.path = str(path)

self._path_generated = False

if _is_valid_url(self.path) and remote_path is None:
self._remote_path = str(path)
else:
Expand Down Expand Up @@ -107,7 +111,8 @@ def downloader():
if data is not None and data["name"] is not None:
local_path_hint = data["name"]

self.path = ctx.file_access.get_random_local_path(local_path_hint)
self._idempotent_set_path(local_path_hint)

return ctx.file_access.get_data(
self._remote_path,
self.path,
Expand All @@ -116,6 +121,24 @@ def downloader():

super().__init__(self.path, downloader, self._remote_path)

def _idempotent_set_path(self, hint: Optional[str] = None):
    """Point ``self.path`` at a random local file path, at most once.

    Args:
        hint: Optional value forwarded to ``get_random_local_path``. It is
            only used by the call that actually generates the path; once a
            path has been generated, later hints are ignored.

    Without a current Flyte context nothing is changed and
    ``self._path_generated`` stays unset, so a later call will try again.
    """
    if not self._path_generated:
        current = FlyteContextManager.current_context()
        if current is not None:
            self.path = current.file_access.get_random_local_path(hint)
            self._path_generated = True

def _create_imposters(self):
    """Create an empty local placeholder file at ``self.path``.

    Makes sure a local path has been chosen, creates any missing parent
    directories, then touches the file. An existing file is left in place.
    """
    self._idempotent_set_path()

    placeholder = Path(self.path)
    placeholder.parent.mkdir(parents=True, exist_ok=True)
    placeholder.touch(exist_ok=True)

@property
def local_path(self) -> str:
"""File path local to the environment executing the task."""
Expand Down
36 changes: 28 additions & 8 deletions latch_cli/snakemake/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import traceback
from pathlib import Path
from textwrap import dedent
from typing import List, Optional, Set, Union, get_args
from typing import Dict, List, Optional, Set, Union, get_args

import click
from flyteidl.admin.launch_plan_pb2 import LaunchPlan as _idl_admin_LaunchPlan
Expand Down Expand Up @@ -124,9 +124,7 @@ def extract_dag(self):
priorityfiles=set(),
)

self.persistence = Persistence(
dag=dag,
)
self._persistence = Persistence(dag=dag)

dag.init()
dag.update_checkpoint_dependencies()
Expand Down Expand Up @@ -191,12 +189,15 @@ def snakemake_workflow_extractor(


def extract_snakemake_workflow(
pkg_root: Path, snakefile: Path, version: Optional[str] = None
pkg_root: Path,
snakefile: Path,
version: Optional[str] = None,
local_to_remote_path_mapping: Optional[Dict[str, str]] = None,
) -> SnakemakeWorkflow:
extractor = snakemake_workflow_extractor(pkg_root, snakefile, version)
with extractor:
dag = extractor.extract_dag()
wf = SnakemakeWorkflow(dag, version)
wf = SnakemakeWorkflow(dag, version, local_to_remote_path_mapping)
wf.compile()

return wf
Expand Down Expand Up @@ -321,7 +322,7 @@ def generate_snakemake_entrypoint(
import shutil
import subprocess
from subprocess import CalledProcessError
from typing import NamedTuple
from typing import NamedTuple, Dict
import stat
import sys
Expand All @@ -332,9 +333,19 @@ def generate_snakemake_entrypoint(
from latch.types.directory import LatchDir
from latch.types.file import LatchFile
from latch_cli.utils import urljoins
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
def update_mapping(local: Path, remote: str, mapping: Dict[str, str]):
if local.is_file():
mapping[str(local)] = remote
return
for p in local.iterdir():
update_mapping(p, urljoins(remote, p.name), mapping)
def check_exists_and_rename(old: Path, new: Path):
if new.exists():
print(f"A file already exists at {new} and will be overwritten.")
Expand Down Expand Up @@ -390,7 +401,7 @@ def generate_jit_register_code(
from functools import partial
from pathlib import Path
import shutil
from typing import List, NamedTuple, Optional, TypedDict
from typing import List, NamedTuple, Optional, TypedDict, Dict
import hashlib
from urllib.parse import urljoin
Expand Down Expand Up @@ -419,6 +430,7 @@ def generate_jit_register_code(
serialize_snakemake,
)
import latch_cli.snakemake
from latch_cli.utils import urljoins
from latch import small_task
from latch_sdk_gql.execute import execute
Expand All @@ -428,6 +440,14 @@ def generate_jit_register_code(
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
def update_mapping(local: Path, remote: str, mapping: Dict[str, str]):
if local.is_file():
mapping[str(local)] = remote
return
for p in local.iterdir():
update_mapping(p, urljoins(remote, p.name), mapping)
def check_exists_and_rename(old: Path, new: Path):
if new.exists():
print(f"A file already exists at {new} and will be overwritten.")
Expand Down
Loading

0 comments on commit 0e445d0

Please sign in to comment.