From 48fd22cee8be7bba05314b405b58209ff142737f Mon Sep 17 00:00:00 2001 From: Forrest Date: Mon, 13 Feb 2023 12:20:53 -0500 Subject: [PATCH] feat(server): async generator stream support This adds stream support (i.e. progress bars) via async generators. --- server/custom/user_api.py | 10 +++++- server/volview_server/__init__.py | 4 +-- server/volview_server/api.py | 39 ++++++++++++++++++++ src/components/ServerModule.vue | 46 +++++++++++++++++++++--- src/core/remote/remote-connection.ts | 53 ++++++++++++++++++++++++++-- 5 files changed, 142 insertions(+), 10 deletions(-) diff --git a/server/custom/user_api.py b/server/custom/user_api.py index 73ba5f871..c5e4562fc 100644 --- a/server/custom/user_api.py +++ b/server/custom/user_api.py @@ -1,4 +1,6 @@ -from volview_server import VolViewAPI, rpc +import asyncio + +from volview_server import VolViewAPI, rpc, stream from volview_server.converters import ( convert_itk_to_vtkjs_image, convert_vtkjs_to_itk_image, @@ -46,3 +48,9 @@ async def number_trivia(self): url = "http://numbersapi.com/random/" async with session.get(url) as resp: return await resp.text() + + @stream("number_stream") + async def number_stream(self): + for i in range(1, 101): + yield {"progress": i} + await asyncio.sleep(0.1) diff --git a/server/volview_server/__init__.py b/server/volview_server/__init__.py index 3cf022a71..a2b0100bc 100644 --- a/server/volview_server/__init__.py +++ b/server/volview_server/__init__.py @@ -1,5 +1,5 @@ __version__ = "0.1.0" __author__ = "Kitware, Inc." -__all__ = ["VolViewAPI", "rpc"] +__all__ = ["VolViewAPI", "rpc", "stream"] -from volview_server.api import VolViewAPI, rpc +from volview_server.api import VolViewAPI, rpc, stream diff --git a/server/volview_server/api.py b/server/volview_server/api.py index 0a9b98268..6ca0583f9 100644 --- a/server/volview_server/api.py +++ b/server/volview_server/api.py @@ -15,6 +15,23 @@ NUM_THREADS = 1 OBJECT_ID_TAG = "__volview_object_id" +_STREAM_URIS_ATTR = "_volview_stream_uris" + + +def stream(name: str): + """Exposes a method as a stream endpoint with the given name.""" + wslink.checkURI(name) + + def wrapper(fn): + if not inspect.isasyncgenfunction(fn): + raise Exception(f'Stream "{name}" is not an async generator') + uris = getattr(fn, _STREAM_URIS_ATTR, []) + uris.append(name) + setattr(fn, _STREAM_URIS_ATTR, uris) + return fn + + return wrapper + def rpc(name: str): """Exposes a method as an RPC endpoint with the given name.""" @@ -63,6 +80,19 @@ def __init__(self, *args, **kwargs): encode_np_array, ] + self._stream_fns = self._get_stream_functions() + + def _get_stream_functions(self): + """Generates a map of name -> stream-enabled function.""" + stream_fn_map = {} + for _, fn in inspect.getmembers(self, inspect.isasyncgenfunction): + uris = getattr(fn, _STREAM_URIS_ATTR, []) + for uri in uris: + if uri in stream_fn_map: + raise Exception(f"{uri} has multiple mappings") + stream_fn_map[uri] = fn + return stream_fn_map + def _unpack_data(self, data: bytes): return msgpack.unpackb( data, ext_hook=self._decode_ext, object_hook=self._decode_object @@ -156,3 +186,12 @@ def _decode_tagged_object(self, obj: Any): if obj_id in self._objects: return self._objects[obj_id] return obj + + @rpc("volview.stream") + async def _stream_generator(self, endpoint: str, channel: str, args: List[Any]): + """Streams the results of an async generator.""" + fn = self._stream_fns.get(endpoint) + if not fn: + raise Exception(f"No stream with the name {endpoint}") + async for data in fn(*args): + self.publish(channel, data) diff --git a/src/components/ServerModule.vue b/src/components/ServerModule.vue index 510ae51f4..3d9b868bf 100644 --- a/src/components/ServerModule.vue +++ b/src/components/ServerModule.vue @@ -34,7 +34,7 @@ export default defineComponent({ const doSum = async () => { doSumLoading.value = true; try { - sum.value = await rconn.call('add', sumOp1.value, sumOp2.value); + sum.value = await rconn.call('add', [sumOp1.value, sumOp2.value]); } finally { doSumLoading.value = false; } @@ -54,6 +54,26 @@ export default defineComponent({ } }; + // --- stream test --- // + + const streamProgress = ref(0); + const streamLoading = ref(false); + + type StreamData = { progress: number }; + + const onStreamData = (data: StreamData) => { + const { progress } = data; + streamProgress.value = progress; + if (progress === 100) { + streamLoading.value = false; + } + }; + + const startStream = async () => { + streamLoading.value = true; + await rconn.stream('number_stream', onStreamData); + }; + // --- median filter --- // const medianFilterLoading = ref(false); @@ -77,11 +97,10 @@ export default defineComponent({ medianFilterLoading.value = true; try { - const blurredImageJSON = await rconn.call( - 'median_filter', + const blurredImageJSON = await rconn.call('median_filter', [ image.toJSON(), - medianFilterRadius.value - ); + medianFilterRadius.value, + ]); const blurredImage = vtk(blurredImageJSON) as vtkImageData; const imageStore = useImageStore(); @@ -114,6 +133,13 @@ export default defineComponent({ trivia, triviaLoading, + startStream, + streamLoading, + streamProgress: computed(() => { + const p = streamProgress.value; + return p < 100 ? p : 'Done!'; + }), + doMedianFilter, medianFilterLoading, hasCurrentImage, @@ -180,6 +206,16 @@ export default defineComponent({ + Progress + + + + Start progress + + Progress: {{ streamProgress }} + + + Median Filter
diff --git a/src/core/remote/remote-connection.ts b/src/core/remote/remote-connection.ts index fcb3d2f9c..fc5d4ea68 100644 --- a/src/core/remote/remote-connection.ts +++ b/src/core/remote/remote-connection.ts @@ -77,13 +77,13 @@ export default class RemoteConnection { * @param methodName * @param args */ - async call(methodName: string, ...args: unknown[]) { + async call(methodName: string, args?: unknown[]) { if (!this.connected || !this.ws) { throw new Error('Not connected'); } const argObjIDs = await Promise.all( - args.map((arg) => this.uploadObject(arg)) + (args ?? []).map((arg) => this.uploadObject(arg)) ); const argObjRefs = argObjIDs.map((id) => ({ [OBJECT_ID_TAG]: id })); @@ -103,6 +103,55 @@ export default class RemoteConnection { } } + /** + * Calls a remote streaming method. + * @param methodName + * @param callback + */ + async stream( + methodName: string, + callback: (data: D) => void + ): Promise; + + /** + * Calls a remote streaming method. + * @param methodName + * @param args + * @param callback + */ + async stream( + methodName: string, + args: unknown[], + callback: (data: D) => void + ): Promise; + + async stream( + methodName: string, + argsOrCallback: unknown[] | ((data: D) => void), + maybeCallback?: (data: D) => void + ) { + if (!this.connected || !this.ws) { + throw new Error('not connected'); + } + + const args = Array.isArray(argsOrCallback) ? argsOrCallback : []; + const callback = Array.isArray(argsOrCallback) + ? maybeCallback! + : argsOrCallback; + + const channel = nanoid(32).replace('-', ''); + const session = this.ws.getSession(); + const { unsubscribe } = session.subscribe(channel, async (data) => { + callback(data[0] as D); + }); + + try { + await this.call('volview.stream', [methodName, channel, args]); + } finally { + unsubscribe(); + } + } + /** * Encodes a set of arguments and invokes the RPC endpoint. * @param methodName