Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] add driver/executor pod in Spark #3016

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
pod_template: "PodTemplate",
primary_container: "task_models.Container",
settings: SerializationSettings,
task_type: str = "",
) -> Dict[str, Any]:
# import here to avoid circular import
from kubernetes.client import ApiClient, V1PodSpec
Expand Down Expand Up @@ -169,15 +170,18 @@
# with the values given to ContainerTask.
# The attributes include: image, command, args, resource, and env (env is unioned)

is_primary = False
if container.name == cast(PodTemplate, pod_template).primary_container_name:
if container.image is None:
# Copy the image from primary_container only if the image is not specified in the pod spec.
container.image = primary_container.image
else:
container.image = get_registerable_container_image(container.image, settings.image_config)

container.command = primary_container.command
container.args = primary_container.args
if task_type != "spark":
# for spark driver/executor, do not use the command and args from task podTemplate
container.command = primary_container.command
container.args = primary_container.args
Comment on lines +181 to +184
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider extracting Spark container logic

Consider extracting the Spark-specific container command/args logic into a separate helper function to improve code organization and readability. The current nested if condition makes the code harder to follow.

Code suggestion
Check the AI-generated fix before applying
 -            if task_type != "spark":
 -                # for spark driver/executor, do not use the command and args from task podTemplate
 -                container.command = primary_container.command
 -                container.args = primary_container.args
 +            if _should_copy_container_command_args(task_type):
 +                container.command = primary_container.command
 +                container.args = primary_container.args
 +
 def _should_copy_container_command_args(task_type: str) -> bool:
 +    # for spark driver/executor, do not use the command and args from task podTemplate
 +    return task_type != "spark"

Code Review Run #27c6ae


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged


limits, requests = {}, {}
for resource in primary_container.resources.limits:
Expand All @@ -192,9 +196,14 @@
container.env = [V1EnvVar(name=key, value=val) for key, val in primary_container.env.items()] + (
container.env or []
)
is_primary = True
else:
container.image = get_registerable_container_image(container.image, settings.image_config)

if task_type == "spark" and not is_primary:
# for spark driver/executor, only take the primary container
continue

Check warning on line 205 in flytekit/core/utils.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/utils.py#L205

Added line #L205 was not covered by tests

final_containers.append(container)
cast(V1PodSpec, pod_template.pod_spec).containers = final_containers

Expand Down
27 changes: 27 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.task import K8sPod


class SparkType(enum.Enum):
Expand All @@ -27,6 +28,8 @@ def __init__(
executor_path: str,
databricks_conf: Optional[Dict[str, Dict[str, Dict]]] = None,
databricks_instance: Optional[str] = None,
driver_pod: Optional[K8sPod] = None,
executor_pod: Optional[K8sPod] = None,
):
"""
This defines a SparkJob target. It will execute the appropriate SparkJob.
Expand All @@ -47,6 +50,8 @@ def __init__(
databricks_conf = {}
self._databricks_conf = databricks_conf
self._databricks_instance = databricks_instance
self._driver_pod = driver_pod
self._executor_pod = executor_pod

def with_overrides(
self,
Expand All @@ -71,6 +76,8 @@ def with_overrides(
hadoop_conf=new_hadoop_conf,
databricks_conf=new_databricks_conf,
databricks_instance=self.databricks_instance,
driver_pod=self.driver_pod,
executor_pod=self.executor_pod,
Comment on lines +79 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing pod config in overrides method

Consider adding driver_pod and executor_pod to the with_overrides method to ensure consistent pod configuration overrides.

Code suggestion
Check the AI-generated fix before applying
 @@ -56,6 +56,8 @@ def with_overrides(
          new_spark_conf: Optional[Dict[str, str]] = None,
          new_hadoop_conf: Optional[Dict[str, str]] = None,
          new_databricks_conf: Optional[Dict[str, Dict]] = None,
 +        new_driver_pod: Optional[K8sPod] = None,
 +        new_executor_pod: Optional[K8sPod] = None,
      ) -> "SparkJob":
          if not new_spark_conf:
              new_spark_conf = self.spark_conf
 @@ -66,6 +68,12 @@ def with_overrides(
          if not new_databricks_conf:
              new_databricks_conf = self.databricks_conf
 
 +        if not new_driver_pod:
 +            new_driver_pod = self.driver_pod
 +
 +        if not new_executor_pod:
 +            new_executor_pod = self.executor_pod
 +
          return SparkJob(
              spark_type=self.spark_type,
              application_file=self.application_file,
 @@ -74,6 +82,8 @@ def with_overrides(
              hadoop_conf=new_hadoop_conf,
              databricks_conf=new_databricks_conf,
              databricks_instance=self.databricks_instance,
 +            driver_pod=new_driver_pod,
 +            executor_pod=new_executor_pod,
              executor_path=self.executor_path,
          )

Code Review Run #3c7587


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

executor_path=self.executor_path,
)

Expand Down Expand Up @@ -139,6 +146,22 @@ def databricks_instance(self) -> str:
"""
return self._databricks_instance

