Skip to content

Commit

Permalink
Merge branch 'dev' into d/new-cookbooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Oct 11, 2024
2 parents 02ec74d + a51441a commit 7ffae7f
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 75 deletions.
14 changes: 11 additions & 3 deletions agents-api/agents_api/activities/excecute_api_call.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
from typing import Annotated, Any, Optional, TypedDict, Union

import httpx
Expand All @@ -20,6 +21,8 @@ class RequestArgs(TypedDict):
json_: Optional[dict[str, Any]]
cookies: Optional[dict[str, str]]
params: Optional[Union[str, dict[str, Any]]]
url: Optional[str]
headers: Optional[dict[str, str]]


@beartype
Expand All @@ -29,18 +32,23 @@ async def execute_api_call(
) -> Any:
try:
async with httpx.AsyncClient() as client:
arg_url = request_args.pop("url", None)
arg_headers = request_args.pop("headers", None)

response = await client.request(
method=api_call.method,
url=str(api_call.url),
headers=api_call.headers,
url=arg_url or str(api_call.url),
headers=arg_headers or api_call.headers,
follow_redirects=api_call.follow_redirects,
**request_args,
)

content_base64 = base64.b64encode(response.content).decode("ascii")

response_dict = {
"status_code": response.status_code,
"headers": dict(response.headers),
"content": response.content,
"content": content_base64,
"json": response.json(),
}

Expand Down
189 changes: 174 additions & 15 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
import base64
import datetime as dt
import functools
import itertools
import json
from functools import reduce
from itertools import accumulate
from random import random
from time import time
from typing import Any, Callable
import math
import random
import statistics
import string
import time
import urllib.parse
from typing import Any, Callable, ParamSpec, Type, TypeVar, cast

import re2
import yaml
import zoneinfo
from beartype import beartype
from simpleeval import EvalWithCompoundTypes, SimpleEval
from yaml import CSafeLoader
from yaml import CSafeDumper, CSafeLoader

T = TypeVar("T")


P = ParamSpec("P")
R = TypeVar("R")


# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"abs": abs,
"accumulate": accumulate,
"all": all,
"any": any,
"bool": bool,
Expand All @@ -25,23 +38,169 @@
"int": int,
"len": len,
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"map": map,
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"random": random,
"range": range,
"reduce": reduce,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"set": set,
"str": str,
"sum": sum,
"time": time,
"tuple": tuple,
"reduce": functools.reduce,
"zip": zip,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
}


class stdlib_re:
fullmatch = re2.fullmatch
search = re2.search
escape = re2.escape
findall = re2.findall
finditer = re2.finditer
match = re2.match
split = re2.split
sub = re2.sub
subn = re2.subn


class stdlib_json:
loads = json.loads
dumps = json.dumps


class stdlib_yaml:
load = lambda string: yaml.load(string, Loader=CSafeLoader) # noqa: E731
dump = lambda value: yaml.dump(value, Dumper=CSafeDumper) # noqa: E731


class stdlib_time:
strftime = time.strftime
strptime = time.strptime
time = time


class stdlib_random:
choice = random.choice
choices = random.choices
sample = random.sample
shuffle = random.shuffle
randrange = random.randrange
randint = random.randint
random = random.random


class stdlib_itertools:
accumulate = itertools.accumulate


class stdlib_functools:
partial = functools.partial
reduce = functools.reduce


class stdlib_base64:
b64encode = base64.b64encode
b64decode = base64.b64decode


class stdlib_urllib:
class parse:
urlparse = urllib.parse.urlparse
urlencode = urllib.parse.urlencode
unquote = urllib.parse.unquote
quote = urllib.parse.quote
parse_qs = urllib.parse.parse_qs
parse_qsl = urllib.parse.parse_qsl
urlsplit = urllib.parse.urlsplit
urljoin = urllib.parse.urljoin
unwrap = urllib.parse.unwrap


class stdlib_string:
ascii_letters = string.ascii_letters
ascii_lowercase = string.ascii_lowercase
ascii_uppercase = string.ascii_uppercase
digits = string.digits
hexdigits = string.hexdigits
octdigits = string.octdigits
punctuation = string.punctuation
whitespace = string.whitespace
printable = string.printable


class stdlib_zoneinfo:
ZoneInfo = zoneinfo.ZoneInfo


