diff --git a/pyproject.toml b/pyproject.toml index f463c1f9c..2f5634499 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.56rc2" +version = "0.9.56rc3" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py index aff9350f7..194adcec7 100644 --- a/truss/base/trt_llm_config.py +++ b/truss/base/trt_llm_config.py @@ -208,27 +208,28 @@ class TRTLLMConfiguration(BaseModel): def migrate_runtime_fields(cls, data: Any) -> Any: extra_runtime_fields = {} valid_build_fields = {} - for key, value in data.get("build").items(): - if key in TrussTRTLLMBuildConfiguration.__annotations__: - valid_build_fields[key] = value - else: - if key in TrussTRTLLMRuntimeConfiguration.__annotations__: - logger.warning(f"Found runtime.{key}: {value} in build config") - extra_runtime_fields[key] = value - if extra_runtime_fields: - logger.warning( - f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values." - " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config." - ) - data.get("runtime").update( - { - k: v - for k, v in extra_runtime_fields.items() - if k not in data.get("runtime") - } - ) - - data.update({"build": valid_build_fields}) + if isinstance(data.get("build"), dict): + for key, value in data.get("build").items(): + if key in TrussTRTLLMBuildConfiguration.__annotations__: + valid_build_fields[key] = value + else: + if key in TrussTRTLLMRuntimeConfiguration.__annotations__: + logger.warning(f"Found runtime.{key}: {value} in build config") + extra_runtime_fields[key] = value + if extra_runtime_fields: + logger.warning( + f"Found extra fields {list(extra_runtime_fields.keys())} in build configuration, unspecified runtime fields will be configured using these values." + " This configuration of deprecated fields is scheduled for removal, please upgrade to the latest truss version and update configs according to https://docs.baseten.co/performance/engine-builder-config." + ) + data.get("runtime").update( + { + k: v + for k, v in extra_runtime_fields.items() + if k not in data.get("runtime") + } + ) + data.update({"build": valid_build_fields}) + return data return data @property diff --git a/truss/tests/trt_llm/test_trt_llm_config.py b/truss/tests/trt_llm/test_trt_llm_config.py index 0c4daf930..e6a0578d1 100644 --- a/truss/tests/trt_llm/test_trt_llm_config.py +++ b/truss/tests/trt_llm/test_trt_llm_config.py @@ -1,9 +1,16 @@ from truss.base.trt_llm_config import ( TRTLLMConfiguration, TrussTRTLLMBatchSchedulerPolicy, + TrussTRTLLMBuildConfiguration, + TrussTRTLLMRuntimeConfiguration, ) +def test_trt_llm_config_init_from_pydantic_models(trtllm_config): + build_config = TrussTRTLLMBuildConfiguration(**trtllm_config["trt_llm"]["build"]) + TRTLLMConfiguration(build=build_config, runtime=TrussTRTLLMRuntimeConfiguration()) + + def test_trt_llm_configuration_init_and_migrate_deprecated_runtime_fields( deprecated_trtllm_config, ):