diff --git a/connector/amqp10/amqp10.go b/connector/amqp10/amqp10.go index 764114b..6b67b1c 100644 --- a/connector/amqp10/amqp10.go +++ b/connector/amqp10/amqp10.go @@ -37,7 +37,8 @@ type AMQP10Connector struct { receivers map[string]*amqp.Receiver senders map[string]*amqp.Sender logger *logging.Logger - interrupt chan bool + inInterrupt chan bool + outInterrupt chan bool } // AMQP10Message holds received (or to be sent) messages from (to) AMQP-1.0 entity @@ -95,7 +96,6 @@ func CreateAMQP10Connector( logger: logger, receivers: make(map[string]*amqp.Receiver), senders: make(map[string]*amqp.Sender), - interrupt: make(chan bool), } // connect @@ -230,9 +230,12 @@ func (conn *AMQP10Connector) connect(connType string) error { switch connType { case "in": conn.inConnection = sess + conn.inInterrupt = make(chan bool) case "out": conn.outConnection = sess + conn.outInterrupt = make(chan bool) } + return nil } @@ -293,7 +296,7 @@ func (conn *AMQP10Connector) CreateReceiver(address string, prefetch int) error func (conn *AMQP10Connector) CreateSender(address string) (*amqp.Sender, error) { channel := strings.TrimPrefix(address, "/") if s, ok := conn.senders[channel]; ok { - s.Close(nil) + s.Close(context.Background()) delete(conn.senders, channel) } @@ -390,6 +393,12 @@ func (conn *AMQP10Connector) startSenders(inchan chan interface{}, wg *sync.Wait sndLock.Unlock() message.SetIdFromCounter(&counter) go func(sender *amqp.Sender, msg AMQP10Message, sndLock *sync.RWMutex, lfLock *sync.RWMutex, timeout time.Duration) { + var ( + connErr *amqp.ConnError + linkErr *amqp.LinkError + sessErr *amqp.SessionError + ) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer func(cancel context.CancelFunc) { cancel() @@ -420,7 +429,15 @@ func (conn *AMQP10Connector) startSenders(inchan chan interface{}, wg *sync.Wait lfLock.Lock() linkFail += 1 lfLock.Unlock() - msgMeta["reason"] = err.(*amqp.Error).Description + if urr, ok := err.(*amqp.Error); ok { + msgMeta["reason"] = urr.Description + } else if errors.As(err, &connErr) && connErr.RemoteErr != nil { + msgMeta["reason"] = connErr.RemoteErr.Description + } else if errors.As(err, &linkErr) && linkErr.RemoteErr != nil { + msgMeta["reason"] = linkErr.RemoteErr.Description + } else if errors.As(err, &sessErr) && sessErr.RemoteErr != nil { + msgMeta["reason"] = sessErr.RemoteErr.Description + } conn.logger.Metadata(msgMeta) conn.logger.Warn("Failed to send message") @@ -437,7 +454,7 @@ func (conn *AMQP10Connector) startSenders(inchan chan interface{}, wg *sync.Wait }) conn.logger.Debug("Skipped processing of sent AMQP1.0 message with invalid type") } - case <-conn.interrupt: + case <-conn.outInterrupt: goto doneSend } } @@ -472,6 +489,12 @@ func (conn *AMQP10Connector) startReceivers(outchan chan interface{}, wg *sync.W for _, rcv := range conn.receivers { wg.Add(1) go func(receiver *amqp.Receiver) { + var ( + connErr *amqp.ConnError + linkErr *amqp.LinkError + sessErr *amqp.SessionError + ) + defer wg.Done() connLogMeta := logging.Metadata{ @@ -480,25 +503,29 @@ func (conn *AMQP10Connector) startReceivers(outchan chan interface{}, wg *sync.W } conn.logger.Metadata(connLogMeta) conn.logger.Warn("Created receiver") + // number of link failures + linkFail := int64(0) for { select { - case <-conn.interrupt: + case <-conn.inInterrupt: goto doneReceive default: - } - var connErr *amqp.ConnError - if msg, err := receiver.Receive(context.Background(), nil); err == nil { - receiver.AcceptMessage(context.Background(), msg) - conn.processIncomingMessage(msg.GetData(), outchan, receiver) - conn.logger.Debug("Message ACKed") - } else if errors.As(err, &connErr) { - conn.logger.Metadata(connLogMeta) - conn.logger.Warn("Channel closed, reconnecting") - goto reconnectReceive - } else { - connLogMeta["err"] = err - conn.logger.Metadata(connLogMeta) - conn.logger.Error("Received AMQP1.0 error") + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if msg, err := receiver.Receive(ctx, nil); err == nil { + receiver.AcceptMessage(context.Background(), msg) + conn.processIncomingMessage(msg.GetData(), outchan, receiver) + conn.logger.Debug("Message ACKed") + linkFail = int64(0) + } else if errors.As(err, &connErr) || errors.As(err, &linkErr) || errors.As(err, &sessErr) { + linkFail += int64(1) + if linkFail > conn.LinkFailureLimit { + conn.logger.Metadata(connLogMeta) + conn.logger.Warn("Too many link failures in row, reconnecting") + goto reconnectReceive + } + } else { + // receiver wait timeouted + } } } @@ -525,8 +552,8 @@ func (conn *AMQP10Connector) stopReceivers() { "receiver": conn.receivers[r].Address(), }) conn.logger.Debug("Closed receiver link") - delete(conn.receivers, r) } + conn.receivers = map[string]*amqp.Receiver{} } // Reconnect tries to reconnect connector to configured AMQP1.0 node. Returns nil if failed @@ -535,6 +562,7 @@ func (conn *AMQP10Connector) Reconnect(connectionType string, msgChannel chan in listen := []string{} switch connectionType { case "in": + close(conn.inInterrupt) for r := range conn.receivers { listen = append(listen, conn.receivers[r].Address()) } @@ -542,6 +570,7 @@ func (conn *AMQP10Connector) Reconnect(connectionType string, msgChannel chan in conn.inConnection.Close(ctx) conn.logger.Debug("Disconnected incoming connection") case "out": + close(conn.outInterrupt) conn.stopSenders() conn.outConnection.Close(ctx) conn.logger.Debug("Disconnected outgoing connection") @@ -577,7 +606,8 @@ func (conn *AMQP10Connector) Reconnect(connectionType string, msgChannel chan in // Disconnect closes connection in both directions func (conn *AMQP10Connector) Disconnect() { ctx := context.Background() - close(conn.interrupt) + close(conn.inInterrupt) + close(conn.outInterrupt) time.Sleep(time.Second) conn.inConnection.Close(ctx) conn.outConnection.Close(ctx) diff --git a/tests/connector_test.go b/tests/connector_test.go index 6974d9e..5967ecc 100644 --- a/tests/connector_test.go +++ b/tests/connector_test.go @@ -240,11 +240,35 @@ func TestAMQP10SendAndReceiveMessage(t *testing.T) { wg.Wait() }) - t.Run("Test reconnect", func(t *testing.T) { + t.Run("Test reconnect of sender", func(t *testing.T) { + var wg sync.WaitGroup + + require.NoError(t, conn.Reconnect("out", sender, cwg)) + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 3; i++ { + data := <-receiver + assert.Equal(t, QDRMsg2, (data.(amqp10.AMQP10Message)).Body) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 3; i++ { + sender <- amqp10.AMQP10Message{Address: "qdrtest", Body: QDRMsg2} + } + }() + + wg.Wait() + }) + + t.Run("Test reconnect of receiver", func(t *testing.T) { var wg sync.WaitGroup require.NoError(t, conn.Reconnect("in", receiver, cwg)) - require.NoError(t, conn.Reconnect("out", receiver, cwg)) wg.Add(1) go func() { @@ -258,16 +282,16 @@ func TestAMQP10SendAndReceiveMessage(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - sender <- amqp10.AMQP10Message{Address: "qdrtest", Body: QDRMsg2} - sender <- amqp10.AMQP10Message{Address: "qdrtest", Body: QDRMsg2} - sender <- amqp10.AMQP10Message{Address: "qdrtest", Body: QDRMsg2} + for i := 0; i < 3; i++ { + sender <- amqp10.AMQP10Message{Address: "qdrtest", Body: QDRMsg2} + } }() wg.Wait() - conn.Disconnect() - cwg.Wait() }) + conn.Disconnect() + cwg.Wait() } func TestLoki(t *testing.T) {