diff --git a/app/models/requests.go b/app/models/requests.go index 81ecadc9..fbe4c224 100644 --- a/app/models/requests.go +++ b/app/models/requests.go @@ -1,6 +1,7 @@ package models import ( + "encoding/base64" "encoding/json" "fmt" "net/url" @@ -13,8 +14,6 @@ import ( log "github.com/sirupsen/logrus" ) -var caser = cases.Title(language.AmericanEnglish) - type CreateQueueRequest struct { QueueName string `json:"QueueName" schema:"QueueName"` Attributes QueueAttributes `json:"Attributes" schema:"Attribute"` @@ -205,12 +204,18 @@ func (r *SendMessageRequest) SetAttributesFromForm(values url.Values) { } stringValue := values.Get(fmt.Sprintf("MessageAttribute.%d.Value.StringValue", i)) - binaryValue := values.Get(fmt.Sprintf("MessageAttribute.%d.Value.BinaryValue", i)) + encodedBinaryValue := values.Get(fmt.Sprintf("MessageAttribute.%d.Value.BinaryValue", i)) + + binaryValue, err := base64.StdEncoding.DecodeString(encodedBinaryValue) + if err != nil { + log.Warnf("Failed to base64 decode. %s may not have been base64 encoded.", encodedBinaryValue) + continue + } r.MessageAttributes[name] = MessageAttribute{ DataType: dataType, StringValue: stringValue, - BinaryValue: []byte(binaryValue), + BinaryValue: binaryValue, } } } @@ -257,16 +262,22 @@ func (r *SendMessageBatchRequest) SetAttributesFromForm(values url.Values) { } stringValue := values.Get(fmt.Sprintf("Entries.%d.MessageAttributes.%d.Value.StringValue", entryIndex, attributeIndex)) - binaryValue := values.Get(fmt.Sprintf("Entries.%d.MessageAttributes.%d.Value.BinaryValue", entryIndex, attributeIndex)) + encodedBinaryValue := values.Get(fmt.Sprintf("Entries.%d.MessageAttributes.%d.Value.BinaryValue", entryIndex, attributeIndex)) if r.Entries[entryIndex].MessageAttributes == nil { r.Entries[entryIndex].MessageAttributes = make(map[string]MessageAttribute) } + binaryValue, err := base64.StdEncoding.DecodeString(encodedBinaryValue) + if err != nil { + log.Warnf("Failed to base64 decode. %s may not have been base64 encoded.", encodedBinaryValue) + continue + } + r.Entries[entryIndex].MessageAttributes[name] = MessageAttribute{ DataType: dataType, StringValue: stringValue, - BinaryValue: []byte(binaryValue), + BinaryValue: binaryValue, } if _, ok := r.Entries[entryIndex].MessageAttributes[name]; !ok { @@ -741,15 +752,21 @@ func (r *PublishRequest) SetAttributesFromForm(values url.Values) { } stringValue := values.Get(fmt.Sprintf("MessageAttributes.entry.%d.Value.StringValue", i)) - binaryValue := values.Get(fmt.Sprintf("MessageAttributes.entry.%d.Value.BinaryValue", i)) + encodedBinaryValue := values.Get(fmt.Sprintf("MessageAttributes.entry.%d.Value.BinaryValue", i)) + binaryValue, err := base64.StdEncoding.DecodeString(encodedBinaryValue) + if err != nil { + log.Warnf("Failed to base64 decode. %s may not have been base64 encoded.", encodedBinaryValue) + continue + } if r.MessageAttributes == nil { r.MessageAttributes = make(map[string]MessageAttribute) } + attributes[name] = MessageAttribute{ - DataType: caser.String(dataType), // capitalize + DataType: cases.Title(language.AmericanEnglish).String(dataType), // capitalize StringValue: stringValue, - BinaryValue: []byte(binaryValue), + BinaryValue: binaryValue, } } r.MessageAttributes = attributes diff --git a/app/models/requests_test.go b/app/models/requests_test.go index 9c075802..cd755cdb 100644 --- a/app/models/requests_test.go +++ b/app/models/requests_test.go @@ -314,7 +314,7 @@ func TestSendMessageRequest_SetAttributesFromForm_success(t *testing.T) { attr2 := r.MessageAttributes["Attr2"] assert.Equal(t, "Binary", attr2.DataType) assert.Empty(t, attr2.StringValue) - assert.Equal(t, []uint8("VmFsdWUy"), attr2.BinaryValue) + assert.Equal(t, []uint8("Value2"), attr2.BinaryValue) } func TestSetQueueAttributesRequest_SetAttributesFromForm_success(t *testing.T) { diff --git a/app/models/responses.go b/app/models/responses.go index fce4d1a2..02e38f12 100644 --- a/app/models/responses.go +++ b/app/models/responses.go @@ -1,6 +1,7 @@ package models import ( + "encoding/base64" "encoding/xml" ) @@ -83,6 +84,9 @@ func (r *ResultMessage) MarshalXML(e *xml.Encoder, start xml.StartElement) error } var messageAttrs []MessageAttributes for key, value := range r.MessageAttributes { + if value.DataType == "Binary" { + value.BinaryValue = []byte(base64.StdEncoding.EncodeToString(value.BinaryValue)) + } attribute := MessageAttributes{ Name: key, Value: value, @@ -95,6 +99,9 @@ func (r *ResultMessage) MarshalXML(e *xml.Encoder, start xml.StartElement) error e.EncodeElement(r.MessageId, xml.StartElement{Name: xml.Name{Local: "MessageId"}}) e.EncodeElement(r.ReceiptHandle, xml.StartElement{Name: xml.Name{Local: "ReceiptHandle"}}) e.EncodeElement(r.MD5OfBody, xml.StartElement{Name: xml.Name{Local: "MD5OfBody"}}) + if r.MessageAttributes != nil { + e.EncodeElement(r.MD5OfMessageAttributes, xml.StartElement{Name: xml.Name{Local: "MD5OfMessageAttributes"}}) + } e.EncodeElement(r.Body, xml.StartElement{Name: xml.Name{Local: "Body"}}) e.EncodeElement(attrs, xml.StartElement{Name: xml.Name{Local: "Attribute"}}) e.EncodeElement(messageAttrs, xml.StartElement{Name: xml.Name{Local: "MessageAttribute"}}) diff --git a/app/models/responses_test.go b/app/models/responses_test.go index 55a4d409..7b4737c2 100644 --- a/app/models/responses_test.go +++ b/app/models/responses_test.go @@ -63,7 +63,7 @@ func Test_ResultMessage_MarshalXML_success_with_attributes(t *testing.T) { resultString := string(result) // We have to assert piecemeal like this, the maps go into their lists unordered, which will randomly break this. - entry := "message-idreceipt-handlebody-md5message-body" + entry := "message-idreceipt-handlebody-md5message-attrs-md5message-body" assert.Contains(t, resultString, entry) entry = "ApproximateFirstReceiveTimestamp1" @@ -81,7 +81,7 @@ func Test_ResultMessage_MarshalXML_success_with_attributes(t *testing.T) { entry = "attr1Stringstring-value" assert.Contains(t, resultString, entry) - entry = "attr2binary-valueBinary" + entry = "attr2YmluYXJ5LXZhbHVlBinary" assert.Contains(t, resultString, entry) entry = "attr3Numbernumber-value" diff --git a/smoke_tests/sns_publish_test.go b/smoke_tests/sns_publish_test.go index aaafceb4..0a9a19bd 100644 --- a/smoke_tests/sns_publish_test.go +++ b/smoke_tests/sns_publish_test.go @@ -1,6 +1,7 @@ package smoke_tests import ( + "bytes" "context" "encoding/json" "io" @@ -64,7 +65,92 @@ func Test_Publish_sqs_json_raw(t *testing.T) { assert.Equal(t, message, *receivedMessage.Messages[0].Body) } -func Test_Publish_Sqs_With_Message_Attributes(t *testing.T) { +func Test_Publish_sqs_json_with_message_attributes_raw(t *testing.T) { + server := generateServer() + defer func() { + server.Close() + models.ResetResources() + }() + + sdkConfig, _ := config.LoadDefaultConfig(context.TODO()) + sdkConfig.BaseEndpoint = aws.String(server.URL) + sqsClient := sqs.NewFromConfig(sdkConfig) + snsClient := sns.NewFromConfig(sdkConfig) + + createQueueResult, _ := sqsClient.CreateQueue(context.TODO(), &sqs.CreateQueueInput{ + QueueName: &af.QueueName, + }) + + topicName := aws.String("unit-topic2") + + createTopicResult, _ := snsClient.CreateTopic(context.TODO(), &sns.CreateTopicInput{ + Name: topicName, + }) + + subscribeResult, _ := snsClient.Subscribe(context.TODO(), &sns.SubscribeInput{ + Protocol: aws.String("sqs"), + TopicArn: createTopicResult.TopicArn, + Attributes: map[string]string{}, + Endpoint: createQueueResult.QueueUrl, + ReturnSubscriptionArn: true, + }) + + snsClient.SetSubscriptionAttributes(context.TODO(), &sns.SetSubscriptionAttributesInput{ + SubscriptionArn: subscribeResult.SubscriptionArn, + AttributeName: aws.String("RawMessageDelivery"), + AttributeValue: aws.String("true"), + }) + message := "{\"IAm\": \"aMessage\"}" + subject := "I am a subject" + stringKey := "string-key" + binaryKey := "binary-key" + numberKey := "number-key" + stringValue := "string-value" + binaryValue := []byte("binary-value") + numberValue := "100" + attributes := map[string]types.MessageAttributeValue{ + stringKey: { + StringValue: aws.String(stringValue), + DataType: aws.String("String"), + }, + binaryKey: { + BinaryValue: binaryValue, + DataType: aws.String("Binary"), + }, + numberKey: { + StringValue: aws.String(numberValue), + DataType: aws.String("Number"), + }, + } + + publishResponse, publishErr := snsClient.Publish(context.TODO(), &sns.PublishInput{ + TopicArn: createTopicResult.TopicArn, + Message: &message, + Subject: &subject, + MessageAttributes: attributes, + }) + + receiveMessageResponse, receiveErr := sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{ + QueueUrl: createQueueResult.QueueUrl, + }) + + assert.Nil(t, publishErr) + assert.NotNil(t, publishResponse) + + assert.Nil(t, receiveErr) + assert.NotNil(t, receiveMessageResponse) + assert.Equal(t, message, *receiveMessageResponse.Messages[0].Body) + + assert.Equal(t, "649b2c548f103e499304eda4d6d4c5a2", *receiveMessageResponse.Messages[0].MD5OfBody) + assert.Equal(t, "ddfbe54b92058bf5b5f00055fa2032a5", *receiveMessageResponse.Messages[0].MD5OfMessageAttributes) + + assert.Equal(t, stringValue, *receiveMessageResponse.Messages[0].MessageAttributes[stringKey].StringValue) + assert.True(t, bytes.Equal(binaryValue, receiveMessageResponse.Messages[0].MessageAttributes[binaryKey].BinaryValue)) + assert.Equal(t, numberValue, *receiveMessageResponse.Messages[0].MessageAttributes[numberKey].StringValue) + +} + +func Test_Publish_sqs_json_with_message_attributes_not_raw(t *testing.T) { server := generateServer() defer func() { server.Close() diff --git a/smoke_tests/sqs_receive_message_test.go b/smoke_tests/sqs_receive_message_test.go index 32c93104..14ce0642 100644 --- a/smoke_tests/sqs_receive_message_test.go +++ b/smoke_tests/sqs_receive_message_test.go @@ -268,7 +268,7 @@ func Test_ReceiveMessageV1_xml_with_attributes(t *testing.T) { assert.Equal(t, 1, len(receiveMessageResponse.Result.Messages)) assert.Equal(t, "MyTestMessage", receiveMessageResponse.Result.Messages[0].Body) assert.Equal(t, "ad4883a84ad41c79aa3a373698c0d4e9", receiveMessageResponse.Result.Messages[0].MD5OfBody) - assert.Equal(t, "", receiveMessageResponse.Result.Messages[0].MD5OfMessageAttributes) + assert.Equal(t, "ae8770938aee44bc548cf65ac377e3bf", receiveMessageResponse.Result.Messages[0].MD5OfMessageAttributes) entry := "ApproximateFirstReceiveTimestamp" assert.Contains(t, response, entry) @@ -288,6 +288,6 @@ func Test_ReceiveMessageV1_xml_with_attributes(t *testing.T) { entry = "attr2Numbernumber-value" assert.Contains(t, response, entry) - entry = "attr3binary-valueBinary" + entry = "attr3YmluYXJ5LXZhbHVlBinary" assert.Contains(t, response, entry) } diff --git a/smoke_tests/sqs_send_message_batch_test.go b/smoke_tests/sqs_send_message_batch_test.go index 78b9dfb9..aec4b149 100644 --- a/smoke_tests/sqs_send_message_batch_test.go +++ b/smoke_tests/sqs_send_message_batch_test.go @@ -2,6 +2,7 @@ package smoke_tests import ( "context" + "encoding/base64" "encoding/xml" "fmt" "net/http" @@ -348,7 +349,8 @@ func TestSendMessageBatchV1_Xml_Success_including_attributes(t *testing.T) { stringType := "String" numberType := "Number" - binaryValue := "binary-value" + binaryValue := []byte("binary-value") + binaryValueEncodeString := base64.StdEncoding.EncodeToString([]byte("binary-value")) stringValue := "string-value" numberValue := "100" @@ -370,7 +372,7 @@ func TestSendMessageBatchV1_Xml_Success_including_attributes(t *testing.T) { WithFormField("Entries.1.MessageBody", messageBody2). WithFormField("Entries.1.MessageAttributes.1.Name", binaryAttributeKey). WithFormField("Entries.1.MessageAttributes.1.Value.DataType", binaryType). - WithFormField("Entries.1.MessageAttributes.1.Value.BinaryValue", binaryValue). + WithFormField("Entries.1.MessageAttributes.1.Value.BinaryValue", binaryValueEncodeString). WithFormField("Entries.1.MessageAttributes.2.Name", stringAttributeKey). WithFormField("Entries.1.MessageAttributes.2.Value.DataType", stringType). WithFormField("Entries.1.MessageAttributes.2.Value.StringValue", stringValue). diff --git a/smoke_tests/sqs_send_message_test.go b/smoke_tests/sqs_send_message_test.go index 424959ac..16f16fc8 100644 --- a/smoke_tests/sqs_send_message_test.go +++ b/smoke_tests/sqs_send_message_test.go @@ -282,7 +282,7 @@ func Test_SendMessageV1_xml_with_attributes(t *testing.T) { WithFormField("MessageAttribute.2.Value.StringValue", "2"). WithFormField("MessageAttribute.3.Name", "attr3"). WithFormField("MessageAttribute.3.Value.DataType", "Binary"). - WithFormField("MessageAttribute.3.Value.BinaryValue", "attr3_value"). + WithFormField("MessageAttribute.3.Value.BinaryValue", "YXR0cjNfdmFsdWU="). // base64 encode string attr3_value Expect(). Status(http.StatusOK). Body().Raw()