Skip to content

Commit

Permalink
Merge pull request basetenlabs#1376 from basetenlabs/bump-version-0.9.60
Browse files Browse the repository at this point in the history
Release 0.9.60
  • Loading branch information
nnarayen authored Feb 10, 2025
2 parents 26671e1 + 3ea049f commit a0e2e11
Show file tree
Hide file tree
Showing 12 changed files with 298 additions and 196 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.60.0"
version = "0.9.60"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
14 changes: 8 additions & 6 deletions truss-chains/truss_chains/deployment/deployment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,14 @@ def __init__(
@property
def run_remote_url(self) -> str:
"""URL to invoke the entrypoint."""
return b10_service.URLConfig.invocation_url(
self._remote.api.rest_api_url,
b10_service.URLConfig.CHAIN,
self._chain_deployment_handle.chain_id,
self._chain_deployment_handle.chain_deployment_id,
self._chain_deployment_handle.is_draft,

handle = self._chain_deployment_handle

return b10_service.URLConfig.invoke_url(
hostname=handle.hostname,
config=b10_service.URLConfig.CHAIN,
entity_version_id=handle.chain_deployment_id,
is_draft=handle.is_draft,
)

def run_remote(self, json_data: Dict) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions truss/api/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class ModelDeployment:
_baseten_service: service.BasetenService

def __init__(self, service: service.BasetenService):
self.model_id = service._model_id
self.model_deployment_id = service._model_version_id
self.model_id = service.model_id
self.model_deployment_id = service.model_version_id
self._baseten_service = service

def wait_for_active(self, timeout_seconds: int = 600) -> bool:
Expand Down
128 changes: 75 additions & 53 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,27 +138,32 @@ def create_model_from_truss(
origin: Optional[b10_types.ModelOrigin] = None,
):
query_string = f"""
mutation {{
create_model_from_truss(
name: "{model_name}",
s3_key: "{s3_key}",
config: "{config}",
semver_bump: "{semver_bump}",
client_version: "{client_version}",
is_trusted: {"true" if is_trusted else "false"},
allow_truss_download: {"true" if allow_truss_download else "false"},
{f'version_name: "{deployment_name}"' if deployment_name else ""}
{f"model_origin: {origin.value}" if origin else ""}
) {{
id,
name,
version_id
mutation {{
create_model_from_truss(
name: "{model_name}"
s3_key: "{s3_key}"
config: "{config}"
semver_bump: "{semver_bump}"
client_version: "{client_version}"
is_trusted: {"true" if is_trusted else "false"}
allow_truss_download: {"true" if allow_truss_download else "false"}
{f'version_name: "{deployment_name}"' if deployment_name else ""}
{f"model_origin: {origin.value}" if origin else ""}
) {{
model_version {{
id
oracle {{
id
name
hostname
}}
}}
}}
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_from_truss"]
return resp["data"]["create_model_from_truss"]["model_version"]

def create_model_version_from_truss(
self,
Expand All @@ -173,25 +178,30 @@ def create_model_version_from_truss(
environment: Optional[str] = None,
):
query_string = f"""
mutation {{
create_model_version_from_truss(
model_id: "{model_id}"
s3_key: "{s3_key}",
config: "{config}",
semver_bump: "{semver_bump}",
client_version: "{client_version}",
is_trusted: {"true" if is_trusted else "false"},
scale_down_old_production: {"false" if preserve_previous_prod_deployment else "true"},
{f'name: "{deployment_name}"' if deployment_name else ""}
{f'environment_name: "{environment}"' if environment else ""}
) {{
id
mutation {{
create_model_version_from_truss(
model_id: "{model_id}"
s3_key: "{s3_key}"
config: "{config}"
semver_bump: "{semver_bump}"
client_version: "{client_version}"
is_trusted: {"true" if is_trusted else "false"}
scale_down_old_production: {"false" if preserve_previous_prod_deployment else "true"}
{f'name: "{deployment_name}"' if deployment_name else ""}
{f'environment_name: "{environment}"' if environment else ""}
) {{
model_version {{
id
oracle {{
hostname
}}
}}
}}
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["create_model_version_from_truss"]
return resp["data"]["create_model_version_from_truss"]["model_version"]

def create_development_model_from_truss(
self,
Expand All @@ -204,23 +214,29 @@ def create_development_model_from_truss(
origin: Optional[b10_types.ModelOrigin] = None,
):
query_string = f"""
mutation {{
deploy_draft_truss(name: "{model_name}",
s3_key: "{s3_key}",
config: "{config}",
client_version: "{client_version}",
is_trusted: {"true" if is_trusted else "false"},
allow_truss_download: {"true" if allow_truss_download else "false"},
mutation {{
deploy_draft_truss(name: "{model_name}"
s3_key: "{s3_key}"
config: "{config}"
client_version: "{client_version}"
is_trusted: {"true" if is_trusted else "false"}
allow_truss_download: {"true" if allow_truss_download else "false"}
{f"model_origin: {origin.value}" if origin else ""}
) {{
id,
name,
version_id
}}
}}
) {{
model_version {{
id
oracle {{
id
name
hostname
}}
}}
}}
}}
"""

resp = self._post_graphql_query(query_string)
return resp["data"]["deploy_draft_truss"]
return resp["data"]["deploy_draft_truss"]["model_version"]

def deploy_chain_atomic(
self,
Expand Down Expand Up @@ -251,10 +267,13 @@ def deploy_chain_atomic(
dependencies: [{dependencies_str}]
client_version: "{truss.version()}"
) {{
chain_id
chain_deployment_id
entrypoint_model_id
entrypoint_model_version_id
chain_deployment {{
id
chain {{
id
hostname
}}
}}
}}
}}
"""
Expand Down Expand Up @@ -385,9 +404,10 @@ def get_model(self, model_name):
query_string = f"""
{{
model(name: "{model_name}") {{
name
id
versions{{
name
hostname
versions {{
id
semver
truss_hash
Expand All @@ -408,8 +428,9 @@ def get_model_by_id(self, model_id: str):
query_string = f"""
{{
model(id: "{model_id}") {{
name
id
name
hostname
primary_version{{
id
semver
Expand Down Expand Up @@ -437,6 +458,7 @@ def get_model_version_by_id(self, model_version_id: str):
oracle{{
id
name
hostname
}}
}}
}}
Expand Down
47 changes: 33 additions & 14 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,16 @@ class TrussWatchState(NamedTuple):


class ChainDeploymentHandleAtomic(NamedTuple):
hostname: str
chain_id: str
chain_deployment_id: str
is_draft: bool
entrypoint_model_id: str
entrypoint_model_version_id: str


class ModelVersionHandle(NamedTuple):
version_id: str
model_id: str
hostname: str


def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]:
Expand Down Expand Up @@ -160,10 +165,9 @@ def create_chain_atomic(
)

return ChainDeploymentHandleAtomic(
chain_id=res["chain_id"],
chain_deployment_id=res["chain_deployment_id"],
entrypoint_model_id=res["entrypoint_model_id"],
entrypoint_model_version_id=res["entrypoint_model_version_id"],
chain_deployment_id=res["chain_deployment"]["id"],
chain_id=res["chain_deployment"]["chain"]["id"],
hostname=res["chain_deployment"]["chain"]["hostname"],
is_draft=is_draft,
)

Expand Down Expand Up @@ -193,9 +197,9 @@ def exists_model(api: BasetenApi, model_name: str) -> Optional[str]:
return model["model"]["id"]


def get_model_versions(api: BasetenApi, model_name: ModelName) -> Tuple[str, List]:
def get_model_and_versions(api: BasetenApi, model_name: ModelName) -> Tuple[dict, List]:
query_result = api.get_model(model_name.value)["model"]
return query_result["id"], query_result["versions"]
return query_result, query_result["versions"]


def get_dev_version_from_versions(versions: List[dict]) -> Optional[dict]:
Expand Down Expand Up @@ -334,7 +338,7 @@ def create_truss_service(
deployment_name: Optional[str] = None,
origin: Optional[b10_types.ModelOrigin] = None,
environment: Optional[str] = None,
) -> Tuple[str, str]:
) -> ModelVersionHandle:
"""
Create a model in the Baseten remote.
Expand All @@ -352,8 +356,9 @@ def create_truss_service(
development model.
Returns:
A tuple of the model ID and version ID
A Model Version handle.
"""

if is_draft:
model_version_json = api.create_development_model_from_truss(
model_name,
Expand All @@ -365,11 +370,16 @@ def create_truss_service(
origin=origin,
)

return model_version_json["id"], model_version_json["version_id"]
return ModelVersionHandle(
version_id=model_version_json["id"],
model_id=model_version_json["oracle"]["id"],
hostname=model_version_json["oracle"]["hostname"],
)

if model_id is None:
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)

model_version_json = api.create_model_from_truss(
model_name=model_name,
s3_key=s3_key,
Expand All @@ -381,7 +391,12 @@ def create_truss_service(
deployment_name=deployment_name,
origin=origin,
)
return model_version_json["id"], model_version_json["version_id"]

return ModelVersionHandle(
version_id=model_version_json["id"],
model_id=model_version_json["oracle"]["id"],
hostname=model_version_json["oracle"]["hostname"],
)

try:
model_version_json = api.create_model_version_from_truss(
Expand All @@ -404,8 +419,12 @@ def create_truss_service(
f'Environment "{environment}" does not exist. You can create environments in the Baseten UI.'
) from e
raise e
model_version_id = model_version_json["id"]
return model_id, model_version_id

return ModelVersionHandle(
version_id=model_version_json["id"],
model_id=model_id,
hostname=model_version_json["oracle"]["hostname"],
)


def validate_truss_config(api: BasetenApi, config: str):
Expand Down
1 change: 0 additions & 1 deletion truss/remote/baseten/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class DeployedChainlet(pydantic.BaseModel):
is_draft: bool
status: str
logs_url: str
oracle_predict_url: str
oracle_name: str


Expand Down
Loading

0 comments on commit a0e2e11

Please sign in to comment.