Skip to content

Commit

Permalink
Cleanup multipart download and allow for concurrent downloads (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-fowler authored Jan 10, 2024
1 parent 457e190 commit 10ac998
Showing 1 changed file with 75 additions and 41 deletions.
116 changes: 75 additions & 41 deletions Sources/Soto/Extensions/S3/S3+multipart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import Atomics
import Logging
import NIOConcurrencyHelpers
import NIOCore
import NIOPosix
import SotoCore
Expand Down Expand Up @@ -50,13 +51,15 @@ extension S3 {
/// - parameters:
/// - input: The GetObjectRequest shape that contains the details of the object request.
/// - partSize: Size of each part to be downloaded
/// - concurrentDownloads: Maximum number of part downloads that can be running at one time
/// - outputStream: Function to be called for each downloaded part. Called with data block and file size
/// - returns: The complete file size once the multipart download has finished.
public func multipartDownload(
_ input: GetObjectRequest,
partSize: Int = 5 * 1024 * 1024,
concurrentDownloads: Int = 4,
logger: Logger = AWSClient.loggingDisabled,
outputStream: @escaping (ByteBuffer, Int64) async throws -> Void
outputStream: @escaping @Sendable (ByteBuffer, Int64) async throws -> Void
) async throws -> Int64 {
// get object size before downloading
let headRequest = S3.HeadObjectRequest(
Expand All @@ -77,48 +80,76 @@ extension S3 {
throw S3ErrorType.multipart.downloadEmpty(message: "Content length is unexpectedly zero")
}

// download part task
func downloadPartTask(offset: Int64, partSize: Int64) -> Task<GetObjectOutput, Swift.Error> {
let range = "bytes=\(offset)-\(offset + Int64(partSize - 1))"
let getRequest = S3.GetObjectRequest(
bucket: input.bucket,
key: input.key,
range: range,
sseCustomerAlgorithm: input.sseCustomerAlgorithm,
sseCustomerKey: input.sseCustomerKey,
sseCustomerKeyMD5: input.sseCustomerKeyMD5,
versionId: input.versionId
)
return Task {
try await getObject(getRequest, logger: logger)
}
}

// save part task
func savePart(downloadedPart: GetObjectOutput) async throws {
try await outputStream(downloadedPart.body.collect(upTo: .max), contentLength)
}

let partSize: Int64 = numericCast(partSize)
var offset = min(partSize, contentLength)
var downloadedPartTask = downloadPartTask(offset: 0, partSize: offset)
while offset < contentLength {
// wait for previous download
let downloadedPart = try await downloadedPartTask.value
try await withThrowingTaskGroup(of: (Int, ByteBuffer).self) { group in
/// Structure used to store downloaded buffers and then save them as and when
/// needed
struct DownloadedBuffers {
let outputStream: @Sendable (ByteBuffer) async throws -> Void
var buffers: [ByteBuffer?]
var bufferSavedIndex: Int

init(numberOfBuffers: Int, outputStream: @escaping @Sendable (ByteBuffer) async throws -> Void) {
self.outputStream = outputStream
self.buffers = Array(repeating: nil, count: numberOfBuffers)
self.bufferSavedIndex = 0
}

// start next download
let downloadPartSize = min(partSize, contentLength - offset)
downloadedPartTask = downloadPartTask(offset: offset, partSize: downloadPartSize)
offset += downloadPartSize
mutating func saveBuffer(index: Int, buffer: ByteBuffer) async throws {
assert(index >= 0 && index < self.buffers.count)
self.buffers[index] = buffer
while self.bufferSavedIndex < self.buffers.count, let bufferToSave = self.buffers[bufferSavedIndex] {
self.buffers[self.bufferSavedIndex] = nil
self.bufferSavedIndex += 1
try await self.outputStream(bufferToSave)
}
}
}
let partSize64: Int64 = numericCast(partSize)
var count = 0
var offset: Int64 = 0
let numberOfParts: Int = numericCast((contentLength - 1) / partSize64) + 1
var downloadBuffers = DownloadedBuffers(numberOfBuffers: numberOfParts) { buffer in
try await outputStream(buffer, contentLength)
}
// Schedule one download task per part, keeping at most `concurrentDownloads`
// of them in flight at any one time.
while count < numberOfParts {
    // Once `concurrentDownloads` tasks have been started, wait for one to finish
    // and hand its buffer over for saving before starting the next. The test is
    // `>=` rather than `>` so that an extra (concurrentDownloads + 1)th task is
    // never started. `group.next()` returns nil when the group is empty, so a
    // `concurrentDownloads` of zero or less falls through and downloads one part
    // at a time instead of stalling.
    if count >= concurrentDownloads {
        if let (index, buffer) = try await group.next() {
            // save the buffer
            try await downloadBuffers.saveBuffer(index: index, buffer: buffer)
        }
    }
    // Copy the loop state into immutable locals so the child task's closure
    // captures these values rather than the mutable `count` / `offset`.
    let index = count
    // The final part may be shorter than `partSize64`.
    let currentPartSize = min(partSize64, contentLength - offset)
    let currentOffset = offset
    // add task downloading this part from S3
    group.addTask {
        // The end of the range is inclusive, hence the `- 1`.
        let range = "bytes=\(currentOffset)-\(currentOffset + currentPartSize - 1)"
        let getRequest = S3.GetObjectRequest(
            bucket: input.bucket,
            key: input.key,
            range: range,
            sseCustomerAlgorithm: input.sseCustomerAlgorithm,
            sseCustomerKey: input.sseCustomerKey,
            sseCustomerKeyMD5: input.sseCustomerKeyMD5,
            versionId: input.versionId
        )
        let getObjectOutput = try await getObject(getRequest, logger: logger)
        // Collect the part's body into a single buffer, capped at `partSize` bytes.
        let buffer = try await getObjectOutput.body.collect(upTo: partSize)
        // Return the part index with the buffer so the results, which can arrive
        // in any order, can be matched back to their part.
        return (index, buffer)
    }
    offset += partSize64
    count += 1
}

// save part
try await savePart(downloadedPart: downloadedPart)
// save the remaining parts
for try await(index, buffer) in group {
// save the buffer
try await downloadBuffers.saveBuffer(index: index, buffer: buffer)
}
}
// wait for last download
let downloadedPart = try await downloadedPartTask.value
// and save part
try await savePart(downloadedPart: downloadedPart)

return contentLength
}

Expand All @@ -128,6 +159,7 @@ extension S3 {
/// - input: The GetObjectRequest shape that contains the details of the object request.
/// - partSize: Size of each part to be downloaded
/// - filename: Filename to save download to
/// - concurrentDownloads: Maximum number of part downloads that can be running at one time
/// - threadPool: Thread pool used to save file
/// - logger: logger
/// - progress: Callback that returns the progress of the download. It is called after each part is downloaded with a value
Expand All @@ -137,9 +169,10 @@ extension S3 {
_ input: GetObjectRequest,
partSize: Int = 5 * 1024 * 1024,
filename: String,
concurrentDownloads: Int = 4,
threadPool: NIOThreadPool = .singleton,
logger: Logger = AWSClient.loggingDisabled,
progress: @escaping (Double) async throws -> Void = { _ in }
progress: @escaping @Sendable (Double) async throws -> Void = { _ in }
) async throws -> Int64 {
let eventLoop = self.client.eventLoopGroup.any()
let fileIO = NonBlockingFileIO(threadPool: threadPool)
Expand All @@ -151,6 +184,7 @@ extension S3 {
downloaded = try await self.multipartDownload(
input,
partSize: partSize,
concurrentDownloads: concurrentDownloads,
logger: logger
) { byteBuffer, fileSize in
let bufferSize = byteBuffer.readableBytes
Expand Down

0 comments on commit 10ac998

Please sign in to comment.