Skip to content

Commit

Permalink
feat: update batchmap and mapstream to use Map proto (numaproj#200)
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <[email protected]>
Co-authored-by: Vigith Maurice <[email protected]>
  • Loading branch information
kohlisid and vigith authored Nov 6, 2024
1 parent 688132a commit f2f7bf6
Show file tree
Hide file tree
Showing 22 changed files with 252 additions and 803 deletions.
7 changes: 5 additions & 2 deletions pynumaflow/batchmapper/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from datetime import datetime
from typing import TypeVar, Callable, Union, Optional
from collections.abc import AsyncIterable
from collections.abc import Awaitable

from pynumaflow._constants import DROP

Expand Down Expand Up @@ -222,5 +221,9 @@ async def handler(self, datums: AsyncIterable[Datum]) -> BatchResponses:
pass


BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], Awaitable[BatchResponses]]
BatchMapAsyncCallable = Callable[[AsyncIterable[Datum]], BatchResponses]
BatchMapCallable = Union[BatchMapper, BatchMapAsyncCallable]


class BatchMapError(Exception):
"""To Raise an error while executing a BatchMap call"""
4 changes: 2 additions & 2 deletions pynumaflow/batchmapper/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
MINIMUM_NUMAFLOW_VERSION,
ContainerType,
)
from pynumaflow.proto.batchmapper import batchmap_pb2_grpc
from pynumaflow.proto.mapper import map_pb2_grpc
from pynumaflow.shared.server import NumaflowServer, start_async_server


Expand Down Expand Up @@ -103,7 +103,7 @@ async def aexec(self):
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)
batchmap_pb2_grpc.add_BatchMapServicer_to_server(
map_pb2_grpc.add_MapServicer_to_server(
self.servicer,
server,
)
Expand Down
157 changes: 69 additions & 88 deletions pynumaflow/batchmapper/servicer/async_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,19 @@
from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow.batchmapper import Datum
from pynumaflow.batchmapper._dtypes import BatchMapCallable
from pynumaflow.proto.batchmapper import batchmap_pb2, batchmap_pb2_grpc
from pynumaflow.batchmapper._dtypes import BatchMapCallable, BatchMapError
from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc
from pynumaflow.shared.asynciter import NonBlockingIterator
from pynumaflow.shared.server import exit_on_error
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER, STREAM_EOF


async def datum_generator(
request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest],
) -> AsyncIterable[Datum]:
"""
This function is used to create an async generator
from the gRPC request iterator.
It yields a Datum instance for each request received which is then
forwarded to the UDF.
"""
async for d in request_iterator:
request = Datum(
keys=d.keys,
value=d.value,
event_time=d.event_time.ToDatetime(),
watermark=d.watermark.ToDatetime(),
headers=dict(d.headers),
id=d.id,
)
yield request


class AsyncBatchMapServicer(batchmap_pb2_grpc.BatchMapServicer):
class AsyncBatchMapServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Batch Map Servicer instance.
It implements the BatchMapServicer interface from the proto
batchmap_pb2_grpc.py file.
It implements the MapServicer interface from the proto
map_pb2_grpc.py file.
Provides the functionality for the required rpc methods.
"""

Expand All @@ -49,41 +28,74 @@ def __init__(
self.background_tasks = set()
self.__batch_map_handler: BatchMapCallable = handler

async def BatchMapFn(
async def MapFn(
self,
request_iterator: AsyncIterable[batchmap_pb2.BatchMapRequest],
request_iterator: AsyncIterable[map_pb2.MapRequest],
context: NumaflowServicerContext,
) -> batchmap_pb2.BatchMapResponse:
) -> AsyncIterable[map_pb2.MapResponse]:
"""
Applies a batch map function to a BatchMapRequest stream in a batching mode.
The pascal case function name comes from the proto batchmap_pb2_grpc.py file.
Applies a batch map function to a MapRequest stream in a batching mode.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
# Create an async iterator from the request iterator
datum_iterator = datum_generator(request_iterator=request_iterator)

try:
# invoke the UDF call for batch map
responses, request_counter = await self.invoke_batch_map(datum_iterator)

