diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 61ae41c060..1c16539b92 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -47,6 +47,8 @@ class Node(object): ID, which from the registration step """ + TIMEOUT_OVERRIDE_SENTINEL = object() + def __init__( self, id: str, @@ -127,7 +129,7 @@ def metadata(self) -> _workflow_model.NodeMetadata: def _override_node_metadata( self, name, - timeout: Optional[Union[int, datetime.timedelta]] = None, + timeout: Optional[Union[int, datetime.timedelta, object]] = TIMEOUT_OVERRIDE_SENTINEL, retries: Optional[int] = None, interruptible: typing.Optional[bool] = None, cache: typing.Optional[bool] = None, @@ -142,14 +144,16 @@ def _override_node_metadata( else: node_metadata = self._metadata - if timeout is None: - node_metadata._timeout = datetime.timedelta() - elif isinstance(timeout, int): - node_metadata._timeout = datetime.timedelta(seconds=timeout) - elif isinstance(timeout, datetime.timedelta): - node_metadata._timeout = timeout - else: - raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds") + if timeout is not Node.TIMEOUT_OVERRIDE_SENTINEL: + if timeout is None: + node_metadata._timeout = 0 + elif isinstance(timeout, int): + node_metadata._timeout = datetime.timedelta(seconds=timeout) + elif isinstance(timeout, datetime.timedelta): + node_metadata._timeout = timeout + else: + raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds") + if retries is not None: assert_not_promise(retries, "retries") node_metadata._retries = ( @@ -181,7 +185,7 @@ def with_overrides( aliases: Optional[Dict[str, str]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, - timeout: Optional[Union[int, datetime.timedelta]] = None, + timeout: Optional[Union[int, datetime.timedelta, object]] = TIMEOUT_OVERRIDE_SENTINEL, retries: Optional[int] = None, interruptible: Optional[bool] = None, name: Optional[str] = None, diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index d64f2461e5..76d817ffeb 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -586,7 +586,7 @@ def with_overrides( aliases: Optional[Dict[str, str]] = None, requests: Optional[Resources] = None, limits: Optional[Resources] = None, - timeout: Optional[Union[int, datetime.timedelta]] = None, + timeout: Optional[Union[int, datetime.timedelta, object]] = Node.TIMEOUT_OVERRIDE_SENTINEL, retries: Optional[int] = None, interruptible: Optional[bool] = None, name: Optional[str] = None, diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 381f456bdb..51686b7385 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -302,18 +302,42 @@ def my_wf(a: typing.List[str]) -> typing.List[str]: ] +preset_timeout = datetime.timedelta(seconds=100) + + @pytest.mark.parametrize( - "timeout,expected", - [(None, datetime.timedelta()), (10, datetime.timedelta(seconds=10))], + "timeout,t1_expected_timeout_overridden, t1_expected_timeout_unset, t2_expected_timeout_overridden, " + "t2_expected_timeout_unset", + [ + (None, 0, 0, 0, preset_timeout), + (10, datetime.timedelta(seconds=10), 0, + datetime.timedelta(seconds=10), preset_timeout) + ], ) -def test_timeout_override(timeout, expected): +def test_timeout_override( + timeout, + t1_expected_timeout_overridden, + t1_expected_timeout_unset, + t2_expected_timeout_overridden, + t2_expected_timeout_unset, + ): @task def t1(a: str) -> str: return f"*~*~*~{a}*~*~*~" + @task( + timeout=preset_timeout + ) + def t2(a: str) -> str: + return f"*~*~*~{a}*~*~*~" + @workflow def my_wf(a: str) -> str: - return t1(a=a).with_overrides(timeout=timeout) + s = t1(a=a).with_overrides(timeout=timeout) + s1 = t1(a=s).with_overrides() + s2 = t2(a=s1).with_overrides(timeout=timeout) + s3 = t2(a=s2).with_overrides() + return s3 serialization_settings = flytekit.configuration.SerializationSettings( project="test_proj", @@ -323,8 +347,11 @@ def my_wf(a: str) -> str: env={}, ) wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf) - assert len(wf_spec.template.nodes) == 1 - assert wf_spec.template.nodes[0].metadata.timeout == expected + assert len(wf_spec.template.nodes) == 4 + assert wf_spec.template.nodes[0].metadata.timeout == t1_expected_timeout_overridden + assert wf_spec.template.nodes[1].metadata.timeout == t1_expected_timeout_unset + assert wf_spec.template.nodes[2].metadata.timeout == t2_expected_timeout_overridden + assert wf_spec.template.nodes[3].metadata.timeout == t2_expected_timeout_unset def test_timeout_override_invalid_value():