From 0c749224279a27f1b7f9bb9ece7093265455a5f8 Mon Sep 17 00:00:00 2001 From: kean Date: Mon, 29 Apr 2024 21:39:38 -0400 Subject: [PATCH] Improve Async/Await support in ImageTask --- Sources/Nuke/ImageTask.swift | 276 ++++++++++-------- Sources/Nuke/Pipeline/ImagePipeline.swift | 53 ++-- .../ImagePipelineAsyncAwaitTests.swift | 36 ++- 3 files changed, 210 insertions(+), 155 deletions(-) diff --git a/Sources/Nuke/ImageTask.swift b/Sources/Nuke/ImageTask.swift index 9239779da..bbe7dd531 100644 --- a/Sources/Nuke/ImageTask.swift +++ b/Sources/Nuke/ImageTask.swift @@ -25,28 +25,21 @@ public final class ImageTask: Hashable, CustomStringConvertible, @unchecked Send /// The original request. public let request: ImageRequest - /// Updates the priority of the task, even if it is already running. + /// Updates the priority of the task. The priority can be updated dynamically + /// even if that task is already running. public var priority: ImageRequest.Priority { get { sync { _priority } } - set { - let didChange: Bool = sync { - guard _priority != newValue else { return false } - _priority = newValue - return _state == .running - } - guard didChange else { return } - pipeline?.imageTaskUpdatePriorityCalled(self, priority: newValue) - } + set { setPriority(newValue) } } private var _priority: ImageRequest.Priority /// Returns the current download progress. Returns zeros before the download /// is started and the expected size of the resource is known. public internal(set) var currentProgress: Progress { - get { sync { _progress } } - set { sync { _progress = newValue } } + get { sync { _currentProgress } } + set { sync { _currentProgress = newValue } } } - private var _progress = Progress(completed: 0, total: 0) + private var _currentProgress = Progress(completed: 0, total: 0) /// The download progress. public struct Progress: Hashable, Sendable { @@ -82,14 +75,7 @@ public final class ImageTask: Hashable, CustomStringConvertible, @unchecked Send case completed } - let isDataTask: Bool - var task: Task? - var continuation: UnsafeContinuation? - var onEvent: ((Event, ImageTask) -> Void)? - weak var pipeline: ImagePipeline? - - /// Using it without a wrapper to reduce the number of allocations. - private let lock: os_unfair_lock_t + // MARK: - Async/Await /// Returns the response image. public var image: PlatformImage { @@ -101,77 +87,78 @@ public final class ImageTask: Hashable, CustomStringConvertible, @unchecked Send /// The image response. public var response: ImageResponse { get async throws { - guard let task else { - assertionFailure("This should never happen") - throw ImagePipeline.Error.pipelineInvalidated - } - return try await withTaskCancellationHandler { + try await withTaskCancellationHandler { try await task.value } onCancel: { - self.cancel() + cancel() } } } - /// The events sent by the pipeline during the task execution. - public var events: AsyncStream { - os_unfair_lock_lock(lock) - defer { os_unfair_lock_unlock(lock) } - if _context.events == nil { _context.events = AsyncStream.makeStream() } - return _context.events!.0 - } - /// The stream of progress updates. - public var progress: AsyncStream { - os_unfair_lock_lock(lock) - defer { os_unfair_lock_unlock(lock) } - if _context.progress == nil { _context.progress = AsyncStream.makeStream() } - return _context.progress!.0 - } + public var progress: AsyncStream { _progress } /// The stream of responses. /// /// If the progressive decoding is enabled (see ``ImagePipeline/Configuration-swift.struct/isProgressiveDecodingEnabled``), - /// the stream contains all of the progressive scans loaded by the pipeline - /// and finished with the full image. - public var stream: AsyncThrowingStream { - os_unfair_lock_lock(lock) - defer { os_unfair_lock_unlock(lock) } - if _context.progress == nil { _context.stream = AsyncThrowingStream.makeStream() } - return _context.stream!.0 - } + /// and the requested image supports it, the stream contains all of the + /// progressive scans loaded by the pipeline and finishes with the full image. + public var stream: AsyncThrowingStream { _stream } - /// Deprecated in Nuke 12.7. - @available(*, deprecated, renamed: "stream", message: "Please the new `stream` API instead that is now a throwing stream that also contains the full image as the last value") + // Deprecated in Nuke 12.7. + @available(*, deprecated, message: "Please use `stream` instead") public var previews: AsyncStream { _previews } - var _previews: AsyncStream { - AsyncStream { continuation in - Task { - for await event in events { - if case .preview(let response) = event { - continuation.yield(response) - } - } - } - } + /// The events sent by the pipeline during the task execution. + public var events: AsyncStream { _events } + + /// An event produced during the runetime of the task. + public enum Event: Sendable { + /// The download progress was updated. + case progress(Progress) + /// The pipleine generated a progressive scan of the image. + case preview(ImageResponse) + /// The task was cancelled. + /// + /// - note: You are guaranteed to receive either `.cancelled` or + /// `.finished`, but never both. + case cancelled + /// The task finish with the given response. + case finished(Result) } - private var _context = AsyncContext() + let isDataTask: Bool + var onEvent: ((Event, ImageTask) -> Void)? + weak var pipeline: ImagePipeline? + + private var task: Task! + private var continuation: UnsafeContinuation? + + /// Using it without a wrapper to reduce the number of allocations. + private let lock: os_unfair_lock_t + private var context = AsyncContext() deinit { lock.deinitialize(count: 1) lock.deallocate() } - init(taskId: Int64, request: ImageRequest, isDataTask: Bool) { + init(taskId: Int64, request: ImageRequest, isDataTask: Bool, pipeline: ImagePipeline) { self.taskId = taskId self.request = request self._priority = request.priority self.isDataTask = isDataTask + self.pipeline = pipeline lock = .allocate(capacity: 1) lock.initialize(to: os_unfair_lock()) + + task = Task { + try await withUnsafeThrowingContinuation { continuation in + self.continuation = continuation + pipeline.imageTaskStartCalled(self) + } + } } /// Marks task as being cancelled. @@ -184,46 +171,53 @@ public final class ImageTask: Hashable, CustomStringConvertible, @unchecked Send } } - private func setState(_ state: ImageTask.State) -> Bool { - assert(state == .cancelled || state == .completed) - os_unfair_lock_lock(lock) - guard _state == .running else { - os_unfair_lock_unlock(lock) - return false - } - _state = state - os_unfair_lock_unlock(lock) - return true + // MARK: Hashable + + public func hash(into hasher: inout Hasher) { + hasher.combine(ObjectIdentifier(self).hashValue) } - private func sync(_ closure: () -> T) -> T { - os_unfair_lock_lock(lock) - defer { os_unfair_lock_unlock(lock) } - return closure() + public static func == (lhs: ImageTask, rhs: ImageTask) -> Bool { + ObjectIdentifier(lhs) == ObjectIdentifier(rhs) } - private struct AsyncContext { - typealias Stream = AsyncThrowingStream + // MARK: CustomStringConvertible - var stream: (Stream, Stream.Continuation)? - var events: (AsyncStream, AsyncStream.Continuation)? - var progress: (AsyncStream, AsyncStream.Continuation)? + public var description: String { + "ImageTask(id: \(taskId), priority: \(_priority), progress: \(currentProgress.completed) / \(currentProgress.total), state: \(state))" } - // MARK: Events + // MARK: Internals + + private func setPriority(_ newValue: ImageRequest.Priority) { + let didChange: Bool = sync { + guard _priority != newValue else { return false } + _priority = newValue + return _state == .running + } + guard didChange else { return } + pipeline?.imageTaskUpdatePriorityCalled(self, priority: newValue) + } + + private func setState(_ state: ImageTask.State) -> Bool { + sync { + guard _state == .running else { return false } + _state = state + return true + } + } func process(_ event: Event) { switch event { case .progress(let progress): currentProgress = progress case .finished: - // TODO: do we need to check state? - _ = setState(.completed) + guard setState(.completed) else { return } default: break } - process(event, in: sync { _context }) + process(event, in: sync { context }) onEvent?(event, self) pipeline?.imageTask(self, didProcessEvent: event) } @@ -235,6 +229,7 @@ public final class ImageTask: Hashable, CustomStringConvertible, @unchecked Send context.progress?.1.yield(progress) case .preview(let response): context.stream?.1.yield(response) + context.previews?.1.yield(response) case .cancelled: context.events?.1.finish() context.progress?.1.finish() @@ -249,52 +244,89 @@ public final class ImageTask: Hashable, CustomStringConvertible, @unchecked Send } } - /// An event produced during the runetime of the task. - public enum Event: Sendable { - /// The download progress was updated. - case progress(Progress) - /// The pipleine generated a progressive scan of the image. - case preview(ImageResponse) - /// The task was cancelled. - /// - /// - note: You are guaranteed to receive either `.cancelled` or - /// `.finished`, but never both. - case cancelled - /// The task finish with the given response. - case finished(Result) + private func sync(_ closure: () -> T) -> T { + os_unfair_lock_lock(lock) + defer { os_unfair_lock_unlock(lock) } + return closure() + } +} + +@available(*, deprecated, renamed: "ImageTask", message: "Async/Await support was addedd directly to the existing `ImageTask` type") +public typealias AsyncImageTask = ImageTask - init(_ event: AsyncTask.Event) { - switch event { - case let .value(response, isCompleted): - if isCompleted { - self = .finished(.success(response)) - } else { - self = .preview(response) - } - case let .progress(value): - self = .progress(value) - case let .error(error): - self = .finished(.failure(error)) +extension ImageTask.Event { + init(_ event: AsyncTask.Event) { + switch event { + case let .value(response, isCompleted): + if isCompleted { + self = .finished(.success(response)) + } else { + self = .preview(response) } + case let .progress(value): + self = .progress(value) + case let .error(error): + self = .finished(.failure(error)) } } +} - // MARK: Hashable +// MARK: - ImageTask (Async) - public func hash(into hasher: inout Hasher) { - hasher.combine(ObjectIdentifier(self).hashValue) +extension ImageTask { + private var _stream: AsyncThrowingStream { + os_unfair_lock_lock(lock) + defer { os_unfair_lock_unlock(lock) } + if context.stream == nil { + context.stream = AsyncThrowingStream.makeStream() + context.stream!.1.onTermination = { [weak self] in + if case .cancelled = $0 { self?.cancel() } + } + } + return context.stream!.0 } - public static func == (lhs: ImageTask, rhs: ImageTask) -> Bool { - ObjectIdentifier(lhs) == ObjectIdentifier(rhs) + private var _progress: AsyncStream { + os_unfair_lock_lock(lock) + defer { os_unfair_lock_unlock(lock) } + if context.progress == nil { + context.progress = AsyncStream.makeStream() + context.progress!.1.onTermination = { [weak self] in + if case .cancelled = $0 { self?.cancel() } + } + } + return context.progress!.0 } - // MARK: CustomStringConvertible + var _previews: AsyncStream { + os_unfair_lock_lock(lock) + defer { os_unfair_lock_unlock(lock) } + if context.previews == nil { + context.previews = AsyncStream.makeStream() + context.previews!.1.onTermination = { [weak self] in + if case .cancelled = $0 { self?.cancel() } + } + } + return context.previews!.0 + } - public var description: String { - "ImageTask(id: \(taskId), priority: \(_priority), progress: \(currentProgress.completed) / \(currentProgress.total), state: \(state))" + /// The events sent by the pipeline during the task execution. + private var _events: AsyncStream { + os_unfair_lock_lock(lock) + defer { os_unfair_lock_unlock(lock) } + if context.events == nil { + context.events = AsyncStream.makeStream() + context.events!.1.onTermination = { [weak self] in + if case .cancelled = $0 { self?.cancel() } + } + } + return context.events!.0 } -} -@available(*, deprecated, renamed: "ImageTask", message: "Async/Await support was addedd directly to the existing `ImageTask` type") -public typealias AsyncImageTask = ImageTask + private struct AsyncContext { + var stream: (AsyncThrowingStream, AsyncThrowingStream.Continuation)? + var previews: (AsyncStream, AsyncStream.Continuation)? + var events: (AsyncStream, AsyncStream.Continuation)? + var progress: (AsyncStream, AsyncStream.Continuation)? + } +} diff --git a/Sources/Nuke/Pipeline/ImagePipeline.swift b/Sources/Nuke/Pipeline/ImagePipeline.swift index f169983ef..0338fc737 100644 --- a/Sources/Nuke/Pipeline/ImagePipeline.swift +++ b/Sources/Nuke/Pipeline/ImagePipeline.swift @@ -286,20 +286,17 @@ public final class ImagePipeline: @unchecked Sendable { // MARK: - ImageTask (Internal) private func makeStartedImageTask(with request: ImageRequest, isDataTask: Bool = false) -> ImageTask { - let task = ImageTask(taskId: nextTaskId, request: request, isDataTask: isDataTask) - task.pipeline = self - task.task = Task { - try await withUnsafeThrowingContinuation { continuation in - queue.async { - task.continuation = continuation - self.startImageTask(task) - } - } - } + let task = ImageTask(taskId: nextTaskId, request: request, isDataTask: isDataTask, pipeline: self) delegate.imageTaskCreated(task, pipeline: self) return task } + private func cancel(_ task: ImageTask) { + guard let subscription = tasks.removeValue(forKey: task) else { return } + task.process(.cancelled) + subscription.unsubscribe() + } + private func startImageTask(_ task: ImageTask) { guard !isInvalidated else { return task.process(.finished(.failure(.pipelineInvalidated))) @@ -317,6 +314,22 @@ public final class ImagePipeline: @unchecked Sendable { delegate.imageTaskDidStart(task, pipeline: self) } + // MARK: - Image Task Events + + func imageTaskCancelCalled(_ task: ImageTask) { + queue.async { self.cancel(task) } + } + + func imageTaskStartCalled(_ task: ImageTask) { + queue.async { self.startImageTask(task) } + } + + func imageTaskUpdatePriorityCalled(_ task: ImageTask, priority: ImageRequest.Priority) { + queue.async { + self.tasks[task]?.setPriority(priority.taskPriority) + } + } + func imageTask(_ task: ImageTask, didProcessEvent event: ImageTask.Event) { switch event { case .cancelled, .finished: @@ -338,26 +351,6 @@ public final class ImagePipeline: @unchecked Sendable { } } - // MARK: - Image Task Events - - func imageTaskCancelCalled(_ task: ImageTask) { - queue.async { - self.cancel(task) - } - } - - private func cancel(_ task: ImageTask) { - guard let subscription = tasks.removeValue(forKey: task) else { return } - task.process(.cancelled) - subscription.unsubscribe() - } - - func imageTaskUpdatePriorityCalled(_ task: ImageTask, priority: ImageRequest.Priority) { - queue.async { - self.tasks[task]?.setPriority(priority.taskPriority) - } - } - // MARK: - Task Factory (Private) // When you request an image or image data, the pipeline creates a graph of tasks diff --git a/Tests/NukeTests/ImagePipelineTests/ImagePipelineAsyncAwaitTests.swift b/Tests/NukeTests/ImagePipelineTests/ImagePipelineAsyncAwaitTests.swift index 9c1dc1d24..b0499966f 100644 --- a/Tests/NukeTests/ImagePipelineTests/ImagePipelineAsyncAwaitTests.swift +++ b/Tests/NukeTests/ImagePipelineTests/ImagePipelineAsyncAwaitTests.swift @@ -161,8 +161,7 @@ class ImagePipelineAsyncAwaitTests: XCTestCase, @unchecked Sendable { XCTAssertTrue(caughtError is CancellationError) } - // TODO: implement - func _testCancelFromEvents() async throws { + func testCancelFromEvents() async throws { dataLoader.queue.isSuspended = true let task = Task { @@ -171,15 +170,46 @@ class ImagePipelineAsyncAwaitTests: XCTestCase, @unchecked Sendable { recordedEvents.append(event) } } + task.cancel() + _ = await task.value + + // THEN nothing is recorded because the task is cancelled and + // stop observing the events + XCTAssertEqual(recordedEvents, []) + } + + func testObserveEventsAndCancelFromOtherTask() async throws { + dataLoader.queue.isSuspended = true + + let task = pipeline.imageTask(with: Test.url) + + let task1 = Task { + for await event in task.events { + recordedEvents.append(event) + } + } + + let task2 = Task { + try await task.response + } + + task2.cancel() + + async let result1: () = task1.value + async let result2 = task2.value + + // THEN you are able to observe `event` update because + // this task does no get cancelled var caughtError: Error? do { - _ = try await task.value + _ = try await (result1, result2) } catch { caughtError = error } XCTAssertTrue(caughtError is CancellationError) + XCTAssertEqual(recordedEvents, [.cancelled]) } func testCancelAsyncImageTask() async throws {