diff --git a/flytekit/core/node.py b/flytekit/core/node.py index f5a3db4afa..e8a37bf3f0 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -65,6 +65,7 @@ def __init__( self._outputs = None self._resources: typing.Optional[_resources_model] = None self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None + self._container_image: typing.Optional[str] = None def runs_before(self, other: Node): """ @@ -193,7 +194,7 @@ def with_overrides(self, *args, **kwargs): if "container_image" in kwargs: v = kwargs["container_image"] assert_not_promise(v, "container_image") - self.run_entity._container_image = v + self._container_image = v if "accelerator" in kwargs: v = kwargs["accelerator"] diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 62636d1420..aef5d3c46f 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -595,10 +595,14 @@ def from_flyte_idl(cls, pb2_object): class TaskNodeOverrides(_common.FlyteIdlEntity): def __init__( - self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources] + self, + resources: typing.Optional[Resources], + extended_resources: typing.Optional[tasks_pb2.ExtendedResources], + container_image: typing.Optional[str], ): self._resources = resources self._extended_resources = extended_resources + self._container_image = container_image @property def resources(self) -> Resources: @@ -608,19 +612,25 @@ def resources(self) -> Resources: def extended_resources(self) -> tasks_pb2.ExtendedResources: return self._extended_resources + @property + def container_image(self) -> str: + return self._container_image + def to_flyte_idl(self): return _core_workflow.TaskNodeOverrides( resources=self.resources.to_flyte_idl() if self.resources is not None else None, extended_resources=self.extended_resources, + container_image=self.container_image, ) @classmethod def from_flyte_idl(cls, pb2_object): resources = Resources.from_flyte_idl(pb2_object.resources) extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None + container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None if bool(resources.requests) or bool(resources.limits): - return cls(resources=resources, extended_resources=extended_resources) - return cls(resources=None, extended_resources=extended_resources) + return cls(resources=resources, extended_resources=extended_resources, container_image=container_image) + return cls(resources=None, extended_resources=extended_resources, container_image=container_image) class TaskNode(_common.FlyteIdlEntity): diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 7bc719cef8..2847ff1b3d 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -477,7 +477,11 @@ def get_serializable_node( output_aliases=[], task_node=workflow_model.TaskNode( reference_id=task_spec.template.id, - overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), + overrides=TaskNodeOverrides( + resources=entity._resources, + extended_resources=entity._extended_resources, + container_image=entity._container_image, + ), ), ) if entity._aliases: @@ -554,7 +558,11 @@ def get_serializable_node( output_aliases=[], task_node=workflow_model.TaskNode( reference_id=entity.flyte_entity.id, - overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources), + overrides=TaskNodeOverrides( + resources=entity._resources, + extended_resources=entity._extended_resources, + container_image=entity._container_image, + ), ), ) elif isinstance(entity.flyte_entity, FlyteWorkflow): @@ -603,7 +611,11 @@ def get_serializable_array_node( task_spec = get_serializable(entity_mapping, settings, entity, options) task_node = workflow_model.TaskNode( reference_id=task_spec.template.id, - overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources), + overrides=TaskNodeOverrides( + resources=node._resources, + extended_resources=node._extended_resources, + container_image=node._container_image, + ), ) node = workflow_model.Node( id=entity.name, diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 40bb864c4f..4bcecde6a7 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -305,4 +305,4 @@ def my_mappable_task(a: int) -> typing.Optional[str]: def wf(x: typing.List[int]): array_node_map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") - assert wf.nodes[0].run_entity.container_image == "random:image" + assert wf.nodes[0]._container_image == "random:image" diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index c87d4c6b1f..2ae716d4b7 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -352,7 +352,7 @@ def my_mappable_task(a: int) -> typing.Optional[str]: def wf(x: typing.List[int]): map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") - assert wf.nodes[0].flyte_entity.run_task.container_image == "random:image" + assert wf.nodes[0]._container_image == "random:image" def test_bounded_inputs_vars_order(serialization_settings): diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index df16ddd244..56eb82aa1d 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -465,7 +465,7 @@ def wf() -> str: bar().with_overrides(container_image="hello/world") return "hi" - assert wf.nodes[0].flyte_entity.container_image == "hello/world" + assert wf.nodes[0]._container_image == "hello/world" def test_override_accelerator():