Commit: clean up

jwlee64 committed Mar 4, 2025
1 parent 4bf2177 commit 08796df
Showing 7 changed files with 151 additions and 106 deletions.
2 changes: 1 addition & 1 deletion weave/__init__.py
@@ -14,8 +14,8 @@
from weave.flow.model import Model
from weave.flow.obj import Object
from weave.flow.prompt.prompt import EasyPrompt, MessagesPrompt, Prompt, StringPrompt
from weave.flow.provider import Provider as Provider
from weave.flow.scorer import Scorer
from weave.flow.provider import Provider
from weave.initialization import *
from weave.trace.util import Thread as Thread
from weave.trace.util import ThreadPoolExecutor as ThreadPoolExecutor
3 changes: 2 additions & 1 deletion weave/flow/provider.py
@@ -1,8 +1,9 @@
from enum import Enum
from typing import Any, Optional, TypedDict, Union
from typing_extensions import Self

from pydantic import Field
from rich.table import Table
from typing_extensions import Self

from weave.flow.obj import Object
from weave.trace.api import publish as weave_publish
124 changes: 38 additions & 86 deletions weave/trace_server/clickhouse_trace_server_batched.py
@@ -87,8 +87,9 @@
from weave.trace_server.file_storage_uris import FileStorageURI
from weave.trace_server.ids import generate_id
from weave.trace_server.llm_completion import (
lite_llm_completion,
CUSTOM_PROVIDER_PREFIX,
get_custom_provider_info,
lite_llm_completion,
)
from weave.trace_server.model_providers.model_providers import (
read_model_to_provider_info_map,
@@ -1613,106 +1614,35 @@ def actions_execute_batch(
def completions_create(
self, req: tsi.CompletionsCreateReq
) -> tsi.CompletionsCreateRes:
# Required fields
model_name = req.inputs.model
api_key = None
provider = None
secret_name = None
base_url = None
extra_headers = {}
return_type = None

# Handle custom provider case
if CUSTOM_PROVIDER_PREFIX in model_name:
# Parse the model name to extract provider_id and provider_model_id
# Format: __weave_custom_provider__/<provider_id>/<provider_model_id>
parts = model_name.split("/")
if len(parts) < 3:
raise InvalidRequest(
f"Invalid custom provider model format: {model_name}"
)

provider_id = parts[1]
provider_model_id = parts[2]

# Fetch the provider object
try:
provider_obj_req = tsi.ObjReadReq(
project_id=req.project_id,
object_id=provider_id,
digest="latest",
metadata_only=False,
)
provider_obj_res = self.obj_read(provider_obj_req)
provider_obj = provider_obj_res.obj

if provider_obj.base_object_class != "Provider":
raise InvalidRequest(
f"Object {provider_id} is not a Provider, it is a {provider_obj.base_object_class}"
)

# Extract provider information
base_url = provider_obj.val.get("base_url")
secret_name = provider_obj.val.get("api_key_name")
extra_headers = provider_obj.val.get("extra_headers", {})
return_type = provider_obj.val.get("return_type", "openai")

# Fetch the provider model object
provider_model_obj_req = tsi.ObjReadReq(
project_id=req.project_id,
object_id=f"{provider_id}-{provider_model_id}",
digest="latest",
metadata_only=False,
)
provider_model_obj_res = self.obj_read(provider_model_obj_req)
provider_model_obj = provider_model_obj_res.obj

if provider_model_obj.base_object_class != "ProviderModel":
raise InvalidRequest(
f"Object {provider_model_id} is not a ProviderModel, it is a {provider_model_obj.base_object_class}"
)

# Use the provider model's name as the actual model name for the API call
req.inputs.model = provider_model_obj.val.get("name")
# Custom model fields
base_url: Optional[str] = None
extra_headers: dict[str, str] = {}
return_type: Optional[str] = None

except Exception as e:
raise InvalidRequest(
f"Failed to fetch provider or model information: {str(e)}"
)
        # For custom and standard models, we fetch the fields differently:
        # 1. Standard models: all of the information comes from the model_to_provider_info_map
        # 2. Custom models: we fetch the provider object and provider model object
if CUSTOM_PROVIDER_PREFIX not in model_name:
# Handle standard model case
# 1. We get the model info from the map
            # 2. We fetch the API key with the secret fetcher
            # 3. We set the provider to the litellm provider
            # 4. If there is no API key, we raise an error, except for bedrock and bedrock_converse (bedrock credentials are fetched later, in lite_llm_completion)

# Get the API key
secret_fetcher = _secret_fetcher_context.get()
if not secret_fetcher:
raise InvalidRequest(
f"No secret fetcher found, cannot fetch API key for model {model_name}"
)

if not secret_name:
raise InvalidRequest(f"No secret name found for provider {provider_id}")

api_key = (
secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)
)

if not api_key:
raise MissingLLMApiKeyError(
f"No API key {secret_name} found for provider {provider_id}",
api_key_name=secret_name,
)

provider = "custom" # Use "custom" as the provider for litellm

else:
# Handle standard model case
model_info = self._model_to_provider_info_map.get(model_name)
if not model_info:
raise InvalidRequest(f"No model info found for model {model_name}")

secret_fetcher = _secret_fetcher_context.get()
if not secret_fetcher:
raise InvalidRequest(
f"No secret fetcher found, cannot fetch API key for model {model_name}"
)

secret_name = model_info.get("api_key_name")
if not secret_name:
raise InvalidRequest(f"No secret name found for model {model_name}")
@@ -1722,12 +1652,34 @@ def completions_create(
)
provider = model_info.get("litellm_provider", "openai")

            # Bedrock credentials are fetched later, in lite_llm_completion
if not api_key and provider != "bedrock" and provider != "bedrock_converse":
raise MissingLLMApiKeyError(
f"No API key {secret_name} found for model {model_name}",
api_key_name=secret_name,
)

else:
# Handle custom provider case
# We fetch the provider object and provider model object
(
base_url,
api_key,
extra_headers,
return_type,
actual_model_name,
) = get_custom_provider_info(
project_id=req.project_id,
model_name=model_name,
obj_read_func=self.obj_read,
)

# Always use "custom" as the provider for litellm
provider = "custom"
# Update the model name for the API call
req.inputs.model = actual_model_name

# Now that we have all the fields for both cases, we can make the API call
start_time = datetime.datetime.now()

# Make the API call
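As an aside, the model-name convention the code above branches on can be illustrated with a small, self-contained sketch. The provider and model ids below are invented; the real resolution is done by _model_to_provider_info_map (standard models) and get_custom_provider_info (custom models), both shown in this commit:

CUSTOM_PROVIDER_PREFIX = "__weave_custom_provider__"

def parse_custom_model_name(model_name: str) -> tuple[str, str]:
    # Expected format: __weave_custom_provider__/<provider_id>/<provider_model_id>
    parts = model_name.split("/")
    if len(parts) < 3:
        raise ValueError(f"Invalid custom provider model format: {model_name}")
    return parts[1], parts[2]

# Example (invented ids): returns ("my-openrouter", "llama-3-70b")
parse_custom_model_name("__weave_custom_provider__/my-openrouter/llama-3-70b")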
@@ -7,8 +7,8 @@
)
from weave.trace_server.interface.builtin_object_classes.leaderboard import Leaderboard
from weave.trace_server.interface.builtin_object_classes.llm_model import (
ProviderModel,
LLMModel,
ProviderModel,
)
from weave.trace_server.interface.builtin_object_classes.provider import Provider
from weave.trace_server.interface.builtin_object_classes.test_only_example import (
weave/trace_server/interface/builtin_object_classes/llm_model.py
@@ -1,5 +1,6 @@
from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel

from weave.trace_server.interface.builtin_object_classes import base_object_def
@@ -20,14 +21,12 @@ class ModelParams(BaseModel):


class ProviderModel(base_object_def.BaseObject):
name: str
provider: base_object_def.RefStr
max_tokens: int
mode: ModelMode = ModelMode.CHAT


class LLMModel(base_object_def.BaseObject):
name: str
provider_model: base_object_def.RefStr
prompt: Optional[base_object_def.RefStr] = None
default_params: ModelParams = ModelParams()
weave/trace_server/interface/builtin_object_classes/provider.py
@@ -1,5 +1,4 @@
from enum import Enum
from typing import Optional, Dict

from weave.trace_server.interface.builtin_object_classes import base_object_def

@@ -9,9 +8,7 @@ class ReturnType(str, Enum):


class Provider(base_object_def.BaseObject):
name: str
base_url: str
api_key_name: str
description: Optional[str] = None
extra_headers: Dict[str, str] = {}
extra_headers: dict[str, str] = {}
return_type: ReturnType = ReturnType.OPENAI
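For illustration, a custom provider declared with these fields might look as follows. This is a minimal sketch with invented values; the import path matches the one used elsewhere in this commit:

from weave.trace_server.interface.builtin_object_classes.provider import (
    Provider,
    ReturnType,
)

provider = Provider(
    name="my-openrouter",  # invented provider id
    base_url="https://api.example-llm.test/v1",  # invented endpoint
    api_key_name="EXAMPLE_API_KEY",  # name of the secret that stores the key
    extra_headers={"X-Org": "my-team"},  # optional extra request headers
    return_type=ReturnType.OPENAI,  # the default, per the class above
)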
118 changes: 107 additions & 11 deletions weave/trace_server/llm_completion.py
@@ -1,11 +1,13 @@
from typing import Optional
from typing import Callable, Optional

from weave.trace_server import trace_server_interface as tsi
from weave.trace_server.errors import (
InvalidRequest,
MissingLLMApiKeyError,
)
from weave.trace_server.secret_fetcher_context import _secret_fetcher_context
from weave.trace_server.secret_fetcher_context import (
_secret_fetcher_context,
)

NOVA_MODELS = ("nova-pro-v1", "nova-lite-v1", "nova-micro-v1")

@@ -37,19 +39,10 @@ def lite_llm_completion(
azure_api_base, azure_api_version = get_azure_credentials(inputs.model)

import litellm
from litellm import LiteLLM

litellm.set_verbose = True
# This allows us to drop params that are not supported by the LLM provider
litellm.drop_params = True

import logging

logging.basicConfig(level=logging.DEBUG)
import os

os.environ["LITELLM_LOG"] = "DEBUG"

# Handle custom provider
if provider == "custom" and base_url:
try:
@@ -163,3 +156,106 @@ def get_azure_credentials(model_name: str) -> tuple[str, str]:
)

return azure_api_base, azure_api_version


def get_custom_provider_info(
project_id: str,
model_name: str,
obj_read_func: Callable,
) -> tuple[str, str, dict[str, str], str, str]:
"""
Extract provider information from a custom provider model.
Args:
project_id: The project ID
model_name: The model name (format: __weave_custom_provider__/<provider_id>/<provider_model_id>)
obj_read_func: Function to read objects from the database
secret_fetcher: Secret fetcher to get API keys
Returns:
Tuple containing:
- base_url: The base URL for the provider
- api_key: The API key for the provider
- extra_headers: Extra headers to send with the request
- return_type: The return type for the provider
- actual_model_name: The actual model name to use for the API call
"""
secret_fetcher = _secret_fetcher_context.get()
if not secret_fetcher:
raise InvalidRequest(
f"No secret fetcher found, cannot fetch API key for model {model_name}"
)

# Parse the model name to extract provider_id and provider_model_id
# Format: __weave_custom_provider__/<provider_id>/<provider_model_id>
parts = model_name.split("/")
if len(parts) < 3:
raise InvalidRequest(f"Invalid custom provider model format: {model_name}")

provider_id = parts[1]
provider_model_id = parts[2]

# Default values
base_url = None
secret_name = None
extra_headers = {}
return_type = "openai"
actual_model_name = model_name

try:
# Fetch the provider object
provider_obj_req = tsi.ObjReadReq(
project_id=project_id,
object_id=provider_id,
digest="latest",
metadata_only=False,
)
provider_obj_res = obj_read_func(provider_obj_req)
provider_obj = provider_obj_res.obj

if provider_obj.base_object_class != "Provider":
raise InvalidRequest(
f"Object {provider_id} is not a Provider, it is a {provider_obj.base_object_class}"
)

# Extract provider information
base_url = provider_obj.val.get("base_url")
secret_name = provider_obj.val.get("api_key_name")
extra_headers = provider_obj.val.get("extra_headers", {})
return_type = provider_obj.val.get("return_type", "openai")

        # Fetch the provider model object
        # Provider model object ids have the format: <provider_id>-<provider_model_id>
provider_model_obj_req = tsi.ObjReadReq(
project_id=project_id,
object_id=f"{provider_id}-{provider_model_id}",
digest="latest",
metadata_only=False,
)
provider_model_obj_res = obj_read_func(provider_model_obj_req)
provider_model_obj = provider_model_obj_res.obj

if provider_model_obj.base_object_class != "ProviderModel":
raise InvalidRequest(
f"Object {provider_model_id} is not a ProviderModel, it is a {provider_model_obj.base_object_class}"
)

# Use the provider model's name as the actual model name for the API call
actual_model_name = provider_model_obj.val.get("name")

except Exception as e:
raise InvalidRequest(f"Failed to fetch provider or model information: {str(e)}")

# Get the API key
if not secret_name:
raise InvalidRequest(f"No secret name found for provider {provider_id}")

api_key = secret_fetcher.fetch(secret_name).get("secrets", {}).get(secret_name)

if not api_key:
raise MissingLLMApiKeyError(
f"No API key {secret_name} found for provider {provider_id}",
api_key_name=secret_name,
)

return base_url, api_key, extra_headers, return_type, actual_model_name
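
A usage sketch for the new helper, with invented ids. It assumes a secret fetcher is already installed in _secret_fetcher_context, and "server" is a hypothetical trace-server handle whose obj_read method matches the obj_read_func parameter:

from weave.trace_server.llm_completion import get_custom_provider_info

base_url, api_key, extra_headers, return_type, actual_model_name = get_custom_provider_info(
    project_id="my-entity/my-project",  # invented project id
    model_name="__weave_custom_provider__/my-openrouter/llama-3-70b",  # invented ids
    obj_read_func=server.obj_read,  # hypothetical server handle
)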
