diff --git a/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java index a25423471f..5d441d0063 100644 --- a/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java +++ b/src/integrationTest/java/org/opensearch/security/ResourceFocusedTests.java @@ -16,17 +16,9 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.lang.management.GarbageCollectorMXBean; -import java.lang.management.ManagementFactory; -import java.lang.management.MemoryPoolMXBean; -import java.lang.management.MemoryUsage; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.zip.GZIPOutputStream; @@ -35,12 +27,16 @@ import org.apache.hc.core5.http.ContentType; import org.apache.hc.core5.http.io.entity.ByteArrayEntity; import org.apache.hc.core5.http.message.BasicHeader; +import org.apache.http.HttpStatus; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; import org.junit.runner.RunWith; import org.opensearch.action.index.IndexRequest; import org.opensearch.client.Client; +import org.opensearch.test.framework.AsyncActions; import org.opensearch.test.framework.TestSecurityConfig; import org.opensearch.test.framework.TestSecurityConfig.User; import org.opensearch.test.framework.cluster.ClusterManager; @@ -52,6 +48,7 @@ @RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class ResourceFocusedTests { + private final static Logger LOG = LogManager.getLogger(AsyncActions.class); private static final User ADMIN_USER = new User("admin").roles(ALL_ACCESS); private static final User LIMITED_USER = new User("limited_user").roles( new TestSecurityConfig.Role("limited-role").clusterPermissions( @@ -93,9 +90,8 @@ public void testUnauthenticatedFewBig() { final String requestPath = "/*/_search"; final int parrallelism = 5; final int totalNumberOfRequests = 100; - final boolean statsPrinter = false; - runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests); } @Test @@ -105,9 +101,8 @@ public void testUnauthenticatedManyMedium() { final String requestPath = "/*/_search"; final int parrallelism = 20; final int totalNumberOfRequests = 10_000; - final boolean statsPrinter = false; - runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests); } @Test @@ -116,62 +111,27 @@ public void testUnauthenticatedTonsSmall() { final RequestBodySize size = RequestBodySize.Small; final String requestPath = "/*/_search"; final int parrallelism = 100; - final int totalNumberOfRequests = 1_000_000; - final boolean statsPrinter = false; + final int totalNumberOfRequests = 15_000; - runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests, statsPrinter); + runResourceTest(size, requestPath, parrallelism, totalNumberOfRequests); } - private Long runResourceTest( + private void runResourceTest( final RequestBodySize size, final String requestPath, final int parrallelism, - final int totalNumberOfRequests, - final boolean statsPrinter + final int totalNumberOfRequests ) { final byte[] compressedRequestBody = createCompressedRequestBody(size); try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) { - - if (statsPrinter) { - printStats(); - } - final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); - post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); - - final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism); - - final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) - .boxed() - .map(i -> CompletableFuture.runAsync(() -> client.executeRequest(post), forkJoinPool)) - .collect(Collectors.toList()); - Supplier getCount = () -> waitingOn.stream().filter(cf -> cf.isDone() && !cf.isCompletedExceptionally()).count(); - - CompletableFuture statPrinter = statsPrinter ? CompletableFuture.runAsync(() -> { - while (true) { - printStats(); - System.err.println(" & Succesful completions: " + getCount.get()); - try { - Thread.sleep(500); - } catch (Exception e) { - break; - } - } - }, forkJoinPool) : CompletableFuture.completedFuture(null); - - final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); - - try { - allOfThem.get(30, TimeUnit.SECONDS); - statPrinter.cancel(true); - } catch (final Exception e) { - // Ignored - } - - if (statsPrinter) { - printStats(); - System.err.println(" & Succesful completions: " + getCount.get()); - } - return getCount.get(); + final var requests = AsyncActions.generate(() -> { + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + return client.executeRequest(post); + }, parrallelism, totalNumberOfRequests); + + AsyncActions.getAll(requests, 2, TimeUnit.MINUTES) + .forEach((response) -> { response.assertStatusCode(HttpStatus.SC_UNAUTHORIZED); }); } } @@ -217,51 +177,17 @@ private byte[] createCompressedRequestBody(final RequestBodySize size) { gzipOutputStream.finish(); final byte[] compressedRequestBody = byteArrayOutputStream.toByteArray(); - System.err.println( - "^^^" - + String.format( - "Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f", - uncompressedBytesSize, - compressedRequestBody.length, - ((double) uncompressedBytesSize / compressedRequestBody.length) - ) + LOG.info( + String.format( + "Original size was %,d bytes, compressed to %,d bytes, ratio %,.2f", + uncompressedBytesSize, + compressedRequestBody.length, + ((double) uncompressedBytesSize / compressedRequestBody.length) + ) ); return compressedRequestBody; } catch (final IOException ioe) { throw new RuntimeException(ioe); } } - - private void printStats() { - System.err.println("** Stats "); - printMemory(); - printMemoryPools(); - printGCPools(); - } - - private void printMemory() { - final Runtime runtime = Runtime.getRuntime(); - - final long totalMemory = runtime.totalMemory(); // Total allocated memory - final long freeMemory = runtime.freeMemory(); // Amount of free memory - final long usedMemory = totalMemory - freeMemory; // Amount of used memory - - System.err.println(" Memory Total: " + totalMemory + " Free:" + freeMemory + " Used:" + usedMemory); - } - - private void printMemoryPools() { - List memoryPools = ManagementFactory.getMemoryPoolMXBeans(); - for (MemoryPoolMXBean memoryPool : memoryPools) { - MemoryUsage usage = memoryPool.getUsage(); - System.err.println(" " + memoryPool.getName() + " USED: " + usage.getUsed() + " MAX: " + usage.getMax()); - } - } - - private void printGCPools() { - List garbageCollectors = ManagementFactory.getGarbageCollectorMXBeans(); - for (GarbageCollectorMXBean garbageCollector : garbageCollectors) { - System.err.println(" " + garbageCollector.getName() + " COLLECTION TIME: " + garbageCollector.getCollectionTime()); - } - } - } diff --git a/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java b/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java index cf07f93ad8..aa747e2586 100644 --- a/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java +++ b/src/integrationTest/java/org/opensearch/security/rest/CompressionTests.java @@ -11,9 +11,9 @@ package org.opensearch.security.rest; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + import org.apache.hc.client5.http.classic.methods.HttpPost; import org.apache.hc.core5.http.ContentType; -import org.apache.hc.core5.http.Header; import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.io.entity.ByteArrayEntity; import org.apache.hc.core5.http.message.BasicHeader; @@ -21,11 +21,7 @@ import org.junit.Test; import org.junit.runner.RunWith; -import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.CoreMatchers.anyOf; -import static org.hamcrest.MatcherAssert.assertThat; +import org.opensearch.test.framework.AsyncActions; import org.opensearch.test.framework.TestSecurityConfig; import org.opensearch.test.framework.cluster.ClusterManager; import org.opensearch.test.framework.cluster.LocalCluster; @@ -34,15 +30,14 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import java.util.zip.GZIPOutputStream; -import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS; import static org.opensearch.test.framework.cluster.TestRestClientConfiguration.getBasicAuthHeader; @@ -60,7 +55,7 @@ public class CompressionTests { .build(); @Test - public void testAuthenticatedGzippedRequests() throws Exception { + public void testAuthenticatedGzippedRequests() { final String requestPath = "/*/_search"; final int parallelism = 10; final int totalNumberOfRequests = 100; @@ -69,31 +64,13 @@ public void testAuthenticatedGzippedRequests() throws Exception { final byte[] compressedRequestBody = createCompressedRequestBody(rawBody); try (final TestRestClient client = cluster.getRestClient(ADMIN_USER, new BasicHeader("Content-Encoding", "gzip"))) { + final var requests = AsyncActions.generate(() -> { + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + return client.executeRequest(post); + }, parallelism, totalNumberOfRequests); - final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism); - - final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) - .boxed() - .map(i -> CompletableFuture.supplyAsync(() -> { - final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); - post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); - return client.executeRequest(post); - }, forkJoinPool)) - .collect(Collectors.toList()); - - final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); - - allOfThem.get(30, TimeUnit.SECONDS); - - waitingOn.stream().forEach(future -> { - try { - final HttpResponse response = future.get(); - response.assertStatusCode(HttpStatus.SC_OK); - } catch (final Exception ex) { - throw new RuntimeException(ex); - } - }); - ; + AsyncActions.getAll(requests, 30, TimeUnit.SECONDS).forEach((response) -> { response.assertStatusCode(HttpStatus.SC_OK); }); } } @@ -101,40 +78,40 @@ public void testAuthenticatedGzippedRequests() throws Exception { public void testMixOfAuthenticatedAndUnauthenticatedGzippedRequests() throws Exception { final String requestPath = "/*/_search"; final int parallelism = 10; - final int totalNumberOfRequests = 100; + final int totalNumberOfRequests = 50; final String rawBody = "{ \"query\": { \"match\": { \"foo\": \"bar\" }}}"; final byte[] compressedRequestBody = createCompressedRequestBody(rawBody); try (final TestRestClient client = cluster.getRestClient(new BasicHeader("Content-Encoding", "gzip"))) { - - final ForkJoinPool forkJoinPool = new ForkJoinPool(parallelism); - - final Header basicAuthHeader = getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword()); - - final List> waitingOn = IntStream.rangeClosed(1, totalNumberOfRequests) - .boxed() - .map(i -> CompletableFuture.supplyAsync(() -> { - final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); - post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); - return i % 2 == 0 ? client.executeRequest(post) : client.executeRequest(post, basicAuthHeader); - }, forkJoinPool)) - .collect(Collectors.toList()); - - final CompletableFuture allOfThem = CompletableFuture.allOf(waitingOn.toArray(new CompletableFuture[0])); - - allOfThem.get(30, TimeUnit.SECONDS); - - waitingOn.stream().forEach(future -> { - try { - final HttpResponse response = future.get(); - assertThat(response.getBody(), not(containsString("json_parse_exception"))); - assertThat(response.getStatusCode(), anyOf(equalTo(HttpStatus.SC_UNAUTHORIZED), equalTo(HttpStatus.SC_OK))); - } catch (final Exception ex) { - throw new RuntimeException(ex); - } + final CountDownLatch countDownLatch = new CountDownLatch(1); + + final var authorizedRequests = AsyncActions.generate(() -> { + countDownLatch.await(); + System.err.println("Generation triggered authorizedRequests"); + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + return client.executeRequest(post, getBasicAuthHeader(ADMIN_USER.getName(), ADMIN_USER.getPassword())); + }, parallelism, totalNumberOfRequests); + + final var unauthorizedRequests = AsyncActions.generate(() -> { + countDownLatch.await(); + System.err.println("Generation triggered unauthorizedRequests"); + final HttpPost post = new HttpPost(client.getHttpServerUri() + requestPath); + post.setEntity(new ByteArrayEntity(compressedRequestBody, ContentType.APPLICATION_JSON)); + return client.executeRequest(post); + }, parallelism, totalNumberOfRequests); + + // Make sure all requests start at the same time + countDownLatch.countDown(); + + AsyncActions.getAll(authorizedRequests, 30, TimeUnit.SECONDS).forEach((response) -> { + assertThat(response.getStatusCode(), equalTo(HttpStatus.SC_OK)); + }); + AsyncActions.getAll(unauthorizedRequests, 30, TimeUnit.SECONDS).forEach((response) -> { + assertThat(response.getBody(), not(containsString("json_parse_exception"))); + assertThat(response.getStatusCode(), equalTo(HttpStatus.SC_UNAUTHORIZED)); }); - ; } } diff --git a/src/integrationTest/java/org/opensearch/test/framework/AsyncActions.java b/src/integrationTest/java/org/opensearch/test/framework/AsyncActions.java new file mode 100644 index 0000000000..409aa5a416 --- /dev/null +++ b/src/integrationTest/java/org/opensearch/test/framework/AsyncActions.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.test.framework; + +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +public class AsyncActions { + private final static Logger LOG = LogManager.getLogger(AsyncActions.class); + + /** + * Using the provided generator create a list of completable futures. + * @param parrallelism How many calls to the generator should be done at the same time. + * @param generationCount The total number of calls to the generator to conduct. + * @return The list of completable futures running on the fork join thread pool. + */ + public static List> generate(final Callable generator, final int parrallelism, final int generationCount) { + final ForkJoinPool forkJoinPool = new ForkJoinPool(parrallelism); + return IntStream.rangeClosed(1, generationCount).boxed().map(i -> CompletableFuture.supplyAsync(() -> { + try { + return generator.call(); + } catch (final Exception ex) { + throw new RuntimeException(ex); + } + }, forkJoinPool)).collect(Collectors.toList()); + } + + /** + * Waits for futures for a time period and then returns them a list + * @param futures Futures to wait for completion with a result + * @param n Amount of time to wait + * @param unit Time associated with those units + * @return Completed results from the futures + */ + public static List getAll(final List> futures, final int n, final TimeUnit unit) { + LOG.info("Starting to wait for " + futures.size() + " futures to complete in " + unit.toSeconds(n) + " seconds."); + final long startTimeMs = System.currentTimeMillis(); + final CompletableFuture futuresCompleted = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); + try { + futuresCompleted.get(n, unit); + } catch (final Exception ex) { + final long completedFuturesCount = futures.stream().filter(CompletableFuture::isDone).count(); + final String perfReport = calculatePerfReport(startTimeMs, completedFuturesCount); + throw new RuntimeException( + "Unable to wait for all futures to complete, of " + + futures.size() + + " futures " + + completedFuturesCount + + " have finished." + + perfReport + ); + } + final long completedFuturesCount = futures.stream().filter(CompletableFuture::isDone).count(); + final String perfReport = calculatePerfReport(startTimeMs, completedFuturesCount); + LOG.info(perfReport); + + final long elapsedTimeMs = System.currentTimeMillis() - startTimeMs; + final long expectedMs = unit.toMillis(n); + if (elapsedTimeMs > .75 * expectedMs) { + LOG.warn("Completion time was within 25% of the expected time, more than this threshold is recommended."); + } + + return futures.stream().map(future -> { + try { + return future.get(); + } catch (final Exception ex) { + throw new RuntimeException(ex); + } + }).collect(Collectors.toList()); + } + + private static String calculatePerfReport(final long startTimeMs, final long completedFuturesCount) { + final long elapsedTimeMs = System.currentTimeMillis() - startTimeMs; + final double avgTimePerFutureMs = (double) elapsedTimeMs / completedFuturesCount; + final double futuresPerSecond = 1000 / avgTimePerFutureMs; + return String.format( + "Waited for %d seconds, completion speed was on average %.2fms per future %.2fx per second.", + TimeUnit.MILLISECONDS.toSeconds(elapsedTimeMs), + avgTimePerFutureMs, + futuresPerSecond + ); + } +} diff --git a/src/integrationTest/resources/log4j2-test.properties b/src/integrationTest/resources/log4j2-test.properties index 8d9cf87666..0b865b46b3 100644 --- a/src/integrationTest/resources/log4j2-test.properties +++ b/src/integrationTest/resources/log4j2-test.properties @@ -28,6 +28,7 @@ logger.auditlogs.level = info # Logger required by test org.opensearch.security.http.JwtAuthenticationTests logger.httpjwtauthenticator.name = com.amazon.dlic.auth.http.jwt.HTTPJwtAuthenticator logger.httpjwtauthenticator.level = debug +logger.backendreg.additivity = false logger.httpjwtauthenticator.appenderRef.capturing.ref = logCapturingAppender #Required by tests: @@ -35,10 +36,12 @@ logger.httpjwtauthenticator.appenderRef.capturing.ref = logCapturingAppender # org.opensearch.security.UserBruteForceAttacksPreventionTests logger.backendreg.name = org.opensearch.security.auth.BackendRegistry logger.backendreg.level = debug +logger.backendreg.additivity = false logger.backendreg.appenderRef.capturing.ref = logCapturingAppender #com.amazon.dlic.auth.ldap #logger.ldap.name=com.amazon.dlic.auth.ldap.backend.LDAPAuthenticationBackend logger.ldap.name=com.amazon.dlic.auth.ldap.backend logger.ldap.level=TRACE +logger.backendreg.additivity = false logger.ldap.appenderRef.capturing.ref = logCapturingAppender