diff --git a/tensorstore/kvstore/test_matchers.h b/tensorstore/kvstore/test_matchers.h index 117487343..a81d0cdac 100644 --- a/tensorstore/kvstore/test_matchers.h +++ b/tensorstore/kvstore/test_matchers.h @@ -15,6 +15,8 @@ #ifndef TENSORSTORE_KVSTORE_TEST_MATCHERS_H_ #define TENSORSTORE_KVSTORE_TEST_MATCHERS_H_ +#include + #include #include diff --git a/tensorstore/kvstore/tsgrpc/BUILD b/tensorstore/kvstore/tsgrpc/BUILD index bb3f88690..4b231452f 100644 --- a/tensorstore/kvstore/tsgrpc/BUILD +++ b/tensorstore/kvstore/tsgrpc/BUILD @@ -71,7 +71,6 @@ tensorstore_cc_library( "//tensorstore/util:result", "@com_github_grpc_grpc//:grpc++", "@com_github_grpc_grpc//:grpc++_public_hdrs", - "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/time", ], @@ -215,6 +214,7 @@ tensorstore_cc_test( "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorstore/kvstore/tsgrpc/handler_template.h b/tensorstore/kvstore/tsgrpc/handler_template.h index 3b9c25448..a0cb89f50 100644 --- a/tensorstore/kvstore/tsgrpc/handler_template.h +++ b/tensorstore/kvstore/tsgrpc/handler_template.h @@ -15,11 +15,8 @@ #ifndef TENSORSTORE_KVSTORE_TSGRPC_HANDLER_TEMPLATE_H_ #define TENSORSTORE_KVSTORE_TSGRPC_HANDLER_TEMPLATE_H_ -#include +#include -#include - -#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "grpcpp/grpcpp.h" // third_party #include "grpcpp/server_context.h" // third_party @@ -37,7 +34,7 @@ class HandlerBase HandlerBase(::grpc::CallbackServerContext* grpc_context) : grpc_context_(grpc_context) { - // This refcount should be adopoted by calling Adopt. + // This refcount should be adopted by calling Adopt. intrusive_ptr_increment(this); } @@ -78,17 +75,18 @@ class Handler : public HandlerBase, public grpc::ServerUnaryReactor { Response* response_; }; -// Handler base class for a stream request. +// Handler base class for an RPC with a streaming response. template -class StreamHandler : public HandlerBase, - public grpc::ServerWriteReactor { +class StreamServerResponseHandler + : public HandlerBase, + public grpc::ServerWriteReactor { public: using Request = RequestProto; using Response = ResponseProto; using Reactor = typename grpc::ServerWriteReactor; - StreamHandler(::grpc::CallbackServerContext* grpc_context, - const Request* request) + StreamServerResponseHandler(::grpc::CallbackServerContext* grpc_context, + const Request* request) : HandlerBase(grpc_context), request_(request) {} using Reactor::Finish; @@ -104,6 +102,33 @@ class StreamHandler : public HandlerBase, const Request* request_; }; +// Handler base class for an RPC with a streaming request. +template +class StreamClientRequestHandler + : public HandlerBase, + public grpc::ServerReadReactor { + public: + using Request = RequestProto; + using Response = ResponseProto; + using Reactor = typename grpc::ServerReadReactor; + + StreamClientRequestHandler(::grpc::CallbackServerContext* grpc_context, + Response* response) + : HandlerBase(grpc_context), response_(response) {} + + using Reactor::Finish; + void Finish(absl::Status status) { + Finish(tensorstore::internal::AbslStatusToGrpcStatus(status)); + } + + Response* response() { return response_; } + + protected: + void OnDone() final { auto adopted = Adopt(); } + + Response* response_; +}; + } // namespace tensorstore_grpc #endif // TENSORSTORE_KVSTORE_TSGRPC_HANDLER_TEMPLATE_H_ diff --git a/tensorstore/kvstore/tsgrpc/kvstore.proto b/tensorstore/kvstore/tsgrpc/kvstore.proto index 19f5746df..82ecec57d 100644 --- a/tensorstore/kvstore/tsgrpc/kvstore.proto +++ b/tensorstore/kvstore/tsgrpc/kvstore.proto @@ -22,10 +22,10 @@ import "tensorstore/kvstore/tsgrpc/common.proto"; // Proto-api for a remote key-value store service KvStoreService { /// Attempts to read the specified key. - rpc Read(ReadRequest) returns (ReadResponse); + rpc Read(ReadRequest) returns (stream ReadResponse); /// Performs an optionally-conditional write. - rpc Write(WriteRequest) returns (WriteResponse); + rpc Write(stream WriteRequest) returns (WriteResponse); /// Performs an optionally-conditional delete. rpc Delete(DeleteRequest) returns (DeleteResponse); @@ -73,6 +73,11 @@ message ReadRequest { } message ReadResponse { + // When multiple ReadResponse messages are received, all messages after the + // initial message are partial responses, and only the value_part field is + // meaningful. In such a case, the value is the catenation of all value_parts + // in order. + /// Optionally, a non-ok status message may be returned. StatusMessage status = 1; @@ -92,14 +97,22 @@ message ReadResponse { /// Indicates a value is present. VALUE = 2; } - State state = 3; - bytes value = 4 [ctype = CORD]; + + // Partial value. Only meaningful when state is VALUE. + // The actaual value is the catenation of all value_part fields in order. + bytes value_part = 4 [ctype = CORD]; } /// See tensorstore/kvstore/operations.h /// kvstore::WriteOptions message WriteRequest { + // When multiple WriteRequest messages are sent, all messages after the + // initial message are partial requests, and only the value_part field + // is meaningful. In such a case, the value is the catenation of all + // value_parts in order. + + /// The key to write. bytes key = 1; /// The write is aborted if the existing generation associated with the @@ -111,7 +124,10 @@ message WriteRequest { /// - The special value of `StorageGeneration::NoValue()` specifies a /// condition that the `key` does not have an existing value. bytes generation_if_equal = 2; - bytes value = 3 [ctype = CORD]; + + // Partial value. + // The actaual value is the catenation of all value_part fields in order. + bytes value_part = 3 [ctype = CORD]; } message WriteResponse { @@ -125,8 +141,12 @@ message WriteResponse { /// See tensorstore/kvstore/operations.h /// kvstore::WriteOptions message DeleteRequest { + // The key to delete. bytes key = 1; KeyRange range = 2; + + /// The delete is aborted if the existing generation associated with the + /// stored `key` does not match `if_equal`. bytes generation_if_equal = 3; } diff --git a/tensorstore/kvstore/tsgrpc/kvstore_server.cc b/tensorstore/kvstore/tsgrpc/kvstore_server.cc index ddf016901..f54ea5c58 100644 --- a/tensorstore/kvstore/tsgrpc/kvstore_server.cc +++ b/tensorstore/kvstore/tsgrpc/kvstore_server.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -74,7 +75,8 @@ using ::tensorstore::internal_metrics::MetricMetadata; using ::tensorstore::kvstore::ListEntry; using ::tensorstore_grpc::EncodeGenerationAndTimestamp; using ::tensorstore_grpc::Handler; -using ::tensorstore_grpc::StreamHandler; +using ::tensorstore_grpc::StreamClientRequestHandler; +using ::tensorstore_grpc::StreamServerResponseHandler; using ::tensorstore_grpc::kvstore::DeleteRequest; using ::tensorstore_grpc::kvstore::DeleteResponse; using ::tensorstore_grpc::kvstore::ListRequest; @@ -108,13 +110,16 @@ auto& list_metric = internal_metrics::Counter::New( ABSL_CONST_INIT internal_log::VerboseFlag verbose_logging("tsgrpc_kvstore"); -class ReadHandler final : public Handler { - using Base = Handler; +constexpr size_t kMaxReadChunkSize = 1 << 20; + +class ReadHandler final + : public StreamServerResponseHandler { + using Base = StreamServerResponseHandler; public: ReadHandler(CallbackServerContext* grpc_context, const Request* request, - Response* response, KvStore kvstore) - : Base(grpc_context, request, response), kvstore_(std::move(kvstore)) {} + KvStore kvstore) + : Base(grpc_context, request), kvstore_(std::move(kvstore)) {} void Run() { ABSL_LOG_IF(INFO, verbose_logging) @@ -142,15 +147,36 @@ class ReadHandler final : public Handler { } internal::IntrusivePtr self{this}; - future_ = - PromiseFuturePair::Link( - [self = std::move(self)](tensorstore::Promise promise, - auto read_result) { - if (!promise.result_needed()) return; - promise.SetResult(self->HandleResult(read_result.result())); - }, - tensorstore::kvstore::Read(kvstore_, request()->key(), options)) - .future; + future_ = tensorstore::kvstore::Read(kvstore_, request()->key(), options); + future_.ExecuteWhenReady( + [self = std::move(self)](ReadyFuture ready) { + self->HandleInitialResult(std::move(ready).result()); + }); + } + + void HandleInitialResult(Result result) { + auto status = result.status(); + if (!status.ok()) { + // Consider setting the status in the response instead. + Finish(status); + return; + } + + auto& r = result.value(); + response_.set_state(static_cast(r.state)); + EncodeGenerationAndTimestamp(r.stamp, &response_); + + value_ = std::move(r.value); + value_offset_ = 0; + + SetNextPart(); + StartWrite(&response_); + } + + void SetNextPart() { + auto next_part = value_.Subcord(value_offset_, kMaxReadChunkSize); + value_offset_ = std::min(value_.size(), value_offset_ + next_part.size()); + response_.set_value_part(std::move(next_part)); } void OnCancel() final { @@ -159,50 +185,75 @@ class ReadHandler final : public Handler { Finish(::grpc::Status(::grpc::StatusCode::CANCELLED, "")); } - absl::Status HandleResult(const Result& result) { - auto status = result.status(); - if (status.ok()) { - auto& r = result.value(); - response()->set_state(static_cast(r.state)); - EncodeGenerationAndTimestamp(r.stamp, response()); - if (r.has_value()) { - response()->set_value(r.value); - } + void OnWriteDone(bool ok) final { + if (!ok) { + // OnDone is going to be called after we return from this method. + Finish(::grpc::Status(::grpc::StatusCode::UNKNOWN, "Write failed")); + return; } - Finish(status); - return status; + if (value_offset_ == value_.size()) { + Finish(::grpc::Status::OK); + return; + } + response_.Clear(); + SetNextPart(); + StartWrite(&response_); } private: KvStore kvstore_; - Future future_; + Future future_; + + ReadResponse response_; + absl::Cord value_; + size_t value_offset_ = 0; }; -class WriteHandler final : public Handler { - using Base = Handler; +class WriteHandler final + : public StreamClientRequestHandler { + using Base = StreamClientRequestHandler; public: - WriteHandler(CallbackServerContext* grpc_context, const Request* request, - Response* response, KvStore kvstore) - : Base(grpc_context, request, response), kvstore_(std::move(kvstore)) {} + WriteHandler(CallbackServerContext* grpc_context, Response* response, + KvStore kvstore) + : Base(grpc_context, response), kvstore_(std::move(kvstore)) {} - void Run() { - ABSL_LOG_IF(INFO, verbose_logging) - << "WriteHandler " << ConciseDebugString(*request()); - tensorstore::kvstore::WriteOptions options{}; - options.generation_conditions.if_equal.value = - request()->generation_if_equal(); + void Run() { StartRead(&request_); } + void OnReadDone(bool ok) override { + if (grpc_context()->IsCancelled()) { + // OnCancelled is going to be called and will invoke Finish, thus just + // stop reading sequence here. + return; + } + if (ok) { + ABSL_LOG_IF(INFO, verbose_logging) + << "WriteHandler " << ConciseDebugString(request_); + // One read chunk has completed, but not everything. + // According to protocol, state, generation and timestamp must be taken + // from the first message in the stream. + if (!first_request_received_) { + first_request_received_ = true; + options_.generation_conditions.if_equal.value = + request_.generation_if_equal(); + key_ = request_.key(); + } + + value_.Append(request_.value_part()); + StartRead(&request_); + return; + } + + ABSL_LOG_IF(INFO, verbose_logging) << "WriteHandler starting write"; + + // Setup the response. internal::IntrusivePtr self{this}; - future_ = - PromiseFuturePair::Link( - [self = std::move(self)](Promise promise, auto write_result) { - if (!promise.result_needed()) return; - promise.SetResult(self->HandleResult(write_result.result())); - }, - kvstore::Write(kvstore_, request()->key(), - absl::Cord(request()->value()), options)) - .future; + future_ = tensorstore::kvstore::Write(kvstore_, key_, value_, options_); + future_.ExecuteWhenReady( + [self = + std::move(self)](ReadyFuture ready) { + self->HandleResult(std::move(ready).result()); + }); } void OnCancel() final { @@ -211,19 +262,24 @@ class WriteHandler final : public Handler { Finish(::grpc::Status(::grpc::StatusCode::CANCELLED, "")); } - absl::Status HandleResult( - const tensorstore::Result& result) { + void HandleResult(tensorstore::Result& result) { auto status = result.status(); if (status.ok()) { EncodeGenerationAndTimestamp(result.value(), response()); } Finish(status); - return status; } private: KvStore kvstore_; - Future future_; + tensorstore::kvstore::WriteOptions options_; + + Future future_; + WriteRequest request_; + + std::string key_; + absl::Cord value_; + bool first_request_received_ = false; }; class DeleteHandler final : public Handler { @@ -294,8 +350,9 @@ class DeleteHandler final : public Handler { tensorstore::Future future_; }; -class ListHandler final : public StreamHandler { - using Base = StreamHandler; +class ListHandler final + : public StreamServerResponseHandler { + using Base = StreamServerResponseHandler; public: ListHandler(CallbackServerContext* grpc_context, const Request* request, @@ -457,12 +514,12 @@ class KvStoreServer::Impl final : public KvStoreService::CallbackService { public: Impl(KvStore kvstore) : kvstore_(std::move(kvstore)) {} - ::grpc::ServerUnaryReactor* Read(::grpc::CallbackServerContext* context, - const ReadRequest* request, - ReadResponse* response) override { + ::grpc::ServerWriteReactor<::tensorstore_grpc::kvstore::ReadResponse>* Read( + ::grpc::CallbackServerContext* context, + const ReadRequest* request) override { read_metric.Increment(); internal::IntrusivePtr handler( - new ReadHandler(context, request, response, kvstore_)); + new ReadHandler(context, request, kvstore_)); assert(handler->use_count() == 2); handler->Run(); assert(handler->use_count() > 0); @@ -470,12 +527,12 @@ class KvStoreServer::Impl final : public KvStoreService::CallbackService { return handler.get(); } - ::grpc::ServerUnaryReactor* Write(::grpc::CallbackServerContext* context, - const WriteRequest* request, - WriteResponse* response) override { + ::grpc::ServerReadReactor<::tensorstore_grpc::kvstore::WriteRequest>* Write( + ::grpc::CallbackServerContext* context, + WriteResponse* response) override { write_metric.Increment(); internal::IntrusivePtr handler( - new WriteHandler(context, request, response, kvstore_)); + new WriteHandler(context, response, kvstore_)); assert(handler->use_count() == 2); handler->Run(); assert(handler->use_count() > 0); @@ -496,7 +553,7 @@ class KvStoreServer::Impl final : public KvStoreService::CallbackService { return handler.get(); } - ::grpc::ServerWriteReactor< ::tensorstore_grpc::kvstore::ListResponse>* List( + ::grpc::ServerWriteReactor<::tensorstore_grpc::kvstore::ListResponse>* List( ::grpc::CallbackServerContext* context, const ListRequest* request) override { list_metric.Increment(); diff --git a/tensorstore/kvstore/tsgrpc/kvstore_server_test.cc b/tensorstore/kvstore/tsgrpc/kvstore_server_test.cc index ef75df93b..2acccf2cb 100644 --- a/tensorstore/kvstore/tsgrpc/kvstore_server_test.cc +++ b/tensorstore/kvstore/tsgrpc/kvstore_server_test.cc @@ -23,6 +23,7 @@ #include "absl/strings/cord.h" #include "absl/strings/str_format.h" #include "absl/synchronization/notification.h" +#include "absl/time/clock.h" #include #include "tensorstore/context.h" #include "tensorstore/kvstore/key_range.h" @@ -43,7 +44,9 @@ namespace { namespace kvstore = ::tensorstore::kvstore; using ::tensorstore::KeyRange; using ::tensorstore::grpc_kvstore::KvStoreServer; +using ::tensorstore::internal::IsRegularStorageGeneration; using ::tensorstore::internal::MatchesKvsReadResultNotFound; +using ::tensorstore::internal::MatchesTimestampedStorageGeneration; class KvStoreSingleton { public: @@ -217,4 +220,35 @@ TEST_F(KvStoreTest, List) { } } +TEST_F(KvStoreTest, MultiPartReadWrite) { + absl::Cord value; + char x = ' '; + while (value.size() < (2 << 20)) { + value.Append(std::string(1 << 12, x)); + x++; + if (static_cast(x) > 126) x = ' '; + } + + auto context = tensorstore::Context::Default(); + TENSORSTORE_ASSERT_OK_AND_ASSIGN( + auto store, tensorstore::kvstore::Open({{"driver", "tsgrpc_kvstore"}, + {"address", address()}, + {"path", "large/"}}, + context) + .result()); + + TENSORSTORE_ASSERT_OK_AND_ASSIGN( + auto generation, kvstore::Write(store, "large_value", value).result()); + EXPECT_THAT(generation.generation, IsRegularStorageGeneration()); + + // memory kvstore driver returns the current time as the timestamp. + auto now = absl::Now(); + TENSORSTORE_ASSERT_OK_AND_ASSIGN( + auto result, kvstore::Read(store, "large_value").result()); + + EXPECT_EQ(value, result.value); + EXPECT_THAT(result.stamp, MatchesTimestampedStorageGeneration( + generation.generation, testing::Ge(now))); +} + } // namespace diff --git a/tensorstore/kvstore/tsgrpc/mock_kvstore_service.h b/tensorstore/kvstore/tsgrpc/mock_kvstore_service.h index 59fd72821..6726fdd01 100644 --- a/tensorstore/kvstore/tsgrpc/mock_kvstore_service.h +++ b/tensorstore/kvstore/tsgrpc/mock_kvstore_service.h @@ -34,10 +34,12 @@ class MockKvStoreService : public kvstore::grpc_gen::KvStoreService::Service { public: using ServiceType = ::tensorstore_grpc::kvstore::grpc_gen::KvStoreService; - TENSORSTORE_GRPC_MOCK(Read, ::tensorstore_grpc::kvstore::ReadRequest, - ::tensorstore_grpc::kvstore::ReadResponse); - TENSORSTORE_GRPC_MOCK(Write, ::tensorstore_grpc::kvstore::WriteRequest, - ::tensorstore_grpc::kvstore::WriteResponse); + TENSORSTORE_GRPC_SERVER_STREAMING_MOCK( + Read, ::tensorstore_grpc::kvstore::ReadRequest, + ::tensorstore_grpc::kvstore::ReadResponse); + TENSORSTORE_GRPC_CLIENT_STREAMING_MOCK( + Write, ::tensorstore_grpc::kvstore::WriteRequest, + ::tensorstore_grpc::kvstore::WriteResponse); TENSORSTORE_GRPC_MOCK(Delete, ::tensorstore_grpc::kvstore::DeleteRequest, ::tensorstore_grpc::kvstore::DeleteResponse); TENSORSTORE_GRPC_SERVER_STREAMING_MOCK( diff --git a/tensorstore/kvstore/tsgrpc/tsgrpc.cc b/tensorstore/kvstore/tsgrpc/tsgrpc.cc index 1da1c40cf..766ce8e5c 100644 --- a/tensorstore/kvstore/tsgrpc/tsgrpc.cc +++ b/tensorstore/kvstore/tsgrpc/tsgrpc.cc @@ -15,8 +15,10 @@ /// \file /// Key-value store proxied over grpc. +#include #include +#include #include #include #include @@ -32,6 +34,7 @@ #include "grpcpp/channel.h" // third_party #include "grpcpp/client_context.h" // third_party #include "grpcpp/create_channel.h" // third_party +#include "grpcpp/support/client_callback.h" // third_party #include "grpcpp/support/status.h" // third_party #include "grpcpp/support/sync_stream.h" // third_party #include "tensorstore/context.h" @@ -113,6 +116,7 @@ auto tsgrpc_metrics = []() -> TsGrpcMetrics { ABSL_CONST_INIT internal_log::VerboseFlag verbose_logging("tsgrpc_kvstore"); +constexpr size_t kMaxWriteChunkSize = 1 << 20; struct TsGrpcKeyValueStoreSpecData { std::string address; @@ -153,6 +157,8 @@ class TsGrpcKeyValueStore : public internal_kvstore::RegisteredDriver { public: + TsGrpcKeyValueStore(const TsGrpcKeyValueStoreSpecData& spec) : spec_(spec) {} + void MaybeSetDeadline(grpc::ClientContext& context) { if (spec_.timeout > absl::ZeroDuration() && spec_.timeout != absl::InfiniteDuration()) { @@ -182,157 +188,218 @@ class TsGrpcKeyValueStore void ListImpl(ListOptions options, ListReceiver receiver) override; - SpecData spec_; + TsGrpcKeyValueStoreSpecData spec_; std::shared_ptr channel_; std::unique_ptr stub_; }; //////////////////////////////////////////////////// -/// Implements `TsGrpcKeyValueStore::Read`. -struct ReadTask : public internal::AtomicReferenceCount { - internal::IntrusivePtr driver; - grpc::ClientContext context; - ReadRequest request; - ReadResponse response; +// Implements TsGrpcKeyValueStore::Read +// TODO: Add retries. +struct ReadTask : public internal::AtomicReferenceCount, + public grpc::ClientReadReactor { + Executor executor_; + Promise promise_; - Future Start(kvstore::Key key, - const kvstore::ReadOptions& options) { - request.set_key(std::move(key)); - request.set_generation_if_equal( - options.generation_conditions.if_equal.value); - request.set_generation_if_not_equal( - options.generation_conditions.if_not_equal.value); - if (!options.byte_range.IsFull()) { - request.mutable_byte_range()->set_inclusive_min( - options.byte_range.inclusive_min); - request.mutable_byte_range()->set_exclusive_max( - options.byte_range.exclusive_max); - } - if (options.staleness_bound != absl::InfiniteFuture()) { - AbslTimeToProto(options.staleness_bound, - request.mutable_staleness_bound()); + // working state. + grpc::ClientContext context_; + kvstore::ReadOptions options_; + ReadRequest request_; + ReadResponse response_; + kvstore::ReadResult result_; + + ReadTask(Executor executor, Promise promise) + : executor_(std::move(executor)), promise_(std::move(promise)) {} + + void TryCancel() { context_.TryCancel(); } + + void Start(KvStoreService::StubInterface* stub) { + intrusive_ptr_increment(this); // adopted in OnDone. + stub->async()->Read(&context_, &request_, this); + + StartRead(&response_); + StartCall(); + } + + void OnReadDone(bool ok) override { + if (!ok) return; + if (!promise_.result_needed()) { + TryCancel(); + return; } - driver->MaybeSetDeadline(context); + auto status = [&]() -> absl::Status { + if (auto status = GetMessageStatus(response_); !status.ok()) { + return status; + } - internal::IntrusivePtr self(this); - auto pair = tensorstore::PromiseFuturePair::Make(); - pair.promise.ExecuteWhenNotNeeded([self] { self->context.TryCancel(); }); + if (result_.value.empty()) { + auto stamp = DecodeGenerationAndTimestamp(response_); + if (!stamp.ok()) { + return std::move(stamp).status(); + } + result_.stamp = std::move(stamp).value(); + result_.state = + static_cast(response_.state()); + } - driver->stub()->async()->Read( - &context, &request, &response, - WithExecutor(driver->executor(), [self = std::move(self), - promise = std::move(pair.promise)]( - ::grpc::Status s) { - if (!promise.result_needed()) return; - promise.SetResult(self->Ready(GrpcStatusToAbslStatus(s))); - })); + result_.value.Append(response_.value_part()); + StartRead(&response_); + return absl::OkStatus(); + }(); - return std::move(pair.future); + if (!status.ok()) { + promise_.SetResult(std::move(status)); + TryCancel(); + } } - Result Ready(absl::Status status) { + void OnDone(const grpc::Status& s) override { + internal::IntrusivePtr self(this, internal::adopt_object_ref); + executor_([self = std::move(self), status = s]() { + self->ReadFinished(GrpcStatusToAbslStatus(status)); + }); + } + + void ReadFinished(absl::Status status) { + // Streaming read complete. + if (!promise_.result_needed()) { + return; + } ABSL_LOG_IF(INFO, verbose_logging) - << "ReadTask::Ready " << ConciseDebugString(response) << " " << status; + << "ReadTask::ReadFinished " << ConciseDebugString(response_) << " " + << status; - TENSORSTORE_RETURN_IF_ERROR(status); - TENSORSTORE_RETURN_IF_ERROR(GetMessageStatus(response)); - TENSORSTORE_ASSIGN_OR_RETURN(auto stamp, - DecodeGenerationAndTimestamp(response)); - return kvstore::ReadResult{ - static_cast(response.state()), - absl::Cord(response.value()), - std::move(stamp), - }; + if (!status.ok()) { + promise_.SetResult(status); + } else { + promise_.SetResult(std::move(result_)); + } } }; -/// Implements `TsGrpcKeyValueStore::Write`. -struct WriteTask : public internal::AtomicReferenceCount { - internal::IntrusivePtr driver; - grpc::ClientContext context; - WriteRequest request; - WriteResponse response; +/// Key value store operations. +Future TsGrpcKeyValueStore::Read(Key key, + ReadOptions options) { + tsgrpc_metrics.read.Increment(); - Future Start( - kvstore::Key key, const absl::Cord value, - const kvstore::WriteOptions& options) { - request.set_key(std::move(key)); - request.set_value(value); - request.set_generation_if_equal( - options.generation_conditions.if_equal.value); + auto pair = PromiseFuturePair::Make(); + + auto task = + internal::MakeIntrusivePtr(executor(), std::move(pair.promise)); + MaybeSetDeadline(task->context_); + + auto& request = task->request_; + request.set_key(std::move(key)); + request.set_generation_if_equal(options.generation_conditions.if_equal.value); + request.set_generation_if_not_equal( + options.generation_conditions.if_not_equal.value); + if (!options.byte_range.IsFull()) { + request.mutable_byte_range()->set_inclusive_min( + options.byte_range.inclusive_min); + request.mutable_byte_range()->set_exclusive_max( + options.byte_range.exclusive_max); + } + if (options.staleness_bound != absl::InfiniteFuture()) { + AbslTimeToProto(options.staleness_bound, request.mutable_staleness_bound()); + } - driver->MaybeSetDeadline(context); + task->Start(stub_.get()); + task->promise_.ExecuteWhenNotNeeded([t = task] { t->TryCancel(); }); + return std::move(pair.future); +} - internal::IntrusivePtr self(this); - auto pair = - tensorstore::PromiseFuturePair::Make(); - pair.promise.ExecuteWhenNotNeeded([self] { self->context.TryCancel(); }); - - driver->stub()->async()->Write( - &context, &request, &response, - WithExecutor(driver->executor(), [self = std::move(self), - promise = std::move(pair.promise)]( - ::grpc::Status s) { - if (!promise.result_needed()) return; - promise.SetResult(self->Ready(GrpcStatusToAbslStatus(s))); - })); - return std::move(pair.future); +////////////////////////////////////////////////////////////////////////// + +// Implements TsGrpcKeyValueStore::Write +// TODO: Add retries. +struct WriteTask : public internal::AtomicReferenceCount, + public grpc::ClientWriteReactor { + Executor executor_; + Promise promise_; + absl::Cord value_; + + // working state. + grpc::ClientContext context_; + WriteRequest request_; + WriteResponse response_; + size_t value_offset_ = 0; + + WriteTask(Executor executor, Promise promise, + absl::Cord value) + : executor_(std::move(executor)), + promise_(std::move(promise)), + value_(std::move(value)) {} + + void UpdateForNextWrite() { + auto next_part = value_.Subcord(value_offset_, kMaxWriteChunkSize); + value_offset_ = std::min(value_.size(), value_offset_ + next_part.size()); + request_.set_value_part(std::move(next_part)); } - Result Ready(absl::Status status) { - ABSL_LOG_IF(INFO, verbose_logging) - << "WriteTask::Ready " << ConciseDebugString(response) << " " << status; - TENSORSTORE_RETURN_IF_ERROR(status); - TENSORSTORE_RETURN_IF_ERROR(GetMessageStatus(response)); - return DecodeGenerationAndTimestamp(response); + void TryCancel() { context_.TryCancel(); } + + void Start(KvStoreService::StubInterface* stub) { + intrusive_ptr_increment(this); // adopted in OnDone. + stub->async()->Write(&context_, &response_, this); + + UpdateForNextWrite(); + + auto options = grpc::WriteOptions(); + if (value_offset_ == value_.size()) { + options.set_last_message(); + } + StartWrite(&request_, options); + StartCall(); } -}; -/// Implements `TsGrpcKeyValueStore::Delete`. -struct DeleteTask : public internal::AtomicReferenceCount { - internal::IntrusivePtr driver; - grpc::ClientContext context; - DeleteRequest request; - DeleteResponse response; + void OnWriteDone(bool ok) override { + // Not streaming any additional data bits. + if (!ok) return; + if (value_offset_ < value_.size()) { + UpdateForNextWrite(); - Future Start( - kvstore::Key key, const kvstore::WriteOptions options) { - request.set_key(std::move(key)); - request.set_generation_if_equal( - options.generation_conditions.if_equal.value); - return StartImpl(); + auto options = grpc::WriteOptions(); + if (value_offset_ == value_.size()) { + options.set_last_message(); + } + StartWrite(&request_, options); + } } - Future StartRange(KeyRange range) { - request.mutable_range()->set_inclusive_min(range.inclusive_min); - request.mutable_range()->set_exclusive_max(range.exclusive_max); - return StartImpl(); + void OnDone(const grpc::Status& s) override { + internal::IntrusivePtr self(this, internal::adopt_object_ref); + executor_([self = std::move(self), status = s]() { + self->WriteFinished(GrpcStatusToAbslStatus(status)); + }); } - Future StartImpl() { - driver->MaybeSetDeadline(context); + void WriteFinished(absl::Status status) { + if (!promise_.result_needed()) { + return; + } + ABSL_LOG_IF(INFO, verbose_logging) + << "WriteTask::WriteFinished " << ConciseDebugString(response_) << " " + << status; - internal::IntrusivePtr self(this); - auto pair = - tensorstore::PromiseFuturePair::Make(); - pair.promise.ExecuteWhenNotNeeded([self] { self->context.TryCancel(); }); - - driver->stub()->async()->Delete( - &context, &request, &response, - WithExecutor(driver->executor(), [self = std::move(self), - promise = std::move(pair.promise)]( - ::grpc::Status s) { - if (!promise.result_needed()) return; - promise.SetResult(self->Ready(GrpcStatusToAbslStatus(s))); - })); - return std::move(pair.future); + promise_.SetResult([&]() -> Result { + TENSORSTORE_RETURN_IF_ERROR(status); + TENSORSTORE_RETURN_IF_ERROR(GetMessageStatus(response_)); + return DecodeGenerationAndTimestamp(response_); + }()); } +}; + +////////////////////////////////////////////////////////////////////////// + +struct DeleteCallbackState { + grpc::ClientContext context; + DeleteResponse response; Result Ready(absl::Status status) { ABSL_LOG_IF(INFO, verbose_logging) - << "DeleteTask::Ready " << ConciseDebugString(response) << " " + << "DeleteCallbackState " << ConciseDebugString(response) << " " << status; TENSORSTORE_RETURN_IF_ERROR(status); TENSORSTORE_RETURN_IF_ERROR(GetMessageStatus(response)); @@ -340,6 +407,82 @@ struct DeleteTask : public internal::AtomicReferenceCount { } }; +Future TsGrpcKeyValueStore::Write( + Key key, std::optional value, WriteOptions options) { + auto pair = PromiseFuturePair::Make(); + + if (!value) { + // empty value is delete. + tsgrpc_metrics.delete_calls.Increment(); + + DeleteRequest request; + request.set_key(std::move(key)); + request.set_generation_if_equal( + options.generation_conditions.if_equal.value); + + auto callback = std::make_shared(); + MaybeSetDeadline(callback->context); + pair.promise.ExecuteWhenNotNeeded( + [callback] { callback->context.TryCancel(); }); + + auto* callback_ptr = callback.get(); + stub()->async()->Delete( + &callback_ptr->context, &request, &callback_ptr->response, + WithExecutor( + executor(), [callback = std::move(callback), + promise = std::move(pair.promise)](::grpc::Status s) { + if (!promise.result_needed()) return; + promise.SetResult(callback->Ready(GrpcStatusToAbslStatus(s))); + })); + return std::move(pair.future); + } + + tsgrpc_metrics.write.Increment(); + + auto state = internal::MakeIntrusivePtr( + executor(), std::move(pair.promise), *std::move(value)); + MaybeSetDeadline(state->context_); + + auto& request = state->request_; + request.set_key(std::move(key)); + request.set_generation_if_equal(options.generation_conditions.if_equal.value); + + state->Start(stub_.get()); + state->promise_.ExecuteWhenNotNeeded([state] { state->TryCancel(); }); + return std::move(pair.future); +} + +Future TsGrpcKeyValueStore::DeleteRange(KeyRange range) { + if (range.empty()) return absl::OkStatus(); + tsgrpc_metrics.delete_range.Increment(); + + DeleteRequest request; + request.mutable_range()->set_inclusive_min(range.inclusive_min); + request.mutable_range()->set_exclusive_max(range.exclusive_max); + + auto callback = std::make_shared(); + MaybeSetDeadline(callback->context); + + auto pair = PromiseFuturePair::Make(absl::OkStatus()); + pair.promise.ExecuteWhenNotNeeded( + [callback] { callback->context.TryCancel(); }); + + auto* callback_ptr = callback.get(); + stub()->async()->Delete( + &callback_ptr->context, &request, &callback_ptr->response, + WithExecutor( + executor(), [callback = std::move(callback), + promise = std::move(pair.promise)](::grpc::Status s) { + if (!promise.result_needed()) return; + if (auto result = callback->Ready(GrpcStatusToAbslStatus(s)); + !result.ok()) { + promise.SetResult(std::move(result).status()); + } + })); + + return std::move(pair.future); +} + // Implements TsGrpcKeyValueStore::List // NOTE: Convert to async(). struct ListTask { @@ -399,46 +542,6 @@ struct ListTask { } }; -/// Key value store operations. -Future TsGrpcKeyValueStore::Read(Key key, - ReadOptions options) { - tsgrpc_metrics.read.Increment(); - auto task = internal::MakeIntrusivePtr(); - task->driver = internal::IntrusivePtr(this); - return task->Start(std::move(key), options); -} - -Future TsGrpcKeyValueStore::Write( - Key key, std::optional value, WriteOptions options) { - if (value) { - tsgrpc_metrics.write.Increment(); - auto task = internal::MakeIntrusivePtr(); - task->driver = internal::IntrusivePtr(this); - return task->Start(std::move(key), value.value(), options); - } else { - // empty value is delete. - tsgrpc_metrics.delete_calls.Increment(); - auto task = internal::MakeIntrusivePtr(); - task->driver = internal::IntrusivePtr(this); - return task->Start(std::move(key), options); - } -} - -Future TsGrpcKeyValueStore::DeleteRange(KeyRange range) { - if (range.empty()) return absl::OkStatus(); - tsgrpc_metrics.delete_range.Increment(); - auto task = internal::MakeIntrusivePtr(); - task->driver = internal::IntrusivePtr(this); - - // Convert Future to Future - return MapFuture( - InlineExecutor{}, - [](const Result& result) { - return MakeResult(result.status()); - }, - task->StartRange(std::move(range))); -} - void TsGrpcKeyValueStore::ListImpl(ListOptions options, ListReceiver receiver) { if (options.range.empty()) { execution::set_starting(receiver, [] {}); @@ -461,15 +564,15 @@ void TsGrpcKeyValueStore::ListImpl(ListOptions options, ListReceiver receiver) { } Future TsGrpcKeyValueStoreSpec::DoOpen() const { - auto driver = internal::MakeIntrusivePtr(); - driver->spec_ = data_; + auto driver = internal::MakeIntrusivePtr(data_); // Create a communication channel with credentials, then use that // to construct a gprc stub. // // TODO: Determine a better mapping to a grpc credentials for this. - // grpc::Credentials ties the authentication to the communication channel - // See: , https://grpc.io/docs/guides/auth/ + // grpc::Credentials ties the authentication to the communication + // channel See: , + // https://grpc.io/docs/guides/auth/ ABSL_LOG_IF(INFO, verbose_logging) << "tsgrpc_kvstore address=" << data_.address; diff --git a/tensorstore/kvstore/tsgrpc/tsgrpc_test.cc b/tensorstore/kvstore/tsgrpc/tsgrpc_test.cc index 8454eafcd..63cd3098a 100644 --- a/tensorstore/kvstore/tsgrpc/tsgrpc_test.cc +++ b/tensorstore/kvstore/tsgrpc/tsgrpc_test.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include #include #include @@ -97,7 +99,7 @@ TEST_F(TsGrpcMockTest, Read) { ReadResponse response = ParseTextProtoOrDie(R"pb( state: 2 - value: '1234' + value_part: '1234' generation_and_timestamp { generation: '1\001' timestamp { seconds: 1634327736 nanos: 123456 } @@ -105,7 +107,12 @@ TEST_F(TsGrpcMockTest, Read) { )pb"); EXPECT_CALL(mock(), Read(_, EqualsProto(expected_request), _)) - .WillOnce(DoAll(SetArgPointee<2>(response), Return(grpc::Status::OK))); + .WillOnce(testing::Invoke( + [=](auto*, auto*, + grpc::ServerWriter* resp) -> ::grpc::Status { + resp->Write(response); + return grpc::Status::OK; + })); kvstore::ReadResult result; { @@ -150,10 +157,57 @@ TEST_F(TsGrpcMockTest, ReadWithOptions) { EXPECT_EQ(result.stamp.generation, StorageGeneration::Unknown()); } +TEST_F(TsGrpcMockTest, ReadMultipart) { + ReadRequest expected_request = ParseTextProtoOrDie(R"pb( + key: 'abc' + )pb"); + + std::vector responses{ + ParseTextProtoOrDie(R"pb( + state: 2 + value_part: '1234' + generation_and_timestamp { + generation: '1\001' + timestamp { seconds: 1634327736 nanos: 123456 } + } + )pb"), + ParseTextProtoOrDie(R"pb( + value_part: '5678' + )pb"), + ParseTextProtoOrDie(R"pb( + value_part: '9012' + )pb"), + }; + + EXPECT_CALL(mock(), Read(_, EqualsProto(expected_request), _)) + .WillOnce(testing::Invoke( + [=](auto*, auto*, + grpc::ServerWriter* resp) -> ::grpc::Status { + for (const auto& response : responses) { + resp->Write(response); + } + return grpc::Status::OK; + })); + + kvstore::ReadResult result; + { + auto store = OpenStore(); + TENSORSTORE_ASSERT_OK_AND_ASSIGN( + result, kvstore::Read(store, expected_request.key()).result()); + } + + // Individual result field verification. + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value, "123456789012"); + EXPECT_EQ(result.stamp.time, + absl::FromUnixSeconds(1634327736) + absl::Nanoseconds(123456)); + EXPECT_EQ(result.stamp.generation, StorageGeneration::FromString("1")); +} + TEST_F(TsGrpcMockTest, Write) { WriteRequest expected_request = ParseTextProtoOrDie(R"pb( key: 'abc' - value: '1234' + value_part: '1234' )pb"); WriteResponse response = ParseTextProtoOrDie(R"pb( @@ -163,15 +217,24 @@ TEST_F(TsGrpcMockTest, Write) { } )pb"); - EXPECT_CALL(mock(), Write(_, EqualsProto(expected_request), _)) - .WillOnce(DoAll(SetArgPointee<2>(response), Return(grpc::Status::OK))); + EXPECT_CALL(mock(), Write(_, _, _)) + .WillOnce( + testing::Invoke([=](auto*, grpc::ServerReader* req, + WriteResponse* resp) -> ::grpc::Status { + WriteRequest actual_request; + EXPECT_TRUE(req->Read(&actual_request)); + EXPECT_THAT(actual_request, EqualsProto(expected_request)); + EXPECT_FALSE(req->Read(&actual_request)); + *resp = response; + return grpc::Status::OK; + })); tensorstore::TimestampedStorageGeneration result; { auto store = OpenStore(); TENSORSTORE_ASSERT_OK_AND_ASSIGN( result, kvstore::Write(store, expected_request.key(), - absl::Cord(expected_request.value())) + absl::Cord(expected_request.value_part())) .result()); } EXPECT_EQ(result.generation, StorageGeneration::FromString("1")); @@ -190,8 +253,17 @@ TEST_F(TsGrpcMockTest, WriteEmpty) { } )pb"); - EXPECT_CALL(mock(), Write(_, EqualsProto(expected_request), _)) - .WillOnce(DoAll(SetArgPointee<2>(response), Return(grpc::Status::OK))); + EXPECT_CALL(mock(), Write(_, _, _)) + .WillOnce( + testing::Invoke([=](auto*, grpc::ServerReader* req, + WriteResponse* resp) -> ::grpc::Status { + WriteRequest actual_request; + EXPECT_TRUE(req->Read(&actual_request)); + EXPECT_THAT(actual_request, EqualsProto(expected_request)); + EXPECT_FALSE(req->Read(&actual_request)); + *resp = response; + return grpc::Status::OK; + })); tensorstore::TimestampedStorageGeneration result; { @@ -207,7 +279,7 @@ TEST_F(TsGrpcMockTest, WriteEmpty) { TEST_F(TsGrpcMockTest, WriteWithOptions) { WriteRequest expected_request = ParseTextProtoOrDie(R"pb( key: 'abc' - value: '1234' + value_part: '1234' generation_if_equal: "abc\001" )pb"); @@ -218,15 +290,24 @@ TEST_F(TsGrpcMockTest, WriteWithOptions) { } )pb"); - EXPECT_CALL(mock(), Write(_, EqualsProto(expected_request), _)) - .WillOnce(DoAll(SetArgPointee<2>(response), Return(grpc::Status::OK))); + EXPECT_CALL(mock(), Write(_, _, _)) + .WillOnce( + testing::Invoke([=](auto*, grpc::ServerReader* req, + WriteResponse* resp) -> ::grpc::Status { + WriteRequest actual_request; + EXPECT_TRUE(req->Read(&actual_request)); + EXPECT_THAT(actual_request, EqualsProto(expected_request)); + EXPECT_FALSE(req->Read(&actual_request)); + *resp = response; + return grpc::Status::OK; + })); tensorstore::TimestampedStorageGeneration result; { auto store = OpenStore(); TENSORSTORE_ASSERT_OK_AND_ASSIGN( result, kvstore::Write(store, expected_request.key(), - absl::Cord(expected_request.value()), + absl::Cord(expected_request.value_part()), {StorageGeneration::FromString("abc")}) .result()); } @@ -260,6 +341,45 @@ TEST_F(TsGrpcMockTest, WriteNullopt) { EXPECT_EQ(result.generation, StorageGeneration::FromString("1")); } +TEST_F(TsGrpcMockTest, WriteMultipart) { + absl::Cord value; + char x = ' '; + while (value.size() < (2 << 20)) { + value.Append(std::string(1 << 12, x)); + x++; + if (static_cast(x) > 126) x = ' '; + }; + + WriteResponse response = ParseTextProtoOrDie(R"pb( + generation_and_timestamp { + generation: '1\001' + timestamp { seconds: 1634327736 nanos: 123456 } + } + )pb"); + + EXPECT_CALL(mock(), Write(_, _, _)) + .WillOnce( + testing::Invoke([=](auto*, grpc::ServerReader* req, + WriteResponse* resp) -> ::grpc::Status { + WriteRequest actual_request; + size_t i = 0; + while (req->Read(&actual_request)) { + i++; + } + EXPECT_EQ(i, 2); + *resp = response; + return grpc::Status::OK; + })); + + tensorstore::TimestampedStorageGeneration result; + { + auto store = OpenStore(); + TENSORSTORE_ASSERT_OK_AND_ASSIGN( + result, kvstore::Write(store, "large_value", value).result()); + } + EXPECT_EQ(result.generation, StorageGeneration::FromString("1")); +} + TEST_F(TsGrpcMockTest, Delete) { DeleteRequest expected_request = ParseTextProtoOrDie(R"pb( key: 'abc'