From c413892139daf9c717b09c46e4ceb867a0ab6ab4 Mon Sep 17 00:00:00 2001
From: "brian.mulier"
Date: Fri, 5 Apr 2024 10:16:06 +0200
Subject: [PATCH] feat(runner): new ScriptRunner API + outputDir

---
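Note: each run now gets a random per-run working directory at the bucket
root, and GCS blob names are derived by stripping the leading '/' from that
path. A rough standalone sketch of the derivation used in run() below (the
id and file name are made-up examples, not values from this patch):

    import java.nio.file.Path;

    // Illustrative only: blob-name derivation for uploaded/downloaded files.
    class BlobPathSketch {
        public static void main(String[] args) {
            Path batchWorkingDirectory = Path.of("/2kDsYtPk"); // IdUtils.create()-style id (hypothetical)
            String workingDirectoryToBlobPath = batchWorkingDirectory.toString().substring(1);
            String relativePath = "data/input.csv"; // file path relative to the working directory
            // Path.of normalizes the separator, so plain string concatenation is safe here
            String blobName = workingDirectoryToBlobPath + Path.of("/" + relativePath);
            System.out.println(blobName); // prints: 2kDsYtPk/data/input.csv
        }
    }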
 .../gcp/runner/GcpBatchScriptRunner.java      | 105 +++++++++++-------
 .../gcp/runner/GcpBatchScriptRunnerTest.java  |   2 +
 2 files changed, 66 insertions(+), 41 deletions(-)

diff --git a/src/main/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunner.java b/src/main/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunner.java
index a82fc968..953df4f6 100644
--- a/src/main/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunner.java
+++ b/src/main/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunner.java
@@ -43,8 +43,7 @@
 This job runner will not resume the job if a Worker is restarted before the job finishes.
 You need the 'Batch Job Editor' and 'Logs Viewer' roles to be able to use it.""")
 @Plugin(examples = {}, beta = true)
-public class GcpBatchScriptRunner extends ScriptRunner implements GcpInterface {
-    private static final String WORKING_DIR = "/kestra/working-dir";
+public class GcpBatchScriptRunner extends ScriptRunner implements GcpInterface, RemoteRunnerInterface {
     private static final int BUFFER_SIZE = 8 * 1024;

     public static final String MOUNT_PATH = "/mnt/disks/share";
@@ -109,26 +108,11 @@ public class GcpBatchScriptRunner extends ScriptRunner implements GcpInterface {

     @Override
     public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, List<String> filesToUpload, List<String> filesToDownload) throws Exception {
-        String renderedBucket = runContext.render(bucket);
-        String workingDirName = IdUtils.create();
-        Map<String, Object> additionalVars = scriptCommands.getAdditionalVars();
-        // TODO outputDir
-        Optional.ofNullable(renderedBucket).ifPresent(bucket -> additionalVars.putAll(Map.of(
-            ScriptService.VAR_BUCKET_PATH, "gs://" + bucket + "/" + workingDirName,
-            ScriptService.VAR_WORKING_DIR, WORKING_DIR
-        )));
+        String renderedBucket = runContext.render(this.bucket);

         GoogleCredentials credentials = CredentialService.credentials(runContext, this);

-        List<String> allFilesToUpload = new ArrayList<>(ListUtils.emptyOnNull(filesToUpload));
-        List<String> command = ScriptService.uploadInputFiles(
-            runContext,
-            runContext.render(scriptCommands.getCommands(), additionalVars),
-            (ignored, localFilePath) -> allFilesToUpload.add(localFilePath),
-            true
-        );
-
-        boolean hasFilesToUpload = !ListUtils.isEmpty(allFilesToUpload);
+        boolean hasFilesToUpload = !ListUtils.isEmpty(filesToUpload);
         if (hasFilesToUpload && bucket == null) {
             throw new IllegalArgumentException("You must provide a Cloud Storage Bucket to use `inputFiles` or `namespaceFiles`");
         }
@@ -137,12 +121,28 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
             throw new IllegalArgumentException("You must provide a Cloud Storage Bucket to use `outputFiles`");
         }

-        if (hasFilesToUpload) {
+        Map<String, Object> additionalVars = this.additionalVars(runContext, scriptCommands);
+        Path batchWorkingDirectory = (Path) additionalVars.get(ScriptService.VAR_WORKING_DIR);
+        String workingDirectoryToBlobPath = batchWorkingDirectory.toString().substring(1);
+        boolean hasBucket = this.bucket != null;
+        if (hasBucket) {
+            List<String> filesToUploadWithOutputDir = new ArrayList<>(filesToUpload);
+            String outputDirName = (batchWorkingDirectory.relativize((Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR)) + "/").substring(1);
+            filesToUploadWithOutputDir.add(outputDirName);
             try (Storage storage = storage(runContext, credentials)) {
-                for (String file: allFilesToUpload) {
-                    BlobInfo destination = BlobInfo.newBuilder(BlobId.of(bucket, workingDirName + "/" + file)).build();
-                    try (var fileInputStream = new FileInputStream(runContext.resolve(Path.of(file)).toFile());
-                         var writer = storage.writer(destination)) {
+                for (String relativePath: filesToUploadWithOutputDir) {
+                    BlobInfo destination = BlobInfo.newBuilder(BlobId.of(
+                        renderedBucket,
+                        workingDirectoryToBlobPath + Path.of("/" + relativePath)
+                    )).build();
+                    Path filePath = runContext.resolve(Path.of(relativePath));
+                    if (relativePath.endsWith("/")) {
+                        storage.create(destination);
+                        continue;
+                    }
+
+                    try (var fileInputStream = new FileInputStream(filePath.toFile());
+                         var writer = storage.writer(destination)) {
                         byte[] buffer = new byte[BUFFER_SIZE];
                         int limit;
                         while ((limit = fileInputStream.read(buffer)) >= 0) {
@@ -159,23 +159,18 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li

         if (hasFilesToDownload || hasFilesToUpload) {
             taskBuilder.addVolumes(Volume.newBuilder()
-                .setGcs(GCS.newBuilder().setRemotePath(this.bucket + "/" + workingDirName).build())
+                .setGcs(GCS.newBuilder().setRemotePath(renderedBucket + batchWorkingDirectory).build())
                 .setMountPath(MOUNT_PATH)
                 .build()
             );
         }

         // main container
-        Map<String, String> environment = new HashMap<>(runContext.renderMap(scriptCommands.getEnv(), additionalVars));
-        environment.put(ScriptService.ENV_BUCKET_PATH, this.bucket + "/" + workingDirName);
-        environment.put(ScriptService.ENV_WORKING_DIR, WORKING_DIR);
-        // TODO outputDir
-//        environment.put(ScriptService.ENV_OUTPUT_DIR, scriptCommands.getOutputDirectory().toString());
         Runnable runnable = Runnable.newBuilder()
-            .setContainer(mainContainer(scriptCommands, command, hasFilesToDownload || hasFilesToUpload))
+            .setContainer(mainContainer(scriptCommands, scriptCommands.getCommands(), hasFilesToDownload || hasFilesToUpload, (Path) additionalVars.get(ScriptService.VAR_WORKING_DIR)))
             .setEnvironment(Environment.newBuilder()
-                .putAllVariables(environment)
+                .putAllVariables(this.env(runContext, scriptCommands))
                 .build()
             )
             .build();
@@ -249,11 +244,14 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
                     runContext.logger().info("Job deleted");
                 }

-                if (hasFilesToDownload) {
+                if (hasBucket) {
                     try (Storage storage = storage(runContext, credentials)) {
-                        for (String file: filesToDownload) {
-                            BlobInfo source = BlobInfo.newBuilder(BlobId.of(bucket, workingDirName + "/" + file)).build();
-                            try (var fileOutputStream = new FileOutputStream(runContext.resolve(Path.of(file)).toFile());
+                        for (String relativePath: filesToDownload) {
+                            BlobInfo source = BlobInfo.newBuilder(BlobId.of(
+                                renderedBucket,
+                                workingDirectoryToBlobPath + Path.of("/" + relativePath)
+                            )).build();
+                            try (var fileOutputStream = new FileOutputStream(runContext.resolve(Path.of(relativePath)).toFile());
                                  var reader = storage.reader(source.getBlobId())) {
                                 byte[] buffer = new byte[BUFFER_SIZE];
                                 int limit;
@@ -262,30 +260,40 @@ public RunnerResult run(RunContext runContext, ScriptCommands scriptCommands, Li
                                 }
                             }
                         }
+
+                        Path batchOutputDirectory = (Path) additionalVars.get(ScriptService.VAR_OUTPUT_DIR);
+                        Page<Blob> outputDirEntries = storage.list(renderedBucket, Storage.BlobListOption.prefix(batchOutputDirectory.toString().substring(1)));
+                        outputDirEntries.iterateAll().forEach(blob -> {
+                            Path relativeBlobPathFromOutputDir = Path.of(batchOutputDirectory.toString().substring(1)).relativize(Path.of(blob.getBlobId().getName()));
+                            storage.downloadTo(
+                                blob.getBlobId(),
+                                scriptCommands.getOutputDirectory().resolve(relativeBlobPathFromOutputDir)
+                            );
+                        });
                     }
                 }

                 return new RunnerResult(0, scriptCommands.getLogConsumer());
             }
         } finally {
-            if (hasFilesToUpload || hasFilesToDownload) {
+            if (hasBucket && delete) {
                 try (Storage storage = storage(runContext, credentials)) {
-                    Page<Blob> list = storage.list(bucket, Storage.BlobListOption.prefix(workingDirName));
+                    Page<Blob> list = storage.list(renderedBucket, Storage.BlobListOption.prefix(workingDirectoryToBlobPath));
                     list.iterateAll().forEach(blob -> storage.delete(blob.getBlobId()));
-                    storage.delete(BlobInfo.newBuilder(BlobId.of(bucket, workingDirName)).build().getBlobId());
+                    storage.delete(BlobInfo.newBuilder(BlobId.of(renderedBucket, workingDirectoryToBlobPath)).build().getBlobId());
                 }
             }
         }
     }

-    private Runnable.Container mainContainer(ScriptCommands scriptCommands, List<String> command, boolean mountVolume) {
+    private Runnable.Container mainContainer(ScriptCommands scriptCommands, List<String> command, boolean mountVolume, Path batchWorkingDirectory) {
         // TODO working directory
         var builder = Runnable.Container.newBuilder()
             .setImageUri(scriptCommands.getContainerImage())
             .addAllCommands(command);

         if (mountVolume) {
-            builder.addVolumes(MOUNT_PATH + ":" + WORKING_DIR);
+            builder.addVolumes(MOUNT_PATH + ":" + batchWorkingDirectory.toString());
         }

         if (this.entryPoint != null) {
@@ -329,6 +337,21 @@ private Storage storage(RunContext runContext, GoogleCredentials credentials) th
             .getService();
     }

+    @Override
+    protected Map<String, Object> runnerAdditionalVars(RunContext runContext, ScriptCommands scriptCommands) throws IllegalVariableEvaluationException {
+        Map<String, Object> additionalVars = new HashMap<>();
+        Path batchWorkingDirectory = Path.of("/" + IdUtils.create());
+        additionalVars.put(ScriptService.VAR_WORKING_DIR, batchWorkingDirectory);
+
+        if (bucket != null) {
+            Path batchOutputDirectory = batchWorkingDirectory.resolve(IdUtils.create());
+            additionalVars.put(ScriptService.VAR_OUTPUT_DIR, batchOutputDirectory);
+            additionalVars.put(ScriptService.VAR_BUCKET_PATH, "gs://" + runContext.render(this.bucket) + batchWorkingDirectory);
+        }
+
+        return additionalVars;
+    }
+
     @Getter
     @Builder
     public static class NetworkInterface {
diff --git a/src/test/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunnerTest.java b/src/test/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunnerTest.java
index 475e4906..4feb092b 100644
--- a/src/test/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunnerTest.java
+++ b/src/test/java/io/kestra/plugin/gcp/runner/GcpBatchScriptRunnerTest.java
@@ -3,9 +3,11 @@
 import io.kestra.core.models.script.AbstractScriptRunnerTest;
 import io.kestra.core.models.script.ScriptRunner;
 import io.micronaut.context.annotation.Value;
+import org.junit.jupiter.api.Disabled;

 import java.util.List;

+@Disabled("Needs a complex CI setup that is still to be done")
 class GcpBatchScriptRunnerTest extends AbstractScriptRunnerTest {
     @Value("${kestra.variables.globals.project}")
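Note on the new runnerAdditionalVars() hook above: the output directory is
nested inside the per-run working directory, so output files travel through
the same bucket prefix as the inputs and are removed by the same prefix
delete in the finally block. A rough standalone sketch of the resulting
layout (the map keys are placeholders because the ScriptService.VAR_*
constant values are not shown in this patch; ids and bucket name are made
up):

    import java.nio.file.Path;
    import java.util.HashMap;
    import java.util.Map;

    // Illustrative only: the working-dir / output-dir / bucket-path layout.
    class RunnerVarsSketch {
        public static void main(String[] args) {
            Map<String, Object> additionalVars = new HashMap<>();
            Path batchWorkingDirectory = Path.of("/" + "2kDsYtPk");                // stands in for IdUtils.create()
            additionalVars.put("workingDir", batchWorkingDirectory);               // ScriptService.VAR_WORKING_DIR
            Path batchOutputDirectory = batchWorkingDirectory.resolve("5xPwQzRn"); // output dir nested in working dir
            additionalVars.put("outputDir", batchOutputDirectory);                 // ScriptService.VAR_OUTPUT_DIR
            additionalVars.put("bucketPath", "gs://" + "my-bucket" + batchWorkingDirectory); // ScriptService.VAR_BUCKET_PATH
            // e.g. {workingDir=/2kDsYtPk, outputDir=/2kDsYtPk/5xPwQzRn, bucketPath=gs://my-bucket/2kDsYtPk}
            // (HashMap iteration order may vary)
            System.out.println(additionalVars);
        }
    }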