diff --git a/google/cloud/bigquery_storage_v1/writer.py b/google/cloud/bigquery_storage_v1/writer.py
index a8c447bb..2f000bfb 100644
--- a/google/cloud/bigquery_storage_v1/writer.py
+++ b/google/cloud/bigquery_storage_v1/writer.py
@@ -68,6 +68,7 @@ def __init__(
         client: big_query_write.BigQueryWriteClient,
         initial_request_template: gapic_types.AppendRowsRequest,
         metadata: Sequence[Tuple[str, str]] = (),
+        resend: bool = True,  # resend queued messages and failed message
     ):
         """Construct a stream manager.
 
@@ -87,12 +88,20 @@ def __init__(
         self._closing = threading.Lock()
         self._closed = False
         self._close_callbacks = []
-        self._futures_queue = queue.Queue()
+        self._queue = queue.Queue()
         self._inital_request_template = initial_request_template
         self._metadata = metadata
+        self._resend = resend
 
-        # Only one call to `send()` should attempt to open the RPC.
-        self._opening = threading.Lock()
+
+        # if self._hibernating == True, this means the connection was closed by
+        # the server. The stream will try to reconnect if the message queue is
+        # non-empty, or when the customer sends another request.
+        self._hibernating = False
+
+        # Only one call to `send()` should attempt to open the RPC. Use a
+        # reentrant lock: send() may re-enter via _retry() while resending.
+        self._opening = threading.RLock()
 
         self._rpc = None
         self._stream_name = None
@@ -170,7 +179,7 @@ def _open(
         request.trace_id = f"python-writer:{package_version.__version__}"
 
         inital_response_future = AppendRowsFuture(self)
-        self._futures_queue.put(inital_response_future)
+        self._queue.put((request, inital_response_future))
 
         self._rpc = bidi.BidiRpc(
             self._client.append_rows,
@@ -235,9 +244,14 @@ def _open(
                 self.close(reason=request_exception)
                 raise request_exception
 
+        self._hibernating = False
         return inital_response_future
 
-    def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
+    def send(
+        self,
+        request: gapic_types.AppendRowsRequest,
+        put_into_queue: bool = True,
+    ) -> "AppendRowsFuture":
         """Send an append rows request to the open stream.
 
         Args:
@@ -259,14 +273,18 @@ def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
         # to open, in which case this send will fail anyway due to a closed
         # RPC.
         with self._opening:
-            if not self.is_active:
+            if not self.is_active or self._hibernating:
                 return self._open(request)
 
         # For each request, we expect exactly one response (in order). Add a
         # future to the queue so that when the response comes, the callback can
         # pull it off and notify completion.
         future = AppendRowsFuture(self)
-        self._futures_queue.put(future)
+
+        # Only put into queue when we are not resending.
+        if put_into_queue:
+            self._queue.put((request, future))
+
         self._rpc.send(request)
         return future
 
@@ -282,7 +300,7 @@ def _on_response(self, response: gapic_types.AppendRowsResponse):
 
         # Since we have 1 response per request, if we get here from a response
         # callback, the queue should never be empty.
-        future: AppendRowsFuture = self._futures_queue.get_nowait()
+        future: AppendRowsFuture = self._queue.get_nowait()[1]
         if response.error.code:
             exc = exceptions.from_grpc_status(
                 response.error.code, response.error.message, response=response
@@ -328,11 +346,11 @@ def _shutdown(self, reason: Optional[Exception] = None):
 
         # We know that no new items will be added to the queue because
         # we've marked the stream as closed.
-        while not self._futures_queue.empty():
+        while not self._queue.empty():
             # Mark each future as failed. Since the consumer thread has
             # stopped (or at least is attempting to stop), we won't get
             # response callbacks to populate the remaining futures.
-            future = self._futures_queue.get_nowait()
+            future = self._queue.get_nowait()[1]
             if reason is None:
                 exc = bqstorage_exceptions.StreamClosedError(
                     "Stream closed before receiving a response."
@@ -344,6 +362,72 @@ def _shutdown(self, reason: Optional[Exception] = None):
         for callback in self._close_callbacks:
             callback(self, reason)
 
+    def _hibernate(self, reason: Optional[Exception] = None):
+        # If the connection is shut down by the server for retriable reasons,
+        # such as idle connection, we shut down the grpc connection and delete
+        # the consumer. However, we preserve the futures queue, and if the
+        # queue is not empty, or if there is a new message to be sent, we
+        # create a new gRPC connection and consumer to resume the process.
+        # The stream manager itself stays usable (self._closed is untouched).
+
+        # Stop consumer
+        if self.is_active:
+            _LOGGER.debug("Stopping consumer.")
+            self._consumer.stop()
+            self._consumer = None
+
+        # Close RPC connection
+        if self._rpc is not None:
+            self._rpc.close()
+        # Deliberately NOT setting self._closed: the stream may be reopened.
+        _LOGGER.debug("Finished stopping manager.")
+
+        # Register error on the future corresponding to this error message.
+        # If self._resend == True, we do not set error in the future,
+        # instead, we retry in self._retry().
+        if not self._resend and not self._queue.empty():
+            future = self._queue.get_nowait()[1]
+            future.set_exception(reason)
+
+        # Mark self._hibernating as True for future reopening
+        self._hibernating = True
+
+        return
+
+    def _retry(self):
+        new_queue = queue.Queue()
+        _LOGGER.debug("Retrying stream.")
+        _LOGGER.debug("Length of queue: %s", self._queue.qsize())
+
+        with self._opening:
+
+            # Resend each request remaining in the queue, and create a new
+            # queue with the new futures.
+            while not self._queue.empty():
+                _LOGGER.debug("Popping from queue: %s", self._queue.qsize())
+                request, _ = self._queue.get_nowait()
+                new_future = self.send(request, put_into_queue=False)
+                new_queue.put((request, new_future))
+
+            self._queue = new_queue
+
+        return
+
+    def _rpc_done_callback(self, reason: Optional[Exception] = None):
+        # When the RPC connection is closed, take action accordingly:
+        # 1. For a retriable error, hibernate the connection, resend queued
+        #    messages if self._resend == True.
+        # 2. Otherwise, shut down connection completely, and report error for
+        #    all queued futures.
+        if isinstance(reason, exceptions.Aborted) and self._resend:
+            self._hibernate(reason)
+
+            # Retry now only if the queue is non-empty; else reopen on next send().
+            if not self._queue.empty():
+                self._retry()
+        else:
+            self._shutdown(reason)
+
     def _on_rpc_done(self, future):
         """Triggered whenever the underlying RPC terminates without recovery.
 
@@ -358,7 +442,9 @@ def _on_rpc_done(self, future):
         _LOGGER.info("RPC termination has signaled streaming pull manager shutdown.")
         error = _wrap_as_exception(future)
         thread = threading.Thread(
-            name=_RPC_ERROR_THREAD_NAME, target=self._shutdown, kwargs={"reason": error}
+            name=_RPC_ERROR_THREAD_NAME,
+            target=self._rpc_done_callback,
+            kwargs={"reason": error},
         )
         thread.daemon = True
         thread.start()