@property
def driver_pod(self) -> Optional[K8sPod]:
    """
    Additional pod spec applied to the Spark driver pod, or ``None`` when no
    driver pod template was configured (the constructor default).

    :rtype: Optional[K8sPod]
    """
    return self._driver_pod

@property
def executor_pod(self) -> Optional[K8sPod]:
    """
    Additional pod spec applied to the Spark executor (worker) pods, or
    ``None`` when no executor pod template was configured (the constructor
    default).

    :rtype: Optional[K8sPod]
    """
    return self._executor_pod

def to_flyte_idl(self):
"""
:rtype: flyteidl.plugins.spark_pb2.SparkJob
Expand Down Expand Up @@ -167,6 +190,8 @@ def to_flyte_idl(self):
hadoopConf=self.hadoop_conf,
databricksConf=databricks_conf,
databricksInstance=self.databricks_instance,
driverPod=self.driver_pod.to_flyte_idl() if self.driver_pod else None,
executorPod=self.executor_pod.to_flyte_idl() if self.executor_pod else None,
)

@classmethod
Expand All @@ -193,4 +218,6 @@ def from_flyte_idl(cls, pb2_object):
executor_path=pb2_object.executorPath,
databricks_conf=json_format.MessageToDict(pb2_object.databricksConf),
databricks_instance=pb2_object.databricksInstance,
driver_pod=pb2_object.driverPod,
executor_pod=pb2_object.executorPod,
)
28 changes: 28 additions & 0 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.pod_template import PodTemplate
from flytekit.core.utils import _serialize_pod_spec
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
from flytekit.models.task import K8sObjectMetadata, K8sPod

from .models import SparkJob, SparkType

Expand All @@ -31,12 +34,16 @@ class Spark(object):
hadoop_conf: Dictionary of hadoop conf. The variables should match a typical hadoop configuration for spark
executor_path: Python binary executable to use for PySpark in driver and executor.
applications_path: MainFile is the path to a bundled JAR, Python, or R file of the application to execute.
driver_pod: PodTemplate for Spark driver pod
executor_pod: PodTemplate for Spark executor pod
"""

spark_conf: Optional[Dict[str, str]] = None
hadoop_conf: Optional[Dict[str, str]] = None
executor_path: Optional[str] = None
applications_path: Optional[str] = None
driver_pod: Optional[PodTemplate] = None
executor_pod: Optional[PodTemplate] = None

def __post_init__(self):
if self.spark_conf is None:
Expand Down Expand Up @@ -168,6 +175,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
executor_path=self._default_executor_path or settings.python_interpreter,
main_class="",
spark_type=SparkType.PYTHON,
driver_pod=self.to_k8s_pod(settings, self.task_config.driver_pod),
executor_pod=self.to_k8s_pod(settings, self.task_config.executor_pod),
)
if isinstance(self.task_config, (Databricks, DatabricksV2)):
cfg = cast(DatabricksV2, self.task_config)
Expand All @@ -176,6 +185,25 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:

return MessageToDict(job.to_flyte_idl())

def to_k8s_pod(
    self, settings: SerializationSettings, pod_template: Optional[PodTemplate] = None
) -> Optional[K8sPod]:
    """
    Serialize a ``PodTemplate`` into a ``K8sPod`` model.

    Returns ``None`` when no template is supplied, so the result can be passed
    straight through to the SparkJob model without extra guards.
    """
    if pod_template is None:
        return None

    # Serialize the pod spec, letting _serialize_pod_spec know this is a Spark
    # task so it applies Spark-specific container handling.
    primary = self._get_container(settings)
    spec = _serialize_pod_spec(pod_template, primary, settings, task_type=self._SPARK_TASK_TYPE)
    meta = K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations)
    return K8sPod(pod_spec=spec, metadata=meta)

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
import pyspark as _pyspark

Expand Down
Loading
Loading