diff --git a/build/bazelutil/check.sh b/build/bazelutil/check.sh
index e2703d20ac35..53b7eee0ed91 100755
--- a/build/bazelutil/check.sh
+++ b/build/bazelutil/check.sh
@@ -24,6 +24,7 @@ pkg/roachprod/prometheus/prometheus.go://go:generate mockgen -package=prometheus
pkg/cmd/roachtest/clusterstats/collector.go://go:generate mockgen -package=clusterstats -destination mocks_generated_test.go github.com/cockroachdb/cockroach/pkg/roachprod/prometheus Client
pkg/cmd/roachtest/tests/drt.go://go:generate mockgen -package tests -destination drt_generated_test.go github.com/cockroachdb/cockroach/pkg/roachprod/prometheus Client
pkg/kv/kvclient/kvcoord/transport.go://go:generate mockgen -package=kvcoord -destination=mocks_generated_test.go . Transport
+pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go://go:generate ../../../util/interval/generic/gen.sh *bufferedWrite kvcoord
pkg/kv/kvclient/rangecache/range_cache.go://go:generate mockgen -package=rangecachemock -destination=rangecachemock/mocks_generated.go . RangeDescriptorDB
pkg/kv/kvclient/rangefeed/rangefeed.go://go:generate mockgen -destination=mocks_generated_test.go --package=rangefeed . DB
pkg/kv/kvserver/concurrency/lock_table.go://go:generate ../../../util/interval/generic/gen.sh *keyLocks concurrency
diff --git a/docs/generated/settings/settings-for-tenants.txt b/docs/generated/settings/settings-for-tenants.txt
index f2c75846ef9a..a49792b619fe 100644
--- a/docs/generated/settings/settings-for-tenants.txt
+++ b/docs/generated/settings/settings-for-tenants.txt
@@ -91,6 +91,7 @@ kv.protectedts.reconciliation.interval duration 5m0s the frequency for reconcili
kv.rangefeed.client.stream_startup_rate integer 100 controls the rate per second the client will initiate new rangefeed stream for a single range; 0 implies unlimited application
kv.rangefeed.closed_timestamp_refresh_interval duration 3s the interval at which closed-timestamp updates are delivered to rangefeeds; set to 0 to use kv.closed_timestamp.side_transport_interval system-visible
kv.rangefeed.enabled boolean false if set, rangefeed registration is enabled system-visible
+kv.transaction.buffered_writes.enabled boolean true if enabled, transactional writes are buffered on the gateway application
kv.transaction.max_intents_and_locks integer 0 maximum count of intents or durable locks for a single transaction, 0 to disable application
kv.transaction.max_intents_bytes integer 4194304 maximum number of bytes used to track locks in transactions application
kv.transaction.max_refresh_spans_bytes integer 4194304 maximum number of bytes used to track refresh spans in serializable transactions application
diff --git a/docs/generated/settings/settings.html b/docs/generated/settings/settings.html
index fe10f493845c..2032b0581b25 100644
--- a/docs/generated/settings/settings.html
+++ b/docs/generated/settings/settings.html
@@ -120,6 +120,7 @@
kv.replication_reports.interval
| duration | 1m0s | the frequency for generating the replication_constraint_stats, replication_stats_report and replication_critical_localities reports (set to 0 to disable) | Dedicated/Self-Hosted |
kv.snapshot_rebalance.max_rate
| byte size | 32 MiB | the rate limit (bytes/sec) to use for rebalance and upreplication snapshots | Dedicated/Self-Hosted |
kv.snapshot_receiver.excise.enabled
| boolean | true | set to false to disable excises in place of range deletions for KV snapshots | Dedicated/Self-Hosted |
+kv.transaction.buffered_writes.enabled
| boolean | true | if enabled, transactional writes are buffered on the gateway | Serverless/Dedicated/Self-Hosted |
kv.transaction.max_intents_and_locks
| integer | 0 | maximum count of intents or durable locks for a single transaction, 0 to disable | Serverless/Dedicated/Self-Hosted |
kv.transaction.max_intents_bytes
| integer | 4194304 | maximum number of bytes used to track locks in transactions | Serverless/Dedicated/Self-Hosted |
kv.transaction.max_refresh_spans_bytes
| integer | 4194304 | maximum number of bytes used to track refresh spans in serializable transactions | Serverless/Dedicated/Self-Hosted |
diff --git a/pkg/gen/misc.bzl b/pkg/gen/misc.bzl
index 8163f27a4bb6..e3470dfd2916 100644
--- a/pkg/gen/misc.bzl
+++ b/pkg/gen/misc.bzl
@@ -4,6 +4,8 @@ MISC_SRCS = [
"//pkg/backup:data_driven_generated_test.go",
"//pkg/ccl/kvccl/kvtenantccl/upgradeinterlockccl:generated_test.go",
"//pkg/internal/team:TEAMS.yaml",
+ "//pkg/kv/kvclient/kvcoord:bufferedwrite_interval_btree.go",
+ "//pkg/kv/kvclient/kvcoord:bufferedwrite_interval_btree_test.go",
"//pkg/kv/kvpb:batch_generated.go",
"//pkg/kv/kvserver/concurrency:keylocks_interval_btree.go",
"//pkg/kv/kvserver/concurrency:keylocks_interval_btree_test.go",
diff --git a/pkg/kv/kvclient/kvcoord/BUILD.bazel b/pkg/kv/kvclient/kvcoord/BUILD.bazel
index 8b926c105187..fe80799e530b 100644
--- a/pkg/kv/kvclient/kvcoord/BUILD.bazel
+++ b/pkg/kv/kvclient/kvcoord/BUILD.bazel
@@ -2,6 +2,7 @@ load("@bazel_gomock//:gomock.bzl", "gomock")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("//build:STRINGER.bzl", "stringer")
load("//pkg/testutils:buildutil/buildutil.bzl", "disallowed_imports_test")
+load("//pkg/util/interval/generic:gen.bzl", "gen_interval_btree")
go_library(
name = "kvcoord",
@@ -31,8 +32,10 @@ go_library(
"txn_interceptor_pipeliner.go",
"txn_interceptor_seq_num_allocator.go",
"txn_interceptor_span_refresher.go",
+ "txn_interceptor_write_buffer.go",
"txn_lock_gatekeeper.go",
"txn_metrics.go",
+ ":bufferedwrite_interval_btree.go", # keep
":gen-txnstate-stringer", # keep
],
importpath = "github.com/cockroachdb/cockroach/pkg/kv/kvclient/kvcoord",
@@ -152,6 +155,7 @@ go_test(
"txn_interceptor_seq_num_allocator_test.go",
"txn_interceptor_span_refresher_test.go",
"txn_test.go",
+ ":bufferedwrite_interval_btree.go", # keep
":mock_kvcoord", # keep
],
data = glob(["testdata/**"]),
@@ -255,6 +259,12 @@ stringer(
typ = "txnState",
)
+gen_interval_btree(
+ name = "buffered_write_interval_btree",
+ package = "kvcoord",
+ type = "*bufferedWrite",
+)
+
disallowed_imports_test(
"kvcoord",
disallowed_list = [
diff --git a/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree.go b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree.go
new file mode 100644
index 000000000000..64d19ec7e15a
--- /dev/null
+++ b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree.go
@@ -0,0 +1,1170 @@
+// Code generated by go_generics. DO NOT EDIT.
+
+// Copyright 2020 The Cockroach Authors.
+//
+// Use of this software is governed by the CockroachDB Software License
+// included in the /LICENSE file.
+
+package kvcoord
+
+import (
+ "bytes"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+)
+
+// nilT is a nil instance of the Template type.
+var nilT *bufferedWrite
+
+const (
+ degree = 16
+ maxItems = 2*degree - 1
+ minItems = degree - 1
+)
+
+// compare returns a value indicating the sort order relationship between
+// a and b. The comparison is performed lexicographically on
+//
+// (a.Key(), a.EndKey(), a.ID())
+//
+// and
+//
+// (b.Key(), b.EndKey(), b.ID())
+//
+// tuples.
+//
+// Given c = compare(a, b):
+//
+// c == -1 if (a.Key(), a.EndKey(), a.ID()) < (b.Key(), b.EndKey(), b.ID())
+// c == 0 if (a.Key(), a.EndKey(), a.ID()) == (b.Key(), b.EndKey(), b.ID())
+// c == 1 if (a.Key(), a.EndKey(), a.ID()) > (b.Key(), b.EndKey(), b.ID())
+func compare(a, b *bufferedWrite) int {
+ c := bytes.Compare(a.Key(), b.Key())
+ if c != 0 {
+ return c
+ }
+ c = bytes.Compare(a.EndKey(), b.EndKey())
+ if c != 0 {
+ return c
+ }
+ if a.ID() < b.ID() {
+ return -1
+ } else if a.ID() > b.ID() {
+ return 1
+ } else {
+ return 0
+ }
+}
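
The ordering above is easiest to see with a small sketch. The helper below is hypothetical (not part of the generated file); it assumes the accessor/setter contract the generic template requires of *bufferedWrite (New, SetKey, SetEndKey, SetID), the same methods the companion test file uses, plus the roachpb import.

```go
// compareExample is an illustrative, in-package sketch of the
// (Key, EndKey, ID) ordering implemented by compare above.
func compareExample() {
	a, b := nilT.New(), nilT.New()
	a.SetKey(roachpb.Key("a"))
	b.SetKey(roachpb.Key("a"))
	b.SetEndKey(roachpb.Key("c"))
	_ = compare(a, b) // -1: start keys tie, and a has no end key

	c := nilT.New()
	c.SetKey(roachpb.Key("a"))
	a.SetID(1)
	c.SetID(2)
	_ = compare(a, c) // -1: keys and end keys tie, so the smaller ID sorts first
}
```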
+
+// keyBound represents the upper-bound of a key range.
+type keyBound struct {
+ key []byte
+ inc bool
+}
+
+func (b keyBound) compare(o keyBound) int {
+ c := bytes.Compare(b.key, o.key)
+ if c != 0 {
+ return c
+ }
+ if b.inc == o.inc {
+ return 0
+ }
+ if b.inc {
+ return 1
+ }
+ return -1
+}
+
+func (b keyBound) contains(a *bufferedWrite) bool {
+ c := bytes.Compare(a.Key(), b.key)
+ if c == 0 {
+ return b.inc
+ }
+ return c < 0
+}
+
+func upperBound(c *bufferedWrite) keyBound {
+ if len(c.EndKey()) != 0 {
+ return keyBound{key: c.EndKey()}
+ }
+ return keyBound{key: c.Key(), inc: true}
+}
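
As a quick illustration of the bound convention, this hypothetical sketch (under the same *bufferedWrite contract as above) shows that a point write yields an inclusive bound on its own key, while a ranged write yields an exclusive bound at its end key.

```go
// upperBoundExample is an illustrative sketch of upperBound's convention.
func upperBoundExample() {
	point := nilT.New()
	point.SetKey(roachpb.Key("a"))
	_ = upperBound(point) // keyBound{key: "a", inc: true}

	ranged := nilT.New()
	ranged.SetKey(roachpb.Key("a"))
	ranged.SetEndKey(roachpb.Key("c"))
	_ = upperBound(ranged) // keyBound{key: "c", inc: false}
}
```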
+
+type node struct {
+ ref int32
+ count int16
+
+ // These fields form a keyBound, but by inlining them into node we can avoid
+ // the extra word that would be needed to pad out maxInc if it were part of
+ // its own struct.
+ maxInc bool
+ maxKey []byte
+
+ items [maxItems]*bufferedWrite
+
+ // The children array pointer is only populated for interior nodes; it is nil
+ // for leaf nodes.
+ children *childrenArray
+}
+
+type childrenArray = [maxItems + 1]*node
+
+var leafPool = sync.Pool{
+ New: func() interface{} {
+ return new(node)
+ },
+}
+
+var nodePool = sync.Pool{
+ New: func() interface{} {
+ type interiorNode struct {
+ node
+ children childrenArray
+ }
+ n := new(interiorNode)
+ n.node.children = &n.children
+ return &n.node
+ },
+}
+
+func newLeafNode() *node {
+ n := leafPool.Get().(*node)
+ n.ref = 1
+ return n
+}
+
+func newNode() *node {
+ n := nodePool.Get().(*node)
+ n.ref = 1
+ return n
+}
+
+// mut creates and returns a mutable node reference. If the node is not shared
+// with any other trees then it can be modified in place. Otherwise, it must be
+// cloned to ensure unique ownership. In this way, we enforce a copy-on-write
+// policy which transparently incorporates the idea of local mutations, like
+// Clojure's transients or Haskell's ST monad, where nodes are only copied
+// during the first time that they are modified between Clone operations.
+//
+// When a node is cloned, the provided pointer will be redirected to the new
+// mutable node.
+func mut(n **node) *node {
+ if atomic.LoadInt32(&(*n).ref) == 1 {
+ // Exclusive ownership. Can mutate in place.
+ return *n
+ }
+ // If we do not have unique ownership over the node then we
+ // clone it to gain unique ownership. After doing so, we can
+ // release our reference to the old node. We pass recursive
+ // as true because even though we just observed the node's
+ // reference count to be greater than 1, we might be racing
+ // with another call to decRef on this node.
+ c := (*n).clone()
+ (*n).decRef(true /* recursive */)
+ *n = c
+ return *n
+}
+
+// leaf returns true if this is a leaf node.
+func (n *node) leaf() bool {
+ return n.children == nil
+}
+
+// max returns the maximum keyBound in the subtree rooted at this node.
+func (n *node) max() keyBound {
+ return keyBound{
+ key: n.maxKey,
+ inc: n.maxInc,
+ }
+}
+
+// setMax sets the maximum keyBound for the subtree rooted at this node.
+func (n *node) setMax(k keyBound) {
+ n.maxKey = k.key
+ n.maxInc = k.inc
+}
+
+// incRef acquires a reference to the node.
+func (n *node) incRef() {
+ atomic.AddInt32(&n.ref, 1)
+}
+
+// decRef releases a reference to the node. If requested, the method
+// will recurse into child nodes and decrease their refcounts as well.
+func (n *node) decRef(recursive bool) {
+ if atomic.AddInt32(&n.ref, -1) > 0 {
+ // Other references remain. Can't free.
+ return
+ }
+ // Clear and release node into memory pool.
+ if n.leaf() {
+ *n = node{}
+ leafPool.Put(n)
+ } else {
+ // Release child references first, if requested.
+ if recursive {
+ for i := int16(0); i <= n.count; i++ {
+ n.children[i].decRef(true /* recursive */)
+ }
+ }
+ *n = node{children: n.children}
+ *n.children = childrenArray{}
+ nodePool.Put(n)
+ }
+}
+
+// clone creates a clone of the receiver with a single reference count.
+func (n *node) clone() *node {
+ var c *node
+ if n.leaf() {
+ c = newLeafNode()
+ } else {
+ c = newNode()
+ }
+ // NB: copy field-by-field without touching n.ref to avoid
+ // triggering the race detector and looking like a data race.
+ c.count = n.count
+ c.maxKey = n.maxKey
+ c.maxInc = n.maxInc
+ c.items = n.items
+ if !c.leaf() {
+ // Copy children and increase each refcount.
+ *c.children = *n.children
+ for i := int16(0); i <= c.count; i++ {
+ c.children[i].incRef()
+ }
+ }
+ return c
+}
+
+func (n *node) insertAt(index int, item *bufferedWrite, nd *node) {
+ if index < int(n.count) {
+ copy(n.items[index+1:n.count+1], n.items[index:n.count])
+ if !n.leaf() {
+ copy(n.children[index+2:n.count+2], n.children[index+1:n.count+1])
+ }
+ }
+ n.items[index] = item
+ if !n.leaf() {
+ n.children[index+1] = nd
+ }
+ n.count++
+}
+
+func (n *node) pushBack(item *bufferedWrite, nd *node) {
+ n.items[n.count] = item
+ if !n.leaf() {
+ n.children[n.count+1] = nd
+ }
+ n.count++
+}
+
+func (n *node) pushFront(item *bufferedWrite, nd *node) {
+ if !n.leaf() {
+ copy(n.children[1:n.count+2], n.children[:n.count+1])
+ n.children[0] = nd
+ }
+ copy(n.items[1:n.count+1], n.items[:n.count])
+ n.items[0] = item
+ n.count++
+}
+
+// removeAt removes a value at a given index, pulling all subsequent values
+// back.
+func (n *node) removeAt(index int) (*bufferedWrite, *node) {
+ var child *node
+ if !n.leaf() {
+ child = n.children[index+1]
+ copy(n.children[index+1:n.count], n.children[index+2:n.count+1])
+ n.children[n.count] = nil
+ }
+ n.count--
+ out := n.items[index]
+ copy(n.items[index:n.count], n.items[index+1:n.count+1])
+ n.items[n.count] = nilT
+ return out, child
+}
+
+// popBack removes and returns the last element in the list.
+func (n *node) popBack() (*bufferedWrite, *node) {
+ n.count--
+ out := n.items[n.count]
+ n.items[n.count] = nilT
+ if n.leaf() {
+ return out, nil
+ }
+ child := n.children[n.count+1]
+ n.children[n.count+1] = nil
+ return out, child
+}
+
+// popFront removes and returns the first element in the list.
+func (n *node) popFront() (*bufferedWrite, *node) {
+ n.count--
+ var child *node
+ if !n.leaf() {
+ child = n.children[0]
+ copy(n.children[:n.count+1], n.children[1:n.count+2])
+ n.children[n.count+1] = nil
+ }
+ out := n.items[0]
+ copy(n.items[:n.count], n.items[1:n.count+1])
+ n.items[n.count] = nilT
+ return out, child
+}
+
+// find returns the index where the given item should be inserted into this
+// list. 'found' is true if the item already exists in the list at the given
+// index.
+func (n *node) find(item *bufferedWrite) (index int, found bool) {
+ // Logic copied from sort.Search. Inlining this gave
+ // an 11% speedup on BenchmarkBTreeDeleteInsert.
+ i, j := 0, int(n.count)
+ for i < j {
+ h := int(uint(i+j) >> 1) // avoid overflow when computing h
+ // i ≤ h < j
+ v := compare(item, n.items[h])
+ if v == 0 {
+ return h, true
+ } else if v > 0 {
+ i = h + 1
+ } else {
+ j = h
+ }
+ }
+ return i, false
+}
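
Since the loop above inlines sort.Search for speed, a non-inlined equivalent may make the contract clearer. The method below is a hypothetical, in-package sketch only, not part of the generated file.

```go
// findViaSortSearch is an illustrative equivalent of find using the standard
// library: locate the first item in the node that is >= the search item,
// then test for equality at that position.
func (n *node) findViaSortSearch(item *bufferedWrite) (index int, found bool) {
	i := sort.Search(int(n.count), func(h int) bool {
		return compare(item, n.items[h]) <= 0
	})
	if i < int(n.count) && compare(item, n.items[i]) == 0 {
		return i, true
	}
	return i, false
}
```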
+
+// split splits the given node at the given index. The current node shrinks,
+// and this function returns the item that existed at that index and a new
+// node containing all items/children after it.
+//
+// Before:
+//
+// +-----------+
+// | x y z |
+// +--/-/-\-\--+
+//
+// After:
+//
+// +-----------+
+// | y |
+// +----/-\----+
+// / \
+// v v
+//
+// +-----------+ +-----------+
+// | x | | z |
+// +-----------+ +-----------+
+func (n *node) split(i int) (*bufferedWrite, *node) {
+ out := n.items[i]
+ var next *node
+ if n.leaf() {
+ next = newLeafNode()
+ } else {
+ next = newNode()
+ }
+ next.count = n.count - int16(i+1)
+ copy(next.items[:], n.items[i+1:n.count])
+ for j := int16(i); j < n.count; j++ {
+ n.items[j] = nilT
+ }
+ if !n.leaf() {
+ copy(next.children[:], n.children[i+1:n.count+1])
+ for j := int16(i + 1); j <= n.count; j++ {
+ n.children[j] = nil
+ }
+ }
+ n.count = int16(i)
+
+ nextMax := next.findUpperBound()
+ next.setMax(nextMax)
+ nMax := n.max()
+ if nMax.compare(nextMax) != 0 && nMax.compare(upperBound(out)) != 0 {
+ // If upper bound wasn't from new node or item
+ // at index i, it must still be from old node.
+ } else {
+ n.setMax(n.findUpperBound())
+ }
+ return out, next
+}
+
+// insert inserts an item into the subtree rooted at this node, making sure no
+// nodes in the subtree exceed maxItems items. Returns true if an existing item
+// was replaced and false if an item was inserted. Also returns whether the
+// node's upper bound changes.
+func (n *node) insert(item *bufferedWrite) (replaced, newBound bool) {
+ i, found := n.find(item)
+ if found {
+ n.items[i] = item
+ return true, false
+ }
+ if n.leaf() {
+ n.insertAt(i, item, nil)
+ return false, n.adjustUpperBoundOnInsertion(item, nil)
+ }
+ if n.children[i].count >= maxItems {
+ splitLa, splitNode := mut(&n.children[i]).split(maxItems / 2)
+ n.insertAt(i, splitLa, splitNode)
+
+ switch v := compare(item, n.items[i]); {
+ case v < 0:
+ // no change, we want first split node
+ case v > 0:
+ i++ // we want second split node
+ default:
+ n.items[i] = item
+ return true, false
+ }
+ }
+ replaced, newBound = mut(&n.children[i]).insert(item)
+ if newBound {
+ newBound = n.adjustUpperBoundOnInsertion(item, nil)
+ }
+ return replaced, newBound
+}
+
+// removeMax removes and returns the maximum item from the subtree rooted at
+// this node.
+func (n *node) removeMax() *bufferedWrite {
+ if n.leaf() {
+ n.count--
+ out := n.items[n.count]
+ n.items[n.count] = nilT
+ n.adjustUpperBoundOnRemoval(out, nil)
+ return out
+ }
+ // Recurse into max child.
+ i := int(n.count)
+ if n.children[i].count <= minItems {
+ // Child not large enough to remove from.
+ n.rebalanceOrMerge(i)
+ return n.removeMax() // redo
+ }
+ child := mut(&n.children[i])
+ out := child.removeMax()
+ n.adjustUpperBoundOnRemoval(out, nil)
+ return out
+}
+
+// remove removes an item from the subtree rooted at this node. Returns the item
+// that was removed or nil if no matching item was found. Also returns whether
+// the node's upper bound changes.
+func (n *node) remove(item *bufferedWrite) (out *bufferedWrite, newBound bool) {
+ i, found := n.find(item)
+ if n.leaf() {
+ if found {
+ out, _ = n.removeAt(i)
+ return out, n.adjustUpperBoundOnRemoval(out, nil)
+ }
+ return nilT, false
+ }
+ if n.children[i].count <= minItems {
+ // Child not large enough to remove from.
+ n.rebalanceOrMerge(i)
+ return n.remove(item) // redo
+ }
+ child := mut(&n.children[i])
+ if found {
+ // Replace the item being removed with the max item in our left child.
+ out = n.items[i]
+ n.items[i] = child.removeMax()
+ return out, n.adjustUpperBoundOnRemoval(out, nil)
+ }
+ // Latch is not in this node and child is large enough to remove from.
+ out, newBound = child.remove(item)
+ if newBound {
+ newBound = n.adjustUpperBoundOnRemoval(out, nil)
+ }
+ return out, newBound
+}
+
+// rebalanceOrMerge grows child 'i' to ensure it has sufficient room to remove
+// an item from it while keeping it at or above minItems.
+func (n *node) rebalanceOrMerge(i int) {
+ switch {
+ case i > 0 && n.children[i-1].count > minItems:
+ // Rebalance from left sibling.
+ //
+ // +-----------+
+ // | y |
+ // +----/-\----+
+ // / \
+ // v v
+ // +-----------+ +-----------+
+ // | x | | |
+ // +----------\+ +-----------+
+ // \
+ // v
+ // a
+ //
+ // After:
+ //
+ // +-----------+
+ // | x |
+ // +----/-\----+
+ // / \
+ // v v
+ // +-----------+ +-----------+
+ // | | | y |
+ // +-----------+ +/----------+
+ // /
+ // v
+ // a
+ //
+ left := mut(&n.children[i-1])
+ child := mut(&n.children[i])
+ xLa, grandChild := left.popBack()
+ yLa := n.items[i-1]
+ child.pushFront(yLa, grandChild)
+ n.items[i-1] = xLa
+
+ left.adjustUpperBoundOnRemoval(xLa, grandChild)
+ child.adjustUpperBoundOnInsertion(yLa, grandChild)
+
+ case i < int(n.count) && n.children[i+1].count > minItems:
+ // Rebalance from right sibling.
+ //
+ // +-----------+
+ // | y |
+ // +----/-\----+
+ // / \
+ // v v
+ // +-----------+ +-----------+
+ // | | | x |
+ // +-----------+ +/----------+
+ // /
+ // v
+ // a
+ //
+ // After:
+ //
+ // +-----------+
+ // | x |
+ // +----/-\----+
+ // / \
+ // v v
+ // +-----------+ +-----------+
+ // | y | | |
+ // +----------\+ +-----------+
+ // \
+ // v
+ // a
+ //
+ right := mut(&n.children[i+1])
+ child := mut(&n.children[i])
+ xLa, grandChild := right.popFront()
+ yLa := n.items[i]
+ child.pushBack(yLa, grandChild)
+ n.items[i] = xLa
+
+ right.adjustUpperBoundOnRemoval(xLa, grandChild)
+ child.adjustUpperBoundOnInsertion(yLa, grandChild)
+
+ default:
+ // Merge with either the left or right sibling.
+ //
+ // +-----------+
+ // | u y v |
+ // +----/-\----+
+ // / \
+ // v v
+ // +-----------+ +-----------+
+ // | x | | z |
+ // +-----------+ +-----------+
+ //
+ // After:
+ //
+ // +-----------+
+ // | u v |
+ // +-----|-----+
+ // |
+ // v
+ // +-----------+
+ // | x y z |
+ // +-----------+
+ //
+ if i >= int(n.count) {
+ i = int(n.count - 1)
+ }
+ child := mut(&n.children[i])
+ // Make mergeChild mutable, bumping the refcounts on its children if necessary.
+ _ = mut(&n.children[i+1])
+ mergeLa, mergeChild := n.removeAt(i)
+ child.items[child.count] = mergeLa
+ copy(child.items[child.count+1:], mergeChild.items[:mergeChild.count])
+ if !child.leaf() {
+ copy(child.children[child.count+1:], mergeChild.children[:mergeChild.count+1])
+ }
+ child.count += mergeChild.count + 1
+
+ child.adjustUpperBoundOnInsertion(mergeLa, mergeChild)
+ mergeChild.decRef(false /* recursive */)
+ }
+}
+
+// findUpperBound returns the largest upper bound among this node's items and
+// children, assuming that its children have correct upper bounds already set.
+func (n *node) findUpperBound() keyBound {
+ var max keyBound
+ for i := int16(0); i < n.count; i++ {
+ up := upperBound(n.items[i])
+ if max.compare(up) < 0 {
+ max = up
+ }
+ }
+ if !n.leaf() {
+ for i := int16(0); i <= n.count; i++ {
+ up := n.children[i].max()
+ if max.compare(up) < 0 {
+ max = up
+ }
+ }
+ }
+ return max
+}
+
+// adjustUpperBoundOnInsertion adjusts the upper key bound for this node given
+// an item and an optional child node that was inserted. Returns true if the
+// upper bound was changed and false if not.
+func (n *node) adjustUpperBoundOnInsertion(item *bufferedWrite, child *node) bool {
+ up := upperBound(item)
+ if child != nil {
+ if childMax := child.max(); up.compare(childMax) < 0 {
+ up = childMax
+ }
+ }
+ if n.max().compare(up) < 0 {
+ n.setMax(up)
+ return true
+ }
+ return false
+}
+
+// adjustUpperBoundOnRemoval adjusts the upper key bound for this node given an
+// item and an optional child node that was removed. Returns true if the upper
+// bound was changed and false if not.
+func (n *node) adjustUpperBoundOnRemoval(item *bufferedWrite, child *node) bool {
+ up := upperBound(item)
+ if child != nil {
+ if childMax := child.max(); up.compare(childMax) < 0 {
+ up = childMax
+ }
+ }
+ if n.max().compare(up) == 0 {
+ // up was previous upper bound of n.
+ max := n.findUpperBound()
+ n.setMax(max)
+ return max.compare(up) != 0
+ }
+ return false
+}
+
+// btree is an implementation of an augmented interval B-Tree.
+//
+// btree stores items in an ordered structure, allowing easy insertion,
+// removal, and iteration. It represents intervals and permits an interval
+// search operation following the approach laid out in CLRS, Chapter 14.
+// The B-Tree stores items in order based on their start key and each
+// B-Tree node maintains the upper-bound end key of all items in its
+// subtree.
+//
+// Write operations are not safe for concurrent mutation by multiple
+// goroutines, but Read operations are.
+type btree struct {
+ root *node
+ length int
+}
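
A minimal usage sketch, mirroring how the companion test file drives the tree. The function name is hypothetical, and the item construction assumes the *bufferedWrite setter contract and the roachpb import.

```go
// btreeUsageExample sketches the basic mutation API.
func btreeUsageExample() {
	var tr btree
	w := nilT.New()
	w.SetKey(roachpb.Key("a"))
	w.SetEndKey(roachpb.Key("c"))

	tr.Set(w)    // insert, or replace an equal item
	_ = tr.Len() // 1
	tr.Delete(w) // remove the item equal to w
	tr.Reset()   // drop remaining items and recycle nodes
}
```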
+
+// Reset removes all items from the btree. In doing so, it allows memory
+// held by the btree to be recycled. Failure to call this method before
+// letting a btree be GCed is safe in that it won't cause a memory leak,
+// but it will prevent btree nodes from being efficiently re-used.
+func (t *btree) Reset() {
+ if t.root != nil {
+ t.root.decRef(true /* recursive */)
+ t.root = nil
+ }
+ t.length = 0
+}
+
+// Clone clones the btree, lazily. It does so in constant time.
+func (t *btree) Clone() btree {
+ c := *t
+ if c.root != nil {
+ // Incrementing the reference count on the root node is sufficient to
+ // ensure that no node in the cloned tree can be mutated by an actor
+ // holding a reference to the original tree and vice versa. This
+ // property is upheld because the root node in the receiver btree and
+ // the returned btree will both necessarily have a reference count of at
+ // least 2 when this method returns. All tree mutations recursively
+ // acquire mutable node references (see mut) as they traverse down the
+ // tree. The act of acquiring a mutable node reference performs a clone
+ // if a node's reference count is greater than one. Cloning a node (see
+ // clone) increases the reference count on each of its children,
+ // ensuring that they have a reference count of at least 2. This, in
+ // turn, ensures that any of the child nodes that are modified will also
+ // be copied-on-write, recursively ensuring the immutability property
+ // over the entire tree.
+ c.root.incRef()
+ }
+ return c
+}
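
The copy-on-write behavior described above can be summarized with a short, hypothetical sketch: after Clone, each handle mutates its own copy of any shared nodes, so neither side observes the other's writes.

```go
// cloneExample sketches the copy-on-write semantics of Clone.
func cloneExample(tr *btree, a, b *bufferedWrite) {
	c := tr.Clone() // constant time: only the root's refcount is bumped
	tr.Set(a)       // nodes are copied on write; a is visible in tr only
	c.Set(b)        // b is visible in c only
}
```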
+
+// Delete removes an item equal to the passed in item from the tree.
+func (t *btree) Delete(item *bufferedWrite) {
+ if t.root == nil || t.root.count == 0 {
+ return
+ }
+ if out, _ := mut(&t.root).remove(item); out != nilT {
+ t.length--
+ }
+ if t.root.count == 0 {
+ old := t.root
+ if t.root.leaf() {
+ t.root = nil
+ } else {
+ t.root = t.root.children[0]
+ }
+ old.decRef(false /* recursive */)
+ }
+}
+
+// Set adds the given item to the tree. If an item in the tree already equals
+// the given one, it is replaced with the new item.
+func (t *btree) Set(item *bufferedWrite) {
+ if t.root == nil {
+ t.root = newLeafNode()
+ } else if t.root.count >= maxItems {
+ splitLa, splitNode := mut(&t.root).split(maxItems / 2)
+ newRoot := newNode()
+ newRoot.count = 1
+ newRoot.items[0] = splitLa
+ newRoot.children[0] = t.root
+ newRoot.children[1] = splitNode
+ newRoot.setMax(newRoot.findUpperBound())
+ t.root = newRoot
+ }
+ if replaced, _ := mut(&t.root).insert(item); !replaced {
+ t.length++
+ }
+}
+
+// MakeIter returns a new iterator object. It is not safe to continue using an
+// iterator after modifications are made to the tree. If modifications are made,
+// create a new iterator.
+func (t *btree) MakeIter() iterator {
+ return iterator{r: t.root, pos: -1}
+}
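
A hypothetical sketch of a full forward scan, matching the iteration pattern used throughout the companion test file.

```go
// iterExample sketches a forward scan over the whole tree.
func iterExample(tr *btree) {
	it := tr.MakeIter()
	for it.First(); it.Valid(); it.Next() {
		_ = it.Cur() // items are visited in (Key, EndKey, ID) order
	}
}
```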
+
+// Height returns the height of the tree.
+func (t *btree) Height() int {
+ if t.root == nil {
+ return 0
+ }
+ h := 1
+ n := t.root
+ for !n.leaf() {
+ n = n.children[0]
+ h++
+ }
+ return h
+}
+
+// Len returns the number of items currently in the tree.
+func (t *btree) Len() int {
+ return t.length
+}
+
+// String returns a string description of the tree. The format is
+// similar to the https://en.wikipedia.org/wiki/Newick_format.
+func (t *btree) String() string {
+ if t.length == 0 {
+ return ";"
+ }
+ var b strings.Builder
+ t.root.writeString(&b)
+ return b.String()
+}
+
+func (n *node) writeString(b *strings.Builder) {
+ if n.leaf() {
+ for i := int16(0); i < n.count; i++ {
+ if i != 0 {
+ b.WriteString(",")
+ }
+ b.WriteString(n.items[i].String())
+ }
+ return
+ }
+ for i := int16(0); i <= n.count; i++ {
+ b.WriteString("(")
+ n.children[i].writeString(b)
+ b.WriteString(")")
+ if i < n.count {
+ b.WriteString(n.items[i].String())
+ }
+ }
+}
+
+// iterStack represents a stack of (node, pos) tuples, which captures
+// iteration state as an iterator descends a btree.
+type iterStack struct {
+ a iterStackArr
+ aLen int16 // -1 when using s
+ s []iterFrame
+}
+
+// Used to avoid allocations for stacks below a certain size.
+type iterStackArr [3]iterFrame
+
+type iterFrame struct {
+ n *node
+ pos int16
+}
+
+func (is *iterStack) push(f iterFrame) {
+ if is.aLen == -1 {
+ is.s = append(is.s, f)
+ } else if int(is.aLen) == len(is.a) {
+ is.s = make([]iterFrame, int(is.aLen)+1, 2*int(is.aLen))
+ copy(is.s, is.a[:])
+ is.s[int(is.aLen)] = f
+ is.aLen = -1
+ } else {
+ is.a[is.aLen] = f
+ is.aLen++
+ }
+}
+
+func (is *iterStack) pop() iterFrame {
+ if is.aLen == -1 {
+ f := is.s[len(is.s)-1]
+ is.s = is.s[:len(is.s)-1]
+ return f
+ }
+ is.aLen--
+ return is.a[is.aLen]
+}
+
+func (is *iterStack) len() int {
+ if is.aLen == -1 {
+ return len(is.s)
+ }
+ return int(is.aLen)
+}
+
+func (is *iterStack) reset() {
+ if is.aLen == -1 {
+ is.s = is.s[:0]
+ } else {
+ is.aLen = 0
+ }
+}
+
+// iterator is responsible for search and traversal within a btree.
+type iterator struct {
+ r *node
+ n *node
+ pos int16
+ s iterStack
+ o overlapScan
+}
+
+func (i *iterator) reset() {
+ i.n = i.r
+ i.pos = -1
+ i.s.reset()
+ i.o = overlapScan{}
+}
+
+func (i *iterator) descend(n *node, pos int16) {
+ i.s.push(iterFrame{n: n, pos: pos})
+ i.n = n.children[pos]
+ i.pos = 0
+}
+
+// ascend ascends up to the current node's parent and resets the position
+// to the one previously set for this parent node.
+func (i *iterator) ascend() {
+ f := i.s.pop()
+ i.n = f.n
+ i.pos = f.pos
+}
+
+// SeekGE seeks to the first item greater-than or equal to the provided
+// item.
+func (i *iterator) SeekGE(item *bufferedWrite) {
+ i.reset()
+ if i.n == nil {
+ return
+ }
+ for {
+ pos, found := i.n.find(item)
+ i.pos = int16(pos)
+ if found {
+ return
+ }
+ if i.n.leaf() {
+ if i.pos == i.n.count {
+ i.Next()
+ }
+ return
+ }
+ i.descend(i.n, i.pos)
+ }
+}
+
+// SeekLT seeks to the first item less-than the provided item.
+func (i *iterator) SeekLT(item *bufferedWrite) {
+ i.reset()
+ if i.n == nil {
+ return
+ }
+ for {
+ pos, found := i.n.find(item)
+ i.pos = int16(pos)
+ if found || i.n.leaf() {
+ i.Prev()
+ return
+ }
+ i.descend(i.n, i.pos)
+ }
+}
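
The two seek primitives are exercised by TestBTreeSeek in the companion test file; this hypothetical sketch shows the intended usage.

```go
// seekExample sketches positioned seeks around an item k.
func seekExample(tr *btree, k *bufferedWrite) {
	it := tr.MakeIter()
	it.SeekGE(k) // first item >= k, if any
	if it.Valid() {
		_ = it.Cur()
	}
	it.SeekLT(k) // predecessor of k, if any
	if it.Valid() {
		_ = it.Cur()
	}
}
```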
+
+// First seeks to the first item in the btree.
+func (i *iterator) First() {
+ i.reset()
+ if i.n == nil {
+ return
+ }
+ for !i.n.leaf() {
+ i.descend(i.n, 0)
+ }
+ i.pos = 0
+}
+
+// Last seeks to the last item in the btree.
+func (i *iterator) Last() {
+ i.reset()
+ if i.n == nil {
+ return
+ }
+ for !i.n.leaf() {
+ i.descend(i.n, i.n.count)
+ }
+ i.pos = i.n.count - 1
+}
+
+// Next positions the iterator to the item immediately following
+// its current position.
+func (i *iterator) Next() {
+ if i.n == nil {
+ return
+ }
+
+ if i.n.leaf() {
+ i.pos++
+ if i.pos < i.n.count {
+ return
+ }
+ for i.s.len() > 0 && i.pos >= i.n.count {
+ i.ascend()
+ }
+ return
+ }
+
+ i.descend(i.n, i.pos+1)
+ for !i.n.leaf() {
+ i.descend(i.n, 0)
+ }
+ i.pos = 0
+}
+
+// Prev positions the iterator to the item immediately preceding
+// its current position.
+func (i *iterator) Prev() {
+ if i.n == nil {
+ return
+ }
+
+ if i.n.leaf() {
+ i.pos--
+ if i.pos >= 0 {
+ return
+ }
+ for i.s.len() > 0 && i.pos < 0 {
+ i.ascend()
+ i.pos--
+ }
+ return
+ }
+
+ i.descend(i.n, i.pos)
+ for !i.n.leaf() {
+ i.descend(i.n, i.n.count)
+ }
+ i.pos = i.n.count - 1
+}
+
+// Valid returns whether the iterator is positioned at a valid position.
+func (i *iterator) Valid() bool {
+ return i.pos >= 0 && i.pos < i.n.count
+}
+
+// Cur returns the item at the iterator's current position. It is illegal
+// to call Cur if the iterator is not valid.
+func (i *iterator) Cur() *bufferedWrite {
+ return i.n.items[i.pos]
+}
+
+// An overlap scan is a scan over all items that overlap with the provided
+// item in order of the overlapping items' start keys. The goal of the scan
+// is to minimize the number of key comparisons performed in total. The
+// algorithm operates based on the following two invariants maintained by
+// the augmented interval btree:
+// 1. all items are sorted in the btree based on their start key.
+// 2. all btree nodes maintain the upper bound end key of all items
+// in their subtree.
+//
+// The scan algorithm starts in "unconstrained minimum" and "unconstrained
+// maximum" states. To enter a "constrained minimum" state, the scan must reach
+// items in the tree with start keys above the search range's start key.
+// Because items in the tree are sorted by start key, once the scan enters the
+// "constrained minimum" state it will remain there. To enter a "constrained
+// maximum" state, the scan must determine the first child btree node in a given
+// subtree that can have items with start keys above the search range's end
+// key. The scan then remains in the "constrained maximum" state until it
+// traverses into this child node, at which point it moves to the "unconstrained
+// maximum" state again.
+//
+// The scan algorithm works like a standard btree forward scan with the
+// following augmentations:
+// 1. before traversing the tree, the scan performs a binary search on the
+// root node's items to determine a "soft" lower-bound constraint position
+// and a "hard" upper-bound constraint position in the root's children.
+// 2. when traversing into a child node in the lower or upper bound constraint
+// position, the constraint is refined by searching the child's items.
+// 3. the initial traversal down the tree follows the left-most children
+// whose upper bound end keys are equal to or greater than the start key
+// of the search range. The children followed will be equal to or less
+// than the soft lower bound constraint.
+// 4. once the initial traversal completes and the scan is in the left-most
+// btree node whose upper bound overlaps the search range, key comparisons
+// must be performed with each item in the tree. This is necessary because
+// any of these items may have end keys that cause them to overlap with the
+// search range.
+// 5. once the scan reaches the lower bound constraint position (the first item
+// with a start key equal to or greater than the search range's start key),
+// it can begin scanning without performing key comparisons. This is allowed
+// because all items from this point forward will have end keys that are
+// greater than the search range's start key.
+// 6. once the scan reaches the upper bound constraint position, it terminates.
+// It does so because the item at this position is the first item with a
+// start key larger than the search range's end key.
+type overlapScan struct {
+ // The "soft" lower-bound constraint.
+ constrMinN *node
+ constrMinPos int16
+ constrMinReached bool
+
+ // The "hard" upper-bound constraint.
+ constrMaxN *node
+ constrMaxPos int16
+}
+
+// FirstOverlap seeks to the first item in the btree that overlaps with the
+// provided search item.
+func (i *iterator) FirstOverlap(item *bufferedWrite) {
+ i.reset()
+ if i.n == nil {
+ return
+ }
+ i.pos = 0
+ i.o = overlapScan{}
+ i.constrainMinSearchBounds(item)
+ i.constrainMaxSearchBounds(item)
+ i.findNextOverlap(item)
+}
+
+// NextOverlap positions the iterator to the item immediately following
+// its current position that overlaps with the search item.
+func (i *iterator) NextOverlap(item *bufferedWrite) {
+ if i.n == nil {
+ return
+ }
+ i.pos++
+ i.findNextOverlap(item)
+}
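
Together, FirstOverlap and NextOverlap form the interval query that motivates the augmented tree; this hypothetical sketch mirrors the loop used by the randomized overlap tests.

```go
// overlapExample sketches an overlap scan: it visits every item whose span
// overlaps the query item, in start-key order.
func overlapExample(tr *btree, query *bufferedWrite) {
	it := tr.MakeIter()
	for it.FirstOverlap(query); it.Valid(); it.NextOverlap(query) {
		_ = it.Cur()
	}
}
```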
+
+func (i *iterator) constrainMinSearchBounds(item *bufferedWrite) {
+ k := item.Key()
+ j := sort.Search(int(i.n.count), func(j int) bool {
+ return bytes.Compare(k, i.n.items[j].Key()) <= 0
+ })
+ i.o.constrMinN = i.n
+ i.o.constrMinPos = int16(j)
+}
+
+func (i *iterator) constrainMaxSearchBounds(item *bufferedWrite) {
+ up := upperBound(item)
+ j := sort.Search(int(i.n.count), func(j int) bool {
+ return !up.contains(i.n.items[j])
+ })
+ i.o.constrMaxN = i.n
+ i.o.constrMaxPos = int16(j)
+}
+
+func (i *iterator) findNextOverlap(item *bufferedWrite) {
+ for {
+ if i.pos > i.n.count {
+ // Iterate up tree.
+ i.ascend()
+ } else if !i.n.leaf() {
+ // Iterate down tree.
+ if i.o.constrMinReached || i.n.children[i.pos].max().contains(item) {
+ par := i.n
+ pos := i.pos
+ i.descend(par, pos)
+
+ // Refine the constraint bounds, if necessary.
+ if par == i.o.constrMinN && pos == i.o.constrMinPos {
+ i.constrainMinSearchBounds(item)
+ }
+ if par == i.o.constrMaxN && pos == i.o.constrMaxPos {
+ i.constrainMaxSearchBounds(item)
+ }
+ continue
+ }
+ }
+
+ // Check search bounds.
+ if i.n == i.o.constrMaxN && i.pos == i.o.constrMaxPos {
+ // Invalid. Past possible overlaps.
+ i.pos = i.n.count
+ return
+ }
+ if i.n == i.o.constrMinN && i.pos == i.o.constrMinPos {
+ // The scan reached the soft lower-bound constraint.
+ i.o.constrMinReached = true
+ }
+
+ // Iterate across node.
+ if i.pos < i.n.count {
+ // Check for overlapping item.
+ if i.o.constrMinReached {
+ // Fast-path to avoid span comparison. i.o.constrMinReached
+ // tells us that all items have end keys above our search
+ // span's start key.
+ return
+ }
+ if upperBound(i.n.items[i.pos]).contains(item) {
+ return
+ }
+ }
+ i.pos++
+ }
+}
diff --git a/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree_test.go b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree_test.go
new file mode 100644
index 000000000000..01ba1c155290
--- /dev/null
+++ b/pkg/kv/kvclient/kvcoord/bufferedwrite_interval_btree_test.go
@@ -0,0 +1,1111 @@
+// Code generated by go_generics. DO NOT EDIT.
+
+// Copyright 2020 The Cockroach Authors.
+//
+// Use of this software is governed by the CockroachDB Software License
+// included in the /LICENSE file.
+
+package kvcoord
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "sync"
+ "testing"
+
+ // Load pkg/keys so that roachpb.Span.String() can pretty-print keys correctly.
+ _ "github.com/cockroachdb/cockroach/pkg/keys"
+ "github.com/cockroachdb/cockroach/pkg/roachpb"
+ "github.com/cockroachdb/cockroach/pkg/util/timeutil"
+ "github.com/stretchr/testify/require"
+)
+
+func newItem(s roachpb.Span) *bufferedWrite {
+ i := nilT.New()
+ i.SetKey(s.Key)
+ i.SetEndKey(s.EndKey)
+ return i
+}
+
+func spanFromItem(i *bufferedWrite) roachpb.Span {
+ return roachpb.Span{Key: i.Key(), EndKey: i.EndKey()}
+}
+
+//////////////////////////////////////////
+// Invariant verification //
+//////////////////////////////////////////
+
+// Verify asserts that the tree's structural invariants all hold.
+func (t *btree) Verify(tt *testing.T) {
+ if t.length == 0 {
+ require.Nil(tt, t.root)
+ return
+ }
+ t.verifyLeafSameDepth(tt)
+ t.verifyCountAllowed(tt)
+ t.isSorted(tt)
+ t.isUpperBoundCorrect(tt)
+}
+
+func (t *btree) verifyLeafSameDepth(tt *testing.T) {
+ h := t.Height()
+ t.root.verifyDepthEqualToHeight(tt, 1, h)
+}
+
+func (n *node) verifyDepthEqualToHeight(t *testing.T, depth, height int) {
+ if n.leaf() {
+ require.Equal(t, height, depth, "all leaves should have the same depth as the tree height")
+ }
+ n.recurse(func(child *node, _ int16) {
+ child.verifyDepthEqualToHeight(t, depth+1, height)
+ })
+}
+
+func (t *btree) verifyCountAllowed(tt *testing.T) {
+ t.root.verifyCountAllowed(tt, true)
+}
+
+func (n *node) verifyCountAllowed(t *testing.T, root bool) {
+ if !root {
+ require.GreaterOrEqual(t, n.count, int16(minItems), "latch count %d must be in range [%d,%d]", n.count, minItems, maxItems)
+ require.LessOrEqual(t, n.count, int16(maxItems), "latch count %d must be in range [%d,%d]", n.count, minItems, maxItems)
+ }
+ for i, item := range n.items {
+ if i < int(n.count) {
+ require.NotNil(t, item, "latch below count")
+ } else {
+ require.Nil(t, item, "latch above count")
+ }
+ }
+ if !n.leaf() {
+ for i, child := range n.children {
+ if i <= int(n.count) {
+ require.NotNil(t, child, "node below count")
+ } else {
+ require.Nil(t, child, "node above count")
+ }
+ }
+ }
+ n.recurse(func(child *node, _ int16) {
+ child.verifyCountAllowed(t, false)
+ })
+}
+
+func (t *btree) isSorted(tt *testing.T) {
+ t.root.isSorted(tt)
+}
+
+func (n *node) isSorted(t *testing.T) {
+ for i := int16(1); i < n.count; i++ {
+ require.LessOrEqual(t, compare(n.items[i-1], n.items[i]), 0)
+ }
+ if !n.leaf() {
+ for i := int16(0); i < n.count; i++ {
+ prev := n.children[i]
+ next := n.children[i+1]
+
+ require.LessOrEqual(t, compare(prev.items[prev.count-1], n.items[i]), 0)
+ require.LessOrEqual(t, compare(n.items[i], next.items[0]), 0)
+ }
+ }
+ n.recurse(func(child *node, _ int16) {
+ child.isSorted(t)
+ })
+}
+
+func (t *btree) isUpperBoundCorrect(tt *testing.T) {
+ t.root.isUpperBoundCorrect(tt)
+}
+
+func (n *node) isUpperBoundCorrect(t *testing.T) {
+ require.Equal(t, 0, n.findUpperBound().compare(n.max()))
+ for i := int16(1); i < n.count; i++ {
+ require.LessOrEqual(t, upperBound(n.items[i]).compare(n.max()), 0)
+ }
+ if !n.leaf() {
+ for i := int16(0); i <= n.count; i++ {
+ child := n.children[i]
+ require.LessOrEqual(t, child.max().compare(n.max()), 0)
+ }
+ }
+ n.recurse(func(child *node, _ int16) {
+ child.isUpperBoundCorrect(t)
+ })
+}
+
+func (n *node) recurse(f func(child *node, pos int16)) {
+ if !n.leaf() {
+ for i := int16(0); i <= n.count; i++ {
+ f(n.children[i], i)
+ }
+ }
+}
+
+//////////////////////////////////////////
+// Unit Tests //
+//////////////////////////////////////////
+
+func key(i int) roachpb.Key {
+ if i < 0 || i > 99999 {
+ panic("key out of bounds")
+ }
+ return []byte(fmt.Sprintf("%05d", i))
+}
+
+func span(i int) roachpb.Span {
+ switch i % 10 {
+ case 0:
+ return roachpb.Span{Key: key(i)}
+ case 1:
+ return roachpb.Span{Key: key(i), EndKey: key(i).Next()}
+ case 2:
+ return roachpb.Span{Key: key(i), EndKey: key(i + 64)}
+ default:
+ return roachpb.Span{Key: key(i), EndKey: key(i + 4)}
+ }
+}
+
+func spanWithEnd(start, end int) roachpb.Span {
+ if start < end {
+ return roachpb.Span{Key: key(start), EndKey: key(end)}
+ } else if start == end {
+ return roachpb.Span{Key: key(start)}
+ } else {
+ panic("illegal span")
+ }
+}
+
+func spanWithMemo(i int, memo map[int]roachpb.Span) roachpb.Span {
+ if s, ok := memo[i]; ok {
+ return s
+ }
+ s := span(i)
+ memo[i] = s
+ return s
+}
+
+func randomSpan(rng *rand.Rand, n int) roachpb.Span {
+ start := rng.Intn(n)
+ end := rng.Intn(n + 1)
+ if end < start {
+ start, end = end, start
+ }
+ return spanWithEnd(start, end)
+}
+
+func checkIter(t *testing.T, it iterator, start, end int, spanMemo map[int]roachpb.Span) {
+ i := start
+ for it.First(); it.Valid(); it.Next() {
+ item := it.Cur()
+ expected := spanWithMemo(i, spanMemo)
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("expected %s, but found %s", expected, spanFromItem(item))
+ }
+ i++
+ }
+ if i != end {
+ t.Fatalf("expected %d, but at %d", end, i)
+ }
+
+ for it.Last(); it.Valid(); it.Prev() {
+ i--
+ item := it.Cur()
+ expected := spanWithMemo(i, spanMemo)
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("expected %s, but found %s", expected, spanFromItem(item))
+ }
+ }
+ if i != start {
+ t.Fatalf("expected %d, but at %d: %+v", start, i, it)
+ }
+
+ all := newItem(spanWithEnd(start, end))
+ for it.FirstOverlap(all); it.Valid(); it.NextOverlap(all) {
+ item := it.Cur()
+ expected := spanWithMemo(i, spanMemo)
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("expected %s, but found %s", expected, spanFromItem(item))
+ }
+ i++
+ }
+ if i != end {
+ t.Fatalf("expected %d, but at %d", end, i)
+ }
+}
+
+// TestBTree tests basic btree operations.
+func TestBTree(t *testing.T) {
+ var tr btree
+ spanMemo := make(map[int]roachpb.Span)
+
+ // With degree == 16 (max-items/node == 31) we need 513 items in order for
+ // there to be 3 levels in the tree. The count here is comfortably above
+ // that.
+ const count = 768
+
+ // Add keys in sorted order.
+ for i := 0; i < count; i++ {
+ tr.Set(newItem(span(i)))
+ tr.Verify(t)
+ if e := i + 1; e != tr.Len() {
+ t.Fatalf("expected length %d, but found %d", e, tr.Len())
+ }
+ checkIter(t, tr.MakeIter(), 0, i+1, spanMemo)
+ }
+
+ // Delete keys in sorted order.
+ for i := 0; i < count; i++ {
+ tr.Delete(newItem(span(i)))
+ tr.Verify(t)
+ if e := count - (i + 1); e != tr.Len() {
+ t.Fatalf("expected length %d, but found %d", e, tr.Len())
+ }
+ checkIter(t, tr.MakeIter(), i+1, count, spanMemo)
+ }
+
+ // Add keys in reverse sorted order.
+ for i := 0; i < count; i++ {
+ tr.Set(newItem(span(count - i)))
+ tr.Verify(t)
+ if e := i + 1; e != tr.Len() {
+ t.Fatalf("expected length %d, but found %d", e, tr.Len())
+ }
+ checkIter(t, tr.MakeIter(), count-i, count+1, spanMemo)
+ }
+
+ // Delete keys in reverse sorted order.
+ for i := 0; i < count; i++ {
+ tr.Delete(newItem(span(count - i)))
+ tr.Verify(t)
+ if e := count - (i + 1); e != tr.Len() {
+ t.Fatalf("expected length %d, but found %d", e, tr.Len())
+ }
+ checkIter(t, tr.MakeIter(), 1, count-i, spanMemo)
+ }
+}
+
+// TestBTreeSeek tests basic btree iterator operations.
+func TestBTreeSeek(t *testing.T) {
+ const count = 513
+
+ var tr btree
+ for i := 0; i < count; i++ {
+ tr.Set(newItem(span(i * 2)))
+ }
+
+ it := tr.MakeIter()
+ for i := 0; i < 2*count-1; i++ {
+ it.SeekGE(newItem(span(i)))
+ if !it.Valid() {
+ t.Fatalf("%d: expected valid iterator", i)
+ }
+ item := it.Cur()
+ expected := span(2 * ((i + 1) / 2))
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item))
+ }
+ }
+ it.SeekGE(newItem(span(2*count - 1)))
+ if it.Valid() {
+ t.Fatalf("expected invalid iterator")
+ }
+
+ for i := 1; i < 2*count; i++ {
+ it.SeekLT(newItem(span(i)))
+ if !it.Valid() {
+ t.Fatalf("%d: expected valid iterator", i)
+ }
+ item := it.Cur()
+ expected := span(2 * ((i - 1) / 2))
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item))
+ }
+ }
+ it.SeekLT(newItem(span(0)))
+ if it.Valid() {
+ t.Fatalf("expected invalid iterator")
+ }
+}
+
+// TestBTreeSeekOverlap tests btree iterator overlap operations.
+func TestBTreeSeekOverlap(t *testing.T) {
+ const count = 513
+ const size = 2 * maxItems
+
+ var tr btree
+ for i := 0; i < count; i++ {
+ tr.Set(newItem(spanWithEnd(i, i+size+1)))
+ }
+
+ // Iterate over overlaps with a point scan.
+ it := tr.MakeIter()
+ for i := 0; i < count+size; i++ {
+ scanItem := newItem(spanWithEnd(i, i))
+ it.FirstOverlap(scanItem)
+ for j := 0; j < size+1; j++ {
+ expStart := i - size + j
+ if expStart < 0 {
+ continue
+ }
+ if expStart >= count {
+ continue
+ }
+
+ if !it.Valid() {
+ t.Fatalf("%d/%d: expected valid iterator", i, j)
+ }
+ item := it.Cur()
+ expected := spanWithEnd(expStart, expStart+size+1)
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item))
+ }
+
+ it.NextOverlap(scanItem)
+ }
+ if it.Valid() {
+ t.Fatalf("%d: expected invalid iterator %v", i, it.Cur())
+ }
+ }
+ it.FirstOverlap(newItem(span(count + size + 1)))
+ if it.Valid() {
+ t.Fatalf("expected invalid iterator")
+ }
+
+ // Iterate over overlaps with a range scan.
+ it = tr.MakeIter()
+ for i := 0; i < count+size; i++ {
+ scanItem := newItem(spanWithEnd(i, i+size+1))
+ it.FirstOverlap(scanItem)
+ for j := 0; j < 2*size+1; j++ {
+ expStart := i - size + j
+ if expStart < 0 {
+ continue
+ }
+ if expStart >= count {
+ continue
+ }
+
+ if !it.Valid() {
+ t.Fatalf("%d/%d: expected valid iterator", i, j)
+ }
+ item := it.Cur()
+ expected := spanWithEnd(expStart, expStart+size+1)
+ if !expected.Equal(spanFromItem(item)) {
+ t.Fatalf("%d: expected %s, but found %s", i, expected, spanFromItem(item))
+ }
+
+ it.NextOverlap(scanItem)
+ }
+ if it.Valid() {
+ t.Fatalf("%d: expected invalid iterator %v", i, it.Cur())
+ }
+ }
+ it.FirstOverlap(newItem(span(count + size + 1)))
+ if it.Valid() {
+ t.Fatalf("expected invalid iterator")
+ }
+}
+
+// TestBTreeCompare tests the btree item comparison.
+func TestBTreeCompare(t *testing.T) {
+ // NB: go_generics doesn't do well with anonymous types, so name this type.
+// Avoid the slice literal syntax, for which GofmtSimplify would mandate
+// anonymous constructors.
+ type testCase struct {
+ spanA, spanB roachpb.Span
+ idA, idB uint64
+ exp int
+ }
+ var testCases []testCase
+ testCases = append(testCases,
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("a")},
+ spanB: roachpb.Span{Key: roachpb.Key("a")},
+ idA: 1,
+ idB: 1,
+ exp: 0,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("a")},
+ spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("b")},
+ idA: 1,
+ idB: 1,
+ exp: -1,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")},
+ spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("b")},
+ idA: 1,
+ idB: 1,
+ exp: 1,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")},
+ spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")},
+ idA: 1,
+ idB: 1,
+ exp: 0,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("a")},
+ spanB: roachpb.Span{Key: roachpb.Key("a")},
+ idA: 1,
+ idB: 2,
+ exp: -1,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("a")},
+ spanB: roachpb.Span{Key: roachpb.Key("a")},
+ idA: 2,
+ idB: 1,
+ exp: 1,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("b")},
+ spanB: roachpb.Span{Key: roachpb.Key("a"), EndKey: roachpb.Key("c")},
+ idA: 1,
+ idB: 1,
+ exp: 1,
+ },
+ testCase{
+ spanA: roachpb.Span{Key: roachpb.Key("b"), EndKey: roachpb.Key("e")},
+ spanB: roachpb.Span{Key: roachpb.Key("c"), EndKey: roachpb.Key("d")},
+ idA: 1,
+ idB: 1,
+ exp: -1,
+ },
+ )
+ for _, tc := range testCases {
+ name := fmt.Sprintf("compare(%s:%d,%s:%d)", tc.spanA, tc.idA, tc.spanB, tc.idB)
+ t.Run(name, func(t *testing.T) {
+ laA := newItem(tc.spanA)
+ laA.SetID(tc.idA)
+ laB := newItem(tc.spanB)
+ laB.SetID(tc.idB)
+ require.Equal(t, tc.exp, compare(laA, laB))
+ })
+ }
+}
+
+// TestIterStack tests the interface of the iterStack type.
+func TestIterStack(t *testing.T) {
+ f := func(i int) iterFrame { return iterFrame{pos: int16(i)} }
+ var is iterStack
+ for i := 1; i <= 2*len(iterStackArr{}); i++ {
+ var j int
+ for j = 0; j < i; j++ {
+ is.push(f(j))
+ }
+ require.Equal(t, j, is.len())
+ for j--; j >= 0; j-- {
+ require.Equal(t, f(j), is.pop())
+ }
+ is.reset()
+ }
+}
+
+//////////////////////////////////////////
+// Randomized Unit Tests //
+//////////////////////////////////////////
+
+// perm returns a random permutation of items with spans in the range [0, n).
+func perm(n int) (out []*bufferedWrite) {
+ for _, i := range rand.Perm(n) {
+ out = append(out, newItem(spanWithEnd(i, i+1)))
+ }
+ return out
+}
+
+// rang returns an ordered list of items with spans in the range [m, n].
+func rang(m, n int) (out []*bufferedWrite) {
+ for i := m; i <= n; i++ {
+ out = append(out, newItem(spanWithEnd(i, i+1)))
+ }
+ return out
+}
+
+// all extracts all items from a tree in order as a slice.
+func all(tr *btree) (out []*bufferedWrite) {
+ it := tr.MakeIter()
+ it.First()
+ for it.Valid() {
+ out = append(out, it.Cur())
+ it.Next()
+ }
+ return out
+}
+
+func run(tb testing.TB, name string, f func(testing.TB)) {
+ switch v := tb.(type) {
+ case *testing.T:
+ v.Run(name, func(t *testing.T) {
+ f(t)
+ })
+ case *testing.B:
+ v.Run(name, func(b *testing.B) {
+ f(b)
+ })
+ default:
+ tb.Fatalf("unknown %T", tb)
+ }
+}
+
+func iters(tb testing.TB, count int) int {
+ switch v := tb.(type) {
+ case *testing.T:
+ return count
+ case *testing.B:
+ return v.N
+ default:
+ tb.Fatalf("unknown %T", tb)
+ return 0
+ }
+}
+
+func verify(tb testing.TB, tr *btree) {
+ if tt, ok := tb.(*testing.T); ok {
+ tr.Verify(tt)
+ }
+}
+
+func resetTimer(tb testing.TB) {
+ if b, ok := tb.(*testing.B); ok {
+ b.ResetTimer()
+ }
+}
+
+func stopTimer(tb testing.TB) {
+ if b, ok := tb.(*testing.B); ok {
+ b.StopTimer()
+ }
+}
+
+func startTimer(tb testing.TB) {
+ if b, ok := tb.(*testing.B); ok {
+ b.StartTimer()
+ }
+}
+
+func runBTreeInsert(tb testing.TB, count int) {
+ iters := iters(tb, count)
+ insertP := perm(count)
+ resetTimer(tb)
+ for i := 0; i < iters; {
+ var tr btree
+ for _, item := range insertP {
+ tr.Set(item)
+ verify(tb, &tr)
+ i++
+ if i >= iters {
+ return
+ }
+ }
+ }
+}
+
+func runBTreeDelete(tb testing.TB, count int) {
+ iters := iters(tb, count)
+ insertP, removeP := perm(count), perm(count)
+ resetTimer(tb)
+ for i := 0; i < iters; {
+ stopTimer(tb)
+ var tr btree
+ for _, item := range insertP {
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+ startTimer(tb)
+ for _, item := range removeP {
+ tr.Delete(item)
+ verify(tb, &tr)
+ i++
+ if i >= iters {
+ return
+ }
+ }
+ if tr.Len() > 0 {
+ tb.Fatalf("tree not empty: %s", &tr)
+ }
+ }
+}
+
+func runBTreeDeleteInsert(tb testing.TB, count int) {
+ iters := iters(tb, count)
+ insertP := perm(count)
+ var tr btree
+ for _, item := range insertP {
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+ resetTimer(tb)
+ for i := 0; i < iters; i++ {
+ item := insertP[i%count]
+ tr.Delete(item)
+ verify(tb, &tr)
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+}
+
+func runBTreeDeleteInsertCloneOnce(tb testing.TB, count int) {
+ iters := iters(tb, count)
+ insertP := perm(count)
+ var tr btree
+ for _, item := range insertP {
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+ tr = tr.Clone()
+ resetTimer(tb)
+ for i := 0; i < iters; i++ {
+ item := insertP[i%count]
+ tr.Delete(item)
+ verify(tb, &tr)
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+}
+
+func runBTreeDeleteInsertCloneEachTime(tb testing.TB, count int) {
+ for _, reset := range []bool{false, true} {
+ run(tb, fmt.Sprintf("reset=%t", reset), func(tb testing.TB) {
+ iters := iters(tb, count)
+ insertP := perm(count)
+ var tr, trReset btree
+ for _, item := range insertP {
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+ resetTimer(tb)
+ for i := 0; i < iters; i++ {
+ item := insertP[i%count]
+ if reset {
+ trReset.Reset()
+ trReset = tr
+ }
+ tr = tr.Clone()
+ tr.Delete(item)
+ verify(tb, &tr)
+ tr.Set(item)
+ verify(tb, &tr)
+ }
+ })
+ }
+}
+
+// randN returns a random integer in the range [min, max).
+func randN(min, max int) int { return rand.Intn(max-min) + min }
+func randCount() int {
+ if testing.Short() {
+ return randN(1, 128)
+ }
+ return randN(1, 1024)
+}
+
+func TestBTreeInsert(t *testing.T) {
+ count := randCount()
+ runBTreeInsert(t, count)
+}
+
+func TestBTreeDelete(t *testing.T) {
+ count := randCount()
+ runBTreeDelete(t, count)
+}
+
+func TestBTreeDeleteInsert(t *testing.T) {
+ count := randCount()
+ runBTreeDeleteInsert(t, count)
+}
+
+func TestBTreeDeleteInsertCloneOnce(t *testing.T) {
+ count := randCount()
+ runBTreeDeleteInsertCloneOnce(t, count)
+}
+
+func TestBTreeDeleteInsertCloneEachTime(t *testing.T) {
+ count := randCount()
+ runBTreeDeleteInsertCloneEachTime(t, count)
+}
+
+// TestBTreeSeekOverlapRandom tests btree iterator overlap operations using
+// randomized input.
+func TestBTreeSeekOverlapRandom(t *testing.T) {
+ rng := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
+
+ const trials = 10
+ for i := 0; i < trials; i++ {
+ var tr btree
+
+ const count = 1000
+ items := make([]*bufferedWrite, count)
+ itemSpans := make([]int, count)
+ for j := 0; j < count; j++ {
+ var item *bufferedWrite
+ end := rng.Intn(count + 10)
+ if end <= j {
+ end = j
+ item = newItem(spanWithEnd(j, end))
+ } else {
+ item = newItem(spanWithEnd(j, end+1))
+ }
+ tr.Set(item)
+ items[j] = item
+ itemSpans[j] = end
+ }
+
+ const scanTrials = 100
+ for j := 0; j < scanTrials; j++ {
+ var scanItem *bufferedWrite
+ scanStart := rng.Intn(count)
+ scanEnd := rng.Intn(count + 10)
+ if scanEnd <= scanStart {
+ scanEnd = scanStart
+ scanItem = newItem(spanWithEnd(scanStart, scanEnd))
+ } else {
+ scanItem = newItem(spanWithEnd(scanStart, scanEnd+1))
+ }
+
+ var exp, found []*bufferedWrite
+ for startKey, endKey := range itemSpans {
+ if startKey <= scanEnd && endKey >= scanStart {
+ exp = append(exp, items[startKey])
+ }
+ }
+
+ it := tr.MakeIter()
+ it.FirstOverlap(scanItem)
+ for it.Valid() {
+ found = append(found, it.Cur())
+ it.NextOverlap(scanItem)
+ }
+
+ require.Equal(t, len(exp), len(found), "search for %v", spanFromItem(scanItem))
+ }
+ }
+}
+
+// TestBTreeCloneConcurrentOperations tests that cloning a btree returns a new
+// btree instance which is an exact logical copy of the original but that can be
+// modified independently going forward.
+func TestBTreeCloneConcurrentOperations(t *testing.T) {
+ const cloneTestSize = 1000
+ p := perm(cloneTestSize)
+
+ var trees []*btree
+ treeC, treeDone := make(chan *btree), make(chan struct{})
+ go func() {
+ for b := range treeC {
+ trees = append(trees, b)
+ }
+ close(treeDone)
+ }()
+
+ var wg sync.WaitGroup
+ var populate func(tr *btree, start int)
+ populate = func(tr *btree, start int) {
+ t.Logf("Starting new clone at %v", start)
+ treeC <- tr
+ for i := start; i < cloneTestSize; i++ {
+ tr.Set(p[i])
+ if i%(cloneTestSize/5) == 0 {
+ wg.Add(1)
+ c := tr.Clone()
+ go populate(&c, i+1)
+ }
+ }
+ wg.Done()
+ }
+
+ wg.Add(1)
+ var tr btree
+ go populate(&tr, 0)
+ wg.Wait()
+ close(treeC)
+ <-treeDone
+
+ t.Logf("Starting equality checks on %d trees", len(trees))
+ want := rang(0, cloneTestSize-1)
+ for i, tree := range trees {
+ if !reflect.DeepEqual(want, all(tree)) {
+ t.Errorf("tree %v mismatch", i)
+ }
+ }
+
+ t.Log("Removing half of items from first half")
+ toRemove := want[cloneTestSize/2:]
+ for i := 0; i < len(trees)/2; i++ {
+ tree := trees[i]
+ wg.Add(1)
+ go func() {
+ for _, item := range toRemove {
+ tree.Delete(item)
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ t.Log("Checking all values again")
+ for i, tree := range trees {
+ var wantpart []*bufferedWrite
+ if i < len(trees)/2 {
+ wantpart = want[:cloneTestSize/2]
+ } else {
+ wantpart = want
+ }
+ if got := all(tree); !reflect.DeepEqual(wantpart, got) {
+ t.Errorf("tree %v mismatch, want %v got %v", i, len(want), len(got))
+ }
+ }
+}
+
+//////////////////////////////////////////
+// Benchmarks //
+//////////////////////////////////////////
+
+func forBenchmarkSizes(b *testing.B, f func(b *testing.B, count int)) {
+ for _, count := range []int{16, 128, 1024, 8192, 65536} {
+ b.Run(fmt.Sprintf("count=%d", count), func(b *testing.B) {
+ f(b, count)
+ })
+ }
+}
+
+// BenchmarkBTreeInsert measures btree insertion performance.
+func BenchmarkBTreeInsert(b *testing.B) {
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ runBTreeInsert(b, count)
+ })
+}
+
+// BenchmarkBTreeDelete measures btree deletion performance.
+func BenchmarkBTreeDelete(b *testing.B) {
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ runBTreeDelete(b, count)
+ })
+}
+
+// BenchmarkBTreeDeleteInsert measures btree deletion and insertion performance.
+func BenchmarkBTreeDeleteInsert(b *testing.B) {
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ runBTreeDeleteInsert(b, count)
+ })
+}
+
+// BenchmarkBTreeDeleteInsertCloneOnce measures btree deletion and insertion
+// performance after the tree has been copy-on-write cloned once.
+func BenchmarkBTreeDeleteInsertCloneOnce(b *testing.B) {
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ runBTreeDeleteInsertCloneOnce(b, count)
+ })
+}
+
+// BenchmarkBTreeDeleteInsertCloneEachTime measures btree deletion and insertion
+// performance while the tree is repeatedly copy-on-write cloned.
+func BenchmarkBTreeDeleteInsertCloneEachTime(b *testing.B) {
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ runBTreeDeleteInsertCloneEachTime(b, count)
+ })
+}
+
+// BenchmarkBTreeMakeIter measures the cost of creating a btree iterator.
+func BenchmarkBTreeMakeIter(b *testing.B) {
+ var tr btree
+ for i := 0; i < b.N; i++ {
+ it := tr.MakeIter()
+ it.First()
+ }
+}
+
+// BenchmarkBTreeIterSeekGE measures the cost of seeking a btree iterator
+// forward.
+func BenchmarkBTreeIterSeekGE(b *testing.B) {
+ rng := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ var spans []roachpb.Span
+ var tr btree
+
+ for i := 0; i < count; i++ {
+ s := span(i)
+ spans = append(spans, s)
+ tr.Set(newItem(s))
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ s := spans[rng.Intn(len(spans))]
+ it := tr.MakeIter()
+ it.SeekGE(newItem(s))
+ if testing.Verbose() {
+ if !it.Valid() {
+ b.Fatal("expected to find key")
+ }
+ if !s.Equal(spanFromItem(it.Cur())) {
+ b.Fatalf("expected %s, but found %s", s, spanFromItem(it.Cur()))
+ }
+ }
+ }
+ })
+}
+
+// BenchmarkBTreeIterSeekLT measures the cost of seeking a btree iterator
+// backward.
+func BenchmarkBTreeIterSeekLT(b *testing.B) {
+ rng := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ var spans []roachpb.Span
+ var tr btree
+
+ for i := 0; i < count; i++ {
+ s := span(i)
+ spans = append(spans, s)
+ tr.Set(newItem(s))
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ j := rng.Intn(len(spans))
+ s := spans[j]
+ it := tr.MakeIter()
+ it.SeekLT(newItem(s))
+ if testing.Verbose() {
+ if j == 0 {
+ if it.Valid() {
+ b.Fatal("unexpected key")
+ }
+ } else {
+ if !it.Valid() {
+ b.Fatal("expected to find key")
+ }
+ s := spans[j-1]
+ if !s.Equal(spanFromItem(it.Cur())) {
+ b.Fatalf("expected %s, but found %s", s, spanFromItem(it.Cur()))
+ }
+ }
+ }
+ }
+ })
+}
+
+// BenchmarkBTreeIterFirstOverlap measures the cost of finding a single
+// overlapping item using a btree iterator.
+func BenchmarkBTreeIterFirstOverlap(b *testing.B) {
+ rng := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
+ forBenchmarkSizes(b, func(b *testing.B, count int) {
+ var spans []roachpb.Span
+ var tr btree
+
+ for i := 0; i < count; i++ {
+ s := spanWithEnd(i, i+1)
+ spans = append(spans, s)
+ tr.Set(newItem(s))
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ j := rng.Intn(len(spans))
+ s := spans[j]
+ it := tr.MakeIter()
+ it.FirstOverlap(newItem(s))
+ if testing.Verbose() {
+ if !it.Valid() {
+ b.Fatal("expected to find key")
+ }
+ if !s.Equal(spanFromItem(it.Cur())) {
+ b.Fatalf("expected %s, but found %s", s, spanFromItem(it.Cur()))
+ }
+ }
+ }
+ })
+}
+
+// BenchmarkBTreeIterNext measures the cost of seeking a btree iterator to the
+// next item in the tree.
+func BenchmarkBTreeIterNext(b *testing.B) {
+ var tr btree
+
+ const count = 8 << 10
+ const size = 2 * maxItems
+ for i := 0; i < count; i++ {
+ item := newItem(spanWithEnd(i, i+size+1))
+ tr.Set(item)
+ }
+
+ it := tr.MakeIter()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if !it.Valid() {
+ it.First()
+ }
+ it.Next()
+ }
+}
+
+// BenchmarkBTreeIterPrev measures the cost of seeking a btree iterator to the
+// previous item in the tree.
+func BenchmarkBTreeIterPrev(b *testing.B) {
+ var tr btree
+
+ const count = 8 << 10
+ const size = 2 * maxItems
+ for i := 0; i < count; i++ {
+ item := newItem(spanWithEnd(i, i+size+1))
+ tr.Set(item)
+ }
+
+ it := tr.MakeIter()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if !it.Valid() {
+ it.Last()
+ }
+ it.Prev()
+ }
+}
+
+// BenchmarkBTreeIterNextOverlap measures the cost of seeking a btree iterator
+// to the next overlapping item in the tree.
+func BenchmarkBTreeIterNextOverlap(b *testing.B) {
+ var tr btree
+
+ const count = 8 << 10
+ const size = 2 * maxItems
+ for i := 0; i < count; i++ {
+ item := newItem(spanWithEnd(i, i+size+1))
+ tr.Set(item)
+ }
+
+ allCmd := newItem(spanWithEnd(0, count+1))
+ it := tr.MakeIter()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ if !it.Valid() {
+ it.FirstOverlap(allCmd)
+ }
+ it.NextOverlap(allCmd)
+ }
+}
+
+// BenchmarkBTreeIterOverlapScan measures the cost of scanning over all
+// overlapping items using a btree iterator.
+func BenchmarkBTreeIterOverlapScan(b *testing.B) {
+ var tr btree
+ rng := rand.New(rand.NewSource(timeutil.Now().UnixNano()))
+
+ const count = 8 << 10
+ const size = 2 * maxItems
+ for i := 0; i < count; i++ {
+ tr.Set(newItem(spanWithEnd(i, i+size+1)))
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ item := newItem(randomSpan(rng, count))
+ it := tr.MakeIter()
+ it.FirstOverlap(item)
+ for it.Valid() {
+ it.NextOverlap(item)
+ }
+ }
+}
diff --git a/pkg/kv/kvclient/kvcoord/txn_coord_sender.go b/pkg/kv/kvclient/kvcoord/txn_coord_sender.go
index b1318201806f..bc26c6f5ced0 100644
--- a/pkg/kv/kvclient/kvcoord/txn_coord_sender.go
+++ b/pkg/kv/kvclient/kvcoord/txn_coord_sender.go
@@ -161,9 +161,10 @@ type TxnCoordSender struct {
// additional heap allocations necessary.
interceptorStack []txnInterceptor
interceptorAlloc struct {
- arr [6]txnInterceptor
+ arr [7]txnInterceptor
txnHeartbeater
txnSeqNumAllocator
+ txnWriteBuffer
txnPipeliner
txnCommitter
txnSpanRefresher
@@ -275,6 +276,7 @@ func newRootTxnCoordSender(
// Various interceptors below rely on sequence number allocation,
// so the sequence number allocator is near the top of the stack.
&tcs.interceptorAlloc.txnSeqNumAllocator,
+ &tcs.interceptorAlloc.txnWriteBuffer,
// The pipeliner sits above the span refresher because it will
// never generate transaction retry errors that could be avoided
// with a refresh.
@@ -312,6 +314,9 @@ func (tc *TxnCoordSender) initCommonInterceptors(
if ds, ok := tcf.wrapped.(*DistSender); ok {
riGen.ds = ds
}
+ tc.interceptorAlloc.txnWriteBuffer = txnWriteBuffer{
+ enabled: bufferedWritesEnabled.Get(&tcf.st.SV),
+ }
tc.interceptorAlloc.txnPipeliner = txnPipeliner{
st: tcf.st,
riGen: riGen,
@@ -500,7 +505,7 @@ func (tc *TxnCoordSender) Send(
return nil, pErr
}
- if ba.IsSingleEndTxnRequest() && !tc.interceptorAlloc.txnPipeliner.hasAcquiredLocks() {
+ if ba.IsSingleEndTxnRequest() && (!tc.interceptorAlloc.txnPipeliner.hasAcquiredLocks() && tc.interceptorAlloc.txnWriteBuffer.buf.Len() == 0) {
return nil, tc.finalizeNonLockingTxnLocked(ctx, ba)
}
diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go
index 6e6407603562..88cc8549dd68 100644
--- a/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go
+++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_pipeliner.go
@@ -21,10 +21,10 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
- "github.com/google/btree"
+ gbtree "github.com/google/btree"
)
-// The degree of the inFlightWrites btree.
+// The degree of the inFlightWrites gbtree.BTree.
const txnPipelinerBtreeDegree = 32
// PipelinedWritesEnabled is the kv.transaction.write_pipelining.enabled cluster setting.
@@ -1048,12 +1048,12 @@ func makeInFlightWrite(key roachpb.Key, seq enginepb.TxnSeq, str lock.Strength)
}}
}
-// Less implements the btree.Item interface.
+// Less implements the gbtree.Item interface.
//
// inFlightWrites are ordered by Key, then by Sequence, then by Strength. Two
// inFlightWrites with the same Key but different Sequences and/or Strengths are
// not considered equal and are maintained separately in the inFlightWritesSet.
-func (a *inFlightWrite) Less(bItem btree.Item) bool {
+func (a *inFlightWrite) Less(bItem gbtree.Item) bool {
b := bItem.(*inFlightWrite)
kCmp := a.Key.Compare(b.Key)
if kCmp != 0 {
@@ -1077,7 +1077,7 @@ func (a *inFlightWrite) Less(bItem btree.Item) bool {
// writes, O(log n) removal of existing in-flight writes, and O(m + log n)
// retrieval over m in-flight writes that overlap with a given key.
type inFlightWriteSet struct {
- t *btree.BTree
+ t *gbtree.BTree
bytes int64
// Avoids allocs.
@@ -1090,7 +1090,7 @@ type inFlightWriteSet struct {
func (s *inFlightWriteSet) insert(key roachpb.Key, seq enginepb.TxnSeq, str lock.Strength) {
if s.t == nil {
// Lazily initialize btree.
- s.t = btree.New(txnPipelinerBtreeDegree)
+ s.t = gbtree.New(txnPipelinerBtreeDegree)
}
w := s.alloc.alloc(key, seq, str)
@@ -1136,7 +1136,7 @@ func (s *inFlightWriteSet) ascend(f func(w *inFlightWrite)) {
// Set is empty.
return
}
- s.t.Ascend(func(i btree.Item) bool {
+ s.t.Ascend(func(i gbtree.Item) bool {
f(i.(*inFlightWrite))
return true
})
@@ -1157,7 +1157,7 @@ func (s *inFlightWriteSet) ascendRange(start, end roachpb.Key, f func(w *inFligh
// Range lookup.
s.tmp2 = makeInFlightWrite(end, 0, 0)
}
- s.t.AscendRange(&s.tmp1, &s.tmp2, func(i btree.Item) bool {
+ s.t.AscendRange(&s.tmp1, &s.tmp2, func(i gbtree.Item) bool {
f(i.(*inFlightWrite))
return true
})
diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go
new file mode 100644
index 000000000000..7dcffb24aa35
--- /dev/null
+++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_write_buffer.go
@@ -0,0 +1,499 @@
+// Copyright 2024 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package kvcoord
+
+import (
+ "bytes"
+ "context"
+ "sort"
+
+ "github.com/cockroachdb/cockroach/pkg/kv/kvpb"
+ "github.com/cockroachdb/cockroach/pkg/roachpb"
+ "github.com/cockroachdb/cockroach/pkg/settings"
+ "github.com/cockroachdb/cockroach/pkg/storage/enginepb"
+ "github.com/cockroachdb/cockroach/pkg/util/log"
+)
+
+var bufferedWritesEnabled = settings.RegisterBoolSetting(
+ settings.ApplicationLevel,
+ "kv.transaction.buffered_writes.enabled",
+ "if enabled, transactional writes are buffered on the gateway",
+ true,
+ settings.WithPublic,
+)
+
+// txnWriteBuffer is a txnInterceptor that buffers a transaction's writes until
+// commit time, when they are flushed to the wrapped lockedSender. Buffering
+// writes client-side has four main benefits:
+//
+// 1. It allows for more batching of writes, which can be more efficient.
+// Instead of sending writes one at a time, we can batch them up and send
+// them all at once. This is a win even if writes would otherwise be
+// pipelined through consensus.
+//
+// 2. It allows for the elimination of redundant writes. If a client writes to
+// the same key multiple times in a transaction, only the last write needs
+// to be written to the key-value layer.
+//
+// 3. It allows the client to serve read-your-writes locally, which can be much
+// faster and cheaper than sending them to the leaseholder. This is
+// especially true when the leaseholder is not collocated with the client.
+//
+// By serving read-your-writes locally from the gateway, write buffering
+// also avoids the problem of pipeline stalls that can occur when a client
+// reads a pipelined intent write before the write has completed consensus.
+// For details on pipeline stalls, see txnPipeliner.
+//
+// 4. It allows clients to passively hit the 1-phase commit fast-path, instead
+// of requiring clients to carefully construct "auto-commit" BatchRequests
+// to make use of the optimization. By buffering writes on the client before
+// commit, we avoid immediately disabling the fast-path when the client
+// issues its first write. Then, at commit time, we flush the buffer and
+// happen to hit the 1-phase commit fast-path if all writes go to the
+// same range.
+//
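+// As a concrete example of the intended flow, consider a transaction that
+// issues Put(a), Get(a), Put(b), and then commits. Both Puts are buffered and
+// acknowledged immediately, the Get(a) is served from the buffer without
+// leaving the gateway, and at commit time the buffer is flushed in the same
+// batch as the EndTxn, which can hit the 1-phase commit fast-path if a and b
+// live on the same range.
+//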
+// However, buffering writes comes with some challenges.
+//
+// The first is that read-only and read-write requests need to be aware of the
+// buffered writes, as they may need to serve reads from the buffer.
+// TODO: discuss distributed execution.
+//
+// The second is that savepoints need to be handled correctly. TODO...
+//
+// The third is that the buffer needs to adhere to memory limits. TODO ...
+type txnWriteBuffer struct {
+ wrapped lockedSender
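+ // enabled indicates whether write buffering is in use for this transaction.
+ // It is set from the kv.transaction.buffered_writes.enabled cluster setting
+ // when the interceptor stack is initialized.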
+ enabled bool
+
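+ // buf holds the buffered writes, keyed by the written key, in an interval
+ // btree. bufSeek is a scratch bufferedWrite reused as the seek key for
+ // overlap queries against buf, avoiding an allocation per lookup, and
+ // bufIDAlloc hands out the unique IDs that the interval btree requires for
+ // its entries.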
+ buf btree
+ bufSeek bufferedWrite
+ bufIDAlloc uint64
+}
+
+// SendLocked implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) SendLocked(
+ ctx context.Context, ba *kvpb.BatchRequest,
+) (*kvpb.BatchResponse, *kvpb.Error) {
+ if !twb.enabled {
+ return twb.wrapped.SendLocked(ctx, ba)
+ }
+
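+ // An EndTxn request (commit or rollback) flushes the entire buffer in the
+ // same batch, ahead of the EndTxn itself.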
+ if _, ok := ba.GetArg(kvpb.EndTxn); ok {
+ return twb.flushWithEndTxn(ctx, ba)
+ }
+
+ baRemote, brLocal, localPositions, pErr := twb.servePointReadsAndAddWritesToBuffer(ctx, ba)
+ if pErr != nil {
+ return nil, pErr
+ }
+ if len(baRemote.Requests) == 0 {
+ return brLocal, nil
+ }
+
+ brRemote, pErr := twb.wrapped.SendLocked(ctx, baRemote)
+ if pErr != nil {
+ return nil, pErr
+ }
+
+ twb.augmentRangedReadsFromBuffer(ctx, baRemote, brRemote)
+ br := twb.mergeLocalAndRemoteResponses(ctx, brLocal, brRemote, localPositions)
+ return br, nil
+}
+
+// flushWithEndTxn flushes the write buffer in the same batch as the EndTxn
+// request, with the buffered writes ordered ahead of the EndTxn.
+func (twb *txnWriteBuffer) flushWithEndTxn(
+ ctx context.Context, ba *kvpb.BatchRequest,
+) (*kvpb.BatchResponse, *kvpb.Error) {
+ flushed := twb.buf.Len()
+ if flushed > 0 {
+ reqs := make([]kvpb.RequestUnion, 0, flushed+len(ba.Requests))
+ it := twb.buf.MakeIter()
+ for it.First(); it.Valid(); it.Next() {
+ reqs = append(reqs, it.Cur().toRequest())
+ }
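+ // The btree iterates in key order; re-sort the flushed writes by sequence
+ // number so they are replayed in the order the client issued them, ahead
+ // of the requests already in the batch (typically just the EndTxn).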
+ sort.SliceStable(reqs, func(i, j int) bool {
+ return reqs[i].GetInner().Header().Sequence < reqs[j].GetInner().Header().Sequence
+ })
+ reqs = append(reqs, ba.Requests...)
+ ba.Requests = reqs
+ }
+
+ br, pErr := twb.wrapped.SendLocked(ctx, ba)
+ if pErr != nil {
+ return nil, pErr
+ }
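+ // Strip the responses to the flushed writes so that the caller only sees
+ // responses to the requests it issued in this batch.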
+ br.Responses = br.Responses[flushed:]
+ return br, nil
+}
+
+// servePointReadsAndAddWritesToBuffer serves point reads from the buffer and
+// adds writes to the buffer when they can be evaluated against it locally
+// (blind writes always can). It returns a BatchRequest with the locally
+// serviceable requests removed, a BatchResponse with the locally serviceable
+// requests' responses, and a slice of the positions of the locally serviceable
+// requests in the original BatchRequest.
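+//
+// For example, a batch of [Get(a), Put(b), Scan(c, e)] where key a is present
+// in the buffer produces local responses for the Get (served from the buffer)
+// and the Put (buffered and acknowledged immediately), while the Scan is sent
+// to the wrapped sender after any overlapping buffered writes have been
+// flushed; localPositions would be [0, 1].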
+func (twb *txnWriteBuffer) servePointReadsAndAddWritesToBuffer(
+ ctx context.Context, ba *kvpb.BatchRequest,
+) (
+ baRemote *kvpb.BatchRequest,
+ brLocal *kvpb.BatchResponse,
+ localPositions []int,
+ pErr *kvpb.Error,
+) {
+ baRemote = ba.ShallowCopy()
+ baRemote.Requests = nil
+ brLocal = &kvpb.BatchResponse{}
+ for i, ru := range ba.Requests {
+ req := ru.GetInner()
+ seek := twb.makeBufferSeekFor(req.Header())
+ prevLocalLen := len(brLocal.Responses)
+ switch t := req.(type) {
+ case *kvpb.GetRequest:
+ it := twb.buf.MakeIter()
+ it.FirstOverlap(seek)
+ if it.Valid() {
+ var ru kvpb.ResponseUnion
+ getResp := &kvpb.GetResponse{}
+ if it.Cur().val.IsPresent() {
+ getResp.Value = it.Cur().valPtr()
+ }
+ ru.MustSetInner(getResp)
+ brLocal.Responses = append(brLocal.Responses, ru)
+ } else {
+ baRemote.Requests = append(baRemote.Requests, ru)
+ }
+
+ case *kvpb.ScanRequest, *kvpb.ReverseScanRequest:
+ // Hack: just flush overlapping writes for now.
+ var baFlush *kvpb.BatchRequest
+ for {
+ it := twb.buf.MakeIter()
+ it.FirstOverlap(seek)
+ if !it.Valid() {
+ break
+ }
+ if baFlush == nil {
+ baFlush = baRemote.ShallowCopy()
+ baFlush.Requests = nil
+ baFlush.MaxSpanRequestKeys = 0
+ baFlush.TargetBytes = 0
+ }
+ baFlush.Requests = append(baFlush.Requests, it.Cur().toRequest())
+ twb.buf.Delete(it.Cur())
+ }
+ if baFlush != nil {
+ sort.SliceStable(baFlush.Requests, func(i, j int) bool {
+ return baFlush.Requests[i].GetInner().Header().Sequence < baFlush.Requests[j].GetInner().Header().Sequence
+ })
+ brFlush, pErr := twb.wrapped.SendLocked(ctx, baFlush)
+ if pErr != nil {
+ return nil, nil, nil, pErr
+ }
+ baRemote.Txn.Update(brFlush.Txn)
+ }
+ // Send the request, then augment the response.
+ baRemote.Requests = append(baRemote.Requests, ru)
+
+ case *kvpb.PutRequest:
+ var ru kvpb.ResponseUnion
+ ru.MustSetInner(&kvpb.PutResponse{})
+ brLocal.Responses = append(brLocal.Responses, ru)
+
+ twb.addToBuffer(t.Key, t.Value, t.Sequence)
+
+ case *kvpb.DeleteRequest:
+ it := twb.buf.MakeIter()
+ it.FirstOverlap(seek)
+ var ru kvpb.ResponseUnion
+ ru.MustSetInner(&kvpb.DeleteResponse{
+ // NOTE: this is incorrect. We only consult the buffer and ignore values
+ // that may already be present in the KV store. This is fine for the
+ // prototype.
+ FoundKey: it.Valid() && it.Cur().val.IsPresent(),
+ })
+ brLocal.Responses = append(brLocal.Responses, ru)
+
+ twb.addToBuffer(t.Key, roachpb.Value{}, t.Sequence)
+
+ case *kvpb.ConditionalPutRequest:
+ it := twb.buf.MakeIter()
+ it.FirstOverlap(seek)
+ if it.Valid() {
+ expBytes := t.ExpBytes
+ existVal := it.Cur().val
+ if expValPresent, existValPresent := len(expBytes) != 0, existVal.IsPresent(); expValPresent && existValPresent {
+ if !bytes.Equal(expBytes, existVal.TagAndDataBytes()) {
+ return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{
+ ActualValue: it.Cur().valPtr(),
+ })
+ }
+ } else if expValPresent != existValPresent && (existValPresent || !t.AllowIfDoesNotExist) {
+ return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{
+ ActualValue: it.Cur().valPtr(),
+ })
+ }
+ var ru kvpb.ResponseUnion
+ ru.MustSetInner(&kvpb.ConditionalPutResponse{})
+ brLocal.Responses = append(brLocal.Responses, ru)
+
+ twb.addToBuffer(t.Key, t.Value, t.Sequence)
+ } else {
+ baRemote.Requests = append(baRemote.Requests, ru)
+ }
+
+ case *kvpb.InitPutRequest:
+ it := twb.buf.MakeIter()
+ it.FirstOverlap(seek)
+ if it.Valid() {
+ failOnTombstones := t.FailOnTombstones
+ existVal := it.Cur().val
+ if failOnTombstones && !existVal.IsPresent() {
+ // We found a tombstone and failOnTombstones is true: fail.
+ return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{
+ ActualValue: it.Cur().valPtr(),
+ })
+ }
+ if existVal.IsPresent() && !existVal.EqualTagAndData(t.Value) {
+ // The existing value does not match the supplied value.
+ return nil, nil, nil, kvpb.NewError(&kvpb.ConditionFailedError{
+ ActualValue: it.Cur().valPtr(),
+ })
+ }
+ var ru kvpb.ResponseUnion
+ ru.MustSetInner(&kvpb.InitPutResponse{})
+ brLocal.Responses = append(brLocal.Responses, ru)
+
+ twb.addToBuffer(t.Key, t.Value, t.Sequence)
+ } else {
+ baRemote.Requests = append(baRemote.Requests, ru)
+ }
+
+ case *kvpb.IncrementRequest:
+ it := twb.buf.MakeIter()
+ it.FirstOverlap(seek)
+ if it.Valid() {
+ log.Fatalf(ctx, "unhandled buffered write overlap with increment")
+ }
+ baRemote.Requests = append(baRemote.Requests, ru)
+
+ case *kvpb.DeleteRangeRequest:
+ it := twb.buf.MakeIter()
+ for it.FirstOverlap(seek); it.Valid(); it.NextOverlap(seek) {
+ log.Fatalf(ctx, "unhandled buffered write overlap with delete range")
+ }
+ baRemote.Requests = append(baRemote.Requests, ru)
+
+ default:
+ log.Fatalf(ctx, "unexpected request type: %T", req)
+ }
+ if len(brLocal.Responses) != prevLocalLen {
+ localPositions = append(localPositions, i)
+ }
+ }
+ return baRemote, brLocal, localPositions, nil
+}
+
+// augmentRangedReadsFromBuffer augments the responses to ranged reads in the
+// BatchResponse with overlapping writes from the buffer. This is not yet
+// implemented; because overlapping buffered writes are currently flushed
+// before a ranged read is sent, the buffer is expected to contain no overlaps
+// here, and the function asserts as much.
+func (twb *txnWriteBuffer) augmentRangedReadsFromBuffer(
+ ctx context.Context, baRemote *kvpb.BatchRequest, brRemote *kvpb.BatchResponse,
+) {
+ for _, ru := range baRemote.Requests {
+ req := ru.GetInner()
+ switch req.(type) {
+ case *kvpb.ScanRequest, *kvpb.ReverseScanRequest:
+ // Send the request, then augment the response.
+ seek := twb.makeBufferSeekFor(req.Header())
+ it := twb.buf.MakeIter()
+ for it.FirstOverlap(seek); it.Valid(); it.NextOverlap(seek) {
+ log.Fatalf(ctx, "unhandled buffered write overlap with scan / reverse scan")
+ }
+ }
+ }
+}
+
+// mergeLocalAndRemoteResponses merges the responses to locally serviceable
+// requests with the responses to remotely serviceable requests. It returns the
+// merged BatchResponse.
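+//
+// For example, with localPositions = [0, 2], responses 0 and 2 of the merged
+// slice are taken from brLocal (in order) and response 1 is taken from
+// brRemote.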
+func (twb *txnWriteBuffer) mergeLocalAndRemoteResponses(
+ ctx context.Context, brLocal, brRemote *kvpb.BatchResponse, localPositions []int,
+) *kvpb.BatchResponse {
+ if brLocal == nil {
+ return brRemote
+ }
+ mergedResps := make([]kvpb.ResponseUnion, len(brLocal.Responses)+len(brRemote.Responses))
+ for i := range mergedResps {
+ if len(localPositions) > 0 && i == localPositions[0] {
+ mergedResps[i] = brLocal.Responses[0]
+ brLocal.Responses = brLocal.Responses[1:]
+ localPositions = localPositions[1:]
+ } else {
+ mergedResps[i] = brRemote.Responses[0]
+ brRemote.Responses = brRemote.Responses[1:]
+ }
+ }
+ if len(brRemote.Responses) > 0 || len(brLocal.Responses) > 0 || len(localPositions) > 0 {
+ log.Fatalf(ctx, "unexpected leftover responses: %d remote, %d local, %d positions",
+ len(brRemote.Responses), len(brLocal.Responses), len(localPositions))
+ }
+ brRemote.Responses = mergedResps
+ return brRemote
+}
+
+// setWrapped implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) setWrapped(wrapped lockedSender) {
+ twb.wrapped = wrapped
+}
+
+// populateLeafInputState implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) populateLeafInputState(*roachpb.LeafTxnInputState) {
+ // TODO(nvanbenschoten): send buffered writes to LeafTxns.
+}
+
+// populateLeafFinalState implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) populateLeafFinalState(*roachpb.LeafTxnFinalState) {
+ // TODO(nvanbenschoten): ingest buffered writes in LeafTxns.
+}
+
+// importLeafFinalState implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) importLeafFinalState(context.Context, *roachpb.LeafTxnFinalState) error {
+ return nil
+}
+
+// epochBumpedLocked implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) epochBumpedLocked() {
+ twb.buf.Reset()
+}
+
+// createSavepointLocked implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) createSavepointLocked(ctx context.Context, s *savepoint) {}
+
+// rollbackToSavepointLocked implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) rollbackToSavepointLocked(ctx context.Context, s savepoint) {
+ // TODO(nvanbenschoten): clear out writes after the savepoint. This means that
+ // we need to retain multiple writes on the same key, so that we can roll one
+ // back if a savepoint is rolled back. That complicates logic above, but not
+ // overly so.
+}
+
+// closeLocked implements the txnInterceptor interface.
+func (twb *txnWriteBuffer) closeLocked() {
+ twb.buf.Reset()
+}
+
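+// makeBufferSeekFor returns a seek key covering the request's key span, for
+// use in overlap queries against the buffer. The returned bufferedWrite is a
+// reused scratch value, so it is only valid until the next call.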
+func (twb *txnWriteBuffer) makeBufferSeekFor(rh kvpb.RequestHeader) *bufferedWrite {
+ seek := &twb.bufSeek
+ seek.key = rh.Key
+ seek.endKey = rh.EndKey
+ return seek
+}
+
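+// addToBuffer records a write of the given key at the given sequence number,
+// overwriting any existing buffered write for the key. An empty value buffers
+// a tombstone (i.e. a Delete).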
+func (twb *txnWriteBuffer) addToBuffer(key roachpb.Key, val roachpb.Value, seq enginepb.TxnSeq) {
+ it := twb.buf.MakeIter()
+ seek := &twb.bufSeek
+ seek.key = key
+ it.FirstOverlap(seek)
+ if it.Valid() {
+ // If we have a write for the same key, update it. This is incorrect, as
+ // it does not handle savepoints, but it makes the prototype simpler.
+ bw := it.Cur()
+ bw.val = val
+ bw.seq = seq
+ } else {
+ twb.bufIDAlloc++
+ twb.buf.Set(&bufferedWrite{
+ id: twb.bufIDAlloc,
+ key: key,
+ val: val,
+ seq: seq,
+ })
+ }
+}
+
+// bufferedWrite is a key-value pair with an associated sequence number.
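+// A bufferedWrite with a non-present value represents a buffered tombstone
+// (Delete). The endKey field is only populated on the scratch seek key used
+// for interval btree lookups; buffered writes themselves are point writes.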
+type bufferedWrite struct {
+ id uint64
+ key roachpb.Key
+ endKey roachpb.Key // used in btree iteration
+ val roachpb.Value
+ seq enginepb.TxnSeq
+}
+
+//go:generate ../../../util/interval/generic/gen.sh *bufferedWrite kvcoord
+
+// Methods required by util/interval/generic type contract.
+func (bw *bufferedWrite) ID() uint64 { return bw.id }
+func (bw *bufferedWrite) Key() []byte { return bw.key }
+func (bw *bufferedWrite) EndKey() []byte { return bw.endKey }
+func (bw *bufferedWrite) String() string { return "todo" }
+func (bw *bufferedWrite) New() *bufferedWrite { return new(bufferedWrite) }
+func (bw *bufferedWrite) SetID(v uint64) { bw.id = v }
+func (bw *bufferedWrite) SetKey(v []byte) { bw.key = v }
+func (bw *bufferedWrite) SetEndKey(v []byte) { bw.endKey = v }
+
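+// toRequest returns a Put (for a present value) or Delete (for a buffered
+// tombstone) request carrying the buffered write, for flushing to the KV
+// layer. The request and its RequestUnion wrapper are allocated together to
+// save a heap allocation.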
+func (bw *bufferedWrite) toRequest() kvpb.RequestUnion {
+ var ru kvpb.RequestUnion
+ if bw.val.IsPresent() {
+ putAlloc := new(struct {
+ put kvpb.PutRequest
+ union kvpb.RequestUnion_Put
+ })
+ putAlloc.put.Key = bw.key
+ putAlloc.put.Value = bw.val
+ putAlloc.put.Sequence = bw.seq
+ putAlloc.union.Put = &putAlloc.put
+ ru.Value = &putAlloc.union
+ } else {
+ delAlloc := new(struct {
+ del kvpb.DeleteRequest
+ union kvpb.RequestUnion_Delete
+ })
+ delAlloc.del.Key = bw.key
+ delAlloc.del.Sequence = bw.seq
+ delAlloc.union.Delete = &delAlloc.del
+ ru.Value = &delAlloc.union
+ }
+ return ru
+}
+
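+// valPtr returns a pointer to a copy of the buffered value, so that callers
+// cannot mutate the value stored in the buffer.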
+func (bw *bufferedWrite) valPtr() *roachpb.Value {
+ valCpy := bw.val
+ return &valCpy
+}
diff --git a/pkg/sql/row/inserter.go b/pkg/sql/row/inserter.go
index e9a5a557892a..c45a7f2c0c6d 100644
--- a/pkg/sql/row/inserter.go
+++ b/pkg/sql/row/inserter.go
@@ -201,7 +201,7 @@ func (ri *Inserter) InsertRow(
for i := range entries {
e := &entries[i]
- if ri.Helper.Indexes[idx].ForcePut() {
+ if ri.Helper.Indexes[idx].ForcePut() || true {
// See the comment on (catalog.Index).ForcePut() for more details.
insertPutFn(ctx, b, &e.Key, &e.Value, traceKV)
} else {
diff --git a/pkg/sql/row/updater.go b/pkg/sql/row/updater.go
index 55cb1e782c0b..d06bbfbee850 100644
--- a/pkg/sql/row/updater.go
+++ b/pkg/sql/row/updater.go
@@ -421,18 +421,14 @@ func (ru *Updater) UpdateRow(
continue
}
- if index.ForcePut() {
+ if index.ForcePut() || expValue == nil {
// See the comment on (catalog.Index).ForcePut() for more details.
insertPutFn(ctx, putter, &newEntry.Key, &newEntry.Value, traceKV)
} else {
if traceKV {
k := keys.PrettyPrint(ru.Helper.secIndexValDirs[i], newEntry.Key)
v := newEntry.Value.PrettyPrint()
- if expValue != nil {
- log.VEventf(ctx, 2, "CPut %s -> %v (replacing %v, if exists)", k, v, oldEntry.Value.PrettyPrint())
- } else {
- log.VEventf(ctx, 2, "CPut %s -> %v (expecting does not exist)", k, v)
- }
+ log.VEventf(ctx, 2, "CPut %s -> %v (replacing %v, if exists)", k, v, oldEntry.Value.PrettyPrint())
}
batch.CPutAllowingIfNotExists(newEntry.Key, &newEntry.Value, expValue)
}