From aa990a7d5e76492ac310ab1b893071d432a3ff77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lovro=20Ma=C5=BEgon?= Date: Fri, 6 Oct 2023 19:03:22 +0200 Subject: [PATCH] Introduce csync and cchan (#88) * introduce csync and cchan utility packages * introduce csync options * stay compatible with Go 1.20 --- destination.go | 15 +- internal/atomicvaluewatcher.go | 117 --------- internal/atomicvaluewatcher_test.go | 160 ------------ internal/cchan/chan.go | 98 +++++++ internal/cchan/chan_test.go | 140 ++++++++++ internal/csync/opt.go | 70 +++++ internal/csync/opt_test.go | 92 +++++++ internal/csync/run.go | 44 ++++ internal/csync/valuewatcher.go | 164 ++++++++++++ internal/csync/valuewatcher_benchmark_test.go | 41 +++ internal/csync/valuewatcher_test.go | 240 ++++++++++++++++++ internal/csync/waitgroup.go | 51 ++++ internal/csync/waitgroup_test.go | 61 +++++ source.go | 45 ++-- 14 files changed, 1022 insertions(+), 316 deletions(-) delete mode 100644 internal/atomicvaluewatcher.go delete mode 100644 internal/atomicvaluewatcher_test.go create mode 100644 internal/cchan/chan.go create mode 100644 internal/cchan/chan_test.go create mode 100644 internal/csync/opt.go create mode 100644 internal/csync/opt_test.go create mode 100644 internal/csync/run.go create mode 100644 internal/csync/valuewatcher.go create mode 100644 internal/csync/valuewatcher_benchmark_test.go create mode 100644 internal/csync/valuewatcher_test.go create mode 100644 internal/csync/waitgroup.go create mode 100644 internal/csync/waitgroup_test.go diff --git a/destination.go b/destination.go index 255f3272..ddc51df7 100644 --- a/destination.go +++ b/destination.go @@ -26,6 +26,7 @@ import ( "github.com/conduitio/conduit-connector-protocol/cpluginv1" "github.com/conduitio/conduit-connector-sdk/internal" + "github.com/conduitio/conduit-connector-sdk/internal/csync" "go.uber.org/multierr" ) @@ -99,7 +100,7 @@ func NewDestinationPlugin(impl Destination) cpluginv1.DestinationPlugin { type destinationPluginAdapter struct { impl Destination - lastPosition *internal.AtomicValueWatcher[Position] + lastPosition *csync.ValueWatcher[Position] openCancel context.CancelFunc // write is the chosen write strategy, either single records or batches @@ -166,7 +167,7 @@ func (a *destinationPluginAdapter) configureWriteStrategy(ctx context.Context, c } func (a *destinationPluginAdapter) Start(ctx context.Context, _ cpluginv1.DestinationStartRequest) (cpluginv1.DestinationStartResponse, error) { - a.lastPosition = new(internal.AtomicValueWatcher[Position]) + a.lastPosition = new(csync.ValueWatcher[Position]) // detach context, so we can control when it's canceled ctxOpen := internal.DetachContext(ctx) @@ -205,7 +206,7 @@ func (a *destinationPluginAdapter) Run(ctx context.Context, stream cpluginv1.Des err = a.writeStrategy.Write(ctx, r, func(err error) error { return a.ack(r, err, stream) }) - a.lastPosition.Store(r.Position) + a.lastPosition.Set(r.Position) if err != nil { return err } @@ -232,14 +233,10 @@ func (a *destinationPluginAdapter) Stop(ctx context.Context, req cpluginv1.Desti // last thing we do is cancel context in Open defer a.openCancel() - // wait for at most 1 minute - waitCtx, cancel := context.WithTimeout(ctx, time.Minute) // TODO make the timeout configurable (https://github.com/ConduitIO/conduit/issues/183) - defer cancel() - // wait for last record to be received - err := a.lastPosition.Await(waitCtx, func(val Position) bool { + _, err := a.lastPosition.Watch(ctx, func(val Position) bool { return bytes.Equal(val, req.LastPosition) - }) + }, csync.WithTimeout(time.Minute)) // TODO make the timeout configurable (https://github.com/ConduitIO/conduit/issues/183) // flush cached records, allow it to take at most 1 minute flushCtx, cancel := context.WithTimeout(ctx, time.Minute) // TODO make the timeout configurable diff --git a/internal/atomicvaluewatcher.go b/internal/atomicvaluewatcher.go deleted file mode 100644 index b1ad9b7f..00000000 --- a/internal/atomicvaluewatcher.go +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright © 2022 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "errors" - "fmt" - "sync" -) - -// AtomicValueWatcher holds a reference to a value. Multiple goroutines are able to put or -// retrieve the value into/from the AtomicValueWatcher, as well as wait for a certain value -// to be put into the AtomicValueWatcher. -// It is similar to atomic.Value except the caller can call Await to be notified -// each time the value in AtomicValueWatcher changes. -type AtomicValueWatcher[T any] struct { - val T - m sync.RWMutex - listener chan T // a single listener, if needed we can expand to multiple in the future -} - -// Store sets val in AtomicValueWatcher and notifies the goroutine that called Await about -// the new value, if such a goroutine exists. -func (h *AtomicValueWatcher[T]) Store(val T) { - h.m.Lock() - defer h.m.Unlock() - - h.val = val - if h.listener != nil { - h.listener <- val - } -} - -// Load returns the current value stored in AtomicValueWatcher. -func (h *AtomicValueWatcher[T]) Load() T { - h.m.RLock() - defer h.m.RUnlock() - - return h.val -} - -// Await blocks and calls foundVal for every value that is put into the AtomicValueWatcher. -// Once foundVal returns true it stops blocking and returns nil. First call to -// foundVal will be with the current value stored in AtomicValueWatcher. Await can only be -// called by one goroutine at a time (we don't need anything more fancy right -// now), if two goroutines call Await one will receive an error. If the context -// gets cancelled before foundVal returns true, the function will return the -// context error. -func (h *AtomicValueWatcher[T]) Await(ctx context.Context, foundVal func(val T) bool) error { - val, err := h.subscribe() - if err != nil { - // the only option subscribe produces an error is if it is called - // concurrently which is an invalid use case at the moment - return fmt.Errorf("invalid use of AtomicValueWatcher.Await: %w", err) - } - defer h.unsubscribe() - - if foundVal(val) { - // first call to foundVal is with the current value - return nil - } - // val was not found yet, we need to wait some more - for { - select { - case <-ctx.Done(): - return ctx.Err() - case val := <-h.listener: - if foundVal(val) { - return nil - } - } - } -} - -// subscribe creates listener and returns the current value stored in AtomicValueWatcher at -// the time the listener was created. -func (h *AtomicValueWatcher[T]) subscribe() (T, error) { - h.m.Lock() - defer h.m.Unlock() - - if h.listener != nil { - var empty T - return empty, errors.New("another goroutine is already subscribed to changes") - } - h.listener = make(chan T) - - return h.val, nil -} - -func (h *AtomicValueWatcher[T]) unsubscribe() { - // drain channel and remove it - go func(in chan T) { - for range in { //nolint:revive // empty block for reason below - // do nothing, just drain channel in case new values come in - // while we try to unsubscribe - } - }(h.listener) - - h.m.Lock() - defer h.m.Unlock() - - close(h.listener) - h.listener = nil -} diff --git a/internal/atomicvaluewatcher_test.go b/internal/atomicvaluewatcher_test.go deleted file mode 100644 index f3eba49d..00000000 --- a/internal/atomicvaluewatcher_test.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright © 2022 Meroxa, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package internal - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/matryer/is" - "go.uber.org/goleak" -) - -func TestAtomicValueWatcher_GetEmptyValue(t *testing.T) { - is := is.New(t) - - var h AtomicValueWatcher[int] - got := h.Load() - is.Equal(0, got) -} - -func TestAtomicValueWatcher_GetEmptyPtr(t *testing.T) { - is := is.New(t) - - var h AtomicValueWatcher[*int] - got := h.Load() - is.Equal(nil, got) -} - -func TestAtomicValueWatcher_PutGetValue(t *testing.T) { - is := is.New(t) - - var h AtomicValueWatcher[int] - want := 123 - h.Store(want) - got := h.Load() - is.Equal(want, got) -} - -func TestAtomicValueWatcher_PutGetPtr(t *testing.T) { - is := is.New(t) - - var h AtomicValueWatcher[*int] - want := 123 - h.Store(&want) - got := h.Load() - is.Equal(&want, got) -} - -func TestAtomicValueWatcher_AwaitSuccess(t *testing.T) { - goleak.VerifyNone(t) - is := is.New(t) - - var h AtomicValueWatcher[int] - - putValue := make(chan int) - defer close(putValue) - go func() { - for val := range putValue { - h.Store(val) - } - }() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - i := 0 - err := h.Await(ctx, func(val int) bool { - i++ - switch i { - case 1: - is.Equal(0, val) // expected first value to be 0 - putValue <- 123 // put next value - return false // not the value we are looking for - case 2: - is.Equal(123, val) - putValue <- 555 // put next value - return false // not the value we are looking for - case 3: - is.Equal(555, val) - return true // that's what we were looking for - default: - is.Fail() // unexpected value for i - return false - } - }) - is.NoErr(err) - is.Equal(3, i) - - got := h.Load() - is.Equal(555, got) - - // we can still put more values into the watcher - h.Store(666) - got = h.Load() - is.Equal(666, got) -} - -func TestAtomicValueWatcher_AwaitContextCancel(t *testing.T) { - goleak.VerifyNone(t) - is := is.New(t) - - var h AtomicValueWatcher[int] - - ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) - defer cancel() - - i := 0 - err := h.Await(ctx, func(val int) bool { - i++ - is.Equal(0, val) - return false - }) - - is.Equal(ctx.Err(), err) - is.Equal(1, i) -} - -func TestAtomicValueWatcher_AwaitMultiple(t *testing.T) { - goleak.VerifyNone(t) - is := is.New(t) - - var h AtomicValueWatcher[int] - - var wg1, wg2, wg3 sync.WaitGroup - wg1.Add(1) - wg2.Add(1) - wg3.Add(1) - go func() { - defer wg3.Done() - err := h.Await(context.Background(), func(val int) bool { - wg1.Done() - wg2.Wait() // wait until test says it's ok to return - return true - }) - is.NoErr(err) - }() - - wg1.Wait() // wait for Await to actually run - - // try to run await a second time - err := h.Await(context.Background(), func(val int) bool { return false }) - is.True(err != nil) // expected an error from second Await call - - wg2.Done() // signal to first Await call to return - wg3.Wait() // wait for goroutine to stop running -} diff --git a/internal/cchan/chan.go b/internal/cchan/chan.go new file mode 100644 index 00000000..55269851 --- /dev/null +++ b/internal/cchan/chan.go @@ -0,0 +1,98 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cchan + +import ( + "context" + "time" +) + +// ChanOut is an output channel with utility methods. +type ChanOut[T any] <-chan T + +// Recv will try to receive a value from the channel, same as <-c. If the +// context is canceled before a value is received, it will return the context +// error. +func (c ChanOut[T]) Recv(ctx context.Context) (T, bool, error) { + select { + case val, ok := <-c: + return val, ok, nil + case <-ctx.Done(): + var empty T + return empty, false, ctx.Err() + } +} + +// RecvTimeout will try to receive a value from the channel, same as <-c. If the +// context is canceled before a value is received or the timeout is reached, it +// will return the context error. +func (c ChanOut[T]) RecvTimeout(ctx context.Context, timeout time.Duration) (T, bool, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return c.Recv(ctx) +} + +// ChanIn is an input channel with utility methods. +type ChanIn[T any] chan<- T + +// Send will try to send a value to the channel, same as c<-v. If the context is +// canceled before a value is sent, it will return the context error. +func (c ChanIn[T]) Send(ctx context.Context, v T) error { + select { + case c <- v: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// SendTimeout will try to send a value to the channel, same as c<-v. If the +// context is canceled before a value is sent or the timeout is reached, it will +// return the context error. +func (c ChanIn[T]) SendTimeout(ctx context.Context, v T, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return c.Send(ctx, v) +} + +// Chan is a channel with utility methods. +type Chan[T any] chan T + +// Recv will try to receive a value from the channel, same as <-c. If the +// context is canceled before a value is received, it will return the context +// error. +func (c Chan[T]) Recv(ctx context.Context) (T, bool, error) { + return ChanOut[T]((chan T)(c)).Recv(ctx) +} + +// RecvTimeout will try to receive a value from the channel, same as <-c. If the +// context is canceled before a value is received or the timeout is reached, it +// will return the context error. +func (c Chan[T]) RecvTimeout(ctx context.Context, timeout time.Duration) (T, bool, error) { + return ChanOut[T]((chan T)(c)).RecvTimeout(ctx, timeout) +} + +// Send will try to send a value to the channel, same as c<-v. If the context is +// canceled before a value is sent, it will return the context error. +func (c Chan[T]) Send(ctx context.Context, v T) error { + return ChanIn[T]((chan T)(c)).Send(ctx, v) +} + +// SendTimeout will try to send a value to the channel, same as c<-v. If the +// context is canceled before a value is sent or the timeout is reached, it will +// return the context error. +func (c Chan[T]) SendTimeout(ctx context.Context, v T, timeout time.Duration) error { + return ChanIn[T]((chan T)(c)).SendTimeout(ctx, v, timeout) +} diff --git a/internal/cchan/chan_test.go b/internal/cchan/chan_test.go new file mode 100644 index 00000000..c627d286 --- /dev/null +++ b/internal/cchan/chan_test.go @@ -0,0 +1,140 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cchan + +import ( + "context" + "testing" + "time" + + "github.com/matryer/is" +) + +func TestChanOut_Recv_Success(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + want := []int{1, 123, 1337} + c := make(chan int, len(want)) + for _, w := range want { + c <- w + } + + for i := range want { + got, ok, err := ChanOut[int](c).Recv(ctx) + is.NoErr(err) + is.True(ok) + is.Equal(want[i], got) + } +} + +func TestChanOut_Recv_Closed(t *testing.T) { + is := is.New(t) + ctx := context.Background() + c := make(chan int) + + close(c) + + got, ok, err := ChanOut[int](c).Recv(ctx) + is.NoErr(err) + is.True(!ok) + is.Equal(got, 0) +} + +func TestChanOut_Recv_Canceled(t *testing.T) { + is := is.New(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + c := make(chan int) + got, ok, err := ChanOut[int](c).Recv(ctx) + is.Equal(err, context.Canceled) + is.True(!ok) + is.Equal(got, 0) +} + +func TestChanOut_RecvTimeout_DeadlineExceeded(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + c := make(chan int) + start := time.Now() + got, ok, err := ChanOut[int](c).RecvTimeout(ctx, time.Millisecond*100) + since := time.Since(start) + + is.Equal(err, context.DeadlineExceeded) + is.True(!ok) + is.Equal(got, 0) + + is.True(since >= time.Millisecond*100) +} + +func TestChanIn_Send_Success(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + want := []int{1, 123, 1337} + c := make(chan int, len(want)) + + for _, w := range want { + err := ChanIn[int](c).Send(ctx, w) + is.NoErr(err) + } + + for i := range want { + got, ok := <-c + is.True(ok) + is.Equal(got, want[i]) + } +} + +func TestChanIn_Send_Closed(t *testing.T) { + is := is.New(t) + ctx := context.Background() + c := make(chan int) + + close(c) + + defer func() { + r := recover() + is.True(r != nil) + is.Equal(r.(error).Error(), "send on closed channel") + }() + _ = ChanIn[int](c).Send(ctx, 1) + is.Fail() // unreachable +} + +func TestChanIn_Send_Canceled(t *testing.T) { + is := is.New(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + c := make(chan int) + err := ChanIn[int](c).Send(ctx, 1) + is.Equal(err, context.Canceled) +} + +func TestChanIn_SendTimeout_DeadlineExceeded(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + c := make(chan int) + start := time.Now() + err := ChanIn[int](c).SendTimeout(ctx, 1, time.Millisecond*100) + since := time.Since(start) + + is.Equal(err, context.DeadlineExceeded) + is.True(since >= time.Millisecond*100) +} diff --git a/internal/csync/opt.go b/internal/csync/opt.go new file mode 100644 index 00000000..f8bca623 --- /dev/null +++ b/internal/csync/opt.go @@ -0,0 +1,70 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "time" +) + +type Option interface { + apply() // dummy method, real methods are in specific options +} + +type ctxOption interface { + Option + applyCtx(context.Context) (context.Context, context.CancelFunc) +} + +type ctxOptionFn func(ctx context.Context) (context.Context, context.CancelFunc) + +func (c ctxOptionFn) apply() { /* noop */ } +func (c ctxOptionFn) applyCtx(ctx context.Context) (context.Context, context.CancelFunc) { + return c(ctx) +} + +func applyAndRemoveCtxOptions(ctx context.Context, opts []Option) (context.Context, context.CancelFunc, []Option) { + if len(opts) == 0 { + return ctx, func() {}, opts // shortcut + } + + remainingOpts := make([]Option, 0, len(opts)) + var cancelFns []context.CancelFunc + for _, opt := range opts { + ctxOpt, ok := opt.(ctxOption) + if !ok { + remainingOpts = append(remainingOpts, opt) + continue + } + + var cancel context.CancelFunc + ctx, cancel = ctxOpt.applyCtx(ctx) + cancelFns = append(cancelFns, cancel) + } + return ctx, func() { + // call cancel functions in reverse + for i := len(cancelFns) - 1; i >= 0; i-- { + cancelFns[i]() + } + }, remainingOpts +} + +// WithTimeout cancels the operation after the timeout if it didn't succeed yet. +// The function returns context.DeadlineExceeded if the timeout is reached. +func WithTimeout(timeout time.Duration) Option { + return ctxOptionFn(func(ctx context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(ctx, timeout) + }) +} diff --git a/internal/csync/opt_test.go b/internal/csync/opt_test.go new file mode 100644 index 00000000..3cb025f7 --- /dev/null +++ b/internal/csync/opt_test.go @@ -0,0 +1,92 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "testing" + "time" + + "github.com/conduitio/conduit-connector-sdk/internal/cchan" + "github.com/matryer/is" +) + +type testOption struct { + foo any +} + +func (t testOption) apply() {} + +func TestCtxOption_WithTimeout_DeadlineExceeded(t *testing.T) { + ctx := context.Background() + is := is.New(t) + + opts := []Option{ + testOption{foo: "test"}, // not a context option + WithTimeout(time.Millisecond), + } + + gotCtx, gotCancel, gotOpts := applyAndRemoveCtxOptions(ctx, opts) + is.True(ctx != gotCtx) + is.True(gotCancel != nil) + is.Equal(gotOpts, []Option{testOption{foo: "test"}}) + + _, _, err := cchan.ChanOut[struct{}](gotCtx.Done()).RecvTimeout(ctx, time.Millisecond*100) + is.NoErr(err) + is.Equal(gotCtx.Err(), context.DeadlineExceeded) + + // running the cancel func should be a noop at this point, just testing if it panics + gotCancel() +} + +func TestCtxOption_WithTimeout_Cancel(t *testing.T) { + ctx := context.Background() + is := is.New(t) + + opts := []Option{ + testOption{foo: "test"}, // not a context option + WithTimeout(time.Second), + } + + gotCtx, gotCancel, gotOpts := applyAndRemoveCtxOptions(ctx, opts) + is.True(ctx != gotCtx) + is.True(gotCancel != nil) + is.Equal(gotOpts, []Option{testOption{foo: "test"}}) + + _, _, err := cchan.ChanOut[struct{}](gotCtx.Done()).RecvTimeout(ctx, time.Millisecond*100) + is.Equal(err, context.DeadlineExceeded) + + // running the cancel func should cancel the context now + gotCancel() + + _, _, err = cchan.ChanOut[struct{}](gotCtx.Done()).RecvTimeout(ctx, time.Millisecond*100) + is.NoErr(err) + is.Equal(gotCtx.Err(), context.Canceled) +} + +func TestCtxOption_Empty(t *testing.T) { + ctx := context.Background() + is := is.New(t) + + gotCtx, gotCancel, gotOpts := applyAndRemoveCtxOptions(ctx, nil) + is.Equal(ctx, gotCtx) + is.True(gotCancel != nil) + is.Equal(gotOpts, nil) + + gotCancel() // canceling the context shouldn't do anything + + _, _, err := cchan.ChanOut[struct{}](gotCtx.Done()).RecvTimeout(ctx, time.Millisecond*100) + is.Equal(err, context.DeadlineExceeded) +} diff --git a/internal/csync/run.go b/internal/csync/run.go new file mode 100644 index 00000000..d97ab638 --- /dev/null +++ b/internal/csync/run.go @@ -0,0 +1,44 @@ +// Copyright © 2023 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "fmt" + + "github.com/conduitio/conduit-connector-sdk/internal/cchan" +) + +// Run executes fn in a goroutine and waits for it to return. If the context +// gets canceled before that happens the method returns the context error. +// +// This is useful for executing long-running functions like sync.WaitGroup.Wait +// that don't take a context and can potentially block the execution forever. +func Run(ctx context.Context, fn func(), opts ...Option) error { + ctx, cancel, opts := applyAndRemoveCtxOptions(ctx, opts) + if len(opts) > 0 { + panic(fmt.Sprintf("invalid option type: %T", opts[0])) + } + defer cancel() + + done := make(chan struct{}) + go func() { + defer close(done) + fn() + }() + + _, _, err := cchan.ChanOut[struct{}](done).Recv(ctx) + return err +} diff --git a/internal/csync/valuewatcher.go b/internal/csync/valuewatcher.go new file mode 100644 index 00000000..505afd2c --- /dev/null +++ b/internal/csync/valuewatcher.go @@ -0,0 +1,164 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "fmt" + "sync" + + "github.com/conduitio/conduit-connector-sdk/internal/cchan" + "github.com/google/uuid" +) + +// ValueWatcher holds a reference to a value. Multiple goroutines are able to +// put or retrieve the value into/from the ValueWatcher, as well as wait for a +// certain value to be put into the ValueWatcher. +// It is similar to atomic.Value except the caller can call Watch to be notified +// each time the value in ValueWatcher changes. +type ValueWatcher[T any] struct { + val T + m sync.Mutex + listeners map[string]chan T +} + +type ValueWatcherFunc[T any] func(val T) bool + +// WatchValues is a utility function for creating a simple ValueWatcherFunc that +// waits for one of the supplied values. +func WatchValues[T comparable](want ...T) ValueWatcherFunc[T] { + if len(want) == 0 { + // this would block forever, prevent misuse + panic("invalid use of WatchValues, need to supply at least one value") + } + if len(want) == 1 { + // optimize + wantVal := want[0] + return func(val T) bool { + return val == wantVal + } + } + return func(val T) bool { + for _, wantVal := range want { + if val == wantVal { + return true + } + } + return false + } +} + +// Set stores val in ValueWatcher and notifies all goroutines that called Watch +// about the new value, if such goroutines exists. +func (vw *ValueWatcher[T]) Set(val T) { + vw.m.Lock() + defer vw.m.Unlock() + + vw.val = val + for _, l := range vw.listeners { + l <- val + } +} + +// Get returns the current value stored in ValueWatcher. +func (vw *ValueWatcher[T]) Get() T { + vw.m.Lock() + defer vw.m.Unlock() + + return vw.val +} + +// Watch blocks and calls f for every value that is put into the ValueWatcher. +// Once f returns true it stops blocking and returns nil. First call to f will +// be with the current value stored in ValueWatcher. Note that if no value was +// stored in ValueWatcher yet, the zero value of type T will be passed to f. +// +// Watch can be safely called by multiple goroutines. If the context gets +// cancelled before f returns true, the function will return the context error. +func (vw *ValueWatcher[T]) Watch(ctx context.Context, f ValueWatcherFunc[T], opts ...Option) (T, error) { + ctx, cancel, opts := applyAndRemoveCtxOptions(ctx, opts) + if len(opts) > 0 { + panic(fmt.Sprintf("invalid option type: %T", opts[0])) + } + defer cancel() + + val, found, listener, unsubscribe := vw.findOrSubscribe(f) + if found { + return val, nil + } + defer unsubscribe() + + // val was not found yet, we need to keep watching + clistener := cchan.ChanOut[T](listener) + for { + val, ok, err := clistener.Recv(ctx) + if err != nil { + var empty T + return empty, ctx.Err() + } + if ok && f(val) { + return val, nil + } + } +} + +func (vw *ValueWatcher[T]) findOrSubscribe(f ValueWatcherFunc[T]) (T, bool, chan T, func()) { + vw.m.Lock() + defer vw.m.Unlock() + + // first call to f is with the current value + if f(vw.val) { + return vw.val, true, nil, nil + } + + listener, unsubscribe := vw.subscribe() + var empty T + return empty, false, listener, unsubscribe +} + +// subscribe creates a channel that will receive changes and returns it +// alongside a cleanup function that closes the channel and removes it from +// ValueWatcher. +func (vw *ValueWatcher[T]) subscribe() (chan T, func()) { + if vw.listeners == nil { + vw.listeners = make(map[string]chan T) + } + + id := uuid.NewString() + c := make(chan T) + + vw.listeners[id] = c + + return c, func() { vw.unsubscribe(id, c) } +} + +func (vw *ValueWatcher[T]) unsubscribe(id string, c chan T) { + // drain channel and remove it + go func() { + //nolint:revive // see comment below + for range c { + // Do nothing, just drain channel. In case another goroutine tries + // to store a new value by calling ValueWatcher.Set, this goroutine + // will unblock it until we successfully unsubscribe and remove the + // channel from listeners. + } + }() + + vw.m.Lock() + defer vw.m.Unlock() + + close(c) + delete(vw.listeners, id) +} diff --git a/internal/csync/valuewatcher_benchmark_test.go b/internal/csync/valuewatcher_benchmark_test.go new file mode 100644 index 00000000..a43e0513 --- /dev/null +++ b/internal/csync/valuewatcher_benchmark_test.go @@ -0,0 +1,41 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "testing" +) + +var valSink int +var errSink error + +func BenchmarkValueWatcher_Watch(b *testing.B) { + var w ValueWatcher[int] + ctx := context.Background() + + for i := 0; i < b.N; i++ { + valSink, errSink = w.Watch(ctx, func(val int) bool { return val == 0 }) + } +} + +func BenchmarkValueWatcher_WatchValues(b *testing.B) { + var w ValueWatcher[int] + ctx := context.Background() + + for i := 0; i < b.N; i++ { + valSink, errSink = w.Watch(ctx, WatchValues(0)) + } +} diff --git a/internal/csync/valuewatcher_test.go b/internal/csync/valuewatcher_test.go new file mode 100644 index 00000000..93f02d34 --- /dev/null +++ b/internal/csync/valuewatcher_test.go @@ -0,0 +1,240 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/matryer/is" + "go.uber.org/goleak" +) + +func TestValueWatcher_GetEmptyValue(t *testing.T) { + is := is.New(t) + + var h ValueWatcher[int] + got := h.Get() + is.Equal(0, got) +} + +func TestValueWatcher_GetEmptyPtr(t *testing.T) { + is := is.New(t) + + var h ValueWatcher[*int] + got := h.Get() + is.Equal(nil, got) +} + +func TestValueWatcher_PutGetValue(t *testing.T) { + is := is.New(t) + + var h ValueWatcher[int] + want := 123 + h.Set(want) + got := h.Get() + is.Equal(want, got) +} + +func TestValueWatcher_PutGetPtr(t *testing.T) { + is := is.New(t) + + var h ValueWatcher[*int] + want := 123 + h.Set(&want) + got := h.Get() + is.Equal(&want, got) +} + +func TestValueWatcher_WatchSuccess(t *testing.T) { + goleak.VerifyNone(t) + is := is.New(t) + + var h ValueWatcher[int] + + putValue := make(chan int) + defer close(putValue) + go func() { + for val := range putValue { + h.Set(val) + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + i := 0 + val, err := h.Watch(ctx, func(val int) bool { + i++ + switch i { + case 1: + is.Equal(0, val) // expected first value to be 0 + putValue <- 123 // put next value + return false // not the value we are looking for + case 2: + is.Equal(123, val) + putValue <- 555 // put next value + return false // not the value we are looking for + case 3: + is.Equal(555, val) + return true // that's what we were looking for + default: + is.Fail() // unexpected value for i + return false + } + }) + is.NoErr(err) + is.Equal(3, i) + is.Equal(555, val) + + got := h.Get() + is.Equal(555, got) + + // we can still put more values into the watcher + h.Set(666) + got = h.Get() + is.Equal(666, got) +} + +func TestValueWatcher_WatchContextCancel(t *testing.T) { + goleak.VerifyNone(t) + is := is.New(t) + + var h ValueWatcher[int] + h.Set(1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + defer cancel() + + i := 0 + val, err := h.Watch(ctx, func(val int) bool { + i++ + is.Equal(1, val) + return false + }) + + is.Equal(ctx.Err(), err) + is.Equal(1, i) + is.Equal(0, val) +} + +func TestValueWatcher_WatchMultiple(t *testing.T) { + const watcherCount = 100 + goleak.VerifyNone(t) + is := is.New(t) + + var h ValueWatcher[int] + + // wg1 waits until all watchers are subscribed to changes + var wg1 sync.WaitGroup + // wg2 waits until all watchers found the value they were looking for + var wg2 sync.WaitGroup + + wg1.Add(watcherCount) + wg2.Add(watcherCount) + for i := 0; i < watcherCount; i++ { + go func(i int) { + defer wg2.Done() + var once sync.Once + old := -1 // first call to Watch will be with 0, pretend old value was -1 + val, err := h.Watch(context.Background(), func(val int) bool { + // first time the function is called with the current value in + // ValueWatcher, after that we know the watcher is successfully + // subscribed + once.Do(wg1.Done) + // make sure that we see all changes + is.Equal(old+1, val) + old = val + // only stop when our specific number shows up + return val == i + }) + is.NoErr(err) + is.Equal(val, i) + }(i) + } + + // wait for all watchers to be subscribed + err := (*WaitGroup)(&wg1).Wait(context.Background(), WithTimeout(time.Second)) + is.NoErr(err) + + // set the value incrementally higher + for i := 1; i < watcherCount; i++ { + h.Set(i) + } + + // wait for all watchers to be done + err = (*WaitGroup)(&wg2).Wait(context.Background(), WithTimeout(time.Second)) + is.NoErr(err) +} + +func TestValueWatcher_Concurrency(t *testing.T) { + const watcherCount = 40 + const setterCount = 40 + const setCount = 20 + + goleak.VerifyNone(t) + is := is.New(t) + + var h ValueWatcher[int] + + // wg1 waits until all watchers are subscribed to changes + var wg1 sync.WaitGroup + // wg2 waits until all watchers found the value they were looking for + var wg2 sync.WaitGroup + + wg1.Add(watcherCount) + wg2.Add(watcherCount) + for i := 0; i < watcherCount; i++ { + go func(i int) { + defer wg2.Done() + var once sync.Once + var count int + _, err := h.Watch(context.Background(), func(val int) bool { + once.Do(wg1.Done) + count++ + // +1 because of first call + return count == (setterCount*setCount)+1 + }) + is.NoErr(err) + is.Equal(count, (setterCount*setCount)+1) + }(i) + } + + // wait for all watchers to be subscribed + err := (*WaitGroup)(&wg1).Wait(context.Background(), WithTimeout(time.Second)) + is.NoErr(err) + + // wg3 waits for all setters to stop setting values + var wg3 sync.WaitGroup + wg3.Add(setterCount) + for i := 0; i < setterCount; i++ { + go func(i int) { + defer wg3.Done() + for j := 0; j < setCount; j++ { + h.Set(i) + } + }(i) + } + + // wait for all setters to be done + err = (*WaitGroup)(&wg3).Wait(context.Background(), WithTimeout(time.Second)) + is.NoErr(err) + + // wait for all watchers to be done + err = (*WaitGroup)(&wg2).Wait(context.Background(), WithTimeout(time.Second)) + is.NoErr(err) +} diff --git a/internal/csync/waitgroup.go b/internal/csync/waitgroup.go new file mode 100644 index 00000000..2abb1115 --- /dev/null +++ b/internal/csync/waitgroup.go @@ -0,0 +1,51 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "sync" +) + +// WaitGroup is a sync.WaitGroup with utility methods. +type WaitGroup sync.WaitGroup + +// Add adds delta, which may be negative, to the WaitGroup counter. +// If the counter becomes zero, all goroutines blocked on Wait are released. +// If the counter goes negative, Add panics. +// +// Note that calls with a positive delta that occur when the counter is zero +// must happen before a Wait. Calls with a negative delta, or calls with a +// positive delta that start when the counter is greater than zero, may happen +// at any time. +// Typically this means the calls to Add should execute before the statement +// creating the goroutine or other event to be waited for. +// If a WaitGroup is reused to wait for several independent sets of events, +// new Add calls must happen after all previous Wait calls have returned. +// See the WaitGroup example. +func (wg *WaitGroup) Add(delta int) { + (*sync.WaitGroup)(wg).Add(delta) +} + +// Done decrements the WaitGroup counter by one. +func (wg *WaitGroup) Done() { + (*sync.WaitGroup)(wg).Done() +} + +// Wait blocks until the WaitGroup counter is zero. If the context gets canceled +// before that happens the method returns an error. +func (wg *WaitGroup) Wait(ctx context.Context, opts ...Option) error { + return Run(ctx, (*sync.WaitGroup)(wg).Wait, opts...) +} diff --git a/internal/csync/waitgroup_test.go b/internal/csync/waitgroup_test.go new file mode 100644 index 00000000..868c7886 --- /dev/null +++ b/internal/csync/waitgroup_test.go @@ -0,0 +1,61 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package csync + +import ( + "context" + "testing" + "time" + + "github.com/matryer/is" +) + +func TestWaitGroup_Wait_Empty(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + var wg WaitGroup + wg.Add(1) + wg.Done() + err := wg.Wait(ctx) + is.NoErr(err) +} + +func TestWaitGroup_Wait_Canceled(t *testing.T) { + is := is.New(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var wg WaitGroup + wg.Add(1) + + err := wg.Wait(ctx) + is.Equal(err, context.Canceled) +} + +func TestWaitGroup_WaitTimeout_DeadlineReached(t *testing.T) { + is := is.New(t) + ctx := context.Background() + + var wg WaitGroup + wg.Add(1) + + start := time.Now() + err := wg.Wait(ctx, WithTimeout(time.Millisecond*100)) + since := time.Since(start) + + is.Equal(err, context.DeadlineExceeded) + is.True(since >= time.Millisecond*100) +} diff --git a/source.go b/source.go index 6a885421..5bd6e848 100644 --- a/source.go +++ b/source.go @@ -25,6 +25,8 @@ import ( "github.com/conduitio/conduit-connector-protocol/cpluginv1" "github.com/conduitio/conduit-connector-sdk/internal" + "github.com/conduitio/conduit-connector-sdk/internal/cchan" + "github.com/conduitio/conduit-connector-sdk/internal/csync" "github.com/jpillora/backoff" "go.uber.org/multierr" "gopkg.in/tomb.v2" @@ -206,14 +208,13 @@ func (a *sourcePluginAdapter) runRead(ctx context.Context, stream cpluginv1.Sour } if errors.Is(err, ErrBackoffRetry) { // the plugin wants us to retry reading later - select { - case <-ctx.Done(): + _, _, err := cchan.ChanOut[time.Time](time.After(b.Duration())).Recv(ctx) + if err != nil { // the plugin is using the SDK for long polling and relying // on the SDK to check for a cancelled context return nil - case <-time.After(b.Duration()): - continue } + continue } return fmt.Errorf("read plugin error: %w", err) } @@ -253,10 +254,7 @@ func (a *sourcePluginAdapter) Stop(ctx context.Context, _ cpluginv1.SourceStopRe // wait for read to actually stop running with a timeout, in case the // connector gets stuck - waitCtx, cancel := context.WithTimeout(ctx, stopTimeout) - defer cancel() - - err := a.waitForClose(waitCtx, a.readDone) + _, _, err := cchan.ChanOut[struct{}](a.readDone).RecvTimeout(ctx, stopTimeout) if err != nil { Logger(ctx).Warn().Err(err).Msg("failed to wait for Read to stop running") return cpluginv1.SourceStopResponse{}, fmt.Errorf("failed to stop connector: %w", err) @@ -282,11 +280,7 @@ func (a *sourcePluginAdapter) Teardown(ctx context.Context, _ cpluginv1.SourceTe var waitErr error if a.t != nil { - // wait for at most 1 minute - waitCtx, cancel := context.WithTimeout(ctx, teardownTimeout) - defer cancel() - - waitErr = a.waitForRun(waitCtx) // wait for Run to stop running + waitErr = a.waitForRun(ctx, teardownTimeout) // wait for Run to stop running if waitErr != nil { // just log error and continue to call Teardown to keep guarantee Logger(ctx).Warn().Err(waitErr).Msg("failed to wait for Run to stop running") @@ -316,23 +310,14 @@ func (a *sourcePluginAdapter) LifecycleOnDeleted(ctx context.Context, req cplugi // waitForRun returns once the Run function returns or the context gets // cancelled, whichever happens first. If the context gets cancelled the context // error will be returned. -func (a *sourcePluginAdapter) waitForRun(ctx context.Context) error { - // wait for all acks to be sent back to Conduit - ackFuncsDone := make(chan struct{}) - go func() { - _ = a.t.Wait() // ignore tomb error, it will be returned in Run anyway - close(ackFuncsDone) - }() - return a.waitForClose(ctx, ackFuncsDone) -} - -func (a *sourcePluginAdapter) waitForClose(ctx context.Context, stop chan struct{}) error { - select { - case <-stop: - return nil - case <-ctx.Done(): - return ctx.Err() - } +func (a *sourcePluginAdapter) waitForRun(ctx context.Context, timeout time.Duration) error { + // wait for all acks to be sent back to Conduit, stop waiting if context + // gets cancelled or timeout is reached + return csync.Run( + ctx, + func() { _ = a.t.Wait() }, // ignore tomb error, it will be returned in Run anyway + csync.WithTimeout(timeout), + ) } func (a *sourcePluginAdapter) convertRecord(r Record) cpluginv1.Record {