diff --git a/pkg/runtime/processor/binding/binding.go b/pkg/runtime/processor/binding/binding.go index bc1975bfa1c..83604149e90 100644 --- a/pkg/runtime/processor/binding/binding.go +++ b/pkg/runtime/processor/binding/binding.go @@ -73,6 +73,7 @@ type binding struct { lock sync.Mutex readingBindings bool + stopForever bool subscribeBindingList []string inputCancels map[string]context.CancelFunc diff --git a/pkg/runtime/processor/binding/send.go b/pkg/runtime/processor/binding/send.go index d336030d6a8..fe9df663a20 100644 --- a/pkg/runtime/processor/binding/send.go +++ b/pkg/runtime/processor/binding/send.go @@ -43,6 +43,10 @@ func (b *binding) StartReadingFromBindings(ctx context.Context) error { b.lock.Lock() defer b.lock.Unlock() + if b.stopForever { + return nil + } + b.readingBindings = true if b.channels.AppChannel() == nil { @@ -110,11 +114,15 @@ func (b *binding) startInputBinding(comp componentsV1alpha1.Component, binding b return nil } -func (b *binding) StopReadingFromBindings() { +func (b *binding) StopReadingFromBindings(forever bool) { b.lock.Lock() defer b.lock.Unlock() defer b.wg.Wait() + if forever { + b.stopForever = true + } + b.readingBindings = false for _, cancel := range b.inputCancels { diff --git a/pkg/runtime/processor/manager.go b/pkg/runtime/processor/manager.go index 71b37fb65a7..1bb0b9811c6 100644 --- a/pkg/runtime/processor/manager.go +++ b/pkg/runtime/processor/manager.go @@ -47,7 +47,7 @@ type PubsubManager interface { BulkPublish(context.Context, *contribpubsub.BulkPublishRequest) (contribpubsub.BulkPublishResponse, error) StartSubscriptions(context.Context) error - StopSubscriptions() + StopSubscriptions(forever bool) Outbox() outbox.Outbox manager } @@ -56,7 +56,7 @@ type BindingManager interface { SendToOutputBinding(context.Context, string, *bindings.InvokeRequest) (*bindings.InvokeResponse, error) StartReadingFromBindings(context.Context) error - StopReadingFromBindings() + StopReadingFromBindings(forever bool) manager } diff --git a/pkg/runtime/processor/pubsub/pubsub.go b/pkg/runtime/processor/pubsub/pubsub.go index bd69433963a..7ecefab95f0 100644 --- a/pkg/runtime/processor/pubsub/pubsub.go +++ b/pkg/runtime/processor/pubsub/pubsub.go @@ -99,6 +99,7 @@ type pubsub struct { lock sync.RWMutex subscribing bool + stopForever bool topicCancels map[string]context.CancelFunc outbox outbox.Outbox diff --git a/pkg/runtime/processor/pubsub/pubsub_test.go b/pkg/runtime/processor/pubsub/pubsub_test.go index 05883866696..0d2512d4bff 100644 --- a/pkg/runtime/processor/pubsub/pubsub_test.go +++ b/pkg/runtime/processor/pubsub/pubsub_test.go @@ -133,7 +133,7 @@ func TestInitPubSub(t *testing.T) { mockAppChannel := new(channelt.MockAppChannel) ps.channels = new(channels.Channels).WithAppChannel(mockAppChannel) - ps.StopSubscriptions() + ps.StopSubscriptions(false) ps.compStore.SetTopicRoutes(nil) ps.compStore.SetSubscriptions(nil) for name := range ps.compStore.ListPubSubs() { @@ -227,7 +227,7 @@ func TestInitPubSub(t *testing.T) { mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), matchDaprRequestMethod("dapr/subscribe")).Return(fakeResp, nil) require.NoError(t, ps.StartSubscriptions(context.Background())) - ps.StopSubscriptions() + ps.StopSubscriptions(false) // act for _, comp := range pubsubComponents { @@ -327,7 +327,7 @@ func TestInitPubSub(t *testing.T) { mockAppChannel.On("InvokeMethod", mock.MatchedBy(matchContextInterface), matchDaprRequestMethod("dapr/subscribe")).Return(fakeResp, nil) require.NoError(t, ps.StartSubscriptions(context.Background())) - ps.StopSubscriptions() + ps.StopSubscriptions(false) // act for _, comp := range pubsubComponents { @@ -1927,7 +1927,7 @@ func TestPubsubLifecycle(t *testing.T) { comp3 := getPubSub("mockPubSub3") comp3.On("unsubscribed", "topic4").Return(nil).Once() - ps.StopSubscriptions() + ps.StopSubscriptions(false) sendMessages(t, 0) @@ -1947,7 +1947,7 @@ func TestPubsubLifecycle(t *testing.T) { comp2.On("unsubscribed", "topic2").Return(nil).Once() comp2.On("unsubscribed", "topic3").Return(nil).Once() - ps.StopSubscriptions() + ps.StopSubscriptions(false) time.Sleep(time.Second / 2) comp1.AssertCalled(t, "unsubscribed", "topic1") diff --git a/pkg/runtime/processor/pubsub/subscribe.go b/pkg/runtime/processor/pubsub/subscribe.go index 7343aff76cc..c087aaae85f 100644 --- a/pkg/runtime/processor/pubsub/subscribe.go +++ b/pkg/runtime/processor/pubsub/subscribe.go @@ -29,11 +29,16 @@ import ( // StartSubscriptions starts the pubsub subscriptions func (p *pubsub) StartSubscriptions(ctx context.Context) error { // Clean any previous state - p.StopSubscriptions() + p.StopSubscriptions(false) p.lock.Lock() defer p.lock.Unlock() + // If Dapr has stopped subscribing forever, return early. + if p.stopForever { + return nil + } + p.subscribing = true var errs []error @@ -47,10 +52,15 @@ func (p *pubsub) StartSubscriptions(ctx context.Context) error { } // StopSubscriptions to all topics and cleans the cached topics -func (p *pubsub) StopSubscriptions() { +func (p *pubsub) StopSubscriptions(forever bool) { p.lock.Lock() defer p.lock.Unlock() + if forever { + // Mark if Dapr has stopped subscribing forever. + p.stopForever = true + } + p.subscribing = false for subKey := range p.topicCancels { diff --git a/pkg/runtime/runtime.go b/pkg/runtime/runtime.go index b414ddd4677..f7de63db64c 100644 --- a/pkg/runtime/runtime.go +++ b/pkg/runtime/runtime.go @@ -339,6 +339,13 @@ func (a *DaprRuntime) Run(parentCtx context.Context) error { } log.Infof("Blocking graceful shutdown for %s or until app reports unhealthy...", *a.runtimeConfig.blockShutdownDuration) + + // Stop reading from subscriptions and input bindings forever while + // blocking graceful shutdown. This will prevent incoming messages from + // being processed, but allow outgoing APIs to be processed. + a.processor.PubSub().StopSubscriptions(true) + a.processor.Binding().StopReadingFromBindings(true) + select { case <-a.clock.After(*a.runtimeConfig.blockShutdownDuration): log.Info("Block shutdown period expired, entering shutdown...") @@ -715,8 +722,8 @@ func (a *DaprRuntime) appHealthChanged(ctx context.Context, status uint8) { } // Stop topic subscriptions and input bindings - a.processor.PubSub().StopSubscriptions() - a.processor.Binding().StopReadingFromBindings() + a.processor.PubSub().StopSubscriptions(false) + a.processor.Binding().StopReadingFromBindings(false) } } @@ -803,10 +810,14 @@ func (a *DaprRuntime) startHTTPServer(port int, publicPort *int, profilePort int return err } - if err := a.runnerCloser.AddCloser(a.processor.PubSub().StopSubscriptions); err != nil { + if err := a.runnerCloser.AddCloser(func() { + a.processor.PubSub().StopSubscriptions(true) + }); err != nil { return err } - if err := a.runnerCloser.AddCloser(a.processor.Binding().StopReadingFromBindings); err != nil { + if err := a.runnerCloser.AddCloser(func() { + a.processor.Binding().StopReadingFromBindings(true) + }); err != nil { return err } diff --git a/tests/integration/suite/daprd/shutdown/block/app/healthy.go b/tests/integration/suite/daprd/shutdown/block/app/healthy.go index 61c9b25dfa3..df89b17082b 100644 --- a/tests/integration/suite/daprd/shutdown/block/app/healthy.go +++ b/tests/integration/suite/daprd/shutdown/block/app/healthy.go @@ -27,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + commonv1 "github.com/dapr/dapr/pkg/proto/common/v1" rtv1 "github.com/dapr/dapr/pkg/proto/runtime/v1" "github.com/dapr/dapr/tests/integration/framework" "github.com/dapr/dapr/tests/integration/framework/process/daprd" @@ -102,6 +103,14 @@ metadata: spec: type: pubsub.in-memory version: v1 +--- +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: mystore +spec: + type: state.in-memory + version: v1 `)) return []framework.Option{ @@ -149,9 +158,16 @@ func (h *healthy) Run(t *testing.T, ctx context.Context) { require.NoError(t, err) select { case <-h.routeCh: - case <-ctx.Done(): - assert.Fail(t, "pubsub did not send message to subscriber") + assert.Fail(t, "pubsub should not have sent message to subscriber") + case <-time.After(time.Second): } + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + {Key: "key", Value: []byte("value")}, + }, + }) + require.NoError(t, err) healthzCalled = h.healthzCalled.Load() h.appHealth.Store(false) @@ -170,6 +186,13 @@ func (h *healthy) Run(t *testing.T, ctx context.Context) { //nolint:testifylint assert.Error(c, err) }, time.Second*5, time.Millisecond*100) + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + {Key: "key", Value: []byte("value2")}, + }, + }) + require.Error(t, err) select { case <-daprdStopped: diff --git a/tests/integration/suite/daprd/shutdown/block/timeout.go b/tests/integration/suite/daprd/shutdown/block/timeout.go index bb9d1ac3d8b..b6bc53d72d6 100644 --- a/tests/integration/suite/daprd/shutdown/block/timeout.go +++ b/tests/integration/suite/daprd/shutdown/block/timeout.go @@ -18,6 +18,7 @@ import ( "io" "net/http" "runtime" + "sync/atomic" "testing" "time" @@ -26,6 +27,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + commonv1 "github.com/dapr/dapr/pkg/proto/common/v1" rtv1 "github.com/dapr/dapr/pkg/proto/runtime/v1" "github.com/dapr/dapr/tests/integration/framework" "github.com/dapr/dapr/tests/integration/framework/process/daprd" @@ -42,9 +44,11 @@ func init() { // timeout tests Daprd's --dapr-block-shutdown-seconds, ensuring shutdown // procedure will begin when seconds is reached when app still reports healthy. type timeout struct { - daprd *daprd.Daprd - logline *logline.LogLine - routeCh chan struct{} + daprd *daprd.Daprd + logline *logline.LogLine + routeCh chan struct{} + listening atomic.Bool + bindingChan chan struct{} } func (i *timeout) Setup(t *testing.T) []framework.Option { @@ -53,6 +57,7 @@ func (i *timeout) Setup(t *testing.T) []framework.Option { } i.routeCh = make(chan struct{}, 1) + i.bindingChan = make(chan struct{}) handler := http.NewServeMux() handler.HandleFunc("/dapr/subscribe", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -64,6 +69,13 @@ func (i *timeout) Setup(t *testing.T) []framework.Option { handler.HandleFunc("/route", func(w http.ResponseWriter, r *http.Request) { i.routeCh <- struct{}{} }) + handler.HandleFunc("/binding", func(w http.ResponseWriter, r *http.Request) { + if i.listening.Load() { + i.listening.Store(false) + i.bindingChan <- struct{}{} + } + }) + app := prochttp.New(t, prochttp.WithHandler(handler), ) @@ -92,6 +104,27 @@ metadata: spec: type: pubsub.in-memory version: v1 +--- +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: 'binding' +spec: + type: bindings.cron + version: v1 + metadata: + - name: schedule + value: "@every 100ms" + - name: direction + value: "input" +--- +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: 'mystore' +spec: + type: state.in-memory + version: v1 `)) return []framework.Option{ @@ -120,25 +153,75 @@ func (i *timeout) Run(t *testing.T, ctx context.Context) { assert.Fail(t, "pubsub message should have been sent to subscriber") } + i.listening.Store(true) + select { + case <-i.bindingChan: + case <-time.After(time.Second * 5): + assert.Fail(t, "timed out waiting for binding event") + } + + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + { + Key: "key", + Value: []byte("value"), + }, + }, + }) + require.NoError(t, err) + resp, err := client.GetState(ctx, &rtv1.GetStateRequest{ + StoreName: "mystore", + Key: "key", + }) + require.NoError(t, err) + assert.Equal(t, "value", string(resp.GetData())) + daprdStopped := make(chan struct{}) go func() { i.daprd.Cleanup(t) close(daprdStopped) }() - t.Run("daprd APIs should still be available during blocked shutdown", func(t *testing.T) { - time.Sleep(time.Second) + t.Run("daprd APIs should still be available during blocked shutdown, except input bindings and subscriptions", func(t *testing.T) { + time.Sleep(time.Second / 2) + + i.listening.Store(true) + select { + case <-i.bindingChan: + assert.Fail(t, "binding event should not have been sent to subscriber") + case <-time.After(time.Second / 2): + } + _, err = client.PublishEvent(ctx, &rtv1.PublishEventRequest{ PubsubName: "foo", Topic: "topic", Data: []byte(`{"status":"completed"}`), }) require.NoError(t, err) + select { case <-i.routeCh: - case <-ctx.Done(): - assert.Fail(t, "pubsub message should have been sent to subscriber") + assert.Fail(t, "pubsub message should not have been sent to subscriber") + case <-time.After(time.Second / 2): } + + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + { + Key: "key", + Value: []byte("value2"), + }, + }, + }) + require.NoError(t, err) + resp, err = client.GetState(ctx, &rtv1.GetStateRequest{ + StoreName: "mystore", + Key: "key", + }) + require.NoError(t, err) + assert.Equal(t, "value2", string(resp.GetData())) }) t.Run("daprd APIs are no longer available when past blocked shutdown", func(t *testing.T) { @@ -149,6 +232,22 @@ func (i *timeout) Run(t *testing.T, ctx context.Context) { Data: []byte(`{"status":"completed"}`), }) require.Error(t, err) + + _, err = client.SaveState(ctx, &rtv1.SaveStateRequest{ + StoreName: "mystore", + States: []*commonv1.StateItem{ + { + Key: "key", + Value: []byte("value3"), + }, + }, + }) + require.Error(t, err) + _, err = client.GetState(ctx, &rtv1.GetStateRequest{ + StoreName: "mystore", + Key: "key", + }) + require.Error(t, err) }) select {