Skip to content

Commit

Permalink
Typing improvements to check that we support both v1 and v1beta1
Browse files Browse the repository at this point in the history
  • Loading branch information
aabmass committed Jan 17, 2025
1 parent c84c532 commit 874098f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
TYPE_CHECKING,
Any,
Callable,
Iterable,
MutableSequence,
Optional,
Union,
)

from opentelemetry._events import EventLogger
Expand All @@ -33,28 +30,35 @@
from opentelemetry.trace import SpanKind, Tracer

if TYPE_CHECKING:
from google.cloud.aiplatform_v1.services.prediction_service import client
from google.cloud.aiplatform_v1.types import (
content,
prediction_service,
)
from vertexai.generative_models import (
GenerationResponse,
from google.cloud.aiplatform_v1beta1.services.prediction_service import (
client as client_v1beta1,
)
from vertexai.generative_models._generative_models import (
_GenerativeModel,
from google.cloud.aiplatform_v1beta1.types import (
content as content_v1beta1,
)
from google.cloud.aiplatform_v1beta1.types import (
prediction_service as prediction_service_v1beta1,
)


# Use parameter signature from
# https://github.com/googleapis/python-aiplatform/blob/v1.76.0/google/cloud/aiplatform_v1/services/prediction_service/client.py#L2088
# to handle named vs positional args robustly
def _extract_params(
request: Optional[
Union[prediction_service.GenerateContentRequest, dict[Any, Any]]
] = None,
request: prediction_service.GenerateContentRequest
| prediction_service_v1beta1.GenerateContentRequest
| dict[Any, Any]
| None = None,
*,
model: Optional[str] = None,
contents: Optional[MutableSequence[content.Content]] = None,
model: str | None = None,
contents: MutableSequence[content.Content]
| MutableSequence[content_v1beta1.Content]
| None = None,
**_kwargs: Any,
) -> GenerateContentParams:
# Request vs the named parameters are mututally exclusive or the RPC will fail
Expand Down Expand Up @@ -86,9 +90,12 @@ def generate_content_create(

def traced_method(
wrapped: Callable[
..., GenerationResponse | Iterable[GenerationResponse]
...,
prediction_service.GenerateContentResponse
| prediction_service_v1beta1.GenerateContentResponse,
],
instance: _GenerativeModel,
instance: client.PredictionServiceClient
| client_v1beta1.PredictionServiceClient,
args: Any,
kwargs: Any,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import (
TYPE_CHECKING,
Mapping,
Optional,
Sequence,
)

Expand All @@ -31,18 +30,32 @@

if TYPE_CHECKING:
from google.cloud.aiplatform_v1.types import content, tool
from google.cloud.aiplatform_v1beta1.types import (
content as content_v1beta1,
)
from google.cloud.aiplatform_v1beta1.types import (
tool as tool_v1beta1,
)


@dataclass(frozen=True)
class GenerateContentParams:
model: str
contents: Optional[Sequence[content.Content]] = None
system_instruction: Optional[content.Content | None] = None
tools: Optional[Sequence[tool.Tool]] = None
tool_config: Optional[tool.ToolConfig] = None
labels: Optional[Mapping[str, str]] = None
safety_settings: Optional[Sequence[content.SafetySetting]] = None
generation_config: Optional[content.GenerationConfig] = None
contents: (
Sequence[content.Content] | Sequence[content_v1beta1.Content] | None
) = None
system_instruction: content.Content | content_v1beta1.Content | None = None
tools: Sequence[tool.Tool] | Sequence[tool_v1beta1.Tool] | None = None
tool_config: tool.ToolConfig | tool_v1beta1.ToolConfig | None = None
labels: Mapping[str, str] | None = None
safety_settings: (
Sequence[content.SafetySetting]
| Sequence[content_v1beta1.SafetySetting]
| None
) = None
generation_config: (
content.GenerationConfig | content_v1beta1.GenerationConfig | None
) = None


def get_genai_request_attributes(
Expand Down

0 comments on commit 874098f

Please sign in to comment.