Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change GroovyScriptExtension.load to use a map of context objects instead of a ComputationManager to be more generic #3308

Merged
merged 8 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import com.powsybl.dsl.ast.BooleanLiteralNode
import com.powsybl.dsl.ast.ExpressionNode
import com.powsybl.iidm.modification.NetworkModification
import com.powsybl.iidm.network.Network
import com.powsybl.scripting.groovy.GroovyScriptExtension
import com.powsybl.scripting.groovy.GroovyScripts
import org.codehaus.groovy.control.CompilationFailedException
import org.slf4j.LoggerFactory

Expand Down Expand Up @@ -110,19 +112,24 @@ class ActionDslLoader extends DslLoader {
}

ActionDb load(Network network) {
load(network, null)
load(network, null, new HashMap<Class<?>, Object>())
}

ActionDb load(Network network, Map<Class<?>, Object> contextObjects) {
load(network, null, contextObjects)
}

/**
* Loads in binding the functions which create contingencies, actions, and rules,
* binding them to the network parameter. The handler defines how created objects will be used.
*
* @param binding The context which functions will be created in
* @param network The network which functions will be bound to
* @param handler Will allow client code to define how objects created when interpreting a script will be used
* @param observer Will allow client code to observe the interpretation of the script
* @param binding The context which functions will be created in
* @param network The network which functions will be bound to
* @param handler Will allow client code to define how objects created when interpreting a script will be used
* @param observer Will allow client code to observe the interpretation of the script
* @param contextObjects Context objects used in groovy script extensions
*/
static void loadDsl(Binding binding, Network network, ActionDslHandler handler, ActionDslLoaderObserver observer) {
static void loadDsl(Binding binding, Network network, ActionDslHandler handler, ActionDslLoaderObserver observer, Map<Class<?>, Object> contextObjects) {

// set base network
binding.setVariable("network", network)
Expand All @@ -132,6 +139,10 @@ class ActionDslLoader extends DslLoader {

ConditionDslLoader.prepareClosures(binding)

// Bindings through extensions
Iterable<GroovyScriptExtension> extensions = ServiceLoader.load(GroovyScriptExtension.class, GroovyScripts.class.getClassLoader())
extensions.forEach { it.load(binding, contextObjects) }

// rules
binding.rule = { String id, Closure<Void> closure ->
def cloned = closure.clone()
Expand Down Expand Up @@ -198,13 +209,17 @@ class ActionDslLoader extends DslLoader {
}

void load(Network network, ActionDslHandler handler, ActionDslLoaderObserver observer) {
load(network, handler, observer, new HashMap<Class<?>, Object>())
}

void load(Network network, ActionDslHandler handler, ActionDslLoaderObserver observer, Map<Class<?>, Object> contextObjects) {

LOGGER.debug("Loading DSL '{}'", dslSrc.getName())
observer?.begin(dslSrc.getName())

Binding binding = new Binding()

loadDsl(binding, network, handler, observer)
loadDsl(binding, network, handler, observer, contextObjects)
try {

def shell = createShell(binding)
Expand All @@ -221,6 +236,10 @@ class ActionDslLoader extends DslLoader {
}

ActionDb load(Network network, ActionDslLoaderObserver observer) {
return load(network, observer, new HashMap<Class<?>, Object>())
}

ActionDb load(Network network, ActionDslLoaderObserver observer, Map<Class<?>, Object> contextObjects) {
ActionDb rulesDb = new ActionDb()

//Handler to create an ActionDb instance
Expand All @@ -242,7 +261,7 @@ class ActionDslLoader extends DslLoader {
}
}

load(network, actionDbBuilder, observer)
load(network, actionDbBuilder, observer, contextObjects)

rulesDb.checkUndefinedActions()
rulesDb
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.powsybl.iidm.network.Network;

import java.util.List;
import java.util.Map;

/**
* @author Geoffroy Jamgotchian {@literal <geoffroy.jamgotchian at rte-france.com>}
Expand All @@ -18,6 +19,10 @@ public interface ContingenciesProvider {

List<Contingency> getContingencies(Network network);

default List<Contingency> getContingencies(Network network, Map<Class<?>, Object> contextObjects) {
return getContingencies(network);
}

default String asScript() {
throw new UnsupportedOperationException("Serialization not supported for contingencies provider of type " + this.getClass().getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
*/
package com.powsybl.contingency;

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.Test;

import java.util.HashMap;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

/**
* @author Mathieu Bague {@literal <mathieu.bague at rte-france.com>}
*/
Expand All @@ -22,5 +26,6 @@ void test() {

assertInstanceOf(EmptyContingencyListProvider.class, provider);
assertEquals(0, provider.getContingencies(null).size());
assertEquals(0, provider.getContingencies(null, new HashMap<>()).size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
*/
package com.powsybl.contingency;

import com.google.common.collect.ImmutableList;
import com.powsybl.computation.Partition;
import com.powsybl.iidm.network.Generator;
import com.powsybl.iidm.network.Network;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand All @@ -29,21 +30,50 @@ class SubContingenciesProviderTest {
@Test
void test() {
ContingenciesProvider provider = n -> IntStream.range(1, 5)
.mapToObj(i -> new Contingency("contingency-" + i))
.collect(Collectors.toList());
.mapToObj(i -> new Contingency("contingency-" + i))
.collect(Collectors.toList());

Network network = Mockito.mock(Network.class);

List<String> subList1 = new SubContingenciesProvider(provider, new Partition(1, 2))
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());

List<String> subList2 = new SubContingenciesProvider(provider, new Partition(2, 2))
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());

assertEquals(ImmutableList.of("contingency-1", "contingency-2"), subList1);
assertEquals(ImmutableList.of("contingency-3", "contingency-4"), subList2);
assertEquals(List.of("contingency-1", "contingency-2"), subList1);
assertEquals(List.of("contingency-3", "contingency-4"), subList2);
}

@Test
void testWithContextObjects() {
ContingenciesProvider provider = n -> IntStream.range(1, 5)
.mapToObj(i -> new Contingency("contingency-" + i))
.collect(Collectors.toList());

Network network = Mockito.mock(Network.class);
Map<Class<?>, Object> contextObjects = Map.of(Generator.class, Mockito.mock(Generator.class));

List<String> subList1 = new SubContingenciesProvider(provider, new Partition(1, 2))
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());

List<String> subList2 = new SubContingenciesProvider(provider, new Partition(2, 2))
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());

List<String> subList1ContextObjects = new SubContingenciesProvider(provider, new Partition(1, 2))
.getContingencies(network, contextObjects)
.stream().map(Contingency::getId).collect(Collectors.toList());

List<String> subList2ContextObjects = new SubContingenciesProvider(provider, new Partition(2, 2))
.getContingencies(network, contextObjects)
.stream().map(Contingency::getId).collect(Collectors.toList());

assertEquals(subList1, subList1ContextObjects);
assertEquals(subList2, subList2ContextObjects);
}

@Test
Expand All @@ -53,8 +83,8 @@ void testEmpty() {
Network network = Mockito.mock(Network.class);

List<String> subList1 = new SubContingenciesProvider(provider, new Partition(1, 1))
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());
.getContingencies(network)
.stream().map(Contingency::getId).collect(Collectors.toList());

assertEquals(Collections.emptyList(), subList1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
package com.powsybl.iidm.network.scripting

import com.google.auto.service.AutoService
import com.powsybl.computation.ComputationManager
import com.powsybl.computation.local.LocalComputationManager
import com.powsybl.iidm.network.ExportersLoader
import com.powsybl.iidm.network.ExportersServiceLoader
Expand Down Expand Up @@ -48,7 +47,7 @@ class NetworkLoadSaveGroovyScriptExtension implements GroovyScriptExtension {
}

@Override
void load(Binding binding, ComputationManager computationManager) {
void load(Binding binding, Map<Class<?>, Object> contextObjects) {
binding.loadNetwork = { String file, Properties parameters = null ->
Network.read(fileSystem.getPath(file), LocalComputationManager.getDefault(),
importConfig, parameters, importersLoader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,24 @@ class LoadFlowGroovyScriptExtension implements GroovyScriptExtension {
private final LoadFlowParameters parameters

LoadFlowGroovyScriptExtension(LoadFlowParameters parameters) {
assert parameters
this.parameters = parameters
this.parameters = Objects.requireNonNull(parameters)
}

LoadFlowGroovyScriptExtension() {
this(LoadFlowParameters.load())
}

@Override
void load(Binding binding, ComputationManager computationManager) {
binding.loadFlow = { Network network, LoadFlowParameters parameters = this.parameters ->
LoadFlow.run(network, network.getVariantManager().getWorkingVariantId(), computationManager, parameters)
}
binding.loadflow = { Network network, LoadFlowParameters parameters = this.parameters ->
LoadFlow.run(network, network.getVariantManager().getWorkingVariantId(), computationManager, parameters)
void load(Binding binding, Map<Class<?>, Object> contextObjects) {
ComputationManager computationManager = contextObjects.get(ComputationManager.class) as ComputationManager
if (computationManager != null) {

binding.loadFlow = { Network network, LoadFlowParameters parameters = this.parameters ->
LoadFlow.run(network, network.getVariantManager().getWorkingVariantId(), computationManager, parameters)
}
binding.loadflow = { Network network, LoadFlowParameters parameters = this.parameters ->
LoadFlow.run(network, network.getVariantManager().getWorkingVariantId(), computationManager, parameters)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
*/
package com.powsybl.loadflow.scripting;

import com.powsybl.computation.ComputationManager;
import com.powsybl.iidm.network.Network;
import com.powsybl.iidm.network.VariantManager;
import com.powsybl.iidm.network.VariantManagerConstants;
Expand All @@ -21,14 +20,14 @@

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* @author Geoffroy Jamgotchian {@literal <geoffroy.jamgotchian at rte-france.com>}
*/
class LoadFlowExtensionGroovyScriptTest extends AbstractGroovyScriptTest {

private Network fooNetwork;
private ComputationManager computationManager;

@BeforeEach
void setUp() {
Expand Down Expand Up @@ -57,12 +56,13 @@ protected String getExpectedOutput() {
protected List<GroovyScriptExtension> getExtensions() {
GroovyScriptExtension ext = new GroovyScriptExtension() {
@Override
public void load(Binding binding, ComputationManager computationManager) {
public void load(Binding binding, Map<Class<?>, Object> contextObjects) {
binding.setVariable("n", fooNetwork);
}

@Override
public void unload() {
// Nothing to do here
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;

Expand All @@ -35,7 +36,7 @@ public abstract class AbstractGroovyScriptTest {
protected abstract List<GroovyScriptExtension> getExtensions();

public void doTest() {
ComputationManager computationManager = Mockito.mock(ComputationManager.class);
Map<Class<?>, Object> contextObjects = Map.of(ComputationManager.class, Mockito.mock(ComputationManager.class));
Binding binding = new Binding();
StringWriter out = null;
try {
Expand All @@ -47,7 +48,7 @@ public void doTest() {
// Add a check on thread interruption in every loop (for, while) in the script
conf.addCompilationCustomizers(new ASTTransformationCustomizer(ThreadInterrupt.class));

getExtensions().forEach(it -> it.load(binding, computationManager));
getExtensions().forEach(it -> it.load(binding, contextObjects));
GroovyShell shell = new GroovyShell(binding, conf);
shell.evaluate(getCode());
out = (StringWriter) binding.getProperty("out");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/
package com.powsybl.scripting.groovy

import com.powsybl.computation.ComputationManager
import com.powsybl.computation.DefaultComputationManagerConfig
import groovy.transform.ThreadInterrupt
import org.codehaus.groovy.control.CompilerConfiguration
Expand Down Expand Up @@ -47,6 +48,10 @@ class GroovyScripts {
}

static void run(Reader codeReader, Binding binding, Iterable<GroovyScriptExtension> extensions, PrintStream out) {
run(codeReader, binding, extensions, out, new HashMap<>())
}

static void run(Reader codeReader, Binding binding, Iterable<GroovyScriptExtension> extensions, PrintStream out, Map<Class<?>, Object> contextObjects) {
assert codeReader
assert extensions != null

Expand All @@ -58,14 +63,15 @@ class GroovyScripts {
// Computation manager
DefaultComputationManagerConfig config = DefaultComputationManagerConfig.load()
binding.computationManager = config.createShortTimeExecutionComputationManager()
contextObjects.put(ComputationManager.class, binding.computationManager)

if (out != null) {
binding.out = out
}

try {
// load extensions
extensions.forEach { it.load(binding, binding.computationManager) }
extensions.forEach { it.load(binding, contextObjects) }

GroovyShell shell = new GroovyShell(binding, conf)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.powsybl.scripting.groovy

import com.google.auto.service.AutoService

/**
* Extension used to bind the groovy script output to a specific writer
* @author Nicolas Rol {@literal <nicolas.rol at rte-france.com>}
*/
@AutoService(GroovyScriptExtension.class)
class LogsGroovyScriptExtension implements GroovyScriptExtension {

LogsGroovyScriptExtension() {}

@Override
void load(Binding binding, Map<Class<?>, Object> contextObjects) {
Writer writer = contextObjects.get(Writer.class) as Writer
if (writer != null) {
binding.out = writer
}
}

@Override
void unload() {}
}
Loading