diff --git a/build/bazelutil/check.sh b/build/bazelutil/check.sh index ca6ff332e576..e2703d20ac35 100755 --- a/build/bazelutil/check.sh +++ b/build/bazelutil/check.sh @@ -18,6 +18,7 @@ GIT_GREP="git $CONFIGS grep" EXISTING_GO_GENERATE_COMMENTS=" pkg/config/field.go://go:generate stringer --type=Field --linecomment pkg/rpc/context.go://go:generate mockgen -destination=mocks_generated_test.go --package=. Dialbacker +pkg/rpc/stream_pool.go://go:generate mockgen -destination=mocks_generated_test.go --package=. BatchStreamClient pkg/roachprod/vm/aws/config.go://go:generate terraformgen -o terraform/main.tf pkg/roachprod/prometheus/prometheus.go://go:generate mockgen -package=prometheus -destination=mocks_generated_test.go . Cluster pkg/cmd/roachtest/clusterstats/collector.go://go:generate mockgen -package=clusterstats -destination mocks_generated_test.go github.com/cockroachdb/cockroach/pkg/roachprod/prometheus Client diff --git a/docs/generated/settings/settings-for-tenants.txt b/docs/generated/settings/settings-for-tenants.txt index 4e996105888c..d9e3251ff718 100644 --- a/docs/generated/settings/settings-for-tenants.txt +++ b/docs/generated/settings/settings-for-tenants.txt @@ -401,4 +401,4 @@ trace.span_registry.enabled boolean false if set, ongoing traces can be seen at trace.zipkin.collector string the address of a Zipkin instance to receive traces, as :. If no port is specified, 9411 will be used. application ui.database_locality_metadata.enabled boolean true if enabled shows extended locality data about databases and tables in DB Console which can be expensive to compute application ui.display_timezone enumeration etc/utc the timezone used to format timestamps in the ui [etc/utc = 0, america/new_york = 1] application -version version 1000024.3-upgrading-to-1000025.1-step-008 set the active cluster version in the format '.' application +version version 1000024.3-upgrading-to-1000025.1-step-010 set the active cluster version in the format '.' application diff --git a/docs/generated/settings/settings.html b/docs/generated/settings/settings.html index f4f277f43c8b..3c7afe299303 100644 --- a/docs/generated/settings/settings.html +++ b/docs/generated/settings/settings.html @@ -360,6 +360,6 @@
trace.zipkin.collector
stringthe address of a Zipkin instance to receive traces, as <host>:<port>. If no port is specified, 9411 will be used.Serverless/Dedicated/Self-Hosted
ui.database_locality_metadata.enabled
booleantrueif enabled shows extended locality data about databases and tables in DB Console which can be expensive to computeServerless/Dedicated/Self-Hosted
ui.display_timezone
enumerationetc/utcthe timezone used to format timestamps in the ui [etc/utc = 0, america/new_york = 1]Serverless/Dedicated/Self-Hosted -
version
version1000024.3-upgrading-to-1000025.1-step-008set the active cluster version in the format '<major>.<minor>'Serverless/Dedicated/Self-Hosted +
version
version1000024.3-upgrading-to-1000025.1-step-010set the active cluster version in the format '<major>.<minor>'Serverless/Dedicated/Self-Hosted diff --git a/pkg/clusterversion/cockroach_versions.go b/pkg/clusterversion/cockroach_versions.go index 4c41adb0251a..cbb01bc889cc 100644 --- a/pkg/clusterversion/cockroach_versions.go +++ b/pkg/clusterversion/cockroach_versions.go @@ -198,6 +198,10 @@ const ( // range-ID local key, which is written below raft. V25_1_AddRangeForceFlushKey + // V25_1_BatchStreamRPC adds the BatchStream RPC, which allows for more + // efficient Batch unary RPCs. + V25_1_BatchStreamRPC + // ************************************************* // Step (1) Add new versions above this comment. // Do not add new versions to a patch release. @@ -240,6 +244,7 @@ var versionTable = [numKeys]roachpb.Version{ V25_1_AddJobsTables: {Major: 24, Minor: 3, Internal: 4}, V25_1_MoveRaftTruncatedState: {Major: 24, Minor: 3, Internal: 6}, V25_1_AddRangeForceFlushKey: {Major: 24, Minor: 3, Internal: 8}, + V25_1_BatchStreamRPC: {Major: 24, Minor: 3, Internal: 10}, // ************************************************* // Step (2): Add new versions above this comment. diff --git a/pkg/cmd/roachtest/roachtestutil/validation_check.go b/pkg/cmd/roachtest/roachtestutil/validation_check.go index d2433f6c0eff..cb24489fb21a 100644 --- a/pkg/cmd/roachtest/roachtestutil/validation_check.go +++ b/pkg/cmd/roachtest/roachtestutil/validation_check.go @@ -128,7 +128,11 @@ func CheckInvalidDescriptors(ctx context.Context, db *gosql.DB) error { // validateTokensReturned ensures that all RACv2 tokens are returned to the pool // at the end of the test. func ValidateTokensReturned( - ctx context.Context, t test.Test, c cluster.Cluster, nodes option.NodeListOption, + ctx context.Context, + t test.Test, + c cluster.Cluster, + nodes option.NodeListOption, + waitTime time.Duration, ) { t.L().Printf("validating all tokens returned") for _, node := range nodes { @@ -163,10 +167,10 @@ func ValidateTokensReturned( } } return nil - // We wait up to 10 minutes for the tokens to be returned. In tests which + // We wait up to waitTime for the tokens to be returned. In tests which // purposefully create a send queue towards a node, the queue may take a // while to drain. The tokens will not be returned until the queue is // empty and there are no inflight requests. - }, 10*time.Minute) + }, waitTime) } } diff --git a/pkg/cmd/roachtest/tests/admission_control_elastic_mixed_version.go b/pkg/cmd/roachtest/tests/admission_control_elastic_mixed_version.go index 1dec1fa84742..26e0089a2f9b 100644 --- a/pkg/cmd/roachtest/tests/admission_control_elastic_mixed_version.go +++ b/pkg/cmd/roachtest/tests/admission_control_elastic_mixed_version.go @@ -133,7 +133,7 @@ func registerElasticWorkloadMixedVersion(r registry.Registry) { mvt.Run() // TODO(pav-kv): also validate that the write throughput was kept under // control, and the foreground traffic was not starved. 
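// Illustrative sketch, not part of this diff: ValidateTokensReturned now takes
// an explicit waitTime rather than the previous hard-coded 10 minutes, so each
// caller decides how long to wait for RACv2 token return. The call sites
// updated in this diff pass time.Minute and a computed tokenReturnTime; a
// hypothetical additional caller would look like:
//
//	// Allow up to five minutes for send queues to drain before failing.
//	roachtestutil.ValidateTokensReturned(ctx, t, c, c.CRDBNodes(), 5*time.Minute)
//
// The 5*time.Minute value above is only an example, not a recommendation from
// this change.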
- roachtestutil.ValidateTokensReturned(ctx, t, c, c.CRDBNodes()) + roachtestutil.ValidateTokensReturned(ctx, t, c, c.CRDBNodes(), time.Minute) }, }) } diff --git a/pkg/cmd/roachtest/tests/perturbation/framework.go b/pkg/cmd/roachtest/tests/perturbation/framework.go index e1c5023eab20..410c7e8b67f7 100644 --- a/pkg/cmd/roachtest/tests/perturbation/framework.go +++ b/pkg/cmd/roachtest/tests/perturbation/framework.go @@ -669,7 +669,14 @@ func (v variations) runTest(ctx context.Context, t test.Test, c cluster.Cluster) t.L().Printf("validating stats after the perturbation") failures = append(failures, isAcceptableChange(t.L(), baselineStats, afterStats, v.acceptableChange)...) require.True(t, len(failures) == 0, strings.Join(failures, "\n")) - roachtestutil.ValidateTokensReturned(ctx, t, v, v.stableNodes()) + // TODO(baptist): Look at the time for token return in actual tests to + // determine if this can be lowered further. + tokenReturnTime := 10 * time.Minute + // TODO(#137017): Increase the return time if disk bandwidth limit is set. + if v.diskBandwidthLimit != "0" { + tokenReturnTime = 1 * time.Hour + } + roachtestutil.ValidateTokensReturned(ctx, t, v, v.stableNodes(), tokenReturnTime) } func (v variations) applyClusterSettings(ctx context.Context, t test.Test) { diff --git a/pkg/crosscluster/logical/udf_row_processor_test.go b/pkg/crosscluster/logical/udf_row_processor_test.go index 8a3403d28679..391552e69552 100644 --- a/pkg/crosscluster/logical/udf_row_processor_test.go +++ b/pkg/crosscluster/logical/udf_row_processor_test.go @@ -188,7 +188,8 @@ func TestUDFPreviousValue(t *testing.T) { runnerA.Exec(t, "UPDATE tallies SET v = 15 WHERE pk = 1") WaitUntilReplicatedTime(t, s.Clock().Now(), runnerB, jobBID) - runnerB.CheckQueryResults(t, "SELECT * FROM tallies", [][]string{ - {"1", "25"}, + // At-least-once delivery means it should be at least 25 (might be 30/35/etc). + runnerB.CheckQueryResults(t, "SELECT v >= 25 FROM tallies", [][]string{ + {"true"}, }) } diff --git a/pkg/kv/kvclient/kvcoord/dist_sender_circuit_breaker_test.go b/pkg/kv/kvclient/kvcoord/dist_sender_circuit_breaker_test.go index f81441c51739..eeb056c13c67 100644 --- a/pkg/kv/kvclient/kvcoord/dist_sender_circuit_breaker_test.go +++ b/pkg/kv/kvclient/kvcoord/dist_sender_circuit_breaker_test.go @@ -45,7 +45,7 @@ func TestDistSenderReplicaStall(t *testing.T) { // The lease won't move unless we use expiration-based leases. We also // speed up the test by reducing various intervals and timeouts. st := cluster.MakeTestingClusterSettings() - kvserver.ExpirationLeasesOnly.Override(ctx, &st.SV, true) + kvserver.OverrideDefaultLeaseType(ctx, &st.SV, roachpb.LeaseExpiration) kvcoord.CircuitBreakersMode.Override( ctx, &st.SV, kvcoord.DistSenderCircuitBreakersAllRanges, ) diff --git a/pkg/kv/kvclient/kvcoord/transport_test.go b/pkg/kv/kvclient/kvcoord/transport_test.go index ddb370236689..92f57021a680 100644 --- a/pkg/kv/kvclient/kvcoord/transport_test.go +++ b/pkg/kv/kvclient/kvcoord/transport_test.go @@ -255,6 +255,12 @@ func (m *mockInternalClient) Batch( return br, nil } +func (m *mockInternalClient) BatchStream( + ctx context.Context, opts ...grpc.CallOption, +) (kvpb.Internal_BatchStreamClient, error) { + return nil, fmt.Errorf("unsupported BatchStream call") +} + // RangeLookup implements the kvpb.InternalClient interface. 
func (m *mockInternalClient) RangeLookup( ctx context.Context, rl *kvpb.RangeLookupRequest, _ ...grpc.CallOption, diff --git a/pkg/kv/kvclient/kvtenant/connector_test.go b/pkg/kv/kvclient/kvtenant/connector_test.go index 22ed7e472ee4..a4264229846c 100644 --- a/pkg/kv/kvclient/kvtenant/connector_test.go +++ b/pkg/kv/kvclient/kvtenant/connector_test.go @@ -121,6 +121,10 @@ func (*mockServer) Batch(context.Context, *kvpb.BatchRequest) (*kvpb.BatchRespon panic("unimplemented") } +func (m *mockServer) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + panic("implement me") +} + func (m *mockServer) MuxRangeFeed(server kvpb.Internal_MuxRangeFeedServer) error { panic("implement me") } diff --git a/pkg/kv/kvpb/api.proto b/pkg/kv/kvpb/api.proto index eeaf95b84bba..befe0824b75a 100644 --- a/pkg/kv/kvpb/api.proto +++ b/pkg/kv/kvpb/api.proto @@ -3668,23 +3668,30 @@ message JoinNodeResponse { // Batch and RangeFeed service implemented by nodes for KV API requests. service Internal { - rpc Batch (BatchRequest) returns (BatchResponse) {} + rpc Batch (BatchRequest) returns (BatchResponse) {} + + // BatchStream is a streaming variant of Batch. There is a 1:1 correspondence + // between requests and responses. The method is used to facilitate pooling of + // gRPC streams to avoid the overhead of creating and discarding a new stream + // for each unary Batch RPC invocation. See rpc.BatchStreamPool. + rpc BatchStream (stream BatchRequest) returns (stream BatchResponse) {} + rpc RangeLookup (RangeLookupRequest) returns (RangeLookupResponse) {} - rpc MuxRangeFeed (stream RangeFeedRequest) returns (stream MuxRangeFeedEvent) {} + rpc MuxRangeFeed (stream RangeFeedRequest) returns (stream MuxRangeFeedEvent) {} rpc GossipSubscription (GossipSubscriptionRequest) returns (stream GossipSubscriptionEvent) {} rpc ResetQuorum (ResetQuorumRequest) returns (ResetQuorumResponse) {} // TokenBucket is used by tenants to obtain Request Units and report // consumption. - rpc TokenBucket (TokenBucketRequest) returns (TokenBucketResponse) {} + rpc TokenBucket (TokenBucketRequest) returns (TokenBucketResponse) {} // Join a bootstrapped cluster. If the target node is itself not part of a // bootstrapped cluster, an appropriate error is returned. - rpc Join(JoinNodeRequest) returns (JoinNodeResponse) { } + rpc Join (JoinNodeRequest) returns (JoinNodeResponse) {} // GetSpanConfigs is used to fetch the span configurations over a given // keyspan. - rpc GetSpanConfigs (GetSpanConfigsRequest) returns (GetSpanConfigsResponse) { } + rpc GetSpanConfigs (GetSpanConfigsRequest) returns (GetSpanConfigsResponse) {} // GetAllSystemSpanConfigsThatApply is used to fetch all system span // configurations that apply over a tenant's ranges. @@ -3692,20 +3699,19 @@ service Internal { // UpdateSpanConfigs is used to update the span configurations over given // keyspans. - rpc UpdateSpanConfigs (UpdateSpanConfigsRequest) returns (UpdateSpanConfigsResponse) { } + rpc UpdateSpanConfigs (UpdateSpanConfigsRequest) returns (UpdateSpanConfigsResponse) {} // SpanConfigConformance is used to determine whether ranges backing the given // keyspans conform to span configs that apply over them. - rpc SpanConfigConformance (SpanConfigConformanceRequest) returns (SpanConfigConformanceResponse) { } + rpc SpanConfigConformance (SpanConfigConformanceRequest) returns (SpanConfigConformanceResponse) {} // TenantSettings is used by tenants to obtain and stay up to date with tenant // setting overrides. 
- rpc TenantSettings (TenantSettingsRequest) returns (stream TenantSettingsEvent) { } - + rpc TenantSettings (TenantSettingsRequest) returns (stream TenantSettingsEvent) {} // GetRangeDescriptors is used by tenants to get range descriptors for their // own ranges. - rpc GetRangeDescriptors (GetRangeDescriptorsRequest) returns (stream GetRangeDescriptorsResponse) { } + rpc GetRangeDescriptors (GetRangeDescriptorsRequest) returns (stream GetRangeDescriptorsResponse) {} } // GetRangeDescriptorsRequest is used to fetch range descriptors. diff --git a/pkg/kv/kvpb/kvpbmock/mocks_generated.go b/pkg/kv/kvpb/kvpbmock/mocks_generated.go index fae67f8dd2c1..6f85105a1cc2 100644 --- a/pkg/kv/kvpb/kvpbmock/mocks_generated.go +++ b/pkg/kv/kvpb/kvpbmock/mocks_generated.go @@ -58,6 +58,26 @@ func (mr *MockInternalClientMockRecorder) Batch(arg0, arg1 interface{}, arg2 ... return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Batch", reflect.TypeOf((*MockInternalClient)(nil).Batch), varargs...) } +// BatchStream mocks base method. +func (m *MockInternalClient) BatchStream(arg0 context.Context, arg1 ...grpc.CallOption) (kvpb.Internal_BatchStreamClient, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BatchStream", varargs...) + ret0, _ := ret[0].(kvpb.Internal_BatchStreamClient) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BatchStream indicates an expected call of BatchStream. +func (mr *MockInternalClientMockRecorder) BatchStream(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BatchStream", reflect.TypeOf((*MockInternalClient)(nil).BatchStream), varargs...) +} + // GetAllSystemSpanConfigsThatApply mocks base method. func (m *MockInternalClient) GetAllSystemSpanConfigsThatApply(arg0 context.Context, arg1 *roachpb.GetAllSystemSpanConfigsThatApplyRequest, arg2 ...grpc.CallOption) (*roachpb.GetAllSystemSpanConfigsThatApplyResponse, error) { m.ctrl.T.Helper() diff --git a/pkg/kv/kvserver/consistency_queue_test.go b/pkg/kv/kvserver/consistency_queue_test.go index 797aa3c0a8cf..e8fd03971c6a 100644 --- a/pkg/kv/kvserver/consistency_queue_test.go +++ b/pkg/kv/kvserver/consistency_queue_test.go @@ -606,18 +606,25 @@ func testConsistencyQueueRecomputeStatsImpl(t *testing.T, hadEstimates bool) { t.Fatal(err) } - // Force a run of the consistency queue, otherwise it might take a while. - store := tc.GetFirstStoreFromServer(t, 0) - require.NoError(t, store.ForceConsistencyQueueProcess()) - - // The stats should magically repair themselves. We'll first do a quick check - // and then a full recomputation. - repl, _, err := tc.Servers[0].GetStores().(*kvserver.Stores).GetReplicaForRangeID(ctx, rangeID) - require.NoError(t, err) - ms := repl.GetMVCCStats() - if ms.SysCount >= sysCountGarbage { - t.Fatalf("still have a SysCount of %d", ms.SysCount) - } + // When running with leader leases, it might take an extra election interval + // for a lease to be established after adding the voters above because the + // leader needs to get store liveness support from the followers. The stats + // re-computation runs on the leaseholder and will fail if there isn't one. + testutils.SucceedsSoon(t, func() error { + // Force a run of the consistency queue, otherwise it might take a while. 
+ store := tc.GetFirstStoreFromServer(t, 0) + require.NoError(t, store.ForceConsistencyQueueProcess()) + + // The stats should magically repair themselves. We'll first do a quick check + // and then a full recomputation. + repl, _, err := tc.Servers[0].GetStores().(*kvserver.Stores).GetReplicaForRangeID(ctx, rangeID) + require.NoError(t, err) + ms := repl.GetMVCCStats() + if ms.SysCount >= sysCountGarbage { + return errors.Newf("still have a SysCount of %d", ms.SysCount) + } + return nil + }) if delta := computeDelta(db0); delta != (enginepb.MVCCStats{}) { t.Fatalf("stats still in need of adjustment: %+v", delta) diff --git a/pkg/kv/kvserver/flow_control_integration_test.go b/pkg/kv/kvserver/flow_control_integration_test.go index 924b52b9e710..86d7c2193cf6 100644 --- a/pkg/kv/kvserver/flow_control_integration_test.go +++ b/pkg/kv/kvserver/flow_control_integration_test.go @@ -5676,6 +5676,216 @@ func TestFlowControlSendQueueRangeMigrate(t *testing.T) { h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) } +// TestFlowControlSendQueueRangeSplitMerge exercises the send queue formation, +// prevention and force flushing due to range split and merge operations. See +// the initial comment for an overview of the test structure. +func TestFlowControlSendQueueRangeSplitMerge(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + ctx := context.Background() + const numNodes = 3 + settings := cluster.MakeTestingClusterSettings() + kvflowcontrol.Mode.Override(ctx, &settings.SV, kvflowcontrol.ApplyToAll) + // We want to exhaust tokens but not overload the test, so we set the limits + // lower (8 and 16 MiB default). + kvflowcontrol.ElasticTokensPerStream.Override(ctx, &settings.SV, 2<<20) + kvflowcontrol.RegularTokensPerStream.Override(ctx, &settings.SV, 4<<20) + // TODO(kvoli): There are unexpected messages popping up, which cause a send + // queue to be created on the RHS range post-split. This appears related to + // leader leases, or at least disabling them deflakes the test. We should + // re-enable leader leases, likely by adjusting the test to ignore the 500b + // send queue formation: + // + // r3=(is_state_replicate=true has_send_queue=true send_queue_size=500 B / 1 entries + // [idx_to_send=12 next_raft_idx=13 next_raft_idx_initial=13 force_flush_stop_idx=0]) + // + // See #136258 for more debug info. + kvserver.OverrideDefaultLeaseType(ctx, &settings.SV, roachpb.LeaseEpoch) + disableWorkQueueGrantingServers := make([]atomic.Bool, numNodes) + setTokenReturnEnabled := func(enabled bool, serverIdxs ...int) { + for _, serverIdx := range serverIdxs { + disableWorkQueueGrantingServers[serverIdx].Store(!enabled) + } + } + + argsPerServer := make(map[int]base.TestServerArgs) + for i := range disableWorkQueueGrantingServers { + disableWorkQueueGrantingServers[i].Store(true) + argsPerServer[i] = base.TestServerArgs{ + Settings: settings, + Knobs: base.TestingKnobs{ + Store: &kvserver.StoreTestingKnobs{ + FlowControlTestingKnobs: &kvflowcontrol.TestingKnobs{ + UseOnlyForScratchRanges: true, + OverrideTokenDeduction: func(tokens kvflowcontrol.Tokens) kvflowcontrol.Tokens { + // Deduct every write as 1 MiB, regardless of how large it + // actually is. + return kvflowcontrol.Tokens(1 << 20) + }, + // We want to test the behavior of the send queue, so we want to + // always have up-to-date stats. This ensures that the send queue + // stats are always refreshed on each call to + // RangeController.HandleRaftEventRaftMuLocked.
+ OverrideAlwaysRefreshSendStreamStats: true, + }, + }, + AdmissionControl: &admission.TestingKnobs{ + DisableWorkQueueFastPath: true, + DisableWorkQueueGranting: func() bool { + idx := i + return disableWorkQueueGrantingServers[idx].Load() + }, + }, + }, + } + } + + tc := testcluster.StartTestCluster(t, numNodes, base.TestClusterArgs{ + ReplicationMode: base.ReplicationManual, + ServerArgsPerNode: argsPerServer, + }) + defer tc.Stopper().Stop(ctx) + + k := tc.ScratchRange(t) + tc.AddVotersOrFatal(t, k, tc.Targets(1, 2)...) + + h := newFlowControlTestHelper( + t, tc, "flow_control_integration_v2", /* testdata */ + kvflowcontrol.V2EnabledWhenLeaderV2Encoding, true, /* isStatic */ + ) + h.init(kvflowcontrol.ApplyToAll) + defer h.close("send_queue_range_split_merge") + + desc, err := tc.LookupRange(k) + require.NoError(t, err) + h.enableVerboseRaftMsgLoggingForRange(desc.RangeID) + h.enableVerboseRaftMsgLoggingForRange(desc.RangeID + 1) + n1 := sqlutils.MakeSQLRunner(tc.ServerConn(0)) + h.waitForConnectedStreams(ctx, desc.RangeID, 3, 0 /* serverIdx */) + // Reset the token metrics, since a send queue may have instantly + // formed when adding one of the replicas, before being quickly + // drained. + h.resetV2TokenMetrics(ctx) + + // TODO(kvoli): Update this comment to also mention writing to the RHS range, + // once we resolve the subsume wait for application issue. See #136649. + h.comment(` +-- We will exhaust the tokens across all streams while admission is blocked on +-- n3, using a single 4 MiB (deduction, the write itself is small) write. Then, +-- we will write a 1 MiB put to the range, split it, write a 1 MiB put to the +-- LHS range, merge the ranges, and write a 1 MiB put to the merged range. We +-- expect that at each stage where a send queue develops n1->s3, the send queue +-- will be flushed by the range merge and range split range operations.`) + h.comment(` +-- Start by exhausting the tokens from n1->s3 and blocking admission on s3. +-- (Issuing 4x1MiB regular, 3x replicated write that's not admitted on s3.)`) + setTokenReturnEnabled(true /* enabled */, 0, 1) + setTokenReturnEnabled(false /* enabled */, 2) + h.put(ctx, k, 1, admissionpb.NormalPri) + h.put(ctx, k, 1, admissionpb.NormalPri) + h.put(ctx, k, 1, admissionpb.NormalPri) + h.put(ctx, k, 1, admissionpb.NormalPri) + h.waitForTotalTrackedTokens(ctx, desc.RangeID, 4<<20 /* 4 MiB */, 0 /* serverIdx */) + + h.comment(`(Sending 1 MiB put request to pre-split range)`) + h.put(ctx, k, 1, admissionpb.NormalPri) + h.comment(`(Sent 1 MiB put request to pre-split range)`) + + h.waitForTotalTrackedTokens(ctx, desc.RangeID, 4<<20 /* 4 MiB */, 0 /* serverIdx */) + h.waitForAllTokensReturnedForStreamsV2(ctx, 0, /* serverIdx */ + testingMkFlowStream(0), testingMkFlowStream(1)) + h.waitForSendQueueSize(ctx, desc.RangeID, 1<<20 /* expSize 1 MiB */, 0 /* serverIdx */) + + h.comment(` +-- Send queue metrics from n1, n3's send queue should have 1 MiB for s3.`) + h.query(n1, flowSendQueueQueryStr) + h.comment(` +-- Observe the total tracked tokens per-stream on n1, s3's entries will still +-- be tracked here.`) + h.query(n1, ` + SELECT range_id, store_id, crdb_internal.humanize_bytes(total_tracked_tokens::INT8) + FROM crdb_internal.kv_flow_control_handles_v2 +`, "range_id", "store_id", "total_tracked_tokens") + h.comment(` +-- Per-store tokens available from n1, these should reflect the lack of tokens +-- for s3.`) + h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) 
+ + h.comment(`-- (Splitting range.)`) + left, right := tc.SplitRangeOrFatal(t, k.Next()) + h.waitForConnectedStreams(ctx, left.RangeID, 3, 0 /* serverIdx */) + h.waitForConnectedStreams(ctx, right.RangeID, 3, 0 /* serverIdx */) + h.waitForSendQueueSize(ctx, left.RangeID, 0 /* expSize 0 MiB */, 0 /* serverIdx */) + h.waitForSendQueueSize(ctx, right.RangeID, 0 /* expSize 0 MiB */, 0 /* serverIdx */) + + h.comment(`-- Observe the newly split off replica, with its own three streams.`) + h.query(n1, ` + SELECT range_id, count(*) AS streams + FROM crdb_internal.kv_flow_control_handles_v2 +GROUP BY (range_id) +ORDER BY streams DESC; +`, "range_id", "stream_count") + h.comment(` +-- Send queue and flow token metrics from n1, post-split. +-- We expect to see a force flush of the send queue for s3.`) + h.query(n1, flowSendQueueQueryStr) + h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) + + h.comment(`(Sending 1 MiB put request to post-split LHS range)`) + h.put(ctx, roachpb.Key(left.StartKey), 1, admissionpb.NormalPri) + h.comment(`(Sent 1 MiB put request to post-split LHS range)`) + h.waitForAllTokensReturnedForStreamsV2(ctx, 0, /* serverIdx */ + testingMkFlowStream(0), testingMkFlowStream(1)) + + // TODO(kvoli): Uncomment once we resolve the subsume wait for application + // issue. See #136649. + // h.comment(`(Sending 1 MiB put request to post-split RHS range)`) + // h.put(ctx, roachpb.Key(right.StartKey), 1, admissionpb.NormalPri) + // h.comment(`(Sent 1 MiB put request to post-split RHS range)`) + // h.waitForAllTokensReturnedForStreamsV2(ctx, 0, /* serverIdx */ + // testingMkFlowStream(0), testingMkFlowStream(1)) + + h.comment(` +-- Send queue and flow token metrics from n1, post-split and 1 MiB put on +-- each side.`) + h.query(n1, flowSendQueueQueryStr) + h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) + + h.comment(`-- (Merging ranges.)`) + merged := tc.MergeRangesOrFatal(t, left.StartKey.AsRawKey()) + h.waitForConnectedStreams(ctx, merged.RangeID, 3, 0 /* serverIdx */) + h.waitForSendQueueSize(ctx, merged.RangeID, 0 /* expSize 0 MiB */, 0 /* serverIdx */) + + h.comment(` +-- Send queue and flow token metrics from n1, post-split-merge. +-- We expect to see a force flush of the send queue for s3 again.`) + h.query(n1, flowSendQueueQueryStr) + h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) + + h.comment(`(Sending 1 MiB put request to post-split-merge range)`) + h.put(ctx, k, 1, admissionpb.NormalPri) + h.comment(`(Sent 1 MiB put request to post-split-merge range)`) + h.waitForAllTokensReturnedForStreamsV2(ctx, 0, /* serverIdx */ + testingMkFlowStream(0), testingMkFlowStream(1)) + h.waitForSendQueueSize(ctx, merged.RangeID, 1<<20 /* expSize 1 MiB */, 0 /* serverIdx */) + + h.comment(` +-- Send queue and flow token metrics from n1, post-split-merge and 1 MiB put. +-- We expect to see the send queue develop for s3 again.`) + h.query(n1, flowSendQueueQueryStr) + h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) + + h.comment(`-- (Allowing below-raft admission to proceed on [n1,n2,n3].)`) + setTokenReturnEnabled(true /* enabled */, 0, 1, 2) + + h.waitForAllTokensReturned(ctx, 3, 0 /* serverIdx */) + h.comment(` +-- Send queue and flow token metrics from n1, all tokens should be returned.`) + h.query(n1, flowSendQueueQueryStr) + h.query(n1, flowPerStoreTokenQueryStr, flowPerStoreTokenQueryHeaderStrs...) 
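// The h.comment strings and h.query results emitted by this test are what make
// up the new expected-output file added below,
// testdata/flow_control_integration_v2/send_queue_range_split_merge (the name
// passed to h.close above).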
+} + type flowControlTestHelper struct { t testing.TB tc *testcluster.TestCluster @@ -5788,7 +5998,7 @@ func (h *flowControlTestHelper) checkSendQueueSize( h.tc.GetFirstStoreFromServer(h.t, serverIdx).GetReplicaIfExists(rangeID).SendStreamStats(&stats) _, sizeBytes := stats.SumSendQueues() if sizeBytes != expSize { - return errors.Errorf("expected send queue size %d, got %d [%v]", expSize, sizeBytes, stats) + return errors.Errorf("expected send queue size %d, got %d [%v]", expSize, sizeBytes, &stats) } return nil } diff --git a/pkg/kv/kvserver/testdata/flow_control_integration_v2/send_queue_range_split_merge b/pkg/kv/kvserver/testdata/flow_control_integration_v2/send_queue_range_split_merge new file mode 100644 index 000000000000..8c89bc52dd7b --- /dev/null +++ b/pkg/kv/kvserver/testdata/flow_control_integration_v2/send_queue_range_split_merge @@ -0,0 +1,250 @@ +echo +---- +---- +-- We will exhaust the tokens across all streams while admission is blocked on +-- n3, using a single 4 MiB (deduction, the write itself is small) write. Then, +-- we will write a 1 MiB put to the range, split it, write a 1 MiB put to the +-- LHS range, merge the ranges, and write a 1 MiB put to the merged range. We +-- expect that at each stage where a send queue develops n1->s3, the send queue +-- will be flushed by the range merge and range split range operations. + + +-- Start by exhausting the tokens from n1->s3 and blocking admission on s3. +-- (Issuing 4x1MiB regular, 3x replicated write that's not admitted on s3.) + + +(Sending 1 MiB put request to pre-split range) + + +(Sent 1 MiB put request to pre-split range) + + +-- Send queue metrics from n1, n3's send queue should have 1 MiB for s3. +SELECT name, crdb_internal.humanize_bytes(value::INT8) + FROM crdb_internal.node_metrics + WHERE name LIKE '%kvflowcontrol%send_queue%' + AND name != 'kvflowcontrol.send_queue.count' +ORDER BY name ASC; + + kvflowcontrol.send_queue.bytes | 1.0 MiB + kvflowcontrol.send_queue.prevent.count | 0 B + kvflowcontrol.send_queue.scheduled.deducted_bytes | 0 B + kvflowcontrol.send_queue.scheduled.force_flush | 0 B + kvflowcontrol.tokens.send.elastic.deducted.force_flush_send_queue | 0 B + kvflowcontrol.tokens.send.elastic.deducted.prevent_send_queue | 0 B + kvflowcontrol.tokens.send.regular.deducted.prevent_send_queue | 0 B + + +-- Observe the total tracked tokens per-stream on n1, s3's entries will still +-- be tracked here. +SELECT range_id, store_id, crdb_internal.humanize_bytes(total_tracked_tokens::INT8) + FROM crdb_internal.kv_flow_control_handles_v2 + + range_id | store_id | total_tracked_tokens +-----------+----------+----------------------- + 74 | 1 | 0 B + 74 | 2 | 0 B + 74 | 3 | 4.0 MiB + + +-- Per-store tokens available from n1, these should reflect the lack of tokens +-- for s3. +SELECT store_id, + crdb_internal.humanize_bytes(available_eval_regular_tokens), + crdb_internal.humanize_bytes(available_eval_elastic_tokens), + crdb_internal.humanize_bytes(available_send_regular_tokens), + crdb_internal.humanize_bytes(available_send_elastic_tokens) + FROM crdb_internal.kv_flow_controller_v2 + ORDER BY store_id ASC; + + store_id | eval_regular_available | eval_elastic_available | send_regular_available | send_elastic_available +-----------+------------------------+------------------------+------------------------+------------------------- + 1 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 2 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 3 | 0 B | -3.0 MiB | 0 B | -2.0 MiB + + +-- (Splitting range.) 
+ + +-- Observe the newly split off replica, with its own three streams. +SELECT range_id, count(*) AS streams + FROM crdb_internal.kv_flow_control_handles_v2 +GROUP BY (range_id) +ORDER BY streams DESC; + + range_id | stream_count +-----------+--------------- + 74 | 3 + 75 | 3 + + +-- Send queue and flow token metrics from n1, post-split. +-- We expect to see a force flush of the send queue for s3. +SELECT name, crdb_internal.humanize_bytes(value::INT8) + FROM crdb_internal.node_metrics + WHERE name LIKE '%kvflowcontrol%send_queue%' + AND name != 'kvflowcontrol.send_queue.count' +ORDER BY name ASC; + + kvflowcontrol.send_queue.bytes | 0 B + kvflowcontrol.send_queue.prevent.count | 0 B + kvflowcontrol.send_queue.scheduled.deducted_bytes | 0 B + kvflowcontrol.send_queue.scheduled.force_flush | 0 B + kvflowcontrol.tokens.send.elastic.deducted.force_flush_send_queue | 1.0 MiB + kvflowcontrol.tokens.send.elastic.deducted.prevent_send_queue | 0 B + kvflowcontrol.tokens.send.regular.deducted.prevent_send_queue | 0 B +SELECT store_id, + crdb_internal.humanize_bytes(available_eval_regular_tokens), + crdb_internal.humanize_bytes(available_eval_elastic_tokens), + crdb_internal.humanize_bytes(available_send_regular_tokens), + crdb_internal.humanize_bytes(available_send_elastic_tokens) + FROM crdb_internal.kv_flow_controller_v2 + ORDER BY store_id ASC; + + store_id | eval_regular_available | eval_elastic_available | send_regular_available | send_elastic_available +-----------+------------------------+------------------------+------------------------+------------------------- + 1 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 2 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 3 | 0 B | -3.0 MiB | 0 B | -3.0 MiB + + +(Sending 1 MiB put request to post-split LHS range) + + +(Sent 1 MiB put request to post-split LHS range) + + +-- Send queue and flow token metrics from n1, post-split and 1 MiB put on +-- each side. +SELECT name, crdb_internal.humanize_bytes(value::INT8) + FROM crdb_internal.node_metrics + WHERE name LIKE '%kvflowcontrol%send_queue%' + AND name != 'kvflowcontrol.send_queue.count' +ORDER BY name ASC; + + kvflowcontrol.send_queue.bytes | 1.0 MiB + kvflowcontrol.send_queue.prevent.count | 0 B + kvflowcontrol.send_queue.scheduled.deducted_bytes | 0 B + kvflowcontrol.send_queue.scheduled.force_flush | 0 B + kvflowcontrol.tokens.send.elastic.deducted.force_flush_send_queue | 1.0 MiB + kvflowcontrol.tokens.send.elastic.deducted.prevent_send_queue | 0 B + kvflowcontrol.tokens.send.regular.deducted.prevent_send_queue | 0 B +SELECT store_id, + crdb_internal.humanize_bytes(available_eval_regular_tokens), + crdb_internal.humanize_bytes(available_eval_elastic_tokens), + crdb_internal.humanize_bytes(available_send_regular_tokens), + crdb_internal.humanize_bytes(available_send_elastic_tokens) + FROM crdb_internal.kv_flow_controller_v2 + ORDER BY store_id ASC; + + store_id | eval_regular_available | eval_elastic_available | send_regular_available | send_elastic_available +-----------+------------------------+------------------------+------------------------+------------------------- + 1 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 2 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 3 | 0 B | -4.0 MiB | 0 B | -3.0 MiB + + +-- (Merging ranges.) + + +-- Send queue and flow token metrics from n1, post-split-merge. +-- We expect to see a force flush of the send queue for s3 again. 
+SELECT name, crdb_internal.humanize_bytes(value::INT8) + FROM crdb_internal.node_metrics + WHERE name LIKE '%kvflowcontrol%send_queue%' + AND name != 'kvflowcontrol.send_queue.count' +ORDER BY name ASC; + + kvflowcontrol.send_queue.bytes | 0 B + kvflowcontrol.send_queue.prevent.count | 0 B + kvflowcontrol.send_queue.scheduled.deducted_bytes | 0 B + kvflowcontrol.send_queue.scheduled.force_flush | 0 B + kvflowcontrol.tokens.send.elastic.deducted.force_flush_send_queue | 2.0 MiB + kvflowcontrol.tokens.send.elastic.deducted.prevent_send_queue | 0 B + kvflowcontrol.tokens.send.regular.deducted.prevent_send_queue | 0 B +SELECT store_id, + crdb_internal.humanize_bytes(available_eval_regular_tokens), + crdb_internal.humanize_bytes(available_eval_elastic_tokens), + crdb_internal.humanize_bytes(available_send_regular_tokens), + crdb_internal.humanize_bytes(available_send_elastic_tokens) + FROM crdb_internal.kv_flow_controller_v2 + ORDER BY store_id ASC; + + store_id | eval_regular_available | eval_elastic_available | send_regular_available | send_elastic_available +-----------+------------------------+------------------------+------------------------+------------------------- + 1 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 2 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 3 | 0 B | -4.0 MiB | 0 B | -4.0 MiB + + +(Sending 1 MiB put request to post-split-merge range) + + +(Sent 1 MiB put request to post-split-merge range) + + +-- Send queue and flow token metrics from n1, post-split-merge and 1 MiB put. +-- We expect to see the send queue develop for s3 again. +SELECT name, crdb_internal.humanize_bytes(value::INT8) + FROM crdb_internal.node_metrics + WHERE name LIKE '%kvflowcontrol%send_queue%' + AND name != 'kvflowcontrol.send_queue.count' +ORDER BY name ASC; + + kvflowcontrol.send_queue.bytes | 1.0 MiB + kvflowcontrol.send_queue.prevent.count | 0 B + kvflowcontrol.send_queue.scheduled.deducted_bytes | 0 B + kvflowcontrol.send_queue.scheduled.force_flush | 0 B + kvflowcontrol.tokens.send.elastic.deducted.force_flush_send_queue | 2.0 MiB + kvflowcontrol.tokens.send.elastic.deducted.prevent_send_queue | 0 B + kvflowcontrol.tokens.send.regular.deducted.prevent_send_queue | 0 B +SELECT store_id, + crdb_internal.humanize_bytes(available_eval_regular_tokens), + crdb_internal.humanize_bytes(available_eval_elastic_tokens), + crdb_internal.humanize_bytes(available_send_regular_tokens), + crdb_internal.humanize_bytes(available_send_elastic_tokens) + FROM crdb_internal.kv_flow_controller_v2 + ORDER BY store_id ASC; + + store_id | eval_regular_available | eval_elastic_available | send_regular_available | send_elastic_available +-----------+------------------------+------------------------+------------------------+------------------------- + 1 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 2 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 3 | 0 B | -5.0 MiB | 0 B | -4.0 MiB + + +-- (Allowing below-raft admission to proceed on [n1,n2,n3].) + + +-- Send queue and flow token metrics from n1, all tokens should be returned. 
+SELECT name, crdb_internal.humanize_bytes(value::INT8) + FROM crdb_internal.node_metrics + WHERE name LIKE '%kvflowcontrol%send_queue%' + AND name != 'kvflowcontrol.send_queue.count' +ORDER BY name ASC; + + kvflowcontrol.send_queue.bytes | 0 B + kvflowcontrol.send_queue.prevent.count | 0 B + kvflowcontrol.send_queue.scheduled.deducted_bytes | 0 B + kvflowcontrol.send_queue.scheduled.force_flush | 0 B + kvflowcontrol.tokens.send.elastic.deducted.force_flush_send_queue | 2.0 MiB + kvflowcontrol.tokens.send.elastic.deducted.prevent_send_queue | 0 B + kvflowcontrol.tokens.send.regular.deducted.prevent_send_queue | 0 B +SELECT store_id, + crdb_internal.humanize_bytes(available_eval_regular_tokens), + crdb_internal.humanize_bytes(available_eval_elastic_tokens), + crdb_internal.humanize_bytes(available_send_regular_tokens), + crdb_internal.humanize_bytes(available_send_elastic_tokens) + FROM crdb_internal.kv_flow_controller_v2 + ORDER BY store_id ASC; + + store_id | eval_regular_available | eval_elastic_available | send_regular_available | send_elastic_available +-----------+------------------------+------------------------+------------------------+------------------------- + 1 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 2 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB + 3 | 4.0 MiB | 2.0 MiB | 4.0 MiB | 2.0 MiB +---- +---- + +# vim:ft=sql diff --git a/pkg/rpc/BUILD.bazel b/pkg/rpc/BUILD.bazel index ab519ec07e62..30fa97abb325 100644 --- a/pkg/rpc/BUILD.bazel +++ b/pkg/rpc/BUILD.bazel @@ -25,6 +25,7 @@ go_library( "restricted_internal_client.go", "settings.go", "snappy.go", + "stream_pool.go", "tls.go", ], embed = [":rpc_go_proto"], @@ -90,7 +91,10 @@ go_library( gomock( name = "mock_rpc", out = "mocks_generated_test.go", - interfaces = ["Dialbacker"], + interfaces = [ + "BatchStreamClient", + "Dialbacker", + ], library = ":rpc", package = "rpc", self_package = "github.com/cockroachdb/cockroach/pkg/rpc", @@ -116,6 +120,7 @@ go_test( "metrics_test.go", "peer_test.go", "snappy_test.go", + "stream_pool_test.go", "tls_test.go", ":mock_rpc", # keep ], @@ -175,6 +180,7 @@ go_test( "@org_golang_google_grpc//metadata", "@org_golang_google_grpc//peer", "@org_golang_google_grpc//status", + "@org_golang_x_sync//errgroup", ], ) diff --git a/pkg/rpc/auth_tenant.go b/pkg/rpc/auth_tenant.go index c5233fc21b41..d923cf643fa4 100644 --- a/pkg/rpc/auth_tenant.go +++ b/pkg/rpc/auth_tenant.go @@ -55,7 +55,7 @@ func (a tenantAuthorizer) authorize( req interface{}, ) error { switch fullMethod { - case "/cockroach.roachpb.Internal/Batch": + case "/cockroach.roachpb.Internal/Batch", "/cockroach.roachpb.Internal/BatchStream": return a.authBatch(ctx, sv, tenID, req.(*kvpb.BatchRequest)) case "/cockroach.roachpb.Internal/RangeLookup": @@ -63,6 +63,7 @@ func (a tenantAuthorizer) authorize( case "/cockroach.roachpb.Internal/RangeFeed", "/cockroach.roachpb.Internal/MuxRangeFeed": return a.authRangeFeed(tenID, req.(*kvpb.RangeFeedRequest)) + case "/cockroach.roachpb.Internal/GossipSubscription": return a.authGossipSubscription(tenID, req.(*kvpb.GossipSubscriptionRequest)) diff --git a/pkg/rpc/auth_test.go b/pkg/rpc/auth_test.go index b3d90548e946..14c735a7163f 100644 --- a/pkg/rpc/auth_test.go +++ b/pkg/rpc/auth_test.go @@ -572,6 +572,30 @@ func TestTenantAuthRequest(t *testing.T) { expErr: noError, }, }, + "/cockroach.roachpb.Internal/BatchStream": { + { + req: &kvpb.BatchRequest{}, + expErr: `requested key span /Max not fully contained in tenant keyspace /Tenant/1{0-1}`, + }, + { + req: &kvpb.BatchRequest{Requests: makeReqs( + 
makeReq("a", "b"), + )}, + expErr: `requested key span {a-b} not fully contained in tenant keyspace /Tenant/1{0-1}`, + }, + { + req: &kvpb.BatchRequest{Requests: makeReqs( + makeReq(prefix(5, "a"), prefix(5, "b")), + )}, + expErr: `requested key span /Tenant/5{a-b} not fully contained in tenant keyspace /Tenant/1{0-1}`, + }, + { + req: &kvpb.BatchRequest{Requests: makeReqs( + makeReq(prefix(10, "a"), prefix(10, "b")), + )}, + expErr: noError, + }, + }, "/cockroach.roachpb.Internal/RangeLookup": { { req: &kvpb.RangeLookupRequest{}, @@ -1009,7 +1033,7 @@ func TestTenantAuthRequest(t *testing.T) { // cross-read capability and the request is a read, expect no error. if canCrossRead && strings.Contains(tc.expErr, "fully contained") { switch method { - case "/cockroach.roachpb.Internal/Batch": + case "/cockroach.roachpb.Internal/Batch", "/cockroach.roachpb.Internal/BatchStream": if tc.req.(*kvpb.BatchRequest).IsReadOnly() { tc.expErr = noError } diff --git a/pkg/rpc/connection.go b/pkg/rpc/connection.go index decfed37f61c..2adfd4560370 100644 --- a/pkg/rpc/connection.go +++ b/pkg/rpc/connection.go @@ -34,17 +34,26 @@ type Connection struct { // It always has to be signaled eventually, regardless of the stopper // draining, etc, since callers might be blocking on it. connFuture connFuture + // batchStreamPool holds a pool of BatchStreamClient streams established on + // the connection. The pool can be used to avoid the overhead of unary Batch + // RPCs. + // + // The pool is only initialized once the ClientConn is resolved. + batchStreamPool BatchStreamPool } // newConnectionToNodeID makes a Connection for the given node, class, and nontrivial Signal // that should be queried in Connect(). -func newConnectionToNodeID(k peerKey, breakerSignal func() circuit.Signal) *Connection { +func newConnectionToNodeID( + opts *ContextOptions, k peerKey, breakerSignal func() circuit.Signal, +) *Connection { c := &Connection{ breakerSignalFn: breakerSignal, k: k, connFuture: connFuture{ ready: make(chan struct{}), }, + batchStreamPool: makeStreamPool(opts.Stopper, newBatchStream), } return c } @@ -156,6 +165,13 @@ func (c *Connection) Signal() circuit.Signal { return c.breakerSignalFn() } +func (c *Connection) BatchStreamPool() *BatchStreamPool { + if !c.connFuture.Resolved() { + panic("BatchStreamPool called on unresolved connection") + } + return &c.batchStreamPool +} + type connFuture struct { ready chan struct{} cc *grpc.ClientConn diff --git a/pkg/rpc/context_test.go b/pkg/rpc/context_test.go index 9e40ae24a724..d81a8ff5f6d8 100644 --- a/pkg/rpc/context_test.go +++ b/pkg/rpc/context_test.go @@ -310,6 +310,10 @@ func (*internalServer) Batch(context.Context, *kvpb.BatchRequest) (*kvpb.BatchRe return nil, nil } +func (*internalServer) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + panic("unimplemented") +} + func (*internalServer) RangeLookup( context.Context, *kvpb.RangeLookupRequest, ) (*kvpb.RangeLookupResponse, error) { diff --git a/pkg/rpc/mocks_generated_test.go b/pkg/rpc/mocks_generated_test.go index f0255bf145f4..efa2c112a466 100644 --- a/pkg/rpc/mocks_generated_test.go +++ b/pkg/rpc/mocks_generated_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/cockroachdb/cockroach/pkg/rpc (interfaces: Dialbacker) +// Source: github.com/cockroachdb/cockroach/pkg/rpc (interfaces: BatchStreamClient,Dialbacker) // Package rpc is a generated GoMock package. 
package rpc @@ -8,12 +8,65 @@ import ( context "context" reflect "reflect" + kvpb "github.com/cockroachdb/cockroach/pkg/kv/kvpb" roachpb "github.com/cockroachdb/cockroach/pkg/roachpb" rpcpb "github.com/cockroachdb/cockroach/pkg/rpc/rpcpb" gomock "github.com/golang/mock/gomock" grpc "google.golang.org/grpc" ) +// MockBatchStreamClient is a mock of BatchStreamClient interface. +type MockBatchStreamClient struct { + ctrl *gomock.Controller + recorder *MockBatchStreamClientMockRecorder +} + +// MockBatchStreamClientMockRecorder is the mock recorder for MockBatchStreamClient. +type MockBatchStreamClientMockRecorder struct { + mock *MockBatchStreamClient +} + +// NewMockBatchStreamClient creates a new mock instance. +func NewMockBatchStreamClient(ctrl *gomock.Controller) *MockBatchStreamClient { + mock := &MockBatchStreamClient{ctrl: ctrl} + mock.recorder = &MockBatchStreamClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBatchStreamClient) EXPECT() *MockBatchStreamClientMockRecorder { + return m.recorder +} + +// Recv mocks base method. +func (m *MockBatchStreamClient) Recv() (*kvpb.BatchResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Recv") + ret0, _ := ret[0].(*kvpb.BatchResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Recv indicates an expected call of Recv. +func (mr *MockBatchStreamClientMockRecorder) Recv() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Recv", reflect.TypeOf((*MockBatchStreamClient)(nil).Recv)) +} + +// Send mocks base method. +func (m *MockBatchStreamClient) Send(arg0 *kvpb.BatchRequest) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Send", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Send indicates an expected call of Send. +func (mr *MockBatchStreamClientMockRecorder) Send(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockBatchStreamClient)(nil).Send), arg0) +} + // MockDialbacker is a mock of Dialbacker interface. 
type MockDialbacker struct { ctrl *gomock.Controller diff --git a/pkg/rpc/nodedialer/BUILD.bazel b/pkg/rpc/nodedialer/BUILD.bazel index 192eb71c9512..b405795e4d32 100644 --- a/pkg/rpc/nodedialer/BUILD.bazel +++ b/pkg/rpc/nodedialer/BUILD.bazel @@ -7,12 +7,16 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/base", + "//pkg/clusterversion", "//pkg/kv/kvbase", "//pkg/kv/kvpb", "//pkg/roachpb", "//pkg/rpc", + "//pkg/settings", + "//pkg/settings/cluster", "//pkg/util/circuit", "//pkg/util/log", + "//pkg/util/metamorphic", "//pkg/util/stop", "//pkg/util/tracing", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/rpc/nodedialer/nodedialer.go b/pkg/rpc/nodedialer/nodedialer.go index 1715f5c7c3a5..4bebfd4f5dd4 100644 --- a/pkg/rpc/nodedialer/nodedialer.go +++ b/pkg/rpc/nodedialer/nodedialer.go @@ -11,12 +11,16 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/base" + "github.com/cockroachdb/cockroach/pkg/clusterversion" "github.com/cockroachdb/cockroach/pkg/kv/kvbase" "github.com/cockroachdb/cockroach/pkg/kv/kvpb" "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/rpc" + "github.com/cockroachdb/cockroach/pkg/settings" + "github.com/cockroachdb/cockroach/pkg/settings/cluster" circuit2 "github.com/cockroachdb/cockroach/pkg/util/circuit" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/metamorphic" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/tracing" "github.com/cockroachdb/errors" @@ -96,7 +100,8 @@ func (n *Dialer) Dial( err = errors.Wrapf(err, "failed to resolve n%d", nodeID) return nil, err } - return n.dial(ctx, nodeID, addr, locality, true, class) + conn, _, err := n.dial(ctx, nodeID, addr, locality, true, class) + return conn, err } // DialNoBreaker is like Dial, but will not check the circuit breaker before @@ -112,7 +117,8 @@ func (n *Dialer) DialNoBreaker( if err != nil { return nil, err } - return n.dial(ctx, nodeID, addr, locality, false, class) + conn, _, err := n.dial(ctx, nodeID, addr, locality, false, class) + return conn, err } // DialInternalClient is a specialization of DialClass for callers that @@ -141,11 +147,14 @@ func (n *Dialer) DialInternalClient( return nil, errors.Wrap(err, "resolver error") } log.VEventf(ctx, 2, "sending request to %s", addr) - conn, err := n.dial(ctx, nodeID, addr, locality, true, class) + conn, pool, err := n.dial(ctx, nodeID, addr, locality, true, class) if err != nil { return nil, err } - client := kvpb.NewInternalClient(conn) + client := newBaseInternalClient(conn) + if shouldUseBatchStreamPoolClient(ctx, n.rpcContext.Settings) { + client = newBatchStreamPoolClient(pool) + } client = maybeWrapInTracingClient(ctx, client) return client, nil } @@ -160,11 +169,11 @@ func (n *Dialer) dial( locality roachpb.Locality, checkBreaker bool, class rpc.ConnectionClass, -) (_ *grpc.ClientConn, err error) { +) (_ *grpc.ClientConn, _ *rpc.BatchStreamPool, err error) { const ctxWrapMsg = "dial" // Don't trip the breaker if we're already canceled. if ctxErr := ctx.Err(); ctxErr != nil { - return nil, errors.Wrap(ctxErr, ctxWrapMsg) + return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) } rpcConn := n.rpcContext.GRPCDialNode(addr.String(), nodeID, locality, class) connect := rpcConn.Connect @@ -175,13 +184,13 @@ func (n *Dialer) dial( if err != nil { // If we were canceled during the dial, don't trip the breaker. 
if ctxErr := ctx.Err(); ctxErr != nil { - return nil, errors.Wrap(ctxErr, ctxWrapMsg) + return nil, nil, errors.Wrap(ctxErr, ctxWrapMsg) } err = errors.Wrapf(err, "failed to connect to n%d at %v", nodeID, addr) - return nil, err + return nil, nil, err } - - return conn, nil + pool := rpcConn.BatchStreamPool() + return conn, pool, nil } // ConnHealth returns nil if we have an open connection of the request @@ -275,25 +284,108 @@ func (n *Dialer) Latency(nodeID roachpb.NodeID) (time.Duration, error) { return latency, nil } -// TracingInternalClient wraps an InternalClient and fills in trace information -// on Batch RPCs. +// baseInternalClient is a wrapper around a grpc.ClientConn that implements the +// RestrictedInternalClient interface. By calling kvpb.NewInternalClient on each +// RPC invocation, that function can be inlined and the returned internalClient +// object (which itself is just a wrapper) never needs to be allocated on the +// heap. +type baseInternalClient grpc.ClientConn + +func newBaseInternalClient(conn *grpc.ClientConn) rpc.RestrictedInternalClient { + return (*baseInternalClient)(conn) +} + +func (c *baseInternalClient) asConn() *grpc.ClientConn { + return (*grpc.ClientConn)(c) +} + +// Batch implements the RestrictedInternalClient interface. +func (c *baseInternalClient) Batch( + ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, +) (*kvpb.BatchResponse, error) { + return kvpb.NewInternalClient(c.asConn()).Batch(ctx, ba, opts...) +} + +// MuxRangeFeed implements the RestrictedInternalClient interface. +func (c *baseInternalClient) MuxRangeFeed( + ctx context.Context, opts ...grpc.CallOption, +) (kvpb.Internal_MuxRangeFeedClient, error) { + return kvpb.NewInternalClient(c.asConn()).MuxRangeFeed(ctx, opts...) +} + +var batchStreamPoolingEnabled = settings.RegisterBoolSetting( + settings.ApplicationLevel, + "rpc.batch_stream_pool.enabled", + "if true, use pooled gRPC streams to execute Batch RPCs", + metamorphic.ConstantWithTestBool("rpc.batch_stream_pool.enabled", true), +) + +func shouldUseBatchStreamPoolClient(ctx context.Context, st *cluster.Settings) bool { + // NOTE: we use ActiveVersionOrEmpty(ctx).IsActive(...) instead of the more + // common IsActive(ctx, ...) to avoid a fatal error if an RPC is made before + // the cluster version is initialized. + if !st.Version.ActiveVersionOrEmpty(ctx).IsActive(clusterversion.V25_1_BatchStreamRPC) { + return false + } + if !batchStreamPoolingEnabled.Get(&st.SV) { + return false + } + return true +} + +// batchStreamPoolClient is a client that sends Batch RPCs using a pooled +// BatchStream RPC stream. Pooling these streams allows for reuse of gRPC +// resources, as opposed to native unary RPCs, which create a new stream and +// throw it away for each unary request (see grpc.invoke). +type batchStreamPoolClient rpc.BatchStreamPool + +func newBatchStreamPoolClient(pool *rpc.BatchStreamPool) rpc.RestrictedInternalClient { + return (*batchStreamPoolClient)(pool) +} + +func (c *batchStreamPoolClient) asPool() *rpc.BatchStreamPool { + return (*rpc.BatchStreamPool)(c) +} + +// Batch implements the RestrictedInternalClient interface, using the pooled +// streams in the BatchStreamPool to issue the Batch RPC. 
+func (c *batchStreamPoolClient) Batch( + ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, +) (*kvpb.BatchResponse, error) { + if len(opts) > 0 { + return nil, errors.AssertionFailedf("batchStreamPoolClient.Batch does not support CallOptions") + } + return c.asPool().Send(ctx, ba) +} + +// MuxRangeFeed implements the RestrictedInternalClient interface. +func (c *batchStreamPoolClient) MuxRangeFeed( + ctx context.Context, opts ...grpc.CallOption, +) (kvpb.Internal_MuxRangeFeedClient, error) { + return kvpb.NewInternalClient(c.asPool().Conn()).MuxRangeFeed(ctx, opts...) +} + +// tracingInternalClient wraps a RestrictedInternalClient and fills in trace +// information on Batch RPCs. // -// Note that TracingInternalClient is not used to wrap the internalClientAdapter +// Note that tracingInternalClient is not used to wrap the internalClientAdapter // - local RPCs don't need this tracing functionality. -type TracingInternalClient struct { - kvpb.InternalClient +type tracingInternalClient struct { + rpc.RestrictedInternalClient } -func maybeWrapInTracingClient(ctx context.Context, client kvpb.InternalClient) kvpb.InternalClient { +func maybeWrapInTracingClient( + ctx context.Context, client rpc.RestrictedInternalClient, +) rpc.RestrictedInternalClient { sp := tracing.SpanFromContext(ctx) if sp != nil && !sp.IsNoop() { - client = &TracingInternalClient{InternalClient: client} + return &tracingInternalClient{RestrictedInternalClient: client} } return client } // Batch overrides the Batch RPC client method and fills in tracing information. -func (tic *TracingInternalClient) Batch( +func (c *tracingInternalClient) Batch( ctx context.Context, ba *kvpb.BatchRequest, opts ...grpc.CallOption, ) (*kvpb.BatchResponse, error) { sp := tracing.SpanFromContext(ctx) @@ -301,5 +393,5 @@ func (tic *TracingInternalClient) Batch( ba = ba.ShallowCopy() ba.TraceInfo = sp.Meta().ToProto() } - return tic.InternalClient.Batch(ctx, ba, opts...) + return c.RestrictedInternalClient.Batch(ctx, ba, opts...) 
} diff --git a/pkg/rpc/nodedialer/nodedialer_test.go b/pkg/rpc/nodedialer/nodedialer_test.go index aa8298bd8d01..1cc7cf453fa5 100644 --- a/pkg/rpc/nodedialer/nodedialer_test.go +++ b/pkg/rpc/nodedialer/nodedialer_test.go @@ -415,6 +415,10 @@ func (*internalServer) Batch(context.Context, *kvpb.BatchRequest) (*kvpb.BatchRe return nil, nil } +func (*internalServer) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + panic("unimplemented") +} + func (*internalServer) RangeLookup( context.Context, *kvpb.RangeLookupRequest, ) (*kvpb.RangeLookupResponse, error) { @@ -422,7 +426,7 @@ func (*internalServer) RangeLookup( } func (s *internalServer) MuxRangeFeed(server kvpb.Internal_MuxRangeFeedServer) error { - panic("implement me") + panic("unimplemented") } func (*internalServer) GossipSubscription( diff --git a/pkg/rpc/peer.go b/pkg/rpc/peer.go index 7b24db0e7e81..3cc0bb599168 100644 --- a/pkg/rpc/peer.go +++ b/pkg/rpc/peer.go @@ -260,7 +260,7 @@ func (rpcCtx *Context) newPeer(k peerKey, locality roachpb.Locality) *peer { }, }) p.b = b - c := newConnectionToNodeID(k, b.Signal) + c := newConnectionToNodeID(p.opts, k, b.Signal) p.mu.PeerSnap = PeerSnap{c: c} return p @@ -361,7 +361,7 @@ func (p *peer) run(ctx context.Context, report func(error), done func()) { func() { p.mu.Lock() defer p.mu.Unlock() - p.mu.c = newConnectionToNodeID(p.k, p.mu.c.breakerSignalFn) + p.mu.c = newConnectionToNodeID(p.opts, p.k, p.mu.c.breakerSignalFn) }() if p.snap().deleteAfter != 0 { @@ -582,6 +582,11 @@ func (p *peer) onInitialHeartbeatSucceeded( p.ConnectionHeartbeats.Inc(1) // ConnectionFailures is not updated here. + // Bind the connection's stream pool to the active gRPC connection. Do this + // ahead of signaling the connFuture, so that the stream pool is ready for use + // by the time the connFuture is resolved. + p.mu.c.batchStreamPool.Bind(ctx, cc) + // Close the channel last which is helpful for unit tests that // first waitOrDefault for a healthy conn to then check metrics. p.mu.c.connFuture.Resolve(cc, nil /* err */) @@ -703,6 +708,10 @@ func (p *peer) onHeartbeatFailed( err = &netutil.InitialHeartbeatFailedError{WrappedErr: err} ls.c.connFuture.Resolve(nil /* cc */, err) } + + // Close down the stream pool that was bound to this connection. + ls.c.batchStreamPool.Close() + // By convention, we stick to updating breaker before updating peer // to make it easier to write non-flaky tests. report(err) diff --git a/pkg/rpc/stream_pool.go b/pkg/rpc/stream_pool.go new file mode 100644 index 000000000000..20773d714f0a --- /dev/null +++ b/pkg/rpc/stream_pool.go @@ -0,0 +1,333 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package rpc + +import ( + "context" + "io" + "slices" + "time" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/errors" + "google.golang.org/grpc" +) + +// streamClient is a type constraint that is satisfied by a bidirectional gRPC +// client stream. +type streamClient[Req, Resp any] interface { + Send(Req) error + Recv() (Resp, error) +} + +// streamConstructor creates a new gRPC stream client over the provided client +// connection, using the provided call options. 
+type streamConstructor[Req, Resp, Conn any] func( + context.Context, Conn, +) (streamClient[Req, Resp], error) + +type result[Resp any] struct { + resp Resp + err error +} + +// defaultPooledStreamIdleTimeout is the default duration after which a pooled +// stream is considered idle and is closed. The idle timeout is used to ensure +// that stream pools eventually shrink when the load decreases. +const defaultPooledStreamIdleTimeout = 10 * time.Second + +// pooledStream is a wrapper around a grpc.ClientStream that is managed by a +// streamPool. It is responsible for sending a single request and receiving a +// single response on the stream at a time, mimicking the behavior of a gRPC +// unary RPC. However, unlike a unary RPC, the client stream is not discarded +// after a single use. Instead, it is returned to the pool for reuse. +// +// Most of the complexity around this type (e.g. the worker goroutine) comes +// from the need to handle context cancellation while a request is in-flight. +// gRPC streams support context cancellation, but they use the context provided +// to the stream when it was created for its entire lifetime. Meanwhile, we want +// to be able to handle context cancellation on a per-request basis while we +// layer unary RPC semantics on top of a pooled, bidirectional stream. To +// accomplish this, we use a worker goroutine to perform the (blocking) RPC +// function calls (Send and Recv) and let callers in Send wait on the result of +// the RPC call while also listening to their own context for cancellation. If +// the caller's context is canceled, it cancels the stream's context, which in +// turn cancels the RPC call. +// +// A pooledStream is not safe for concurrent use. It is intended to be used by +// only a single caller at a time. Mutual exclusion is coordinated by removing a +// pooledStream from the pool while it is in use. +// +// A pooledStream must only be returned to the pool for reuse after a successful +// Send call. If the Send call fails, the pooledStream must not be reused. +type pooledStream[Req, Resp any, Conn comparable] struct { + pool *streamPool[Req, Resp, Conn] + stream streamClient[Req, Resp] + streamCtx context.Context + streamCancel context.CancelFunc + + reqC chan Req + respC chan result[Resp] +} + +func newPooledStream[Req, Resp any, Conn comparable]( + pool *streamPool[Req, Resp, Conn], + stream streamClient[Req, Resp], + streamCtx context.Context, + streamCancel context.CancelFunc, +) *pooledStream[Req, Resp, Conn] { + return &pooledStream[Req, Resp, Conn]{ + pool: pool, + stream: stream, + streamCtx: streamCtx, + streamCancel: streamCancel, + reqC: make(chan Req), + respC: make(chan result[Resp], 1), + } +} + +func (s *pooledStream[Req, Resp, Conn]) run(ctx context.Context) { + defer s.close() + for s.runOnce(ctx) { + } +} + +func (s *pooledStream[Req, Resp, Conn]) runOnce(ctx context.Context) (loop bool) { + select { + case req := <-s.reqC: + err := s.stream.Send(req) + if err != nil { + // From grpc.ClientStream.SendMsg: + // > On error, SendMsg aborts the stream. + s.respC <- result[Resp]{err: err} + return false + } + resp, err := s.stream.Recv() + if err != nil { + // From grpc.ClientStream.RecvMsg: + // > It returns io.EOF when the stream completes successfully. On any + // > other error, the stream is aborted and the error contains the RPC + // > status. 
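Before runOnce's error handling continues below, here is the shape the pooledStream doc comment above describes: the blocking stream calls live in a dedicated worker goroutine, fed requests over a channel and answering over a buffered response channel, so callers can wait on the response while still honoring their own context. A simplified, self-contained model (local types only, no gRPC; the worker/reply names are illustrative):

package main

import (
	"context"
	"fmt"
	"time"
)

type reply struct {
	resp string
	err  error
}

// worker mimics pooledStream.run/runOnce: it owns the blocking calls and
// exits on an idle timeout or when its context is canceled.
func worker(ctx context.Context, reqC <-chan string, respC chan<- reply, idle time.Duration) {
	for {
		select {
		case req := <-reqC:
			// In the real code this is stream.Send followed by stream.Recv,
			// both of which block until the server responds or the stream dies.
			respC <- reply{resp: "resp for " + req}
		case <-time.After(idle):
			return // idle timeout: give up the stream
		case <-ctx.Done():
			return
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	reqC := make(chan string)
	respC := make(chan reply, 1)
	go worker(ctx, reqC, respC, time.Second)

	reqC <- "batch-1"
	fmt.Println(<-respC)
}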
+ if errors.Is(err, io.EOF) { + log.Errorf(ctx, "stream unexpectedly closed by server: %+v", err) + } + s.respC <- result[Resp]{err: err} + return false + } + s.respC <- result[Resp]{resp: resp} + return true + + case <-time.After(s.pool.idleTimeout): + // Try to remove ourselves from the pool. If we don't find ourselves in the + // pool, someone just grabbed us from the pool and we should keep running. + // If we do find and remove ourselves, we can close the stream and stop + // running. This ensures that callers never encounter spurious stream + // closures due to idle timeouts. + return !s.pool.remove(s) + + case <-ctx.Done(): + return false + } +} + +func (s *pooledStream[Req, Resp, Conn]) close() { + // Make sure the stream's context is canceled to ensure that we clean up + // resources in idle timeout case. + // + // From grpc.ClientConn.NewStream: + // > To ensure resources are not leaked due to the stream returned, one of the + // > following actions must be performed: + // > ... + // > 2. Cancel the context provided. + // > ... + s.streamCancel() + // Try to remove ourselves from the pool, now that we're closed. If we don't + // find ourselves in the pool, someone has already grabbed us from the pool + // and will check whether we are closed before putting us back. + s.pool.remove(s) +} + +// Send sends a request on the pooled stream and returns the response in a unary +// RPC fashion. Context cancellation is respected. +func (s *pooledStream[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) { + var resp result[Resp] + select { + case s.reqC <- req: + // The request was passed to the stream's worker goroutine, which will + // invoke the RPC function calls (Send and Recv). Wait for a response. + select { + case resp = <-s.respC: + // Return the response. + case <-ctx.Done(): + // Cancel the stream and return the request's context error. + s.streamCancel() + resp.err = ctx.Err() + } + case <-s.streamCtx.Done(): + // The stream was closed before its worker goroutine could accept the + // request. Return the stream's context error. + resp.err = s.streamCtx.Err() + } + + if resp.err != nil { + // On error, wait until we see the streamCtx.Done() signal, to ensure that + // the stream has been cleaned up and won't be placed back in the pool by + // putIfNotClosed. + <-s.streamCtx.Done() + } + return resp.resp, resp.err +} + +// streamPool is a pool of grpc.ClientStream objects (wrapped in pooledStream) +// that are used to send requests and receive corresponding responses in a +// manner that mimics unary RPC invocation. Pooling these streams allows for +// reuse of gRPC resources across calls, as opposed to native unary RPCs, which +// create a new stream and throw it away for each request (see grpc.invoke). +type streamPool[Req, Resp any, Conn comparable] struct { + stopper *stop.Stopper + idleTimeout time.Duration + newStream streamConstructor[Req, Resp, Conn] + + // cc and ccCtx are set on bind, when the gRPC connection is established. + cc Conn + // Derived from rpc.Context.MasterCtx, canceled on stopper quiesce. 
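Send, shown just above, is where per-request cancellation is layered on top of the worker loop: the caller races the worker's response against its own context, and on cancellation it cancels the stream's context (tearing the whole stream down) rather than trying to abandon a single in-flight message. A simplified, self-contained model of that select; the send function and channel names are stand-ins for the real fields, and the error-path wait is condensed:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// send models pooledStream.Send: hand the request to the worker, then wait on
// either the worker's response or the caller's context. On cancellation, the
// stream's own context is canceled so the blocked Send/Recv unwinds too, and
// the caller waits for that teardown so the stream is never pooled again.
func send(
	ctx context.Context,
	reqC chan<- string, respC <-chan string,
	streamCtx context.Context, streamCancel context.CancelFunc,
	req string,
) (string, error) {
	select {
	case reqC <- req:
	case <-streamCtx.Done():
		return "", streamCtx.Err()
	}
	select {
	case resp := <-respC:
		return resp, nil
	case <-ctx.Done():
		streamCancel()     // abort the underlying stream
		<-streamCtx.Done() // wait for teardown before returning
		return "", ctx.Err()
	}
}

func main() {
	streamCtx, streamCancel := context.WithCancel(context.Background())
	reqC := make(chan string, 1)
	respC := make(chan string) // worker never answers in this demo

	reqCtx, cancel := context.WithCancel(context.Background())
	go func() { time.Sleep(10 * time.Millisecond); cancel() }()

	_, err := send(reqCtx, reqC, respC, streamCtx, streamCancel, "batch-1")
	fmt.Println(errors.Is(err, context.Canceled)) // true
}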
+ ccCtx context.Context + + streams struct { + syncutil.Mutex + s []*pooledStream[Req, Resp, Conn] + } +} + +func makeStreamPool[Req, Resp any, Conn comparable]( + stopper *stop.Stopper, newStream streamConstructor[Req, Resp, Conn], +) streamPool[Req, Resp, Conn] { + return streamPool[Req, Resp, Conn]{ + stopper: stopper, + idleTimeout: defaultPooledStreamIdleTimeout, + newStream: newStream, + } +} + +// Bind sets the gRPC connection and context for the streamPool. This must be +// called once before streamPool.Send. +func (p *streamPool[Req, Resp, Conn]) Bind(ctx context.Context, cc Conn) { + p.cc = cc + p.ccCtx = ctx +} + +// Conn returns the gRPC connection bound to the streamPool. +func (p *streamPool[Req, Resp, Conn]) Conn() Conn { + return p.cc +} + +// Close closes all streams in the pool. +func (p *streamPool[Req, Resp, Conn]) Close() { + p.streams.Lock() + defer p.streams.Unlock() + for _, s := range p.streams.s { + s.streamCancel() + } + p.streams.s = nil +} + +func (p *streamPool[Req, Resp, Conn]) get() *pooledStream[Req, Resp, Conn] { + p.streams.Lock() + defer p.streams.Unlock() + if len(p.streams.s) == 0 { + return nil + } + // Pop from the tail to bias towards reusing the same streams repeatedly so + // that streams at the head of the slice are more likely to be closed due to + // idle timeouts. + s := p.streams.s[len(p.streams.s)-1] + p.streams.s[len(p.streams.s)-1] = nil + p.streams.s = p.streams.s[:len(p.streams.s)-1] + return s +} + +func (p *streamPool[Req, Resp, Conn]) putIfNotClosed(s *pooledStream[Req, Resp, Conn]) { + p.streams.Lock() + defer p.streams.Unlock() + if s.streamCtx.Err() != nil { + // The stream is closed, don't put it in the pool. Note that this must be + // done under lock to avoid racing with pooledStream.close, which attempts + // to remove a closing stream from the pool. + return + } + p.streams.s = append(p.streams.s, s) +} + +func (p *streamPool[Req, Resp, Conn]) remove(s *pooledStream[Req, Resp, Conn]) bool { + p.streams.Lock() + defer p.streams.Unlock() + i := slices.Index(p.streams.s, s) + if i == -1 { + return false + } + copy(p.streams.s[i:], p.streams.s[i+1:]) + p.streams.s[len(p.streams.s)-1] = nil + p.streams.s = p.streams.s[:len(p.streams.s)-1] + return true +} + +func (p *streamPool[Req, Resp, Conn]) newPooledStream() (*pooledStream[Req, Resp, Conn], error) { + var zero Conn + if p.cc == zero { + return nil, errors.AssertionFailedf("streamPool not bound to a client conn") + } + + ctx, cancel := context.WithCancel(p.ccCtx) + defer func() { + if cancel != nil { + cancel() + } + }() + + stream, err := p.newStream(ctx, p.cc) + if err != nil { + return nil, err + } + + s := newPooledStream(p, stream, ctx, cancel) + if err := p.stopper.RunAsyncTask(ctx, "pooled gRPC stream", s.run); err != nil { + return nil, err + } + cancel = nil + return s, nil +} + +// Send sends a request on a pooled stream and returns the response in a unary +// RPC fashion. If no stream is available in the pool, a new stream is created. +func (p *streamPool[Req, Resp, Conn]) Send(ctx context.Context, req Req) (Resp, error) { + s := p.get() + if s == nil { + var err error + s, err = p.newPooledStream() + if err != nil { + var zero Resp + return zero, err + } + } + defer p.putIfNotClosed(s) + return s.Send(ctx, req) +} + +// BatchStreamPool is a streamPool specialized for BatchStreamClient streams. +type BatchStreamPool = streamPool[*kvpb.BatchRequest, *kvpb.BatchResponse, *grpc.ClientConn] + +// BatchStreamClient is a streamClient specialized for the BatchStream RPC. 
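The pool bookkeeping that follows is a small mutex-guarded slice: get pops from the tail so hot streams keep getting reused while idle ones age out at the head, and putIfNotClosed only returns a stream that is still healthy. A self-contained sketch of that get/put discipline; the stream type and its closed flag are stand-ins for pooledStream and its streamCtx.Err() check:

package main

import (
	"fmt"
	"sync"
)

type stream struct {
	id     int
	closed bool
}

type pool struct {
	mu sync.Mutex
	s  []*stream
}

// get pops from the tail, biasing reuse toward recently used streams.
func (p *pool) get() *stream {
	p.mu.Lock()
	defer p.mu.Unlock()
	if len(p.s) == 0 {
		return nil
	}
	s := p.s[len(p.s)-1]
	p.s[len(p.s)-1] = nil // avoid retaining the pointer
	p.s = p.s[:len(p.s)-1]
	return s
}

// putIfNotClosed returns a stream to the pool only if it is still usable.
func (p *pool) putIfNotClosed(s *stream) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if s.closed {
		return
	}
	p.s = append(p.s, s)
}

func main() {
	p := &pool{}
	p.putIfNotClosed(&stream{id: 1})
	p.putIfNotClosed(&stream{id: 2})

	s := p.get() // tail pop: stream 2
	fmt.Println(s.id, len(p.s))

	s.closed = true
	p.putIfNotClosed(s) // dropped rather than pooled
	fmt.Println(len(p.s))
}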
+// +//go:generate mockgen -destination=mocks_generated_test.go --package=. BatchStreamClient +type BatchStreamClient = streamClient[*kvpb.BatchRequest, *kvpb.BatchResponse] + +// newBatchStream constructs a BatchStreamClient from a grpc.ClientConn. +func newBatchStream(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) { + return kvpb.NewInternalClient(cc).BatchStream(ctx) +} diff --git a/pkg/rpc/stream_pool_test.go b/pkg/rpc/stream_pool_test.go new file mode 100644 index 000000000000..192bda18b355 --- /dev/null +++ b/pkg/rpc/stream_pool_test.go @@ -0,0 +1,416 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package rpc + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/stop" + "github.com/cockroachdb/errors" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" +) + +type mockBatchStreamConstructor struct { + stream BatchStreamClient + streamErr error + streamCount int + lastStreamCtx *context.Context +} + +func (m *mockBatchStreamConstructor) newStream( + ctx context.Context, conn *grpc.ClientConn, +) (BatchStreamClient, error) { + m.streamCount++ + if m.lastStreamCtx != nil { + if m.streamCount != 1 { + panic("unexpected stream creation with non-nil lastStreamCtx") + } + *m.lastStreamCtx = ctx + } + return m.stream, m.streamErr +} + +func TestStreamPool_Basic(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(2) + stream.EXPECT().Recv().Return(&kvpb.BatchResponse{}, nil).Times(2) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + defer p.Close() + + // Exercise the pool. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + require.Equal(t, 1, stopper.NumTasks()) + require.Len(t, p.streams.s, 1) + + // Exercise the pool again. Should re-use the same connection. + resp, err = p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + require.Equal(t, 1, stopper.NumTasks()) + require.Len(t, p.streams.s, 1) +} + +func TestStreamPool_Multi(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + const num = 3 + sendC := make(chan struct{}) + recvC := make(chan struct{}) + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT(). + Send(gomock.Any()). + DoAndReturn(func(_ *kvpb.BatchRequest) error { + sendC <- struct{}{} + return nil + }). + Times(num) + stream.EXPECT(). + Recv(). + DoAndReturn(func() (*kvpb.BatchResponse, error) { + <-recvC + return &kvpb.BatchResponse{}, nil + }). 
+ Times(num) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + defer p.Close() + + // Exercise the pool with concurrent requests, pausing between each one to + // wait for its request to have been sent. + var g errgroup.Group + for range num { + g.Go(func() error { + _, err := p.Send(ctx, &kvpb.BatchRequest{}) + return err + }) + <-sendC + } + + // Assert that all requests have been sent and are waiting for responses. The + // pool is empty at this point, as all streams are in use. + require.Equal(t, num, conn.streamCount) + require.Equal(t, num, stopper.NumTasks()) + require.Len(t, p.streams.s, 0) + + // Allow all requests to complete. + for range num { + recvC <- struct{}{} + } + require.NoError(t, g.Wait()) + + // All three streams should be returned to the pool. + require.Len(t, p.streams.s, num) +} + +func TestStreamPool_SendBeforeBind(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Times(0) + stream.EXPECT().Recv().Times(0) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + + // Exercise the pool before it is bound to a gRPC connection. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.Regexp(t, err, "streamPool not bound to a client conn") + require.Equal(t, 0, conn.streamCount) + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_SendError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sendErr := errors.New("test error") + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(sendErr).Times(1) + stream.EXPECT().Recv().Times(0) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the error. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, sendErr) + + // The stream should not be returned to the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_RecvError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + recvErr := errors.New("test error") + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + stream.EXPECT().Recv().Return(nil, recvErr).Times(1) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the error. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, recvErr) + + // The stream should not be returned to the pool. 
+ require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_NewStreamError(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + + streamErr := errors.New("test error") + conn := &mockBatchStreamConstructor{streamErr: streamErr} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the error. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.Nil(t, resp) + require.Error(t, err) + require.ErrorIs(t, err, streamErr) + + // The stream should not be placed in the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_Cancellation(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sendC := make(chan struct{}) + var streamCtx context.Context + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT(). + Send(gomock.Any()). + DoAndReturn(func(_ *kvpb.BatchRequest) error { + sendC <- struct{}{} + return nil + }). + Times(1) + stream.EXPECT(). + Recv(). + DoAndReturn(func() (*kvpb.BatchResponse, error) { + <-streamCtx.Done() + return nil, streamCtx.Err() + }). + Times(1) + conn := &mockBatchStreamConstructor{stream: stream, lastStreamCtx: &streamCtx} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + + // Exercise the pool and observe the request get stuck. + reqCtx, cancel := context.WithCancel(ctx) + type result struct { + resp *kvpb.BatchResponse + err error + } + resC := make(chan result) + go func() { + resp, err := p.Send(reqCtx, &kvpb.BatchRequest{}) + resC <- result{resp: resp, err: err} + }() + <-sendC + select { + case <-resC: + t.Fatal("unexpected result") + case <-time.After(10 * time.Millisecond): + } + + // Cancel the request and observe the result. + cancel() + res := <-resC + require.Nil(t, res.resp) + require.Error(t, res.err) + require.ErrorIs(t, res.err, context.Canceled) + + // The stream should not be returned to the pool. + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_Quiesce(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + stream.EXPECT().Recv().Return(&kvpb.BatchResponse{}, nil).Times(1) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + stopperCtx, _ := stopper.WithCancelOnQuiesce(ctx) + p.Bind(stopperCtx, new(grpc.ClientConn)) + + // Exercise the pool to create a worker goroutine. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + require.Equal(t, 1, stopper.NumTasks()) + require.Len(t, p.streams.s, 1) + + // Stop the stopper, which closes the pool. + stopper.Stop(ctx) + require.Len(t, p.streams.s, 0) +} + +func TestStreamPool_QuiesceDuringSend(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + sendC := make(chan struct{}) + var streamCtx context.Context + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT(). + Send(gomock.Any()). 
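The two quiesce tests (the second continues below) rely on every pooled stream's context being derived from the context passed to Bind, which the tests obtain from stopper.WithCancelOnQuiesce; stopping the stopper therefore cancels every worker transitively. A minimal sketch of that fan-out using plain context.WithCancel, with no stopper involved:

package main

import (
	"context"
	"fmt"
)

func main() {
	// bindCtx models the context handed to Bind; in the tests it comes from
	// stopper.WithCancelOnQuiesce(ctx).
	bindCtx, quiesce := context.WithCancel(context.Background())

	// Each pooled stream derives its own context from bindCtx, mirroring the
	// context.WithCancel(p.ccCtx) call in newPooledStream.
	var streamCtxs []context.Context
	var cancels []context.CancelFunc
	for i := 0; i < 3; i++ {
		ctx, cancel := context.WithCancel(bindCtx)
		streamCtxs = append(streamCtxs, ctx)
		cancels = append(cancels, cancel)
	}
	defer func() {
		for _, cancel := range cancels {
			cancel()
		}
	}()

	// Quiescing cancels bindCtx, which tears down every derived stream context.
	quiesce()
	for i, ctx := range streamCtxs {
		<-ctx.Done()
		fmt.Printf("stream %d context canceled: %v\n", i, ctx.Err())
	}
}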
+ DoAndReturn(func(_ *kvpb.BatchRequest) error { + sendC <- struct{}{} + return nil + }). + Times(1) + stream.EXPECT(). + Recv(). + DoAndReturn(func() (*kvpb.BatchResponse, error) { + <-streamCtx.Done() + return nil, streamCtx.Err() + }). + Times(1) + conn := &mockBatchStreamConstructor{stream: stream, lastStreamCtx: &streamCtx} + p := makeStreamPool(stopper, conn.newStream) + stopperCtx, _ := stopper.WithCancelOnQuiesce(ctx) + p.Bind(stopperCtx, new(grpc.ClientConn)) + + // Exercise the pool and observe the request get stuck. + type result struct { + resp *kvpb.BatchResponse + err error + } + resC := make(chan result) + go func() { + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + resC <- result{resp: resp, err: err} + }() + <-sendC + select { + case <-resC: + t.Fatal("unexpected result") + case <-time.After(10 * time.Millisecond): + } + + // Stop the stopper, which cancels the request and closes the pool. + stopper.Stop(ctx) + require.Len(t, p.streams.s, 0) + res := <-resC + require.Nil(t, res.resp) + require.Error(t, res.err) + require.ErrorIs(t, res.err, context.Canceled) +} + +func TestStreamPool_IdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + stopper := stop.NewStopper() + defer stopper.Stop(ctx) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + stream := NewMockBatchStreamClient(ctrl) + stream.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + stream.EXPECT().Recv().Return(&kvpb.BatchResponse{}, nil).Times(1) + conn := &mockBatchStreamConstructor{stream: stream} + p := makeStreamPool(stopper, conn.newStream) + p.Bind(ctx, new(grpc.ClientConn)) + p.idleTimeout = 10 * time.Millisecond + + // Exercise the pool to create a worker goroutine. + resp, err := p.Send(ctx, &kvpb.BatchRequest{}) + require.NotNil(t, resp) + require.NoError(t, err) + require.Equal(t, 1, conn.streamCount) + + // Eventually the worker should be stopped due to the idle timeout. + testutils.SucceedsSoon(t, func() error { + if stopper.NumTasks() != 0 { + return errors.New("worker not stopped") + } + return nil + }) + + // Once the worker is stopped, the pool should be empty. + require.Len(t, p.streams.s, 0) +} diff --git a/pkg/server/node.go b/pkg/server/node.go index bbe86bcd3244..5f5a969990a7 100644 --- a/pkg/server/node.go +++ b/pkg/server/node.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "fmt" + "io" "math" "net" "sort" @@ -1872,6 +1873,31 @@ func (n *Node) Batch(ctx context.Context, args *kvpb.BatchRequest) (*kvpb.BatchR return br, nil } +// BatchStream implements the kvpb.InternalServer interface. +func (n *Node) BatchStream(stream kvpb.Internal_BatchStreamServer) error { + ctx := stream.Context() + for { + args, err := stream.Recv() + if err != nil { + // From grpc.ServerStream.Recv: + // > It returns io.EOF when the client has performed a CloseSend. + if errors.Is(err, io.EOF) { + return nil + } + return err + } + + br, err := n.Batch(ctx, args) + if err != nil { + return err + } + err = stream.Send(br) + if err != nil { + return err + } + } +} + // spanForRequest is the retval of setupSpanForIncomingRPC. It groups together a // few variables needed when finishing an RPC's span. 
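Node.BatchStream above is the entire server side of the new RPC: receive a request, evaluate it through the existing unary Batch path, send the response, and treat io.EOF from Recv as the client's clean CloseSend. A self-contained model of that loop against a toy stream interface; serverStream, memStream, and handle are stand-ins, not CockroachDB types:

package main

import (
	"errors"
	"fmt"
	"io"
)

// serverStream models the Recv/Send half of kvpb.Internal_BatchStreamServer.
type serverStream interface {
	Recv() (string, error)
	Send(string) error
}

// serve mirrors Node.BatchStream: loop until the client closes the stream
// (io.EOF) or an error occurs; every request gets exactly one response.
func serve(stream serverStream, handle func(string) string) error {
	for {
		req, err := stream.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil // client performed CloseSend; clean shutdown
			}
			return err
		}
		if err := stream.Send(handle(req)); err != nil {
			return err
		}
	}
}

// memStream is a toy in-memory stream: a fixed set of requests, then io.EOF.
type memStream struct{ reqs []string }

func (m *memStream) Recv() (string, error) {
	if len(m.reqs) == 0 {
		return "", io.EOF
	}
	req := m.reqs[0]
	m.reqs = m.reqs[1:]
	return req, nil
}

func (m *memStream) Send(resp string) error {
	fmt.Println("sent:", resp)
	return nil
}

func main() {
	err := serve(&memStream{reqs: []string{"a", "b"}}, func(r string) string { return "resp-" + r })
	fmt.Println("serve returned:", err)
}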
// diff --git a/pkg/sql/catalog/lease/BUILD.bazel b/pkg/sql/catalog/lease/BUILD.bazel index 1f0b94d53586..2ef44dc5b0e7 100644 --- a/pkg/sql/catalog/lease/BUILD.bazel +++ b/pkg/sql/catalog/lease/BUILD.bazel @@ -134,6 +134,7 @@ go_test( "//pkg/sql/sqlliveness", "//pkg/sql/sqlliveness/slbase", "//pkg/sql/sqlliveness/slprovider", + "//pkg/sql/stats", "//pkg/sql/types", "//pkg/storage", "//pkg/testutils", diff --git a/pkg/sql/catalog/lease/lease_test.go b/pkg/sql/catalog/lease/lease_test.go index d68819c26bcc..aed3fb0057fc 100644 --- a/pkg/sql/catalog/lease/lease_test.go +++ b/pkg/sql/catalog/lease/lease_test.go @@ -55,6 +55,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sqlinstance/instancestorage" "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slbase" "github.com/cockroachdb/cockroach/pkg/sql/sqlliveness/slprovider" + "github.com/cockroachdb/cockroach/pkg/sql/stats" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" @@ -949,14 +950,16 @@ func TestDescriptorRefreshOnRetry(t *testing.T) { }, }, } + params.Settings = cluster.MakeTestingClusterSettings() + // Disable the automatic stats collection, which could interfere with + // the lease acquisition counts in this test. + stats.AutomaticStatisticsClusterMode.Override(ctx, ¶ms.Settings.SV, false) + // Set a long lease duration so that the periodic task to refresh leases does + // not run. + lease.LeaseDuration.Override(ctx, ¶ms.Settings.SV, 24*time.Hour) srv, sqlDB, kvDB := serverutils.StartServer(t, params) defer srv.Stopper().Stop(context.Background()) s := srv.ApplicationLayer() - // Disable the automatic stats collection, which could interfere with - // the lease acquisition counts in this test. - if _, err := sqlDB.Exec("SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false"); err != nil { - t.Fatal(err) - } if _, err := sqlDB.Exec(` CREATE DATABASE t; CREATE TABLE t.foo (v INT); diff --git a/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go b/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go index 949b27c2e1b9..8fee33b68b6e 100644 --- a/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go +++ b/pkg/util/tracing/grpcinterceptor/grpc_interceptor.go @@ -46,6 +46,9 @@ func setGRPCErrorTag(sp *tracing.Span, err error) { // BatchMethodName is the method name of Internal.Batch RPC. const BatchMethodName = "/cockroach.roachpb.Internal/Batch" +// BatchStreamMethodName is the method name of the Internal.BatchStream RPC. +const BatchStreamMethodName = "/cockroach.roachpb.Internal/BatchStream" + // sendKVBatchMethodName is the method name for adminServer.SendKVBatch. const sendKVBatchMethodName = "/cockroach.server.serverpb.Admin/SendKVBatch" @@ -61,6 +64,7 @@ const flowStreamMethodName = "/cockroach.sql.distsqlrun.DistSQL/FlowStream" // tracing because it's not worth it. func methodExcludedFromTracing(method string) bool { return method == BatchMethodName || + method == BatchStreamMethodName || method == sendKVBatchMethodName || method == SetupFlowMethodName || method == flowStreamMethodName
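BatchStream joins Batch on the interceptor exclusion list: the tracing client wrapper earlier in the patch already attaches trace info to the BatchRequest payload itself, so interceptor-level spans add little for these methods. A rough sketch of how such a method-name filter is typically consulted from a unary client interceptor; the excluded-methods set and the interceptor here are illustrative, not the real grpcinterceptor package (only the two method-name strings are taken from the patch):

package main

import (
	"context"

	"google.golang.org/grpc"
)

var excludedFromTracing = map[string]bool{
	"/cockroach.roachpb.Internal/Batch":       true,
	"/cockroach.roachpb.Internal/BatchStream": true,
}

// clientInterceptor skips span bookkeeping for excluded methods and simply
// forwards the call; for everything else it would normally start a span and
// inject it into the outgoing metadata (elided here).
func clientInterceptor(
	ctx context.Context, method string, req, reply interface{},
	cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption,
) error {
	if excludedFromTracing[method] {
		return invoker(ctx, method, req, reply, cc, opts...)
	}
	// ... start a client span, inject it, finish it after the call ...
	return invoker(ctx, method, req, reply, cc, opts...)
}

func main() {
	// Registered at dial time, e.g.:
	//   grpc.Dial(addr, grpc.WithUnaryInterceptor(clientInterceptor))
	_ = clientInterceptor
}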