Skip to content

Commit

Permalink
fix initialization from pydantic models (basetenlabs#1281)
Browse files Browse the repository at this point in the history
* fix initialization from pydantic models

* 0.9.56rc3
  • Loading branch information
joostinyi authored Dec 11, 2024
1 parent 90169ed commit 80a965b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.56rc2"
version = "0.9.56rc3"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
43 changes: 22 additions & 21 deletions truss/base/trt_llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,27 +208,28 @@ class TRTLLMConfiguration(BaseModel):
def migrate_runtime_fields(cls, data: Any) -> Any:
extra_runtime_fields = {}
valid_build_fields = {}
for key, value in data.get("build").items():
if key in TrussTRTLLMBuildConfiguration.__annotations__:
valid_build_fields[key] = value
else:
if key in TrussTRTLLMRuntimeConfiguration.__annotations__:
logger.warning(f"Found runtime.{key}: {value} in build config")
extra_runtime_fields[key] = value
if extra_runtime_fields:
logger.warning(
f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
" This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
)
data.get("runtime").update(
{
k: v
for k, v in extra_runtime_fields.items()
if k not in data.get("runtime")
}
)

data.update({"build": valid_build_fields})
if isinstance(data.get("build"), dict):
for key, value in data.get("build").items():
if key in TrussTRTLLMBuildConfiguration.__annotations__:
valid_build_fields[key] = value
else:
if key in TrussTRTLLMRuntimeConfiguration.__annotations__:
logger.warning(f"Found runtime.{key}: {value} in build config")
extra_runtime_fields[key] = value
if extra_runtime_fields:
logger.warning(
f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values."
" This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config."
)
data.get("runtime").update(
{
k: v
for k, v in extra_runtime_fields.items()
if k not in data.get("runtime")
}
)
data.update({"build": valid_build_fields})
return data
return data

@property
Expand Down
7 changes: 7 additions & 0 deletions truss/tests/trt_llm/test_trt_llm_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from truss.base.trt_llm_config import (
TRTLLMConfiguration,
TrussTRTLLMBatchSchedulerPolicy,
TrussTRTLLMBuildConfiguration,
TrussTRTLLMRuntimeConfiguration,
)


def test_trt_llm_config_init_from_pydantic_models(trtllm_config):
build_config = TrussTRTLLMBuildConfiguration(**trtllm_config["trt_llm"]["build"])
TRTLLMConfiguration(build=build_config, runtime=TrussTRTLLMRuntimeConfiguration())


def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields(
deprecated_trtllm_config,
):
Expand Down

0 comments on commit 80a965b

Please sign in to comment.