Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: reconnect write stream if disconnected by server #868

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 96 additions & 11 deletions google/cloud/bigquery_storage_v1/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
client: big_query_write.BigQueryWriteClient,
initial_request_template: gapic_types.AppendRowsRequest,
metadata: Sequence[Tuple[str, str]] = (),
resend: bool = True, # resend queued messages and failed message
):
"""Construct a stream manager.

Expand All @@ -87,12 +88,20 @@ def __init__(
self._closing = threading.Lock()
self._closed = False
self._close_callbacks = []
self._futures_queue = queue.Queue()
self._queue = queue.Queue()
self._inital_request_template = initial_request_template
self._metadata = metadata
self._resend = resend

# Only one call to `send()` should attempt to open the RPC.
self._opening = threading.Lock()

# if self._hibernating == True, this means the connection was closed by
# the server. The stream will try to reconnect if the message queue is
# non-empty, or when the customer sends another request.
self._hibernating = False

# Only one call to `send()` should attempt to open the RPC. Use
Copy link
Contributor Author

@Linchin Linchin Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: using reentrant lock may be against the original purpose that "Only one call to send() should attempt to open the RPC". Need to check:

  1. Do we really need thread lock in _retry()? (I think so)
  2. If yes, we may need another send func without using any thread lock

# reentrant lock, because it may be called in a nested way.
self._opening = threading.RLock()

self._rpc = None
self._stream_name = None
Expand Down Expand Up @@ -170,7 +179,7 @@ def _open(
request.trace_id = f"python-writer:{package_version.__version__}"

inital_response_future = AppendRowsFuture(self)
self._futures_queue.put(inital_response_future)
self._queue.put((request, inital_response_future))

self._rpc = bidi.BidiRpc(
self._client.append_rows,
Expand Down Expand Up @@ -216,6 +225,8 @@ def _open(
# may be None.
pass

# breakpoint()

try:
is_consumer_active = self._consumer.is_active
except AttributeError:
Expand All @@ -235,9 +246,14 @@ def _open(
self.close(reason=request_exception)
raise request_exception

self._hibernating = False
return inital_response_future

def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
def send(
self,
request: gapic_types.AppendRowsRequest,
put_into_queue: bool = True,
) -> "AppendRowsFuture":
"""Send an append rows request to the open stream.

Args:
Expand All @@ -259,14 +275,18 @@ def send(self, request: gapic_types.AppendRowsRequest) -> "AppendRowsFuture":
# to open, in which case this send will fail anyway due to a closed
# RPC.
with self._opening:
if not self.is_active:
if not self.is_active or self._hibernating:
return self._open(request)

# For each request, we expect exactly one response (in order). Add a
# future to the queue so that when the response comes, the callback can
# pull it off and notify completion.
future = AppendRowsFuture(self)
self._futures_queue.put(future)

# Only put into queue when we are not resending.
if put_into_queue:
self._queue.put((request, future))

self._rpc.send(request)
return future

Expand All @@ -282,7 +302,7 @@ def _on_response(self, response: gapic_types.AppendRowsResponse):

