diff --git a/lib/trino-cache/pom.xml b/lib/trino-cache/pom.xml
index eb8af95ab59c..9c7cfc40aa7c 100644
--- a/lib/trino-cache/pom.xml
+++ b/lib/trino-cache/pom.xml
@@ -72,5 +72,17 @@
            <artifactId>junit-jupiter-api</artifactId>
            <scope>test</scope>
        </dependency>
+
+        <dependency>
+            <groupId>org.openjdk.jmh</groupId>
+            <artifactId>jmh-core</artifactId>
+            <scope>test</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.openjdk.jmh</groupId>
+            <artifactId>jmh-generator-annprocess</artifactId>
+            <scope>test</scope>
+        </dependency>
    </dependencies>
</project>
diff --git a/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java b/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java
index c62648e89d3a..be946fbd7e5c 100644
--- a/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java
+++ b/lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java
@@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListenableFuture;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
import jakarta.annotation.Nullable;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;
@@ -40,7 +41,6 @@
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
-import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
@@ -68,9 +68,6 @@ class EvictableCache
// The dataCache must be bounded.
private final LoadingCache<Token<K>, V> dataCache;
- // Logically a concurrent Multiset
- private final ConcurrentHashMap<Token<K>, Long> ongoingLoads = new ConcurrentHashMap<>();
-
EvictableCache(CacheBuilder<? super Token<K>, ? super V> cacheBuilder, CacheLoader<? super K, V> cacheLoader)
{
dataCache = buildUnsafeCache(
@@ -78,13 +75,16 @@ class EvictableCache
.<Token<K>, V>removalListener(removal -> {
Token<K> token = removal.getKey();
verify(token != null, "token is null");
- if (removal.getCause() == RemovalCause.REPLACED) {
- return;
- }
- if (removal.getCause() == RemovalCause.EXPIRED && ongoingLoads.containsKey(token)) {
- return;
+ // synchronize ongoing load check and token removal
+ synchronized (token) {
+ if (removal.getCause() == RemovalCause.REPLACED) {
+ return;
+ }
+ if (removal.getCause() == RemovalCause.EXPIRED && token.hasOngoingLoad()) {
+ return;
+ }
+ tokens.remove(token.getKey(), token);
}
- tokens.remove(token.getKey(), token);
}),
new TokenCacheLoader<>(cacheLoader));
}
@@ -113,12 +113,12 @@ public V get(K key, Callable<? extends V> valueLoader)
Token<K> newToken = new Token<>(key);
Token<K> token = tokens.computeIfAbsent(key, _ -> newToken);
try {
- startLoading(token);
+ token.startLoading();
try {
return dataCache.get(token, valueLoader);
}
finally {
- endLoading(token);
+ token.endLoading();
}
}
catch (Throwable e) {
@@ -142,12 +142,12 @@ public V get(K key)
Token<K> newToken = new Token<>(key);
Token<K> token = tokens.computeIfAbsent(key, _ -> newToken);
try {
- startLoading(token);
+ token.startLoading();
try {
return dataCache.get(token);
}
finally {
- endLoading(token);
+ token.endLoading();
}
}
catch (Throwable e) {
@@ -215,27 +215,14 @@ public ImmutableMap<K, V> getAll(Iterable<? extends K> keys)
}
}
- private void startLoading(Token<K> token)
- {
- ongoingLoads.compute(token, (_, count) -> firstNonNull(count, 0L) + 1);
- }
-
- private void endLoading(Token<K> token)
- {
- ongoingLoads.compute(token, (_, count) -> {
- verify(count != null && count > 0, "Incorrect count for token %s: %s", token, count);
- if (count == 1) {
- return null;
- }
- return count - 1;
- });
- }
-
// Token eviction via removalListener is blocked during loading, so we may need to do manual cleanup
private void removeDangling(Token<K> token)
{
- if (!dataCache.asMap().containsKey(token)) {
- tokens.remove(token.getKey(), token);
+ // synchronize to make accessing both collections thread-safe
+ synchronized (token) {
+ if (!dataCache.asMap().containsKey(token) && !token.hasOngoingLoad()) {
+ tokens.remove(token.getKey(), token);
+ }
}
}
@@ -431,6 +418,8 @@ public Set<Entry<K, V>> entrySet()
static final class Token<K>
{
private final K key;
+ @GuardedBy("this")
+ private int ongoingLoads;
Token(K key)
{
@@ -447,6 +436,22 @@ public String toString()
{
return format("CacheToken(%s; %s)", Integer.toHexString(hashCode()), key);
}
+
+ synchronized boolean hasOngoingLoad()
+ {
+ return ongoingLoads > 0;
+ }
+
+ synchronized void startLoading()
+ {
+ ongoingLoads++;
+ }
+
+ synchronized void endLoading()
+ {
+ ongoingLoads--;
+ verify(ongoingLoads >= 0, "ongoingLoads must be greater than or equal to 0");
+ }
}
private static class TokenCacheLoader<K, V>
diff --git a/lib/trino-cache/src/test/java/io/trino/cache/BenchmarkEvictableCache.java b/lib/trino-cache/src/test/java/io/trino/cache/BenchmarkEvictableCache.java
new file mode 100644
index 000000000000..765480d8729f
--- /dev/null
+++ b/lib/trino-cache/src/test/java/io/trino/cache/BenchmarkEvictableCache.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.cache;
+
+import com.google.common.cache.Cache;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.RunnerException;
+
+import java.time.Duration;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.Function;
+
+import static io.trino.jmh.Benchmarks.benchmark;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.NANOSECONDS;
+
+@OutputTimeUnit(NANOSECONDS)
+@Fork(1)
+@Warmup(iterations = 5, time = 700, timeUnit = MILLISECONDS)
+@Measurement(iterations = 10, time = 700, timeUnit = MILLISECONDS)
+@BenchmarkMode(Mode.AverageTime)
+public class BenchmarkEvictableCache
+{
+ @Benchmark
+ public void cacheBenchmark(CacheBenchmarkState state)
+ throws InterruptedException, ExecutionException
+ {
+ Cache<String, String> cache = EvictableCacheBuilder.newBuilder()
+ .expireAfterWrite(Duration.ofSeconds(60))
+ .maximumSize(1000)
+ .build();
+
+ try (ExecutorService executor = Executors.newFixedThreadPool(state.competingThreadsCount)) {
+ for (int iteration = 0; iteration < state.iterations; iteration++) {
+ CountDownLatch startLatch = new CountDownLatch(state.competingThreadsCount);
+ CountDownLatch endLatch = new CountDownLatch(state.competingThreadsCount);
+ for (int threadIndex = 0; threadIndex < state.competingThreadsCount; threadIndex++) {
+ executor.submit(() -> {
+ try {
+ startLatch.countDown();
+
+ startLatch.await();
+ cache.get("key", state.expensiveCacheLoaderSupplier.apply(state.cacheValueLoadTimeMillis));
+ endLatch.countDown();
+ }
+ catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new RuntimeException(e);
+ }
+ catch (ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ endLatch.await();
+ // final cache retrieval after race (to see if value is still in cache)
+ cache.get("key", state.expensiveCacheLoaderSupplier.apply(state.cacheValueLoadTimeMillis));
+ }
+ }
+ }
+
+ @State(Scope.Benchmark)
+ public static class CacheBenchmarkState
+ {
+ public final Function<Integer, Callable<String>> expensiveCacheLoaderSupplier = cacheValueLoadTime -> () -> {
+ if (cacheValueLoadTime > 0) {
+ // simulate that value loading takes some time, so each cache miss is expensive
+ Thread.sleep(cacheValueLoadTime);
+ }
+ return "key_nano:" + System.nanoTime();
+ };
+
+ @Param({"2", "5", "15", "50"})
+ public int competingThreadsCount;
+
+ @Param({"1", "2", "5"})
+ public int iterations;
+
+ @Param({"0", "500", "2000"})
+ public int cacheValueLoadTimeMillis;
+ }
+
+ public static void main(String[] args)
+ throws RunnerException
+ {
+ benchmark(BenchmarkEvictableCache.class).run();
+ }
+}
diff --git a/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java b/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java
index 32d8a384e8be..ee431db2dab5 100644
--- a/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java
+++ b/lib/trino-cache/src/test/java/io/trino/cache/TestEvictableCache.java
@@ -21,9 +21,11 @@
import io.airlift.testing.TestingTicker;
import io.trino.cache.EvictableCacheBuilder.DisabledCacheImplementation;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;
+import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -34,6 +36,7 @@
import java.util.concurrent.Exchanger;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
@@ -585,4 +588,27 @@ public void testPutOnNonEmptyCacheImplementation()
.isInstanceOf(UnsupportedOperationException.class)
.hasMessage("The operation is not supported, as in inherently races with cache invalidation");
}
+
+ @RepeatedTest(1_000)
+ public void testParallelLoadingCacheEntries()
+ {
+ Cache<String, String> cache = EvictableCacheBuilder.newBuilder()
+ .expireAfterWrite(Duration.ofSeconds(60))
+ .maximumSize(10)
+ .build();
+ try (ExecutorService executor = Executors.newFixedThreadPool(2)) {
+ Runnable cacheLoader = () -> {
+ try {
+ String value = cache.get("key", () -> "value");
+ assertThat(value).isEqualTo("value");
+ }
+ catch (ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ };
+ executor.submit(cacheLoader);
+ executor.submit(cacheLoader);
+ }
+ assertThat(cache.getIfPresent("key")).isNotNull();
+ }
}