Skip to content

Commit

Permalink
Support promote to --environment in chains (basetenlabs#1180)
Browse files Browse the repository at this point in the history
* first version of --environment

* refactor promote support

* couple more fixes

* add tests

* couple fixes

* clean up env resource not found error handling

* simpler url replacement
  • Loading branch information
spal1 authored Oct 14, 2024
1 parent 2dd8237 commit 7c94c48
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 59 deletions.
12 changes: 9 additions & 3 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import pydantic
from truss import truss_config
from truss.constants import PRODUCTION_ENVIRONMENT_NAME
from truss.remote import baseten as baseten_remote
from truss.remote import remote_cli, remote_factory

Expand Down Expand Up @@ -609,22 +610,27 @@ class PushOptions(SafeModelNonSerializable):
class PushOptionsBaseten(PushOptions):
remote_provider: baseten_remote.BasetenRemote
publish: bool
promote: bool
environment: Optional[str]

@classmethod
def create(
cls,
chain_name: str,
publish: bool,
promote: bool,
promote: Optional[bool],
only_generate_trusses: bool,
user_env: Mapping[str, str],
remote: Optional[str] = None,
environment: Optional[str] = None,
) -> "PushOptionsBaseten":
if not remote:
remote = remote_cli.inquire_remote_name(
remote_factory.RemoteFactory.get_available_config_names()
)
if promote and not environment:
environment = PRODUCTION_ENVIRONMENT_NAME
if environment:
publish = True
remote_provider = cast(
baseten_remote.BasetenRemote,
remote_factory.RemoteFactory.create(remote=remote),
Expand All @@ -633,9 +639,9 @@ def create(
remote_provider=remote_provider,
chain_name=chain_name,
publish=publish,
promote=promote,
only_generate_trusses=only_generate_trusses,
user_env=user_env,
environment=environment,
)


Expand Down
3 changes: 3 additions & 0 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def push(
user_env: Optional[Mapping[str, str]] = None,
only_generate_trusses: bool = False,
remote: Optional[str] = None,
environment: Optional[str] = None,
) -> chains_remote.BasetenChainService:
"""
Deploys a chain remotely (with all dependent chainlets).
Expand All @@ -144,6 +145,7 @@ def push(
``/tmp/.chains_generated``.
remote: name of a remote config in `.trussrc`. If not provided, it will be
inquired.
environment: The name of an environment to promote deployment into.
Returns:
A chain service handle to the deployed chain.
Expand All @@ -156,6 +158,7 @@ def push(
user_env=user_env or {},
only_generate_trusses=only_generate_trusses,
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint, options)
assert isinstance(service, chains_remote.BasetenChainService) # Per options above.
Expand Down
8 changes: 2 additions & 6 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,15 @@ def _push_to_baseten(
model_name = truss_handle.spec.config.model_name
assert model_name is not None
assert bool(_MODEL_NAME_RE.match(model_name))
if options.promote and not options.publish:
logging.info("`promote=True` overrides `publish` to `True`.")
logging.info(
f"Pushing chainlet `{model_name}` as a truss model on Baseten "
f"(publish={options.publish}, promote={options.promote})."
f"Pushing chainlet `{model_name}` as a truss model on Baseten (publish={options.publish})"
)
# Models must be trusted to use the API KEY secret.
service = options.remote_provider.push(
truss_handle,
model_name=model_name,
trusted=True,
publish=options.publish,
promote=options.promote,
origin=b10_types.ModelOrigin.CHAINS,
)
return cast(b10_service.BasetenService, service)
Expand Down Expand Up @@ -327,7 +323,7 @@ def _create_baseten_chain(
chain_name=baseten_options.chain_name,
chainlets=chainlet_data,
publish=baseten_options.publish,
promote=baseten_options.promote,
environment=baseten_options.environment,
)
return BasetenChainService(
baseten_options.chain_name,
Expand Down
32 changes: 29 additions & 3 deletions truss/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,13 @@ def chains():
"""Subcommands for truss chains"""


def _make_chains_curl_snippet(run_remote_url: str) -> str:
def _make_chains_curl_snippet(run_remote_url: str, environment: Optional[str]) -> str:
if environment:
idx = run_remote_url.find("deployment")
if idx != -1:
run_remote_url = (
run_remote_url[:idx] + f"environments/{environment}/run_remote"
)
return (
f"curl -X POST '{run_remote_url}' \\\n"
' -H "Authorization: Api-Key $BASETEN_API_KEY" \\\n'
Expand Down Expand Up @@ -505,6 +511,15 @@ def _create_chains_table(service) -> Tuple[rich.table.Table, List[str]]:
default=False,
help="Replace production chainlets with newly deployed chainlets.",
)
@click.option(
"--environment",
type=str,
required=False,
help=(
"Deploy the chain as a published deployment to the specified environment."
"If specified, --publish is implied and the supplied value of --promote will be ignored."
),
)
@click.option(
"--wait/--no-wait",
type=bool,
Expand Down Expand Up @@ -557,6 +572,7 @@ def push_chain(
dryrun: bool,
user_env: Optional[str],
remote: Optional[str],
environment: Optional[str],
) -> None:
"""
Deploys a chain remotely.
Expand Down Expand Up @@ -597,6 +613,10 @@ def push_chain(
else:
user_env_parsed = {}

if promote and environment:
promote_warning = "`promote` flag and `environment` flag were both specified. Ignoring the value of `promote`"
console.print(promote_warning, style="yellow")

with framework.import_target(source, entrypoint) as entrypoint_cls:
chain_name = name or entrypoint_cls.__name__
options = chains_def.PushOptionsBaseten.create(
Expand All @@ -606,6 +626,7 @@ def push_chain(
only_generate_trusses=dryrun,
user_env=user_env_parsed,
remote=remote,
environment=environment,
)
service = chains_remote.push(entrypoint_cls, options)

Expand All @@ -614,7 +635,9 @@ def push_chain(
return

assert isinstance(service, chains_remote.BasetenChainService)
curl_snippet = _make_chains_curl_snippet(service.run_remote_url)
curl_snippet = _make_chains_curl_snippet(
service.run_remote_url, options.environment
)

table, statuses = _create_chains_table(service)
status_check_wait_sec = 2
Expand Down Expand Up @@ -647,7 +670,10 @@ def push_chain(
for log in intercepted_logs:
console.print(f"\t{log}")
if success:
console.print("Deployment succeeded.", style="bold green")
deploy_success_text = "Deployment succeeded."
if environment:
deploy_success_text = f"Your chain has been deployed into the {options.environment} environment."
console.print(deploy_success_text, style="bold green")
console.print(f"You can run the chain with:\n{curl_snippet}")
if watch: # Note that this command will print a startup message.
chains_remote.watch(
Expand Down
21 changes: 12 additions & 9 deletions truss/remote/baseten/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,25 @@ def deploy_draft_chain(
return resp["data"]["deploy_draft_chain"]

def deploy_chain_deployment(
self, chain_id: str, chainlet_data: List[b10_types.ChainletData], promote: bool
self,
chain_id: str,
chainlet_data: List[b10_types.ChainletData],
environment: Optional[str] = None,
):
chainlet_data_strings = [
_chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data
]
chainlets_string = ", ".join(chainlet_data_strings)
query_string = f"""
mutation {{
deploy_chain_deployment(
chain_id: "{chain_id}",
chainlets: [{chainlets_string}],
promote_after_deploy: {'true' if promote else 'false'},
) {{
chain_id
chain_deployment_id
}}
deploy_chain_deployment(
chain_id: "{chain_id}",
chainlets: [{chainlets_string}],
{f'environment_name: "{environment}"' if environment else ""}
) {{
chain_id
chain_deployment_id
}}
}}
"""
resp = self._post_graphql_query(query_string)
Expand Down
21 changes: 18 additions & 3 deletions truss/remote/baseten/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def create_chain(
chain_name: str,
chainlets: List[b10_types.ChainletData],
is_draft: bool,
promote: bool,
environment: Optional[str],
) -> ChainDeploymentHandle:
if is_draft:
response = api.deploy_draft_chain(chain_name, chainlets)
Expand All @@ -93,8 +93,20 @@ def create_chain(
# if there is no chain already, the first deployment will
# already be production, and only published deployments can
# be promoted.
response = api.deploy_chain_deployment(chain_id, chainlets, promote)
try:
response = api.deploy_chain_deployment(chain_id, chainlets, environment)
except ApiError as e:
if (
e.graphql_error_code
== BasetenApi.GraphQLErrorCodes.RESOURCE_NOT_FOUND.value
):
raise ValueError(
f'Environment "{environment}" does not exist. You can create environments in the Chains UI.'
) from e
raise e
else:
if environment and environment != PRODUCTION_ENVIRONMENT_NAME:
raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING)
response = api.deploy_chain(chain_name, chainlets)

return ChainDeploymentHandle(
Expand Down Expand Up @@ -299,7 +311,10 @@ def create_truss_service(
environment=environment,
)
except ApiError as e:
if "Environment matching query does not exist" in e.message:
if (
e.graphql_error_code
== BasetenApi.GraphQLErrorCodes.RESOURCE_NOT_FOUND.value
):
raise ValueError(
f'Environment "{environment}" does not exist. You can create environments in the Baseten UI.'
) from e
Expand Down
8 changes: 4 additions & 4 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def create_chain(
chain_name: str,
chainlets: List[custom_types.ChainletData],
publish: bool = False,
promote: bool = False,
environment: Optional[str] = None,
) -> ChainDeploymentHandle:
if promote:
# If we are promoting a model after deploy, it must be published.
if environment:
# If we are promoting a model to an environment after deploy, it must be published.
# Draft models cannot be promoted.
publish = True
# Returns tuple of (chain_id, chain_deployment_id)
Expand All @@ -81,7 +81,7 @@ def create_chain(
chain_name=chain_name,
chainlets=chainlets,
is_draft=not publish,
promote=promote,
environment=environment,
)

def get_chainlets(
Expand Down
56 changes: 56 additions & 0 deletions truss/tests/remote/baseten/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import requests
from requests import Response
from truss.remote.baseten.api import BasetenApi
from truss.remote.baseten.custom_types import ChainletData
from truss.remote.baseten.error import ApiError


Expand Down Expand Up @@ -53,6 +54,22 @@ def mock_create_model_response():
return response


def mock_deploy_chain_deployment_response():
response = Response()
response.status_code = 200
response.json = mock.Mock(
return_value={
"data": {
"deploy_chain_deployment": {
"chain_id": "12345",
"chain_deployment_id": "54321",
}
}
}
)
return response


@pytest.fixture
def baseten_api(mock_auth_service):
return BasetenApi("https://app.test.com", mock_auth_service)
Expand Down Expand Up @@ -204,3 +221,42 @@ def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified(
assert 'client_version: "client_version"' in gql_mutation
assert "is_trusted: true" in gql_mutation
assert "version_name: " not in gql_mutation


@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
def test_deploy_chain_deployment(mock_post, baseten_api):
baseten_api.deploy_chain_deployment(
"chain_id",
[
ChainletData(
name="chainlet-1",
oracle_version_id="some-ov-id",
is_entrypoint=True,
)
],
"production",
)

gql_mutation = mock_post.call_args[1]["data"]["query"]
assert 'chain_id: "chain_id"' in gql_mutation
assert "chainlets:" in gql_mutation
assert 'environment_name: "production"' in gql_mutation


@mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response())
def test_deploy_chain_deployment_no_environment(mock_post, baseten_api):
baseten_api.deploy_chain_deployment(
"chain_id",
[
ChainletData(
name="chainlet-1",
oracle_version_id="some-ov-id",
is_entrypoint=True,
)
],
)

gql_mutation = mock_post.call_args[1]["data"]["query"]
assert 'chain_id: "chain_id"' in gql_mutation
assert "chainlets:" in gql_mutation
assert "environment_name" not in gql_mutation
Loading

0 comments on commit 7c94c48

Please sign in to comment.