Skip to content

Commit

Permalink
Update Chains Docs. (basetenlabs#1282)
Browse files Browse the repository at this point in the history
* Update Chains Docs.

* Fix imblanced braces

* Fix docstring
  • Loading branch information
marius-baseten authored Dec 12, 2024
1 parent 80a965b commit b34acc8
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 241 deletions.
151 changes: 103 additions & 48 deletions docs/chains/doc_gen/API-reference.mdx

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/chains/doc_gen/generate_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
NON_PUBLIC_SYMBOLS = [
# "truss_chains.definitions.AssetSpec",
# "truss_chains.definitions.ComputeSpec",
"truss_chains.remote.ChainService",
"truss_chains.deployment.deployment_client.ChainService",
"truss_chains.definitions.Environment",
]

Expand Down Expand Up @@ -69,7 +69,7 @@
"General framework and helper functions.",
[
"truss_chains.push",
"truss_chains.remote.ChainService",
"truss_chains.deployment.deployment_client.ChainService",
"truss_chains.make_abs_path_here",
"truss_chains.run_local",
"truss_chains.DeployedServiceDescriptor",
Expand Down
108 changes: 76 additions & 32 deletions docs/chains/doc_gen/generated-reference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ https://github.com/basetenlabs/truss/tree/main/docs/chains/doc_gen

APIs for creating user-defined Chainlets.


### *class* `truss_chains.ChainletBase`

Base class for all chainlets.
Expand Down Expand Up @@ -41,8 +42,9 @@ chainlet instance is provided.
| Name | Type | Description |
|------|------|-------------|
| `chainlet_cls` | *Type[ChainletT]* | The chainlet class of the dependency. |
| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). |
| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). For streaming, retries are only made if the request fails before streaming any results back. Failures mid-stream not retried. |
| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. |
| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. |

