Skip to content

Commit

Permalink
feat(server): async generator stream support
Browse files Browse the repository at this point in the history
This adds stream support (i.e. progress bars) via async generators.
  • Loading branch information
floryst committed Feb 13, 2023
1 parent 27d8ec4 commit 48fd22c
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 10 deletions.
10 changes: 9 additions & 1 deletion server/custom/user_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from volview_server import VolViewAPI, rpc
import asyncio

from volview_server import VolViewAPI, rpc, stream
from volview_server.converters import (
convert_itk_to_vtkjs_image,
convert_vtkjs_to_itk_image,
Expand Down Expand Up @@ -46,3 +48,9 @@ async def number_trivia(self):
url = "http://numbersapi.com/random/"
async with session.get(url) as resp:
return await resp.text()

@stream("number_stream")
async def number_stream(self):
for i in range(1, 101):
yield {"progress": i}
await asyncio.sleep(0.1)
4 changes: 2 additions & 2 deletions server/volview_server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__version__ = "0.1.0"
__author__ = "Kitware, Inc."
__all__ = ["VolViewAPI", "rpc"]
__all__ = ["VolViewAPI", "rpc", "stream"]

from volview_server.api import VolViewAPI, rpc
from volview_server.api import VolViewAPI, rpc, stream
39 changes: 39 additions & 0 deletions server/volview_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@
NUM_THREADS = 1
OBJECT_ID_TAG = "__volview_object_id"

_STREAM_URIS_ATTR = "_volview_stream_uris"


def stream(name: str):
"""Exposes a method as a stream endpoint with the given name."""
wslink.checkURI(name)

def wrapper(fn):
if not inspect.isasyncgenfunction(fn):
raise Exception(f'Stream "{name}" is not an async generator')
uris = getattr(fn, _STREAM_URIS_ATTR, [])
uris.append(name)
setattr(fn, _STREAM_URIS_ATTR, uris)
return fn

return wrapper


def rpc(name: str):
"""Exposes a method as an RPC endpoint with the given name."""
Expand Down Expand Up @@ -63,6 +80,19 @@ def __init__(self, *args, **kwargs):
encode_np_array,
]

self._stream_fns = self._get_stream_functions()

def _get_stream_functions(self):
"""Generates a map of name -> stream-enabled function."""
stream_fn_map = {}
for _, fn in inspect.getmembers(self, inspect.isasyncgenfunction):
uris = getattr(fn, _STREAM_URIS_ATTR, [])
for uri in uris:
if uri in stream_fn_map:
raise Exception(f"{uri} has multiple mappings")
stream_fn_map[uri] = fn
return stream_fn_map

def _unpack_data(self, data: bytes):
return msgpack.unpackb(
data, ext_hook=self._decode_ext, object_hook=self._decode_object
Expand Down Expand Up @@ -156,3 +186,12 @@ def _decode_tagged_object(self, obj: Any):
if obj_id in self._objects:
return self._objects[obj_id]
return obj

@rpc("volview.stream")
async def _stream_generator(self, endpoint: str, channel: str, args: List[Any]):
"""Streams the results of an async generator."""
fn = self._stream_fns.get(endpoint)
if not fn:
raise Exception(f"No stream with the name {endpoint}")
async for data in fn(*args):
self.publish(channel, data)
46 changes: 41 additions & 5 deletions src/components/ServerModule.vue
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export default defineComponent({
const doSum = async () => {
doSumLoading.value = true;
try {
sum.value = await rconn.call('add', sumOp1.value, sumOp2.value);
sum.value = await rconn.call('add', [sumOp1.value, sumOp2.value]);
} finally {
doSumLoading.value = false;
}
Expand All @@ -54,6 +54,26 @@ export default defineComponent({
}
};
// --- stream test --- //
const streamProgress = ref(0);
const streamLoading = ref(false);
type StreamData = { progress: number };
const onStreamData = (data: StreamData) => {
const { progress } = data;
streamProgress.value = progress;
if (progress === 100) {
streamLoading.value = false;
}
};
const startStream = async () => {
streamLoading.value = true;
await rconn.stream('number_stream', onStreamData);
};
// --- median filter --- //
const medianFilterLoading = ref(false);
Expand All @@ -77,11 +97,10 @@ export default defineComponent({
medianFilterLoading.value = true;
try {
const blurredImageJSON = await rconn.call(
'median_filter',
const blurredImageJSON = await rconn.call('median_filter', [
image.toJSON(),
medianFilterRadius.value
);
medianFilterRadius.value,
]);
const blurredImage = vtk(blurredImageJSON) as vtkImageData;
const imageStore = useImageStore();
Expand Down Expand Up @@ -114,6 +133,13 @@ export default defineComponent({
trivia,
triviaLoading,
startStream,
streamLoading,
streamProgress: computed(() => {
const p = streamProgress.value;
return p < 100 ? p : 'Done!';
}),
doMedianFilter,
medianFilterLoading,
hasCurrentImage,
Expand Down Expand Up @@ -180,6 +206,16 @@ export default defineComponent({
</v-col>
</v-row>
<v-divider />
<v-subheader>Progress</v-subheader>
<v-row class="mb-4">
<v-col>
<v-btn @click="startStream" :loading="streamLoading" :disabled="!ready">
Start progress
</v-btn>
<span class="ml-3"> Progress: {{ streamProgress }} </span>
</v-col>
</v-row>
<v-divider />
<v-subheader>Median Filter</v-subheader>
<div>
<v-row>
Expand Down
53 changes: 51 additions & 2 deletions src/core/remote/remote-connection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ export default class RemoteConnection {
* @param methodName
* @param args
*/
async call<R = unknown>(methodName: string, ...args: unknown[]) {
async call<R>(methodName: string, args?: unknown[]) {
if (!this.connected || !this.ws) {
throw new Error('Not connected');
}

const argObjIDs = await Promise.all(
args.map((arg) => this.uploadObject(arg))
(args ?? []).map((arg) => this.uploadObject(arg))
);
const argObjRefs = argObjIDs.map((id) => ({ [OBJECT_ID_TAG]: id }));

Expand All @@ -103,6 +103,55 @@ export default class RemoteConnection {
}
}

/**
* Calls a remote streaming method.
* @param methodName
* @param callback
*/
async stream<D>(
methodName: string,
callback: (data: D) => void
): Promise<void>;

/**
* Calls a remote streaming method.
* @param methodName
* @param args
* @param callback
*/
async stream<D>(
methodName: string,
args: unknown[],
callback: (data: D) => void
): Promise<void>;

async stream<D>(
methodName: string,
argsOrCallback: unknown[] | ((data: D) => void),
maybeCallback?: (data: D) => void
) {
if (!this.connected || !this.ws) {
throw new Error('not connected');
}

const args = Array.isArray(argsOrCallback) ? argsOrCallback : [];
const callback = Array.isArray(argsOrCallback)
? maybeCallback!
: argsOrCallback;

const channel = nanoid(32).replace('-', '');
const session = this.ws.getSession();
const { unsubscribe } = session.subscribe(channel, async (data) => {
callback(data[0] as D);
});

try {
await this.call('volview.stream', [methodName, channel, args]);
} finally {
unsubscribe();
}
}

/**
* Encodes a set of arguments and invokes the RPC endpoint.
* @param methodName
Expand Down

0 comments on commit 48fd22c

Please sign in to comment.