Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc fixes #627

Merged
merged 6 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions agents-api/agents_api/activities/excecute_api_call.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from typing import Annotated, Any, Optional, TypedDict, Union

import httpx
Expand All @@ -20,6 +21,8 @@ class RequestArgs(TypedDict):
json_: Optional[dict[str, Any]]
cookies: Optional[dict[str, str]]
params: Optional[Union[str, dict[str, Any]]]
url: Optional[str]
headers: Optional[dict[str, str]]


@beartype
Expand All @@ -29,18 +32,23 @@ async def execute_api_call(
) -> Any:
try:
async with httpx.AsyncClient() as client:
arg_url = request_args.pop("url", None)
arg_headers = request_args.pop("headers", None)

response = await client.request(
method=api_call.method,
url=str(api_call.url),
headers=api_call.headers,
url=arg_url or str(api_call.url),
headers=arg_headers or api_call.headers,
follow_redirects=api_call.follow_redirects,
**request_args,
)

content_base64 = base64.b64encode(response.content).decode("ascii")

response_dict = {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": response.content,
"content": content_base64,
"json": response.json(),
}

Expand Down
189 changes: 174 additions & 15 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
import base64
import datetime as dt
import functools
import itertools
import json
from functools import reduce
from itertools import accumulate
from random import random
from time import time
from typing import Any, Callable
import math
import random
import statistics
import string
import time
import urllib.parse
from typing import Any, Callable, ParamSpec, Type, TypeVar, cast

import re2
import yaml
import zoneinfo
from beartype import beartype
from simpleeval import EvalWithCompoundTypes, SimpleEval
from yaml import CSafeLoader
from yaml import CSafeDumper, CSafeLoader

T = TypeVar("T")


P = ParamSpec("P")
R = TypeVar("R")


# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"abs": abs,
"accumulate": accumulate,
"all": all,
"any": any,
"bool": bool,
Expand All @@ -25,23 +38,169 @@
"int": int,
"len": len,
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"map": map,
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"random": random,
"range": range,
"reduce": reduce,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"set": set,
"str": str,
"sum": sum,
"time": time,
"tuple": tuple,
"reduce": functools.reduce,
"zip": zip,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
}


class stdlib_re:
fullmatch = re2.fullmatch
search = re2.search
escape = re2.escape
findall = re2.findall
finditer = re2.finditer
match = re2.match
split = re2.split
sub = re2.sub
subn = re2.subn


class stdlib_json:
loads = json.loads
dumps = json.dumps


class stdlib_yaml:
load = lambda string: yaml.load(string, Loader=CSafeLoader) # noqa: E731
dump = lambda value: yaml.dump(value, Dumper=CSafeDumper) # noqa: E731


class stdlib_time:
strftime = time.strftime
strptime = time.strptime
time = time


class stdlib_random:
choice = random.choice
choices = random.choices
sample = random.sample
shuffle = random.shuffle
randrange = random.randrange
randint = random.randint
random = random.random


class stdlib_itertools:
accumulate = itertools.accumulate


class stdlib_functools:
partial = functools.partial
reduce = functools.reduce


class stdlib_base64:
b64encode = base64.b64encode
b64decode = base64.b64decode


class stdlib_urllib:
class parse:
urlparse = urllib.parse.urlparse
urlencode = urllib.parse.urlencode
unquote = urllib.parse.unquote
quote = urllib.parse.quote
parse_qs = urllib.parse.parse_qs
parse_qsl = urllib.parse.parse_qsl
urlsplit = urllib.parse.urlsplit
urljoin = urllib.parse.urljoin
unwrap = urllib.parse.unwrap


class stdlib_string:
ascii_letters = string.ascii_letters
ascii_lowercase = string.ascii_lowercase
ascii_uppercase = string.ascii_uppercase
digits = string.digits
hexdigits = string.hexdigits
octdigits = string.octdigits
punctuation = string.punctuation
whitespace = string.whitespace
printable = string.printable


class stdlib_zoneinfo:
ZoneInfo = zoneinfo.ZoneInfo


