Skip to content

Commit

Permalink
Merge pull request basetenlabs#1373 from basetenlabs/bump-version-0.60.0
Browse files Browse the repository at this point in the history
Release 0.60.0
  • Loading branch information
nnarayen authored Feb 7, 2025
2 parents 8037619 + ef2e43b commit 26671e1
Show file tree
Hide file tree
Showing 51 changed files with 2,011 additions and 1,135 deletions.
808 changes: 418 additions & 390 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.59"
version = "0.60.0"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -24,6 +24,7 @@ requires-poetry = ">=2.0"

[tool.poetry.scripts]
truss = "truss.cli.cli:truss_cli"
truss-docker-build-setup = "truss.contexts.docker_build_setup:docker_build_setup"

[tool.poetry.urls]
"Homepage" = "https://truss.baseten.co"
Expand Down Expand Up @@ -165,7 +166,7 @@ numpy = ">=1.23.5"
opentelemetry-api = ">=1.25.0"
opentelemetry-exporter-otlp = ">=1.25.0"
opentelemetry-sdk = ">=1.25.0"
truss_transfer="0.0.1rc4"
truss_transfer="0.0.1"
uvicorn = ">=0.24.0"
uvloop = ">=0.17.0"

Expand Down
10 changes: 5 additions & 5 deletions smoketests/test_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from truss.remote.baseten.utils import status as status_utils

from truss_chains import definitions
from truss_chains.remote_chainlet import stub
from truss_chains.remote_chainlet import stub, utils

backend_env_domain = "staging.baseten.co"
BASETEN_API_KEY = os.environ["BASETEN_API_KEY_STAGING"]
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_itest_chain_publish(prepare) -> None:
# Test regular (JSON) invocation.
chain_stub = make_stub(url, definitions.RPCOptions(timeout_sec=10))
trace_parent = generate_traceparent()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
result = chain_stub.predict_sync({"length": 30, "num_partitions": 3})

expected = [
Expand All @@ -169,7 +169,7 @@ def test_itest_chain_publish(prepare) -> None:
invocation_times_sec = []
for i in range(10):
t0 = time.perf_counter()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
chain_stub.predict_sync({"length": 30, "num_partitions": 3})
invocation_times_sec.append(time.perf_counter() - t0)

Expand All @@ -182,7 +182,7 @@ def test_itest_chain_publish(prepare) -> None:
url, definitions.RPCOptions(timeout_sec=10, use_binary=True)
)
trace_parent = generate_traceparent()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
result = chain_stub_binary.predict_sync({"length": 30, "num_partitions": 3})

expected = [
Expand All @@ -198,7 +198,7 @@ def test_itest_chain_publish(prepare) -> None:
invocation_times_sec = []
for i in range(10):
t0 = time.perf_counter()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
chain_stub_binary.predict_sync({"length": 30, "num_partitions": 3})
invocation_times_sec.append(time.perf_counter() - t0)

Expand Down
7 changes: 7 additions & 0 deletions truss-chains/tests/import/model_without_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class ClassWithoutModelInheritance:
def __init__(self):
self._call_count = 0

async def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count
19 changes: 19 additions & 0 deletions truss-chains/tests/import/standalone_with_multiple_entrypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import truss_chains as chains


class FirstModel(chains.ModelBase):
def __init__(self):
self._call_count = 0

async def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count


class SecondModel(chains.ModelBase):
def __init__(self):
self._call_count = 0

async def predict(self, call_count_increment: int) -> int:
self._call_count += call_count_increment
return self._call_count
6 changes: 6 additions & 0 deletions truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import math

from user_package import shared_chainlet
from user_package.nested_package import io_types

import truss_chains as chains

logger = logging.getLogger(__name__)

IMAGE_BASETEN = chains.DockerImage(
base_image=chains.BasetenImage.PY310,
pip_requirements_file=chains.make_abs_path_here("requirements.txt"),
Expand Down Expand Up @@ -103,6 +106,7 @@ def __init__(
text_to_num: TextToNum = chains.depends(TextToNum),
context=chains.depends_context(),
) -> None:
logging.info("User log root during load.")
self._context = context
self._data_generator = data_generator
self._data_splitter = splitter
Expand All @@ -117,6 +121,8 @@ async def run_remote(
),
simple_default_arg: list[str] = ["a", "b"],
) -> tuple[int, str, int, shared_chainlet.SplitTextOutput, list[str]]:
logging.info("User log root.")
logger.info("User log module.")
data = self._data_generator.run_remote(length)
text_parts, number, items = await self._data_splitter.run_remote(
io_types.SplitTextInput(
Expand Down
36 changes: 25 additions & 11 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
def test_chain():
with ensure_kill_all():
chain_root = TEST_ROOT / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "ItestChain"
) as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
Expand All @@ -37,7 +39,7 @@ def test_chain():
response = requests.post(
url,
json={"length": 30, "num_partitions": 3},
headers={"traceparent": "TEST TEST TEST"},
headers={"traceparent": "TRACE_ID"},
)
print(response.content)
assert response.status_code == 200
Expand Down Expand Up @@ -70,7 +72,10 @@ def test_chain():

# Test with errors.
response = requests.post(
url, json={"length": 300, "num_partitions": 3}, stream=True
url,
json={"length": 300, "num_partitions": 3},
stream=True,
headers={"traceparent": "TRACE_ID"},
)
print(response)
assert response.status_code == 500
Expand All @@ -86,12 +91,12 @@ def test_chain():
File \".*?/itest_chain\.py\", line \d+, in _accumulate_parts
value \+= self\._text_to_num\.run_remote\(part\)
ValueError: \(showing chained remote errors, root error at the bottom\)
├─ Error in dependency Chainlet `TextToNum` \(HTTP status 500\):
├─ Error calling dependency Chainlet `TextToNum`, HTTP status=500, trace ID=`TRACE_ID`.
│ Chainlet-Traceback \(most recent call last\):
│ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ generated_text = self\._replicator\.run_remote\(data\)
│ ValueError: \(showing chained remote errors, root error at the bottom\)
│ ├─ Error in dependency Chainlet `TextReplicator` \(HTTP status 500\):
│ ├─ Error calling dependency Chainlet `TextReplicator`, HTTP status=500, trace ID=`TRACE_ID`.
│ │ Chainlet-Traceback \(most recent call last\):
│ │ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ │ validate_data\(data\)
Expand All @@ -106,7 +111,9 @@ def test_chain():
@pytest.mark.asyncio
async def test_chain_local():
chain_root = TEST_ROOT / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "ItestChain"
) as entrypoint:
with public_api.run_local():
with pytest.raises(ValueError):
# First time `SplitTextFailOnce` raises an error and
Expand Down Expand Up @@ -140,7 +147,9 @@ def test_streaming_chain():
with ensure_kill_all():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "Consumer"
) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand Down Expand Up @@ -176,7 +185,7 @@ def test_streaming_chain():
async def test_streaming_chain_local():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with framework.ChainletImporter.import_target(chain_root, "Consumer") as entrypoint:
with public_api.run_local():
result = await entrypoint().run_remote(cause_error=False)
print(result)
Expand All @@ -198,7 +207,7 @@ def test_numpy_chain(mode):
target = "HostBinary"
with ensure_kill_all():
chain_root = TEST_ROOT / "numpy_and_binary" / "chain.py"
with framework.import_target(chain_root, target) as entrypoint:
with framework.ChainletImporter.import_target(chain_root, target) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand All @@ -213,11 +222,14 @@ def test_numpy_chain(mode):
print(response.json())


