From b42edfaca1a9e10a1f22744d35b4a2778a68a6eb Mon Sep 17 00:00:00 2001 From: Umer Ahmad Date: Sat, 25 Jan 2025 03:30:28 -0500 Subject: [PATCH 1/3] adding missing fields to flytetask remote entity Signed-off-by: Umer Ahmad Signed-off-by: Umer Ahmad --- flytekit/remote/entities.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index fd78d4c3c4..3f397c5db3 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -49,7 +49,11 @@ def __init__( custom, container=None, task_type_version: int = 0, + security_context=None, config=None, + k8s_pod=None, + sql=None, + extended_resources=None, should_register: bool = False, ): super(FlyteTask, self).__init__( @@ -61,7 +65,11 @@ def __init__( custom, container=container, task_type_version=task_type_version, + security_context=security_context, config=config, + k8s_pod=k8s_pod, + sql=sql, + extended_resources=extended_resources, ) ) self._should_register = should_register @@ -146,6 +154,10 @@ def k8s_pod(self): def sql(self): return self.template.sql + @property + def extended_resources(self): + return self.template.extended_resources + @property def should_register(self) -> bool: return self._should_register @@ -172,6 +184,11 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask: custom=base_model.custom, container=base_model.container, task_type_version=base_model.task_type_version, + security_context=base_model.security_context, + config=base_model.config, + k8s_pod=base_model.k8s_pod, + sql=base_model.sql, + extended_resources=base_model.extended_resources, ) # Override the newly generated name if one exists in the base model if not base_model.id.is_empty: From 9a58c5ab2e252a61449934177677a4de639716e7 Mon Sep 17 00:00:00 2001 From: Umer Ahmad Date: Sat, 25 Jan 2025 19:15:24 -0500 Subject: [PATCH 2/3] Patch fetch task remote unit test Signed-off-by: Umer Ahmad Signed-off-by: Umer Ahmad --- tests/flytekit/unit/remote/test_remote.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 9911cad02f..253da792fb 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -829,6 +829,9 @@ def dynamic0(): def workflow1(): return dynamic0() + mock_client.get_task.return_value.closure.compiled_task.template.sql = None + mock_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None + rr = FlyteRemote( Config.for_sandbox(), default_project="flytesnacks", From 950d4c530521ee40a4d9af540dfc6c152c0b73da Mon Sep 17 00:00:00 2001 From: Umer Ahmad Date: Fri, 7 Feb 2025 16:20:17 -0800 Subject: [PATCH 3/3] Change patch to be global using fixture Signed-off-by: Umer Ahmad --- tests/flytekit/unit/remote/test_remote.py | 41 +++++++++++------------ 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 253da792fb..2e2dcdc22b 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -547,14 +547,21 @@ def wf1(name: str = "union") -> float: flyte_remote.register_script(wf1) -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_local_server(mock_client): +@pytest.fixture() +def mock_flyte_remote_client(): + with patch("flytekit.remote.remote.FlyteRemote.client") as mock_flyte_remote_client: + mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.sql = None + mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None + yield mock_flyte_remote_client + + +def test_local_server(mock_flyte_remote_client): ctx = FlyteContextManager.current_context() lt = TypeEngine.to_literal_type(typing.Dict[str, int]) lm = TypeEngine.to_literal(ctx, {"hello": 55}, typing.Dict[str, int], lt) lm = lm.map.to_flyte_idl() - mock_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm) + mock_flyte_remote_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm) rr = FlyteRemote( Config.for_sandbox(), @@ -566,8 +573,7 @@ def test_local_server(mock_client): @mock.patch("flytekit.remote.remote.uuid") -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_execution_name(mock_client, mock_uuid): +def test_execution_name(mock_uuid, mock_flyte_remote_client): test_uuid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da") mock_uuid.uuid4.return_value = test_uuid remote = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain") @@ -597,7 +603,7 @@ def test_execution_name(mock_client, mock_uuid): entity=ft, inputs={"t": datetime.now(), "v": 0}, ) - mock_client.create_execution.assert_has_calls( + mock_flyte_remote_client.create_execution.assert_has_calls( [ mock.call(ANY, ANY, "execution-test", ANY, ANY), mock.call(ANY, ANY, "execution-test-" + test_uuid.hex[:19], ANY, ANY), @@ -688,9 +694,8 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist ) -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_fetch_active_launchplan_not_found(mock_client, remote): - mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found") +def test_fetch_active_launchplan_not_found(mock_flyte_remote_client, remote): + mock_flyte_remote_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found") assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None @@ -785,8 +790,7 @@ async def eager_wf(a: int) -> int: _get_pickled_target_dict(eager_wf) -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_launchplan_auto_activate(mock_client): +def test_launchplan_auto_activate(mock_flyte_remote_client): @workflow def wf() -> int: return 1 @@ -804,15 +808,14 @@ def wf() -> int: # The first one should not update the launchplan rr.register_launch_plan(lp1, version="1", serialization_settings=ss) - mock_client.update_launch_plan.assert_not_called() + mock_flyte_remote_client.update_launch_plan.assert_not_called() # the second one should rr.register_launch_plan(lp2, version="1", serialization_settings=ss) - mock_client.update_launch_plan.assert_called() + mock_flyte_remote_client.update_launch_plan.assert_called() -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_register_task_with_node_dependency_hints(mock_client): +def test_register_task_with_node_dependency_hints(mock_flyte_remote_client): @task def task0(): return None @@ -829,9 +832,6 @@ def dynamic0(): def workflow1(): return dynamic0() - mock_client.get_task.return_value.closure.compiled_task.template.sql = None - mock_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None - rr = FlyteRemote( Config.for_sandbox(), default_project="flytesnacks", @@ -861,8 +861,7 @@ def workflow1(): @mock.patch("flytekit.remote.remote.FlyteRemote.fetch_launch_plan") @mock.patch("flytekit.remote.remote.FlyteRemote.raw_register") @mock.patch("flytekit.remote.remote.FlyteRemote._serialize_and_register") -@mock.patch("flytekit.remote.remote.FlyteRemote.client") -def test_register_launch_plan(mock_client, mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable): +def test_register_launch_plan(mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable, mock_flyte_remote_client): serialization_settings = SerializationSettings( image_config=ImageConfig.auto_default_image(), version="dummy_version", @@ -886,7 +885,7 @@ def hello_world_wf() -> str: lp = LaunchPlan.get_or_create(workflow=hello_world_wf, name="additional_lp_for_hello_world", default_inputs={}) mock_get_serializable.return_value = MagicMock() - mock_client.get_workflow.return_value = MagicMock() + mock_flyte_remote_client.get_workflow.return_value = MagicMock() mock_remote_lp = MagicMock() mock_fetch_launch_plan.return_value = mock_remote_lp