diff --git a/src/server/src/main/java/io/cassandrareaper/storage/MemoryStorageFacade.java b/src/server/src/main/java/io/cassandrareaper/storage/MemoryStorageFacade.java index 8891b97b7..84facac48 100644 --- a/src/server/src/main/java/io/cassandrareaper/storage/MemoryStorageFacade.java +++ b/src/server/src/main/java/io/cassandrareaper/storage/MemoryStorageFacade.java @@ -64,13 +64,13 @@ */ public final class MemoryStorageFacade implements IStorageDao { - private static final long DEFAULT_LEAD_TIME = 90; + // Default time to live of leads taken on a segment + private static final long DEFAULT_LEAD_TTL = 90_000; private static final Logger LOG = LoggerFactory.getLogger(MemoryStorageFacade.class); /** Field evaluator to find transient attributes. This is needed to deal with persisting Guava collections objects * that sometimes use the transient keyword for some of their implementation's backing stores**/ private static final PersistenceFieldEvaluator TRANSIENT_FIELD_EVALUATOR = (clazz, field) -> !field.getName().startsWith("_"); - private static final UUID REAPER_INSTANCE_ID = UUID.randomUUID(); private final EmbeddedStorageManager embeddedStorage; private final MemoryStorageRoot memoryStorageRoot; @@ -89,7 +89,7 @@ public final class MemoryStorageFacade implements IStorageDao { ); private final MemorySnapshotDao memSnapshotDao = new MemorySnapshotDao(); private final MemoryMetricsDao memMetricsDao = new MemoryMetricsDao(); - private final ReplicaLockManagerWithTtl repairRunLockManager; + private final ReplicaLockManagerWithTtl replicaLockManagerWithTtl; public MemoryStorageFacade(String persistenceStoragePath, long leadTime) { LOG.info("Using memory storage backend. Persistence storage path: {}", persistenceStoragePath); @@ -108,15 +108,15 @@ public MemoryStorageFacade(String persistenceStoragePath, long leadTime) { LOG.info("Loading existing data from persistence storage"); this.memoryStorageRoot = (MemoryStorageRoot) this.embeddedStorage.root(); } - this.repairRunLockManager = new ReplicaLockManagerWithTtl(leadTime); + this.replicaLockManagerWithTtl = new ReplicaLockManagerWithTtl(leadTime); } public MemoryStorageFacade() { - this(Files.createTempDir().getAbsolutePath(), DEFAULT_LEAD_TIME); + this(Files.createTempDir().getAbsolutePath(), DEFAULT_LEAD_TTL); } public MemoryStorageFacade(String persistenceStoragePath) { - this(persistenceStoragePath, DEFAULT_LEAD_TIME); + this(persistenceStoragePath, DEFAULT_LEAD_TTL); } public MemoryStorageFacade(long leadTime) { @@ -313,22 +313,22 @@ public Map getSubscriptionsById() { @Override public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set replicas) { - return repairRunLockManager.lockRunningRepairsForNodes(runId, segmentId, replicas); + return replicaLockManagerWithTtl.lockRunningRepairsForNodes(runId, segmentId, replicas); } @Override public boolean renewRunningRepairsForNodes(UUID runId, UUID segmentId, Set replicas) { - return repairRunLockManager.renewRunningRepairsForNodes(runId, segmentId, replicas); + return replicaLockManagerWithTtl.renewRunningRepairsForNodes(runId, segmentId, replicas); } @Override public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set replicas) { LOG.info("Releasing locks for runId: {}, segmentId: {}, replicas: {}", runId, segmentId, replicas); - return repairRunLockManager.releaseRunningRepairsForNodes(runId, segmentId, replicas); + return replicaLockManagerWithTtl.releaseRunningRepairsForNodes(runId, segmentId, replicas); } @Override public Set getLockedSegmentsForRun(UUID runId) { - return repairRunLockManager.getLockedSegmentsForRun(runId); + return replicaLockManagerWithTtl.getLockedSegmentsForRun(runId); } } diff --git a/src/server/src/main/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtl.java b/src/server/src/main/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtl.java index 1485e1c11..e684a3ad0 100644 --- a/src/server/src/main/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtl.java +++ b/src/server/src/main/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtl.java @@ -31,42 +31,49 @@ public class ReplicaLockManagerWithTtl { private final ConcurrentHashMap replicaLocks = new ConcurrentHashMap<>(); - private final ConcurrentHashMap> runToSegmentLocks = new ConcurrentHashMap<>(); - private final Lock lock = new ReentrantLock(); - private final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1); + private final ConcurrentHashMap> repairRunToSegmentLocks = new ConcurrentHashMap<>(); + private final ConcurrentHashMap runIdLocks = new ConcurrentHashMap<>(); - private final long ttlSeconds; // 1 minute + private final long ttlMilliSeconds; - public ReplicaLockManagerWithTtl(long ttlSeconds) { - this.ttlSeconds = ttlSeconds; + public ReplicaLockManagerWithTtl(long ttlMilliSeconds) { + this.ttlMilliSeconds = ttlMilliSeconds; // Schedule cleanup of expired locks - scheduler.scheduleAtFixedRate(this::cleanupExpiredLocks, 1, 1, TimeUnit.SECONDS); + ScheduledExecutorService lockCleanupScheduler = Executors.newScheduledThreadPool(1); + lockCleanupScheduler.scheduleAtFixedRate(this::cleanupExpiredLocks, 1, 1, TimeUnit.SECONDS); } private String getReplicaLockKey(String replica, UUID runId) { return replica + runId; } + private Lock getLockForRunId(UUID runId) { + return runIdLocks.computeIfAbsent(runId, k -> new ReentrantLock()); + } + public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set replicas) { + Lock lock = getLockForRunId(runId); lock.lock(); try { long currentTime = System.currentTimeMillis(); // Check if any replica is already locked by another runId - for (String replica : replicas) { - LockInfo lockInfo = replicaLocks.get(getReplicaLockKey(replica, runId)); - if (lockInfo != null && lockInfo.expirationTime > currentTime && lockInfo.runId.equals(runId)) { - return false; // Replica is locked by another runId and not expired - } + boolean anyReplicaLocked = replicas.stream() + .map(replica -> replicaLocks.get(getReplicaLockKey(replica, runId))) + .anyMatch(lockInfo -> lockInfo != null + && lockInfo.expirationTime > currentTime && lockInfo.runId.equals(runId)); + + if (anyReplicaLocked) { + return false; // Replica is locked by another runId and not expired } // Lock the replicas for the given runId and segmentId - long expirationTime = currentTime + (ttlSeconds * 1000); - for (String replica : replicas) { - replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, expirationTime)); - } + long expirationTime = currentTime + ttlMilliSeconds; + replicas.forEach(replica -> + replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, expirationTime)) + ); // Update runId to segmentId mapping - runToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId); + repairRunToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId); return true; } finally { lock.unlock(); @@ -74,26 +81,29 @@ public boolean lockRunningRepairsForNodes(UUID runId, UUID segmentId, Set replicas) { + Lock lock = getLockForRunId(runId); lock.lock(); try { long currentTime = System.currentTimeMillis(); // Check if all replicas are already locked by this runId - for (String replica : replicas) { - LockInfo lockInfo = replicaLocks.get(getReplicaLockKey(replica, runId)); - if (lockInfo == null || !lockInfo.runId.equals(runId) || lockInfo.expirationTime <= currentTime) { - return false; // Some replica is not validly locked by this runId - } + boolean allReplicasLocked = replicas.stream() + .map(replica -> replicaLocks.get(getReplicaLockKey(replica, runId))) + .allMatch(lockInfo -> lockInfo != null && lockInfo.runId.equals(runId) + && lockInfo.expirationTime > currentTime); + + if (!allReplicasLocked) { + return false; // Some replica is not validly locked by this runId } // Renew the lock by extending the expiration time - long newExpirationTime = currentTime + (ttlSeconds * 1000); - for (String replica : replicas) { - replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, newExpirationTime)); - } + long newExpirationTime = currentTime + ttlMilliSeconds; + replicas.forEach(replica -> + replicaLocks.put(getReplicaLockKey(replica, runId), new LockInfo(runId, newExpirationTime)) + ); // Ensure the segmentId is linked to the runId - runToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId); + repairRunToSegmentLocks.computeIfAbsent(runId, k -> ConcurrentHashMap.newKeySet()).add(segmentId); return true; } finally { lock.unlock(); @@ -101,22 +111,22 @@ public boolean renewRunningRepairsForNodes(UUID runId, UUID segmentId, Set replicas) { + Lock lock = getLockForRunId(runId); lock.lock(); try { // Remove the lock for replicas - for (String replica : replicas) { - LockInfo lockInfo = replicaLocks.get(getReplicaLockKey(replica, runId)); - if (lockInfo != null && lockInfo.runId.equals(runId)) { - replicaLocks.remove(getReplicaLockKey(replica, runId)); - } - } + replicas.stream() + .map(replica -> getReplicaLockKey(replica, runId)) + .map(replicaLocks::get) + .filter(lockInfo -> lockInfo != null && lockInfo.runId.equals(runId)) + .forEach(lockInfo -> replicaLocks.remove(getReplicaLockKey(lockInfo.runId.toString(), runId))); // Remove the segmentId from the runId mapping - Set segments = runToSegmentLocks.get(runId); + Set segments = repairRunToSegmentLocks.get(runId); if (segments != null) { segments.remove(segmentId); if (segments.isEmpty()) { - runToSegmentLocks.remove(runId); + repairRunToSegmentLocks.remove(runId); } } return true; @@ -126,12 +136,12 @@ public boolean releaseRunningRepairsForNodes(UUID runId, UUID segmentId, Set getLockedSegmentsForRun(UUID runId) { - return runToSegmentLocks.getOrDefault(runId, Collections.emptySet()); + return repairRunToSegmentLocks.getOrDefault(runId, Collections.emptySet()); } @VisibleForTesting public void cleanupExpiredLocks() { - lock.lock(); + runIdLocks.values().forEach(Lock::lock); try { long currentTime = System.currentTimeMillis(); @@ -139,7 +149,7 @@ public void cleanupExpiredLocks() { replicaLocks.entrySet().removeIf(entry -> entry.getValue().expirationTime <= currentTime); // Clean up runToSegmentLocks by removing segments with no active replicas - runToSegmentLocks.entrySet().removeIf(entry -> { + repairRunToSegmentLocks.entrySet().removeIf(entry -> { UUID runId = entry.getKey(); Set segments = entry.getValue(); @@ -152,7 +162,7 @@ public void cleanupExpiredLocks() { return segments.isEmpty(); }); } finally { - lock.unlock(); + runIdLocks.values().forEach(Lock::unlock); } } diff --git a/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerHangingTest.java b/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerHangingTest.java index b6de95266..577b6520b 100644 --- a/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerHangingTest.java +++ b/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerHangingTest.java @@ -83,7 +83,7 @@ public final class RepairRunnerHangingTest { - private static final long LEAD_TIME = 1L; + private static final long LEAD_TTL = 1000L; private static final Logger LOG = LoggerFactory.getLogger(RepairRunnerHangingTest.class); private static final Set TABLES = ImmutableSet.of("table1"); private static final List THREE_TOKENS = Lists.newArrayList( @@ -243,7 +243,7 @@ public void testHangingRepair() throws InterruptedException, ReaperException, JM final double intensity = 0.5f; final int repairThreadCount = 1; final int segmentTimeout = 1; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); storage.getClusterDao().addCluster(cluster); RepairUnit cf = storage.getRepairUnitDao().addRepairUnit( RepairUnit.builder() @@ -398,7 +398,7 @@ public void testHangingRepairNewApi() throws InterruptedException, ReaperExcepti final double intensity = 0.5f; final int repairThreadCount = 1; final int segmentTimeout = 1; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); storage.getClusterDao().addCluster(cluster); DateTimeUtils.setCurrentMillisFixed(timeRun); RepairUnit cf = storage.getRepairUnitDao().addRepairUnit( @@ -554,7 +554,7 @@ public void testDontFailRepairAfterTopologyChangeIncrementalRepair() throws Inte final int repairThreadCount = 1; final int segmentTimeout = 30; final List tokens = THREE_TOKENS; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); AppContext context = new AppContext(); context.storage = storage; context.config = new ReaperApplicationConfiguration(); diff --git a/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerTest.java b/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerTest.java index 147407420..91789d398 100644 --- a/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerTest.java +++ b/src/server/src/test/java/io/cassandrareaper/service/RepairRunnerTest.java @@ -99,7 +99,7 @@ import static org.mockito.Mockito.when; public final class RepairRunnerTest { - private static final long LEAD_TIME = 1L; + private static final long LEAD_TTL = 100L; private static final Set TABLES = ImmutableSet.of("table1"); private static final Duration POLL_INTERVAL = Duration.TWO_SECONDS; private static final List THREE_TOKENS = Lists.newArrayList( @@ -224,7 +224,7 @@ public void testResumeRepair() throws InterruptedException, ReaperException, Mal final int repairThreadCount = 1; final int segmentTimeout = 30; final List tokens = THREE_TOKENS; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); AppContext context = new AppContext(); context.storage = storage; context.config = new ReaperApplicationConfiguration(); @@ -356,7 +356,7 @@ public void testTooManyPendingCompactions() final int repairThreadCount = 1; final int segmentTimeout = 30; final List tokens = THREE_TOKENS; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); AppContext context = new AppContext(); context.storage = storage; context.config = new ReaperApplicationConfiguration(); @@ -549,7 +549,7 @@ public void testDontFailRepairAfterTopologyChange() throws InterruptedException, final int repairThreadCount = 1; final int segmentTimeout = 30; final List tokens = THREE_TOKENS; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); AppContext context = new AppContext(); context.storage = storage; context.config = new ReaperApplicationConfiguration(); @@ -690,7 +690,7 @@ public void testSubrangeIncrementalRepair() throws InterruptedException, ReaperE final int repairThreadCount = 1; final int segmentTimeout = 30; final List tokens = THREE_TOKENS; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); AppContext context = new AppContext(); context.storage = storage; context.config = new ReaperApplicationConfiguration(); @@ -965,7 +965,7 @@ public void getNodeMetricsInLocalDcAvailabilityForLocalDcNodeTest() throws Excep final int repairThreadCount = 1; final int segmentTimeout = 30; final List tokens = THREE_TOKENS; - final IStorageDao storage = new MemoryStorageFacade(LEAD_TIME); + final IStorageDao storage = new MemoryStorageFacade(LEAD_TTL); AppContext context = new AppContext(); context.storage = storage; context.config = new ReaperApplicationConfiguration(); diff --git a/src/server/src/test/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtlTest.java b/src/server/src/test/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtlTest.java index 7418ea072..08a5e12be 100644 --- a/src/server/src/test/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtlTest.java +++ b/src/server/src/test/java/io/cassandrareaper/storage/memory/ReplicaLockManagerWithTtlTest.java @@ -39,7 +39,7 @@ public class ReplicaLockManagerWithTtlTest { @BeforeEach public void setUp() { - replicaLockManager = new ReplicaLockManagerWithTtl(1); // TTL of 60 seconds + replicaLockManager = new ReplicaLockManagerWithTtl(1000); runId = UUID.randomUUID(); segmentId = UUID.randomUUID(); replicas = new HashSet<>(Arrays.asList("replica1", "replica2", "replica3"));