diff --git a/weave/__init__.py b/weave/__init__.py
index 3a21558ec465..4a1d37aa6817 100644
--- a/weave/__init__.py
+++ b/weave/__init__.py
@@ -14,8 +14,8 @@
 from weave.flow.model import Model
 from weave.flow.obj import Object
 from weave.flow.prompt.prompt import EasyPrompt, MessagesPrompt, Prompt, StringPrompt
+from weave.flow.provider import Provider as Provider
 from weave.flow.scorer import Scorer
-from weave.flow.provider import Provider
 from weave.initialization import *
 from weave.trace.util import Thread as Thread
 from weave.trace.util import ThreadPoolExecutor as ThreadPoolExecutor
diff --git a/weave/flow/provider.py b/weave/flow/provider.py
index 8f14fb51c117..34569730ac65 100644
--- a/weave/flow/provider.py
+++ b/weave/flow/provider.py
@@ -1,8 +1,9 @@
 from enum import Enum
 from typing import Any, Optional, TypedDict, Union
-from typing_extensions import Self
+
 from pydantic import Field
 from rich.table import Table
+from typing_extensions import Self
 
 from weave.flow.obj import Object
 from weave.trace.api import publish as weave_publish
diff --git a/weave/trace_server/clickhouse_trace_server_batched.py b/weave/trace_server/clickhouse_trace_server_batched.py
index 958ce3e032de..12e9fcc06968 100644
--- a/weave/trace_server/clickhouse_trace_server_batched.py
+++ b/weave/trace_server/clickhouse_trace_server_batched.py
@@ -87,8 +87,9 @@ from weave.trace_server.file_storage_uris import FileStorageURI
 from weave.trace_server.ids import generate_id
 from weave.trace_server.llm_completion import (
-    lite_llm_completion,
     CUSTOM_PROVIDER_PREFIX,
+    get_custom_provider_info,
+    lite_llm_completion,
 )
 from weave.trace_server.model_providers.model_providers import (
     read_model_to_provider_info_map,
 )
@@ -1613,106 +1614,35 @@ def actions_execute_batch(
     def completions_create(
         self, req: tsi.CompletionsCreateReq
     ) -> tsi.CompletionsCreateRes:
+        # Required fields
         model_name = req.inputs.model
         api_key = None
         provider = None
-        secret_name = None
-        base_url = None
-        extra_headers = {}
-        return_type = None
-
-        # Handle custom provider case
-        if CUSTOM_PROVIDER_PREFIX in model_name:
-            # Parse the model name to extract provider_id and provider_model_id
-            # Format: __weave_custom_provider__/<provider_id>/<provider_model_id>
-            parts = model_name.split("/")
-            if len(parts) < 3:
-                raise InvalidRequest(
-                    f"Invalid custom provider model format: {model_name}"
-                )
-
-            provider_id = parts[1]
-            provider_model_id = parts[2]
-
-            # Fetch the provider object
-            try:
-                provider_obj_req = tsi.ObjReadReq(
-                    project_id=req.project_id,
-                    object_id=provider_id,
-                    digest="latest",
-                    metadata_only=False,
-                )
-                provider_obj_res = self.obj_read(provider_obj_req)
-                provider_obj = provider_obj_res.obj
-
-                if provider_obj.base_object_class != "Provider":
-                    raise InvalidRequest(
-                        f"Object {provider_id} is not a Provider, it is a {provider_obj.base_object_class}"
-                    )
-
-                # Extract provider information
-                base_url = provider_obj.val.get("base_url")
-                secret_name = provider_obj.val.get("api_key_name")
-                extra_headers = provider_obj.val.get("extra_headers", {})
-                return_type = provider_obj.val.get("return_type", "openai")
-                # Fetch the provider model object
-                provider_model_obj_req = tsi.ObjReadReq(
-                    project_id=req.project_id,
-                    object_id=f"{provider_id}-{provider_model_id}",
-                    digest="latest",
-                    metadata_only=False,
-                )
-                provider_model_obj_res = self.obj_read(provider_model_obj_req)
-                provider_model_obj = provider_model_obj_res.obj
-
-                if provider_model_obj.base_object_class != "ProviderModel":
-                    raise InvalidRequest(
-                        f"Object {provider_model_id} is not a ProviderModel, it is a {provider_model_obj.base_object_class}"
-                    )
-                )
-
-                # Use the provider model's name as the actual model name for the API call
-                req.inputs.model = provider_model_obj.val.get("name")
+        # Custom model fields
+        base_url: Optional[str] = None
+        extra_headers: dict[str, str] = {}
+        return_type: Optional[str] = None
 
-            except Exception as e:
-                raise InvalidRequest(
-                    f"Failed to fetch provider or model information: {str(e)}"
-                )
+        # For custom and standard models, we fetch the fields differently:
+        # 1. Standard models: all of the information comes from the model_to_provider_info_map
+        # 2. Custom models: we fetch the provider object and provider model object
+        if CUSTOM_PROVIDER_PREFIX not in model_name:
+            # Handle standard model case
+            # 1. We get the model info from the map
+            # 2. We fetch the API key with the secret fetcher
+            # 3. We set the provider to the litellm provider
+            # 4. If there is no API key, we raise an error, except for bedrock and bedrock_converse (we fetch bedrock credentials later, in lite_llm_completion)
 
-            # Get the API key
             secret_fetcher = _secret_fetcher_context.get()
             if not secret_fetcher:
                 raise InvalidRequest(
                     f"No secret fetcher found, cannot fetch API key for model {model_name}"
                 )
-
-            if not secret_name:
-                raise InvalidRequest(f"No secret name found for provider {provider_id}")
-
-            api_key = (
-                secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
-            )
-
-            if not api_key:
-                raise MissingLLMApiKeyError(
-                    f"No API key {secret_name} found for provider {provider_id}",
-                    api_key_name=secret_name,
-                )
-
-            provider = "custom"  # Use "custom" as the provider for litellm
-
-        else:
-            # Handle standard model case
             model_info = self._model_to_provider_info_map.get(model_name)
             if not model_info:
                 raise InvalidRequest(f"No model info found for model {model_name}")
 
-            secret_fetcher = _secret_fetcher_context.get()
-            if not secret_fetcher:
-                raise InvalidRequest(
-                    f"No secret fetcher found, cannot fetch API key for model {model_name}"
-                )
-
             secret_name = model_info.get("api_key_name")
             if not secret_name:
                 raise InvalidRequest(f"No secret name found for model {model_name}")
@@ -1722,12 +1652,34 @@ def completions_create(
             )
 
             provider = model_info.get("litellm_provider", "openai")
+            # We fetch bedrock credentials later, in lite_llm_completion
             if not api_key and provider != "bedrock" and provider != "bedrock_converse":
                 raise MissingLLMApiKeyError(
                     f"No API key {secret_name} found for model {model_name}",
                     api_key_name=secret_name,
                 )
 
+        else:
+            # Handle custom provider case
+            # We fetch the provider object and provider model object
+            (
+                base_url,
+                api_key,
+                extra_headers,
+                return_type,
+                actual_model_name,
+            ) = get_custom_provider_info(
+                project_id=req.project_id,
+                model_name=model_name,
+                obj_read_func=self.obj_read,
+            )
+
+            # Always use "custom" as the provider for litellm
+            provider = "custom"
+            # Update the model name for the API call
+            req.inputs.model = actual_model_name
+
+        # Now that we have all the fields for both cases, we can make the API call
         start_time = datetime.datetime.now()
 
         # Make the API call
diff --git a/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py b/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py
index a0761b6ee62f..37d3ac6aec07 100644
--- a/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py
+++ b/weave/trace_server/interface/builtin_object_classes/builtin_object_registry.py
@@ -7,8 +7,8 @@
 )
 from weave.trace_server.interface.builtin_object_classes.leaderboard import Leaderboard
 from weave.trace_server.interface.builtin_object_classes.llm_model import (
-    ProviderModel,
     LLMModel,
+    ProviderModel,
 )
 from weave.trace_server.interface.builtin_object_classes.provider import Provider
 from weave.trace_server.interface.builtin_object_classes.test_only_example import (
diff --git a/weave/trace_server/interface/builtin_object_classes/llm_model.py b/weave/trace_server/interface/builtin_object_classes/llm_model.py
index a92e151de0ec..f7a62bf50a0b 100644
--- a/weave/trace_server/interface/builtin_object_classes/llm_model.py
+++ b/weave/trace_server/interface/builtin_object_classes/llm_model.py
@@ -1,5 +1,6 @@
 from enum import Enum
 from typing import Optional, Union
+
 from pydantic import BaseModel
 
 from weave.trace_server.interface.builtin_object_classes import base_object_def
@@ -20,14 +21,12 @@ class ModelParams(BaseModel):
 
 
 class ProviderModel(base_object_def.BaseObject):
-    name: str
     provider: base_object_def.RefStr
     max_tokens: int
     mode: ModelMode = ModelMode.CHAT
 
 
 class LLMModel(base_object_def.BaseObject):
-    name: str
     provider_model: base_object_def.RefStr
     prompt: Optional[base_object_def.RefStr] = None
     default_params: ModelParams = ModelParams()
diff --git a/weave/trace_server/interface/builtin_object_classes/provider.py b/weave/trace_server/interface/builtin_object_classes/provider.py
index 20dcbaffc927..93c38f529492 100644
--- a/weave/trace_server/interface/builtin_object_classes/provider.py
+++ b/weave/trace_server/interface/builtin_object_classes/provider.py
@@ -1,5 +1,4 @@
 from enum import Enum
-from typing import Optional, Dict
 
 from weave.trace_server.interface.builtin_object_classes import base_object_def
 
@@ -9,9 +8,7 @@ class ReturnType(str, Enum):
 
 
 class Provider(base_object_def.BaseObject):
-    name: str
     base_url: str
     api_key_name: str
-    description: Optional[str] = None
-    extra_headers: Dict[str, str] = {}
+    extra_headers: dict[str, str] = {}
     return_type: ReturnType = ReturnType.OPENAI
diff --git a/weave/trace_server/llm_completion.py b/weave/trace_server/llm_completion.py
index da5d08cb9501..b0bfa0d78eef 100644
--- a/weave/trace_server/llm_completion.py
+++ b/weave/trace_server/llm_completion.py
@@ -1,11 +1,13 @@
-from typing import Optional
+from typing import Callable, Optional
 
 from weave.trace_server import trace_server_interface as tsi
 from weave.trace_server.errors import (
     InvalidRequest,
     MissingLLMApiKeyError,
 )
-from weave.trace_server.secret_fetcher_context import _secret_fetcher_context
+from weave.trace_server.secret_fetcher_context import (
+    _secret_fetcher_context,
+)
 
 NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")
 
@@ -37,19 +39,10 @@ def lite_llm_completion(
         azure_api_base, azure_api_version = get_azure_credentials(inputs.model)
 
     import litellm
-    from litellm import LiteLLM
 
-    litellm.set_verbose = True
     # This allows us to drop params that are not supported by the LLM provider
     litellm.drop_params = True
-    import logging
-
-    logging.basicConfig(level=logging.DEBUG)
-    import os
-
-    os.environ["LITELLM_LOG"] = "DEBUG"
-
 
     # Handle custom provider
     if provider == "custom" and base_url:
         try:
@@ -163,3 +156,105 @@ def get_azure_credentials(model_name: str) -> tuple[str, str]:
         )
 
     return azure_api_base, azure_api_version
+
+
+def get_custom_provider_info(
+    project_id: str,
+    model_name: str,
+    obj_read_func: Callable,
+) -> tuple[str, str, dict[str, str], str, str]:
+    """
+    Extract provider information from a custom provider model.
+
+    Args:
+        project_id: The project ID
+        model_name: The model name (format: __weave_custom_provider__/<provider_id>/<provider_model_id>)
+        obj_read_func: Function to read objects from the database
+
+    Returns:
+        Tuple containing:
+        - base_url: The base URL for the provider
+        - api_key: The API key for the provider
+        - extra_headers: Extra headers to send with the request
+        - return_type: The return type for the provider
+        - actual_model_name: The actual model name to use for the API call
+    """
+    secret_fetcher = _secret_fetcher_context.get()
+    if not secret_fetcher:
+        raise InvalidRequest(
+            f"No secret fetcher found, cannot fetch API key for model {model_name}"
+        )
+
+    # Parse the model name to extract provider_id and provider_model_id
+    # Format: __weave_custom_provider__/<provider_id>/<provider_model_id>
+    parts = model_name.split("/")
+    if len(parts) < 3:
+        raise InvalidRequest(f"Invalid custom provider model format: {model_name}")
+
+    provider_id = parts[1]
+    provider_model_id = parts[2]
+
+    # Default values
+    base_url = None
+    secret_name = None
+    extra_headers = {}
+    return_type = "openai"
+    actual_model_name = model_name
+
+    try:
+        # Fetch the provider object
+        provider_obj_req = tsi.ObjReadReq(
+            project_id=project_id,
+            object_id=provider_id,
+            digest="latest",
+            metadata_only=False,
+        )
+        provider_obj_res = obj_read_func(provider_obj_req)
+        provider_obj = provider_obj_res.obj
+
+        if provider_obj.base_object_class != "Provider":
+            raise InvalidRequest(
+                f"Object {provider_id} is not a Provider, it is a {provider_obj.base_object_class}"
+            )
+
+        # Extract provider information
+        base_url = provider_obj.val.get("base_url")
+        secret_name = provider_obj.val.get("api_key_name")
+        extra_headers = provider_obj.val.get("extra_headers", {})
+        return_type = provider_obj.val.get("return_type", "openai")
+
+        # Fetch the provider model object
+        # Provider models have the format: <provider_id>-<provider_model_id>
+        provider_model_obj_req = tsi.ObjReadReq(
+            project_id=project_id,
+            object_id=f"{provider_id}-{provider_model_id}",
+            digest="latest",
+            metadata_only=False,
+        )
+        provider_model_obj_res = obj_read_func(provider_model_obj_req)
+        provider_model_obj = provider_model_obj_res.obj
+
+        if provider_model_obj.base_object_class != "ProviderModel":
+            raise InvalidRequest(
+                f"Object {provider_model_id} is not a ProviderModel, it is a {provider_model_obj.base_object_class}"
+            )
+
+        # Use the provider model's name as the actual model name for the API call
+        actual_model_name = provider_model_obj.val.get("name")
+
+    except Exception as e:
+        raise InvalidRequest(f"Failed to fetch provider or model information: {str(e)}")
+
+    # Get the API key
+    if not secret_name:
+        raise InvalidRequest(f"No secret name found for provider {provider_id}")
+
+    api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
+
+    if not api_key:
+        raise MissingLLMApiKeyError(
+            f"No API key {secret_name} found for provider {provider_id}",
+            api_key_name=secret_name,
+        )
+
+    return base_url, api_key, extra_headers, return_type, actual_model_name
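
For reference, a minimal sketch of how the new model-name format is resolved, mirroring the parsing in get_custom_provider_info above. The CUSTOM_PROVIDER_PREFIX value is assumed to equal the `__weave_custom_provider__` literal used in the comments, and the ids below are hypothetical examples; only the split logic and the `{provider_id}-{provider_model_id}` object-id convention come from this patch.

```python
# Sketch of the custom-provider model-name resolution (illustrative only).
CUSTOM_PROVIDER_PREFIX = "__weave_custom_provider__"  # assumed value


def parse_custom_model_name(model_name: str) -> tuple[str, str]:
    """Split a custom model string into (provider_id, provider_model_id)."""
    # Format: __weave_custom_provider__/<provider_id>/<provider_model_id>
    parts = model_name.split("/")
    if len(parts) < 3:
        raise ValueError(f"Invalid custom provider model format: {model_name}")
    return parts[1], parts[2]


provider_id, provider_model_id = parse_custom_model_name(
    "__weave_custom_provider__/my-provider/my-model"
)
assert (provider_id, provider_model_id) == ("my-provider", "my-model")

# get_custom_provider_info then reads two objects from the project:
#   - the Provider object at object_id="my-provider"
#     (supplies base_url, api_key_name, extra_headers, return_type)
#   - the ProviderModel object at object_id="my-provider-my-model"
#     (its "name" field becomes the model passed to litellm, which is
#     invoked with provider="custom" and the Provider's base_url)
```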