# If the number of responses received does not align with the request batch size,
# we will not be able to process the data correctly.
# This should be marked as an error and raised to the user.
if len(responses) != request_counter:
err_msg = "batchMapFn: mismatch between length of batch requests and responses"
raise Exception(err_msg)

# iterate over the responses received and convert to the required proto format
for batch_response in responses:
single_req_resp = []
for msg in batch_response.messages:
single_req_resp.append(
batchmap_pb2.BatchMapResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise BatchMapError("BatchMapFn: expected handshake as the first message")
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

# cur_task is used to track the task (coroutine) processing
# the current batch of messages.
cur_task = None
# iterate over the incoming messages on the request stream
async for d in request_iterator:
# if we do not have any active task currently processing the batch
# we need to create one and call the User function for processing the same.
if cur_task is None:
req_queue = NonBlockingIterator()
cur_task = asyncio.create_task(
self.__batch_map_handler(req_queue.read_iterator())
)

# send the response for a given ID back to the stream
yield batchmap_pb2.BatchMapResponse(id=batch_response.id, results=single_req_resp)
self.background_tasks.add(cur_task)
cur_task.add_done_callback(self.background_tasks.discard)
# when we receive an end-of-transmission message, we need to stop processing the
# current batch and wait for the next batch of messages.
# We will also wait for the current task to finish processing the current batch.
# We mark the current task as None to indicate that we are
# ready to process the next batch.
if d.status and d.status.eot:
await req_queue.put(STREAM_EOF)
await cur_task
ret = cur_task.result()

# iterate over the responses received and convert to the required proto format
for batch_response in ret:
single_req_resp = []
for msg in batch_response.messages:
single_req_resp.append(
map_pb2.MapResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
)
# send the response for a given ID back to the stream
yield map_pb2.MapResponse(id=batch_response.id, results=single_req_resp)

# send EOT after finishing each batch of responses
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True))
cur_task = None
continue

# if we have a valid message, we will add it to the request queue for processing.
datum = Datum(
keys=list(d.request.keys),
value=d.request.value,
event_time=d.request.event_time.ToDatetime(),
watermark=d.request.watermark.ToDatetime(),
headers=dict(d.request.headers),
id=d.id,
)
await req_queue.put(datum)

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
Expand All @@ -93,42 +105,11 @@ async def BatchMapFn(
exit_on_error(context, repr(err))
return

async def invoke_batch_map(self, datum_iterator: AsyncIterable[Datum]):
"""
# iterate over the incoming requests, and keep sending to the user code
# once all messages have been sent, we wait for the responses
"""
# create a message queue to send to the user code
niter = NonBlockingIterator()
riter = niter.read_iterator()
# create a task for invoking the UDF handler
task = asyncio.create_task(self.__batch_map_handler(riter))
# Save a reference to the result of this function, to avoid a
# task disappearing mid-execution.
self.background_tasks.add(task)
task.add_done_callback(lambda t: self.background_tasks.remove(t))

req_count = 0
# start streaming the messages to the UDF code, and increment the request counter
async for datum in datum_iterator:
await niter.put(datum)
req_count += 1

# once all messages have been exhausted, send an EOF to indicate end of messages
# to the UDF
await niter.put(STREAM_EOF)

# wait for all the responses
await task

# return the result from the UDF, along with the request_counter
return task.result(), req_count

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> batchmap_pb2.ReadyResponse:
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
return batchmap_pb2.ReadyResponse(ready=True)
return map_pb2.ReadyResponse(ready=True)
4 changes: 4 additions & 0 deletions pynumaflow/mapstreamer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,7 @@ async def handler(self, keys: list[str], datum: Datum) -> AsyncIterable[Message]

MapStreamAsyncCallable = Callable[[list[str], Datum], AsyncIterable[Message]]
MapStreamCallable = Union[MapStreamer, MapStreamAsyncCallable]


class MapStreamError(Exception):
"""To Raise an error while executing a MapStream call"""
4 changes: 2 additions & 2 deletions pynumaflow/mapstreamer/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ContainerType,
)
from pynumaflow.mapstreamer.servicer.async_servicer import AsyncMapStreamServicer
from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc
from pynumaflow.proto.mapper import map_pb2_grpc

