From 7ef78d4de5eb314e086c3ef608a54f9fa16566ae Mon Sep 17 00:00:00 2001
From: Danlock
Date: Thu, 21 Sep 2023 01:24:27 -0400
Subject: [PATCH] added benchmarks, PublishBatchUntilAcked for publishing many messages, slight refactoring of Consumer to take the connection on constructor

---
 Makefile                |   6 +-
 README.md               |  15 +-
 benchmark_int_test.go   | 215 ++++++++++++++++++++++++++++++++++++++++
 consumer.go             |  27 ++--
 consumer_int_test.go    |  30 +++--
 healthcheck_int_test.go |   4 +-
 publisher.go            | 151 ++++++++++++++++++++++++++--
 7 files changed, 401 insertions(+), 47 deletions(-)
 create mode 100644 benchmark_int_test.go

diff --git a/Makefile b/Makefile
index a3c88a6..dada866 100644
--- a/Makefile
+++ b/Makefile
@@ -40,4 +40,8 @@ coverage-browser:
 update-readme-badge:
 	@go tool cover -func=$(COVERAGE_PATH) -o=$(COVERAGE_PATH).badge
-	@go run github.com/AlexBeauchemin/gobadge@v0.3.0 -filename=$(COVERAGE_PATH).badge
\ No newline at end of file
+	@go run github.com/AlexBeauchemin/gobadge@v0.3.0 -filename=$(COVERAGE_PATH).badge
+
+# pkg.go.dev documentation is updated via go get
+update-proxy-cache:
+	@GOPROXY=https://proxy.golang.org go get github.com/danlock/rmq
diff --git a/README.md b/README.md
index c620b76..c81123c 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,8 @@ This package attempts to provide a wrapper of useful features on top of amqp091,
 Using an AMQP publisher to publish a message with at least once delivery.
 
 ```
-ctx := context.TODO()
+ctx, cancel := context.WithTimeout(context.TODO(), time.Minute)
+defer cancel()
 cfg := rmq.CommonConfig{Log: slog.Log}
 rmqConn := rmq.ConnectWithURLs(ctx, rmq.ConnectConfig{CommonConfig: cfg}, os.Getenv("AMQP_URL_1"), os.Getenv("AMQP_URL_2"))
@@ -39,21 +40,21 @@ if err := rmqPub.PublishUntilAcked(ctx, time.Minute, msg); err != nil {
 }
 ```
 
-Using a reliable AMQP consumer that delivers messages through transient network failures while processing work concurrently with bounded goroutines.
+Using a reliable AMQP consumer that receives deliveries through transient network failures while processing work concurrently with bounded goroutines.
 
 ```
 ctx := context.TODO()
 
 cfg := rmq.CommonConfig{Log: slog.Log}
-rmqConn := rmq.ConnectWithURL(ctx, rmq.ConnectConfig{CommonConfig: cfg}, os.Getenv("AMQP_URL"))
+rmqConn := rmq.ConnectWithAMQPConfig(ctx, rmq.ConnectConfig{CommonConfig: cfg}, os.Getenv("AMQP_URL"), amqp.Config{})
 
 consCfg := rmq.ConsumerConfig{
 	CommonConfig: cfg,
 	Queue:        rmq.Queue{Name: "q2d2", AutoDelete: true},
-	Qos:          rmq.Qos{PrefetchCount: 100},
+	Qos:          rmq.Qos{PrefetchCount: 1000},
 }
 
-rmq.NewConsumer(consCfg).ConsumeConcurrently(ctx, rmqConn, 50, func(ctx context.Context, msg amqp.Delivery) {
+rmq.NewConsumer(rmqConn, consCfg).ConsumeConcurrently(ctx, 100, func(ctx context.Context, msg amqp.Delivery) {
 	process(msg)
 	if err := msg.Ack(false); err != nil {
 		handleErr(err)
@@ -74,7 +75,7 @@ Here is an example logrus wrapper. danlock/rmq only uses the predefined slog.Lev
 PublisherConfig{
 	Log: func(ctx context.Context, level slog.Level, msg string, args ...any) {
 		logruslevel, _ := logrus.ParseLevel(level.String())
-		logrus.StandardLogger().WithContext(ctx).Logf(logruslevel, msg, args...)
+		logrus.StandardLogger().WithContext(ctx).Logf(logruslevel, msg)
 	}
 }
 ```
\ No newline at end of file
diff --git a/benchmark_int_test.go b/benchmark_int_test.go
new file mode 100644
index 0000000..03a5fa7
--- /dev/null
+++ b/benchmark_int_test.go
@@ -0,0 +1,215 @@
+//go:build rabbit
+
+package rmq_test
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"log/slog"
+	"os"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/danlock/rmq"
+	amqp "github.com/rabbitmq/amqp091-go"
+)
+
+func generatePublishings(num int, routingKey string) []rmq.Publishing {
+	publishings := make([]rmq.Publishing, num)
+	for i := range publishings {
+		publishings[i] = rmq.Publishing{
+			RoutingKey: routingKey,
+			Mandatory:  true,
+			Publishing: amqp.Publishing{
+				Body: []byte(fmt.Sprintf("%d.%d", i, time.Now().UnixNano())),
+			},
+		}
+	}
+	return publishings
+}
+
+func BenchmarkPublishAndConsumeMany(b *testing.B) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+	defer cancel()
+
+	randSuffix := fmt.Sprintf("%d.%p", time.Now().UnixNano(), b)
+
+	queueName := "BenchmarkPublishAndConsumeMany" + randSuffix
+	baseCfg := rmq.CommonConfig{Log: slog.Log}
+	topology := rmq.Topology{
+		CommonConfig: baseCfg,
+		Queues: []rmq.Queue{{
+			Name: queueName,
+			Args: amqp.Table{
+				amqp.QueueTTLArg: time.Minute.Milliseconds(),
+			},
+		}},
+	}
+
+	subRMQConn := rmq.ConnectWithURL(ctx, rmq.ConnectConfig{CommonConfig: baseCfg, Topology: topology}, os.Getenv("TEST_AMQP_URI"))
+	pubRMQConn := rmq.ConnectWithURL(ctx, rmq.ConnectConfig{CommonConfig: baseCfg, Topology: topology}, os.Getenv("TEST_AMQP_URI"))
+
+	consumer := rmq.NewConsumer(subRMQConn, rmq.ConsumerConfig{
+		CommonConfig: baseCfg,
+		Queue:        topology.Queues[0],
+	})
+
+	publisher := rmq.NewPublisher(ctx, pubRMQConn, rmq.PublisherConfig{
+		CommonConfig: baseCfg,
+		LogReturns:   true,
+	})
+
+	publisher2, publisher3 := rmq.NewPublisher(ctx, pubRMQConn, rmq.PublisherConfig{
+		CommonConfig: baseCfg,
+		LogReturns:   true,
+	}), rmq.NewPublisher(ctx, pubRMQConn, rmq.PublisherConfig{
+		CommonConfig: baseCfg,
+		LogReturns:   true,
+	})
+
+	dot := []byte(".")
+	errChan := make(chan error)
+	consumeChan := consumer.Consume(ctx)
+
+	publishings := generatePublishings(10000, queueName)
+
+	cases := []struct {
+		name        string
+		publishFunc func(b *testing.B)
+	}{
+		{
+			"PublishBatchUntilAcked",
+			func(b *testing.B) {
+				if err := publisher.PublishBatchUntilAcked(ctx, 0, publishings...); err != nil {
+					b.Fatalf("PublishBatchUntilAcked err %v", err)
+				}
+			},
+		},
+		{
+			"PublishBatchUntilAcked into thirds",
+			func(b *testing.B) {
+				errChan := make(chan error)
+				publishers := []*rmq.Publisher{publisher, publisher, publisher}
+				for i := range publishers {
+					go func(i int) {
+						errChan <- publishers[i].PublishBatchUntilAcked(ctx, 0, publishings[i*len(publishings)/3:(i+1)*len(publishings)/3]...)
+					}(i)
+				}
+				successes := 0
+				for {
+					select {
+					case err := <-errChan:
+						if err != nil {
+							b.Fatalf("PublishBatchUntilAcked err %v", err)
+						}
+						successes++
+						if successes == len(publishers) {
+							return
+						}
+					case <-ctx.Done():
+						b.Fatalf("PublishBatchUntilAcked timed out")
+					}
+				}
+			},
+		},
+		{
+			"PublishBatchUntilAcked on three Publishers",
+			func(b *testing.B) {
+				errChan := make(chan error)
+				publishers := []*rmq.Publisher{publisher, publisher2, publisher3}
+				for i := range publishers {
+					go func(i int) {
+						errChan <- publishers[i].PublishBatchUntilAcked(ctx, 0, publishings[i*len(publishings)/3:(i+1)*len(publishings)/3]...)
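+						// Each third rides its own Publisher here; assuming every rmq.Publisher holds its own
+						// amqp.Channel, the three batches publish and confirm independently of each other.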
+					}(i)
+				}
+				successes := 0
+				for {
+					select {
+					case err := <-errChan:
+						if err != nil {
+							b.Fatalf("PublishBatchUntilAcked err %v", err)
+						}
+						successes++
+						if successes == len(publishers) {
+							return
+						}
+					case <-ctx.Done():
+						b.Fatalf("PublishBatchUntilAcked timed out")
+					}
+				}
+			},
+		},
+		{
+			"Concurrent PublishUntilAcked",
+			func(b *testing.B) {
+				errChan := make(chan error)
+				for i := range publishings {
+					go func(i int) {
+						errChan <- publisher.PublishUntilAcked(ctx, 0, publishings[i])
+					}(i)
+				}
+				successes := 0
+				for {
+					select {
+					case err := <-errChan:
+						if err != nil {
+							b.Fatalf("PublishUntilAcked err %v", err)
+						}
+						successes++
+						if successes == len(publishings) {
+							return
+						}
+					case <-ctx.Done():
+						b.Fatalf("PublishUntilAcked timed out")
+					}
+				}
+			},
+		},
+	}
+
+	for _, bb := range cases {
+		b.Run(bb.name, func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				go func(i int) (err error) {
+					received := make(map[uint64]struct{}, len(publishings))
+					defer func() { errChan <- err }()
+					for {
+						select {
+						case msg := <-consumeChan:
+							rawIndex := bytes.Split(msg.Body, dot)[0]
+							index, err := strconv.ParseUint(string(rawIndex), 10, 64)
+							if err != nil {
+								return fmt.Errorf("strconv.ParseUint err %w", err)
+							}
+							received[index] = struct{}{}
+							if err := msg.Ack(false); err != nil {
+								return fmt.Errorf("msg.Ack err %w", err)
+							}
+							if len(received) == len(publishings) {
+								return nil
+							}
+						case <-ctx.Done():
+							return fmt.Errorf("timed out after consuming %d publishings on bench run %d", len(received), i)
+						}
+					}
+				}(i)
+
+				bb.publishFunc(b)
+
+				select {
+				case <-ctx.Done():
+					b.Fatalf("timed out on bench run %d", i)
+				case err := <-errChan:
+					if err != nil {
+						b.Fatalf("on bench run %d consumer err %v", i, err)
+					}
+				}
+			}
+		})
+	}
+}
diff --git a/consumer.go b/consumer.go
index 5ae11c1..4a7217b 100644
--- a/consumer.go
+++ b/consumer.go
@@ -16,7 +16,7 @@ type ConsumerConfig struct {
 	CommonConfig
 
 	Queue         Queue
-	QueueBindings []QueueBinding // Should only be used for anonymous queues, otherwise QueueBinding's be declared with DeclareTopology
+	QueueBindings []QueueBinding // Only needed for anonymous queues since Consumer's do not return the generated RabbitMQ queue name
 	Consume       Consume
 	Qos           Qos
 }
@@ -61,24 +61,25 @@ type Qos struct {
 // Consumer enables reliable AMQP Queue consumption.
 type Consumer struct {
 	config ConsumerConfig
+	conn   *Connection
 }
 
 // NewConsumer takes in a ConsumerConfig that describes the AMQP topology of a single queue,
 // and returns a rmq.Consumer that can redeclare this topology on any errors during queue consumption.
 // This enables robust reconnections even on unreliable networks.
-func NewConsumer(config ConsumerConfig) *Consumer {
+func NewConsumer(rmqConn *Connection, config ConsumerConfig) *Consumer {
 	config.setDefaults()
-	return &Consumer{config: config}
+	return &Consumer{config: config, conn: rmqConn}
 }
 
 // safeDeclareAndConsume safely declares and consumes from an amqp.Queue
 // Closes the amqp.Channel on errors.
-func (c *Consumer) safeDeclareAndConsume(ctx context.Context, rmqConn *Connection) (_ *amqp.Channel, _ <-chan amqp.Delivery, err error) {
+func (c *Consumer) safeDeclareAndConsume(ctx context.Context) (_ *amqp.Channel, _ <-chan amqp.Delivery, err error) {
 	logPrefix := fmt.Sprintf("rmq.Consumer.safeDeclareAndConsume for queue %s", c.config.Queue.Name)
 	ctx, cancel := context.WithTimeout(ctx, c.config.AMQPTimeout)
 	defer cancel()
 
-	mqChan, err := rmqConn.Channel(ctx)
+	mqChan, err := c.conn.Channel(ctx)
 	if err != nil {
 		return nil, nil, fmt.Errorf(logPrefix+" failed to get a channel due to err %w", err)
 	}
@@ -172,7 +173,7 @@ func (c *Consumer) declareAndConsume(ctx context.Context, mqChan *amqp.Channel)
 // On errors Consume reconnects to AMQP, redeclares and resumes consumption and forwarding of deliveries.
 // Consume returns an unbuffered channel, and will block on sending to it if no ones listening.
 // The returned channel is closed only after the context finishes.
-func (c *Consumer) Consume(ctx context.Context, rmqConn *Connection) <-chan amqp.Delivery {
+func (c *Consumer) Consume(ctx context.Context) <-chan amqp.Delivery {
 	outChan := make(chan amqp.Delivery)
 	go func() {
 		logPrefix := fmt.Sprintf("rmq.Consumer.Consume for queue (%s)", c.config.Queue.Name)
@@ -185,7 +186,7 @@ func (c *Consumer) Consume(ctx context.Context, rmqConn *Connection) <-chan amqp
 				return
 			case <-time.After(delay):
 			}
-			mqChan, inChan, err := c.safeDeclareAndConsume(ctx, rmqConn)
+			mqChan, inChan, err := c.safeDeclareAndConsume(ctx)
 			if err != nil {
 				delay = c.config.Delay(attempt)
 				attempt++
@@ -202,7 +203,7 @@ func (c *Consumer) Consume(ctx context.Context, rmqConn *Connection) <-chan amqp
 	return outChan
 }
 
-// forwardDeliveries forwards from inChan until it closes. If the context finishes it closes the amqp Channel so that the delivery channel will close eventually.
+// forwardDeliveries forwards from inChan until it closes. If the context finishes it closes the amqp Channel so that the delivery channel will close after sending its deliveries.
 func (c *Consumer) forwardDeliveries(ctx context.Context, mqChan *amqp.Channel, inChan <-chan amqp.Delivery, outChan chan<- amqp.Delivery) {
 	logPrefix := fmt.Sprintf("rmq.Consumer.forwardDeliveries for queue (%s)", c.config.Queue.Name)
 	closeNotifier := mqChan.NotifyClose(make(chan *amqp.Error, 6))
@@ -233,18 +234,18 @@ func (c *Consumer) forwardDeliveries(ctx context.Context, mqChan *amqp.Channel,
 }
 
 // ConsumeConcurrently simply runs the provided deliveryProcessor on each delivery from Consume in a new goroutine.
-// maxGoroutines limits the amounts of goroutines spawned and defaults to 2000.
+// maxGoroutines limits the number of goroutines spawned and defaults to 500.
 // Qos.PrefetchCount can also limit goroutines spawned if deliveryProcessor properly Acks messages.
 // Blocks until the context is finished and the Consume channel closes.
-func (c *Consumer) ConsumeConcurrently(ctx context.Context, rmqConn *Connection, maxGoroutines uint64, deliveryProcessor func(ctx context.Context, msg amqp.Delivery)) {
+func (c *Consumer) ConsumeConcurrently(ctx context.Context, maxGoroutines uint64, deliveryProcessor func(ctx context.Context, msg amqp.Delivery)) {
 	if maxGoroutines == 0 {
-		maxGoroutines = 2000
+		maxGoroutines = 500
 	}
 	// We use a simple semaphore here and a new goroutine each time.
-	// It may be more efficient to use a goroutine pool, but a concerned caller can probably do it better themselves.
+	// It may be more efficient to use a goroutine pool for small amounts of work, but a concerned caller can probably do it better themselves.
 	semaphore := make(chan struct{}, maxGoroutines)
 	deliverAndReleaseSemaphore := func(msg amqp.Delivery) { deliveryProcessor(ctx, msg); <-semaphore }
-	for msg := range c.Consume(ctx, rmqConn) {
+	for msg := range c.Consume(ctx) {
 		semaphore <- struct{}{}
 		go deliverAndReleaseSemaphore(msg)
 	}
diff --git a/consumer_int_test.go b/consumer_int_test.go
index 37d1c77..da7bc1d 100644
--- a/consumer_int_test.go
+++ b/consumer_int_test.go
@@ -49,15 +49,15 @@ func TestConsumer(t *testing.T) {
 		},
 	}
 
-	baseConsumer := rmq.NewConsumer(baseConsConfig)
+	baseConsumer := rmq.NewConsumer(rmqConn, baseConsConfig)
 
 	canceledCtx, canceledCancel := context.WithCancel(ctx)
 	canceledCancel()
 	// ConsumeConcurrently should exit immediately on canceled contexts.
-	baseConsumer.ConsumeConcurrently(canceledCtx, rmqConn, 0, nil)
+	baseConsumer.ConsumeConcurrently(canceledCtx, 0, nil)
 
 	rmqBaseConsMessages := make(chan amqp.Delivery, 10)
-	go baseConsumer.ConsumeConcurrently(ctx, rmqConn, 0, func(ctx context.Context, msg amqp.Delivery) {
+	go baseConsumer.ConsumeConcurrently(ctx, 0, func(ctx context.Context, msg amqp.Delivery) {
 		rmqBaseConsMessages <- msg
 		_ = msg.Ack(false)
 	})
@@ -155,7 +155,7 @@ func TestConsumer_Load(t *testing.T) {
 	publisher := rmq.NewPublisher(ctx, rmqConn, rmq.PublisherConfig{CommonConfig: baseCfg})
 
 	msgCount := 5_000
-	errChan := make(chan error, (msgCount+1)*len(consumers))
+	errChan := make(chan error, (msgCount/2+1)*len(consumers))
 	for _, c := range consumers {
 		c := c
 		go func() {
@@ -163,7 +163,7 @@ func TestConsumer_Load(t *testing.T) {
 			receives := make(map[int]struct{})
 			var msgRecv uint64
 			var consMu sync.Mutex
-			rmq.NewConsumer(c).ConsumeConcurrently(ctx, rmqConn, 0, func(ctx context.Context, msg amqp.Delivery) {
+			rmq.NewConsumer(rmqConn, c).ConsumeConcurrently(ctx, 0, func(ctx context.Context, msg amqp.Delivery) {
 				if !c.Consume.AutoAck {
 					defer msg.Ack(false)
 				}
@@ -187,8 +187,8 @@ func TestConsumer_Load(t *testing.T) {
 			})
 		}()
 		go func() {
-			// Send half of the messages in parallel, then the rest serially
-			// The listen() goroutine will serially execute all of these publishes anyway. Even the underlying *amqp.Channel will lock it's mutex on publishes.
+			// Send half of the messages with an incredibly inefficient use of goroutines, and the rest in a PublishBatchUntilAcked.
+			// Publishing all of this stuff in different goroutines should not cause any races.
 			for i := 0; i < msgCount/2; i++ {
 				go func(i int) {
 					errChan <- publisher.PublishUntilAcked(ctx, 0, rmq.Publishing{
@@ -200,15 +200,17 @@ func TestConsumer_Load(t *testing.T) {
 					})
 				}(i)
 			}
-			for i := msgCount / 2; i < msgCount; i++ {
-				errChan <- publisher.PublishUntilAcked(ctx, 0, rmq.Publishing{
+			pubs := make([]rmq.Publishing, msgCount/2)
+			for i := range pubs {
+				pubs[i] = rmq.Publishing{
 					RoutingKey: c.Queue.Name,
 					Mandatory:  true,
 					Publishing: amqp.Publishing{
-						Body: []byte(fmt.Sprint(c.Queue.Name, ":", i)),
+						Body: []byte(fmt.Sprint(c.Queue.Name, ":", i+len(pubs))),
 					},
-				})
+				}
 			}
+			errChan <- publisher.PublishBatchUntilAcked(ctx, 0, pubs...)
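+			// The batch reports a single error for the whole second half, which is why errChan's
+			// capacity above is (msgCount/2+1) per consumer.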
 		}()
 	}
 
@@ -233,14 +235,14 @@ func TestRMQConsumer_AutogeneratedQueueNames(t *testing.T) {
 
 	// NewConsumer with an empty Queue.Name will declare a queue with a RabbitMQ generated name
 	// This is useless unless the config also includes QueueBindings, since reconnections cause RabbitMQ to generate a different name anyway
-	cons := rmq.NewConsumer(rmq.ConsumerConfig{
+	cons := rmq.NewConsumer(rmqConn, rmq.ConsumerConfig{
 		CommonConfig: baseCfg,
 		QueueBindings: []rmq.QueueBinding{
 			{ExchangeName: "amq.fanout", RoutingKey: "TestRMQConsumer_AutogeneratedQueueNames"},
 		},
 		Qos: rmq.Qos{PrefetchCount: 1},
 	})
-	deliveries := cons.Consume(ctx, rmqConn)
+	deliveries := cons.Consume(ctx)
 	// Wait a sec for Consume to actually bring up the queue, since otherwise a published message could happen before a queue is declared.
 	// danlock/rmq best practice to only use queues named in your Topology so you won't have to remember this.
 	time.Sleep(time.Second / 3)
@@ -253,7 +255,7 @@ func TestRMQConsumer_AutogeneratedQueueNames(t *testing.T) {
 	// Declaring again should work without errors, but it will create a different queue rather than consuming from the first one.
 	// rmq.Consumer could remember the last queue name to consume from it again, but that wouldn't be reliable with auto-deleted or expiring queues.
 	// It's simpler to disallow that use case by not making RabbitMQ generated queue names available from rmq.Consumer.
-	secondDeliveries := cons.Consume(ctx, rmqConn)
+	secondDeliveries := cons.Consume(ctx)
 	publisher := rmq.NewPublisher(ctx, rmqConn, rmq.PublisherConfig{CommonConfig: baseCfg, LogReturns: true})
 	pubCount := 10
 	time.Sleep(time.Second / 3)
diff --git a/healthcheck_int_test.go b/healthcheck_int_test.go
index 4999e0e..000c1a8 100644
--- a/healthcheck_int_test.go
+++ b/healthcheck_int_test.go
@@ -55,13 +55,13 @@ func Example() {
 		panic("couldn't get a channel")
 	}
 
-	rmqCons := rmq.NewConsumer(rmq.ConsumerConfig{
+	rmqCons := rmq.NewConsumer(subRMQConn, rmq.ConsumerConfig{
 		CommonConfig: commonCfg,
 		Queue:        topology.Queues[0],
 		Qos:          rmq.Qos{PrefetchCount: 10},
 	})
 	// Now we have a RabbitMQ queue with messages incoming on the deliveries channel, even if the network flakes.
-	deliveries := rmqCons.Consume(ctx, subRMQConn)
+	deliveries := rmqCons.Consume(ctx)
 
 	rmqPub := rmq.NewPublisher(ctx, pubRMQConn, rmq.PublisherConfig{CommonConfig: commonCfg})
 	// Now we have an AMQP publisher that can sends messages with at least once delivery.
diff --git a/publisher.go b/publisher.go
index ea5759a..72c3547 100644
--- a/publisher.go
+++ b/publisher.go
@@ -16,10 +16,10 @@ type PublisherConfig struct {
 	// NotifyReturn will receive amqp.Return's from any amqp.Channel this rmq.Publisher sends on.
 	// Recommended to use a buffered channel. Closed after the publisher's context is done.
 	NotifyReturn chan<- amqp.Return
-	// LogReturns without their amqp.Return.Body using PublisherConfig.Log.
+	// LogReturns, when true, logs amqp.Return's without their Body using CommonConfig.Log.
 	LogReturns bool
 
-	// DontConfirm will not set the amqp.Channel in Confirm mode, and disallow PublishUntilConfirmed.
+	// DontConfirm means the Publisher's amqp.Channel won't be in Confirm mode, so methods other than Publish will return an error.
 	DontConfirm bool
 }
 
@@ -183,6 +183,10 @@ func (p *Publishing) publish(mqChan *amqp.Channel) {
 	p.req.RespChan <- resp
 }
 
+func (p *Publishing) empty() bool {
+	return p.Exchange == "" && p.RoutingKey == "" && len(p.Body) == 0
+}
+
 // Publish send a Publishing on rmq.Publisher's current amqp.Channel.
 // Returns amqp.DefferedConfirmation's only if the rmq.Publisher has Confirm set.
 // If an error is returned, rmq.Publisher will grab another amqp.Channel from rmq.Connection, which itself will redial AMQP if necessary.
@@ -209,9 +213,9 @@ func (p *Publisher) Publish(ctx context.Context, pub Publishing) (*amqp.Deferred
 }
 
 // PublishUntilConfirmed calls Publish and waits for Publishing to be confirmed.
-// It republishes if a message isn't confirmed after ConfirmTimeout, or if Publish returns an error.
+// It republishes if a message isn't confirmed after confirmTimeout, or if Publish returns an error.
 // Returns *amqp.DeferredConfirmation so the caller can check if it's Acked().
-// Recommended to call with context.WithTimeout.
+// confirmTimeout defaults to 1 minute. Recommended to call with context.WithTimeout.
 func (p *Publisher) PublishUntilConfirmed(ctx context.Context, confirmTimeout time.Duration, pub Publishing) (*amqp.DeferredConfirmation, error) {
 	logPrefix := "rmq.Publisher.PublishUntilConfirmed"
 
@@ -225,15 +229,18 @@ func (p *Publisher) PublishUntilConfirmed(ctx context.Context, confirmTimeout ti
 	var pubDelay time.Duration
 	attempt := 0
 
+	errs := make([]error, 0)
+
 	for {
 		defConf, err := p.Publish(ctx, pub)
 		if err != nil {
 			pubDelay = p.config.Delay(attempt)
 			attempt++
-			p.config.Log(ctx, slog.LevelError, logPrefix+" got a Publish error. Republishing due to %v", err)
+			errs = append(errs, err)
 			select {
 			case <-ctx.Done():
-				return defConf, fmt.Errorf(logPrefix+" context done before the publish was sent %w", context.Cause(ctx))
+				err = fmt.Errorf(logPrefix+" context done before the publish was sent %w", context.Cause(ctx))
+				return defConf, errors.Join(append(errs, err)...)
 			case <-time.After(pubDelay):
 				continue
 			}
@@ -245,10 +252,11 @@ func (p *Publisher) PublishUntilConfirmed(ctx context.Context, confirmTimeout ti
 
 		select {
 		case <-confirmTimeout.C:
-			p.config.Log(ctx, slog.LevelWarn, logPrefix+" timed out waiting for confirm, republishing")
+			errs = append(errs, errors.New(logPrefix+" timed out waiting for confirm, republishing"))
 			continue
 		case <-ctx.Done():
-			return defConf, fmt.Errorf("rmq.Publisher.PublishUntilConfirmed context done before the publish was confirmed %w", context.Cause(ctx))
+			err = fmt.Errorf("rmq.Publisher.PublishUntilConfirmed context done before the publish was confirmed %w", context.Cause(ctx))
+			return defConf, errors.Join(append(errs, err)...)
 		case <-defConf.Done():
 			return defConf, nil
 		}
@@ -263,12 +271,16 @@ func (p *Publisher) PublishUntilConfirmed(ctx context.Context, confirmTimeout ti
 // RabbitMQ acks Publishing's so monitor the NotifyReturn chan to ensure your Publishing's are being delivered.
 //
 // PublishUntilAcked is intended for ensuring a Publishing with a known destination queue will get acked despite flaky connections or temporary RabbitMQ node failures.
+// Recommended to call with context.WithTimeout.
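+//
+// A minimal sketch of the intended call pattern (pub is any rmq.Publishing, and a confirmTimeout of 0 falls back to PublishUntilConfirmed's one minute default):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
+//	defer cancel()
+//	if err := publisher.PublishUntilAcked(ctx, 0, pub); err != nil {
+//		// ctx finished before the Publishing was acked
+//	}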
 func (p *Publisher) PublishUntilAcked(ctx context.Context, confirmTimeout time.Duration, pub Publishing) error {
 	logPrefix := "rmq.Publisher.PublishUntilAcked"
 	nacks := 0
 	for {
 		defConf, err := p.PublishUntilConfirmed(ctx, confirmTimeout, pub)
 		if err != nil {
+			if nacks > 0 {
+				return fmt.Errorf(logPrefix+" resent nacked Publishings %d time(s) and %w", nacks, err)
+			}
 			return err
 		}
 
@@ -277,7 +289,126 @@ func (p *Publisher) PublishUntilAcked(ctx context.Context, confirmTimeout time.D
 		}
 
 		nacks++
-		p.config.Log(ctx, slog.LevelWarn, logPrefix+" resending Publishing that has been nacked %d time(s)...", nacks)
-		// There isn't a delay here since PublishUntilConfirmed waiting for the confirm should effectively slow us down to what can be handled by the AMQP server.
 	}
 }
+
+// PublishBatchUntilAcked Publishes all of your Publishings at once, and then waits for each DeferredConfirmation to be Acked,
+// resending any that haven't been confirmed within confirmTimeout or that have been nacked.
+// confirmTimeout defaults to 1 minute. Recommended to call with context.WithTimeout.
+func (p *Publisher) PublishBatchUntilAcked(ctx context.Context, confirmTimeout time.Duration, pubs ...Publishing) error {
+	logPrefix := "rmq.Publisher.PublishBatchUntilAcked"
+
+	if len(pubs) == 0 {
+		return nil
+	}
+	if p.config.DontConfirm {
+		return fmt.Errorf(logPrefix + " called on a rmq.Publisher that's not in Confirm mode")
+	}
+
+	if confirmTimeout == 0 {
+		confirmTimeout = time.Minute
+	}
+
+	errs := make([]error, 0)
+	pendingPubs := make([]*amqp.DeferredConfirmation, len(pubs))
+	ackedPubs := make([]bool, len(pubs))
+
+	remainingPubs := func() int {
+		unacks := 0
+		for _, acked := range ackedPubs {
+			if !acked {
+				unacks++
+			}
+		}
+		return unacks
+	}
+
+	for {
+		select {
+		case <-ctx.Done():
+			err := fmt.Errorf(logPrefix+" timed out because %w", context.Cause(ctx))
+			return errors.Join(append(errs, err)...)
+		default:
+		}
+
+		err := p.publishBatch(ctx, confirmTimeout, remainingPubs(), pubs, pendingPubs, ackedPubs, errs)
+		if err == nil {
+			return nil
+		}
+		// publishBatch folds the errs we passed in into its returned error, so replace rather than append.
+		errs = []error{err}
+		clear(pendingPubs)
+	}
+}
+
+// publishBatch sends each not-yet-acked Publishing once, retrying individual Publish errors, and then waits for the resulting confirms.
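+// It works in two phases: a publish loop that fills pendingPubs, then a polling loop over those DeferredConfirmations.
+// Nacked Publishings simply stay unacked in ackedPubs, so the retry loop in PublishBatchUntilAcked resends them.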
+// It blocks until all of the confirms arrive, the context ends, or confirmTimeout fires.
+func (p *Publisher) publishBatch(
+	ctx context.Context,
+	confirmTimeout time.Duration,
+	remaining int,
+	pubs []Publishing,
+	pendingPubs []*amqp.DeferredConfirmation,
+	ackedPubs []bool,
+	errs []error,
+) (err error) {
+	logPrefix := "rmq.Publisher.publishBatch"
+	published := 0
+	attempt := 0
+	var delay time.Duration
+	for published != remaining {
+		for i, pub := range pubs {
+			// Skip if it's been successfully published or acked
+			if pendingPubs[i] != nil || ackedPubs[i] {
+				continue
+			}
+
+			select {
+			case <-ctx.Done():
+				return errors.Join(append(errs, fmt.Errorf(logPrefix+" timed out because %w", context.Cause(ctx)))...)
+			default:
+			}
+
+			pendingPubs[i], err = p.Publish(ctx, pub)
+			if err != nil {
+				errs = append(errs, err)
+				delay = p.config.Delay(attempt)
+				attempt++
+				select {
+				case <-ctx.Done():
+					return errors.Join(append(errs, fmt.Errorf(logPrefix+" timed out because %w", context.Cause(ctx)))...)
+				case <-time.After(delay):
+				}
+			} else {
+				published++
+				attempt = 0
+			}
+		}
+	}
+
+	confirmTimer := time.After(confirmTimeout)
+	confirmed := 0
+	for confirmed != remaining {
+		for i, pub := range pendingPubs {
+			// Skip if it's already been confirmed
+			if pendingPubs[i] == nil {
+				continue
+			}
+
+			select {
+			case <-ctx.Done():
+				return errors.Join(append(errs, fmt.Errorf(logPrefix+" timed out because %w", context.Cause(ctx)))...)
+			case <-confirmTimer:
+				return errors.Join(append(errs, errors.New(logPrefix+" timed out waiting on confirms"))...)
+			case <-pub.Done():
+				if pub.Acked() {
+					ackedPubs[i] = true
+				}
+				confirmed++
+				pendingPubs[i] = nil
+			default:
+			}
+		}
+	}
+
+	return nil
+}