diff --git a/src/main/java/com/powsybl/ws/commons/computation/service/AbstractComputationRunContext.java b/src/main/java/com/powsybl/ws/commons/computation/service/AbstractComputationRunContext.java
index 6962d4f..3c5282e 100644
--- a/src/main/java/com/powsybl/ws/commons/computation/service/AbstractComputationRunContext.java
+++ b/src/main/java/com/powsybl/ws/commons/computation/service/AbstractComputationRunContext.java
@@ -7,6 +7,7 @@
package com.powsybl.ws.commons.computation.service;
import com.powsybl.commons.report.ReportNode;
+import com.powsybl.computation.ComputationManager;
import com.powsybl.iidm.network.Network;
import com.powsybl.ws.commons.computation.dto.ReportInfos;
import lombok.Getter;
@@ -30,9 +31,11 @@ public abstract class AbstractComputationRunContext
{
private P parameters;
private ReportNode reportNode;
private Network network;
+ private boolean debug;
+ private ComputationManager computationManager;
protected AbstractComputationRunContext(UUID networkUuid, String variantId, String receiver, ReportInfos reportInfos,
- String userId, String provider, P parameters) {
+ String userId, String provider, P parameters, boolean debug) {
this.networkUuid = networkUuid;
this.variantId = variantId;
this.receiver = receiver;
@@ -42,5 +45,6 @@ protected AbstractComputationRunContext(UUID networkUuid, String variantId, Stri
this.parameters = parameters;
this.reportNode = ReportNode.NO_OP;
this.network = null;
+ this.debug = debug;
}
}
diff --git a/src/main/java/com/powsybl/ws/commons/computation/service/AbstractResultContext.java b/src/main/java/com/powsybl/ws/commons/computation/service/AbstractResultContext.java
index 7ef4170..84ccc68 100644
--- a/src/main/java/com/powsybl/ws/commons/computation/service/AbstractResultContext.java
+++ b/src/main/java/com/powsybl/ws/commons/computation/service/AbstractResultContext.java
@@ -17,9 +17,7 @@
import java.util.Objects;
import java.util.UUID;
-import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_PROVIDER;
-import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_RECEIVER;
-import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_USER_ID;
+import static com.powsybl.ws.commons.computation.service.NotificationService.*;
/**
* @author Mathieu Deharbe
@@ -42,6 +40,8 @@ public abstract class AbstractResultContext toMessage(ObjectMapper objectMapper) {
.setHeader(REPORT_UUID_HEADER, runContext.getReportInfos().reportUuid() != null ? runContext.getReportInfos().reportUuid().toString() : null)
.setHeader(REPORTER_ID_HEADER, runContext.getReportInfos().reporterId())
.setHeader(REPORT_TYPE_HEADER, runContext.getReportInfos().computationType())
+ .setHeader(DEBUG_HEADER, runContext.isDebug())
.copyHeaders(getSpecificMsgHeaders(objectMapper))
.build();
}
diff --git a/src/main/java/com/powsybl/ws/commons/computation/service/AbstractWorkerService.java b/src/main/java/com/powsybl/ws/commons/computation/service/AbstractWorkerService.java
index 4787ad4..2b9b648 100644
--- a/src/main/java/com/powsybl/ws/commons/computation/service/AbstractWorkerService.java
+++ b/src/main/java/com/powsybl/ws/commons/computation/service/AbstractWorkerService.java
@@ -9,6 +9,9 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.powsybl.commons.PowsyblException;
import com.powsybl.commons.report.ReportNode;
+import com.powsybl.computation.ComputationManager;
+import com.powsybl.computation.local.LocalComputationConfig;
+import com.powsybl.computation.local.LocalComputationManager;
import com.powsybl.iidm.network.Network;
import com.powsybl.iidm.network.VariantManagerConstants;
import com.powsybl.network.store.client.NetworkStoreService;
@@ -20,9 +23,11 @@
import org.springframework.messaging.Message;
import org.springframework.web.server.ResponseStatusException;
-import java.util.Map;
-import java.util.Objects;
-import java.util.UUID;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.*;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
@@ -32,6 +37,8 @@
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
+import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_DEBUG_DIR;
+
/**
* @author Mathieu Deharbe
* @param powsybl Result class specific to the computation
@@ -165,6 +172,8 @@ public Consumer> consumeRun() {
protected void clean(AbstractResultContext resultContext) {
futures.remove(resultContext.getResultUuid());
cancelComputationRequests.remove(resultContext.getResultUuid());
+
+ Optional.ofNullable(resultContext.getRunContext().getComputationManager()).ifPresent(ComputationManager::close);
}
/**
@@ -187,22 +196,41 @@ public Consumer> consumeCancel() {
protected abstract void saveResult(Network network, AbstractResultContext resultContext, R result);
- protected void sendResultMessage(AbstractResultContext resultContext, R ignoredResult) {
+ private Map getAdditionalHeaders(AbstractResultContext resultContext, R ignoredResult) {
+ Map additionalHeaders = new HashMap<>();
+ if (resultContext.getRunContext().isDebug() && resultContext.getRunContext().getComputationManager() != null) {
+ additionalHeaders.put(HEADER_DEBUG_DIR, resultContext.getRunContext().getComputationManager().getLocalDir().toAbsolutePath().toString());
+ }
+ return additionalHeaders;
+ }
+
+ public Map getResultHeaders(AbstractResultContext resultContext, R result) {
+ return getAdditionalHeaders(resultContext, result);
+ }
+
+ public Map getFailHeaders(AbstractResultContext resultContext, R result) {
+ return getAdditionalHeaders(resultContext, result);
+ }
+
+ private void sendResultMessage(AbstractResultContext resultContext, R result) {
+ Map additionalHeaders = getResultHeaders(resultContext, result);
notificationService.sendResultMessage(resultContext.getResultUuid(), resultContext.getRunContext().getReceiver(),
- resultContext.getRunContext().getUserId(), null);
+ resultContext.getRunContext().getUserId(), additionalHeaders);
}
- protected void publishFail(AbstractResultContext resultContext, String message) {
+ private void publishFail(AbstractResultContext resultContext, String message) {
+ Map additionalHeaders = getFailHeaders(resultContext, null);
notificationService.publishFail(resultContext.getResultUuid(), resultContext.getRunContext().getReceiver(),
- message, resultContext.getRunContext().getUserId(), getComputationType(), null);
+ message, resultContext.getRunContext().getUserId(), getComputationType(), additionalHeaders);
}
/**
* Do some extra task before running the computation, e.g. print log or init extra data for the run context
- * @param ignoredRunContext This context may be used for further computation in overriding classes
+ * @param runContext This context may be used for further computation in overriding classes
*/
- protected void preRun(C ignoredRunContext) {
+ protected void preRun(C runContext) {
LOGGER.info("Run {} computation...", getComputationType());
+ runContext.setComputationManager(createComputationManager());
}
protected R run(C runContext, UUID resultUuid, AtomicReference rootReporter) throws Exception {
@@ -262,4 +290,21 @@ protected CompletableFuture runAsync(
protected abstract String getComputationType();
protected abstract CompletableFuture getCompletableFuture(C runContext, String provider, UUID resultUuid);
+
+ /**
+ * set method as public to mock DockerLocalComputationManager when testing with test container
+ * @return a computation manager
+ */
+ public ComputationManager createComputationManager() {
+ LocalComputationConfig localComputationConfig = LocalComputationConfig.load();
+ Path localDir = localComputationConfig.getLocalDir();
+ try {
+ String workDirPrefix = getComputationType().replaceAll("\\s+", "_").toLowerCase() + "_";
+ Path workDir = Files.createTempDirectory(localDir, workDirPrefix);
+ return new LocalComputationManager(new LocalComputationConfig(workDir, localComputationConfig.getAvailableCore()), executionService.getExecutorService());
+ } catch (IOException e) {
+ throw new UncheckedIOException(String.format("Error occurred while creating a working directory inside the local directory %s",
+ localDir.toAbsolutePath()), e);
+ }
+ }
}
diff --git a/src/main/java/com/powsybl/ws/commons/computation/service/ExecutionService.java b/src/main/java/com/powsybl/ws/commons/computation/service/ExecutionService.java
index f6a124a..77c06ef 100644
--- a/src/main/java/com/powsybl/ws/commons/computation/service/ExecutionService.java
+++ b/src/main/java/com/powsybl/ws/commons/computation/service/ExecutionService.java
@@ -7,8 +7,6 @@
package com.powsybl.ws.commons.computation.service;
-import com.powsybl.computation.ComputationManager;
-import com.powsybl.computation.local.LocalComputationManager;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import lombok.Getter;
@@ -28,13 +26,10 @@ public class ExecutionService {
private ExecutorService executorService;
- private ComputationManager computationManager;
-
@SneakyThrows
@PostConstruct
private void postConstruct() {
executorService = Executors.newCachedThreadPool();
- computationManager = new LocalComputationManager(getExecutorService());
}
@PreDestroy
diff --git a/src/main/java/com/powsybl/ws/commons/computation/service/NotificationService.java b/src/main/java/com/powsybl/ws/commons/computation/service/NotificationService.java
index 59d9a01..d86fd80 100644
--- a/src/main/java/com/powsybl/ws/commons/computation/service/NotificationService.java
+++ b/src/main/java/com/powsybl/ws/commons/computation/service/NotificationService.java
@@ -48,6 +48,7 @@ public class NotificationService {
public static final String HEADER_PROVIDER = "provider";
public static final String HEADER_MESSAGE = "message";
public static final String HEADER_USER_ID = "userId";
+ public static final String HEADER_DEBUG_DIR = "debugDir";
public static final String SENDING_MESSAGE = "Sending message : {}";
diff --git a/src/test/java/com/powsybl/ws/commons/computation/ComputationTest.java b/src/test/java/com/powsybl/ws/commons/computation/ComputationTest.java
index 8dfd238..253ed9a 100644
--- a/src/test/java/com/powsybl/ws/commons/computation/ComputationTest.java
+++ b/src/test/java/com/powsybl/ws/commons/computation/ComputationTest.java
@@ -6,17 +6,7 @@
import com.powsybl.network.store.client.NetworkStoreService;
import com.powsybl.network.store.client.PreloadingStrategy;
import com.powsybl.ws.commons.computation.dto.ReportInfos;
-import com.powsybl.ws.commons.computation.service.AbstractComputationObserver;
-import com.powsybl.ws.commons.computation.service.AbstractComputationResultService;
-import com.powsybl.ws.commons.computation.service.AbstractComputationRunContext;
-import com.powsybl.ws.commons.computation.service.AbstractComputationService;
-import com.powsybl.ws.commons.computation.service.AbstractResultContext;
-import com.powsybl.ws.commons.computation.service.AbstractWorkerService;
-import com.powsybl.ws.commons.computation.service.CancelContext;
-import com.powsybl.ws.commons.computation.service.ExecutionService;
-import com.powsybl.ws.commons.computation.service.NotificationService;
-import com.powsybl.ws.commons.computation.service.ReportService;
-import com.powsybl.ws.commons.computation.service.UuidGeneratorService;
+import com.powsybl.ws.commons.computation.service.*;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.micrometer.observation.ObservationRegistry;
@@ -39,17 +29,13 @@
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
-import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_RECEIVER;
-import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_RESULT_UUID;
-import static com.powsybl.ws.commons.computation.service.NotificationService.HEADER_USER_ID;
+import static com.powsybl.ws.commons.computation.service.NotificationService.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.ArgumentMatchers.isA;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.*;
@ExtendWith({ MockitoExtension.class })
@Slf4j
@@ -61,7 +47,9 @@ class ComputationTest implements WithAssertions {
private NetworkStoreService networkStoreService;
@Mock
private ReportService reportService;
+ @Mock
private final ExecutionService executionService = new ExecutionService();
+ private final ExecutorService executorService = Executors.newCachedThreadPool();
private final UuidGeneratorService uuidGeneratorService = new UuidGeneratorService();
@Mock
private StreamBridge publisher;
@@ -125,7 +113,7 @@ private static class MockComputationRunContext extends AbstractComputationRunCon
protected MockComputationRunContext(UUID networkUuid, String variantId, String receiver, ReportInfos reportInfos,
String userId, String provider, Object parameters) {
- super(networkUuid, variantId, receiver, reportInfos, userId, provider, parameters);
+ super(networkUuid, variantId, receiver, reportInfos, userId, provider, parameters, false);
}
}
@@ -235,6 +223,7 @@ private void initComputationExecution() {
when(networkStoreService.getNetwork(eq(networkUuid), any(PreloadingStrategy.class)))
.thenReturn(network);
when(network.getVariantManager()).thenReturn(variantManager);
+ when(executionService.getExecutorService()).thenReturn(executorService);
}
@Test
@@ -261,6 +250,7 @@ void testComputationFailed() {
// test the course
verify(notificationService.getPublisher(), times(1)).send(eq("publishFailed-out-0"), isA(Message.class));
+ executionService.getExecutorService().shutdown();
}
@Test