From 9c69f83f75a707e7bb11ee877ac3fe052636fcf1 Mon Sep 17 00:00:00 2001 From: Josh van Leeuwen Date: Thu, 1 Feb 2024 16:14:24 +0000 Subject: [PATCH] Block Graceful Shutdown: stop input bindings and subscriptions (#7474) Whilst Dapr should keep available outgoing API requests whilst blocking graceful shutdown, we want to prevent incoming input bindings and subscriptions from receiving new messages. PR updates the block graceful shutdown to forever stop input bindings and subscriptions from being read. Signed-off-by: joshvanl --- pkg/runtime/processor/binding/binding.go | 1 + pkg/runtime/processor/binding/send.go | 10 +- pkg/runtime/processor/manager.go | 4 +- pkg/runtime/processor/pubsub/pubsub.go | 1 + pkg/runtime/processor/pubsub/pubsub_test.go | 10 +- pkg/runtime/processor/pubsub/subscribe.go | 14 ++- pkg/runtime/runtime.go | 19 ++- .../suite/daprd/shutdown/block/app/healthy.go | 27 ++++- .../suite/daprd/shutdown/block/timeout.go | 113 ++++++++++++++++-- 9 files changed, 176 insertions(+), 23 deletions(-) diff --git a/pkg/runtime/processor/binding/binding.go b/pkg/runtime/processor/binding/binding.go index bc1975bfa1c..83604149e90 100644 --- a/pkg/runtime/processor/binding/binding.go +++ b/pkg/runtime/processor/binding/binding.go @@ -73,6 +73,7 @@ type binding struct { lock sync.Mutex readingBindings bool + stopForever bool subscribeBindingList []string inputCancels map[string]context.CancelFunc diff --git a/pkg/runtime/processor/binding/send.go b/pkg/runtime/processor/binding/send.go index d336030d6a8..fe9df663a20 100644 --- a/pkg/runtime/processor/binding/send.go +++ b/pkg/runtime/processor/binding/send.go @@ -43,6 +43,10 @@ func (b *binding) StartReadingFromBindings(ctx context.Context) error { b.lock.Lock() defer b.lock.Unlock() + if b.stopForever { + return nil + } + b.readingBindings = true if b.channels.AppChannel() == nil { @@ -110,11 +114,15 @@ func (b *binding) startInputBinding(comp componentsV1alpha1.Component, binding b return nil } -func (b *binding) StopReadingFromBindings() { +func (b *binding) StopReadingFromBindings(forever bool) { b.lock.Lock() defer b.lock.Unlock() defer b.wg.Wait() + if forever { + b.stopForever = true + } + b.readingBindings = false for _, cancel := range b.inputCancels { diff --git a/pkg/runtime/processor/manager.go b/pkg/runtime/processor/manager.go index 71b37fb65a7..1bb0b9811c6 100644 --- a/pkg/runtime/processor/manager.go +++ b/pkg/runtime/processor/manager.go @@ -47,7 +47,7 @@ type PubsubManager interface { BulkPublish(context.Context, *contribpubsub.BulkPublishRequest) (contribpubsub.BulkPublishResponse, error) StartSubscriptions(context.Context) error - StopSubscriptions() + StopSubscriptions(forever bool) Outbox() outbox.Outbox manager } @@ -56,7 +56,7 @@ type BindingManager interface { SendToOutputBinding(context.Context, string, *bindings.InvokeRequest) (*bindings.InvokeResponse, error) StartReadingFromBindings(context.Context) error - StopReadingFromBindings() + StopReadingFromBindings(forever bool) manager } diff --git a/pkg/runtime/processor/pubsub/pubsub.go b/pkg/runtime/processor/pubsub/pubsub.go index bd69433963a..7ecefab95f0 100644 --- a/pkg/runtime/processor/pubsub/pubsub.go +++ b/pkg/runtime/processor/pubsub/pubsub.go @@ -99,6 +99,7 @@ type pubsub struct { lock sync.RWMutex subscribing bool + stopForever bool topicCancels map[string]context.CancelFunc outbox outbox.Outbox diff --git a/pkg/runtime/processor/pubsub/pubsub_test.go b/pkg/runtime/processor/pubsub/pubsub_test.go index 05883866696..0d2512d4bff 100644 --- a/pkg/runtime/processor/pubsub/pubsub_test.go +++ b/pkg/runtime/processor/pubsub/pubsub_test.go @@ -133,7 +133,7 @@ func TestInitPubSub(t *testing.T) { mockAppChannel := new(channelt.MockAppChannel) ps.channels = new(channels.Channels).WithAppChannel(mockAppChannel) - ps.StopSubscriptions() + ps.StopSubscriptions(false) ps.compStore.SetTopicRoutes(nil) ps.compStore.SetSubscriptions(nil) for name := range ps.compStore.ListPubSubs() { @@ -227,7 +227,7 @@ func TestInitPubSub(t *testing.T) { mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), matchDaprRequestMethod("dapr/subscribe")).Return(fakeResp, nil) require.NoError(t, ps.StartSubscriptions(context.Background())) - ps.StopSubscriptions() + ps.StopSubscriptions(false) // act for _, comp := range pubsubComponents { @@ -327,7 +327,7 @@ func TestInitPubSub(t *testing.T) { mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), matchDaprRequestMethod("dapr/subscribe")).Return(fakeResp, nil) require.NoError(t, ps.StartSubscriptions(context.Background())) - ps.StopSubscriptions() + ps.StopSubscriptions(false) // act for _, comp := range pubsubComponents { @@ -1927,7 +1927,7 @@ func TestPubsubLifecycle(t *testing.T) { comp3 := getPubSub("mockPubSub3") comp3.On("unsubscribed", "topic4").Return(nil).Once() - ps.StopSubscriptions() + ps.StopSubscriptions(false) sendMessages(t, 0) @@ -1947,7 +1947,7 @@ func TestPubsubLifecycle(t *testing.T) { comp2.On("unsubscribed", "topic2").Return(nil).Once() comp2.On("unsubscribed", "topic3").Return(nil).Once() - ps.StopSubscriptions() + ps.StopSubscriptions(false) time.Sleep(time.Second / 2) comp1.AssertCalled(t, "unsubscribed", "topic1") diff --git a/pkg/runtime/processor/pubsub/subscribe.go b/pkg/runtime/processor/pubsub/subscribe.go index 7343aff76cc..c087aaae85f 100644 --- a/pkg/runtime/processor/pubsub/subscribe.go +++ b/pkg/runtime/processor/pubsub/subscribe.go @@ -29,11 +29,16 @@ import ( // StartSubscriptions starts the pubsub subscriptions func (p *pubsub) StartSubscriptions(ctx context.Context) error { // Clean any previous state - p.StopSubscriptions() + p.StopSubscriptions(false) p.lock.Lock() defer p.lock.Unlock() + // If Dapr has stopped subscribing forever, return early. + if p.stopForever { + return nil + } + p.subscribing = true var errs []error @@ -47,10 +52,15 @@ func (p *pubsub) StartSubscriptions(ctx context.Context) error { } // StopSubscriptions to all topics and cleans the cached topics -func (p *pubsub) StopSubscriptions() { +func (p *pubsub) StopSubscriptions(forever bool) { p.lock.Lock() defer p.lock.Unlock() + if forever { + // Mark if Dapr has stopped subscribing forever. + p.stopForever = true + } + p.subscribing = false for subKey := range p.topicCancels { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index b414ddd4677..f7de63db64c 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -339,6 +339,13 @@ func (a *DaprRuntime) Run(parentCtx context.Context) error { } log.Infof("Blocking graceful shutdown for %s or until app reports unhealthy...", *a.runtimeConfig.blockShutdownDuration) + + // Stop reading from subscriptions and input bindings forever while + // blocking graceful shutdown. This will prevent incoming messages from + // being processed, but allow outgoing APIs to be processed. + a.processor.PubSub().StopSubscriptions(true) + a.processor.Binding().StopReadingFromBindings(true) + select { case <-a.clock.After(*a.runtimeConfig.blockShutdownDuration): log.Info("Block shutdown period expired, entering shutdown...") @@ -715,8 +722,8 @@ func (a *DaprRuntime) appHealthChanged(ctx context.Context, status uint8) { } // Stop topic subscriptions and input bindings - a.processor.PubSub().StopSubscriptions() - a.processor.Binding().StopReadingFromBindings() + a.processor.PubSub().StopSubscriptions(false) + a.processor.Binding().StopReadingFromBindings(false) } } @@ -803,10 +810,14 @@ func (a *DaprRuntime) startHTTPServer(port int, publicPort *int, profilePort int return err } - if err := a.runnerCloser.AddCloser(a.processor.PubSub().StopSubscriptions); err != nil { + if err := a.runnerCloser.AddCloser(func() { + a.processor.PubSub().StopSubscriptions(true) + }); err != nil { return err } - if err := a.runnerCloser.AddCloser(a.processor.Binding().StopReadingFromBindings); err != nil { + if err := a.runnerCloser.AddCloser(func() { + a.processor.Binding().StopReadingFromBindings(true) + }); err != nil { return err } diff --git a/tests/integration/suite/daprd/shutdown/block/app/healthy.go b/tests/integration/suite/daprd/shutdown/block/app/healthy.go index 61c9b25dfa3..df89b17082b 100644 --- a/tests/integration/suite/daprd/shutdown/block/app/healthy.go +++ b/tests/integration/suite/daprd/shutdown/block/app/healthy.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + commonv1 "github.com/dapr/dapr/pkg/proto/common/v1" rtv1 "github.com/dapr/dapr/pkg/proto/runtime/v1" "github.com/dapr/dapr/tests/integration/framework" "github.com/dapr/dapr/tests/integration/framework/process/daprd" @@ -102,6 +103,14 @@ metadata: spec: type: pubsub.in-memory version: v1 +--- +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: mystore +spec: + type: state.in-memory + version: v1 `)) return []framework.Option{ @@ -149,9 +158,16 @@ func (h *healthy) Run(t *testing.T, ctx context.Context) { require.NoError(t, err) select { case <-h.routeCh: - case <-ctx.Done(): - assert.Fail(t, "pubsub did not send message to subscriber") + assert.Fail(t, "pubsub should not have sent message to subscriber") + case <-time.After(time.Second): } + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + {Key: "key", Value: []byte("value")}, + }, + }) + require.NoError(t, err) healthzCalled = h.healthzCalled.Load() h.appHealth.Store(false) @@ -170,6 +186,13 @@ func (h *healthy) Run(t *testing.T, ctx context.Context) { //nolint:testifylint assert.Error(c, err) }, time.Second*5, time.Millisecond*100) + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + {Key: "key", Value: []byte("value2")}, + }, + }) + require.Error(t, err) select { case <-daprdStopped: diff --git a/tests/integration/suite/daprd/shutdown/block/timeout.go b/tests/integration/suite/daprd/shutdown/block/timeout.go index bb9d1ac3d8b..b6bc53d72d6 100644 --- a/tests/integration/suite/daprd/shutdown/block/timeout.go +++ b/tests/integration/suite/daprd/shutdown/block/timeout.go @@ -18,6 +18,7 @@ import ( "io" "net/http" "runtime" + "sync/atomic" "testing" "time" @@ -26,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + commonv1 "github.com/dapr/dapr/pkg/proto/common/v1" rtv1 "github.com/dapr/dapr/pkg/proto/runtime/v1" "github.com/dapr/dapr/tests/integration/framework" "github.com/dapr/dapr/tests/integration/framework/process/daprd" @@ -42,9 +44,11 @@ func init() { // timeout tests Daprd's --dapr-block-shutdown-seconds, ensuring shutdown // procedure will begin when seconds is reached when app still reports healthy. type timeout struct { - daprd *daprd.Daprd - logline *logline.LogLine - routeCh chan struct{} + daprd *daprd.Daprd + logline *logline.LogLine + routeCh chan struct{} + listening atomic.Bool + bindingChan chan struct{} } func (i *timeout) Setup(t *testing.T) []framework.Option { @@ -53,6 +57,7 @@ func (i *timeout) Setup(t *testing.T) []framework.Option { } i.routeCh = make(chan struct{}, 1) + i.bindingChan = make(chan struct{}) handler := http.NewServeMux() handler.HandleFunc("/dapr/subscribe", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -64,6 +69,13 @@ func (i *timeout) Setup(t *testing.T) []framework.Option { handler.HandleFunc("/route", func(w http.ResponseWriter, r *http.Request) { i.routeCh <- struct{}{} }) + handler.HandleFunc("/binding", func(w http.ResponseWriter, r *http.Request) { + if i.listening.Load() { + i.listening.Store(false) + i.bindingChan <- struct{}{} + } + }) + app := prochttp.New(t, prochttp.WithHandler(handler), ) @@ -92,6 +104,27 @@ metadata: spec: type: pubsub.in-memory version: v1 +--- +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: 'binding' +spec: + type: bindings.cron + version: v1 + metadata: + - name: schedule + value: "@every 100ms" + - name: direction + value: "input" +--- +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: 'mystore' +spec: + type: state.in-memory + version: v1 `)) return []framework.Option{ @@ -120,25 +153,75 @@ func (i *timeout) Run(t *testing.T, ctx context.Context) { assert.Fail(t, "pubsub message should have been sent to subscriber") } + i.listening.Store(true) + select { + case <-i.bindingChan: + case <-time.After(time.Second * 5): + assert.Fail(t, "timed out waiting for binding event") + } + + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + { + Key: "key", + Value: []byte("value"), + }, + }, + }) + require.NoError(t, err) + resp, err := client.GetState(ctx, &rtv1.GetStateRequest{ + StoreName: "mystore", + Key: "key", + }) + require.NoError(t, err) + assert.Equal(t, "value", string(resp.GetData())) + daprdStopped := make(chan struct{}) go func() { i.daprd.Cleanup(t) close(daprdStopped) }() - t.Run("daprd APIs should still be available during blocked shutdown", func(t *testing.T) { - time.Sleep(time.Second) + t.Run("daprd APIs should still be available during blocked shutdown, except input bindings and subscriptions", func(t *testing.T) { + time.Sleep(time.Second / 2) + + i.listening.Store(true) + select { + case <-i.bindingChan: + assert.Fail(t, "binding event should not have been sent to subscriber") + case <-time.After(time.Second / 2): + } + _, err = client.PublishEvent(ctx, &rtv1.PublishEventRequest{ PubsubName: "foo", Topic: "topic", Data: []byte(`{"status":"completed"}`), }) require.NoError(t, err) + select { case <-i.routeCh: - case <-ctx.Done(): - assert.Fail(t, "pubsub message should have been sent to subscriber") + assert.Fail(t, "pubsub message should not have been sent to subscriber") + case <-time.After(time.Second / 2): } + + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + { + Key: "key", + Value: []byte("value2"), + }, + }, + }) + require.NoError(t, err) + resp, err = client.GetState(ctx, &rtv1.GetStateRequest{ + StoreName: "mystore", + Key: "key", + }) + require.NoError(t, err) + assert.Equal(t, "value2", string(resp.GetData())) }) t.Run("daprd APIs are no longer available when past blocked shutdown", func(t *testing.T) { @@ -149,6 +232,22 @@ func (i *timeout) Run(t *testing.T, ctx context.Context) { Data: []byte(`{"status":"completed"}`), }) require.Error(t, err) + + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + { + Key: "key", + Value: []byte("value3"), + }, + }, + }) + require.Error(t, err) + _, err = client.GetState(ctx, &rtv1.GetStateRequest{ + StoreName: "mystore", + Key: "key", + }) + require.Error(t, err) }) select {