diff --git a/flytekit/core/utils.py b/flytekit/core/utils.py
index 9f1967d2f9..f2e6e24bbb 100644
--- a/flytekit/core/utils.py
+++ b/flytekit/core/utils.py
@@ -139,6 +139,7 @@ def _serialize_pod_spec(
     pod_template: "PodTemplate",
     primary_container: "task_models.Container",
     settings: SerializationSettings,
+    task_type: str = "",
 ) -> Dict[str, Any]:
     # import here to avoid circular import
     from kubernetes.client import ApiClient, V1PodSpec
@@ -169,6 +170,7 @@ def _serialize_pod_spec(
         # with the values given to ContainerTask.
         # The attributes include: image, command, args, resource, and env (env is unioned)
 
+        is_primary = False
         if container.name == cast(PodTemplate, pod_template).primary_container_name:
             if container.image is None:
                 # Copy the image from primary_container only if the image is not specified in the pod spec.
@@ -176,8 +178,10 @@ def _serialize_pod_spec(
             else:
                 container.image = get_registerable_container_image(container.image, settings.image_config)
 
-            container.command = primary_container.command
-            container.args = primary_container.args
+            if task_type != "spark":
+                # For Spark driver/executor pods, do not override the pod template's command and args with the primary container's.
+                container.command = primary_container.command
+                container.args = primary_container.args
 
             limits, requests = {}, {}
             for resource in primary_container.resources.limits:
@@ -192,9 +196,14 @@ def _serialize_pod_spec(
             container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + (
                 container.env or []
             )
+            is_primary = True
         else:
             container.image = get_registerable_container_image(container.image, settings.image_config)
 
+        if task_type == "spark" and not is_primary:
+            # For Spark driver/executor pods, keep only the primary container.
+            continue
+
         final_containers.append(container)
 
     cast(V1PodSpec, pod_template.pod_spec).containers = final_containers
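A minimal sketch of what the new task_type flag changes (names and values here are illustrative, not part of this diff). With task_type="spark", only the pod template's primary container is serialized, and its command/args are left as authored instead of being overwritten from the task's primary container:

from kubernetes.client import V1Container, V1PodSpec

from flytekit import PodTemplate

# A pod template with a primary container and a sidecar.
pod_template = PodTemplate(
    primary_container_name="primary",
    pod_spec=V1PodSpec(
        containers=[
            V1Container(name="primary", command=["echo"], args=["hello"]),
            V1Container(name="sidecar", image="busybox"),
        ]
    ),
)

# _serialize_pod_spec(pod_template, primary_container, settings) keeps both
# containers and copies command/args from the task's primary container.
# _serialize_pod_spec(pod_template, primary_container, settings, task_type="spark")
# drops the "sidecar" container and keeps ["echo"] / ["hello"] as authored.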
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/models.py b/plugins/flytekit-spark/flytekitplugins/spark/models.py
index e74a9fbe3f..1f185609f4 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/models.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@
 from flytekit.exceptions import user as _user_exceptions
 from flytekit.models import common as _common
+from flytekit.models.task import K8sPod
 
 
 class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
         executor_path: str,
         databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
         databricks_instance: Optional[str] = None,
+        driver_pod: Optional[K8sPod] = None,
+        executor_pod: Optional[K8sPod] = None,
     ):
         """
         This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
             databricks_conf = {}
         self._databricks_conf = databricks_conf
         self._databricks_instance = databricks_instance
+        self._driver_pod = driver_pod
+        self._executor_pod = executor_pod
 
     def with_overrides(
         self,
@@ -71,6 +76,8 @@ def with_overrides(
             hadoop_conf=new_hadoop_conf,
             databricks_conf=new_databricks_conf,
             databricks_instance=self.databricks_instance,
+            driver_pod=self.driver_pod,
+            executor_pod=self.executor_pod,
             executor_path=self.executor_path,
         )
 
@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
         """
         return self._databricks_instance
 
+    @property
+    def driver_pod(self) -> Optional[K8sPod]:
+        """
+        Additional pod spec to apply to the Spark driver pod.
+        :rtype: Optional[K8sPod]
+        """
+        return self._driver_pod
+
+    @property
+    def executor_pod(self) -> Optional[K8sPod]:
+        """
+        Additional pod spec to apply to the Spark executor pods.
+        :rtype: Optional[K8sPod]
+        """
+        return self._executor_pod
+
     def to_flyte_idl(self):
         """
         :rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
             hadoopConf=self.hadoop_conf,
             databricksConf=databricks_conf,
             databricksInstance=self.databricks_instance,
+            driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
+            executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
         )
 
     @classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
             executor_path=pb2_object.executorPath,
             databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
             databricks_instance=pb2_object.databricksInstance,
+            driver_pod=K8sPod.from_flyte_idl(pb2_object.driverPod) if pb2_object.HasField("driverPod") else None,
+            executor_pod=K8sPod.from_flyte_idl(pb2_object.executorPod) if pb2_object.HasField("executorPod") else None,
         )
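A hedged round-trip sketch for the new fields. Constructor arguments not visible in the hunks above (spark_type, application_file, main_class, spark_conf, hadoop_conf) are assumed from the surrounding module; the HasField guards in from_flyte_idl are what let an unset pod round-trip back to None:

from flytekit.models.task import K8sObjectMetadata, K8sPod
from flytekitplugins.spark.models import SparkJob, SparkType

driver_pod = K8sPod(
    metadata=K8sObjectMetadata(labels={"team": "data"}),
    pod_spec={"containers": []},  # illustrative pod spec dict
)

job = SparkJob(
    spark_type=SparkType.PYTHON,
    application_file="local:///usr/local/bin/entrypoint.py",
    main_class="",
    spark_conf={"spark.driver.memory": "1000M"},
    hadoop_conf={},
    executor_path="/usr/bin/python3",
    driver_pod=driver_pod,  # executor_pod is left unset
)

restored = SparkJob.from_flyte_idl(job.to_flyte_idl())
assert restored.driver_pod == driver_pod  # flytekit models compare by IDL form
assert restored.executor_pod is None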
diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py
index 7d2f718617..40ff840ac9 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/task.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -10,9 +10,12 @@
 from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
 from flytekit.configuration import DefaultImages, SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters
+from flytekit.core.pod_template import PodTemplate
+from flytekit.core.utils import _serialize_pod_spec
 from flytekit.extend import ExecutionState, TaskPlugins
 from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.image_spec import ImageSpec
+from flytekit.models.task import K8sObjectMetadata, K8sPod
 
 from .models import SparkJob, SparkType
 
@@ -31,12 +34,16 @@ class Spark(object):
         hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
         executor_path: Python binary executable to use for PySpark in driver and executor.
         applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
+        driver_pod: The pod template to use for the Spark driver pod.
+        executor_pod: The pod template to use for the Spark executor pods.
     """
 
     spark_conf: Optional[Dict[str, str]] = None
     hadoop_conf: Optional[Dict[str, str]] = None
     executor_path: Optional[str] = None
     applications_path: Optional[str] = None
+    driver_pod: Optional[PodTemplate] = None
+    executor_pod: Optional[PodTemplate] = None
 
     def __post_init__(self):
         if self.spark_conf is None:
@@ -168,6 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
             executor_path=self._default_executor_path or settings.python_interpreter,
             main_class="",
             spark_type=SparkType.PYTHON,
+            driver_pod=self.to_k8s_pod(settings, self.task_config.driver_pod),
+            executor_pod=self.to_k8s_pod(settings, self.task_config.executor_pod),
         )
         if isinstance(self.task_config, (Databricks, DatabricksV2)):
             cfg = cast(DatabricksV2, self.task_config)
@@ -176,6 +185,25 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
 
         return MessageToDict(job.to_flyte_idl())
 
+    def to_k8s_pod(
+        self, settings: SerializationSettings, pod_template: Optional[PodTemplate] = None
+    ) -> Optional[K8sPod]:
+        """
+        Convert the given PodTemplate into a K8sPod, serializing its pod spec against this task's primary container.
+        """
+        if pod_template is None:
+            return None
+
+        return K8sPod(
+            pod_spec=_serialize_pod_spec(
+                pod_template, self._get_container(settings), settings, task_type=self._SPARK_TASK_TYPE
+            ),
+            metadata=K8sObjectMetadata(
+                labels=pod_template.labels,
+                annotations=pod_template.annotations,
+            ),
+        )
+
     def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
         import pyspark as _pyspark
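Putting it together, a usage sketch modeled on the test below (image names, container names, and tolerations are illustrative). At serialization time, get_custom runs each template through to_k8s_pod, so the driver and executors can carry their own labels, annotations, tolerations, and env on top of the task's container settings:

from kubernetes.client import V1Container, V1PodSpec, V1Toleration

from flytekit import PodTemplate, task
from flytekitplugins.spark import Spark

driver_pod = PodTemplate(
    primary_container_name="driver",
    pod_spec=V1PodSpec(
        containers=[V1Container(name="driver")],
        tolerations=[
            V1Toleration(
                key="dedicated", operator="Equal", value="spark", effect="NoSchedule"
            )
        ],
    ),
)

executor_pod = PodTemplate(
    primary_container_name="executor",
    pod_spec=V1PodSpec(containers=[V1Container(name="executor")]),
)


@task(
    task_config=Spark(
        spark_conf={"spark.driver.memory": "1000M"},
        driver_pod=driver_pod,
        executor_pod=executor_pod,
    )
)
def my_spark_task(n: int) -> int:
    return n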
diff --git a/plugins/flytekit-spark/tests/test_spark_task.py b/plugins/flytekit-spark/tests/test_spark_task.py
index 7ce5f14ebf..ebcba98935 100644
--- a/plugins/flytekit-spark/tests/test_spark_task.py
+++ b/plugins/flytekit-spark/tests/test_spark_task.py
@@ -5,23 +5,50 @@
 import pyspark
 import pytest
+from google.protobuf.json_format import MessageToDict
 
+from flytekit import PodTemplate
 from flytekit.core import context_manager
 from flytekitplugins.spark import Spark
 from flytekitplugins.spark.task import Databricks, new_spark_session
 from pyspark.sql import SparkSession
 
 import flytekit
-from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
-from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
-from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
+from flytekit import (
+    StructuredDataset,
+    StructuredDatasetTransformerEngine,
+    task,
+    ImageSpec,
+)
+from flytekit.configuration import (
+    Image,
+    ImageConfig,
+    SerializationSettings,
+    FastSerializationSettings,
+    DefaultImages,
+)
+from flytekit.core.context_manager import (
+    ExecutionParameters,
+    FlyteContextManager,
+    ExecutionState,
+)
+from flytekit.models.task import K8sObjectMetadata, K8sPod
+from kubernetes.client.models import (
+    V1Container,
+    V1PodSpec,
+    V1Toleration,
+    V1EnvVar,
+)
 
 
 @pytest.fixture(scope="function")
 def reset_spark_session() -> None:
-    pyspark.sql.SparkSession.builder.getOrCreate().stop()
+    if SparkSession._instantiatedSession:
+        SparkSession.builder.getOrCreate().stop()
+        SparkSession._instantiatedSession = None
     yield
-    pyspark.sql.SparkSession.builder.getOrCreate().stop()
-
+    if SparkSession._instantiatedSession:
+        SparkSession.builder.getOrCreate().stop()
+        SparkSession._instantiatedSession = None
 
 def test_spark_task(reset_spark_session):
     databricks_conf = {
@@ -68,7 +95,10 @@ def my_spark(a: str) -> int:
     retrieved_settings = my_spark.get_custom(settings)
     assert retrieved_settings["sparkConf"] == {"spark": "1"}
     assert retrieved_settings["executorPath"] == "/usr/bin/python3"
-    assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py"
+    assert (
+        retrieved_settings["mainApplicationFile"]
+        == "local:///usr/local/bin/entrypoint.py"
+    )
 
     pb = ExecutionParameters.new_builder()
     pb.working_dir = "/tmp"
@@ -121,11 +151,13 @@ def test_to_html():
     df = spark.createDataFrame([("Bob", 10)], ["name", "age"])
     sd = StructuredDataset(dataframe=df)
     tf = StructuredDatasetTransformerEngine()
-    output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
+    output = tf.to_html(
+        FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame
+    )
     assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output
 
 
-@mock.patch('pyspark.context.SparkContext.addPyFile')
+@mock.patch("pyspark.context.SparkContext.addPyFile")
 def test_spark_addPyFile(mock_add_pyfile):
     @task(
         task_config=Spark(
@@ -151,8 +183,11 @@ def my_spark(a: int) -> int:
 
     ctx = context_manager.FlyteContextManager.current_context()
     with context_manager.FlyteContextManager.with_context(
-        ctx.with_execution_state(
-            ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
+        ctx.with_execution_state(
+            ctx.new_execution_state().with_params(
+                mode=ExecutionState.Mode.TASK_EXECUTION
+            )
+        ).with_serialization_settings(serialization_settings)
     ) as new_ctx:
         my_spark.pre_execute(new_ctx.user_space_params)
         mock_add_pyfile.assert_called_once()
@@ -173,7 +208,10 @@ def spark1(partitions: int) -> float:
         print("Starting Spark with Partitions: {}".format(partitions))
         return 1.0
 
-    assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    assert (
+        spark1.container_image.base_image
+        == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    )
     assert spark1._default_executor_path == "/usr/bin/python3"
     assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py"
 
@@ -185,6 +223,212 @@ def spark2(partitions: int) -> float:
         print("Starting Spark with Partitions: {}".format(partitions))
         return 1.0
 
-    assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    assert (
+        spark2.container_image.base_image
+        == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
+    )
     assert spark2._default_executor_path == "/usr/bin/python3"
     assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py"
+
+
+def clean_dict(d):
+    """
+    Recursively drop None values from dicts and None items from lists.
+    """
+    if isinstance(d, dict):
+        return {k: clean_dict(v) for k, v in d.items() if v is not None}
+    elif isinstance(d, list):
+        return [clean_dict(item) for item in d if item is not None]
+    else:
+        return d
+
+
+def test_spark_driver_executor_podSpec(reset_spark_session):
+    custom_image = ImageSpec(
+        registry="ghcr.io/flyteorg",
+        packages=["flytekitplugins-spark"],
+    )
+
+    driver_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name="driver-primary",
+                image="ghcr.io/flyteorg",
+                command=["echo"],
+                args=["wow"],
+                env=[V1EnvVar(name="x/custom-driver", value="driver")],
+            ),
+            V1Container(
+                name="not-primary",
+                command=["echo"],
+                args=["not_primary"],
+            ),
+        ],
+        tolerations=[
+            V1Toleration(
+                key="x/custom-driver",
+                operator="Equal",
+                value="foo-driver",
+                effect="NoSchedule",
+            ),
+        ],
+    )
+
+    executor_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name="executor-primary",
+                image="ghcr.io/flyteorg",
+                command=["echo"],
+                args=["wow"],
+                env=[V1EnvVar(name="x/custom-executor", value="executor")],
+            ),
+            V1Container(
+                name="not-primary",
+                command=["echo"],
+                args=["not_primary"],
+            ),
+        ],
+        tolerations=[
+            V1Toleration(
+                key="x/custom-executor",
+                operator="Equal",
+                value="foo-executor",
+                effect="NoSchedule",
+            ),
+        ],
+    )
+
+    driver_pod = PodTemplate(
+        labels={"lKeyA_d": "lValA", "lKeyB_d": "lValB"},
+        annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"},
+        primary_container_name="driver-primary",
+        pod_spec=driver_pod_spec,
+    )
+
+    executor_pod = PodTemplate(
+        labels={"lKeyA_e": "lValA", "lKeyB_e": "lValB"},
+        annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"},
+        primary_container_name="executor-primary",
+        pod_spec=executor_pod_spec,
+    )
+
+    expect_driver_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name="driver-primary",
+                image="ghcr.io/flyteorg",
+                command=["echo"],
+                args=["wow"],
+                env=[
+                    V1EnvVar(name="FOO", value="baz"),
+                    V1EnvVar(name="x/custom-driver", value="driver"),
+                ],
+            ),
+        ],
+        tolerations=[
+            V1Toleration(
+                key="x/custom-driver",
+                operator="Equal",
+                value="foo-driver",
+                effect="NoSchedule",
+            ),
+        ],
+    )
+
+    expect_executor_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name="executor-primary",
+                image="ghcr.io/flyteorg",
+                command=["echo"],
+                args=["wow"],
+                env=[
+                    V1EnvVar(name="FOO", value="baz"),
+                    V1EnvVar(name="x/custom-executor", value="executor"),
+                ],
+            ),
+        ],
+        tolerations=[
+            V1Toleration(
+                key="x/custom-executor",
+                operator="Equal",
+                value="foo-executor",
+                effect="NoSchedule",
+            ),
+        ],
+    )
+
+    expected_driver_pod_spec_dict = clean_dict(expect_driver_pod_spec.to_dict())
+    expected_executor_pod_spec_dict = clean_dict(expect_executor_pod_spec.to_dict())
+
+    target_driver_k8sPod = K8sPod(
+        metadata=K8sObjectMetadata(
+            labels={"lKeyA_d": "lValA", "lKeyB_d": "lValB"},
+            annotations={"aKeyA_d": "aValA", "aKeyB_d": "aValB"},
+        ),
+        pod_spec=expected_driver_pod_spec_dict,  # type: ignore
+    )
+
+    target_executor_k8sPod = K8sPod(
+        metadata=K8sObjectMetadata(
+            labels={"lKeyA_e": "lValA", "lKeyB_e": "lValB"},
+            annotations={"aKeyA_e": "aValA", "aKeyB_e": "aValB"},
+        ),
+        pod_spec=expected_executor_pod_spec_dict,  # type: ignore
+    )
+
+    @task(
+        task_config=Spark(
+            spark_conf={"spark.driver.memory": "1000M"},
+            driver_pod=driver_pod,
+            executor_pod=executor_pod,
+        ),
+        container_image=custom_image,
+        pod_template=PodTemplate(primary_container_name="primary"),
+    )
+    def my_spark(a: str) -> int:
+        session = flytekit.current_context().spark_session
+        configs = session.sparkContext.getConf().getAll()
+        assert ("spark.driver.memory", "1000M") in configs
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return 10
+
+    assert my_spark.task_config is not None
+    assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"}
+    default_img = Image(name="default", fqn="test", tag="tag")
+
+    settings = SerializationSettings(
+        project="project",
+        domain="domain",
+        version="version",
+        env={"FOO": "baz"},
+        image_config=ImageConfig(default_image=default_img, images=[default_img]),
+    )
+
+    retrieved_settings = my_spark.get_custom(settings)
+    assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"}
+    assert retrieved_settings["executorPath"] == "/usr/bin/python3"
+    assert (
+        retrieved_settings["mainApplicationFile"]
+        == "local:///usr/local/bin/entrypoint.py"
+    )
+    assert retrieved_settings["driverPod"] == MessageToDict(
+        target_driver_k8sPod.to_flyte_idl()
+    )
+    assert retrieved_settings["executorPod"] == MessageToDict(
+        target_executor_k8sPod.to_flyte_idl()
+    )
+
+    pb = ExecutionParameters.new_builder()
+    pb.working_dir = "/tmp"
+    pb.execution_id = "ex:local:local:local"
+    p = pb.build()
+    new_p = my_spark.pre_execute(p)
+    assert new_p is not None
+    assert new_p.has_attr("SPARK_SESSION")
+
+    assert my_spark.sess is not None
+    configs = my_spark.sess.sparkContext.getConf().getAll()
+    assert ("spark.driver.memory", "1000M") in configs
+    assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs