diff --git a/adaptor/rethinkdb/writer.go b/adaptor/rethinkdb/writer.go index 8be968d44..d29e61e91 100644 --- a/adaptor/rethinkdb/writer.go +++ b/adaptor/rethinkdb/writer.go @@ -30,8 +30,9 @@ type Writer struct { } type bulkOperation struct { - s *r.Session - docs []map[string]interface{} + s *r.Session + confirms chan struct{} + docs []map[string]interface{} } func newWriter(done chan struct{}, wg *sync.WaitGroup) *Writer { @@ -51,7 +52,11 @@ func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error switch msg.OP() { case ops.Delete: w.flushAll() - return msg, do(r.DB(rSession.Database()).Table(table).Get(prepareDocument(msg)["id"]).Delete(), rSession) + return msg, do( + r.DB(rSession.Database()).Table(table).Get(prepareDocument(msg)["id"]).Delete(), + rSession, + msg.Confirms(), + ) case ops.Insert: w.Lock() bOp, ok := w.bulkMap[table] @@ -62,6 +67,9 @@ func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error } w.bulkMap[table] = bOp } + if msg.Confirms() != nil { + bOp.confirms = msg.Confirms() + } bOp.docs = append(bOp.docs, prepareDocument(msg)) w.Unlock() w.opCounter++ @@ -70,7 +78,11 @@ func (w *Writer) Write(msg message.Msg) func(client.Session) (message.Msg, error } case ops.Update: w.flushAll() - return msg, do(r.DB(rSession.Database()).Table(table).Insert(prepareDocument(msg), r.InsertOpts{Conflict: "replace"}), rSession) + return msg, do( + r.DB(rSession.Database()).Table(table).Insert(prepareDocument(msg), r.InsertOpts{Conflict: "replace"}), + rSession, + msg.Confirms(), + ) } return msg, nil } @@ -118,7 +130,7 @@ func (w *Writer) flushAll() error { if err != nil { return err } - if err := handleResponse(&resp); err != nil { + if err := handleResponse(&resp, bOp.confirms); err != nil { return err } } @@ -126,20 +138,23 @@ func (w *Writer) flushAll() error { return nil } -func do(t r.Term, s *r.Session) error { +func do(t r.Term, s *r.Session, confirms chan struct{}) error { resp, err := t.RunWrite(s) if err != nil { return err } - return handleResponse(&resp) + return handleResponse(&resp, confirms) } // handleresponse takes the rethink response and turn it into something we can consume elsewhere -func handleResponse(resp *r.WriteResponse) error { +func handleResponse(resp *r.WriteResponse, confirms chan struct{}) error { if resp.Errors != 0 { if !strings.Contains(resp.FirstError, "Duplicate primary key") { // we don't care about this error return fmt.Errorf("%s\n%s", "problem inserting docs", resp.FirstError) } } + if confirms != nil { + close(confirms) + } return nil } diff --git a/adaptor/rethinkdb/writer_test.go b/adaptor/rethinkdb/writer_test.go index 26ab11633..b12343866 100644 --- a/adaptor/rethinkdb/writer_test.go +++ b/adaptor/rethinkdb/writer_test.go @@ -77,9 +77,15 @@ func TestBulkInsert(t *testing.T) { if err != nil { t.Fatalf("unable to obtain session to rethinkdb, %s", err) } + confirms := make(chan struct{}) + var confirmed bool + go func() { + <-confirms + confirmed = true + }() for i := 0; i < 999; i++ { - msg := message.From(ops.Insert, "bulk", map[string]interface{}{"i": i}) - if _, err := w.Write(msg)(s); err != nil { + msg := message.From(ops.Insert, "bulk", map[string]interface{}{"id": i, "i": i}) + if _, err := w.Write(message.WithConfirms(confirms, msg))(s); err != nil { t.Errorf("unexpected Insert error, %s", err) } } @@ -96,6 +102,27 @@ func TestBulkInsert(t *testing.T) { if count != 999 { t.Errorf("[bulk] mismatched doc count, expected 999, got %d", count) } + + if !confirmed { + t.Errorf("[bulk] confirm chan never closed but should have") + } + + for i := 0; i < 2000; i++ { + msg := message.From(ops.Insert, "bulk", map[string]interface{}{"id": i, "i": i}) + if _, err := w.Write(msg)(s); err != nil { + t.Errorf("unexpected Insert error, %s", err) + } + } + + countResp, err = r.DB(writerTestData.DB).Table("bulk").Count().Run(defaultSession.session) + if err != nil { + t.Errorf("unable to determine table count, %s", err) + } + countResp.One(&count) + countResp.Close() + if count != 2000 { + t.Errorf("[bulk] mismatched doc count, expected 2000, got %d", count) + } } func TestInsert(t *testing.T) { diff --git a/pipe/pipe.go b/pipe/pipe.go index c065bc2a6..41f17979c 100644 --- a/pipe/pipe.go +++ b/pipe/pipe.go @@ -8,6 +8,7 @@ package pipe import ( "errors" + "sync" "time" "github.com/compose/transporter/events" @@ -25,7 +26,7 @@ var ( type messageChan chan TrackedMessage func newMessageChan() messageChan { - return make(chan TrackedMessage) + return make(chan TrackedMessage, 10) } type TrackedMessage struct { @@ -48,8 +49,9 @@ type Pipe struct { MessageCount int path string // the path of this pipe (for events and errors) - chStop chan chan bool + chStop chan struct{} listening bool + wg sync.WaitGroup } // NewPipe creates a new Pipe. If the pipe that is passed in is nil, then this pipe will be treated as a source pipe that just serves to emit messages. @@ -60,7 +62,7 @@ func NewPipe(pipe *Pipe, path string) *Pipe { p := &Pipe{ Out: make([]messageChan, 0), path: path, - chStop: make(chan chan bool), + chStop: make(chan struct{}), } if pipe != nil { @@ -84,22 +86,25 @@ func (p *Pipe) Listen(fn func(message.Msg, offset.Offset) (message.Msg, error)) return ErrUnableToListen } p.listening = true + p.wg.Add(1) for { // check for stop select { - case c := <-p.chStop: - p.Stopped = true - c <- true + case <-p.chStop: + if len(p.In) > 0 { + log.With("buffer_length", len(p.In)).Infoln("received stop, message buffer not empty, continuing...") + continue + } + log.Infoln("received stop, message buffer is empty, closing...") + p.wg.Done() return nil case m := <-p.In: - if p.Stopped { - break - } outmsg, err := fn(m.Msg, m.Off) if err != nil { p.Stopped = true p.Err <- err - break + p.wg.Done() + return err } if outmsg == nil { break @@ -120,11 +125,34 @@ func (p *Pipe) Stop() { // we only worry about the stop channel if we're in a listening loop if p.listening { - c := make(chan bool) - p.chStop <- c - <-c + close(p.chStop) + p.wg.Wait() + return + } + + timeout := time.After(10 * time.Second) + for { + select { + case <-timeout: + log.Errorln("timeout reached waiting for Out channels to clear") + return + default: + } + if p.empty() { + return + } + time.Sleep(1 * time.Second) + } + } +} + +func (p *Pipe) empty() bool { + for _, ch := range p.Out { + if len(ch) > 0 { + return false } } + return true } // Send emits the given message on the 'Out' channel. the send Timesout after 100 ms in order to chaeck of the Pipe has stopped and we've been asked to exit. @@ -138,12 +166,6 @@ func (p *Pipe) Send(msg message.Msg, off offset.Offset) { select { case ch <- TrackedMessage{msg, off}: break A - case <-time.After(100 * time.Millisecond): - if p.Stopped { - // return, with no guarantee - log.Infoln("returning with no guarantee") - return - } } } } diff --git a/pipe/pipe_test.go b/pipe/pipe_test.go index 93601ef43..c0d08516b 100644 --- a/pipe/pipe_test.go +++ b/pipe/pipe_test.go @@ -17,7 +17,6 @@ func TestSend(t *testing.T) { source := NewPipe(nil, "source") sink1 := NewPipe(source, "sink1") go sink1.Listen(func(msg message.Msg, _ offset.Offset) (message.Msg, error) { - time.Sleep(200 * time.Millisecond) msgsProcessed++ return msg, nil }) @@ -26,42 +25,34 @@ func TestSend(t *testing.T) { msgsProcessed++ return msg, nil }) - go func() { - source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) - source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) - }() - time.Sleep(300 * time.Millisecond) - if msgsProcessed != 3 { - t.Errorf("unexpected messages processed count, expected 3, got %d", msgsProcessed) - } + source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) + source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) source.Stop() sink1.Stop() sink2.Stop() + if msgsProcessed != 4 { + t.Errorf("unexpected messages processed count, expected 4, got %d", msgsProcessed) + } } -func TestSendTimeout(t *testing.T) { +func TestStopMessageInFlight(t *testing.T) { var msgsProcessed int - source := NewPipe(nil, "source") - sink1 := NewPipe(source, "sink1") + source := NewPipe(nil, "in-flight-source") + sink1 := NewPipe(source, "in-flight-sink1") go sink1.Listen(func(msg message.Msg, _ offset.Offset) (message.Msg, error) { - time.Sleep(200 * time.Millisecond) - msgsProcessed++ - return msg, nil - }) - sink2 := NewPipe(source, "sink2") - go sink2.Listen(func(msg message.Msg, _ offset.Offset) (message.Msg, error) { + time.Sleep(100 * time.Millisecond) msgsProcessed++ return msg, nil }) - source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) - go source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) - time.Sleep(100 * time.Millisecond) - source.Stop() + for i := 0; i < 20; i++ { + source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) + } sink1.Stop() - sink2.Stop() - if msgsProcessed != 2 { - t.Errorf("unexpected messages processed count, expected 2, got %d", msgsProcessed) + source.Stop() + if msgsProcessed != 20 { + t.Errorf("unexpected messages processed count, expected 20, got %d", msgsProcessed) } + } func TestChainMessage(t *testing.T) { @@ -77,7 +68,6 @@ func TestChainMessage(t *testing.T) { return msg, nil }) source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) - time.Sleep(100 * time.Millisecond) source.Stop() sink1.Stop() sink2.Stop() @@ -99,7 +89,6 @@ func TestSkipMessage(t *testing.T) { return msg, nil }) source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) - time.Sleep(100 * time.Millisecond) source.Stop() sink1.Stop() sink2.Stop() @@ -119,12 +108,11 @@ func TestListenErr(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func(wg *sync.WaitGroup, t *testing.T) { - for err := range source.Err { - if !reflect.DeepEqual(err, errListen) { - t.Errorf("wrong error received, expected %s, got %s", errListen, err) - } - wg.Done() + err := <-source.Err + if !reflect.DeepEqual(err, errListen) { + t.Errorf("wrong error received, expected %s, got %s", errListen, err) } + wg.Done() }(&wg, t) source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) source.Send(message.From(ops.Insert, "test", map[string]interface{}{}), offset.Offset{}) diff --git a/pipeline/node_test.go b/pipeline/node_test.go index fb4fea906..d1ebfc147 100644 --- a/pipeline/node_test.go +++ b/pipeline/node_test.go @@ -649,29 +649,41 @@ func TestStop(t *testing.T) { for _, st := range stopTests { source, s, deferFunc := st.node() defer deferFunc() - var errored bool + var errorChecked bool stopC := make(chan struct{}) - go func() { + var mu sync.Mutex + go func(mu *sync.Mutex) { select { case <-source.pipe.Err: - errored = true + mu.Lock() + defer mu.Unlock() + if errorChecked { + return + } + errorChecked = true + time.Sleep(1 * time.Second) source.Stop() close(stopC) } - }() + }(&mu) if err := source.Start(); err != st.startErr { t.Errorf("[%s] unexpected Start() error, expected %s, got %s", st.name, st.startErr, err) } - if !errored { + mu.Lock() + if !errorChecked { + errorChecked = true + time.Sleep(1 * time.Second) source.Stop() close(stopC) } + mu.Unlock() <-stopC for _, child := range source.children { if !s.Closed { t.Errorf("[%s] child node was not closed but should have been", child.Name) } } + if st.msgCount != s.MsgCount { t.Errorf("[%s] wrong number of messages received, expected %d, got %d", st.name, st.msgCount, s.MsgCount) } diff --git a/pipeline/pipeline_events_integration_test.go b/pipeline/pipeline_events_integration_test.go index f24932b76..1687a5479 100644 --- a/pipeline/pipeline_events_integration_test.go +++ b/pipeline/pipeline_events_integration_test.go @@ -53,6 +53,7 @@ func TestEventsBroadcast(t *testing.T) { eh := &EventHolder{rawEvents: make([][]byte, 0)} ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { event, _ := ioutil.ReadAll(r.Body) + fmt.Println(string(event)) r.Body.Close() eh.rawEvents = append(eh.rawEvents, event) })) @@ -103,7 +104,7 @@ func TestEventsBroadcast(t *testing.T) { t.Fatalf("can't create NewNode, got %s", err) } - p, err := NewDefaultPipeline(dummyOutNode, ts.URL, "asdf", "jklm", "test", 1*time.Second) + p, err := NewDefaultPipeline(dummyOutNode, ts.URL, "asdf", "jklm", "test", 10*time.Second) if err != nil { t.Errorf("can't create pipeline, got %s", err.Error()) t.FailNow()