From 967a28aa6ad3f41a91fc83c8cbfe0e6942f6e9f2 Mon Sep 17 00:00:00 2001 From: Vladimir Dementyev Date: Thu, 2 Nov 2023 12:08:39 -0700 Subject: [PATCH] fix: watch nats epoch changes When restarting a cluster, there can be a race condition in epoch calculation and writing, so we must keep it in sync --- broker/broker.go | 1 + broker/memory.go | 22 +++++++++---- broker/nats.go | 79 +++++++++++++++++++++++++++++++++++++-------- broker/nats_test.go | 23 +++++++++++++ 4 files changed, 105 insertions(+), 20 deletions(-) diff --git a/broker/broker.go b/broker/broker.go index b727dd3d..0562b708 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -63,6 +63,7 @@ type LocalBroker interface { Start() error Shutdown(ctx context.Context) error SetEpoch(epoch string) + GetEpoch() string HistoryFrom(stream string, epoch string, offset uint64) ([]common.StreamMessage, error) HistorySince(stream string, ts int64) ([]common.StreamMessage, error) Store(stream string, msg []byte, seq uint64, ts time.Time) (uint64, error) diff --git a/broker/memory.go b/broker/memory.go index d188432c..3d405add 100644 --- a/broker/memory.go +++ b/broker/memory.go @@ -181,6 +181,7 @@ type Memory struct { streamsMu sync.RWMutex sessionsMu sync.RWMutex + epochMu sync.RWMutex } var _ Broker = (*Memory)(nil) @@ -201,7 +202,7 @@ func NewMemoryBroker(node Broadcaster, config *Config) *Memory { func (b *Memory) Announce() string { return fmt.Sprintf( "Using in-memory broker (epoch: %s, history limit: %d, history ttl: %ds, sessions ttl: %ds)", - b.epoch, + b.GetEpoch(), b.config.HistoryLimit, b.config.HistoryTTL, b.config.SessionsTTL, @@ -209,10 +210,16 @@ func (b *Memory) Announce() string { } func (b *Memory) GetEpoch() string { + b.epochMu.RLock() + defer b.epochMu.RUnlock() + return b.epoch } func (b *Memory) SetEpoch(v string) { + b.epochMu.Lock() + defer b.epochMu.Unlock() + b.epoch = v } @@ -229,7 +236,7 @@ func (b *Memory) Shutdown(ctx context.Context) error { func (b *Memory) HandleBroadcast(msg *common.StreamMessage) { offset := b.add(msg.Stream, msg.Data) - msg.Epoch = b.epoch + msg.Epoch = b.GetEpoch() msg.Offset = offset if b.tracker.Has(msg.Stream) { @@ -264,8 +271,10 @@ func (b *Memory) Unsubscribe(stream string) string { } func (b *Memory) HistoryFrom(name string, epoch string, offset uint64) ([]common.StreamMessage, error) { - if b.epoch != epoch { - return nil, fmt.Errorf("Unknown epoch: %s, current: %s", epoch, b.epoch) + bepoch := b.GetEpoch() + + if bepoch != epoch { + return nil, fmt.Errorf("Unknown epoch: %s, current: %s", epoch, bepoch) } stream := b.get(name) @@ -281,7 +290,7 @@ func (b *Memory) HistoryFrom(name string, epoch string, offset uint64) ([]common Stream: name, Data: entry.data, Offset: entry.offset, - Epoch: b.epoch, + Epoch: bepoch, }) }) @@ -299,6 +308,7 @@ func (b *Memory) HistorySince(name string, ts int64) ([]common.StreamMessage, er return nil, nil } + bepoch := b.GetEpoch() history := []common.StreamMessage{} err := stream.filterByTime(ts, func(entry *entry) { @@ -306,7 +316,7 @@ func (b *Memory) HistorySince(name string, ts int64) ([]common.StreamMessage, er Stream: name, Data: entry.data, Offset: entry.offset, - Epoch: b.epoch, + Epoch: bepoch, }) }) diff --git a/broker/nats.go b/broker/nats.go index 8aa80400..cf3fe78c 100644 --- a/broker/nats.go +++ b/broker/nats.go @@ -25,8 +25,9 @@ type NATS struct { nconf *natsconfig.NATSConfig conn *nats.Conn - js jetstream.JetStream - kv jetstream.KeyValue + js jetstream.JetStream + kv jetstream.KeyValue + epochKV jetstream.KeyValue jstreams *lru[string] jconsumers *lru[jetstream.Consumer] @@ -36,9 +37,13 @@ type NATS struct { local LocalBroker clientMu sync.RWMutex + epochMu sync.RWMutex epoch string + shutdownCtx context.Context + shutdownFn func() + log *log.Entry } @@ -61,10 +66,14 @@ func WithNATSLocalBroker(b LocalBroker) NATSOption { } func NewNATSBroker(broadcaster Broadcaster, c *Config, nc *natsconfig.NATSConfig, opts ...NATSOption) *NATS { + shutdownCtx, shutdownFn := context.WithCancel(context.Background()) + n := NATS{ broadcaster: broadcaster, conf: c, nconf: nc, + shutdownCtx: shutdownCtx, + shutdownFn: shutdownFn, tracker: NewStreamsTracker(), streamSync: newStreamsSynchronizer(), jstreams: newLRU[string](time.Duration(c.HistoryTTL * int64(time.Second))), @@ -134,17 +143,20 @@ func (n *NATS) Start() error { return errorx.Decorate(err, "Failed to calculate epoch") } - n.epoch = epoch - - n.local.SetEpoch(epoch) - + n.writeEpoch(epoch) err = n.local.Start() if err != nil { return errorx.Decorate(err, "Failed to start internal memory broker") } - n.log.Debugf("Current epoch: %s", n.epoch) + err = n.watchEpoch(n.shutdownCtx) + + if err != nil { + n.log.Warnf("failed to set up epoch watcher: %s", err) + } + + n.log.Debugf("Current epoch: %s", epoch) return nil } @@ -153,6 +165,8 @@ func (n *NATS) Shutdown(ctx context.Context) error { n.clientMu.Lock() defer n.clientMu.Unlock() + n.shutdownFn() + if n.conn != nil { n.conn.Close() n.conn = nil @@ -172,8 +186,8 @@ func (n *NATS) Announce() string { } func (n *NATS) Epoch() string { - n.clientMu.RLock() - defer n.clientMu.RUnlock() + n.epochMu.RLock() + defer n.epochMu.RUnlock() return n.epoch } @@ -193,13 +207,19 @@ func (n *NATS) SetEpoch(epoch string) error { return err } - n.epoch = epoch + n.writeEpoch(epoch) + + return nil +} + +func (n *NATS) writeEpoch(val string) { + n.epochMu.Lock() + defer n.epochMu.Unlock() + n.epoch = val if n.local != nil { - n.local.SetEpoch(epoch) + n.local.SetEpoch(val) } - - return nil } func (n *NATS) HandleBroadcast(msg *common.StreamMessage) { @@ -210,7 +230,7 @@ func (n *NATS) HandleBroadcast(msg *common.StreamMessage) { return } - msg.Epoch = n.epoch + msg.Epoch = n.Epoch() msg.Offset = offset if n.tracker.Has(msg.Stream) { @@ -487,6 +507,8 @@ fetchEpoch: return "", errorx.Decorate(err, "failed to connect to JetStream KV") } + n.epochKV = kv + _, err = kv.Create(context.Background(), epochKey, []byte(maybeNewEpoch)) if err != nil && strings.Contains(err.Error(), "key exists") { @@ -508,6 +530,35 @@ fetchEpoch: return maybeNewEpoch, nil } +func (n *NATS) watchEpoch(ctx context.Context) error { + watcher, err := n.epochKV.Watch(context.Background(), epochKey, jetstream.IgnoreDeletes()) + + if err != nil { + return err + } + + go func() { + for { + select { + case <-ctx.Done(): + watcher.Stop() // nolint:errcheck + return + case entry := <-watcher.Updates(): + if entry != nil { + newEpoch := string(entry.Value()) + + if n.Epoch() != newEpoch { + n.log.Warnf("epoch updated: %s", newEpoch) + n.writeEpoch(newEpoch) + } + } + } + } + }() + + return nil +} + func (n *NATS) fetchBucketWithTTL(key string, ttl time.Duration) (jetstream.KeyValue, error) { var bucket jetstream.KeyValue newBucket := true diff --git a/broker/nats_test.go b/broker/nats_test.go index f6f9eb6c..4090f77b 100644 --- a/broker/nats_test.go +++ b/broker/nats_test.go @@ -367,6 +367,29 @@ func TestNATSBroker_Epoch(t *testing.T) { defer anotherBroker.Shutdown(context.Background()) // nolint: errcheck assert.Equal(t, epoch, anotherBroker.Epoch()) + + // Now let's test that epoch changes are picked up + require.NoError(t, anotherBroker.SetEpoch("new-epoch")) + + assert.Equal(t, "new-epoch", anotherBroker.Epoch()) + assert.Equal(t, "new-epoch", anotherBroker.local.GetEpoch()) + + timer := time.After(2 * time.Second) + +wait: + for { + select { + case <-timer: + assert.Fail(t, "Epoch change wasn't picked up") + return + default: + if broker.Epoch() == "new-epoch" { + break wait + } + + time.Sleep(100 * time.Millisecond) + } + } } func buildNATSServer(t *testing.T, addr string) *enats.Service {