From 9cf730da476987fa6a796e7fc095f1000243de01 Mon Sep 17 00:00:00 2001
From: Pieter Loubser
Date: Tue, 26 Nov 2024 13:32:32 +0000
Subject: [PATCH] Add consumer and stream balancer

Here we add the `balancer` package that allows us to balance stream and
consumer leaders over a set of servers. The balancer will determine what
an even distribution of stream/consumer leaders over the set of available
servers is, and then trigger leader elections to rebalance the
distribution. A perfect distribution is not guaranteed after a run.
---
 balancer/balancer.go      | 278 ++++++++++++++++++++++++++++++++++++++
 balancer/balancer_test.go | 191 ++++++++++++++++++++++++++
 consumers.go              |   9 ++
 go.mod                    |   1 +
 go.sum                    |   2 +
 streams.go                |   9 ++
 6 files changed, 490 insertions(+)
 create mode 100644 balancer/balancer.go
 create mode 100644 balancer/balancer_test.go

diff --git a/balancer/balancer.go b/balancer/balancer.go
new file mode 100644
index 0000000..70f57d4
--- /dev/null
+++ b/balancer/balancer.go
@@ -0,0 +1,278 @@
+// Copyright 2024 The NATS Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package balancer
+
+import (
+	"context"
+	"fmt"
+	"math"
+	"slices"
+	"time"
+
+	"github.com/nats-io/jsm.go"
+	"github.com/nats-io/jsm.go/api"
+	"github.com/nats-io/nats.go"
+	"golang.org/x/exp/rand"
+)
+
+// Balancer is used to redistribute stream and consumer leaders in a cluster.
+// The Balancer will first find all the leaders and peers for the given set of
+// streams or consumers, and then determine an even distribution. If any of the
+// servers leads more streams/consumers than the even distribution allows, the
+// Balancer will step down randomly selected streams/consumers on that server
+// until the even distribution is met. If stepping down fails, or if the same
+// server is elected leader again, the Balancer retries once with another
+// randomly selected stream/consumer. On a second, similar failure the
+// Balancer will return an error.
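+//
+// A minimal usage sketch (illustrative only; error handling is elided and the
+// stream name is an assumption):
+//
+//	nc, _ := nats.Connect(nats.DefaultURL)
+//	mgr, _ := jsm.New(nc)
+//	stream, _ := mgr.LoadStream("ORDERS")
+//
+//	b, _ := balancer.New(nc, api.NewDefaultLogger(api.DebugLevel))
+//	count, _ := b.BalanceStreams([]*jsm.Stream{stream})
+//	fmt.Printf("stepped down %d stream leaders\n", count)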
+type Balancer struct {
+	nc  *nats.Conn
+	log api.Logger
+}
+
+type balanceEntity interface {
+	LeaderStepDown() error
+	Name() string
+	ClusterInfo() (api.ClusterInfo, error)
+}
+
+type peer struct {
+	hostname    string
+	entities    []balanceEntity
+	leaderCount int
+	rebalance   int
+}
+
+// New returns a new instance of the Balancer
+func New(nc *nats.Conn, log api.Logger) (*Balancer, error) {
+	return &Balancer{
+		nc:  nc,
+		log: log,
+	}, nil
+}
+
+// updateServersWithExclude rebuilds the server map from scratch, leaving out
+// the server that was just rebalanced so it is not considered again this run
+func (b *Balancer) updateServersWithExclude(servers map[string]*peer, exclude string) (map[string]*peer, error) {
+	updated := map[string]*peer{}
+	var err error
+
+	for _, s := range servers {
+		if s.hostname == exclude {
+			continue
+		}
+		for _, e := range s.entities {
+			updated, err = b.mapEntityToServers(e, updated)
+			if err != nil {
+				return updated, err
+			}
+		}
+	}
+
+	return updated, nil
+}
+
+// getOvers records, for each server, how many leaderships it holds above the
+// even distribution
+func (b *Balancer) getOvers(servers map[string]*peer, evenDistribution int) {
+	for _, s := range servers {
+		if s.leaderCount == 0 {
+			continue
+		}
+
+		if over := s.leaderCount - evenDistribution; over > 0 {
+			s.rebalance = over
+		}
+	}
+}
+
+func (b *Balancer) isBalanced(servers map[string]*peer) bool {
+	for _, s := range servers {
+		if s.rebalance > 0 {
+			return false
+		}
+	}
+
+	return true
+}
+
+// mapEntityToServers records the entity against its current leader and makes
+// sure every replica peer is also present in the server map
+func (b *Balancer) mapEntityToServers(entity balanceEntity, serverMap map[string]*peer) (map[string]*peer, error) {
+	info, err := entity.ClusterInfo()
+	if err != nil {
+		return nil, err
+	}
+
+	leaderName := info.Leader
+	_, ok := serverMap[leaderName]
+	if !ok {
+		serverMap[leaderName] = &peer{
+			hostname: leaderName,
+			entities: []balanceEntity{},
+		}
+	}
+	serverMap[leaderName].entities = append(serverMap[leaderName].entities, entity)
+	serverMap[leaderName].leaderCount++
+
+	for _, replica := range info.Replicas {
+		_, ok = serverMap[replica.Name]
+		if !ok {
+			serverMap[replica.Name] = &peer{
+				hostname: replica.Name,
+				entities: []balanceEntity{},
+			}
+		}
+	}
+
+	return serverMap, nil
+}
+
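+// calcDistribution returns the mean number of leaders per server, rounded to
+// the nearest integer. For example: 10 leaders across 3 servers gives
+// 10/3 ≈ 3.33, which rounds to 3, so any server leading more than 3 entities
+// is a candidate for rebalancing.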
+func (b *Balancer) calcDistribution(entities, servers int) int {
+	evenDistributionf := float64(entities) / float64(servers)
+	return int(math.Floor(evenDistributionf + 0.5))
+}
+
+func (b *Balancer) balance(servers map[string]*peer, evenDistribution int) (int, error) {
+	var err error
+	steppedDown := 0
+
+	for !b.isBalanced(servers) {
+		for _, s := range servers {
+			// skip servers that hold no leaderships
+			if s.leaderCount == 0 {
+				continue
+			}
+
+			if s.rebalance > 0 {
+				b.log.Infof("Found server '%s' with %d entities over the even distribution\n", s.hostname, s.rebalance)
+				// Step down a random selection of entities, one for each leadership over the even distribution
+				retries := 0
+				for i := 0; i < s.rebalance; i++ {
+					randomIndex := rand.Intn(len(s.entities))
+					entity := s.entities[randomIndex]
+					b.log.Infof("Requesting leader (%s) step down for %s", s.hostname, entity.Name())
+					info, err := entity.ClusterInfo()
+					if err != nil {
+						return 0, err
+					}
+
+					currentLeader := info.Leader
+
+					err = entity.LeaderStepDown()
+					if err != nil {
+						b.log.Errorf("Unable to step down leader for %s - %s", entity.Name(), err)
+						// If we failed to step down the entity, decrement the loop counter so that we don't kick one too few.
+						// Limit this to one retry; if we can't step down multiple leaders something is wrong
+						if retries == 0 {
+							i--
+							retries++
+							continue
+						}
+						return 0, err
+					}
+					b.log.Infof("Successful step down of %s", entity.Name())
+
+					err = b.waitForLeaderUpdate(currentLeader, entity)
+					if err != nil {
+						// If the leader election doesn't result in a new leader we retry once before giving up
+						if retries == 0 {
+							i--
+							retries++
+							continue
+						}
+						return 0, err
+					}
+					retries = 0
+					s.entities = slices.Delete(s.entities, randomIndex, randomIndex+1)
+					steppedDown++
+				}
+
+				// Finally, if we rebalanced a server, rebuild the servers list and start again, excluding the one we just rebalanced
+				servers, err = b.updateServersWithExclude(servers, s.hostname)
+				if err != nil {
+					return steppedDown, err
+				}
+				b.getOvers(servers, evenDistribution)
+				break
+			}
+		}
+	}
+
+	return steppedDown, nil
+}
+
+func (b *Balancer) waitForLeaderUpdate(currentLeader string, entity balanceEntity) error {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	ticker := time.NewTicker(250 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			info, err := entity.ClusterInfo()
+			if err != nil {
+				continue
+			}
+			// keep polling while the old leader is still in place; any other
+			// leader means the step down succeeded
+			if currentLeader == info.Leader {
+				continue
+			}
+
+			return nil
+		case <-ctx.Done():
+			return fmt.Errorf("leader did not change - %s", entity.Name())
+		}
+	}
+}
+
+// BalanceStreams finds the expected distribution of stream leaders over servers
+// and forces leader elections on any server that exceeds it
+func (b *Balancer) BalanceStreams(streams []*jsm.Stream) (int, error) {
+	var err error
+	servers := map[string]*peer{}
+
+	for _, s := range streams {
+		servers, err = b.mapEntityToServers(s, servers)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	b.log.Debugf("found %d streams on %d servers\n", len(streams), len(servers))
+	evenDistribution := b.calcDistribution(len(streams), len(servers))
+	b.log.Debugf("even distribution is %d\n", evenDistribution)
+	b.getOvers(servers, evenDistribution)
+
+	return b.balance(servers, evenDistribution)
+}
+
+// BalanceConsumers finds the expected distribution of consumer leaders over
+// servers and forces leader elections on any server that exceeds it
+func (b *Balancer) BalanceConsumers(consumers []*jsm.Consumer) (int, error) {
+	var err error
+	servers := map[string]*peer{}
+
+	for _, c := range consumers {
+		servers, err = b.mapEntityToServers(c, servers)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	b.log.Debugf("found %d consumers on %d servers\n", len(consumers), len(servers))
+	evenDistribution := b.calcDistribution(len(consumers), len(servers))
+	b.log.Debugf("even distribution is %d\n", evenDistribution)
+	b.getOvers(servers, evenDistribution)
+
+	return b.balance(servers, evenDistribution)
+}
diff --git a/balancer/balancer_test.go b/balancer/balancer_test.go
new file mode 100644
index 0000000..410062f
--- /dev/null
+++ b/balancer/balancer_test.go
@@ -0,0 +1,191 @@
+package balancer
+
+import (
+	"context"
+	"fmt"
+	"net/url"
+	"os"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/nats-io/jsm.go"
+	"github.com/nats-io/jsm.go/api"
+	"github.com/nats-io/nats-server/v2/server"
+	"github.com/nats-io/nats.go"
+)
+
+func TestBalanceStream(t *testing.T) {
+	withJSCluster(t, 3, func(t *testing.T, servers []*server.Server, nc *nats.Conn, mgr *jsm.Manager) error {
+		streams := []*jsm.Stream{}
+		for i := 1; i <= 10; i++ {
+			streamName := fmt.Sprintf("tests%d", i)
+			subjects := fmt.Sprintf("tests%d.*", i)
+			s, err := mgr.NewStream(streamName, jsm.Subjects(subjects), jsm.MemoryStorage(), jsm.Replicas(3))
+			if err != nil {
+				t.Fatalf("could not create stream: %s", err)
+			}
+			streams = append(streams, s)
+			defer s.Delete()
+		}
+
+		// bounce JetStream on one server so that leadership skews towards the others
+		if err := servers[2].DisableJetStream(); err != nil {
+			return err
+		}
+		if err := servers[2].EnableJetStream(nil); err != nil {
+			return err
+		}
+
+		b, err := New(nc, api.NewDefaultLogger(api.DebugLevel))
+		if err != nil {
+			return err
+		}
+
+		count, err := b.BalanceStreams(streams)
+		if err != nil {
+			return err
+		}
+
+		if count == 0 {
+			return fmt.Errorf("expected at least one stream leader to be stepped down")
+		}
+		return nil
+	})
+}
+
+func TestBalanceConsumer(t *testing.T) {
+	withJSCluster(t, 3, func(t *testing.T, servers []*server.Server, nc *nats.Conn, mgr *jsm.Manager) error {
+		s, err := mgr.NewStream("TEST_CONSUMER_BALANCE", jsm.Subjects("test.*"), jsm.MemoryStorage(), jsm.Replicas(3))
+		if err != nil {
+			return err
+		}
+
+		defer s.Delete()
+
+		consumers := []*jsm.Consumer{}
+		for i := 1; i <= 10; i++ {
+			consumerName := fmt.Sprintf("testc%d", i)
+			c, err := mgr.NewConsumer("TEST_CONSUMER_BALANCE", jsm.ConsumerName(consumerName))
+			if err != nil {
+				return err
+			}
+			consumers = append(consumers, c)
+			defer c.Delete()
+		}
+
+		// bounce JetStream on one server so that leadership skews towards the others
+		if err := servers[2].DisableJetStream(); err != nil {
+			return err
+		}
+		if err := servers[2].EnableJetStream(nil); err != nil {
+			return err
+		}
+
+		b, err := New(nc, api.NewDefaultLogger(api.DebugLevel))
+		if err != nil {
+			return err
+		}
+
+		count, err := b.BalanceConsumers(consumers)
+		if err != nil {
+			return err
+		}
+
+		if count == 0 {
+			return fmt.Errorf("expected at least one consumer leader to be stepped down")
+		}
+
+		return nil
+	})
+}
+
+func withJSCluster(t *testing.T, retries int, cb func(*testing.T, []*server.Server, *nats.Conn, *jsm.Manager) error) {
+	t.Helper()
+
+	d, err := os.MkdirTemp("", "jstest")
+	if err != nil {
+		t.Fatalf("temp dir could not be made: %s", err)
+	}
+	defer os.RemoveAll(d)
+
+	var servers []*server.Server
+
+	for i := 1; i <= 3; i++ {
+		opts := &server.Options{
+			JetStream:  true,
+			StoreDir:   filepath.Join(d, fmt.Sprintf("s%d", i)),
+			Port:       -1,
+			Host:       "localhost",
+			ServerName: fmt.Sprintf("s%d", i),
+			LogFile:    "/dev/null",
+			Cluster: server.ClusterOpts{
+				Name: "TEST",
+				Port: 12000 + i,
+			},
+			Routes: []*url.URL{
+				{Host: "localhost:12001"},
+				{Host: "localhost:12002"},
+				{Host: "localhost:12003"},
+			},
+		}
+
+		s, err := server.NewServer(opts)
+		if err != nil {
+			t.Fatalf("server %d start failed: %v", i, err)
+		}
+		s.ConfigureLogger()
+		go s.Start()
+		if !s.ReadyForConnections(10 * time.Second) {
+			t.Fatalf("nats server %d did not start", i)
+		}
+		defer s.Shutdown()
+
+		servers = append(servers, s)
+	}
+
+	if len(servers) != 3 {
+		t.Fatalf("servers did not start")
+	}
+
+	nc, err := nats.Connect(servers[0].ClientURL(), nats.UseOldRequestStyle())
+	if err != nil {
+		t.Fatalf("client start failed: %s", err)
+	}
+	defer nc.Close()
+
+	mgr, err := jsm.New(nc, jsm.WithTimeout(time.Second))
+	if err != nil {
+		t.Fatalf("manager creation failed: %s", err)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	ticker := time.NewTicker(250 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			// wait until JetStream is available before running the callback
+			_, err := mgr.JetStreamAccountInfo()
+			if err != nil {
+				continue
+			}
+
+			for i := 0; i < retries; i++ {
+				err = cb(t, servers, nc, mgr)
+				if err == nil {
+					break
+				}
+			}
+
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			return
+		case <-ctx.Done():
+			t.Fatalf("jetstream did not become available")
+		}
+	}
+}
diff --git a/consumers.go b/consumers.go
index a2a607c..15092a1 100644
--- a/consumers.go
+++ b/consumers.go
@@ -892,6 +892,15 @@ func (c *Consumer) LatestState() (api.ConsumerInfo, error) {
 	return c.State()
 }
 
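+// ClusterInfo returns the cluster information for the consumer, including its
+// current leader and replica peers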
+func (c *Consumer) ClusterInfo() (api.ClusterInfo, error) {
+	nfo, err := c.LatestState()
+	if err != nil {
+		return api.ClusterInfo{}, err
+	}
+
+	// guard against non-clustered consumers, which report no cluster information
+	if nfo.Cluster == nil {
+		return api.ClusterInfo{}, fmt.Errorf("no cluster information found for consumer %s > %s", c.stream, c.name)
+	}
+
+	return *nfo.Cluster, nil
+}
+
 // State loads a snapshot of consumer state including delivery counts, retries and more
 func (c *Consumer) State() (api.ConsumerInfo, error) {
 	s, err := c.mgr.loadConsumerInfo(c.stream, c.name)
diff --git a/go.mod b/go.mod
index 7b3395f..3fc0ff8 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
 	github.com/nats-io/nuid v1.0.1
 	github.com/prometheus/client_golang v1.20.5
 	github.com/prometheus/common v0.60.1
+	golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f
 	golang.org/x/net v0.31.0
 	golang.org/x/text v0.20.0
 	gopkg.in/yaml.v3 v3.0.1
diff --git a/go.sum b/go.sum
index 5f0c95b..4ffd981 100644
--- a/go.sum
+++ b/go.sum
@@ -58,6 +58,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
 golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
+golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo=
+golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak=
 golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
 golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
 golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
diff --git a/streams.go b/streams.go
index 4b8958c..6e713c9 100644
--- a/streams.go
+++ b/streams.go
@@ -670,6 +670,15 @@ func (s *Stream) LatestState() (state api.StreamState, err error) {
 	return nfo.State, nil
 }
 
+// ClusterInfo returns the cluster information for the stream, including its
+// current leader and replica peers
+func (s *Stream) ClusterInfo() (api.ClusterInfo, error) {
+	nfo, err := s.LatestInformation()
+	if err != nil {
+		return api.ClusterInfo{}, err
+	}
+
+	// guard against non-clustered streams, which report no cluster information
+	if nfo.Cluster == nil {
+		return api.ClusterInfo{}, fmt.Errorf("no cluster information found for stream %s", s.Name())
+	}
+
+	return *nfo.Cluster, nil
+}
+
 // State retrieves the Stream State
 func (s *Stream) State(req ...api.JSApiStreamInfoRequest) (stats api.StreamState, err error) {
 	info, err := s.Information(req...)