diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 7ebe3b8..3a0040d 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -2,12 +2,14 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.checkMessageAttributes; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.embedS3PointerInReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.extractMessageFromSnsJson; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getMessagePointerFromModifiedReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getOrigReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageInSnsJson; import java.util.ArrayList; import java.util.HashMap; @@ -211,15 +213,26 @@ public CompletableFuture receiveMessage(ReceiveMessageRe for (Message message : messages) { Message.Builder messageBuilder = message.toBuilder(); + final String originalBody = message.body(); + String effectiveBody = originalBody; + if (clientConfiguration.isPayloadSupportFromSnsEnabled()) { + effectiveBody = extractMessageFromSnsJson(originalBody); + } + + final Message messageToProcess = messageBuilder.body(effectiveBody).build(); + // For each received message check if they are stored in S3. Optional largePayloadAttributeName = getReservedAttributeNameIfPresent( - message.messageAttributes()); + messageToProcess.messageAttributes()); if (!largePayloadAttributeName.isPresent()) { // Not S3 + // If it was SNS, the builder already has effectiveBody, but we want to return originalBody + // if it's not a large payload, to preserve the envelope. + messageBuilder.body(originalBody); modifiedMessageFutures.add(CompletableFuture.completedFuture(messageBuilder.build())); } else { // In S3 - final String largeMessagePointer = message.body() + final String largeMessagePointer = messageToProcess.body() .replace("com.amazon.sqs.javamessaging.MessageS3Pointer", "software.amazon.payloadoffloading.PayloadS3Pointer"); @@ -234,7 +247,7 @@ public CompletableFuture receiveMessage(ReceiveMessageRe DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest .builder() .queueUrl(queueUrl) - .receiptHandle(message.receiptHandle()) + .receiptHandle(messageToProcess.receiptHandle()) .build(); deleteMessage(deleteMessageRequest).join(); @@ -248,18 +261,22 @@ public CompletableFuture receiveMessage(ReceiveMessageRe } // Set original payload - messageBuilder.body(originalPayload); + if (clientConfiguration.isPayloadSupportFromSnsEnabled()) { + messageBuilder.body(updateMessageInSnsJson(originalBody, originalPayload)); + } else { + messageBuilder.body(originalPayload); + } // Remove the additional attribute before returning the message // to user. Map messageAttributes = new HashMap<>( - message.messageAttributes()); + messageToProcess.messageAttributes()); messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); messageBuilder.messageAttributes(messageAttributes); // Embed s3 object pointer in the receipt handle. String modifiedReceiptHandle = embedS3PointerInReceiptHandle( - message.receiptHandle(), + messageToProcess.receiptHandle(), largeMessagePointer); messageBuilder.receiptHandle(modifiedReceiptHandle); diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 5b372a9..6ab0834 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -17,12 +17,14 @@ import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.checkMessageAttributes; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.embedS3PointerInReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.extractMessageFromSnsJson; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getMessagePointerFromModifiedReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getOrigReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageInSnsJson; import java.util.ArrayList; import java.util.HashMap; @@ -335,14 +337,25 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag for (Message message : messages) { Message.Builder messageBuilder = message.toBuilder(); + String originalBody = message.body(); + String effectiveBody = originalBody; + if (clientConfiguration.isPayloadSupportFromSnsEnabled()) { + effectiveBody = extractMessageFromSnsJson(originalBody); + } + // for each received message check if they are stored in S3. Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(message.messageAttributes()); if (largePayloadAttributeName.isPresent()) { - String largeMessagePointer = message.body(); + String largeMessagePointer = effectiveBody; largeMessagePointer = largeMessagePointer.replace("com.amazon.sqs.javamessaging.MessageS3Pointer", "software.amazon.payloadoffloading.PayloadS3Pointer"); try { - messageBuilder.body(payloadStore.getOriginalPayload(largeMessagePointer)); + String resolvedPayload = payloadStore.getOriginalPayload(largeMessagePointer); + if (clientConfiguration.isPayloadSupportFromSnsEnabled()) { + messageBuilder.body(updateMessageInSnsJson(originalBody, resolvedPayload)); + } else { + messageBuilder.body(resolvedPayload); + } } catch (SdkException e) { if (e.getCause() instanceof NoSuchKeyException && clientConfiguration.ignoresPayloadNotFound()) { DeleteMessageRequest deleteMessageRequest = DeleteMessageRequest diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java index 8bf1609..436dce7 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java @@ -1,5 +1,8 @@ package com.amazon.sqs.javamessaging; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -21,6 +24,7 @@ public class AmazonSQSExtendedClientUtil { private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedClientUtil.class); + private static final ObjectMapper MAPPER = new ObjectMapper(); public static final String LEGACY_RESERVED_ATTRIBUTE_NAME = "SQSLargePayloadSize"; public static final List RESERVED_ATTRIBUTE_NAMES = Arrays.asList(LEGACY_RESERVED_ATTRIBUTE_NAME, @@ -138,6 +142,32 @@ public static T appendUserAgent( .build()); } + public static String extractMessageFromSnsJson(String snsJson) { + try { + JsonNode rootNode = MAPPER.readTree(snsJson); + if (rootNode.has("Message")) { + return rootNode.get("Message").asText(); + } + } catch (Exception e) { + LOG.warn("Failed to parse SNS JSON message body", e); + } + return snsJson; + } + + public static String updateMessageInSnsJson(String snsJson, String newMessage) { + try { + JsonNode rootNode = MAPPER.readTree(snsJson); + if (rootNode.isObject() && rootNode.has("Message")) { + ObjectNode objectNode = (ObjectNode) rootNode; + objectNode.put("Message", newMessage); + return MAPPER.writeValueAsString(objectNode); + } + } catch (Exception e) { + LOG.warn("Failed to update SNS JSON message body", e); + } + return newMessage; + } + private static String getFromReceiptHandleByMarker(String receiptHandle, String marker) { int firstOccurence = receiptHandle.indexOf(marker); int secondOccurence = receiptHandle.indexOf(marker, firstOccurence + 1); diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java index e96fff2..46dc05d 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -17,6 +17,7 @@ public class ExtendedAsyncClientConfiguration extends PayloadStorageAsyncConfigu private boolean cleanupS3Payload = true; private boolean useLegacyReservedAttributeName = true; private boolean ignorePayloadNotFound = false; + private boolean payloadSupportFromSnsEnabled = false; private String s3KeyPrefix = ""; public ExtendedAsyncClientConfiguration() { @@ -28,6 +29,7 @@ public ExtendedAsyncClientConfiguration(ExtendedAsyncClientConfiguration other) this.cleanupS3Payload = other.doesCleanupS3Payload(); this.useLegacyReservedAttributeName = other.usesLegacyReservedAttributeName(); this.ignorePayloadNotFound = other.ignoresPayloadNotFound(); + this.payloadSupportFromSnsEnabled = other.isPayloadSupportFromSnsEnabled(); this.s3KeyPrefix = other.s3KeyPrefix; } @@ -183,6 +185,43 @@ public boolean ignoresPayloadNotFound() { return ignorePayloadNotFound; } + /** + * Sets whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. + * + * @param payloadSupportFromSnsEnabled + * Whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. Default: false + */ + public void setPayloadSupportFromSnsEnabled(boolean payloadSupportFromSnsEnabled) { + this.payloadSupportFromSnsEnabled = payloadSupportFromSnsEnabled; + } + + /** + * Sets whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. + * + * @param payloadSupportFromSnsEnabled + * Whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. Default: false + * @return the updated ExtendedAsyncClientConfiguration object. + */ + public ExtendedAsyncClientConfiguration withPayloadSupportFromSnsEnabled(boolean payloadSupportFromSnsEnabled) { + setPayloadSupportFromSnsEnabled(payloadSupportFromSnsEnabled); + return this; + } + + /** + * Checks whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. + * + * @return True if the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. Default: false + */ + public boolean isPayloadSupportFromSnsEnabled() { + return payloadSupportFromSnsEnabled; + } + @Override public ExtendedAsyncClientConfiguration withAlwaysThroughS3(boolean alwaysThroughS3) { setAlwaysThroughS3(alwaysThroughS3); diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java index 75a30f8..28dcf1b 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java @@ -35,6 +35,7 @@ public class ExtendedClientConfiguration extends PayloadStorageConfiguration { private boolean cleanupS3Payload = true; private boolean useLegacyReservedAttributeName = true; private boolean ignorePayloadNotFound = false; + private boolean payloadSupportFromSnsEnabled = false; private String s3KeyPrefix = ""; public ExtendedClientConfiguration() { @@ -47,6 +48,7 @@ public ExtendedClientConfiguration(ExtendedClientConfiguration other) { this.cleanupS3Payload = other.doesCleanupS3Payload(); this.useLegacyReservedAttributeName = other.usesLegacyReservedAttributeName(); this.ignorePayloadNotFound = other.ignoresPayloadNotFound(); + this.payloadSupportFromSnsEnabled = other.isPayloadSupportFromSnsEnabled(); this.s3KeyPrefix = other.s3KeyPrefix; } @@ -196,6 +198,43 @@ public boolean ignoresPayloadNotFound() { return ignorePayloadNotFound; } + /** + * Sets whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. + * + * @param payloadSupportFromSnsEnabled + * Whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. Default: false + */ + public void setPayloadSupportFromSnsEnabled(boolean payloadSupportFromSnsEnabled) { + this.payloadSupportFromSnsEnabled = payloadSupportFromSnsEnabled; + } + + /** + * Sets whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. + * + * @param payloadSupportFromSnsEnabled + * Whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. Default: false + * @return the updated ExtendedClientConfiguration object. + */ + public ExtendedClientConfiguration withPayloadSupportFromSnsEnabled(boolean payloadSupportFromSnsEnabled) { + setPayloadSupportFromSnsEnabled(payloadSupportFromSnsEnabled); + return this; + } + + /** + * Checks whether or not the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. + * + * @return True if the client should attempt to find the message body + * in the "Message" field of a JSON-formatted SNS message. Default: false + */ + public boolean isPayloadSupportFromSnsEnabled() { + return payloadSupportFromSnsEnabled; + } + @Override public ExtendedClientConfiguration withAlwaysThroughS3(boolean alwaysThroughS3) { setAlwaysThroughS3(alwaysThroughS3); diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientSnsTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientSnsTest.java new file mode 100644 index 0000000..a4390df --- /dev/null +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientSnsTest.java @@ -0,0 +1,138 @@ +package com.amazon.sqs.javamessaging; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; +import software.amazon.awssdk.utils.StringInputStream; +import software.amazon.payloadoffloading.PayloadS3Pointer; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AmazonSQSExtendedClientSnsTest { + + private static final String S3_BUCKET_NAME = "test-bucket"; + private static final ObjectMapper MAPPER = new ObjectMapper(); + private SqsClient mockSqs; + private S3Client mockS3; + private SqsAsyncClient mockSqsAsync; + private S3AsyncClient mockS3Async; + + @BeforeEach + public void setup() { + mockSqs = mock(SqsClient.class); + mockS3 = mock(S3Client.class); + mockSqsAsync = mock(SqsAsyncClient.class); + mockS3Async = mock(S3AsyncClient.class); + } + + @Test + public void testSyncReceiveMessageWithSnsAndS3Pointer() throws Exception { + ExtendedClientConfiguration config = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withPayloadSupportFromSnsEnabled(true); + AmazonSQSExtendedClient extendedClient = new AmazonSQSExtendedClient(mockSqs, config); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "s3-key").toJson(); + String escapedS3Pointer = s3Pointer.replace("\"", "\\\""); + String snsJson = "{\"Type\":\"Notification\",\"Subject\":\"LargePayload\",\"Message\":\"" + escapedS3Pointer + "\"}"; + + Message message = Message.builder() + .body(snsJson) + .messageAttributes(Collections.singletonMap(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, + MessageAttributeValue.builder().dataType("Number").stringValue("100").build())) + .build(); + + when(mockSqs.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder().messages(message).build()); + + String largePayload = "ResolvedLargePayload"; + GetObjectResponse getObjectResponse = GetObjectResponse.builder().build(); + ResponseInputStream s3Stream = new ResponseInputStream<>(getObjectResponse, + AbortableInputStream.create(new StringInputStream(largePayload))); + when(mockS3.getObject(any(GetObjectRequest.class))).thenReturn(s3Stream); + + ReceiveMessageResponse result = extendedClient.receiveMessage(ReceiveMessageRequest.builder().build()); + String resultBody = result.messages().get(0).body(); + + JsonNode root = MAPPER.readTree(resultBody); + assertEquals("Notification", root.get("Type").asText()); + assertEquals("LargePayload", root.get("Subject").asText()); + assertEquals(largePayload, root.get("Message").asText()); + } + + @Test + public void testAsyncReceiveMessageWithSnsAndS3Pointer() throws Exception { + ExtendedAsyncClientConfiguration config = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3Async, S3_BUCKET_NAME) + .withPayloadSupportFromSnsEnabled(true); + AmazonSQSExtendedAsyncClient extendedClient = new AmazonSQSExtendedAsyncClient(mockSqsAsync, config); + + String s3Pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "s3-key").toJson(); + String escapedS3Pointer = s3Pointer.replace("\"", "\\\""); + String snsJson = "{\"Type\":\"Notification\",\"Message\":\"" + escapedS3Pointer + "\"}"; + + Message message = Message.builder() + .body(snsJson) + .messageAttributes(Collections.singletonMap(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, + MessageAttributeValue.builder().dataType("Number").stringValue("100").build())) + .build(); + + when(mockSqsAsync.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(CompletableFuture.completedFuture(ReceiveMessageResponse.builder().messages(message).build())); + + String largePayload = "ResolvedLargePayloadAsync"; + ResponseBytes s3Object = ResponseBytes.fromByteArray( + GetObjectResponse.builder().build(), + largePayload.getBytes(StandardCharsets.UTF_8)); + when(mockS3Async.getObject(isA(GetObjectRequest.class), isA(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture(s3Object)); + + ReceiveMessageResponse result = extendedClient.receiveMessage(ReceiveMessageRequest.builder().build()).get(); + String resultBody = result.messages().get(0).body(); + + JsonNode root = MAPPER.readTree(resultBody); + assertEquals(largePayload, root.get("Message").asText()); + } + + @Test + public void testReceiveMessageStandardSqsWhenSnsEnabled() { + ExtendedClientConfiguration config = new ExtendedClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withPayloadSupportFromSnsEnabled(true); + AmazonSQSExtendedClient extendedClient = new AmazonSQSExtendedClient(mockSqs, config); + + String standardBody = "Standard SQS Body"; + Message message = Message.builder().body(standardBody).build(); + + when(mockSqs.receiveMessage(any(ReceiveMessageRequest.class))) + .thenReturn(ReceiveMessageResponse.builder().messages(message).build()); + + ReceiveMessageResponse result = extendedClient.receiveMessage(ReceiveMessageRequest.builder().build()); + assertEquals(standardBody, result.messages().get(0).body()); + } +} diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtilTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtilTest.java new file mode 100644 index 0000000..2ba6bd9 --- /dev/null +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtilTest.java @@ -0,0 +1,42 @@ +package com.amazon.sqs.javamessaging; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AmazonSQSExtendedClientUtilTest { + + @Test + public void testExtractMessageFromSnsJson() { + String snsJson = "{\"Type\":\"Notification\",\"MessageId\":\"123\",\"Message\":\"Hello World\"}"; + assertEquals("Hello World", AmazonSQSExtendedClientUtil.extractMessageFromSnsJson(snsJson)); + } + + @Test + public void testUpdateMessageInSnsJson() { + String snsJson = "{\"Type\":\"Notification\",\"Subject\":\"Test\",\"Message\":\"Old\"}"; + String result = AmazonSQSExtendedClientUtil.updateMessageInSnsJson(snsJson, "NewContent"); + assertEquals("{\"Type\":\"Notification\",\"Subject\":\"Test\",\"Message\":\"NewContent\"}", result); + } + + @Test + public void testUpdateMessageInSnsJsonEscaping() { + String snsJson = "{\"Message\":\"Old\"}"; + String result = AmazonSQSExtendedClientUtil.updateMessageInSnsJson(snsJson, "Line1\nLine2 \"Quoted\""); + assertEquals("{\"Message\":\"Line1\\nLine2 \\\"Quoted\\\"\"}", result); + } + + @Test + public void testUpdateMessageInSnsJsonNonSnsJson() { + // If it's valid JSON but doesn't have "Message", it should return the new message directly (fallback) + String json = "{\"OtherField\":\"Value\"}"; + String result = AmazonSQSExtendedClientUtil.updateMessageInSnsJson(json, "NewContent"); + assertEquals("NewContent", result); + } + + @Test + public void testUpdateMessageInSnsJsonMalformed() { + String malformed = "not json"; + String result = AmazonSQSExtendedClientUtil.updateMessageInSnsJson(malformed, "NewContent"); + assertEquals("NewContent", result); + } +} diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java index 879f098..6ec69ca 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java @@ -34,7 +34,8 @@ public void testCopyConstructor() { extendedClientConfig.withPayloadSupportEnabled(s3, s3BucketName, doesCleanupS3Payload) .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(messageSizeThreshold) - .withServerSideEncryption(serverSideEncryptionStrategy); + .withServerSideEncryption(serverSideEncryptionStrategy) + .withPayloadSupportFromSnsEnabled(true); ExtendedAsyncClientConfiguration newExtendedClientConfig = new ExtendedAsyncClientConfiguration(extendedClientConfig); @@ -45,10 +46,24 @@ public void testCopyConstructor() { assertEquals(doesCleanupS3Payload, newExtendedClientConfig.doesCleanupS3Payload()); assertEquals(alwaysThroughS3, newExtendedClientConfig.isAlwaysThroughS3()); assertEquals(messageSizeThreshold, newExtendedClientConfig.getPayloadSizeThreshold()); + assertTrue(newExtendedClientConfig.isPayloadSupportFromSnsEnabled()); assertNotSame(newExtendedClientConfig, extendedClientConfig); } + @Test + public void testPayloadSupportFromSnsEnabledDefault() { + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + assertFalse(extendedClientConfiguration.isPayloadSupportFromSnsEnabled()); + } + + @Test + public void testPayloadSupportFromSnsEnabledSet() { + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.setPayloadSupportFromSnsEnabled(true); + assertTrue(extendedClientConfiguration.isPayloadSupportFromSnsEnabled()); + } + @Test public void testLargePayloadSupportEnabledWithDefaultDeleteFromS3Config() { S3AsyncClient s3 = mock(S3AsyncClient.class); diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java index 2dc5b6b..16243ac 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedClientConfigurationTest.java @@ -56,7 +56,8 @@ public void testCopyConstructor() { extendedClientConfig.withPayloadSupportEnabled(s3, s3BucketName, doesCleanupS3Payload) .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(messageSizeThreshold) - .withServerSideEncryption(serverSideEncryptionStrategy); + .withServerSideEncryption(serverSideEncryptionStrategy) + .withPayloadSupportFromSnsEnabled(true); ExtendedClientConfiguration newExtendedClientConfig = new ExtendedClientConfiguration(extendedClientConfig); @@ -67,10 +68,24 @@ public void testCopyConstructor() { assertEquals(doesCleanupS3Payload, newExtendedClientConfig.doesCleanupS3Payload()); assertEquals(alwaysThroughS3, newExtendedClientConfig.isAlwaysThroughS3()); assertEquals(messageSizeThreshold, newExtendedClientConfig.getPayloadSizeThreshold()); + assertTrue(newExtendedClientConfig.isPayloadSupportFromSnsEnabled()); assertNotSame(newExtendedClientConfig, extendedClientConfig); } + @Test + public void testPayloadSupportFromSnsEnabledDefault() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + assertFalse(extendedClientConfiguration.isPayloadSupportFromSnsEnabled()); + } + + @Test + public void testPayloadSupportFromSnsEnabledSet() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration(); + extendedClientConfiguration.setPayloadSupportFromSnsEnabled(true); + assertTrue(extendedClientConfiguration.isPayloadSupportFromSnsEnabled()); + } + @Test public void testLargePayloadSupportEnabledWithDefaultDeleteFromS3Config() { S3Client s3 = mock(S3Client.class);