class stdlib_datetime:
class timezone:
class utc:
utc = dt.timezone.utc

class datetime:
now = dt.datetime.now
datetime = dt.datetime
timedelta = dt.timedelta
date = dt.date
time = dt.time

timedelta = dt.timedelta


class stdlib_math:
sqrt = math.sqrt
exp = math.exp
ceil = math.ceil
floor = math.floor
isinf = math.isinf
isnan = math.isnan
log = math.log
log10 = math.log10
log2 = math.log2
pow = math.pow
sin = math.sin
cos = math.cos
tan = math.tan
asin = math.asin
acos = math.acos
atan = math.atan
atan2 = math.atan2

pi = math.pi
e = math.e


class stdlib_statistics:
mean = statistics.mean
stdev = statistics.stdev
geometric_mean = statistics.geometric_mean
median = statistics.median
median_low = statistics.median_low
median_high = statistics.median_high
mode = statistics.mode
quantiles = statistics.quantiles


stdlib = {
"re": stdlib_re,
"json": stdlib_json,
"yaml": stdlib_yaml,
"time": stdlib_time,
"random": stdlib_random,
"itertools": stdlib_itertools,
"functools": stdlib_functools,
"base64": stdlib_base64,
"urllib": stdlib_urllib,
"string": stdlib_string,
"zoneinfo": stdlib_zoneinfo,
"datetime": stdlib_datetime,
"math": stdlib_math,
"statistics": stdlib_statistics,
}


Expand All @@ -50,7 +209,7 @@ def get_evaluator(
names: dict[str, Any], extra_functions: dict[str, Callable] | None = None
) -> SimpleEval:
evaluator = EvalWithCompoundTypes(
names=names, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
names=names | stdlib, functions=ALLOWED_FUNCTIONS | (extra_functions or {})
)

return evaluator
Expand Down
8 changes: 4 additions & 4 deletions agents-api/agents_api/common/exceptions/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import temporalio.exceptions

# List of error types that should not be retried
NON_RETRYABLE_ERROR_TYPES = [
NON_RETRYABLE_ERROR_TYPES = (
# Temporal-specific errors
temporalio.exceptions.WorkflowAlreadyStartedError,
temporalio.exceptions.TerminatedError,
Expand Down Expand Up @@ -99,10 +99,10 @@
litellm.exceptions.ServiceUnavailableError,
litellm.exceptions.OpenAIError,
litellm.exceptions.APIError,
]
)


def is_non_retryable_error(error: Exception) -> bool:
def is_non_retryable_error(error: BaseException) -> bool:
"""
Determines if the given error is non-retryable.
Expand All @@ -115,4 +115,4 @@ def is_non_retryable_error(error: Exception) -> bool:
Returns:
bool: True if the error is non-retryable, False otherwise.
"""
return isinstance(error, tuple(NON_RETRYABLE_ERROR_TYPES))
return isinstance(error, NON_RETRYABLE_ERROR_TYPES)
4 changes: 2 additions & 2 deletions agents-api/agents_api/common/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CustomActivityInterceptor(ActivityInboundInterceptor):
async def execute_activity(self, input: ExecuteActivityInput):
try:
return await super().execute_activity(input)
except Exception as e:
except BaseException as e:
if is_non_retryable_error(e):
raise ApplicationError(
str(e),
Expand All @@ -53,7 +53,7 @@ class CustomWorkflowInterceptor(WorkflowInboundInterceptor):
async def execute_workflow(self, input: ExecuteWorkflowInput):
try:
return await super().execute_workflow(input)
except Exception as e:
except BaseException as e:
if is_non_retryable_error(e):
raise ApplicationError(
str(e),
Expand Down
3 changes: 2 additions & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@
} # type: ignore


PartialTransition: Type[BaseModel] = create_partial_model(CreateTransitionRequest)
class PartialTransition(create_partial_model(CreateTransitionRequest)):
user_state: dict[str, Any] = Field(default_factory=dict)


class ExecutionInput(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def create_execution_transition(
data=UpdateExecutionRequest(
status=transition_to_execution_status[data.type]
),
output=data.output if data.type == "finish" else None,
output=data.output if data.type != "error" else None,
error=str(data.output)
if data.type == "error" and data.output
else None,
Expand Down
Loading

0 comments on commit 7ffae7f

Please sign in to comment.