Skip to content

Commit

Permalink
pyflyte run spark task (flyteorg#2280)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Mar 25, 2024
1 parent d9cea30 commit ecc7835
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 8 deletions.
12 changes: 8 additions & 4 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,18 @@ def _get_data_loading_config(self) -> _task_model.DataLoadingConfig:
io_strategy=self._io_strategy.value if self._io_strategy else None,
)

def _get_image(self, settings: SerializationSettings) -> str:
if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled:
if isinstance(self._image, ImageSpec):
# Set the source root for the image spec if it's non-fast registration
self._image.source_root = settings.source_root
return get_registerable_container_image(self._image, settings.image_config)

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = settings.env or {}
env = {**env, **self.environment} if self.environment else env
if isinstance(self._image, ImageSpec):
if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled:
self._image.source_root = settings.source_root
return _get_container_definition(
image=get_registerable_container_image(self._image, settings.image_config),
image=self._get_image(settings),
command=self._cmd,
args=self._args,
data_loading_config=self._get_data_loading_config(),
Expand Down
12 changes: 8 additions & 4 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]:
"""
return self._get_command_fn(settings)

def get_image(self, settings: SerializationSettings) -> str:
if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled:
if isinstance(self.container_image, ImageSpec):
# Set the source root for the image spec if it's non-fast registration
self.container_image.source_root = settings.source_root
return get_registerable_container_image(self.container_image, settings.image_config)

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container
if self.pod_template is not None:
Expand All @@ -187,11 +194,8 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain
for elem in (settings.env, self.environment):
if elem:
env.update(elem)
if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled:
if isinstance(self.container_image, ImageSpec):
self.container_image.source_root = settings.source_root
return _get_container_definition(
image=get_registerable_container_image(self.container_image, settings.image_config),
image=self.get_image(settings),
command=[],
args=self.get_command(settings=settings),
data_loading_config=None,
Expand Down
1 change: 1 addition & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ def register_script(
destination_dir=destination_dir,
distribution_location=upload_native_url,
),
source_root=source_path,
)

if version is None:
Expand Down
8 changes: 8 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
Expand Down Expand Up @@ -136,6 +137,13 @@ def __init__(
**kwargs,
)

def get_image(self, settings: SerializationSettings) -> str:
if isinstance(self.container_image, ImageSpec):
# Ensure that the code is always copied into the image, even during fast-registration.
self.container_image.source_root = settings.source_root

return get_registerable_container_image(self.container_image, settings.image_config)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
spark_conf=self.task_config.spark_conf,
Expand Down

0 comments on commit ecc7835

Please sign in to comment.