diff --git a/build/bazelutil/check.sh b/build/bazelutil/check.sh index e2703d20ac35..53b7eee0ed91 100755 --- a/build/bazelutil/check.sh +++ b/build/bazelutil/check.sh @@ -24,6 +24,7 @@ pkg/roachprod/prometheus/prometheus.go://go:generate mockgen -package=prometheus pkg/cmd/roachtest/clusterstats/collector.go://go:generate mockgen -package=clusterstats -destination mocks_generated_test.go github.com/cockroachdb/cockroach/pkg/roachprod/prometheus Client pkg/cmd/roachtest/tests/drt.go://go:generate mockgen -package tests -destination drt_generated_test.go github.com/cockroachdb/cockroach/pkg/roachprod/prometheus Client pkg/kv/kvclient/kvcoord/transport.go://go:generate mockgen -package=kvcoord -destination=mocks_generated_test.go . Transport +pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go://go:generate ../../../util/interval/generic/gen.sh *bufferedWrite kvcoord pkg/kv/kvclient/rangecache/range_cache.go://go:generate mockgen -package=rangecachemock -destination=rangecachemock/mocks_generated.go . RangeDescriptorDB pkg/kv/kvclient/rangefeed/rangefeed.go://go:generate mockgen -destination=mocks_generated_test.go --package=rangefeed . DB pkg/kv/kvserver/concurrency/lock_table.go://go:generate ../../../util/interval/generic/gen.sh *keyLocks concurrency diff --git a/docs/generated/settings/settings-for-tenants.txt b/docs/generated/settings/settings-for-tenants.txt index f2c75846ef9a..a49792b619fe 100644 --- a/docs/generated/settings/settings-for-tenants.txt +++ b/docs/generated/settings/settings-for-tenants.txt @@ -91,6 +91,7 @@ kv.protectedts.reconciliation.interval duration 5m0s the frequency for reconcili kv.rangefeed.client.stream_startup_rate integer 100 controls the rate per second the client will initiate new rangefeed stream for a single range; 0 implies unlimited application kv.rangefeed.closed_timestamp_refresh_interval duration 3s the interval at which closed-timestamp updatesare delivered to rangefeeds; set to 0 to use kv.closed_timestamp.side_transport_interval system-visible kv.rangefeed.enabled boolean false if set, rangefeed registration is enabled system-visible +kv.transaction.buffered_writes.enabled boolean true if enabled, transactional writes are buffered on the gateway application kv.transaction.max_intents_and_locks integer 0 maximum count of inserts or durable locks for a single transactions, 0 to disable application kv.transaction.max_intents_bytes integer 4194304 maximum number of bytes used to track locks in transactions application kv.transaction.max_refresh_spans_bytes integer 4194304 maximum number of bytes used to track refresh spans in serializable transactions application diff --git a/docs/generated/settings/settings.html b/docs/generated/settings/settings.html index fe10f493845c..2032b0581b25 100644 --- a/docs/generated/settings/settings.html +++ b/docs/generated/settings/settings.html @@ -120,6 +120,7 @@
 kv.replication_reports.interval duration 1m0s the frequency for generating the replication_constraint_stats, replication_stats_report and replication_critical_localities reports (set to 0 to disable) Dedicated/Self-Hosted
 kv.snapshot_rebalance.max_rate byte size 32 MiB the rate limit (bytes/sec) to use for rebalance and upreplication snapshots Dedicated/Self-Hosted
 kv.snapshot_receiver.excise.enabled boolean true set to false to disable excises in place of range deletions for KV snapshots Dedicated/Self-Hosted
+kv.transaction.buffered_writes.enabled boolean true if enabled, transactional writes are buffered on the gateway Serverless/Dedicated/Self-Hosted
 kv.transaction.max_intents_and_locks integer 0 maximum count of inserts or durable locks for a single transactions, 0 to disable Serverless/Dedicated/Self-Hosted
 kv.transaction.max_intents_bytes integer 4194304 maximum number of bytes used to track locks in transactions Serverless/Dedicated/Self-Hosted
 kv.transaction.max_refresh_spans_bytes
integer4194304maximum number of bytes used to track refresh spans in serializable transactionsServerless/Dedicated/Self-Hosted diff --git a/pkg/gen/misc.bzl b/pkg/gen/misc.bzl index 8163f27a4bb6..e3470dfd2916 100644 --- a/pkg/gen/misc.bzl +++ b/pkg/gen/misc.bzl @@ -4,6 +4,8 @@ MISC_SRCS = [ "//pkg/backup:data_driven_generated_test.go", "//pkg/ccl/kvccl/kvtenantccl/upgradeinterlockccl:generated_test.go", "//pkg/internal/team:TEAMS.yaml", + "//pkg/kv/kvclient/kvcoord:bufferedwrite_interval_btree.go", + "//pkg/kv/kvclient/kvcoord:bufferedwrite_interval_btree_test.go", "//pkg/kv/kvpb:batch_generated.go", "//pkg/kv/kvserver/concurrency:keylocks_interval_btree.go", "//pkg/kv/kvserver/concurrency:keylocks_interval_btree_test.go", diff --git a/pkg/kv/kvclient/kvcoord/BUILD.bazel b/pkg/kv/kvclient/kvcoord/BUILD.bazel index 8b926c105187..fe80799e530b 100644 --- a/pkg/kv/kvclient/kvcoord/BUILD.bazel +++ b/pkg/kv/kvclient/kvcoord/BUILD.bazel @@ -2,6 +2,7 @@ load("@bazel_gomock//:gomock.bzl", "gomock") load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") load("//build:STRINGER.bzl", "stringer") load("//pkg/testutils:buildutil/buildutil.bzl", "disallowed_imports_test") +load("//pkg/util/interval/generic:gen.bzl", "gen_interval_btree") go_library( name = "kvcoord", @@ -31,8 +32,10 @@ go_library( "txn_interceptor_pipeliner.go", "txn_interceptor_seq_num_allocator.go", "txn_interceptor_span_refresher.go", + "txn_interceptor_write_buffer.go", "txn_lock_gatekeeper.go", "txn_metrics.go", + ":bufferedwrite_interval_btree.go", # keep ":gen-txnstate-stringer", # keep ], importpath = "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvcoord", @@ -152,6 +155,7 @@ go_test( "txn_interceptor_seq_num_allocator_test.go", "txn_interceptor_span_refresher_test.go", "txn_test.go", + ":bufferedwrite_interval_btree.go", # keep ":mock_kvcoord", # keep ], data = glob(["testdata/**"]), @@ -255,6 +259,12 @@ stringer( typ = "txnState", ) +gen_interval_btree( + name = "buffered_write_interval_btree", + package = "kvcoord", + type = "*bufferedWrite", +) + disallowed_imports_test( "kvcoord", disallowed_list = [ diff --git a/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree.go b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree.go new file mode 100644 index 000000000000..64d19ec7e15a --- /dev/null +++ b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree.go @@ -0,0 +1,1170 @@ +// Code generated by go_generics. DO NOT EDIT. + +// Copyright 2020 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package kvcoord + +import ( + "bytes" + "sort" + "strings" + "sync" + "sync/atomic" +) + +// nilT is a nil instance of the Template type. +var nilT *bufferedWrite + +const ( + degree = 16 + maxItems = 2*degree - 1 + minItems = degree - 1 +) + +// compare returns a value indicating the sort order relationship between +// a and b. The comparison is performed lexicographically on +// +// (a.Key(), a.EndKey(), a.ID()) +// +// and +// +// (b.Key(), b.EndKey(), b.ID()) +// +// tuples. 
+// +// Given c = compare(a, b): +// +// c == -1 if (a.Key(), a.EndKey(), a.ID()) < (b.Key(), b.EndKey(), b.ID()) +// c == 0 if (a.Key(), a.EndKey(), a.ID()) == (b.Key(), b.EndKey(), b.ID()) +// c == 1 if (a.Key(), a.EndKey(), a.ID()) > (b.Key(), b.EndKey(), b.ID()) +func compare(a, b *bufferedWrite) int { + c := bytes.Compare(a.Key(), b.Key()) + if c != 0 { + return c + } + c = bytes.Compare(a.EndKey(), b.EndKey()) + if c != 0 { + return c + } + if a.ID() < b.ID() { + return -1 + } else if a.ID() > b.ID() { + return 1 + } else { + return 0 + } +} + +// keyBound represents the upper-bound of a key range. +type keyBound struct { + key []byte + inc bool +} + +func (b keyBound) compare(o keyBound) int { + c := bytes.Compare(b.key, o.key) + if c != 0 { + return c + } + if b.inc == o.inc { + return 0 + } + if b.inc { + return 1 + } + return -1 +} + +func (b keyBound) contains(a *bufferedWrite) bool { + c := bytes.Compare(a.Key(), b.key) + if c == 0 { + return b.inc + } + return c < 0 +} + +func upperBound(c *bufferedWrite) keyBound { + if len(c.EndKey()) != 0 { + return keyBound{key: c.EndKey()} + } + return keyBound{key: c.Key(), inc: true} +} + +type node struct { + ref int32 + count int16 + + // These fields form a keyBound, but by inlining them into node we can avoid + // the extra word that would be needed to pad out maxInc if it were part of + // its own struct. + maxInc bool + maxKey []byte + + items [maxItems]*bufferedWrite + + // The children array pointer is only populated for interior nodes; it is nil + // for leaf nodes. + children *childrenArray +} + +type childrenArray = [maxItems + 1]*node + +var leafPool = sync.Pool{ + New: func() interface{} { + return new(node) + }, +} + +var nodePool = sync.Pool{ + New: func() interface{} { + type interiorNode struct { + node + children childrenArray + } + n := new(interiorNode) + n.node.children = &n.children + return &n.node + }, +} + +func newLeafNode() *node { + n := leafPool.Get().(*node) + n.ref = 1 + return n +} + +func newNode() *node { + n := nodePool.Get().(*node) + n.ref = 1 + return n +} + +// mut creates and returns a mutable node reference. If the node is not shared +// with any other trees then it can be modified in place. Otherwise, it must be +// cloned to ensure unique ownership. In this way, we enforce a copy-on-write +// policy which transparently incorporates the idea of local mutations, like +// Clojure's transients or Haskell's ST monad, where nodes are only copied +// during the first time that they are modified between Clone operations. +// +// When a node is cloned, the provided pointer will be redirected to the new +// mutable node. +func mut(n **node) *node { + if atomic.LoadInt32(&(*n).ref) == 1 { + // Exclusive ownership. Can mutate in place. + return *n + } + // If we do not have unique ownership over the node then we + // clone it to gain unique ownership. After doing so, we can + // release our reference to the old node. We pass recursive + // as true because even though we just observed the node's + // reference count to be greater than 1, we might be racing + // with another call to decRef on this node. + c := (*n).clone() + (*n).decRef(true /* recursive */) + *n = c + return *n +} + +// leaf returns true if this is a leaf node. +func (n *node) leaf() bool { + return n.children == nil +} + +// max returns the maximum keyBound in the subtree rooted at this node. 
+func (n *node) max() keyBound { + return keyBound{ + key: n.maxKey, + inc: n.maxInc, + } +} + +// setMax sets the maximum keyBound for the subtree rooted at this node. +func (n *node) setMax(k keyBound) { + n.maxKey = k.key + n.maxInc = k.inc +} + +// incRef acquires a reference to the node. +func (n *node) incRef() { + atomic.AddInt32(&n.ref, 1) +} + +// decRef releases a reference to the node. If requested, the method +// will recurse into child nodes and decrease their refcounts as well. +func (n *node) decRef(recursive bool) { + if atomic.AddInt32(&n.ref, -1) > 0 { + // Other references remain. Can't free. + return + } + // Clear and release node into memory pool. + if n.leaf() { + *n = node{} + leafPool.Put(n) + } else { + // Release child references first, if requested. + if recursive { + for i := int16(0); i <= n.count; i++ { + n.children[i].decRef(true /* recursive */) + } + } + *n = node{children: n.children} + *n.children = childrenArray{} + nodePool.Put(n) + } +} + +// clone creates a clone of the receiver with a single reference count. +func (n *node) clone() *node { + var c *node + if n.leaf() { + c = newLeafNode() + } else { + c = newNode() + } + // NB: copy field-by-field without touching n.ref to avoid + // triggering the race detector and looking like a data race. + c.count = n.count + c.maxKey = n.maxKey + c.maxInc = n.maxInc + c.items = n.items + if !c.leaf() { + // Copy children and increase each refcount. + *c.children = *n.children + for i := int16(0); i <= c.count; i++ { + c.children[i].incRef() + } + } + return c +} + +func (n *node) insertAt(index int, item *bufferedWrite, nd *node) { + if index < int(n.count) { + copy(n.items[index+1:n.count+1], n.items[index:n.count]) + if !n.leaf() { + copy(n.children[index+2:n.count+2], n.children[index+1:n.count+1]) + } + } + n.items[index] = item + if !n.leaf() { + n.children[index+1] = nd + } + n.count++ +} + +func (n *node) pushBack(item *bufferedWrite, nd *node) { + n.items[n.count] = item + if !n.leaf() { + n.children[n.count+1] = nd + } + n.count++ +} + +func (n *node) pushFront(item *bufferedWrite, nd *node) { + if !n.leaf() { + copy(n.children[1:n.count+2], n.children[:n.count+1]) + n.children[0] = nd + } + copy(n.items[1:n.count+1], n.items[:n.count]) + n.items[0] = item + n.count++ +} + +// removeAt removes a value at a given index, pulling all subsequent values +// back. +func (n *node) removeAt(index int) (*bufferedWrite, *node) { + var child *node + if !n.leaf() { + child = n.children[index+1] + copy(n.children[index+1:n.count], n.children[index+2:n.count+1]) + n.children[n.count] = nil + } + n.count-- + out := n.items[index] + copy(n.items[index:n.count], n.items[index+1:n.count+1]) + n.items[n.count] = nilT + return out, child +} + +// popBack removes and returns the last element in the list. +func (n *node) popBack() (*bufferedWrite, *node) { + n.count-- + out := n.items[n.count] + n.items[n.count] = nilT + if n.leaf() { + return out, nil + } + child := n.children[n.count+1] + n.children[n.count+1] = nil + return out, child +} + +// popFront removes and returns the first element in the list. +func (n *node) popFront() (*bufferedWrite, *node) { + n.count-- + var child *node + if !n.leaf() { + child = n.children[0] + copy(n.children[:n.count+1], n.children[1:n.count+2]) + n.children[n.count+1] = nil + } + out := n.items[0] + copy(n.items[:n.count], n.items[1:n.count+1]) + n.items[n.count] = nilT + return out, child +} + +// find returns the index where the given item should be inserted into this +// list. 
'found' is true if the item already exists in the list at the given +// index. +func (n *node) find(item *bufferedWrite) (index int, found bool) { + // Logic copied from sort.Search. Inlining this gave + // an 11% speedup on BenchmarkBTreeDeleteInsert. + i, j := 0, int(n.count) + for i < j { + h := int(uint(i+j) >> 1) // avoid overflow when computing h + // i ≤ h < j + v := compare(item, n.items[h]) + if v == 0 { + return h, true + } else if v > 0 { + i = h + 1 + } else { + j = h + } + } + return i, false +} + +// split splits the given node at the given index. The current node shrinks, +// and this function returns the item that existed at that index and a new +// node containing all items/children after it. +// +// Before: +// +// +-----------+ +// | x y z | +// +--/-/-\-\--+ +// +// After: +// +// +-----------+ +// | y | +// +----/-\----+ +// / \ +// v v +// +// +-----------+ +-----------+ +// | x | | z | +// +-----------+ +-----------+ +func (n *node) split(i int) (*bufferedWrite, *node) { + out := n.items[i] + var next *node + if n.leaf() { + next = newLeafNode() + } else { + next = newNode() + } + next.count = n.count - int16(i+1) + copy(next.items[:], n.items[i+1:n.count]) + for j := int16(i); j < n.count; j++ { + n.items[j] = nilT + } + if !n.leaf() { + copy(next.children[:], n.children[i+1:n.count+1]) + for j := int16(i + 1); j <= n.count; j++ { + n.children[j] = nil + } + } + n.count = int16(i) + + nextMax := next.findUpperBound() + next.setMax(nextMax) + nMax := n.max() + if nMax.compare(nextMax) != 0 && nMax.compare(upperBound(out)) != 0 { + // If upper bound wasn't from new node or item + // at index i, it must still be from old node. + } else { + n.setMax(n.findUpperBound()) + } + return out, next +} + +// insert inserts an item into the subtree rooted at this node, making sure no +// nodes in the subtree exceed maxItems items. Returns true if an existing item +// was replaced and false if an item was inserted. Also returns whether the +// node's upper bound changes. +func (n *node) insert(item *bufferedWrite) (replaced, newBound bool) { + i, found := n.find(item) + if found { + n.items[i] = item + return true, false + } + if n.leaf() { + n.insertAt(i, item, nil) + return false, n.adjustUpperBoundOnInsertion(item, nil) + } + if n.children[i].count >= maxItems { + splitLa, splitNode := mut(&n.children[i]).split(maxItems / 2) + n.insertAt(i, splitLa, splitNode) + + switch v := compare(item, n.items[i]); { + case v < 0: + // no change, we want first split node + case v > 0: + i++ // we want second split node + default: + n.items[i] = item + return true, false + } + } + replaced, newBound = mut(&n.children[i]).insert(item) + if newBound { + newBound = n.adjustUpperBoundOnInsertion(item, nil) + } + return replaced, newBound +} + +// removeMax removes and returns the maximum item from the subtree rooted at +// this node. +func (n *node) removeMax() *bufferedWrite { + if n.leaf() { + n.count-- + out := n.items[n.count] + n.items[n.count] = nilT + n.adjustUpperBoundOnRemoval(out, nil) + return out + } + // Recurse into max child. + i := int(n.count) + if n.children[i].count <= minItems { + // Child not large enough to remove from. + n.rebalanceOrMerge(i) + return n.removeMax() // redo + } + child := mut(&n.children[i]) + out := child.removeMax() + n.adjustUpperBoundOnRemoval(out, nil) + return out +} + +// remove removes an item from the subtree rooted at this node. Returns the item +// that was removed or nil if no matching item was found. 
Also returns whether +// the node's upper bound changes. +func (n *node) remove(item *bufferedWrite) (out *bufferedWrite, newBound bool) { + i, found := n.find(item) + if n.leaf() { + if found { + out, _ = n.removeAt(i) + return out, n.adjustUpperBoundOnRemoval(out, nil) + } + return nilT, false + } + if n.children[i].count <= minItems { + // Child not large enough to remove from. + n.rebalanceOrMerge(i) + return n.remove(item) // redo + } + child := mut(&n.children[i]) + if found { + // Replace the item being removed with the max item in our left child. + out = n.items[i] + n.items[i] = child.removeMax() + return out, n.adjustUpperBoundOnRemoval(out, nil) + } + // Latch is not in this node and child is large enough to remove from. + out, newBound = child.remove(item) + if newBound { + newBound = n.adjustUpperBoundOnRemoval(out, nil) + } + return out, newBound +} + +// rebalanceOrMerge grows child 'i' to ensure it has sufficient room to remove +// an item from it while keeping it at or above minItems. +func (n *node) rebalanceOrMerge(i int) { + switch { + case i > 0 && n.children[i-1].count > minItems: + // Rebalance from left sibling. + // + // +-----------+ + // | y | + // +----/-\----+ + // / \ + // v v + // +-----------+ +-----------+ + // | x | | | + // +----------\+ +-----------+ + // \ + // v + // a + // + // After: + // + // +-----------+ + // | x | + // +----/-\----+ + // / \ + // v v + // +-----------+ +-----------+ + // | | | y | + // +-----------+ +/----------+ + // / + // v + // a + // + left := mut(&n.children[i-1]) + child := mut(&n.children[i]) + xLa, grandChild := left.popBack() + yLa := n.items[i-1] + child.pushFront(yLa, grandChild) + n.items[i-1] = xLa + + left.adjustUpperBoundOnRemoval(xLa, grandChild) + child.adjustUpperBoundOnInsertion(yLa, grandChild) + + case i < int(n.count) && n.children[i+1].count > minItems: + // Rebalance from right sibling. + // + // +-----------+ + // | y | + // +----/-\----+ + // / \ + // v v + // +-----------+ +-----------+ + // | | | x | + // +-----------+ +/----------+ + // / + // v + // a + // + // After: + // + // +-----------+ + // | x | + // +----/-\----+ + // / \ + // v v + // +-----------+ +-----------+ + // | y | | | + // +----------\+ +-----------+ + // \ + // v + // a + // + right := mut(&n.children[i+1]) + child := mut(&n.children[i]) + xLa, grandChild := right.popFront() + yLa := n.items[i] + child.pushBack(yLa, grandChild) + n.items[i] = xLa + + right.adjustUpperBoundOnRemoval(xLa, grandChild) + child.adjustUpperBoundOnInsertion(yLa, grandChild) + + default: + // Merge with either the left or right sibling. + // + // +-----------+ + // | u y v | + // +----/-\----+ + // / \ + // v v + // +-----------+ +-----------+ + // | x | | z | + // +-----------+ +-----------+ + // + // After: + // + // +-----------+ + // | u v | + // +-----|-----+ + // | + // v + // +-----------+ + // | x y z | + // +-----------+ + // + if i >= int(n.count) { + i = int(n.count - 1) + } + child := mut(&n.children[i]) + // Make mergeChild mutable, bumping the refcounts on its children if necessary. 
+ _ = mut(&n.children[i+1]) + mergeLa, mergeChild := n.removeAt(i) + child.items[child.count] = mergeLa + copy(child.items[child.count+1:], mergeChild.items[:mergeChild.count]) + if !child.leaf() { + copy(child.children[child.count+1:], mergeChild.children[:mergeChild.count+1]) + } + child.count += mergeChild.count + 1 + + child.adjustUpperBoundOnInsertion(mergeLa, mergeChild) + mergeChild.decRef(false /* recursive */) + } +} + +// findUpperBound returns the largest end key node range, assuming that its +// children have correct upper bounds already set. +func (n *node) findUpperBound() keyBound { + var max keyBound + for i := int16(0); i < n.count; i++ { + up := upperBound(n.items[i]) + if max.compare(up) < 0 { + max = up + } + } + if !n.leaf() { + for i := int16(0); i <= n.count; i++ { + up := n.children[i].max() + if max.compare(up) < 0 { + max = up + } + } + } + return max +} + +// adjustUpperBoundOnInsertion adjusts the upper key bound for this node given +// an item and an optional child node that was inserted. Returns true is the +// upper bound was changed and false if not. +func (n *node) adjustUpperBoundOnInsertion(item *bufferedWrite, child *node) bool { + up := upperBound(item) + if child != nil { + if childMax := child.max(); up.compare(childMax) < 0 { + up = childMax + } + } + if n.max().compare(up) < 0 { + n.setMax(up) + return true + } + return false +} + +// adjustUpperBoundOnRemoval adjusts the upper key bound for this node given an +// item and an optional child node that was removed. Returns true is the upper +// bound was changed and false if not. +func (n *node) adjustUpperBoundOnRemoval(item *bufferedWrite, child *node) bool { + up := upperBound(item) + if child != nil { + if childMax := child.max(); up.compare(childMax) < 0 { + up = childMax + } + } + if n.max().compare(up) == 0 { + // up was previous upper bound of n. + max := n.findUpperBound() + n.setMax(max) + return max.compare(up) != 0 + } + return false +} + +// btree is an implementation of an augmented interval B-Tree. +// +// btree stores items in an ordered structure, allowing easy insertion, +// removal, and iteration. It represents intervals and permits an interval +// search operation following the approach laid out in CLRS, Chapter 14. +// The B-Tree stores items in order based on their start key and each +// B-Tree node maintains the upper-bound end key of all items in its +// subtree. +// +// Write operations are not safe for concurrent mutation by multiple +// goroutines, but Read operations are. +type btree struct { + root *node + length int +} + +// Reset removes all items from the btree. In doing so, it allows memory +// held by the btree to be recycled. Failure to call this method before +// letting a btree be GCed is safe in that it won't cause a memory leak, +// but it will prevent btree nodes from being efficiently re-used. +func (t *btree) Reset() { + if t.root != nil { + t.root.decRef(true /* recursive */) + t.root = nil + } + t.length = 0 +} + +// Clone clones the btree, lazily. It does so in constant time. +func (t *btree) Clone() btree { + c := *t + if c.root != nil { + // Incrementing the reference count on the root node is sufficient to + // ensure that no node in the cloned tree can be mutated by an actor + // holding a reference to the original tree and vice versa. This + // property is upheld because the root node in the receiver btree and + // the returned btree will both necessarily have a reference count of at + // least 2 when this method returns. 
All tree mutations recursively + // acquire mutable node references (see mut) as they traverse down the + // tree. The act of acquiring a mutable node reference performs a clone + // if a node's reference count is greater than one. Cloning a node (see + // clone) increases the reference count on each of its children, + // ensuring that they have a reference count of at least 2. This, in + // turn, ensures that any of the child nodes that are modified will also + // be copied-on-write, recursively ensuring the immutability property + // over the entire tree. + c.root.incRef() + } + return c +} + +// Delete removes an item equal to the passed in item from the tree. +func (t *btree) Delete(item *bufferedWrite) { + if t.root == nil || t.root.count == 0 { + return + } + if out, _ := mut(&t.root).remove(item); out != nilT { + t.length-- + } + if t.root.count == 0 { + old := t.root + if t.root.leaf() { + t.root = nil + } else { + t.root = t.root.children[0] + } + old.decRef(false /* recursive */) + } +} + +// Set adds the given item to the tree. If an item in the tree already equals +// the given one, it is replaced with the new item. +func (t *btree) Set(item *bufferedWrite) { + if t.root == nil { + t.root = newLeafNode() + } else if t.root.count >= maxItems { + splitLa, splitNode := mut(&t.root).split(maxItems / 2) + newRoot := newNode() + newRoot.count = 1 + newRoot.items[0] = splitLa + newRoot.children[0] = t.root + newRoot.children[1] = splitNode + newRoot.setMax(newRoot.findUpperBound()) + t.root = newRoot + } + if replaced, _ := mut(&t.root).insert(item); !replaced { + t.length++ + } +} + +// MakeIter returns a new iterator object. It is not safe to continue using an +// iterator after modifications are made to the tree. If modifications are made, +// create a new iterator. +func (t *btree) MakeIter() iterator { + return iterator{r: t.root, pos: -1} +} + +// Height returns the height of the tree. +func (t *btree) Height() int { + if t.root == nil { + return 0 + } + h := 1 + n := t.root + for !n.leaf() { + n = n.children[0] + h++ + } + return h +} + +// Len returns the number of items currently in the tree. +func (t *btree) Len() int { + return t.length +} + +// String returns a string description of the tree. The format is +// similar to the https://en.wikipedia.org/wiki/Newick_format. +func (t *btree) String() string { + if t.length == 0 { + return ";" + } + var b strings.Builder + t.root.writeString(&b) + return b.String() +} + +func (n *node) writeString(b *strings.Builder) { + if n.leaf() { + for i := int16(0); i < n.count; i++ { + if i != 0 { + b.WriteString(",") + } + b.WriteString(n.items[i].String()) + } + return + } + for i := int16(0); i <= n.count; i++ { + b.WriteString("(") + n.children[i].writeString(b) + b.WriteString(")") + if i < n.count { + b.WriteString(n.items[i].String()) + } + } +} + +// iterStack represents a stack of (node, pos) tuples, which captures +// iteration state as an iterator descends a btree. +type iterStack struct { + a iterStackArr + aLen int16 // -1 when using s + s []iterFrame +} + +// Used to avoid allocations for stacks below a certain size. 
+type iterStackArr [3]iterFrame + +type iterFrame struct { + n *node + pos int16 +} + +func (is *iterStack) push(f iterFrame) { + if is.aLen == -1 { + is.s = append(is.s, f) + } else if int(is.aLen) == len(is.a) { + is.s = make([]iterFrame, int(is.aLen)+1, 2*int(is.aLen)) + copy(is.s, is.a[:]) + is.s[int(is.aLen)] = f + is.aLen = -1 + } else { + is.a[is.aLen] = f + is.aLen++ + } +} + +func (is *iterStack) pop() iterFrame { + if is.aLen == -1 { + f := is.s[len(is.s)-1] + is.s = is.s[:len(is.s)-1] + return f + } + is.aLen-- + return is.a[is.aLen] +} + +func (is *iterStack) len() int { + if is.aLen == -1 { + return len(is.s) + } + return int(is.aLen) +} + +func (is *iterStack) reset() { + if is.aLen == -1 { + is.s = is.s[:0] + } else { + is.aLen = 0 + } +} + +// iterator is responsible for search and traversal within a btree. +type iterator struct { + r *node + n *node + pos int16 + s iterStack + o overlapScan +} + +func (i *iterator) reset() { + i.n = i.r + i.pos = -1 + i.s.reset() + i.o = overlapScan{} +} + +func (i *iterator) descend(n *node, pos int16) { + i.s.push(iterFrame{n: n, pos: pos}) + i.n = n.children[pos] + i.pos = 0 +} + +// ascend ascends up to the current node's parent and resets the position +// to the one previously set for this parent node. +func (i *iterator) ascend() { + f := i.s.pop() + i.n = f.n + i.pos = f.pos +} + +// SeekGE seeks to the first item greater-than or equal to the provided +// item. +func (i *iterator) SeekGE(item *bufferedWrite) { + i.reset() + if i.n == nil { + return + } + for { + pos, found := i.n.find(item) + i.pos = int16(pos) + if found { + return + } + if i.n.leaf() { + if i.pos == i.n.count { + i.Next() + } + return + } + i.descend(i.n, i.pos) + } +} + +// SeekLT seeks to the first item less-than the provided item. +func (i *iterator) SeekLT(item *bufferedWrite) { + i.reset() + if i.n == nil { + return + } + for { + pos, found := i.n.find(item) + i.pos = int16(pos) + if found || i.n.leaf() { + i.Prev() + return + } + i.descend(i.n, i.pos) + } +} + +// First seeks to the first item in the btree. +func (i *iterator) First() { + i.reset() + if i.n == nil { + return + } + for !i.n.leaf() { + i.descend(i.n, 0) + } + i.pos = 0 +} + +// Last seeks to the last item in the btree. +func (i *iterator) Last() { + i.reset() + if i.n == nil { + return + } + for !i.n.leaf() { + i.descend(i.n, i.n.count) + } + i.pos = i.n.count - 1 +} + +// Next positions the iterator to the item immediately following +// its current position. +func (i *iterator) Next() { + if i.n == nil { + return + } + + if i.n.leaf() { + i.pos++ + if i.pos < i.n.count { + return + } + for i.s.len() > 0 && i.pos >= i.n.count { + i.ascend() + } + return + } + + i.descend(i.n, i.pos+1) + for !i.n.leaf() { + i.descend(i.n, 0) + } + i.pos = 0 +} + +// Prev positions the iterator to the item immediately preceding +// its current position. +func (i *iterator) Prev() { + if i.n == nil { + return + } + + if i.n.leaf() { + i.pos-- + if i.pos >= 0 { + return + } + for i.s.len() > 0 && i.pos < 0 { + i.ascend() + i.pos-- + } + return + } + + i.descend(i.n, i.pos) + for !i.n.leaf() { + i.descend(i.n, i.n.count) + } + i.pos = i.n.count - 1 +} + +// Valid returns whether the iterator is positioned at a valid position. +func (i *iterator) Valid() bool { + return i.pos >= 0 && i.pos < i.n.count +} + +// Cur returns the item at the iterator's current position. It is illegal +// to call Cur if the iterator is not valid. 
+func (i *iterator) Cur() *bufferedWrite { + return i.n.items[i.pos] +} + +// An overlap scan is a scan over all items that overlap with the provided +// item in order of the overlapping items' start keys. The goal of the scan +// is to minimize the number of key comparisons performed in total. The +// algorithm operates based on the following two invariants maintained by +// augmented interval btree: +// 1. all items are sorted in the btree based on their start key. +// 2. all btree nodes maintain the upper bound end key of all items +// in their subtree. +// +// The scan algorithm starts in "unconstrained minimum" and "unconstrained +// maximum" states. To enter a "constrained minimum" state, the scan must reach +// items in the tree with start keys above the search range's start key. +// Because items in the tree are sorted by start key, once the scan enters the +// "constrained minimum" state it will remain there. To enter a "constrained +// maximum" state, the scan must determine the first child btree node in a given +// subtree that can have items with start keys above the search range's end +// key. The scan then remains in the "constrained maximum" state until it +// traverse into this child node, at which point it moves to the "unconstrained +// maximum" state again. +// +// The scan algorithm works like a standard btree forward scan with the +// following augmentations: +// 1. before tranversing the tree, the scan performs a binary search on the +// root node's items to determine a "soft" lower-bound constraint position +// and a "hard" upper-bound constraint position in the root's children. +// 2. when tranversing into a child node in the lower or upper bound constraint +// position, the constraint is refined by searching the child's items. +// 3. the initial traversal down the tree follows the left-most children +// whose upper bound end keys are equal to or greater than the start key +// of the search range. The children followed will be equal to or less +// than the soft lower bound constraint. +// 4. once the initial tranversal completes and the scan is in the left-most +// btree node whose upper bound overlaps the search range, key comparisons +// must be performed with each item in the tree. This is necessary because +// any of these items may have end keys that cause them to overlap with the +// search range. +// 5. once the scan reaches the lower bound constraint position (the first item +// with a start key equal to or greater than the search range's start key), +// it can begin scaning without performing key comparisons. This is allowed +// because all items from this point forward will have end keys that are +// greater than the search range's start key. +// 6. once the scan reaches the upper bound constraint position, it terminates. +// It does so because the item at this position is the first item with a +// start key larger than the search range's end key. +type overlapScan struct { + // The "soft" lower-bound constraint. + constrMinN *node + constrMinPos int16 + constrMinReached bool + + // The "hard" upper-bound constraint. + constrMaxN *node + constrMaxPos int16 +} + +// FirstOverlap seeks to the first item in the btree that overlaps with the +// provided search item. 
+func (i *iterator) FirstOverlap(item *bufferedWrite) { + i.reset() + if i.n == nil { + return + } + i.pos = 0 + i.o = overlapScan{} + i.constrainMinSearchBounds(item) + i.constrainMaxSearchBounds(item) + i.findNextOverlap(item) +} + +// NextOverlap positions the iterator to the item immediately following +// its current position that overlaps with the search item. +func (i *iterator) NextOverlap(item *bufferedWrite) { + if i.n == nil { + return + } + i.pos++ + i.findNextOverlap(item) +} + +func (i *iterator) constrainMinSearchBounds(item *bufferedWrite) { + k := item.Key() + j := sort.Search(int(i.n.count), func(j int) bool { + return bytes.Compare(k, i.n.items[j].Key()) <= 0 + }) + i.o.constrMinN = i.n + i.o.constrMinPos = int16(j) +} + +func (i *iterator) constrainMaxSearchBounds(item *bufferedWrite) { + up := upperBound(item) + j := sort.Search(int(i.n.count), func(j int) bool { + return !up.contains(i.n.items[j]) + }) + i.o.constrMaxN = i.n + i.o.constrMaxPos = int16(j) +} + +func (i *iterator) findNextOverlap(item *bufferedWrite) { + for { + if i.pos > i.n.count { + // Iterate up tree. + i.ascend() + } else if !i.n.leaf() { + // Iterate down tree. + if i.o.constrMinReached || i.n.children[i.pos].max().contains(item) { + par := i.n + pos := i.pos + i.descend(par, pos) + + // Refine the constraint bounds, if necessary. + if par == i.o.constrMinN && pos == i.o.constrMinPos { + i.constrainMinSearchBounds(item) + } + if par == i.o.constrMaxN && pos == i.o.constrMaxPos { + i.constrainMaxSearchBounds(item) + } + continue + } + } + + // Check search bounds. + if i.n == i.o.constrMaxN && i.pos == i.o.constrMaxPos { + // Invalid. Past possible overlaps. + i.pos = i.n.count + return + } + if i.n == i.o.constrMinN && i.pos == i.o.constrMinPos { + // The scan reached the soft lower-bound constraint. + i.o.constrMinReached = true + } + + // Iterate across node. + if i.pos < i.n.count { + // Check for overlapping item. + if i.o.constrMinReached { + // Fast-path to avoid span comparison. i.o.constrMinReached + // tells us that all items have end keys above our search + // span's start key. + return + } + if upperBound(i.n.items[i.pos]).contains(item) { + return + } + } + i.pos++ + } +} diff --git a/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree_test.go b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree_test.go new file mode 100644 index 000000000000..01ba1c155290 --- /dev/null +++ b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree_test.go @@ -0,0 +1,1111 @@ +// Code generated by go_generics. DO NOT EDIT. + +// Copyright 2020 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package kvcoord + +import ( + "fmt" + "math/rand" + "reflect" + "sync" + "testing" + + // Load pkg/keys so that roachpb.Span.String() could be executed correctly. + _ "github.com/cockroachdb/cockroach/pkg/keys" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/stretchr/testify/require" +) + +func newItem(s roachpb.Span) *bufferedWrite { + i := nilT.New() + i.SetKey(s.Key) + i.SetEndKey(s.EndKey) + return i +} + +func spanFromItem(i *bufferedWrite) roachpb.Span { + return roachpb.Span{Key: i.Key(), EndKey: i.EndKey()} +} + +////////////////////////////////////////// +// Invariant verification // +////////////////////////////////////////// + +// Verify asserts that the tree's structural invariants all hold. 
+func (t *btree) Verify(tt *testing.T) { + if t.length == 0 { + require.Nil(tt, t.root) + return + } + t.verifyLeafSameDepth(tt) + t.verifyCountAllowed(tt) + t.isSorted(tt) + t.isUpperBoundCorrect(tt) +} + +func (t *btree) verifyLeafSameDepth(tt *testing.T) { + h := t.Height() + t.root.verifyDepthEqualToHeight(tt, 1, h) +} + +func (n *node) verifyDepthEqualToHeight(t *testing.T, depth, height int) { + if n.leaf() { + require.Equal(t, height, depth, "all leaves should have the same depth as the tree height") + } + n.recurse(func(child *node, _ int16) { + child.verifyDepthEqualToHeight(t, depth+1, height) + }) +} + +func (t *btree) verifyCountAllowed(tt *testing.T) { + t.root.verifyCountAllowed(tt, true) +} + +func (n *node) verifyCountAllowed(t *testing.T, root bool) { + if !root { + require.GreaterOrEqual(t, n.count, int16(minItems), "latch count %d must be in range [%d,%d]", n.count, minItems, maxItems) + require.LessOrEqual(t, n.count, int16(maxItems), "latch count %d must be in range [%d,%d]", n.count, minItems, maxItems) + } + for i, item := range n.items { + if i < int(n.count) { + require.NotNil(t, item, "latch below count") + } else { + require.Nil(t, item, "latch above count") + } + } + if !n.leaf() { + for i, child := range n.children { + if i <= int(n.count) { + require.NotNil(t, child, "node below count") + } else { + require.Nil(t, child, "node above count") + } + } + } + n.recurse(func(child *node, _ int16) { + child.verifyCountAllowed(t, false) + }) +} + +func (t *btree) isSorted(tt *testing.T) { + t.root.isSorted(tt) +} + +func (n *node) isSorted(t *testing.T) { + for i := int16(1); i < n.count; i++ { + require.LessOrEqual(t, compare(n.items[i-1], n.items[i]), 0) + } + if !n.leaf() { + for i := int16(0); i < n.count; i++ { + prev := n.children[i] + next := n.children[i+1] + + require.LessOrEqual(t, compare(prev.items[prev.count-1], n.items[i]), 0) + require.LessOrEqual(t, compare(n.items[i], next.items[0]), 0) + } + } + n.recurse(func(child *node, _ int16) { + child.isSorted(t) + }) +} + +func (t *btree) isUpperBoundCorrect(tt *testing.T) { + t.root.isUpperBoundCorrect(tt) +} + +func (n *node) isUpperBoundCorrect(t *testing.T) { + require.Equal(t, 0, n.findUpperBound().compare(n.max())) + for i := int16(1); i < n.count; i++ { + require.LessOrEqual(t, upperBound(n.items[i]).compare(n.max()), 0) + } + if !n.leaf() { + for i := int16(0); i <= n.count; i++ { + child := n.children[i] + require.LessOrEqual(t, child.max().compare(n.max()), 0) + } + } + n.recurse(func(child *node, _ int16) { + child.isUpperBoundCorrect(t) + }) +} + +func (n *node) recurse(f func(child *node, pos int16)) { + if !n.leaf() { + for i := int16(0); i <= n.count; i++ { + f(n.children[i], i) + } + } +} + +////////////////////////////////////////// +// Unit Tests // +////////////////////////////////////////// + +func key(i int) roachpb.Key { + if i < 0 || i > 99999 { + panic("key out of bounds") + } + return []byte(fmt.Sprintf("%05d", i)) +} + +func span(i int) roachpb.Span { + switch i % 10 { + case 0: + return roachpb.Span{Key: key(i)} + case 1: + return roachpb.Span{Key: key(i), EndKey: key(i).Next()} + case 2: + return roachpb.Span{Key: key(i), EndKey: key(i + 64)} + default: + return roachpb.Span{Key: key(i), EndKey: key(i + 4)} + } +} + +func spanWithEnd(start, end int) roachpb.Span { + if start < end { + return roachpb.Span{Key: key(start), EndKey: key(end)} + } else if start == end { + return roachpb.Span{Key: key(start)} + } else { + panic("illegal span") + } +} + +func spanWithMemo(i int, memo 
map[int]roachpb.Span) roachpb.Span { + if s, ok := memo[i]; ok { + return s + } + s := span(i) + memo[i] = s + return s +} + +func randomSpan(rng *rand.Rand, n int) roachpb.Span { + start := rng.Intn(n) + end := rng.Intn(n + 1) + if end < start { + start, end = end, start + } + return spanWithEnd(start, end) +} + +func checkIter(t *testing.T, it iterator, start, end int, spanMemo map[int]roachpb.Span) { + i := start + for it.First(); it.Valid(); it.Next() { + item := it.Cur() + expected := spanWithMemo(i, spanMemo) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("expected %s, but found %s", expected, spanFromItem(item)) + } + i++ + } + if i != end { + t.Fatalf("expected %d, but at %d", end, i) + } + + for it.Last(); it.Valid(); it.Prev() { + i-- + item := it.Cur() + expected := spanWithMemo(i, spanMemo) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("expected %s, but found %s", expected, spanFromItem(item)) + } + } + if i != start { + t.Fatalf("expected %d, but at %d: %+v", start, i, it) + } + + all := newItem(spanWithEnd(start, end)) + for it.FirstOverlap(all); it.Valid(); it.NextOverlap(all) { + item := it.Cur() + expected := spanWithMemo(i, spanMemo) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("expected %s, but found %s", expected, spanFromItem(item)) + } + i++ + } + if i != end { + t.Fatalf("expected %d, but at %d", end, i) + } +} + +// TestBTree tests basic btree operations. +func TestBTree(t *testing.T) { + var tr btree + spanMemo := make(map[int]roachpb.Span) + + // With degree == 16 (max-items/node == 31) we need 513 items in order for + // there to be 3 levels in the tree. The count here is comfortably above + // that. + const count = 768 + + // Add keys in sorted order. + for i := 0; i < count; i++ { + tr.Set(newItem(span(i))) + tr.Verify(t) + if e := i + 1; e != tr.Len() { + t.Fatalf("expected length %d, but found %d", e, tr.Len()) + } + checkIter(t, tr.MakeIter(), 0, i+1, spanMemo) + } + + // Delete keys in sorted order. + for i := 0; i < count; i++ { + tr.Delete(newItem(span(i))) + tr.Verify(t) + if e := count - (i + 1); e != tr.Len() { + t.Fatalf("expected length %d, but found %d", e, tr.Len()) + } + checkIter(t, tr.MakeIter(), i+1, count, spanMemo) + } + + // Add keys in reverse sorted order. + for i := 0; i < count; i++ { + tr.Set(newItem(span(count - i))) + tr.Verify(t) + if e := i + 1; e != tr.Len() { + t.Fatalf("expected length %d, but found %d", e, tr.Len()) + } + checkIter(t, tr.MakeIter(), count-i, count+1, spanMemo) + } + + // Delete keys in reverse sorted order. + for i := 0; i < count; i++ { + tr.Delete(newItem(span(count - i))) + tr.Verify(t) + if e := count - (i + 1); e != tr.Len() { + t.Fatalf("expected length %d, but found %d", e, tr.Len()) + } + checkIter(t, tr.MakeIter(), 1, count-i, spanMemo) + } +} + +// TestBTreeSeek tests basic btree iterator operations. 
+func TestBTreeSeek(t *testing.T) { + const count = 513 + + var tr btree + for i := 0; i < count; i++ { + tr.Set(newItem(span(i * 2))) + } + + it := tr.MakeIter() + for i := 0; i < 2*count-1; i++ { + it.SeekGE(newItem(span(i))) + if !it.Valid() { + t.Fatalf("%d: expected valid iterator", i) + } + item := it.Cur() + expected := span(2 * ((i + 1) / 2)) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item)) + } + } + it.SeekGE(newItem(span(2*count - 1))) + if it.Valid() { + t.Fatalf("expected invalid iterator") + } + + for i := 1; i < 2*count; i++ { + it.SeekLT(newItem(span(i))) + if !it.Valid() { + t.Fatalf("%d: expected valid iterator", i) + } + item := it.Cur() + expected := span(2 * ((i - 1) / 2)) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item)) + } + } + it.SeekLT(newItem(span(0))) + if it.Valid() { + t.Fatalf("expected invalid iterator") + } +} + +// TestBTreeSeekOverlap tests btree iterator overlap operations. +func TestBTreeSeekOverlap(t *testing.T) { + const count = 513 + const size = 2 * maxItems + + var tr btree + for i := 0; i < count; i++ { + tr.Set(newItem(spanWithEnd(i, i+size+1))) + } + + // Iterate over overlaps with a point scan. + it := tr.MakeIter() + for i := 0; i < count+size; i++ { + scanItem := newItem(spanWithEnd(i, i)) + it.FirstOverlap(scanItem) + for j := 0; j < size+1; j++ { + expStart := i - size + j + if expStart < 0 { + continue + } + if expStart >= count { + continue + } + + if !it.Valid() { + t.Fatalf("%d/%d: expected valid iterator", i, j) + } + item := it.Cur() + expected := spanWithEnd(expStart, expStart+size+1) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item)) + } + + it.NextOverlap(scanItem) + } + if it.Valid() { + t.Fatalf("%d: expected invalid iterator %v", i, it.Cur()) + } + } + it.FirstOverlap(newItem(span(count + size + 1))) + if it.Valid() { + t.Fatalf("expected invalid iterator") + } + + // Iterate over overlaps with a range scan. + it = tr.MakeIter() + for i := 0; i < count+size; i++ { + scanItem := newItem(spanWithEnd(i, i+size+1)) + it.FirstOverlap(scanItem) + for j := 0; j < 2*size+1; j++ { + expStart := i - size + j + if expStart < 0 { + continue + } + if expStart >= count { + continue + } + + if !it.Valid() { + t.Fatalf("%d/%d: expected valid iterator", i, j) + } + item := it.Cur() + expected := spanWithEnd(expStart, expStart+size+1) + if !expected.Equal(spanFromItem(item)) { + t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item)) + } + + it.NextOverlap(scanItem) + } + if it.Valid() { + t.Fatalf("%d: expected invalid iterator %v", i, it.Cur()) + } + } + it.FirstOverlap(newItem(span(count + size + 1))) + if it.Valid() { + t.Fatalf("expected invalid iterator") + } +} + +// TestBTreeCompare tests the btree item comparison. +func TestBTreeCompare(t *testing.T) { + // NB: go_generics doesn't do well with anonymous types, so name this type. + // Avoid the slice literal syntax, which GofmtSimplify mandates the use of + // anonymous constructors with. 
+ type testCase struct { + spanA, spanB roachpb.Span + idA, idB uint64 + exp int + } + var testCases []testCase + testCases = append(testCases, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("a")}, + spanB: roachpb.Span{Key: roachpb.Key("a")}, + idA: 1, + idB: 1, + exp: 0, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("a")}, + spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("b")}, + idA: 1, + idB: 1, + exp: -1, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")}, + spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("b")}, + idA: 1, + idB: 1, + exp: 1, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")}, + spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")}, + idA: 1, + idB: 1, + exp: 0, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("a")}, + spanB: roachpb.Span{Key: roachpb.Key("a")}, + idA: 1, + idB: 2, + exp: -1, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("a")}, + spanB: roachpb.Span{Key: roachpb.Key("a")}, + idA: 2, + idB: 1, + exp: 1, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("b")}, + spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")}, + idA: 1, + idB: 1, + exp: 1, + }, + testCase{ + spanA: roachpb.Span{Key: roachpb.Key("b"), EndKey: roachpb.Key("e")}, + spanB: roachpb.Span{Key: roachpb.Key("c"), EndKey: roachpb.Key("d")}, + idA: 1, + idB: 1, + exp: -1, + }, + ) + for _, tc := range testCases { + name := fmt.Sprintf("compare(%s:%d,%s:%d)", tc.spanA, tc.idA, tc.spanB, tc.idB) + t.Run(name, func(t *testing.T) { + laA := newItem(tc.spanA) + laA.SetID(tc.idA) + laB := newItem(tc.spanB) + laB.SetID(tc.idB) + require.Equal(t, tc.exp, compare(laA, laB)) + }) + } +} + +// TestIterStack tests the interface of the iterStack type. +func TestIterStack(t *testing.T) { + f := func(i int) iterFrame { return iterFrame{pos: int16(i)} } + var is iterStack + for i := 1; i <= 2*len(iterStackArr{}); i++ { + var j int + for j = 0; j < i; j++ { + is.push(f(j)) + } + require.Equal(t, j, is.len()) + for j--; j >= 0; j-- { + require.Equal(t, f(j), is.pop()) + } + is.reset() + } +} + +////////////////////////////////////////// +// Randomized Unit Tests // +////////////////////////////////////////// + +// perm returns a random permutation of items with spans in the range [0, n). +func perm(n int) (out []*bufferedWrite) { + for _, i := range rand.Perm(n) { + out = append(out, newItem(spanWithEnd(i, i+1))) + } + return out +} + +// rang returns an ordered list of items with spans in the range [m, n]. +func rang(m, n int) (out []*bufferedWrite) { + for i := m; i <= n; i++ { + out = append(out, newItem(spanWithEnd(i, i+1))) + } + return out +} + +// all extracts all items from a tree in order as a slice. 
+func all(tr *btree) (out []*bufferedWrite) { + it := tr.MakeIter() + it.First() + for it.Valid() { + out = append(out, it.Cur()) + it.Next() + } + return out +} + +func run(tb testing.TB, name string, f func(testing.TB)) { + switch v := tb.(type) { + case *testing.T: + v.Run(name, func(t *testing.T) { + f(t) + }) + case *testing.B: + v.Run(name, func(b *testing.B) { + f(b) + }) + default: + tb.Fatalf("unknown %T", tb) + } +} + +func iters(tb testing.TB, count int) int { + switch v := tb.(type) { + case *testing.T: + return count + case *testing.B: + return v.N + default: + tb.Fatalf("unknown %T", tb) + return 0 + } +} + +func verify(tb testing.TB, tr *btree) { + if tt, ok := tb.(*testing.T); ok { + tr.Verify(tt) + } +} + +func resetTimer(tb testing.TB) { + if b, ok := tb.(*testing.B); ok { + b.ResetTimer() + } +} + +func stopTimer(tb testing.TB) { + if b, ok := tb.(*testing.B); ok { + b.StopTimer() + } +} + +func startTimer(tb testing.TB) { + if b, ok := tb.(*testing.B); ok { + b.StartTimer() + } +} + +func runBTreeInsert(tb testing.TB, count int) { + iters := iters(tb, count) + insertP := perm(count) + resetTimer(tb) + for i := 0; i < iters; { + var tr btree + for _, item := range insertP { + tr.Set(item) + verify(tb, &tr) + i++ + if i >= iters { + return + } + } + } +} + +func runBTreeDelete(tb testing.TB, count int) { + iters := iters(tb, count) + insertP, removeP := perm(count), perm(count) + resetTimer(tb) + for i := 0; i < iters; { + stopTimer(tb) + var tr btree + for _, item := range insertP { + tr.Set(item) + verify(tb, &tr) + } + startTimer(tb) + for _, item := range removeP { + tr.Delete(item) + verify(tb, &tr) + i++ + if i >= iters { + return + } + } + if tr.Len() > 0 { + tb.Fatalf("tree not empty: %s", &tr) + } + } +} + +func runBTreeDeleteInsert(tb testing.TB, count int) { + iters := iters(tb, count) + insertP := perm(count) + var tr btree + for _, item := range insertP { + tr.Set(item) + verify(tb, &tr) + } + resetTimer(tb) + for i := 0; i < iters; i++ { + item := insertP[i%count] + tr.Delete(item) + verify(tb, &tr) + tr.Set(item) + verify(tb, &tr) + } +} + +func runBTreeDeleteInsertCloneOnce(tb testing.TB, count int) { + iters := iters(tb, count) + insertP := perm(count) + var tr btree + for _, item := range insertP { + tr.Set(item) + verify(tb, &tr) + } + tr = tr.Clone() + resetTimer(tb) + for i := 0; i < iters; i++ { + item := insertP[i%count] + tr.Delete(item) + verify(tb, &tr) + tr.Set(item) + verify(tb, &tr) + } +} + +func runBTreeDeleteInsertCloneEachTime(tb testing.TB, count int) { + for _, reset := range []bool{false, true} { + run(tb, fmt.Sprintf("reset=%t", reset), func(tb testing.TB) { + iters := iters(tb, count) + insertP := perm(count) + var tr, trReset btree + for _, item := range insertP { + tr.Set(item) + verify(tb, &tr) + } + resetTimer(tb) + for i := 0; i < iters; i++ { + item := insertP[i%count] + if reset { + trReset.Reset() + trReset = tr + } + tr = tr.Clone() + tr.Delete(item) + verify(tb, &tr) + tr.Set(item) + verify(tb, &tr) + } + }) + } +} + +// randN returns a random integer in the range [min, max). 
+func randN(min, max int) int { return rand.Intn(max-min) + min } +func randCount() int { + if testing.Short() { + return randN(1, 128) + } + return randN(1, 1024) +} + +func TestBTreeInsert(t *testing.T) { + count := randCount() + runBTreeInsert(t, count) +} + +func TestBTreeDelete(t *testing.T) { + count := randCount() + runBTreeDelete(t, count) +} + +func TestBTreeDeleteInsert(t *testing.T) { + count := randCount() + runBTreeDeleteInsert(t, count) +} + +func TestBTreeDeleteInsertCloneOnce(t *testing.T) { + count := randCount() + runBTreeDeleteInsertCloneOnce(t, count) +} + +func TestBTreeDeleteInsertCloneEachTime(t *testing.T) { + count := randCount() + runBTreeDeleteInsertCloneEachTime(t, count) +} + +// TestBTreeSeekOverlapRandom tests btree iterator overlap operations using +// randomized input. +func TestBTreeSeekOverlapRandom(t *testing.T) { + rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + + const trials = 10 + for i := 0; i < trials; i++ { + var tr btree + + const count = 1000 + items := make([]*bufferedWrite, count) + itemSpans := make([]int, count) + for j := 0; j < count; j++ { + var item *bufferedWrite + end := rng.Intn(count + 10) + if end <= j { + end = j + item = newItem(spanWithEnd(j, end)) + } else { + item = newItem(spanWithEnd(j, end+1)) + } + tr.Set(item) + items[j] = item + itemSpans[j] = end + } + + const scanTrials = 100 + for j := 0; j < scanTrials; j++ { + var scanItem *bufferedWrite + scanStart := rng.Intn(count) + scanEnd := rng.Intn(count + 10) + if scanEnd <= scanStart { + scanEnd = scanStart + scanItem = newItem(spanWithEnd(scanStart, scanEnd)) + } else { + scanItem = newItem(spanWithEnd(scanStart, scanEnd+1)) + } + + var exp, found []*bufferedWrite + for startKey, endKey := range itemSpans { + if startKey <= scanEnd && endKey >= scanStart { + exp = append(exp, items[startKey]) + } + } + + it := tr.MakeIter() + it.FirstOverlap(scanItem) + for it.Valid() { + found = append(found, it.Cur()) + it.NextOverlap(scanItem) + } + + require.Equal(t, len(exp), len(found), "search for %v", spanFromItem(scanItem)) + } + } +} + +// TestBTreeCloneConcurrentOperations tests that cloning a btree returns a new +// btree instance which is an exact logical copy of the original but that can be +// modified independently going forward. 
+func TestBTreeCloneConcurrentOperations(t *testing.T) { + const cloneTestSize = 1000 + p := perm(cloneTestSize) + + var trees []*btree + treeC, treeDone := make(chan *btree), make(chan struct{}) + go func() { + for b := range treeC { + trees = append(trees, b) + } + close(treeDone) + }() + + var wg sync.WaitGroup + var populate func(tr *btree, start int) + populate = func(tr *btree, start int) { + t.Logf("Starting new clone at %v", start) + treeC <- tr + for i := start; i < cloneTestSize; i++ { + tr.Set(p[i]) + if i%(cloneTestSize/5) == 0 { + wg.Add(1) + c := tr.Clone() + go populate(&c, i+1) + } + } + wg.Done() + } + + wg.Add(1) + var tr btree + go populate(&tr, 0) + wg.Wait() + close(treeC) + <-treeDone + + t.Logf("Starting equality checks on %d trees", len(trees)) + want := rang(0, cloneTestSize-1) + for i, tree := range trees { + if !reflect.DeepEqual(want, all(tree)) { + t.Errorf("tree %v mismatch", i) + } + } + + t.Log("Removing half of items from first half") + toRemove := want[cloneTestSize/2:] + for i := 0; i < len(trees)/2; i++ { + tree := trees[i] + wg.Add(1) + go func() { + for _, item := range toRemove { + tree.Delete(item) + } + wg.Done() + }() + } + wg.Wait() + + t.Log("Checking all values again") + for i, tree := range trees { + var wantpart []*bufferedWrite + if i < len(trees)/2 { + wantpart = want[:cloneTestSize/2] + } else { + wantpart = want + } + if got := all(tree); !reflect.DeepEqual(wantpart, got) { + t.Errorf("tree %v mismatch, want %v got %v", i, len(want), len(got)) + } + } +} + +////////////////////////////////////////// +// Benchmarks // +////////////////////////////////////////// + +func forBenchmarkSizes(b *testing.B, f func(b *testing.B, count int)) { + for _, count := range []int{16, 128, 1024, 8192, 65536} { + b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) { + f(b, count) + }) + } +} + +// BenchmarkBTreeInsert measures btree insertion performance. +func BenchmarkBTreeInsert(b *testing.B) { + forBenchmarkSizes(b, func(b *testing.B, count int) { + runBTreeInsert(b, count) + }) +} + +// BenchmarkBTreeDelete measures btree deletion performance. +func BenchmarkBTreeDelete(b *testing.B) { + forBenchmarkSizes(b, func(b *testing.B, count int) { + runBTreeDelete(b, count) + }) +} + +// BenchmarkBTreeDeleteInsert measures btree deletion and insertion performance. +func BenchmarkBTreeDeleteInsert(b *testing.B) { + forBenchmarkSizes(b, func(b *testing.B, count int) { + runBTreeDeleteInsert(b, count) + }) +} + +// BenchmarkBTreeDeleteInsertCloneOnce measures btree deletion and insertion +// performance after the tree has been copy-on-write cloned once. +func BenchmarkBTreeDeleteInsertCloneOnce(b *testing.B) { + forBenchmarkSizes(b, func(b *testing.B, count int) { + runBTreeDeleteInsertCloneOnce(b, count) + }) +} + +// BenchmarkBTreeDeleteInsertCloneEachTime measures btree deletion and insertion +// performance while the tree is repeatedly copy-on-write cloned. +func BenchmarkBTreeDeleteInsertCloneEachTime(b *testing.B) { + forBenchmarkSizes(b, func(b *testing.B, count int) { + runBTreeDeleteInsertCloneEachTime(b, count) + }) +} + +// BenchmarkBTreeMakeIter measures the cost of creating a btree iterator. +func BenchmarkBTreeMakeIter(b *testing.B) { + var tr btree + for i := 0; i < b.N; i++ { + it := tr.MakeIter() + it.First() + } +} + +// BenchmarkBTreeIterSeekGE measures the cost of seeking a btree iterator +// forward. 
+func BenchmarkBTreeIterSeekGE(b *testing.B) { + rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + forBenchmarkSizes(b, func(b *testing.B, count int) { + var spans []roachpb.Span + var tr btree + + for i := 0; i < count; i++ { + s := span(i) + spans = append(spans, s) + tr.Set(newItem(s)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + s := spans[rng.Intn(len(spans))] + it := tr.MakeIter() + it.SeekGE(newItem(s)) + if testing.Verbose() { + if !it.Valid() { + b.Fatal("expected to find key") + } + if !s.Equal(spanFromItem(it.Cur())) { + b.Fatalf("expected %s, but found %s", s, spanFromItem(it.Cur())) + } + } + } + }) +} + +// BenchmarkBTreeIterSeekLT measures the cost of seeking a btree iterator +// backward. +func BenchmarkBTreeIterSeekLT(b *testing.B) { + rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + forBenchmarkSizes(b, func(b *testing.B, count int) { + var spans []roachpb.Span + var tr btree + + for i := 0; i < count; i++ { + s := span(i) + spans = append(spans, s) + tr.Set(newItem(s)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + j := rng.Intn(len(spans)) + s := spans[j] + it := tr.MakeIter() + it.SeekLT(newItem(s)) + if testing.Verbose() { + if j == 0 { + if it.Valid() { + b.Fatal("unexpected key") + } + } else { + if !it.Valid() { + b.Fatal("expected to find key") + } + s := spans[j-1] + if !s.Equal(spanFromItem(it.Cur())) { + b.Fatalf("expected %s, but found %s", s, spanFromItem(it.Cur())) + } + } + } + } + }) +} + +// BenchmarkBTreeIterFirstOverlap measures the cost of finding a single +// overlapping item using a btree iterator. +func BenchmarkBTreeIterFirstOverlap(b *testing.B) { + rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + forBenchmarkSizes(b, func(b *testing.B, count int) { + var spans []roachpb.Span + var tr btree + + for i := 0; i < count; i++ { + s := spanWithEnd(i, i+1) + spans = append(spans, s) + tr.Set(newItem(s)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + j := rng.Intn(len(spans)) + s := spans[j] + it := tr.MakeIter() + it.FirstOverlap(newItem(s)) + if testing.Verbose() { + if !it.Valid() { + b.Fatal("expected to find key") + } + if !s.Equal(spanFromItem(it.Cur())) { + b.Fatalf("expected %s, but found %s", s, spanFromItem(it.Cur())) + } + } + } + }) +} + +// BenchmarkBTreeIterNext measures the cost of seeking a btree iterator to the +// next item in the tree. +func BenchmarkBTreeIterNext(b *testing.B) { + var tr btree + + const count = 8 << 10 + const size = 2 * maxItems + for i := 0; i < count; i++ { + item := newItem(spanWithEnd(i, i+size+1)) + tr.Set(item) + } + + it := tr.MakeIter() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if !it.Valid() { + it.First() + } + it.Next() + } +} + +// BenchmarkBTreeIterPrev measures the cost of seeking a btree iterator to the +// previous item in the tree. +func BenchmarkBTreeIterPrev(b *testing.B) { + var tr btree + + const count = 8 << 10 + const size = 2 * maxItems + for i := 0; i < count; i++ { + item := newItem(spanWithEnd(i, i+size+1)) + tr.Set(item) + } + + it := tr.MakeIter() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if !it.Valid() { + it.Last() + } + it.Prev() + } +} + +// BenchmarkBTreeIterNextOverlap measures the cost of seeking a btree iterator +// to the next overlapping item in the tree. 
+func BenchmarkBTreeIterNextOverlap(b *testing.B) { + var tr btree + + const count = 8 << 10 + const size = 2 * maxItems + for i := 0; i < count; i++ { + item := newItem(spanWithEnd(i, i+size+1)) + tr.Set(item) + } + + allCmd := newItem(spanWithEnd(0, count+1)) + it := tr.MakeIter() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if !it.Valid() { + it.FirstOverlap(allCmd) + } + it.NextOverlap(allCmd) + } +} + +// BenchmarkBTreeIterOverlapScan measures the cost of scanning over all +// overlapping items using a btree iterator. +func BenchmarkBTreeIterOverlapScan(b *testing.B) { + var tr btree + rng := rand.New(rand.NewSource(timeutil.Now().UnixNano())) + + const count = 8 << 10 + const size = 2 * maxItems + for i := 0; i < count; i++ { + tr.Set(newItem(spanWithEnd(i, i+size+1))) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + item := newItem(randomSpan(rng, count)) + it := tr.MakeIter() + it.FirstOverlap(item) + for it.Valid() { + it.NextOverlap(item) + } + } +} diff --git a/pkg/kv/kvclient/kvcoord/txn_coord_sender.go b/pkg/kv/kvclient/kvcoord/txn_coord_sender.go index b1318201806f..bc26c6f5ced0 100644 --- a/pkg/kv/kvclient/kvcoord/txn_coord_sender.go +++ b/pkg/kv/kvclient/kvcoord/txn_coord_sender.go @@ -161,9 +161,10 @@ type TxnCoordSender struct { // additional heap allocations necessary. interceptorStack []txnInterceptor interceptorAlloc struct { - arr [6]txnInterceptor + arr [7]txnInterceptor txnHeartbeater txnSeqNumAllocator + txnWriteBuffer txnPipeliner txnCommitter txnSpanRefresher @@ -275,6 +276,7 @@ func newRootTxnCoordSender( // Various interceptors below rely on sequence number allocation, // so the sequence number allocator is near the top of the stack. &tcs.interceptorAlloc.txnSeqNumAllocator, + &tcs.interceptorAlloc.txnWriteBuffer, // The pipeliner sits above the span refresher because it will // never generate transaction retry errors that could be avoided // with a refresh. @@ -312,6 +314,9 @@ func (tc *TxnCoordSender) initCommonInterceptors( if ds, ok := tcf.wrapped.(*DistSender); ok { riGen.ds = ds } + tc.interceptorAlloc.txnWriteBuffer = txnWriteBuffer{ + enabled: bufferedWritesEnabled.Get(&tcf.st.SV), + } tc.interceptorAlloc.txnPipeliner = txnPipeliner{ st: tcf.st, riGen: riGen, @@ -500,7 +505,7 @@ func (tc *TxnCoordSender) Send( return nil, pErr } - if ba.IsSingleEndTxnRequest() && !tc.interceptorAlloc.txnPipeliner.hasAcquiredLocks() { + if ba.IsSingleEndTxnRequest() && (!tc.interceptorAlloc.txnPipeliner.hasAcquiredLocks() && tc.interceptorAlloc.txnWriteBuffer.buf.Len() == 0) { return nil, tc.finalizeNonLockingTxnLocked(ctx, ba) } diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go index 6e6407603562..88cc8549dd68 100644 --- a/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go +++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go @@ -21,10 +21,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" - "github.com/google/btree" + gbtree "github.com/google/btree" ) -// The degree of the inFlightWrites btree. +// The degree of the inFlightWrites gbtree. const txnPipelinerBtreeDegree = 32 // PipelinedWritesEnabled is the kv.transaction.write_pipelining.enabled cluster setting. @@ -1048,12 +1048,12 @@ func makeInFlightWrite(key roachpb.Key, seq enginepb.TxnSeq, str lock.Strength) }} } -// Less implements the btree.Item interface. +// Less implements the gbtree.Item interface. 
// // inFlightWrites are ordered by Key, then by Sequence, then by Strength. Two // inFlightWrites with the same Key but different Sequences and/or Strengths are // not considered equal and are maintained separately in the inFlightWritesSet. -func (a *inFlightWrite) Less(bItem btree.Item) bool { +func (a *inFlightWrite) Less(bItem gbtree.Item) bool { b := bItem.(*inFlightWrite) kCmp := a.Key.Compare(b.Key) if kCmp != 0 { @@ -1077,7 +1077,7 @@ func (a *inFlightWrite) Less(bItem btree.Item) bool { // writes, O(log n) removal of existing in-flight writes, and O(m + log n) // retrieval over m in-flight writes that overlap with a given key. type inFlightWriteSet struct { - t *btree.BTree + t *gbtree.BTree bytes int64 // Avoids allocs. @@ -1090,7 +1090,7 @@ type inFlightWriteSet struct { func (s *inFlightWriteSet) insert(key roachpb.Key, seq enginepb.TxnSeq, str lock.Strength) { if s.t == nil { // Lazily initialize btree. - s.t = btree.New(txnPipelinerBtreeDegree) + s.t = gbtree.New(txnPipelinerBtreeDegree) } w := s.alloc.alloc(key, seq, str) @@ -1136,7 +1136,7 @@ func (s *inFlightWriteSet) ascend(f func(w *inFlightWrite)) { // Set is empty. return } - s.t.Ascend(func(i btree.Item) bool { + s.t.Ascend(func(i gbtree.Item) bool { f(i.(*inFlightWrite)) return true }) @@ -1157,7 +1157,7 @@ func (s *inFlightWriteSet) ascendRange(start, end roachpb.Key, f func(w *inFligh // Range lookup. s.tmp2 = makeInFlightWrite(end, 0, 0) } - s.t.AscendRange(&s.tmp1, &s.tmp2, func(i btree.Item) bool { + s.t.AscendRange(&s.tmp1, &s.tmp2, func(i gbtree.Item) bool { f(i.(*inFlightWrite)) return true }) diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go new file mode 100644 index 000000000000..7dcffb24aa35 --- /dev/null +++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go @@ -0,0 +1,499 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package kvcoord + +import ( + "bytes" + "context" + "sort" + + "github.com/cockroachdb/cockroach/pkg/kv/kvpb" + "github.com/cockroachdb/cockroach/pkg/roachpb" + "github.com/cockroachdb/cockroach/pkg/settings" + "github.com/cockroachdb/cockroach/pkg/storage/enginepb" + "github.com/cockroachdb/cockroach/pkg/util/log" +) + +var bufferedWritesEnabled = settings.RegisterBoolSetting( + settings.ApplicationLevel, + "kv.transaction.buffered_writes.enabled", + "if enabled, transactional writes are buffered on the gateway", + true, + settings.WithPublic, +) + +// txnWriteBuffer is a txnInterceptor that buffers writes for a transaction +// before sending them during commit to the wrapped lockedSender. Buffering +// writes client side has four main benefits: +// +// 1. It allows for more batching of writes, which can be more efficient. +// Instead of sending writes one at a time, we can batch them up and send +// them all at once. This is a win even if writes would otherwise be +// pipelined through consensus. +// +// 2. It allows for the elimination of redundant writes. If a client writes to +// the same key multiple times in a transaction, only the last write needs +// to be written to the key-value layer. +// +// 3. 
It allows the client to serve read-your-writes locally, which can be much +// faster and cheaper than sending them to the leaseholder. This is +// especially true when the leaseholder is not collocated with the client. +// +// By serving read-your-writes locally from the gateway, write buffering +// also avoids the problem of pipeline stalls that can occur when a client +// reads a pipelined intent write before the write has completed consensus. +// For details on pipeline stalls, see txnPipeliner. +// +// 4. It allows clients to passively hit the 1-phase commit fast-path, instead +// of requiring clients to carefully construct "auto-commit" BatchRequests +// to make use of the optimization. By buffering writes on the client before +// commit, we avoid immediately disabling the fast-path when the client +// issues their first write. Then, at commit time, we flush the buffer and +// will happen to hit the 1-phase commit fast path if all writes go to the +// same range. +// +// However, buffering writes comes with some challenges. +// +// The first is that read-only and read-write requests need to be aware of the +// buffered writes, as they may need to serve reads from the buffer. +// TODO: discuss distributed execution. +// +// The second is that savepoints need to be handled correctly. TODO... +// +// The third is that the buffer needs to adhere to memory limits. TODO ... +type txnWriteBuffer struct { + wrapped lockedSender + enabled bool + + buf btree + bufSeek bufferedWrite + bufIDAlloc uint64 +} + +// SendLocked implements the txnInterceptor interface. +func (twb *txnWriteBuffer) SendLocked( + ctx context.Context, ba *kvpb.BatchRequest, +) (*kvpb.BatchResponse, *kvpb.Error) { + if !twb.enabled { + return twb.wrapped.SendLocked(ctx, ba) + } + + if _, ok := ba.GetArg(kvpb.EndTxn); ok { + return twb.flushWithEndTxn(ctx, ba) + } + + baRemote, brLocal, localPositions, pErr := twb.servePointReadsAndAddWritesToBuffer(ctx, ba) + if pErr != nil { + return nil, pErr + } + if len(baRemote.Requests) == 0 { + return brLocal, nil + } + + brRemote, pErr := twb.wrapped.SendLocked(ctx, baRemote) + if pErr != nil { + return nil, pErr + } + + twb.augmentRangedReadsFromBuffer(ctx, baRemote, brRemote) + br := twb.mergeLocalAndRemoteResponses(ctx, brLocal, brRemote, localPositions) + return br, nil +} + +// flushWithEndTxn flushes the write buffer with the EndTxn request. +func (twb *txnWriteBuffer) flushWithEndTxn( + ctx context.Context, ba *kvpb.BatchRequest, +) (*kvpb.BatchResponse, *kvpb.Error) { + flushed := twb.buf.Len() + if flushed > 0 { + reqs := make([]kvpb.RequestUnion, 0, flushed+len(ba.Requests)) + it := twb.buf.MakeIter() + for it.First(); it.Valid(); it.Next() { + reqs = append(reqs, it.Cur().toRequest()) + } + sort.SliceStable(reqs, func(i, j int) bool { + return reqs[i].GetInner().Header().Sequence < reqs[j].GetInner().Header().Sequence + }) + reqs = append(reqs, ba.Requests...) + ba.Requests = reqs + } + + br, pErr := twb.wrapped.SendLocked(ctx, ba) + if pErr != nil { + return nil, pErr + } + br.Responses = br.Responses[flushed:] + return br, nil +} + +// servePointReadsAndAddWritesToBuffer serves point reads from the buffer and +// adds blind writes to the buffer. It returns a BatchRequest with the locally +// serviceable requests removed, a BatchResponse with the locally serviceable +// requests' responses, and a slice of the positions of the locally serviceable +// requests in the original BatchRequest.
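To make the model described in the interceptor comment concrete, here is a minimal, self-contained sketch of the same idea; it is not part of the patch. The last write to each key wins, point reads are served from the buffer, and the buffer is flushed in sequence order at commit (mirroring flushWithEndTxn above). The sketchBuffer and sketchWrite names are illustrative only; the real interceptor keeps bufferedWrite entries in an interval btree so that ranged requests can find overlapping buffered writes, and it must also handle savepoints and memory limits, both of which this sketch ignores.

package main

import (
	"fmt"
	"sort"
)

// sketchWrite is a simplified stand-in for bufferedWrite: a key, an optional
// value (present == false models a tombstone), and the sequence number of the
// last write to the key.
type sketchWrite struct {
	key, val string
	present  bool
	seq      int
}

// sketchBuffer buffers a transaction's writes client side until commit.
type sketchBuffer struct {
	writes map[string]sketchWrite
	seq    int
}

func newSketchBuffer() *sketchBuffer {
	return &sketchBuffer{writes: make(map[string]sketchWrite)}
}

// Put overwrites any earlier buffered write to the same key, so redundant
// writes never reach the KV layer (benefit 2 above).
func (b *sketchBuffer) Put(key, val string) {
	b.seq++
	b.writes[key] = sketchWrite{key: key, val: val, present: true, seq: b.seq}
}

// Del buffers a tombstone for the key.
func (b *sketchBuffer) Del(key string) {
	b.seq++
	b.writes[key] = sketchWrite{key: key, seq: b.seq}
}

// Get serves read-your-writes from the buffer (benefit 3 above). The third
// return value reports whether the buffer had an opinion about the key at all;
// if not, the read would have to go to the leaseholder.
func (b *sketchBuffer) Get(key string) (string, bool, bool) {
	w, ok := b.writes[key]
	if !ok {
		return "", false, false
	}
	return w.val, w.present, true
}

// Flush returns the buffered writes ordered by sequence number, mirroring how
// flushWithEndTxn prepends the buffer to the commit batch.
func (b *sketchBuffer) Flush() []sketchWrite {
	out := make([]sketchWrite, 0, len(b.writes))
	for _, w := range b.writes {
		out = append(out, w)
	}
	sort.Slice(out, func(i, j int) bool { return out[i].seq < out[j].seq })
	return out
}

func main() {
	b := newSketchBuffer()
	b.Put("a", "1")
	b.Put("a", "2") // redundant write: only the last value is flushed
	b.Del("b")
	if v, present, served := b.Get("a"); served && present {
		fmt.Println("read-your-writes:", v) // read-your-writes: 2
	}
	fmt.Println("flushed at commit:", b.Flush())
}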
+func (twb *txnWriteBuffer) servePointReadsAndAddWritesToBuffer( + ctx context.Context, ba *kvpb.BatchRequest, +) ( + baRemote *kvpb.BatchRequest, + brLocal *kvpb.BatchResponse, + localPositions []int, + pErr *kvpb.Error, +) { + baRemote = ba.ShallowCopy() + baRemote.Requests = nil + brLocal = &kvpb.BatchResponse{} + for i, ru := range ba.Requests { + req := ru.GetInner() + seek := twb.makeBufferSeekFor(req.Header()) + prevLocalLen := len(brLocal.Responses) + switch t := req.(type) { + case *kvpb.GetRequest: + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + if it.Valid() { + var ru kvpb.ResponseUnion + getResp := &kvpb.GetResponse{} + if it.Cur().val.IsPresent() { + getResp.Value = it.Cur().valPtr() + } + ru.MustSetInner(getResp) + brLocal.Responses = append(brLocal.Responses, ru) + } else { + baRemote.Requests = append(baRemote.Requests, ru) + } + + case *kvpb.ScanRequest: + // Hack: just flush overlapping writes for now. + var baFlush *kvpb.BatchRequest + for { + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + if !it.Valid() { + break + } + if baFlush == nil { + baFlush = baRemote.ShallowCopy() + baFlush.Requests = nil + baFlush.MaxSpanRequestKeys = 0 + baFlush.TargetBytes = 0 + } + baFlush.Requests = append(baFlush.Requests, it.Cur().toRequest()) + twb.buf.Delete(it.Cur()) + } + if baFlush != nil { + sort.SliceStable(baFlush.Requests, func(i, j int) bool { + return baFlush.Requests[i].GetInner().Header().Sequence < baFlush.Requests[j].GetInner().Header().Sequence + }) + brFlush, pErr := twb.wrapped.SendLocked(ctx, baFlush) + if pErr != nil { + return nil, nil, nil, pErr + } + baRemote.Txn.Update(brFlush.Txn) + } + // Send the request, then augment the response. + baRemote.Requests = append(baRemote.Requests, ru) + + case *kvpb.ReverseScanRequest: + // Hack: just flush overlapping writes for now. + var baFlush *kvpb.BatchRequest + for { + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + if !it.Valid() { + break + } + if baFlush == nil { + baFlush = baRemote.ShallowCopy() + baFlush.Requests = nil + baFlush.MaxSpanRequestKeys = 0 + baFlush.TargetBytes = 0 + } + baFlush.Requests = append(baFlush.Requests, it.Cur().toRequest()) + twb.buf.Delete(it.Cur()) + } + if baFlush != nil { + sort.SliceStable(baFlush.Requests, func(i, j int) bool { + return baFlush.Requests[i].GetInner().Header().Sequence < baFlush.Requests[j].GetInner().Header().Sequence + }) + brFlush, pErr := twb.wrapped.SendLocked(ctx, baFlush) + if pErr != nil { + return nil, nil, nil, pErr + } + baRemote.Txn.Update(brFlush.Txn) + } + // Send the request, then augment the response. + baRemote.Requests = append(baRemote.Requests, ru) + + case *kvpb.PutRequest: + var ru kvpb.ResponseUnion + ru.MustSetInner(&kvpb.PutResponse{}) + brLocal.Responses = append(brLocal.Responses, ru) + + twb.addToBuffer(t.Key, t.Value, t.Sequence) + + case *kvpb.DeleteRequest: + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + var ru kvpb.ResponseUnion + ru.MustSetInner(&kvpb.DeleteResponse{ + // NOTE: this is incorrect. We aren't considering values that are + // present in the KV store. This is fine for the prototype. 
+ FoundKey: it.Valid() && it.Cur().val.IsPresent(), + }) + brLocal.Responses = append(brLocal.Responses, ru) + + twb.addToBuffer(t.Key, roachpb.Value{}, t.Sequence) + + case *kvpb.ConditionalPutRequest: + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + if it.Valid() { + expBytes := t.ExpBytes + existVal := it.Cur().val + if expValPresent, existValPresent := len(expBytes) != 0, existVal.IsPresent(); expValPresent && existValPresent { + if !bytes.Equal(expBytes, existVal.TagAndDataBytes()) { + return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{ + ActualValue: it.Cur().valPtr(), + }) + } + } else if expValPresent != existValPresent && (existValPresent || !t.AllowIfDoesNotExist) { + return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{ + ActualValue: it.Cur().valPtr(), + }) + } + var ru kvpb.ResponseUnion + ru.MustSetInner(&kvpb.ConditionalPutResponse{}) + brLocal.Responses = append(brLocal.Responses, ru) + + twb.addToBuffer(t.Key, t.Value, t.Sequence) + } else { + baRemote.Requests = append(baRemote.Requests, ru) + } + + case *kvpb.InitPutRequest: + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + if it.Valid() { + failOnTombstones := t.FailOnTombstones + existVal := it.Cur().val + if failOnTombstones && !existVal.IsPresent() { + // We found a tombstone and failOnTombstones is true: fail. + return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{ + ActualValue: it.Cur().valPtr(), + }) + } + if existVal.IsPresent() && !existVal.EqualTagAndData(t.Value) { + // The existing value does not match the supplied value. + return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{ + ActualValue: it.Cur().valPtr(), + }) + } + var ru kvpb.ResponseUnion + ru.MustSetInner(&kvpb.InitPutResponse{}) + brLocal.Responses = append(brLocal.Responses, ru) + + twb.addToBuffer(t.Key, t.Value, t.Sequence) + } else { + baRemote.Requests = append(baRemote.Requests, ru) + } + + case *kvpb.IncrementRequest: + it := twb.buf.MakeIter() + it.FirstOverlap(seek) + if it.Valid() { + log.Fatalf(ctx, "unhandled buffered write overlap with increment") + } + baRemote.Requests = append(baRemote.Requests, ru) + + case *kvpb.DeleteRangeRequest: + it := twb.buf.MakeIter() + for it.FirstOverlap(seek); it.Valid(); it.NextOverlap(seek) { + log.Fatalf(ctx, "unhandled buffered write overlap with delete range") + } + baRemote.Requests = append(baRemote.Requests, ru) + + default: + log.Fatalf(ctx, "unexpected request type: %T", req) + } + if len(brLocal.Responses) != prevLocalLen { + localPositions = append(localPositions, i) + } + } + return baRemote, brLocal, localPositions, nil +} + +// augmentRangedReadsFromBuffer augments the responses to ranged reads in the +// BatchResponse with reads from the buffer. +func (twb *txnWriteBuffer) augmentRangedReadsFromBuffer( + ctx context.Context, baRemote *kvpb.BatchRequest, brRemote *kvpb.BatchResponse, +) { + for _, ru := range baRemote.Requests { + req := ru.GetInner() + switch req.(type) { + case *kvpb.ScanRequest, *kvpb.ReverseScanRequest: + // Send the request, then augment the response. + seek := twb.makeBufferSeekFor(req.Header()) + it := twb.buf.MakeIter() + for it.FirstOverlap(seek); it.Valid(); it.NextOverlap(seek) { + log.Fatalf(ctx, "unhandled buffered write overlap with scan / reverse scan") + } + } + } +} + +// mergeLocalAndRemoteResponses merges the responses to locally serviceable +// requests with the responses to remotely serviceable requests. It returns the +// merged BatchResponse. 
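The localPositions bookkeeping is easiest to see in isolation. The following standalone sketch, not part of the patch, interleaves locally and remotely produced responses back into the original batch order using the recorded positions of the locally served requests; plain strings stand in for kvpb.ResponseUnion values.

package main

import "fmt"

// mergeByPosition interleaves locally produced responses back into the
// remotely produced ones, using the original-batch positions recorded for the
// local responses. It mirrors the shape of mergeLocalAndRemoteResponses.
func mergeByPosition(local, remote []string, localPositions []int) []string {
	merged := make([]string, len(local)+len(remote))
	for i := range merged {
		if len(localPositions) > 0 && i == localPositions[0] {
			merged[i] = local[0]
			local, localPositions = local[1:], localPositions[1:]
		} else {
			merged[i] = remote[0]
			remote = remote[1:]
		}
	}
	return merged
}

func main() {
	// Original batch: [Put a, Get b, Put c, Scan d-e]. The two Puts were
	// buffered and answered locally (positions 0 and 2); the Get and the Scan
	// had to go to the KV layer.
	local := []string{"PutResponse(a)", "PutResponse(c)"}
	remote := []string{"GetResponse(b)", "ScanResponse(d-e)"}
	fmt.Println(mergeByPosition(local, remote, []int{0, 2}))
	// Output: [PutResponse(a) GetResponse(b) PutResponse(c) ScanResponse(d-e)]
}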
+func (twb *txnWriteBuffer) mergeLocalAndRemoteResponses( + ctx context.Context, brLocal, brRemote *kvpb.BatchResponse, localPositions []int, +) *kvpb.BatchResponse { + if brLocal == nil { + return brRemote + } + mergedResps := make([]kvpb.ResponseUnion, len(brLocal.Responses)+len(brRemote.Responses)) + for i := range mergedResps { + if len(localPositions) > 0 && i == localPositions[0] { + mergedResps[i] = brLocal.Responses[0] + brLocal.Responses = brLocal.Responses[1:] + localPositions = localPositions[1:] + } else { + mergedResps[i] = brRemote.Responses[0] + brRemote.Responses = brRemote.Responses[1:] + } + } + if len(brRemote.Responses) > 0 || len(brLocal.Responses) > 0 || len(localPositions) > 0 { + log.Fatalf(ctx, "unexpected leftover responses: %d remote, %d local, %d positions", + len(brRemote.Responses), len(brLocal.Responses), len(localPositions)) + } + brRemote.Responses = mergedResps + return brRemote +} + +// setWrapped implements the txnInterceptor interface. +func (twb *txnWriteBuffer) setWrapped(wrapped lockedSender) { + twb.wrapped = wrapped +} + +// populateLeafInputState implements the txnInterceptor interface. +func (twb *txnWriteBuffer) populateLeafInputState(*roachpb.LeafTxnInputState) { + // TODO(nvanbenschoten): send buffered writes to LeafTxns. +} + +// populateLeafFinalState implements the txnInterceptor interface. +func (twb *txnWriteBuffer) populateLeafFinalState(*roachpb.LeafTxnFinalState) { + // TODO(nvanbenschoten): ingest buffered writes in LeafTxns. +} + +// importLeafFinalState implements the txnInterceptor interface. +func (twb *txnWriteBuffer) importLeafFinalState(context.Context, *roachpb.LeafTxnFinalState) error { + return nil +} + +// epochBumpedLocked implements the txnInterceptor interface. +func (twb *txnWriteBuffer) epochBumpedLocked() { + twb.buf.Reset() +} + +// createSavepointLocked implements the txnInterceptor interface. +func (twb *txnWriteBuffer) createSavepointLocked(ctx context.Context, s *savepoint) {} + +// rollbackToSavepointLocked implements the txnInterceptor interface. +func (twb *txnWriteBuffer) rollbackToSavepointLocked(ctx context.Context, s savepoint) { + // TODO(nvanbenschoten): clear out writes after the savepoint. This means that + // we need to retain multiple writes on the same key, so that we can roll one + // back if a savepoint is rolled back. That complicates logic above, but not + // overly so. +} + +// closeLocked implements the txnInterceptor interface. +func (twb *txnWriteBuffer) closeLocked() { + twb.buf.Reset() +} + +func (twb *txnWriteBuffer) makeBufferSeekFor(rh kvpb.RequestHeader) *bufferedWrite { + seek := &twb.bufSeek + seek.key = rh.Key + seek.endKey = rh.EndKey + return seek +} + +func (twb *txnWriteBuffer) addToBuffer(key roachpb.Key, val roachpb.Value, seq enginepb.TxnSeq) { + it := twb.buf.MakeIter() + seek := &twb.bufSeek + seek.key = key + it.FirstOverlap(seek) + if it.Valid() { + // If we have a write for the same key, update it. This is incorrect, as + // it does not handle savepoints, but it makes the prototype simpler. + bw := it.Cur() + bw.val = val + bw.seq = seq + } else { + twb.bufIDAlloc++ + twb.buf.Set(&bufferedWrite{ + id: twb.bufIDAlloc, + key: key, + val: val, + seq: seq, + }) + } +} + +// bufferedWrite is a key-value pair with an associated sequence number. 
+type bufferedWrite struct { + id uint64 + key roachpb.Key + endKey roachpb.Key // used in btree iteration + val roachpb.Value + seq enginepb.TxnSeq +} + +//go:generate ../../../util/interval/generic/gen.sh *bufferedWrite kvcoord + +// Methods required by util/interval/generic type contract. +func (bw *bufferedWrite) ID() uint64 { return bw.id } +func (bw *bufferedWrite) Key() []byte { return bw.key } +func (bw *bufferedWrite) EndKey() []byte { return bw.endKey } +func (bw *bufferedWrite) String() string { return "todo" } +func (bw *bufferedWrite) New() *bufferedWrite { return new(bufferedWrite) } +func (bw *bufferedWrite) SetID(v uint64) { bw.id = v } +func (bw *bufferedWrite) SetKey(v []byte) { bw.key = v } +func (bw *bufferedWrite) SetEndKey(v []byte) { bw.endKey = v } + +func (bw *bufferedWrite) toRequest() kvpb.RequestUnion { + var ru kvpb.RequestUnion + if bw.val.IsPresent() { + putAlloc := new(struct { + put kvpb.PutRequest + union kvpb.RequestUnion_Put + }) + putAlloc.put.Key = bw.key + putAlloc.put.Value = bw.val + putAlloc.put.Sequence = bw.seq + putAlloc.union.Put = &putAlloc.put + ru.Value = &putAlloc.union + } else { + delAlloc := new(struct { + del kvpb.DeleteRequest + union kvpb.RequestUnion_Delete + }) + delAlloc.del.Key = bw.key + delAlloc.del.Sequence = bw.seq + delAlloc.union.Delete = &delAlloc.del + ru.Value = &delAlloc.union + } + return ru +} + +func (bw *bufferedWrite) valPtr() *roachpb.Value { + valCpy := bw.val + return &valCpy +} diff --git a/pkg/sql/row/inserter.go b/pkg/sql/row/inserter.go index e9a5a557892a..c45a7f2c0c6d 100644 --- a/pkg/sql/row/inserter.go +++ b/pkg/sql/row/inserter.go @@ -201,7 +201,7 @@ func (ri *Inserter) InsertRow( for i := range entries { e := &entries[i] - if ri.Helper.Indexes[idx].ForcePut() { + if ri.Helper.Indexes[idx].ForcePut() || true { // See the comment on (catalog.Index).ForcePut() for more details. insertPutFn(ctx, b, &e.Key, &e.Value, traceKV) } else { diff --git a/pkg/sql/row/updater.go b/pkg/sql/row/updater.go index 55cb1e782c0b..d06bbfbee850 100644 --- a/pkg/sql/row/updater.go +++ b/pkg/sql/row/updater.go @@ -421,18 +421,14 @@ func (ru *Updater) UpdateRow( continue } - if index.ForcePut() { + if index.ForcePut() || expValue == nil { // See the comment on (catalog.Index).ForcePut() for more details. insertPutFn(ctx, putter, &newEntry.Key, &newEntry.Value, traceKV) } else { if traceKV { k := keys.PrettyPrint(ru.Helper.secIndexValDirs[i], newEntry.Key) v := newEntry.Value.PrettyPrint() - if expValue != nil { - log.VEventf(ctx, 2, "CPut %s -> %v (replacing %v, if exists)", k, v, oldEntry.Value.PrettyPrint()) - } else { - log.VEventf(ctx, 2, "CPut %s -> %v (expecting does not exist)", k, v) - } + log.VEventf(ctx, 2, "CPut %s -> %v (replacing %v, if exists)", k, v, oldEntry.Value.PrettyPrint()) } batch.CPutAllowingIfNotExists(newEntry.Key, &newEntry.Value, expValue) }
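To close the loop, here is a sketch, not part of the patch, of how the new interceptor might be exercised on its own, in the style of the other txn interceptor unit tests in this package. The fakeSender helper, the test name, and the exact assertions are assumptions made for illustration; the lockedSender contract is inferred from the calls in txn_interceptor_write_buffer.go.

package kvcoord

import (
	"context"
	"testing"

	"github.com/cockroachdb/cockroach/pkg/kv/kvpb"
	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/stretchr/testify/require"
)

// fakeSender is a hypothetical stand-in for the wrapped lockedSender: it
// records the batches that reach the KV layer and replies with empty
// responses.
type fakeSender struct {
	sent []*kvpb.BatchRequest
}

func (s *fakeSender) SendLocked(
	ctx context.Context, ba *kvpb.BatchRequest,
) (*kvpb.BatchResponse, *kvpb.Error) {
	s.sent = append(s.sent, ba)
	return ba.CreateReply(), nil
}

// TestTxnWriteBufferSketch is a hypothetical test (not part of the patch)
// walking through the interceptor's three basic behaviors: buffering a Put,
// serving a Get from the buffer, and flushing the buffer at commit.
func TestTxnWriteBufferSketch(t *testing.T) {
	ctx := context.Background()
	sender := &fakeSender{}
	twb := txnWriteBuffer{enabled: true}
	twb.setWrapped(sender)

	keyA := roachpb.Key("a")

	// 1. The Put is buffered on the gateway; nothing reaches the wrapped sender.
	ba := &kvpb.BatchRequest{}
	ba.Add(&kvpb.PutRequest{
		RequestHeader: kvpb.RequestHeader{Key: keyA},
		Value:         roachpb.MakeValueFromString("val"),
	})
	br, pErr := twb.SendLocked(ctx, ba)
	require.Nil(t, pErr)
	require.Len(t, br.Responses, 1)
	require.Len(t, sender.sent, 0)

	// 2. A Get on the same key is served from the buffer (read-your-writes).
	ba = &kvpb.BatchRequest{}
	ba.Add(&kvpb.GetRequest{RequestHeader: kvpb.RequestHeader{Key: keyA}})
	br, pErr = twb.SendLocked(ctx, ba)
	require.Nil(t, pErr)
	require.NotNil(t, br.Responses[0].GetInner().(*kvpb.GetResponse).Value)
	require.Len(t, sender.sent, 0)

	// 3. Committing flushes the buffered Put together with the EndTxn.
	ba = &kvpb.BatchRequest{}
	ba.Add(&kvpb.EndTxnRequest{Commit: true})
	_, pErr = twb.SendLocked(ctx, ba)
	require.Nil(t, pErr)
	require.Len(t, sender.sent, 1)
	require.Len(t, sender.sent[0].Requests, 2) // the flushed Put, then the EndTxn
}

The assertion in the final step relies on flushWithEndTxn prepending the buffered writes ahead of the EndTxn request, as implemented above.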