Skip to content

Commit

Permalink
fix resend
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchin committed Jan 17, 2025
1 parent e7c0249 commit ffefe90
Showing 1 changed file with 39 additions and 22 deletions.
61 changes: 39 additions & 22 deletions google/cloud/bigquery_storage_v1/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,7 @@ def __init__(
client: big_query_write.BigQueryWriteClient,
initial_request_template: gapic_types.AppendRowsRequest,
metadata: Sequence[Tuple[str, str]] = (),
resend: bool = True, # resend queued message
resend_failed: bool = True,
resend: bool = True, # resend queued messages and failed message
):
"""Construct a stream manager.
Expand All @@ -93,16 +92,16 @@ def __init__(
self._inital_request_template = initial_request_template
self._metadata = metadata
self._resend = resend
self._resend_failed = resend_failed


# if self._hibernating == True, this means the connection was closed by
# the server. The stream will try to reconnect if the message queue is
# non-empty, or when the customer sends another request.
self._hibernating = False

# Only one call to `send()` should attempt to open the RPC.
self._opening = threading.Lock()
# Only one call to `send()` should attempt to open the RPC. Use
# reentrant lock, because it may be called in a nested way.
self._opening = threading.RLock()

self._rpc = None
self._stream_name = None
Expand Down Expand Up @@ -250,7 +249,11 @@ def _open(
self._hibernating = False
return inital_response_future

def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
def send(
self,
request: gapic_types.AppendRowsRequest,
put_into_queue: bool = True,
) -> "AppendRowsFuture":
"""Send an append rows request to the open stream.
Args:
Expand Down Expand Up @@ -279,7 +282,11 @@ def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
# future to the queue so that when the response comes, the callback can
# pull it off and notify completion.
future = AppendRowsFuture(self)
self._queue.put((request, future))

# Only put into queue when we are not resending.
if put_into_queue:
self._queue.put((request, future))

self._rpc.send(request)
return future

Expand Down Expand Up @@ -377,10 +384,12 @@ def _hibernate(self, reason: Optional[Exception] = None):
# self._closed = True
_LOGGER.debug("Finished stopping manager.")

# Register error on the future corresponding to this error message
# if not self._resend_failed:
future = self._queue.get_nowait()[1]
future.set_exception(reason)
# Register error on the future corresponding to this error message.
# If self._resend == True, we do not set error in the future,
# instead, we retry in self._retry().
if not self._resend:
future = self._queue.get_nowait()[1]
future.set_exception(reason)

# Mark self._hibernating as True for future reopening
self._hibernating = True
Expand All @@ -389,20 +398,28 @@ def _hibernate(self, reason: Optional[Exception] = None):

def _retry(self):
new_queue = queue.Queue()
print("retrying")
print(f"length of queue: {self._queue.qsize()}")

with self._opening:

# Resend each request remaining in the queue, and create a new queue
# with the new futures
while not self._queue.empty():
print(f"popping from queue: {self._queue.qsize()}")
request, _ = self._queue.get_nowait()
new_future = self.send(request, put_into_queue=False)
new_queue.put((request, new_future))

# Resend each request remaining in the queue, and create a new queue
# with the new futures
while not self._queue.empty():
request, _ = self._queue.get_nowait()
new_future = self.send(request)
new_queue.put((request, new_future))
self._queue = new_queue

self._queue = new_queue
return

def _shutdown_or_hibernate_or_retry(self, reason: Optional[Exception] = None):
# Hibernate if a retriable error is received, otherwise, shut down
# completely.
def _rpc_done_callback(self, reason: Optional[Exception] = None):
# When the RPC connection is closed, take action accordingly:
# 1. For a retriable error, hibernate the connection, resend queued
# messages if self._resend == True.
# 2. Otherwise, shut down connection completely.
if isinstance(reason, exceptions.Aborted):
self._hibernate(reason)
if self._resend:
Expand All @@ -425,7 +442,7 @@ def _on_rpc_done(self, future):
error = _wrap_as_exception(future)
thread = threading.Thread(
name=_RPC_ERROR_THREAD_NAME,
target=self._shutdown_or_hibernate_or_retry,
target=self._rpc_done_callback,
kwargs={"reason": error},
)
thread.daemon = True
Expand Down

0 comments on commit ffefe90

Please sign in to comment.