Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,20 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static software.amazon.encryption.s3.TestUtils.*;

import java.net.Socket;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.amazonaws.services.s3.model.KMSEncryptionMaterials;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.smithy.java.aws.client.restjson.RestJsonClientProtocol;
import software.amazon.smithy.java.client.core.ClientConfig;
import software.amazon.smithy.java.client.core.ClientProtocol;
import software.amazon.smithy.java.client.core.endpoint.EndpointResolver;
import software.amazon.encryption.s3.client.S3ECTestServerClient;
import software.amazon.encryption.s3.model.CreateClientInput;
import software.amazon.encryption.s3.model.CreateClientOutput;
Expand All @@ -41,200 +30,26 @@
import software.amazon.encryption.s3.model.KeyMaterial;
import software.amazon.encryption.s3.model.PutObjectInput;
import software.amazon.encryption.s3.model.S3ECConfig;
import software.amazon.encryption.s3.model.S3ECTestServerApiService;
import software.amazon.encryption.s3.model.S3EncryptionClientError;
import software.amazon.smithy.java.http.api.HttpRequest;
import software.amazon.smithy.java.http.api.HttpResponse;

import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.s3.AmazonS3Encryption;
import com.amazonaws.services.s3.AmazonS3EncryptionClient;
import com.amazonaws.services.s3.model.CryptoConfiguration;
import com.amazonaws.services.s3.model.CryptoMode;
import com.amazonaws.services.s3.model.CryptoStorageMode;
import software.amazon.encryption.s3.TestUtils.*;
import com.amazonaws.services.s3.model.EncryptionMaterialsProvider;
import com.amazonaws.services.s3.model.KMSEncryptionMaterialsProvider;

public class RoundTripTests {
private static final String JAVA_V3 = "Java-V3";
private static final String PYTHON_V3 = "Python-V3";
private static final String GO_V3 = "Go-V3";
private static final String CPP_V2 = "CPP-V2";
private static final String NET_V2 = "NET-V2";
private static final String NET_V3 = "NET-V3";
private static final String PHP_V2 = "PHP-V2";
private static final String PHP_V3 = "PHP-V3";
private static final String RUBY_V2 = "Ruby-V2";
private static final String RUBY_V3 = "Ruby-V3";

private static final Map<String, LanguageServerTarget> serverMap;

private static final String KMS_KEY_ARN = System.getenv("TEST_SERVER_KMS_KEY_ARN") != null ?
System.getenv("TEST_SERVER_KMS_KEY_ARN") : "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Test-Server-Github-KMS-Key";
private static final Region KMS_REGION = Region.getRegion(Regions.fromName("us-west-2"));
private static final String BUCKET = System.getenv("TEST_SERVER_S3_BUCKET") != null ?
System.getenv("TEST_SERVER_S3_BUCKET") : "s3ec-test-server-github-bucket";

static {
final Map<String, LanguageServerTarget> servers = new LinkedHashMap<>();
servers.put(JAVA_V3, new LanguageServerTarget(JAVA_V3, "8080"));
servers.put(PYTHON_V3, new LanguageServerTarget(PYTHON_V3, "8081"));
servers.put(GO_V3, new LanguageServerTarget(GO_V3, "8082"));
servers.put(NET_V2, new LanguageServerTarget(NET_V2, "8083"));
servers.put(NET_V3, new LanguageServerTarget(NET_V3, "8084"));
servers.put(CPP_V2, new LanguageServerTarget(CPP_V2, "8085"));
servers.put(PHP_V2, new LanguageServerTarget(PHP_V2, "8087"));
servers.put(PHP_V3, new LanguageServerTarget(PHP_V3, "8093"));
servers.put(RUBY_V2, new LanguageServerTarget(RUBY_V2, "8086"));
servers.put(RUBY_V3, new LanguageServerTarget(RUBY_V3, "8092"));

serverMap = filterServers(servers);
}

private static Map<String, LanguageServerTarget> filterServers(Map<String, LanguageServerTarget> allServers) {

final String maybeFilter = System.getProperty("test.filter.servers");
if (maybeFilter == null || maybeFilter.trim().isEmpty()) {
return allServers; // No filtering - use all servers
}

final String[] filters = Arrays.stream(maybeFilter.split(","))
.map(String::trim)
.map(String::toLowerCase)
.toArray(String[]::new);

return allServers.entrySet().stream()
.filter(entry -> {
String key = entry.getKey().toLowerCase();
return Arrays.stream(filters).anyMatch(key::contains);
})
.collect(Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue,
(e1, e2) -> e1, // merge function (not really needed)
LinkedHashMap::new // preserve order
));
}

