diff --git a/third-party/squangle/src/squangle/mysql_client/FetchOperation.cpp b/third-party/squangle/src/squangle/mysql_client/FetchOperation.cpp index 6ad418b5f1f68..6349265dff1f5 100644 --- a/third-party/squangle/src/squangle/mysql_client/FetchOperation.cpp +++ b/third-party/squangle/src/squangle/mysql_client/FetchOperation.cpp @@ -87,6 +87,21 @@ const InternalConnection& FetchOperationImpl::getInternalConnection() const { return conn().getInternalConnection(); } +void FetchOperationImpl::cancel() { + // Free any allocated results before the connection is closed + // We need to do this in the mysql_thread for async versions as the + // mysql_thread _might_ be using that memory + auto cancelFn = [&]() { + current_row_stream_ = folly::none; + OperationBase::cancel(); + }; + if (client_.isInCorrectThread(true)) { + cancelFn(); + } else { + client_.runInThread(std::move(cancelFn), true /*wait*/); + } +} + uint64_t FetchOperationImpl::currentLastInsertId() const { CHECK_THROW(isStreamAccessAllowed(), db::OperationStateException); return current_last_insert_id_; diff --git a/third-party/squangle/src/squangle/mysql_client/FetchOperation.h b/third-party/squangle/src/squangle/mysql_client/FetchOperation.h index bf0715556055e..94f76302c49d7 100644 --- a/third-party/squangle/src/squangle/mysql_client/FetchOperation.h +++ b/third-party/squangle/src/squangle/mysql_client/FetchOperation.h @@ -132,11 +132,7 @@ class FetchOperationImpl : virtual public OperationBase { use_checksum_ = useChecksum; } - void cancel() override { - // Free any allocated results before the connection is closed - current_row_stream_ = folly::none; - OperationBase::cancel(); - } + void cancel() override; uint64_t currentLastInsertId() const; uint64_t currentAffectedRows() const; diff --git a/third-party/squangle/src/squangle/mysql_client/Operation.h b/third-party/squangle/src/squangle/mysql_client/Operation.h index 879db595fb299..01166c1044566 100644 --- a/third-party/squangle/src/squangle/mysql_client/Operation.h +++ b/third-party/squangle/src/squangle/mysql_client/Operation.h @@ -174,7 +174,11 @@ class OperationBase { } OperationState state() const { - return state_; + return state_.load(std::memory_order_relaxed); + } + + bool isCancelling() const { + return state() == OperationState::Cancelling; } void setObserverCallback(ObserverCallback obs_cb); @@ -217,7 +221,7 @@ class OperationBase { void setConnectionContext( std::shared_ptr context) { CHECK_THROW( - state_ == OperationState::Unstarted, db::OperationStateException); + state() == OperationState::Unstarted, db::OperationStateException); connection_context_ = std::move(context); } @@ -454,7 +458,7 @@ class OperationBase { friend class SyncConnection; // Data members; subclasses freely interact with these. - OperationState state_{OperationState::Unstarted}; + std::atomic state_{OperationState::Unstarted}; OperationResult result_{OperationResult::Unknown}; // Connection or query attributes (depending on the Operation type) diff --git a/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.cpp b/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.cpp index 360687435a3bd..ef1cc16c51507 100644 --- a/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.cpp +++ b/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.cpp @@ -24,10 +24,20 @@ bool MysqlFetchOperationImpl::isStreamAccessAllowed() const { return isPaused() || isInEventBaseThread(); } -bool MysqlFetchOperationImpl::isPaused() const { +bool MysqlFetchOperationImpl::isPausedImpl() const { return active_fetch_action_ == FetchAction::WaitForConsumer; } +bool MysqlFetchOperationImpl::isPaused() const { + if (client_.isInCorrectThread(true)) { + return isPausedImpl(); + } + + bool isPaused = false; + client_.runInThread([&]() { isPaused = isPausedImpl(); }, true); + return isPaused; +} + void MysqlFetchOperationImpl::specializedRun() { if (!conn().runInThread([&]() { specializedRunImpl(); })) { completeOperationInner(OperationResult::Failed); @@ -277,7 +287,7 @@ void MysqlFetchOperationImpl::pauseForConsumer() { } void MysqlFetchOperationImpl::resumeImpl() { - CHECK_THROW(isPaused(), db::OperationStateException); + CHECK_THROW(isPausedImpl(), db::OperationStateException); // We should only allow pauses during fetch or between queries. // If we come back as RowsFetched and the stream has completed the query, diff --git a/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.h b/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.h index 699b02221fa69..3e29a03cc8b51 100644 --- a/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.h +++ b/third-party/squangle/src/squangle/mysql_client/mysql_protocol/MysqlFetchOperationImpl.h @@ -49,6 +49,7 @@ class MysqlFetchOperationImpl : public MysqlOperationImpl, private: void resumeImpl(); + bool isPausedImpl() const; // Checks if the current thread has access to stream, or result data. bool isStreamAccessAllowed() const override;