Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend response header autocompletion support #120

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 72 additions & 6 deletions spot_wrapper/testing/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from abc import ABC, abstractmethod

import grpc
from bosdyn.api.header_pb2 import CommonError

from spot_wrapper.testing.helpers import GeneralizedDecorator
from spot_wrapper.testing.helpers import ForwardingWrapper


def implemented(function: typing.Callable) -> bool:
Expand Down Expand Up @@ -94,10 +95,12 @@ class AutoServicer(object):
autospec: if true, deferred handlers will be used in place for every
non-implemented method handler.
autotrack: if true, tracking handlers will decorate every method handler.
autocomplete: if true, autocompleting handlers will decorate every method handler.
"""

autospec = False
autotrack = False
autocomplete = False

def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(*args, **kwargs)
Expand All @@ -112,6 +115,8 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.needs_shutdown.append(underlying_callable)
if self.autotrack:
underlying_callable = TrackingStreamStreamRpcHandler(underlying_callable)
if self.autocomplete:
underlying_callable = AutoCompletingStreamStreamRpcHandler(underlying_callable)
if underlying_callable is not handler.stream_stream:
setattr(self, unqualified_name, underlying_callable)
else:
Expand All @@ -122,6 +127,8 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.needs_shutdown.append(underlying_callable)
if self.autotrack:
underlying_callable = TrackingUnaryStreamRpcHandler(underlying_callable)
if self.autocomplete:
underlying_callable = AutoCompletingUnaryStreamRpcHandler(underlying_callable)
if underlying_callable is not handler.unary_stream:
setattr(self, unqualified_name, underlying_callable)
else:
Expand All @@ -133,6 +140,8 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.needs_shutdown.append(underlying_callable)
if self.autotrack:
underlying_callable = TrackingStreamUnaryRpcHandler(underlying_callable)
if self.autocomplete:
underlying_callable = AutoCompletingStreamUnaryRpcHandler(underlying_callable)
if underlying_callable is not handler.stream_unary:
setattr(self, unqualified_name, underlying_callable)
else:
Expand All @@ -143,6 +152,8 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.needs_shutdown.append(underlying_callable)
if self.autotrack:
underlying_callable = TrackingUnaryUnaryRpcHandler(underlying_callable)
if self.autocomplete:
underlying_callable = AutoCompletingUnaryUnaryRpcHandler(underlying_callable)
if underlying_callable is not handler.unary_unary:
setattr(self, unqualified_name, underlying_callable)

Expand All @@ -167,7 +178,7 @@ def shutdown(self):
obj.shutdown()


class TrackingUnaryUnaryRpcHandler(GeneralizedDecorator):
class TrackingUnaryUnaryRpcHandler(ForwardingWrapper):
"""A decorator for unary-unary gRPC handlers that tracks calls."""

def __init__(self, handler: typing.Callable) -> None:
Expand All @@ -183,7 +194,7 @@ def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing
self.num_calls += 1


class TrackingStreamUnaryRpcHandler(GeneralizedDecorator):
class TrackingStreamUnaryRpcHandler(ForwardingWrapper):
"""A decorator for stream-unary gRPC handlers that tracks calls."""

def __init__(self, handler: typing.Callable) -> None:
Expand All @@ -201,7 +212,7 @@ def __call__(self, request_iterator: typing.Iterator, context: grpc.ServicerCont
self.num_calls += 1


class TrackingUnaryStreamRpcHandler(GeneralizedDecorator):
class TrackingUnaryStreamRpcHandler(ForwardingWrapper):
"""A decorator for unary-stream gRPC handlers that tracks calls."""

def __init__(self, handler: typing.Callable) -> None:
Expand All @@ -217,7 +228,7 @@ def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing
self.num_calls += 1


class TrackingStreamStreamRpcHandler(GeneralizedDecorator):
class TrackingStreamStreamRpcHandler(ForwardingWrapper):
"""A decorator for stream-stream gRPC handlers that tracks calls."""

def __init__(self, handler: typing.Callable) -> None:
Expand All @@ -235,7 +246,62 @@ def __call__(self, request_iterator: typing.Iterator, context: grpc.ServicerCont
self.num_calls += 1


class DeferredRpcHandler(GeneralizedDecorator):
def fill_response_header(request: typing.Any, response: typing.Any) -> bool:
"""Fill response header if any when missing."""
if not hasattr(response, "header"):
return False
if hasattr(request, "header"):
response.header.request_header.CopyFrom(request.header)
response.header.request_received_timestamp.CopyFrom(request.header.request_timestamp)
response.header.error.code = response.header.error.code or CommonError.CODE_OK
return True


class AutoCompletingUnaryUnaryRpcHandler(ForwardingWrapper):
"""A decorator for unary-unary gRPC handlers that autocompletes response headers."""

def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing.Any:
response = self.__wrapped__(request, context)
fill_response_header(request, response)
return response


class AutoCompletingStreamUnaryRpcHandler(ForwardingWrapper):
"""A decorator for stream-unary gRPC handlers that autocompletes response headers.

