diff --git a/airflow/providers/microsoft/azure/hooks/msgraph.py b/airflow/providers/microsoft/azure/hooks/msgraph.py index 7fcc328f8670a..84b2252bd25ae 100644 --- a/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -17,28 +17,74 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING -from urllib.parse import urljoin, urlparse +import json +from http import HTTPStatus +from io import BytesIO +from typing import TYPE_CHECKING, Any, Callable +from urllib.parse import quote, urljoin, urlparse import httpx from azure.identity import ClientSecretCredential from httpx import Timeout +from kiota_abstractions.api_error import APIError +from kiota_abstractions.method import Method +from kiota_abstractions.request_information import RequestInformation +from kiota_abstractions.response_handler import ResponseHandler from kiota_authentication_azure.azure_identity_authentication_provider import ( AzureIdentityAuthenticationProvider, ) from kiota_http.httpx_request_adapter import HttpxRequestAdapter -from msgraph_core import GraphClientFactory -from msgraph_core._enums import APIVersion, NationalClouds +from kiota_http.middleware.options import ResponseHandlerOption +from msgraph_core import APIVersion, GraphClientFactory +from msgraph_core._enums import NationalClouds -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException from airflow.hooks.base import BaseHook if TYPE_CHECKING: from kiota_abstractions.request_adapter import RequestAdapter + from kiota_abstractions.request_information import QueryParams + from kiota_abstractions.response_handler import NativeResponseType + from kiota_abstractions.serialization import ParsableFactory + from kiota_http.httpx_request_adapter import ResponseType from airflow.models import Connection +class CallableResponseHandler(ResponseHandler): + """ + CallableResponseHandler executes the passed callable_function with response as parameter. + + param callable_function: Function that is applied to the response. + """ + + def __init__( + self, + callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any], + ): + self.callable_function = callable_function + + async def handle_response_async( + self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None + ) -> Any: + """ + Invoke this callback method when a response is received. + + param response: The type of the native response object. + param error_map: The error dict to use in case of a failed request. + """ + value = self.callable_function(response, error_map) + if response.status_code not in {200, 201, 202, 204, 302}: + message = value or response.reason_phrase + status_code = HTTPStatus(response.status_code) + if status_code == HTTPStatus.BAD_REQUEST: + raise AirflowBadRequest(message) + elif status_code == HTTPStatus.NOT_FOUND: + raise AirflowNotFoundException(message) + raise AirflowException(message) + return value + + class KiotaRequestAdapterHook(BaseHook): """ A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter. @@ -54,6 +100,7 @@ class KiotaRequestAdapterHook(BaseHook): or you can pass a string as "v1.0" or "beta". """ + DEFAULT_HEADERS = {"Accept": "application/json;q=1"} cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {} default_conn_name: str = "msgraph_default" @@ -117,13 +164,16 @@ def to_httpx_proxies(cls, proxies: dict) -> dict: proxies[cls.format_no_proxy_url(url.strip())] = None return proxies - @classmethod - def to_msal_proxies(cls, authority: str | None, proxies: dict): + def to_msal_proxies(self, authority: str | None, proxies: dict): + self.log.info("authority: %s", authority) if authority: no_proxies = proxies.get("no") + self.log.info("no_proxies: %s", no_proxies) if no_proxies: for url in no_proxies.split(","): + self.log.info("url: %s", url) domain_name = urlparse(url).path.replace("*", "") + self.log.info("domain_name: %s", domain_name) if authority.endswith(domain_name): return None return proxies @@ -206,3 +256,103 @@ def get_conn(self) -> RequestAdapter: self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) self._api_version = api_version return request_adapter + + def test_connection(self): + """Test HTTP Connection.""" + try: + self.run() + return True, "Connection successfully tested" + except Exception as e: + return False, str(e) + + async def run( + self, + url: str = "", + response_type: ResponseType | None = None, + response_handler: Callable[ + [NativeResponseType, dict[str, ParsableFactory | None] | None], Any + ] = lambda response, error_map: response.json(), + path_parameters: dict[str, Any] | None = None, + method: str = "GET", + query_parameters: dict[str, QueryParams] | None = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | str | BytesIO | None = None, + ): + response = await self.get_conn().send_primitive_async( + request_info=self.request_information( + url=url, + response_type=response_type, + response_handler=response_handler, + path_parameters=path_parameters, + method=method, + query_parameters=query_parameters, + headers=headers, + data=data, + ), + response_type=response_type, + error_map=self.error_mapping(), + ) + + self.log.debug("response: %s", response) + + return response + + def request_information( + self, + url: str, + response_type: ResponseType | None = None, + response_handler: Callable[ + [NativeResponseType, dict[str, ParsableFactory | None] | None], Any + ] = lambda response, error_map: response.json(), + path_parameters: dict[str, Any] | None = None, + method: str = "GET", + query_parameters: dict[str, QueryParams] | None = None, + headers: dict[str, str] | None = None, + data: dict[str, Any] | str | BytesIO | None = None, + ) -> RequestInformation: + request_information = RequestInformation() + request_information.path_parameters = path_parameters or {} + request_information.http_method = Method(method.strip().upper()) + request_information.query_parameters = self.encoded_query_parameters(query_parameters) + if url.startswith("http"): + request_information.url = url + elif request_information.query_parameters.keys(): + query = ",".join(request_information.query_parameters.keys()) + request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}{{?{query}}}" + else: + request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}" + if not response_type: + request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption( + response_handler=CallableResponseHandler(response_handler) + ) + headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS + for header_name, header_value in headers.items(): + request_information.headers.try_add(header_name=header_name, header_value=header_value) + self.log.info("data: %s", data) + if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str): + request_information.content = data + elif data: + request_information.headers.try_add( + header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" + ) + request_information.content = json.dumps(data).encode("utf-8") + return request_information + + @staticmethod + def normalize_url(url: str) -> str | None: + if url.startswith("/"): + return url.replace("/", "", 1) + return url + + @staticmethod + def encoded_query_parameters(query_parameters) -> dict: + if query_parameters: + return {quote(key): value for key, value in query_parameters.items()} + return {} + + @staticmethod + def error_mapping() -> dict[str, ParsableFactory | None]: + return { + "4XX": APIError, + "5XX": APIError, + } diff --git a/airflow/providers/microsoft/azure/triggers/msgraph.py b/airflow/providers/microsoft/azure/triggers/msgraph.py index 814a91f7c10a4..1848f969f8431 100644 --- a/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -22,7 +22,6 @@ from base64 import b64encode from contextlib import suppress from datetime import datetime -from io import BytesIO from json import JSONDecodeError from typing import ( TYPE_CHECKING, @@ -31,21 +30,17 @@ Callable, Sequence, ) -from urllib.parse import quote from uuid import UUID import pendulum -from kiota_abstractions.api_error import APIError -from kiota_abstractions.method import Method -from kiota_abstractions.request_information import RequestInformation -from kiota_abstractions.response_handler import ResponseHandler -from kiota_http.middleware.options import ResponseHandlerOption from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.module_loading import import_string if TYPE_CHECKING: + from io import BytesIO + from kiota_abstractions.request_adapter import RequestAdapter from kiota_abstractions.request_information import QueryParams from kiota_abstractions.response_handler import NativeResponseType @@ -87,31 +82,6 @@ def deserialize(self, response) -> Any: return response -class CallableResponseHandler(ResponseHandler): - """ - CallableResponseHandler executes the passed callable_function with response as parameter. - - param callable_function: Function that is applied to the response. - """ - - def __init__( - self, - callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any], - ): - self.callable_function = callable_function - - async def handle_response_async( - self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None - ) -> Any: - """ - Invoke this callback method when a response is received. - - param response: The type of the native response object. - param error_map: The error dict to use in case of a failed request. - """ - return self.callable_function(response, error_map) - - class MSGraphTrigger(BaseTrigger): """ A Microsoft Graph API trigger which allows you to execute an async REST call to the Microsoft Graph API. @@ -134,7 +104,6 @@ class MSGraphTrigger(BaseTrigger): Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ - DEFAULT_HEADERS = {"Accept": "application/json;q=1"} template_fields: Sequence[str] = ( "url", "response_type", @@ -235,7 +204,16 @@ def api_version(self) -> APIVersion: async def run(self) -> AsyncIterator[TriggerEvent]: """Make a series of asynchronous HTTP calls via a KiotaRequestAdapterHook.""" try: - response = await self.execute() + response = await self.hook.run( + url=self.url, + response_type=self.response_type, + response_handler=self.response_handler, + path_parameters=self.path_parameters, + method=self.method, + query_parameters=self.query_parameters, + headers=self.headers, + data=self.data, + ) self.log.debug("response: %s", response) @@ -262,55 +240,3 @@ async def run(self) -> AsyncIterator[TriggerEvent]: except Exception as e: self.log.exception("An error occurred: %s", e) yield TriggerEvent({"status": "failure", "message": str(e)}) - - def normalize_url(self) -> str | None: - if self.url.startswith("/"): - return self.url.replace("/", "", 1) - return self.url - - def encoded_query_parameters(self) -> dict: - if self.query_parameters: - return {quote(key): value for key, value in self.query_parameters.items()} - return {} - - def request_information(self) -> RequestInformation: - request_information = RequestInformation() - request_information.path_parameters = self.path_parameters or {} - request_information.http_method = Method(self.method.strip().upper()) - request_information.query_parameters = self.encoded_query_parameters() - if self.url.startswith("http"): - request_information.url = self.url - elif request_information.query_parameters.keys(): - query = ",".join(request_information.query_parameters.keys()) - request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}{{?{query}}}" - else: - request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}" - if not self.response_type: - request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption( - response_handler=CallableResponseHandler(self.response_handler) - ) - headers = {**self.DEFAULT_HEADERS, **self.headers} if self.headers else self.DEFAULT_HEADERS - for header_name, header_value in headers.items(): - request_information.headers.try_add(header_name=header_name, header_value=header_value) - if isinstance(self.data, BytesIO) or isinstance(self.data, bytes) or isinstance(self.data, str): - request_information.content = self.data - elif self.data: - request_information.headers.try_add( - header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" - ) - request_information.content = json.dumps(self.data).encode("utf-8") - return request_information - - @staticmethod - def error_mapping() -> dict[str, ParsableFactory | None]: - return { - "4XX": APIError, - "5XX": APIError, - } - - async def execute(self) -> AsyncIterator[TriggerEvent]: - return await self.get_conn().send_primitive_async( - request_info=self.request_information(), - response_type=self.response_type, - error_map=self.error_mapping(), - ) diff --git a/tests/providers/microsoft/azure/hooks/test_msgraph.py b/tests/providers/microsoft/azure/hooks/test_msgraph.py index 1c1046e1fa4f3..9d2db07acf709 100644 --- a/tests/providers/microsoft/azure/hooks/test_msgraph.py +++ b/tests/providers/microsoft/azure/hooks/test_msgraph.py @@ -16,13 +16,21 @@ # under the License. from __future__ import annotations +import asyncio from unittest.mock import patch +import pytest from kiota_http.httpx_request_adapter import HttpxRequestAdapter from msgraph_core import APIVersion, NationalClouds -from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook -from tests.providers.microsoft.conftest import get_airflow_connection, mock_connection +from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException +from airflow.providers.microsoft.azure.hooks.msgraph import CallableResponseHandler, KiotaRequestAdapterHook +from tests.providers.microsoft.conftest import ( + get_airflow_connection, + load_json, + mock_connection, + mock_json_response, +) class TestKiotaRequestAdapterHook: @@ -77,3 +85,55 @@ def test_get_host_when_connection_has_no_scheme_or_host(self): actual = KiotaRequestAdapterHook.get_host(connection) assert actual == NationalClouds.Global.value + + def test_encoded_query_parameters(self): + actual = KiotaRequestAdapterHook.encoded_query_parameters( + query_parameters={"$expand": "reports,users,datasets,dataflows,dashboards", "$top": 5000}, + ) + + assert actual == {"%24expand": "reports,users,datasets,dataflows,dashboards", "%24top": 5000} + + +class TestResponseHandler: + def test_handle_response_async_when_ok(self): + users = load_json("resources", "users.json") + response = mock_json_response(200, users) + + actual = asyncio.run( + CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( + response, None + ) + ) + + assert isinstance(actual, dict) + assert actual == users + + def test_handle_response_async_when_bad_request(self): + response = mock_json_response(400, {}) + + with pytest.raises(AirflowBadRequest): + asyncio.run( + CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( + response, None + ) + ) + + def test_handle_response_async_when_not_found(self): + response = mock_json_response(404, {}) + + with pytest.raises(AirflowNotFoundException): + asyncio.run( + CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( + response, None + ) + ) + + def test_handle_response_async_when_internal_server_error(self): + response = mock_json_response(500, {}) + + with pytest.raises(AirflowException): + asyncio.run( + CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( + response, None + ) + ) diff --git a/tests/providers/microsoft/azure/triggers/test_msgraph.py b/tests/providers/microsoft/azure/triggers/test_msgraph.py index a4e2c7f0fd598..23085563cf8f9 100644 --- a/tests/providers/microsoft/azure/triggers/test_msgraph.py +++ b/tests/providers/microsoft/azure/triggers/test_msgraph.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import asyncio import json import locale from base64 import b64decode, b64encode @@ -28,7 +27,6 @@ from airflow.exceptions import AirflowException from airflow.providers.microsoft.azure.triggers.msgraph import ( - CallableResponseHandler, MSGraphTrigger, ResponseSerializer, ) @@ -132,32 +130,6 @@ def test_template_fields(self): for template_field in MSGraphTrigger.template_fields: getattr(trigger, template_field) - def test_encoded_query_parameters(self): - trigger = MSGraphTrigger( - url="myorg/admin/groups", - conn_id="msgraph_api", - query_parameters={"$expand": "reports,users,datasets,dataflows,dashboards", "$top": 5000}, - ) - - actual = trigger.encoded_query_parameters() - - assert actual == {"%24expand": "reports,users,datasets,dataflows,dashboards", "%24top": 5000} - - -class TestResponseHandler: - def test_handle_response_async(self): - users = load_json("resources", "users.json") - response = mock_json_response(200, users) - - actual = asyncio.run( - CallableResponseHandler(lambda response, error_map: response.json()).handle_response_async( - response, None - ) - ) - - assert isinstance(actual, dict) - assert actual == users - class TestResponseSerializer: def test_serialize_when_bytes_then_base64_encoded(self):