Skip to content

Commit

Permalink
Fix: Handle overriding of container image in backend (flyteorg#2176)
Browse files Browse the repository at this point in the history
* Handle overriding of container image in backend

Signed-off-by: Fabio Grätz <[email protected]>

* Adapt tests

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
  • Loading branch information
fg91 and Fabio Grätz authored Mar 1, 2024
1 parent ea3c02d commit eb20459
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 10 deletions.
3 changes: 2 additions & 1 deletion flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self._outputs = None
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None
self._container_image: typing.Optional[str] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -193,7 +194,7 @@ def with_overrides(self, *args, **kwargs):
if "container_image" in kwargs:
v = kwargs["container_image"]
assert_not_promise(v, "container_image")
self.run_entity._container_image = v
self._container_image = v

if "accelerator" in kwargs:
v = kwargs["accelerator"]
Expand Down
16 changes: 13 additions & 3 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,14 @@ def from_flyte_idl(cls, pb2_object):

class TaskNodeOverrides(_common.FlyteIdlEntity):
def __init__(
self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources]
self,
resources: typing.Optional[Resources],
extended_resources: typing.Optional[tasks_pb2.ExtendedResources],
container_image: typing.Optional[str],
):
self._resources = resources
self._extended_resources = extended_resources
self._container_image = container_image

@property
def resources(self) -> Resources:
Expand All @@ -608,19 +612,25 @@ def resources(self) -> Resources:
def extended_resources(self) -> tasks_pb2.ExtendedResources:
return self._extended_resources

@property
def container_image(self) -> str:
return self._container_image

def to_flyte_idl(self):
return _core_workflow.TaskNodeOverrides(
resources=self.resources.to_flyte_idl() if self.resources is not None else None,
extended_resources=self.extended_resources,
container_image=self.container_image,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
resources = Resources.from_flyte_idl(pb2_object.resources)
extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None
container_image = pb2_object.container_image if len(pb2_object.container_image) > 0 else None
if bool(resources.requests) or bool(resources.limits):
return cls(resources=resources, extended_resources=extended_resources)
return cls(resources=None, extended_resources=extended_resources)
return cls(resources=resources, extended_resources=extended_resources, container_image=container_image)
return cls(resources=None, extended_resources=extended_resources, container_image=container_image)


class TaskNode(_common.FlyteIdlEntity):
Expand Down
18 changes: 15 additions & 3 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,11 @@ def get_serializable_node(
output_aliases=[],
task_node=workflow_model.TaskNode(
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
overrides=TaskNodeOverrides(
resources=entity._resources,
extended_resources=entity._extended_resources,
container_image=entity._container_image,
),
),
)
if entity._aliases:
Expand Down Expand Up @@ -554,7 +558,11 @@ def get_serializable_node(
output_aliases=[],
task_node=workflow_model.TaskNode(
reference_id=entity.flyte_entity.id,
overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
overrides=TaskNodeOverrides(
resources=entity._resources,
extended_resources=entity._extended_resources,
container_image=entity._container_image,
),
),
)
elif isinstance(entity.flyte_entity, FlyteWorkflow):
Expand Down Expand Up @@ -603,7 +611,11 @@ def get_serializable_array_node(
task_spec = get_serializable(entity_mapping, settings, entity, options)
task_node = workflow_model.TaskNode(
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources),
overrides=TaskNodeOverrides(
resources=node._resources,
extended_resources=node._extended_resources,
container_image=node._container_image,
),
)
node = workflow_model.Node(
id=entity.name,
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,4 @@ def my_mappable_task(a: int) -> typing.Optional[str]:
def wf(x: typing.List[int]):
array_node_map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

assert wf.nodes[0].run_entity.container_image == "random:image"
assert wf.nodes[0]._container_image == "random:image"
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def my_mappable_task(a: int) -> typing.Optional[str]:
def wf(x: typing.List[int]):
map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image")

assert wf.nodes[0].flyte_entity.run_task.container_image == "random:image"
assert wf.nodes[0]._container_image == "random:image"


def test_bounded_inputs_vars_order(serialization_settings):
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def wf() -> str:
bar().with_overrides(container_image="hello/world")
return "hi"

assert wf.nodes[0].flyte_entity.container_image == "hello/world"
assert wf.nodes[0]._container_image == "hello/world"


def test_override_accelerator():
Expand Down

0 comments on commit eb20459

Please sign in to comment.