class stdlib_datetime:
class timezone:
class utc:
utc = dt.timezone.utc

class datetime:
now = dt.datetime.now
datetime = dt.datetime
timedelta = dt.timedelta
date = dt.date
time = dt.time

timedelta = dt.timedelta


class stdlib_math:
sqrt = math.sqrt
exp = math.exp
ceil = math.ceil
floor = math.floor
isinf = math.isinf
isnan = math.isnan
log = math.log
log10 = math.log10
log2 = math.log2
pow = math.pow
sin = math.sin
cos = math.cos
tan = math.tan
asin = math.asin
acos = math.acos
atan = math.atan
atan2 = math.atan2

pi = math.pi
e = math.e


class stdlib_statistics:
mean = statistics.mean
stdev = statistics.stdev
geometric_mean = statistics.geometric_mean
median = statistics.median
median_low = statistics.median_low
median_high = statistics.median_high
mode = statistics.mode
quantiles = statistics.quantiles


stdlib = {
"re": stdlib_re,
"json": stdlib_json,
"yaml": stdlib_yaml,
"time": stdlib_time,
"random": stdlib_random,
"itertools": stdlib_itertools,
"functools": stdlib_functools,
"base64": stdlib_base64,
"urllib": stdlib_urllib,
"string": stdlib_string,
"zoneinfo": stdlib_zoneinfo,
"datetime": stdlib_datetime,
"math": stdlib_math,
"statistics": stdlib_statistics,
}


Expand All @@ -50,7 +209,7 @@ def get_evaluator(
names: dict[str, Any], extra_functions: dict[str, Callable] | None = None
) -> SimpleEval:
evaluator = EvalWithCompoundTypes(
names=names, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
names=names | stdlib, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
)

return evaluator
Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/common/exceptions/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import temporalio.exceptions

# List of error types that should not be retried
NON_RETRYABLE_ERROR_TYPES = [
NON_RETRYABLE_ERROR_TYPES = (
# Temporal-specific errors
temporalio.exceptions.WorkflowAlreadyStartedError,
temporalio.exceptions.TerminatedError,
Expand Down Expand Up @@ -99,10 +99,10 @@
litellm.exceptions.ServiceUnavailableError,
litellm.exceptions.OpenAIError,
litellm.exceptions.APIError,
]
)


def is_non_retryable_error(error: Exception) -> bool:
def is_non_retryable_error(error: BaseException) -> bool:
"""
Determines if the given error is non-retryable.

Expand All @@ -115,4 +115,4 @@ def is_non_retryable_error(error: Exception) -> bool:
Returns:
bool: True if the error is non-retryable, False otherwise.
"""
return isinstance(error, tuple(NON_RETRYABLE_ERROR_TYPES))
return isinstance(error, NON_RETRYABLE_ERROR_TYPES)
4 changes: 2 additions & 2 deletions agents-api/agents_api/common/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CustomActivityInterceptor(ActivityInboundInterceptor):
async def execute_activity(self, input: ExecuteActivityInput):
try:
return await super().execute_activity(input)
except Exception as e:
except BaseException as e:
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
if is_non_retryable_error(e):
raise ApplicationError(
str(e),
Expand All @@ -53,7 +53,7 @@ class CustomWorkflowInterceptor(WorkflowInboundInterceptor):
async def execute_workflow(self, input: ExecuteWorkflowInput):
try:
return await super().execute_workflow(input)
except Exception as e:
except BaseException as e:
if is_non_retryable_error(e):
raise ApplicationError(
str(e),
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@
} # type: ignore


PartialTransition: Type[BaseModel] = create_partial_model(CreateTransitionRequest)
class PartialTransition(create_partial_model(CreateTransitionRequest)):
user_state: dict[str, Any] = Field(default_factory=dict)


class ExecutionInput(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def create_execution_transition(
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type == "finish" else None,
output=data.output if data.type != "error" else None,
error=str(data.output)
if data.type == "error" and data.output
else None,
Expand Down
Loading