Changes to get soto-s3-file-transfer working with 7.x.x (#700)
adam-fowler authored Oct 23, 2023
1 parent cf1de27 commit ed8128e
Showing 2 changed files with 31 additions and 62 deletions.
10 changes: 5 additions & 5 deletions Sources/Soto/Extensions/S3/ReportSizeByteBufferSequence.swift
@@ -20,24 +20,24 @@ struct ReportProgressByteBufferAsyncSequence<Base: AsyncSequence>: AsyncSequence
typealias Element = ByteBuffer

let base: Base
- let reportFn: @Sendable (Int) throws -> Void
+ let reportFn: @Sendable (Int) async throws -> Void

struct AsyncIterator: AsyncIteratorProtocol {
@usableFromInline
var iterator: Base.AsyncIterator
@usableFromInline
- let reportFn: @Sendable (Int) throws -> Void
+ let reportFn: @Sendable (Int) async throws -> Void

@usableFromInline
- init(iterator: Base.AsyncIterator, reportFn: @Sendable @escaping (Int) throws -> Void) {
+ init(iterator: Base.AsyncIterator, reportFn: @Sendable @escaping (Int) async throws -> Void) {
self.iterator = iterator
self.reportFn = reportFn
}

@inlinable
public mutating func next() async throws -> ByteBuffer? {
if let buffer = try await self.iterator.next() {
- try self.reportFn(buffer.readableBytes)
+ try await self.reportFn(buffer.readableBytes)
return buffer
}
return nil
@@ -57,7 +57,7 @@ extension ReportProgressByteBufferAsyncSequence: Sendable where Base: Sendable {
extension AsyncSequence where Element == ByteBuffer {
/// Return an AsyncSequence that returns ByteBuffers of a fixed size
/// - Parameter chunkSize: Size of each chunk
- func reportProgress(reportFn: @Sendable @escaping (Int) throws -> Void) -> ReportProgressByteBufferAsyncSequence<Self> {
+ func reportProgress(reportFn: @Sendable @escaping (Int) async throws -> Void) -> ReportProgressByteBufferAsyncSequence<Self> {
return .init(base: self, reportFn: reportFn)
}
}
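
For illustration only, not part of this commit: a minimal sketch of how a caller inside the same module might drive the updated reportProgress extension now that the report closure can suspend. The ProgressCounter actor and streamWithProgress function are hypothetical names used only for this example.

import NIOCore

// Hypothetical helper: an actor keeps the running byte count safe to update
// from the @Sendable closure, which may now suspend.
actor ProgressCounter {
    private(set) var bytes = 0
    func add(_ amount: Int) { bytes += amount }
}

// Hypothetical caller: every ByteBuffer yielded by `buffers` reports its
// readable byte count through the async closure before being consumed.
func streamWithProgress<Buffers: AsyncSequence>(
    _ buffers: Buffers
) async throws where Buffers.Element == ByteBuffer {
    let counter = ProgressCounter()
    for try await buffer in buffers.reportProgress(reportFn: { bytes in
        await counter.add(bytes)
    }) {
        _ = buffer // consume the chunk
    }
    print("streamed \(await counter.bytes) bytes")
}

Because the closure can now await, progress can be accumulated on an actor or forwarded to other async code rather than being limited to synchronous work.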
83 changes: 26 additions & 57 deletions Sources/Soto/Extensions/S3/S3+multipart.swift
@@ -32,34 +32,6 @@ extension S3ErrorType {
}

extension S3 {
- public struct ThreadPoolProvider {
- enum Internal {
- case singleton
- case shared(NIOThreadPool)
- }
-
- let value: Internal
- init(_ value: Internal) {
- self.value = value
- }
-
- public var threadPool: NIOThreadPool {
- get async {
- switch self.value {
- case .singleton:
- return await withCheckedContinuation { cont in
- cont.resume(returning: NIOThreadPool.singleton)
- }
- case .shared(let sharedPool):
- return sharedPool
- }
- }
- }
-
- public static var singleton: Self { .init(.singleton) }
- public static func shared(_ threadPool: NIOThreadPool) -> Self { .init(.shared(threadPool)) }
- }

/// Resume Multipart upload request object. This is used as a paramter `resumeMultipartUpload`. It can only be generated by a `abortedUpload` error
public struct ResumeMultipartUploadRequest: Sendable {
let uploadRequest: CreateMultipartUploadRequest
@@ -156,7 +128,7 @@ extension S3 {
/// - input: The GetObjectRequest shape that contains the details of the object request.
/// - partSize: Size of each part to be downloaded
/// - filename: Filename to save download to
- /// - threadPoolProvider: Where to get the thread pool used by the file loader
+ /// - threadPool: Thread pool used to save file
/// - logger: logger
/// - progress: Callback that returns the progress of the download. It is called after each part is downloaded with a value
/// between 0.0 and 1.0 indicating how far the download is complete (1.0 meaning finished).
@@ -165,13 +137,11 @@
_ input: GetObjectRequest,
partSize: Int = 5 * 1024 * 1024,
filename: String,
- threadPoolProvider: ThreadPoolProvider = .singleton,
+ threadPool: NIOThreadPool = .singleton,
logger: Logger = AWSClient.loggingDisabled,
- progress: @escaping (Double) throws -> Void = { _ in }
+ progress: @escaping (Double) async throws -> Void = { _ in }
) async throws -> Int64 {
let eventLoop = self.client.eventLoopGroup.any()

- let threadPool = await threadPoolProvider.threadPool
let fileIO = NonBlockingFileIO(threadPool: threadPool)
let fileHandle = try await fileIO.openFile(path: filename, mode: .write, flags: .allowFileCreation(), eventLoop: eventLoop).get()
let progressValue = ManagedAtomic(0)
@@ -186,7 +156,7 @@
let bufferSize = byteBuffer.readableBytes
_ = try await fileIO.write(fileHandle: fileHandle, buffer: byteBuffer, eventLoop: eventLoop).get()
let progressIntValue = progressValue.wrappingIncrementThenLoad(by: bufferSize, ordering: .relaxed)
- try progress(Double(progressIntValue) / Double(fileSize))
+ try await progress(Double(progressIntValue) / Double(fileSize))
}
} catch {
try fileHandle.close()
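
For illustration only, not part of this commit: a hypothetical call site for the download API whose parameters are documented above (assumed here to be S3.multipartDownload), passing the NIOThreadPool directly and using an async-capable progress closure. Bucket, key, and file paths are placeholders.

import NIOPosix
import SotoS3

// Hypothetical example: download an object to a local file with the updated
// 7.x-style surface sketched in the hunks above.
func downloadExample(s3: S3) async throws {
    let bytesDownloaded = try await s3.multipartDownload(
        S3.GetObjectRequest(bucket: "my-bucket", key: "big-file.bin"),
        filename: "/tmp/big-file.bin",
        threadPool: .singleton,         // NIOThreadPool is now passed directly
        progress: { fraction in         // the progress closure may now be async
            print("download \(Int(fraction * 100))% complete")
        }
    )
    print("downloaded \(bytesDownloaded) bytes")
}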
@@ -221,7 +191,7 @@
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
- progress: (@Sendable (Int) throws -> Void)? = nil
+ progress: (@Sendable (Int) async throws -> Void)? = nil
) async throws -> CompleteMultipartUploadOutput {
try await self.multipartUpload(
input,
@@ -248,7 +218,7 @@
/// - concurrentUploads: Number of uploads to run at one time
/// - abortOnFail: Whether should abort multipart upload if it fails. If you want to attempt to resume after a fail this should
/// be set to false
- /// - threadPoolProvider: Provide a thread pool to use or create a new one
+ /// - threadPool: Thread pool used to load file
/// - logger: logger
/// - progress: Callback that returns the progress of the upload. It is called after each part is uploaded with a value between
/// 0.0 and 1.0 indicating how far the upload is complete (1.0 meaning finished).
@@ -259,21 +229,21 @@
filename: String,
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
- threadPoolProvider: ThreadPoolProvider = .singleton,
+ threadPool: NIOThreadPool = .singleton,
logger: Logger = AWSClient.loggingDisabled,
- progress: @escaping @Sendable (Double) throws -> Void = { _ in }
+ progress: @escaping @Sendable (Double) async throws -> Void = { _ in }
) async throws -> CompleteMultipartUploadOutput {
let eventLoop = self.client.eventLoopGroup.any()

return try await openFileForMultipartUpload(
filename: filename,
logger: logger,
on: eventLoop,
- threadPoolProvider: threadPoolProvider
+ threadPool: threadPool
) { fileHandle, fileRegion, fileIO in
let length = Double(fileRegion.readableBytes)
- @Sendable func percentProgress(_ value: Int) throws {
- try progress(Double(value) / length)
+ @Sendable func percentProgress(_ value: Int) async throws {
+ try await progress(Double(value) / length)
}
return try await self.multipartUpload(
input,
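
For illustration only, not part of this commit: a hypothetical call site for the filename-based multipartUpload shown in the hunk above, using the new threadPool: parameter and an async progress closure. Bucket, key, and file paths are placeholders.

import NIOPosix
import SotoS3

// Hypothetical example: multipart upload from disk, passing the NIOThreadPool
// directly and reporting progress from an async-capable closure.
func uploadExample(s3: S3) async throws {
    let output = try await s3.multipartUpload(
        S3.CreateMultipartUploadRequest(bucket: "my-bucket", key: "big-file.bin"),
        filename: "/tmp/big-file.bin",
        threadPool: .singleton,
        progress: { fraction in
            print("upload \(Int(fraction * 100))% complete")
        }
    )
    print("upload complete: \(output.location ?? "unknown location")")
}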
@@ -313,7 +283,7 @@
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
- progress: (@Sendable (Int) throws -> Void)? = nil
+ progress: (@Sendable (Int) async throws -> Void)? = nil
) async throws -> CompleteMultipartUploadOutput {
try await self.resumeMultipartUpload(
input,
@@ -338,7 +308,7 @@
/// - concurrentUploads: Number of uploads to run at one time
/// - abortOnFail: Whether should abort multipart upload if it fails. If you want to attempt to resume after a fail
/// this should be set to false
- /// - threadPoolProvider: Provide a thread pool to use or create a new one
+ /// - threadPool: Thread pool used to load file
/// - progress: Callback that returns the progress of the upload. It is called after each part is uploaded with a value
/// between 0.0 and 1.0 indicating how far the upload is complete (1.0 meaning finished).
/// - returns: Output from CompleteMultipartUpload.
@@ -349,20 +319,20 @@
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
- threadPoolProvider: ThreadPoolProvider = .singleton,
- progress: @escaping (Double) throws -> Void = { _ in }
+ threadPool: NIOThreadPool = .singleton,
+ progress: @escaping (Double) async throws -> Void = { _ in }
) async throws -> CompleteMultipartUploadOutput {
let eventLoop = self.client.eventLoopGroup.any()

return try await openFileForMultipartUpload(
filename: filename,
logger: logger,
on: eventLoop,
- threadPoolProvider: threadPoolProvider
+ threadPool: threadPool
) { fileHandle, fileRegion, fileIO in
let length = Double(fileRegion.readableBytes)
- @Sendable func percentProgress(_ value: Int) throws {
- try progress(Double(value) / length)
+ @Sendable func percentProgress(_ value: Int) async throws {
+ try await progress(Double(value) / length)
}
return try await self.resumeMultipartUpload(
input,
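
For illustration only, not part of this commit: a hypothetical sketch of resuming a failed upload with the resumeMultipartUpload overload shown above, given a ResumeMultipartUploadRequest recovered from the abortedUpload error described earlier in this file. Function names, bucket, key, and paths are placeholders.

import NIOPosix
import SotoS3

// Hypothetical example: resume a previously failed multipart upload from the
// same file, using the ResumeMultipartUploadRequest carried by the
// `abortedUpload` error (the original upload must have used abortOnFail: false).
func resumeExample(
    s3: S3,
    resumeRequest: S3.ResumeMultipartUploadRequest,
    filename: String
) async throws {
    let output = try await s3.resumeMultipartUpload(
        resumeRequest,
        filename: filename,
        threadPool: .singleton,
        progress: { fraction in
            print("resumed upload \(Int(fraction * 100))% complete")
        }
    )
    print("resume complete: \(output.location ?? "unknown location")")
}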
@@ -481,7 +451,7 @@
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
- progress: (@Sendable (Int) throws -> Void)? = nil
+ progress: (@Sendable (Int) async throws -> Void)? = nil
) async throws -> CompleteMultipartUploadOutput where ByteBufferSequence.Element == ByteBuffer {
// initialize multipart upload
let upload = try await createMultipartUpload(input, logger: logger)
@@ -561,7 +531,7 @@
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
- progress: (@Sendable (Int) throws -> Void)? = nil
+ progress: (@Sendable (Int) async throws -> Void)? = nil
) async throws -> CompleteMultipartUploadOutput where ByteBufferSequence.Element == ByteBuffer {
// upload all the parts
let partsSet = Set<Int>(input.completedParts.map { $0.partNumber! - 1 })
@@ -601,7 +571,7 @@
concurrentUploads: Int = 4,
abortOnFail: Bool = true,
logger: Logger = AWSClient.loggingDisabled,
- progress: (@Sendable (Int) throws -> Void)? = nil
+ progress: (@Sendable (Int) async throws -> Void)? = nil
) async throws -> CompleteMultipartUploadOutput where PartsSequence.Element == (Int, ByteBuffer) {
let uploadRequest = input.uploadRequest

@@ -666,10 +636,9 @@
filename: String,
logger: Logger,
on eventLoop: EventLoop,
- threadPoolProvider: ThreadPoolProvider,
+ threadPool: NIOThreadPool,
uploadCallback: @escaping (NIOFileHandle, FileRegion, NonBlockingFileIO) async throws -> CompleteMultipartUploadOutput
) async throws -> CompleteMultipartUploadOutput {
- let threadPool = await threadPoolProvider.threadPool
let fileIO = NonBlockingFileIO(threadPool: threadPool)
let (fileHandle, fileRegion) = try await fileIO.openFile(path: filename, eventLoop: eventLoop).get()

@@ -703,14 +672,14 @@
concurrentUploads: Int,
initialProgress: Int,
logger: Logger,
- progress: (@Sendable (Int) throws -> Void)?
+ progress: (@Sendable (Int) async throws -> Void)?
) async throws -> [S3.CompletedPart] where PartSequence.Element == (Int, ByteBuffer) {
- var newProgress: (@Sendable (Int) throws -> Void)?
+ var newProgress: (@Sendable (Int) async throws -> Void)?
if let progress = progress {
let size = ManagedAtomic(initialProgress)
- @Sendable func accumulatingProgress(_ amount: Int) throws {
+ @Sendable func accumulatingProgress(_ amount: Int) async throws {
let totalSize = size.wrappingIncrementThenLoad(by: amount, ordering: .relaxed)
- try progress(totalSize)
+ try await progress(totalSize)
}
newProgress = accumulatingProgress
}