* **Returns:**
A “symbolic marker” to be used as a default argument in a chainlet’s
Expand All @@ -69,6 +71,7 @@ context instance is provided.
* **Return type:**
[*DeploymentContext*](#truss_chains.DeploymentContext)


### *class* `truss_chains.DeploymentContext`

Bases: `pydantic.BaseModel`
Expand All @@ -85,12 +88,11 @@ an access token for downloading model weights).
| Name | Type | Description |
|------|------|-------------|
| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. |
| `user_config` | ** | User-defined configuration for the chainlet. |
| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. |
| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. |
| `secrets` | *MappingNoIter[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. |
| `environment` | *[Environment](#truss_chains.definitions.Environment* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. |

#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#truss_chains.ServiceDescriptor)]*
#### chainlet_to_service *: Mapping[str, [DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor)]*

#### data_dir *: Path | None*

Expand All @@ -106,10 +108,11 @@ an access token for downloading model weights).
* **Parameters:**
**chainlet_name** (*str*)
* **Return type:**
[*ServiceDescriptor*](#truss_chains.ServiceDescriptor)
[*DeployedServiceDescriptor*](#truss_chains.DeployedServiceDescriptor)

#### secrets *: MappingNoIter[str, str]*


### *class* `truss_chains.definitions.Environment`

Bases: `pydantic.BaseModel`
Expand All @@ -120,6 +123,7 @@ The environment the chainlet is deployed in.
**name** (*str*) – The name of the environment.
#### name *: str*


### *class* `truss_chains.ChainletOptions`

Bases: `pydantic.BaseModel`
Expand All @@ -136,24 +140,28 @@ Bases: `pydantic.BaseModel`

#### env_variables *: Mapping[str, str]*


### *class* `truss_chains.RPCOptions`

Bases: `pydantic.BaseModel`

Options to customize RPCs to dependency chainlets.


**Parameters:**

| Name | Type | Description |
|------|------|-------------|
| `timeout_sec` | *int* | |
| `retries` | *int* | |

| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). For streaming, retries are only made if the request fails before streaming any results back. Failures mid-stream not retried. |
| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. |
| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. |

#### retries *: int*

#### timeout_sec *: int*

#### use_binary *: bool*

### `truss_chains.mark_entrypoint`

Decorator to mark a chainlet as the entrypoint of a chain.
Expand Down Expand Up @@ -181,6 +189,7 @@ class MyChainlet(ChainletBase):

These data structures specify for each chainlet how it gets deployed remotely, e.g. dependencies and compute resources.


### *class* `truss_chains.RemoteConfig`

Bases: `pydantic.BaseModel`
Expand Down Expand Up @@ -234,6 +243,7 @@ class MyChainlet(chains.ChainletBase):

#### options *: [ChainletOptions](#truss_chains.ChainletOptions)*


### *class* `truss_chains.DockerImage`

Bases: `pydantic.BaseModel`
Expand Down Expand Up @@ -266,10 +276,13 @@ modules and keep their requirement files right next their python source files.

#### external_package_dirs *: list[AbsPath] | None*

#### *classmethod* migrate_fields(values)

#### pip_requirements *: list[str]*

#### pip_requirements_file *: AbsPath | None*


### *class* `truss_chains.BasetenImage`

Bases: `Enum`
Expand All @@ -283,6 +296,7 @@ uses GPUs, drivers will be included in the image.

#### PY39 *= 'py39'*


### *class* `truss_chains.CustomImage`

Bases: `pydantic.BaseModel`
Expand All @@ -304,6 +318,7 @@ Configures the usage of a custom image hosted on dockerhub.

#### python_executable_path *: str | None*


### *class* `truss_chains.Compute`

Specifies which compute resources a chainlet has in the *remote* deployment.
Expand Down Expand Up @@ -342,6 +357,7 @@ two ways:
* **Return type:**
*ComputeSpec*


### *class* `truss_chains.Assets`

Specifies which assets a chainlet can access in the remote deployment.
Expand All @@ -355,7 +371,7 @@ from truss.base import truss_config
mistral_cache = truss_config.ModelRepo(
repo_id="mistralai/Mistral-7B-Instruct-v0.2",
allow_patterns=["*.json", "*.safetensors", ".model"]
)
)
chains.Assets(cached=[mistral_cache], ...)
```

Expand Down Expand Up @@ -402,15 +418,17 @@ Deploys a chain remotely (with all dependent chainlets).
| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) |
| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). |
| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. |
| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. |
| `remote` | *str* | name of a remote config in .trussrc. If not provided, it will be inquired. |
| `environment` | *str\|None* | The name of an environment to promote deployment into. |
| `progress_bar` | *Type[progress.Progress]\|None* | Optional rich.progress.Progress if output is desired. |

* **Returns:**
A chain service handle to the deployed chain.
* **Return type:**
*BasetenChainService*

### *class* `truss_chains.remote.ChainService`

### *class* `truss_chains.deployment.deployment_client.ChainService`

Bases: `ABC`

Expand Down Expand Up @@ -535,7 +553,7 @@ corresponding fields of `DeploymentContext`.
|------|------|-------------|
| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. |
| `data_dir` | *Path\|str\|None* | Path to a directory with data files. |
| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A dict of chainlet names to service descriptors. |
| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | A dict of chainlet names to service descriptors. |

* **Return type:**
*ContextManager*[None]
Expand All @@ -555,8 +573,9 @@ if __name__ == "__main__":
with chains.run_local(
secrets={"some_token": os.environ["SOME_TOKEN"]},
chainlet_to_service={
"SomeChainlet": chains.ServiceDescriptor(
"SomeChainlet": chains.DeployedServiceDescriptor(
name="SomeChainlet",
display_name="SomeChainlet",
predict_url="https://...",
options=chains.RPCOptions(),
)
Expand All @@ -571,52 +590,61 @@ if __name__ == "__main__":
Refer to the [local debugging guide](https://docs.baseten.co/chains/guide#test-a-chain-locally)
for more details.

### *class* `truss_chains.ServiceDescriptor`

Bases: `pydantic.BaseModel`
### *class* `truss_chains.DeployedServiceDescriptor`

Bundles values to establish an RPC session to a dependency chainlet,
specifically with `StubBase`.
Bases: `ServiceDescriptor`

**Parameters:**

| Name | Type | Description |
|------|------|-------------|
| `name` | *str* | |
| `predict_url` | *str* | |
| `display_name` | *str* | |
| `options` | *[RPCOptions](#truss_chains.RPCOptions* | |
| `predict_url` | *str* | |


#### name *: str*

#### options *: [RPCOptions](#truss_chains.RPCOptions)*

#### predict_url *: str*


### *class* `truss_chains.StubBase`

Bases: `ABC`
Bases: `BasetenSession`, `ABC`

Base class for stubs that invoke remote chainlets.

Extends `BasetenSession` with methods for data serialization, de-serialization
and invoking other endpoints.

It is used internally for RPCs to dependency chainlets, but it can also be used
in user-code for wrapping a deployed truss model into the chains framework, e.g.
like that:
in user-code for wrapping a deployed truss model into the Chains framework. It
flexibly supports JSON and pydantic inputs and output. Example usage:

```default
import pydantic
import truss_chains as chains
class WhisperOutput(pydantic.BaseModel):
...
class DeployedWhisper(chains.StubBase):
# Input JSON, output JSON.
async def run_remote(self, audio_b64: str) -> Any:
return await self.predict_async(
inputs={"audio": audio_b64})
# resp == {"text": ..., "language": ...}
# OR Input JSON, output pydantic model.
async def run_remote(self, audio_b64: str) -> WhisperOutput:
resp = await self._remote.predict_async(
json_payload={"audio": audio_b64})
return WhisperOutput(text=resp["text"], language=resp["language"])
return await self.predict_async(
inputs={"audio": audio_b64}, output_model=WhisperOutput)
# OR Input and output are pydantic models.
async def run_remote(self, data: WhisperInput) -> WhisperOutput:
return await self.predict_async(data, output_model=WhisperOutput)
class MyChainlet(chains.ChainletBase):
Expand All @@ -628,14 +656,17 @@ class MyChainlet(chains.ChainletBase):
context,
options=chains.RPCOptions(retries=3),
)
async def run_remote(self, ...):
await self._whisper.run_remote(...)
```


**Parameters:**

| Name | Type | Description |
|------|------|-------------|
| `service_descriptor` | *[ServiceDescriptor](#truss_chains.ServiceDescriptor* | Contains the URL and other configuration. |
| `service_descriptor` | *[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | Contains the URL and other configuration. |
| `api_key` | *str* | A baseten API key to authorize requests. |


Expand All @@ -653,6 +684,22 @@ Factory method, convenient to be used in chainlet’s `__init__`-method.
| `options` | *[RPCOptions](#truss_chains.RPCOptions* | RPC options, e.g. retries. |


#### *async* predict_async(inputs: InputT, output_model: Type[OutputModelT]) → OutputModelT

#### *async* predict_async(inputs: InputT, output_model: None = None) → Any

#### *async* predict_async_stream(inputs)

* **Parameters:**
**inputs** (*InputT*)
* **Return type:**
*AsyncIterator*[bytes]

#### predict_sync(inputs: InputT, output_model: Type[OutputModelT]) → OutputModelT

#### predict_sync(inputs: InputT, output_model: None = None) → Any


### *class* `truss_chains.RemoteErrorDetail`

Bases: `pydantic.BaseModel`
Expand All @@ -665,7 +712,6 @@ error response.

| Name | Type | Description |
|------|------|-------------|
| `remote_name` | *str* | |
| `exception_cls_name` | *str* | |
| `exception_module_name` | *str\|None* | |
| `exception_message` | *str* | |
Expand All @@ -686,6 +732,4 @@ with stack traces.
* **Return type:**
str

#### remote_name *: str*

#### user_stack_trace *: list[StackFrame]*
2 changes: 1 addition & 1 deletion docs/chains/doc_gen/mdx_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _line_replacements(line: str) -> str:
first_brace = line.find("(")
if first_brace > 0:
line = line[:first_brace]
return f"### *class* `{line}`"
return f"\n### *class* `{line}`"
elif line.startswith("### "):
line = line.replace("### ", "").strip()
if not any(sym in line for sym in NON_PUBLIC_SYMBOLS):
Expand Down
Loading

0 comments on commit b34acc8

Please sign in to comment.