diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 904ba51f42..55445edaa2 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -16,7 +16,6 @@ from flytekit.core.utils import timeit from flytekit.exceptions import scopes as exception_scopes from flytekit.loggers import logger -from flytekit.models.array_job import ArrayJob from flytekit.models.core.workflow import NodeMetadata from flytekit.models.interface import Variable from flytekit.models.task import Container, K8sPod, Sql, Task @@ -146,8 +145,8 @@ def prepare_target(self): finally: self.python_function_task.reset_command_fn() - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - return ArrayJob(parallelism=self._concurrency, min_success_ratio=self._min_success_ratio).to_dict() + def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]]: + return self.python_function_task.get_custom(settings) def get_container(self, settings: SerializationSettings) -> Container: with self.prepare_target(): diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index cd30ef1143..86c1b7470f 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -64,7 +64,6 @@ def t1(a: int) -> int: task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) assert task_spec.template.metadata.retries.retries == 2 - assert task_spec.template.custom["minSuccessRatio"] == 1.0 assert task_spec.template.type == "python-task" assert task_spec.template.task_type_version == 1 assert task_spec.template.container.args == [ @@ -94,6 +93,31 @@ def t1(a: int) -> int: ] +def test_serialize_plugin_custom(serialization_settings): + from flytekitplugins.kftensorflow import TfJob, Worker, Chief, PS, Evaluator + + task_config = TfJob( + worker=Worker(replicas=1), + chief=Chief(replicas=1), + ps=PS(replicas=1), + evaluator=Evaluator(replicas=1), + ) + + @task(task_config=task_config) + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2)) + task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) + + assert task_spec.template.custom == { + "chiefReplicas": {"replicas": 1, "resources": {}}, + "evaluatorReplicas": {"replicas": 1, "resources": {}}, + "psReplicas": {"replicas": 1, "resources": {}}, + "workerReplicas": {"replicas": 1, "resources": {}}, + } + + def test_fast_serialization(serialization_settings): serialization_settings.fast_serialization_settings = FastSerializationSettings(enabled=True)