diff --git a/agenta-backend/agenta_backend/models/api/evaluation_model.py b/agenta-backend/agenta_backend/models/api/evaluation_model.py index 44f3ed500..c9bad617a 100644 --- a/agenta-backend/agenta_backend/models/api/evaluation_model.py +++ b/agenta-backend/agenta_backend/models/api/evaluation_model.py @@ -274,7 +274,7 @@ class LLMRunRateLimit(BaseModel): class LMProvidersEnum(str, Enum): openai = "OPENAI_API_KEY" - mistralai = "MISTRAL_API_KEY" + mistral = "MISTRAL_API_KEY" cohere = "COHERE_API_KEY" anthropic = "ANTHROPIC_API_KEY" anyscale = "ANYSCALE_API_KEY" diff --git a/agenta-backend/agenta_backend/routers/app_router.py b/agenta-backend/agenta_backend/routers/app_router.py index d4a54f16b..3074c4b00 100644 --- a/agenta-backend/agenta_backend/routers/app_router.py +++ b/agenta-backend/agenta_backend/routers/app_router.py @@ -578,6 +578,8 @@ async def create_app_and_variant_from_template( if isCloudEE() else "Step 7: Starting variant and injecting environment variables" ) + + envvars = {} if isCloudEE(): supported_llm_prodviders_keys = [ "OPENAI_API_KEY", @@ -593,21 +595,27 @@ async def create_app_and_variant_from_template( "GROQ_API_KEY", "GEMINI_API_KEY", ] + missing_keys = [ key for key in supported_llm_prodviders_keys if not os.environ.get(key) ] + if missing_keys: missing_keys_str = ", ".join(missing_keys) raise Exception( f"Unable to start app container. The following environment variables are missing: {missing_keys_str}. Please file an issue by clicking on the button below." ) - envvars = {**(payload.env_vars or {})} + if not isCloudEE(): + envvars = {**(payload.env_vars or {})} + for key in supported_llm_prodviders_keys: if not envvars.get(key): envvars[key] = os.environ[key] + else: envvars = {} if payload.env_vars is None else payload.env_vars + await app_manager.start_variant( app_variant_db, str(app.project_id), diff --git a/agenta-backend/agenta_backend/routers/permissions_router.py b/agenta-backend/agenta_backend/routers/permissions_router.py index 9e7532b97..3dfd0bdba 100644 --- a/agenta-backend/agenta_backend/routers/permissions_router.py +++ b/agenta-backend/agenta_backend/routers/permissions_router.py @@ -51,7 +51,7 @@ async def verify_permissions( if isOss(): return Allow(None) - if not action or not resource_type or not resource_id: + if not action or not resource_type: raise Deny() if isCloudEE(): @@ -70,8 +70,8 @@ async def verify_permissions( # CHECK PERMISSION 2/2: RESOURCE allow_resource = await check_resource_access( project_id=UUID(request.state.project_id), - resource_id=resource_id, resource_type=resource_type, + resource_id=resource_id, ) if not allow_resource: @@ -80,13 +80,14 @@ async def verify_permissions( return Allow(request.state.credentials) except Exception as exc: # pylint: disable=bare-except + print(exc) raise Deny() from exc async def check_resource_access( project_id: UUID, - resource_id: UUID, resource_type: str, + resource_id: Optional[UUID] = None, ) -> bool: resource_project_id = None @@ -95,6 +96,15 @@ async def check_resource_access( resource_project_id = app.project_id + if resource_type == "service": + if resource_id is None: + resource_project_id = project_id + + else: + base = await db_manager.fetch_base_by_id(base_id=str(resource_id)) + + resource_project_id = base.project_id + allow_resource = resource_project_id == project_id return allow_resource diff --git a/agenta-backend/agenta_backend/services/helpers.py b/agenta-backend/agenta_backend/services/helpers.py index 18951ad6f..1e0b904db 100644 --- a/agenta-backend/agenta_backend/services/helpers.py +++ b/agenta-backend/agenta_backend/services/helpers.py @@ -94,7 +94,7 @@ def format_llm_provider_keys( Dict[str, str]: formatted llm provided keys Example: - Input: {: '...', ...} + Input: {: '...', ...} Output: {'MISTRAL_API_KEY': '...', ...} """ diff --git a/agenta-backend/pyproject.toml b/agenta-backend/pyproject.toml index 901da388f..daf64b108 100644 --- a/agenta-backend/pyproject.toml +++ b/agenta-backend/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agenta_backend" -version = "0.30.0" +version = "0.31.0" description = "" authors = ["Mahmoud Mabrouk "] readme = "README.md" diff --git a/agenta-cli/agenta/__init__.py b/agenta-cli/agenta/__init__.py index 53c65db70..59629b8dd 100644 --- a/agenta-cli/agenta/__init__.py +++ b/agenta-cli/agenta/__init__.py @@ -28,6 +28,7 @@ from .sdk.utils.costs import calculate_token_usage from .sdk.client import Agenta from .sdk.litellm import litellm as callbacks +from .sdk.managers.secrets import SecretsManager from .sdk.managers.config import ConfigManager from .sdk.managers.variant import VariantManager from .sdk.managers.deployment import DeploymentManager diff --git a/agenta-cli/agenta/client/backend/types/provider_kind.py b/agenta-cli/agenta/client/backend/types/provider_kind.py index d46946203..175e2aa3c 100644 --- a/agenta-cli/agenta/client/backend/types/provider_kind.py +++ b/agenta-cli/agenta/client/backend/types/provider_kind.py @@ -10,7 +10,7 @@ "deepinfra", "alephalpha", "groq", - "mistralai", + "mistral", "anthropic", "perplexityai", "togetherai", diff --git a/agenta-cli/agenta/client/client.py b/agenta-cli/agenta/client/client.py index d5e4547f7..17dc1ac46 100644 --- a/agenta-cli/agenta/client/client.py +++ b/agenta-cli/agenta/client/client.py @@ -559,5 +559,5 @@ def run_evaluation(app_name: str, host: str, api_key: str = None) -> str: raise APIRequestError( f"Request to run evaluations failed with status code {response.status_code} and error message: {error_message}." ) - print(response.json()) + return response.json() diff --git a/agenta-cli/agenta/sdk/__init__.py b/agenta-cli/agenta/sdk/__init__.py index c1e40757c..9f3dc8662 100644 --- a/agenta-cli/agenta/sdk/__init__.py +++ b/agenta-cli/agenta/sdk/__init__.py @@ -27,6 +27,7 @@ from .decorators.routing import entrypoint, app, route from .agenta_init import Config, AgentaSingleton, init as _init from .utils.costs import calculate_token_usage +from .managers.secrets import SecretsManager from .managers.config import ConfigManager from .managers.variant import VariantManager from .managers.deployment import DeploymentManager diff --git a/agenta-cli/agenta/sdk/agenta_init.py b/agenta-cli/agenta/sdk/agenta_init.py index c2180457c..06659f4f4 100644 --- a/agenta-cli/agenta/sdk/agenta_init.py +++ b/agenta-cli/agenta/sdk/agenta_init.py @@ -6,8 +6,9 @@ from agenta.sdk.utils.logging import log from agenta.sdk.utils.globals import set_global from agenta.client.backend.client import AgentaApi, AsyncAgentaApi + from agenta.sdk.tracing import Tracing -from agenta.client.exceptions import APIRequestError +from agenta.sdk.context.routing import routing_context class AgentaSingleton: @@ -59,9 +60,7 @@ def init( ValueError: If `app_id` is not specified either as an argument, in the config file, or in the environment variables. """ - log.info("---------------------------") - log.info("Agenta SDK - using version: %s", version("agenta")) - log.info("---------------------------") + log.info("Agenta - SDK version: %s", version("agenta")) config = {} if config_fname: @@ -86,6 +85,13 @@ def init( self.api_key = api_key or getenv("AGENTA_API_KEY") or config.get("api_key") + self.base_id = getenv("AGENTA_BASE_ID") + + self.service_id = getenv("AGENTA_SERVICE_ID") or self.base_id + + log.info("Agenta - Service ID: %s", self.service_id) + log.info("Agenta - Application ID: %s", self.app_id) + self.tracing = Tracing( url=f"{self.host}/api/observability/v1/otlp/traces", # type: ignore redact=redact, @@ -94,6 +100,7 @@ def init( self.tracing.configure( api_key=self.api_key, + service_id=self.service_id, # DEPRECATING app_id=self.app_id, ) @@ -108,8 +115,6 @@ def init( api_key=self.api_key if self.api_key else "", ) - self.base_id = getenv("AGENTA_BASE_ID") - self.config = Config( host=self.host, base_id=self.base_id, @@ -120,28 +125,43 @@ def init( class Config: def __init__( self, - host: str, + # LEGACY + host: Optional[str] = None, base_id: Optional[str] = None, - api_key: Optional[str] = "", + api_key: Optional[str] = None, + # LEGACY + **kwargs, ): - self.host = host + self.default_parameters = {**kwargs} + + def set_default(self, **kwargs): + self.default_parameters.update(kwargs) + + def get_default(self): + return self.default_parameters + + def __getattr__(self, key): + context = routing_context.get() + + parameters = context.parameters + + if not parameters: + return None + + if key in parameters: + value = parameters[key] + + if isinstance(value, dict): + nested_config = Config() + nested_config.set_default(**value) - self.base_id = base_id + return nested_config - if self.base_id is None: - # print( - # "Warning: Your configuration will not be saved permanently since base_id is not provided.\n" - # ) - pass + return value - if base_id is None or host is None: - self.persist = False - else: - self.persist = True - self.client = AgentaApi( - base_url=self.host + "/api", - api_key=api_key if api_key else "", - ) + return None + + ### --- LEGACY --- ### def register_default(self, overwrite=False, **kwargs): """alias for default""" @@ -153,104 +173,13 @@ def default(self, overwrite=False, **kwargs): overwrite: Whether to overwrite the existing configuration or not **kwargs: A dict containing the parameters """ - self.set( - **kwargs - ) # In case there is no connectivity, we still can use the default values - try: - self.push(config_name="default", overwrite=overwrite, **kwargs) - except Exception as ex: - log.warning( - "Unable to push the default configuration to the server. %s", str(ex) - ) - - def push(self, config_name: str, overwrite=True, **kwargs): - """Pushes the parameters for the app variant to the server - Args: - config_name: Name of the configuration to push to - overwrite: Whether to overwrite the existing configuration or not - **kwargs: A dict containing the parameters - """ - if not self.persist: - return - try: - self.client.configs.save_config( - base_id=self.base_id, - config_name=config_name, - parameters=kwargs, - overwrite=overwrite, - ) - except Exception as ex: - log.warning( - "Failed to push the configuration to the server with error: %s", ex - ) - - def pull( - self, config_name: str = "default", environment_name: Optional[str] = None - ): - """Pulls the parameters for the app variant from the server and sets them to the config""" - if not self.persist and ( - config_name != "default" or environment_name is not None - ): - raise ValueError( - "Cannot pull the configuration from the server since the app_name and base_name are not provided." - ) - if self.persist: - try: - if environment_name: - config = self.client.configs.get_config( - base_id=self.base_id, environment_name=environment_name - ) - - else: - config = self.client.configs.get_config( - base_id=self.base_id, - config_name=config_name, - ) - except Exception as ex: - log.warning( - "Failed to pull the configuration from the server with error: %s", - str(ex), - ) - try: - self.set(**{"current_version": config.current_version, **config.parameters}) - except Exception as ex: - log.warning("Failed to set the configuration with error: %s", str(ex)) + self.set(**kwargs) - def all(self): - """Returns all the parameters for the app variant""" - return { - k: v - for k, v in self.__dict__.items() - if k - not in [ - "app_name", - "base_name", - "host", - "base_id", - "api_key", - "persist", - "client", - ] - } - - # function to set the parameters for the app variant def set(self, **kwargs): - """Sets the parameters for the app variant - - Args: - **kwargs: A dict containing the parameters - """ - for key, value in kwargs.items(): - setattr(self, key, value) - - def dump(self): - """Returns all the information about the current version in the configuration. + self.set_default(**kwargs) - Raises: - NotImplementedError: _description_ - """ - - raise NotImplementedError() + def all(self): + return self.default_parameters def init( diff --git a/agenta-cli/agenta/sdk/assets.py b/agenta-cli/agenta/sdk/assets.py index c62cc9dd9..ef1d79984 100644 --- a/agenta-cli/agenta/sdk/assets.py +++ b/agenta-cli/agenta/sdk/assets.py @@ -1,23 +1,9 @@ supported_llm_models = { - "Mistral AI": [ - "mistral/mistral-tiny", - "mistral/mistral-small", - "mistral/mistral-medium", - "mistral/mistral-large-latest", - ], - "Open AI": [ - "gpt-3.5-turbo-1106", - "gpt-3.5-turbo", - "gpt-4", - "gpt-4o", - "gpt-4o-mini", - "gpt-4-1106-preview", - ], - "Gemini": ["gemini/gemini-1.5-pro-latest", "gemini/gemini-1.5-flash"], - "Cohere": [ - "cohere/command-light", - "cohere/command-r-plus", - "cohere/command-nightly", + "Aleph Alpha": [ + "luminous-base", + "luminous-base-control", + "luminous-extended-control", + "luminous-supreme", ], "Anthropic": [ "anthropic/claude-3-5-sonnet-20240620", @@ -33,11 +19,10 @@ "anyscale/meta-llama/Llama-2-13b-chat-hf", "anyscale/meta-llama/Llama-2-70b-chat-hf", ], - "Perplexity AI": [ - "perplexity/pplx-7b-chat", - "perplexity/pplx-70b-chat", - "perplexity/pplx-7b-online", - "perplexity/pplx-70b-online", + "Cohere": [ + "cohere/command-light", + "cohere/command-r-plus", + "cohere/command-nightly", ], "DeepInfra": [ "deepinfra/meta-llama/Llama-2-70b-chat-hf", @@ -46,6 +31,46 @@ "deepinfra/mistralai/Mistral-7B-Instruct-v0.1", "deepinfra/jondurbin/airoboros-l2-70b-gpt4-1.4.1", ], + "Gemini": [ + "gemini/gemini-1.5-pro-latest", + "gemini/gemini-1.5-flash", + ], + "Groq": [ + "groq/llama3-8b-8192", + "groq/llama3-70b-8192", + "groq/llama2-70b-4096", + "groq/mixtral-8x7b-32768", + "groq/gemma-7b-it", + ], + "Mistral": [ + "mistral/mistral-tiny", + "mistral/mistral-small", + "mistral/mistral-medium", + "mistral/mistral-large-latest", + ], + "Open AI": [ + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-4", + "gpt-4o", + "gpt-4o-mini", + "gpt-4-1106-preview", + ], + "OpenRouter": [ + "openrouter/openai/gpt-3.5-turbo", + "openrouter/openai/gpt-3.5-turbo-16k", + "openrouter/anthropic/claude-instant-v1", + "openrouter/google/palm-2-chat-bison", + "openrouter/google/palm-2-codechat-bison", + "openrouter/meta-llama/llama-2-13b-chat", + "openrouter/meta-llama/llama-2-70b-chat", + ], + "Perplexity AI": [ + "perplexity/pplx-7b-chat", + "perplexity/pplx-70b-chat", + "perplexity/pplx-7b-online", + "perplexity/pplx-70b-online", + ], "Together AI": [ "together_ai/togethercomputer/llama-2-70b-chat", "together_ai/togethercomputer/llama-2-70b", @@ -59,26 +84,12 @@ "together_ai/NousResearch/Nous-Hermes-Llama2-13b", "together_ai/Austism/chronos-hermes-13b", ], - "Aleph Alpha": [ - "luminous-base", - "luminous-base-control", - "luminous-extended-control", - "luminous-supreme", - ], - "OpenRouter": [ - "openrouter/openai/gpt-3.5-turbo", - "openrouter/openai/gpt-3.5-turbo-16k", - "openrouter/anthropic/claude-instant-v1", - "openrouter/google/palm-2-chat-bison", - "openrouter/google/palm-2-codechat-bison", - "openrouter/meta-llama/llama-2-13b-chat", - "openrouter/meta-llama/llama-2-70b-chat", - ], - "Groq": [ - "groq/llama3-8b-8192", - "groq/llama3-70b-8192", - "groq/llama2-70b-4096", - "groq/mixtral-8x7b-32768", - "groq/gemma-7b-it", - ], +} + +providers_list = list(supported_llm_models.keys()) + +model_to_provider_mapping = { + model: provider + for provider, models in supported_llm_models.items() + for model in models } diff --git a/agenta-cli/agenta/sdk/context/exporting.py b/agenta-cli/agenta/sdk/context/exporting.py new file mode 100644 index 000000000..2fe03a09c --- /dev/null +++ b/agenta-cli/agenta/sdk/context/exporting.py @@ -0,0 +1,25 @@ +from typing import Optional + +from contextlib import contextmanager +from contextvars import ContextVar + +from pydantic import BaseModel + + +class ExportingContext(BaseModel): + credentials: Optional[str] = None + + +exporting_context = ContextVar("exporting_context", default=ExportingContext()) + + +@contextmanager +def exporting_context_manager( + *, + context: Optional[ExportingContext] = None, +): + token = exporting_context.set(context) + try: + yield + finally: + exporting_context.reset(token) diff --git a/agenta-cli/agenta/sdk/context/routing.py b/agenta-cli/agenta/sdk/context/routing.py index 1d716a69e..128489828 100644 --- a/agenta-cli/agenta/sdk/context/routing.py +++ b/agenta-cli/agenta/sdk/context/routing.py @@ -1,24 +1,24 @@ +from typing import Any, Dict, List, Optional + from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, Dict, Optional -routing_context = ContextVar("routing_context", default={}) +from pydantic import BaseModel + + +class RoutingContext(BaseModel): + parameters: Optional[Dict[str, Any]] = None + secrets: Optional[List[Any]] = None + + +routing_context = ContextVar("routing_context", default=RoutingContext()) @contextmanager def routing_context_manager( *, - config: Optional[Dict[str, Any]] = None, - application: Optional[Dict[str, Any]] = None, - variant: Optional[Dict[str, Any]] = None, - environment: Optional[Dict[str, Any]] = None, + context: Optional[RoutingContext] = None, ): - context = { - "config": config, - "application": application, - "variant": variant, - "environment": environment, - } token = routing_context.set(context) try: yield diff --git a/agenta-cli/agenta/sdk/context/tracing.py b/agenta-cli/agenta/sdk/context/tracing.py index 0585a014a..3bebe13dc 100644 --- a/agenta-cli/agenta/sdk/context/tracing.py +++ b/agenta-cli/agenta/sdk/context/tracing.py @@ -1,3 +1,28 @@ +from typing import Any, Dict, Optional + +from contextlib import contextmanager from contextvars import ContextVar -tracing_context = ContextVar("tracing_context", default={}) +from pydantic import BaseModel + + +class TracingContext(BaseModel): + credentials: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + references: Optional[Dict[str, Any]] = None + link: Optional[Dict[str, Any]] = None + + +tracing_context = ContextVar("tracing_context", default=TracingContext()) + + +@contextmanager +def tracing_context_manager( + *, + context: Optional[TracingContext] = None, +): + token = tracing_context.set(context) + try: + yield + finally: + tracing_context.reset(token) diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 6be0c1c30..a0da8996a 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -1,26 +1,34 @@ from typing import Type, Any, Callable, Dict, Optional, Tuple, List -from annotated_types import Ge, Le, Gt, Lt -from pydantic import BaseModel, HttpUrl, ValidationError -from json import dumps from inspect import signature, iscoroutinefunction, Signature, Parameter, _empty -from argparse import ArgumentParser from functools import wraps -from asyncio import sleep, get_event_loop -from traceback import format_exc, format_exception -from pathlib import Path +from traceback import format_exception +from asyncio import sleep + from tempfile import NamedTemporaryFile -from os import environ +from annotated_types import Ge, Le, Gt, Lt +from pydantic import BaseModel, HttpUrl, ValidationError -from fastapi.middleware.cors import CORSMiddleware -from fastapi import Body, FastAPI, UploadFile, HTTPException +from fastapi import Body, FastAPI, UploadFile, HTTPException, Request -from agenta.sdk.middleware.auth import AuthorizationMiddleware -from agenta.sdk.context.routing import routing_context_manager, routing_context -from agenta.sdk.context.tracing import tracing_context +from agenta.sdk.middleware.auth import AuthMiddleware +from agenta.sdk.middleware.otel import OTelMiddleware +from agenta.sdk.middleware.config import ConfigMiddleware +from agenta.sdk.middleware.vault import VaultMiddleware +from agenta.sdk.middleware.cors import CORSMiddleware + +from agenta.sdk.context.routing import ( + routing_context_manager, + RoutingContext, +) +from agenta.sdk.context.tracing import ( + tracing_context_manager, + tracing_context, + TracingContext, +) from agenta.sdk.router import router -from agenta.sdk.utils import helpers -from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.utils.exceptions import suppress, display_exception from agenta.sdk.utils.logging import log +from agenta.sdk.utils.helpers import get_current_version from agenta.sdk.types import ( DictInput, FloatParam, @@ -39,19 +47,10 @@ import agenta as ag -AGENTA_USE_CORS = str(environ.get("AGENTA_USE_CORS", "true")).lower() in ( - "true", - "1", - "t", -) - app = FastAPI() log.setLevel("DEBUG") -_MIDDLEWARES = True - - app.include_router(router, prefix="") @@ -59,13 +58,17 @@ class PathValidator(BaseModel): url: HttpUrl -class route: +class route: # pylint: disable=invalid-name # This decorator is used to expose specific stages of a workflow (embedding, retrieval, summarization, etc.) # as independent endpoints. It is designed for backward compatibility with existing code that uses # the @entrypoint decorator, which has certain limitations. By using @route(), we can create new # routes without altering the main workflow entrypoint. This helps in modularizing the services # and provides flexibility in how we expose different functionalities as APIs. - def __init__(self, path, config_schema: BaseModel): + def __init__( + self, + path: Optional[str] = "/", + config_schema: Optional[BaseModel] = None, + ): self.config_schema: BaseModel = config_schema path = "/" + path.strip("/").strip() path = "" if path == "/" else path @@ -73,9 +76,13 @@ def __init__(self, path, config_schema: BaseModel): self.route_path = path + self.e = None + def __call__(self, f): self.e = entrypoint( - f, route_path=self.route_path, config_schema=self.config_schema + f, + route_path=self.route_path, + config_schema=self.config_schema, ) return f @@ -114,231 +121,212 @@ async def chain_of_prompts_llm(prompt: str): routes = list() + _middleware = False + _run_path = "/run" + _test_path = "/test" + # LEGACY + _legacy_playground_run_path = "/playground/run" + _legacy_generate_path = "/generate" + _legacy_generate_deployed_path = "/generate_deployed" + def __init__( self, func: Callable[..., Any], - route_path="", + route_path: str = "", config_schema: Optional[BaseModel] = None, ): - ### --- Update Middleware --- # - try: - global _MIDDLEWARES # pylint: disable=global-statement - - if _MIDDLEWARES: - app.add_middleware( - AuthorizationMiddleware, - host=ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host, - resource_id=ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id, - resource_type="application", - ) + self.func = func + self.route_path = route_path + self.config_schema = config_schema - if AGENTA_USE_CORS: - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - allow_credentials=True, - ) + signature_parameters = signature(func).parameters + ingestible_files = self.extract_ingestible_files() + config, default_parameters = self.parse_config() - _MIDDLEWARES = False + ### --- Middleware --- # + if not entrypoint._middleware: + entrypoint._middleware = True - except: # pylint: disable=bare-except - log.warning("Agenta SDK - failed to secure route: %s", route_path) - ### --- Update Middleware --- # + app.add_middleware(VaultMiddleware) + app.add_middleware(ConfigMiddleware) + app.add_middleware(AuthMiddleware) + app.add_middleware(OTelMiddleware) + app.add_middleware(CORSMiddleware) + ### ------------------ # - DEFAULT_PATH = "generate" - PLAYGROUND_PATH = "/playground" - RUN_PATH = "/run" - func_signature = signature(func) - try: - config = ( - config_schema() if config_schema else None - ) # we initialize the config object to be able to use it - except ValidationError as e: - raise ValueError( - f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}" - ) from e - except Exception as e: - raise ValueError( - f"Unexpected error initializing config_schema: {str(e)}" - ) from e - - config_params = config.dict() if config else ag.config.all() - ingestible_files = self.extract_ingestible_files(func_signature) + ### --- Run --- # + @wraps(func) + async def run_wrapper(request: Request, *args, **kwargs) -> Any: + # LEGACY + # TODO: Removing this implies breaking changes in : + # - calls to /generate_deployed + kwargs = { + k: v + for k, v in kwargs.items() + if k not in ["config", "environment", "app"] + } + # LEGACY - self.route_path = route_path + kwargs, _ = self.split_kwargs(kwargs, default_parameters) - ### --- Playground --- # + # TODO: Why is this not used in the run_wrapper? + # self.ingest_files(kwargs, ingestible_files) + + return await self.execute_wrapper(request, False, *args, **kwargs) + + self.update_run_wrapper_signature( + wrapper=run_wrapper, + ingestible_files=ingestible_files, + ) + + run_route = f"{entrypoint._run_path}{route_path}" + app.post(run_route, response_model=BaseResponse)(run_wrapper) + + # LEGACY + # TODO: Removing this implies breaking changes in : + # - calls to /generate_deployed must be replaced with calls to /run + if route_path == "": + run_route = entrypoint._legacy_generate_deployed_path + app.post(run_route, response_model=BaseResponse)(run_wrapper) + # LEGACY + ### ----------- # + + ### --- Test --- # @wraps(func) - async def wrapper(*args, **kwargs) -> Any: - func_params, api_config_params = self.split_kwargs(kwargs, config_params) - self.ingest_files(func_params, ingestible_files) - if not config_schema: - ag.config.set(**api_config_params) - - with routing_context_manager( - config=api_config_params, - ): - entrypoint_result = await self.execute_function( - func, - True, # inline trace: True - *args, - params=func_params, - config_params=config_params, - ) + async def test_wrapper(request: Request, *args, **kwargs) -> Any: + kwargs, parameters = self.split_kwargs(kwargs, default_parameters) - return entrypoint_result + request.state.config["parameters"] = parameters - self.update_function_signature( - wrapper=wrapper, - func_signature=func_signature, - config_class=config, - config_dict=config_params, + # TODO: Why is this only used in the test_wrapper? + self.ingest_files(kwargs, ingestible_files) + + return await self.execute_wrapper(request, True, *args, **kwargs) + + self.update_test_wrapper_signature( + wrapper=test_wrapper, ingestible_files=ingestible_files, + config_class=config, + config_dict=default_parameters, ) - # + test_route = f"{entrypoint._test_path}{route_path}" + app.post(test_route, response_model=BaseResponse)(test_wrapper) + + # LEGACY + # TODO: Removing this implies breaking changes in : + # - calls to /generate must be replaced with calls to /test if route_path == "": - route = f"/{DEFAULT_PATH}" - app.post(route, response_model=BaseResponse)(wrapper) - entrypoint.routes.append( - { - "func": func.__name__, - "endpoint": route, - "params": ( - {**config_params, **func_signature.parameters} - if not config - else func_signature.parameters - ), - "config": config, - } - ) + test_route = entrypoint._legacy_generate_path + app.post(test_route, response_model=BaseResponse)(test_wrapper) + # LEGACY - route = f"{PLAYGROUND_PATH}{RUN_PATH}{route_path}" - app.post(route, response_model=BaseResponse)(wrapper) + # LEGACY + # TODO: Removing this implies no breaking changes + if route_path == "": + test_route = entrypoint._legacy_playground_run_path + app.post(test_route, response_model=BaseResponse)(test_wrapper) + # LEGACY + ### ------------ # + + ### --- OpenAPI --- # + test_route = f"{entrypoint._test_path}{route_path}" entrypoint.routes.append( { "func": func.__name__, - "endpoint": route, + "endpoint": test_route, "params": ( - {**config_params, **func_signature.parameters} + {**default_parameters, **signature_parameters} if not config - else func_signature.parameters + else signature_parameters ), "config": config, } ) - ### ---------------------------- # - ### --- Deployed --- # - @wraps(func) - async def wrapper_deployed(*args, **kwargs) -> Any: - func_params = { - k: v - for k, v in kwargs.items() - if k not in ["config", "environment", "app"] - } - if not config_schema: - if "environment" in kwargs and kwargs["environment"] is not None: - ag.config.pull(environment_name=kwargs["environment"]) - elif "config" in kwargs and kwargs["config"] is not None: - ag.config.pull(config_name=kwargs["config"]) - else: - ag.config.pull(config_name="default") - - app_id = environ.get("AGENTA_APP_ID") - - with routing_context_manager( - application={ - "id": app_id, - "slug": kwargs.get("app"), - }, - variant={ - "slug": kwargs.get("config"), - }, - environment={ - "slug": kwargs.get("environment"), - }, - ): - entrypoint_result = await self.execute_function( - func, - False, # inline trace: False - *args, - params=func_params, - config_params=config_params, - ) - - return entrypoint_result - - self.update_deployed_function_signature( - wrapper_deployed, - func_signature, - ingestible_files, - ) + # LEGACY if route_path == "": - route_deployed = f"/{DEFAULT_PATH}_deployed" - app.post(route_deployed, response_model=BaseResponse)(wrapper_deployed) - - route_deployed = f"{RUN_PATH}{route_path}" - app.post(route_deployed, response_model=BaseResponse)(wrapper_deployed) - ### ---------------- # + test_route = entrypoint._legacy_generate_path + entrypoint.routes.append( + { + "func": func.__name__, + "endpoint": test_route, + "params": ( + {**default_parameters, **signature_parameters} + if not config + else signature_parameters + ), + "config": config, + } + ) + # LEGACY - ### --- Update OpenAPI --- # app.openapi_schema = None # Forces FastAPI to re-generate the schema openapi_schema = app.openapi() - # Inject the current version of the SDK into the openapi_schema - openapi_schema["agenta_sdk"] = {"version": helpers.get_current_version()} + openapi_schema["agenta_sdk"] = {"version": get_current_version()} - for route in entrypoint.routes: + for _route in entrypoint.routes: self.override_schema( openapi_schema=openapi_schema, - func_name=route["func"], - endpoint=route["endpoint"], - params=route["params"], + func_name=_route["func"], + endpoint=_route["endpoint"], + params=_route["params"], ) - if route["config"] is not None: # new SDK version + + if _route["config"] is not None: # new SDK version self.override_config_in_schema( openapi_schema=openapi_schema, - func_name=route["func"], - endpoint=route["endpoint"], - config=route["config"], + func_name=_route["func"], + endpoint=_route["endpoint"], + config=_route["config"], ) + ### --------------- # - if self.is_main_script(func) and route_path == "": - self.handle_terminal_run( - func, - func_signature.parameters, # type: ignore - config_params, - ingestible_files, - ) - - def extract_ingestible_files( - self, - func_signature: Signature, - ) -> Dict[str, Parameter]: + def extract_ingestible_files(self) -> Dict[str, Parameter]: """Extract parameters annotated as InFile from function signature.""" return { name: param - for name, param in func_signature.parameters.items() + for name, param in signature(self.func).parameters.items() if param.annotation is InFile } + def parse_config(self) -> Dict[str, Any]: + config = None + default_parameters = ag.config.all() + + if self.config_schema: + try: + config = self.config_schema() if self.config_schema else None + default_parameters = config.dict() if config else default_parameters + except ValidationError as e: + raise ValueError( + f"Error initializing config_schema. Please ensure all required fields have default values: {str(e)}" + ) from e + except Exception as e: + raise ValueError( + f"Unexpected error initializing config_schema: {str(e)}" + ) from e + + return config, default_parameters + def split_kwargs( - self, kwargs: Dict[str, Any], config_params: Dict[str, Any] + self, kwargs: Dict[str, Any], default_parameters: Dict[str, Any] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """Split keyword arguments into function parameters and API configuration parameters.""" + arguments = {k: v for k, v in kwargs.items() if k not in default_parameters} + parameters = {k: v for k, v in kwargs.items() if k in default_parameters} - func_params = {k: v for k, v in kwargs.items() if k not in config_params} - api_config_params = {k: v for k, v in kwargs.items() if k in config_params} - return func_params, api_config_params + return arguments, parameters - def ingest_file(self, upfile: UploadFile): + def ingest_file( + self, + upfile: UploadFile, + ): temp_file = NamedTemporaryFile(delete=False) temp_file.write(upfile.file.read()) temp_file.close() + return InFile(file_name=upfile.filename, file_path=temp_file.name) def ingest_files( @@ -352,51 +340,81 @@ def ingest_files( if name in func_params and func_params[name] is not None: func_params[name] = self.ingest_file(func_params[name]) - async def execute_function( + async def execute_wrapper( self, - func: Callable[..., Any], - inline_trace, + request: Request, + inline: bool, *args, - **func_params, + **kwargs, ): - log.info("Agenta SDK - handling route: %s", repr(self.route_path or "/")) + if not request: + raise HTTPException(status_code=500, detail="Missing 'request'.") + + state = request.state + credentials = state.auth.get("credentials") + parameters = state.config.get("parameters") + references = state.config.get("references") + secrets = state.vault.get("secrets") + + with routing_context_manager( + context=RoutingContext( + parameters=parameters, + secrets=secrets, + ) + ): + with tracing_context_manager( + context=TracingContext( + credentials=credentials, + parameters=parameters, + references=references, + ) + ): + result = await self.execute_function(inline, *args, **kwargs) - tracing_context.set(routing_context.get()) + return result + async def execute_function( + self, + inline: bool, + *args, + **kwargs, + ): try: result = ( - await func(*args, **func_params["params"]) - if iscoroutinefunction(func) - else func(*args, **func_params["params"]) + await self.func(*args, **kwargs) + if iscoroutinefunction(self.func) + else self.func(*args, **kwargs) ) - return await self.handle_success(result, inline_trace) + return await self.handle_success(result, inline) - except Exception as error: + except Exception as error: # pylint: disable=broad-except self.handle_failure(error) - async def handle_success(self, result: Any, inline_trace: bool): + async def handle_success( + self, + result: Any, + inline: bool, + ): data = None tree = None with suppress(): data = self.patch_result(result) - if inline_trace: - tree = await self.fetch_inline_trace(inline_trace) - - log.info(f"----------------------------------") - log.info(f"Agenta SDK - exiting with success: 200") - log.info(f"----------------------------------") + if inline: + tree = await self.fetch_inline_trace(inline) - return BaseResponse(data=data, tree=tree) + try: + return BaseResponse(data=data, tree=tree) + except: + return BaseResponse(data=data) - def handle_failure(self, error: Exception): - log.warning("--------------------------------------------------") - log.warning("Agenta SDK - handling application exception below:") - log.warning("--------------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("--------------------------------------------------") + def handle_failure( + self, + error: Exception, + ): + display_exception("Application Exception") status_code = 500 message = str(error) @@ -405,7 +423,10 @@ def handle_failure(self, error: Exception): raise HTTPException(status_code=status_code, detail=detail) - def patch_result(self, result: Any): + def patch_result( + self, + result: Any, + ): """ Patch the result to only include the message if the result is a FuncResponse-style dictionary with message, cost, and usage keys. @@ -442,7 +463,10 @@ def patch_result(self, result: Any): return data - async def fetch_inline_trace(self, inline_trace): + async def fetch_inline_trace( + self, + inline, + ): WAIT_FOR_SPANS = True TIMEOUT = 1 TIMESTEP = 0.1 @@ -451,12 +475,14 @@ async def fetch_inline_trace(self, inline_trace): trace = None - root_context: Dict[str, Any] = tracing_context.get().get("root") + context = tracing_context.get() - trace_id = root_context.get("trace_id") if root_context else None + link = context.link + + trace_id = link.get("tree_id") if link else None if trace_id is not None: - if inline_trace: + if inline: if WAIT_FOR_SPANS: remaining_steps = NOFSTEPS @@ -476,6 +502,27 @@ async def fetch_inline_trace(self, inline_trace): return trace + # --- OpenAPI --- # + + def add_request_to_signature( + self, + wrapper: Callable[..., Any], + ): + original_sig = signature(wrapper) + parameters = [ + Parameter( + "request", + kind=Parameter.POSITIONAL_OR_KEYWORD, + annotation=Request, + ), + *original_sig.parameters.values(), + ] + new_sig = Signature( + parameters, + return_annotation=original_sig.return_annotation, + ) + wrapper.__signature__ = new_sig + def update_wrapper_signature( self, wrapper: Callable[..., Any], updated_params: List ): @@ -492,10 +539,9 @@ def update_wrapper_signature( wrapper_signature = wrapper_signature.replace(parameters=updated_params) wrapper.__signature__ = wrapper_signature # type: ignore - def update_function_signature( + def update_test_wrapper_signature( self, wrapper: Callable[..., Any], - func_signature: Signature, config_class: Type[BaseModel], # TODO: change to our type config_dict: Dict[str, Any], ingestible_files: Dict[str, Parameter], @@ -507,19 +553,19 @@ def update_function_signature( self.add_config_params_to_parser(updated_params, config_class) else: self.deprecated_add_config_params_to_parser(updated_params, config_dict) - self.add_func_params_to_parser(updated_params, func_signature, ingestible_files) + self.add_func_params_to_parser(updated_params, ingestible_files) self.update_wrapper_signature(wrapper, updated_params) + self.add_request_to_signature(wrapper) - def update_deployed_function_signature( + def update_run_wrapper_signature( self, wrapper: Callable[..., Any], - func_signature: Signature, ingestible_files: Dict[str, Parameter], ) -> None: """Update the function signature to include new parameters.""" updated_params: List[Parameter] = [] - self.add_func_params_to_parser(updated_params, func_signature, ingestible_files) + self.add_func_params_to_parser(updated_params, ingestible_files) for param in [ "config", "environment", @@ -533,6 +579,7 @@ def update_deployed_function_signature( ) ) self.update_wrapper_signature(wrapper, updated_params) + self.add_request_to_signature(wrapper) def add_config_params_to_parser( self, updated_params: list, config_class: Type[BaseModel] @@ -573,11 +620,10 @@ def deprecated_add_config_params_to_parser( def add_func_params_to_parser( self, updated_params: list, - func_signature: Signature, ingestible_files: Dict[str, Parameter], ) -> None: """Add function parameters to function signature.""" - for name, param in func_signature.parameters.items(): + for name, param in signature(self.func).parameters.items(): if name in ingestible_files: updated_params.append( Parameter(name, param.kind, annotation=UploadFile) @@ -599,115 +645,6 @@ def add_func_params_to_parser( ) ) - def is_main_script(self, func: Callable) -> bool: - """ - Check if the script containing the function is the main script being run. - - Args: - func (Callable): The function object to check. - - Returns: - bool: True if the script containing the function is the main script, False otherwise. - - Example: - if is_main_script(my_function): - print("This is the main script.") - """ - return func.__module__ == "__main__" - - def handle_terminal_run( - self, - func: Callable, - func_params: Dict[str, Parameter], - config_params: Dict[str, Any], - ingestible_files: Dict, - ): - """ - Parses command line arguments and sets configuration when script is run from the terminal. - - Args: - func_params (dict): A dictionary containing the function parameters and their annotations. - config_params (dict): A dictionary containing the configuration parameters. - ingestible_files (dict): A dictionary containing the files that should be ingested. - """ - - # For required parameters, we add them as arguments - parser = ArgumentParser() - for name, param in func_params.items(): - if name in ingestible_files: - parser.add_argument(name, type=str) - else: - parser.add_argument(name, type=param.annotation) - - for name, param in config_params.items(): - if type(param) is MultipleChoiceParam: - parser.add_argument( - f"--{name}", - type=str, - default=param.default, - choices=param.choices, # type: ignore - ) - else: - parser.add_argument( - f"--{name}", - type=type(param), - default=param, - ) - - args = parser.parse_args() - - # split the arg list into the arg in the app_param and - # the args from the sig.parameter - args_config_params = {k: v for k, v in vars(args).items() if k in config_params} - args_func_params = { - k: v for k, v in vars(args).items() if k not in config_params - } - for name in ingestible_files: - args_func_params[name] = InFile( - file_name=Path(args_func_params[name]).stem, - file_path=args_func_params[name], - ) - - # Update args_config_params with default values from config_params if not provided in command line arguments - args_config_params.update( - { - key: value - for key, value in config_params.items() - if key not in args_config_params - } - ) - - loop = get_event_loop() - - with routing_context_manager(config=args_config_params): - result = loop.run_until_complete( - self.execute_function( - func, - True, # inline trace: True - **{"params": args_func_params, "config_params": args_config_params}, - ) - ) - - if result.trace: - log.info("\n========= Result =========\n") - - log.info(f"trace_id: {result.trace['trace_id']}") - log.info(f"latency: {result.trace.get('latency')}") - log.info(f"cost: {result.trace.get('cost')}") - log.info(f"usage: {list(result.trace.get('usage', {}).values())}") - - log.info(" ") - log.info("data:") - log.info(dumps(result.data, indent=2)) - - log.info(" ") - log.info("trace:") - log.info("----------------") - log.info(dumps(result.trace.get("spans", []), indent=2)) - log.info("----------------") - - log.info("\n==========================\n") - def override_config_in_schema( self, openapi_schema: dict, diff --git a/agenta-cli/agenta/sdk/decorators/tracing.py b/agenta-cli/agenta/sdk/decorators/tracing.py index 68f707b69..f368509fc 100644 --- a/agenta-cli/agenta/sdk/decorators/tracing.py +++ b/agenta-cli/agenta/sdk/decorators/tracing.py @@ -1,8 +1,12 @@ from typing import Callable, Optional, Any, Dict, List, Union + from functools import wraps from itertools import chain from inspect import iscoroutinefunction, getfullargspec +from opentelemetry import baggage as baggage +from opentelemetry.context import attach, detach + from agenta.sdk.utils.exceptions import suppress from agenta.sdk.context.tracing import tracing_context from agenta.sdk.tracing.conventions import parse_span_kind @@ -39,10 +43,12 @@ def __call__(self, func: Callable[..., Any]): is_coroutine_function = iscoroutinefunction(func) @wraps(func) - async def async_wrapper(*args, **kwargs): - async def _async_auto_instrumented(*args, **kwargs): + async def awrapper(*args, **kwargs): + async def aauto_instrumented(*args, **kwargs): self._parse_type_and_kind() + token = self._attach_baggage() + with ag.tracer.start_as_current_span(func.__name__, kind=self.kind): self._pre_instrument(func, *args, **kwargs) @@ -52,13 +58,17 @@ async def _async_auto_instrumented(*args, **kwargs): return result - return await _async_auto_instrumented(*args, **kwargs) + self._detach_baggage(token) + + return await aauto_instrumented(*args, **kwargs) @wraps(func) - def sync_wrapper(*args, **kwargs): - def _sync_auto_instrumented(*args, **kwargs): + def wrapper(*args, **kwargs): + def auto_instrumented(*args, **kwargs): self._parse_type_and_kind() + token = self._attach_baggage() + with ag.tracer.start_as_current_span(func.__name__, kind=self.kind): self._pre_instrument(func, *args, **kwargs) @@ -68,9 +78,11 @@ def _sync_auto_instrumented(*args, **kwargs): return result - return _sync_auto_instrumented(*args, **kwargs) + self._detach_baggage(token) + + return auto_instrumented(*args, **kwargs) - return async_wrapper if is_coroutine_function else sync_wrapper + return awrapper if is_coroutine_function else wrapper def _parse_type_and_kind(self): if not ag.tracing.get_current_span().is_recording(): @@ -78,6 +90,25 @@ def _parse_type_and_kind(self): self.kind = parse_span_kind(self.type) + def _attach_baggage(self): + context = tracing_context.get() + + references = context.references + + token = None + if references: + for k, v in references.items(): + token = attach(baggage.set_baggage(f"ag.refs.{k}", v)) + + return token + + def _detach_baggage( + self, + token, + ): + if token: + detach(token) + def _pre_instrument( self, func, @@ -86,29 +117,21 @@ def _pre_instrument( ): span = ag.tracing.get_current_span() + context = tracing_context.get() + with suppress(): + trace_id = span.context.trace_id + + ag.tracing.credentials[trace_id] = context.credentials + span.set_attributes( attributes={"node": self.type}, namespace="type", ) if span.parent is None: - rctx = tracing_context.get() - - span.set_attributes( - attributes={"configuration": rctx.get("config", {})}, - namespace="meta", - ) - span.set_attributes( - attributes={"environment": rctx.get("environment", {})}, - namespace="meta", - ) span.set_attributes( - attributes={"version": rctx.get("version", {})}, - namespace="meta", - ) - span.set_attributes( - attributes={"variant": rctx.get("variant", {})}, + attributes={"configuration": context.parameters or {}}, namespace="meta", ) @@ -118,6 +141,7 @@ def _pre_instrument( io=self._parse(func, *args, **kwargs), ignore=self.ignore_inputs, ) + span.set_attributes( attributes={"inputs": _inputs}, namespace="data", @@ -161,6 +185,7 @@ def _post_instrument( io=self._patch(result), ignore=self.ignore_outputs, ) + span.set_attributes( attributes={"outputs": _outputs}, namespace="data", @@ -171,15 +196,12 @@ def _post_instrument( with suppress(): if hasattr(span, "parent") and span.parent is None: - tracing_context.set( - tracing_context.get() - | { - "root": { - "trace_id": span.get_span_context().trace_id, - "span_id": span.get_span_context().span_id, - } - } - ) + context = tracing_context.get() + context.link = { + "tree_id": span.get_span_context().trace_id, + "node_id": span.get_span_context().span_id, + } + tracing_context.set(context) def _parse( self, diff --git a/agenta-cli/agenta/sdk/managers/config.py b/agenta-cli/agenta/sdk/managers/config.py index edadbaedc..d3ec7b97c 100644 --- a/agenta-cli/agenta/sdk/managers/config.py +++ b/agenta-cli/agenta/sdk/managers/config.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from agenta.sdk.managers.shared import SharedManager -from agenta.sdk.decorators.routing import routing_context +from agenta.sdk.context.routing import routing_context T = TypeVar("T", bound=BaseModel) @@ -20,7 +20,7 @@ class ConfigManager: @staticmethod def get_from_route( schema: Optional[Type[T]] = None, - ) -> Union[Dict[str, Any], T]: + ) -> Optional[Union[Dict[str, Any], T]]: """ Retrieves the configuration from the route context and returns a config object. @@ -47,125 +47,15 @@ def get_from_route( context = routing_context.get() - parameters = None - - if "config" in context and context["config"]: - parameters = context["config"] - - else: - app_id: Optional[str] = None - app_slug: Optional[str] = None - variant_id: Optional[str] = None - variant_slug: Optional[str] = None - variant_version: Optional[int] = None - environment_id: Optional[str] = None - environment_slug: Optional[str] = None - environment_version: Optional[int] = None - - if "application" in context: - app_id = context["application"].get("id") - app_slug = context["application"].get("slug") - - if "variant" in context: - variant_id = context["variant"].get("id") - variant_slug = context["variant"].get("slug") - variant_version = context["variant"].get("version") - - if "environment" in context: - environment_id = context["environment"].get("id") - environment_slug = context["environment"].get("slug") - environment_version = context["environment"].get("version") - - parameters = ConfigManager.get_from_registry( - app_id=app_id, - app_slug=app_slug, - variant_id=variant_id, - variant_slug=variant_slug, - variant_version=variant_version, - environment_id=environment_id, - environment_slug=environment_slug, - environment_version=environment_version, - ) + parameters = context.parameters - if schema: - return schema(**parameters) - - return parameters + if not parameters: + return None - @staticmethod - async def aget_from_route( - schema: Optional[Type[T]] = None, - ) -> Union[Dict[str, Any], T]: - """ - Asynchronously retrieves the configuration from the route context and returns a config object. + if not schema: + return parameters - This method checks the route context for configuration information and returns - an instance of the specified schema based on the available context data. - - Args: - schema (Type[T]): A Pydantic model class that defines the structure of the configuration. - - Returns: - T: An instance of the specified schema populated with the configuration data. - - Raises: - ValueError: If conflicting configuration sources are provided or if no valid - configuration source is found in the context. - - Note: - The method prioritizes the inputs in the following way: - 1. 'config' (i.e. when called explicitly from the playground) - 2. 'environment' - 3. 'variant' - Only one of these should be provided. - """ - - context = routing_context.get() - - parameters = None - - if "config" in context and context["config"]: - parameters = context["config"] - - else: - app_id: Optional[str] = None - app_slug: Optional[str] = None - variant_id: Optional[str] = None - variant_slug: Optional[str] = None - variant_version: Optional[int] = None - environment_id: Optional[str] = None - environment_slug: Optional[str] = None - environment_version: Optional[int] = None - - if "application" in context: - app_id = context["application"].get("id") - app_slug = context["application"].get("slug") - - if "variant" in context: - variant_id = context["variant"].get("id") - variant_slug = context["variant"].get("slug") - variant_version = context["variant"].get("version") - - if "environment" in context: - environment_id = context["environment"].get("id") - environment_slug = context["environment"].get("slug") - environment_version = context["environment"].get("version") - - parameters = await ConfigManager.async_get_from_registry( - app_id=app_id, - app_slug=app_slug, - variant_id=variant_id, - variant_slug=variant_slug, - variant_version=variant_version, - environment_id=environment_id, - environment_slug=environment_slug, - environment_version=environment_version, - ) - - if schema: - return schema(**parameters) - - return parameters + return schema(**parameters) @staticmethod def get_from_registry( diff --git a/agenta-cli/agenta/sdk/managers/secrets.py b/agenta-cli/agenta/sdk/managers/secrets.py new file mode 100644 index 000000000..aca598881 --- /dev/null +++ b/agenta-cli/agenta/sdk/managers/secrets.py @@ -0,0 +1,38 @@ +from typing import Optional, Dict, Any + +from agenta.sdk.context.routing import routing_context + +from agenta.sdk.assets import model_to_provider_mapping + + +class SecretsManager: + @staticmethod + def get_from_route() -> Optional[Dict[str, Any]]: + context = routing_context.get() + + secrets = context.secrets + + if not secrets: + return None + + return secrets + + @staticmethod + def get_api_key_for_model(model: str) -> str: + secrets = SecretsManager.get_from_route() + + if not secrets: + return None + + provider = model_to_provider_mapping.get(model) + + if not provider: + return None + + provider = provider.lower().replace(" ", "") + + for secret in secrets: + if secret["data"]["provider"] == provider: + return secret["data"]["key"] + + return None diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index c02e46322..fd82198d0 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -1,90 +1,116 @@ from typing import Callable, Optional -from os import environ -from uuid import UUID + +from os import getenv from json import dumps -from traceback import format_exc import httpx from starlette.middleware.base import BaseHTTPMiddleware -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse -from agenta.sdk.utils.logging import log -from agenta.sdk.middleware.cache import TTLLRUCache +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL +from agenta.sdk.utils.constants import TRUTHY +from agenta.sdk.utils.exceptions import display_exception -AGENTA_SDK_AUTH_CACHE_CAPACITY = environ.get( - "AGENTA_SDK_AUTH_CACHE_CAPACITY", - 512, -) +import agenta as ag -AGENTA_SDK_AUTH_CACHE_TTL = environ.get( - "AGENTA_SDK_AUTH_CACHE_TTL", - 15 * 60, # 15 minutes -) -AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in ( - "true", - "1", - "t", +_SHARED_SERVICE = getenv("AGENTA_SHARED_SERVICE", "false").lower() in TRUTHY +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY +_UNAUTHORIZED_ALLOWED = ( + getenv("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", "false").lower() in TRUTHY ) +_ALWAYS_ALLOW_LIST = ["/health"] -AGENTA_SDK_AUTH_CACHE = False +_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL) -AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str( - environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False) -).lower() in ("true", "1", "t") +class DenyResponse(JSONResponse): + def __init__( + self, + status_code: int = 401, + detail: str = "Unauthorized", + ) -> None: + super().__init__( + status_code=status_code, + content={"detail": detail}, + ) -class Deny(Response): - def __init__(self) -> None: - super().__init__(status_code=401, content="Unauthorized") +class DenyException(Exception): + def __init__( + self, + status_code: int = 401, + content: str = "Unauthorized", + ) -> None: + super().__init__() -cache = TTLLRUCache( - capacity=AGENTA_SDK_AUTH_CACHE_CAPACITY, - ttl=AGENTA_SDK_AUTH_CACHE_TTL, -) + self.status_code = status_code + self.content = content -class AuthorizationMiddleware(BaseHTTPMiddleware): - def __init__( - self, - app: FastAPI, - host: str, - resource_id: UUID, - resource_type: str, - ): +class AuthMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): super().__init__(app) - self.host = host - self.resource_id = resource_id - self.resource_type = resource_type + self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + self.resource_id = ( + ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.service_id + if not _SHARED_SERVICE + else None + ) + + async def dispatch(self, request: Request, call_next: Callable): + try: + if _UNAUTHORIZED_ALLOWED or request.url.path in _ALWAYS_ALLOW_LIST: + request.state.auth = {} + + else: + credentials = await self._get_credentials(request) + + request.state.auth = {"credentials": credentials} - async def dispatch( - self, - request: Request, - call_next: Callable, - ): - if AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED: return await call_next(request) - try: - authorization = ( - request.headers.get("Authorization") - or request.headers.get("authorization") - or None + except DenyException as deny: + display_exception("Auth Middleware Exception") + + return DenyResponse( + status_code=deny.status_code, + detail=deny.content, ) + except: # pylint: disable=bare-except + display_exception("Auth Middleware Exception") + + return DenyResponse( + status_code=500, + detail="Auth: Unexpected Error.", + ) + + async def _get_credentials(self, request: Request) -> Optional[str]: + try: + authorization = request.headers.get("authorization", None) + headers = {"Authorization": authorization} if authorization else None - cookies = {"sAccessToken": request.cookies.get("sAccessToken")} + access_token = request.cookies.get("sAccessToken", None) + + cookies = {"sAccessToken": access_token} if access_token else None + + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + project_id = ( + # CLEANEST + baggage.get("project_id") + # ALTERNATIVE + or request.query_params.get("project_id") + ) - params = { - "action": "run_service", - "resource_type": self.resource_type, - "resource_id": self.resource_id, - } + params = {"action": "run_service", "resource_type": "service"} - project_id = request.query_params.get("project_id") + if self.resource_id: + params["resource_id"] = self.resource_id if project_id: params["project_id"] = project_id @@ -98,48 +124,57 @@ async def dispatch( sort_keys=True, ) - policy = None - if AGENTA_SDK_AUTH_CACHE: - policy = cache.get(_hash) - - if not policy: - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.host}/api/permissions/verify", - headers=headers, - cookies=cookies, - params=params, - ) + if _CACHE_ENABLED: + credentials = _cache.get(_hash) + + if credentials: + return credentials - if response.status_code != 200: - cache.put(_hash, {"effect": "deny"}) - return Deny() + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/permissions/verify", + headers=headers, + cookies=cookies, + params=params, + ) - auth = response.json() + if response.status_code == 401: + raise DenyException( + status_code=401, + content="Invalid credentials", + ) + elif response.status_code == 403: + raise DenyException( + status_code=403, + content="Service execution not allowed.", + ) + elif response.status_code != 200: + raise DenyException( + status_code=400, + content="Auth: Unexpected Error.", + ) - if auth.get("effect") != "allow": - cache.put(_hash, {"effect": "deny"}) - return Deny() + auth = response.json() - policy = { - "effect": "allow", - "credentials": auth.get("credentials"), - } + if auth.get("effect") != "allow": + raise DenyException( + status_code=403, + content="Service execution not allowed.", + ) - cache.put(_hash, policy) + credentials = auth.get("credentials") - if not policy or policy.get("effect") == "deny": - return Deny() + _cache.put(_hash, credentials) - request.state.credentials = policy.get("credentials") + return credentials - return await call_next(request) + except DenyException as deny: + raise deny - except: # pylint: disable=bare-except - log.warning("------------------------------------------------------") - log.warning("Agenta SDK - handling auth middleware exception below:") - log.warning("------------------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------------------") + except Exception as exc: # pylint: disable=bare-except + display_exception("Auth Middleware Exception (suppressed)") - return Deny() + raise DenyException( + status_code=500, + content="Auth: Unexpected Error.", + ) from exc diff --git a/agenta-cli/agenta/sdk/middleware/cache.py b/agenta-cli/agenta/sdk/middleware/cache.py index 5445b1faf..641f4f802 100644 --- a/agenta-cli/agenta/sdk/middleware/cache.py +++ b/agenta-cli/agenta/sdk/middleware/cache.py @@ -1,6 +1,10 @@ +from os import getenv from time import time from collections import OrderedDict +CACHE_CAPACITY = int(getenv("AGENTA_MIDDLEWARE_CACHE_CAPACITY", "512")) +CACHE_TTL = int(getenv("AGENTA_MIDDLEWARE_CACHE_TTL", str(5 * 60))) # 5 minutes + class TTLLRUCache: def __init__(self, capacity: int, ttl: int): diff --git a/agenta-cli/agenta/sdk/middleware/config.py b/agenta-cli/agenta/sdk/middleware/config.py new file mode 100644 index 000000000..8ea9eb9ff --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/config.py @@ -0,0 +1,254 @@ +from typing import Callable, Optional, Tuple, Dict + +from os import getenv +from json import dumps + +from pydantic import BaseModel + +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import Request, FastAPI + +import httpx + +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL +from agenta.sdk.utils.constants import TRUTHY +from agenta.sdk.utils.exceptions import suppress + +import agenta as ag + + +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY + +_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL) + + +class Reference(BaseModel): + id: Optional[str] = None + slug: Optional[str] = None + version: Optional[str] = None + + +class ConfigMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + self.application_id = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id + + async def dispatch( + self, + request: Request, + call_next: Callable, + ): + request.state.config = {} + + with suppress(): + parameters, references = await self._get_config(request) + + request.state.config = { + "parameters": parameters, + "references": references, + } + + return await call_next(request) + + # @atimeit + async def _get_config(self, request: Request) -> Optional[Tuple[Dict, Dict]]: + credentials = request.state.auth.get("credentials") + + headers = None + if credentials: + headers = {"Authorization": credentials} + + application_ref = await self._parse_application_ref(request) + variant_ref = await self._parse_variant_ref(request) + environment_ref = await self._parse_environment_ref(request) + + refs = {} + if application_ref: + refs["application_ref"] = application_ref.model_dump() + if variant_ref: + refs["variant_ref"] = variant_ref.model_dump() + if environment_ref: + refs["environment_ref"] = environment_ref.model_dump() + + if not refs: + return None, None + + _hash = dumps( + { + "headers": headers, + "refs": refs, + }, + sort_keys=True, + ) + + if _CACHE_ENABLED: + config_cache = _cache.get(_hash) + + if config_cache: + parameters = config_cache.get("parameters") + references = config_cache.get("references") + + return parameters, references + + config = None + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.host}/api/variants/configs/fetch", + headers=headers, + json=refs, + ) + + if response.status_code != 200: + return None, None + + config = response.json() + + if not config: + _cache.put(_hash, {"parameters": None, "references": None}) + + return None, None + + parameters = config.get("params") + + references = {} + + for ref_key in ["application_ref", "variant_ref", "environment_ref"]: + refs = config.get(ref_key) + ref_prefix = ref_key.split("_", maxsplit=1)[0] + + for ref_part_key in ["id", "slug", "version"]: + ref_part = refs.get(ref_part_key) + + if ref_part: + references[ref_prefix + "." + ref_part_key] = ref_part + + _cache.put(_hash, {"parameters": parameters, "references": references}) + + return parameters, references + + async def _parse_application_ref( + self, + request: Request, + ) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + body = {} + try: + body = await request.json() + except: # pylint: disable=bare-except + pass + + application_id = ( + # CLEANEST + baggage.get("application_id") + # ALTERNATIVE + or request.query_params.get("application_id") + # LEGACY + or request.query_params.get("app_id") + or self.application_id + ) + application_slug = ( + # CLEANEST + baggage.get("application_slug") + # ALTERNATIVE + or request.query_params.get("application_slug") + # LEGACY + or request.query_params.get("app_slug") + or body.get("app") + ) + + if not any([application_id, application_slug]): + return None + + return Reference( + id=application_id, + slug=application_slug, + ) + + async def _parse_variant_ref( + self, + request: Request, + ) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + body = {} + try: + body = await request.json() + except: # pylint: disable=bare-except + pass + + variant_id = ( + # CLEANEST + baggage.get("variant_id") + # ALTERNATIVE + or request.query_params.get("variant_id") + ) + variant_slug = ( + # CLEANEST + baggage.get("variant_slug") + # ALTERNATIVE + or request.query_params.get("variant_slug") + # LEGACY + or request.query_params.get("config") + or body.get("config") + ) + variant_version = ( + # CLEANEST + baggage.get("variant_version") + # ALTERNATIVE + or request.query_params.get("variant_version") + ) + + if not any([variant_id, variant_slug, variant_version]): + return None + + return Reference( + id=variant_id, + slug=variant_slug, + version=variant_version, + ) + + async def _parse_environment_ref( + self, + request: Request, + ) -> Optional[Reference]: + baggage = request.state.otel.get("baggage") if request.state.otel else {} + + body = {} + try: + body = await request.json() + except: # pylint: disable=bare-except + pass + + environment_id = ( + # CLEANEST + baggage.get("environment_id") + # ALTERNATIVE + or request.query_params.get("environment_id") + ) + environment_slug = ( + # CLEANEST + baggage.get("environment_slug") + # ALTERNATIVE + or request.query_params.get("environment_slug") + # LEGACY + or request.query_params.get("environment") + or body.get("environment") + ) + environment_version = ( + # CLEANEST + baggage.get("environment_version") + # ALTERNATIVE + or request.query_params.get("environment_version") + ) + + if not any([environment_id, environment_slug, environment_version]): + return None + + return Reference( + id=environment_id, + slug=environment_slug, + version=environment_version, + ) diff --git a/agenta-cli/agenta/sdk/middleware/cors.py b/agenta-cli/agenta/sdk/middleware/cors.py new file mode 100644 index 000000000..a4568dbe4 --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/cors.py @@ -0,0 +1,29 @@ +from os import getenv + +from starlette.types import ASGIApp, Receive, Scope, Send +from fastapi.middleware.cors import CORSMiddleware as _CORSMiddleware + +_TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} +_USE_CORS = getenv("AGENTA_USE_CORS", "enable").lower() in _TRUTHY + + +class CORSMiddleware(_CORSMiddleware): + def __init__(self, app: ASGIApp): + self.app = app + + if _USE_CORS: + super().__init__( + app=app, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, + expose_headers=None, + max_age=None, + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if _USE_CORS: + return await super().__call__(scope, receive, send) + + return await self.app(scope, receive, send) diff --git a/agenta-cli/agenta/sdk/middleware/otel.py b/agenta-cli/agenta/sdk/middleware/otel.py new file mode 100644 index 000000000..0a6396f97 --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/otel.py @@ -0,0 +1,40 @@ +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import Request, FastAPI + +from opentelemetry.baggage.propagation import W3CBaggagePropagator + +from agenta.sdk.utils.exceptions import suppress + + +class OTelMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + async def dispatch(self, request: Request, call_next: Callable): + request.state.otel = {} + + with suppress(): + baggage = await self._get_baggage(request) + + request.state.otel = {"baggage": baggage} + + return await call_next(request) + + async def _get_baggage( + self, + request, + ): + _baggage = {"baggage": request.headers.get("Baggage", "")} + + context = W3CBaggagePropagator().extract(_baggage) + + baggage = {} + + if context: + for partial in context.values(): + for key, value in partial.items(): + baggage[key] = value + + return baggage diff --git a/agenta-cli/agenta/sdk/middleware/vault.py b/agenta-cli/agenta/sdk/middleware/vault.py new file mode 100644 index 000000000..d7c1793af --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/vault.py @@ -0,0 +1,158 @@ +from os import getenv +from json import dumps +from typing import Callable, Dict, Optional, List, Any + +import httpx +from fastapi import FastAPI, Request +from starlette.middleware.base import BaseHTTPMiddleware + +from agenta.sdk.utils.constants import TRUTHY +from agenta.client.backend.types.provider_kind import ProviderKind +from agenta.sdk.utils.exceptions import suppress, display_exception +from agenta.client.backend.types.secret_dto import SecretDto as SecretDTO +from agenta.client.backend.types.provider_key_dto import ( + ProviderKeyDto as ProviderKeyDTO, +) +from agenta.sdk.middleware.cache import TTLLRUCache, CACHE_CAPACITY, CACHE_TTL + +import agenta as ag + + +# ProviderKind (agenta.client.backend.types.provider_kind import ProviderKind) defines a type hint that allows \ +# for a fixed set of string literals representing various provider names, alongside `typing.Any`. +PROVIDER_KINDS = [] + +# Rationale behind the following: +# ------------------------------- +# You cannot loop directly over the values in `typing.Literal` because: +# - `Literal` is not iterable. +# - `ProviderKind.__args__` includes `Literal` and `Any`, but the actual string values +# are nested within the `Literal`'s own `__args__` attribute. + +# To solve this, we programmatically extract the values from `Literal` while retaining +# the structure of ProviderKind. This ensures: +# 1. We don't modify the original `ProviderKind` type definition. +# 2. We dynamically access the literal values for use at runtime when necessary. +for arg in ProviderKind.__args__: # type: ignore + if hasattr(arg, "__args__"): + PROVIDER_KINDS.extend(arg.__args__) + + +_CACHE_ENABLED = getenv("AGENTA_MIDDLEWARE_CACHE_ENABLED", "true").lower() in TRUTHY + +_cache = TTLLRUCache(capacity=CACHE_CAPACITY, ttl=CACHE_TTL) + + +class VaultMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + + self.host = ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host + + def _transform_secrets_response_to_secret_dto( + self, secrets_list: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + secrets_dto_dict = [ + { + "kind": secret.get("secret", {}).get("kind"), + "data": secret.get("secret", {}).get("data", {}), + } + for secret in secrets_list + ] + return secrets_dto_dict + + async def dispatch( + self, + request: Request, + call_next: Callable, + ): + request.state.vault = {} + + with suppress(): + secrets = await self._get_secrets(request) + + request.state.vault = {"secrets": secrets} + + return await call_next(request) + + async def _get_secrets(self, request: Request) -> Optional[Dict]: + credentials = request.state.auth.get("credentials") + + headers = None + if credentials: + headers = {"Authorization": credentials} + + _hash = dumps( + { + "headers": headers, + }, + sort_keys=True, + ) + + if _CACHE_ENABLED: + secrets_cache = _cache.get(_hash) + + if secrets_cache: + secrets = secrets_cache.get("secrets") + + return secrets + + local_secrets: List[SecretDTO] = [] + + try: + for provider_kind in PROVIDER_KINDS: + provider = provider_kind + key_name = f"{provider.upper()}_API_KEY" + key = getenv(key_name) + + if not key: + continue + + secret = SecretDTO( # 'kind' attribute in SecretDTO defaults to 'provider_kind' + data=ProviderKeyDTO( + provider=provider, + key=key, + ), + ) + + local_secrets.append(secret.model_dump()) + except: # pylint: disable=bare-except + display_exception("Vault: Local Secrets Exception") + + vault_secrets: List[SecretDTO] = [] + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/vault/v1/secrets", + headers=headers, + ) + + if response.status_code != 200: + vault_secrets = [] + + else: + secrets = response.json() + vault_secrets = self._transform_secrets_response_to_secret_dto( + secrets + ) + except: # pylint: disable=bare-except + display_exception("Vault: Vault Secrets Exception") + + merged_secrets = {} + + if local_secrets: + for secret in local_secrets: + provider = secret["data"]["provider"] + merged_secrets[provider] = secret + + if vault_secrets: + for secret in vault_secrets: + provider = secret["data"]["provider"] + merged_secrets[provider] = secret + + secrets = list(merged_secrets.values()) + + _cache.put(_hash, {"secrets": secrets}) + + return secrets diff --git a/agenta-cli/agenta/sdk/tracing/context.py b/agenta-cli/agenta/sdk/tracing/context.py deleted file mode 100644 index 23925db01..000000000 --- a/agenta-cli/agenta/sdk/tracing/context.py +++ /dev/null @@ -1,24 +0,0 @@ -from contextvars import ContextVar -from contextlib import contextmanager -from traceback import format_exc - -from agenta.sdk.utils.logging import log - -tracing_context = ContextVar("tracing_context", default={}) - - -@contextmanager -def tracing_context_manager(): - _tracing_context = {"health": {"status": "ok"}} - - token = tracing_context.set(_tracing_context) - try: - yield - except: # pylint: disable=bare-except - log.warning("----------------------------------------------") - log.warning("Agenta SDK - handling tracing exception below:") - log.warning("----------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("----------------------------------------------") - finally: - tracing_context.reset(token) diff --git a/agenta-cli/agenta/sdk/tracing/exporters.py b/agenta-cli/agenta/sdk/tracing/exporters.py index 62f03a10b..7a38201d5 100644 --- a/agenta-cli/agenta/sdk/tracing/exporters.py +++ b/agenta-cli/agenta/sdk/tracing/exporters.py @@ -9,6 +9,11 @@ ) from agenta.sdk.utils.exceptions import suppress +from agenta.sdk.context.exporting import ( + exporting_context_manager, + exporting_context, + ExportingContext, +) class InlineTraceExporter(SpanExporter): @@ -58,8 +63,41 @@ def fetch( return trace -OTLPSpanExporter._MAX_RETRY_TIMEOUT = 2 # pylint: disable=protected-access +class OTLPExporter(OTLPSpanExporter): + _MAX_RETRY_TIMEOUT = 2 + + def __init__(self, *args, credentials: Dict[int, str] = None, **kwargs): + super().__init__(*args, **kwargs) + + self.credentials = credentials + + def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: + credentials = None + + if self.credentials: + trace_ids = set(span.get_span_context().trace_id for span in spans) + + if len(trace_ids) == 1: + trace_id = trace_ids.pop() + + if trace_id in self.credentials: + credentials = self.credentials.pop(trace_id) + + with exporting_context_manager( + context=ExportingContext( + credentials=credentials, + ) + ): + return super().export(spans) + + def _export(self, serialized_data: bytes): + credentials = exporting_context.get().credentials + + if credentials: + self._session.headers.update({"Authorization": credentials}) + + return super()._export(serialized_data) + ConsoleExporter = ConsoleSpanExporter InlineExporter = InlineTraceExporter -OTLPExporter = OTLPSpanExporter diff --git a/agenta-cli/agenta/sdk/tracing/inline.py b/agenta-cli/agenta/sdk/tracing/inline.py index 6905ad5cf..3bf55cdf8 100644 --- a/agenta-cli/agenta/sdk/tracing/inline.py +++ b/agenta-cli/agenta/sdk/tracing/inline.py @@ -101,8 +101,8 @@ class NodeDTO(BaseModel): Data = Dict[str, Any] Metrics = Dict[str, Any] Metadata = Dict[str, Any] -Tags = Dict[str, str] -Refs = Dict[str, str] +Tags = Dict[str, Any] +Refs = Dict[str, Any] class LinkDTO(BaseModel): diff --git a/agenta-cli/agenta/sdk/tracing/processors.py b/agenta-cli/agenta/sdk/tracing/processors.py index b5d04d808..2c612220c 100644 --- a/agenta-cli/agenta/sdk/tracing/processors.py +++ b/agenta-cli/agenta/sdk/tracing/processors.py @@ -1,5 +1,6 @@ from typing import Optional, Dict, List +from opentelemetry.baggage import get_all as get_baggage from opentelemetry.context import Context from opentelemetry.sdk.trace import Span from opentelemetry.sdk.trace.export import ( @@ -11,8 +12,7 @@ ) from agenta.sdk.utils.logging import log - -# LOAD CONTEXT, HERE ! +from agenta.sdk.tracing.conventions import Reference class TraceProcessor(BatchSpanProcessor): @@ -43,9 +43,17 @@ def on_start( span: Span, parent_context: Optional[Context] = None, ) -> None: + baggage = get_baggage(parent_context) + for key in self.references.keys(): span.set_attribute(f"ag.refs.{key}", self.references[key]) + for key in baggage.keys(): + if key.startswith("ag.refs."): + _key = key.replace("ag.refs.", "") + if _key in [_.value for _ in Reference.__members__.values()]: + span.set_attribute(key, baggage[key]) + if span.context.trace_id not in self._registry: self._registry[span.context.trace_id] = dict() @@ -89,7 +97,7 @@ def force_flush( ret = super().force_flush(timeout_millis) if not ret: - log.warning("Agenta SDK - skipping export due to timeout.") + log.warning("Agenta - Skipping export due to timeout.") def is_ready( self, diff --git a/agenta-cli/agenta/sdk/tracing/tracing.py b/agenta-cli/agenta/sdk/tracing/tracing.py index 809c86493..0e92bb9d1 100644 --- a/agenta-cli/agenta/sdk/tracing/tracing.py +++ b/agenta-cli/agenta/sdk/tracing/tracing.py @@ -41,6 +41,8 @@ def __init__( self.headers: Dict[str, str] = dict() # REFERENCES self.references: Dict[str, str] = dict() + # CREDENTIALS + self.credentials: Dict[int, str] = dict() # TRACER PROVIDER self.tracer_provider: Optional[TracerProvider] = None @@ -60,13 +62,16 @@ def __init__( def configure( self, api_key: Optional[str] = None, + service_id: Optional[str] = None, # DEPRECATING app_id: Optional[str] = None, ): # HEADERS (OTLP) if api_key: - self.headers["Authorization"] = api_key + self.headers["Authorization"] = f"ApiKey {api_key}" # REFERENCES + if service_id: + self.references["service.id"] = service_id if app_id: self.references["application.id"] = app_id @@ -84,31 +89,28 @@ def configure( self.tracer_provider.add_span_processor(self.inline) # TRACE PROCESSORS -- OTLP try: - log.info("--------------------------------------------") log.info( - "Agenta SDK - connecting to otlp receiver at: %s", + "Agenta - OLTP URL: %s", self.otlp_url, ) - log.info("--------------------------------------------") - check( - self.otlp_url, - headers=self.headers, - timeout=1, - ) + # check( + # self.otlp_url, + # headers=self.headers, + # timeout=1, + # ) _otlp = TraceProcessor( OTLPExporter( endpoint=self.otlp_url, headers=self.headers, + credentials=self.credentials, ), references=self.references, ) self.tracer_provider.add_span_processor(_otlp) - log.info("Success: traces will be exported.") - log.info("--------------------------------------------") except: # pylint: disable=bare-except - log.warning("Agenta SDK - traces will not be exported.") + log.warning("Agenta - OLTP unreachable, skipping exports.") # GLOBAL TRACER PROVIDER -- INSTRUMENTATION LIBRARIES set_tracer_provider(self.tracer_provider) diff --git a/agenta-cli/agenta/sdk/utils/constants.py b/agenta-cli/agenta/sdk/utils/constants.py new file mode 100644 index 000000000..fc2e1ae25 --- /dev/null +++ b/agenta-cli/agenta/sdk/utils/constants.py @@ -0,0 +1 @@ +TRUTHY = {"true", "1", "t", "y", "yes", "on", "enable", "enabled"} diff --git a/agenta-cli/agenta/sdk/utils/exceptions.py b/agenta-cli/agenta/sdk/utils/exceptions.py index a451b1de7..a1d5cb379 100644 --- a/agenta-cli/agenta/sdk/utils/exceptions.py +++ b/agenta-cli/agenta/sdk/utils/exceptions.py @@ -6,6 +6,17 @@ from agenta.sdk.utils.logging import log +def display_exception(message: str): + _len = len("Agenta - ") + len(message) + len(":") + _bar = "-" * _len + + log.warning(_bar) + log.warning("Agenta - %s:", message) + log.warning(_bar) + log.warning(format_exc().strip("\n")) + log.warning(_bar) + + class suppress(AbstractContextManager): # pylint: disable=invalid-name def __init__(self): pass @@ -14,15 +25,10 @@ def __enter__(self): pass def __exit__(self, exc_type, exc_value, exc_tb): - if exc_type is None: - return True - else: - log.warning("-------------------------------------------------") - log.warning("Agenta SDK - suppressing tracing exception below:") - log.warning("-------------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("-------------------------------------------------") - return True + if exc_type is not None: + display_exception("Exception (suppressed)") + + return True def handle_exceptions(): @@ -33,12 +39,10 @@ def decorator(func): async def async_wrapper(*args, **kwargs): try: return await func(*args, **kwargs) + except Exception as e: - log.warning("------------------------------------------") - log.warning("Agenta SDK - intercepting exception below:") - log.warning("------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------") + display_exception("Exception") + raise e @wraps(func) @@ -46,11 +50,8 @@ def sync_wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - log.warning("------------------------------------------") - log.warning("Agenta SDK - intercepting exception below:") - log.warning("------------------------------------------") - log.warning(format_exc().strip("\n")) - log.warning("------------------------------------------") + display_exception("Exception") + raise e return async_wrapper if is_coroutine_function else sync_wrapper diff --git a/agenta-cli/agenta/sdk/utils/globals.py b/agenta-cli/agenta/sdk/utils/globals.py index f05141e08..ceae07642 100644 --- a/agenta-cli/agenta/sdk/utils/globals.py +++ b/agenta-cli/agenta/sdk/utils/globals.py @@ -1,14 +1,10 @@ -import agenta +import agenta as ag def set_global(config=None, tracing=None): - """Allows usage of agenta.config and agenta.tracing in the user's code. + """Allows usage of agenta.config and agenta.tracing in the user's code.""" - Args: - config: _description_. Defaults to None. - tracing: _description_. Defaults to None. - """ if config is not None: - agenta.config = config + ag.config = config if tracing is not None: - agenta.tracing = tracing + ag.tracing = tracing diff --git a/agenta-cli/agenta/sdk/utils/timing.py b/agenta-cli/agenta/sdk/utils/timing.py new file mode 100644 index 000000000..c73b5f210 --- /dev/null +++ b/agenta-cli/agenta/sdk/utils/timing.py @@ -0,0 +1,58 @@ +import time +from functools import wraps + +from agenta.sdk.utils.logging import log + + +def timeit(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + + execution_time = end_time - start_time + + if execution_time < 1e-3: + time_value = execution_time * 1e6 + unit = "us" + elif execution_time < 1: + time_value = execution_time * 1e3 + unit = "ms" + else: + time_value = execution_time + unit = "s" + + class_name = args[0].__class__.__name__ if args else None + + log.info(f"'{class_name}.{func.__name__}' executed in {time_value:.4f} {unit}.") + return result + + return wrapper + + +def atimeit(func): + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = time.time() + result = await func(*args, **kwargs) + end_time = time.time() + + execution_time = end_time - start_time + + if execution_time < 1e-3: + time_value = execution_time * 1e6 + unit = "us" + elif execution_time < 1: + time_value = execution_time * 1e3 + unit = "ms" + else: + time_value = execution_time + unit = "s" + + class_name = args[0].__class__.__name__ if args else None + + log.info(f"'{class_name}.{func.__name__}' executed in {time_value:.4f} {unit}.") + return result + + return wrapper diff --git a/agenta-cli/pyproject.toml b/agenta-cli/pyproject.toml index 0eb01187a..52b0a9069 100644 --- a/agenta-cli/pyproject.toml +++ b/agenta-cli/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agenta" -version = "0.30.0" +version = "0.31.0" description = "The SDK for agenta is an open-source LLMOps platform." readme = "README.md" authors = ["Mahmoud Mabrouk "] diff --git a/agenta-cli/tests/baggage/_main.py b/agenta-cli/tests/baggage/_main.py new file mode 100644 index 000000000..4040d2adb --- /dev/null +++ b/agenta-cli/tests/baggage/_main.py @@ -0,0 +1,8 @@ +from uvicorn import run + +import app # pylint: disable=unused-import + +import agenta # pylint: disable=unused-import + +if __name__ == "__main__": + run("agenta:app", host="0.0.0.0", port=8888, reload=True) diff --git a/agenta-cli/tests/baggage/agenta b/agenta-cli/tests/baggage/agenta new file mode 120000 index 000000000..d77f00d8a --- /dev/null +++ b/agenta-cli/tests/baggage/agenta @@ -0,0 +1 @@ +/Users/junaway/Agenta/github/agenta-sandbox/baggage/agenta \ No newline at end of file diff --git a/agenta-cli/tests/baggage/app.py b/agenta-cli/tests/baggage/app.py new file mode 100644 index 000000000..465b65b0a --- /dev/null +++ b/agenta-cli/tests/baggage/app.py @@ -0,0 +1,17 @@ +import agenta as ag + +ag.init(config_fname="config.toml") + +ag.config.default( + flag=ag.BinaryParam(value=False), +) + + +@ag.route() +@ag.instrument() +def main(aloha: str = "Aloha") -> str: + print(ag.ConfigManager.get_from_route()) + print(ag.SecretsManager.get_from_route()) + print(ag.config.flag) + + return aloha diff --git a/agenta-cli/tests/baggage/config.toml b/agenta-cli/tests/baggage/config.toml new file mode 100644 index 000000000..f32346649 --- /dev/null +++ b/agenta-cli/tests/baggage/config.toml @@ -0,0 +1,4 @@ +app_name = "baggage" +app_id = "0193b67a-b673-7919-85c2-0b5b0a2183d3" +backend_host = "http://localhost" +api_key = "XELnjVve.c1f177c87250b603cf1ed2a69ebdfc1cec3124776058e7afcbba93890c515e74" diff --git a/agenta-cli/tests/baggage/specs/check_generate.py b/agenta-cli/tests/baggage/specs/check_generate.py new file mode 100644 index 000000000..e07277b74 --- /dev/null +++ b/agenta-cli/tests/baggage/specs/check_generate.py @@ -0,0 +1,82 @@ +import pytest +import httpx +import os + +BASE_URL = os.getenv("BASE_URL", None) or None +API_KEY = os.getenv("API_KEY", None) or None + +# 200 +# 401 +# 403 +# 405 +# 422 +# 500 + + +def test_unauth_generate(): + """Test /generate without credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + response = httpx.get(f"{BASE_URL}/generate") + + assert ( + response.status_code == 401 + ), f"Expected status 401, got {response.status_code}" + + data = response.json() + + assert ( + data["detail"] == "Missing 'authorization' header." + ), f'Expected "Missing \'authorization\' header.", got "{data["detail"]}"' + + +# REQUIRES +# - a valid API key -> API endpoint to create a new API key +# - a valid APP_ID -> API endpoint to create a an app from hooks +def test_auth_generate(): + """Test /generate with credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + assert ( + API_KEY is not None + ), "API KEY environment variable must be set to run this test" + + response = httpx.post( + f"{BASE_URL}/generate", + headers={"Authorization": API_KEY}, + json={ + "aloha": "mahalo", + }, + ) + + assert ( + response.status_code == 200 + ), f"Expected status 200, got {response.status_code}" + + data = response.json() + + assert "data" in data, "Expected 'data' key in response JSON" + + assert "mahalo" in data["data"], "Expected data:'mahalo' in response JSON" + + assert "tree" in data, "Expected 'tree' key in response JSON" + + assert "nodes" in data["tree"], "Expected tree:'nodes' in response JSON" + + assert ( + len(data["tree"]["nodes"]) == 1 + ), "Expected tree:'nodes' length 1 in response JSON" + + assert ( + "inputs" in data["tree"]["nodes"][0]["data"] + ), "Expected tree:'nodes':'inputs' in response JSON" + + assert ( + "outputs" in data["tree"]["nodes"][0]["data"] + ), "Expected tree:'nodes':'outputs' in response JSON" diff --git a/agenta-cli/tests/baggage/specs/check_openapi.py b/agenta-cli/tests/baggage/specs/check_openapi.py new file mode 100644 index 000000000..753131380 --- /dev/null +++ b/agenta-cli/tests/baggage/specs/check_openapi.py @@ -0,0 +1,51 @@ +import pytest +import httpx +import os + +BASE_URL = os.getenv("BASE_URL", None) or None +API_KEY = os.getenv("API_KEY", None) or None + + +def test_unauth_openapi(): + """Test /openapi.json without credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + response = httpx.get(f"{BASE_URL}/openapi.json") + + assert ( + response.status_code == 401 + ), f"Expected status 401, got {response.status_code}" + + data = response.json() + + assert ( + data["detail"] == "Missing 'authorization' header." + ), f'Expected "Missing \'authorization\' header.", got "{data["detail"]}"' + + +# REQUIRES +# - a valid API key -> API endpoint to create a new API key +# - a valid APP_ID -> API endpoint to create a an app from hooks +def test_auth_openapi(): + """Test /openapi.json with credentials for status 401.""" + + assert ( + BASE_URL is not None + ), "BASE_URL environment variable must be set to run this test" + + assert ( + API_KEY is not None + ), "API KEY environment variable must be set to run this test" + + response = httpx.get(f"{BASE_URL}/openapi.json", headers={"Authorization": API_KEY}) + + assert ( + response.status_code == 200 + ), f"Expected status 200, got {response.status_code}" + + data = response.json() + + assert "openapi" in data, "Expected 'openapi' key in response JSON" diff --git a/agenta-cli/tests/run_pytest.sh b/agenta-cli/tests/run_pytest.sh new file mode 100755 index 000000000..19bdf1163 --- /dev/null +++ b/agenta-cli/tests/run_pytest.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Define default values +TEST_TARGET="specs/*" +PYTEST_OPTIONS="" +MARKERS="" +APP="" + +# Function to display usage +usage() { + echo "Usage: $0 [-t test_target] [-o pytest_options] [-m markers] [-a app]" + echo " -t test_target Specify the pytest test target to run. Default is 'specs/'." + echo " -o pytest_options Pass additional options to pytest." + echo " -m markers Specify marker expressions (e.g., 'smoke or integration')." + echo " -a app Specify the FastAPI app to run." + exit 1 +} + +# Parse command-line arguments +while getopts "t:o:m:a:" opt; do + case ${opt} in + t) TEST_TARGET="$OPTARG" ;; + o) PYTEST_OPTIONS="$OPTARG" ;; + m) MARKERS="$OPTARG" ;; + a) APP="$OPTARG" ;; + *) usage ;; + esac +done + +if [[ -z "$APP" ]]; then + echo "Error: Please specify the FastAPI app to run with the -a option." + usage +fi + +TEST_TARGET="./apps/${APP}/${TEST_TARGET}" + +# Build marker expression if markers are specified +if [[ -n "$MARKERS" ]]; then + MARKER_EXPR="-m \"$MARKERS\"" +fi + +# Run pytest with the specified options +echo "Running pytest tests in $TEST_TARGET with options: $PYTEST_OPTIONS $MARKER_EXPR" +eval pytest "$TEST_TARGET" $PYTEST_OPTIONS $MARKER_EXPR \ No newline at end of file diff --git a/agenta-cli/tests/run_tests.sh b/agenta-cli/tests/run_tests.sh new file mode 100755 index 000000000..a3f5e82f1 --- /dev/null +++ b/agenta-cli/tests/run_tests.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +# Define default values for the server +HOST="127.0.0.1" +PORT="8888" +TEST_TARGET="specs/*" +PYTEST_OPTIONS="" +MARKERS="" +APP="" + +# Function to display usage +usage() { + echo "Usage: $0 [-h host] [-p port] [-t test_target] [-o pytest_options] [-m markers] [-a app] [-k key]" + echo " -h host Specify the FastAPI server host. Default is 127.0.0.1." + echo " -p port Specify the FastAPI server port. Default is 8000." + echo " -t test_target Specify the pytest test target to run. Default is 'specs'." + echo " -o pytest_options Pass additional options to pytest." + echo " -m markers Specify marker expressions (e.g., 'smoke or integration')." + echo " -a app Specify the FastAPI app to run." + echo " -k key Specify the API key." + exit 1 +} + +# Parse command-line arguments +while getopts "h:p:t:o:m:a:k:" opt; do + case ${opt} in + h) HOST="$OPTARG" ;; + p) PORT="$OPTARG" ;; + t) TEST_TARGET="$OPTARG" ;; + o) PYTEST_OPTIONS="$OPTARG" ;; + m) MARKERS="$OPTARG" ;; + a) APP="$OPTARG" ;; + k) API_KEY="$OPTARG" ;; + *) usage ;; + esac +done + +if [[ -z "$APP" ]]; then + echo "Error: Please specify the FastAPI app to run with the -a option." + usage +fi + +# Start the FastAPI server +./start_server.sh -h "$HOST" -p "$PORT" -a "$APP" + +# Export the base URL as an environment variable +export BASE_URL="http://${HOST}:${PORT}" + +# Export the API key as an environment variable +export API_KEY="$API_KEY" + +# Run pytest tests with markers +./run_pytest.sh -t "$TEST_TARGET" -o "$PYTEST_OPTIONS" -m "$MARKERS" -a "$APP" + +# Stop the FastAPI server +./stop_server.sh \ No newline at end of file diff --git a/agenta-cli/tests/start_server.sh b/agenta-cli/tests/start_server.sh new file mode 100755 index 000000000..f6b4129ef --- /dev/null +++ b/agenta-cli/tests/start_server.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +# Define default values +HOST="127.0.0.1" +PORT="8888" +APP= + +# Function to display usage +usage() { + echo "Usage: $0 [-h host] [-p port] [-a app]" + echo " -h host Specify the FastAPI server host. Default is 127.0.0.1." + echo " -p port Specify the FastAPI server port. Default is 8000." + echo " -a app Specify the FastAPI app to run. Default is 'baggage'." + exit 1 +} + +# Parse command-line arguments +while getopts "h:p:a:" opt; do + case ${opt} in + h) HOST="$OPTARG" ;; + p) PORT="$OPTARG" ;; + a) APP="$OPTARG" ;; + *) usage ;; + esac +done + +if [[ -z "$APP" ]]; then + echo "Error: Please specify the FastAPI app to run with the -a option." + usage +fi + +# Start the FastAPI server +echo "Starting FastAPI server at http://${HOST}:${PORT}..." +#uvicorn app:app --host "${HOST}" --port "${PORT}" & +cd ./apps/${APP} +python3 _main.py & +SERVER_PID=$! +echo "Server PID: $SERVER_PID" +echo $SERVER_PID > ../../server.pid # Save PID to a file for later use +sleep 3 # Wait for the server to start \ No newline at end of file diff --git a/agenta-cli/tests/stop_server.sh b/agenta-cli/tests/stop_server.sh new file mode 100755 index 000000000..1678d850d --- /dev/null +++ b/agenta-cli/tests/stop_server.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Function to display usage +usage() { + echo "Usage: $0" + echo "Stops the FastAPI server started with start_server.sh." + exit 1 +} + +# Check if the PID file exists +if [[ ! -f server.pid ]]; then + echo "Error: PID file 'server.pid' not found. Is the server running?" + exit 1 +fi + +# Read the PID from the file +SERVER_PID=$(cat server.pid) + +# Validate that the PID is a running process +if ps -p "$SERVER_PID" > /dev/null 2>&1; then + echo "Stopping FastAPI server (PID: $SERVER_PID)..." + kill "$SERVER_PID" # Send the termination signal + + # Wait for the process to terminate + sleep 2 + + # Double-check if the process is still running + if ps -p "$SERVER_PID" > /dev/null 2>&1; then + echo "Error: Failed to stop the server. Attempting to force stop..." + kill -9 "$SERVER_PID" # Force kill the process + sleep 1 + + if ps -p "$SERVER_PID" > /dev/null 2>&1; then + echo "Error: Unable to stop the server process even with force. Manual intervention required." + exit 1 + fi + fi + + echo "FastAPI server stopped successfully." + rm -f server.pid # Remove the PID file +else + echo "Error: No process found with PID $SERVER_PID. Removing stale PID file." + rm -f server.pid + exit 1 +fi \ No newline at end of file diff --git a/agenta-web/cypress/support/commands/evaluations.ts b/agenta-web/cypress/support/commands/evaluations.ts index 57004db6d..1747649de 100644 --- a/agenta-web/cypress/support/commands/evaluations.ts +++ b/agenta-web/cypress/support/commands/evaluations.ts @@ -50,7 +50,8 @@ Cypress.Commands.add("createVariantsAndTestsets", () => { cy.wrap(testsetName).as("testsetName") cy.get(".ag-row").should("have.length", 1) - cy.get('[data-cy="testset-header-column-edit-button"]').eq(0).click() + cy.wait(2000) + cy.get('[data-cy="testset-header-column-edit-button"]').eq(0).should("exist").click() cy.get('[data-cy="testset-header-column-edit-input"]').clear() cy.get('[data-cy="testset-header-column-edit-input"]').type("country") cy.get('[data-cy="testset-header-column-save-button"]').click() diff --git a/agenta-web/package-lock.json b/agenta-web/package-lock.json index d79b676fb..af84fca67 100644 --- a/agenta-web/package-lock.json +++ b/agenta-web/package-lock.json @@ -1,12 +1,12 @@ { "name": "agenta", - "version": "0.30.0", + "version": "0.31.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "agenta", - "version": "0.30.0", + "version": "0.31.0", "dependencies": { "@ag-grid-community/client-side-row-model": "^31.3.4", "@ag-grid-community/core": "^31.3.4", @@ -85,7 +85,7 @@ "cypress": "^13.15.0", "husky": "^9.1.7", "node-mocks-http": "^1.12.2", - "prettier": "^3.2.5", + "prettier": "3.2.5", "prettier-plugin-tailwindcss": "^0.1" }, "engines": { @@ -14770,9 +14770,9 @@ } }, "node_modules/prettier": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", - "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", + "version": "3.2.5", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.2.5.tgz", + "integrity": "sha512-3/GWa9aOC0YeD7LUfvOG2NiDyhOWRvt1k+rcKhOuYnMY24iiCphgneUfJDyFXd6rZCAnuLBv6UeAULtrhT/F4A==", "license": "MIT", "bin": { "prettier": "bin/prettier.cjs" diff --git a/agenta-web/package.json b/agenta-web/package.json index 624a54174..48af4de39 100644 --- a/agenta-web/package.json +++ b/agenta-web/package.json @@ -1,6 +1,6 @@ { "name": "agenta", - "version": "0.30.0", + "version": "0.31.0", "private": true, "engines": { "node": ">=18" @@ -99,7 +99,7 @@ "cypress": "^13.15.0", "husky": "^9.1.7", "node-mocks-http": "^1.12.2", - "prettier": "^3.2.5", + "prettier": "3.2.5", "prettier-plugin-tailwindcss": "^0.1" } } diff --git a/agenta-web/src/components/pages/app-management/index.tsx b/agenta-web/src/components/pages/app-management/index.tsx index 6a69ef89a..f5822b29d 100644 --- a/agenta-web/src/components/pages/app-management/index.tsx +++ b/agenta-web/src/components/pages/app-management/index.tsx @@ -8,7 +8,7 @@ import {createUseStyles} from "react-jss" import {useAppsData} from "@/contexts/app.context" import {useProfileData} from "@/contexts/profile.context" import {usePostHogAg} from "@/lib/helpers/analytics/hooks/usePostHogAg" -import {LlmProvider, getAllProviderLlmKeys} from "@/lib/helpers/llmProviders" +import {type LlmProvider} from "@/lib/helpers/llmProviders" import {dynamicComponent, dynamicContext} from "@/lib/helpers/dynamic" import dayjs from "dayjs" import {useAppTheme} from "@/components/Layout/ThemeContextProvider" @@ -17,6 +17,7 @@ import GetStartedSection from "./components/GetStartedSection" import ApplicationManagementSection from "./components/ApplicationManagementSection" import ResultComponent from "@/components/ResultComponent/ResultComponent" import {useProjectData} from "@/contexts/project.context" +import {useVaultSecret} from "@/hooks/useVaultSecret" const CreateAppStatusModal: any = dynamicComponent( "pages/app-management/modals/CreateAppStatusModal", @@ -83,6 +84,7 @@ const AppManagement: React.FC = () => { details: undefined, appId: undefined, }) + const {secrets} = useVaultSecret() const {project} = useProjectData() const [useOrgData, setUseOrgData] = useState(() => () => "") @@ -118,7 +120,7 @@ const AppManagement: React.FC = () => { setStatusModalOpen(true) // attempt to create and start the template, notify user of the progress - const apiKeys = getAllProviderLlmKeys() + const apiKeys = secrets await createAndStartTemplate({ appName: newApp, templateId: template_id, diff --git a/agenta-web/src/components/pages/evaluations/NewEvaluation/NewEvaluationModal.tsx b/agenta-web/src/components/pages/evaluations/NewEvaluation/NewEvaluationModal.tsx index 68cb03539..f1f2b42f6 100644 --- a/agenta-web/src/components/pages/evaluations/NewEvaluation/NewEvaluationModal.tsx +++ b/agenta-web/src/components/pages/evaluations/NewEvaluation/NewEvaluationModal.tsx @@ -14,6 +14,7 @@ import SelectTestsetSection from "./SelectTestsetSection" import SelectVariantSection from "./SelectVariantSection" import SelectEvaluatorSection from "./SelectEvaluatorSection" import {dynamicComponent} from "@/lib/helpers/dynamic" +import {useVaultSecret} from "@/hooks/useVaultSecret" const AdvancedSettingsPopover: any = dynamicComponent( "pages/evaluations/NewEvaluation/AdvancedSettingsPopover", @@ -73,6 +74,7 @@ const NewEvaluationModal: React.FC = ({onSuccess, ...props}) => { const [selectedTestsetId, setSelectedTestsetId] = useState("") const [selectedVariantIds, setSelectedVariantIds] = useState([]) const [selectedEvalConfigs, setSelectedEvalConfigs] = useState([]) + const {secrets} = useVaultSecret() const [activePanel, setActivePanel] = useState("testsetPanel") const handlePanelChange = (key: string | string[]) => { @@ -163,7 +165,7 @@ const NewEvaluationModal: React.FC = ({onSuccess, ...props}) => { variant_ids: selectedVariantIds, evaluators_configs: selectedEvalConfigs, rate_limit: rateLimitValues, - lm_providers_keys: apiKeyObject(), + lm_providers_keys: apiKeyObject(secrets), correct_answer_column: correctAnswerColumn, }) .then(onSuccess) diff --git a/agenta-web/src/components/pages/settings/Secrets/Secrets.tsx b/agenta-web/src/components/pages/settings/Secrets/Secrets.tsx index b28f7dbf6..e4cb775c7 100644 --- a/agenta-web/src/components/pages/settings/Secrets/Secrets.tsx +++ b/agenta-web/src/components/pages/settings/Secrets/Secrets.tsx @@ -1,19 +1,25 @@ -import { - getLlmProviderKey, - saveLlmProviderKey, - removeSingleLlmProviderKey, - getAllProviderLlmKeys, - LlmProvider, -} from "@/lib/helpers/llmProviders" +import {useVaultSecret} from "@/hooks/useVaultSecret" +import {type LlmProvider} from "@/lib/helpers/llmProviders" import {Button, Input, Space, Typography, message} from "antd" -import {useState} from "react" +import {useEffect, useState} from "react" const {Title, Text} = Typography export default function Secrets() { - const [llmProviderKeys, setLlmProviderKeys] = useState(getAllProviderLlmKeys()) + const {secrets, handleModifyVaultSecret, handleDeleteVaultSecret} = useVaultSecret() + const [llmProviderKeys, setLlmProviderKeys] = useState([]) + const [loadingSecrets, setLoadingSecrets] = useState>({}) const [messageAPI, contextHolder] = message.useMessage() + useEffect(() => { + setLlmProviderKeys(secrets) + }, [secrets]) + + const setSecretLoading = (id: string | undefined, isLoading: boolean) => { + if (!id) return + setLoadingSecrets((prev) => ({...prev, [id]: isLoading})) + } + return (
{contextHolder} @@ -30,47 +36,70 @@ export default function Secrets() { Available Providers
- {llmProviderKeys.map(({title, key}: LlmProvider, i: number) => ( - - { - const newLlmProviderKeys = [...llmProviderKeys] - newLlmProviderKeys[i].key = e.target.value - setLlmProviderKeys(newLlmProviderKeys) - }} - addonBefore={`${title}`} - visibilityToggle={false} - className={"w-[420px]"} - /> - - - - ))} + {llmProviderKeys.map( + ({name, title, key, id: secretId}: LlmProvider, i: number) => ( + + { + const newLlmProviderKeys = [...llmProviderKeys] + newLlmProviderKeys[i].key = e.target.value + setLlmProviderKeys(newLlmProviderKeys) + }} + addonBefore={`${title}`} + visibilityToggle={false} + className={"w-[420px]"} + /> + + + + ), + )}
diff --git a/agenta-web/src/hooks/useVaultSecret.ts b/agenta-web/src/hooks/useVaultSecret.ts new file mode 100644 index 000000000..499413097 --- /dev/null +++ b/agenta-web/src/hooks/useVaultSecret.ts @@ -0,0 +1,151 @@ +import {useEffect, useRef, useState} from "react" +import { + getAllProviderLlmKeys, + llmAvailableProviders, + llmAvailableProvidersToken, + LlmProvider, + removeSingleLlmProviderKey, + saveLlmProviderKey, +} from "@/lib/helpers/llmProviders" +import {isDemo} from "@/lib/helpers/utils" +import {dynamicLib, dynamicService} from "@/lib/helpers/dynamic" + +export const useVaultSecret = () => { + const [secrets, setSecrets] = useState(llmAvailableProviders) + const shouldRunMigration = useRef(true) + + const getVaultSecrets = async () => { + try { + if (isDemo()) { + const {fetchVaultSecret} = await dynamicService("vault/api") + const data = await fetchVaultSecret() + + setSecrets((prevSecret) => { + return prevSecret.map((secret) => { + const match = data.find((item: LlmProvider) => item.name === secret.name) + if (match) { + return { + ...secret, + key: match.key, + id: match.id, + } + } else { + return secret + } + }) + }) + } else { + setSecrets(getAllProviderLlmKeys()) + } + } catch (error) { + console.error(error) + } + } + + useEffect(() => { + getVaultSecrets() + }, []) + + const migrateProviderKeys = async () => { + try { + const localStorageProviders = localStorage.getItem(llmAvailableProvidersToken) + + if (localStorageProviders) { + const providers = JSON.parse(localStorageProviders) + + for (const provider of providers) { + if (provider.key) { + await handleModifyVaultSecret(provider as LlmProvider) + } + } + + localStorage.setItem(`${llmAvailableProvidersToken}Backup`, localStorageProviders) + + localStorage.removeItem(llmAvailableProvidersToken) + } + } catch (error) { + console.error(error) + } + } + + useEffect(() => { + if (shouldRunMigration.current) { + shouldRunMigration.current = false + if (isDemo()) { + migrateProviderKeys() + } + } + }, []) + + const handleModifyVaultSecret = async (provider: LlmProvider) => { + try { + if (isDemo()) { + const {updateVaultSecret, createVaultSecret} = await dynamicService("vault/api") + const {SecretDTOProvider, SecretDTOKind} = await dynamicLib("types_ee") + + const envNameMap: Record = { + OPENAI_API_KEY: SecretDTOProvider.OPENAI, + COHERE_API_KEY: SecretDTOProvider.COHERE, + ANYSCALE_API_KEY: SecretDTOProvider.ANYSCALE, + DEEPINFRA_API_KEY: SecretDTOProvider.DEEPINFRA, + ALEPHALPHA_API_KEY: SecretDTOProvider.ALEPHALPHA, + GROQ_API_KEY: SecretDTOProvider.GROQ, + MISTRAL_API_KEY: SecretDTOProvider.MISTRAL, + ANTHROPIC_API_KEY: SecretDTOProvider.ANTHROPIC, + PERPLEXITYAI_API_KEY: SecretDTOProvider.PERPLEXITYAI, + TOGETHERAI_API_KEY: SecretDTOProvider.TOGETHERAI, + OPENROUTER_API_KEY: SecretDTOProvider.OPENROUTER, + GEMINI_API_KEY: SecretDTOProvider.GEMINI, + } + + const payload = { + header: { + name: provider.title, + description: "", + }, + secret: { + kind: SecretDTOKind.PROVIDER_KEY, + data: { + provider: envNameMap[provider.name], + key: provider.key, + }, + }, + } + + const findSecret = secrets.find((s) => s.name === provider.name) + + if (findSecret && provider.id) { + await updateVaultSecret({secret_id: provider.id, payload}) + } else { + await createVaultSecret({payload}) + } + + await getVaultSecrets() + } else { + saveLlmProviderKey(provider.title, provider.key) + } + } catch (error) { + console.error(error) + } + } + + const handleDeleteVaultSecret = async (provider: LlmProvider) => { + try { + if (isDemo() && provider.id) { + const {deleteVaultSecret} = await dynamicService("vault/api") + await deleteVaultSecret({secret_id: provider.id}) + await getVaultSecrets() + } else { + removeSingleLlmProviderKey(provider.title) + } + } catch (error) { + console.error(error) + } + } + + return { + secrets, + handleModifyVaultSecret, + handleDeleteVaultSecret, + } +} diff --git a/agenta-web/src/lib/helpers/dynamic.ts b/agenta-web/src/lib/helpers/dynamic.ts index 19f5a003c..effbc89ea 100644 --- a/agenta-web/src/lib/helpers/dynamic.ts +++ b/agenta-web/src/lib/helpers/dynamic.ts @@ -30,3 +30,11 @@ export async function dynamicService(path: string, fallback?: any) { return fallback } } + +export async function dynamicLib(path: string, fallback?: any) { + try { + return await import(`@/lib/${path}`) + } catch (error) { + return fallback + } +} diff --git a/agenta-web/src/lib/helpers/llmProviders.ts b/agenta-web/src/lib/helpers/llmProviders.ts index 5d2b3105f..4c750a336 100644 --- a/agenta-web/src/lib/helpers/llmProviders.ts +++ b/agenta-web/src/lib/helpers/llmProviders.ts @@ -1,12 +1,12 @@ import cloneDeep from "lodash/cloneDeep" -import {camelToSnake} from "./utils" -const llmAvailableProvidersToken = "llmAvailableProvidersToken" +export const llmAvailableProvidersToken = "llmAvailableProvidersToken" export type LlmProvider = { title: string key: string name: string + id?: string } export const llmAvailableProviders: LlmProvider[] = [ @@ -55,9 +55,6 @@ export const saveLlmProviderKey = (providerName: string, keyValue: string) => { localStorage.setItem(llmAvailableProvidersToken, JSON.stringify(keys)) } -export const getLlmProviderKey = (providerName: string) => - getAllProviderLlmKeys().find((item: LlmProvider) => item.title === providerName)?.key - export const getAllProviderLlmKeys = () => { const providers = cloneDeep(llmAvailableProviders) try { diff --git a/agenta-web/src/lib/helpers/utils.ts b/agenta-web/src/lib/helpers/utils.ts index 1842365e3..289159993 100644 --- a/agenta-web/src/lib/helpers/utils.ts +++ b/agenta-web/src/lib/helpers/utils.ts @@ -7,7 +7,7 @@ import dayjs from "dayjs" import utc from "dayjs/plugin/utc" import {notification} from "antd" import Router from "next/router" -import {getAllProviderLlmKeys, getApikeys} from "./llmProviders" +import {getApikeys, LlmProvider} from "./llmProviders" import yaml from "js-yaml" if (typeof window !== "undefined") { @@ -50,9 +50,7 @@ export const EvaluationTypeLabels: Record = { [EvaluationType.rag_context_relevancy]: "RAG Context Relevancy", } -export const apiKeyObject = () => { - const apiKeys = getAllProviderLlmKeys() - +export const apiKeyObject = (apiKeys: LlmProvider[]) => { if (!apiKeys) return {} return apiKeys.reduce((acc: GenericObject, {key, name}: GenericObject) => {