diff --git a/internal/command_types.go b/internal/command_types.go index d3116dc6..c34cb24b 100644 --- a/internal/command_types.go +++ b/internal/command_types.go @@ -84,6 +84,7 @@ const ( CommandDeletePublisher uint16 = 0x0006 // 6 CommandSubscribe uint16 = 0x0007 // 7 CommandDeliver uint16 = 0x0008 // 8 + CommandCredit uint16 = 0x0009 // 9 CommandCreate uint16 = 0x000d // 13 CommandDelete uint16 = 0x000e // 14 CommandPeerProperties uint16 = 0x0011 // 17 @@ -99,6 +100,7 @@ const ( CommandDeclarePublisherResponse uint16 = 0x8001 CommandDeletePublisherResponse uint16 = 0x8006 CommandSubscribeResponse uint16 = 0x8007 + CommandCreditResponse uint16 = 0x8009 CommandCreateResponse uint16 = 0x800d CommandDeleteResponse uint16 = 0x800e CommandPeerPropertiesResponse uint16 = 0x8011 diff --git a/internal/credit.go b/internal/credit.go new file mode 100644 index 00000000..27939bc4 --- /dev/null +++ b/internal/credit.go @@ -0,0 +1,80 @@ +package internal + +import ( + "bufio" + "bytes" +) + +type CreditRequest struct { + subscriptionId uint8 + // number of chunks that can be sent + credit uint16 +} + +func NewCreditRequest(subscriptionId uint8, credit uint16) *CreditRequest { + return &CreditRequest{subscriptionId: subscriptionId, credit: credit} +} + +func (c *CreditRequest) UnmarshalBinary(data []byte) error { + return readMany(bytes.NewReader(data), &c.subscriptionId, &c.credit) +} + +func (c *CreditRequest) SubscriptionId() uint8 { + return c.subscriptionId +} + +func (c *CreditRequest) Credit() uint16 { + return c.credit +} + +func (c *CreditRequest) Write(w *bufio.Writer) (int, error) { + return writeMany(w, c.subscriptionId, c.credit) +} + +func (c *CreditRequest) Key() uint16 { + return CommandCredit +} + +func (c *CreditRequest) SizeNeeded() int { + return streamProtocolHeaderSizeBytes + streamProtocolKeySizeUint8 + streamProtocolKeySizeUint16 +} + +func (c *CreditRequest) Version() int16 { + return Version1 +} + +type CreditResponse struct { + responseCode uint16 + subscriptionId uint8 +} + +func (c *CreditResponse) ResponseCode() uint16 { + return c.responseCode +} + +func (c *CreditResponse) SubscriptionId() uint8 { + return c.subscriptionId +} + +func (c *CreditResponse) MarshalBinary() (data []byte, err error) { + w := &bytes.Buffer{} + n, err := writeMany(w, c.responseCode, c.subscriptionId) + if err != nil { + return nil, err + } + + if n != 3 { + return nil, errWriteShort + } + + data = w.Bytes() + return +} + +func NewCreditResponse(responseCode uint16, subscriptionId uint8) *CreditResponse { + return &CreditResponse{responseCode: responseCode, subscriptionId: subscriptionId} +} + +func (c *CreditResponse) Read(r *bufio.Reader) error { + return readMany(r, &c.responseCode, &c.subscriptionId) +} diff --git a/internal/credit_test.go b/internal/credit_test.go new file mode 100644 index 00000000..92b9853a --- /dev/null +++ b/internal/credit_test.go @@ -0,0 +1,54 @@ +package internal + +import ( + "bufio" + "bytes" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Internal/Credit", func() { + Context("Request", func() { + It("knows the size needed to encode itself", func() { + c := &CreditRequest{ + subscriptionId: 123, + credit: 987, + } + Expect(c.SizeNeeded()).To(BeNumerically("==", streamProtocolHeaderSizeBytes+1+2)) + }) + + It("encodes itself into a binary sequence", func() { + c := &CreditRequest{ + subscriptionId: 5, + credit: 255, + } + buff := &bytes.Buffer{} + wr := bufio.NewWriter(buff) + Expect(c.Write(wr)).To(BeNumerically("==", 3)) + Expect(wr.Flush()).To(Succeed()) + + expectedByteSequence := []byte{ + 0x05, // subscription ID + 0x00, 0xff, // credit + } + Expect(buff.Bytes()).To(Equal(expectedByteSequence)) + }) + }) + + Context("Response", func() { + It("decodes itself into a response struct", func() { + byteSequence := []byte{ + 0x00, 0x0f, // response code + 0x10, // sub ID + } + + buff := bytes.NewBuffer(byteSequence) + r := bufio.NewReader(buff) + c := &CreditResponse{} + + Expect(c.Read(r)).Should(Succeed()) + Expect(c.responseCode).To(BeNumerically("==", 15)) + Expect(c.subscriptionId).To(BeNumerically("==", 16)) + }) + }) +}) diff --git a/main.go b/main.go index a6d342e0..f8e66d76 100644 --- a/main.go +++ b/main.go @@ -95,12 +95,19 @@ func main() { go func() { for c := range chunkChan { received += int(c.NumEntries) + err := streamClient.Credit(ctx, 1, 1) + if err != nil { + log.Error(err, "error sending credits") + } if (received % totalMessages) == 0 { log.Info("Received", "messages ", received) } } }() + // this should log a warning message: subscription does not exist + _ = streamClient.Credit(ctx, 123, 1) + err = streamClient.Subscribe(ctx, stream, constants.OffsetTypeFirst, 1, 10, map[string]string{"name": "my_consumer"}, 10) fmt.Println("Press any key to stop ") reader := bufio.NewReader(os.Stdin) diff --git a/pkg/raw/client.go b/pkg/raw/client.go index e5688839..85f77b8e 100644 --- a/pkg/raw/client.go +++ b/pkg/raw/client.go @@ -52,6 +52,7 @@ type Client struct { connectionProperties map[string]string confirmsCh chan *PublishConfirm chunkCh chan *Chunk + notifyCh chan *CreditError } // IsOpen returns true if the connection is open, false otherwise @@ -341,6 +342,32 @@ func (tc *Client) handleIncoming(ctx context.Context) error { log.Error(err, "error ") } tc.handleResponse(ctx, exchangeResponse) + case internal.CommandCreditResponse: + creditResp := new(CreditError) + err = creditResp.Read(buffer) + log.Error( + errUnknownSubscription, + "received credit response for unknown subscription", + "responseCode", + creditResp.ResponseCode(), + "subscriptionId", + creditResp.SubscriptionId(), + ) + if err != nil { + log.Error(err, "error in credit response") + return err + } + + tc.mu.Lock() + if tc.notifyCh != nil { + select { + case <-ctx.Done(): + tc.mu.Unlock() + return ctx.Err() + case tc.notifyCh <- creditResp: + } + } + tc.mu.Unlock() default: log.Info("frame not implemented", "command ID", fmt.Sprintf("%X", header.Command())) _, err := buffer.Discard(header.Length() - 4) @@ -945,6 +972,17 @@ func (tc *Client) ExchangeCommandVersions(ctx context.Context) error { return streamErrorOrNil(response.ResponseCode()) } +// Credit TODO: go docs +func (tc *Client) Credit(ctx context.Context, subscriptionID uint8, credits uint16) error { + if ctx == nil { + return errNilContext + } + logger := logr.FromContextOrDiscard(ctx).WithName("Credit") + logger.V(debugLevel).Info("starting credit") + + return tc.request(ctx, internal.NewCreditRequest(subscriptionID, credits)) +} + // NotifyPublish TODO: godocs func (tc *Client) NotifyPublish(c chan *PublishConfirm) <-chan *PublishConfirm { tc.mu.Lock() @@ -963,3 +1001,11 @@ func (tc *Client) NotifyChunk(c chan *Chunk) <-chan *Chunk { tc.chunkCh = c return c } + +// NotifyCreditError TODO: go docs +func (tc *Client) NotifyCreditError(notification chan *CreditError) <-chan *CreditError { + tc.mu.Lock() + defer tc.mu.Unlock() + tc.notifyCh = notification + return notification +} diff --git a/pkg/raw/client_test.go b/pkg/raw/client_test.go index c3d68321..9c9ab92f 100644 --- a/pkg/raw/client_test.go +++ b/pkg/raw/client_test.go @@ -150,6 +150,36 @@ var _ = Describe("Client", func() { Expect(streamClient.DeclareStream(itCtx, "test-stream", constants.StreamConfiguration{"some-key": "some-value"})).To(Succeed()) }) + Context("credits", func() { + It("sends credits to the server", func(ctx SpecContext) { + Expect(fakeClientConn.SetDeadline(time.Now().Add(time.Second))).To(Succeed()) + streamClient := raw.NewClient(fakeClientConn, conf) + + go fakeRabbitMQ.fakeRabbitMQCredit(2, 100) + + Expect(streamClient.Credit(ctx, 2, 100)).To(Succeed()) + }) + + When("sending credits for non-existing subscription", func() { + It("returns an error", func(ctx SpecContext) { + Expect(fakeClientConn.SetDeadline(time.Now().Add(time.Second))).To(Succeed()) + streamClient := raw.NewClient(fakeClientConn, conf) + go streamClient.(*raw.Client).StartFrameListener(ctx) + + go fakeRabbitMQ.fakeRabbitMQCreditResponse( + newContextWithResponseCode(ctx, streamResponseCodeSubscriptionIdDoesNotExist, "credit"), + 123, + ) + var notification *raw.CreditError + notificationCh := streamClient.NotifyCreditError(make(chan *raw.CreditError)) + Eventually(notificationCh).Should(Receive(¬ification)) + Expect(notification).To(BeAssignableToTypeOf(&raw.CreditError{})) + Expect(notification.ResponseCode()).To(BeNumerically("==", streamResponseCodeSubscriptionIdDoesNotExist)) + Expect(notification.SubscriptionId()).To(BeNumerically("==", 123)) + }) + }) + }) + It("Delete a stream", func(ctx SpecContext) { itCtx, cancel := context.WithTimeout(logr.NewContext(ctx, GinkgoLogr), time.Second*3) defer cancel() diff --git a/pkg/raw/client_types.go b/pkg/raw/client_types.go index e7db3d36..777a35e0 100644 --- a/pkg/raw/client_types.go +++ b/pkg/raw/client_types.go @@ -12,10 +12,11 @@ import ( ) var ( - errURIScheme = errors.New("RabbitMQ Stream scheme must be either 'rabbitmq-stream://' or 'rabbitmq-stream+tls://'") - errURIWhitespace = errors.New("URI must not contain whitespace") - errNilContext = errors.New("context cannot be nil") - errNilConfig = errors.New("RabbitmqConfiguration cannot be nil") + errURIScheme = errors.New("RabbitMQ Stream scheme must be either 'rabbitmq-stream://' or 'rabbitmq-stream+tls://'") + errURIWhitespace = errors.New("URI must not contain whitespace") + errNilContext = errors.New("context cannot be nil") + errNilConfig = errors.New("RabbitmqConfiguration cannot be nil") + errUnknownSubscription = errors.New("unknown subscription ID") ) var schemePorts = map[string]int{"rabbitmq-stream": 5552, "rabbitmq-stream+tls": 5551} @@ -106,6 +107,7 @@ func NewClientConfiguration(rabbitmqUrls ...string) (*ClientConfiguration, error type PublishConfirm = internal.PublishConfirmResponse type Chunk = internal.ChunkResponse +type CreditError = internal.CreditResponse type Clienter interface { Connect(ctx context.Context) error @@ -120,4 +122,6 @@ type Clienter interface { NotifyPublish(chan *PublishConfirm) <-chan *PublishConfirm NotifyChunk(c chan *Chunk) <-chan *Chunk ExchangeCommandVersions(ctx context.Context) error + Credit(ctx context.Context, subscriptionId uint8, credit uint16) error + NotifyCreditError(notification chan *CreditError) <-chan *CreditError } diff --git a/pkg/raw/stream_suite_test.go b/pkg/raw/stream_suite_test.go index 75e38538..57ec2dcf 100644 --- a/pkg/raw/stream_suite_test.go +++ b/pkg/raw/stream_suite_test.go @@ -512,6 +512,49 @@ func (rmq *fakeRabbitMQServer) fakeRabbitMQPublisherConfirms(pubId uint8, numOfC expectOffset1(err).ToNot(HaveOccurred()) } +func (rmq *fakeRabbitMQServer) fakeRabbitMQCredit(subscriptionId uint8, credits uint16) { + defer GinkgoRecover() + expectOffset1(rmq.connection.SetDeadline(time.Now().Add(time.Second))). + To(Succeed()) + + serverReader := bufio.NewReader(rmq.connection) + + header := new(internal.Header) + expectOffset1(header.Read(serverReader)).To(Succeed()) + expectOffset1(header.Command()).To(BeNumerically("==", 0x0009)) + expectOffset1(header.Version()).To(BeNumerically("==", 1)) + + buff := make([]byte, header.Length()-4) + expectOffset1(io.ReadFull(serverReader, buff)). + To(BeNumerically("==", header.Length()-4)) + + body := new(internal.CreditRequest) + expectOffset1(body.UnmarshalBinary(buff)).To(Succeed()) + expectOffset1(body.SubscriptionId()).To(Equal(subscriptionId)) + expectOffset1(body.Credit()).To(Equal(credits)) +} + +func (rmq *fakeRabbitMQServer) fakeRabbitMQCreditResponse(ctx context.Context, subscriptionId uint8) { + defer GinkgoRecover() + expectOffset1(rmq.connection.SetDeadline(time.Now().Add(time.Second))). + To(Succeed()) + + serverWriter := bufio.NewWriter(rmq.connection) + + header := internal.NewHeader(4, 0x8009, 1) + expectOffset1(header.Write(serverWriter)).To(BeNumerically("==", 8)) + + body := internal.NewCreditResponse(responseCodeFromContext(ctx, "credit"), subscriptionId) + creditResponse, err := body.MarshalBinary() + expectOffset1(err).NotTo(HaveOccurred()) + + n, err := serverWriter.Write(creditResponse) + expectOffset1(err).NotTo(HaveOccurred()) + expectOffset1(n).To(BeNumerically("==", 3)) + + expectOffset1(serverWriter.Flush()).To(Succeed()) +} + func newContextWithResponseCode(ctx context.Context, respCode uint16, suffix ...string) context.Context { var key = "rabbitmq-stream.response-code" if suffix != nil {