diff --git a/provider.go b/provider.go index be53958..ec7bb26 100644 --- a/provider.go +++ b/provider.go @@ -98,9 +98,11 @@ type BlockWriter interface { Write(*BlockWithBaseTimecode) error // ReadResponse reads a response from Kinesis Video Stream. ReadResponse() (*FragmentEvent, error) - // Close immediately shuts down the client + // Close immediately shuts down the client. Close() error - // Shutdown gracefully shuts down the client without interrupting on-going PutMedia request + // Shutdown gracefully shuts down the client without interrupting on-going PutMedia request. + // If Shotdown returned an error, some of the internal resources might not released yet and + // caller should call Shutdown or Close again. Shutdown(ctx context.Context) error } @@ -248,14 +250,20 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { close(allDone) }() - shutdown := func(ctx context.Context) { + shutdown := func(ctx context.Context) error { timeout.Stop() cleanConnections() - close(chConnection) + if chConnection != nil { + close(chConnection) + chConnection = nil + } select { case <-allDone: + cancel() case <-ctx.Done(): + return ctx.Err() } + return nil } prepareNextConn := func() { nextConn = newConnection(options) @@ -322,13 +330,11 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { return resp, nil }, fnShutdown: func(ctx context.Context) error { - shutdown(ctx) - return nil + return shutdown(ctx) }, fnClose: func() error { cancel() - shutdown(context.Background()) - return nil + return shutdown(context.Background()) }, } diff --git a/provider_test.go b/provider_test.go index 09cc144..69f3969 100644 --- a/provider_test.go +++ b/provider_test.go @@ -539,6 +539,40 @@ func TestProvider_WithPutMediaLogger(t *testing.T) { } } +func TestProvider_shutdownTwice(t *testing.T) { + server := kvsm.NewKinesisVideoServer() + defer server.Close() + + pro := newProvider(t, server) + + w, err := pro.PutMedia() + if err != nil { + t.Fatal(err) + } + + go func() { + for { + if _, err := w.ReadResponse(); err != nil { + return + } + } + }() + time.Sleep(10 * time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if err := w.Shutdown(ctx); err != context.Canceled { + t.Fatalf("Expected error: %v, got: %v", context.Canceled, err) + } + if err := w.Shutdown(ctx); err != context.Canceled { + t.Fatalf("Expected error: %v, got: %v", context.Canceled, err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } +} + func newProvider(t *testing.T, server *kvsm.KinesisVideoServer) *kvm.Provider { cfg := &aws.Config{ Credentials: credentials.NewStaticCredentials("key", "secret", "token"),