Skip to content

Commit

Permalink
Implement run-method on KiotaRequestAdapterHook and move logic away f…
Browse files Browse the repository at this point in the history
…rom triggerer to hook (apache#39237)



---------

Co-authored-by: David Blain <[email protected]>
  • Loading branch information
dabla and davidblain-infrabel authored Apr 25, 2024
1 parent 5eaf173 commit 15c2734
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 123 deletions.
164 changes: 157 additions & 7 deletions airflow/providers/microsoft/azure/hooks/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,74 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from urllib.parse import urljoin, urlparse
import json
from http import HTTPStatus
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import quote, urljoin, urlparse

import httpx
from azure.identity import ClientSecretCredential
from httpx import Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.response_handler import ResponseHandler
from kiota_authentication_azure.azure_identity_authentication_provider import (
AzureIdentityAuthenticationProvider,
)
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from msgraph_core import GraphClientFactory
from msgraph_core._enums import APIVersion, NationalClouds
from kiota_http.middleware.options import ResponseHandlerOption
from msgraph_core import APIVersion, GraphClientFactory
from msgraph_core._enums import NationalClouds

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from kiota_http.httpx_request_adapter import ResponseType

from airflow.models import Connection


class CallableResponseHandler(ResponseHandler):
"""
CallableResponseHandler executes the passed callable_function with response as parameter.
param callable_function: Function that is applied to the response.
"""

def __init__(
self,
callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any],
):
self.callable_function = callable_function

async def handle_response_async(
self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
) -> Any:
"""
Invoke this callback method when a response is received.
param response: The type of the native response object.
param error_map: The error dict to use in case of a failed request.
"""
value = self.callable_function(response, error_map)
if response.status_code not in {200, 201, 202, 204, 302}:
message = value or response.reason_phrase
status_code = HTTPStatus(response.status_code)
if status_code == HTTPStatus.BAD_REQUEST:
raise AirflowBadRequest(message)
elif status_code == HTTPStatus.NOT_FOUND:
raise AirflowNotFoundException(message)
raise AirflowException(message)
return value


class KiotaRequestAdapterHook(BaseHook):
"""
A Microsoft Graph API interaction hook, a Wrapper around KiotaRequestAdapter.
Expand All @@ -54,6 +100,7 @@ class KiotaRequestAdapterHook(BaseHook):
or you can pass a string as "v1.0" or "beta".
"""

DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
default_conn_name: str = "msgraph_default"

Expand Down Expand Up @@ -117,13 +164,16 @@ def to_httpx_proxies(cls, proxies: dict) -> dict:
proxies[cls.format_no_proxy_url(url.strip())] = None
return proxies

@classmethod
def to_msal_proxies(cls, authority: str | None, proxies: dict):
def to_msal_proxies(self, authority: str | None, proxies: dict):
self.log.info("authority: %s", authority)
if authority:
no_proxies = proxies.get("no")
self.log.info("no_proxies: %s", no_proxies)
if no_proxies:
for url in no_proxies.split(","):
self.log.info("url: %s", url)
domain_name = urlparse(url).path.replace("*", "")
self.log.info("domain_name: %s", domain_name)
if authority.endswith(domain_name):
return None
return proxies
Expand Down Expand Up @@ -206,3 +256,103 @@ def get_conn(self) -> RequestAdapter:
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
self._api_version = api_version
return request_adapter

def test_connection(self):
"""Test HTTP Connection."""
try:
self.run()
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)

async def run(
self,
url: str = "",
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
):
response = await self.get_conn().send_primitive_async(
request_info=self.request_information(
url=url,
response_type=response_type,
response_handler=response_handler,
path_parameters=path_parameters,
method=method,
query_parameters=query_parameters,
headers=headers,
data=data,
),
response_type=response_type,
error_map=self.error_mapping(),
)

self.log.debug("response: %s", response)

return response

def request_information(
self,
url: str,
response_type: ResponseType | None = None,
response_handler: Callable[
[NativeResponseType, dict[str, ParsableFactory | None] | None], Any
] = lambda response, error_map: response.json(),
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
) -> RequestInformation:
request_information = RequestInformation()
request_information.path_parameters = path_parameters or {}
request_information.http_method = Method(method.strip().upper())
request_information.query_parameters = self.encoded_query_parameters(query_parameters)
if url.startswith("http"):
request_information.url = url
elif request_information.query_parameters.keys():
query = ",".join(request_information.query_parameters.keys())
request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}{{?{query}}}"
else:
request_information.url_template = f"{{+baseurl}}/{self.normalize_url(url)}"
if not response_type:
request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
response_handler=CallableResponseHandler(response_handler)
)
headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
for header_name, header_value in headers.items():
request_information.headers.try_add(header_name=header_name, header_value=header_value)
self.log.info("data: %s", data)
if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str):
request_information.content = data
elif data:
request_information.headers.try_add(
header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
)
request_information.content = json.dumps(data).encode("utf-8")
return request_information

@staticmethod
def normalize_url(url: str) -> str | None:
if url.startswith("/"):
return url.replace("/", "", 1)
return url

@staticmethod
def encoded_query_parameters(query_parameters) -> dict:
if query_parameters:
return {quote(key): value for key, value in query_parameters.items()}
return {}

@staticmethod
def error_mapping() -> dict[str, ParsableFactory | None]:
return {
"4XX": APIError,
"5XX": APIError,
}
98 changes: 12 additions & 86 deletions airflow/providers/microsoft/azure/triggers/msgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from base64 import b64encode
from contextlib import suppress
from datetime import datetime
from io import BytesIO
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Expand All @@ -31,21 +30,17 @@
Callable,
Sequence,
)
from urllib.parse import quote
from uuid import UUID

import pendulum
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
from kiota_abstractions.response_handler import ResponseHandler
from kiota_http.middleware.options import ResponseHandlerOption

from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.module_loading import import_string

if TYPE_CHECKING:
from io import BytesIO

from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
Expand Down Expand Up @@ -87,31 +82,6 @@ def deserialize(self, response) -> Any:
return response


class CallableResponseHandler(ResponseHandler):
"""
CallableResponseHandler executes the passed callable_function with response as parameter.
param callable_function: Function that is applied to the response.
"""

def __init__(
self,
callable_function: Callable[[NativeResponseType, dict[str, ParsableFactory | None] | None], Any],
):
self.callable_function = callable_function

async def handle_response_async(
self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
) -> Any:
"""
Invoke this callback method when a response is received.
param response: The type of the native response object.
param error_map: The error dict to use in case of a failed request.
"""
return self.callable_function(response, error_map)


class MSGraphTrigger(BaseTrigger):
"""
A Microsoft Graph API trigger which allows you to execute an async REST call to the Microsoft Graph API.
Expand All @@ -134,7 +104,6 @@ class MSGraphTrigger(BaseTrigger):
Bytes will be base64 encoded into a string, so it can be stored as an XCom.
"""

DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
template_fields: Sequence[str] = (
"url",
"response_type",
Expand Down Expand Up @@ -235,7 +204,16 @@ def api_version(self) -> APIVersion:
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make a series of asynchronous HTTP calls via a KiotaRequestAdapterHook."""
try:
response = await self.execute()
response = await self.hook.run(
url=self.url,
response_type=self.response_type,
response_handler=self.response_handler,
path_parameters=self.path_parameters,
method=self.method,
query_parameters=self.query_parameters,
headers=self.headers,
data=self.data,
)

self.log.debug("response: %s", response)

Expand All @@ -262,55 +240,3 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
except Exception as e:
self.log.exception("An error occurred: %s", e)
yield TriggerEvent({"status": "failure", "message": str(e)})

def normalize_url(self) -> str | None:
if self.url.startswith("/"):
return self.url.replace("/", "", 1)
return self.url

def encoded_query_parameters(self) -> dict:
if self.query_parameters:
return {quote(key): value for key, value in self.query_parameters.items()}
return {}

def request_information(self) -> RequestInformation:
request_information = RequestInformation()
request_information.path_parameters = self.path_parameters or {}
request_information.http_method = Method(self.method.strip().upper())
request_information.query_parameters = self.encoded_query_parameters()
if self.url.startswith("http"):
request_information.url = self.url
elif request_information.query_parameters.keys():
query = ",".join(request_information.query_parameters.keys())
request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}{{?{query}}}"
else:
request_information.url_template = f"{{+baseurl}}/{self.normalize_url()}"
if not self.response_type:
request_information.request_options[ResponseHandlerOption.get_key()] = ResponseHandlerOption(
response_handler=CallableResponseHandler(self.response_handler)
)
headers = {**self.DEFAULT_HEADERS, **self.headers} if self.headers else self.DEFAULT_HEADERS
for header_name, header_value in headers.items():
request_information.headers.try_add(header_name=header_name, header_value=header_value)
if isinstance(self.data, BytesIO) or isinstance(self.data, bytes) or isinstance(self.data, str):
request_information.content = self.data
elif self.data:
request_information.headers.try_add(
header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
)
request_information.content = json.dumps(self.data).encode("utf-8")
return request_information

@staticmethod
def error_mapping() -> dict[str, ParsableFactory | None]:
return {
"4XX": APIError,
"5XX": APIError,
}

async def execute(self) -> AsyncIterator[TriggerEvent]:
return await self.get_conn().send_primitive_async(
request_info=self.request_information(),
response_type=self.response_type,
error_map=self.error_mapping(),
)
Loading

0 comments on commit 15c2734

Please sign in to comment.