From de3bfa9bf56cb96c3ea4ffdbfff3b76bb9d19789 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Mon, 30 Dec 2024 17:28:53 +0100 Subject: [PATCH 1/3] removed legacy, fix inline trace, fix multiple choice, fix content type, remove entrypoint, etc. --- agenta-cli/agenta/__init__.py | 6 +- agenta-cli/agenta/sdk/__init__.py | 2 +- agenta-cli/agenta/sdk/decorators/routing.py | 601 ++++++------------ agenta-cli/agenta/sdk/middleware/config.py | 19 +- agenta-cli/agenta/sdk/middleware/inline.py | 28 + agenta-cli/agenta/sdk/types.py | 165 +++-- .../debugging/simple-app/agenta/__init__.py | 2 +- .../simple-app/agenta/sdk/__init__.py | 2 +- agenta-web/package-lock.json | 240 +++---- services/chat-new-sdk/_app.py | 5 +- services/completion-new-sdk-prompt.rest | 24 +- services/completion-new-sdk-prompt/_app.py | 19 +- .../docker-compose.yml | 2 +- services/completion-new-sdk.rest | 51 +- services/completion-new-sdk/_app.py | 12 +- 15 files changed, 510 insertions(+), 668 deletions(-) create mode 100644 agenta-cli/agenta/sdk/middleware/inline.py diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index 71097cd83..c21691717 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -5,8 +5,9 @@ import agenta.client.backend.types as client_types # pylint: disable=wrong-import-order from .sdk.types import ( + BaseConfigModel, DictInput, - MultipleChoice, + MCField, FloatParam, IntParam, MultipleChoiceParam, @@ -16,13 +17,14 @@ FileInputURL, BinaryParam, Prompt, + PromptTemplate, ) from .sdk.utils.logging import log as logging from .sdk.tracing import Tracing, get_tracer from .sdk.decorators.tracing import instrument from .sdk.tracing.conventions import Reference -from .sdk.decorators.routing import entrypoint, app, route +from .sdk.decorators.routing import app, route from .sdk.agenta_init import Config, AgentaSingleton, init as _init from .sdk.utils.costs import calculate_token_usage from .sdk.client import Agenta diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index 6d3c4da84..c185c8824 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -23,7 +23,7 @@ from .tracing import Tracing, get_tracer from .decorators.tracing import instrument from .tracing.conventions import Reference -from .decorators.routing import entrypoint, app, route +from .decorators.routing import app, route from .agenta_init import Config, AgentaSingleton, init as _init from .utils.costs import calculate_token_usage from .managers.vault import VaultManager diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 783533f65..101c55530 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -1,19 +1,20 @@ -from typing import Type, Any, Callable, Dict, Optional, Tuple, List -from inspect import signature, iscoroutinefunction, Signature, Parameter, _empty +from typing import Type, Any, Callable, Dict, Optional, Tuple +from inspect import signature, iscoroutinefunction, Parameter from functools import wraps from traceback import format_exception from asyncio import sleep +from json import dumps +from uuid import UUID -from tempfile import NamedTemporaryFile -from annotated_types import Ge, Le, Gt, Lt from pydantic import BaseModel, HttpUrl, ValidationError -from fastapi import Body, FastAPI, UploadFile, HTTPException, Request +from fastapi import Body, FastAPI, HTTPException, Request +from agenta.sdk.middleware.inline 
import InlineMiddleware +from agenta.sdk.middleware.vault import VaultMiddleware +from agenta.sdk.middleware.config import ConfigMiddleware from agenta.sdk.middleware.auth import AuthMiddleware from agenta.sdk.middleware.otel import OTelMiddleware -from agenta.sdk.middleware.config import ConfigMiddleware -from agenta.sdk.middleware.vault import VaultMiddleware from agenta.sdk.middleware.cors import CORSMiddleware from agenta.sdk.context.routing import ( @@ -25,22 +26,14 @@ tracing_context, TracingContext, ) +from agenta.sdk.utils.exceptions import ( + display_exception, + suppress, +) from agenta.sdk.router import router -from agenta.sdk.utils.exceptions import suppress, display_exception from agenta.sdk.utils.logging import log -from agenta.sdk.types import ( - DictInput, - FloatParam, - IntParam, - MultipleChoiceParam, - MultipleChoice, - GroupedMultipleChoiceParam, - TextParam, - MessagesInput, - FileInputURL, - BaseResponse, - BinaryParam, -) + +from agenta.sdk.types import BaseResponse import agenta as ag @@ -57,246 +50,156 @@ class PathValidator(BaseModel): class route: # pylint: disable=invalid-name - # This decorator is used to expose specific stages of a workflow (embedding, retrieval, summarization, etc.) - # as independent endpoints. It is designed for backward compatibility with existing code that uses - # the @entrypoint decorator, which has certain limitations. By using @route(), we can create new - # routes without altering the main workflow entrypoint. This helps in modularizing the services - # and provides flexibility in how we expose different functionalities as APIs. - def __init__( - self, - path: Optional[str] = "/", - config_schema: Optional[BaseModel] = None, - ): - self.config_schema: BaseModel = config_schema - path = "/" + path.strip("/").strip() - path = "" if path == "/" else path - PathValidator(url=f"http://example.com{path}") - - self.route_path = path - - self.e = None - - def __call__(self, f): - self.e = entrypoint( - f, - route_path=self.route_path, - config_schema=self.config_schema, - ) - - return f - - -class entrypoint: """ Decorator class to wrap a function for HTTP POST, terminal exposure and enable tracing. This decorator generates the following endpoints: Playground Endpoints - - /generate with @entrypoint, @route("/"), @route(path="") # LEGACY - - /playground/run with @entrypoint, @route("/"), @route(path="") - - /playground/run/{route} with @route({route}), @route(path={route}) - - Deployed Endpoints: - - /generate_deployed with @entrypoint, @route("/"), @route(path="") # LEGACY - - /run with @entrypoint, @route("/"), @route(path="") - - /run/{route} with @route({route}), @route(path={route}) + - /test with e.g. @route("/"), @route(path="") + - /test/{route} with e.g. @route({route}), @route(path={route}) - The rationale is: - - There may be multiple endpoints, based on the different routes. - - It's better to make it explicit that an endpoint is for the playground. - - Prefixing the routes with /run is more futureproof in case we add more endpoints. + Environment Endpoints: + - /run with e.g. @route("/"), @route(path="") + - /run/{route} with e.g. @route({route}), @route(path={route}) Example: ```python import agenta as ag - @ag.entrypoint + @ag.route() async def chain_of_prompts_llm(prompt: str): return ... 
``` """ routes = list() - _middleware = False + _run_path = "/run" _test_path = "/test" - # LEGACY - _legacy_playground_run_path = "/playground/run" - _legacy_generate_path = "/generate" - _legacy_generate_deployed_path = "/generate_deployed" + + _config_key = "ag_config" def __init__( self, - func: Callable[..., Any], - route_path: str = "", + path: Optional[str] = "/", config_schema: Optional[BaseModel] = None, + content_type: Optional[str] = None, ): + self.route_path = "/" + path.strip("/").strip() + self.route_path = "" if self.route_path == "/" else self.route_path + self.config_schema: BaseModel = config_schema + self.content_type = content_type + + PathValidator(url=f"http://example.com{path}") + + self.func = None + self.config = None + self.default_parameters = {} + + self.parse_config() + + if not route._middleware: + route._middleware = True + self.attach_middleware() + + def __call__( + self, + func: Callable[..., Any], + ) -> Callable[..., Any]: self.func = func - self.route_path = route_path - self.config_schema = config_schema - signature_parameters = signature(func).parameters - config, default_parameters = self.parse_config() + self.create_run_route() + self.create_test_route() - ### --- Middleware --- # - if not entrypoint._middleware: - entrypoint._middleware = True + # --- Route(r) Setup --- # - app.add_middleware(VaultMiddleware) - app.add_middleware(ConfigMiddleware) - app.add_middleware(AuthMiddleware) - app.add_middleware(OTelMiddleware) - app.add_middleware(CORSMiddleware) - ### ------------------ # + def parse_config(self) -> Tuple[Optional[Type[BaseModel]], Dict[str, Any]]: + if self.config_schema: + try: + self.config = self.config_schema() if self.config_schema else None + self.default_parameters = self.config.dict() if self.config else {} + except ValidationError as e: + raise ValueError( + f"Error initializing config_schema. 
Please ensure all required fields have default values: {str(e)}" + ) from e + except Exception as e: + raise ValueError( + f"Unexpected error initializing config_schema: {str(e)}" + ) from e - ### --- Run --- # - @wraps(func) + def attach_middleware(self): + app.add_middleware(InlineMiddleware) + app.add_middleware(VaultMiddleware) + app.add_middleware(ConfigMiddleware) + app.add_middleware(AuthMiddleware) + app.add_middleware(OTelMiddleware) + app.add_middleware(CORSMiddleware) + + # --- Route Registration --- # + + def create_run_route(self): + @wraps(self.func) async def run_wrapper(request: Request, *args, **kwargs) -> Any: - # LEGACY - # TODO: Removing this implies breaking changes in : - # - calls to /generate_deployed - kwargs = { - k: v - for k, v in kwargs.items() - if k not in ["config", "environment", "app"] - } - # LEGACY - - kwargs, _ = self.process_kwargs(kwargs, default_parameters) - if request.state.config["parameters"] is None: + kwargs, _ = self.process_kwargs(kwargs) + + if ( + request.state.config["parameters"] is None + or request.state.config["references"] is None + ): raise HTTPException( status_code=400, detail="Config not found based on provided references.", ) + return await self.execute_wrapper(request, *args, **kwargs) - return await self.execute_wrapper(request, False, *args, **kwargs) + self.update_wrapper_signature(wrapper=run_wrapper, add_config=False) - self.update_run_wrapper_signature(wrapper=run_wrapper) - - run_route = f"{entrypoint._run_path}{route_path}" + run_route = f"{route._run_path}{self.route_path}" app.post(run_route, response_model=BaseResponse)(run_wrapper) - # LEGACY - # TODO: Removing this implies breaking changes in : - # - calls to /generate_deployed must be replaced with calls to /run - if route_path == "": - run_route = entrypoint._legacy_generate_deployed_path - app.post(run_route, response_model=BaseResponse)(run_wrapper) - # LEGACY - ### ----------- # - - ### --- Test --- # - @wraps(func) + def create_test_route(self): + @wraps(self.func) async def test_wrapper(request: Request, *args, **kwargs) -> Any: - kwargs, config = self.process_kwargs(kwargs, default_parameters) + kwargs, config = self.process_kwargs(kwargs) + request.state.inline = True request.state.config["parameters"] = config - return await self.execute_wrapper(request, True, *args, **kwargs) - self.update_test_wrapper_signature( - wrapper=test_wrapper, - config_instance=config - ) + if request.state.config["references"]: + request.state.config["references"] = { + k: v + for k, v in request.state.config["references"].items() + if k.startswith("application") + } or None - test_route = f"{entrypoint._test_path}{route_path}" - app.post(test_route, response_model=BaseResponse)(test_wrapper) + return await self.execute_wrapper(request, *args, **kwargs) - # LEGACY - # TODO: Removing this implies breaking changes in : - # - calls to /generate must be replaced with calls to /test - if route_path == "": - test_route = entrypoint._legacy_generate_path - app.post(test_route, response_model=BaseResponse)(test_wrapper) - # LEGACY - - # LEGACY - # TODO: Removing this implies no breaking changes - if route_path == "": - test_route = entrypoint._legacy_playground_run_path - app.post(test_route, response_model=BaseResponse)(test_wrapper) - # LEGACY - ### ------------ # - - ### --- OpenAPI --- # - test_route = f"{entrypoint._test_path}{route_path}" - entrypoint.routes.append( - { - "func": func.__name__, - "endpoint": test_route, - "params": signature_parameters, - "config": config, - } - 
) + self.update_wrapper_signature(wrapper=test_wrapper, add_config=True) - # LEGACY - if route_path == "": - test_route = entrypoint._legacy_generate_path - entrypoint.routes.append( - { - "func": func.__name__, - "endpoint": test_route, - "params": ( - {**default_parameters, **signature_parameters} - if not config - else signature_parameters - ), - "config": config, - } - ) - # LEGACY - - app.openapi_schema = None # Forces FastAPI to re-generate the schema - openapi_schema = app.openapi() - - for _route in entrypoint.routes: - if _route["config"] is not None: - self.override_config_in_schema( - openapi_schema=openapi_schema, - func_name=_route["func"], - endpoint=_route["endpoint"], - config=_route["config"], - ) - ### --------------- # + test_route = f"{route._test_path}{self.route_path}" + app.post(test_route, response_model=BaseResponse)(test_wrapper) - def parse_config(self) -> Tuple[Optional[Type[BaseModel]], Dict[str, Any]]: - """Parse the config schema and return the config class and default parameters.""" - config = None - default_parameters = {} + def process_kwargs( + self, + kwargs: Dict[str, Any], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + config_params = kwargs.pop(route._config_key, {}) # TODO: rename this - if self.config_schema: - try: - config = self.config_schema() if self.config_schema else None - default_parameters = config.dict() if config else {} - except ValidationError as e: - raise ValueError( - f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}" - ) from e - except Exception as e: - raise ValueError( - f"Unexpected error initializing config_schema: {str(e)}" - ) from e + if isinstance(config_params, BaseModel): # TODO: explain this + config_params = config_params.model_dump() - return config, default_parameters + config = {**self.default_parameters, **config_params} - def process_kwargs( - self, kwargs: Dict[str, Any], default_parameters: Dict[str, Any] - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """Remove the config parameters from the kwargs.""" - # Extract agenta_config if present - config_params = kwargs.pop("agenta_config", {}) - if isinstance(config_params, BaseModel): - config_params = config_params.dict() - # Merge with default parameters - config = {**default_parameters, **config_params} return kwargs, config + # --- Function Request/Response --- # + async def execute_wrapper( self, request: Request, - inline: bool, *args, **kwargs, ): @@ -308,6 +211,7 @@ async def execute_wrapper( parameters = state.config.get("parameters") references = state.config.get("references") secrets = state.vault.get("secrets") + inline = state.inline with routing_context_manager( context=RoutingContext( @@ -322,27 +226,19 @@ async def execute_wrapper( references=references, ) ): - result = await self.execute_function(inline, *args, **kwargs) + try: + result = ( + await self.func(*args, **kwargs) + if iscoroutinefunction(self.func) + else self.func(*args, **kwargs) + ) - return result + return await self.handle_success(result, inline) - async def execute_function( - self, - inline: bool, - *args, - **kwargs, - ): - try: - result = ( - await self.func(*args, **kwargs) - if iscoroutinefunction(self.func) - else self.func(*args, **kwargs) - ) + except Exception as error: # pylint: disable=broad-except + self.handle_failure(error) - return await self.handle_success(result, inline) - - except Exception as error: # pylint: disable=broad-except - self.handle_failure(error) + return result async def handle_success( self, @@ 
-350,21 +246,38 @@ async def handle_success( inline: bool, ): data = None + content_type = self.content_type tree = None - content_type = "string" + tree_id = None + + print(result) + print(content_type) with suppress(): - if isinstance(result, (dict, list)): - content_type = "json" - data = self.patch_result(result) + if isinstance(result, str): + content_type = "text/plain" + data = result + elif not isinstance(result, str) and content_type == "text/plain": + data = dumps(result) if inline: - tree = await self.fetch_inline_trace(inline) + tree, tree_id = await self.fetch_inline_trace(inline) try: - return BaseResponse(data=data, tree=tree, content_type=content_type) - except: - return BaseResponse(data=data, content_type=content_type) + return BaseResponse( + data=data, + content_type=content_type, + tree=tree, + tree_id=tree_id, + ) + + except: # pylint: disable=bare-except + display_exception("Response Exception") + + return BaseResponse( + data=data, + content_type=content_type, + ) def handle_failure( self, @@ -372,217 +285,87 @@ def handle_failure( ): display_exception("Application Exception") - status_code = 500 - message = str(error) - stacktrace = format_exception(error, value=error, tb=error.__traceback__) # type: ignore - detail = {"message": message, "stacktrace": stacktrace} - - raise HTTPException(status_code=status_code, detail=detail) - - def patch_result( - self, - result: Any, - ): - """ - Patch the result to only include the message if the result is a FuncResponse-style dictionary with message, cost, and usage keys. - - Example: - ```python - result = { - "message": "Hello, world!", - "cost": 0.5, - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - "total_tokens": 30 - } - } - result = patch_result(result) - print(result) - # Output: "Hello, world!" - ``` - """ - data = ( - result["message"] - if isinstance(result, dict) - and all(key in result for key in ["message", "cost", "usage"]) - else result + raise HTTPException( + status_code=500, + detail={ + "message": str(error), + "stacktrace": format_exception( + error, value=error, tb=error.__traceback__ + ), + }, ) - if data is None: - data = ( - "Function executed successfully, but did return None. 
\n Are you sure you did not forget to return a value?", - ) - - if not isinstance(result, dict): - data = str(data) - - return data - async def fetch_inline_trace( self, - inline, + inline: bool, ): - WAIT_FOR_SPANS = True - TIMEOUT = 1 TIMESTEP = 0.1 - FINALSTEP = 0.001 - NOFSTEPS = TIMEOUT / TIMESTEP - - trace = None + NOFSTEPS = 1 / TIMESTEP context = tracing_context.get() - link = context.link - trace_id = link.get("tree_id") if link else None + tree = None + _tree_id = link.get("tree_id") if link else None + tree_id = str(UUID(int=_tree_id)) if _tree_id else None - if trace_id is not None: + if _tree_id is not None: if inline: - if WAIT_FOR_SPANS: - remaining_steps = NOFSTEPS - - while ( - not ag.tracing.is_inline_trace_ready(trace_id) - and remaining_steps > 0 - ): - await sleep(TIMESTEP) + remaining_steps = NOFSTEPS - remaining_steps -= 1 + while ( + not ag.tracing.is_inline_trace_ready(_tree_id) + and remaining_steps > 0 + ): + await sleep(TIMESTEP) - await sleep(FINALSTEP) + remaining_steps -= 1 - trace = ag.tracing.get_inline_trace(trace_id) - else: - trace = {"trace_id": trace_id} + tree = ag.tracing.get_inline_trace(_tree_id) - return trace + return tree, tree_id - # --- OpenAPI --- # + # --- Function Signature --- # - def add_request_to_signature( + def update_wrapper_signature( self, wrapper: Callable[..., Any], + add_config: bool, ): - original_sig = signature(wrapper) parameters = [ Parameter( - "request", + name="request", kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=Request, - ), - *original_sig.parameters.values(), - ] - new_sig = Signature( - parameters, - return_annotation=original_sig.return_annotation, - ) - wrapper.__signature__ = new_sig - - def update_wrapper_signature( - self, wrapper: Callable[..., Any], updated_params: List - ): - """ - Updates the signature of a wrapper function with a new list of parameters. - - Args: - wrapper (callable): A callable object, such as a function or a method, that requires a signature update. - updated_params (List[Parameter]): A list of `Parameter` objects representing the updated parameters - for the wrapper function. 
- """ - - wrapper_signature = signature(wrapper) - wrapper_signature = wrapper_signature.replace(parameters=updated_params) - wrapper.__signature__ = wrapper_signature # type: ignore - - def update_test_wrapper_signature( - self, - wrapper: Callable[..., Any], - config_instance: Type[BaseModel], # TODO: change to our type - ) -> None: - """Update the function signature to include new parameters.""" - - updated_params: List[Parameter] = [] - self.add_config_params_to_parser(updated_params, config_instance) - self.add_func_params_to_parser(updated_params) - self.update_wrapper_signature(wrapper, updated_params) - self.add_request_to_signature(wrapper) - - def update_run_wrapper_signature( - self, - wrapper: Callable[..., Any], - ) -> None: - """Update the function signature to include new parameters.""" - - updated_params: List[Parameter] = [] - self.add_func_params_to_parser(updated_params) - self.update_wrapper_signature(wrapper, updated_params) - self.add_request_to_signature(wrapper) - - def add_config_params_to_parser( - self, updated_params: list, config_instance: Type[BaseModel] - ) -> None: - """Add configuration parameters to function signature.""" - for name, field in config_instance.__fields__.items(): - assert field.default is not None, f"Field {name} has no default value" - updated_params.append( - Parameter( - name="agenta_config", - kind=Parameter.KEYWORD_ONLY, - annotation=type(config_instance), # Get the actual class type - default=Body(config_instance), # Use the instance directly ) - ) + ] - def add_func_params_to_parser(self, updated_params: list) -> None: - """Add function parameters to function signature.""" for name, param in signature(self.func).parameters.items(): assert ( len(param.default.__class__.__bases__) == 1 ), f"Inherited standard type of {param.default.__class__} needs to be one." - updated_params.append( + + parameters.append( Parameter( - name, - Parameter.KEYWORD_ONLY, - default=Body(..., embed=True), - annotation=param.default.__class__.__bases__[ - 0 - ], # determines and get the base (parent/inheritance) type of the sdk-type at run-time. \ - # E.g __class__ is ag.MessagesInput() and accessing it parent type will return (,), \ - # thus, why we are accessing the first item. + name=name, + kind=Parameter.KEYWORD_ONLY, + default=Body(param.default, embed=True), + annotation=param.default.__class__.__bases__[0], + # ^ determines and gets the base (parent/inheritance) type of the SDK type at run-time. 
) ) - def override_config_in_schema( - self, - openapi_schema: dict, - func_name: str, - endpoint: str, - config: Type[BaseModel], - ): - """Override config in OpenAPI schema to add agenta-specific metadata.""" - endpoint = endpoint[1:].replace("/", "_") - schema_key = f"Body_{func_name}_{endpoint}_post" - schema_to_override = openapi_schema["components"]["schemas"][schema_key] - - # Get the config class name to find its schema - config_class_name = type(config).__name__ - config_schema = openapi_schema["components"]["schemas"][config_class_name] - - # Process each field in the config class - for field_name, field in config.__class__.__fields__.items(): - # Check if field has Annotated metadata for MultipleChoice - if hasattr(field, "metadata") and field.metadata: - for meta in field.metadata: - if isinstance(meta, MultipleChoice): - choices = meta.choices - if isinstance(choices, dict): - config_schema["properties"][field_name].update({ - "x-parameter": "grouped_choice", - "choices": choices - }) - elif isinstance(choices, list): - config_schema["properties"][field_name].update({ - "x-parameter": "choice", - "enum": choices - }) + if self.config and add_config: + for name, field in self.config.model_fields.items(): + assert field.default is not None, f"Field {name} has no default value" + + parameters.append( + Parameter( + name=self._config_key, + kind=Parameter.KEYWORD_ONLY, + annotation=type(self.config), # Get the actual class type + default=Body(self.config), # Use the instance directly + ) + ) + + wrapper.__signature__ = signature(wrapper).replace(parameters=parameters) diff --git a/agenta-cli/agenta/sdk/middleware/config.py b/agenta-cli/agenta/sdk/middleware/config.py index 8ea9eb9ff..5c307a4fa 100644 --- a/agenta-cli/agenta/sdk/middleware/config.py +++ b/agenta-cli/agenta/sdk/middleware/config.py @@ -40,7 +40,7 @@ async def dispatch( request: Request, call_next: Callable, ): - request.state.config = {} + request.state.config = {"parameters": None, "references": None} with suppress(): parameters, references = await self._get_config(request) @@ -105,24 +105,21 @@ async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: config = response.json() - if not config: - _cache.put(_hash, {"parameters": None, "references": None}) - - return None, None - parameters = config.get("params") references = {} for ref_key in ["application_ref", "variant_ref", "environment_ref"]: refs = config.get(ref_key) - ref_prefix = ref_key.split("_", maxsplit=1)[0] - for ref_part_key in ["id", "slug", "version"]: - ref_part = refs.get(ref_part_key) + if refs: + ref_prefix = ref_key.split("_", maxsplit=1)[0] + + for ref_part_key in ["id", "slug", "version"]: + ref_part = refs.get(ref_part_key) - if ref_part: - references[ref_prefix + "." + ref_part_key] = ref_part + if ref_part: + references[ref_prefix + "." 
+ ref_part_key] = ref_part _cache.put(_hash, {"parameters": parameters, "references": references}) diff --git a/agenta-cli/agenta/sdk/middleware/inline.py b/agenta-cli/agenta/sdk/middleware/inline.py new file mode 100644 index 000000000..1a0d71d11 --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/inline.py @@ -0,0 +1,28 @@ +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import Request, FastAPI + + +from agenta.sdk.utils.exceptions import suppress + +from agenta.sdk.utils.constants import TRUTHY + + +class InlineMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + async def dispatch( + self, + request: Request, + call_next: Callable, + ): + request.state.inline = False + + with suppress(): + inline = str(request.query_params.get("inline")) in TRUTHY + + request.state.inline = inline + + return await call_next(request) diff --git a/agenta-cli/agenta/sdk/types.py b/agenta-cli/agenta/sdk/types.py index cd12a99a1..9f77c62f2 100644 --- a/agenta-cli/agenta/sdk/types.py +++ b/agenta-cli/agenta/sdk/types.py @@ -11,11 +11,46 @@ from agenta.sdk.assets import supported_llm_models +class BaseConfigModel(BaseModel): + @classmethod + def model_json_schema( + cls, + by_alias: bool = True, + ref_template: str = "#/components/schemas/{model}", + schema_generator: Any = None, + mode: Any = None, + ) -> dict: + schema = super().model_json_schema(by_alias=by_alias, ref_template=ref_template) + + for field_name, field in cls.model_fields.items(): + if "multiple_choice" in field.json_schema_extra: + values: MultipleChoice = field.json_schema_extra["multiple_choice"] + choices = values.choices + + if isinstance(choices, dict): + schema["properties"][field_name]["x-parameter"] = "grouped_choice" + schema["properties"][field_name]["choices"] = choices + elif isinstance(choices, list): + schema["properties"][field_name]["x-parameter"] = "choice" + schema["properties"][field_name]["enum"] = choices + + return schema + + @dataclass class MultipleChoice: choices: Union[List[str], Dict[str, List[str]]] +def MCField( # pylint: disable=invalid-name + default: str, + choices: Union[List[str], Dict[str, List[str]]], +) -> Field: + field = Field(default=default) + field.json_schema_extra = {"multiple_choice": MultipleChoice(choices)} + return field + + class LLMTokenUsage(BaseModel): completion_tokens: int prompt_tokens: int @@ -26,6 +61,7 @@ class BaseResponse(BaseModel): version: Optional[str] = "3.1" data: Optional[Union[str, Dict[str, Any]]] = None content_type: Optional[str] = "string" + tree_id: Optional[str] = None tree: Optional[AgentaNodesResponse] = None @@ -247,6 +283,7 @@ class Prompt(BaseModel): frequency_penalty: float presence_penalty: float + # ----------------------------------------------------- # New Prompt model # ----------------------------------------------------- @@ -257,6 +294,7 @@ class ToolCall(BaseModel): type: Literal["function"] = "function" function: Dict[str, str] + class Message(BaseModel): role: Literal["system", "user", "assistant", "tool", "function"] content: Optional[str] = None @@ -264,6 +302,7 @@ class Message(BaseModel): tool_calls: Optional[List[ToolCall]] = None tool_call_id: Optional[str] = None + class ResponseFormatText(BaseModel): type: Literal["text"] """The type of response format being defined: `text`""" @@ -286,7 +325,7 @@ class JSONSchema(BaseModel): model_config = { "populate_by_name": True, - "json_schema_extra": {"required": ["name", "schema"]} + "json_schema_extra": 
{"required": ["name", "schema"]}, } @@ -296,98 +335,106 @@ class ResponseFormatJSONSchema(BaseModel): json_schema: JSONSchema -ResponseFormat = Union[ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema] +ResponseFormat = Union[ + ResponseFormatText, ResponseFormatJSONObject, ResponseFormatJSONSchema +] + class ModelConfig(BaseModel): """Configuration for model parameters""" + model: Annotated[str, MultipleChoice(choices=supported_llm_models)] = Field( - default="gpt-3.5-turbo", - description="ID of the model to use" + default="gpt-3.5-turbo", description="ID of the model to use" ) temperature: Optional[float] = Field( default=None, ge=0.0, le=2.0, - description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic" + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic", ) max_tokens: Optional[int] = Field( default=None, ge=0, - description="The maximum number of tokens that can be generated in the chat completion" + description="The maximum number of tokens that can be generated in the chat completion", ) top_p: Optional[float] = Field( default=None, ge=0.0, le=1.0, - description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass" + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass", ) frequency_penalty: Optional[float] = Field( default=None, ge=-2.0, le=2.0, - description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far" + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far", ) presence_penalty: Optional[float] = Field( default=None, ge=-2.0, le=2.0, - description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far" + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far", ) response_format: Optional[ResponseFormat] = Field( default=None, - description="An object specifying the format that the model must output" + description="An object specifying the format that the model must output", ) stream: Optional[bool] = Field( - default=None, - description="If set, partial message deltas will be sent" + default=None, description="If set, partial message deltas will be sent" ) tools: Optional[List[Dict]] = Field( default=None, - description="A list of tools the model may call. Currently, only functions are supported as a tool" + description="A list of tools the model may call. 
Currently, only functions are supported as a tool", ) tool_choice: Optional[Union[Literal["none", "auto"], Dict]] = Field( - default=None, - description="Controls which (if any) tool is called by the model" + default=None, description="Controls which (if any) tool is called by the model" ) + class PromptTemplateError(Exception): """Base exception for all PromptTemplate errors""" + pass + class InputValidationError(PromptTemplateError): """Raised when input validation fails""" - def __init__(self, message: str, missing: Optional[set] = None, extra: Optional[set] = None): + + def __init__( + self, message: str, missing: Optional[set] = None, extra: Optional[set] = None + ): self.missing = missing self.extra = extra super().__init__(message) + class TemplateFormatError(PromptTemplateError): """Raised when template formatting fails""" + def __init__(self, message: str, original_error: Optional[Exception] = None): self.original_error = original_error super().__init__(message) + class PromptTemplate(BaseModel): """A template for generating prompts with formatting capabilities""" + messages: List[Message] = Field( - default=[ - Message(role="system", content=""), - Message(role="user", content="") - ] + default=[Message(role="system", content=""), Message(role="user", content="")] ) system_prompt: Optional[str] = None user_prompt: Optional[str] = None template_format: Literal["fstring", "jinja2", "curly"] = Field( default="fstring", - description="Format type for template variables: fstring {var}, jinja2 {{ var }}, or curly {{var}}" + description="Format type for template variables: fstring {var}, jinja2 {{ var }}, or curly {{var}}", ) input_keys: Optional[List[str]] = Field( default=None, - description="Optional list of input keys for validation. If not provided, any inputs will be accepted" + description="Optional list of input keys for validation. If not provided, any inputs will be accepted", ) llm_config: ModelConfig = Field( default_factory=ModelConfig, - description="Configuration for the model parameters" + description="Configuration for the model parameters", ) model_config = { @@ -398,7 +445,6 @@ class PromptTemplate(BaseModel): } } - @root_validator(pre=True) def init_messages(cls, values): if "messages" not in values: @@ -418,26 +464,30 @@ def _format_with_template(self, content: str, kwargs: Dict[str, Any]) -> str: return content.format(**kwargs) elif self.template_format == "jinja2": from jinja2 import Template, TemplateError + try: return Template(content).render(**kwargs) except TemplateError as e: raise TemplateFormatError( f"Jinja2 template error in content: '{content}'. 
Error: {str(e)}", - original_error=e + original_error=e, ) elif self.template_format == "curly": import re + result = content for key, value in kwargs.items(): - result = re.sub(r'\{\{' + key + r'\}\}', str(value), result) - if re.search(r'\{\{.*?\}\}', result): - unreplaced = re.findall(r'\{\{(.*?)\}\}', result) + result = re.sub(r"\{\{" + key + r"\}\}", str(value), result) + if re.search(r"\{\{.*?\}\}", result): + unreplaced = re.findall(r"\{\{(.*?)\}\}", result) raise TemplateFormatError( f"Unreplaced variables in curly template: {unreplaced}" ) return result else: - raise TemplateFormatError(f"Unknown template format: {self.template_format}") + raise TemplateFormatError( + f"Unknown template format: {self.template_format}" + ) except KeyError as e: key = str(e).strip("'") raise TemplateFormatError( @@ -445,11 +495,10 @@ def _format_with_template(self, content: str, kwargs: Dict[str, Any]) -> str: ) except Exception as e: raise TemplateFormatError( - f"Error formatting template '{content}': {str(e)}", - original_error=e + f"Error formatting template '{content}': {str(e)}", original_error=e ) - def format(self, **kwargs) -> 'PromptTemplate': + def format(self, **kwargs) -> "PromptTemplate": """ Format the template with provided inputs. Only validates against input_keys if they are specified. @@ -462,18 +511,20 @@ def format(self, **kwargs) -> 'PromptTemplate': if self.input_keys is not None: missing = set(self.input_keys) - set(kwargs.keys()) extra = set(kwargs.keys()) - set(self.input_keys) - + error_parts = [] if missing: - error_parts.append(f"Missing required inputs: {', '.join(sorted(missing))}") + error_parts.append( + f"Missing required inputs: {', '.join(sorted(missing))}" + ) if extra: error_parts.append(f"Unexpected inputs: {', '.join(sorted(extra))}") - + if error_parts: raise InputValidationError( " | ".join(error_parts), missing=missing if missing else None, - extra=extra if extra else None + extra=extra if extra else None, ) new_messages = [] @@ -484,24 +535,26 @@ def format(self, **kwargs) -> 'PromptTemplate': except TemplateFormatError as e: raise TemplateFormatError( f"Error in message {i} ({msg.role}): {str(e)}", - original_error=e.original_error + original_error=e.original_error, ) else: new_content = None - - new_messages.append(Message( - role=msg.role, - content=new_content, - name=msg.name, - tool_calls=msg.tool_calls, - tool_call_id=msg.tool_call_id - )) - + + new_messages.append( + Message( + role=msg.role, + content=new_content, + name=msg.name, + tool_calls=msg.tool_calls, + tool_call_id=msg.tool_call_id, + ) + ) + return PromptTemplate( messages=new_messages, template_format=self.template_format, llm_config=self.llm_config, - input_keys=self.input_keys + input_keys=self.input_keys, ) def to_openai_kwargs(self) -> dict: @@ -514,29 +567,31 @@ def to_openai_kwargs(self) -> dict: # Add optional parameters only if they are set if self.llm_config.temperature is not None: kwargs["temperature"] = self.llm_config.temperature - + if self.llm_config.top_p is not None: kwargs["top_p"] = self.llm_config.top_p if self.llm_config.stream is not None: kwargs["stream"] = self.llm_config.stream - + if self.llm_config.max_tokens is not None: kwargs["max_tokens"] = self.llm_config.max_tokens - + if self.llm_config.frequency_penalty is not None: kwargs["frequency_penalty"] = self.llm_config.frequency_penalty - + if self.llm_config.presence_penalty is not None: kwargs["presence_penalty"] = self.llm_config.presence_penalty - + if self.llm_config.response_format: - 
kwargs["response_format"] = self.llm_config.response_format.dict(by_alias=True) - + kwargs["response_format"] = self.llm_config.response_format.dict( + by_alias=True + ) + if self.llm_config.tools: kwargs["tools"] = self.llm_config.tools # Only set tool_choice if tools are present if self.llm_config.tool_choice is not None: kwargs["tool_choice"] = self.llm_config.tool_choice - return kwargs \ No newline at end of file + return kwargs diff --git a/agenta-cli/debugging/simple-app/agenta/__init__.py b/agenta-cli/debugging/simple-app/agenta/__init__.py index 53c65db70..1f99a81d8 100644 --- a/agenta-cli/debugging/simple-app/agenta/__init__.py +++ b/agenta-cli/debugging/simple-app/agenta/__init__.py @@ -23,7 +23,7 @@ from .sdk.tracing import Tracing, get_tracer from .sdk.decorators.tracing import instrument from .sdk.tracing.conventions import Reference -from .sdk.decorators.routing import entrypoint, app, route +from .sdk.decorators.routing import app, route from .sdk.agenta_init import Config, AgentaSingleton, init as _init from .sdk.utils.costs import calculate_token_usage from .sdk.client import Agenta diff --git a/agenta-cli/debugging/simple-app/agenta/sdk/__init__.py b/agenta-cli/debugging/simple-app/agenta/sdk/__init__.py index c1e40757c..85c3488ae 100644 --- a/agenta-cli/debugging/simple-app/agenta/sdk/__init__.py +++ b/agenta-cli/debugging/simple-app/agenta/sdk/__init__.py @@ -24,7 +24,7 @@ from .tracing import Tracing, get_tracer from .decorators.tracing import instrument from .tracing.conventions import Reference -from .decorators.routing import entrypoint, app, route +from .decorators.routing import app, route from .agenta_init import Config, AgentaSingleton, init as _init from .utils.costs import calculate_token_usage from .managers.config import ConfigManager diff --git a/agenta-web/package-lock.json b/agenta-web/package-lock.json index d30e525ce..fdfd94379 100644 --- a/agenta-web/package-lock.json +++ b/agenta-web/package-lock.json @@ -1816,6 +1816,21 @@ } } }, + "node_modules/@next/swc-darwin-arm64": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.17.tgz", + "integrity": "sha512-WiOf5nElPknrhRMTipXYTJcUz7+8IAjOYw3vXzj3BYRcVY0hRHKWgTgQ5439EvzQyHEko77XK+yN9x9OJ0oOog==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, "node_modules/@next/swc-darwin-x64": { "version": "14.2.17", "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.17.tgz", @@ -1832,6 +1847,111 @@ "node": ">= 10" } }, + "node_modules/@next/swc-linux-arm64-gnu": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.17.tgz", + "integrity": "sha512-SSHLZls3ZwNEHsc+d0ynKS+7Af0Nr8+KTUBAy9pm6xz9SHkJ/TeuEg6W3cbbcMSh6j4ITvrjv3Oi8n27VR+IPw==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-arm64-musl": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.17.tgz", + "integrity": "sha512-VFge37us5LNPatB4F7iYeuGs9Dprqe4ZkW7lOEJM91r+Wf8EIdViWHLpIwfdDXinvCdLl6b4VyLpEBwpkctJHA==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-gnu": { + "version": "14.2.17", + "resolved": 
"https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.17.tgz", + "integrity": "sha512-aaQlpxUVb9RZ41adlTYVQ3xvYEfBPUC8+6rDgmQ/0l7SvK8S1YNJzPmDPX6a4t0jLtIoNk7j+nroS/pB4nx7vQ==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-linux-x64-musl": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.17.tgz", + "integrity": "sha512-HSyEiFaEY3ay5iATDqEup5WAfrhMATNJm8dYx3ZxL+e9eKv10XKZCwtZByDoLST7CyBmyDz+OFJL1wigyXeaoA==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-arm64-msvc": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.17.tgz", + "integrity": "sha512-h5qM9Btqv87eYH8ArrnLoAHLyi79oPTP2vlGNSg4CDvUiXgi7l0+5KuEGp5pJoMhjuv9ChRdm7mRlUUACeBt4w==", + "cpu": [ + "arm64" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-ia32-msvc": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.17.tgz", + "integrity": "sha512-BD/G++GKSLexQjdyoEUgyo5nClU7er5rK0sE+HlEqnldJSm96CIr/+YOTT063LVTT/dUOeQsNgp5DXr86/K7/A==", + "cpu": [ + "ia32" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@next/swc-win32-x64-msvc": { + "version": "14.2.17", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.17.tgz", + "integrity": "sha512-vkQfN1+4V4KqDibkW2q0sJ6CxQuXq5l2ma3z0BRcfIqkAMZiiW67T9yCpwqJKP68QghBtPEFjPAlaqe38O6frw==", + "cpu": [ + "x64" + ], + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -20357,126 +20477,6 @@ "type": "github", "url": "https://github.com/sponsors/wooorm" } - }, - "node_modules/@next/swc-darwin-arm64": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.2.17.tgz", - "integrity": "sha512-WiOf5nElPknrhRMTipXYTJcUz7+8IAjOYw3vXzj3BYRcVY0hRHKWgTgQ5439EvzQyHEko77XK+yN9x9OJ0oOog==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-arm64-gnu": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.17.tgz", - "integrity": "sha512-SSHLZls3ZwNEHsc+d0ynKS+7Af0Nr8+KTUBAy9pm6xz9SHkJ/TeuEg6W3cbbcMSh6j4ITvrjv3Oi8n27VR+IPw==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-arm64-musl": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.17.tgz", - "integrity": "sha512-VFge37us5LNPatB4F7iYeuGs9Dprqe4ZkW7lOEJM91r+Wf8EIdViWHLpIwfdDXinvCdLl6b4VyLpEBwpkctJHA==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-x64-gnu": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.17.tgz", - "integrity": 
"sha512-aaQlpxUVb9RZ41adlTYVQ3xvYEfBPUC8+6rDgmQ/0l7SvK8S1YNJzPmDPX6a4t0jLtIoNk7j+nroS/pB4nx7vQ==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-x64-musl": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.17.tgz", - "integrity": "sha512-HSyEiFaEY3ay5iATDqEup5WAfrhMATNJm8dYx3ZxL+e9eKv10XKZCwtZByDoLST7CyBmyDz+OFJL1wigyXeaoA==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-win32-arm64-msvc": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.17.tgz", - "integrity": "sha512-h5qM9Btqv87eYH8ArrnLoAHLyi79oPTP2vlGNSg4CDvUiXgi7l0+5KuEGp5pJoMhjuv9ChRdm7mRlUUACeBt4w==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-win32-ia32-msvc": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.17.tgz", - "integrity": "sha512-BD/G++GKSLexQjdyoEUgyo5nClU7er5rK0sE+HlEqnldJSm96CIr/+YOTT063LVTT/dUOeQsNgp5DXr86/K7/A==", - "cpu": [ - "ia32" - ], - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-win32-x64-msvc": { - "version": "14.2.17", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.17.tgz", - "integrity": "sha512-vkQfN1+4V4KqDibkW2q0sJ6CxQuXq5l2ma3z0BRcfIqkAMZiiW67T9yCpwqJKP68QghBtPEFjPAlaqe38O6frw==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">= 10" - } } } } diff --git a/services/chat-new-sdk/_app.py b/services/chat-new-sdk/_app.py index 935118c1f..65b8a8038 100644 --- a/services/chat-new-sdk/_app.py +++ b/services/chat-new-sdk/_app.py @@ -4,6 +4,7 @@ from agenta.sdk.assets import supported_llm_models from pydantic import BaseModel, Field import os + # Import mock if MOCK_LLM environment variable is set if os.getenv("MOCK_LLM", True): from mock_litellm import MockLiteLLM @@ -22,9 +23,7 @@ class MyConfig(BaseModel): temperature: float = Field(default=0.2, le=1, ge=0) - model: Annotated[str, ag.MultipleChoice(choices=supported_llm_models)] = Field( - default="gpt-3.5-turbo" - ) + model: str = ag.MCField(default="gpt-3.5-turbo", choices=supported_llm_models) max_tokens: int = Field(default=-1, ge=-1, le=4000) prompt_system: str = Field(default=SYSTEM_PROMPT) diff --git a/services/completion-new-sdk-prompt.rest b/services/completion-new-sdk-prompt.rest index 12366f3da..17ec01c5f 100644 --- a/services/completion-new-sdk-prompt.rest +++ b/services/completion-new-sdk-prompt.rest @@ -4,17 +4,17 @@ ### These request can be run using the Rest Client extension in vsCode ### Health Check -GET {{baseUrl}}/{{service}}/health HTTP/1.1 +GET {{baseUrl}}/{{service}}/health ### OpenAPI -GET {{baseUrl}}/{{service}}/openapi.json HTTP/1.1 +GET {{baseUrl}}/{{service}}/openapi.json ### Basic Text Response - Geography Assistant -POST {{baseUrl}}/{{service}}/generate HTTP/1.1 +POST {{baseUrl}}/{{service}}/test Content-Type: application/json { - "agenta_config": { + "ag_config": { "prompt": { "llm_config": { "model": "gpt-4", @@ -41,11 +41,11 @@ Content-Type: application/json } ### JSON Object Response - Movie Information -POST {{baseUrl}}/{{service}}/generate HTTP/1.1 +POST 
{{baseUrl}}/{{service}}/test Content-Type: application/json { - "agenta_config": { + "ag_config": { "prompt": { "llm_config": { "model": "gpt-4", @@ -72,11 +72,11 @@ Content-Type: application/json } ### JSON Schema Response - Recipe Generator -POST {{baseUrl}}/{{service}}/generate HTTP/1.1 +POST {{baseUrl}}/{{service}}/test Content-Type: application/json { - "agenta_config": { + "ag_config": { "prompt": { "llm_config": { "model": "gpt-4", @@ -150,11 +150,11 @@ Content-Type: application/json } ### Function Calling with Tools - Weather Assistant -POST {{baseUrl}}/{{service}}/generate HTTP/1.1 +POST {{baseUrl}}/{{service}}/test Content-Type: application/json { - "agenta_config": { + "ag_config": { "prompt": { "llm_config": { "model": "gpt-4", @@ -204,11 +204,11 @@ Content-Type: application/json } ### Function Calling with Multiple Tools - Smart Home Assistant -POST {{baseUrl}}/{{service}}/generate HTTP/1.1 +POST {{baseUrl}}/{{service}}/test Content-Type: application/json { - "agenta_config": { + "ag_config": { "prompt": { "llm_config": { "model": "gpt-4", diff --git a/services/completion-new-sdk-prompt/_app.py b/services/completion-new-sdk-prompt/_app.py index 9d9ed65b3..35e16a6db 100644 --- a/services/completion-new-sdk-prompt/_app.py +++ b/services/completion-new-sdk-prompt/_app.py @@ -1,10 +1,11 @@ from typing import Dict + +from pydantic import Field from fastapi import HTTPException import agenta as ag + import litellm -from agenta.sdk.types import PromptTemplate -from pydantic import BaseModel, Field litellm.drop_params = True litellm.callbacks = [ag.callbacks.litellm_handler()] @@ -12,21 +13,22 @@ ag.init() -class MyConfig(BaseModel): - prompt: PromptTemplate = Field( - default=PromptTemplate( +class MyConfig(ag.BaseConfigModel): + prompt: ag.PromptTemplate = Field( + default=ag.PromptTemplate( system_prompt="You are an expert in geography", user_prompt="What is the capital of {country}?", ) ) -@ag.route("/", config_schema=MyConfig) +@ag.route("/", config_schema=MyConfig, content_type="text/plain") @ag.instrument() async def generate( inputs: Dict[str, str], ): config = ag.ConfigManager.get_from_route(schema=MyConfig) + if config.prompt.input_keys is not None: required_keys = set(config.prompt.input_keys) provided_keys = set(inputs.keys()) @@ -36,16 +38,21 @@ async def generate( status_code=422, detail=f"Invalid inputs. 
Expected: {sorted(required_keys)}, got: {sorted(provided_keys)}",
         )
+
     response = await litellm.acompletion(
         **config.prompt.format(**inputs).to_openai_kwargs()
     )
+
     message = response.choices[0].message
     if message.content is not None:
         return message.content
+
     if hasattr(message, "refusal") and message.refusal is not None:
         return message.refusal
+
     if hasattr(message, "parsed") and message.parsed is not None:
         return message.parsed
+
     if hasattr(message, "tool_calls") and message.tool_calls is not None:
         return [tool_call.dict() for tool_call in message.tool_calls]

diff --git a/services/completion-new-sdk-prompt/docker-compose.yml b/services/completion-new-sdk-prompt/docker-compose.yml
index 623442d7a..2f565bd20 100644
--- a/services/completion-new-sdk-prompt/docker-compose.yml
+++ b/services/completion-new-sdk-prompt/docker-compose.yml
@@ -7,7 +7,7 @@ services:
     environment:
       - AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED=True
       - AGENTA_HOST=http://host.docker.internal
-      - OPENAI_API_KEY=sk-xxxx
+      - OPENAI_API_KEY=sk-proj-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
     networks:
       - agenta-network

diff --git a/services/completion-new-sdk.rest b/services/completion-new-sdk.rest
index af8ddf26a..d394ecd70 100644
--- a/services/completion-new-sdk.rest
+++ b/services/completion-new-sdk.rest
@@ -1,76 +1,53 @@
 @baseUrl = http://localhost
-@service = completion-live-sdk
+@service = completion-new-sdk

 ### These request can be run using the Rest Client extension in vsCode

 ### Health Check
-GET {{baseUrl}}/{{service}}/health HTTP/1.1
+GET {{baseUrl}}/{{service}}/health

 ### OpenAPI
-GET {{baseUrl}}/{{service}}/openapi.json HTTP/1.1
+GET {{baseUrl}}/{{service}}/openapi.json

 ### Generate
-POST {{baseUrl}}/{{service}}/generate HTTP/1.1
+POST {{baseUrl}}/{{service}}/test
 Content-Type: application/json

 {
-    "agenta_config": {
+    "ag_config": {
         "temperature": 0.7,
-        "model": "gpt-3.5-turbo-16k",
+        "model": "gpt-3.5",
         "max_tokens": 256,
         "prompt_system": "You are an expert in linguistics.",
         "prompt_user": "What is the meaning of {word}?",
         "top_p": 0.95,
         "frequence_penalty": 0.5,
         "presence_penalty": 0.5,
-        "force_json": true
+        "force_json": false
     },
     "inputs": {
         "word": "language"
     }
 }

-### Playground Run
-POST {{baseUrl}}/{{service}}/playground/run HTTP/1.1
+### Run
+POST {{baseUrl}}/{{service}}/run?application_id=0194184e-e836-7e10-83a5-4f2c32927a21&variant_slug=default
 Content-Type: application/json

 {
-    "agenta_config": {
-        "temperature": 0.7,
-        "model": "gpt-3.5-turbo-16k",
-        "max_tokens": 256,
-        "prompt_system": "You are an expert in linguistics.",
-        "prompt_user": "What is the meaning of {word}?",
-        "top_p": 0.95,
-        "frequence_penalty": 0.5,
-        "presence_penalty": 0.5,
-        "force_json": true
-    },
     "inputs": {
-        "word": "language"
+        "country": "Germany"
     }
 }

-### Generate Deployed
-POST {{baseUrl}}/{{service}}/generate_deployed HTTP/1.1
-Content-Type: application/json
-
-{
-    "inputs": {
-        "country": "Italy"
-    },
-    "config": "default",
-    "environment": "production"
-}

-### Run
-POST {{baseUrl}}/{{service}}/run HTTP/1.1
+### Run (baggage)
+POST {{baseUrl}}/{{service}}/run
 Content-Type: application/json
+Baggage: application_id=0194184e-e836-7e10-83a5-4f2c32927a21, variant_slug=default

 {
     "inputs": {
         "country": "Germany"
-    },
-    "config": "default",
-    "environment": "production"
+    }
 }

diff --git a/services/completion-new-sdk/_app.py b/services/completion-new-sdk/_app.py
index 60cef2965..d3a606482 100644
--- a/services/completion-new-sdk/_app.py
+++ b/services/completion-new-sdk/_app.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel, Field
 import os
+
 # Import mock if MOCK_LLM environment variable is set
 if os.getenv("MOCK_LLM", True):
     from mock_litellm import MockLiteLLM
@@ -30,9 +31,7 @@ class MyConfig(BaseModel):
     temperature: float = Field(default=1, ge=0.0, le=2.0)
-    model: Annotated[str, ag.MultipleChoice(choices=supported_llm_models)] = Field(
-        default="gpt-3.5-turbo"
-    )
+    model: str = ag.MCField(default="gpt-3.5-turbo", choices=supported_llm_models)
     max_tokens: int = Field(default=-1, ge=-1, le=4000)
     prompt_system: str = Field(default=prompts["system_prompt"])
     prompt_user: str = Field(default=prompts["user_prompt"])
@@ -89,7 +88,6 @@ async def generate(
     inputs: ag.DictInput = ag.DictInput(default_keys=["country"]),
 ):
     config = ag.ConfigManager.get_from_route(schema=MyConfig)
-    print("popo", config)
     try:
         prompt_user = config.prompt_user.format(**inputs)
     except Exception as e:
@@ -106,8 +104,4 @@
     )
     response = await llm_call(prompt_system=prompt_system, prompt_user=prompt_user)
-    return {
-        "message": response["message"],
-        "usage": response.get("usage", None),
-        "cost": response.get("cost", None),
-    }
+    return response["message"]
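Taken together, PATCH 1/3 replaces the `@ag.entrypoint` surface with `@ag.route`, `ag.MCField`, and `ag.BaseConfigModel`. A minimal service written against that new surface might look like the sketch below; the choice list, prompt text, and field bounds are illustrative rather than taken from the diffs:

```python
from pydantic import BaseModel, Field

import agenta as ag

ag.init()


class MyConfig(BaseModel):
    # MCField stores MultipleChoice metadata in json_schema_extra so that
    # ag.BaseConfigModel.model_json_schema() can surface it as the
    # "x-parameter" choice hints consumed by the playground.
    model: str = ag.MCField(default="gpt-3.5-turbo", choices=["gpt-3.5-turbo", "gpt-4"])
    temperature: float = Field(default=0.2, ge=0.0, le=1.0)


# Registers POST /test (playground, inline trace) and POST /run (environments).
@ag.route("/", config_schema=MyConfig, content_type="text/plain")
@ag.instrument()
async def generate(
    inputs: ag.DictInput = ag.DictInput(default_keys=["country"]),
):
    config = ag.ConfigManager.get_from_route(schema=MyConfig)
    question = "What is the capital of {country}?".format(**inputs)
    return f"[{config.model} @ temperature={config.temperature}] {question}"
```

Because `route._config_key` renames the request body key from `agenta_config` to `ag_config`, the `/test` endpoint accepts `{"ag_config": {...}, "inputs": {...}}` bodies like the `.rest` requests above.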
From d4cfc98b431d834b3d9befc8da4efb55b0a983c1 Mon Sep 17 00:00:00 2001
From: Juan Pablo Vega
Date: Mon, 30 Dec 2024 17:30:28 +0100
Subject: [PATCH 2/3] fix secret

---
 services/completion-new-sdk-prompt/docker-compose.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/services/completion-new-sdk-prompt/docker-compose.yml b/services/completion-new-sdk-prompt/docker-compose.yml
index 2f565bd20..623442d7a 100644
--- a/services/completion-new-sdk-prompt/docker-compose.yml
+++ b/services/completion-new-sdk-prompt/docker-compose.yml
@@ -7,7 +7,7 @@ services:
     environment:
       - AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED=True
       - AGENTA_HOST=http://host.docker.internal
-      - OPENAI_API_KEY=sk-proj-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+      - OPENAI_API_KEY=sk-xxxx
     networks:
       - agenta-network

From b10460a71de490f222585ceff9d7469d0c4e0cb8 Mon Sep 17 00:00:00 2001
From: Juan Pablo Vega
Date: Mon, 30 Dec 2024 17:36:30 +0100
Subject: [PATCH 3/3] fix app id

---
 services/completion-new-sdk.rest | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/services/completion-new-sdk.rest b/services/completion-new-sdk.rest
index d394ecd70..b1c1cfe3a 100644
--- a/services/completion-new-sdk.rest
+++ b/services/completion-new-sdk.rest
@@ -1,5 +1,6 @@
 @baseUrl = http://localhost
 @service = completion-new-sdk
+@application_id = 0194184e-e836-7e10-83a5-4f2c32927a21

 ### These request can be run using the Rest Client extension in vsCode

@@ -31,7 +32,7 @@ Content-Type: application/json
 }

 ### Run
-POST {{baseUrl}}/{{service}}/run?application_id=0194184e-e836-7e10-83a5-4f2c32927a21&variant_slug=default
+POST {{baseUrl}}/{{service}}/run?application_id={{application_id}}&variant_slug=default
 Content-Type: application/json

 {
@@ -44,7 +45,7 @@ Content-Type: application/json
 ### Run (baggage)
 POST {{baseUrl}}/{{service}}/run
 Content-Type: application/json
-Baggage: application_id=0194184e-e836-7e10-83a5-4f2c32927a21, variant_slug=default
+Baggage: application_id={{application_id}}, variant_slug=default

 {
     "inputs": {
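For reference, the `PromptTemplate` / `ModelConfig` pair reformatted in `agenta/sdk/types.py` can be exercised on its own. This sketch only uses fields visible in the diff; the model, temperature, and input values are illustrative, and the shape of the returned kwargs is inferred from the service code's `litellm.acompletion(**...)` usage:

```python
from agenta.sdk.types import ModelConfig, PromptTemplate

template = PromptTemplate(
    system_prompt="You are an expert in geography",
    user_prompt="What is the capital of {country}?",
    template_format="fstring",  # "jinja2" and "curly" are also supported
    input_keys=["country"],     # optional; enables strict input validation
    llm_config=ModelConfig(model="gpt-4", temperature=0.7),
)

# format() validates the inputs against input_keys, renders each message,
# and returns a new PromptTemplate; to_openai_kwargs() then flattens it
# into keyword arguments for an OpenAI-compatible completion call.
kwargs = template.format(country="France").to_openai_kwargs()
# kwargs carries the model, the rendered messages, and only those sampling
# parameters that were explicitly set (here: temperature).
```

Missing or unexpected inputs raise `InputValidationError`, and unresolved template variables raise `TemplateFormatError`, both defined alongside `PromptTemplate` in `agenta/sdk/types.py`.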