Skip to content

Commit

Permalink
Raw Container Task Local Execution (#2258)
Browse files Browse the repository at this point in the history
* init

Signed-off-by: Future-Outlier <[email protected]>

* v1

Signed-off-by: Future-Outlier <[email protected]>

* fix arguments bug and add a log message when pulling the image

Signed-off-by: Future-Outlier <[email protected]>

* change v to k and handle boolean special case

Signed-off-by: Future-Outlier <[email protected]>

* support blob type and datetime

Signed-off-by: Future-Outlier <[email protected]>

* add unit tests

Signed-off-by: Future-Outlier <[email protected]>

* add exception

Signed-off-by: Future-Outlier <[email protected]>

* nit

Signed-off-by: Future-Outlier <[email protected]>

* fix test

Signed-off-by: Future-Outlier <[email protected]>

* update for flytefile and flytedirectory

Signed-off-by: Future-Outlier <[email protected]>

* support both file paths and template inputs

Signed-off-by: Future-Outlier <[email protected]>

* use sys.platform in pytest to handle the macOS and Windows cases, and support regex to parse the input

Signed-off-by: Future-Outlier <[email protected]>

* support datetime.timedelta

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* add tests and change boolean logic

Signed-off-by: Future-Outlier <[email protected]>

* support

Signed-off-by: Future-Outlier <[email protected]>

* change annotations

Signed-off-by: Future-Outlier <[email protected]>

* add flytefile and flytedir tests

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* add more tests

Signed-off-by: Future-Outlier <[email protected]>

* lint

Signed-off-by: Future-Outlier <[email protected]>

* change image name

Signed-off-by: Future-Outlier <[email protected]>

* Update pingsu's advice

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>

* add docker in dev-requirement

Signed-off-by: Future-Outlier <[email protected]>

* refactor execution

Signed-off-by: Future-Outlier <[email protected]>

* use render pattern

Signed-off-by: Future-Outlier <[email protected]>

* add back container task object in test

Signed-off-by: Future-Outlier <[email protected]>

* refactor output in container task execution

Signed-off-by: Future-Outlier <[email protected]>

* update pingsu's render input advice

Signed-off-by: Future-Outlier <[email protected]>

* update tests

Signed-off-by: Future-Outlier <[email protected]>

* add LiteralMap TypeHints

Signed-off-by: Future-Outlier <[email protected]>

* update dev-req

Signed-off-by: Future-Outlier <[email protected]>

---------

Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
  • Loading branch information
2 people authored and fiedlerNr9 committed Jul 25, 2024
1 parent 9cd1b13 commit edb6037
Show file tree
Hide file tree
Showing 7 changed files with 525 additions and 17 deletions.
165 changes: 162 additions & 3 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import typing
from enum import Enum
from typing import Any, Dict, List, Optional, OrderedDict, Type
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand All @@ -11,10 +12,13 @@
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret, SecurityContext

# Field name for the primary container's name. Its use is outside this view;
# presumably a key in the pod configuration -- TODO confirm against the pod-spec code.
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
# Message of the ImportError raised by ContainerTask.execute when `import docker` fails.
DOCKER_IMPORT_ERROR_MESSAGE = "Docker is not installed. Please install Docker by running `pip install docker`."


class ContainerTask(PythonTask):
Expand Down Expand Up @@ -82,6 +86,7 @@ def __init__(
self._args = arguments
self._input_data_dir = input_data_dir
self._output_data_dir = output_data_dir
self._outputs = outputs
self._md_format = metadata_format
self._io_strategy = io_strategy
self._resources = ResourceSpec(
Expand All @@ -93,8 +98,162 @@ def __init__(
def resources(self) -> ResourceSpec:
return self._resources

def local_execute(self, ctx: FlyteContext, **kwargs) -> Any:
    """
    Refuse local execution.

    This implementation never runs anything: it unconditionally raises ``RuntimeError``.
    """
    unsupported_message = "ContainerTask is not supported in local executions."
    raise RuntimeError(unsupported_message)
def _extract_command_key(self, cmd: str, **kwargs) -> Any:
"""
Extract the key from the command using regex.
"""
import re

input_regex = r"^\{\{\s*\.inputs\.(.*?)\s*\}\}$"
match = re.match(input_regex, cmd)
if match:
return match.group(1)
return None

def _render_command_and_volume_binding(self, cmd: str, **kwargs) -> Tuple[str, Dict[str, Dict[str, str]]]:
    """
    Render one command token and work out the volume binding it needs, if any.

    We support template-style references to inputs, e.g., "{{.inputs.infile}}".

    - A token that is not an input template is returned unchanged, with no binding.
    - A template whose input value is a ``FlyteFile`` / ``FlyteDirectory`` (or an
      instance of a subclass of either) is replaced by a path under
      ``self._input_data_dir``, and a read-write bind from the local path to that
      container path is returned.
    - Any other input value is rendered with ``str()``; an input that is absent
      from ``kwargs`` therefore renders as the string ``"None"``.

    :param cmd: a single element of the command/arguments list.
    :param kwargs: task inputs keyed by input name.
    :return: ``(rendered_token, volume_binding)`` where ``volume_binding`` maps a
        local path to ``{"bind": <container path>, "mode": "rw"}`` and is empty
        when no file or directory is involved.
    """
    from flytekit.types.directory import FlyteDirectory
    from flytekit.types.file import FlyteFile

    command = ""
    volume_binding = {}
    k = self._extract_command_key(cmd)

    if k:
        input_val = kwargs.get(k)
        # isinstance (rather than an exact type() comparison) so that instances of
        # FlyteFile / FlyteDirectory subclasses are mounted too instead of being
        # passed to the container as a bare host path string.
        if isinstance(input_val, (FlyteFile, FlyteDirectory)):
            local_flyte_file_or_dir_path = str(input_val)
            # Dots in the input name become path separators inside the container.
            remote_flyte_file_or_dir_path = os.path.join(self._input_data_dir, k.replace(".", "/"))  # type: ignore
            volume_binding[local_flyte_file_or_dir_path] = {
                "bind": remote_flyte_file_or_dir_path,
                "mode": "rw",
            }
            command = remote_flyte_file_or_dir_path
        else:
            command = str(input_val)
    else:
        command = cmd

    return command, volume_binding

def _prepare_command_and_volumes(
self, cmd_and_args: List[str], **kwargs
) -> Tuple[List[str], Dict[str, Dict[str, str]]]:
"""
Prepares the command and volume bindings for the container based on input arguments and command templates.
Parameters:
- cmd_and_args (List[str]): The command and arguments to prepare.
- **kwargs: Keyword arguments representing task inputs.
Returns:
- Tuple[List[str], Dict[str, Dict[str, str]]]: A tuple containing the prepared commands and volume bindings.
"""

commands = []
volume_bindings = {}

for cmd in cmd_and_args:
command, volume_binding = self._render_command_and_volume_binding(cmd, **kwargs)
commands.append(command)
volume_bindings.update(volume_binding)

return commands, volume_bindings

def _pull_image_if_not_exists(self, client, image: str):
    """
    Pull ``image`` through ``client`` unless the client already lists a matching image.

    Any exception raised while listing or pulling is logged and then re-raised.
    """
    try:
        existing_matches = client.images.list(filters={"reference": image})
        if existing_matches:
            return
        logger.info(f"Pulling image: {image} for container task: {self.name}")
        client.images.pull(image)
    except Exception as e:
        logger.error(f"Failed to pull image {image}: {str(e)}")
        raise

def _string_to_timedelta(self, s: str):
import datetime
import re

regex = r"(?:(\d+) days?, )?(?:(\d+):)?(\d+):(\d+)(?:\.(\d+))?"
parts = re.match(regex, s)
if not parts:
raise ValueError("Invalid timedelta string format")

days = int(parts.group(1)) if parts.group(1) else 0
hours = int(parts.group(2)) if parts.group(2) else 0
minutes = int(parts.group(3)) if parts.group(3) else 0
seconds = int(parts.group(4)) if parts.group(4) else 0
microseconds = int(parts.group(5)) if parts.group(5) else 0

return datetime.timedelta(
days=days,
hours=hours,
minutes=minutes,
seconds=seconds,
microseconds=microseconds,
)

def _convert_output_val_to_correct_type(self, output_val: Any, output_type: Any) -> Any:
import datetime

if output_type == bool:
return output_val.lower() != "false"
elif output_type == datetime.datetime:
return datetime.datetime.fromisoformat(output_val)
elif output_type == datetime.timedelta:
return self._string_to_timedelta(output_val)
else:
return output_type(output_val)

def _get_output_dict(self, output_directory: str) -> Dict[str, Any]:
    """
    Build the task's Python outputs from the files found in ``output_directory``.

    For every declared output ``k`` the value is taken from ``<output_directory>/<k>``:

    - if the declared type is ``FlyteFile`` / ``FlyteDirectory`` or a subclass of
      either, the type is instantiated with that path;
    - otherwise the file is read as text and converted with
      ``_convert_output_val_to_correct_type``.

    :param output_directory: local directory holding one entry per output name.
    :return: mapping of output name to value; empty when no outputs are declared.
    """
    from flytekit.types.directory import FlyteDirectory
    from flytekit.types.file import FlyteFile

    output_dict = {}
    if self._outputs:
        for k, output_type in self._outputs.items():
            output_path = os.path.join(output_directory, k)
            # issubclass (instead of list membership) so that subclasses of the two
            # blob types are returned as paths rather than having their file content
            # read and passed to the constructor. The isinstance guard keeps
            # non-class annotations (e.g. typing generics) on the text branch.
            is_file_or_dir = isinstance(output_type, type) and issubclass(output_type, (FlyteFile, FlyteDirectory))
            if is_file_or_dir:
                output_dict[k] = output_type(path=output_path)
            else:
                with open(output_path, "r") as f:
                    output_val = f.read()
                output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type)
    return output_dict

def execute(self, **kwargs) -> LiteralMap:
    """
    Run the task's image locally through Docker and return its outputs.

    Steps:

    1. Normalize ``self._input_data_dir`` / ``self._output_data_dir`` (the
       normalized values are stored back on the task).
    2. Render command and arguments, collecting bind mounts for file and directory
       inputs, and bind a fresh local directory to the output data dir.
    3. Pull the image if needed, start the container detached and wait for it.
    4. Read the declared outputs from the local output directory and convert them
       to a literal map.

    :param kwargs: task inputs keyed by input name.
    :return: the task outputs as a ``LiteralMap``.
    :raises ImportError: if the ``docker`` Python package cannot be imported.
    """
    try:
        import docker
    except ImportError as e:
        # Keep the original failure attached as the cause for easier debugging.
        raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE) from e

    from flytekit.core.type_engine import TypeEngine

    ctx = FlyteContext.current_context()

    # Normalize the input and output directories
    self._input_data_dir = os.path.normpath(self._input_data_dir) if self._input_data_dir else ""
    self._output_data_dir = os.path.normpath(self._output_data_dir) if self._output_data_dir else ""

    output_directory = ctx.file_access.get_random_local_directory()
    cmd_and_args = (self._cmd or []) + (self._args or [])
    commands, volume_bindings = self._prepare_command_and_volumes(cmd_and_args, **kwargs)
    volume_bindings[output_directory] = {"bind": self._output_data_dir, "mode": "rw"}

    client = docker.from_env()
    try:
        self._pull_image_if_not_exists(client, self._image)

        container = client.containers.run(
            self._image, command=commands, remove=True, volumes=volume_bindings, detach=True
        )
        # Wait for the container to finish the task
        # TODO: Add a 'timeout' parameter to control the max wait time for the container to finish the task.
        wait_result = container.wait()
    finally:
        # Release the client's connections whether or not the run succeeded.
        client.close()

    # Surface a failing container instead of ignoring its status; the outputs are
    # still read below, exactly as before.
    exit_code = wait_result.get("StatusCode") if isinstance(wait_result, dict) else None
    if exit_code:
        logger.error(f"Container for task {self.name} (image {self._image}) exited with status code {exit_code}")

    output_dict = self._get_output_dict(output_directory)
    outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
    return outputs_literal_map

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down
9 changes: 9 additions & 0 deletions tests/flytekit/unit/core/Dockerfile.raw_container
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Image definition stored next to the unit tests (Dockerfile.raw_container).
# NOTE(review): presumably the image the raw-container tests run against -- confirm the tag the tests use.
FROM python:3.9-alpine

# Working directory inside the image; the helper scripts below are copied into it.
WORKDIR /root

# Helper scripts made available inside the container.
COPY ./write_flytefile.py /root/write_flytefile.py
COPY ./write_flytedir.py /root/write_flytedir.py
COPY ./return_same_value.py /root/return_same_value.py

# Default command is a plain shell; callers are expected to supply their own command.
CMD ["/bin/sh"]
22 changes: 22 additions & 0 deletions tests/flytekit/unit/core/return_same_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import sys


def write_output(output_dir, output_file, v):
    """Write ``str(v)`` to ``<output_dir>/<output_file>``, creating ``output_dir`` when missing."""
    # exist_ok=True makes repeated calls for the same directory harmless.
    os.makedirs(output_dir, exist_ok=True)
    target_path = f"{output_dir}/{output_file}"
    with open(target_path, "w") as handle:
        handle.write(str(v))


def main(*args, output_dir):
    """Write each positional argument to its own file in ``output_dir``.

    Files are named with consecutive characters starting at 'a': the first
    argument goes to 'a', the second to 'b', and so on.
    """
    first_name_code = ord("a")
    for offset, value in enumerate(args):
        write_output(output_dir, chr(first_name_code + offset), value)


if __name__ == "__main__":
    # CLI contract: every argument but the last is an input value; the last one is
    # the directory the outputs are written to.
    cli_args = sys.argv[1:]
    *inputs, output_dir = cli_args
    main(*inputs, output_dir=output_dir)
87 changes: 73 additions & 14 deletions tests/flytekit/unit/core/test_container_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import sys
from collections import OrderedDict
from typing import Tuple

import pytest
from kubernetes.client.models import (
Expand All @@ -13,14 +16,83 @@
V1Toleration,
)

from flytekit import kwtypes
from flytekit import kwtypes, task, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.container_task import ContainerTask
from flytekit.core.pod_template import PodTemplate
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.tools.translator import get_serializable_task


@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"],
    reason="Skip if running on windows or macos due to CI Docker environment setup failure",
)
def test_local_execution():
    """Run a ContainerTask locally: first called directly, then from inside a workflow."""
    ellipse_area_task = ContainerTask(
        name="calculate_ellipse_area_python_template_style",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        inputs=kwtypes(a=float, b=float),
        outputs=kwtypes(area=float, metadata=str),
        image="ghcr.io/flyteorg/rawcontainers-python:v2",
        command=[
            "python",
            "calculate-ellipse-area.py",
            "{{.inputs.a}}",
            "{{.inputs.b}}",
            "/var/outputs",
        ],
    )

    # Direct call of the container task
    direct_area, direct_metadata = ellipse_area_task(a=3.0, b=4.0)
    assert isinstance(direct_area, float)
    assert isinstance(direct_metadata, str)

    # Workflow execution with container task
    @task
    def t1(a: float, b: float) -> Tuple[float, float]:
        return a + b, a * b

    @workflow
    def wf(a: float, b: float) -> Tuple[float, str]:
        a, b = t1(a=a, b=b)
        area, metadata = ellipse_area_task(a=a, b=b)
        return area, metadata

    wf_area, wf_metadata = wf(a=3.0, b=4.0)
    assert isinstance(wf_area, float)
    assert isinstance(wf_metadata, str)


@pytest.mark.skipif(
    sys.platform == "win32",
    reason="Skip if running on windows due to path error",
)
def test_local_execution_special_cases():
    """Check the string conversions and path normalization that local execution relies on."""
    from datetime import datetime, timedelta

    # Command given as a list, like the other ContainerTask constructions in this file.
    ct = ContainerTask(
        name="local-execution",
        image="test-image",
        command=["echo"],
    )

    # Boolean conversion from string checks: exercise the task's own conversion.
    # (bool("False") is truthy, so asserting on bool(s) would not test anything.)
    for raw in ["False", "false"]:
        assert ct._convert_output_val_to_correct_type(raw, bool) is False
    for raw in ["True", "true"]:
        assert ct._convert_output_val_to_correct_type(raw, bool) is True

    # Path normalization
    input_data_dir = "/var/inputs"
    assert os.path.normpath(input_data_dir) == "/var/inputs"
    assert os.path.normpath(input_data_dir + "/") == "/var/inputs"

    # Datetime and timedelta string conversions
    now = datetime.now()
    assert datetime.fromisoformat(str(now)) == now
    assert ct._convert_output_val_to_correct_type(str(now), datetime) == now
    td = timedelta(days=1, hours=1, minutes=1, seconds=1, microseconds=1)
    assert td == ct._string_to_timedelta(str(td))
    assert td == ct._convert_output_val_to_correct_type(str(td), timedelta)


def test_pod_template():
ps = V1PodSpec(
containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")]
Expand Down Expand Up @@ -86,19 +158,6 @@ def test_pod_template():
assert serialized_pod_spec["runtimeClassName"] == "nvidia"


def test_local_execution():
    """Calling the container task directly is expected to raise ``RuntimeError``."""
    unsupported_task = ContainerTask(
        name="name",
        input_data_dir="/var/inputs",
        output_data_dir="/var/outputs",
        image="inexistent-image:v42",
        command=["some", "command"],
    )

    with pytest.raises(RuntimeError):
        unsupported_task()


def test_raw_container_with_image_spec(mock_image_spec_builder):
ImageBuildEngine.register("test-raw-container", mock_image_spec_builder)
image_spec = ImageSpec(registry="flyte", base_image="r-base", builder="test-raw-container")
Expand Down
Loading

0 comments on commit edb6037

Please sign in to comment.