feat(runner): new ScriptRunner API + outputDir

brian-mulier-p authored and loicmathieu committed Apr 5, 2024
1 parent 0e00a94 commit c413892
Showing 2 changed files with 66 additions and 41 deletions.
105 changes: 64 additions & 41 deletions src/main/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunner.java

@@ -43,8 +43,7 @@
     This job runner didn't resume the job if a Worker is restarted before the job finish.
     You need to have roles 'Batch Job Editor' and 'Logs Viewer' to be able to use it.""")
 @Plugin(examples = {}, beta = true)
-public class GcpBatchScriptRunner extends ScriptRunner implements GcpInterface {
-    private static final String WORKING_DIR = "/kestra/working-dir";
+public class GcpBatchScriptRunner extends ScriptRunner implements GcpInterface, RemoteRunnerInterface {
     private static final int BUFFER_SIZE = 8 * 1024;
     public static final String MOUNT_PATH = "/mnt/disks/share";

@@ -109,26 +108,11 @@ public class GcpBatchScriptRunner extends ScriptRunner implements GcpInterface {

     @Override
     public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
-        String renderedBucket = runContext.render(bucket);
-        String workingDirName = IdUtils.create();
-        Map<String, Object> additionalVars = scriptCommands.getAdditionalVars();
-        // TODO outputDir
-        Optional.ofNullable(renderedBucket).ifPresent(bucket -> additionalVars.putAll(Map.<String, Object>of(
-            ScriptService.VAR_BUCKET_PATH, "gs://" + bucket + "/" + workingDirName,
-            ScriptService.VAR_WORKING_DIR, WORKING_DIR
-        )));
+        String renderedBucket = runContext.render(this.bucket);

         GoogleCredentials credentials = CredentialService.credentials(runContext, this);

-        List<String> allFilesToUpload = new ArrayList<>(ListUtils.emptyOnNull(filesToUpload));
-        List<String> command = ScriptService.uploadInputFiles(
-            runContext,
-            runContext.render(scriptCommands.getCommands(), additionalVars),
-            (ignored, localFilePath) -> allFilesToUpload.add(localFilePath),
-            true
-        );
-
-        boolean hasFilesToUpload = !ListUtils.isEmpty(allFilesToUpload);
+        boolean hasFilesToUpload = !ListUtils.isEmpty(filesToUpload);
         if (hasFilesToUpload && bucket == null) {
             throw new IllegalArgumentException("You must provide a Cloud Storage Bucket to use `inputFiles` or `namespaceFiles`");
         }
@@ -137,12 +121,28 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
             throw new IllegalArgumentException("You must provide a Cloud Storage Bucket to use `outputFiles`");
         }

-        if (hasFilesToUpload) {
+        Map<String, Object> additionalVars = this.additionalVars(runContext, scriptCommands);
+        Path batchWorkingDirectory = (Path) additionalVars.get(ScriptService.VAR_WORKING_DIR);
+        String workingDirectoryToBlobPath = batchWorkingDirectory.toString().substring(1);
+        boolean hasBucket = this.bucket != null;
+        if (hasBucket) {
+            List<String> filesToUploadWithOutputDir = new ArrayList<>(filesToUpload);
+            String outputDirName = (batchWorkingDirectory.relativize((Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR)) + "/").substring(1);
+            filesToUploadWithOutputDir.add(outputDirName);
             try (Storage storage = storage(runContext, credentials)) {
-                for (String file: allFilesToUpload) {
-                    BlobInfo destination = BlobInfo.newBuilder(BlobId.of(bucket, workingDirName + "/" + file)).build();
-                    try (var fileInputStream = new FileInputStream(runContext.resolve(Path.of(file)).toFile());
-                         var writer = storage.writer(destination)) {
+                for (String relativePath: filesToUploadWithOutputDir) {
+                    BlobInfo destination = BlobInfo.newBuilder(BlobId.of(
+                        renderedBucket,
+                        workingDirectoryToBlobPath + Path.of("/" + relativePath)
+                    )).build();
+                    Path filePath = runContext.resolve(Path.of(relativePath));
+                    if (relativePath.endsWith("/")) {
+                        storage.create(destination);
+                        continue;
+                    }
+
+                    try (var fileInputStream = new FileInputStream(filePath.toFile());
+                         var writer = storage.writer(destination)) {
                         byte[] buffer = new byte[BUFFER_SIZE];
                         int limit;
                         while ((limit = fileInputStream.read(buffer)) >= 0) {
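
Aside: Cloud Storage has no real directories, which is why the new upload loop special-cases names ending in "/" and creates them as zero-byte objects before the job runs. A minimal standalone sketch of that convention — the bucket and object names are hypothetical, not taken from the plugin:

import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.BlobInfo;
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.StorageOptions;

import java.nio.charset.StandardCharsets;

public class GcsFolderMarkerSketch {
    public static void main(String[] args) {
        Storage storage = StorageOptions.getDefaultInstance().getService();

        // A zero-byte object whose name ends with "/" acts as a folder
        // placeholder, so the mounted GCS volume exposes an empty directory.
        BlobInfo folderMarker = BlobInfo.newBuilder(BlobId.of("my-bucket", "wd-123/out-456/")).build();
        storage.create(folderMarker);

        // Regular files are uploaded as ordinary objects under the same prefix.
        BlobInfo file = BlobInfo.newBuilder(BlobId.of("my-bucket", "wd-123/main.py")).build();
        storage.create(file, "print('hello')".getBytes(StandardCharsets.UTF_8));
    }
}
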
@@ -159,23 +159,18 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {

         if (hasFilesToDownload || hasFilesToUpload) {
             taskBuilder.addVolumes(Volume.newBuilder()
-                .setGcs(GCS.newBuilder().setRemotePath(this.bucket + "/" + workingDirName).build())
+                .setGcs(GCS.newBuilder().setRemotePath(renderedBucket + batchWorkingDirectory).build())
                 .setMountPath(MOUNT_PATH)
                 .build()
             );
         }

         // main container
-        Map<String, String> environment = new HashMap<>(runContext.renderMap(scriptCommands.getEnv(), additionalVars));
-        environment.put(ScriptService.ENV_BUCKET_PATH, this.bucket + "/" + workingDirName);
-        environment.put(ScriptService.ENV_WORKING_DIR, WORKING_DIR);
-        // TODO outputDir
-        // environment.put(ScriptService.ENV_OUTPUT_DIR, scriptCommands.getOutputDirectory().toString());
         Runnable runnable =
             Runnable.newBuilder()
-                .setContainer(mainContainer(scriptCommands, command, hasFilesToDownload || hasFilesToUpload))
+                .setContainer(mainContainer(scriptCommands, scriptCommands.getCommands(), hasFilesToDownload || hasFilesToUpload, (Path) additionalVars.get(ScriptService.VAR_WORKING_DIR)))
                 .setEnvironment(Environment.newBuilder()
-                    .putAllVariables(environment)
+                    .putAllVariables(this.env(runContext, scriptCommands))
                     .build()
                 )
                 .build();
@@ -249,11 +244,14 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
             runContext.logger().info("Job deleted");
         }

-        if (hasFilesToDownload) {
+        if (hasBucket) {
             try (Storage storage = storage(runContext, credentials)) {
-                for (String file: filesToDownload) {
-                    BlobInfo source = BlobInfo.newBuilder(BlobId.of(bucket, workingDirName + "/" + file)).build();
-                    try (var fileOutputStream = new FileOutputStream(runContext.resolve(Path.of(file)).toFile());
+                for (String relativePath: filesToDownload) {
+                    BlobInfo source = BlobInfo.newBuilder(BlobId.of(
+                        renderedBucket,
+                        workingDirectoryToBlobPath + Path.of("/" + relativePath)
+                    )).build();
+                    try (var fileOutputStream = new FileOutputStream(runContext.resolve(Path.of(relativePath)).toFile());
                          var reader = storage.reader(source.getBlobId())) {
                         byte[] buffer = new byte[BUFFER_SIZE];
                         int limit;
@@ -262,30 +260,40 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
                     }
                 }
             }
+
+                Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR);
+                Page<Blob> outputDirEntries = storage.list(renderedBucket, Storage.BlobListOption.prefix(batchOutputDirectory.toString().substring(1)));
+                outputDirEntries.iterateAll().forEach(blob -> {
+                    Path relativeBlobPathFromOutputDir = Path.of(batchOutputDirectory.toString().substring(1)).relativize(Path.of(blob.getBlobId().getName()));
+                    storage.downloadTo(
+                        blob.getBlobId(),
+                        scriptCommands.getOutputDirectory().resolve(relativeBlobPathFromOutputDir)
+                    );
+                });
             }
         }

             return new RunnerResult(0, scriptCommands.getLogConsumer());
         }
         } finally {
-            if (hasFilesToUpload || hasFilesToDownload) {
+            if (hasBucket && delete) {
                 try (Storage storage = storage(runContext, credentials)) {
-                    Page<Blob> list = storage.list(bucket, Storage.BlobListOption.prefix(workingDirName));
+                    Page<Blob> list = storage.list(renderedBucket, Storage.BlobListOption.prefix(workingDirectoryToBlobPath));
                     list.iterateAll().forEach(blob -> storage.delete(blob.getBlobId()));
-                    storage.delete(BlobInfo.newBuilder(BlobId.of(bucket, workingDirName)).build().getBlobId());
+                    storage.delete(BlobInfo.newBuilder(BlobId.of(renderedBucket, workingDirectoryToBlobPath)).build().getBlobId());
                 }
             }
         }
     }

-    private Runnable.Container mainContainer(ScriptCommands scriptCommands, List<String> command, boolean mountVolume) {
+    private Runnable.Container mainContainer(ScriptCommands scriptCommands, List<String> command, boolean mountVolume, Path batchWorkingDirectory) {
         // TODO working directory
         var builder = Runnable.Container.newBuilder()
             .setImageUri(scriptCommands.getContainerImage())
             .addAllCommands(command);

         if (mountVolume) {
-            builder.addVolumes(MOUNT_PATH + ":" + WORKING_DIR);
+            builder.addVolumes(MOUNT_PATH + ":" + batchWorkingDirectory.toString());
         }

         if (this.entryPoint != null) {
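
Aside: the volume wiring above happens in two steps — the Batch job mounts the bucket prefix on the VM at a fixed path, and the container then bind-mounts that VM path at the per-task working directory, so the script sees its uploaded files under the same absolute path the variables were rendered with. A minimal sketch of that pairing; the bucket, prefix, and image names are hypothetical:

import com.google.cloud.batch.v1.GCS;
import com.google.cloud.batch.v1.Runnable;
import com.google.cloud.batch.v1.Volume;

public class BatchVolumeSketch {
    public static void main(String[] args) {
        String mountPath = "/mnt/disks/share";   // MOUNT_PATH in the plugin
        String workingDir = "/wd-123";           // stands in for batchWorkingDirectory

        // 1) Mount the bucket prefix on the Batch VM.
        Volume volume = Volume.newBuilder()
            .setGcs(GCS.newBuilder().setRemotePath("my-bucket" + workingDir).build())
            .setMountPath(mountPath)
            .build();

        // 2) Bind-mount the VM path into the container at the working
        //    directory, docker-style "host:container".
        Runnable.Container container = Runnable.Container.newBuilder()
            .setImageUri("ubuntu:22.04")
            .addVolumes(mountPath + ":" + workingDir)
            .build();

        System.out.println(volume);
        System.out.println(container);
    }
}
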
@@ -329,6 +337,21 @@ private Storage storage(RunContext runContext, GoogleCredentials credentials)
             .getService();
     }

+    @Override
+    protected Map<String, Object> runnerAdditionalVars(RunContext runContext, ScriptCommands scriptCommands) throws IllegalVariableEvaluationException {
+        Map<String, Object> additionalVars = new HashMap<>();
+        Path batchWorkingDirectory = Path.of("/" + IdUtils.create());
+        additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory);
+
+        if (bucket != null) {
+            Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create());
+            additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory);
+            additionalVars.put(ScriptService.VAR_BUCKET_PATH, "gs://" + runContext.render(this.bucket) + batchWorkingDirectory);
+        }
+
+        return additionalVars;
+    }
+
     @Getter
     @Builder
     public static class NetworkInterface {
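
The new output-directory handling at the end of run() boils down to list-by-prefix plus relativize: every blob under the remote output prefix is mirrored into the task's local output directory. A standalone sketch of that pattern — the bucket, prefix, and local path are hypothetical, not the plugin's code:

import com.google.api.gax.paging.Page;
import com.google.cloud.storage.Blob;
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.StorageOptions;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class OutputDirDownloadSketch {
    public static void main(String[] args) {
        Storage storage = StorageOptions.getDefaultInstance().getService();
        String outputPrefix = "wd-123/out-456";   // VAR_OUTPUT_DIR without the leading "/"
        Path localOutputDir = Path.of("/tmp/outputs");

        Page<Blob> blobs = storage.list("my-bucket", Storage.BlobListOption.prefix(outputPrefix));
        blobs.iterateAll().forEach(blob -> {
            if (blob.getName().endsWith("/")) {
                return; // skip the folder-marker object itself
            }
            // "wd-123/out-456/report/out.csv" -> /tmp/outputs/report/out.csv
            Path target = localOutputDir.resolve(Path.of(outputPrefix).relativize(Path.of(blob.getName())));
            try {
                Files.createDirectories(target.getParent()); // ensure local subfolders exist
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
            blob.downloadTo(target);
        });
    }
}
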
2 changes: 2 additions & 0 deletions src/test/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunnerTest.java

@@ -3,9 +3,11 @@
 import io.kestra.core.models.script.AbstractScriptRunnerTest;
 import io.kestra.core.models.script.ScriptRunner;
 import io.micronaut.context.annotation.Value;
+import org.junit.jupiter.api.Disabled;

 import java.util.List;

+@Disabled("Need complex CI setup still needed to be done")
 class GcpBatchScriptRunnerTest extends AbstractScriptRunnerTest {

     @Value("${kestra.variables.globals.project}")