# Since we have 1 response per request, if we get here from a response
# callback, the queue should never be empty.
future: AppendRowsFuture = self._futures_queue.get_nowait()
future: AppendRowsFuture = self._queue.get_nowait()[1]
if response.error.code:
exc = exceptions.from_grpc_status(
response.error.code, response.error.message, response=response
Expand Down Expand Up @@ -328,11 +348,11 @@ def _shutdown(self, reason: Optional[Exception] = None):

# We know that no new items will be added to the queue because
# we've marked the stream as closed.
while not self._futures_queue.empty():
while not self._queue.empty():
# Mark each future as failed. Since the consumer thread has
# stopped (or at least is attempting to stop), we won't get
# response callbacks to populate the remaining futures.
future = self._futures_queue.get_nowait()
future = self._queue.get_nowait()[1]
if reason is None:
exc = bqstorage_exceptions.StreamClosedError(
"Stream closed before receiving a response."
Expand All @@ -344,6 +364,69 @@ def _shutdown(self, reason: Optional[Exception] = None):
for callback in self._close_callbacks:
callback(self, reason)

def _hibernate(self, reason: Optional[Exception] = None):
    """Pause the stream after a retriable, server-initiated disconnect.

    Unlike ``_shutdown``, this stops the background consumer and closes
    the gRPC connection while *preserving* the request/future queue, so
    the stream can later be reopened — either by ``_retry`` replaying the
    queued requests, or by the next user call to ``send``.

    Args:
        reason:
            The exception that terminated the RPC. Only used to fail the
            pending future when ``self._resend`` is disabled.
    """
    # Stop the consumer thread, if one is running.
    if self.is_active:
        _LOGGER.debug("Stopping consumer.")
        self._consumer.stop()
    self._consumer = None

    # Close the RPC connection. Note that ``self._closed`` is
    # deliberately left untouched: the manager is hibernating, not
    # closed, and may be reopened later.
    if self._rpc is not None:
        self._rpc.close()
    _LOGGER.debug("Finished stopping manager.")

    # If we are not going to resend, surface the error on the future of
    # the request that was in flight. With resend enabled, ``_retry``
    # replays the request instead of failing it. Guard against an empty
    # queue so ``get_nowait`` cannot raise ``queue.Empty`` here.
    if not self._resend and not self._queue.empty():
        future = self._queue.get_nowait()[1]
        future.set_exception(reason)

    # Mark the manager as hibernating so ``send`` knows to reopen the
    # stream on the next request.
    self._hibernating = True

def _retry(self):
    """Replay all queued requests over a freshly opened connection.

    Called after a retriable disconnect when ``self._resend`` is
    enabled. Each request still awaiting a response is resent via
    ``send`` (which reopens the RPC because ``self._hibernating`` is
    set), and the queue is rebuilt with the new futures in the original
    order. ``put_into_queue=False`` prevents ``send`` from also
    enqueueing the request, which would duplicate entries.

    NOTE(review): the first resent request goes through ``_open``, which
    itself enqueues ``(request, future)`` into ``self._queue`` — verify
    this does not double-enqueue the first request. Also, callers still
    hold the *old* futures that are discarded here; confirm how their
    completion is propagated.
    """
    replacement_queue = queue.Queue()
    _LOGGER.debug("Retrying %d queued request(s).", self._queue.qsize())

    with self._opening:
        # Drain the old queue, resending each request and pairing it
        # with the future returned for the resent copy.
        while not self._queue.empty():
            request, _ = self._queue.get_nowait()
            new_future = self.send(request, put_into_queue=False)
            replacement_queue.put((request, new_future))

        self._queue = replacement_queue

def _rpc_done_callback(self, reason: Optional[Exception] = None):
    """Dispatch on the cause of an RPC termination.

    A retriable failure (``Aborted``, e.g. the server dropping an idle
    connection) puts the manager into hibernation and, when resending is
    enabled, replays the queued requests. Any other failure shuts the
    manager down permanently.
    """
    # Non-retriable error: tear everything down and stop.
    if not isinstance(reason, exceptions.Aborted):
        self._shutdown(reason)
        return

    # Retriable error: keep the queue, close the connection, and
    # optionally replay the outstanding requests.
    self._hibernate(reason)
    if self._resend:
        self._retry()

def _on_rpc_done(self, future):
"""Triggered whenever the underlying RPC terminates without recovery.

Expand All @@ -358,7 +441,9 @@ def _on_rpc_done(self, future):
_LOGGER.info("RPC termination has signaled streaming pull manager shutdown.")
error = _wrap_as_exception(future)
thread = threading.Thread(
name=_RPC_ERROR_THREAD_NAME, target=self._shutdown, kwargs={"reason": error}
name=_RPC_ERROR_THREAD_NAME,
target=self._rpc_done_callback,
kwargs={"reason": error},
)
thread.daemon = True
thread.start()
Expand Down
Loading