Skip to content

Commit

Permalink
[sqs] Get messages with attributes round-tripping through Send/Receive
Browse files Browse the repository at this point in the history
  • Loading branch information
dzbarsky committed Feb 19, 2024
1 parent 26af496 commit 0b6e2b0
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 37 deletions.
90 changes: 80 additions & 10 deletions services/sqs/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"reflect"
"strconv"

"github.com/gofrs/uuid/v5"

Expand Down Expand Up @@ -68,33 +69,102 @@ func unmarshal(r *http.Request, target any) error {

for i := 0; i < ty.NumField(); i++ {
field := ty.Field(i)
fieldSingular := field.Name[:len(field.Name)-1]

f := v.Field(i)

switch field.Type.Kind() {
switch k := field.Type.Kind(); k {
case reflect.Int:
v := r.FormValue(field.Name)
if v == "" {
continue
}
i, err := strconv.Atoi(r.FormValue(field.Name))
if err != nil {
return err
}
f.Set(reflect.ValueOf(i))
case reflect.String:
f.Set(reflect.ValueOf(r.FormValue(field.Name)))
case reflect.Slice:
for i := 1; ; i++ {
v := r.FormValue(fmt.Sprintf("%s.%d", fieldSingular, i))
if v == "" {
break
}
f.Set(reflect.Append(f, reflect.ValueOf(v)))
}
case reflect.Map:
// Initialize the map and then read as many elements as we can.
f.Set(reflect.MakeMap(f.Type()))

EntriesLoop:
for i := 1; ; i++ {
mapKey := r.FormValue(fmt.Sprintf("%s.%d.Key", field.Name, i))
mapValue := r.FormValue(fmt.Sprintf("%s.%d.Value", field.Name, i))
if mapKey == "" && mapValue == "" {
break
// TODO(zbarsky): this is pretty HAX way to control the deserialization
switch field.Name {
case "Attribute", "Tag":
mapKey := r.FormValue(fmt.Sprintf("%s.%d.Key", fieldSingular, i))
mapValue := r.FormValue(fmt.Sprintf("%s.%d.Value", fieldSingular, i))
if mapKey == "" && mapValue == "" {
break EntriesLoop
}
if mapKey != "" && mapValue != "" {
f.SetMapIndex(reflect.ValueOf(mapKey), reflect.ValueOf(mapValue))
continue EntriesLoop
}
return errors.New("mismatched key/value?")
case "MessageAttributes", "MessageSystemAttributes":
mapKey := r.FormValue(fmt.Sprintf("%s.%d.Name", fieldSingular, i))
mapValue := extractAPIAttribute(
fmt.Sprintf("%s.%d.Value", fieldSingular, i),
r.FormValue)
if mapKey == "" && mapValue.DataType == "" {
break EntriesLoop
}
if mapKey != "" && mapValue.DataType != "" {
f.SetMapIndex(reflect.ValueOf(mapKey), reflect.ValueOf(mapValue))
continue EntriesLoop
}
return errors.New("mismatched key/value?")
default:
panic("Unknown field: " + field.Name)
}
if mapKey != "" && mapValue != "" {
f.SetMapIndex(reflect.ValueOf(mapKey), reflect.ValueOf(mapValue))
continue
}
return errors.New("mismatched key/value?")
}
default:
panic(field)
}

}

return nil
}

func extractAPIAttribute(prefix string, get func(string) string) APIAttribute {
attr := APIAttribute{
BinaryValue: []byte(get(prefix + ".BinaryValue")),
StringValue: get(prefix + ".StringValue"),
DataType: get(prefix + ".DataType"),
}

for i := 1; ; i++ {
v := get(fmt.Sprintf("%s.BinaryListValue.%d", prefix, i))
if v == "" {
break
}
attr.BinaryListValues = append(attr.BinaryListValues, []byte(v))
}

for i := 1; ; i++ {
v := get(fmt.Sprintf("%s.StringListValue.%d", prefix, i))
if v == "" {
break
}
attr.StringListValues = append(attr.StringListValues, v)
}

return attr
}

type xmlResp[T any] struct {
T *T
ResponseMetadata ResponseMetadata
Expand Down
67 changes: 61 additions & 6 deletions services/sqs/itest/sqs_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package itest

import (
"bytes"
"context"
"fmt"
"log/slog"
"net"
"net/http"
"slices"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"

"aws-in-a-box/server"
sqsImpl "aws-in-a-box/services/sqs"
Expand Down Expand Up @@ -38,7 +40,7 @@ func makeClientServerPair() (*sqs.Client, *http.Server) {
return client, srv
}

func TestQueue(t *testing.T) {
func TestSendReceiveMessage_RoundtripAttributes(t *testing.T) {
ctx := context.Background()
client, srv := makeClientServerPair()
defer srv.Shutdown(ctx)
Expand All @@ -54,12 +56,65 @@ func TestQueue(t *testing.T) {
t.Fatal(err)
}

msg, err := client.SendMessage(ctx, &sqs.SendMessageInput{
QueueUrl: resp.QueueUrl,
MessageBody: aws.String("READ THIS AND WEEP"),
messageAttributes := map[string]types.MessageAttributeValue{
"string": {
DataType: aws.String("String"),
StringValue: aws.String("s"),
},
"stringList": {
DataType: aws.String("String"),
StringListValues: []string{"s1", "s2"},
},
"binary": {
DataType: aws.String("Binary"),
BinaryValue: []byte("b"),
},
"binaryList": {
DataType: aws.String("Binary"),
BinaryListValues: [][]byte{[]byte("b1"), []byte("b2")},
},
}

body := "just a body, nothing to see here"
_, err = client.SendMessage(ctx, &sqs.SendMessageInput{
QueueUrl: resp.QueueUrl,
MessageBody: aws.String(body),
MessageAttributes: messageAttributes,
})
if err != nil {
t.Fatal(err)
}
fmt.Println("msg ", msg)

receiveResp, err := client.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
QueueUrl: resp.QueueUrl,
MessageAttributeNames: []string{".*"},
})
if err != nil {
t.Fatal(err)
}
if len(receiveResp.Messages) != 1 {
t.Fatalf("Did not receive right number of messages: %d", len(receiveResp.Messages))
}
msg := receiveResp.Messages[0]
if *msg.Body != body {
t.Fatal("Didn't get back the right message")
}
if *messageAttributes["string"].StringValue != *msg.MessageAttributes["string"].StringValue {
t.Fatal("string attribute did not roundtrip")
}
if !slices.Equal(messageAttributes["binary"].BinaryValue, msg.MessageAttributes["binary"].BinaryValue) {
t.Fatal("binary attribute did not roundtrip")
}
if !slices.Equal(messageAttributes["stringList"].StringListValues, msg.MessageAttributes["stringList"].StringListValues) {
t.Fatalf("stringList attribute did not roundtrip, got %v, want %v",
msg.MessageAttributes["stringList"].StringListValues,
messageAttributes["stringList"].StringListValues,
)
}
if !slices.EqualFunc(messageAttributes["binaryList"].BinaryListValues, msg.MessageAttributes["binaryList"].BinaryListValues, bytes.Equal) {
t.Fatalf("binaryList attribute did not roundtrip, got %v, want %v",
msg.MessageAttributes["binaryList"].BinaryListValues,
messageAttributes["binaryList"].BinaryListValues,
)
}
}
61 changes: 40 additions & 21 deletions services/sqs/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ const AWSTraceHeaderAttributeName = "AWSTraceHeader"

type SendMessageInput struct {
DelaySeconds int
MessageAttributes map[string]APIAttribute
MessageAttributes APIMessageAttributes
MessageBody string
MessageDeduplicationId string
MessageGroupId string
MessageSystemAttributes map[string]APIAttribute
MessageSystemAttributes APIMessageAttributes
QueueUrl string
}

Expand All @@ -42,10 +42,10 @@ type SendMessageOutput struct {
}

type APIAttribute struct {
BinaryListValues [][]byte
BinaryListValues [][]byte `xml:"BinaryListValue"`
BinaryValue []byte
DataType string
StringListValues []string
StringListValues []string `xml:"StringListValue"`
StringValue string
}

Expand Down Expand Up @@ -99,30 +99,30 @@ type ListQueueTagsOutput struct {
}

// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-AttributeNames
type AttributeName string
type SystemAttributeName string

const (
All = AttributeName("All")
ApproximateFirstReceiveTimestamp = AttributeName("ApproximateFirstReceiveTimestamp")
ApproximateReceiveCount = AttributeName("ApproximateReceiveCount")
AWSTraceHeader = AttributeName("AWSTraceHeader")
SenderId = AttributeName("SenderId")
SentTimestamp = AttributeName("SentTimestamp")
// TODO: this one not listedin https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-MessageSystemAttributeNames
SqsManagedSseEnabled = AttributeName("SqsManagedSseEnabled")
MessageDeduplicationId = AttributeName("MessageDeduplicationId")
MessageGroupId = AttributeName("MessageGroupId")
SequenceNumber = AttributeName("SequenceNumber")
DeadLetterQueueSourceArn = AttributeName("DeadLetterQueueSourceArn")
All = SystemAttributeName("All")
ApproximateFirstReceiveTimestamp = SystemAttributeName("ApproximateFirstReceiveTimestamp")
ApproximateReceiveCount = SystemAttributeName("ApproximateReceiveCount")
AWSTraceHeader = SystemAttributeName("AWSTraceHeader")
SenderId = SystemAttributeName("SenderId")
SentTimestamp = SystemAttributeName("SentTimestamp")
// TODO: this one not listed in https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html#SQS-ReceiveMessage-request-MessageSystemAttributeNames
SqsManagedSseEnabled = SystemAttributeName("SqsManagedSseEnabled")
MessageDeduplicationId = SystemAttributeName("MessageDeduplicationId")
MessageGroupId = SystemAttributeName("MessageGroupId")
SequenceNumber = SystemAttributeName("SequenceNumber")
DeadLetterQueueSourceArn = SystemAttributeName("DeadLetterQueueSourceArn")
// TODO: there are more
)

type ReceiveMessageInput struct {
// Deprecated
AttributeNames []AttributeName
AttributeNames []SystemAttributeName
MaxNumberOfMessages int
MessageAttributeNames []string
MessageSystemAttributeNames []AttributeName
MessageSystemAttributeNames []SystemAttributeName
QueueUrl string
// ReceiveRequestAttemptId
VisibilityTimeout int
Expand All @@ -134,12 +134,31 @@ type ReceiveMessageOutput struct {
Message []APIMessage
}

type APIAttributes map[string]string
type APIMessageAttributes map[string]APIAttribute

func (a APIMessageAttributes) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
type XMLAttribute struct {
Name string
Value APIAttribute
}
attrs := make([]XMLAttribute, 0, len(a))
for k, v := range a {
attrs = append(attrs, XMLAttribute{
Name: k,
Value: v,
})
}

return e.EncodeElement(attrs, start)
}

type APIMessage struct {
Attributes map[string]string
//Attributes APIAttributes
Body string
MD5OfBody string
MD5OfMessageAttributes string
MessageAttributes map[string]APIAttribute
MessageAttributes APIMessageAttributes `xml:"MessageAttribute"`
MessageId string
ReceiptHandle string
}
Expand Down

0 comments on commit 0b6e2b0

Please sign in to comment.