From ad03bf92846b8effd9539d05cfb559409a5e6b57 Mon Sep 17 00:00:00 2001 From: Nikhil Narayen Date: Fri, 7 Feb 2025 13:37:24 -0500 Subject: [PATCH 1/3] Fix integration test for span cleanup (#1374) --- truss/templates/server/model_wrapper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 49b1d256e..8b78c00ba 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -641,7 +641,7 @@ async def _stream_with_background_task( generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]], span: trace.Span, trace_ctx: trace.Context, - release_and_end: Callable[[], None], + cleanup_fn: Callable[[], None], ) -> AsyncGenerator[bytes, None]: # The streaming read timeout is the amount of time in between streamed chunk # before a timeout is triggered. @@ -661,7 +661,7 @@ async def _stream_with_background_task( self._write_response_to_queue(response_queue, async_generator, span) ) # Defer the release of the semaphore until the write_response_to_queue task. - gen_task.add_done_callback(lambda _: release_and_end()) + gen_task.add_done_callback(lambda _: cleanup_fn()) # The gap between responses in a stream must be < streaming_read_timeout # TODO: this whole buffering might be superfluous and sufficiently done by @@ -717,7 +717,7 @@ async def _process_model_fn( if inspect.isgenerator(result) or inspect.isasyncgen(result): return await self._handle_generator_response( - request, result, fn_span, detached_ctx, release_and_end=lambda: None + request, result, fn_span, detached_ctx ) return result @@ -738,13 +738,13 @@ async def _handle_generator_response( generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]], span: trace.Span, trace_ctx: trace.Context, - release_and_end: Callable[[], None], + get_cleanup_fn: Callable[[], Callable[[], None]] = lambda: lambda: None, ): if self._should_gather_generator(request): return await _gather_generator(generator) else: return await self._stream_with_background_task( - generator, span, trace_ctx, release_and_end + generator, span, trace_ctx, cleanup_fn=get_cleanup_fn() ) async def completions( @@ -824,7 +824,7 @@ async def __call__( predict_result, span_predict, detached_ctx, - release_and_end=get_defer_fn(), + get_cleanup_fn=get_defer_fn, ) if isinstance(predict_result, starlette.responses.Response): From ecc1b4f2c44127e6241001035133c173eacaf38f Mon Sep 17 00:00:00 2001 From: Tyron Jung <37984135+tyranitar@users.noreply.github.com> Date: Fri, 7 Feb 2025 12:23:56 -0800 Subject: [PATCH 2/3] Use hostname for Chains (#1349) * Use hostname for Chains * Use hostname for Oracle invoke URL * Fix tests * Fix tests * Reformat * Address PR feedback * Fix watch command * Fix tests --- .../deployment/deployment_client.py | 14 +- truss/api/definitions.py | 4 +- truss/remote/baseten/api.py | 128 ++++++++++-------- truss/remote/baseten/core.py | 47 +++++-- truss/remote/baseten/custom_types.py | 1 - truss/remote/baseten/remote.py | 39 +++--- truss/remote/baseten/service.py | 45 +++--- truss/tests/remote/baseten/test_api.py | 19 ++- truss/tests/remote/baseten/test_core.py | 47 ++++--- truss/tests/remote/baseten/test_remote.py | 112 ++++++++++----- truss/tests/remote/baseten/test_service.py | 36 +++-- 11 files changed, 297 insertions(+), 195 deletions(-) diff --git a/truss-chains/truss_chains/deployment/deployment_client.py b/truss-chains/truss_chains/deployment/deployment_client.py index 60e2df437..c075cb3ff 100644 --- a/truss-chains/truss_chains/deployment/deployment_client.py +++ b/truss-chains/truss_chains/deployment/deployment_client.py @@ -364,12 +364,14 @@ def __init__( @property def run_remote_url(self) -> str: """URL to invoke the entrypoint.""" - return b10_service.URLConfig.invocation_url( - self._remote.api.rest_api_url, - b10_service.URLConfig.CHAIN, - self._chain_deployment_handle.chain_id, - self._chain_deployment_handle.chain_deployment_id, - self._chain_deployment_handle.is_draft, + + handle = self._chain_deployment_handle + + return b10_service.URLConfig.invoke_url( + hostname=handle.hostname, + config=b10_service.URLConfig.CHAIN, + entity_version_id=handle.chain_deployment_id, + is_draft=handle.is_draft, ) def run_remote(self, json_data: Dict) -> Any: diff --git a/truss/api/definitions.py b/truss/api/definitions.py index dcb68081b..9106fa8e2 100644 --- a/truss/api/definitions.py +++ b/truss/api/definitions.py @@ -14,8 +14,8 @@ class ModelDeployment: _baseten_service: service.BasetenService def __init__(self, service: service.BasetenService): - self.model_id = service._model_id - self.model_deployment_id = service._model_version_id + self.model_id = service.model_id + self.model_deployment_id = service.model_version_id self._baseten_service = service def wait_for_active(self, timeout_seconds: int = 600) -> bool: diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 81be1d91d..d8633dce9 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -138,27 +138,32 @@ def create_model_from_truss( origin: Optional[b10_types.ModelOrigin] = None, ): query_string = f""" - mutation {{ - create_model_from_truss( - name: "{model_name}", - s3_key: "{s3_key}", - config: "{config}", - semver_bump: "{semver_bump}", - client_version: "{client_version}", - is_trusted: {"true" if is_trusted else "false"}, - allow_truss_download: {"true" if allow_truss_download else "false"}, - {f'version_name: "{deployment_name}"' if deployment_name else ""} - {f"model_origin: {origin.value}" if origin else ""} - ) {{ - id, - name, - version_id + mutation {{ + create_model_from_truss( + name: "{model_name}" + s3_key: "{s3_key}" + config: "{config}" + semver_bump: "{semver_bump}" + client_version: "{client_version}" + is_trusted: {"true" if is_trusted else "false"} + allow_truss_download: {"true" if allow_truss_download else "false"} + {f'version_name: "{deployment_name}"' if deployment_name else ""} + {f"model_origin: {origin.value}" if origin else ""} + ) {{ + model_version {{ + id + oracle {{ + id + name + hostname + }} + }} + }} }} - }} """ resp = self._post_graphql_query(query_string) - return resp["data"]["create_model_from_truss"] + return resp["data"]["create_model_from_truss"]["model_version"] def create_model_version_from_truss( self, @@ -173,25 +178,30 @@ def create_model_version_from_truss( environment: Optional[str] = None, ): query_string = f""" - mutation {{ - create_model_version_from_truss( - model_id: "{model_id}" - s3_key: "{s3_key}", - config: "{config}", - semver_bump: "{semver_bump}", - client_version: "{client_version}", - is_trusted: {"true" if is_trusted else "false"}, - scale_down_old_production: {"false" if preserve_previous_prod_deployment else "true"}, - {f'name: "{deployment_name}"' if deployment_name else ""} - {f'environment_name: "{environment}"' if environment else ""} - ) {{ - id + mutation {{ + create_model_version_from_truss( + model_id: "{model_id}" + s3_key: "{s3_key}" + config: "{config}" + semver_bump: "{semver_bump}" + client_version: "{client_version}" + is_trusted: {"true" if is_trusted else "false"} + scale_down_old_production: {"false" if preserve_previous_prod_deployment else "true"} + {f'name: "{deployment_name}"' if deployment_name else ""} + {f'environment_name: "{environment}"' if environment else ""} + ) {{ + model_version {{ + id + oracle {{ + hostname + }} + }} + }} }} - }} """ resp = self._post_graphql_query(query_string) - return resp["data"]["create_model_version_from_truss"] + return resp["data"]["create_model_version_from_truss"]["model_version"] def create_development_model_from_truss( self, @@ -204,23 +214,29 @@ def create_development_model_from_truss( origin: Optional[b10_types.ModelOrigin] = None, ): query_string = f""" - mutation {{ - deploy_draft_truss(name: "{model_name}", - s3_key: "{s3_key}", - config: "{config}", - client_version: "{client_version}", - is_trusted: {"true" if is_trusted else "false"}, - allow_truss_download: {"true" if allow_truss_download else "false"}, + mutation {{ + deploy_draft_truss(name: "{model_name}" + s3_key: "{s3_key}" + config: "{config}" + client_version: "{client_version}" + is_trusted: {"true" if is_trusted else "false"} + allow_truss_download: {"true" if allow_truss_download else "false"} {f"model_origin: {origin.value}" if origin else ""} - ) {{ - id, - name, - version_id - }} - }} + ) {{ + model_version {{ + id + oracle {{ + id + name + hostname + }} + }} + }} + }} """ + resp = self._post_graphql_query(query_string) - return resp["data"]["deploy_draft_truss"] + return resp["data"]["deploy_draft_truss"]["model_version"] def deploy_chain_atomic( self, @@ -251,10 +267,13 @@ def deploy_chain_atomic( dependencies: [{dependencies_str}] client_version: "{truss.version()}" ) {{ - chain_id - chain_deployment_id - entrypoint_model_id - entrypoint_model_version_id + chain_deployment {{ + id + chain {{ + id + hostname + }} + }} }} }} """ @@ -385,9 +404,10 @@ def get_model(self, model_name): query_string = f""" {{ model(name: "{model_name}") {{ - name id - versions{{ + name + hostname + versions {{ id semver truss_hash @@ -408,8 +428,9 @@ def get_model_by_id(self, model_id: str): query_string = f""" {{ model(id: "{model_id}") {{ - name id + name + hostname primary_version{{ id semver @@ -437,6 +458,7 @@ def get_model_version_by_id(self, model_version_id: str): oracle{{ id name + hostname }} }} }} diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index 555385c04..e8508f7b7 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -64,11 +64,16 @@ class TrussWatchState(NamedTuple): class ChainDeploymentHandleAtomic(NamedTuple): + hostname: str chain_id: str chain_deployment_id: str is_draft: bool - entrypoint_model_id: str - entrypoint_model_version_id: str + + +class ModelVersionHandle(NamedTuple): + version_id: str + model_id: str + hostname: str def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]: @@ -160,10 +165,9 @@ def create_chain_atomic( ) return ChainDeploymentHandleAtomic( - chain_id=res["chain_id"], - chain_deployment_id=res["chain_deployment_id"], - entrypoint_model_id=res["entrypoint_model_id"], - entrypoint_model_version_id=res["entrypoint_model_version_id"], + chain_deployment_id=res["chain_deployment"]["id"], + chain_id=res["chain_deployment"]["chain"]["id"], + hostname=res["chain_deployment"]["chain"]["hostname"], is_draft=is_draft, ) @@ -193,9 +197,9 @@ def exists_model(api: BasetenApi, model_name: str) -> Optional[str]: return model["model"]["id"] -def get_model_versions(api: BasetenApi, model_name: ModelName) -> Tuple[str, List]: +def get_model_and_versions(api: BasetenApi, model_name: ModelName) -> Tuple[dict, List]: query_result = api.get_model(model_name.value)["model"] - return query_result["id"], query_result["versions"] + return query_result, query_result["versions"] def get_dev_version_from_versions(versions: List[dict]) -> Optional[dict]: @@ -334,7 +338,7 @@ def create_truss_service( deployment_name: Optional[str] = None, origin: Optional[b10_types.ModelOrigin] = None, environment: Optional[str] = None, -) -> Tuple[str, str]: +) -> ModelVersionHandle: """ Create a model in the Baseten remote. @@ -352,8 +356,9 @@ def create_truss_service( development model. Returns: - A tuple of the model ID and version ID + A Model Version handle. """ + if is_draft: model_version_json = api.create_development_model_from_truss( model_name, @@ -365,11 +370,16 @@ def create_truss_service( origin=origin, ) - return model_version_json["id"], model_version_json["version_id"] + return ModelVersionHandle( + version_id=model_version_json["id"], + model_id=model_version_json["oracle"]["id"], + hostname=model_version_json["oracle"]["hostname"], + ) if model_id is None: if environment and environment != PRODUCTION_ENVIRONMENT_NAME: raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING) + model_version_json = api.create_model_from_truss( model_name=model_name, s3_key=s3_key, @@ -381,7 +391,12 @@ def create_truss_service( deployment_name=deployment_name, origin=origin, ) - return model_version_json["id"], model_version_json["version_id"] + + return ModelVersionHandle( + version_id=model_version_json["id"], + model_id=model_version_json["oracle"]["id"], + hostname=model_version_json["oracle"]["hostname"], + ) try: model_version_json = api.create_model_version_from_truss( @@ -404,8 +419,12 @@ def create_truss_service( f'Environment "{environment}" does not exist. You can create environments in the Baseten UI.' ) from e raise e - model_version_id = model_version_json["id"] - return model_id, model_version_id + + return ModelVersionHandle( + version_id=model_version_json["id"], + model_id=model_id, + hostname=model_version_json["oracle"]["hostname"], + ) def validate_truss_config(api: BasetenApi, config: str): diff --git a/truss/remote/baseten/custom_types.py b/truss/remote/baseten/custom_types.py index 55124a419..a8e04851b 100644 --- a/truss/remote/baseten/custom_types.py +++ b/truss/remote/baseten/custom_types.py @@ -11,7 +11,6 @@ class DeployedChainlet(pydantic.BaseModel): is_draft: bool status: str logs_url: str - oracle_predict_url: str oracle_name: str diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 00d4588d1..2ed4f64c3 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -22,6 +22,7 @@ ModelId, ModelIdentifier, ModelName, + ModelVersionHandle, ModelVersionId, archive_truss, create_chain_atomic, @@ -29,7 +30,7 @@ exists_model, get_dev_version, get_dev_version_from_versions, - get_model_versions, + get_model_and_versions, get_prod_version_from_versions, get_truss_watch_state, upload_truss, @@ -95,13 +96,6 @@ def get_chainlets( chain_deployment_id, chainlet["id"], ), - oracle_predict_url=URLConfig.invocation_url( - self._api.rest_api_url, - URLConfig.MODEL, - chainlet["oracle"]["id"], - chainlet["oracle_version"]["id"], - chainlet["oracle_version"]["is_draft"], - ), oracle_name=chainlet["oracle"]["name"], ) for chainlet in self._api.get_chainlets_by_deployment_id( @@ -228,7 +222,7 @@ def push( # type: ignore # many functions. We should consolidate them into a # data class with standardized default values so # we're not drilling these arguments everywhere. - model_id, model_version_id = create_truss_service( + model_version_handle = create_truss_service( api=self._api, model_name=push_data.model_name, s3_key=push_data.s3_key, @@ -244,11 +238,10 @@ def push( # type: ignore ) return BasetenService( - model_id=model_id, - model_version_id=model_version_id, + model_version_handle=model_version_handle, is_draft=push_data.is_draft, api_key=self._auth_service.authenticate().value, - service_url=f"{self._remote_url}/model_versions/{model_version_id}", + service_url=f"{self._remote_url}/model_versions/{model_version_handle.version_id}", truss_handle=truss_handle, api=self._api, ) @@ -332,23 +325,29 @@ def _get_matching_version(model_versions: List[dict], published: bool) -> dict: @staticmethod def _get_service_url_path_and_model_ids( api: BasetenApi, model_identifier: ModelIdentifier, published: bool - ) -> Tuple[str, str, str]: + ) -> Tuple[str, ModelVersionHandle]: if isinstance(model_identifier, ModelVersionId): try: model_version = api.get_model_version_by_id(model_identifier.value) except ApiError: raise RemoteError(f"Model version {model_identifier.value} not found.") model_version_id = model_version["model_version"]["id"] + hostname = model_version["model_version"]["oracle"]["hostname"] model_id = model_version["model_version"]["oracle"]["id"] service_url_path = f"/model_versions/{model_version_id}" - return service_url_path, model_id, model_version_id + + return service_url_path, ModelVersionHandle( + version_id=model_version_id, model_id=model_id, hostname=hostname + ) if isinstance(model_identifier, ModelName): - model_id, model_versions = get_model_versions(api, model_identifier) + model, model_versions = get_model_and_versions(api, model_identifier) model_version = BasetenRemote._get_matching_version( model_versions, published ) + model_id = model["id"] model_version_id = model_version["id"] + hostname = model["hostname"] service_url_path = f"/model_versions/{model_version_id}" elif isinstance(model_identifier, ModelId): # TODO(helen): consider making this consistent with getting the @@ -359,6 +358,7 @@ def _get_service_url_path_and_model_ids( raise RemoteError(f"Model {model_identifier.value} not found.") model_id = model["model"]["id"] model_version_id = model["model"]["primary_version"]["id"] + hostname = model["model"]["hostname"] service_url_path = f"/models/{model_id}" else: # Model identifier is of invalid type. @@ -367,7 +367,9 @@ def _get_service_url_path_and_model_ids( "--model-deployment or --model options." ) - return service_url_path, model_id, model_version_id + return service_url_path, ModelVersionHandle( + version_id=model_version_id, model_id=model_id, hostname=hostname + ) def get_service(self, **kwargs) -> BasetenService: try: @@ -376,15 +378,14 @@ def get_service(self, **kwargs) -> BasetenService: raise ValueError("Baseten Service requires a model_identifier") published = kwargs.get("published", False) - (service_url_path, model_id, model_version_id) = ( + (service_url_path, model_version_handle) = ( self._get_service_url_path_and_model_ids( self._api, model_identifier, published ) ) return BasetenService( - model_id=model_id, - model_version_id=model_version_id, + model_version_handle=model_version_handle, is_draft=not published, api_key=self._auth_service.authenticate().value, service_url=f"{self._remote_url}{service_url_path}", diff --git a/truss/remote/baseten/service.py b/truss/remote/baseten/service.py index 488d379e9..d7c46d3e5 100644 --- a/truss/remote/baseten/service.py +++ b/truss/remote/baseten/service.py @@ -9,6 +9,7 @@ from truss.base.errors import RemoteNetworkError from truss.remote.baseten.api import BasetenApi from truss.remote.baseten.auth import AuthService +from truss.remote.baseten.core import ModelVersionHandle from truss.remote.truss_remote import TrussService from truss.truss_handle.truss_handle import TrussHandle @@ -36,23 +37,18 @@ class Data(NamedTuple): CHAIN = Data("chain", "run_remote", "chains") @staticmethod - def invocation_url( - api_url: str, # E.g. https://api.baseten.co + def invoke_url( + hostname: str, # E.g. https://model-{model_id}.api.baseten.co config: "URLConfig", - entity_id: str, entity_version_id: str, is_draft, ) -> str: """Get the URL for the predict/run_remote endpoint.""" - # E.g. `https://api.baseten.co` -> `https://model-{model_id}.api.baseten.co` - url = _add_model_subdomain(api_url, f"{config.value.prefix}-{entity_id}") + if is_draft: - # "https://model-{model_id}.api.baseten.co/development". - url = f"{url}/development/{config.value.invoke_endpoint}" + return f"{hostname}/development/{config.value.invoke_endpoint}" else: - # "https://model-{model_id}.api.baseten.co/deployment/{deployment_id}". - url = f"{url}/deployment/{entity_version_id}/{config.value.invoke_endpoint}" - return url + return f"{hostname}/deployment/{entity_version_id}/{config.value.invoke_endpoint}" @staticmethod def status_page_url( @@ -89,8 +85,7 @@ def chainlet_logs_url( class BasetenService(TrussService): def __init__( self, - model_id: str, - model_version_id: str, + model_version_handle: ModelVersionHandle, is_draft: bool, api_key: str, service_url: str, @@ -98,8 +93,7 @@ def __init__( truss_handle: Optional[TrussHandle] = None, ): super().__init__(is_draft=is_draft, service_url=service_url) - self._model_id = model_id - self._model_version_id = model_version_id + self._model_version_handle = model_version_handle self._auth_service = AuthService(api_key=api_key) self._api = api self._truss_handle = truss_handle @@ -112,15 +106,11 @@ def is_ready(self) -> bool: @property def model_id(self) -> str: - return self._model_id + return self._model_version_handle.model_id @property def model_version_id(self) -> str: - return self._model_version_id - - @property - def invocation_url(self) -> str: - return f"{self._service_url}/predict" + return self._model_version_handle.version_id def predict(self, model_request_body: Dict) -> Any: response = self._send_request( @@ -164,17 +154,18 @@ def logs_url(self) -> str: @property def predict_url(self) -> str: - return URLConfig.invocation_url( - self._api.rest_api_url, - URLConfig.MODEL, - self.model_id, - self._model_version_id, - self.is_draft, + handle = self._model_version_handle + + return URLConfig.invoke_url( + hostname=handle.hostname, + config=URLConfig.MODEL, + entity_version_id=handle.version_id, + is_draft=self.is_draft, ) @retry(stop=stop_after_delay(60), wait=wait_fixed(1), reraise=True) def _fetch_deployment(self) -> Any: - return self._api.get_deployment(self._model_id, self._model_version_id) + return self._api.get_deployment(self.model_id, self.model_version_id) def poll_deployment_status(self, sleep_secs: int = 1) -> Iterator[str]: """ diff --git a/truss/tests/remote/baseten/test_api.py b/truss/tests/remote/baseten/test_api.py index b1ebbbf4b..8d6bb682e 100644 --- a/truss/tests/remote/baseten/test_api.py +++ b/truss/tests/remote/baseten/test_api.py @@ -40,7 +40,11 @@ def mock_create_model_version_response(): response = Response() response.status_code = 200 response.json = mock.Mock( - return_value={"data": {"create_model_version_from_truss": {"id": "12345"}}} + return_value={ + "data": { + "create_model_version_from_truss": {"model_version": {"id": "12345"}} + } + } ) return response @@ -49,7 +53,9 @@ def mock_create_model_response(): response = Response() response.status_code = 200 response.json = mock.Mock( - return_value={"data": {"create_model_from_truss": {"id": "12345"}}} + return_value={ + "data": {"create_model_from_truss": {"model_version": {"id": "12345"}}} + } ) return response @@ -58,7 +64,9 @@ def mock_create_development_model_response(): response = Response() response.status_code = 200 response.json = mock.Mock( - return_value={"data": {"deploy_draft_truss": {"id": "12345"}}} + return_value={ + "data": {"deploy_draft_truss": {"model_version": {"id": "12345"}}} + } ) return response @@ -70,10 +78,7 @@ def mock_deploy_chain_deployment_response(): return_value={ "data": { "deploy_chain_atomic": { - "chain_id": "12345", - "chain_deployment_id": "54321", - "entrypoint_model_id": "67890", - "entrypoint_model_version_id": "09876", + "chain_deployment": {"id": "54321", "chain": {"id": "12345"}} } } } diff --git a/truss/tests/remote/baseten/test_core.py b/truss/tests/remote/baseten/test_core.py index bc60d4d4d..960494d98 100644 --- a/truss/tests/remote/baseten/test_core.py +++ b/truss/tests/remote/baseten/test_core.py @@ -86,9 +86,12 @@ def test_get_prod_version_from_versions_error(): @pytest.mark.parametrize("environment", [None, PRODUCTION_ENVIRONMENT_NAME]) def test_create_truss_service_handles_eligible_environment_values(environment): api = MagicMock() - return_value = {"id": "id", "version_id": "model_version_id"} + return_value = { + "id": "model_version_id", + "oracle": {"id": "model_id", "hostname": "hostname"}, + } api.create_model_from_truss.return_value = return_value - model_id, model_version_id = create_truss_service( + version_handle = create_truss_service( api, "model_name", "s3_key", @@ -100,17 +103,20 @@ def test_create_truss_service_handles_eligible_environment_values(environment): deployment_name="deployment_name", environment=environment, ) - assert model_id == return_value["id"] - assert model_version_id == return_value["version_id"] + assert version_handle.version_id == "model_version_id" + assert version_handle.model_id == "model_id" api.create_model_from_truss.assert_called_once() @pytest.mark.parametrize("model_id", ["some_model_id", None]) def test_create_truss_services_handles_is_draft(model_id): api = MagicMock() - return_value = {"id": "id", "version_id": "model_version_id"} + return_value = { + "id": "model_version_id", + "oracle": {"id": "model_id", "hostname": "hostname"}, + } api.create_development_model_from_truss.return_value = return_value - model_id, model_version_id = create_truss_service( + version_handle = create_truss_service( api, "model_name", "s3_key", @@ -121,8 +127,8 @@ def test_create_truss_services_handles_is_draft(model_id): model_id=model_id, deployment_name="deployment_name", ) - assert model_id == return_value["id"] - assert model_version_id == return_value["version_id"] + assert version_handle.version_id == "model_version_id" + assert version_handle.model_id == "model_id" api.create_development_model_from_truss.assert_called_once() @@ -151,9 +157,12 @@ def test_create_truss_services_handles_is_draft(model_id): ) def test_create_truss_service_handles_existing_model(inputs): api = MagicMock() - return_value = {"id": "model_version_id"} + return_value = { + "id": "model_version_id", + "oracle": {"id": "model_id", "hostname": "hostname"}, + } api.create_model_version_from_truss.return_value = return_value - model_id, model_version_id = create_truss_service( + version_handle = create_truss_service( api, "model_name", "s3_key", @@ -163,8 +172,8 @@ def test_create_truss_service_handles_existing_model(inputs): **inputs, ) - assert model_id == "model_id" - assert model_version_id == return_value["id"] + assert version_handle.version_id == "model_version_id" + assert version_handle.model_id == "model_id" api.create_model_version_from_truss.assert_called_once() _, kwargs = api.create_model_version_from_truss.call_args for k, v in inputs.items(): @@ -177,12 +186,14 @@ def test_create_truss_service_handles_allow_truss_download_for_new_models( is_draft, allow_truss_download ): api = MagicMock() - return_value = {"id": "id", "version_id": "model_version_id"} + return_value = { + "id": "model_version_id", + "oracle": {"id": "model_id", "hostname": "hostname"}, + } api.create_model_from_truss.return_value = return_value api.create_development_model_from_truss.return_value = return_value - model_id = None - model_id, model_version_id = create_truss_service( + version_handle = create_truss_service( api, "model_name", "s3_key", @@ -190,12 +201,12 @@ def test_create_truss_service_handles_allow_truss_download_for_new_models( is_trusted=False, preserve_previous_prod_deployment=False, is_draft=is_draft, - model_id=model_id, + model_id=None, deployment_name="deployment_name", allow_truss_download=allow_truss_download, ) - assert model_id == return_value["id"] - assert model_version_id == return_value["version_id"] + assert version_handle.version_id == "model_version_id" + assert version_handle.model_id == "model_id" create_model_mock = ( api.create_development_model_from_truss diff --git a/truss/tests/remote/baseten/test_remote.py b/truss/tests/remote/baseten/test_remote.py index da992939c..7c33f6d89 100644 --- a/truss/tests/remote/baseten/test_remote.py +++ b/truss/tests/remote/baseten/test_remote.py @@ -34,7 +34,7 @@ def assert_request_matches_expected_query(request, expected_query) -> None: def test_get_service_by_version_id(): remote = BasetenRemote(_TEST_REMOTE_URL, "api_key") - version = {"id": "version_id", "oracle": {"id": "model_id"}} + version = {"id": "version_id", "oracle": {"id": "model_id", "hostname": "hostname"}} model_version_response = {"data": {"model_version": version}} with requests_mock.Mocker() as m: @@ -64,7 +64,12 @@ def test_get_service_by_model_name(): ] model_response = { "data": { - "model": {"name": "model_name", "id": "model_id", "versions": versions} + "model": { + "name": "model_name", + "id": "model_id", + "hostname": "hostname", + "versions": versions, + } } } @@ -92,7 +97,12 @@ def test_get_service_by_model_name_no_dev_version(): versions = [{"id": "1", "is_draft": False, "is_primary": True}] model_response = { "data": { - "model": {"name": "model_name", "id": "model_id", "versions": versions} + "model": { + "name": "model_name", + "id": "model_id", + "hostname": "hostname", + "versions": versions, + } } } @@ -120,7 +130,12 @@ def test_get_service_by_model_name_no_prod_version(): versions = [{"id": "1", "is_draft": True, "is_primary": False}] model_response = { "data": { - "model": {"name": "model_name", "id": "model_id", "versions": versions} + "model": { + "name": "model_name", + "id": "model_id", + "hostname": "hostname", + "versions": versions, + } } } @@ -149,6 +164,7 @@ def test_get_service_by_model_id(): "name": "model_name", "id": "model_id", "primary_version": {"id": "version_id"}, + "hostname": "hostname", } } } @@ -277,10 +293,13 @@ def test_create_chain_with_no_publish(): "json": { "data": { "deploy_chain_atomic": { - "chain_id": "new-chain-id", - "chain_deployment_id": "new-chain-deployment-id", - "entrypoint_model_id": "new-entrypoint-model-id", - "entrypoint_model_version_id": "new-entrypoint-model-version-id", + "chain_deployment": { + "id": "new-chain-deployment-id", + "chain": { + "id": "new-chain-id", + "hostname": "hostname", + }, + } } } } @@ -345,10 +364,13 @@ def test_create_chain_with_no_publish(): dependencies: [] client_version: "{truss.version()}" ) {{ - chain_id - chain_deployment_id - entrypoint_model_id - entrypoint_model_version_id + chain_deployment {{ + id + chain {{ + id + hostname + }} + }} }} }} """.strip() @@ -373,10 +395,13 @@ def test_create_chain_no_existing_chain(): "json": { "data": { "deploy_chain_atomic": { - "chain_id": "new-chain-id", - "chain_deployment_id": "new-chain-deployment-id", - "entrypoint_model_id": "new-entrypoint-model-id", - "entrypoint_model_version_id": "new-entrypoint-model-version-id", + "chain_deployment": { + "id": "new-chain-deployment-id", + "chain": { + "id": "new-chain-id", + "hostname": "hostname", + }, + } } } } @@ -439,10 +464,13 @@ def test_create_chain_no_existing_chain(): dependencies: [] client_version: "{truss.version()}" ) {{ - chain_id - chain_deployment_id - entrypoint_model_id - entrypoint_model_version_id + chain_deployment {{ + id + chain {{ + id + hostname + }} + }} }} }} """.strip() @@ -473,10 +501,13 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false() "json": { "data": { "deploy_chain_atomic": { - "chain_id": "new-chain-id", - "chain_deployment_id": "new-chain-deployment-id", - "entrypoint_model_id": "new-entrypoint-model-id", - "entrypoint_model_version_id": "new-entrypoint-model-version-id", + "chain_deployment": { + "id": "new-chain-deployment-id", + "chain": { + "id": "new-chain-id", + "hostname": "hostname", + }, + } } } } @@ -542,10 +573,13 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false() dependencies: [] client_version: "{truss.version()}" ) {{ - chain_id - chain_deployment_id - entrypoint_model_id - entrypoint_model_version_id + chain_deployment {{ + id + chain {{ + id + hostname + }} + }} }} }} """.strip() @@ -576,10 +610,13 @@ def test_create_chain_existing_chain_publish_true_no_promotion(): "json": { "data": { "deploy_chain_atomic": { - "chain_id": "new-chain-id", - "chain_deployment_id": "new-chain-deployment-id", - "entrypoint_model_id": "new-entrypoint-model-id", - "entrypoint_model_version_id": "new-entrypoint-model-version-id", + "chain_deployment": { + "id": "new-chain-deployment-id", + "chain": { + "id": "new-chain-id", + "hostname": "hostname", + }, + } } } } @@ -642,10 +679,13 @@ def test_create_chain_existing_chain_publish_true_no_promotion(): dependencies: [] client_version: "{truss.version()}" ) {{ - chain_id - chain_deployment_id - entrypoint_model_id - entrypoint_model_version_id + chain_deployment {{ + id + chain {{ + id + hostname + }} + }} }} }} """.strip() diff --git a/truss/tests/remote/baseten/test_service.py b/truss/tests/remote/baseten/test_service.py index 923e1c9ac..2abcdbadb 100644 --- a/truss/tests/remote/baseten/test_service.py +++ b/truss/tests/remote/baseten/test_service.py @@ -1,30 +1,42 @@ from truss.remote.baseten import service -def test_model_invocation_url_prod(): - url = service.URLConfig.invocation_url( - "https://api.baseten.co", service.URLConfig.MODEL, "123", "789", is_draft=False +def test_model_invoke_url_prod(): + url = service.URLConfig.invoke_url( + "https://model-123.api.baseten.co", + service.URLConfig.MODEL, + "789", + is_draft=False, ) assert url == "https://model-123.api.baseten.co/deployment/789/predict" -def test_model_invocation_url_draft(): - url = service.URLConfig.invocation_url( - "https://api.baseten.co", service.URLConfig.MODEL, "123", "789", is_draft=True +def test_model_invoke_url_draft(): + url = service.URLConfig.invoke_url( + "https://model-123.api.baseten.co", + service.URLConfig.MODEL, + "789", + is_draft=True, ) assert url == "https://model-123.api.baseten.co/development/predict" -def test_chain_invocation_url_prod(): - url = service.URLConfig.invocation_url( - "https://api.baseten.co", service.URLConfig.CHAIN, "abc", "666", is_draft=False +def test_chain_invoke_url_prod(): + url = service.URLConfig.invoke_url( + "https://chain-abc.api.baseten.co", + service.URLConfig.CHAIN, + "666", + is_draft=False, ) assert url == "https://chain-abc.api.baseten.co/deployment/666/run_remote" -def test_chain_invocation_url_draft(): - url = service.URLConfig.invocation_url( - "https://api.baseten.co", service.URLConfig.CHAIN, "abc", "666", is_draft=True +def test_chain_invoke_url_draft(): + url = service.URLConfig.invoke_url( + "https://chain-abc.api.baseten.co", + service.URLConfig.CHAIN, + "666", + is_draft=True, ) assert url == "https://chain-abc.api.baseten.co/development/run_remote" From 3ea049ff7f7f5469cd621dad3bcb7f48d237969e Mon Sep 17 00:00:00 2001 From: basetenbot <96544894+basetenbot@users.noreply.github.com> Date: Fri, 7 Feb 2025 21:15:23 +0000 Subject: [PATCH 3/3] Bump version to 0.9.60 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b8a5a17e3..8bf143259 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.60.0" +version = "0.9.60" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md"