diff --git a/astronomer/providers/databricks/hooks/databricks.py b/astronomer/providers/databricks/hooks/databricks.py index a105170ee..53722ef0a 100644 --- a/astronomer/providers/databricks/hooks/databricks.py +++ b/astronomer/providers/databricks/hooks/databricks.py @@ -2,6 +2,7 @@ import asyncio import base64 +import warnings from typing import Any, Dict, cast import aiohttp @@ -21,21 +22,17 @@ class DatabricksHookAsync(DatabricksHook): """ - Interact with Databricks. - - :param databricks_conn_id: Reference to the Databricks connection. - :type databricks_conn_id: str - :param timeout_seconds: The amount of time in seconds the requests library - will wait before timing-out. - :type timeout_seconds: int - :param retry_limit: The number of times to retry the connection in case of - service outages. - :type retry_limit: int - :param retry_delay: The number of seconds to wait between retries (it - might be a floating point number). - :type retry_delay: float + This class is deprecated and will be removed in 2.0.0. + Use :class:`~airflow.providers.databricks.hooks.databricks.DatabricksHook` instead. """ + def __init__(self, *args: Any, **kwargs: Any): + warnings.warn( + "This class is deprecated and will be removed in 2.0.0. " + "Use `airflow.providers.databricks.hooks.databricks.DatabricksHook` instead." + ) + super().__init__(*args, **kwargs) + async def get_run_state_async(self, run_id: str) -> RunState: """ Retrieves run state of the run using an asynchronous api call. diff --git a/astronomer/providers/databricks/operators/databricks.py b/astronomer/providers/databricks/operators/databricks.py index 4e908a8ba..e5bb334fa 100644 --- a/astronomer/providers/databricks/operators/databricks.py +++ b/astronomer/providers/databricks/operators/databricks.py @@ -1,459 +1,41 @@ from __future__ import annotations +import warnings from typing import Any -from airflow.exceptions import AirflowException -from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState from airflow.providers.databricks.operators.databricks import ( - XCOM_RUN_ID_KEY, - XCOM_RUN_PAGE_URL_KEY, DatabricksRunNowOperator, DatabricksSubmitRunOperator, ) -from astronomer.providers.databricks.triggers.databricks import DatabricksTrigger -from astronomer.providers.utils.typing_compat import Context - - -def _handle_non_successful_terminal_states( - run_state: RunState, run_info: dict[str, Any], hook: DatabricksHook, task_id: str -) -> None: - """Raise AirflowException with detailed error message from run_info - - Check if the "result_state" is "FAILED".
- If not, raise an AirflowException with the "result_state" and "state_message" from getrun endpoint [1] - If so, it further digs into the result_state of each task to get error output from "error" in - getrunoutput endpoint [2] or "state_message" from getrun endpoint[1] - - [1] https://docs.databricks.com/api-explorer/workspace/jobs/getrun - [2] https://docs.databricks.com/api-explorer/workspace/jobs/getrunoutput - - :param run_state: the state information extract from run_info["state"] - :param run_info: response from https://docs.databricks.com/api-explorer/workspace/jobs/getrun - :param hook: hook to connect to Databricks - """ - if run_state.result_state == "FAILED": - task_run_id = None - if "tasks" in run_info: - for task in run_info["tasks"]: - if task.get("state", {}).get("result_state", "") == "FAILED": - task_run_id = task["run_id"] - if task_run_id is not None: - run_output = hook.get_run_output(task_run_id) - if "error" in run_output: - notebook_error = run_output["error"] - else: - notebook_error = run_state.state_message - else: - notebook_error = run_state.state_message - error_message = ( - f"{task_id} failed with terminal state: {run_state} " f"and with the error {notebook_error}" - ) - else: - error_message = ( - f"{task_id} failed with terminal state: {run_state} " - f"and with the error {run_state.state_message}" - ) - raise AirflowException(error_message) - class DatabricksSubmitRunOperatorAsync(DatabricksSubmitRunOperator): """ - Submits a Spark job run to Databricks using the - `api/2.1/jobs/runs/submit - `_ - API endpoint. Using DatabricksHook, it makes two non-async API calls to - submit the run, and retrieve the run page URL. By getting the job id from the response, polls for the status - in the Databricks trigger, and defer execution as expected. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:DatabricksSubmitRunOperator` - - :param tasks: Array of Objects(RunSubmitTaskSettings) <= 100 items. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit - :param json: A JSON object containing API parameters which will be passed - directly to the ``api/2.1/jobs/runs/submit`` endpoint. The other named parameters - (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will - be merged with this json dictionary if they are provided. - If there are conflicts during the merge, the named parameters will - take precedence and override the top level json keys. (templated) - - .. seealso:: - For more information about templating see :ref:`concepts:jinja-templating`. - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit - :param spark_jar_task: The main class and parameters for the JAR task. Note that - the actual JAR is specified in the ``libraries``. - *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparkjartask - :param notebook_task: The notebook path and parameters for the notebook task. - *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. - This field will be templated. - - .. 
seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobsnotebooktask - :param spark_python_task: The python file path and parameters to run the python file with. - *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparkpythontask - :param spark_submit_task: Parameters needed to run a spark-submit command. - *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparksubmittask - :param pipeline_task: Parameters needed to execute a Delta Live Tables pipeline task. - The provided dictionary must contain at least ``pipeline_id`` field! - *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task`` - *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobspipelinetask - :param new_cluster: Specs for a new cluster on which this task will be run. - *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified - (except when ``pipeline_task`` is used). - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobsclusterspecnewcluster - :param existing_cluster_id: ID for existing cluster on which to run this task. - *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified - (except when ``pipeline_task`` is used). - This field will be templated. - :param libraries: Libraries which this run will use. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/2.0/jobs.html#managedlibrarieslibrary - :param run_name: The run name used for this task. - By default this will be set to the Airflow ``task_id``. This ``task_id`` is a - required parameter of the superclass ``BaseOperator``. - This field will be templated. - :param idempotency_token: an optional token that can be used to guarantee the idempotency of job run - requests. If a run with the provided token already exists, the request does not create a new run but - returns the ID of the existing run instead. This token must have at most 64 characters. - :param access_control_list: optional list of dictionaries representing Access Control List (ACL) for - a given job run. Each dictionary consists of following field - specific subject (``user_name`` for - users, or ``group_name`` for groups), and ``permission_level`` for that subject. See Jobs API - documentation for more details. - :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. - :param timeout_seconds: The timeout for this run. By default a value of 0 is used - which means to have no timeout. - This field will be templated. - :param databricks_conn_id: Reference to the :ref:`Databricks connection `. - By default and in the common case this will be ``databricks_default``. To use - token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. (templated) - :param polling_period_seconds: Controls the rate which we poll for the result of - this run. 
By default the operator will poll every 30 seconds. - :param databricks_retry_limit: Amount of times retry if the Databricks backend is - unreachable. Its value must be greater than or equal to 1. - :param databricks_retry_delay: Number of seconds to wait between retries (it - might be a floating point number). - :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. - :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. - :param git_source: Optional specification of a remote git repository from which - supported task types are retrieved. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit + This class is deprecated. + Use :class: `~airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator` and set + `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> None: - """ - Execute the Databricks trigger, and defer execution as expected. It makes two non-async API calls to - submit the run, and retrieve the run page URL. It also pushes these - values as xcom data if do_xcom_push is set to True in the context. - """ - # Note: This hook makes non-async calls. - # It is imported from the Databricks base class. - # Async calls (i.e. polling) are handled in the Trigger. - try: - # for apache-airflow-providers-databricks<3.2.0 - hook = self._get_hook() # type: ignore[call-arg] - except TypeError: - # for apache-airflow-providers-databricks>=3.2.0 - hook = self._get_hook(caller="DatabricksSubmitRunOperatorAsync") - self.run_id = hook.submit_run(self.json) - job_id = hook.get_job_id(self.run_id) - - if self.do_xcom_push: - context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=self.run_id) - self.log.info("Run submitted with run_id: %s", self.run_id) - self.run_page_url = hook.get_run_page_url(self.run_id) - if self.do_xcom_push: - context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=self.run_page_url) - - self.log.info("View run status, Spark UI, and logs at %s", self.run_page_url) - - run_info = hook.get_run(self.run_id) - run_state = RunState(**run_info["state"]) - if not run_state.is_terminal: - self.defer( - timeout=self.execution_timeout, - trigger=DatabricksTrigger( - conn_id=self.databricks_conn_id, - task_id=self.task_id, - run_id=str(self.run_id), - job_id=job_id, - run_page_url=self.run_page_url, - retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay, - polling_period_seconds=self.polling_period_seconds, - ), - method_name="execute_complete", - ) - else: - if run_state.is_successful: - self.log.info("%s completed successfully.", self.task_id) - return - else: - _handle_non_successful_terminal_states(run_state, run_info, hook, self.task_id) - - def execute_complete(self, context: Context, event: Any = None) -> None: # type: ignore[override] - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info("%s completed successfully.", self.task_id) - if event.get("job_id"): - context["ti"].xcom_push(key="job_id", value=event["job_id"]) + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "This class is deprecated." + "Use `airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator` " + "and set `deferrable` param to `True` instead." 
+ ) + super().__init__(*args, deferrable=True, **kwargs) class DatabricksRunNowOperatorAsync(DatabricksRunNowOperator): """ - Runs an existing Spark job run to Databricks using the - `api/2.1/jobs/run-now - `_ - API endpoint. - - There are two ways to instantiate this operator. - - In the first way, you can take the JSON payload that you typically use - to call the ``api/2.1/jobs/run-now`` endpoint and pass it directly - to our ``DatabricksRunNowOperator`` through the ``json`` parameter. - For example :: - - json = { - "job_id": 42, - "notebook_params": { - "dry-run": "true", - "oldest-time-to-consider": "1457570074236" - } - } - - notebook_run = DatabricksRunNowOperator(task_id='notebook_run', json=json) - - Another way to accomplish the same thing is to use the named parameters - of the ``DatabricksRunNowOperator`` directly. Note that there is exactly - one named parameter for each top level parameter in the ``run-now`` - endpoint. In this method, your code would look like this: :: - - job_id=42 - - notebook_params = { - "dry-run": "true", - "oldest-time-to-consider": "1457570074236" - } - - python_params = ["douglas adams", "42"] - - jar_params = ["douglas adams", "42"] - - spark_submit_params = ["--class", "org.apache.spark.examples.SparkPi"] - - notebook_run = DatabricksRunNowOperator( - job_id=job_id, - notebook_params=notebook_params, - python_params=python_params, - jar_params=jar_params, - spark_submit_params=spark_submit_params - ) - - In the case where both the json parameter **AND** the named parameters - are provided, they will be merged together. If there are conflicts during the merge, - the named parameters will take precedence and override the top level ``json`` keys. - - Currently the named parameters that ``DatabricksRunNowOperator`` supports are - - ``job_id`` - - ``job_name`` - - ``json`` - - ``notebook_params`` - - ``python_params`` - - ``python_named_parameters`` - - ``jar_params`` - - ``spark_submit_params`` - - ``idempotency_token`` - - :param job_id: the job_id of the existing Databricks job. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param job_name: the name of the existing Databricks job. - It must exist only one job with the specified name. - ``job_id`` and ``job_name`` are mutually exclusive. - This field will be templated. - :param json: A JSON object containing API parameters which will be passed - directly to the ``api/2.1/jobs/run-now`` endpoint. The other named parameters - (i.e. ``notebook_params``, ``spark_submit_params``..) to this operator will - be merged with this json dictionary if they are provided. - If there are conflicts during the merge, the named parameters will - take precedence and override the top level json keys. (templated) - - .. seealso:: - For more information about templating see :ref:`concepts:jinja-templating`. - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param notebook_params: A dict from keys to values for jobs with notebook task, - e.g. "notebook_params": {"name": "john doe", "age": "35"}. - The map is passed to the notebook and will be accessible through the - dbutils.widgets.get function. See Widgets for more information. - If not specified upon run-now, the triggered run will use the - job's base parameters. notebook_params cannot be - specified in conjunction with jar_params. The json representation - of this field (i.e. 
{"notebook_params":{"name":"john doe","age":"35"}}) - cannot exceed 10,000 bytes. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/user-guide/notebooks/widgets.html - :param python_params: A list of parameters for jobs with python tasks, - e.g. "python_params": ["john doe", "35"]. - The parameters will be passed to python file as command line parameters. - If specified upon run-now, it would overwrite the parameters specified in job setting. - The json representation of this field (i.e. {"python_params":["john doe","35"]}) - cannot exceed 10,000 bytes. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param python_named_params: A list of named parameters for jobs with python wheel tasks, - e.g. "python_named_params": {"name": "john doe", "age": "35"}. - If specified upon run-now, it would overwrite the parameters specified in job setting. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param jar_params: A list of parameters for jobs with JAR tasks, - e.g. "jar_params": ["john doe", "35"]. - The parameters will be passed to JAR file as command line parameters. - If specified upon run-now, it would overwrite the parameters specified in - job setting. - The json representation of this field (i.e. {"jar_params":["john doe","35"]}) - cannot exceed 10,000 bytes. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param spark_submit_params: A list of parameters for jobs with spark submit task, - e.g. "spark_submit_params": ["--class", "org.apache.spark.examples.SparkPi"]. - The parameters will be passed to spark-submit script as command line parameters. - If specified upon run-now, it would overwrite the parameters specified - in job setting. - The json representation of this field cannot exceed 10,000 bytes. - This field will be templated. - - .. seealso:: - https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow - :param idempotency_token: an optional token that can be used to guarantee the idempotency of job run - requests. If a run with the provided token already exists, the request does not create a new run but - returns the ID of the existing run instead. This token must have at most 64 characters. - :param databricks_conn_id: Reference to the :ref:`Databricks connection `. - By default and in the common case this will be ``databricks_default``. To use - token based authentication, provide the key ``token`` in the extra field for the - connection and create the key ``host`` and leave the ``host`` field empty. (templated) - :param polling_period_seconds: Controls the rate which we poll for the result of - this run. By default, the operator will poll every 30 seconds. - :param databricks_retry_limit: Amount of times retry if the Databricks backend is - unreachable. Its value must be greater than or equal to 1. - :param databricks_retry_delay: Number of seconds to wait between retries (it - might be a floating point number). - :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. - :param do_xcom_push: Whether we should push run_id and run_page_url to xcom. - :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default. + This class is deprecated. 
+ Use :class: `~airflow.providers.databricks.operators.databricks.DatabricksRunNowOperator` and set + `deferrable` param to `True` instead. """ - def execute(self, context: Context) -> None: - """ - Logic that the operator uses to execute the Databricks trigger, - and defer execution as expected. It makes two non-async API calls to - submit the run, and retrieve the run page URL. It also pushes these - values as xcom data if do_xcom_push is set to True in the context. - """ - # Note: This hook makes non-async calls. - # It is from the Databricks base class. - try: - # for apache-airflow-providers-databricks<3.2.0 - hook = self._get_hook() # type: ignore[call-arg] - except TypeError: - # for apache-airflow-providers-databricks>=3.2.0 - hook = self._get_hook(caller="DatabricksRunNowOperatorAsync") - - if "job_name" in self.json: - job_id = hook.find_job_id_by_name(self.json["job_name"]) - if job_id is None: - raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found") - self.json["job_id"] = job_id - del self.json["job_name"] - self.run_id = hook.run_now(self.json) - - if self.do_xcom_push: - context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=self.run_id) - self.log.info("Run submitted with run_id: %s", self.run_id) - self.run_page_url = hook.get_run_page_url(self.run_id) - if self.do_xcom_push: - context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=self.run_page_url) - - self.log.info("View run status, Spark UI, and logs at %s", self.run_page_url) - - run_info = hook.get_run(self.run_id) - run_state = RunState(**run_info["state"]) - if not run_state.is_terminal: - self.defer( - timeout=self.execution_timeout, - trigger=DatabricksTrigger( - task_id=self.task_id, - conn_id=self.databricks_conn_id, - run_id=str(self.run_id), - run_page_url=self.run_page_url, - retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay, - polling_period_seconds=self.polling_period_seconds, - ), - method_name="execute_complete", - ) - elif run_state.is_terminal: - if run_state.is_successful: - self.log.info("%s completed successfully.", self.task_id) - return - else: - _handle_non_successful_terminal_states(run_state, run_info, hook, self.task_id) - - def execute_complete( - self, context: Context, event: Any = None - ) -> None: # pylint: disable=unused-argument - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. - """ - if event["status"] == "error": - raise AirflowException(event["message"]) - - self.log.info("%s completed successfully.", self.task_id) - return None + def __init__(self, *args: Any, **kwargs: Any) -> None: + warnings.warn( + "This class is deprecated." + "Use `airflow.providers.databricks.operators.databricks.DatabricksRunNowOperator` " + "and set `deferrable` param to `True` instead." + ) + super().__init__(*args, deferrable=True, **kwargs) diff --git a/astronomer/providers/databricks/triggers/databricks.py b/astronomer/providers/databricks/triggers/databricks.py index 1506143b7..0422e64d6 100644 --- a/astronomer/providers/databricks/triggers/databricks.py +++ b/astronomer/providers/databricks/triggers/databricks.py @@ -1,4 +1,5 @@ import asyncio +import warnings from typing import Any, AsyncIterator, Dict, Optional, Tuple from airflow.providers.databricks.hooks.databricks import RunState @@ -9,20 +10,8 @@ class DatabricksTrigger(BaseTrigger): """ - Wait asynchronously for databricks job to reach the terminal state. 
- - :param conn_id: The databricks connection id. - The default value is ``databricks_default``. - :param task_id: The task id. - :param run_id: The databricks job run id. - :param retry_limit: Amount of times retry if the Databricks backend is - unreachable. Its value must be greater than or equal to 1. - :param retry_delay: Number of seconds to wait between retries (it - might be a floating point number). - :param polling_period_seconds: Controls the rate which we poll for the result of - this run. By default, the operator will poll every 30 seconds. - :param job_id: The databricks job id. - :param run_page_url: The databricks run page url. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` instead. """ def __init__( @@ -36,6 +25,10 @@ def __init__( job_id: Optional[int] = None, run_page_url: Optional[str] = None, ): + warnings.warn( + "This class is deprecated and will be removed in 2.0.0." + "Use `airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger` instead." + ) super().__init__() self.conn_id = conn_id self.task_id = task_id diff --git a/setup.cfg b/setup.cfg index f9ba982d5..db125178c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -56,7 +56,7 @@ cncf.kubernetes = apache-airflow-providers-cncf-kubernetes>=4 kubernetes_asyncio databricks = - apache-airflow-providers-databricks>=2.2.0 + apache-airflow-providers-databricks>=6.1.0 databricks-sql-connector>=2.0.4;python_version>='3.10' dbt.cloud = apache-airflow-providers-dbt-cloud>=3.5.1 @@ -122,7 +122,7 @@ all = apache-airflow-providers-apache-hive>=6.1.5 apache-airflow-providers-apache-livy>=3.7.1 apache-airflow-providers-cncf-kubernetes>=4 - apache-airflow-providers-databricks>=2.2.0 + apache-airflow-providers-databricks>=6.1.0 apache-airflow-providers-google>=10.14.0 apache-airflow-providers-http apache-airflow-providers-snowflake diff --git a/tests/databricks/operators/test_databricks.py b/tests/databricks/operators/test_databricks.py index fa7b19e7f..7a98aa221 100644 --- a/tests/databricks/operators/test_databricks.py +++ b/tests/databricks/operators/test_databricks.py @@ -1,463 +1,37 @@ -from unittest import mock +from __future__ import annotations -import pytest -from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.databricks.operators.databricks import ( + DatabricksRunNowOperator, + DatabricksSubmitRunOperator, +) from astronomer.providers.databricks.operators.databricks import ( DatabricksRunNowOperatorAsync, DatabricksSubmitRunOperatorAsync, ) -from astronomer.providers.databricks.triggers.databricks import DatabricksTrigger -from tests.utils.airflow_util import create_context -JOB_ID = "42" -TASK_ID = "databricks_check" CONN_ID = "databricks_default" -RUN_ID = "1" -RUN_PAGE_URL = "https://www.test.com" -RETRY_LIMIT = 2 -RETRY_DELAY = 1.0 -POLLING_PERIOD_SECONDS = 1.0 -XCOM_RUN_ID_KEY = "run_id" -XCOM_RUN_PAGE_URL_KEY = "run_page_url" - - -def make_run_with_state_mock(lifecycle_state: str, result_state: str, state_message: str = ""): - return { - "state": { - "life_cycle_state": lifecycle_state, - "result_state": result_state, - "state_message": state_message, - }, - } class TestDatabricksSubmitRunOperatorAsync: - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksSubmitRunOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - 
@mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_job_id") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_operator_async_succeeded_before_defered( - self, submit_run_response, get_job_id, get_run_page_url_response, get_run, defer - ): - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_job_id.return_value = None - get_run.return_value = make_run_with_state_mock("TERMINATED", "SUCCESS") - - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - ) - - operator.execute(context=create_context(operator)) - - assert not defer.called - - @pytest.mark.parametrize("result_state", ("FAILED", "UNEXPECTED")) - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksSubmitRunOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_job_id") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_operator_async_failed_before_defered( - self, submit_run_response, get_job_id, get_run_page_url_response, get_run, defer, result_state - ): - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_job_id.return_value = None - get_run.return_value = make_run_with_state_mock("TERMINATED", result_state) - - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - ) - with pytest.raises(AirflowException): - operator.execute(context=create_context(operator)) - - assert not defer.called - - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksSubmitRunOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_output") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_job_id") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_operator_async_failed_with_error_in_run_output_before_defered( - self, - submit_run_response, - get_job_id, - get_run_page_url_response, - get_run, - get_run_output, - defer, - ): - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_job_id.return_value = None - get_run.return_value = make_run_with_state_mock("TERMINATED", "FAILED") - get_run.return_value["tasks"] = [{"state": {"result_state": "FAILED"}, "run_id": RUN_ID}] - get_run_output.return_value = {"error": "notebook error"} - - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - 
existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - ) - with pytest.raises(AirflowException): - operator.execute(context=create_context(operator)) - - assert not defer.called - - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksSubmitRunOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_output") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_job_id") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_operator_async_failed_without_error_in_run_output_before_defered( - self, - submit_run_response, - get_job_id, - get_run_page_url_response, - get_run, - get_run_output, - defer, - ): - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_job_id.return_value = None - get_run.return_value = make_run_with_state_mock("TERMINATED", "FAILED") - get_run.return_value["tasks"] = [{"state": {"result_state": "FAILED"}, "run_id": RUN_ID}] - - operator = DatabricksSubmitRunOperatorAsync( + def test_init(self): + task = DatabricksSubmitRunOperatorAsync( task_id="submit_run", databricks_conn_id=CONN_ID, existing_cluster_id="xxxx-xxxxxx-xxxxxx", notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, ) - with pytest.raises(AirflowException): - operator.execute(context=create_context(operator)) - assert not defer.called - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_job_id") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_operator_async( - self, submit_run_response, get_job_id, get_run_page_url_response, get_run - ): - """ - Asserts that a task is deferred and an DatabricksTrigger will be fired - when the DatabricksSubmitRunOperatorAsync is executed. - """ - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_job_id.return_value = None - get_run.return_value = make_run_with_state_mock("RUNNING", "SUCCESS") - - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - ) - - with pytest.raises(TaskDeferred) as exc: - operator.execute(context=create_context(operator)) - - assert isinstance(exc.value.trigger, DatabricksTrigger), "Trigger is not a DatabricksTrigger" - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_execute_complete_error( - self, submit_run_response, get_run_page_url_response - ): - """ - Asserts that a task is completed with success status. 
- """ - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - ) - - with pytest.raises(AirflowException): - operator.execute_complete(context={}, event={"status": "error", "message": "error"}) - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_submit_run_execute_complete_success( - self, submit_run_response, get_run_page_url_response - ): - """Asserts that a task is completed with success status.""" - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - do_xcom_push=True, - ) - - assert ( - operator.execute_complete( - context=create_context(operator), - event={ - "status": "success", - "message": "success", - "job_id": "12345", - "run_id": RUN_ID, - "run_page_url": RUN_PAGE_URL, - }, - ) - is None - ) - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator._get_hook") - def test_databricks_submit_run_operator_async_hook(self, mock_get_hook): - """ - Asserts that the hook raises TypeError for apache-airflow-providers-databricks>=3.2.0 - when the DatabricksSubmitRunOperatorAsync is executed. 
- """ - mock_get_hook.side_effect = TypeError("test exception") - operator = DatabricksSubmitRunOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - existing_cluster_id="xxxx-xxxxxx-xxxxxx", - notebook_task={"notebook_path": "/Users/test@astronomer.io/Quickstart Notebook"}, - ) - - with pytest.raises(TypeError): - operator.execute(context=create_context(operator)) + assert isinstance(task, DatabricksSubmitRunOperator) + assert task.deferrable is True class TestDatabricksRunNowOperatorAsync: - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksRunNowOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.run_now") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_operator_async_succeeded_before_defered( - self, run_now_response, get_run_page_url_response, get_run, defer - ): - run_now_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_run.return_value = make_run_with_state_mock("TERMINATED", "SUCCESS") - - operator = DatabricksRunNowOperatorAsync( - task_id="run_now", - databricks_conn_id=CONN_ID, - ) - - operator.execute(context=create_context(operator)) - - assert not defer.called - - @pytest.mark.parametrize("result_state", ("FAILED", "UNEXPECTED")) - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksRunNowOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.run_now") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_operator_async_failed_before_defered( - self, run_now_response, get_run_page_url_response, get_run, defer, result_state - ): - run_now_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_run.return_value = make_run_with_state_mock("TERMINATED", result_state) - - operator = DatabricksRunNowOperatorAsync( - task_id="run_now", - databricks_conn_id=CONN_ID, - ) - with pytest.raises(AirflowException): - operator.execute(context=create_context(operator)) - - assert not defer.called - - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksRunNowOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_output") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.run_now") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_operator_failed_with_error_in_run_output_before_defered( - self, - run_now_response, - get_run_page_url_response, - get_run, - get_run_output, - defer, - ): - run_now_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_run.return_value = make_run_with_state_mock("TERMINATED", "FAILED") - get_run.return_value["tasks"] = [{"state": {"result_state": "FAILED"}, "run_id": RUN_ID}] - get_run_output.return_value = {"error": "notebook error"} - - operator = DatabricksRunNowOperatorAsync( - task_id="run_now", - databricks_conn_id=CONN_ID, - ) - with 
pytest.raises(AirflowException): - operator.execute(context=create_context(operator)) - - assert not defer.called - - @mock.patch("astronomer.providers.databricks.operators.databricks.DatabricksRunNowOperatorAsync.defer") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_output") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.run_now") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_operator_async_failed_without_error_in_run_output_before_defered( - self, - run_now_response, - get_run_page_url_response, - get_run, - get_run_output, - defer, - ): - run_now_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_run.return_value = make_run_with_state_mock("TERMINATED", "FAILED") - get_run.return_value["tasks"] = [{"state": {"result_state": "FAILED"}, "run_id": RUN_ID}] - - operator = DatabricksRunNowOperatorAsync( - task_id="run_now", - databricks_conn_id=CONN_ID, - ) - with pytest.raises(AirflowException): - operator.execute(context=create_context(operator)) - - assert not defer.called - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.run_now") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_operator_async(self, run_now_response, get_run_page_url_response, get_run): - """ - Asserts that a task is deferred and an DatabricksTrigger will be fired - when the DatabricksRunNowOperatorAsync is executed. - """ - run_now_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_run.return_value = make_run_with_state_mock("RUNNING", "SUCCESS") - - operator = DatabricksRunNowOperatorAsync( - task_id="run_now", - databricks_conn_id=CONN_ID, - ) - - with pytest.raises(TaskDeferred) as exc: - operator.execute(context=create_context(operator)) - - assert isinstance(exc.value.trigger, DatabricksTrigger), "Trigger is not a DatabricksTrigger" - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.find_job_id_by_name") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.run_now") - @mock.patch("astronomer.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_operator_async_with_job_name( - self, run_now_response, get_run_page_url_response, get_run, find_job_id_by_name - ): - """ - Asserts that a task is deferred and an DatabricksTrigger will be fired - when the DatabricksRunNowOperatorAsync is executed. 
- """ - run_now_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - get_run.return_value = make_run_with_state_mock("RUNNING", "SUCCESS") - find_job_id_by_name.return_value = 1 - - operator = DatabricksRunNowOperatorAsync( - task_id="run_now", databricks_conn_id=CONN_ID, job_name="mock_name" - ) - - with pytest.raises(TaskDeferred) as exc: - operator.execute(context=create_context(operator)) - - assert isinstance(exc.value.trigger, DatabricksTrigger), "Trigger is not a DatabricksTrigger" - - def test_databricks_run_now_execute_complete(self): - """Asserts that logging occurs as expected""" - operator = DatabricksRunNowOperatorAsync( - task_id=TASK_ID, - databricks_conn_id=CONN_ID, - do_xcom_push=True, - ) - operator.run_page_url = RUN_PAGE_URL - with mock.patch.object(operator.log, "info") as mock_log_info: - operator.execute_complete(create_context(operator), {"status": "success", "message": "success"}) - mock_log_info.assert_called_with("%s completed successfully.", "databricks_check") - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_execute_complete_error(self, submit_run_response, get_run_page_url_response): - """Asserts that a task is completed with success status.""" - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - - operator = DatabricksRunNowOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - job_id="12345", - ) - - with pytest.raises(AirflowException): - operator.execute_complete(context={}, event={"status": "error", "message": "error"}) - - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.submit_run") - @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.get_run_page_url") - def test_databricks_run_now_execute_complete_success( - self, submit_run_response, get_run_page_url_response - ): - """Asserts that a task is completed with success status.""" - submit_run_response.return_value = {"run_id": RUN_ID} - get_run_page_url_response.return_value = RUN_PAGE_URL - - operator = DatabricksRunNowOperatorAsync( - task_id="submit_run", - databricks_conn_id=CONN_ID, - job_id="12345", - do_xcom_push=True, - ) - - assert ( - operator.execute_complete( - context=create_context(operator), - event={ - "status": "success", - "message": "success", - "job_id": "12345", - "run_id": RUN_ID, - "run_page_url": RUN_PAGE_URL, - }, - ) - is None - ) - - @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksRunNowOperator._get_hook") - def test_databricks_run_now_operator_async_hook(self, mock_get_hook): - """ - Asserts that the hook raises TypeError for apache-airflow-providers-databricks>=3.2.0 - when the DatabricksRunNowOperatorAsync is executed. - """ - mock_get_hook.side_effect = TypeError("test exception") - operator = DatabricksRunNowOperatorAsync( + def test_init(self): + task = DatabricksRunNowOperatorAsync( task_id="run_now", databricks_conn_id=CONN_ID, ) - with pytest.raises(TypeError): - operator.execute(context=create_context(operator)) + assert isinstance(task, DatabricksRunNowOperator) + assert task.deferrable is True