[FEAT] add driver/executor pod in Spark #3016

Open · wants to merge 7 commits into base: master
27 changes: 27 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
@@ -7,6 +7,7 @@

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.task import K8sPod


class SparkType(enum.Enum):
@@ -27,6 +28,8 @@ def __init__(
executor_path: str,
databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
databricks_instance: Optional[str] = None,
driver_pod: Optional[K8sPod] = None,
executor_pod: Optional[K8sPod] = None,
):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.
@@ -47,6 +50,8 @@ def __init__(
databricks_conf = {}
self._databricks_conf = databricks_conf
self._databricks_instance = databricks_instance
self._driver_pod = driver_pod
self._executor_pod = executor_pod

def with_overrides(
self,
@@ -71,6 +76,8 @@ def with_overrides(
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_instance=self.databricks_instance,
driver_pod=self.driver_pod,
executor_pod=self.executor_pod,
Comment on lines +79 to +80
Missing pod config in overrides method

Consider adding driver_pod and executor_pod to the with_overrides method to ensure consistent pod configuration overrides.

Code suggestion
Check the AI-generated fix before applying
 @@ -56,6 +56,8 @@ def with_overrides(
          new_spark_conf: Optional[Dict[str, str]] = None,
          new_hadoop_conf: Optional[Dict[str, str]] = None,
          new_databricks_conf: Optional[Dict[str, Dict]] = None,
 +        new_driver_pod: Optional[K8sPod] = None,
 +        new_executor_pod: Optional[K8sPod] = None,
      ) -> "SparkJob":
          if not new_spark_conf:
              new_spark_conf = self.spark_conf
 @@ -66,6 +68,12 @@ def with_overrides(
          if not new_databricks_conf:
              new_databricks_conf = self.databricks_conf
 
 +        if not new_driver_pod:
 +            new_driver_pod = self.driver_pod
 +
 +        if not new_executor_pod:
 +            new_executor_pod = self.executor_pod
 +
          return SparkJob(
              spark_type=self.spark_type,
              application_file=self.application_file,
 @@ -74,6 +82,8 @@ def with_overrides(
              hadoop_conf=new_hadoop_conf,
              databricks_conf=new_databricks_conf,
              databricks_instance=self.databricks_instance,
 +            driver_pod=new_driver_pod,
 +            executor_pod=new_executor_pod,
              executor_path=self.executor_path,
          )

Code Review Run #3c7587


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

executor_path=self.executor_path,
)

@@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
"""
return self._databricks_instance

@property
def driver_pod(self) -> K8sPod:
"""
Additional pod specs for the driver pod.
:rtype: K8sPod
"""
return self._driver_pod

@property
def executor_pod(self) -> K8sPod:
"""
Additional pod specs for the worker node pods.
:rtype: K8sPod
"""
return self._executor_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
@@ -167,6 +190,8 @@ def to_flyte_idl(self):
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksInstance=self.databricks_instance,
driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
)

@classmethod
@@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_instance=pb2_object.databricksInstance,
driver_pod=pb2_object.driverPod,
executor_pod=pb2_object.executorPod,
)
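
For context, a minimal sketch (not part of this diff) of how the new fields round-trip through to_flyte_idl(); the constructor arguments mirror those visible in this file, and the bare-dict pod specs are placeholder values.

# Sketch only, not part of the PR; assumes the SparkJob signature shown above.
from flytekit.models.task import K8sPod
from flytekitplugins.spark.models import SparkJob, SparkType

job = SparkJob(
    spark_type=SparkType.PYTHON,
    application_file="local:///usr/local/bin/entrypoint.py",
    main_class="",
    spark_conf={"spark.driver.memory": "1000M"},
    hadoop_conf={},
    executor_path="/usr/bin/python3",
    driver_pod=K8sPod(pod_spec={"containers": [{"name": "primary"}]}),
    executor_pod=K8sPod(pod_spec={"containers": [{"name": "primary"}]}),
)

idl = job.to_flyte_idl()
# driverPod/executorPod are populated only when the fields are set, matching
# the `if self.driver_pod else None` guards in to_flyte_idl() above.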
25 changes: 25 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
@@ -10,9 +10,12 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.core.pod_template import PodTemplate
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
from flytekit.models.task import K8sPod, K8sObjectMetadata

from .models import SparkJob, SparkType

@@ -31,12 +34,16 @@ class Spark(object):
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
executor_path: Python binary executable to use for PySpark in driver and executor.
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
driver_pod: K8sPod for Spark driver pod
executor_pod: K8sPod for Spark executor pod
"""

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
executor_path: Optional[str] = None
applications_path: Optional[str] = None
driver_pod: Optional[PodTemplate] = None
executor_pod: Optional[PodTemplate] = None

def __post_init__(self):
if self.spark_conf is None:
@@ -168,6 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
executor_path=self._default_executor_path or settings.python_interpreter,
main_class="",
spark_type=SparkType.PYTHON,
driver_pod=self.to_k8s_pod(self.task_config.driver_pod, settings),
executor_pod=self.to_k8s_pod(self.task_config.executor_pod, settings),
)
if isinstance(self.task_config, (Databricks, DatabricksV2)):
cfg = cast(DatabricksV2, self.task_config)
Expand All @@ -176,6 +185,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:

return MessageToDict(job.to_flyte_idl())

def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSettings) -> K8sPod | None:
Consider adding return type hints

Consider adding type hints for the return value of _get_container() in the to_k8s_pod() method. The method appears to use this internal method but its return type is not clearly specified in the type hints.

Code suggestion
Check the AI-generated fix before applying
Suggested change
def to_k8s_pod(self, pod_template: PodTemplate | None, settings: SerializationSettings) -> K8sPod | None:
    from flytekit.models import task as _task_model
    _get_container: Callable[..., _task_model.Container] = self._get_container

Code Review Run #3c7587


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

"""
Convert the PodTemplate to a K8sPod
"""
if pod_template is None:
return None

return K8sPod(
pod_spec=_serialize_pod_spec(pod_template, self._get_container(settings), settings),
metadata=K8sObjectMetadata(
labels=pod_template.labels,
annotations=pod_template.annotations,
),
)


def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
import pyspark as _pyspark

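For reference, a minimal usage sketch (not part of this diff) of the PodTemplate-based configuration implied by the type hints above; the pod contents here are placeholders, and the test below exercises the same plugin by passing K8sPod specs directly.

# Sketch only, not part of the PR; assumes the PodTemplate-typed
# driver_pod/executor_pod fields on the Spark config shown above.
from flytekit import task, PodTemplate
from flytekitplugins.spark import Spark
from kubernetes.client.models import V1Container, V1PodSpec

driver_template = PodTemplate(
    primary_container_name="primary",
    pod_spec=V1PodSpec(containers=[V1Container(name="primary")]),
    labels={"role": "spark-driver"},
)

@task(task_config=Spark(spark_conf={"spark.driver.memory": "1000M"}, driver_pod=driver_template))
def my_job(a: int) -> int:
    return a

# At serialization time, get_custom() passes task_config.driver_pod through
# to_k8s_pod(), which wraps the serialized pod spec plus the template's
# labels/annotations in a K8sPod on the SparkJob.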
155 changes: 145 additions & 10 deletions plugins/flytekit-spark/tests/test_spark_task.py
@@ -5,15 +5,39 @@
import pyspark
import pytest

from google.protobuf.json_format import MessageToDict
from flytekit import PodTemplate
from flytekit.core import context_manager
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
from flytekit import (
StructuredDataset,
StructuredDatasetTransformerEngine,
task,
ImageSpec,
)
from flytekit.configuration import (
Image,
ImageConfig,
SerializationSettings,
FastSerializationSettings,
DefaultImages,
)
from flytekit.core.context_manager import (
ExecutionParameters,
FlyteContextManager,
ExecutionState,
)
from flytekit.models.task import K8sPod
from kubernetes.client.models import (
V1Container,
V1PodSpec,
V1Toleration,
V1EnvVar,
)


@pytest.fixture(scope="function")
@@ -68,7 +92,10 @@ def my_spark(a: str) -> int:
retrieved_settings = my_spark.get_custom(settings)
assert retrieved_settings["sparkConf"] == {"spark": "1"}
assert retrieved_settings["executorPath"] == "/usr/bin/python3"
assert retrieved_settings["mainApplicationFile"] == "local:///usr/local/bin/entrypoint.py"
assert (
retrieved_settings["mainApplicationFile"]
== "local:///usr/local/bin/entrypoint.py"
)

pb = ExecutionParameters.new_builder()
pb.working_dir = "/tmp"
@@ -121,11 +148,13 @@ def test_to_html():
df = spark.createDataFrame([("Bob", 10)], ["name", "age"])
sd = StructuredDataset(dataframe=df)
tf = StructuredDatasetTransformerEngine()
output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
output = tf.to_html(
FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame
)
assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output


@mock.patch('pyspark.context.SparkContext.addPyFile')
@mock.patch("pyspark.context.SparkContext.addPyFile")
def test_spark_addPyFile(mock_add_pyfile):
@task(
task_config=Spark(
@@ -151,8 +180,11 @@ def my_spark(a: int) -> int:

ctx = context_manager.FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
ctx.with_execution_state(
ctx.new_execution_state().with_params(
mode=ExecutionState.Mode.TASK_EXECUTION
)
).with_serialization_settings(serialization_settings)
) as new_ctx:
my_spark.pre_execute(new_ctx.user_space_params)
mock_add_pyfile.assert_called_once()
@@ -173,7 +205,10 @@ def spark1(partitions: int) -> float:
print("Starting Spark with Partitions: {}".format(partitions))
return 1.0

assert spark1.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
assert (
spark1.container_image.base_image
== f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
)
assert spark1._default_executor_path == "/usr/bin/python3"
assert spark1._default_applications_path == "local:///usr/local/bin/entrypoint.py"

@@ -185,6 +220,106 @@ def spark2(partitions: int) -> float:
print("Starting Spark with Partitions: {}".format(partitions))
return 1.0

assert spark2.container_image.base_image == f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
assert (
spark2.container_image.base_image
== f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}"
)
assert spark2._default_executor_path == "/usr/bin/python3"
assert spark2._default_applications_path == "local:///usr/local/bin/entrypoint.py"


def test_spark_driver_executor_podSpec():
custom_image = ImageSpec(
registry="ghcr.io/flyteorg",
packages=["flytekitplugins-spark"],
)

driver_pod_spec = V1PodSpec(
containers=[
V1Container(
name="primary",
image="ghcr.io/flyteorg",
command=["echo"],
args=["wow"],
env=[V1EnvVar(name="x/custom-driver", value="driver")],
),
],
tolerations=[
V1Toleration(
key="x/custom-driver",
operator="Equal",
value="foo-driver",
effect="NoSchedule",
),
],
)

executor_pod_spec = V1PodSpec(
containers=[
V1Container(
name="primary",
image="ghcr.io/flyteorg",
command=["echo"],
args=["wow"],
env=[V1EnvVar(name="x/custom-executor", value="executor")],
),
],
tolerations=[
V1Toleration(
key="x/custom-executor",
operator="Equal",
value="foo-executor",
effect="NoSchedule",
),
],
)

@task(
task_config=Spark(
spark_conf={"spark.driver.memory": "1000M"},
driver_pod=K8sPod(pod_spec=driver_pod_spec.to_dict()),
executor_pod=K8sPod(pod_spec=executor_pod_spec.to_dict()),
),
# limits=Resources(cpu="50m", mem="2000M"),
container_image=custom_image,
pod_template=PodTemplate(primary_container_name="primary"),
)
def my_spark(a: str) -> int:
session = flytekit.current_context().spark_session
assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
return 10

assert my_spark.task_config is not None
assert my_spark.task_config.spark_conf == {"spark.driver.memory": "1000M"}

default_img = Image(name="default", fqn="test", tag="tag")
settings = SerializationSettings(
project="project",
domain="domain",
version="version",
env={"FOO": "baz"},
image_config=ImageConfig(default_image=default_img, images=[default_img]),
)

retrieved_settings = my_spark.get_custom(settings)
assert retrieved_settings["sparkConf"] == {"spark.driver.memory": "1000M"}
assert retrieved_settings["executorPath"] == "/usr/bin/python3"
assert (
retrieved_settings["mainApplicationFile"]
== "local:///usr/local/bin/entrypoint.py"
)
assert retrieved_settings["driverPod"] == MessageToDict(K8sPod(pod_spec=driver_pod_spec.to_dict()).to_flyte_idl())
assert retrieved_settings["executorPod"] == MessageToDict(K8sPod(pod_spec=executor_pod_spec.to_dict()).to_flyte_idl())

pb = ExecutionParameters.new_builder()
pb.working_dir = "/tmp"
pb.execution_id = "ex:local:local:local"
p = pb.build()
new_p = my_spark.pre_execute(p)
assert new_p is not None
assert new_p.has_attr("SPARK_SESSION")

assert my_spark.sess is not None
configs = my_spark.sess.sparkContext.getConf().getAll()
assert ("spark.driver.memory", "1000M") in configs
assert ("spark.app.name", "FlyteSpark: ex:local:local:local") in configs