Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in EvictableCache#removeDangling #23401

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions lib/trino-cache/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,17 @@
<artifactId>junit-jupiter-api</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
69 changes: 37 additions & 32 deletions lib/trino-cache/src/main/java/io/trino/cache/EvictableCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import jakarta.annotation.Nullable;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;

Expand All @@ -40,7 +41,6 @@
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;

import static com.google.common.base.MoreObjects.firstNonNull;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
Expand Down Expand Up @@ -68,23 +68,23 @@ class EvictableCache<K, V>
// The dataCache must be bounded.
private final LoadingCache<Token<K>, V> dataCache;

// Logically a concurrent Multiset
private final ConcurrentHashMap<Token<K>, Long> ongoingLoads = new ConcurrentHashMap<>();

EvictableCache(CacheBuilder<? super Token<K>, ? super V> cacheBuilder, CacheLoader<? super K, V> cacheLoader)
{
dataCache = buildUnsafeCache(
cacheBuilder
.<Token<K>, V>removalListener(removal -> {
Token<K> token = removal.getKey();
verify(token != null, "token is null");
if (removal.getCause() == RemovalCause.REPLACED) {
return;
}
if (removal.getCause() == RemovalCause.EXPIRED && ongoingLoads.containsKey(token)) {
return;
// synchronize ongoing load check and token removal
synchronized (token) {
if (removal.getCause() == RemovalCause.REPLACED) {
return;
}
if (removal.getCause() == RemovalCause.EXPIRED && token.hasOngoingLoad()) {
return;
}
tokens.remove(token.getKey(), token);
}
tokens.remove(token.getKey(), token);
}),
new TokenCacheLoader<>(cacheLoader));
}
Expand Down Expand Up @@ -113,12 +113,12 @@ public V get(K key, Callable<? extends V> valueLoader)
Token<K> newToken = new Token<>(key);
Token<K> token = tokens.computeIfAbsent(key, _ -> newToken);
try {
startLoading(token);
token.startLoading();
try {
return dataCache.get(token, valueLoader);
}
finally {
endLoading(token);
token.endLoading();
}
}
catch (Throwable e) {
Expand All @@ -142,12 +142,12 @@ public V get(K key)
Token<K> newToken = new Token<>(key);
Token<K> token = tokens.computeIfAbsent(key, _ -> newToken);
try {
startLoading(token);
token.startLoading();
try {
return dataCache.get(token);
}
finally {
endLoading(token);
token.endLoading();
}
}
catch (Throwable e) {
Expand Down Expand Up @@ -215,27 +215,14 @@ public ImmutableMap<K, V> getAll(Iterable<? extends K> keys)
}
}

private void startLoading(Token<K> token)
{
ongoingLoads.compute(token, (_, count) -> firstNonNull(count, 0L) + 1);
}

private void endLoading(Token<K> token)
{
ongoingLoads.compute(token, (_, count) -> {
verify(count != null && count > 0, "Incorrect count for token %s: %s", token, count);
if (count == 1) {
return null;
}
return count - 1;
});
}

// Token eviction via removalListener is blocked during loading, so we may need to do manual cleanup
private void removeDangling(Token<K> token)
{
if (!dataCache.asMap().containsKey(token)) {
tokens.remove(token.getKey(), token);
// synchronize to make accessing both collections thread-safe
synchronized (token) {
if (!dataCache.asMap().containsKey(token) && !token.hasOngoingLoad()) {
tokens.remove(token.getKey(), token);
}
}
}

Expand Down Expand Up @@ -431,6 +418,8 @@ public Set<Map.Entry<K, V>> entrySet()
static final class Token<K>
{
private final K key;
@GuardedBy("this")
private int ongoingLoads;

Token(K key)
{
Expand All @@ -447,6 +436,22 @@ public String toString()
{
return format("CacheToken(%s; %s)", Integer.toHexString(hashCode()), key);
}

synchronized boolean hasOngoingLoad()
{
return ongoingLoads > 0;
}

synchronized void startLoading()
{
ongoingLoads++;
}

synchronized void endLoading()
{
ongoingLoads--;
verify(ongoingLoads >= 0, "ongoingLoads must be greater than or equal 0");
}
}

private static class TokenCacheLoader<K, V>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cache;

import com.google.common.cache.Cache;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.RunnerException;

import java.time.Duration;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

import static io.trino.jmh.Benchmarks.benchmark;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.NANOSECONDS;

@OutputTimeUnit(NANOSECONDS)
@Fork(1)
@Warmup(iterations = 5, time = 700, timeUnit = MILLISECONDS)
@Measurement(iterations = 10, time = 700, timeUnit = MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkEvictableCache
{
@Benchmark
public void cacheBenchmark(CacheBenchmarkState state)
throws InterruptedException, ExecutionException
{
Cache<String, String> cache = EvictableCacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofSeconds(60))
.maximumSize(1000)
.build();

try (ExecutorService executor = Executors.newFixedThreadPool(state.competingThreadsCount)) {
for (int iteration = 0; iteration < state.iterations; iteration++) {
CountDownLatch startLatch = new CountDownLatch(state.competingThreadsCount);
CountDownLatch endLatch = new CountDownLatch(state.competingThreadsCount);
for (int threadIndex = 0; threadIndex < state.competingThreadsCount; threadIndex++) {
executor.submit(() -> {
try {
startLatch.countDown();

startLatch.await();
cache.get("key", state.expensiveCacheLoaderSupplier.apply(state.cacheValueLoadTimeMillis));
endLatch.countDown();
}
catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
});
}
endLatch.await();
// final cache retrieval after race (to see if value is still in cache)
cache.get("key", state.expensiveCacheLoaderSupplier.apply(state.cacheValueLoadTimeMillis));
}
}
}

@State(Scope.Benchmark)
public static class CacheBenchmarkState
{
public final Function<Integer, Callable<String>> expensiveCacheLoaderSupplier = cacheValueLoadTime -> () -> {
if (cacheValueLoadTime > 0) {
// simulate value loading takes some time, so each cache miss is expensive
Thread.sleep(cacheValueLoadTime);
}
return "key_nano:" + System.nanoTime();
};

@Param({"2", "5", "15", "50"})
public int competingThreadsCount;

@Param({"1", "2", "5"})
public int iterations;

@Param({"0", "500", "2000"})
public int cacheValueLoadTimeMillis;
}

public static void main(String[] args)
throws RunnerException
{
benchmark(BenchmarkEvictableCache.class).run();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import io.airlift.testing.TestingTicker;
import io.trino.cache.EvictableCacheBuilder.DisabledCacheImplementation;
import org.gaul.modernizer_maven_annotations.SuppressModernizer;
import org.junit.jupiter.api.RepeatedTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -34,6 +36,7 @@
import java.util.concurrent.Exchanger;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -585,4 +588,27 @@ public void testPutOnNonEmptyCacheImplementation()
.isInstanceOf(UnsupportedOperationException.class)
.hasMessage("The operation is not supported, as in inherently races with cache invalidation");
}

@RepeatedTest(1_000)
public void testParallelLoadingCacheEntries()
{
Cache<String, String> cache = EvictableCacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofSeconds(60))
.maximumSize(10)
.build();
try (ExecutorService executor = Executors.newFixedThreadPool(2)) {
Runnable cacheLoader = () -> {
try {
String value = cache.get("key", () -> "value");
assertThat(value).isEqualTo("value");
}
catch (ExecutionException e) {
throw new RuntimeException(e);
}
};
executor.submit(cacheLoader);
executor.submit(cacheLoader);
}
assertThat(cache.getIfPresent("key")).isNotNull();
}
}
Loading