The last request chunk header will be used to complete the response header.
"""

def __call__(self, request_iterator: typing.Iterator, context: grpc.ServicerContext) -> typing.Any:
*head_requests, tail_request = request_iterator
response = self.__wrapped__(iter([*head_requests, tail_request]), context)
fill_response_header(tail_request, response)
return response


class AutoCompletingUnaryStreamRpcHandler(ForwardingWrapper):
"""A decorator for unary-stream gRPC handlers that autocompletes response headers."""

def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing.Iterator:
for response in self.__wrapped__(request, context):
fill_response_header(request, response)
yield response


class AutoCompletingStreamStreamRpcHandler(ForwardingWrapper):
"""A decorator for stream-stream gRPC handlers that autocompletes response headers.

The last request chunk header will be used to complete the response header.
"""

def __call__(self, request_iterator: typing.Iterator, context: grpc.ServicerContext) -> typing.Iterator:
*head_requests, tail_request = request_iterator
for response in self.__wrapped__(iter([*head_requests, tail_request]), context):
fill_response_header(tail_request, response)
yield response


class DeferredRpcHandler(ForwardingWrapper):
"""
A gRPC handler that decouples invocation and computation execution paths.

Expand Down
28 changes: 4 additions & 24 deletions spot_wrapper/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import functools
import typing

import grpc
from bosdyn.api.header_pb2 import CommonError
from bosdyn.api.lease_pb2 import ResourceTree


class GeneralizedDecorator:
class ForwardingWrapper:
"""A `functools.wraps` equivalent that is transparent to attribute access."""

__name__ = __qualname__ = __doc__ = ""

__annotations__ = {}

@staticmethod
def wraps(wrapped: typing.Callable):
def decorator(func: typing.Callable):
class wrapper(GeneralizedDecorator):
class wrapper(ForwardingWrapper):
def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return func(*args, **kwargs)

Expand All @@ -34,26 +34,6 @@ def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
raise NotImplementedError()


UnaryUnaryHandlerCallable = typing.Callable[[typing.Any, grpc.ServicerContext], typing.Any]


def enforce_matching_headers(
handler: UnaryUnaryHandlerCallable,
) -> UnaryUnaryHandlerCallable:
"""Enforce headers for handler request and response match (by copy)."""

@GeneralizedDecorator.wraps(handler)
def wrapper(request: typing.Any, context: grpc.ServicerContext) -> typing.Any:
response = handler(request, context)
if hasattr(request, "header") and hasattr(response, "header"):
response.header.request_header.CopyFrom(request.header)
response.header.request_received_timestamp.CopyFrom(request.header.request_timestamp)
response.header.error.code = response.header.error.code or CommonError.CODE_OK
return response

return wrapper


def walk_resource_tree(resource_tree: ResourceTree) -> typing.Iterable[ResourceTree]:
"""Walks `resource_tree` top-down, depth-first."""
yield resource_tree
Expand Down
15 changes: 1 addition & 14 deletions spot_wrapper/testing/mocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
from bosdyn.api.world_object_service_pb2_grpc import WorldObjectServiceServicer

from spot_wrapper.testing.grpc import AutoServicer, collect_method_handlers
from spot_wrapper.testing.helpers import enforce_matching_headers
from spot_wrapper.testing.mocks.auth import MockAuthService
from spot_wrapper.testing.mocks.cam import MockCAMService
from spot_wrapper.testing.mocks.directory import MockDirectoryService
Expand Down Expand Up @@ -161,19 +160,7 @@ class BaseMockSpot(
"""

name = "mockie"

def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(*args, **kwargs)
for _, handler in collect_method_handlers(self):
if handler.request_streaming:
continue
if handler.response_streaming:
continue
setattr(
self,
handler.unary_unary.__name__,
enforce_matching_headers(handler.unary_unary),
)
autocomplete = True


class MockSpot(
Expand Down