diff --git a/common/common/src/main/java/io/helidon/build/common/LazyValue.java b/common/common/src/main/java/io/helidon/build/common/LazyValue.java index 8caa2068f..d99121363 100644 --- a/common/common/src/main/java/io/helidon/build/common/LazyValue.java +++ b/common/common/src/main/java/io/helidon/build/common/LazyValue.java @@ -17,7 +17,7 @@ import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; import java.util.function.Supplier; /** @@ -25,26 +25,36 @@ * * @param The type of the instance. */ -public class LazyValue { +public class LazyValue implements Supplier { - private static final VarHandle LATCH; + private static final VarHandle LOCK; private static final VarHandle STATE; static { try { + LOCK = MethodHandles.lookup().findVarHandle(LazyValue.class, "lock", Semaphore.class); STATE = MethodHandles.lookup().findVarHandle(LazyValue.class, "state", int.class); - LATCH = MethodHandles.lookup().findVarHandle(LazyValue.class, "latch", CountDownLatch.class); - } catch (ReflectiveOperationException e) { - throw new ExceptionInInitializerError(e); + } catch (Exception e) { + throw new Error("Unable to obtain VarHandle's", e); } } - private final Supplier supplier; + private T value; + private Supplier delegate; + @SuppressWarnings("unused") + private volatile Semaphore lock; private volatile int state; - @SuppressWarnings("unused") - private volatile CountDownLatch latch; - private T value; + + /** + * Create a new loaded instance. + * + * @param value value + */ + public LazyValue(T value) { + this.value = value; + this.state = 2; + } /** * Create a new instance. @@ -52,7 +62,7 @@ public class LazyValue { * @param supplier value supplier. */ public LazyValue(Supplier supplier) { - this.supplier = supplier; + this.delegate = supplier; } /** @@ -60,38 +70,41 @@ public LazyValue(Supplier supplier) { * * @return The value. */ + @Override public T get() { int stateCopy = state; - CountDownLatch latchCopy; - if (stateCopy == 0) { - // init - if (STATE.compareAndSet(this, 0, 1)) { - try { - value = supplier.get(); - state = 2; - } catch (Throwable th) { - state = 0; - throw th; - } finally { - latchCopy = latch; - if (latchCopy != null) { - latchCopy.countDown(); - } - } + if (stateCopy == 2) { + return value; + } + Semaphore lockCopy = lock; + while (stateCopy != 2 && !STATE.compareAndSet(this, 0, 1)) { + if (lockCopy == null) { + LOCK.compareAndSet(this, null, new Semaphore(0)); + lockCopy = lock; } stateCopy = state; + if (stateCopy == 1) { + lockCopy.acquireUninterruptibly(); + stateCopy = state; + } } - if (stateCopy == 1) { - // init race - latchCopy = latch; - if (latchCopy == null) { - LATCH.compareAndSet(this, null, new CountDownLatch(1)); - latchCopy = latch; + + try { + if (stateCopy == 2) { + return value; + } + stateCopy = 0; + value = delegate.get(); + delegate = null; + stateCopy = 2; + state = 2; + } finally { + if (stateCopy == 0) { + state = 0; } - try { - latchCopy.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); + lockCopy = lock; + if (lockCopy != null) { + lockCopy.release(); } } return value; diff --git a/common/common/src/test/java/io/helidon/build/common/LazyValueTest.java b/common/common/src/test/java/io/helidon/build/common/LazyValueTest.java index f33f578ae..bdd203b69 100644 --- a/common/common/src/test/java/io/helidon/build/common/LazyValueTest.java +++ b/common/common/src/test/java/io/helidon/build/common/LazyValueTest.java @@ -31,7 +31,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; import static org.junit.jupiter.api.Assertions.assertThrows; /** @@ -42,7 +41,7 @@ class LazyValueTest { @Test void testInitialized() { - LazyValue lazyValue = new LazyValue<>(() -> "value"); + LazyValue lazyValue = new LazyValue<>("value"); assertThat(lazyValue.get(), is("value")); assertThat(lazyValue.get(), is("value")); } @@ -64,76 +63,71 @@ void testBadSupplier() { @Test void testInitRaceWithBadSupplier() throws InterruptedException, ExecutionException { AtomicInteger counter = new AtomicInteger(); + CountDownLatch l1 = new CountDownLatch(1); + CountDownLatch l2 = new CountDownLatch(1); LazyValue lazyValue = new LazyValue<>(() -> { - if (counter.getAndIncrement() == 0) { - sleep(); - throw new RuntimeException("error!"); + try { + l1.countDown(); + l2.await(); + if (counter.getAndIncrement() == 0) { + throw new RuntimeException("error!"); + } + return Thread.currentThread().getName(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } finally { + System.out.println("after supplier.get()"); } - return Thread.currentThread().getName(); }); Deque> futures = new ArrayDeque<>(); ExecutorService executorService = executorService(); - CountDownLatch latch1 = new CountDownLatch(1); + futures.add(executorService.submit(lazyValue::get)); + l1.await(); futures.add(executorService.submit(() -> { - latch1.countDown(); + l2.countDown(); return lazyValue.get(); })); - latch1.await(); - futures.add(executorService.submit(lazyValue::get)); - - Future first = futures.pop(); - ExecutionException ex = assertThrows(ExecutionException.class, first::get); + ExecutionException ex = assertThrows(ExecutionException.class, futures.pop()::get); assertThat(ex.getCause(), is(instanceOf(RuntimeException.class))); assertThat(ex.getCause().getMessage(), is("error!")); - while (!futures.isEmpty()) { - Future future = futures.pop(); - String value = future.get(); - assertThat(value, is(nullValue())); - } - - CountDownLatch latch2 = new CountDownLatch(1); - futures.add(executorService.submit(() -> { - latch2.countDown(); - return lazyValue.get(); - })); - latch2.await(); - - futures.add(executorService.submit(lazyValue::get)); - - while (!futures.isEmpty()) { - Future future = futures.pop(); + for (Future future : futures) { String value = future.get(); - assertThat(value, is("test-3")); + assertThat(value, is("test-2")); } String value = lazyValue.get(); - assertThat(value, is("test-3")); + assertThat(value, is("test-2")); } @Test void testInitRace() throws InterruptedException, ExecutionException { + CountDownLatch l1 = new CountDownLatch(1); + CountDownLatch l2 = new CountDownLatch(3); LazyValue lazyValue = new LazyValue<>(() -> { - sleep(); - return Thread.currentThread().getName(); + try { + l1.countDown(); + l2.await(); + return Thread.currentThread().getName(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } }); List> futures = new ArrayList<>(); ExecutorService executorService = executorService(); - CountDownLatch latch = new CountDownLatch(1); - futures.add(executorService.submit(() -> { - latch.countDown(); - return lazyValue.get(); - })); - latch.await(); - - futures.add(executorService.submit(lazyValue::get)); - futures.add(executorService.submit(lazyValue::get)); futures.add(executorService.submit(lazyValue::get)); + l1.await(); + for (int i = 0; i < 3; i++) { + futures.add(executorService.submit(() -> { + l2.countDown(); + return lazyValue.get(); + })); + } for (Future future : futures) { String value = future.get(); @@ -148,12 +142,4 @@ private static ExecutorService executorService() { AtomicInteger counter = new AtomicInteger(1); return Executors.newFixedThreadPool(4, r -> new Thread(null, r, "test-" + counter.getAndIncrement())); } - - private static void sleep() { - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } }