Webhook tasks using FlyteAgents #3058

Open
wants to merge 12 commits into
base: master
4 changes: 4 additions & 0 deletions flytekit/extras/webhook/__init__.py
@@ -0,0 +1,4 @@
from .agent import WebhookAgent
from .task import WebhookTask

__all__ = ["WebhookTask", "WebhookAgent"]
94 changes: 94 additions & 0 deletions flytekit/extras/webhook/agent.py
@@ -0,0 +1,94 @@
import http
from typing import Optional

import aiohttp
from flyteidl.core.execution_pb2 import TaskExecution

from flytekit.extend.backend.base_agent import AgentRegistry, Resource, SyncAgentBase
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.utils.dict_formatter import format_dict

from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, TIMEOUT_SEC, URL_KEY


class WebhookAgent(SyncAgentBase):
    name = "Webhook Agent"

    def __init__(self):
        super().__init__(task_type_name=TASK_TYPE)

    async def do(
        self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
    ) -> Resource:
        try:
            final_dict = self._get_final_dict(task_template, inputs)
Consider adding error handling for dict formatting

Consider adding error handling for self._get_final_dict() call to handle potential exceptions from format_dict() or literal_map_string_repr()

Code suggestion
Check the AI-generated fix before applying
Suggested change
-            final_dict = self._get_final_dict(task_template, inputs)
+            try:
+                final_dict = self._get_final_dict(task_template, inputs)
+            except ValueError as e:
+                return Resource(phase=TaskExecution.FAILED, message=f"Failed to format webhook data: {str(e)}")

Code Review Run #882444


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

            return await self._process_webhook(final_dict)
        except aiohttp.ClientError as e:
            return Resource(phase=TaskExecution.FAILED, message=str(e))

    def _get_final_dict(self, task_template: TaskTemplate, inputs: LiteralMap) -> dict:
        custom_dict = task_template.custom
        input_dict = {
            "inputs": literal_map_string_repr(inputs),
        }
        return format_dict("test", custom_dict, input_dict)

    async def _make_http_request(
        self, method: http.HTTPMethod, url: str, headers: dict, data: dict, timeout: int
    ) -> tuple:
Consider more specific return type annotation

The _make_http_request method's return type annotation tuple is too generic. Consider specifying the exact types being returned (int for status and str for text).

Code suggestion
Check the AI-generated fix before applying
Suggested change
-    ) -> tuple:
+    ) -> tuple[int, str]:

Code Review Run #882444


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

        # TODO: This is a potential performance bottleneck. Consider using a connection pool. To do this, we need to
        # create a session object and reuse it for multiple requests, which reduces the overhead of creating a new
        # connection for each request. This is not done yet because local execution does not have a common event
        # loop, and the agent executor creates a new event loop for each request (in the mixin).
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
            if method == http.HTTPMethod.GET:
                response = await session.get(url, headers=headers, params=data)
            else:
                response = await session.post(url, json=data, headers=headers)
            return response.status, await response.text()

    @staticmethod
    def _build_response(
        status: int,
        text: str,
        data: dict = None,
        url: str = None,
        show_data: bool = False,
        show_url: bool = False,
    ) -> dict:
        final_response = {
            "status_code": status,
            "response_data": text,
        }
        if show_data:
            final_response["input_data"] = data
        if show_url:
            final_response["url"] = url
        return final_response

    async def _process_webhook(self, final_dict: dict) -> Resource:
        url = final_dict.get(URL_KEY)
        body = final_dict.get(DATA_KEY)
        headers = final_dict.get(HEADERS_KEY)
        method = http.HTTPMethod(final_dict.get(METHOD_KEY))
        show_data = final_dict.get(SHOW_DATA_KEY, False)
        show_url = final_dict.get(SHOW_URL_KEY, False)
        timeout_sec = final_dict.get(TIMEOUT_SEC, 10)

        status, text = await self._make_http_request(method, url, headers, body, timeout_sec)
Comment on lines +88 to +90
Consider validating timeout_sec parameter value

Consider validating the timeout_sec value to ensure it's a positive integer. A negative or zero timeout could cause unexpected behavior.

Code suggestion
Check the AI-generated fix before applying
Suggested change
-        timeout_sec = final_dict.get(TIMEOUT_SEC, 10)
-        status, text = await self._make_http_request(method, url, headers, body, timeout_sec)
+        timeout_sec = final_dict.get(TIMEOUT_SEC, 10)
+        if not isinstance(timeout_sec, int) or timeout_sec <= 0:
+            timeout_sec = 10
+        status, text = await self._make_http_request(method, url, headers, body, timeout_sec)

Code Review Run #882444


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged
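
For context on the value being validated here: `WebhookTask.__init__` keeps `timeout` unchanged when it is an `int` and otherwise calls `timedelta.total_seconds()`, which returns a `float`. The sketch below reproduces that expression next to the guard proposed in this comment. `normalize_timeout` and `suggested_guard` are hypothetical helper names used only for illustration, not part of the PR. Under these assumptions, a `timedelta`-based timeout such as 2.5 seconds would fail the `isinstance(timeout_sec, int)` check and fall back to 10.

```python
from datetime import timedelta
from typing import Union


def normalize_timeout(timeout: Union[int, timedelta]) -> Union[int, float]:
    """Same expression as in WebhookTask.__init__ (helper name is hypothetical)."""
    return timeout if isinstance(timeout, int) else timeout.total_seconds()


def suggested_guard(timeout_sec: Union[int, float]) -> Union[int, float]:
    """The validation proposed in the review comment, reproduced for comparison."""
    if not isinstance(timeout_sec, int) or timeout_sec <= 0:
        return 10
    return timeout_sec


# The default timedelta(seconds=10) becomes a float, not an int.
default_timeout = normalize_timeout(timedelta(seconds=10))
# A sub-second-precision timeout is also a float.
short_timeout = normalize_timeout(timedelta(seconds=2.5))
# The proposed guard rejects floats, so 2.5 would be replaced by 10.
guarded = suggested_guard(short_timeout)
```

If the guard were adopted, accepting `(int, float)` rather than only `int` would avoid overriding valid `timedelta` values.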

        if status != 200:
            return Resource(
                phase=TaskExecution.FAILED,
Consider broader exception handling for webhook

Consider adding error handling for the _process_webhook call. The method could raise other exceptions besides aiohttp.ClientError that should be caught.

Code suggestion
Check the AI-generated fix before applying
Suggested change
-                phase=TaskExecution.FAILED,
+            return Resource(phase=TaskExecution.FAILED, message=f"HTTP client error: {str(e)}")
+        except Exception as e:
+            return Resource(phase=TaskExecution.FAILED, message=f"Webhook processing error: {str(e)}")

Code Review Run #cec794


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

                message=f"Webhook failed with status code {status}, response: {text}",
            )
        final_response = self._build_response(status, text, body, url, show_data, show_url)
        return Resource(
            phase=TaskExecution.SUCCEEDED,
            outputs={"info": final_response},
            message="Webhook was successfully invoked!",
        )


AgentRegistry.register(WebhookAgent())
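
The success and failure paths of `_process_webhook` can be exercised without a network call. The following is a standalone sketch, assuming plain dicts in place of flytekit's `Resource`; `build_response` copies the logic of `_build_response`, while `summarize` and the sample URL are hypothetical names chosen for illustration. It shows that, as the diff is written, any status other than 200 (for example 201) is reported as a failure.

```python
from typing import Any, Dict, Optional


def build_response(
    status: int,
    text: str,
    data: Optional[dict] = None,
    url: Optional[str] = None,
    show_data: bool = False,
    show_url: bool = False,
) -> Dict[str, Any]:
    """Standalone copy of WebhookAgent._build_response."""
    final_response: Dict[str, Any] = {"status_code": status, "response_data": text}
    if show_data:
        final_response["input_data"] = data
    if show_url:
        final_response["url"] = url
    return final_response


def summarize(status: int, text: str, **kwargs: Any) -> Dict[str, Any]:
    """Mirror the status check in _process_webhook, returning a dict instead of a Resource."""
    if status != 200:
        return {"phase": "FAILED", "message": f"Webhook failed with status code {status}, response: {text}"}
    return {"phase": "SUCCEEDED", "outputs": {"info": build_response(status, text, **kwargs)}}


# show_url=True adds the URL; show_data defaults to False, so input_data is omitted.
ok = summarize(200, "pong", data={"a": 1}, url="https://example.com/hook", show_url=True)
# A 201 response takes the failure branch because the check is `status != 200`.
created = summarize(201, "created")
```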
9 changes: 9 additions & 0 deletions flytekit/extras/webhook/constants.py
@@ -0,0 +1,9 @@
TASK_TYPE: str = "webhook"

URL_KEY: str = "url"
METHOD_KEY: str = "method"
HEADERS_KEY: str = "headers"
DATA_KEY: str = "data"
SHOW_DATA_KEY: str = "show_data"
SHOW_URL_KEY: str = "show_url"
TIMEOUT_SEC: str = "timeout_sec"
105 changes: 105 additions & 0 deletions flytekit/extras/webhook/task.py
@@ -0,0 +1,105 @@
import http
from datetime import timedelta
from typing import Any, Dict, Optional, Type, Union

from flytekit import Documentation
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin

from ...core.interface import Interface
from .constants import DATA_KEY, HEADERS_KEY, METHOD_KEY, SHOW_DATA_KEY, SHOW_URL_KEY, TASK_TYPE, TIMEOUT_SEC, URL_KEY


class WebhookTask(SyncAgentExecutorMixin, PythonTask):
    """
    A task that invokes a webhook over HTTP (GET or POST) and returns the response details as its output.
    """

    def __init__(
        self,
        name: str,
        url: str,
        method: http.HTTPMethod = http.HTTPMethod.POST,
        headers: Optional[Dict[str, str]] = None,
        data: Optional[Dict[str, Any]] = None,
        dynamic_inputs: Optional[Dict[str, Type]] = None,
        show_data: bool = False,
        show_url: bool = False,
        description: Optional[str] = None,
        timeout: Union[int, timedelta] = timedelta(seconds=10),
        # secret_requests: Optional[List[Secret]] = None, TODO Secret support is coming soon
    ):
        """
        This task is used to invoke a webhook. The webhook can be invoked with a POST or GET method.

        All the parameters can be formatted using python format strings. The following parameters are available for
        formatting:
        - dynamic_inputs: These are the dynamic inputs to the task. The keys are the names of the inputs and the values
            are the values of the inputs. All inputs are available under the prefix `inputs.`.
            For example, if the inputs are {"input1": 10, "input2": "hello"}, then you can
            use {inputs.input1} and {inputs.input2} in the URL and the body. Define the dynamic_inputs argument in the
            constructor to use these inputs. The dynamic_inputs argument takes the types of the inputs, not actual
            values.

        TODO: Secrets support is coming soon.
        - secrets: These are the secrets that are requested by the task. The keys are the names of the secrets and the
            values are the values of the secrets. All secrets are available under the prefix `secrets.`.
            For example, if the secrets requested are Secret(name="secret1") and Secret(name="secret2"), then you can
            use {secrets.secret1} and {secrets.secret2} in the URL and the body. Define the secret_requests argument
            in the constructor to use these secrets. The secret_requests argument takes Secret requests, not the
            actual secret values.

        :param name: Name of this task, should be unique in the project
        :param url: The endpoint or URL to invoke for this webhook. This can be a static string or a python format string,
            where the format arguments are the dynamic_inputs to the task, secrets etc. Refer to the description for more
            details of available formatting parameters.
        :param method: The HTTP method to use for the request. Default is POST.
        :param headers: The headers to send with the request. This can be a static dictionary or a dictionary whose
            values are python format strings, where the format arguments are the dynamic_inputs to the task, secrets
            etc. Refer to the description for more details of available formatting parameters.
        :param data: The body to send with the request. This can be a static dictionary or a dictionary whose values
            are python format strings, where the format arguments are the dynamic_inputs to the task, secrets etc.
            Refer to the description for more details of available formatting parameters. The data should be a
            JSON-serializable dictionary; it is sent as the JSON body of a POST request and as the query parameters
            of a GET request.
        :param dynamic_inputs: The dynamic inputs to the task. The keys are the names of the inputs and the values
            are the types of the inputs. These inputs are available under the prefix `inputs.` to be used in the URL,
            headers and body and other formatted fields.
        :param secret_requests: The secrets that are requested by the task. (TODO not yet supported)
        :param show_data: If True, the body of the request will be logged in the UI as the output of the task.
        :param show_url: If True, the URL of the request will be logged in the UI as the output of the task.
        :param description: Description of the task
        :param timeout: The timeout for the request (connection and read). Default is 10 seconds. If an int value is
            provided, it is interpreted as seconds.
        """
        if method not in {http.HTTPMethod.GET, http.HTTPMethod.POST}:
            raise ValueError(f"Method should be either GET or POST. Got {method}")

        interface = Interface(
            inputs=dynamic_inputs or {},
            outputs={"info": dict},
        )
        super().__init__(
            name=name,
            interface=interface,
            task_type=TASK_TYPE,
            # secret_requests=secret_requests,
            docs=Documentation(short_description=description) if description else None,
        )
        self._url = url
        self._method = method
        self._headers = headers
        self._data = data
Consider validating data parameter JSON serialization

Consider adding validation for the data parameter to ensure it is JSON serializable before sending the request. This could help catch serialization issues early.

Code suggestion
Check the AI-generated fix before applying
 @@ -85,2 +85,7 @@
          self._headers = headers
 +        if data is not None:
 +            import json
 +            try:
 +                json.dumps(data)
 +            except (TypeError, ValueError) as e:
 +                raise ValueError(f"The data parameter must be JSON serializable. Error: {str(e)}")
          self._data = data

Code Review Run #cec794


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

        self._show_data = show_data
        self._show_url = show_url
        self._timeout_sec = timeout if isinstance(timeout, int) else timeout.total_seconds()

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        config = {
            URL_KEY: self._url,
            METHOD_KEY: self._method.value,
            HEADERS_KEY: self._headers or {},
            DATA_KEY: self._data or {},
            SHOW_DATA_KEY: self._show_data,
Comment on lines +137 to +138
Consider backward compatibility for parameter rename

Consider if renaming from body to data maintains backward compatibility. This change could potentially break existing code that relies on the body parameter.

Code suggestion
Check the AI-generated fix before applying
Suggested change
-            DATA_KEY: self._data or {},
-            SHOW_DATA_KEY: self._show_data,
+            DATA_KEY: self._data or self._body or {},
+            SHOW_DATA_KEY: self._show_data if self._show_data is not None else self._show_body,

Code Review Run #cec794


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

            SHOW_URL_KEY: self._show_url,
            TIMEOUT_SEC: self._timeout_sec,
        }
        return config
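
The `data` docstring says the dictionary is sent as the JSON body for POST and as query parameters for GET. The stdlib-only sketch below illustrates that split without aiohttp. `encode_request` is a hypothetical helper, the URL is a placeholder, and the exact query-string encoding aiohttp produces may differ in details.

```python
import json
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlencode


def encode_request(method: str, url: str, data: Dict[str, Any]) -> Tuple[str, Optional[bytes]]:
    """Return (final_url, body) for a GET or POST request carrying `data`."""
    if method == "GET":
        # GET: data travels in the query string and there is no request body.
        query = urlencode(data)
        return (f"{url}?{query}" if query else url), None
    # POST: the URL is unchanged and data is serialized as a JSON body.
    return url, json.dumps(data).encode("utf-8")


get_url, get_body = encode_request("GET", "https://example.com/hook", {"id": 7, "state": "done"})
post_url, post_body = encode_request("POST", "https://example.com/hook", {"id": 7})
```

Note that `http.HTTPMethod`, used in the constructor signature above, is available only on Python 3.11 and later.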
86 changes: 86 additions & 0 deletions flytekit/utils/dict_formatter.py
@@ -0,0 +1,86 @@
import re
from typing import Any, Dict, Optional


def get_nested_value(d: Dict[str, Any], keys: list[str]) -> Any:
    """
    Retrieve the nested value from a dictionary based on a list of keys.
    """
    for key in keys:
        if key not in d:
            raise ValueError(f"Could not find the key {key} in {d}.")
        d = d[key]
    return d


def replace_placeholder(
    service: str,
    original_dict: str,
    placeholder: str,
    replacement: str,
) -> str:
    """
    Replace a placeholder in the original string and handle the specific logic for the sagemaker service and idempotence token.
    """
    temp_dict = original_dict.replace(f"{{{placeholder}}}", replacement)
    if service == "sagemaker" and placeholder in [
        "inputs.idempotence_token",
        "idempotence_token",
    ]:
        if len(temp_dict) > 63:
            truncated_token = replacement[: 63 - len(original_dict.replace(f"{{{placeholder}}}", ""))]
eapolinario marked this conversation as resolved.
            return original_dict.replace(f"{{{placeholder}}}", truncated_token)
        else:
            return temp_dict
    return temp_dict


def format_dict(
    service: str,
    original_dict: Any,
    update_dict: Dict[str, Any],
    idempotence_token: Optional[str] = None,
) -> Any:
    """
    Recursively update a dictionary with format strings with values from another dictionary where the keys match
    the format string. This goes a little beyond regular python string formatting and uses `.` to denote nested keys.

    For example, if original_dict is {"EndpointConfigName": "{endpoint_config_name}"},
    and update_dict is {"endpoint_config_name": "my-endpoint-config"},
    then the result will be {"EndpointConfigName": "my-endpoint-config"}.

    For nested keys if the original_dict is {"EndpointConfigName": "{inputs.endpoint_config_name}"},
    and update_dict is {"inputs": {"endpoint_config_name": "my-endpoint-config"}},
    then the result will be {"EndpointConfigName": "my-endpoint-config"}.

    :param service: The AWS service to use
    :param original_dict: The dictionary to update (in place)
    :param update_dict: The dictionary to use for updating
    :param idempotence_token: Hash of config -- this is to ensure the execution ID is deterministic
    :return: The updated dictionary
    """
    if original_dict is None:
        return None

    if isinstance(original_dict, str) and "{" in original_dict and "}" in original_dict:
        matches = re.findall(r"\{([^}]+)\}", original_dict)
        for match in matches:
            if "." in match:
                keys = match.split(".")
                nested_value = get_nested_value(update_dict, keys)
                if f"{{{match}}}" == original_dict:
                    return nested_value
                else:
                    original_dict = replace_placeholder(service, original_dict, match, str(nested_value))
            elif match == "idempotence_token" and idempotence_token:
                original_dict = replace_placeholder(service, original_dict, match, idempotence_token)
        return original_dict

    if isinstance(original_dict, list):
        return [format_dict(service, item, update_dict, idempotence_token) for item in original_dict]

    if isinstance(original_dict, dict):
        for key, value in original_dict.items():
            original_dict[key] = format_dict(service, value, update_dict, idempotence_token)

    return original_dict
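
To make the substitution rules concrete, here is a condensed, standalone sketch of the string branch of `format_dict`, leaving out the `service` and idempotence-token handling. `resolve` and `format_value` are hypothetical names used only for this illustration, and unlike the original this version builds a new dict rather than updating in place. Reading the diff as shown, only placeholders that contain a `.` are looked up in `update_dict`. A string that consists solely of one placeholder returns the raw value with its type preserved, while an embedded placeholder is replaced by `str(value)`.

```python
import re
from typing import Any, Dict, List


def resolve(d: Dict[str, Any], keys: List[str]) -> Any:
    """Walk nested dictionaries along `keys`, as get_nested_value does."""
    for key in keys:
        if key not in d:
            raise ValueError(f"Could not find the key {key} in {d}.")
        d = d[key]
    return d


def format_value(original: Any, update: Dict[str, Any]) -> Any:
    """Simplified sketch of format_dict: dotted placeholders only, returns new containers."""
    if isinstance(original, str):
        for match in re.findall(r"\{([^}]+)\}", original):
            if "." not in match:
                continue  # placeholders without a dot are left untouched
            value = resolve(update, match.split("."))
            if original == f"{{{match}}}":
                return value  # whole-string placeholder: keep the value's type
            original = original.replace(f"{{{match}}}", str(value))
        return original
    if isinstance(original, list):
        return [format_value(item, update) for item in original]
    if isinstance(original, dict):
        return {key: format_value(value, update) for key, value in original.items()}
    return original


custom = {
    "url": "https://example.com/items/{inputs.item_id}",
    "data": {"count": "{inputs.count}", "plain": "{name}"},
}
result = format_value(custom, {"inputs": {"item_id": "42", "count": 3}})
```

In this sketch `result["data"]["count"]` is the integer `3`, the URL contains the string `42`, and `{name}` stays as written because it has no dotted prefix.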