diff --git a/flytekit/core/container_task.py b/flytekit/core/container_task.py index 1b078f83a7..7773226c1a 100644 --- a/flytekit/core/container_task.py +++ b/flytekit/core/container_task.py @@ -112,14 +112,18 @@ def _get_data_loading_config(self) -> _task_model.DataLoadingConfig: io_strategy=self._io_strategy.value if self._io_strategy else None, ) + def _get_image(self, settings: SerializationSettings) -> str: + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if isinstance(self._image, ImageSpec): + # Set the source root for the image spec if it's non-fast registration + self._image.source_root = settings.source_root + return get_registerable_container_image(self._image, settings.image_config) + def _get_container(self, settings: SerializationSettings) -> _task_model.Container: env = settings.env or {} env = {**env, **self.environment} if self.environment else env - if isinstance(self._image, ImageSpec): - if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: - self._image.source_root = settings.source_root return _get_container_definition( - image=get_registerable_container_image(self._image, settings.image_config), + image=self._get_image(settings), command=self._cmd, args=self._args, data_loading_config=self._get_data_loading_config(), diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index c43e3d4d14..7099456e5b 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -175,6 +175,13 @@ def get_command(self, settings: SerializationSettings) -> List[str]: """ return self._get_command_fn(settings) + def get_image(self, settings: SerializationSettings) -> str: + if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: + if isinstance(self.container_image, ImageSpec): + # Set the source root for the image spec if it's non-fast registration + self.container_image.source_root = settings.source_root + return get_registerable_container_image(self.container_image, settings.image_config) + def get_container(self, settings: SerializationSettings) -> _task_model.Container: # if pod_template is not None, return None here but in get_k8s_pod, return pod_template merged with container if self.pod_template is not None: @@ -187,11 +194,8 @@ def _get_container(self, settings: SerializationSettings) -> _task_model.Contain for elem in (settings.env, self.environment): if elem: env.update(elem) - if settings.fast_serialization_settings is None or not settings.fast_serialization_settings.enabled: - if isinstance(self.container_image, ImageSpec): - self.container_image.source_root = settings.source_root return _get_container_definition( - image=get_registerable_container_image(self.container_image, settings.image_config), + image=self.get_image(settings), command=[], args=self.get_command(settings=settings), data_loading_config=None, diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 4fd17fe40b..aad5adbd3f 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -986,6 +986,7 @@ def register_script( destination_dir=destination_dir, distribution_location=upload_native_url, ), + source_root=source_path, ) if version is None: diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 39a93afd06..079cf8815c 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -7,6 +7,7 @@ from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters +from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.extend import ExecutionState, TaskPlugins from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.image_spec import ImageSpec @@ -136,6 +137,13 @@ def __init__( **kwargs, ) + def get_image(self, settings: SerializationSettings) -> str: + if isinstance(self.container_image, ImageSpec): + # Ensure that the code is always copied into the image, even during fast-registration. + self.container_image.source_root = settings.source_root + + return get_registerable_container_image(self.container_image, settings.image_config) + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: job = SparkJob( spark_conf=self.task_config.spark_conf,