Skip to content

Commit

Permalink
Use move trick for saveStitchedTileGrid too. #329
Browse files Browse the repository at this point in the history
  • Loading branch information
EmileSonneveld committed Oct 29, 2024
1 parent bf32624 commit 7a00557
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ package object geotiff {

val layout = rdd.metadata.layout
val crs = rdd.metadata.crs
rdd.flatMap {
val res = rdd.flatMap {
case (key, tile) => features.filter { case (_, extent) =>
val tileBounds = layout.mapTransform(extent)

Expand All @@ -732,12 +732,31 @@ package object geotiff {
((name, extent), (key, tile))
}
}.groupByKey()
.map { case ((name, extent), tiles) =>
val filePath = newFilePath(path, name)
.map { case ((tileId, extent), tiles) =>
// Each tile is written to its own uniquely named temporary folder to avoid conflicts:
val uniqueFolderName = "tmp" + java.lang.Long.toUnsignedString(new java.security.SecureRandom().nextLong())
val base = Paths.get(Path.of(path).getParent + "/" + uniqueFolderName)
Files.createDirectories(base)
val filePath = base + "/" + newFilePath(Path.of(path).getFileName.toString, tileId)

(stitchAndWriteToTiff(tiles, filePath, layout, crs, extent, croppedExtent, cropDimensions, compression), extent)
}.collect()
.toList.asJava
}.collect().map({
case (absolutePath, croppedExtent) =>
// Move the output file to its final location. (On S3, a move is effectively a copy followed by a delete.)
val relativePath = Path.of(path).getParent.relativize(Path.of(absolutePath)).toString
val destinationPath = Path.of(path).getParent.resolve(relativePath.substring(relativePath.indexOf("/") + 1))
waitTillPathAvailable(Path.of(absolutePath))
Files.move(Path.of(absolutePath), destinationPath)
(destinationPath.toString, croppedExtent)
}).toList.asJava

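// Clean up the leftover temporary folders (emptied by the moves above, or left behind by failed tasks).
// NOTE: this removes every directory in the output folder whose name starts with "tmp".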
Files.list(Path.of(path).getParent).forEach { p =>
if (Files.isDirectory(p) && p.getFileName.toString.startsWith("tmp")) {
FileUtils.deleteDirectory(p.toFile)
}
}
res
}

private def stitchAndWriteToTiff(tiles: Iterable[(SpatialKey, MultibandTile)], filePath: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.openeo.geotrellis.geotiff

import better.files.File.apply

import java.time.LocalTime.MIDNIGHT
import java.time.ZoneOffset.UTC
import java.time.{LocalDate, ZonedDateTime}
Expand All @@ -16,8 +18,10 @@ import org.openeo.geotrellis.png.PngTest
import org.openeo.geotrellis.tile_grid.TileGrid
import org.openeo.geotrellis.{LayerFixtures, geotiff}

import java.nio.file.{Files, Paths}
import java.time.format.DateTimeFormatter.ISO_ZONED_DATE_TIME
import scala.collection.JavaConverters._
import scala.reflect.io.Directory

object TileGridTest {
private var sc: SparkContext = _
Expand Down Expand Up @@ -48,6 +52,10 @@ class TileGridTest {

@Test
def testSaveStitchWithTileGrids(): Unit = {
val outDir = Paths.get("tmp/testSaveStitchWithTileGrids/")
new Directory(outDir.toFile).deepList().foreach(_.delete())
Files.createDirectories(outDir)

val date = ZonedDateTime.of(LocalDate.of(2020, 4, 5), MIDNIGHT, UTC)
val bbox = ProjectedExtent(Extent(1.95, 50.95, 2.05, 51.05), LatLng)

Expand All @@ -57,17 +65,27 @@ class TileGridTest {
.toSpatial()
.persist(DISK_ONLY)

val tiles = geotiff.saveStitchedTileGrid(spatialLayer, "/tmp/testSaveStitched.tiff", "10km", DeflateCompression(6))
val expectedPaths = Set("/tmp/testSaveStitched-31UDS_3_4.tiff", "/tmp/testSaveStitched-31UDS_2_4.tiff", "/tmp/testSaveStitched-31UDS_3_5.tiff", "/tmp/testSaveStitched-31UDS_2_5.tiff")
val tiles = geotiff.saveStitchedTileGrid(spatialLayer, outDir + "/testSaveStitched.tiff", "10km", DeflateCompression(6))
val expectedPaths = Set(
outDir + "/testSaveStitched-31UDS_3_4.tiff",
outDir + "/testSaveStitched-31UDS_2_4.tiff",
outDir + "/testSaveStitched-31UDS_3_5.tiff",
outDir + "/testSaveStitched-31UDS_2_5.tiff",
)

// TODO: check if extents (in the layer CRS) are 10000m wide/high (in UTM)
Assert.assertEquals(expectedPaths, tiles.asScala.map { case (path, _) => path }.toSet)

val extent = bbox.reproject(spatialLayer.metadata.crs)
val cropBounds = mapAsJavaMap(Map("xmin" -> extent.xmin, "xmax" -> extent.xmax, "ymin" -> extent.ymin, "ymax" -> extent.ymax))

val croppedTiles = geotiff.saveStitchedTileGrid(spatialLayer, "/tmp/testSaveStitched_cropped.tiff", "10km", cropBounds, DeflateCompression(6))
val expectedCroppedPaths = Set("/tmp/testSaveStitched_cropped-31UDS_3_4.tiff", "/tmp/testSaveStitched_cropped-31UDS_2_4.tiff", "/tmp/testSaveStitched_cropped-31UDS_3_5.tiff", "/tmp/testSaveStitched_cropped-31UDS_2_5.tiff")
val croppedTiles = geotiff.saveStitchedTileGrid(spatialLayer, outDir + "/testSaveStitched_cropped.tiff", "10km", cropBounds, DeflateCompression(6))
val expectedCroppedPaths = Set(
outDir + "/testSaveStitched_cropped-31UDS_3_4.tiff",
outDir + "/testSaveStitched_cropped-31UDS_2_4.tiff",
outDir + "/testSaveStitched_cropped-31UDS_3_5.tiff",
outDir + "/testSaveStitched_cropped-31UDS_2_5.tiff",
)

// TODO: also check extents
Assert.assertEquals(expectedCroppedPaths, croppedTiles.asScala.map { case (path, _) => path }.toSet)
Expand Down

0 comments on commit 7a00557

Please sign in to comment.