Inference client for Pykos (#38)
WT-MM authored Jan 17, 2025
1 parent 44dc678 commit a4486de
Showing 6 changed files with 275 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -7,7 +7,7 @@ members = [
]

[workspace.package]
version = "0.4.1"
version = "0.5.0"
authors = [
"Benjamin Bolte <[email protected]>",
"Denys Bezmenov <[email protected]>",
2 changes: 1 addition & 1 deletion kos-py/pykos/__init__.py
@@ -1,6 +1,6 @@
"""KOS Python client."""

__version__ = "0.4.1"
__version__ = "0.5.0"

from . import services
from .client import KOS
2 changes: 2 additions & 0 deletions kos-py/pykos/client.py
@@ -4,6 +4,7 @@

from pykos.services.actuator import ActuatorServiceClient
from pykos.services.imu import IMUServiceClient
from pykos.services.inference import InferenceServiceClient
from pykos.services.process_manager import ProcessManagerServiceClient
from pykos.services.sim import SimServiceClient

@@ -26,6 +27,7 @@ def __init__(self, ip: str = "localhost", port: int = 50051) -> None:
self.imu = IMUServiceClient(self.channel)
self.actuator = ActuatorServiceClient(self.channel)
self.process_manager = ProcessManagerServiceClient(self.channel)
self.inference = InferenceServiceClient(self.channel)
self.sim = SimServiceClient(self.channel)

def close(self) -> None:
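For orientation, a minimal sketch of how the new attribute is reached from user code once this change lands; the host and port simply mirror the constructor defaults shown above, and the variable names are illustrative.

from pykos import KOS

# Connect to a KOS instance; defaults mirror the constructor signature above.
kos = KOS(ip="localhost", port=50051)

# The new inference client shares the same gRPC channel as the other service clients.
inference = kos.inference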
244 changes: 244 additions & 0 deletions kos-py/pykos/services/inference.py
@@ -0,0 +1,244 @@
"""Inference service client."""

from typing import NotRequired, TypedDict

import grpc

from kos_protos import common_pb2, inference_pb2, inference_pb2_grpc


class ModelMetadata(TypedDict):
"""Model metadata for uploading models.
All fields are optional and can be used to provide additional information about the model.
"""

model_name: NotRequired[str | None]
model_description: NotRequired[str | None]
model_version: NotRequired[str | None]
model_author: NotRequired[str | None]


class TensorDimension(TypedDict):
"""Information about a tensor dimension.
Args:
size: Size of this dimension
name: Name of this dimension (e.g., "batch", "channels", "height")
dynamic: Whether this dimension can vary (e.g., batch size)
"""

size: int
name: str
dynamic: bool


class Tensor(TypedDict):
"""A tensor containing data.
Args:
values: Tensor values in row-major order
shape: List of dimension information
"""

values: list[float]
shape: list[TensorDimension]


class ForwardResponse(TypedDict):
"""Response from running model inference.
Args:
outputs: Dictionary mapping tensor names to output tensors
error: Optional error information if inference failed
"""

outputs: dict[str, Tensor]
error: NotRequired[common_pb2.Error | None]


class ModelInfo(TypedDict):
"""Information about a model.
Args:
uid: Model UID (assigned by server)
metadata: Model metadata
input_specs: Expected input tensor specifications
output_specs: Expected output tensor specifications
description: Optional description of tensor usage
"""

uid: str
metadata: ModelMetadata
input_specs: dict[str, Tensor]
output_specs: dict[str, Tensor]
description: str


class GetModelsInfoResponse(TypedDict):
"""Response containing information about available models."""

models: list[ModelInfo]
error: NotRequired[common_pb2.Error | None]


class InferenceServiceClient:
"""Client for the InferenceService.
This service allows uploading models and running inference on them.
"""

def __init__(self, channel: grpc.Channel) -> None:
"""Initialize the inference service client.
Args:
channel: gRPC channel to use for communication.
"""
self.stub = inference_pb2_grpc.InferenceServiceStub(channel)

def upload_model(
self, model_data: bytes, metadata: ModelMetadata | None = None
) -> inference_pb2.UploadModelResponse:
"""Upload a model to the robot.
Example:
>>> client.upload_model(model_data,
... metadata={"model_name": "MyModel",
... "model_description": "A model for inference",
... "model_version": "1.0.0",
... "model_author": "John Doe"})
Args:
model_data: The binary model data to upload.
metadata: Optional metadata about the model that can include:
model_name: Name of the model
model_description: Description of the model
model_version: Version of the model
model_author: Author of the model
Returns:
UploadModelResponse containing the model UID and any error information.
"""
proto_metadata = None
if metadata is not None:
proto_metadata = inference_pb2.ModelMetadata(**metadata)
request = inference_pb2.UploadModelRequest(model=model_data, metadata=proto_metadata)
return self.stub.UploadModel(request)

def load_models(self, uids: list[str]) -> inference_pb2.LoadModelsResponse:
"""Load models from the robot's filesystem.
Args:
uids: List of model UIDs to load.
Returns:
LoadModelsResponse containing information about the loaded models.
"""
request = inference_pb2.ModelUids(uids=uids)
return self.stub.LoadModels(request)

def unload_models(self, uids: list[str]) -> common_pb2.ActionResponse:
"""Unload models from the robot's filesystem.
Args:
uids: List of model UIDs to unload.
Returns:
ActionResponse indicating success/failure of the unload operation.
"""
request = inference_pb2.ModelUids(uids=uids)
return self.stub.UnloadModels(request)

def get_models_info(self, model_uids: list[str] | None = None) -> GetModelsInfoResponse:
"""Get information about available models.
Args:
model_uids: Optional list of specific model UIDs to get info for.
If None, returns info for all models.
Returns:
GetModelsInfoResponse containing:
models: List of ModelInfo objects
error: Optional error information if fetching failed
"""
if model_uids is not None:
request = inference_pb2.GetModelsInfoRequest(model_uids=inference_pb2.ModelUids(uids=model_uids))
else:
request = inference_pb2.GetModelsInfoRequest(all=True)

response = self.stub.GetModelsInfo(request)

return GetModelsInfoResponse(
models=[
ModelInfo(
uid=model.uid,
metadata=ModelMetadata(
model_name=model.metadata.model_name if model.metadata.HasField("model_name") else None,
model_description=(
model.metadata.model_description if model.metadata.HasField("model_description") else None
),
model_version=(
model.metadata.model_version if model.metadata.HasField("model_version") else None
),
model_author=model.metadata.model_author if model.metadata.HasField("model_author") else None,
),
input_specs={
name: Tensor(
values=list(tensor.values),
shape=[
TensorDimension(size=dim.size, name=dim.name, dynamic=dim.dynamic)
for dim in tensor.shape
],
)
for name, tensor in model.input_specs.items()
},
output_specs={
name: Tensor(
values=list(tensor.values),
shape=[
TensorDimension(size=dim.size, name=dim.name, dynamic=dim.dynamic)
for dim in tensor.shape
],
)
for name, tensor in model.output_specs.items()
},
description=model.description,
)
for model in response.models
],
error=response.error if response.HasField("error") else None,
)

def forward(self, model_uid: str, inputs: dict[str, Tensor]) -> ForwardResponse:
"""Run inference using a specified model.
Args:
model_uid: The UID of the model to use for inference.
inputs: Dictionary mapping tensor names to tensors.
Returns:
ForwardResponse containing:
outputs: Dictionary mapping tensor names to output tensors
error: Optional error information if inference failed
"""
tensor_inputs = {}
for name, tensor in inputs.items():
shape = [
inference_pb2.Tensor.Dimension(size=dim["size"], name=dim["name"], dynamic=dim["dynamic"])
for dim in tensor["shape"]
]
proto_tensor = inference_pb2.Tensor(values=tensor["values"], shape=shape)
tensor_inputs[name] = proto_tensor

response = self.stub.Forward(inference_pb2.ForwardRequest(model_uid=model_uid, inputs=tensor_inputs))

return ForwardResponse(
outputs={
name: Tensor(
values=list(tensor.values),
shape=[TensorDimension(size=dim.size, name=dim.name, dynamic=dim.dynamic) for dim in tensor.shape],
)
for name, tensor in response.outputs.items()
},
error=response.error if response.HasField("error") else None,
)
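To make the new surface concrete, here is a hedged end-to-end sketch of uploading a model and running forward() with named-tensor inputs as defined by the TypedDicts above. The model file, tensor names, shapes, and the model_uid field on UploadModelResponse are assumptions for illustration, not values taken from this commit.

from pykos import KOS

kos = KOS(ip="localhost", port=50051)

# Upload a (hypothetical) model binary with optional metadata.
with open("policy.onnx", "rb") as f:
    upload = kos.inference.upload_model(
        f.read(),
        metadata={"model_name": "policy", "model_version": "1.0.0"},
    )
model_uid = upload.model_uid  # field name assumed from the "model UID" wording in the docstring

# Load the model, then build a named input tensor (names and shapes are placeholders).
kos.inference.load_models([model_uid])
inputs = {
    "observation": {
        "values": [0.0] * 10,  # row-major values
        "shape": [
            {"size": 1, "name": "batch", "dynamic": True},
            {"size": 10, "name": "features", "dynamic": False},
        ],
    }
}
result = kos.inference.forward(model_uid, inputs)
if result.get("error") is None:
    print(result["outputs"])  # dict mapping tensor names to output Tensors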
24 changes: 20 additions & 4 deletions kos/proto/kos/inference.proto
@@ -74,18 +74,34 @@ message ModelMetadata {

// Information about a model
message ModelInfo {
-string uid = 1; // Model UID (assigned by server)
-ModelMetadata metadata = 2; // Model metadata
+string uid = 1; // Model UID (assigned by server)
+ModelMetadata metadata = 2; // Model metadata
+map<string, Tensor> input_specs = 3; // Expected input tensor specifications
+map<string, Tensor> output_specs = 4; // Expected output tensor specifications
+string description = 5; // Optional description of tensor usage
}

// Request message for running inference.
message ForwardRequest {
string model_uid = 1; // Model UID to use for inference
-repeated float inputs = 2; // Input data for the model
+map<string, Tensor> inputs = 2; // Named input tensors
}

// A tensor containing data
message Tensor {
repeated float values = 1; // Tensor values in row-major order
repeated Dimension shape = 2; // Shape of the tensor

// Dimension information
message Dimension {
uint32 size = 1; // Size of this dimension
string name = 2; // Name (e.g., "batch", "channels", "height")
bool dynamic = 3; // Whether this dimension can vary (e.g., batch size)
}
}

// Response message containing inference results.
message ForwardResponse {
-repeated float outputs = 1; // Output data from the model
+map<string, Tensor> outputs = 1; // Named output tensors
kos.common.Error error = 2; // Error details if inference failed
}
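For reference, a short Python sketch of the proto-level messages this schema change implies, mirroring what forward() in pykos builds internally; the model UID, tensor name, and values below are placeholders.

from kos_protos import inference_pb2

# Named-tensor request replacing the old `repeated float inputs` field (placeholder values).
dim = inference_pb2.Tensor.Dimension(size=3, name="features", dynamic=False)
tensor = inference_pb2.Tensor(values=[0.1, 0.2, 0.3], shape=[dim])
request = inference_pb2.ForwardRequest(model_uid="some-uid", inputs={"observation": tensor})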
8 changes: 7 additions & 1 deletion kos/src/hal.rs
@@ -49,11 +49,17 @@ pub trait Inference: Send + Sync {
model: Vec<u8>,
metadata: Option<ModelMetadata>,
) -> Result<UploadModelResponse>;

async fn get_models_info(&self, request: GetModelsInfoRequest)
-> Result<GetModelsInfoResponse>;

async fn load_models(&self, uids: Vec<String>) -> Result<LoadModelsResponse>;
async fn unload_models(&self, uids: Vec<String>) -> Result<ActionResponse>;
-async fn forward(&self, model_uid: String, inputs: Vec<f32>) -> Result<ForwardResponse>;
+async fn forward(
+    &self,
+    model_uid: String,
+    inputs: std::collections::HashMap<String, Tensor>,
+) -> Result<ForwardResponse>;
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
