diff --git a/lib/trino-cache/pom.xml b/lib/trino-cache/pom.xml index eb8af95ab59c..9c7cfc40aa7c 100644 --- a/lib/trino-cache/pom.xml +++ b/lib/trino-cache/pom.xml @@ -72,5 +72,17 @@ junit-jupiter-api test + + + org.openjdk.jmh + jmh-core + test + + + + org.openjdk.jmh + jmh-generator-annprocess + test + diff --git a/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java b/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java index c62648e89d3a..be946fbd7e5c 100644 --- a/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java +++ b/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; import jakarta.annotation.Nullable; import org.gaul.modernizer_maven_annotations.SuppressModernizer; @@ -40,7 +41,6 @@ import java.util.concurrent.ExecutionException; import java.util.function.BiFunction; -import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Verify.verify; import static java.lang.String.format; @@ -68,9 +68,6 @@ class EvictableCache // The dataCache must be bounded. private final LoadingCache, V> dataCache; - // Logically a concurrent Multiset - private final ConcurrentHashMap, Long> ongoingLoads = new ConcurrentHashMap<>(); - EvictableCache(CacheBuilder, ? super V> cacheBuilder, CacheLoader cacheLoader) { dataCache = buildUnsafeCache( @@ -78,13 +75,16 @@ class EvictableCache ., V>removalListener(removal -> { Token token = removal.getKey(); verify(token != null, "token is null"); - if (removal.getCause() == RemovalCause.REPLACED) { - return; - } - if (removal.getCause() == RemovalCause.EXPIRED && ongoingLoads.containsKey(token)) { - return; + // synchronize ongoing load check and token removal + synchronized (token) { + if (removal.getCause() == RemovalCause.REPLACED) { + return; + } + if (removal.getCause() == RemovalCause.EXPIRED && token.hasOngoingLoad()) { + return; + } + tokens.remove(token.getKey(), token); } - tokens.remove(token.getKey(), token); }), new TokenCacheLoader<>(cacheLoader)); } @@ -113,12 +113,12 @@ public V get(K key, Callable valueLoader) Token newToken = new Token<>(key); Token token = tokens.computeIfAbsent(key, _ -> newToken); try { - startLoading(token); + token.startLoading(); try { return dataCache.get(token, valueLoader); } finally { - endLoading(token); + token.endLoading(); } } catch (Throwable e) { @@ -142,12 +142,12 @@ public V get(K key) Token newToken = new Token<>(key); Token token = tokens.computeIfAbsent(key, _ -> newToken); try { - startLoading(token); + token.startLoading(); try { return dataCache.get(token); } finally { - endLoading(token); + token.endLoading(); } } catch (Throwable e) { @@ -215,27 +215,14 @@ public ImmutableMap getAll(Iterable keys) } } - private void startLoading(Token token) - { - ongoingLoads.compute(token, (_, count) -> firstNonNull(count, 0L) + 1); - } - - private void endLoading(Token token) - { - ongoingLoads.compute(token, (_, count) -> { - verify(count != null && count > 0, "Incorrect count for token %s: %s", token, count); - if (count == 1) { - return null; - } - return count - 1; - }); - } - // Token eviction via removalListener is blocked during loading, so we may need to do manual cleanup private void removeDangling(Token token) { - if (!dataCache.asMap().containsKey(token)) { - tokens.remove(token.getKey(), token); + // synchronize to make accessing both collections thread-safe + synchronized (token) { + if (!dataCache.asMap().containsKey(token) && !token.hasOngoingLoad()) { + tokens.remove(token.getKey(), token); + } } } @@ -431,6 +418,8 @@ public Set> entrySet() static final class Token { private final K key; + @GuardedBy("this") + private int ongoingLoads; Token(K key) { @@ -447,6 +436,22 @@ public String toString() { return format("CacheToken(%s; %s)", Integer.toHexString(hashCode()), key); } + + synchronized boolean hasOngoingLoad() + { + return ongoingLoads > 0; + } + + synchronized void startLoading() + { + ongoingLoads++; + } + + synchronized void endLoading() + { + ongoingLoads--; + verify(ongoingLoads >= 0, "ongoingLoads must be greater than or equal 0"); + } } private static class TokenCacheLoader diff --git a/lib/trino-cache/src/test/java/io/trino/cache/BenchmarkEvictableCache.java b/lib/trino-cache/src/test/java/io/trino/cache/BenchmarkEvictableCache.java new file mode 100644 index 000000000000..765480d8729f --- /dev/null +++ b/lib/trino-cache/src/test/java/io/trino/cache/BenchmarkEvictableCache.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cache; + +import com.google.common.cache.Cache; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.RunnerException; + +import java.time.Duration; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.Function; + +import static io.trino.jmh.Benchmarks.benchmark; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +@OutputTimeUnit(NANOSECONDS) +@Fork(1) +@Warmup(iterations = 5, time = 700, timeUnit = MILLISECONDS) +@Measurement(iterations = 10, time = 700, timeUnit = MILLISECONDS) +@BenchmarkMode(Mode.AverageTime) +public class BenchmarkEvictableCache +{ + @Benchmark + public void cacheBenchmark(CacheBenchmarkState state) + throws InterruptedException, ExecutionException + { + Cache cache = EvictableCacheBuilder.newBuilder() + .expireAfterWrite(Duration.ofSeconds(60)) + .maximumSize(1000) + .build(); + + try (ExecutorService executor = Executors.newFixedThreadPool(state.competingThreadsCount)) { + for (int iteration = 0; iteration < state.iterations; iteration++) { + CountDownLatch startLatch = new CountDownLatch(state.competingThreadsCount); + CountDownLatch endLatch = new CountDownLatch(state.competingThreadsCount); + for (int threadIndex = 0; threadIndex < state.competingThreadsCount; threadIndex++) { + executor.submit(() -> { + try { + startLatch.countDown(); + + startLatch.await(); + cache.get("key", state.expensiveCacheLoaderSupplier.apply(state.cacheValueLoadTimeMillis)); + endLatch.countDown(); + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + catch (ExecutionException e) { + throw new RuntimeException(e); + } + }); + } + endLatch.await(); + // final cache retrieval after race (to see if value is still in cache) + cache.get("key", state.expensiveCacheLoaderSupplier.apply(state.cacheValueLoadTimeMillis)); + } + } + } + + @State(Scope.Benchmark) + public static class CacheBenchmarkState + { + public final Function> expensiveCacheLoaderSupplier = cacheValueLoadTime -> () -> { + if (cacheValueLoadTime > 0) { + // simulate value loading takes some time, so each cache miss is expensive + Thread.sleep(cacheValueLoadTime); + } + return "key_nano:" + System.nanoTime(); + }; + + @Param({"2", "5", "15", "50"}) + public int competingThreadsCount; + + @Param({"1", "2", "5"}) + public int iterations; + + @Param({"0", "500", "2000"}) + public int cacheValueLoadTimeMillis; + } + + public static void main(String[] args) + throws RunnerException + { + benchmark(BenchmarkEvictableCache.class).run(); + } +} diff --git a/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java b/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java index 32d8a384e8be..ee431db2dab5 100644 --- a/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java +++ b/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java @@ -21,9 +21,11 @@ import io.airlift.testing.TestingTicker; import io.trino.cache.EvictableCacheBuilder.DisabledCacheImplementation; import org.gaul.modernizer_maven_annotations.SuppressModernizer; +import org.junit.jupiter.api.RepeatedTest; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -34,6 +36,7 @@ import java.util.concurrent.Exchanger; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -585,4 +588,27 @@ public void testPutOnNonEmptyCacheImplementation() .isInstanceOf(UnsupportedOperationException.class) .hasMessage("The operation is not supported, as in inherently races with cache invalidation"); } + + @RepeatedTest(1_000) + public void testParallelLoadingCacheEntries() + { + Cache cache = EvictableCacheBuilder.newBuilder() + .expireAfterWrite(Duration.ofSeconds(60)) + .maximumSize(10) + .build(); + try (ExecutorService executor = Executors.newFixedThreadPool(2)) { + Runnable cacheLoader = () -> { + try { + String value = cache.get("key", () -> "value"); + assertThat(value).isEqualTo("value"); + } + catch (ExecutionException e) { + throw new RuntimeException(e); + } + }; + executor.submit(cacheLoader); + executor.submit(cacheLoader); + } + assertThat(cache.getIfPresent("key")).isNotNull(); + } }