diff --git a/provider.go b/provider.go index 6dd2a56..8940171 100644 --- a/provider.go +++ b/provider.go @@ -218,6 +218,7 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { o(options) } + var muConn sync.Mutex var conn, nextConn *connection var lastAbsTime uint64 chConnection := make(chan *connection) @@ -235,6 +236,9 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { var timeout *time.Timer resetTimeout := func() { timeout = time.AfterFunc(options.connectionTimeout, func() { + muConn.Lock() + defer muConn.Unlock() + options.logger.Warnf(`Receiving block timed out, clean connections: { StreamID: "%s" }`, p.streamID) cleanConnections() }) @@ -250,13 +254,23 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { close(allDone) }() + closed := make(chan struct{}) + var closedOnce sync.Once + shutdown := func(ctx context.Context) error { + closedOnce.Do(func() { + close(closed) + }) + + muConn.Lock() timeout.Stop() cleanConnections() if chConnection != nil { close(chConnection) chConnection = nil } + muConn.Unlock() + select { case <-allDone: cancel() @@ -267,7 +281,10 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { } prepareNextConn := func() { nextConn = newConnection(options) - chConnection <- nextConn + select { + case chConnection <- nextConn: + case <-closed: + } } switchToNextConn := func(startTime uint64) { if conn != nil { @@ -282,6 +299,7 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { writer := &blockWriter{ fnWrite: func(bt *BlockWithBaseTimecode) error { + var forceSwitchConn bool absTime := uint64(bt.AbsTimecode()) if lastAbsTime != 0 { diff := int64(absTime - lastAbsTime) @@ -298,10 +316,17 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { if nextConn == nil { prepareNextConn() } - switchToNextConn(absTime) + forceSwitchConn = true } } + muConn.Lock() + defer muConn.Unlock() + + if forceSwitchConn { + switchToNextConn(absTime) + } + if conn == nil || (nextConn == nil && int16(absTime-conn.baseTimecode) > 8000) { options.logger.Debugf(`Prepare next connection: { StreamID: "%s" }`, p.streamID) prepareNextConn() @@ -321,6 +346,7 @@ func (p *Provider) PutMedia(opts ...PutMediaOption) (BlockWriter, error) { p.streamID, bt.AbsTimecode(), ErrWriteTimeout, ) + case <-closed: } return nil }, diff --git a/provider_test.go b/provider_test.go index a7c965a..ed0173d 100644 --- a/provider_test.go +++ b/provider_test.go @@ -571,6 +571,51 @@ func TestProvider_shutdownTwice(t *testing.T) { } } +func TestProvider_writeAfterClose(t *testing.T) { + server := kvsm.NewKinesisVideoServer() + defer server.Close() + + for i := 0; i < 100 && !t.Failed(); i++ { + pro := newProvider(t, server) + + w, err := pro.PutMedia( + kvm.OnError(func(e error) {}), + ) + if err != nil { + t.Fatal(err) + } + + go func() { + for { + if _, err := w.ReadResponse(); err != nil { + return + } + } + }() + time.Sleep(10 * time.Millisecond) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + if err := w.Close(); err != nil { + t.Error(err) + } + }() + go func() { + defer wg.Done() + if err := w.Write(&kvm.BlockWithBaseTimecode{ + Timecode: 1, + Block: newBlock(0), + }); err != nil { + t.Error(err) + } + }() + wg.Wait() + } +} + func newProvider(t *testing.T, server *kvsm.KinesisVideoServer) *kvm.Provider { cfg := &aws.Config{ Credentials: credentials.NewStaticCredentials("key", "secret", "token"),