from pynumaflow._constants import (
MAP_STREAM_SOCK_PATH,
Expand Down Expand Up @@ -122,7 +122,7 @@ async def aexec(self):
# Create a new async server instance and add the servicer to it
server = grpc.aio.server(options=self._server_options)
server.add_insecure_port(self.sock_path)
mapstream_pb2_grpc.add_MapStreamServicer_to_server(
map_pb2_grpc.add_MapServicer_to_server(
self.servicer,
server,
)
Expand Down
68 changes: 37 additions & 31 deletions pynumaflow/mapstreamer/servicer/async_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
from google.protobuf import empty_pb2 as _empty_pb2

from pynumaflow.mapstreamer import Datum
from pynumaflow.mapstreamer._dtypes import MapStreamCallable
from pynumaflow.proto.mapstreamer import mapstream_pb2_grpc, mapstream_pb2
from pynumaflow.mapstreamer._dtypes import MapStreamCallable, MapStreamError
from pynumaflow.proto.mapper import map_pb2_grpc, map_pb2
from pynumaflow.shared.server import exit_on_error
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER


class AsyncMapStreamServicer(mapstream_pb2_grpc.MapStreamServicer):
class AsyncMapStreamServicer(map_pb2_grpc.MapServicer):
"""
This class is used to create a new grpc Map Stream Servicer instance.
It implements the MapServicer interface from the proto
mapstream_pb2_grpc.py file.
map_pb2_grpc.py file.
Provides the functionality for the required rpc methods.
"""

Expand All @@ -24,52 +24,58 @@ def __init__(
):
self.__map_stream_handler: MapStreamCallable = handler

async def MapStreamFn(
async def MapFn(
self,
request: mapstream_pb2.MapStreamRequest,
request_iterator: AsyncIterable[map_pb2.MapRequest],
context: NumaflowServicerContext,
) -> AsyncIterable[mapstream_pb2.MapStreamResponse]:
) -> AsyncIterable[map_pb2.MapResponse]:
"""
Applies a map function to a datum stream in streaming mode.
The pascal case function name comes from the proto mapstream_pb2_grpc.py file.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""

try:
async for res in self.__invoke_map_stream(
list(request.keys),
Datum(
keys=list(request.keys),
value=request.value,
event_time=request.event_time.ToDatetime(),
watermark=request.watermark.ToDatetime(),
headers=dict(request.headers),
),
context,
):
yield mapstream_pb2.MapStreamResponse(result=res)
# The first message to be received should be a valid handshake
req = await request_iterator.__anext__()
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise MapStreamError("MapStreamFn: expected handshake as the first message")
yield map_pb2.MapResponse(handshake=map_pb2.Handshake(sot=True))

# read for each input request
async for req in request_iterator:
# yield messages as received from the UDF
async for res in self.__invoke_map_stream(
list(req.request.keys),
Datum(
keys=list(req.request.keys),
value=req.request.value,
event_time=req.request.event_time.ToDatetime(),
watermark=req.request.watermark.ToDatetime(),
headers=dict(req.request.headers),
),
):
yield map_pb2.MapResponse(results=[res], id=req.id)
# send EOT to indicate end of transmission for a given message
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
exit_on_error(context, repr(err))
return

async def __invoke_map_stream(
self, keys: list[str], req: Datum, context: NumaflowServicerContext
):
async def __invoke_map_stream(self, keys: list[str], req: Datum):
try:
# Invoke the user handler for map stream
async for msg in self.__map_stream_handler(keys, req):
yield mapstream_pb2.MapStreamResponse.Result(
keys=msg.keys, value=msg.value, tags=msg.tags
)
yield map_pb2.MapResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
exit_on_error(context, repr(err))
raise err

async def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> mapstream_pb2.ReadyResponse:
) -> map_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto mapstream_pb2_grpc.py file.
The pascal case function name comes from the proto map_pb2_grpc.py file.
"""
return mapstream_pb2.ReadyResponse(ready=True)
return map_pb2.ReadyResponse(ready=True)
Empty file.
Loading

0 comments on commit f2f7bf6

Please sign in to comment.