// Encryption context validation behavior varies by implementation:
// - Go: Does not validate encryption context on decrypt operations
// - .NET: Only validates against encryption context stored in the object metadata
// If the encryption context provided to getObject does not match the encryption context on the stored object,
// these implementations will not raise an error as expected.
// For now, skip tests that expect encryption context validation on decrypt.
private static final Set<String> ENCRYPTION_CONTEXT_ON_DECRYPT_UNSUPPORTED =
Set.of(GO_V3, PHP_V2, PHP_V3, NET_V2, NET_V3);

// S3EC .NET implementations does not accept encryption context (EC) during putObject operations.
// These tests are not configured to pass encryption context at client level but at encrypt,
// So, for .NET EC is not passed.
// For now, skip tests that expect encryption context validation on decrypt.
private static final Set<String> ENCRYPTION_CONTEXT_ON_ENCRYPT_UNSUPPORTED =
Set.of(NET_V2, NET_V3);

static public class LanguageServerTarget {
public String getLanguageName() {
return languageName;
}

public URI getServerURI() {
return serverURI;
}

private final String baseURI = "http://localhost";
private String languageName;
private URI serverURI;

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
LanguageServerTarget that = (LanguageServerTarget) o;
return Objects.equals(languageName, that.languageName) && Objects.equals(serverURI, that.serverURI);
}

@Override
public int hashCode() {
return Objects.hash(languageName, serverURI);
}

LanguageServerTarget(String language, String port) {
languageName = language;
serverURI = URI.create(baseURI+ ":" + port);
}

@Override
public String toString() {
return languageName;
}
}

@BeforeAll
public static void setup() {
// Wait for servers to start
for (LanguageServerTarget server : serverMap.values()) {
if (!serverListening(server.getServerURI())) {
throw new RuntimeException(String.format("Test Server for %s is not running at endpoint: %s", server.getLanguageName(), server.getServerURI()));
}
}
}

public static boolean serverListening(URI uri) {
try (Socket ignored = new Socket(uri.getHost(), uri.getPort())) {
return true;
} catch (Exception e) {
e.printStackTrace();
return false;
}
}

static S3ECTestServerClient testServerClientFor(LanguageServerTarget server) {
S3ECTestServerApiService apiService = S3ECTestServerApiService.instance();
ClientProtocol<HttpRequest, HttpResponse> rest = new RestJsonClientProtocol(apiService.schema().id());
return S3ECTestServerClient.builder()
.endpointResolver(EndpointResolver.staticEndpoint(server.serverURI))
.withConfiguration(ClientConfig.builder()
.service(apiService)
.protocol(rest)
.endpointResolver(EndpointResolver.staticEndpoint(server.serverURI))
.build())
.build();
}

static Stream<Arguments> clientsForTest() {
return serverMap.values().stream()
.map(LanguageServerTarget::getLanguageName)
.map(Arguments::of);
}

static Stream<Arguments> crossLanguageClients() {
return serverMap.values().stream()
.flatMap(t1 -> serverMap.values().stream()
.flatMap(t2 -> Stream.of(
Arguments.of(t1, t2)
)));
}

/**
* Annoyingly, Smithy doesn't provide an interface for map types
* in HTTP headers, so we have to do the serde ourselves
* Servers need an equivalent utility.
* TODO: Move to a utilities class or something.
*/
private List<String> metadataMapToList(Map<String, String> md) {
List<String> mdAsList = new ArrayList<>(md.size());
for (Map.Entry<String, String> keyValue : md.entrySet()) {
// Using ":" because Smithy will parse "," into a flattened list
mdAsList.add("[" + keyValue.getKey() + "]:[" + keyValue.getValue() + "]");
}
return mdAsList;
validateServersRunning();
}

