diff --git a/.codecov.yml b/.codecov.yml index fb3d02de..22838bc6 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -11,10 +11,4 @@ coverage: ignore: - "examples/" - - "pynumaflow/mapper/proto/*" - - "pynumaflow/sinker/proto/*" - - "pynumaflow/mapstreamer/proto/*" - - "pynumaflow/reducer/proto/*" - - "pynumaflow/sourcetransformer/proto/*" - - "pynumaflow/sideinput/proto/*" - - "pynumaflow/sourcer/proto/*" + - "pynumaflow/proto/*" diff --git a/.coveragerc b/.coveragerc index 95134092..c3dc24a6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,14 +5,18 @@ source = pynumaflow omit = pynumaflow/tests/* examples/* + pynumaflow/proto/* + pynumaflow/shared/server.py [report] exclude_lines = - def start - def start_async - def __serve_async - def start_multiproc + def sync_server_start def _run_server + def start_multiproc_server + async def start_async_server def _reserve_port if os.getenv("PYTHONDEBUG"): _LOGGER.setLevel(logging.DEBUG) + def exec_multiproc + def exec + async def aexec diff --git a/Makefile b/Makefile index 7c42a33b..af48b9ae 100644 --- a/Makefile +++ b/Makefile @@ -27,13 +27,13 @@ setup: proto: - python3 -m grpc_tools.protoc -I=pynumaflow/sinker/proto --python_out=pynumaflow/sinker/proto --grpc_python_out=pynumaflow/sinker/proto pynumaflow/sinker/proto/*.proto - python3 -m grpc_tools.protoc -I=pynumaflow/mapper/proto --python_out=pynumaflow/mapper/proto --grpc_python_out=pynumaflow/mapper/proto pynumaflow/mapper/proto/*.proto - python3 -m grpc_tools.protoc -I=pynumaflow/mapstreamer/proto --python_out=pynumaflow/mapstreamer/proto --grpc_python_out=pynumaflow/mapstreamer/proto pynumaflow/mapstreamer/proto/*.proto - python3 -m grpc_tools.protoc -I=pynumaflow/reducer/proto --python_out=pynumaflow/reducer/proto --grpc_python_out=pynumaflow/reducer/proto pynumaflow/reducer/proto/*.proto - python3 -m grpc_tools.protoc -I=pynumaflow/sourcetransformer/proto --python_out=pynumaflow/sourcetransformer/proto --grpc_python_out=pynumaflow/sourcetransformer/proto pynumaflow/sourcetransformer/proto/*.proto - python3 -m grpc_tools.protoc -I=pynumaflow/sideinput/proto --python_out=pynumaflow/sideinput/proto --grpc_python_out=pynumaflow/sideinput/proto pynumaflow/sideinput/proto/*.proto - python3 -m grpc_tools.protoc -I=pynumaflow/sourcer/proto --python_out=pynumaflow/sourcer/proto --grpc_python_out=pynumaflow/sourcer/proto pynumaflow/sourcer/proto/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/sinker --python_out=pynumaflow/proto/sinker --grpc_python_out=pynumaflow/proto/sinker pynumaflow/proto/sinker/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/mapper --python_out=pynumaflow/proto/mapper --grpc_python_out=pynumaflow/proto/mapper pynumaflow/proto/mapper/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/mapstreamer --python_out=pynumaflow/proto/mapstreamer --grpc_python_out=pynumaflow/proto/mapstreamer pynumaflow/proto/mapstreamer/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/reducer --python_out=pynumaflow/proto/reducer --grpc_python_out=pynumaflow/proto/reducer pynumaflow/proto/reducer/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/sourcetransformer --python_out=pynumaflow/proto/sourcetransformer --grpc_python_out=pynumaflow/proto/sourcetransformer pynumaflow/proto/sourcetransformer/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/sideinput --python_out=pynumaflow/proto/sideinput --grpc_python_out=pynumaflow/proto/sideinput pynumaflow/proto/sideinput/*.proto + python3 -m grpc_tools.protoc -I=pynumaflow/proto/sourcer 
--python_out=pynumaflow/proto/sourcer --grpc_python_out=pynumaflow/proto/sourcer pynumaflow/proto/sourcer/*.proto - sed -i '' 's/^\(import.*_pb2\)/from . \1/' pynumaflow/*/proto/*.py + sed -i '' 's/^\(import.*_pb2\)/from . \1/' pynumaflow/proto/*/*.py diff --git a/README.md b/README.md index f9d99129..af3809fa 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) [![Release Version](https://img.shields.io/github/v/release/numaproj/numaflow-python?label=pynumaflow)](https://github.com/numaproj/numaflow-python/releases/latest) +This is the Python SDK for [Numaflow](https://numaflow.numaproj.io/). -This SDK provides the interface for writing [UDFs](https://numaflow.numaproj.io/user-guide/user-defined-functions/user-defined-functions/) -and [UDSinks](https://numaflow.numaproj.io/user-guide/sinks/user-defined-sinks/) in Python. +This SDK provides the interface for writing different functionalities of Numaflow, such as [UDFs](https://numaflow.numaproj.io/user-guide/user-defined-functions/user-defined-functions/), [UDSinks](https://numaflow.numaproj.io/user-guide/sinks/user-defined-sinks/), [UDSources](https://numaflow.numaproj.io/user-guide/sources/user-defined-sources/) and [SideInput](https://numaflow.numaproj.io/specifications/side-inputs/), in Python. ## Installation @@ -40,100 +40,119 @@ Setup [pre-commit](https://pre-commit.com/) hooks: pre-commit install ``` -## Implement a User Defined Function (UDF) +## Implementing different functionalities +- [Implement User Defined Sources](https://github.com/numaproj/numaflow-python/tree/main/examples/source) +- [Implement User Defined Source Transformers](https://github.com/numaproj/numaflow-python/tree/main/examples/sourcetransform) +- Implement User Defined Functions + - [Map](https://github.com/numaproj/numaflow-python/tree/main/examples/map) + - [Reduce](https://github.com/numaproj/numaflow-python/tree/main/examples/reduce) + - [Map Stream](https://github.com/numaproj/numaflow-python/tree/main/examples/mapstream) +- [Implement User Defined Sinks](https://github.com/numaproj/numaflow-python/tree/main/examples/sink) +- [Implement User Defined SideInputs](https://github.com/numaproj/numaflow-python/tree/main/examples/sideinput) ## Server Types -### Map +There are different types of gRPC server mechanisms which can be used to serve the UDFs, UDSinks and UDSources. +These have different functionalities and are used for different use cases. -```python -from pynumaflow.mapper import Messages, Message, Datum, Mapper +Currently, we support the following server types: +- Sync Server +- Asynchronous Server +- MultiProcessing Server +Not all server types are supported for every UDF, UDSource and UDSink. -def my_handler(keys: list[str], datum: Datum) -> Messages: - val = datum.value - _ = datum.event_time - _ = datum.watermark - return Messages(Message(value=val, keys=keys)) +For each of the UDFs, UDSource and UDSinks, there are separate classes for each of the server types. +This keeps the interface simple and easy to use, and lets the user start the specific server type based +on the use case. -if __name__ == "__main__": - grpc_server = Mapper(handler=my_handler) - grpc_server.start() -``` -### SourceTransformer - Map with event time assignment capability -In addition to the regular Map function, SourceTransformer supports assigning a new event time to the message.
-SourceTransformer is only supported at source vertex to enable (a) early data filtering and (b) watermark assignment by extracting new event time from the message payload. - -```python -from datetime import datetime -from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformer - - -def transform_handler(keys: list[str], datum: Datum) -> Messages: - val = datum.value - new_event_time = datetime.now() - _ = datum.watermark - message_t_s = Messages(Message(val, event_time=new_event_time, keys=keys)) - return message_t_s +#### SyncServer +Synchronous Server is the simplest server type. It is a multithreaded server which can be used for simple UDFs and UDSinks. +Here the server will invoke the handler function for each message. The messaging is synchronous and the server will wait +for the handler to return before processing the next message. -if __name__ == "__main__": - grpc_server = SourceTransformer(handler=transform_handler) - grpc_server.start() ``` - -### Reduce - -```python -import aiorun -from typing import Iterator, List -from pynumaflow.reducer import Messages, Message, Datum, Metadata, AsyncReducer - - -async def my_handler( - keys: List[str], datums: Iterator[Datum], md: Metadata -) -> Messages: - interval_window = md.interval_window - counter = 0 - async for _ in datums: - counter += 1 - msg = ( - f"counter:{counter} interval_window_start:{interval_window.start} " - f"interval_window_end:{interval_window.end}" - ) - return Messages(Message(str.encode(msg), keys)) - - -if __name__ == "__main__": - grpc_server = AsyncReducer(handler=my_handler) - aiorun.run(grpc_server.start()) +grpc_server = MapServer(handler) ``` -### Sample Image -A sample UDF [Dockerfile](examples/map/forward_message/Dockerfile) is provided -under [examples](examples/map/forward_message). +#### AsyncServer -## Implement a User Defined Sink (UDSink) +Asynchronous Server is a multithreaded server which can be used for asynchronous UDFs. Here we utilize the asynchronous capabilities of Python to process multiple messages in parallel. The server will invoke the handler function for each message, but the messaging is asynchronous and the server will not wait for the handler to return before processing the next message. +The handler function for such a server should be an async function. -```python -from typing import Iterator -from pynumaflow.sinker import Datum, Responses, Response, Sinker + +``` +grpc_server = MapAsyncServer(handler) +```
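+For illustration, here is a minimal sketch of an async map handler served by `MapAsyncServer` (the handler name is arbitrary; the imports follow the `pynumaflow.mapper` examples elsewhere in this repository):
+
+```python
+from pynumaflow.mapper import Messages, Message, Datum, MapAsyncServer
+
+
+async def my_async_handler(keys: list[str], datum: Datum) -> Messages:
+    # The server awaits this coroutine, so slow I/O performed here does
+    # not block the processing of other in-flight messages.
+    return Messages(Message(value=datum.value, keys=keys))
+
+
+if __name__ == "__main__":
+    grpc_server = MapAsyncServer(my_async_handler)
+    grpc_server.start()
+```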
+#### MultiProcessServer -def my_handler(datums: Iterator[Datum]) -> Responses: - responses = Responses() - for msg in datums: - print("User Defined Sink", msg.value.decode("utf-8")) - responses.append(Response.as_success(msg.id)) - return responses +MultiProcess Server is a multiprocess server which can be used for CPU-intensive UDFs. Here we utilize the multiprocessing capabilities of Python to process multiple messages in parallel by forking multiple servers in different processes. +The server will invoke the handler function for each message. Individually, at the server level, the messaging is synchronous and the server will wait for the handler to return before processing the next message. But since we have multiple servers running in parallel, the overall messaging also executes in parallel. +This can be an alternative to creating multiple replicas of the same UDF container, as it uses the multiprocessing capabilities of the system to process multiple messages in parallel within the same container. -if __name__ == "__main__": - grpc_server = Sinker(my_handler) - grpc_server.start() +Thus this server type is useful for CPU-intensive UDFs. +``` +grpc_server = MapMultiprocServer(handler) ``` -### Sample Image - -A sample UDSink [Dockerfile](examples/sink/log/Dockerfile) is provided -under [examples](examples/sink/log). \ No newline at end of file +#### Currently Supported Server Types for each functionality + +These are the class names for the server types supported by each of the functionalities. + +- UDFs + - Map + - MapServer + - MapAsyncServer + - MapMultiprocServer + - Reduce + - ReduceAsyncServer + - MapStream + - MapStreamAsyncServer + - Source Transform + - SourceTransformServer + - SourceTransformMultiProcServer +- UDSource + - SourceServer + - SourceAsyncServer +- UDSink + - SinkServer + - SinkAsyncServer +- SideInput + - SideInputServer + +### Handler Function and Classes + +All the server types take an instance of a handler class or a handler function as an argument. +The handler function or class is what implements the functionality of the UDF, UDSource or UDSink. +For ease of use, the user can pass either of the two to the server and the server will handle the rest. + +The handler for each of the servers has a specific signature, which is defined by the server type, and the implementation of the handlers +should follow the same signature. + +To use the class based handlers, the user can inherit from the base handler class for the given functionality and implement the handler function. +The base handler class for each of the functionalities has the same signature as the handler function for the respective server type. +The list of base handler classes for each of the functionalities is given below: +- UDFs + - Map + - Mapper + - Reduce + - Reducer + - MapStream + - MapStreamer + - Source Transform + - SourceTransformer +- UDSource + - Sourcer +- UDSink + - Sinker +- SideInput + - SideInput + +More details about the signature of the handler function for each of the server types are given in the +documentation of the respective server type.
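+As a quick illustration of the class based approach, here is a minimal sketch of a `Mapper` subclass passed to a `MapServer` (the class name is arbitrary; the same pattern is used in the map examples in this repository):
+
+```python
+from pynumaflow.mapper import Messages, Message, Datum, MapServer, Mapper
+
+
+class EchoMapper(Mapper):
+    # Inherit from the Mapper base class and implement `handler`,
+    # which the server invokes once for every incoming datum.
+    def handler(self, keys: list[str], datum: Datum) -> Messages:
+        return Messages(Message(value=datum.value, keys=keys))
+
+
+if __name__ == "__main__":
+    # An instance of the class (or a plain function with the same
+    # signature) is passed directly to the server.
+    grpc_server = MapServer(EchoMapper())
+    grpc_server.start()
+```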
diff --git a/examples/developer_guide/example.py b/examples/developer_guide/example.py index 8dbee528..dddc2e5e 100644 --- a/examples/developer_guide/example.py +++ b/examples/developer_guide/example.py @@ -1,29 +1,27 @@ -import aiorun -from collections.abc import Iterator +from collections.abc import AsyncIterable + from pynumaflow.reducer import ( Messages, Message, Datum, Metadata, - AsyncReducer, + ReduceAsyncServer, ) -async def my_handler(keys: list[str], datums: Iterator[Datum], md: Metadata) -> Messages: - # count the number of events +async def reduce_handler(keys: list[str], datums: AsyncIterable[Datum], md: Metadata) -> Messages: interval_window = md.interval_window counter = 0 async for _ in datums: counter += 1 - msg = ( f"counter:{counter} interval_window_start:{interval_window.start} " f"interval_window_end:{interval_window.end}" ) - return Messages(Message(keys=keys, value=str.encode(msg))) + return Messages(Message(str.encode(msg), keys=keys)) if __name__ == "__main__": - grpc_server = AsyncReducer(handler=my_handler) - aiorun.run(grpc_server.start()) + grpc_server = ReduceAsyncServer(reduce_handler) + grpc_server.start() diff --git a/examples/map/even_odd/example.py b/examples/map/even_odd/example.py index 68c63ea3..52405590 100644 --- a/examples/map/even_odd/example.py +++ b/examples/map/even_odd/example.py @@ -1,4 +1,4 @@ -from pynumaflow.mapper import Messages, Message, Datum, Mapper +from pynumaflow.mapper import Messages, Message, Datum, MapServer def my_handler(keys: list[str], datum: Datum) -> Messages: @@ -22,5 +22,12 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: if __name__ == "__main__": - grpc_server = Mapper(handler=my_handler) + """ + This example shows how to create a simple map function that takes in a + number and outputs it to the "even" or "odd" key depending on whether it + is even or odd. + We use a function as handler, but a class that implements + a Mapper can be used as well. + """ + grpc_server = MapServer(my_handler) grpc_server.start() diff --git a/examples/map/even_odd/pyproject.toml b/examples/map/even_odd/pyproject.toml index d5df62f1..63e388cf 100644 --- a/examples/map/even_odd/pyproject.toml +++ b/examples/map/even_odd/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/map/flatmap/example.py b/examples/map/flatmap/example.py index 48642741..eda861bf 100644 --- a/examples/map/flatmap/example.py +++ b/examples/map/flatmap/example.py @@ -1,20 +1,30 @@ -from pynumaflow.mapper import Messages, Message, Datum, Mapper +from pynumaflow.mapper import Messages, Message, Datum, MapServer, Mapper -def my_handler(keys: list[str], datum: Datum) -> Messages: - val = datum.value - _ = datum.event_time - _ = datum.watermark - strs = val.decode("utf-8").split(",") - messages = Messages() - if len(strs) == 0: - messages.append(Message.to_drop()) +class Flatmap(Mapper): + """ + This is a class that inherits from the Mapper class. + It implements the handler method that is called for each datum. 
+ """ + + def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + messages = Messages() + if len(strs) == 0: + messages.append(Message.to_drop()) + return messages + for s in strs: + messages.append(Message(str.encode(s))) return messages - for s in strs: - messages.append(Message(str.encode(s))) - return messages if __name__ == "__main__": - grpc_server = Mapper(handler=my_handler) + """ + This example shows how to use the Flatmap mapper. + We use a class as handler, but a function can be used as well. + """ + grpc_server = MapServer(Flatmap()) grpc_server.start() diff --git a/examples/map/flatmap/pipeline.yaml b/examples/map/flatmap/pipeline.yaml index 5e1127d1..41a7c2f7 100644 --- a/examples/map/flatmap/pipeline.yaml +++ b/examples/map/flatmap/pipeline.yaml @@ -15,7 +15,7 @@ spec: - name: flatmap udf: container: - image: "quay.io/numaio/numaflow-python/map-flatmap:v0.5.0" + image: "quay.io/numaio/numaflow-python/map-flatmap:v0.7.0" env: - name: PYTHONDEBUG value: "true" diff --git a/examples/map/flatmap/pyproject.toml b/examples/map/flatmap/pyproject.toml index 3ecf4c88..badafc14 100644 --- a/examples/map/flatmap/pyproject.toml +++ b/examples/map/flatmap/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/map/forward_message/Makefile b/examples/map/forward_message/Makefile index 982dfc22..6bbab66a 100644 --- a/examples/map/forward_message/Makefile +++ b/examples/map/forward_message/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/map-forward-message:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/map-forward-message:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. # To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/map/forward_message/example.py b/examples/map/forward_message/example.py index a64eb897..9a6c9d09 100644 --- a/examples/map/forward_message/example.py +++ b/examples/map/forward_message/example.py @@ -1,15 +1,38 @@ -from pynumaflow.mapper import Messages, Message, Datum, Mapper +import os + +from pynumaflow.mapper import Messages, Message, Datum, MapServer, Mapper + + +class MessageForwarder(Mapper): + """ + This is a class that inherits from the Mapper class. + It implements the handler method that is called for each datum. 
+ """ + + def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + _ = datum.event_time + _ = datum.watermark + return Messages(Message(value=val, keys=keys)) def my_handler(keys: list[str], datum: Datum) -> Messages: val = datum.value _ = datum.event_time _ = datum.watermark - messages = Messages() - messages.append(Message(value=val, keys=keys)) - return messages + return Messages(Message(value=val, keys=keys)) if __name__ == "__main__": - grpc_server = Mapper(handler=my_handler) + """ + Use the class based approach or the function based handler, + selected via the INVOKE env variable. + Both can be used and passed directly to the server class. + """ + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + handler = MessageForwarder() + else: + handler = my_handler + grpc_server = MapServer(handler) grpc_server.start() diff --git a/examples/map/forward_message/pyproject.toml index 361ba9e5..441e8dd4 100644 --- a/examples/map/forward_message/pyproject.toml +++ b/examples/map/forward_message/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/map/multiproc_map/Makefile index 50c4444f..e7679224 100644 --- a/examples/map/multiproc_map/Makefile +++ b/examples/map/multiproc_map/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/multiproc:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/multiproc:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. # To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/map/multiproc_map/README.md index 3b42bbc4..f2053199 100644 --- a/examples/map/multiproc_map/README.md +++ b/examples/map/multiproc_map/README.md @@ -8,18 +8,15 @@ writing UDFs using map function. These are particularly useful for CPU intensive as it allows for better resource utilisation. In this mode we would spawn N number (N = Cpu count) of grpc servers in different processes, where each of them -listening on the same TCP socket. - -This is possible by enabling the `SO_REUSEPORT` flag for the TCP socket, which allows these different -processes to bind to the same port. +listens on its own TCP socket. To enable multiprocessing mode -1) Start the multiproc server in the UDF using the following command +1) Start the multiproc server in the UDF using the following command +2) Provide the optional argument `server_count` to specify the number of +servers to be forked. Defaults to `os.cpu_count()` if not provided ```python if __name__ == "__main__": - grpc_server = MultiProcServer(map_handler=my_handler) + grpc_server = MapMultiprocServer(handler, server_count=3) grpc_server.start() -``` -2) Set the ENV var value `NUM_CPU_MULTIPROC="n"` for the UDF container, -to set the value of the number of server instances (one for each subprocess) to be created.
\ No newline at end of file +``` \ No newline at end of file diff --git a/examples/map/multiproc_map/example.py index 195aad89..e99795b5 100644 --- a/examples/map/multiproc_map/example.py +++ b/examples/map/multiproc_map/example.py @@ -1,6 +1,7 @@ import math +import os -from pynumaflow.mapper import Messages, Message, Datum, MultiProcMapper +from pynumaflow.mapper import Messages, Message, Datum, Mapper, MapMultiprocServer def is_prime(n): @@ -11,23 +12,32 @@ def is_prime(n): return True -def my_handler(keys: list[str], datum: Datum) -> Messages: - val = datum.value - _ = datum.event_time - _ = datum.watermark - messages = Messages() - for i in range(2, 100000): - is_prime(i) - messages.append(Message(val, keys=keys)) - return messages +class PrimeMap(Mapper): + """ + This class needs to inherit from the Mapper class to be used + as a handler for the MapMultiprocServer class. + Example of a mapper that checks if a number is prime. + """ + + def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + _ = datum.event_time + _ = datum.watermark + messages = Messages() + for i in range(2, 100000): + is_prime(i) + messages.append(Message(val, keys=keys)) + return messages if __name__ == "__main__": """ Example of starting a multiprocessing map vertex. - To enable set the env variable - MAP_MULTIPROC="true" - in the pipeline config for the numa container. """ + # To set the server_count value, set the env variable + # NUM_CPU_MULTIPROC="N" + server_count = int(os.getenv("NUM_CPU_MULTIPROC", "2")) + prime_class = PrimeMap() + # Server count is the number of server processes to start + grpc_server = MapMultiprocServer(prime_class, server_count=server_count) grpc_server.start() diff --git a/examples/map/multiproc_map/pipeline.yaml index 1fb6471e..54fe2b6a 100644 --- a/examples/map/multiproc_map/pipeline.yaml +++ b/examples/map/multiproc_map/pipeline.yaml @@ -15,12 +15,12 @@ spec: - name: mult udf: container: - image: "quay.io/numaio/numaflow-python/multiproc:latest" + image: "quay.io/numaio/numaflow-python/multiproc:v0.7.0" env: - name: PYTHONDEBUG value: "true" - name: NUM_CPU_MULTIPROC - value: "2" # DO NOT forget the double quotes!!! + value: "3" # DO NOT forget the double quotes!!! containerTemplate: resources: limits: diff --git a/examples/map/multiproc_map/pyproject.toml index 361ba9e5..441e8dd4 100644 --- a/examples/map/multiproc_map/pyproject.toml +++ b/examples/map/multiproc_map/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/mapstream/flatmap_stream/Makefile index 824bec93..b84a451b 100644 --- a/examples/mapstream/flatmap_stream/Makefile +++ b/examples/mapstream/flatmap_stream/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/map-flatmap-stream:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/map-flatmap-stream:v0.6.1" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment.
# To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/mapstream/flatmap_stream/example.py b/examples/mapstream/flatmap_stream/example.py index 58f78f97..e25a11ab 100644 --- a/examples/mapstream/flatmap_stream/example.py +++ b/examples/mapstream/flatmap_stream/example.py @@ -1,7 +1,24 @@ -import aiorun +import os from collections.abc import AsyncIterable +from pynumaflow.mapstreamer import Message, Datum, MapStreamAsyncServer, MapStreamer -from pynumaflow.mapstreamer import Message, Datum, AsyncMapStreamer + +class FlatMapStream(MapStreamer): + async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message]: + """ + A handler that splits the input datum value into multiple strings by `,` separator and + emits them as a stream. + """ + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + + if len(strs) == 0: + yield Message.to_drop() + return + for s in strs: + yield Message(str.encode(s)) async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Message]: @@ -22,5 +39,10 @@ async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Messag if __name__ == "__main__": - grpc_server = AsyncMapStreamer(handler=map_stream_handler) - aiorun.run(grpc_server.start()) + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + handler = FlatMapStream() + else: + handler = map_stream_handler + grpc_server = MapStreamAsyncServer(handler) + grpc_server.start() diff --git a/examples/mapstream/flatmap_stream/pipeline.yaml b/examples/mapstream/flatmap_stream/pipeline.yaml new file mode 100644 index 00000000..eb568f72 --- /dev/null +++ b/examples/mapstream/flatmap_stream/pipeline.yaml @@ -0,0 +1,49 @@ +apiVersion: numaflow.numaproj.io/v1alpha1 +kind: Pipeline +metadata: + name: simple-pipeline +spec: + limits: + readBatchSize: 2 + vertices: + - name: in + source: + # A self data generating source + generator: + rpu: 10 + duration: 1s + - name: flatmap + metadata: + annotations: + numaflow.numaproj.io/map-stream: "true" + limits: + readBatchSize: 1 + udf: + container: + image: "quay.io/numaio/numaflow-python/map-flatmap-stream:v0.6.1" + imagePullPolicy: Always + env: + - name: PYTHONDEBUG + value: "true" + - name : INVOKE + value: "func_handler" + containerTemplate: + resources: + limits: + cpu: "1" + memory: 2Gi + requests: + cpu: "500m" + memory: 1Gi + env: + - name: NUMAFLOW_DEBUG + value: "true" # DO NOT forget the double quotes!!! + - name: out + sink: + # A simple log printing sink + log: {} + edges: + - from: in + to: flatmap + - from: flatmap + to: out diff --git a/examples/mapstream/flatmap_stream/pyproject.toml b/examples/mapstream/flatmap_stream/pyproject.toml index 7df9056e..e99df380 100644 --- a/examples/mapstream/flatmap_stream/pyproject.toml +++ b/examples/mapstream/flatmap_stream/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/reduce/README.md b/examples/reduce/README.md new file mode 100644 index 00000000..130dc6ee --- /dev/null +++ b/examples/reduce/README.md @@ -0,0 +1,65 @@ +# Reducer in Python + +For creating a reducer UDF we can use two different approaches: +- Class based reducer + - For the class based reducer we need to implement a class that inherits from the `Reducer` class and implements the required methods. 
- Next we need to create a `ReduceAsyncServer` instance and pass the reducer class to it along with any input args or + kwargs that the custom reducer class requires. + - Finally, we need to call the `start` method on the `ReduceAsyncServer` instance to start the reducer server. + ```python + from collections.abc import AsyncIterable + + from pynumaflow.reducer import Messages, Message, Datum, Metadata, Reducer, ReduceAsyncServer + + class Example(Reducer): + def __init__(self, counter): + self.counter = counter + + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + interval_window = md.interval_window + self.counter = 0 + async for _ in datums: + self.counter += 1 + msg = ( + f"counter:{self.counter} interval_window_start:{interval_window.start} " + f"interval_window_end:{interval_window.end}" + ) + return Messages(Message(str.encode(msg), keys=keys)) + + if __name__ == "__main__": + # Here we are using the class instance as the reducer_instance + # which will be used to invoke the handler function. + # We are passing the init_args for the class instance. + grpc_server = ReduceAsyncServer(Example, init_args=(0,)) + grpc_server.start() + ``` + +- Function based reducer + For the function based reducer we need to create a function of the signature + ```python + async def handler(keys: list[str], datums: AsyncIterable[Datum], md: Metadata) -> Messages: + ``` + that takes the required arguments and returns the `Messages` object. + - Next we need to create a `ReduceAsyncServer` instance and pass the function to it. + - Finally, we need to call the `start` method on the `ReduceAsyncServer` instance to start the reducer server. + - We must ensure that no init_args or init_kwargs are passed to the `ReduceAsyncServer` instance as they are not used for function based reducers. + ```python + from collections.abc import AsyncIterable + + from pynumaflow.reducer import Messages, Message, Datum, Metadata, ReduceAsyncServer + + async def handler(keys: list[str], datums: AsyncIterable[Datum], md: Metadata) -> Messages: + counter = 0 + interval_window = md.interval_window + async for _ in datums: + counter += 1 + msg = ( + f"counter:{counter} interval_window_start:{interval_window.start} " + f"interval_window_end:{interval_window.end}" + ) + return Messages(Message(str.encode(msg), keys=keys)) + + if __name__ == "__main__": + # Here we are using the function as the reducer_instance + # which will be used to invoke the handler function. + grpc_server = ReduceAsyncServer(handler) + grpc_server.start() + ``` + + diff --git a/examples/reduce/counter/Makefile index 16146522..363debad 100644 --- a/examples/reduce/counter/Makefile +++ b/examples/reduce/counter/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/reduce-counter:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/reduce-counter:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment.
# To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/reduce/counter/example.py b/examples/reduce/counter/example.py index 043b4dd9..124f010c 100644 --- a/examples/reduce/counter/example.py +++ b/examples/reduce/counter/example.py @@ -1,7 +1,25 @@ -import aiorun +import os from collections.abc import AsyncIterable -from pynumaflow.reducer import Messages, Message, Datum, Metadata, AsyncReducer +from pynumaflow.reducer import Messages, Message, Datum, Metadata, ReduceAsyncServer, Reducer + + +class ReduceCounter(Reducer): + def __init__(self, counter): + self.counter = counter + + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + interval_window = md.interval_window + self.counter = 0 + async for _ in datums: + self.counter += 1 + msg = ( + f"counter:{self.counter} interval_window_start:{interval_window.start} " + f"interval_window_end:{interval_window.end}" + ) + return Messages(Message(str.encode(msg), keys=keys)) async def reduce_handler(keys: list[str], datums: AsyncIterable[Datum], md: Metadata) -> Messages: @@ -17,5 +35,13 @@ async def reduce_handler(keys: list[str], datums: AsyncIterable[Datum], md: Meta if __name__ == "__main__": - grpc_server = AsyncReducer(handler=reduce_handler) - aiorun.run(grpc_server.start()) + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + # Here we are using the class instance as the reducer_instance + # which will be used to invoke the handler function. + # We are passing the init_args for the class instance. + grpc_server = ReduceAsyncServer(ReduceCounter, init_args=(0,)) + else: + # Here we are using the handler function directly as the reducer_instance. 
+ grpc_server = ReduceAsyncServer(reduce_handler) + grpc_server.start() diff --git a/examples/reduce/counter/pipeline.yaml b/examples/reduce/counter/pipeline.yaml new file mode 100644 index 00000000..1b7d6b34 --- /dev/null +++ b/examples/reduce/counter/pipeline.yaml @@ -0,0 +1,49 @@ +apiVersion: numaflow.numaproj.io/v1alpha1 +kind: Pipeline +metadata: + name: even-odd-sum +spec: + vertices: + - name: in + source: + http: {} + - name: atoi + scale: + min: 3 + udf: + container: + # Tell the input number is even or odd, see https://github.com/numaproj/numaflow-go/tree/main/pkg/mapper/examples/even_odd + image: quay.io/numaio/numaflow-go/map-even-odd:v0.5.0 + - name: compute-sum + udf: + container: + # compute the sum + image: quay.io/numaio/numaflow-python/reduce-counter:latest + imagePullPolicy: Always + env: + - name: PYTHONDEBUG + value: "true" + - name: INVOKE + value: "class" + groupBy: + window: + fixed: + length: 60s + keyed: true + storage: + persistentVolumeClaim: + volumeSize: 10Gi + accessMode: ReadWriteOnce + partitions: 1 + - name: sink + scale: + min: 1 + sink: + log: {} + edges: + - from: in + to: atoi + - from: atoi + to: compute-sum + - from: compute-sum + to: sink diff --git a/examples/reduce/counter/pyproject.toml b/examples/reduce/counter/pyproject.toml index 7c956677..dc3cc41d 100644 --- a/examples/reduce/counter/pyproject.toml +++ b/examples/reduce/counter/pyproject.toml @@ -6,8 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" -aiorun = "^2022.11.1" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/sideinput/simple-sideinput/Makefile b/examples/sideinput/simple-sideinput/Makefile index 2d36ab46..cd2f0add 100644 --- a/examples/sideinput/simple-sideinput/Makefile +++ b/examples/sideinput/simple-sideinput/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/sideinput-example:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/sideinput-example:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. # To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/sideinput/simple-sideinput/example.py b/examples/sideinput/simple-sideinput/example.py index 8cb73a62..c4265711 100644 --- a/examples/sideinput/simple-sideinput/example.py +++ b/examples/sideinput/simple-sideinput/example.py @@ -1,26 +1,27 @@ import datetime -from pynumaflow.sideinput import Response, SideInput +from pynumaflow.sideinput import Response, SideInputServer, SideInput -counter = 0 +class ExampleSideInput(SideInput): + def __init__(self): + self.counter = 0 -def my_handler() -> Response: - """ - This function is called every time the side input is requested. - """ - time_now = datetime.datetime.now() - # val is the value to be broadcasted - val = "an example:" + str(time_now) - global counter - counter += 1 - # broadcast every other time - if counter % 2 == 0: - # no_broadcast_message() is used to indicate that there is no broadcast - return Response.no_broadcast_message() - # broadcast_message() is used to indicate that there is a broadcast - return Response.broadcast_message(val.encode("utf-8")) + def retrieve_handler(self) -> Response: + """ + This function is called every time the side input is requested. 
+ """ + time_now = datetime.datetime.now() + # val is the value to be broadcasted + val = f"an example: {str(time_now)}" + self.counter += 1 + # broadcast every other time + if self.counter % 2 == 0: + # no_broadcast_message() is used to indicate that there is no broadcast + return Response.no_broadcast_message() + # broadcast_message() is used to indicate that there is a broadcast + return Response.broadcast_message(val.encode("utf-8")) if __name__ == "__main__": - grpc_server = SideInput(handler=my_handler) + grpc_server = SideInputServer(ExampleSideInput()) grpc_server.start() diff --git a/examples/sideinput/simple-sideinput/pipeline-numaflow.yaml b/examples/sideinput/simple-sideinput/pipeline-numaflow.yaml index 3d471fb4..4e07d5f1 100644 --- a/examples/sideinput/simple-sideinput/pipeline-numaflow.yaml +++ b/examples/sideinput/simple-sideinput/pipeline-numaflow.yaml @@ -6,7 +6,7 @@ spec: sideInputs: - name: myticker container: - image: "quay.io/numaio/numaflow-python/sideinput-example:v0.5.0" + image: "quay.io/numaio/numaflow-python/sideinput-example:v0.7.0" imagePullPolicy: Always trigger: schedule: "*/2 * * * *" diff --git a/examples/sideinput/simple-sideinput/pyproject.toml b/examples/sideinput/simple-sideinput/pyproject.toml index 361ba9e5..441e8dd4 100644 --- a/examples/sideinput/simple-sideinput/pyproject.toml +++ b/examples/sideinput/simple-sideinput/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/sideinput/simple-sideinput/udf/example.py b/examples/sideinput/simple-sideinput/udf/example.py index a155c2d6..5f3bc8f1 100644 --- a/examples/sideinput/simple-sideinput/udf/example.py +++ b/examples/sideinput/simple-sideinput/udf/example.py @@ -1,6 +1,6 @@ from threading import Thread import pynumaflow.sideinput as sideinputsdk -from pynumaflow.mapper import Messages, Mapper, Message, Datum +from pynumaflow.mapper import Messages, MapServer, Message, Datum from watchfiles import watch @@ -14,7 +14,7 @@ def watcher(): """ This function is used to watch the side input directory for changes. """ - path = sideinputsdk.SideInput.SIDE_INPUT_DIR_PATH + path = sideinputsdk.SIDE_INPUT_DIR_PATH for changes in watch(path): print(changes) @@ -24,7 +24,7 @@ def watcher(): This function is used to start the GRPC server and the watcher thread. """ daemon = Thread(target=watcher, daemon=True, name="Monitor") - grpc_server = Mapper(handler=my_handler) + grpc_server = MapServer(my_handler) thread_server = Thread(target=grpc_server.start, daemon=True, name="GRPC Server") daemon.start() thread_server.start() diff --git a/examples/sink/async_log/Makefile b/examples/sink/async_log/Makefile index 5f300878..49ed516f 100644 --- a/examples/sink/async_log/Makefile +++ b/examples/sink/async_log/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/async-sink-log:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/async-sink-log:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. 
# To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/sink/async_log/example.py b/examples/sink/async_log/example.py index d8b9ac1f..7e338c3e 100644 --- a/examples/sink/async_log/example.py +++ b/examples/sink/async_log/example.py @@ -1,18 +1,35 @@ +import os from collections.abc import AsyncIterable +from pynumaflow.sinker import Datum, Responses, Response, Sinker +from pynumaflow.sinker import SinkAsyncServer +import logging -import aiorun +logging.basicConfig(level=logging.DEBUG) +_LOGGER = logging.getLogger(__name__) -from pynumaflow.sinker import Datum, Responses, Response, AsyncSinker + +class UserDefinedSink(Sinker): + async def handler(self, datums: AsyncIterable[Datum]) -> Responses: + responses = Responses() + async for msg in datums: + _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) + responses.append(Response.as_success(msg.id)) + return responses async def udsink_handler(datums: AsyncIterable[Datum]) -> Responses: responses = Responses() async for msg in datums: - print("User Defined Sink", msg.value.decode("utf-8")) + _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) responses.append(Response.as_success(msg.id)) return responses if __name__ == "__main__": - grpc_server = AsyncSinker(handler=udsink_handler) - aiorun.run(grpc_server.start()) + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + sink_handler = UserDefinedSink() + else: + sink_handler = udsink_handler + grpc_server = SinkAsyncServer(sink_handler) + grpc_server.start() diff --git a/examples/sink/async_log/pipeline-numaflow.yaml b/examples/sink/async_log/pipeline-numaflow.yaml index 690d1d7d..b2e23ac8 100644 --- a/examples/sink/async_log/pipeline-numaflow.yaml +++ b/examples/sink/async_log/pipeline-numaflow.yaml @@ -21,7 +21,13 @@ spec: args: - python - example.py - image: quay.io/numaio/numaflow-python/async-sink-log:latest + image: quay.io/numaio/numaflow-python/async-sink-log:v0.7.0 + imagePullPolicy: Always + env: + - name: PYTHONDEBUG + value: "true" + - name: INVOKE + value: "func_handler" - name: log-output sink: log: {} diff --git a/examples/sink/async_log/pyproject.toml b/examples/sink/async_log/pyproject.toml index 629d9c26..583a6388 100644 --- a/examples/sink/async_log/pyproject.toml +++ b/examples/sink/async_log/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/sink/log/Makefile b/examples/sink/log/Makefile index f0997a86..fa77ab1a 100644 --- a/examples/sink/log/Makefile +++ b/examples/sink/log/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/sink-log:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/sink-log:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. 
# To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/sink/log/example.py b/examples/sink/log/example.py index 653d8489..2c960139 100644 --- a/examples/sink/log/example.py +++ b/examples/sink/log/example.py @@ -1,16 +1,35 @@ +import os from collections.abc import Iterator +from pynumaflow.sinker import Datum, Responses, Response, SinkServer +from pynumaflow.sinker import Sinker +import logging -from pynumaflow.sinker import Datum, Responses, Response, Sinker +logging.basicConfig(level=logging.DEBUG) +_LOGGER = logging.getLogger(__name__) + + +class UserDefinedSink(Sinker): + def handler(self, datums: Iterator[Datum]) -> Responses: + responses = Responses() + for msg in datums: + _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) + responses.append(Response.as_success(msg.id)) + return responses def udsink_handler(datums: Iterator[Datum]) -> Responses: responses = Responses() for msg in datums: - print("User Defined Sink", msg.value.decode("utf-8")) + _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) responses.append(Response.as_success(msg.id)) return responses if __name__ == "__main__": - grpc_server = Sinker(handler=udsink_handler) + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + sink_handler = UserDefinedSink() + else: + sink_handler = udsink_handler + grpc_server = SinkServer(sink_handler) grpc_server.start() diff --git a/examples/sink/log/pipeline-numaflow.yaml b/examples/sink/log/pipeline-numaflow.yaml index c2bc60ad..609ed58c 100644 --- a/examples/sink/log/pipeline-numaflow.yaml +++ b/examples/sink/log/pipeline-numaflow.yaml @@ -21,7 +21,13 @@ spec: args: - python - example.py - image: quay.io/numaio/numaflow-python/sink-log:latest + image: "quay.io/numaio/numaflow-python/sink-log:v0.7.0" + imagePullPolicy: Always + env: + - name: PYTHONDEBUG + value: "true" + - name: INVOKE + value: "func_handler" - name: log-output sink: log: {} diff --git a/examples/sink/log/pyproject.toml b/examples/sink/log/pyproject.toml index 629d9c26..583a6388 100644 --- a/examples/sink/log/pyproject.toml +++ b/examples/sink/log/pyproject.toml @@ -6,7 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/source/async-source/Makefile b/examples/source/async-source/Makefile index 782ee86d..ddcc242e 100644 --- a/examples/source/async-source/Makefile +++ b/examples/source/async-source/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/async-source:v0.5.5" . + docker build -t "quay.io/numaio/numaflow-python/async-source:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. 
# To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/source/async-source/example.py b/examples/source/async-source/example.py index cf413808..352175ff 100644 --- a/examples/source/async-source/example.py +++ b/examples/source/async-source/example.py @@ -1,7 +1,5 @@ -from datetime import datetime from collections.abc import AsyncIterable - -import aiorun +from datetime import datetime from pynumaflow.sourcer import ( ReadRequest, @@ -9,13 +7,14 @@ AckRequest, PendingResponse, Offset, - AsyncSourcer, PartitionsResponse, get_default_partitions, + Sourcer, + SourceAsyncServer, ) -class AsyncSource: +class AsyncSource(Sourcer): """ AsyncSource is a class for User Defined Source implementation. """ @@ -69,10 +68,5 @@ async def partitions_handler(self) -> PartitionsResponse: if __name__ == "__main__": ud_source = AsyncSource() - grpc_server = AsyncSourcer( - read_handler=ud_source.read_handler, - ack_handler=ud_source.ack_handler, - pending_handler=ud_source.pending_handler, - partitions_handler=ud_source.partitions_handler, - ) - aiorun.run(grpc_server.start()) + grpc_server = SourceAsyncServer(ud_source) + grpc_server.start() diff --git a/examples/source/async-source/pipeline-numaflow.yaml b/examples/source/async-source/pipeline-numaflow.yaml index 1626001e..da8d482d 100644 --- a/examples/source/async-source/pipeline-numaflow.yaml +++ b/examples/source/async-source/pipeline-numaflow.yaml @@ -9,7 +9,7 @@ spec: udsource: container: # A simple user-defined async source - image: quay.io/numaio/numaflow-python/async-source:v0.5.5 + image: "quay.io/numaio/numaflow-python/async-source:v0.7.0" imagePullPolicy: Always limits: readBatchSize: 2 diff --git a/examples/source/async-source/pyproject.toml b/examples/source/async-source/pyproject.toml index bf23dd70..60fba875 100644 --- a/examples/source/async-source/pyproject.toml +++ b/examples/source/async-source/pyproject.toml @@ -6,8 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" -aiorun = "^2023.7" +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/source/simple-source/Makefile b/examples/source/simple-source/Makefile index 1c350bc5..647da242 100644 --- a/examples/source/simple-source/Makefile +++ b/examples/source/simple-source/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/simple-source:v0.5.5" . + docker build -t "quay.io/numaio/numaflow-python/simple-source:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. # To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/source/simple-source/example.py b/examples/source/simple-source/example.py index 49047b27..a2ae814a 100644 --- a/examples/source/simple-source/example.py +++ b/examples/source/simple-source/example.py @@ -4,16 +4,17 @@ from pynumaflow.sourcer import ( ReadRequest, Message, - Sourcer, AckRequest, PendingResponse, Offset, PartitionsResponse, get_default_partitions, + Sourcer, + SourceServer, ) -class SimpleSource: +class SimpleSource(Sourcer): """ SimpleSource is a class for User Defined Source implementation. 
""" @@ -67,10 +68,5 @@ def partitions_handler(self) -> PartitionsResponse: if __name__ == "__main__": ud_source = SimpleSource() - grpc_server = Sourcer( - read_handler=ud_source.read_handler, - ack_handler=ud_source.ack_handler, - pending_handler=ud_source.pending_handler, - partitions_handler=ud_source.partitions_handler, - ) + grpc_server = SourceServer(ud_source) grpc_server.start() diff --git a/examples/source/simple-source/pipeline-numaflow.yaml b/examples/source/simple-source/pipeline-numaflow.yaml index 920ded41..50246f7d 100644 --- a/examples/source/simple-source/pipeline-numaflow.yaml +++ b/examples/source/simple-source/pipeline-numaflow.yaml @@ -9,7 +9,7 @@ spec: udsource: container: # A simple user-defined source for e2e testing - image: quay.io/numaio/numaflow-python/simple-source:v0.5.4 + image: quay.io/numaio/numaflow-python/simple-source:v0.7.0 imagePullPolicy: Always limits: readBatchSize: 2 diff --git a/examples/source/simple-source/pyproject.toml b/examples/source/simple-source/pyproject.toml index 82428bc2..a2fa357a 100644 --- a/examples/source/simple-source/pyproject.toml +++ b/examples/source/simple-source/pyproject.toml @@ -6,8 +6,7 @@ authors = ["Numaflow developers"] [tool.poetry.dependencies] python = "~3.10" -pynumaflow = "~0.6.0" - +pynumaflow = "~0.7.0" [tool.poetry.dev-dependencies] diff --git a/examples/sourcetransform/event_time_filter/Makefile b/examples/sourcetransform/event_time_filter/Makefile index c47eac03..e22f7be8 100644 --- a/examples/sourcetransform/event_time_filter/Makefile +++ b/examples/sourcetransform/event_time_filter/Makefile @@ -1,6 +1,6 @@ .PHONY: image image: - docker build -t "quay.io/numaio/numaflow-python/mapt-event-time-filter:v0.5.0" . + docker build -t "quay.io/numaio/numaflow-python/mapt-event-time-filter:v0.7.0" . # Github CI runner uses platform linux/amd64. If your local environment don't, the image built by command above might not work # under the CI E2E test environment. 
# To build an image that supports multiple platforms(linux/amd64,linux/arm64) and push to quay.io, use the following command diff --git a/examples/sourcetransform/event_time_filter/example.py b/examples/sourcetransform/event_time_filter/example.py index e33604dc..add91b96 100644 --- a/examples/sourcetransform/event_time_filter/example.py +++ b/examples/sourcetransform/event_time_filter/example.py @@ -1,7 +1,7 @@ import datetime import logging -from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformer +from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformServer """ This is a simple User Defined Function example which receives a message, applies the following @@ -43,5 +43,5 @@ def my_handler(keys: list[str], datum: Datum) -> Messages: if __name__ == "__main__": - grpc_server = SourceTransformer(handler=my_handler) + grpc_server = SourceTransformServer(my_handler) grpc_server.start() diff --git a/examples/sourcetransform/event_time_filter/pyproject.toml b/examples/sourcetransform/event_time_filter/pyproject.toml index a6d19b57..31dfd83c 100644 --- a/examples/sourcetransform/event_time_filter/pyproject.toml +++ b/examples/sourcetransform/event_time_filter/pyproject.toml @@ -8,7 +8,7 @@ packages = [{include = "mapt_event_time_filter"}] [tool.poetry.dependencies] python = ">=3.9, <3.12" -pynumaflow = "~0.6.0" +pynumaflow = "~0.7.0" [build-system] requires = ["poetry-core"] diff --git a/poetry.lock b/poetry.lock index 8b310569..49aa67ad 100644 --- a/poetry.lock +++ b/poetry.lock @@ -762,6 +762,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -769,8 +770,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -787,6 +795,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -794,6 +803,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -915,6 +925,50 @@ secure = 
["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvloop" +version = "0.19.0" +description = "Fast implementation of asyncio event loop on top of libuv" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:de4313d7f575474c8f5a12e163f6d89c0a878bc49219641d49e6f1444369a90e"}, + {file = "uvloop-0.19.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5588bd21cf1fcf06bded085f37e43ce0e00424197e7c10e77afd4bbefffef428"}, + {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b1fd71c3843327f3bbc3237bedcdb6504fd50368ab3e04d0410e52ec293f5b8"}, + {file = "uvloop-0.19.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a05128d315e2912791de6088c34136bfcdd0c7cbc1cf85fd6fd1bb321b7c849"}, + {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cd81bdc2b8219cb4b2556eea39d2e36bfa375a2dd021404f90a62e44efaaf957"}, + {file = "uvloop-0.19.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:5f17766fb6da94135526273080f3455a112f82570b2ee5daa64d682387fe0dcd"}, + {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4ce6b0af8f2729a02a5d1575feacb2a94fc7b2e983868b009d51c9a9d2149bef"}, + {file = "uvloop-0.19.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:31e672bb38b45abc4f26e273be83b72a0d28d074d5b370fc4dcf4c4eb15417d2"}, + {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:570fc0ed613883d8d30ee40397b79207eedd2624891692471808a95069a007c1"}, + {file = "uvloop-0.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5138821e40b0c3e6c9478643b4660bd44372ae1e16a322b8fc07478f92684e24"}, + {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:91ab01c6cd00e39cde50173ba4ec68a1e578fee9279ba64f5221810a9e786533"}, + {file = "uvloop-0.19.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:47bf3e9312f63684efe283f7342afb414eea4d3011542155c7e625cd799c3b12"}, + {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:da8435a3bd498419ee8c13c34b89b5005130a476bda1d6ca8cfdde3de35cd650"}, + {file = "uvloop-0.19.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:02506dc23a5d90e04d4f65c7791e65cf44bd91b37f24cfc3ef6cf2aff05dc7ec"}, + {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2693049be9d36fef81741fddb3f441673ba12a34a704e7b4361efb75cf30befc"}, + {file = "uvloop-0.19.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7010271303961c6f0fe37731004335401eb9075a12680738731e9c92ddd96ad6"}, + {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5daa304d2161d2918fa9a17d5635099a2f78ae5b5960e742b2fcfbb7aefaa593"}, + {file = "uvloop-0.19.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:7207272c9520203fea9b93843bb775d03e1cf88a80a936ce760f60bb5add92f3"}, + {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:78ab247f0b5671cc887c31d33f9b3abfb88d2614b84e4303f1a63b46c046c8bd"}, + {file = "uvloop-0.19.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:472d61143059c84947aa8bb74eabbace30d577a03a1805b77933d6bd13ddebbd"}, + {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:45bf4c24c19fb8a50902ae37c5de50da81de4922af65baf760f7c0c42e1088be"}, + {file = "uvloop-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271718e26b3e17906b28b67314c45d19106112067205119dddbd834c2b7ce797"}, + {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:34175c9fd2a4bc3adc1380e1261f60306344e3407c20a4d684fd5f3be010fa3d"}, + {file = "uvloop-0.19.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e27f100e1ff17f6feeb1f33968bc185bf8ce41ca557deee9d9bbbffeb72030b7"}, + {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:13dfdf492af0aa0a0edf66807d2b465607d11c4fa48f4a1fd41cbea5b18e8e8b"}, + {file = "uvloop-0.19.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6e3d4e85ac060e2342ff85e90d0c04157acb210b9ce508e784a944f852a40e67"}, + {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8ca4956c9ab567d87d59d49fa3704cf29e37109ad348f2d5223c9bf761a332e7"}, + {file = "uvloop-0.19.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f467a5fd23b4fc43ed86342641f3936a68ded707f4627622fa3f82a120e18256"}, + {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:492e2c32c2af3f971473bc22f086513cedfc66a130756145a931a90c3958cb17"}, + {file = "uvloop-0.19.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2df95fca285a9f5bfe730e51945ffe2fa71ccbfdde3b0da5772b4ee4f2e770d5"}, + {file = "uvloop-0.19.0.tar.gz", hash = "sha256:0246f4fd1bf2bf702e06b0d45ee91677ee5c31242f39aab4ea6fe0c51aedd0fd"}, +] + +[package.extras] +docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] + [[package]] name = "virtualenv" version = "20.24.5" @@ -938,4 +992,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.12" -content-hash = "6372817a9a99177a328bfc5ec53fb44d5cf5b66205c3079bc35751561363966e" +content-hash = "e6fd5e2ffdc1b0e57b4cba288c6cb20260a66250085fe1f0b4f5982488ad81b4" diff --git a/pynumaflow/_constants.py b/pynumaflow/_constants.py index 253d0401..ca9766c4 100644 --- a/pynumaflow/_constants.py +++ b/pynumaflow/_constants.py @@ -1,3 +1,9 @@ +import logging +import os +from enum import Enum + +from pynumaflow import setup_logging + MAP_SOCK_PATH = "/var/run/numaflow/map.sock" MAP_STREAM_SOCK_PATH = "/var/run/numaflow/mapstream.sock" REDUCE_SOCK_PATH = "/var/run/numaflow/reduce.sock" @@ -7,6 +13,7 @@ MULTIPROC_MAP_SOCK_ADDR = "0.0.0.0" SIDE_INPUT_SOCK_PATH = "/var/run/numaflow/sideinput.sock" SOURCE_SOCK_PATH = "/var/run/numaflow/source.sock" +SIDE_INPUT_DIR_PATH = "/var/numaflow/side-inputs" # TODO: need to make sure the DATUM_KEY value is the same as # https://github.com/numaproj/numaflow-go/blob/main/pkg/function/configs.go#L6 @@ -17,3 +24,23 @@ STREAM_EOF = "EOF" DELIMITER = ":" DROP = "U+005C__DROP__" + +_PROCESS_COUNT = os.cpu_count() +MAX_THREADS = int(os.getenv("MAX_THREADS", "4")) + +_LOGGER = setup_logging(__name__) +if os.getenv("PYTHONDEBUG"): + _LOGGER.setLevel(logging.DEBUG) + + +class UDFType(str, Enum): + """ + Enumerate the type of UDF. 
+ """ + + Map = "map" + Reduce = "reduce" + Sink = "sink" + Source = "source" + SideInput = "sideinput" + SourceTransformer = "sourcetransformer" diff --git a/pynumaflow/mapper/__init__.py b/pynumaflow/mapper/__init__.py index 374b123f..a713d039 100644 --- a/pynumaflow/mapper/__init__.py +++ b/pynumaflow/mapper/__init__.py @@ -1,12 +1,8 @@ -from pynumaflow.mapper._dtypes import ( - Message, - Messages, - Datum, - DROP, -) -from pynumaflow.mapper.async_server import AsyncMapper -from pynumaflow.mapper.multiproc_server import MultiProcMapper -from pynumaflow.mapper.server import Mapper +from pynumaflow.mapper.async_server import MapAsyncServer +from pynumaflow.mapper.multiproc_server import MapMultiprocServer +from pynumaflow.mapper.sync_server import MapServer + +from pynumaflow.mapper._dtypes import Message, Messages, Datum, DROP, Mapper __all__ = [ "Message", @@ -14,6 +10,7 @@ "Datum", "DROP", "Mapper", - "AsyncMapper", - "MultiProcMapper", + "MapServer", + "MapAsyncServer", + "MapMultiprocServer", ] diff --git a/pynumaflow/mapper/_dtypes.py b/pynumaflow/mapper/_dtypes.py index 92556a9d..607ef4c8 100644 --- a/pynumaflow/mapper/_dtypes.py +++ b/pynumaflow/mapper/_dtypes.py @@ -1,7 +1,8 @@ +from abc import ABCMeta, abstractmethod from collections.abc import Iterator, Sequence, Awaitable from dataclasses import dataclass from datetime import datetime -from typing import TypeVar, Callable +from typing import TypeVar, Callable, Union from warnings import warn from pynumaflow._constants import DROP @@ -162,5 +163,31 @@ def watermark(self) -> datetime: return self._watermark -MapCallable = Callable[[list[str], Datum], Messages] -MapAsyncCallable = Callable[[list[str], Datum], Awaitable[Messages]] +class Mapper(metaclass=ABCMeta): + """ + Provides an interface to write a SyncMapServicer + which will be exposed over a Synchronous gRPC server. + """ + + def __call__(self, *args, **kwargs): + """ + This allows to execute the handler function directly if + class instance is sent as a callable. + """ + return self.handler(*args, **kwargs) + + @abstractmethod + def handler(self, keys: list[str], datum: Datum) -> Messages: + """ + Implement this handler function which implements the MapSyncCallable interface. 
+ """ + pass + + +# MapSyncCallable is a callable which can be used as a handler for the Synchronous Map UDF +MapSyncHandlerCallable = Callable[[list[str], Datum], Messages] +MapSyncCallable = Union[Mapper, MapSyncHandlerCallable] + +# MapAsyncCallable is a callable which can be used as a handler for the Asynchronous Map UDF +MapAsyncHandlerCallable = Callable[[list[str], Datum], Awaitable[Messages]] +MapAsyncCallable = Union[Mapper, MapAsyncHandlerCallable] diff --git a/pynumaflow/mapper/async_server.py b/pynumaflow/mapper/async_server.py index 2c479848..534c6606 100644 --- a/pynumaflow/mapper/async_server.py +++ b/pynumaflow/mapper/async_server.py @@ -1,161 +1,100 @@ -import logging -import multiprocessing import os - +import aiorun import grpc -from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow import setup_logging from pynumaflow._constants import ( + MAX_THREADS, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, ) -from pynumaflow.mapper import Datum from pynumaflow.mapper._dtypes import MapAsyncCallable -from pynumaflow.mapper.proto import map_pb2 -from pynumaflow.mapper.proto import map_pb2_grpc -from pynumaflow.types import NumaflowServicerContext -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) - -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) +from pynumaflow.mapper.servicer.async_servicer import AsyncMapServicer +from pynumaflow.proto.mapper import map_pb2_grpc +from pynumaflow.shared.server import ( + NumaflowServer, + start_async_server, +) -class AsyncMapper(map_pb2_grpc.MapServicer): +class MapAsyncServer(NumaflowServer): """ - Provides an interface to write an Async Mapper - which will be exposed over gRPC. - + Create a new grpc Map Server instance. Args: - handler: Function callable following the type signature of MapCallable - sock_path: Path to the UNIX Domain Socket + mapper_instance: The mapper instance to be used for Map UDF + sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x4 + defaults to number of processors x4 Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.mapper import Messages, Message\ - ... Datum, AsyncMapper - ... import aiorun - ... - >>> async def map_handler(key: [str], datum: Datum) -> Messages: - ... val = datum.value - ... _ = datum.event_time - ... _ = datum.watermark - ... messages = Messages(Message(val, keys=keys)) - ... return messages - ... 
- >>> grpc_server = AsyncMapper(handler=map_handler) - >>> aiorun.run(grpc_server.start()) + from pynumaflow.mapper import Messages, Message, Datum, MapAsyncServer + async def async_map_handler(keys: list[str], datum: Datum) -> Messages: + val = datum.value + msg = "payload:{} event_time:{} watermark:{}".format( + val.decode("utf-8"), + datum.event_time, + datum.watermark, + ) + val = bytes(msg, encoding="utf-8") + return Messages(Message(value=val, keys=keys)) + + if __name__ == "__main__": + grpc_server = MapAsyncServer(async_map_handler) + grpc_server.start() """ def __init__( self, - handler: MapAsyncCallable, + mapper_instance: MapAsyncCallable, sock_path=MAP_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, max_threads=MAX_THREADS, ): - self.__map_handler: MapAsyncCallable = handler + """ + Create a new grpc Asynchronous Map Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. + Args: + mapper_instance: The mapper instance to be used for Map UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + """ self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads - self.cleanup_coroutines = [] - # Collection for storing strong references to all running tasks. - # Event loop only keeps a weak reference, which can cause it to - # get lost during execution. - self.background_tasks = set() + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size + + self.mapper_instance = mapper_instance self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ] + # Get the servicer instance for the async server + self.servicer = AsyncMapServicer(handler=mapper_instance) - async def MapFn( - self, request: map_pb2.MapRequest, context: NumaflowServicerContext - ) -> map_pb2.MapResponse: - """ - Applies a function to each datum element. - The pascal case function name comes from the proto map_pb2_grpc.py file. - """ - # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - try: - res = await self.__invoke_map( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - ), - ) - except Exception as e: - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(e)) - return map_pb2.MapResponse(results=[]) - - return map_pb2.MapResponse(results=res) - - async def __invoke_map(self, keys: list[str], req: Datum): + def start(self) -> None: """ - Invokes the user defined function. 
+ Starter function for the Async server class, need a separate caller + so that all the async coroutines can be started from a single context """ - try: - msgs = await self.__map_handler(keys, req) - except Exception as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - raise err - datums = [] - for msg in msgs: - datums.append(map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)) + aiorun.run(self.aexec(), use_uvloop=True) - return datums - - async def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> map_pb2.ReadyResponse: + async def aexec(self) -> None: """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto map_pb2_grpc.py file. + Starts the Async gRPC server on the given UNIX socket with + given max threads. """ - return map_pb2.ReadyResponse(ready=True) - - async def __serve_async(self, server) -> None: - map_pb2_grpc.add_MapServicer_to_server( - AsyncMapper(handler=self.__map_handler), - server, - ) - server.add_insecure_port(self.sock_path) - _LOGGER.info("gRPC Async Map Server listening on: %s", self.sock_path) - await server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - async def server_graceful_shutdown(): - """ - Shuts down the server with 5 seconds of grace period. During the - grace period, the server won't accept new connections and allow - existing RPCs to continue within the grace period. - """ - _LOGGER.info("Starting graceful shutdown...") - await server.stop(5) + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context - self.cleanup_coroutines.append(server_graceful_shutdown()) - await server.wait_for_termination() + server_new = grpc.aio.server() + server_new.add_insecure_port(self.sock_path) + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server_new) - async def start(self) -> None: - """Starts the Async gRPC mapper on the given UNIX socket.""" - server = grpc.aio.server(options=self._server_options) - await self.__serve_async(server) + # Start the async server + await start_async_server(server_new, self.sock_path, self.max_threads, self._server_options) diff --git a/pynumaflow/mapper/multiproc_server.py b/pynumaflow/mapper/multiproc_server.py index d14fde93..e91a1fa2 100644 --- a/pynumaflow/mapper/multiproc_server.py +++ b/pynumaflow/mapper/multiproc_server.py @@ -1,208 +1,111 @@ -import contextlib -import logging -import multiprocessing import os -import socket -from concurrent import futures -from collections.abc import Iterator -import grpc -from google.protobuf import empty_pb2 as _empty_pb2 - -from pynumaflow import setup_logging from pynumaflow._constants import ( + MAX_THREADS, MAX_MESSAGE_SIZE, + MAP_SOCK_PATH, + UDFType, + _PROCESS_COUNT, ) -from pynumaflow._constants import MULTIPROC_MAP_SOCK_ADDR -from pynumaflow.exceptions import SocketError -from pynumaflow.mapper import Datum -from pynumaflow.mapper._dtypes import MapCallable -from pynumaflow.mapper.proto import map_pb2 -from pynumaflow.mapper.proto import map_pb2_grpc -from pynumaflow.types import NumaflowServicerContext -from pynumaflow.info.server import ( - get_sdk_version, - write as info_server_write, - get_metadata_env, -) -from pynumaflow.info.types import ( - ServerInfo, - Protocol, - Language, - 
SERVER_INFO_FILE_PATH, - METADATA_ENVS, +from pynumaflow.mapper._dtypes import MapSyncCallable +from pynumaflow.mapper.servicer.sync_servicer import SyncMapServicer +from pynumaflow.shared.server import ( + NumaflowServer, + start_multiproc_server, ) -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) - -class MultiProcMapper(map_pb2_grpc.MapServicer): +class MapMultiprocServer(NumaflowServer): """ - Provides an interface to write a Multi Proc Mapper - which will be exposed over gRPC. - - Args: - handler: Function callable following the type signature of MapCallable - max_message_size: The max message size in bytes the server can receive and send - - Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.mapper import Messages, Message \ - ... Datum, MultiProcMapper - ... - >>> def map_handler(keys: list[str], datum: Datum) -> Messages: - ... val = datum.value - ... _ = datum.event_time - ... _ = datum.watermark - ... messages = Messages(Message(val, keys=keys)) - ... return messages - ... - >>> grpc_server = MultiProcMapper(handler=map_handler) - >>> grpc_server.start() + Create a new grpc Multiproc Map Server instance. """ - __slots__ = ( - "__map_handler", - "_max_message_size", - "_server_options", - "_process_count", - "_threads_per_proc", - ) - def __init__( self, - handler: MapCallable, + mapper_instance: MapSyncCallable, + server_count: int = _PROCESS_COUNT, + sock_path=MAP_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, + max_threads=MAX_THREADS, ): - self.__map_handler: MapCallable = handler - self._max_message_size = max_message_size + """ + Create a new grpc Multiproc Map Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. 
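+        The number of server processes defaults to the CPU count and is
+        capped at 2 * the CPU count.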
+ Args: + mapper_instance: The mapper instance to be used for Map UDF + server_count: The number of grpc server instances to be forked for multiproc + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + + Example invocation: + import math + import os + from pynumaflow.mapper import Messages, Message, Datum, Mapper, MapMultiprocServer + + def is_prime(n): + for i in range(2, int(math.ceil(math.sqrt(n)))): + if n % i == 0: + return False + else: + return True + + class PrimeMap(Mapper): + def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + _ = datum.event_time + _ = datum.watermark + messages = Messages() + for i in range(2, 100000): + is_prime(i) + messages.append(Message(val, keys=keys)) + return messages + + if __name__ == "__main__": + # To set the env server_count value set the env variable + # NUM_CPU_MULTIPROC="N" + server_count = int(os.getenv("NUM_CPU_MULTIPROC", "2")) + prime_class = PrimeMap() + # Server count is the number of server processes to start + grpc_server = MapMultiprocServer(prime_class, server_count=server_count) + grpc_server.start() + + """ + self.sock_path = f"unix://{sock_path}" + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size + + self.mapper_instance = mapper_instance self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1), ] # Set the number of processes to be spawned to the number of CPUs or # the value of the env var NUM_CPU_MULTIPROC defined by the user # Setting the max value to 2 * CPU count - self._process_count = min( - int(os.getenv("NUM_CPU_MULTIPROC", str(os.cpu_count()))), 2 * os.cpu_count() - ) - self._threads_per_proc = int(os.getenv("MAX_THREADS", "4")) - - def MapFn( - self, request: map_pb2.MapRequest, context: NumaflowServicerContext - ) -> map_pb2.MapResponse: - """ - Applies a function to each datum element. - The pascal case function name comes from the proto map_pb2_grpc.py file. - """ - # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - try: - msgs = self.__map_handler( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - ), - ) - except Exception as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(err)) - return map_pb2.MapResponse(results=[]) - - datums = [] - - for msg in msgs: - datums.append(map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)) - - return map_pb2.MapResponse(results=datums) - - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> map_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto map_pb2_grpc.py file. 
- """ - return map_pb2.ReadyResponse(ready=True) - - def _run_server(self, bind_address: str) -> None: - """Start a server in a subprocess.""" - _LOGGER.info( - "Starting new server with num_procs: %s, num_threads/proc: %s", - self._process_count, - self._threads_per_proc, - ) - server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=self._threads_per_proc, - ), - options=self._server_options, - ) - map_pb2_grpc.add_MapServicer_to_server(self, server) - server.add_insecure_port(bind_address) - server.start() - _LOGGER.info("GRPC Multi-Processor Server listening on: %s %d", bind_address, os.getpid()) - server.wait_for_termination() - - @contextlib.contextmanager - def _reserve_port(self, port_num: int) -> Iterator[int]: - """Find and reserve a port for all subprocesses to use.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0: - raise SocketError("Failed to set SO_REUSEADDR.") - try: - sock.bind(("", port_num)) - yield sock.getsockname()[1] - finally: - sock.close() + # Used for multiproc server + self._process_count = min(server_count, 2 * _PROCESS_COUNT) + self.servicer = SyncMapServicer(handler=mapper_instance) def start(self) -> None: """ - Start N grpc servers in different processes where N = The number of CPUs or the + Starts the N grpc servers gRPC serves on the with + given max threads. + where N = The number of CPUs or the value of the env var NUM_CPU_MULTIPROC defined by the user. The max value is set to 2 * CPU count. - Each server will be bound to a different port, and we will create equal number of - workers to handle each server. - On the client side there will be same number of connections as the number of servers. """ - workers = [] - server_ports = [] - for _ in range(self._process_count): - # Find a port to bind to for each server, thus sending the port number = 0 - # to the _reserve_port function so that kernel can find and return a free port - with self._reserve_port(0) as port: - bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}" - _LOGGER.info("Starting server on port: %s", port) - # NOTE: It is imperative that the worker subprocesses be forked before - # any gRPC servers start up. See - # https://github.com/grpc/grpc/issues/16001 for more details. 
- worker = multiprocessing.Process(target=self._run_server, args=(bind_address,)) - worker.start() - workers.append(worker) - server_ports.append(port) - - # Convert the available ports to a comma separated string - ports = ",".join(map(str, server_ports)) - serv_info = ServerInfo( - protocol=Protocol.TCP, - language=Language.PYTHON, - version=get_sdk_version(), - metadata=get_metadata_env(envs=METADATA_ENVS), + # Start the multiproc server + start_multiproc_server( + max_threads=self.max_threads, + servicer=self.servicer, + process_count=self._process_count, + server_options=self._server_options, + udf_type=UDFType.Map, ) - # Add the PORTS metadata using the available ports - serv_info.metadata["SERV_PORTS"] = ports - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - - for worker in workers: - worker.join() diff --git a/pynumaflow/mapper/server.py b/pynumaflow/mapper/server.py deleted file mode 100644 index 0ae779ee..00000000 --- a/pynumaflow/mapper/server.py +++ /dev/null @@ -1,136 +0,0 @@ -import logging -import multiprocessing -import os -from concurrent.futures import ThreadPoolExecutor - -import grpc -from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH - -from pynumaflow import setup_logging -from pynumaflow._constants import ( - MAX_MESSAGE_SIZE, - MAP_SOCK_PATH, -) -from pynumaflow.mapper import Datum -from pynumaflow.mapper._dtypes import MapCallable -from pynumaflow.mapper.proto import map_pb2 -from pynumaflow.mapper.proto import map_pb2_grpc -from pynumaflow.types import NumaflowServicerContext - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) - - -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) - - -class Mapper(map_pb2_grpc.MapServicer): - """ - Provides an interface to write a Mapper - which will be exposed over a Synchronous gRPC server. - - Args: - handler: Function callable following the type signature of MapCallable - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x4 - - Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.mapper import Messages, Message\ - ... Datum, Mapper - ... - >>> def map_handler(key: [str], datum: Datum) -> Messages: - ... val = datum.value - ... _ = datum.event_time - ... _ = datum.watermark - ... messages = Messages(Message(val, keys=keys)) - ... return messages - ... - >>> grpc_server = Mapper(handler=map_handler) - >>> grpc_server.start() - """ - - def __init__( - self, - handler: MapCallable, - sock_path=MAP_SOCK_PATH, - max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, - ): - self.__map_handler: MapCallable = handler - self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads - self.cleanup_coroutines = [] - - self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), - ] - - def MapFn( - self, request: map_pb2.MapRequest, context: NumaflowServicerContext - ) -> map_pb2.MapResponse: - """ - Applies a function to each datum element. - The pascal case function name comes from the proto map_pb2_grpc.py file. 
- """ - # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - try: - msgs = self.__map_handler( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - ), - ) - except Exception as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(err)) - return map_pb2.MapResponse(results=[]) - - datums = [] - - for msg in msgs: - datums.append(map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)) - - return map_pb2.MapResponse(results=datums) - - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> map_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto map_pb2_grpc.py file. - """ - return map_pb2.ReadyResponse(ready=True) - - def start(self) -> None: - """ - Starts the gRPC server on the given UNIX socket with given max threads. - """ - server = grpc.server( - ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options - ) - map_pb2_grpc.add_MapServicer_to_server(self, server) - server.add_insecure_port(self.sock_path) - server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - _LOGGER.info( - "GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads - ) - server.wait_for_termination() diff --git a/pynumaflow/mapper/proto/__init__.py b/pynumaflow/mapper/servicer/__init__.py similarity index 100% rename from pynumaflow/mapper/proto/__init__.py rename to pynumaflow/mapper/servicer/__init__.py diff --git a/pynumaflow/mapper/servicer/async_servicer.py b/pynumaflow/mapper/servicer/async_servicer.py new file mode 100644 index 00000000..9b076cce --- /dev/null +++ b/pynumaflow/mapper/servicer/async_servicer.py @@ -0,0 +1,72 @@ +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.mapper._dtypes import Datum +from pynumaflow.mapper._dtypes import MapAsyncHandlerCallable, MapSyncCallable +from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +class AsyncMapServicer(map_pb2_grpc.MapServicer): + """ + This class is used to create a new grpc Async Map Servicer instance. + It implements the SyncMapServicer interface from the proto map.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: MapAsyncHandlerCallable, + ): + self.__map_handler: MapSyncCallable = handler + + async def MapFn( + self, request: map_pb2.MapRequest, context: NumaflowServicerContext + ) -> map_pb2.MapResponse: + """ + Applies a function to each datum element. + The pascal case function name comes from the proto map_pb2_grpc.py file. 
+ """ + # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer + # we need to explicitly convert it to list + try: + res = await self.__invoke_map( + list(request.keys), + Datum( + keys=list(request.keys), + value=request.value, + event_time=request.event_time.ToDatetime(), + watermark=request.watermark.ToDatetime(), + ), + ) + except Exception as e: + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(str(e)) + return map_pb2.MapResponse(results=[]) + + return map_pb2.MapResponse(results=res) + + async def __invoke_map(self, keys: list[str], req: Datum): + """ + Invokes the user defined function. + """ + try: + msgs = await self.__map_handler(keys, req) + except Exception as err: + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + raise err + datums = [] + for msg in msgs: + datums.append(map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)) + + return datums + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> map_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto map_pb2_grpc.py file. + """ + return map_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/mapper/servicer/sync_servicer.py b/pynumaflow/mapper/servicer/sync_servicer.py new file mode 100644 index 00000000..b01690d0 --- /dev/null +++ b/pynumaflow/mapper/servicer/sync_servicer.py @@ -0,0 +1,38 @@ +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.mapper._dtypes import MapSyncCallable +from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc +from pynumaflow.mapper.servicer.utils import _map_fn_util +from pynumaflow.types import NumaflowServicerContext + + +class SyncMapServicer(map_pb2_grpc.MapServicer): + """ + This class is used to create a new grpc Map Servicer instance. + It implements the SyncMapServicer interface from the proto map.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: MapSyncCallable, + ): + self.__map_handler: MapSyncCallable = handler + + def MapFn( + self, request: map_pb2.MapRequest, context: NumaflowServicerContext + ) -> map_pb2.MapResponse: + """ + Applies a function to each datum element. + The pascal case function name comes from the proto map_pb2_grpc.py file. + """ + return _map_fn_util(self.__map_handler, request, context) + + def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> map_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto map_pb2_grpc.py file. 
+ """ + return map_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/mapper/servicer/utils.py b/pynumaflow/mapper/servicer/utils.py new file mode 100644 index 00000000..c0c26185 --- /dev/null +++ b/pynumaflow/mapper/servicer/utils.py @@ -0,0 +1,36 @@ +import grpc +from pynumaflow.mapper._dtypes import MapSyncCallable + +from pynumaflow.mapper._dtypes import Datum +from pynumaflow.proto.mapper import map_pb2 +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +def _map_fn_util( + __map_handler: MapSyncCallable, request: map_pb2.MapRequest, context: NumaflowServicerContext +) -> map_pb2.MapResponse: + # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer + # we need to explicitly convert it to list + try: + msgs = __map_handler( + list(request.keys), + Datum( + keys=list(request.keys), + value=request.value, + event_time=request.event_time.ToDatetime(), + watermark=request.watermark.ToDatetime(), + ), + ) + except Exception as err: + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(str(err)) + return map_pb2.MapResponse(results=[]) + + datums = [] + + for msg in msgs: + datums.append(map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)) + + return map_pb2.MapResponse(results=datums) diff --git a/pynumaflow/mapper/sync_server.py b/pynumaflow/mapper/sync_server.py new file mode 100644 index 00000000..a45f582d --- /dev/null +++ b/pynumaflow/mapper/sync_server.py @@ -0,0 +1,109 @@ +import os + + +from pynumaflow.mapper.servicer.sync_servicer import SyncMapServicer + +from pynumaflow._constants import ( + MAX_THREADS, + MAX_MESSAGE_SIZE, + _LOGGER, + MAP_SOCK_PATH, + UDFType, +) + +from pynumaflow.mapper._dtypes import MapSyncCallable +from pynumaflow.shared.server import ( + NumaflowServer, + sync_server_start, +) + + +class MapServer(NumaflowServer): + """ + Create a new grpc Map Server instance. + Args: + mapper_instance: The mapper instance to be used for Map UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + + Example Invocation: + from pynumaflow.mapper import Messages, Message, Datum, MapServer, Mapper + + class MessageForwarder(Mapper): + def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + _ = datum.event_time + _ = datum.watermark + return Messages(Message(value=val, keys=keys)) + + def my_handler(keys: list[str], datum: Datum) -> Messages: + val = datum.value + _ = datum.event_time + _ = datum.watermark + return Messages(Message(value=val, keys=keys)) + + + if __name__ == "__main__": + Use the class based approach or function based handler + based on the env variable + Both can be used and passed directly to the server class + + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + handler = MessageForwarder() + else: + handler = my_handler + grpc_server = MapServer(handler) + grpc_server.start() + """ + + def __init__( + self, + mapper_instance: MapSyncCallable, + sock_path=MAP_SOCK_PATH, + max_message_size=MAX_MESSAGE_SIZE, + max_threads=MAX_THREADS, + ): + """ + Create a new grpc Synchronous Map Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. 
+        Args:
+            mapper_instance: The mapper instance to be used for Map UDF
+            sock_path: The UNIX socket path to be used for the server
+            max_message_size: The max message size in bytes the server can receive and send
+            max_threads: The max number of threads to be spawned;
+            defaults to number of processors x4
+        """
+        self.sock_path = f"unix://{sock_path}"
+        self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4")))
+        self.max_message_size = max_message_size
+
+        self.mapper_instance = mapper_instance
+
+        self._server_options = [
+            ("grpc.max_send_message_length", self.max_message_size),
+            ("grpc.max_receive_message_length", self.max_message_size),
+        ]
+        # Get the servicer instance for the sync server
+        self.servicer = SyncMapServicer(handler=mapper_instance)
+
+    def start(self) -> None:
+        """
+        Starts the Synchronous gRPC server on the given UNIX socket with given max threads.
+        """
+        _LOGGER.info(
+            "Sync GRPC Server listening on: %s with max threads: %s",
+            self.sock_path,
+            self.max_threads,
+        )
+        # Start the server
+        sync_server_start(
+            servicer=self.servicer,
+            bind_address=self.sock_path,
+            max_threads=self.max_threads,
+            server_options=self._server_options,
+            udf_type=UDFType.Map,
+        )
diff --git a/pynumaflow/mapstreamer/__init__.py b/pynumaflow/mapstreamer/__init__.py
index 2896a903..f26f4bd4 100644
--- a/pynumaflow/mapstreamer/__init__.py
+++ b/pynumaflow/mapstreamer/__init__.py
@@ -1,15 +1,13 @@
-from pynumaflow.mapstreamer._dtypes import (
-    Message,
-    Messages,
-    Datum,
-    DROP,
-)
-from pynumaflow.mapstreamer.async_server import AsyncMapStreamer
+from pynumaflow._constants import DROP
+
+from pynumaflow.mapstreamer._dtypes import Message, Messages, Datum, MapStreamer
+from pynumaflow.mapstreamer.async_server import MapStreamAsyncServer

 __all__ = [
     "Message",
     "Messages",
     "Datum",
     "DROP",
-    "AsyncMapStreamer",
+    "MapStreamAsyncServer",
+    "MapStreamer",
 ]
diff --git a/pynumaflow/mapstreamer/_dtypes.py b/pynumaflow/mapstreamer/_dtypes.py
index 27a1fb14..2a467b9c 100644
--- a/pynumaflow/mapstreamer/_dtypes.py
+++ b/pynumaflow/mapstreamer/_dtypes.py
@@ -1,7 +1,8 @@
+from abc import ABCMeta, abstractmethod
 from collections.abc import Iterator, Sequence
 from dataclasses import dataclass
 from datetime import datetime
-from typing import TypeVar, Callable
+from typing import TypeVar, Callable, Union
 from collections.abc import AsyncIterable
 from warnings import warn
@@ -163,4 +164,28 @@ def watermark(self) -> datetime:
     return self._watermark


-MapStreamCallable = Callable[[list[str], Datum], AsyncIterable[Message]]
+class MapStreamer(metaclass=ABCMeta):
+    """
+    Provides an interface to write a Map Streamer
+    which will be exposed over a gRPC server.
+    """
+
+    def __call__(self, *args, **kwargs):
+        """
+        Allows the class instance to be invoked directly as a callable;
+        the call is delegated to the handler method.
+        """
+        return self.handler(*args, **kwargs)
+
+    @abstractmethod
+    async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message]:
+        """
+        Implement this handler method; it satisfies the MapStreamCallable interface.
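+
+        Args:
+            keys: the keys of the input datum
+            datum: the input Datum, carrying value, event_time and watermark
+
+        Yields:
+            Message: the messages to stream downstream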
+ """ + pass + + +MapStreamAsyncCallable = Callable[[list[str], Datum], AsyncIterable[Message]] +MapStreamCallable = Union[MapStreamer, MapStreamAsyncCallable] diff --git a/pynumaflow/mapstreamer/async_server.py b/pynumaflow/mapstreamer/async_server.py index 284517ee..1092c456 100644 --- a/pynumaflow/mapstreamer/async_server.py +++ b/pynumaflow/mapstreamer/async_server.py @@ -1,151 +1,121 @@ -import logging -import multiprocessing import os -from collections.abc import AsyncIterable - +import aiorun import grpc -from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow import setup_logging +from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer +from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc + from pynumaflow._constants import ( - MAX_MESSAGE_SIZE, MAP_STREAM_SOCK_PATH, + MAX_MESSAGE_SIZE, + MAX_THREADS, + _LOGGER, ) -from pynumaflow.mapstreamer import Datum -from pynumaflow.mapstreamer._dtypes import MapStreamCallable -from pynumaflow.mapstreamer.proto import mapstream_pb2 -from pynumaflow.mapstreamer.proto import mapstream_pb2_grpc -from pynumaflow.types import NumaflowServicerContext -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) +from pynumaflow.mapstreamer._dtypes import MapStreamCallable -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) +from pynumaflow.shared.server import NumaflowServer, start_async_server -class AsyncMapStreamer(mapstream_pb2_grpc.MapStreamServicer): +class MapStreamAsyncServer(NumaflowServer): """ - Provides an interface to write a Map Streamer - which will be exposed over gRPC. - - Args: - handler: Function callable following the type signature of MapStreamCallable - sock_path: Path to the UNIX Domain Socket - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x4 - - Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.mapstreamer import Messages, Message \ - ... Datum, AsyncMapStreamer - ... import aiorun - >>> async def map_stream_handler(key: [str], datums: Datum) -> AsyncIterable[Message]: - ... val = datum.value - ... _ = datum.event_time - ... _ = datum.watermark - ... for i in range(10): - ... yield Message(val, keys=keys) - ... - >>> grpc_server = AsyncMapStreamer(handler=map_stream_handler) - >>> aiorun.run(grpc_server.start()) + Class for a new Map Stream Server instance. """ def __init__( self, - handler: MapStreamCallable, + map_stream_instance: MapStreamCallable, sock_path=MAP_STREAM_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, max_threads=MAX_THREADS, ): - self.__map_stream_handler: MapStreamCallable = handler + """ + Create a new grpc Async Map Stream Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. 
+ Args: + map_stream_instance: The map stream instance to be used for Map Stream UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + server_type: The type of server to be used + + Example invocation: + import os + from collections.abc import AsyncIterable + from pynumaflow.mapstreamer import Message, Datum, MapStreamAsyncServer, MapStreamer + + class FlatMapStream(MapStreamer): + async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message]: + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + + if len(strs) == 0: + yield Message.to_drop() + return + for s in strs: + yield Message(str.encode(s)) + + async def map_stream_handler(_: list[str], datum: Datum) -> AsyncIterable[Message]: + + val = datum.value + _ = datum.event_time + _ = datum.watermark + strs = val.decode("utf-8").split(",") + + if len(strs) == 0: + yield Message.to_drop() + return + for s in strs: + yield Message(str.encode(s)) + + if __name__ == "__main__": + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + handler = FlatMapStream() + else: + handler = map_stream_handler + grpc_server = MapStreamAsyncServer(handler) + grpc_server.start() + + """ + self.map_stream_instance: MapStreamCallable = map_stream_instance self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads - self.cleanup_coroutines = [] - # Collection for storing strong references to all running tasks. - # Event loop only keeps a weak reference, which can cause it to - # get lost during execution. - self.background_tasks = set() + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ] - async def MapStreamFn( - self, - request: mapstream_pb2.MapStreamRequest, - context: NumaflowServicerContext, - ) -> AsyncIterable[mapstream_pb2.MapStreamResponse]: + self.servicer = AsyncMapStreamServicer(handler=self.map_stream_instance) + + def start(self): """ - Applies a map function to a datum stream in streaming mode. - The pascal case function name comes from the proto mapstream_pb2_grpc.py file. 
+ Starter function for the Async Map Stream server, we need a separate caller + to the aexec so that all the async coroutines can be started from a single context """ + aiorun.run(self.aexec(), use_uvloop=True) - async for res in self.__invoke_map_stream( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - ), - ): - yield mapstream_pb2.MapStreamResponse(result=res) - - async def __invoke_map_stream(self, keys: list[str], req: Datum): - try: - async for msg in self.__map_stream_handler(keys, req): - yield mapstream_pb2.MapStreamResponse.Result( - keys=msg.keys, value=msg.value, tags=msg.tags - ) - except Exception as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - raise err - - async def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> mapstream_pb2.ReadyResponse: + async def aexec(self): """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto mapstream_pb2_grpc.py file. + Starts the Async gRPC server on the given UNIX socket with + given max threads. """ - return mapstream_pb2.ReadyResponse(ready=True) - - async def __serve_async(self, server) -> None: + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context + # Create a new async server instance and add the servicer to it + server = grpc.aio.server() + server.add_insecure_port(self.sock_path) mapstream_pb2_grpc.add_MapStreamServicer_to_server( - AsyncMapStreamer(handler=self.__map_stream_handler), + self.servicer, server, ) - server.add_insecure_port(self.sock_path) - _LOGGER.info("GRPC Async Server listening on: %s", self.sock_path) - await server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - - async def server_graceful_shutdown(): - """ - Shuts down the server with 5 seconds of grace period. During the - grace period, the server won't accept new connections and allow - existing RPCs to continue within the grace period. 
- """ - _LOGGER.info("Starting graceful shutdown...") - await server.stop(5) - - self.cleanup_coroutines.append(server_graceful_shutdown()) - await server.wait_for_termination() - - async def start(self) -> None: - """Starts the Async gRPC server on the given UNIX socket.""" - server = grpc.aio.server(options=self._server_options) - await self.__serve_async(server) + _LOGGER.info("Starting Map Stream Server") + await start_async_server(server, self.sock_path, self.max_threads, self._server_options) diff --git a/pynumaflow/mapstreamer/proto/__init__.py b/pynumaflow/mapstreamer/servicer/__init__.py similarity index 100% rename from pynumaflow/mapstreamer/proto/__init__.py rename to pynumaflow/mapstreamer/servicer/__init__.py diff --git a/pynumaflow/mapstreamer/servicer/async_servicer.py b/pynumaflow/mapstreamer/servicer/async_servicer.py new file mode 100644 index 00000000..33c8bd7c --- /dev/null +++ b/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -0,0 +1,64 @@ +from collections.abc import AsyncIterable + +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.mapstreamer import Datum +from pynumaflow.mapstreamer._dtypes import MapStreamCallable +from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc, mapstream_pb2 +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +class AsyncMapStreamServicer(mapstream_pb2_grpc.MapStreamServicer): + """ + This class is used to create a new grpc Map Stream Servicer instance. + It implements the SyncMapServicer interface from the proto + mapstream_pb2_grpc.py file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: MapStreamCallable, + ): + self.__map_stream_handler: MapStreamCallable = handler + + async def MapStreamFn( + self, + request: mapstream_pb2.MapStreamRequest, + context: NumaflowServicerContext, + ) -> AsyncIterable[mapstream_pb2.MapStreamResponse]: + """ + Applies a map function to a datum stream in streaming mode. + The pascal case function name comes from the proto mapstream_pb2_grpc.py file. + """ + + async for res in self.__invoke_map_stream( + list(request.keys), + Datum( + keys=list(request.keys), + value=request.value, + event_time=request.event_time.ToDatetime(), + watermark=request.watermark.ToDatetime(), + ), + ): + yield mapstream_pb2.MapStreamResponse(result=res) + + async def __invoke_map_stream(self, keys: list[str], req: Datum): + try: + async for msg in self.__map_stream_handler(keys, req): + yield mapstream_pb2.MapStreamResponse.Result( + keys=msg.keys, value=msg.value, tags=msg.tags + ) + except Exception as err: + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + raise err + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> mapstream_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto mapstream_pb2_grpc.py file. 
+ """ + return mapstream_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/reducer/proto/__init__.py b/pynumaflow/proto/__init__.py similarity index 100% rename from pynumaflow/reducer/proto/__init__.py rename to pynumaflow/proto/__init__.py diff --git a/pynumaflow/sideinput/proto/__init__.py b/pynumaflow/proto/mapper/__init__.py similarity index 100% rename from pynumaflow/sideinput/proto/__init__.py rename to pynumaflow/proto/mapper/__init__.py diff --git a/pynumaflow/mapper/proto/map.proto b/pynumaflow/proto/mapper/map.proto similarity index 100% rename from pynumaflow/mapper/proto/map.proto rename to pynumaflow/proto/mapper/map.proto diff --git a/pynumaflow/mapper/proto/map_pb2.py b/pynumaflow/proto/mapper/map_pb2.py similarity index 98% rename from pynumaflow/mapper/proto/map_pb2.py rename to pynumaflow/proto/mapper/map_pb2.py index ddb812df..881e4fb3 100644 --- a/pynumaflow/mapper/proto/map_pb2.py +++ b/pynumaflow/proto/mapper/map_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: map.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/pynumaflow/mapper/proto/map_pb2_grpc.py b/pynumaflow/proto/mapper/map_pb2_grpc.py similarity index 100% rename from pynumaflow/mapper/proto/map_pb2_grpc.py rename to pynumaflow/proto/mapper/map_pb2_grpc.py diff --git a/pynumaflow/sinker/proto/__init__.py b/pynumaflow/proto/mapstreamer/__init__.py similarity index 100% rename from pynumaflow/sinker/proto/__init__.py rename to pynumaflow/proto/mapstreamer/__init__.py diff --git a/pynumaflow/mapstreamer/proto/mapstream.proto b/pynumaflow/proto/mapstreamer/mapstream.proto similarity index 100% rename from pynumaflow/mapstreamer/proto/mapstream.proto rename to pynumaflow/proto/mapstreamer/mapstream.proto diff --git a/pynumaflow/mapstreamer/proto/mapstream_pb2.py b/pynumaflow/proto/mapstreamer/mapstream_pb2.py similarity index 98% rename from pynumaflow/mapstreamer/proto/mapstream_pb2.py rename to pynumaflow/proto/mapstreamer/mapstream_pb2.py index f1c2c169..abbdf0a0 100644 --- a/pynumaflow/mapstreamer/proto/mapstream_pb2.py +++ b/pynumaflow/proto/mapstreamer/mapstream_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: mapstream.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/pynumaflow/mapstreamer/proto/mapstream_pb2_grpc.py b/pynumaflow/proto/mapstreamer/mapstream_pb2_grpc.py similarity index 100% rename from pynumaflow/mapstreamer/proto/mapstream_pb2_grpc.py rename to pynumaflow/proto/mapstreamer/mapstream_pb2_grpc.py diff --git a/pynumaflow/sourcer/proto/__init__.py b/pynumaflow/proto/reducer/__init__.py similarity index 100% rename from pynumaflow/sourcer/proto/__init__.py rename to pynumaflow/proto/reducer/__init__.py diff --git a/pynumaflow/reducer/proto/reduce.proto b/pynumaflow/proto/reducer/reduce.proto similarity index 100% rename from pynumaflow/reducer/proto/reduce.proto rename to pynumaflow/proto/reducer/reduce.proto diff --git a/pynumaflow/reducer/proto/reduce_pb2.py b/pynumaflow/proto/reducer/reduce_pb2.py similarity index 98% rename from pynumaflow/reducer/proto/reduce_pb2.py rename to pynumaflow/proto/reducer/reduce_pb2.py index f61b8887..e5b2aceb 100644 --- a/pynumaflow/reducer/proto/reduce_pb2.py +++ b/pynumaflow/proto/reducer/reduce_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: reduce.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/pynumaflow/reducer/proto/reduce_pb2_grpc.py b/pynumaflow/proto/reducer/reduce_pb2_grpc.py similarity index 100% rename from pynumaflow/reducer/proto/reduce_pb2_grpc.py rename to pynumaflow/proto/reducer/reduce_pb2_grpc.py diff --git a/pynumaflow/sourcetransformer/proto/__init__.py b/pynumaflow/proto/sideinput/__init__.py similarity index 100% rename from pynumaflow/sourcetransformer/proto/__init__.py rename to pynumaflow/proto/sideinput/__init__.py diff --git a/pynumaflow/sideinput/proto/sideinput.proto b/pynumaflow/proto/sideinput/sideinput.proto similarity index 100% rename from pynumaflow/sideinput/proto/sideinput.proto rename to pynumaflow/proto/sideinput/sideinput.proto diff --git a/pynumaflow/sideinput/proto/sideinput_pb2.py b/pynumaflow/proto/sideinput/sideinput_pb2.py similarity index 97% rename from pynumaflow/sideinput/proto/sideinput_pb2.py rename to pynumaflow/proto/sideinput/sideinput_pb2.py index 8278c1df..82983082 100644 --- a/pynumaflow/sideinput/proto/sideinput_pb2.py +++ b/pynumaflow/proto/sideinput/sideinput_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: sideinput.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/pynumaflow/sideinput/proto/sideinput_pb2_grpc.py b/pynumaflow/proto/sideinput/sideinput_pb2_grpc.py similarity index 100% rename from pynumaflow/sideinput/proto/sideinput_pb2_grpc.py rename to pynumaflow/proto/sideinput/sideinput_pb2_grpc.py diff --git a/pynumaflow/proto/sinker/__init__.py b/pynumaflow/proto/sinker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sinker/proto/sink.proto b/pynumaflow/proto/sinker/sink.proto similarity index 100% rename from pynumaflow/sinker/proto/sink.proto rename to pynumaflow/proto/sinker/sink.proto diff --git a/pynumaflow/sinker/proto/sink_pb2.py b/pynumaflow/proto/sinker/sink_pb2.py similarity index 92% rename from pynumaflow/sinker/proto/sink_pb2.py rename to pynumaflow/proto/sinker/sink_pb2.py index b6182a45..eada281c 100644 --- a/pynumaflow/sinker/proto/sink_pb2.py +++ b/pynumaflow/proto/sinker/sink_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: sink.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool @@ -24,8 +25,10 @@ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "sink_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b"Z6github.com/numaproj/numaflow-go/pkg/apis/proto/sink/v1" + _globals["DESCRIPTOR"]._options = None + _globals[ + "DESCRIPTOR" + ]._serialized_options = b"Z6github.com/numaproj/numaflow-go/pkg/apis/proto/sink/v1" _globals["_SINKREQUEST"]._serialized_start = 86 _globals["_SINKREQUEST"]._serialized_end = 235 _globals["_READYRESPONSE"]._serialized_start = 237 diff --git a/pynumaflow/sinker/proto/sink_pb2_grpc.py b/pynumaflow/proto/sinker/sink_pb2_grpc.py similarity index 100% rename from pynumaflow/sinker/proto/sink_pb2_grpc.py rename to pynumaflow/proto/sinker/sink_pb2_grpc.py diff --git a/pynumaflow/proto/sourcer/__init__.py b/pynumaflow/proto/sourcer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sourcer/proto/source.proto b/pynumaflow/proto/sourcer/source.proto similarity index 100% rename from pynumaflow/sourcer/proto/source.proto rename to pynumaflow/proto/sourcer/source.proto diff --git a/pynumaflow/sourcer/proto/source_pb2.py b/pynumaflow/proto/sourcer/source_pb2.py similarity index 99% rename from pynumaflow/sourcer/proto/source_pb2.py rename to pynumaflow/proto/sourcer/source_pb2.py index 73c282e1..10fe18d7 100644 --- a/pynumaflow/sourcer/proto/source_pb2.py +++ b/pynumaflow/proto/sourcer/source_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: source.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/pynumaflow/sourcer/proto/source_pb2_grpc.py b/pynumaflow/proto/sourcer/source_pb2_grpc.py similarity index 100% rename from pynumaflow/sourcer/proto/source_pb2_grpc.py rename to pynumaflow/proto/sourcer/source_pb2_grpc.py diff --git a/pynumaflow/proto/sourcetransformer/__init__.py b/pynumaflow/proto/sourcetransformer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sourcetransformer/proto/transform.proto b/pynumaflow/proto/sourcetransformer/transform.proto similarity index 100% rename from pynumaflow/sourcetransformer/proto/transform.proto rename to pynumaflow/proto/sourcetransformer/transform.proto diff --git a/pynumaflow/sourcetransformer/proto/transform_pb2.py b/pynumaflow/proto/sourcetransformer/transform_pb2.py similarity index 98% rename from pynumaflow/sourcetransformer/proto/transform_pb2.py rename to pynumaflow/proto/sourcetransformer/transform_pb2.py index 41946e02..2f96e5fb 100644 --- a/pynumaflow/sourcetransformer/proto/transform_pb2.py +++ b/pynumaflow/proto/sourcetransformer/transform_pb2.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: transform.proto +# Protobuf Python Version: 4.25.0 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool diff --git a/pynumaflow/sourcetransformer/proto/transform_pb2_grpc.py b/pynumaflow/proto/sourcetransformer/transform_pb2_grpc.py similarity index 100% rename from pynumaflow/sourcetransformer/proto/transform_pb2_grpc.py rename to pynumaflow/proto/sourcetransformer/transform_pb2_grpc.py diff --git a/pynumaflow/reducer/__init__.py b/pynumaflow/reducer/__init__.py index 36fe4a9f..7a1c878b 100644 --- a/pynumaflow/reducer/__init__.py +++ b/pynumaflow/reducer/__init__.py @@ -5,8 +5,9 @@ IntervalWindow, Metadata, DROP, + Reducer, ) -from pynumaflow.reducer.async_server import AsyncReducer +from pynumaflow.reducer.async_server import ReduceAsyncServer __all__ = [ "Message", @@ -15,5 +16,6 @@ "IntervalWindow", "Metadata", "DROP", - "AsyncReducer", + "ReduceAsyncServer", + "Reducer", ] diff --git a/pynumaflow/reducer/_dtypes.py b/pynumaflow/reducer/_dtypes.py index 534d4e28..bc881e33 100644 --- a/pynumaflow/reducer/_dtypes.py +++ b/pynumaflow/reducer/_dtypes.py @@ -1,12 +1,13 @@ +from abc import ABCMeta, abstractmethod from asyncio import Task from collections.abc import Iterator, Sequence, Awaitable from dataclasses import dataclass from datetime import datetime -from typing import TypeVar, Callable +from typing import TypeVar, Callable, Union from collections.abc import AsyncIterable from warnings import warn -from pynumaflow.reducer.asynciter import NonBlockingIterator +from pynumaflow.reducer.servicer.asynciter import NonBlockingIterator from pynumaflow._constants import DROP M = TypeVar("M", bound="Message") @@ -232,4 +233,54 @@ def keys(self) -> list[str]: return self._key -ReduceCallable = Callable[[list[str], AsyncIterable[Datum], Metadata], Awaitable[Messages]] +ReduceAsyncCallable = Callable[[list[str], AsyncIterable[Datum], Metadata], Awaitable[Messages]] + + +class Reducer(metaclass=ABCMeta): + """ + Provides an interface to write a Reducer + which will be exposed over a gRPC server. 
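+
+    A minimal concrete subclass might look like the following (illustrative
+    sketch only; the class name is hypothetical):
+
+        class CounterReducer(Reducer):
+            async def handler(
+                self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata
+            ) -> Messages:
+                count = 0
+                async for _ in datums:
+                    count += 1
+                return Messages(Message(str.encode(str(count)), keys=keys))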
+ """ + + def __call__(self, *args, **kwargs): + """ + Allow to call handler function directly if class instance is sent + as the reducer_instance. + """ + return self.handler(*args, **kwargs) + + @abstractmethod + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + """ + Implement this handler function which implements the ReduceCallable interface. + """ + pass + + +class _ReduceBuilderClass: + """ + Class to build a Reducer class instance. + Used Internally + + Args: + reducer_class: the reducer class to be used for Reduce UDF + args: the arguments to be passed to the reducer class + kwargs: the keyword arguments to be passed to the reducer class + """ + + def __init__(self, reducer_class: type[Reducer], args: tuple, kwargs: dict): + self._reducer_class: type[Reducer] = reducer_class + self._args = args + self._kwargs = kwargs + + def create(self) -> Reducer: + """ + Create a new Reducer instance. + """ + return self._reducer_class(*self._args, **self._kwargs) + + +# ReduceCallable is a callable which can be used as a handler for the Reduce UDF. +ReduceCallable = Union[ReduceAsyncCallable, type[Reducer]] diff --git a/pynumaflow/reducer/async_server.py b/pynumaflow/reducer/async_server.py index 90d83e1e..a42d7ee7 100644 --- a/pynumaflow/reducer/async_server.py +++ b/pynumaflow/reducer/async_server.py @@ -1,251 +1,165 @@ -import asyncio -import logging -import multiprocessing -import os - -from datetime import datetime, timezone -from collections.abc import AsyncIterable +import inspect +import aiorun import grpc -from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow import setup_logging +from pynumaflow.proto.reducer import reduce_pb2_grpc + +from pynumaflow.reducer.servicer.async_servicer import AsyncReduceServicer + from pynumaflow._constants import ( - WIN_START_TIME, - WIN_END_TIME, - MAX_MESSAGE_SIZE, - STREAM_EOF, - DELIMITER, REDUCE_SOCK_PATH, + MAX_MESSAGE_SIZE, + MAX_THREADS, + _LOGGER, ) -from pynumaflow.reducer import Datum, IntervalWindow, Metadata -from pynumaflow.reducer._dtypes import ReduceResult, ReduceCallable -from pynumaflow.reducer.asynciter import NonBlockingIterator -from pynumaflow.reducer.proto import reduce_pb2 -from pynumaflow.reducer.proto import reduce_pb2_grpc -from pynumaflow.types import NumaflowServicerContext -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) - -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) - - -async def datum_generator( - request_iterator: AsyncIterable[reduce_pb2.ReduceRequest], -) -> AsyncIterable[Datum]: - async for d in request_iterator: - datum = Datum( - keys=list(d.keys), - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - ) - yield datum +from pynumaflow.reducer._dtypes import ( + ReduceCallable, + _ReduceBuilderClass, + Reducer, +) -class AsyncReducer(reduce_pb2_grpc.ReduceServicer): - """ - Provides an interface to write a Reduce Function - which will be exposed over gRPC. 
+from pynumaflow.shared.server import NumaflowServer, checkInstance, start_async_server
+
+def get_handler(reducer_handler: ReduceCallable, init_args: tuple = (), init_kwargs: dict = None):
+    """
+    Get the correct handler type based on the arguments passed.
+    """
+    if init_kwargs is None:
+        init_kwargs = {}
+    if inspect.isfunction(reducer_handler):
+        if len(init_args) > 0 or len(init_kwargs) > 0:
+            # if the init_args or init_kwargs are passed, then the reducer_handler
+            # can only be of class Reducer type
+            raise TypeError("Cannot pass function handler with init args or kwargs")
+        # return the function handler
+        return reducer_handler
+    elif not checkInstance(reducer_handler, Reducer) and issubclass(reducer_handler, Reducer):
+        # if the handler is a Reducer class, wrap it in a
+        # _ReduceBuilderClass instance
+        return _ReduceBuilderClass(reducer_handler, init_args, init_kwargs)
+    else:
+        _LOGGER.error("Invalid Type: please provide the handler or the class name")
+        raise TypeError("Invalid Type: please provide the handler or the class name")
+
+
+class ReduceAsyncServer(NumaflowServer):
+    """
+    Class for a new Reduce Server instance.
+    A new servicer instance is created and attached to the server when it starts.
     Args:
-        handler: Function callable following the type signature of ReduceCallable
-        sock_path: Path to the UNIX Domain Socket
+        reducer_handler: The reducer handler or Reducer class to be used for the Reduce UDF
+        sock_path: The UNIX socket path to be used for the server
         max_message_size: The max message size in bytes the server can receive and send
         max_threads: The max number of threads to be spawned;
-        defaults to number of processors x4
-
+            defaults to number of processors x4
     Example invocation:
-    >>> from typing import Iterator
-    >>> from pynumaflow.reducer import Messages, Message\
-    ...     Datum, Metadata, AsyncReducer
-    ...     import aiorun
-    ...
-    >>> async def reduce_handler(key: list[str], datums: AsyncIterable[Datum],
-    >>>     md: Metadata) -> Messages:
-    ...     interval_window = md.interval_window
-    ...     counter = 0
-    ...     async for _ in datums:
-    ...         counter += 1
-    ...     msg = (
-    ...         f"counter:{counter} interval_window_start:{interval_window.start} "
-    ...         f"interval_window_end:{interval_window.end}"
-    ...     )
-    ...     return Messages(Message(value=str.encode(msg), keys=keys))
-    ...
-    >>> grpc_server = AsyncReducer(handler=reduce_handler)
-    >>> aiorun.run(grpc_server.start())
+        import os
+        from collections.abc import AsyncIterable
+        from pynumaflow.reducer import (
+            Messages, Message, Datum, Metadata, ReduceAsyncServer, Reducer,
+        )
+
+        class ReduceCounter(Reducer):
+            def __init__(self, counter):
+                self.counter = counter
+
+            async def handler(
+                self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata
+            ) -> Messages:
+                interval_window = md.interval_window
+                self.counter = 0
+                async for _ in datums:
+                    self.counter += 1
+                msg = (
+                    f"counter:{self.counter} interval_window_start:{interval_window.start} "
+                    f"interval_window_end:{interval_window.end}"
+                )
+                return Messages(Message(str.encode(msg), keys=keys))
+
+        async def reduce_handler(keys: list[str],
+                                 datums: AsyncIterable[Datum],
+                                 md: Metadata) -> Messages:
+            interval_window = md.interval_window
+            counter = 0
+            async for _ in datums:
+                counter += 1
+            msg = (
+                f"counter:{counter} interval_window_start:{interval_window.start} "
+                f"interval_window_end:{interval_window.end}"
+            )
+            return Messages(Message(str.encode(msg), keys=keys))
+
+        if __name__ == "__main__":
+            invoke = os.getenv("INVOKE", "func_handler")
+            if invoke == "class":
+                # Here we are using the class instance as the reducer_instance
+                # which will be used to invoke the handler function.
+                # We are passing the init_args for the class instance.
+                grpc_server = ReduceAsyncServer(ReduceCounter, init_args=(0,))
+            else:
+                # Here we are using the handler function directly as the reducer_instance.
+                grpc_server = ReduceAsyncServer(reduce_handler)
+            grpc_server.start()
+
+    """
 
     def __init__(
         self,
-        handler: ReduceCallable,
+        reducer_handler: ReduceCallable,
+        init_args: tuple = (),
+        init_kwargs: dict = None,
         sock_path=REDUCE_SOCK_PATH,
         max_message_size=MAX_MESSAGE_SIZE,
         max_threads=MAX_THREADS,
     ):
-        self.__reduce_handler: ReduceCallable = handler
+        """
+        Create a new grpc Reduce Server instance.
+        A new servicer instance is created and attached to the server when it starts.
+        Args:
+            reducer_handler: The reducer handler function or Reducer class
+                to be used for the Reduce UDF
+            init_args: The positional arguments to be passed to the reducer class
+            init_kwargs: The keyword arguments to be passed to the reducer class
+            sock_path: The UNIX socket path to be used for the server
+            max_message_size: The max message size in bytes the server can receive and send
+            max_threads: The max number of threads to be spawned;
+                defaults to number of processors x4
+        """
+        if init_kwargs is None:
+            init_kwargs = {}
+        self.reducer_handler = get_handler(reducer_handler, init_args, init_kwargs)
         self.sock_path = f"unix://{sock_path}"
-        self._max_message_size = max_message_size
-        self._max_threads = max_threads
-        self.cleanup_coroutines = []
-        # Collection for storing strong references to all running tasks.
-        # Event loop only keeps a weak reference, which can cause it to
-        # get lost during execution.
-        self.background_tasks = set()
+        self.max_message_size = max_message_size
+        self.max_threads = max_threads
         self._server_options = [
-            ("grpc.max_send_message_length", self._max_message_size),
-            ("grpc.max_receive_message_length", self._max_message_size),
+            ("grpc.max_send_message_length", self.max_message_size),
+            ("grpc.max_receive_message_length", self.max_message_size),
         ]
+        # Get the servicer instance for the async server
+        self.servicer = AsyncReduceServicer(self.reducer_handler)
 
-    async def ReduceFn(
-        self,
-        request_iterator: AsyncIterable[reduce_pb2.ReduceRequest],
-        context: NumaflowServicerContext,
-    ) -> reduce_pb2.ReduceResponse:
+    def start(self):
         """
-        Applies a reduce function to a datum stream.
-        The pascal case function name comes from the proto reduce_pb2_grpc.py file.
+        Starter function for the Async server class; a separate caller is needed
+        so that all the async coroutines can be started from a single context.
         """
-
-        start, end = None, None
-        for metadata_key, metadata_value in context.invocation_metadata():
-            if metadata_key == WIN_START_TIME:
-                start = metadata_value
-            elif metadata_key == WIN_END_TIME:
-                end = metadata_value
-        if not (start or end):
-            context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
-            context.set_details(
-                f"Expected to have all key/window_start_time/window_end_time; "
-                f"got start: {start}, end: {end}."
-            )
-            yield reduce_pb2.ReduceResponse(results=[])
-            return
-
-        start_dt = datetime.fromtimestamp(int(start) / 1e3, timezone.utc)
-        end_dt = datetime.fromtimestamp(int(end) / 1e3, timezone.utc)
-        interval_window = IntervalWindow(start=start_dt, end=end_dt)
-
-        datum_iterator = datum_generator(request_iterator=request_iterator)
-
-        response_task = asyncio.create_task(
-            self.__async_reduce_handler(interval_window, datum_iterator)
+        _LOGGER.info(
+            "Starting Async Reduce Server",
         )
+        aiorun.run(self.aexec(), use_uvloop=True)
 
-        # Save a reference to the result of this function, to avoid a
-        # task disappearing mid-execution.
-        self.background_tasks.add(response_task)
-        response_task.add_done_callback(lambda t: self.background_tasks.remove(t))
-
-        await response_task
-        results_futures = response_task.result()
-
-        try:
-            for fut in results_futures:
-                await fut
-                yield reduce_pb2.ReduceResponse(results=fut.result())
-        except Exception as e:
-            context.set_code(grpc.StatusCode.UNKNOWN)
-            context.set_details(e.__str__())
-            yield reduce_pb2.ReduceResponse(results=[])
-
-    async def __async_reduce_handler(self, interval_window, datum_iterator: AsyncIterable[Datum]):
-        callable_dict = {}
-        # iterate through all the values
-        async for d in datum_iterator:
-            keys = d.keys()
-            unified_key = DELIMITER.join(keys)
-            result = callable_dict.get(unified_key, None)
-
-            if not result:
-                niter = NonBlockingIterator()
-                riter = niter.read_iterator()
-                # schedule an async task for consumer
-                # returns a future that will give the results later.
-                task = asyncio.create_task(
-                    self.__invoke_reduce(keys, riter, Metadata(interval_window=interval_window))
-                )
-                # Save a reference to the result of this function, to avoid a
-                # task disappearing mid-execution.
-                self.background_tasks.add(task)
-                task.add_done_callback(lambda t: self.background_tasks.remove(t))
-                result = ReduceResult(task, niter, keys)
-
-                callable_dict[unified_key] = result
-
-            await result.iterator.put(d)
-
-        for unified_key in callable_dict:
-            await callable_dict[unified_key].iterator.put(STREAM_EOF)
-
-        tasks = []
-        for unified_key in callable_dict:
-            fut = callable_dict[unified_key].future
-            tasks.append(fut)
-
-        return tasks
-
-    async def __invoke_reduce(
-        self, keys: list[str], request_iterator: AsyncIterable[Datum], md: Metadata
-    ):
-        try:
-            msgs = await self.__reduce_handler(keys, request_iterator, md)
-        except Exception as err:
-            _LOGGER.critical("UDFError, re-raising the error", exc_info=True)
-            raise err
-
-        datum_responses = []
-        for msg in msgs:
-            datum_responses.append(
-                reduce_pb2.ReduceResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
-            )
-
-        return datum_responses
-
-    async def IsReady(
-        self, request: _empty_pb2.Empty, context: NumaflowServicerContext
-    ) -> reduce_pb2.ReadyResponse:
+    async def aexec(self):
         """
-        IsReady is the heartbeat endpoint for gRPC.
-        The pascal case function name comes from the proto reduce_pb2_grpc.py file.
+        Starts the Async gRPC server on the given UNIX socket with
+        given max threads.
         """
-        return reduce_pb2.ReadyResponse(ready=True)
-
-    async def __serve_async(self, server) -> None:
-        reduce_pb2_grpc.add_ReduceServicer_to_server(
-            AsyncReducer(handler=self.__reduce_handler),
-            server,
-        )
+        # As the server is async, we need to create a new server instance in the
+        # same thread as the event loop so that all the async calls are made in the
+        # same context
+        # Create a new async server instance and add the servicer to it
+        server = grpc.aio.server(options=self._server_options)
         server.add_insecure_port(self.sock_path)
-        _LOGGER.info("GRPC Async Server listening on: %s", self.sock_path)
-        await server.start()
-        serv_info = ServerInfo(
-            protocol=Protocol.UDS,
-            language=Language.PYTHON,
-            version=get_sdk_version(),
-        )
-        info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH)
-
-        async def server_graceful_shutdown():
-            """
-            Shuts down the server with 5 seconds of grace period. During the
-            grace period, the server won't accept new connections and allow
-            existing RPCs to continue within the grace period.
- """ - _LOGGER.info("Starting graceful shutdown...") - await server.stop(5) - - self.cleanup_coroutines.append(server_graceful_shutdown()) - await server.wait_for_termination() - - async def start(self) -> None: - """Starts the Async gRPC server on the given UNIX socket.""" - server = grpc.aio.server(options=self._server_options) - await self.__serve_async(server) + reduce_servicer = self.servicer + reduce_pb2_grpc.add_ReduceServicer_to_server(reduce_servicer, server) + await start_async_server(server, self.sock_path, self.max_threads, self._server_options) diff --git a/pynumaflow/reducer/servicer/__init__.py b/pynumaflow/reducer/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/reducer/servicer/async_servicer.py b/pynumaflow/reducer/servicer/async_servicer.py new file mode 100644 index 00000000..b8f0aef9 --- /dev/null +++ b/pynumaflow/reducer/servicer/async_servicer.py @@ -0,0 +1,180 @@ +import asyncio + +from datetime import datetime, timezone +from collections.abc import AsyncIterable +from typing import Union + +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow._constants import ( + WIN_START_TIME, + WIN_END_TIME, + STREAM_EOF, + DELIMITER, +) +from pynumaflow.reducer._dtypes import ( + Datum, + IntervalWindow, + Metadata, + ReduceAsyncCallable, + _ReduceBuilderClass, +) +from pynumaflow.reducer._dtypes import ReduceResult +from pynumaflow.reducer.servicer.asynciter import NonBlockingIterator +from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +async def datum_generator( + request_iterator: AsyncIterable[reduce_pb2.ReduceRequest], +) -> AsyncIterable[Datum]: + async for d in request_iterator: + datum = Datum( + keys=list(d.keys), + value=d.value, + event_time=d.event_time.ToDatetime(), + watermark=d.watermark.ToDatetime(), + ) + yield datum + + +class AsyncReduceServicer(reduce_pb2_grpc.ReduceServicer): + """ + This class is used to create a new grpc Reduce servicer instance. + It implements the SyncMapServicer interface from the proto reduce.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: Union[ReduceAsyncCallable, _ReduceBuilderClass], + ): + # Collection for storing strong references to all running tasks. + # Event loop only keeps a weak reference, which can cause it to + # get lost during execution. + self.background_tasks = set() + # The reduce handler can be a function or a builder class instance. + self.__reduce_handler: Union[ReduceAsyncCallable, _ReduceBuilderClass] = handler + + async def ReduceFn( + self, + request_iterator: AsyncIterable[reduce_pb2.ReduceRequest], + context: NumaflowServicerContext, + ) -> reduce_pb2.ReduceResponse: + """ + Applies a reduce function to a datum stream. + The pascal case function name comes from the proto reduce_pb2_grpc.py file. + """ + + start, end = None, None + for metadata_key, metadata_value in context.invocation_metadata(): + if metadata_key == WIN_START_TIME: + start = metadata_value + elif metadata_key == WIN_END_TIME: + end = metadata_value + if not (start or end): + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details( + f"Expected to have all key/window_start_time/window_end_time; " + f"got start: {start}, end: {end}." 
+ ) + yield reduce_pb2.ReduceResponse(results=[]) + return + + start_dt = datetime.fromtimestamp(int(start) / 1e3, timezone.utc) + end_dt = datetime.fromtimestamp(int(end) / 1e3, timezone.utc) + interval_window = IntervalWindow(start=start_dt, end=end_dt) + + datum_iterator = datum_generator(request_iterator=request_iterator) + + response_task = asyncio.create_task( + self.__async_reduce_handler(interval_window, datum_iterator) + ) + + # Save a reference to the result of this function, to avoid a + # task disappearing mid-execution. + self.background_tasks.add(response_task) + response_task.add_done_callback(lambda t: self.background_tasks.remove(t)) + + await response_task + results_futures = response_task.result() + + try: + for fut in results_futures: + await fut + yield reduce_pb2.ReduceResponse(results=fut.result()) + except Exception as e: + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(e.__str__()) + yield reduce_pb2.ReduceResponse(results=[]) + + async def __async_reduce_handler(self, interval_window, datum_iterator: AsyncIterable[Datum]): + callable_dict = {} + # iterate through all the values + async for d in datum_iterator: + keys = d.keys() + unified_key = DELIMITER.join(keys) + result = callable_dict.get(unified_key, None) + + if not result: + niter = NonBlockingIterator() + riter = niter.read_iterator() + # schedule an async task for consumer + # returns a future that will give the results later. + task = asyncio.create_task( + self.__invoke_reduce(keys, riter, Metadata(interval_window=interval_window)) + ) + # Save a reference to the result of this function, to avoid a + # task disappearing mid-execution. + self.background_tasks.add(task) + task.add_done_callback(lambda t: self.background_tasks.remove(t)) + result = ReduceResult(task, niter, keys) + + callable_dict[unified_key] = result + + await result.iterator.put(d) + + for unified_key in callable_dict: + await callable_dict[unified_key].iterator.put(STREAM_EOF) + + tasks = [] + for unified_key in callable_dict: + fut = callable_dict[unified_key].future + tasks.append(fut) + + return tasks + + async def __invoke_reduce( + self, keys: list[str], request_iterator: AsyncIterable[Datum], md: Metadata + ): + new_instance = self.__reduce_handler + # If the reduce handler is a class instance, create a new instance of it. + # It is required for a new key to be processed by a + # new instance of the reducer for a given window + # Otherwise the function handler can be called directly + if isinstance(self.__reduce_handler, _ReduceBuilderClass): + new_instance = self.__reduce_handler.create() + try: + msgs = await new_instance(keys, request_iterator, md) + except Exception as err: + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + raise err + + datum_responses = [] + for msg in msgs: + datum_responses.append( + reduce_pb2.ReduceResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags) + ) + + return datum_responses + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> reduce_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto reduce_pb2_grpc.py file. 
+ """ + return reduce_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/reducer/asynciter.py b/pynumaflow/reducer/servicer/asynciter.py similarity index 100% rename from pynumaflow/reducer/asynciter.py rename to pynumaflow/reducer/servicer/asynciter.py diff --git a/pynumaflow/shared/__init__.py b/pynumaflow/shared/__init__.py new file mode 100644 index 00000000..857f0a9f --- /dev/null +++ b/pynumaflow/shared/__init__.py @@ -0,0 +1,4 @@ +from pynumaflow.shared.server import NumaflowServer + + +__all__ = ["NumaflowServer"] diff --git a/pynumaflow/shared/server.py b/pynumaflow/shared/server.py new file mode 100644 index 00000000..d58af987 --- /dev/null +++ b/pynumaflow/shared/server.py @@ -0,0 +1,251 @@ +import contextlib +import multiprocessing +import os +import socket +from abc import ABCMeta, abstractmethod +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor + +import grpc +from pynumaflow._constants import ( + _LOGGER, + MULTIPROC_MAP_SOCK_ADDR, + UDFType, +) +from pynumaflow.exceptions import SocketError +from pynumaflow.info.server import get_sdk_version, write as info_server_write, get_metadata_env +from pynumaflow.info.types import ( + ServerInfo, + Protocol, + Language, + SERVER_INFO_FILE_PATH, + METADATA_ENVS, +) +from pynumaflow.proto.mapper import map_pb2_grpc +from pynumaflow.proto.sideinput import sideinput_pb2_grpc +from pynumaflow.proto.sinker import sink_pb2_grpc +from pynumaflow.proto.sourcer import source_pb2_grpc +from pynumaflow.proto.sourcetransformer import transform_pb2_grpc + + +class NumaflowServer(metaclass=ABCMeta): + """ + Provides an interface to write a Numaflow Server + which will be exposed over gRPC. + """ + + @abstractmethod + def start(self): + """ + Start the gRPC server + """ + pass + + +def write_info_file(protocol: Protocol, info_file=SERVER_INFO_FILE_PATH) -> None: + """ + Write the server info file to the given path. + """ + serv_info = ServerInfo( + protocol=protocol, + language=Language.PYTHON, + version=get_sdk_version(), + ) + info_server_write(server_info=serv_info, info_file=info_file) + + +def sync_server_start( + servicer, + bind_address: str, + max_threads: int, + server_options=None, + udf_type: str = UDFType.Map, + add_info_server=True, +): + """ + Utility function to start a sync grpc server instance. + """ + # Add the server information to the server info file, + # here we just write the protocol and language information + if add_info_server: + server_info = ServerInfo( + protocol=Protocol.UDS, + language=Language.PYTHON, + version=get_sdk_version(), + ) + else: + server_info = None + # Run a sync server instances + _run_server( + servicer=servicer, + bind_address=bind_address, + threads_per_proc=max_threads, + server_options=server_options, + udf_type=udf_type, + server_info=server_info, + ) + + +def _run_server( + servicer, + bind_address: str, + threads_per_proc, + server_options, + udf_type: str, + server_info=None, + server_info_file=SERVER_INFO_FILE_PATH, +) -> None: + """ + Starts the Synchronous server instance on the given UNIX socket + with given max threads. Wait for the server to terminate. 
+ """ + server = grpc.server( + ThreadPoolExecutor( + max_workers=threads_per_proc, + ), + options=server_options, + ) + + # add the correct servicer to the server based on the UDF type + if udf_type == UDFType.Map: + map_pb2_grpc.add_MapServicer_to_server(servicer, server) + elif udf_type == UDFType.Sink: + sink_pb2_grpc.add_SinkServicer_to_server(servicer, server) + elif udf_type == UDFType.SourceTransformer: + transform_pb2_grpc.add_SourceTransformServicer_to_server(servicer, server) + elif udf_type == UDFType.Source: + source_pb2_grpc.add_SourceServicer_to_server(servicer, server) + elif udf_type == UDFType.SideInput: + sideinput_pb2_grpc.add_SideInputServicer_to_server(servicer, server) + + # bind the server to the UDS/TCP socket + server.add_insecure_port(bind_address) + # start the gRPC server + server.start() + + # Add the server information to the server info file if provided + if server_info and server_info_file: + info_server_write(server_info=server_info, info_file=server_info_file) + + _LOGGER.info("GRPC Server listening on: %s %d", bind_address, os.getpid()) + server.wait_for_termination() + + +def start_multiproc_server( + max_threads: int, servicer, process_count: int, server_options=None, udf_type: str = UDFType.Map +): + """ + Start N grpc servers in different processes where N = The number of CPUs or the + value of the env var NUM_CPU_MULTIPROC defined by the user. The max value + is set to 2 * CPU count. + Each server will be bound to a different port, and we will create equal number of + workers to handle each server. + On the client side there will be same number of connections as the number of servers. + """ + + _LOGGER.info( + "Starting new Multiproc server with num_procs: %s, num_threads per proc: %s", + process_count, + max_threads, + ) + workers = [] + server_ports = [] + for _ in range(process_count): + # Find a port to bind to for each server, thus sending the port number = 0 + # to the _reserve_port function so that kernel can find and return a free port + with _reserve_port(port_num=0) as port: + bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}" + _LOGGER.info("Starting server on port: %s", port) + # NOTE: It is imperative that the worker subprocesses be forked before + # any gRPC servers start up. See + # https://github.com/grpc/grpc/issues/16001 for more details. + worker = multiprocessing.Process( + target=_run_server, + args=(servicer, bind_address, max_threads, server_options, udf_type), + ) + worker.start() + workers.append(worker) + server_ports.append(port) + + # Convert the available ports to a comma separated string + ports = ",".join(map(str, server_ports)) + + serv_info = ServerInfo( + protocol=Protocol.TCP, + language=Language.PYTHON, + version=get_sdk_version(), + metadata=get_metadata_env(envs=METADATA_ENVS), + ) + # Add the PORTS metadata using the available ports + serv_info.metadata["SERV_PORTS"] = ports + info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) + + for worker in workers: + worker.join() + + +async def start_async_server( + server_async: grpc.aio.Server, sock_path: str, max_threads: int, cleanup_coroutines: list +): + """ + Starts the Async server instance on the given UNIX socket with given max threads. + Add the server graceful shutdown coroutine to the cleanup_coroutines list. + Wait for the server to terminate. 
+ """ + await server_async.start() + + # Add the server information to the server info file + # Here we just write the protocol and language information + serv_info = ServerInfo( + protocol=Protocol.UDS, + language=Language.PYTHON, + version=get_sdk_version(), + ) + info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) + + # Log the server start + _LOGGER.info( + "New Async GRPC Server listening on: %s with max threads: %s", + sock_path, + max_threads, + ) + + async def server_graceful_shutdown(): + """ + Shuts down the server with 5 seconds of grace period. During the + grace period, the server won't accept new connections and allow + existing RPCs to continue within the grace period. + """ + _LOGGER.info("Starting graceful shutdown...") + await server_async.stop(5) + + cleanup_coroutines.append(server_graceful_shutdown()) + await server_async.wait_for_termination() + + +@contextlib.contextmanager +def _reserve_port(port_num: int) -> Iterator[int]: + """Find and reserve a port for all subprocesses to use.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0: + raise SocketError("Failed to set SO_REUSEADDR.") + try: + sock.bind(("", port_num)) + yield sock.getsockname()[1] + finally: + sock.close() + + +def checkInstance(instance, callable_type) -> bool: + """ + Check if the given instance is of the given callable_type. + """ + try: + if not isinstance(instance, callable_type): + return False + else: + return True + except Exception as e: + _LOGGER.error(e) + return False diff --git a/pynumaflow/sideinput/__init__.py b/pynumaflow/sideinput/__init__.py index 8a3c36f3..2058fe97 100644 --- a/pynumaflow/sideinput/__init__.py +++ b/pynumaflow/sideinput/__init__.py @@ -1,4 +1,5 @@ -from pynumaflow.sideinput._dtypes import Response -from pynumaflow.sideinput.server import SideInput +from pynumaflow._constants import SIDE_INPUT_DIR_PATH +from pynumaflow.sideinput._dtypes import Response, SideInput +from pynumaflow.sideinput.server import SideInputServer -__all__ = ["Response", "SideInput"] +__all__ = ["Response", "SideInput", "SideInputServer", "SIDE_INPUT_DIR_PATH"] diff --git a/pynumaflow/sideinput/_dtypes.py b/pynumaflow/sideinput/_dtypes.py index 86826578..6a68f420 100644 --- a/pynumaflow/sideinput/_dtypes.py +++ b/pynumaflow/sideinput/_dtypes.py @@ -1,5 +1,6 @@ +from abc import ABCMeta, abstractmethod from dataclasses import dataclass -from typing import TypeVar +from typing import TypeVar, Callable, Union R = TypeVar("R", bound="Response") @@ -36,3 +37,28 @@ def no_broadcast_message(cls: type[R]) -> R: This event will not be broadcasted. """ return Response(value=b"", no_broadcast=True) + + +class SideInput(metaclass=ABCMeta): + """ + Provides an interface to write a SideInput Class + which will be exposed over gRPC. + """ + + def __call__(self, *args, **kwargs): + """ + This allows to execute the handler function directly if + class instance is sent as a callable. + """ + return self.retrieve_handler(*args, **kwargs) + + @abstractmethod + def retrieve_handler(self) -> Response: + """ + This function is called when a Side Input request is received. 
+ """ + pass + + +RetrieverHandlerCallable = Callable[[], Response] +RetrieverCallable = Union[SideInput, RetrieverHandlerCallable] diff --git a/pynumaflow/sideinput/server.py b/pynumaflow/sideinput/server.py index d786f0d7..ea6685e0 100644 --- a/pynumaflow/sideinput/server.py +++ b/pynumaflow/sideinput/server.py @@ -1,115 +1,91 @@ -import logging -import multiprocessing import os -from concurrent.futures import ThreadPoolExecutor -from typing import Callable - -import grpc -from google.protobuf import empty_pb2 as _empty_pb2 - -from pynumaflow import setup_logging +from pynumaflow.shared import NumaflowServer +from pynumaflow.shared.server import sync_server_start +from pynumaflow.sideinput._dtypes import RetrieverCallable +from pynumaflow.sideinput.servicer.servicer import SideInputServicer from pynumaflow._constants import ( + MAX_THREADS, MAX_MESSAGE_SIZE, SIDE_INPUT_SOCK_PATH, + _LOGGER, + UDFType, + SIDE_INPUT_DIR_PATH, ) -from pynumaflow.sideinput import Response -from pynumaflow.sideinput.proto import sideinput_pb2, sideinput_pb2_grpc -from pynumaflow.types import NumaflowServicerContext - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) -RetrieverCallable = Callable[[], Response] -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) - -class SideInput(sideinput_pb2_grpc.SideInputServicer): +class SideInputServer(NumaflowServer): """ - Provides an interface to write a User Defined Side Input (UDSideInput) - which will be exposed over gRPC. - + Class for a new Side Input Server instance. Args: - handler: Function callable following the type signature of RetrieverCallable - sock_path: Path to the UNIX Domain Socket + side_input_instance: The side input instance to be used for Side Input UDF + sock_path: The UNIX socket path to be used for the server max_message_size: The max message size in bytes the server can receive and send max_threads: The max number of threads to be spawned; - defaults to number of processors x 4 Example invocation: - >>> from typing import List - >>> from pynumaflow.sideinput import Response, SideInput - >>> def my_handler() -> Response: - ... response = Response.broadcast_message(b"hello") - ... 
return response - >>> grpc_server = SideInput(my_handler) - >>> grpc_server.start() - """ + import datetime + from pynumaflow.sideinput import Response, SideInputServer, SideInput + + class ExampleSideInput(SideInput): + def __init__(self): + self.counter = 0 + + def retrieve_handler(self) -> Response: + time_now = datetime.datetime.now() + # val is the value to be broadcasted + val = f"an example: {str(time_now)}" + self.counter += 1 + # broadcast every other time + if self.counter % 2 == 0: + # no_broadcast_message() is used to indicate that there is no broadcast + return Response.no_broadcast_message() + # broadcast_message() is used to indicate that there is a broadcast + return Response.broadcast_message(val.encode("utf-8")) - SIDE_INPUT_DIR_PATH = "/var/numaflow/side-inputs" + if __name__ == "__main__": + grpc_server = SideInputServer(ExampleSideInput()) + grpc_server.start() + + """ def __init__( self, - handler: RetrieverCallable, + side_input_instance: RetrieverCallable, sock_path=SIDE_INPUT_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, max_threads=MAX_THREADS, + side_input_dir_path=SIDE_INPUT_DIR_PATH, ): - self.__retrieve_handler: RetrieverCallable = handler self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads - self.cleanup_coroutines = [] + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ] - def RetrieveSideInput( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> sideinput_pb2.SideInputResponse: - """ - Applies a sideinput function for a retrieval request. - The pascal case function name comes from the proto sideinput_pb2_grpc.py file. - """ - # if there is an exception, we will mark all the responses as a failure - try: - rspn = self.__retrieve_handler() - except Exception as err: - err_msg = "RetrieveSideInputErr: %r" % err - _LOGGER.critical(err_msg, exc_info=True) - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(err)) - return sideinput_pb2.SideInputResponse(value=None, no_broadcast=True) - - return sideinput_pb2.SideInputResponse(value=rspn.value, no_broadcast=rspn.no_broadcast) + self.side_input_instance = side_input_instance + self.side_input_dir_path = side_input_dir_path + self.servicer = SideInputServicer(side_input_instance) - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> sideinput_pb2.ReadyResponse: + def start(self): """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto sideinput_pb2_grpc.py file. - """ - return sideinput_pb2.ReadyResponse(ready=True) - - def start(self) -> None: + Starts the Synchronous gRPC server on the given UNIX socket with given max threads. """ - Starts the gRPC server on the given UNIX socket with given max threads. 
- """ - server = grpc.server( - ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options - ) - sideinput_pb2_grpc.add_SideInputServicer_to_server( - SideInput(self.__retrieve_handler), server - ) - server.add_insecure_port(self.sock_path) - server.start() + # Get the servicer instance based on the server type + side_input_servicer = self.servicer _LOGGER.info( - "Side Input gRPC Server listening on: %s with max threads: %s", + "Side Input GRPC Server listening on: %s with max threads: %s", self.sock_path, - self._max_threads, + self.max_threads, + ) + # Start the server + sync_server_start( + servicer=side_input_servicer, + bind_address=self.sock_path, + max_threads=self.max_threads, + server_options=self._server_options, + udf_type=UDFType.SideInput, + add_info_server=False, ) - server.wait_for_termination() diff --git a/pynumaflow/sideinput/servicer/__init__.py b/pynumaflow/sideinput/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sideinput/servicer/servicer.py b/pynumaflow/sideinput/servicer/servicer.py new file mode 100644 index 00000000..2f050149 --- /dev/null +++ b/pynumaflow/sideinput/servicer/servicer.py @@ -0,0 +1,45 @@ +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow._constants import ( + _LOGGER, +) +from pynumaflow.proto.sideinput import sideinput_pb2_grpc, sideinput_pb2 +from pynumaflow.sideinput._dtypes import RetrieverCallable +from pynumaflow.types import NumaflowServicerContext + + +class SideInputServicer(sideinput_pb2_grpc.SideInputServicer): + def __init__( + self, + handler: RetrieverCallable, + ): + self.__retrieve_handler: RetrieverCallable = handler + + def RetrieveSideInput( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> sideinput_pb2.SideInputResponse: + """ + Applies a sideinput function for a retrieval request. + The pascal case function name comes from the proto sideinput_pb2_grpc.py file. + """ + # if there is an exception, we will mark all the responses as a failure + try: + rspn = self.__retrieve_handler() + except Exception as err: + err_msg = "RetrieveSideInputErr: %r" % err + _LOGGER.critical(err_msg, exc_info=True) + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(str(err)) + return sideinput_pb2.SideInputResponse(value=None, no_broadcast=True) + + return sideinput_pb2.SideInputResponse(value=rspn.value, no_broadcast=rspn.no_broadcast) + + def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> sideinput_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto sideinput_pb2_grpc.py file. 
+ """ + return sideinput_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/sinker/__init__.py b/pynumaflow/sinker/__init__.py index c6b5e679..4df6f270 100644 --- a/pynumaflow/sinker/__init__.py +++ b/pynumaflow/sinker/__init__.py @@ -1,5 +1,6 @@ -from pynumaflow.sinker._dtypes import Response, Responses, Datum -from pynumaflow.sinker.async_sink import AsyncSinker -from pynumaflow.sinker.server import Sinker +from pynumaflow.sinker.async_server import SinkAsyncServer +from pynumaflow.sinker.server import SinkServer -__all__ = ["Response", "Responses", "Datum", "Sinker", "AsyncSinker"] +from pynumaflow.sinker._dtypes import Response, Responses, Datum, Sinker + +__all__ = ["Response", "Responses", "Datum", "Sinker", "SinkServer", "SinkAsyncServer"] diff --git a/pynumaflow/sinker/_dtypes.py b/pynumaflow/sinker/_dtypes.py index 1a020ac7..5f6c5b12 100644 --- a/pynumaflow/sinker/_dtypes.py +++ b/pynumaflow/sinker/_dtypes.py @@ -1,6 +1,8 @@ +from abc import abstractmethod, ABCMeta from dataclasses import dataclass from datetime import datetime -from typing import TypeVar, Optional, Callable +from typing import TypeVar, Optional, Callable, Union +from collections.abc import AsyncIterable, Awaitable from collections.abc import Sequence, Iterator from warnings import warn @@ -161,4 +163,32 @@ def watermark(self) -> datetime: return self._watermark -SinkCallable = Callable[[Iterator[Datum]], Responses] +class Sinker(metaclass=ABCMeta): + """ + Provides an interface to write a Sinker + which will be exposed over a gRPC server. + + """ + + def __call__(self, *args, **kwargs): + """ + Allow to call handler function directly if class instance is sent + as the sinker_instance. + """ + return self.handler(*args, **kwargs) + + @abstractmethod + def handler(self, datums: Iterator[Datum]) -> Responses: + """ + Implement this handler function which implements the SinkCallable interface. + """ + pass + + +# SyncSinkCallable is a callable which can be used as a handler for the Synchronous UDSink. +SinkHandlerCallable = Callable[[Iterator[Datum]], Responses] +SyncSinkCallable = Union[Sinker, SinkHandlerCallable] + +# AsyncSinkCallable is a callable which can be used as a handler for the Asynchronous UDSink. +AsyncSinkHandlerCallable = Callable[[AsyncIterable[Datum]], Awaitable[Responses]] +AsyncSinkCallable = Union[Sinker, AsyncSinkHandlerCallable] diff --git a/pynumaflow/sinker/async_server.py b/pynumaflow/sinker/async_server.py new file mode 100644 index 00000000..6cb0eabc --- /dev/null +++ b/pynumaflow/sinker/async_server.py @@ -0,0 +1,105 @@ +import os + +import aiorun +import grpc + +from pynumaflow.sinker.servicer.async_servicer import AsyncSinkServicer +from pynumaflow.proto.sinker import sink_pb2_grpc + + +from pynumaflow._constants import ( + SINK_SOCK_PATH, + MAX_MESSAGE_SIZE, + MAX_THREADS, +) + +from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.sinker._dtypes import AsyncSinkCallable + + +class SinkAsyncServer(NumaflowServer): + """ + SinkAsyncServer is the main class to start a gRPC server for a sinker. + Create a new grpc Async Sink Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. 
+    Args:
+        sinker_instance: The sinker instance to be used for Sink UDF
+        sock_path: The UNIX socket path to be used for the server
+        max_message_size: The max message size in bytes the server can receive and send
+        max_threads: The max number of threads to be spawned;
+            defaults to number of processors x4
+
+    Example invocation:
+        import os
+        from collections.abc import AsyncIterable
+        from pynumaflow.sinker import Datum, Responses, Response, Sinker
+        from pynumaflow.sinker import SinkAsyncServer
+        from pynumaflow._constants import _LOGGER
+
+
+        class UserDefinedSink(Sinker):
+            async def handler(self, datums: AsyncIterable[Datum]) -> Responses:
+                responses = Responses()
+                async for msg in datums:
+                    _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8"))
+                    responses.append(Response.as_success(msg.id))
+                return responses
+
+
+        async def udsink_handler(datums: AsyncIterable[Datum]) -> Responses:
+            responses = Responses()
+            async for msg in datums:
+                _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8"))
+                responses.append(Response.as_success(msg.id))
+            return responses
+
+
+        if __name__ == "__main__":
+            invoke = os.getenv("INVOKE", "func_handler")
+            if invoke == "class":
+                sink_handler = UserDefinedSink()
+            else:
+                sink_handler = udsink_handler
+            grpc_server = SinkAsyncServer(sink_handler)
+            grpc_server.start()
+    """
+
+    def __init__(
+        self,
+        sinker_instance: AsyncSinkCallable,
+        sock_path=SINK_SOCK_PATH,
+        max_message_size=MAX_MESSAGE_SIZE,
+        max_threads=MAX_THREADS,
+    ):
+        self.sock_path = f"unix://{sock_path}"
+        self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4")))
+        self.max_message_size = max_message_size
+
+        self.sinker_instance = sinker_instance
+
+        self._server_options = [
+            ("grpc.max_send_message_length", self.max_message_size),
+            ("grpc.max_receive_message_length", self.max_message_size),
+        ]
+        self.servicer = AsyncSinkServicer(sinker_instance)
+
+    def start(self):
+        """
+        Starter function for the Async server class; a separate caller is needed
+        so that all the async coroutines can be started from a single context.
+        """
+        aiorun.run(self.aexec(), use_uvloop=True)
+
+    async def aexec(self):
+        """
+        Starts the Asynchronous gRPC server on the given UNIX socket with given max threads.
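+
+        Normally driven by start(); a direct invocation would look like the
+        following (illustrative only; server is a SinkAsyncServer instance):
+
+            aiorun.run(server.aexec(), use_uvloop=True)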
+ """ + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context + # Create a new server instance, add the servicer to it and start the server + server = grpc.aio.server() + server.add_insecure_port(self.sock_path) + sink_pb2_grpc.add_SinkServicer_to_server(self.servicer, server) + await start_async_server(server, self.sock_path, self.max_threads, self._server_options) diff --git a/pynumaflow/sinker/async_sink.py b/pynumaflow/sinker/async_sink.py deleted file mode 100644 index 8333710c..00000000 --- a/pynumaflow/sinker/async_sink.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -import multiprocessing -import os -from collections.abc import AsyncIterable - -import grpc -from google.protobuf import empty_pb2 as _empty_pb2 - -from pynumaflow import setup_logging -from pynumaflow._constants import ( - SINK_SOCK_PATH, - MAX_MESSAGE_SIZE, -) -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH -from pynumaflow.sinker import Responses, Datum, Response -from pynumaflow.sinker._dtypes import SinkCallable -from pynumaflow.sinker.proto import sink_pb2_grpc, sink_pb2 -from pynumaflow.types import NumaflowServicerContext - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) - -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) - - -async def datum_generator( - request_iterator: AsyncIterable[sink_pb2.SinkRequest], -) -> AsyncIterable[Datum]: - async for d in request_iterator: - datum = Datum( - keys=list(d.keys), - sink_msg_id=d.id, - value=d.value, - event_time=d.event_time.ToDatetime(), - watermark=d.watermark.ToDatetime(), - ) - yield datum - - -class AsyncSinker(sink_pb2_grpc.SinkServicer): - """ - Provides an interface to write an Async Sinker - which will be exposed over an Asyncronous gRPC server. - - Args: - handler: Function callable following the type signature of SinkCallable - sock_path: Path to the UNIX Domain Socket - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x 4 - - Example invocation: - >>> import aiorun - >>> from pynumaflow.sinker import Datum, Responses, Response, AsyncSinker - >>> async def my_handler(datums: AsyncIterable[Datum]) -> Responses: - ... responses = Responses() - ... async for msg in datums: - ... responses.append(Response.as_success(msg.id)) - ... return responses - >>> grpc_server = AsyncSinker(handler=my_handler) - >>> aiorun.run(grpc_server.start()) - """ - - def __init__( - self, - handler: SinkCallable, - sock_path=SINK_SOCK_PATH, - max_message_size=MAX_MESSAGE_SIZE, - max_threads=MAX_THREADS, - ): - self.background_tasks = set() - self.__sink_handler: SinkCallable = handler - self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads - self.cleanup_coroutines = [] - - self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), - ] - - async def SinkFn( - self, - request_iterator: AsyncIterable[sink_pb2.SinkRequest], - context: NumaflowServicerContext, - ) -> sink_pb2.SinkResponse: - """ - Applies a sink function to a list of datum elements. 
- The pascal case function name comes from the proto sink_pb2_grpc.py file. - """ - # if there is an exception, we will mark all the responses as a failure - datum_iterator = datum_generator(request_iterator=request_iterator) - results = await self.__invoke_sink(datum_iterator) - - return sink_pb2.SinkResponse(results=results) - - async def __invoke_sink(self, datum_iterator: AsyncIterable[Datum]): - try: - rspns = await self.__sink_handler(datum_iterator) - except Exception as err: - err_msg = "UDSinkError: %r" % err - _LOGGER.critical(err_msg, exc_info=True) - rspns = Responses() - async for _datum in datum_iterator: - rspns.append(Response.as_failure(_datum.id, err_msg)) - responses = [] - for rspn in rspns: - responses.append( - sink_pb2.SinkResponse.Result(id=rspn.id, success=rspn.success, err_msg=rspn.err) - ) - return responses - - async def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> sink_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto sink_pb2_grpc.py file. - """ - return sink_pb2.ReadyResponse(ready=True) - - async def __serve_async(self, server) -> None: - sink_pb2_grpc.add_SinkServicer_to_server(AsyncSinker(self.__sink_handler), server) - server.add_insecure_port(self.sock_path) - _LOGGER.info("GRPC Async Server listening on: %s", self.sock_path) - await server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - - async def server_graceful_shutdown(): - _LOGGER.info("Starting graceful shutdown...") - """ - Shuts down the server with 5 seconds of grace period. During the - grace period, the server won't accept new connections and allow - existing RPCs to continue within the grace period. 
-            await server.stop(5)
-            """
-
-        self.cleanup_coroutines.append(server_graceful_shutdown())
-        await server.wait_for_termination()
-
-    async def start(self) -> None:
-        """Starts the Async gRPC server on the given UNIX socket."""
-        server = grpc.aio.server(options=self._server_options)
-        await self.__serve_async(server)
diff --git a/pynumaflow/sinker/server.py b/pynumaflow/sinker/server.py
index 195cee10..8c95a861 100644
--- a/pynumaflow/sinker/server.py
+++ b/pynumaflow/sinker/server.py
@@ -1,139 +1,103 @@
-import logging
-import multiprocessing
 import os
-from concurrent.futures import ThreadPoolExecutor
-from collections.abc import Iterator, Iterable
-import grpc
-from google.protobuf import empty_pb2 as _empty_pb2
-from pynumaflow import setup_logging
+from pynumaflow.sinker.servicer.sync_servicer import SyncSinkServicer
+
 from pynumaflow._constants import (
     SINK_SOCK_PATH,
     MAX_MESSAGE_SIZE,
+    MAX_THREADS,
+    _LOGGER,
+    UDFType,
 )
-from pynumaflow.info.server import get_sdk_version, write as info_server_write
-from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH
-from pynumaflow.sinker import Responses, Datum, Response
-from pynumaflow.sinker._dtypes import SinkCallable
-from pynumaflow.sinker.proto import sink_pb2_grpc, sink_pb2
-from pynumaflow.types import NumaflowServicerContext
-
-_LOGGER = setup_logging(__name__)
-if os.getenv("PYTHONDEBUG"):
-    _LOGGER.setLevel(logging.DEBUG)
-
-
-_PROCESS_COUNT = multiprocessing.cpu_count()
-MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4)
-
-
-def datum_generator(request_iterator: Iterable[sink_pb2.SinkRequest]) -> Iterable[Datum]:
-    for d in request_iterator:
-        datum = Datum(
-            keys=list(d.keys),
-            sink_msg_id=d.id,
-            value=d.value,
-            event_time=d.event_time.ToDatetime(),
-            watermark=d.watermark.ToDatetime(),
-        )
-        yield datum
 
-
-class Sinker(sink_pb2_grpc.SinkServicer):
+from pynumaflow.shared.server import NumaflowServer, sync_server_start
+from pynumaflow.sinker._dtypes import SyncSinkCallable
+
+
+class SinkServer(NumaflowServer):
     """
-    Provides an interface to write a Sinker
-    which will be exposed over gRPC.
-
-    Args:
-        handler: Function callable following the type signature of SinkCallable
-        sock_path: Path to the UNIX Domain Socket
-        max_message_size: The max message size in bytes the server can receive and send
-        max_threads: The max number of threads to be spawned;
-        defaults to number of processors x 4
-
-    Example invocation:
-    >>> from typing import List
-    >>> from pynumaflow.sinker import Datum, Responses, Response, Sinker
-    >>> def my_handler(datums: Iterator[Datum]) -> Responses:
-    ...     responses = Responses()
-    ...     for msg in datums:
-    ...         responses.append(Response.as_success(msg.id))
-    ...     return responses
-    >>> grpc_server = Sinker(handler=my_handler)
-    >>> grpc_server.start()
+    SinkServer is the main class to start a gRPC server for a sinker.
     """
 
    def __init__(
        self,
-        handler: SinkCallable,
+        sinker_instance: SyncSinkCallable,
        sock_path=SINK_SOCK_PATH,
        max_message_size=MAX_MESSAGE_SIZE,
        max_threads=MAX_THREADS,
    ):
-        self.__sink_handler: SinkCallable = handler
+        """
+        Create a new grpc Sink Server instance.
+        A new servicer instance is created and attached to the server when it starts.
+ Args: + sinker_instance: The sinker instance to be used for Sink UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + Example invocation: + import os + from collections.abc import Iterator + + from pynumaflow.sinker import Datum, Responses, Response, SinkServer + from pynumaflow.sinker import Sinker + from pynumaflow._constants import _LOGGER + + class UserDefinedSink(Sinker): + def handler(self, datums: Iterator[Datum]) -> Responses: + responses = Responses() + for msg in datums: + _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) + responses.append(Response.as_success(msg.id)) + return responses + + def udsink_handler(datums: Iterator[Datum]) -> Responses: + responses = Responses() + for msg in datums: + _LOGGER.info("User Defined Sink %s", msg.value.decode("utf-8")) + responses.append(Response.as_success(msg.id)) + return responses + + if __name__ == "__main__": + invoke = os.getenv("INVOKE", "func_handler") + if invoke == "class": + sink_handler = UserDefinedSink() + else: + sink_handler = udsink_handler + grpc_server = SinkServer(sink_handler) + grpc_server.start() + + """ self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size + + self.sinker_instance = sinker_instance self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ] - def SinkFn( - self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext - ) -> sink_pb2.SinkResponse: - """ - Applies a sink function to a list of datum elements. - The pascal case function name comes from the proto sink_pb2_grpc.py file. - """ - # if there is an exception, we will mark all the responses as a failure - datum_iterator = datum_generator(request_iterator) - try: - rspns = self.__sink_handler(datum_iterator) - except Exception as err: - err_msg = "UDSinkError: %r" % err - _LOGGER.critical(err_msg, exc_info=True) - rspns = Responses() - for _datum in datum_iterator: - rspns.append(Response.as_failure(_datum.id, err_msg)) - - responses = [] - for rspn in rspns: - responses.append( - sink_pb2.SinkResponse.Result(id=rspn.id, success=rspn.success, err_msg=rspn.err) - ) - - return sink_pb2.SinkResponse(results=responses) - - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> sink_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto sink_pb2_grpc.py file. - """ - return sink_pb2.ReadyResponse(ready=True) + self.servicer = SyncSinkServicer(sinker_instance) - def start(self) -> None: + def start(self): """ - Starts the gRPC server on the given UNIX socket with given max threads. + Starts the Synchronous gRPC server on the + given UNIX socket with given max threads. 
""" - server = grpc.server( - ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options - ) - sink_pb2_grpc.add_SinkServicer_to_server(Sinker(self.__sink_handler), server) - server.add_insecure_port(self.sock_path) - server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - _LOGGER.info( - "GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads + "Sync GRPC Sink listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, + ) + # Start the server + sync_server_start( + servicer=self.servicer, + bind_address=self.sock_path, + max_threads=self.max_threads, + server_options=self._server_options, + udf_type=UDFType.Sink, ) - server.wait_for_termination() diff --git a/pynumaflow/sinker/servicer/__init__.py b/pynumaflow/sinker/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sinker/servicer/async_servicer.py b/pynumaflow/sinker/servicer/async_servicer.py new file mode 100644 index 00000000..59d2364e --- /dev/null +++ b/pynumaflow/sinker/servicer/async_servicer.py @@ -0,0 +1,78 @@ +from collections.abc import AsyncIterable + +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.sinker._dtypes import Responses, Datum, Response +from pynumaflow.sinker._dtypes import SyncSinkCallable +from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +async def datum_generator( + request_iterator: AsyncIterable[sink_pb2.SinkRequest], +) -> AsyncIterable[Datum]: + async for d in request_iterator: + datum = Datum( + keys=list(d.keys), + sink_msg_id=d.id, + value=d.value, + event_time=d.event_time.ToDatetime(), + watermark=d.watermark.ToDatetime(), + ) + yield datum + + +class AsyncSinkServicer(sink_pb2_grpc.SinkServicer): + """ + This class is used to create a new grpc Sink servicer instance. + It implements the SinkServicer interface from the proto sink.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: SyncSinkCallable, + ): + self.__sink_handler: SyncSinkCallable = handler + self.cleanup_coroutines = [] + + async def SinkFn( + self, + request_iterator: AsyncIterable[sink_pb2.SinkRequest], + context: NumaflowServicerContext, + ) -> sink_pb2.SinkResponse: + """ + Applies a sink function to a list of datum elements. + The pascal case function name comes from the proto sink_pb2_grpc.py file. 
+ """ + # if there is an exception, we will mark all the responses as a failure + datum_iterator = datum_generator(request_iterator=request_iterator) + results = await self.__invoke_sink(datum_iterator) + + return sink_pb2.SinkResponse(results=results) + + async def __invoke_sink(self, datum_iterator: AsyncIterable[Datum]): + try: + rspns = await self.__sink_handler(datum_iterator) + except Exception as err: + err_msg = "UDSinkError: %r" % err + _LOGGER.critical(err_msg, exc_info=True) + rspns = Responses() + async for _datum in datum_iterator: + rspns.append(Response.as_failure(_datum.id, err_msg)) + responses = [] + for rspn in rspns: + responses.append( + sink_pb2.SinkResponse.Result(id=rspn.id, success=rspn.success, err_msg=rspn.err) + ) + return responses + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> sink_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto sink_pb2_grpc.py file. + """ + return sink_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/sinker/servicer/sync_servicer.py b/pynumaflow/sinker/servicer/sync_servicer.py new file mode 100644 index 00000000..652a56c9 --- /dev/null +++ b/pynumaflow/sinker/servicer/sync_servicer.py @@ -0,0 +1,69 @@ +from collections.abc import Iterator, Iterable + +from google.protobuf import empty_pb2 as _empty_pb2 +from pynumaflow._constants import _LOGGER +from pynumaflow.sinker._dtypes import Responses, Datum, Response +from pynumaflow.sinker._dtypes import SyncSinkCallable +from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 +from pynumaflow.types import NumaflowServicerContext + + +def datum_generator(request_iterator: Iterable[sink_pb2.SinkRequest]) -> Iterable[Datum]: + for d in request_iterator: + datum = Datum( + keys=list(d.keys), + sink_msg_id=d.id, + value=d.value, + event_time=d.event_time.ToDatetime(), + watermark=d.watermark.ToDatetime(), + ) + yield datum + + +class SyncSinkServicer(sink_pb2_grpc.SinkServicer): + """ + This class is used to create a new grpc Sink servicer instance. + It implements the SinkServicer interface from the proto sink.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: SyncSinkCallable, + ): + self.__sink_handler: SyncSinkCallable = handler + + def SinkFn( + self, request_iterator: Iterator[sink_pb2.SinkRequest], context: NumaflowServicerContext + ) -> sink_pb2.SinkResponse: + """ + Applies a sink function to a list of datum elements. + The pascal case function name comes from the proto sink_pb2_grpc.py file. + """ + # if there is an exception, we will mark all the responses as a failure + datum_iterator = datum_generator(request_iterator) + try: + rspns = self.__sink_handler(datum_iterator) + except Exception as err: + err_msg = "UDSinkError: %r" % err + _LOGGER.critical(err_msg, exc_info=True) + rspns = Responses() + for _datum in datum_iterator: + rspns.append(Response.as_failure(_datum.id, err_msg)) + + responses = [] + for rspn in rspns: + responses.append( + sink_pb2.SinkResponse.Result(id=rspn.id, success=rspn.success, err_msg=rspn.err) + ) + + return sink_pb2.SinkResponse(results=responses) + + def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> sink_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto sink_pb2_grpc.py file. 
+        """
+        return sink_pb2.ReadyResponse(ready=True)
diff --git a/pynumaflow/sourcer/__init__.py b/pynumaflow/sourcer/__init__.py
index f846b5f7..a90d2e4d 100644
--- a/pynumaflow/sourcer/__init__.py
+++ b/pynumaflow/sourcer/__init__.py
@@ -6,9 +6,10 @@
     Offset,
     PartitionsResponse,
     get_default_partitions,
+    Sourcer,
 )
-from pynumaflow.sourcer.async_server import AsyncSourcer
-from pynumaflow.sourcer.server import Sourcer
+from pynumaflow.sourcer.async_server import SourceAsyncServer
+from pynumaflow.sourcer.server import SourceServer

 __all__ = [
     "Message",
@@ -16,8 +17,9 @@
     "PendingResponse",
     "AckRequest",
     "Offset",
-    "AsyncSourcer",
-    "Sourcer",
     "PartitionsResponse",
     "get_default_partitions",
+    "SourceServer",
+    "Sourcer",
+    "SourceAsyncServer",
 ]
diff --git a/pynumaflow/sourcer/_dtypes.py b/pynumaflow/sourcer/_dtypes.py
index 8e042b28..9fd0e910 100644
--- a/pynumaflow/sourcer/_dtypes.py
+++ b/pynumaflow/sourcer/_dtypes.py
@@ -1,4 +1,5 @@
 import os
+from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
 from dataclasses import dataclass
 from datetime import datetime
@@ -203,11 +204,60 @@
         return self._partitions


+class Sourcer(metaclass=ABCMeta):
+    """
+    Provides an interface to write a Sourcer
+    which will be exposed over a gRPC server.
+    """
+
+    def __call__(self, *args, **kwargs):
+        """
+        Allows the class instance to be called directly as a handler function
+        if it is passed as the sourcer instance.
+        """
+        return self.handler(*args, **kwargs)
+
+    @abstractmethod
+    def read_handler(self, datum: ReadRequest) -> Iterable[Message]:
+        """
+        Implement this handler function, which implements the SourceReadCallable
+        interface. read_handler is used to read the data from the source and send
+        it forward; for each read request we process num_records and increment the
+        read_idx to indicate that the message has been read, and the same is added
+        to the ack set.
+        """
+        pass
+
+    @abstractmethod
+    def ack_handler(self, ack_request: AckRequest):
+        """
+        The ack handler is used to acknowledge the offsets that have been read,
+        and remove them from the to_ack_set.
+        """
+        pass
+
+    @abstractmethod
+    def pending_handler(self) -> PendingResponse:
+        """
+        The pending handler returns the number of pending records
+        at the user defined source.
+        """
+        pass
+
+    @abstractmethod
+    def partitions_handler(self) -> PartitionsResponse:
+        """
+        The partitions handler returns the partitions of the user defined source.
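+
+        If the source does not have partitions, get_default_partitions() can be
+        used to return the default partitions. For example (illustrative):
+
+            return PartitionsResponse(partitions=get_default_partitions())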
+ """ + pass + + # Create default partition id from the environment variable "NUMAFLOW_REPLICA" DefaultPartitionId = int(os.getenv("NUMAFLOW_REPLICA", "0")) SourceReadCallable = Callable[[ReadRequest], Iterable[Message]] AsyncSourceReadCallable = Callable[[ReadRequest], AsyncIterable[Message]] SourceAckCallable = Callable[[AckRequest], None] +SourceCallable = Sourcer def get_default_partitions() -> list[int]: diff --git a/pynumaflow/sourcer/async_server.py b/pynumaflow/sourcer/async_server.py index 411349a8..98c026f6 100644 --- a/pynumaflow/sourcer/async_server.py +++ b/pynumaflow/sourcer/async_server.py @@ -1,237 +1,135 @@ -import logging -import multiprocessing import os -from collections.abc import AsyncIterable -from google.protobuf import timestamp_pb2 as _timestamp_pb2 +import aiorun import grpc -from google.protobuf import empty_pb2 as _empty_pb2 +from pynumaflow.sourcer.servicer.async_servicer import AsyncSourceServicer -from pynumaflow import setup_logging from pynumaflow._constants import ( - MAX_MESSAGE_SIZE, SOURCE_SOCK_PATH, + MAX_MESSAGE_SIZE, + MAX_THREADS, ) -from pynumaflow.sourcer import ReadRequest -from pynumaflow.sourcer._dtypes import AsyncSourceReadCallable, Offset, AckRequest -from pynumaflow.sourcer.proto import source_pb2 -from pynumaflow.sourcer.proto import source_pb2_grpc -from pynumaflow.types import NumaflowServicerContext -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) +from pynumaflow.proto.sourcer import source_pb2_grpc -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", "4")) +from pynumaflow.shared.server import NumaflowServer, start_async_server +from pynumaflow.sourcer._dtypes import SourceCallable -class AsyncSourcer(source_pb2_grpc.SourceServicer): +class SourceAsyncServer(NumaflowServer): """ - Provides an interface to write an Asynchronous Sourcer - which will be exposed over gRPC. - - Args: - read_handler: Function callable following the type signature of AsyncSourceReadCallable - ack_handler: Function handler for AckFn - pending_handler: Function handler for PendingFn - partitions_handler: Function handler for PartitionsFn - - sock_path: Path to the UNIX Domain Socket - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x4 - - Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.sourcer import Message, get_default_partitions \ - ... ReadRequest, AsyncSourcer, - ... import aiorun - ... async def read_handler(datum: ReadRequest) -> AsyncIterable[Message]: - ... payload = b"payload:test_mock_message" - ... keys = ["test_key"] - ... offset = mock_offset() - ... event_time = mock_event_time() - ... for i in range(10): - ... yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time) - ... async def ack_handler(ack_request: AckRequest): - ... return - ... async def pending_handler() -> PendingResponse: - ... PendingResponse(count=10) - ... async def partitions_handler() -> PartitionsResponse: - ... return PartitionsResponse(partitions=get_default_partitions()) - >>> grpc_server = AsyncSourcer(read_handler=read_handler, - ... ack_handler=ack_handler, - ... pending_handler=pending_handler, - ... 
partitions_handler=partitions_handler) - >>> aiorun.run(grpc_server.start()) + Class for a new Async Source Server instance. """ def __init__( self, - read_handler: AsyncSourceReadCallable, - ack_handler, - pending_handler, - partitions_handler, + sourcer_instance: SourceCallable, sock_path=SOURCE_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, max_threads=MAX_THREADS, ): - self.__source_read_handler: AsyncSourceReadCallable = read_handler - self.__source_ack_handler = ack_handler - self.__source_pending_handler = pending_handler - self.__source_partitions_handler = partitions_handler + """ + Create a new grpc Async Source Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. + Args: + sourcer_instance: The sourcer instance to be used for Source UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + + Example invocation: + from collections.abc import AsyncIterable + from datetime import datetime + from pynumaflow.sourcer import ( + ReadRequest, + Message, + AckRequest, + PendingResponse, + Offset, + PartitionsResponse, + get_default_partitions, + Sourcer, + SourceAsyncServer, + ) + + class AsyncSource(Sourcer): + # AsyncSource is a class for User Defined Source implementation. + + def __init__(self): + # to_ack_set: Set to maintain a track of the offsets yet to be acknowledged + # read_idx : the offset idx till where the messages have been read + self.to_ack_set = set() + self.read_idx = 0 + + async def read_handler(self, datum: ReadRequest) -> AsyncIterable[Message]: + # read_handler is used to read the data from the source and send + # the data forward + # for each read request we process num_records and increment + # the read_idx to indicate that + # the message has been read and the same is added to the ack set + if self.to_ack_set: + return + + for x in range(datum.num_records): + yield Message( + payload=str(self.read_idx).encode(), + offset=Offset.offset_with_default_partition_id(str(self.read_idx).encode()), + event_time=datetime.now(), + ) + self.to_ack_set.add(str(self.read_idx)) + self.read_idx += 1 + + async def ack_handler(self, ack_request: AckRequest): + # The ack handler is used acknowledge the offsets that have been read, + # and remove them from the to_ack_set + for offset in ack_request.offset: + self.to_ack_set.remove(str(offset.offset, "utf-8")) + + async def pending_handler(self) -> PendingResponse: + # The simple source always returns zero to indicate there is no pending record. + return PendingResponse(count=0) + + async def partitions_handler(self) -> PartitionsResponse: + # The simple source always returns default partitions. + return PartitionsResponse(partitions=get_default_partitions()) + + if __name__ == "__main__": + ud_source = AsyncSource() + grpc_server = SourceAsyncServer(ud_source) + grpc_server.start() + + """ self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads - self.cleanup_coroutines = [] - # Collection for storing strong references to all running tasks. - # Event loop only keeps a weak reference, which can cause it to - # get lost during execution. 
- self.background_tasks = set() + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size + + self.sourcer_instance = sourcer_instance self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ] - async def ReadFn( - self, - request: source_pb2.ReadRequest, - context: NumaflowServicerContext, - ) -> AsyncIterable[source_pb2.ReadResponse]: - """ - Applies a Read function and returns a stream of datum responses. - The pascal case function name comes from the proto source_pb2_grpc.py file. - """ + self.servicer = AsyncSourceServicer(source_handler=sourcer_instance) - async for res in self.__invoke_source_read_stream( - ReadRequest( - num_records=request.request.num_records, - timeout_in_ms=request.request.timeout_in_ms, - ) - ): - yield source_pb2.ReadResponse(result=res) - - async def __invoke_source_read_stream(self, req: ReadRequest): - try: - async for msg in self.__source_read_handler(req): - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=msg.event_time) - yield source_pb2.ReadResponse.Result( - payload=msg.payload, - keys=msg.keys, - offset=msg.offset.as_dict, - event_time=event_time_timestamp, - ) - except Exception as err: - _LOGGER.critical("User-Defined Source ReadError ", exc_info=True) - raise err - - async def AckFn( - self, request: source_pb2.AckRequest, context: NumaflowServicerContext - ) -> source_pb2.AckResponse: - """ - Applies an Ack function in User Defined Source - """ - # proto repeated field(offsets) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - offsets = [] - for offset in request.request.offsets: - offsets.append(Offset(offset.offset, offset.partition_id)) - try: - await self.__invoke_ack(ack_req=offsets) - except Exception as e: - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(e)) - raise e - - return source_pb2.AckResponse() - - async def __invoke_ack(self, ack_req: list[Offset]): + def start(self): """ - Invokes the Source Ack Function. + Starter function for the Async server class, need a separate caller + so that all the async coroutines can be started from a single context """ - try: - await self.__source_ack_handler(AckRequest(offsets=ack_req)) - except Exception as err: - _LOGGER.critical("AckFn Error", exc_info=True) - raise err - return source_pb2.AckResponse.Result() - - async def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto source_pb2_grpc.py file. - """ - return source_pb2.ReadyResponse(ready=True) + aiorun.run(self.aexec(), use_uvloop=True) - async def PendingFn( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.PendingResponse: - """ - PendingFn returns the number of pending records - at the user defined source. 
- """ - try: - count = await self.__source_pending_handler() - except Exception as err: - _LOGGER.critical("PendingFn Error", exc_info=True) - raise err - resp = source_pb2.PendingResponse.Result(count=count.count) - return source_pb2.PendingResponse(result=resp) - - async def PartitionsFn( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.PartitionsResponse: + async def aexec(self): """ - PartitionsFn returns the partitions of the user defined source. + Starts the Async gRPC server on the given UNIX socket with given max threads """ - try: - partitions = await self.__source_partitions_handler() - except Exception as err: - _LOGGER.critical("PartitionsFn Error", exc_info=True) - raise err - resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) - return source_pb2.PartitionsResponse(result=resp) - - async def __serve_async(self, server) -> None: - source_pb2_grpc.add_SourceServicer_to_server( - AsyncSourcer( - read_handler=self.__source_read_handler, - ack_handler=self.__source_ack_handler, - pending_handler=self.__source_pending_handler, - partitions_handler=self.__source_partitions_handler, - ), - server, - ) + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context + # Create a new async server instance and add the servicer to it + server = grpc.aio.server() server.add_insecure_port(self.sock_path) - _LOGGER.info("GRPC Async Server listening on: %s", self.sock_path) - await server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - - async def server_graceful_shutdown(): - """ - Shuts down the server with 5 seconds of grace period. During the - grace period, the server won't accept new connections and allow - existing RPCs to continue within the grace period. 
- """ - _LOGGER.info("Starting graceful shutdown...") - await server.stop(5) - - self.cleanup_coroutines.append(server_graceful_shutdown()) - await server.wait_for_termination() - - async def start(self) -> None: - """Starts the Async gRPC server on the given UNIX socket.""" - server = grpc.aio.server(options=self._server_options) - await self.__serve_async(server) + source_servicer = self.servicer + source_pb2_grpc.add_SourceServicer_to_server(source_servicer, server) + await start_async_server(server, self.sock_path, self.max_threads, self._server_options) diff --git a/pynumaflow/sourcer/server.py b/pynumaflow/sourcer/server.py index b36b22ba..6177045f 100644 --- a/pynumaflow/sourcer/server.py +++ b/pynumaflow/sourcer/server.py @@ -1,231 +1,133 @@ -import logging -import multiprocessing import os -from collections.abc import Iterable -from concurrent.futures import ThreadPoolExecutor - -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -import grpc -from google.protobuf import empty_pb2 as _empty_pb2 - -from pynumaflow import setup_logging from pynumaflow._constants import ( - MAX_MESSAGE_SIZE, SOURCE_SOCK_PATH, + MAX_MESSAGE_SIZE, + MAX_THREADS, + _LOGGER, + UDFType, ) -from pynumaflow.sourcer import ReadRequest -from pynumaflow.sourcer._dtypes import ( - SourceReadCallable, - Offset, - AckRequest, - SourceAckCallable, -) -from pynumaflow.sourcer.proto import source_pb2 -from pynumaflow.sourcer.proto import source_pb2_grpc -from pynumaflow.types import NumaflowServicerContext -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) +from pynumaflow.shared.server import NumaflowServer, sync_server_start +from pynumaflow.sourcer._dtypes import SourceCallable +from pynumaflow.sourcer.servicer.sync_servicer import SyncSourceServicer -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", "4")) - -class Sourcer(source_pb2_grpc.SourceServicer): +class SourceServer(NumaflowServer): """ - Provides an interface to write a Sourcer - which will be exposed over gRPC. - - Args: - read_handler: Function callable following the type signature of SyncSourceReadCallable - ack_handler: Function handler for AckFn - pending_handler: Function handler for PendingFn - partitions_handler: Function handler for PartitionsFn - sock_path: Path to the UNIX Domain Socket - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x4 - - Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.sourcer import Message, get_default_partitions, PartitionsResponse \ - ... ReadRequest, Sourcer, AckRequest, - ... def read_handler(datum: ReadRequest) -> Iterable[Message]: - ... payload = b"payload:test_mock_message" - ... keys = ["test_key"] - ... offset = mock_offset() - ... event_time = mock_event_time() - ... for i in range(10): - ... yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time) - ... def ack_handler(ack_request: AckRequest): - ... return - ... def pending_handler() -> PendingResponse: - ... PendingResponse(count=10) - ... def partitions_handler() -> PartitionsResponse: - ... return PartitionsResponse(partitions=get_default_partitions()) - >>> grpc_server = Sourcer(read_handler=read_handler, - ... 
ack_handler=ack_handler,
-    ...                        pending_handler=pending_handler,
-    ...                        partitions_handler=partition_handler,)
-    >>> grpc_server.start()
+    Class for a new Source Server instance.
     """

     def __init__(
         self,
-        read_handler: SourceReadCallable,
-        ack_handler: SourceAckCallable,
-        pending_handler,
-        partitions_handler,
+        sourcer_instance: SourceCallable,
         sock_path=SOURCE_SOCK_PATH,
         max_message_size=MAX_MESSAGE_SIZE,
         max_threads=MAX_THREADS,
     ):
-        self.__source_read_handler: SourceReadCallable = read_handler
-        self.__source_ack_handler: SourceAckCallable = ack_handler
-        self.__source_pending_handler = pending_handler
-        self.__source_partitions_handler = partitions_handler
+        """
+        Create a new grpc Source Server instance.
+        A new servicer instance is created and attached to the server.
+        The server instance is returned.
+        Args:
+            sourcer_instance: The sourcer instance to be used for Source UDF
+            sock_path: The UNIX socket path to be used for the server
+            max_message_size: The max message size in bytes the server can receive and send
+            max_threads: The max number of threads to be spawned;
+            defaults to number of processors x4
+
+        Example invocation:
+            from collections.abc import Iterable
+            from datetime import datetime
+
+            from pynumaflow.sourcer import (
+                ReadRequest,
+                Message,
+                AckRequest,
+                PendingResponse,
+                Offset,
+                PartitionsResponse,
+                get_default_partitions,
+                Sourcer,
+                SourceServer,
+            )
+
+            class SimpleSource(Sourcer):
+                # SimpleSource is a class for User Defined Source implementation.
+
+                def __init__(self):
+                    # to_ack_set: Set to maintain a track of the offsets yet to be acknowledged
+                    # read_idx : the offset idx till where the messages have been read
+                    self.to_ack_set = set()
+                    self.read_idx = 0
+
+                def read_handler(self, datum: ReadRequest) -> Iterable[Message]:
+                    # read_handler is used to read the data from the source and
+                    # send the data forward
+                    # for each read request we process num_records and increment the
+                    # read_idx to indicate that
+                    # the message has been read and the same is added to the ack set
+                    if self.to_ack_set:
+                        return
+
+                    for x in range(datum.num_records):
+                        yield Message(
+                            payload=str(self.read_idx).encode(),
+                            offset=Offset.offset_with_default_partition_id(str(self.read_idx).encode()),
+                            event_time=datetime.now(),
+                        )
+                        self.to_ack_set.add(str(self.read_idx))
+                        self.read_idx += 1
+
+                def ack_handler(self, ack_request: AckRequest):
+                    # The ack handler is used to acknowledge the offsets that
+                    # have been read, and remove them
+                    # from the to_ack_set
+                    for offset in ack_request.offset:
+                        self.to_ack_set.remove(str(offset.offset, "utf-8"))
+
+                def pending_handler(self) -> PendingResponse:
+                    # The simple source always returns zero to indicate there is no pending record.
+                    return PendingResponse(count=0)
+
+                def partitions_handler(self) -> PartitionsResponse:
+                    # The simple source always returns the default partitions.
+                    return PartitionsResponse(partitions=get_default_partitions())
+
+            if __name__ == "__main__":
+                ud_source = SimpleSource()
+                grpc_server = SourceServer(ud_source)
+                grpc_server.start()
+        """
         self.sock_path = f"unix://{sock_path}"
-        self._max_message_size = max_message_size
-        self._max_threads = max_threads
-        self.cleanup_coroutines = []
-        # Collection for storing strong references to all running tasks.
-        # Event loop only keeps a weak reference, which can cause it to
-        # get lost during execution.
- self.background_tasks = set() + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size + + self.sourcer_instance = sourcer_instance self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ] - def ReadFn( - self, - request: source_pb2.ReadRequest, - context: NumaflowServicerContext, - ) -> Iterable[source_pb2.ReadResponse]: - """ - Applies a Read function to a datum stream in streaming mode. - The pascal case function name comes from the proto source_pb2_grpc.py file. - """ - - for res in self.__invoke_source_read_stream( - ReadRequest( - num_records=request.request.num_records, - timeout_in_ms=request.request.timeout_in_ms, - ) - ): - yield source_pb2.ReadResponse(result=res) - - def __invoke_source_read_stream(self, req: ReadRequest): - try: - for msg in self.__source_read_handler(req): - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=msg.event_time) - yield source_pb2.ReadResponse.Result( - payload=msg.payload, - keys=msg.keys, - offset=msg.offset.as_dict, - event_time=event_time_timestamp, - ) - except Exception as err: - _LOGGER.critical("User-Defined Source ReadError ", exc_info=True) - raise err - - def AckFn( - self, request: source_pb2.AckRequest, context: NumaflowServicerContext - ) -> source_pb2.AckResponse: - """ - Applies an Ack function in User Defined Source - """ - # proto repeated field(offsets) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - offsets = [] - for offset in request.request.offsets: - offsets.append(Offset(offset.offset, offset.partition_id)) - try: - self.__invoke_ack(ack_req=offsets) - except Exception as e: - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(e)) - raise e - - return source_pb2.AckResponse() - - def __invoke_ack(self, ack_req: list[Offset]): - """ - Invokes the Source Ack Function. - """ - try: - self.__source_ack_handler(AckRequest(offsets=ack_req)) - except Exception as err: - _LOGGER.critical("AckFn Error", exc_info=True) - raise err - return source_pb2.AckResponse.Result() - - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto source_pb2_grpc.py file. - """ - return source_pb2.ReadyResponse(ready=True) + self.servicer = SyncSourceServicer(source_handler=sourcer_instance) - def PendingFn( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.PendingResponse: + def start(self): """ - PendingFn returns the number of pending records - at the user defined source. + Starts the Synchronous Source gRPC server on the given + UNIX socket with given max threads. 
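+        This is a thin wrapper (see the body below): it logs the socket path
+        and thread count and delegates to the shared sync_server_start helper
+        with udf_type=UDFType.Source.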
""" - try: - count = self.__source_pending_handler() - except Exception as err: - _LOGGER.critical("PendingFn error", exc_info=True) - raise err - resp = source_pb2.PendingResponse.Result(count=count.count) - return source_pb2.PendingResponse(result=resp) - - def PartitionsFn( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> source_pb2.PartitionsResponse: - """ - Partitions returns the partitions associated with the source, will be used by - the platform to determine the partitions to which the watermark should be published. - If the source doesn't have partitions, get_default_partitions() can be used to - return the default partitions. In most cases, the get_default_partitions() - should be enough; the cases where we need to implement custom partitions_handler() - is in a case like Kafka, where a reader can read from multiple Kafka partitions. - """ - try: - partitions = self.__source_partitions_handler() - except Exception as err: - _LOGGER.critical("PartitionFn error", exc_info=True) - raise err - resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) - return source_pb2.PartitionsResponse(result=resp) - - def start(self) -> None: - """ - Starts the gRPC server on the given UNIX socket with given max threads. - """ - server = grpc.server( - ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options - ) - source_pb2_grpc.add_SourceServicer_to_server(self, server) - server.add_insecure_port(self.sock_path) - server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) + # Get the servicer instance + source_servicer = self.servicer _LOGGER.info( - "GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads + "Sync Source GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, + ) + # Start the sync server + sync_server_start( + servicer=source_servicer, + bind_address=self.sock_path, + max_threads=self.max_threads, + server_options=self._server_options, + udf_type=UDFType.Source, ) - server.wait_for_termination() diff --git a/pynumaflow/sourcer/servicer/__init__.py b/pynumaflow/sourcer/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py new file mode 100644 index 00000000..cdffae92 --- /dev/null +++ b/pynumaflow/sourcer/servicer/async_servicer.py @@ -0,0 +1,129 @@ +from collections.abc import AsyncIterable +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.sourcer._dtypes import ReadRequest +from pynumaflow.sourcer._dtypes import Offset, AckRequest, SourceCallable +from pynumaflow.proto.sourcer import source_pb2 +from pynumaflow.proto.sourcer import source_pb2_grpc +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +class AsyncSourceServicer(source_pb2_grpc.SourceServicer): + """ + This class is used to create a new grpc Source servicer instance. + It implements the SourceServicer interface from the proto source.proto file. + Provides the functionality for the required rpc methods. 
+ """ + + def __init__(self, source_handler: SourceCallable): + self.source_handler = source_handler + self.__source_read_handler = source_handler.read_handler + self.__source_ack_handler = source_handler.ack_handler + self.__source_pending_handler = source_handler.pending_handler + self.__source_partitions_handler = source_handler.partitions_handler + self.cleanup_coroutines = [] + + async def ReadFn( + self, + request: source_pb2.ReadRequest, + context: NumaflowServicerContext, + ) -> AsyncIterable[source_pb2.ReadResponse]: + """ + Applies a Read function and returns a stream of datum responses. + The pascal case function name comes from the proto source_pb2_grpc.py file. + """ + + async for res in self.__invoke_source_read_stream( + ReadRequest( + num_records=request.request.num_records, + timeout_in_ms=request.request.timeout_in_ms, + ) + ): + yield source_pb2.ReadResponse(result=res) + + async def __invoke_source_read_stream(self, req: ReadRequest): + try: + async for msg in self.__source_read_handler(req): + event_time_timestamp = _timestamp_pb2.Timestamp() + event_time_timestamp.FromDatetime(dt=msg.event_time) + yield source_pb2.ReadResponse.Result( + payload=msg.payload, + keys=msg.keys, + offset=msg.offset.as_dict, + event_time=event_time_timestamp, + ) + except Exception as err: + _LOGGER.critical("User-Defined Source ReadError ", exc_info=True) + raise err + + async def AckFn( + self, request: source_pb2.AckRequest, context: NumaflowServicerContext + ) -> source_pb2.AckResponse: + """ + Applies an Ack function in User Defined Source + """ + # proto repeated field(offsets) is of type google._upb._message.RepeatedScalarContainer + # we need to explicitly convert it to list + offsets = [] + for offset in request.request.offsets: + offsets.append(Offset(offset.offset, offset.partition_id)) + try: + await self.__invoke_ack(ack_req=offsets) + except Exception as e: + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(str(e)) + raise e + + return source_pb2.AckResponse() + + async def __invoke_ack(self, ack_req: list[Offset]): + """ + Invokes the Source Ack Function. + """ + try: + await self.__source_ack_handler(AckRequest(offsets=ack_req)) + except Exception as err: + _LOGGER.critical("AckFn Error", exc_info=True) + raise err + return source_pb2.AckResponse.Result() + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto source_pb2_grpc.py file. + """ + return source_pb2.ReadyResponse(ready=True) + + async def PendingFn( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.PendingResponse: + """ + PendingFn returns the number of pending records + at the user defined source. + """ + try: + count = await self.__source_pending_handler() + except Exception as err: + _LOGGER.critical("PendingFn Error", exc_info=True) + raise err + resp = source_pb2.PendingResponse.Result(count=count.count) + return source_pb2.PendingResponse(result=resp) + + async def PartitionsFn( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.PartitionsResponse: + """ + PartitionsFn returns the partitions of the user defined source. 
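+        The partitions will be used by the platform to determine the partitions
+        to which the watermark should be published; if the source doesn't have
+        partitions, get_default_partitions() can be used to return the default
+        partitions.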
+ """ + try: + partitions = await self.__source_partitions_handler() + except Exception as err: + _LOGGER.critical("PartitionsFn Error", exc_info=True) + raise err + resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) + return source_pb2.PartitionsResponse(result=resp) diff --git a/pynumaflow/sourcer/servicer/sync_servicer.py b/pynumaflow/sourcer/servicer/sync_servicer.py new file mode 100644 index 00000000..824508c5 --- /dev/null +++ b/pynumaflow/sourcer/servicer/sync_servicer.py @@ -0,0 +1,143 @@ +from collections.abc import Iterable + +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow.sourcer._dtypes import ReadRequest +from pynumaflow.sourcer._dtypes import ( + SourceReadCallable, + Offset, + AckRequest, + SourceAckCallable, + SourceCallable, +) +from pynumaflow.proto.sourcer import source_pb2 +from pynumaflow.proto.sourcer import source_pb2_grpc +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +class SyncSourceServicer(source_pb2_grpc.SourceServicer): + """ + This class is used to create a new grpc Source servicer instance. + It implements the SourceServicer interface from the proto source.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + source_handler: SourceCallable, + ): + self.source_handler = source_handler + self.__source_read_handler: SourceReadCallable = source_handler.read_handler + self.__source_ack_handler: SourceAckCallable = source_handler.ack_handler + self.__source_pending_handler = source_handler.pending_handler + self.__source_partitions_handler = source_handler.partitions_handler + + def ReadFn( + self, + request: source_pb2.ReadRequest, + context: NumaflowServicerContext, + ) -> Iterable[source_pb2.ReadResponse]: + """ + Applies a Read function to a datum stream in streaming mode. + The pascal case function name comes from the proto source_pb2_grpc.py file. + """ + + for res in self.__invoke_source_read_stream( + ReadRequest( + num_records=request.request.num_records, + timeout_in_ms=request.request.timeout_in_ms, + ) + ): + yield source_pb2.ReadResponse(result=res) + + def __invoke_source_read_stream(self, req: ReadRequest): + try: + for msg in self.__source_read_handler(req): + event_time_timestamp = _timestamp_pb2.Timestamp() + event_time_timestamp.FromDatetime(dt=msg.event_time) + yield source_pb2.ReadResponse.Result( + payload=msg.payload, + keys=msg.keys, + offset=msg.offset.as_dict, + event_time=event_time_timestamp, + ) + except Exception as err: + _LOGGER.critical("User-Defined Source ReadError ", exc_info=True) + raise err + + def AckFn( + self, request: source_pb2.AckRequest, context: NumaflowServicerContext + ) -> source_pb2.AckResponse: + """ + Applies an Ack function in User Defined Source + """ + # proto repeated field(offsets) is of type google._upb._message.RepeatedScalarContainer + # we need to explicitly convert it to list + offsets = [] + for offset in request.request.offsets: + offsets.append(Offset(offset.offset, offset.partition_id)) + try: + self.__invoke_ack(ack_req=offsets) + except Exception as e: + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(str(e)) + raise e + + return source_pb2.AckResponse() + + def __invoke_ack(self, ack_req: list[Offset]): + """ + Invokes the Source Ack Function. 
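+        Errors from the user defined ack handler are logged and re-raised;
+        on success an empty AckResponse.Result is returned.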
+ """ + try: + self.__source_ack_handler(AckRequest(offsets=ack_req)) + except Exception as err: + _LOGGER.critical("AckFn Error", exc_info=True) + raise err + return source_pb2.AckResponse.Result() + + def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto source_pb2_grpc.py file. + """ + return source_pb2.ReadyResponse(ready=True) + + def PendingFn( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.PendingResponse: + """ + PendingFn returns the number of pending records + at the user defined source. + """ + try: + count = self.__source_pending_handler() + except Exception as err: + _LOGGER.critical("PendingFn error", exc_info=True) + raise err + resp = source_pb2.PendingResponse.Result(count=count.count) + return source_pb2.PendingResponse(result=resp) + + def PartitionsFn( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> source_pb2.PartitionsResponse: + """ + Partitions returns the partitions associated with the source, will be used by + the platform to determine the partitions to which the watermark should be published. + If the source doesn't have partitions, get_default_partitions() can be used to + return the default partitions. In most cases, the get_default_partitions() + should be enough; the cases where we need to implement custom partitions_handler() + is in a case like Kafka, where a reader can read from multiple Kafka partitions. + """ + try: + partitions = self.__source_partitions_handler() + except Exception as err: + _LOGGER.critical("PartitionFn error", exc_info=True) + raise err + resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions) + return source_pb2.PartitionsResponse(result=resp) diff --git a/pynumaflow/sourcetransformer/__init__.py b/pynumaflow/sourcetransformer/__init__.py index 4708603d..69f8018c 100644 --- a/pynumaflow/sourcetransformer/__init__.py +++ b/pynumaflow/sourcetransformer/__init__.py @@ -1,12 +1,19 @@ -from pynumaflow.sourcetransformer._dtypes import Message, Messages, Datum, DROP -from pynumaflow.sourcetransformer.multiproc_server import MultiProcSourceTransformer -from pynumaflow.sourcetransformer.server import SourceTransformer +from pynumaflow.sourcetransformer._dtypes import ( + Message, + Messages, + Datum, + DROP, + SourceTransformer, +) +from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer +from pynumaflow.sourcetransformer.server import SourceTransformServer __all__ = [ "Message", "Messages", "Datum", "DROP", + "SourceTransformServer", "SourceTransformer", - "MultiProcSourceTransformer", + "SourceTransformMultiProcServer", ] diff --git a/pynumaflow/sourcetransformer/_dtypes.py b/pynumaflow/sourcetransformer/_dtypes.py index b3242cd2..66e6978c 100644 --- a/pynumaflow/sourcetransformer/_dtypes.py +++ b/pynumaflow/sourcetransformer/_dtypes.py @@ -1,7 +1,8 @@ +from abc import ABCMeta, abstractmethod from collections.abc import Iterator, Sequence from dataclasses import dataclass from datetime import datetime -from typing import TypeVar, Callable +from typing import TypeVar, Callable, Union from warnings import warn from pynumaflow._constants import DROP @@ -172,4 +173,29 @@ def watermark(self) -> datetime: return self._watermark -SourceTransformCallable = Callable[[list[str], Datum], Messages] +class SourceTransformer(metaclass=ABCMeta): + """ + Provides an 
interface to write a Source Transformer
+    which will be exposed over a gRPC server.
+    """
+
+    def __call__(self, *args, **kwargs):
+        """
+        Allows the class instance to be called directly as a handler function
+        if it is passed as the source_transformer_instance.
+        """
+        return self.handler(*args, **kwargs)
+
+    @abstractmethod
+    def handler(self, keys: list[str], datum: Datum) -> Messages:
+        """
+        Implement this handler function which implements the
+        SourceTransformCallable interface.
+        """
+        pass
+
+
+SourceTransformHandler = Callable[[list[str], Datum], Messages]
+# SourceTransformCallable is the type of the handler function for the
+# Source Transformer UDF.
+SourceTransformCallable = Union[SourceTransformHandler, SourceTransformer]
diff --git a/pynumaflow/sourcetransformer/multiproc_server.py b/pynumaflow/sourcetransformer/multiproc_server.py
index 7aa58e9d..ace9aa1d 100644
--- a/pynumaflow/sourcetransformer/multiproc_server.py
+++ b/pynumaflow/sourcetransformer/multiproc_server.py
@@ -1,206 +1,127 @@
-import contextlib
-import logging
-import multiprocessing
 import os
-import socket
-from concurrent import futures
-from collections.abc import Iterator
-import grpc
-from google.protobuf import empty_pb2 as _empty_pb2
-from google.protobuf import timestamp_pb2 as _timestamp_pb2
+from pynumaflow.sourcetransformer.servicer.server import SourceTransformServicer
+
+from pynumaflow.shared.server import start_multiproc_server

-from pynumaflow import setup_logging
 from pynumaflow._constants import (
     MAX_MESSAGE_SIZE,
+    SOURCE_TRANSFORMER_SOCK_PATH,
+    MAX_THREADS,
+    UDFType,
+    _PROCESS_COUNT,
 )
-from pynumaflow._constants import MULTIPROC_MAP_SOCK_ADDR
-from pynumaflow.exceptions import SocketError
-from pynumaflow.info.server import (
-    get_sdk_version,
-    write as info_server_write,
-    get_metadata_env,
-)
-from pynumaflow.info.types import (
-    ServerInfo,
-    Protocol,
-    Language,
-    SERVER_INFO_FILE_PATH,
-    METADATA_ENVS,
-)
-from pynumaflow.sourcetransformer import Datum
+
 from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable
-from pynumaflow.sourcetransformer.proto import transform_pb2
-from pynumaflow.sourcetransformer.proto import transform_pb2_grpc
-from pynumaflow.types import NumaflowServicerContext

-_LOGGER = setup_logging(__name__)
-if os.getenv("PYTHONDEBUG"):
-    _LOGGER.setLevel(logging.DEBUG)
+from pynumaflow.shared import NumaflowServer


-class MultiProcSourceTransformer(transform_pb2_grpc.SourceTransformServicer):
+class SourceTransformMultiProcServer(NumaflowServer):
     """
-    Provides an interface to write a Multi-Processor Source Transformer
-    which will be exposed over gRPC.
-
-    Args:
-
-    handler: Function callable following the type signature of SourceTransformCallable
-    max_message_size: The max message size in bytes the server can receive and send
-
-    Example invocation:
-    >>> from typing import Iterator
-    >>> from pynumaflow.sourcetransformer import Messages, Message \
-    ...     Datum, MultiProcSourceTransformer
-    >>> def transform_handler(key: [str], datum: Datum) -> Messages:
-    ...     val = datum.value
-    ...     new_event_time = datetime.time()
-    ...     _ = datum.watermark
-    ...     message_t_s = Messages(Message(val, event_time=new_event_time, keys=key))
-    ...     return message_t_s
-    ...
-    ...
-    >>> grpc_server = MultiProcSourceTransformer(handler=transform_handler)
-    >>> grpc_server.start()
+    Class for a new Source Transformer MultiProc Server instance.
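+
+    Under the hood it starts server_count copies of the gRPC server in
+    separate processes (capped at 2 * the CPU count), each bound to its own
+    TCP port, so the client side opens one connection per server.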
""" def __init__( self, - handler: SourceTransformCallable, + source_transform_instance: SourceTransformCallable, + server_count: int = _PROCESS_COUNT, + sock_path=SOURCE_TRANSFORMER_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, + max_threads=MAX_THREADS, ): - self.__transform_handler: SourceTransformCallable = handler - self._max_message_size = max_message_size + """ + Create a new grpc Source Transformer Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. + Args: + source_transform_instance: The source transformer instance to be used for + Source Transformer UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 + + Example invocation: + import datetime + import logging + + from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformServer + + # This is a simple User Defined Function example which receives a message, + # applies the following + # data transformation, and returns the message. + # If the message event time is before year 2022, drop the message + # with event time unchanged. + # If it's within year 2022, update the tag to "within_year_2022" and + # update the message event time to Jan 1st 2022. + # Otherwise, (exclusively after year 2022), update the tag to + # "after_year_2022" and update the + + + january_first_2022 = datetime.datetime.fromtimestamp(1640995200) + january_first_2023 = datetime.datetime.fromtimestamp(1672531200) + + + def my_handler(keys: list[str], datum: Datum) -> Messages: + val = datum.value + event_time = datum.event_time + messages = Messages() + + if event_time < january_first_2022: + logging.info("Got event time:%s, it is before 2022, so dropping", event_time) + messages.append(Message.to_drop(event_time)) + elif event_time < january_first_2023: + logging.info( + "Got event time:%s, it is within year 2022, so + forwarding to within_year_2022", + event_time, + ) + messages.append( + Message(value=val, event_time=january_first_2022, tags=["within_year_2022"]) + ) + else: + logging.info( + "Got event time:%s, it is after year 2022, so forwarding to + after_year_2022", event_time + ) + messages.append(Message(value=val, event_time=january_first_2023, + tags=["after_year_2022"])) + + return messages + + if __name__ == "__main__": + grpc_server = SourceTransformServer(my_handler) + grpc_server.start() + """ + self.sock_path = f"unix://{sock_path}" + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size + + self.source_transform_instance = source_transform_instance self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), ("grpc.so_reuseport", 1), ("grpc.so_reuseaddr", 1), ] - # Set the number of processes to be spawned to the number of CPUs or the value - # of the env var NUM_CPU_MULTIPROC defined by the user + # Set the number of processes to be spawned to the number of CPUs or + # the value of the env var NUM_CPU_MULTIPROC defined by the user # Setting the max value to 2 * CPU count - self._process_count = min( - int(os.getenv("NUM_CPU_MULTIPROC", str(os.cpu_count()))), 2 * os.cpu_count() - ) - self._threads_per_proc = 
int(os.getenv("MAX_THREADS", "4")) + # Used for multiproc server + self._process_count = min(server_count, 2 * _PROCESS_COUNT) + self.servicer = SourceTransformServicer(handler=source_transform_instance) - def SourceTransformFn( - self, request: transform_pb2.SourceTransformRequest, context: NumaflowServicerContext - ) -> transform_pb2.SourceTransformResponse: + def start(self): """ - Applies a function to each datum element. - The pascal case function name comes from the generated transform_pb2_grpc.py file. + Starts the Multiproc gRPC server on the given TCP sockets + with given max threads. """ - - # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - try: - msgts = self.__transform_handler( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - ), - ) - except Exception as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(err)) - return transform_pb2.SourceTransformResponse(results=[]) - - datums = [] - for msgt in msgts: - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=msgt.event_time) - datums.append( - transform_pb2.SourceTransformResponse.Result( - keys=list(msgt.keys), - value=msgt.value, - tags=msgt.tags, - event_time=event_time_timestamp, - ) - ) - return transform_pb2.SourceTransformResponse(results=datums) - - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> transform_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto transform_pb2_grpc.py file. - """ - return transform_pb2.ReadyResponse(ready=True) - - def _run_server(self, bind_address): - """Start a server in a subprocess.""" - _LOGGER.info("Starting new server.") - server = grpc.server( - futures.ThreadPoolExecutor( - max_workers=self._threads_per_proc, - ), - options=self._server_options, + start_multiproc_server( + max_threads=self.max_threads, + servicer=self.servicer, + process_count=self._process_count, + server_options=self._server_options, + udf_type=UDFType.Map, ) - transform_pb2_grpc.add_SourceTransformServicer_to_server(self, server) - server.add_insecure_port(bind_address) - server.start() - _LOGGER.info("GRPC Multi-Processor Server listening on: %s %d", bind_address, os.getpid()) - server.wait_for_termination() - - @contextlib.contextmanager - def _reserve_port(self, port_num: int) -> Iterator[int]: - """Find and reserve a port for all subprocesses to use.""" - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 0: - raise SocketError("Failed to set SO_REUSEADDR.") - try: - sock.bind(("", port_num)) - yield sock.getsockname()[1] - finally: - sock.close() - - def start(self) -> None: - """ - Start N grpc servers in different processes where N = The number of CPUs or the - value of the env var NUM_CPU_MULTIPROC defined by the user. The max value - is set to 2 * CPU count. - Each server will be bound to a different port, and we will create equal number of - workers to handle each server. - On the client side there will be same number of connections as the number of servers. 
- """ - workers = [] - server_ports = [] - for _ in range(self._process_count): - # Find a port to bind to for each server, thus sending the port number = 0 - # to the _reserve_port function so that kernel can find and return a free port - with self._reserve_port(0) as port: - bind_address = f"{MULTIPROC_MAP_SOCK_ADDR}:{port}" - _LOGGER.info("Starting server on port: %s", port) - # NOTE: It is imperative that the worker subprocesses be forked before - # any gRPC servers start up. See - # https://github.com/grpc/grpc/issues/16001 for more details. - worker = multiprocessing.Process(target=self._run_server, args=(bind_address,)) - worker.start() - workers.append(worker) - server_ports.append(port) - - # Convert the available ports to a comma separated string - ports = ",".join(map(str, server_ports)) - - serv_info = ServerInfo( - protocol=Protocol.TCP, - language=Language.PYTHON, - version=get_sdk_version(), - metadata=get_metadata_env(envs=METADATA_ENVS), - ) - # Add the PORTS metadata using the available ports - serv_info.metadata["SERV_PORTS"] = ports - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) - - for worker in workers: - worker.join() diff --git a/pynumaflow/sourcetransformer/server.py b/pynumaflow/sourcetransformer/server.py index 5320668e..a2dd6c57 100644 --- a/pynumaflow/sourcetransformer/server.py +++ b/pynumaflow/sourcetransformer/server.py @@ -1,143 +1,120 @@ -import logging -import multiprocessing import os -from concurrent.futures import ThreadPoolExecutor -import grpc -from google.protobuf import empty_pb2 as _empty_pb2 -from google.protobuf import timestamp_pb2 as _timestamp_pb2 - -from pynumaflow import setup_logging from pynumaflow._constants import ( - SOURCE_TRANSFORMER_SOCK_PATH, MAX_MESSAGE_SIZE, + SOURCE_TRANSFORMER_SOCK_PATH, + MAX_THREADS, + _LOGGER, + UDFType, ) -from pynumaflow.info.server import get_sdk_version, write as info_server_write -from pynumaflow.info.types import ServerInfo, Protocol, Language, SERVER_INFO_FILE_PATH -from pynumaflow.sourcetransformer import Datum +from pynumaflow.shared import NumaflowServer +from pynumaflow.shared.server import sync_server_start from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable -from pynumaflow.sourcetransformer.proto import transform_pb2 -from pynumaflow.sourcetransformer.proto import transform_pb2_grpc -from pynumaflow.types import NumaflowServicerContext - -_LOGGER = setup_logging(__name__) -if os.getenv("PYTHONDEBUG"): - _LOGGER.setLevel(logging.DEBUG) +from pynumaflow.sourcetransformer.servicer.server import SourceTransformServicer -_PROCESS_COUNT = multiprocessing.cpu_count() -MAX_THREADS = int(os.getenv("MAX_THREADS", 0)) or (_PROCESS_COUNT * 4) - -class SourceTransformer(transform_pb2_grpc.SourceTransformServicer): +class SourceTransformServer(NumaflowServer): """ - Provides an interface to write a Source Transformer - which will be exposed over a Synchronous gRPC server. - - Args: - handler: Function callable following the type signature of SourceTransformCallable - sock_path: Path to the UNIX Domain Socket - max_message_size: The max message size in bytes the server can receive and send - max_threads: The max number of threads to be spawned; - defaults to number of processors x4 - - Example invocation: - >>> from typing import Iterator - >>> from pynumaflow.sourcetransformer import Messages, Message \ - ... Datum, SourceTransformer - >>> def transform_handler(key: [str], datum: Datum) -> Messages: - ... val = datum.value - ... 
new_event_time = datetime.time() - ... _ = datum.watermark - ... message_t_s = Messages(Message(val, event_time=new_event_time, keys=key)) - ... return message_t_s - ... - >>> grpc_server = SourceTransformer(handler=transform_handler) - >>> grpc_server.start() + Class for a new Source Transformer Server instance. """ def __init__( self, - handler: SourceTransformCallable, + source_transform_instance: SourceTransformCallable, sock_path=SOURCE_TRANSFORMER_SOCK_PATH, max_message_size=MAX_MESSAGE_SIZE, max_threads=MAX_THREADS, ): - self.__transform_handler: SourceTransformCallable = handler - self.sock_path = f"unix://{sock_path}" - self._max_message_size = max_message_size - self._max_threads = max_threads + """ + Create a new grpc Source Transformer Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. + Args: + source_transform_instance: The source transformer instance to be used for + Source Transformer UDF + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to number of processors x4 - self._server_options = [ - ("grpc.max_send_message_length", self._max_message_size), - ("grpc.max_receive_message_length", self._max_message_size), - ] + Example Invocation: + + import datetime + import logging + + from pynumaflow.sourcetransformer import Messages, Message, Datum, SourceTransformServer + # This is a simple User Defined Function example which receives a message, + # applies the following + # data transformation, and returns the message. + # If the message event time is before year 2022, drop the message with event time unchanged. + # If it's within year 2022, update the tag to "within_year_2022" and + # update the message event time to Jan 1st 2022. + # Otherwise, (exclusively after year 2022), update the tag to + # "after_year_2022" and update the + # message event time to Jan 1st 2023. + + january_first_2022 = datetime.datetime.fromtimestamp(1640995200) + january_first_2023 = datetime.datetime.fromtimestamp(1672531200) - def SourceTransformFn( - self, request: transform_pb2.SourceTransformRequest, context: NumaflowServicerContext - ) -> transform_pb2.SourceTransformResponse: - """ - Applies a function to each datum element. - The pascal case function name comes from the generated transform_pb2_grpc.py file. 
- """ - # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer - # we need to explicitly convert it to list - try: - msgts = self.__transform_handler( - list(request.keys), - Datum( - keys=list(request.keys), - value=request.value, - event_time=request.event_time.ToDatetime(), - watermark=request.watermark.ToDatetime(), - ), - ) - except Exception as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - context.set_code(grpc.StatusCode.UNKNOWN) - context.set_details(str(err)) - return transform_pb2.SourceTransformResponse(results=[]) - - datums = [] - for msgt in msgts: - event_time_timestamp = _timestamp_pb2.Timestamp() - event_time_timestamp.FromDatetime(dt=msgt.event_time) - datums.append( - transform_pb2.SourceTransformResponse.Result( - keys=list(msgt.keys), - value=msgt.value, - tags=msgt.tags, - event_time=event_time_timestamp, + def my_handler(keys: list[str], datum: Datum) -> Messages: + val = datum.value + event_time = datum.event_time + messages = Messages() + + if event_time < january_first_2022: + logging.info("Got event time:%s, it is before 2022, so dropping", event_time) + messages.append(Message.to_drop(event_time)) + elif event_time < january_first_2023: + logging.info( + "Got event time:%s, it is within year 2022, so forwarding to within_year_2022", + event_time, + ) + messages.append( + Message(value=val, event_time=january_first_2022, + tags=["within_year_2022"]) ) - ) - return transform_pb2.SourceTransformResponse(results=datums) + else: + logging.info( + "Got event time:%s, it is after year 2022, so forwarding to + after_year_2022", event_time + ) + messages.append(Message(value=val, event_time=january_first_2023, + tags=["after_year_2022"])) - def IsReady( - self, request: _empty_pb2.Empty, context: NumaflowServicerContext - ) -> transform_pb2.ReadyResponse: - """ - IsReady is the heartbeat endpoint for gRPC. - The pascal case function name comes from the proto transform_pb2_grpc.py file. + return messages + + + if __name__ == "__main__": + grpc_server = SourceTransformServer(my_handler) + grpc_server.start() """ - return transform_pb2.ReadyResponse(ready=True) + self.sock_path = f"unix://{sock_path}" + self.max_threads = min(max_threads, int(os.getenv("MAX_THREADS", "4"))) + self.max_message_size = max_message_size - def start(self) -> None: + self.source_transform_instance = source_transform_instance + + self._server_options = [ + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), + ] + self.servicer = SourceTransformServicer(handler=source_transform_instance) + + def start(self): """ - Starts the gRPC server on the given UNIX socket with given max threads. + Starts the Synchronous gRPC server on the given UNIX socket with given max threads. 
""" - server = grpc.server( - ThreadPoolExecutor(max_workers=self._max_threads), options=self._server_options - ) - transform_pb2_grpc.add_SourceTransformServicer_to_server(self, server) - server.add_insecure_port(self.sock_path) - server.start() - serv_info = ServerInfo( - protocol=Protocol.UDS, - language=Language.PYTHON, - version=get_sdk_version(), - ) - info_server_write(server_info=serv_info, info_file=SERVER_INFO_FILE_PATH) _LOGGER.info( - "GRPC Server listening on: %s with max threads: %s", self.sock_path, self._max_threads + "Sync GRPC Server listening on: %s with max threads: %s", + self.sock_path, + self.max_threads, + ) + # Start the sync server + sync_server_start( + servicer=self.servicer, + bind_address=self.sock_path, + max_threads=self.max_threads, + server_options=self._server_options, + udf_type=UDFType.SourceTransformer, ) - server.wait_for_termination() diff --git a/pynumaflow/sourcetransformer/servicer/__init__.py b/pynumaflow/sourcetransformer/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/sourcetransformer/servicer/server.py b/pynumaflow/sourcetransformer/servicer/server.py new file mode 100644 index 00000000..6f803fe2 --- /dev/null +++ b/pynumaflow/sourcetransformer/servicer/server.py @@ -0,0 +1,73 @@ +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 + +from pynumaflow.sourcetransformer import Datum +from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable +from pynumaflow.proto.sourcetransformer import transform_pb2 +from pynumaflow.proto.sourcetransformer import transform_pb2_grpc +from pynumaflow.types import NumaflowServicerContext +from pynumaflow._constants import _LOGGER + + +class SourceTransformServicer(transform_pb2_grpc.SourceTransformServicer): + """ + This class is used to create a new grpc SourceTransform servicer instance. + It implements the SourceTransformServicer interface from the proto transform.proto file. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: SourceTransformCallable, + ): + self.__transform_handler: SourceTransformCallable = handler + + def SourceTransformFn( + self, request: transform_pb2.SourceTransformRequest, context: NumaflowServicerContext + ) -> transform_pb2.SourceTransformResponse: + """ + Applies a function to each datum element. + The pascal case function name comes from the generated transform_pb2_grpc.py file. 
+ """ + + # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer + # we need to explicitly convert it to list + try: + msgts = self.__transform_handler( + list(request.keys), + Datum( + keys=list(request.keys), + value=request.value, + event_time=request.event_time.ToDatetime(), + watermark=request.watermark.ToDatetime(), + ), + ) + except Exception as err: + _LOGGER.critical("UDFError, re-raising the error", exc_info=True) + context.set_code(grpc.StatusCode.UNKNOWN) + context.set_details(str(err)) + return transform_pb2.SourceTransformResponse(results=[]) + + datums = [] + for msgt in msgts: + event_time_timestamp = _timestamp_pb2.Timestamp() + event_time_timestamp.FromDatetime(dt=msgt.event_time) + datums.append( + transform_pb2.SourceTransformResponse.Result( + keys=list(msgt.keys), + value=msgt.value, + tags=msgt.tags, + event_time=event_time_timestamp, + ) + ) + return transform_pb2.SourceTransformResponse(results=datums) + + def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> transform_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto transform_pb2_grpc.py file. + """ + return transform_pb2.ReadyResponse(ready=True) diff --git a/pyproject.toml b/pyproject.toml index 7a315da3..9f198883 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ google-cloud = "^0.34.0" google-api-core = "^2.11.0" protobuf = ">=3.20,<5.0" aiorun = "^2023.7" +uvloop = "^0.19.0" [tool.poetry.group.dev] optional = true diff --git a/tests/map/test_async_mapper.py b/tests/map/test_async_mapper.py index eacbe4ab..1d3fa7f2 100644 --- a/tests/map/test_async_mapper.py +++ b/tests/map/test_async_mapper.py @@ -10,12 +10,12 @@ from pynumaflow import setup_logging from pynumaflow.mapper import ( - AsyncMapper, Datum, Messages, Message, ) -from pynumaflow.mapper.proto import map_pb2_grpc, map_pb2 +from pynumaflow.mapper.async_server import MapAsyncServer +from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from tests.testing_utils import ( mock_event_time, mock_watermark, @@ -52,7 +52,7 @@ def request_generator(count, request, resetkey: bool = False): _s: Server = None -_channel = grpc.insecure_channel("localhost:50056") +_channel = grpc.insecure_channel("unix:///tmp/async_map.sock") _loop = None @@ -62,14 +62,15 @@ def startup_callable(loop): def new_async_mapper(): - udfs = AsyncMapper(handler=async_map_handler) + server = MapAsyncServer(mapper_instance=async_map_handler) + udfs = server.servicer return udfs -async def start_server(udfs: AsyncMapper): +async def start_server(udfs): server = grpc.aio.server() map_pb2_grpc.add_MapServicer_to_server(udfs, server) - listen_addr = "[::]:50056" + listen_addr = "unix:///tmp/async_map.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -90,7 +91,7 @@ def setUpClass(cls) -> None: asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) while True: try: - with grpc.insecure_channel("localhost:50056") as channel: + with grpc.insecure_channel("unix:///tmp/async_map.sock") as channel: f = grpc.channel_ready_future(channel) f.result(timeout=10) if f.done(): @@ -108,7 +109,7 @@ def tearDownClass(cls) -> None: LOGGER.error(e) def test_run_server(self) -> None: - with grpc.insecure_channel("localhost:50056") as channel: + with grpc.insecure_channel("unix:///tmp/async_map.sock") as channel: stub = map_pb2_grpc.MapStub(channel) event_time_timestamp = 
_timestamp_pb2.Timestamp() event_time_timestamp.FromDatetime(dt=mock_event_time()) @@ -198,7 +199,7 @@ def test_map_grpc_error(self) -> None: self.assertIsNotNone(grpcException) def test_is_ready(self) -> None: - with grpc.insecure_channel("localhost:50056") as channel: + with grpc.insecure_channel("unix:///tmp/async_map.sock") as channel: stub = map_pb2_grpc.MapStub(channel) request = _empty_pb2.Empty() @@ -210,6 +211,10 @@ def test_is_ready(self) -> None: self.assertTrue(response.ready) + def test_invalid_input(self): + with self.assertRaises(TypeError): + MapAsyncServer() + def __stub(self): return map_pb2_grpc.MapStub(_channel) diff --git a/tests/map/test_messages.py b/tests/map/test_messages.py index e3a29027..b2edbad7 100644 --- a/tests/map/test_messages.py +++ b/tests/map/test_messages.py @@ -1,6 +1,6 @@ import unittest -from pynumaflow.mapper import Messages, Message, DROP +from pynumaflow.mapper import Messages, Message, DROP, Mapper, Datum from tests.testing_utils import mock_message @@ -90,5 +90,30 @@ def test_err(self): msgts[:1] +class ExampleMapper(Mapper): + def handler(self, keys: list[str], datum: Datum) -> Messages: + messages = Messages() + messages.append(Message(mock_message(), keys=keys)) + return messages + + +class TestMapClass(unittest.TestCase): + def setUp(self) -> None: + # Create a map class instance + self.mapper_instance = ExampleMapper() + + def test_map_class_call(self): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + # make a call to the class directly + ret = self.mapper_instance([], None) + self.assertEqual(mock_message(), ret[0].value) + # make a call to the handler + ret_handler = self.mapper_instance.handler(keys=[], datum=None) + # + self.assertEqual(ret[0], ret_handler[0]) + + if __name__ == "__main__": unittest.main() diff --git a/tests/map/test_multiproc_mapper.py b/tests/map/test_multiproc_mapper.py index 7613be53..ccb10b5a 100644 --- a/tests/map/test_multiproc_mapper.py +++ b/tests/map/test_multiproc_mapper.py @@ -1,7 +1,6 @@ import os import unittest from unittest import mock -from unittest.mock import patch, Mock import grpc from google.protobuf import empty_pb2 as _empty_pb2 @@ -9,8 +8,8 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.mapper.multiproc_server import MultiProcMapper -from pynumaflow.mapper.proto import map_pb2_grpc, map_pb2 +from pynumaflow.mapper import MapMultiprocServer +from pynumaflow.proto.mapper import map_pb2 from tests.map.utils import map_handler, err_map_handler from tests.testing_utils import ( mock_event_time, @@ -25,53 +24,29 @@ def mockenv(**envvars): class TestMultiProcMethods(unittest.TestCase): def setUp(self) -> None: - my_servicer = MultiProcMapper(handler=map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_servicer} + my_server = MapMultiprocServer(mapper_instance=map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} self.test_server = server_from_dictionary(services, strict_real_time()) - @mockenv(NUM_CPU_MULTIPROC="3") def test_multiproc_init(self) -> None: - server = MultiProcMapper(handler=map_handler) - self.assertEqual(server._process_count, 3) + my_server = MapMultiprocServer(mapper_instance=map_handler, server_count=3) + self.assertEqual(my_server._process_count, 3) - @patch("os.cpu_count", Mock(return_value=4)) def test_multiproc_process_count(self) -> None: - server = 
MultiProcMapper(handler=map_handler) - self.assertEqual(server._process_count, 4) + default_val = os.cpu_count() + my_server = MapMultiprocServer(mapper_instance=map_handler) + self.assertEqual(my_server._process_count, default_val) - @patch("os.cpu_count", Mock(return_value=4)) - @mockenv(NUM_CPU_MULTIPROC="10") def test_max_process_count(self) -> None: - server = MultiProcMapper(handler=map_handler) - self.assertEqual(server._process_count, 8) - - # To test the reuse property for the grpc servers which allow multiple - # bindings to the same server - def test_reuse_port(self): - serv_options = [("grpc.so_reuseaddr", 1)] - - server = MultiProcMapper(handler=map_handler) - - with server._reserve_port(0) as port: - print(port) - bind_address = f"localhost:{port}" - server1 = grpc.server(thread_pool=None, options=serv_options) - map_pb2_grpc.add_MapServicer_to_server(server, server1) - server1.add_insecure_port(bind_address) - - # so_reuseport=0 -> the bind should raise an error - server2 = grpc.server(thread_pool=None, options=(("grpc.so_reuseport", 0),)) - map_pb2_grpc.add_MapServicer_to_server(server, server2) - self.assertRaises(RuntimeError, server2.add_insecure_port, bind_address) - - # so_reuseport=1 -> should allow server to bind to port again - server3 = grpc.server(thread_pool=None, options=(("grpc.so_reuseport", 1),)) - map_pb2_grpc.add_MapServicer_to_server(server, server3) - server3.add_insecure_port(bind_address) + """Max process count is capped at 2 * os.cpu_count, irrespective of what the user + provides as input""" + default_val = os.cpu_count() + server = MapMultiprocServer(mapper_instance=map_handler, server_count=20) + self.assertEqual(server._process_count, default_val * 2) def test_udf_map_err(self): - my_servicer = MultiProcMapper(handler=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_servicer} + my_server = MapMultiprocServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} self.test_server = server_from_dictionary(services, strict_real_time()) event_time_timestamp = _timestamp_pb2.Timestamp() @@ -148,7 +123,7 @@ def test_map_forward_message(self): def test_invalid_input(self): with self.assertRaises(TypeError): - MultiProcMapper() + MapMultiprocServer() if __name__ == "__main__": diff --git a/tests/map/test_sync_mapper.py b/tests/map/test_sync_mapper.py index 8a4d7c4c..839ec2da 100644 --- a/tests/map/test_sync_mapper.py +++ b/tests/map/test_sync_mapper.py @@ -6,9 +6,9 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.mapper import Mapper -from pynumaflow.mapper.proto import map_pb2 -from tests.map.utils import map_handler, err_map_handler +from pynumaflow.mapper import MapServer +from pynumaflow.proto.mapper import map_pb2 +from tests.map.utils import map_handler, err_map_handler, ExampleMap from tests.testing_utils import ( mock_event_time, mock_watermark, @@ -18,20 +18,23 @@ class TestSyncMapper(unittest.TestCase): def setUp(self) -> None: - my_servicer = Mapper(handler=map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_servicer} + class_instance = ExampleMap() + my_server = MapServer(mapper_instance=class_instance) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} self.test_server = server_from_dictionary(services, strict_real_time()) def test_init_with_args(self) -> None: - my_servicer = Mapper( - handler=map_handler, sock_path="/tmp/test.sock", 
max_message_size=1024 * 1024 * 5 + my_servicer = MapServer( + mapper_instance=map_handler, + sock_path="/tmp/test.sock", + max_message_size=1024 * 1024 * 5, ) self.assertEqual(my_servicer.sock_path, "unix:///tmp/test.sock") - self.assertEqual(my_servicer._max_message_size, 1024 * 1024 * 5) + self.assertEqual(my_servicer.max_message_size, 1024 * 1024 * 5) def test_udf_map_err(self): - my_servicer = Mapper(handler=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_servicer} + my_server = MapServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} self.test_server = server_from_dictionary(services, strict_real_time()) event_time_timestamp = _timestamp_pb2.Timestamp() @@ -57,8 +60,8 @@ def test_udf_map_err(self): self.assertEqual(grpc.StatusCode.UNKNOWN, code) def test_udf_map_error_response(self): - my_servicer = Mapper(handler=err_map_handler) - services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_servicer} + my_server = MapServer(mapper_instance=err_map_handler) + services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer} self.test_server = server_from_dictionary(services, strict_real_time()) event_time_timestamp = _timestamp_pb2.Timestamp() @@ -135,7 +138,7 @@ def test_map_forward_message(self): def test_invalid_input(self): with self.assertRaises(TypeError): - Mapper() + MapServer() if __name__ == "__main__": diff --git a/tests/map/utils.py b/tests/map/utils.py index ef5d7c21..6cecd503 100644 --- a/tests/map/utils.py +++ b/tests/map/utils.py @@ -1,10 +1,24 @@ -from pynumaflow.mapper import Datum, Messages, Message +from pynumaflow.mapper import Datum, Messages, Message, Mapper async def async_map_error_fn(keys: list[str], datum: Datum) -> Messages: raise ValueError("error invoking map") +class ExampleMap(Mapper): + def handler(self, keys: list[str], datum: Datum) -> Messages: + val = datum.value + msg = "payload:{} event_time:{} watermark:{}".format( + val.decode("utf-8"), + datum.event_time, + datum.watermark, + ) + val = bytes(msg, encoding="utf-8") + messages = Messages() + messages.append(Message(val, keys=keys)) + return messages + + def map_handler(keys: list[str], datum: Datum) -> Messages: val = datum.value msg = "payload:{} event_time:{} watermark:{}".format( diff --git a/tests/mapstream/test_async_map_stream.py b/tests/mapstream/test_async_map_stream.py index e2022713..107289a6 100644 --- a/tests/mapstream/test_async_map_stream.py +++ b/tests/mapstream/test_async_map_stream.py @@ -12,9 +12,9 @@ from pynumaflow.mapstreamer import ( Message, Datum, - AsyncMapStreamer, + MapStreamAsyncServer, ) -from pynumaflow.mapstreamer.proto import mapstream_pb2_grpc +from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc from tests.mapstream.utils import start_request_map_stream LOGGER = setup_logging(__name__) @@ -35,7 +35,7 @@ async def async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIterab _s: Server = None -_channel = grpc.insecure_channel("localhost:50060") +_channel = grpc.insecure_channel("unix:///tmp/async_map_stream.sock") _loop = None @@ -47,15 +47,15 @@ def startup_callable(loop): def NewAsyncMapStreamer( map_stream_handler=async_map_stream_handler, ): - udfs = AsyncMapStreamer(handler=async_map_stream_handler) - + server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler) + udfs = server.servicer return udfs -async def start_server(udfs: AsyncMapStreamer): +async def start_server(udfs): server = grpc.aio.server() 
mapstream_pb2_grpc.add_MapStreamServicer_to_server(udfs, server) - listen_addr = "[::]:50060" + listen_addr = "unix:///tmp/async_map_stream.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -76,7 +76,7 @@ def setUpClass(cls) -> None: asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) while True: try: - with grpc.insecure_channel("localhost:50060") as channel: + with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel: f = grpc.channel_ready_future(channel) f.result(timeout=10) if f.done(): @@ -118,7 +118,7 @@ def test_map_stream(self) -> None: self.assertEqual(10, counter) def test_is_ready(self) -> None: - with grpc.insecure_channel("localhost:50060") as channel: + with grpc.insecure_channel("unix:///tmp/async_map_stream.sock") as channel: stub = mapstream_pb2_grpc.MapStreamStub(channel) request = _empty_pb2.Empty() diff --git a/tests/mapstream/test_async_map_stream_err.py b/tests/mapstream/test_async_map_stream_err.py index 1211165c..a1bf137e 100644 --- a/tests/mapstream/test_async_map_stream_err.py +++ b/tests/mapstream/test_async_map_stream_err.py @@ -9,8 +9,8 @@ from grpc.aio._server import Server from pynumaflow import setup_logging -from pynumaflow.mapstreamer import Message, Datum, AsyncMapStreamer -from pynumaflow.mapstreamer.proto import mapstream_pb2_grpc +from pynumaflow.mapstreamer import Message, Datum, MapStreamAsyncServer +from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc from tests.mapstream.utils import start_request_map_stream LOGGER = setup_logging(__name__) @@ -32,7 +32,7 @@ async def err_async_map_stream_handler(keys: list[str], datum: Datum) -> AsyncIt _s: Server = None -_channel = grpc.insecure_channel("localhost:50052") +_channel = grpc.insecure_channel("unix:///tmp/async_map_stream_err.sock") _loop = None @@ -43,9 +43,10 @@ def startup_callable(loop): async def start_server(): server = grpc.aio.server() - udfs = AsyncMapStreamer(handler=err_async_map_stream_handler) + server_instance = MapStreamAsyncServer(err_async_map_stream_handler) + udfs = server_instance.servicer mapstream_pb2_grpc.add_MapStreamServicer_to_server(udfs, server) - listen_addr = "[::]:50052" + listen_addr = "unix:///tmp/async_map_stream_err.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -65,7 +66,7 @@ def setUpClass(cls) -> None: asyncio.run_coroutine_threadsafe(start_server(), loop=loop) while True: try: - with grpc.insecure_channel("localhost:50052") as channel: + with grpc.insecure_channel("unix:///tmp/async_map_stream_err.sock") as channel: f = grpc.channel_ready_future(channel) f.result(timeout=10) if f.done(): @@ -100,7 +101,7 @@ def __stub(self): def test_invalid_input(self): with self.assertRaises(TypeError): - AsyncMapStreamer() + MapStreamAsyncServer() if __name__ == "__main__": diff --git a/tests/mapstream/utils.py b/tests/mapstream/utils.py index 9f0960f0..4e9e4824 100644 --- a/tests/mapstream/utils.py +++ b/tests/mapstream/utils.py @@ -1,5 +1,5 @@ from pynumaflow.mapstreamer import Datum -from pynumaflow.mapstreamer.proto import mapstream_pb2 +from pynumaflow.proto.mapstreamer import mapstream_pb2 from tests.testing_utils import get_time_args, mock_message diff --git a/tests/reduce/test_async_reduce.py b/tests/reduce/test_async_reduce.py index b75501ed..c65f69ac 100644 --- a/tests/reduce/test_async_reduce.py +++ b/tests/reduce/test_async_reduce.py @@ -3,7 +3,6 @@ import threading import unittest from collections.abc import 
AsyncIterable -from collections.abc import Iterator import grpc from google.protobuf import empty_pb2 as _empty_pb2 @@ -15,10 +14,11 @@ Messages, Message, Datum, - AsyncReducer, Metadata, + ReduceAsyncServer, + Reducer, ) -from pynumaflow.reducer.proto import reduce_pb2, reduce_pb2_grpc +from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc from tests.testing_utils import ( mock_message, mock_interval_window_start, @@ -28,24 +28,6 @@ LOGGER = setup_logging(__name__) -# if set to true, map handler will raise a `ValueError` exception. -raise_error_from_map = False - - -async def async_reduce_handler( - keys: list[str], datums: AsyncIterable[Datum], md: Metadata -) -> Messages: - interval_window = md.interval_window - counter = 0 - async for _ in datums: - counter += 1 - msg = ( - f"counter:{counter} interval_window_start:{interval_window.start} " - f"interval_window_end:{interval_window.end}" - ) - - return Messages(Message(str.encode(msg), keys=keys)) - def request_generator(count, request, resetkey: bool = False): for i in range(count): @@ -70,7 +52,7 @@ def start_request() -> (Datum, tuple): _s: Server = None -_channel = grpc.insecure_channel("localhost:50057") +_channel = grpc.insecure_channel("unix:///tmp/reduce.sock") _loop = None @@ -79,7 +61,27 @@ def startup_callable(loop): loop.run_forever() -async def reduce_handler(keys: list[str], datums: Iterator[Datum], md: Metadata) -> Messages: +class ExampleClass(Reducer): + def __init__(self, counter): + self.counter = counter + + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + interval_window = md.interval_window + self.counter = 0 + async for _ in datums: + self.counter += 1 + msg = ( + f"counter:{self.counter} interval_window_start:{interval_window.start} " + f"interval_window_end:{interval_window.end}" + ) + return Messages(Message(str.encode(msg), keys=keys)) + + +async def reduce_handler_func( + keys: list[str], datums: AsyncIterable[Datum], md: Metadata +) -> Messages: interval_window = md.interval_window counter = 0 async for _ in datums: @@ -91,18 +93,17 @@ async def reduce_handler(keys: list[str], datums: Iterator[Datum], md: Metadata) return Messages(Message(str.encode(msg), keys=keys)) -def NewAsyncReducer( - reduce_handler=async_reduce_handler, -): - udfs = AsyncReducer(handler=async_reduce_handler) +def NewAsyncReducer(): + server_instance = ReduceAsyncServer(ExampleClass, init_args=(0,)) + udfs = server_instance.servicer return udfs -async def start_server(udfs: AsyncReducer): +async def start_server(udfs): server = grpc.aio.server() reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server) - listen_addr = "[::]:50057" + listen_addr = "unix:///tmp/reduce.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -123,7 +124,7 @@ def setUpClass(cls) -> None: asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) while True: try: - with grpc.insecure_channel("localhost:50057") as channel: + with grpc.insecure_channel("unix:///tmp/reduce.sock") as channel: f = grpc.channel_ready_future(channel) f.result(timeout=10) if f.done(): @@ -217,7 +218,7 @@ def test_reduce_with_multiple_keys(self) -> None: self.assertEqual(100, count) def test_is_ready(self) -> None: - with grpc.insecure_channel("localhost:50057") as channel: + with grpc.insecure_channel("unix:///tmp/reduce.sock") as channel: stub = reduce_pb2_grpc.ReduceStub(channel) request = _empty_pb2.Empty() @@ -232,6 +233,19 @@ def test_is_ready(self) -> 
None: def __stub(self): return reduce_pb2_grpc.ReduceStub(_channel) + def test_error_init(self): + # Check that reducer_handler in required + with self.assertRaises(TypeError): + ReduceAsyncServer() + # Check that the init_args and init_kwargs are passed + # only with a Reducer class + with self.assertRaises(TypeError): + ReduceAsyncServer(reduce_handler_func, init_args=(0, 1)) + # Check that an instance is not passed instead of the class + # signature + with self.assertRaises(TypeError): + ReduceAsyncServer(ExampleClass(0)) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/reduce/test_async_reduce_err.py b/tests/reduce/test_async_reduce_err.py new file mode 100644 index 00000000..8da36d0c --- /dev/null +++ b/tests/reduce/test_async_reduce_err.py @@ -0,0 +1,145 @@ +import asyncio +import logging +import threading +import unittest +from collections.abc import AsyncIterable + +import grpc +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow._constants import WIN_START_TIME, WIN_END_TIME +from pynumaflow.reducer import ( + Messages, + Message, + Datum, + Metadata, + ReduceAsyncServer, +) +from pynumaflow.proto.reducer import reduce_pb2, reduce_pb2_grpc +from tests.testing_utils import ( + mock_message, + mock_interval_window_start, + mock_interval_window_end, + get_time_args, +) + +LOGGER = setup_logging(__name__) + + +def request_generator(count, request, resetkey: bool = False): + for i in range(count): + if resetkey: + request.keys.extend([f"key-{i}"]) + yield request + + +def start_request() -> (Datum, tuple): + event_time_timestamp, watermark_timestamp = get_time_args() + + request = reduce_pb2.ReduceRequest( + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + ) + metadata = ( + (WIN_START_TIME, f"{mock_interval_window_start()}"), + (WIN_END_TIME, f"{mock_interval_window_end()}"), + ) + return request, metadata + + +_s: Server = None +_channel = grpc.insecure_channel("unix:///tmp/reduce_err.sock") +_loop = None + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +async def err_handler(keys: list[str], datums: AsyncIterable[Datum], md: Metadata) -> Messages: + interval_window = md.interval_window + counter = 0 + async for _ in datums: + counter += 1 + msg = ( + f"counter:{counter} interval_window_start:{interval_window.start} " + f"interval_window_end:{interval_window.end}" + ) + raise RuntimeError("Got a runtime error from reduce handler.") + return Messages(Message(str.encode(msg), keys=keys)) + + +def NewAsyncReducer(): + server_instance = ReduceAsyncServer(err_handler) + udfs = server_instance.servicer + + return udfs + + +async def start_server(udfs): + server = grpc.aio.server() + reduce_pb2_grpc.add_ReduceServicer_to_server(udfs, server) + listen_addr = "unix:///tmp/reduce_err.sock" + server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +class TestAsyncReducerError(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + udfs = NewAsyncReducer() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + while True: + try: + with grpc.insecure_channel("unix:///tmp/reduce_err.sock") as channel: + f = 
grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + def test_reduce(self) -> None: + stub = self.__stub() + request, metadata = start_request() + generator_response = None + try: + generator_response = stub.ReduceFn( + request_iterator=request_generator(count=10, request=request), metadata=metadata + ) + counter = 0 + for _ in generator_response: + counter += 1 + except Exception as err: + self.assertTrue("Got a runtime error from reduce handler." in err.__str__()) + return + self.fail("Expected an exception.") + + def __stub(self): + return reduce_pb2_grpc.ReduceStub(_channel) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/reduce/test_datatypes.py b/tests/reduce/test_datatypes.py index 59f54dc1..5433a044 100644 --- a/tests/reduce/test_datatypes.py +++ b/tests/reduce/test_datatypes.py @@ -1,6 +1,9 @@ +from copy import deepcopy import unittest +from collections.abc import AsyncIterable from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from pynumaflow.reducer import Reducer, Messages from pynumaflow.reducer._dtypes import ( IntervalWindow, @@ -102,5 +105,46 @@ def test_interval_window(self): self.assertEqual(i, m.interval_window) +class TestReducerClass(unittest.TestCase): + class ExampleClass(Reducer): + async def handler( + self, keys: list[str], datums: AsyncIterable[Datum], md: Metadata + ) -> Messages: + pass + + def __init__(self, test1, test2): + self.test1 = test1 + self.test2 = test2 + self.test3 = self.test1 + + def test_init(self): + r = self.ExampleClass(test1=1, test2=2) + self.assertEqual(1, r.test1) + self.assertEqual(2, r.test2) + self.assertEqual(1, r.test3) + + def test_deep_copy(self): + """Test that the deepcopy works as expected""" + r = self.ExampleClass(test1=1, test2=2) + # Create a copy of r + r_copy = deepcopy(r) + # Check that the attributes are the same + self.assertEqual(1, r_copy.test1) + self.assertEqual(2, r_copy.test2) + self.assertEqual(1, r_copy.test3) + # Check that the objects are not the same + self.assertNotEqual(id(r), id(r_copy)) + # Update the attributes of r + r.test1 = 5 + r.test3 = 6 + # Check that the other object is not updated + self.assertNotEqual(r.test1, r_copy.test1) + self.assertNotEqual(r.test3, r_copy.test3) + self.assertNotEqual(id(r.test3), id(r_copy.test3)) + # Verify that the instance type is correct + self.assertTrue(isinstance(r_copy, self.ExampleClass)) + self.assertTrue(isinstance(r_copy, Reducer)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/sideinput/test_responses.py b/tests/sideinput/test_responses.py index 589250e3..859f4bb1 100644 --- a/tests/sideinput/test_responses.py +++ b/tests/sideinput/test_responses.py @@ -1,6 +1,6 @@ import unittest -from pynumaflow.sideinput import Response +from pynumaflow.sideinput import Response, SideInput class TestResponse(unittest.TestCase): @@ -26,5 +26,28 @@ def test_no_broadcast_message(self): self.assertTrue(succ_response.no_broadcast) +class ExampleSideInput(SideInput): + def retrieve_handler(self) -> Response: + return Response.broadcast_message(b"testMessage") + + +class TestSideInputClass(unittest.TestCase): + def setUp(self) -> None: + # Create a side input class instance + 
self.side_input_instance = ExampleSideInput() + + def test_side_input_class_call(self): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + # make a call to the class directly + ret = self.side_input_instance() + self.assertEqual(b"testMessage", ret.value) + # make a call to the handler + ret_handler = self.side_input_instance.retrieve_handler() + # Both responses should be equal + self.assertEqual(ret, ret_handler) + + if __name__ == "__main__": unittest.main() diff --git a/tests/sideinput/test_side_input_server.py b/tests/sideinput/test_side_input_server.py index 53e360bb..501c378f 100644 --- a/tests/sideinput/test_side_input_server.py +++ b/tests/sideinput/test_side_input_server.py @@ -4,10 +4,9 @@ from google.protobuf import empty_pb2 as _empty_pb2 from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.sideinput import SideInput -from pynumaflow.sideinput.proto import sideinput_pb2 +from pynumaflow.proto.sideinput import sideinput_pb2 -from pynumaflow.sideinput import Response +from pynumaflow.sideinput import Response, SideInputServer def retrieve_side_input_handler() -> Response: @@ -34,7 +33,8 @@ class TestServer(unittest.TestCase): """ def setUp(self) -> None: - my_service = SideInput(handler=retrieve_side_input_handler) + server = SideInputServer(retrieve_side_input_handler) + my_service = server.servicer services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_service} self.test_server = server_from_dictionary(services, strict_real_time()) @@ -42,20 +42,21 @@ def test_init_with_args(self) -> None: """ Test the initialization of the SideInput class, """ - my_servicer = SideInput( - handler=retrieve_side_input_handler, + my_servicer = SideInputServer( + side_input_instance=retrieve_side_input_handler, sock_path="/tmp/test_side_input.sock", max_message_size=1024 * 1024 * 5, ) self.assertEqual(my_servicer.sock_path, "unix:///tmp/test_side_input.sock") - self.assertEqual(my_servicer._max_message_size, 1024 * 1024 * 5) + self.assertEqual(my_servicer.max_message_size, 1024 * 1024 * 5) def test_side_input_err(self): """ Test the error case for the RetrieveSideInput method, """ - my_servicer = SideInput(handler=err_retrieve_handler) - services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_servicer} + server = SideInputServer(err_retrieve_handler) + my_service = server.servicer + services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_service} self.test_server = server_from_dictionary(services, strict_real_time()) method = self.test_server.invoke_unary_unary( @@ -115,7 +116,8 @@ def test_side_input_no_broadcast(self): Test the no_broadcast_message method, where we expect the no_broadcast flag to be True. 
""" - my_servicer = SideInput(handler=retrieve_no_broadcast_handler) + server = SideInputServer(side_input_instance=retrieve_no_broadcast_handler) + my_servicer = server.servicer services = {sideinput_pb2.DESCRIPTOR.services_by_name["SideInput"]: my_servicer} self.test_server = server_from_dictionary(services, strict_real_time()) @@ -137,7 +139,7 @@ def test_side_input_no_broadcast(self): def test_invalid_input(self): with self.assertRaises(TypeError): - SideInput() + SideInputServer() if __name__ == "__main__": diff --git a/tests/sink/test_async_sink.py b/tests/sink/test_async_sink.py index 062c5781..1ba23edd 100644 --- a/tests/sink/test_async_sink.py +++ b/tests/sink/test_async_sink.py @@ -12,9 +12,9 @@ from pynumaflow.sinker import ( Datum, ) -from pynumaflow.sinker import Responses, Response, AsyncSinker -from pynumaflow.sinker.proto import sink_pb2 -from pynumaflow.sinker.proto import sink_pb2_grpc +from pynumaflow.sinker import Responses, Response +from pynumaflow.proto.sinker import sink_pb2_grpc, sink_pb2 +from pynumaflow.sinker.async_server import SinkAsyncServer from tests.sink.test_server import ( mock_message, mock_err_message, @@ -58,7 +58,7 @@ def start_sink_streaming_request(err=False) -> (Datum, tuple): _s: Server = None -_channel = grpc.insecure_channel("localhost:50055") +_channel = grpc.insecure_channel("unix:///tmp/async_sink.sock") _loop = None @@ -69,9 +69,10 @@ def startup_callable(loop): async def start_server(): server = grpc.aio.server() - uds = AsyncSinker(handler=udsink_handler) + server_instance = SinkAsyncServer(sinker_instance=udsink_handler) + uds = server_instance.servicer sink_pb2_grpc.add_SinkServicer_to_server(uds, server) - listen_addr = "[::]:50055" + listen_addr = "unix:///tmp/async_sink.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -91,7 +92,7 @@ def setUpClass(cls) -> None: asyncio.run_coroutine_threadsafe(start_server(), loop=loop) while True: try: - with grpc.insecure_channel("localhost:50055") as channel: + with grpc.insecure_channel("unix:///tmp/async_sink.sock") as channel: f = grpc.channel_ready_future(channel) f.result(timeout=10) if f.done(): @@ -110,7 +111,7 @@ def tearDownClass(cls) -> None: # def test_run_server(self) -> None: - with grpc.insecure_channel("localhost:50055") as channel: + with grpc.insecure_channel("unix:///tmp/async_sink.sock") as channel: stub = sink_pb2_grpc.SinkStub(channel) request = _empty_pb2.Empty() @@ -160,6 +161,10 @@ def test_sink_err(self) -> None: def __stub(self): return sink_pb2_grpc.SinkStub(_channel) + def test_invalid_server_type(self) -> None: + with self.assertRaises(TypeError): + SinkAsyncServer() + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/sink/test_responses.py b/tests/sink/test_responses.py index 9fed6d97..032323b8 100644 --- a/tests/sink/test_responses.py +++ b/tests/sink/test_responses.py @@ -1,6 +1,7 @@ import unittest +from collections.abc import Iterator -from pynumaflow.sinker import Response, Responses +from pynumaflow.sinker import Response, Responses, Sinker, Datum class TestResponse(unittest.TestCase): @@ -39,5 +40,30 @@ def test_responses(self): ) +class ExampleSinkClass(Sinker): + def handler(self, datums: Iterator[Datum]) -> Responses: + results = Responses() + results.append(Response.as_success("test_message")) + return results + + +class TestSinkClass(unittest.TestCase): + def setUp(self) -> None: + # Create a map class instance + self.sinker_instance = ExampleSinkClass() + + def 
test_sink_class_call(self): + """Test that the __call__ functionality for the class works, + ie the class instance can be called directly to invoke the handler function + """ + # make a call to the class directly + ret = self.sinker_instance(None) + self.assertEqual("test_message", ret[0].id) + # make a call to the handler + ret_handler = self.sinker_instance.handler(None) + # Both responses should be equal + self.assertEqual(ret[0], ret_handler[0]) + + if __name__ == "__main__": unittest.main() diff --git a/tests/sink/test_server.py b/tests/sink/test_server.py index af9bccac..4a9bea86 100644 --- a/tests/sink/test_server.py +++ b/tests/sink/test_server.py @@ -7,8 +7,8 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.sinker import Responses, Datum, Response, Sinker -from pynumaflow.sinker.proto import sink_pb2 +from pynumaflow.sinker import Responses, Datum, Response, SinkServer +from pynumaflow.proto.sinker import sink_pb2 def udsink_handler(datums: Iterator[Datum]) -> Responses: @@ -47,7 +47,8 @@ def mock_watermark(): class TestServer(unittest.TestCase): def setUp(self) -> None: - my_servicer = Sinker(udsink_handler) + server = SinkServer(sinker_instance=udsink_handler) + my_servicer = server.servicer services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} self.test_server = server_from_dictionary(services, strict_real_time()) @@ -67,7 +68,8 @@ def test_is_ready(self): self.assertEqual(code, StatusCode.OK) def test_udsink_err(self): - my_servicer = Sinker(err_udsink_handler) + server = SinkServer(sinker_instance=err_udsink_handler) + my_servicer = server.servicer services = {sink_pb2.DESCRIPTOR.services_by_name["Sink"]: my_servicer} self.test_server = server_from_dictionary(services, strict_real_time()) @@ -156,6 +158,10 @@ def test_forward_message(self): self.assertEqual("mock sink message error", response.results[1].err_msg) self.assertEqual(code, StatusCode.OK) + def test_invalid_init(self): + with self.assertRaises(TypeError): + SinkServer() + if __name__ == "__main__": unittest.main() diff --git a/tests/source/test_async_source.py b/tests/source/test_async_source.py index 53a6d97a..fdf5e756 100644 --- a/tests/source/test_async_source.py +++ b/tests/source/test_async_source.py @@ -8,19 +8,16 @@ from grpc.aio._server import Server from pynumaflow import setup_logging +from pynumaflow.proto.sourcer import source_pb2_grpc, source_pb2 from pynumaflow.sourcer import ( - AsyncSourcer, + SourceAsyncServer, ) -from pynumaflow.sourcer.proto import source_pb2_grpc, source_pb2 from tests.source.utils import ( - async_source_read_handler, - async_source_ack_handler, - async_source_pending_handler, - async_source_partition_handler, mock_offset, read_req_source_fn, ack_req_source_fn, mock_partitions, + AsyncSource, ) LOGGER = setup_logging(__name__) @@ -28,8 +25,7 @@ # if set to true, map handler will raise a `ValueError` exception. 
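From here on, the source tests replace the four free-function handlers (read/ack/pending/partitions) with a single class-based `Sourcer` implementation; see the `AsyncSource`/`SyncSource` classes in `tests/source/utils.py` further down. A rough end-to-end sketch of the new pattern (`SimpleSource` is illustrative only, and the `Offset` constructor shape is assumed from the test utilities):

```python
import datetime
from collections.abc import Iterable

from pynumaflow.sourcer import Message, ReadRequest, SourceServer
from pynumaflow.sourcer._dtypes import (
    AckRequest,
    Offset,
    PartitionsResponse,
    PendingResponse,
    Sourcer,
)


class SimpleSource(Sourcer):
    def read_handler(self, datum: ReadRequest) -> Iterable[Message]:
        # Emit a small fixed batch; a real source would honour the read request.
        for i in range(10):
            yield Message(
                payload=f"message-{i}".encode(),
                keys=["example"],
                offset=Offset(offset=str(i).encode(), partition_id=1),
                event_time=datetime.datetime.now(datetime.timezone.utc),
            )

    def ack_handler(self, ack_request: AckRequest):
        # Nothing to clean up in this sketch.
        return

    def pending_handler(self) -> PendingResponse:
        return PendingResponse(count=0)

    def partitions_handler(self) -> PartitionsResponse:
        return PartitionsResponse(partitions=[1])


if __name__ == "__main__":
    SourceServer(sourcer_instance=SimpleSource()).start()
```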
raise_error_from_map = False - -server_port = "localhost:50058" +server_port = "unix:///tmp/async_source.sock" _s: Server = None _channel = grpc.insecure_channel(server_port) @@ -41,25 +37,17 @@ def startup_callable(loop): loop.run_forever() -def NewAsyncSourcer( - handler=async_source_read_handler, - ack_handler=async_source_ack_handler, - pending_handler=async_source_pending_handler, - partitions_handler=async_source_partition_handler, -): - udfs = AsyncSourcer( - read_handler=async_source_read_handler, - ack_handler=async_source_ack_handler, - pending_handler=async_source_pending_handler, - partitions_handler=async_source_partition_handler, - ) +def NewAsyncSourcer(): + class_instance = AsyncSource() + server = SourceAsyncServer(sourcer_instance=class_instance) + udfs = server.servicer return udfs -async def start_server(udfs: AsyncSourcer): +async def start_server(udfs): server = grpc.aio.server() source_pb2_grpc.add_SourceServicer_to_server(udfs, server) - listen_addr = "[::]:50058" + listen_addr = "unix:///tmp/async_source.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s diff --git a/tests/source/test_async_source_err.py b/tests/source/test_async_source_err.py index 6b9bd3f0..85acef7e 100644 --- a/tests/source/test_async_source_err.py +++ b/tests/source/test_async_source_err.py @@ -8,22 +8,19 @@ from grpc.aio._server import Server from pynumaflow import setup_logging -from pynumaflow.sourcer import AsyncSourcer -from pynumaflow.sourcer.proto import source_pb2_grpc, source_pb2 +from pynumaflow.sourcer import SourceAsyncServer +from pynumaflow.proto.sourcer import source_pb2_grpc, source_pb2 from google.protobuf import empty_pb2 as _empty_pb2 from tests.source.utils import ( - err_async_source_read_handler, - err_async_source_ack_handler, - err_async_source_pending_handler, read_req_source_fn, ack_req_source_fn, - err_async_source_partition_handler, + AsyncSourceError, ) LOGGER = setup_logging(__name__) _s: Server = None -server_port = "localhost:50062" +server_port = "unix:///tmp/async_err_source.sock" _channel = grpc.insecure_channel(server_port) _loop = None @@ -35,14 +32,11 @@ def startup_callable(loop): async def start_server(): server = grpc.aio.server() - udfs = AsyncSourcer( - read_handler=err_async_source_read_handler, - ack_handler=err_async_source_ack_handler, - pending_handler=err_async_source_pending_handler, - partitions_handler=err_async_source_partition_handler, - ) + class_instance = AsyncSourceError() + server_instance = SourceAsyncServer(sourcer_instance=class_instance) + udfs = server_instance.servicer source_pb2_grpc.add_SourceServicer_to_server(udfs, server) - listen_addr = "[::]:50062" + listen_addr = "unix:///tmp/async_err_source.sock" server.add_insecure_port(listen_addr) logging.info("Starting server on %s", listen_addr) global _s @@ -62,7 +56,7 @@ def setUpClass(cls) -> None: asyncio.run_coroutine_threadsafe(start_server(), loop=loop) while True: try: - with grpc.insecure_channel("localhost:50062") as channel: + with grpc.insecure_channel("unix:///tmp/async_err_source.sock") as channel: f = grpc.channel_ready_future(channel) f.result(timeout=10) if f.done(): @@ -126,6 +120,10 @@ def test_partition_error(self) -> None: return self.fail("Expected an exception.") + def test_invalid_server_type(self) -> None: + with self.assertRaises(TypeError): + SourceAsyncServer() + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) diff --git a/tests/source/test_message.py b/tests/source/test_message.py 
index 43d54f87..00ca83e7 100644 --- a/tests/source/test_message.py +++ b/tests/source/test_message.py @@ -1,6 +1,11 @@ import unittest -from pynumaflow.sourcer import Message, Offset, ReadRequest, PartitionsResponse +from pynumaflow.sourcer import ( + Message, + Offset, + ReadRequest, + PartitionsResponse, +) from tests.source.utils import mock_offset from tests.testing_utils import mock_event_time diff --git a/tests/source/test_sync_source.py b/tests/source/test_sync_source.py index 8b230d5e..5e6c2ca8 100644 --- a/tests/source/test_sync_source.py +++ b/tests/source/test_sync_source.py @@ -4,42 +4,34 @@ from grpc import StatusCode from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.sourcer import Sourcer -from pynumaflow.sourcer.proto import source_pb2 +from pynumaflow.sourcer import SourceServer +from pynumaflow.proto.sourcer import source_pb2 from tests.source.utils import ( - sync_source_read_handler, - sync_source_ack_handler, - sync_source_pending_handler, - sync_source_partition_handler, read_req_source_fn, mock_offset, ack_req_source_fn, mock_partitions, + SyncSource, ) class TestSyncSourcer(unittest.TestCase): def setUp(self) -> None: - my_servicer = Sourcer( - read_handler=sync_source_read_handler, - ack_handler=sync_source_ack_handler, - pending_handler=sync_source_pending_handler, - partitions_handler=sync_source_partition_handler, - ) + class_instance = SyncSource() + server = SourceServer(sourcer_instance=class_instance) + my_servicer = server.servicer services = {source_pb2.DESCRIPTOR.services_by_name["Source"]: my_servicer} self.test_server = server_from_dictionary(services, strict_real_time()) def test_init_with_args(self) -> None: - my_servicer = Sourcer( - read_handler=sync_source_read_handler, - ack_handler=sync_source_ack_handler, - pending_handler=sync_source_pending_handler, - partitions_handler=sync_source_partition_handler, + class_instance = SyncSource() + server = SourceServer( + sourcer_instance=class_instance, sock_path="/tmp/test.sock", max_message_size=1024 * 1024 * 5, ) - self.assertEqual(my_servicer.sock_path, "unix:///tmp/test.sock") - self.assertEqual(my_servicer._max_message_size, 1024 * 1024 * 5) + self.assertEqual(server.sock_path, "unix:///tmp/test.sock") + self.assertEqual(server.max_message_size, 1024 * 1024 * 5) def test_is_ready(self): method = self.test_server.invoke_unary_unary( diff --git a/tests/source/test_sync_source_err.py b/tests/source/test_sync_source_err.py index 875c4494..e135913e 100644 --- a/tests/source/test_sync_source_err.py +++ b/tests/source/test_sync_source_err.py @@ -4,26 +4,20 @@ from google.protobuf import empty_pb2 as _empty_pb2 from grpc_testing import server_from_dictionary, strict_real_time -from pynumaflow.sourcer import Sourcer -from pynumaflow.sourcer.proto import source_pb2 +from pynumaflow.sourcer import SourceServer +from pynumaflow.proto.sourcer import source_pb2 from tests.source.utils import ( read_req_source_fn, ack_req_source_fn, - err_sync_source_read_handler, - err_sync_source_ack_handler, - err_sync_source_pending_handler, - err_sync_source_partition_handler, + SyncSourceError, ) class TestSyncSourcer(unittest.TestCase): def setUp(self) -> None: - my_servicer = Sourcer( - read_handler=err_sync_source_read_handler, - ack_handler=err_sync_source_ack_handler, - pending_handler=err_sync_source_pending_handler, - partitions_handler=err_sync_source_partition_handler, - ) + class_instance = SyncSourceError() + server = SourceServer(sourcer_instance=class_instance) + my_servicer = 
server.servicer services = {source_pb2.DESCRIPTOR.services_by_name["Source"]: my_servicer} self.test_server = server_from_dictionary(services, strict_real_time()) @@ -105,7 +99,7 @@ def test_source_partition(self): def test_invalid_input(self): with self.assertRaises(TypeError): - Sourcer() + SourceServer() if __name__ == "__main__": diff --git a/tests/source/utils.py b/tests/source/utils.py index 7ddc8f85..41a3d637 100644 --- a/tests/source/utils.py +++ b/tests/source/utils.py @@ -1,8 +1,14 @@ from collections.abc import AsyncIterable, Iterable from pynumaflow.sourcer import ReadRequest, Message -from pynumaflow.sourcer._dtypes import AckRequest, PendingResponse, Offset, PartitionsResponse -from pynumaflow.sourcer.proto import source_pb2 +from pynumaflow.sourcer._dtypes import ( + AckRequest, + PendingResponse, + Offset, + PartitionsResponse, + Sourcer, +) +from pynumaflow.proto.sourcer import source_pb2 from tests.testing_utils import mock_event_time @@ -14,46 +20,42 @@ def mock_partitions() -> list[int]: return [1, 2, 3] -async def async_source_read_handler(datum: ReadRequest) -> AsyncIterable[Message]: - payload = b"payload:test_mock_message" - keys = ["test_key"] - offset = mock_offset() - event_time = mock_event_time() - for i in range(10): - yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time) +class AsyncSource(Sourcer): + async def read_handler(self, datum: ReadRequest) -> AsyncIterable[Message]: + payload = b"payload:test_mock_message" + keys = ["test_key"] + offset = mock_offset() + event_time = mock_event_time() + for i in range(10): + yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time) + async def ack_handler(self, ack_request: AckRequest): + return -async def async_source_ack_handler(ack_request: AckRequest): - return + async def pending_handler(self) -> PendingResponse: + return PendingResponse(count=10) + async def partitions_handler(self) -> PartitionsResponse: + return PartitionsResponse(partitions=mock_partitions()) -async def async_source_pending_handler() -> PendingResponse: - return PendingResponse(count=10) +class SyncSource(Sourcer): + def read_handler(self, datum: ReadRequest) -> Iterable[Message]: + payload = b"payload:test_mock_message" + keys = ["test_key"] + offset = mock_offset() + event_time = mock_event_time() + for i in range(10): + yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time) -async def async_source_partition_handler() -> PartitionsResponse: - return PartitionsResponse(partitions=mock_partitions()) + def ack_handler(self, ack_request: AckRequest): + return + def pending_handler(self) -> PendingResponse: + return PendingResponse(count=10) -def sync_source_read_handler(datum: ReadRequest) -> Iterable[Message]: - payload = b"payload:test_mock_message" - keys = ["test_key"] - offset = mock_offset() - event_time = mock_event_time() - for i in range(10): - yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time) - - -def sync_source_ack_handler(ack_request: AckRequest): - return - - -def sync_source_pending_handler() -> PendingResponse: - return PendingResponse(count=10) - - -def sync_source_partition_handler() -> PartitionsResponse: - return PartitionsResponse(partitions=mock_partitions()) + def partitions_handler(self) -> PartitionsResponse: + return PartitionsResponse(partitions=mock_partitions()) def read_req_source_fn() -> ReadRequest: @@ -70,40 +72,36 @@ def ack_req_source_fn() -> AckRequest: return request -# This handler mimics the scenario 
diff --git a/tests/source/utils.py b/tests/source/utils.py
index 7ddc8f85..41a3d637 100644
--- a/tests/source/utils.py
+++ b/tests/source/utils.py
@@ -1,8 +1,14 @@
 from collections.abc import AsyncIterable, Iterable

 from pynumaflow.sourcer import ReadRequest, Message
-from pynumaflow.sourcer._dtypes import AckRequest, PendingResponse, Offset, PartitionsResponse
-from pynumaflow.sourcer.proto import source_pb2
+from pynumaflow.sourcer._dtypes import (
+    AckRequest,
+    PendingResponse,
+    Offset,
+    PartitionsResponse,
+    Sourcer,
+)
+from pynumaflow.proto.sourcer import source_pb2
 from tests.testing_utils import mock_event_time


@@ -14,46 +20,42 @@ def mock_partitions() -> list[int]:
     return [1, 2, 3]


-async def async_source_read_handler(datum: ReadRequest) -> AsyncIterable[Message]:
-    payload = b"payload:test_mock_message"
-    keys = ["test_key"]
-    offset = mock_offset()
-    event_time = mock_event_time()
-    for i in range(10):
-        yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time)
+class AsyncSource(Sourcer):
+    async def read_handler(self, datum: ReadRequest) -> AsyncIterable[Message]:
+        payload = b"payload:test_mock_message"
+        keys = ["test_key"]
+        offset = mock_offset()
+        event_time = mock_event_time()
+        for i in range(10):
+            yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time)

+    async def ack_handler(self, ack_request: AckRequest):
+        return

-async def async_source_ack_handler(ack_request: AckRequest):
-    return
+    async def pending_handler(self) -> PendingResponse:
+        return PendingResponse(count=10)

+    async def partitions_handler(self) -> PartitionsResponse:
+        return PartitionsResponse(partitions=mock_partitions())

-async def async_source_pending_handler() -> PendingResponse:
-    return PendingResponse(count=10)

+class SyncSource(Sourcer):
+    def read_handler(self, datum: ReadRequest) -> Iterable[Message]:
+        payload = b"payload:test_mock_message"
+        keys = ["test_key"]
+        offset = mock_offset()
+        event_time = mock_event_time()
+        for i in range(10):
+            yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time)

-async def async_source_partition_handler() -> PartitionsResponse:
-    return PartitionsResponse(partitions=mock_partitions())
+    def ack_handler(self, ack_request: AckRequest):
+        return

+    def pending_handler(self) -> PendingResponse:
+        return PendingResponse(count=10)

-def sync_source_read_handler(datum: ReadRequest) -> Iterable[Message]:
-    payload = b"payload:test_mock_message"
-    keys = ["test_key"]
-    offset = mock_offset()
-    event_time = mock_event_time()
-    for i in range(10):
-        yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time)
-
-
-def sync_source_ack_handler(ack_request: AckRequest):
-    return
-
-
-def sync_source_pending_handler() -> PendingResponse:
-    return PendingResponse(count=10)
-
-
-def sync_source_partition_handler() -> PartitionsResponse:
-    return PartitionsResponse(partitions=mock_partitions())
+    def partitions_handler(self) -> PartitionsResponse:
+        return PartitionsResponse(partitions=mock_partitions())


 def read_req_source_fn() -> ReadRequest:
@@ -70,40 +72,36 @@ def ack_req_source_fn() -> AckRequest:
     return request


-# This handler mimics the scenario where map stream UDF throws a runtime error.
-async def err_async_source_read_handler(datum: ReadRequest) -> AsyncIterable[Message]:
-    payload = b"payload:test_mock_message"
-    keys = ["test_key"]
-    offset = mock_offset()
-    event_time = mock_event_time()
-    for i in range(10):
-        yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time)
-    raise RuntimeError("Got a runtime error from read handler.")
-
-
-async def err_async_source_ack_handler(ack_request: AckRequest):
-    raise RuntimeError("Got a runtime error from ack handler.")
-
-
-async def err_async_source_pending_handler() -> PendingResponse:
-    raise RuntimeError("Got a runtime error from pending handler.")
-
-
-async def err_async_source_partition_handler() -> PartitionsResponse:
-    raise RuntimeError("Got a runtime error from partition handler.")
+class AsyncSourceError(Sourcer):
+    # This handler mimics the scenario where the source UDF throws a runtime error.
+    async def read_handler(self, datum: ReadRequest) -> AsyncIterable[Message]:
+        payload = b"payload:test_mock_message"
+        keys = ["test_key"]
+        offset = mock_offset()
+        event_time = mock_event_time()
+        for i in range(10):
+            yield Message(payload=payload, keys=keys, offset=offset, event_time=event_time)
+        raise RuntimeError("Got a runtime error from read handler.")

+    async def ack_handler(self, ack_request: AckRequest):
+        raise RuntimeError("Got a runtime error from ack handler.")

-def err_sync_source_read_handler(datum: ReadRequest) -> Iterable[Message]:
-    raise RuntimeError("Got a runtime error from read handler.")
+    async def pending_handler(self) -> PendingResponse:
+        raise RuntimeError("Got a runtime error from pending handler.")

+    async def partitions_handler(self) -> PartitionsResponse:
+        raise RuntimeError("Got a runtime error from partition handler.")

-def err_sync_source_ack_handler(ack_request: AckRequest):
-    raise RuntimeError("Got a runtime error from ack handler.")
+class SyncSourceError(Sourcer):
+    def read_handler(self, datum: ReadRequest) -> Iterable[Message]:
+        raise RuntimeError("Got a runtime error from read handler.")

-def err_sync_source_pending_handler() -> PendingResponse:
-    raise RuntimeError("Got a runtime error from pending handler.")
+    def ack_handler(self, ack_request: AckRequest):
+        raise RuntimeError("Got a runtime error from ack handler.")

+    def pending_handler(self) -> PendingResponse:
+        raise RuntimeError("Got a runtime error from pending handler.")

-def err_sync_source_partition_handler() -> PartitionsResponse:
-    raise RuntimeError("Got a runtime error from partition handler.")
+    def partitions_handler(self) -> PartitionsResponse:
+        raise RuntimeError("Got a runtime error from partition handler.")
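One practical payoff of the `Sourcer` base class that the rewritten utils hint at: because `read_handler`, `ack_handler`, `pending_handler`, and `partitions_handler` now live on one object, they can share per-instance state instead of module-level globals. Below is a hypothetical sketch (not part of this diff) of a source that tracks in-flight offsets across read and ack; the `ack_request.offsets` list and the `offset.offset` field are illustrative assumptions about the request shape, as is the `Offset` constructor.

```python
from collections.abc import Iterable
from datetime import datetime, timezone

from pynumaflow.sourcer import Message, Offset, ReadRequest
from pynumaflow.sourcer._dtypes import (
    AckRequest,
    PartitionsResponse,
    PendingResponse,
    Sourcer,
)


class BufferedSource(Sourcer):
    def __init__(self):
        # Offsets that have been read but not yet acknowledged,
        # shared between read_handler and ack_handler.
        self.in_flight: set[bytes] = set()
        self.next_offset = 0

    def read_handler(self, datum: ReadRequest) -> Iterable[Message]:
        for _ in range(10):
            offset_value = str(self.next_offset).encode()
            self.next_offset += 1
            self.in_flight.add(offset_value)
            yield Message(
                payload=b"payload:test",
                keys=["test_key"],
                offset=Offset(offset=offset_value, partition_id=0),
                event_time=datetime.now(timezone.utc),
            )

    def ack_handler(self, ack_request: AckRequest):
        # Clear acknowledged offsets from the shared in-flight buffer
        # (field names assumed for illustration).
        for offset in ack_request.offsets:
            self.in_flight.discard(offset.offset)

    def pending_handler(self) -> PendingResponse:
        return PendingResponse(count=len(self.in_flight))

    def partitions_handler(self) -> PartitionsResponse:
        return PartitionsResponse(partitions=[0])
```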
diff --git a/tests/sourcetransform/test_messages.py b/tests/sourcetransform/test_messages.py
index 8f2caf46..9f6baceb 100644
--- a/tests/sourcetransform/test_messages.py
+++ b/tests/sourcetransform/test_messages.py
@@ -1,7 +1,8 @@
 import unittest
 from datetime import datetime, timezone

-from pynumaflow.sourcetransformer import Messages, Message, DROP
+from pynumaflow.sourcetransformer import Messages, Message, DROP, SourceTransformer, Datum
+from tests.testing_utils import mock_new_event_time


 def mock_message_t():
@@ -93,5 +94,30 @@ def test_err(self):
             msgts[:1]


+class ExampleSourceTransformClass(SourceTransformer):
+    def handler(self, keys: list[str], datum: Datum) -> Messages:
+        messages = Messages()
+        messages.append(Message(mock_message_t(), mock_new_event_time(), keys=keys))
+        return messages
+
+
+class TestSourceTransformClass(unittest.TestCase):
+    def setUp(self) -> None:
+        # Create a source transformer class instance
+        self.transform_instance = ExampleSourceTransformClass()
+
+    def test_source_transform_class_call(self):
+        """Test that the __call__ functionality for the class works,
+        i.e., the class instance can be called directly to invoke the handler function.
+        """
+        # make a call to the class directly
+        ret = self.transform_instance([], None)
+        self.assertEqual(mock_message_t(), ret[0].value)
+        # make a call to the handler
+        ret_handler = self.transform_instance.handler([], None)
+        # Both responses should be equal
+        self.assertEqual(ret[0], ret_handler[0])
+
+
 if __name__ == "__main__":
     unittest.main()
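The new `TestSourceTransformClass` above exercises the fact that a `SourceTransformer` instance is itself callable. A plausible sketch of how that delegation works inside the base class is shown here; it is illustrative only, since the real base-class internals are not part of this diff.

```python
from abc import ABCMeta, abstractmethod


class SourceTransformerSketch(metaclass=ABCMeta):
    """Illustrative stand-in for pynumaflow.sourcetransformer.SourceTransformer."""

    def __call__(self, *args, **kwargs):
        # Calling the instance forwards straight to the user-defined handler,
        # which is why transform_instance([], None) and
        # transform_instance.handler([], None) yield equal results in the test.
        return self.handler(*args, **kwargs)

    @abstractmethod
    def handler(self, keys, datum):
        """Implemented by subclasses such as ExampleSourceTransformClass."""
        raise NotImplementedError
```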
diff --git a/tests/sourcetransform/test_multiproc.py b/tests/sourcetransform/test_multiproc.py
index 25cca845..da4dcffb 100644
--- a/tests/sourcetransform/test_multiproc.py
+++ b/tests/sourcetransform/test_multiproc.py
@@ -1,7 +1,6 @@
 import os
 import unittest
 from unittest import mock
-from unittest.mock import Mock, patch

 import grpc
 from google.protobuf import empty_pb2 as _empty_pb2
@@ -9,8 +8,8 @@
 from grpc import StatusCode
 from grpc_testing import server_from_dictionary, strict_real_time

-from pynumaflow.sourcetransformer.multiproc_server import MultiProcSourceTransformer
-from pynumaflow.sourcetransformer.proto import transform_pb2_grpc, transform_pb2
+from pynumaflow.proto.sourcetransformer import transform_pb2
+from pynumaflow.sourcetransformer.multiproc_server import SourceTransformMultiProcServer
 from tests.sourcetransform.utils import transform_handler, err_transform_handler
 from tests.testing_utils import (
     mock_event_time,
@@ -27,52 +26,32 @@ def mockenv(**envvars):

 class TestMultiProcMethods(unittest.TestCase):
     def setUp(self) -> None:
-        my_servicer = MultiProcSourceTransformer(handler=transform_handler)
+        server = SourceTransformMultiProcServer(source_transform_instance=transform_handler)
+        my_servicer = server.servicer
         services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer}
         self.test_server = server_from_dictionary(services, strict_real_time())

-    @mockenv(NUM_CPU_MULTIPROC="3")
     def test_multiproc_init(self) -> None:
-        server = MultiProcSourceTransformer(handler=transform_handler)
+        server = SourceTransformMultiProcServer(
+            source_transform_instance=transform_handler, server_count=3
+        )
         self.assertEqual(server._process_count, 3)

-    @patch("os.cpu_count", Mock(return_value=4))
     def test_multiproc_process_count(self) -> None:
-        server = MultiProcSourceTransformer(handler=transform_handler)
-        self.assertEqual(server._process_count, 4)
+        default_value = os.cpu_count()
+        server = SourceTransformMultiProcServer(source_transform_instance=transform_handler)
+        self.assertEqual(server._process_count, default_value)

-    @patch("os.cpu_count", Mock(return_value=4))
-    @mockenv(NUM_CPU_MULTIPROC="10")
     def test_max_process_count(self) -> None:
-        server = MultiProcSourceTransformer(handler=transform_handler)
-        self.assertEqual(server._process_count, 8)
-
-    # To test the reuse property for the grpc servers which allow multiple
-    # bindings to the same server
-    def test_reuse_port(self):
-        serv_options = [("grpc.so_reuseaddr", 1)]
-
-        server = MultiProcSourceTransformer(handler=transform_handler)
-
-        with server._reserve_port(0) as port:
-            print(port)
-            bind_address = f"localhost:{port}"
-            server1 = grpc.server(thread_pool=None, options=serv_options)
-            transform_pb2_grpc.add_SourceTransformServicer_to_server(server, server1)
-            server1.add_insecure_port(bind_address)
-
-            # so_reuseport=0 -> the bind should raise an error
-            server2 = grpc.server(thread_pool=None, options=(("grpc.so_reuseport", 0),))
-            transform_pb2_grpc.add_SourceTransformServicer_to_server(server, server2)
-            self.assertRaises(RuntimeError, server2.add_insecure_port, bind_address)
-
-            # so_reuseport=1 -> should allow server to bind to port again
-            server3 = grpc.server(thread_pool=None, options=(("grpc.so_reuseport", 1),))
-            transform_pb2_grpc.add_SourceTransformServicer_to_server(server, server3)
-            server3.add_insecure_port(bind_address)
+        default_value = os.cpu_count()
+        server = SourceTransformMultiProcServer(
+            source_transform_instance=transform_handler, server_count=50
+        )
+        self.assertEqual(server._process_count, 2 * default_value)

     def test_udf_mapt_err(self):
-        my_servicer = MultiProcSourceTransformer(handler=err_transform_handler)
+        server = SourceTransformMultiProcServer(source_transform_instance=err_transform_handler)
+        my_servicer = server.servicer
         services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer}
         self.test_server = server_from_dictionary(services, strict_real_time())

@@ -166,7 +145,7 @@ def test_mapt_assign_new_event_time(self, test_server=None):

     def test_invalid_input(self):
         with self.assertRaises(TypeError):
-            MultiProcSourceTransformer()
+            SourceTransformMultiProcServer()


 if __name__ == "__main__":
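The three multiproc tests above pin down the new `server_count` semantics: default to `os.cpu_count()`, honor an explicit count, and cap it at twice the CPU count. A small reconstruction of that clamping logic, derived only from the test assertions (the SDK's actual implementation may differ in detail):

```python
import os
from typing import Optional


def resolve_process_count(server_count: Optional[int] = None) -> int:
    """Reconstructed from the test assertions: default to os.cpu_count(),
    cap an explicit request at 2 * os.cpu_count()."""
    cpus = os.cpu_count()
    if server_count is None:
        return cpus
    return min(server_count, 2 * cpus)


# Mirrors test_multiproc_init, test_multiproc_process_count, and
# test_max_process_count above:
assert resolve_process_count(3) == 3
assert resolve_process_count() == os.cpu_count()
assert resolve_process_count(50) == 2 * os.cpu_count()
```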
diff --git a/tests/sourcetransform/test_sync_server.py b/tests/sourcetransform/test_sync_server.py
index 54830c40..382c36e5 100644
--- a/tests/sourcetransform/test_sync_server.py
+++ b/tests/sourcetransform/test_sync_server.py
@@ -6,8 +6,8 @@
 from grpc import StatusCode
 from grpc_testing import server_from_dictionary, strict_real_time

-from pynumaflow.sourcetransformer import SourceTransformer
-from pynumaflow.sourcetransformer.proto import transform_pb2
+from pynumaflow.sourcetransformer import SourceTransformServer
+from pynumaflow.proto.sourcetransformer import transform_pb2
 from tests.sourcetransform.utils import transform_handler, err_transform_handler
 from tests.testing_utils import (
     mock_event_time,
@@ -19,19 +19,23 @@

 class TestServer(unittest.TestCase):
     def setUp(self) -> None:
-        my_servicer = SourceTransformer(handler=transform_handler)
+        server = SourceTransformServer(source_transform_instance=transform_handler)
+        my_servicer = server.servicer
         services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer}
         self.test_server = server_from_dictionary(services, strict_real_time())

     def test_init_with_args(self) -> None:
-        my_servicer = SourceTransformer(
-            handler=transform_handler, sock_path="/tmp/test.sock", max_message_size=1024 * 1024 * 5
+        server = SourceTransformServer(
+            source_transform_instance=transform_handler,
+            sock_path="/tmp/test.sock",
+            max_message_size=1024 * 1024 * 5,
         )
-        self.assertEqual(my_servicer.sock_path, "unix:///tmp/test.sock")
-        self.assertEqual(my_servicer._max_message_size, 1024 * 1024 * 5)
+        self.assertEqual(server.sock_path, "unix:///tmp/test.sock")
+        self.assertEqual(server.max_message_size, 1024 * 1024 * 5)

     def test_udf_mapt_err(self):
-        my_servicer = SourceTransformer(handler=err_transform_handler)
+        server = SourceTransformServer(source_transform_instance=err_transform_handler)
+        my_servicer = server.servicer
         services = {transform_pb2.DESCRIPTOR.services_by_name["SourceTransform"]: my_servicer}
         self.test_server = server_from_dictionary(services, strict_real_time())

@@ -128,7 +132,7 @@ def test_mapt_assign_new_event_time(self, test_server=None):

     def test_invalid_input(self):
         with self.assertRaises(TypeError):
-            SourceTransformer()
+            SourceTransformServer()


 if __name__ == "__main__":
diff --git a/tests/test_shared.py b/tests/test_shared.py
new file mode 100644
index 00000000..9818d9ac
--- /dev/null
+++ b/tests/test_shared.py
@@ -0,0 +1,37 @@
+import unittest
+
+
+from pynumaflow.info.server import get_sdk_version
+from pynumaflow.info.types import Protocol
+from pynumaflow.mapper import Datum, Messages, Message
+
+from pynumaflow.shared.server import write_info_file
+from tests.testing_utils import read_info_server
+
+
+def map_handler(keys: list[str], datum: Datum) -> Messages:
+    val = datum.value
+    msg = "payload:{} event_time:{} watermark:{}".format(
+        val.decode("utf-8"),
+        datum.event_time,
+        datum.watermark,
+    )
+    val = bytes(msg, encoding="utf-8")
+    messages = Messages()
+    messages.append(Message(val, keys=keys))
+    return messages
+
+
+class TestSharedUtils(unittest.TestCase):
+    def test_write_info_file(self):
+        """
+        Test the write_info_file function:
+        write data to the info file and read it back to verify.
+        """
+        info_file = "/tmp/test_info_server"
+        ret = write_info_file(info_file=info_file, protocol=Protocol.UDS)
+        self.assertIsNone(ret)
+        file_data = read_info_server(info_file=info_file)
+        self.assertEqual(file_data["protocol"], "uds")
+        self.assertEqual(file_data["language"], "python")
+        self.assertEqual(file_data["version"], get_sdk_version())
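Finally, on the new `tests/test_shared.py`: the `read_info_server` helper it imports is not part of this diff, but its expected behavior follows from the assertions. Below is a hypothetical reader, assuming the info file payload written by `write_info_file` is plain JSON; the real on-disk format may append an end-of-payload marker that would need stripping before parsing.

```python
import json


def read_info_file(info_file: str) -> dict:
    """Hypothetical stand-in for tests.testing_utils.read_info_server."""
    with open(info_file) as f:
        payload = f.read()
    # Parse the metadata written by pynumaflow.shared.server.write_info_file;
    # per TestSharedUtils above, it should contain at least:
    #   {"protocol": "uds", "language": "python", "version": "<sdk version>"}
    return json.loads(payload)
```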