@pytest.mark.integration
@pytest.mark.asyncio
async def test_timeout():
with ensure_kill_all():
chain_root = TEST_ROOT / "timeout" / "timeout_chain.py"
with framework.import_target(chain_root, "TimeoutChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "TimeoutChain"
) as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
Expand Down Expand Up @@ -284,7 +296,9 @@ def test_traditional_truss():
def test_custom_health_checks_chain():
with ensure_kill_all():
chain_root = TEST_ROOT / "custom_health_checks" / "custom_health_checks.py"
with framework.import_target(chain_root, "CustomHealthChecks") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "CustomHealthChecks"
) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand Down
18 changes: 18 additions & 0 deletions truss-chains/tests/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import contextlib
import logging
import pathlib
import re
from typing import AsyncIterator, Iterator, List

Expand All @@ -12,6 +13,7 @@

utils.setup_dev_logging(logging.DEBUG)

TEST_ROOT = pathlib.Path(__file__).parent.resolve()

# Assert that naive chainlet initialization is detected and prevented. #################

Expand Down Expand Up @@ -668,3 +670,19 @@ def is_healthy(self) -> str: # type: ignore[misc]

async def run_remote(self) -> str:
return ""


def test_import_model_requires_entrypoint():
model_src = TEST_ROOT / "import" / "model_without_inheritance.py"
match = r"No Model class in `.+` inherits from"
with pytest.raises(ValueError, match=match), _raise_errors():
with framework.ModelImporter.import_target(model_src):
pass


def test_import_model_requires_single_entrypoint():
model_src = TEST_ROOT / "import" / "standalone_with_multiple_entrypoints.py"
match = r"Multiple Model classes in `.+` inherit from"
with pytest.raises(ValueError, match=match), _raise_errors():
with framework.ModelImporter.import_target(model_src):
pass
3 changes: 1 addition & 2 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
RemoteErrorDetail,
RPCOptions,
)
from truss_chains.framework import ChainletBase, ModelBase
from truss_chains.public_api import (
ChainletBase,
ModelBase,
depends,
depends_context,
mark_entrypoint,
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def migrate_fields(cls, values):
class ComputeSpec(pydantic.BaseModel):
"""Parsed and validated compute. See ``Compute`` for more information."""

# TODO[rcano] add node count
cpu_count: int = 1
predict_concurrency: int = 1
memory: str = "2Gi"
Expand Down
8 changes: 3 additions & 5 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
f"request: starlette.requests.Request) -> {output_type_name}:"
)
# Add error handling context manager:
parts.append(
_indent("with stub.trace_parent(request), utils.exception_to_http_error():")
)
parts.append(_indent("with utils.predict_context(request):"))
# Invoke Chainlet.
if (
chainlet_descriptor.endpoint.is_async
Expand Down Expand Up @@ -733,7 +731,7 @@ def gen_truss_model_from_source(
# TODO(nikhil): Improve detection of directory structure, since right now
# we assume a flat structure
root_dir = model_src.absolute().parent
with framework.import_target(model_src) as entrypoint_cls:
with framework.ModelImporter.import_target(model_src) as entrypoint_cls:
descriptor = framework.get_descriptor(entrypoint_cls)
return gen_truss_model(
model_root=root_dir,
Expand Down Expand Up @@ -773,7 +771,7 @@ def gen_truss_chainlet(
gen_root = pathlib.Path(tempfile.gettempdir())
chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
logging.info(
f"Code generation for Chainlet `{chainlet_descriptor.name}` "
f"Code generation for {chainlet_descriptor.chainlet_cls.entity_type} `{chainlet_descriptor.name}` "
f"in `{chainlet_dir}`."
)
_write_truss_config_yaml(
Expand Down
Loading

0 comments on commit 26671e1

Please sign in to comment.