@ParameterizedTest(name = "{displayName} for Encrypt: {0}, Decrypt: {1}")
@MethodSource("crossLanguageClients")
@MethodSource("software.amazon.encryption.s3.TestUtils#crossLanguageClients")
public void crossLanguageTestKms(LanguageServerTarget encLang, LanguageServerTarget decLang) {
S3ECTestServerClient encClient = testServerClientFor(encLang);
final String objectKey = "cross-lang-test-key-" + encLang;
Expand Down Expand Up @@ -271,7 +86,7 @@ public void crossLanguageTestKms(LanguageServerTarget encLang, LanguageServerTar
}

@ParameterizedTest(name = "{displayName} for Encrypt: {0}, Decrypt: {1}")
@MethodSource("crossLanguageClients")
@MethodSource("software.amazon.encryption.s3.TestUtils#crossLanguageClients")
public void crossLanguageTestKmsWithEncCtx(LanguageServerTarget encLang, LanguageServerTarget decLang) {
if (ENCRYPTION_CONTEXT_ON_ENCRYPT_UNSUPPORTED.contains(encLang.getLanguageName())) {
return;
Expand Down Expand Up @@ -318,7 +133,7 @@ public void crossLanguageTestKmsWithEncCtx(LanguageServerTarget encLang, Languag
}

@ParameterizedTest(name = "{displayName} for Encrypt: {0}, Decrypt: {1}")
@MethodSource("crossLanguageClients")
@MethodSource("software.amazon.encryption.s3.TestUtils#crossLanguageClients")
public void crossLanguageTestKmsWithSubsetEncCtxFails(LanguageServerTarget encLang, LanguageServerTarget decLang) {
if (ENCRYPTION_CONTEXT_ON_DECRYPT_UNSUPPORTED.contains(decLang.getLanguageName())) {
return;
Expand Down Expand Up @@ -363,7 +178,7 @@ public void crossLanguageTestKmsWithSubsetEncCtxFails(LanguageServerTarget encLa
.build());
fail("Expected exception!");
} catch (S3EncryptionClientError e) {
if (decLang.languageName.equals(RUBY_V3) || decLang.languageName.equals(RUBY_V2)) {
if (decLang.getLanguageName().equals(RUBY_V3) || decLang.getLanguageName().equals(RUBY_V2)) {
assertTrue(e.getMessage().contains("Value of encryption context from envelope does not match the provided encryption context"));
} else {
assertTrue(e.getMessage().contains("Provided encryption context does not match information retrieved from S3"));
Expand All @@ -372,7 +187,7 @@ public void crossLanguageTestKmsWithSubsetEncCtxFails(LanguageServerTarget encLa
}

@ParameterizedTest(name = "{displayName} for Encrypt: {0}, Decrypt: {1}")
@MethodSource("crossLanguageClients")
@MethodSource("software.amazon.encryption.s3.TestUtils#crossLanguageClients")
public void crossLanguageTestKmsWithIncorrectEncCtxFails(LanguageServerTarget encLang, LanguageServerTarget decLang) {
if (ENCRYPTION_CONTEXT_ON_DECRYPT_UNSUPPORTED.contains(decLang.getLanguageName())) {
return;
Expand Down Expand Up @@ -419,7 +234,7 @@ public void crossLanguageTestKmsWithIncorrectEncCtxFails(LanguageServerTarget en
.build());
fail("Expected exception!");
} catch (S3EncryptionClientError e) {
if (decLang.languageName.equals(RUBY_V3) || decLang.languageName.equals(RUBY_V2)) {
if (decLang.getLanguageName().equals(RUBY_V3) || decLang.getLanguageName().equals(RUBY_V2)) {
assertTrue(e.getMessage().contains("Value of encryption context from envelope does not match the provided encryption context"));
} else {
assertTrue(e.getMessage().contains("Provided encryption context does not match information retrieved from S3"));
Expand All @@ -428,9 +243,9 @@ public void crossLanguageTestKmsWithIncorrectEncCtxFails(LanguageServerTarget en
}

@ParameterizedTest(name = "{displayName} for Encrypt: Java, Decrypt: {0}")
@MethodSource("clientsForTest")
@MethodSource("software.amazon.encryption.s3.TestUtils#clientsForTest")
public void kmsV1Legacy(String language) {
S3ECTestServerClient client = testServerClientFor(serverMap.get(language));
S3ECTestServerClient client = testServerClientFor(getServerMap().get(language));
final String objectKey = "test-key-kms-v1-" + language;
final String input = "simple-test-input";
KeyMaterial kmsKeyArn = KeyMaterial.builder()
Expand Down Expand Up @@ -470,9 +285,9 @@ public void kmsV1Legacy(String language) {
}

@ParameterizedTest(name = "{displayName} for Encrypt: Java, Decrypt: {0}")
@MethodSource("clientsForTest")
@MethodSource("software.amazon.encryption.s3.TestUtils#clientsForTest")
public void kmsV1LegacyWithEncCtx(String language) {
S3ECTestServerClient client = testServerClientFor(serverMap.get(language));
S3ECTestServerClient client = testServerClientFor(getServerMap().get(language));
final String objectKey = "test-key-kms-v1-with-enc-ctx-" + language;
final String input = "simple-test-input";
KeyMaterial kmsKeyArn = KeyMaterial.builder()
Expand Down Expand Up @@ -519,9 +334,9 @@ public void kmsV1LegacyWithEncCtx(String language) {
}

@ParameterizedTest(name = "{displayName} for Encrypt: Java, Decrypt: {0}")
@MethodSource("clientsForTest")
@MethodSource("software.amazon.encryption.s3.TestUtils#clientsForTest")
public void kmsV1LegacyFailsWhenLegacyDisabled(String language) {
S3ECTestServerClient client = testServerClientFor(serverMap.get(language));
S3ECTestServerClient client = testServerClientFor(getServerMap().get(language));
final String objectKey = "test-key-kms-v1-fails-disabled" + language;
final String input = "simple-test-input";
KeyMaterial kmsKeyArn = KeyMaterial.builder()
Expand Down
Loading
Loading