diff --git a/README.md b/README.md index 6798091..0b3b77d 100644 --- a/README.md +++ b/README.md @@ -23,50 +23,29 @@ let connection = SSHConnection( ) ) -connection.start(withTimeout: 3.0) { result in - switch result { - case .success: - // Handle connection - case .failure: - // Handle failure - } -} +try await connection.start() ``` Once connected, you can start executing concrete SSH operations on child communication channels. As `SSH Client` means to be a high level interface, you do not directly interact with them. Instead you use interfaces dedicated to your use case. - SSH shell ```swift -connection.requestShell(withTimeout: 3.0) { result in - switch result { - case .success(let shell): - // Start shell operations - ... - } +let shell = try await connection.requestShell() +for try await chunk in shell.data { + // ... } ``` - SFTP client ```swift -connection.requestSFTPClient(withTimeout: 3.0) { result in - switch result { - case .success(let client): - // Start sftp operations - ... - } -} +let sftpClient = try await connection.requestSFTPClient() +// sftp operations ``` - SSH commands ```swift -connection.execute("echo Hello\n", withTimeout: 3.0) { result in - switch result { - case .success(let response): - // Handle response - case .failure: - // Handle failure - } -} +let response = try await connection.execute("echo Hello\n") +// Handle response ``` You keep track of the connection state, using the dedicated `stateUpdateHandler` property: diff --git a/Sources/SSHClient/Async/Completion+Async.swift b/Sources/SSHClient/Async/Completion+Async.swift new file mode 100644 index 0000000..7beeb03 --- /dev/null +++ b/Sources/SSHClient/Async/Completion+Async.swift @@ -0,0 +1,63 @@ + +import Foundation + +func withCheckedResultContinuation(_ operation: (_ completion: @escaping (Result) -> Void) -> Void) async throws -> T { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + operation { result in + switch result { + case .success(let success): + continuation.resume(returning: success) + case .failure(let failure): + continuation.resume(throwing: failure) + } + } + } +} + +func withTaskCancellationHandler(_ operation: (_ completion: @escaping (Result) -> Void) -> SSHTask) async throws -> T { + let action = TaskAction() + return try await withTaskCancellationHandler(operation: { + try await withCheckedResultContinuation { completion in + let task = operation(completion) + action.setTask(task) + } + }, onCancel: { + action.cancel() + }) +} + +// inspired by https://github.com/swift-server/async-http-client/blob/main/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient%2Bexecute.swift#L155 +actor TaskAction { + enum State { + case initialized + case task(SSHTask) + case ended + } + + private var state: State = .initialized + + nonisolated func setTask(_ task: SSHTask) { + Task { + await _setTask(task) + } + } + + nonisolated func cancel() { + Task { + await _cancel() + } + } + + private func _setTask(_ task: SSHTask) { + state = .task(task) + } + + private func _cancel() { + switch state { + case .ended, .initialized: + break + case .task(let task): + task.cancel() + } + } +} diff --git a/Sources/SSHClient/Async/SFTPClient+Async.swift b/Sources/SSHClient/Async/SFTPClient+Async.swift new file mode 100644 index 0000000..8f1d1be --- /dev/null +++ b/Sources/SSHClient/Async/SFTPClient+Async.swift @@ -0,0 +1,108 @@ + +import Foundation + +public extension SFTPFile { + func readAttributes() async throws -> SFTPFileAttributes { + try await withCheckedResultContinuation { completion in + readAttributes(completion: completion) + } + } + + func read(from offset: UInt64 = 0, + length: UInt32 = .max) async throws -> Data { + try await withCheckedResultContinuation { completion in + read(from: offset, length: length, completion: completion) + } + } + + func write(_ data: Data, + at offset: UInt64 = 0) async throws { + try await withCheckedResultContinuation { completion in + write(data, at: offset, completion: completion) + } + } + + func close() async throws { + try await withCheckedResultContinuation { completion in + close(completion: completion) + } + } +} + +public extension SFTPClient { + func openFile(filePath: String, + flags: SFTPOpenFileFlags, + attributes: SFTPFileAttributes = .none) async throws -> SFTPFile { + try await withCheckedResultContinuation { completion in + openFile( + filePath: filePath, + flags: flags, + attributes: attributes, + completion: completion + ) + } + } + + func withFile(filePath: String, + flags: SFTPOpenFileFlags, + attributes: SFTPFileAttributes = .none, + _ closure: @escaping (SFTPFile) async -> Void) async throws { + try await withCheckedResultContinuation { completion in + withFile( + filePath: filePath, + flags: flags, + attributes: attributes, { file, close in + Task { + await closure(file) + close() + } + }, + completion: completion + ) + } + } + + func listDirectory(atPath path: String) async throws -> [SFTPPathComponent] { + try await withCheckedResultContinuation { completion in + listDirectory(atPath: path, completion: completion) + } + } + + func getAttributes(at filePath: String) async throws -> SFTPFileAttributes { + try await withCheckedResultContinuation { completion in + getAttributes(at: filePath, completion: completion) + } + } + + func createDirectory(atPath path: String, + attributes: SFTPFileAttributes = .none) async throws { + try await withCheckedResultContinuation { completion in + createDirectory(atPath: path, attributes: attributes, completion: completion) + } + } + + func moveItem(atPath current: String, + toPath destination: String) async throws { + try await withCheckedResultContinuation { completion in + moveItem(atPath: current, toPath: destination, completion: completion) + } + } + + func removeDirectory(atPath path: String) async throws { + try await withCheckedResultContinuation { completion in + removeDirectory(atPath: path, completion: completion) + } + } + + func removeFile(atPath path: String) async throws { + try await withCheckedResultContinuation { completion in + removeFile(atPath: path, completion: completion) + } + } + + func close() async { + await withCheckedContinuation { continuation in + close(completion: continuation.resume) + } + } +} diff --git a/Sources/SSHClient/Async/SSHConnection+Async.swift b/Sources/SSHClient/Async/SSHConnection+Async.swift new file mode 100644 index 0000000..7eba491 --- /dev/null +++ b/Sources/SSHClient/Async/SSHConnection+Async.swift @@ -0,0 +1,88 @@ + +import Foundation + +public extension SSHConnection { + typealias AsyncSSHCommandResponse = AsyncThrowingStream + + func start(withTimeout timeout: TimeInterval? = nil) async throws { + try await withCheckedResultContinuation { completion in + start(withTimeout: timeout, completion: completion) + } + } + + func cancel() async { + await withCheckedContinuation { continuation in + cancel(completion: continuation.resume) + } + } + + func execute(_ command: SSHCommand, + withTimeout timeout: TimeInterval? = nil) async throws -> SSHCommandResponse { + try await withTaskCancellationHandler { completion in + execute(command, withTimeout: timeout, completion: completion) + } + } + + func requestShell(withTimeout timeout: TimeInterval? = nil) async throws -> SSHShell { + try await withTaskCancellationHandler { completion in + requestShell(withTimeout: timeout, completion: completion) + } + } + + func requestSFTPClient(withTimeout timeout: TimeInterval? = nil) async throws -> SFTPClient { + try await withTaskCancellationHandler { completion in + requestSFTPClient(withTimeout: timeout, completion: completion) + } + } + + func stream(_ command: SSHCommand, + withTimeout timeout: TimeInterval? = nil) async throws -> AsyncSSHCommandResponse { + try await withTaskCancellationHandler { completion in + enum State { + case initializing + case streaming(AsyncSSHCommandResponse.Continuation) + } + let action = TaskAction() + // Each callback are executed on the internal serial ssh connection queue. + // This is thread safe to modify the state inside them. + var state: State = .initializing + let stream = { (responseChunk: SSHCommandResponseChunk) in + switch state { + case .initializing: + let response = AsyncSSHCommandResponse { continuation in + state = .streaming(continuation) + continuation.onTermination = { _ in + action.cancel() + } + continuation.yield(responseChunk) + } + completion(.success(response)) + case .streaming(let continuation): + continuation.yield(responseChunk) + } + } + let resultTask = execute( + command, + withTimeout: timeout + ) { chunk in + stream(.chunk(chunk)) + } onStatus: { st in + stream(.status(st)) + } completion: { result in + switch state { + case .initializing: + completion(.failure(SSHConnectionError.unknown)) + case .streaming(let continuation): + switch result { + case .success: + continuation.finish() + case .failure(let error): + continuation.finish(throwing: error) + } + } + } + action.setTask(resultTask) + return resultTask + } + } +} diff --git a/Sources/SSHClient/Async/SSHShell+Async.swift b/Sources/SSHClient/Async/SSHShell+Async.swift new file mode 100644 index 0000000..51aee8b --- /dev/null +++ b/Sources/SSHClient/Async/SSHShell+Async.swift @@ -0,0 +1,37 @@ +// +// SSHShell+async.swift +// Atomics +// +// Created by Gaetan Zanella on 10/04/2023. +// + +import Foundation + +public extension SSHShell { + typealias AsyncBytes = AsyncThrowingStream + + var data: AsyncBytes { + AsyncBytes { continuation in + let readID = addReadListener { continuation.yield($0) } + let closeID = addCloseListener { error in + continuation.finish(throwing: error) + } + continuation.onTermination = { [weak self] _ in + self?.removeReadListener(readID) + self?.removeCloseListener(closeID) + } + } + } + + func write(_ data: Data) async throws { + try await withCheckedResultContinuation { completion in + write(data, completion: completion) + } + } + + func close() async throws { + try await withCheckedResultContinuation { completion in + close(completion: completion) + } + } +} diff --git a/Sources/SSHClient/Internal/Command/SSHCommandSession.swift b/Sources/SSHClient/Internal/Command/SSHCommandSession.swift index fdefcf1..2b59b9e 100644 --- a/Sources/SSHClient/Internal/Command/SSHCommandSession.swift +++ b/Sources/SSHClient/Internal/Command/SSHCommandSession.swift @@ -4,37 +4,35 @@ import NIOSSH class SSHCommandSession: SSHSession { private let invocation: SSHCommandInvocation - private let promise: Promise - var futureResult: Future { - promise.futureResult - } + private var promise: Promise? // MARK: - Life Cycle - init(invocation: SSHCommandInvocation, - promise: Promise) { + init(invocation: SSHCommandInvocation) { self.invocation = invocation - self.promise = promise } deinit { - promise.fail(SSHConnectionError.unknown) + promise?.fail(SSHConnectionError.unknown) } // MARK: - SSHSession func start(in context: SSHSessionContext) { + promise = context.promise let channel = context.channel - let result = channel.pipeline.addHandlers( + channel.pipeline.addHandlers( [ SSHCommandHandler( invocation: invocation, - promise: promise + promise: context.promise ), ] ) - context.promise.completeWith(result) + .whenFailure { error in + context.promise.fail(error) + } } } @@ -78,6 +76,7 @@ private class SSHCommandHandler: ChannelDuplexHandler { func errorCaught(context: ChannelHandlerContext, error: Error) { context.channel.close(promise: nil) promise.fail(SSHConnectionError.unknown) + context.fireErrorCaught(error) } func handlerRemoved(context: ChannelHandlerContext) { @@ -85,6 +84,9 @@ private class SSHCommandHandler: ChannelDuplexHandler { } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + defer { + context.fireUserInboundEventTriggered(event) + } switch event { case let event as SSHChannelRequestEvent.ExitStatus: invocation.onStatus?(SSHCommandStatus(exitStatus: event.exitStatus)) @@ -96,11 +98,14 @@ private class SSHCommandHandler: ChannelDuplexHandler { break } default: - context.fireUserInboundEventTriggered(event) + break } } func channelRead(context: ChannelHandlerContext, data: NIOAny) { + defer { + context.fireChannelRead(data) + } let channelData = unwrapInboundIn(data) guard case .byteBuffer(var bytes) = channelData.data, let data = bytes.readData(length: bytes.readableBytes) diff --git a/Sources/SSHClient/Internal/Extension/Result+Utils.swift b/Sources/SSHClient/Internal/Extension/Result+Utils.swift index dae8364..efc2c1f 100644 --- a/Sources/SSHClient/Internal/Extension/Result+Utils.swift +++ b/Sources/SSHClient/Internal/Extension/Result+Utils.swift @@ -3,6 +3,6 @@ import Foundation extension Result { func mapThrowing(_ block: (Success) throws -> New) -> Result { - Result { try block(try get()) } + Result { try block(get()) } } } diff --git a/Sources/SSHClient/Internal/SFTP/SFTPMessage.swift b/Sources/SSHClient/Internal/SFTP/SFTPMessage.swift index 2402c10..dd1a801 100644 --- a/Sources/SSHClient/Internal/SFTP/SFTPMessage.swift +++ b/Sources/SSHClient/Internal/SFTP/SFTPMessage.swift @@ -7,7 +7,7 @@ typealias Future = EventLoopFuture typealias SFTPFileHandle = ByteBuffer typealias SFTPRequestID = UInt32 -public struct SFTPPathComponent { +public struct SFTPPathComponent: Sendable { public let filename: String public let longname: String public let attributes: SFTPFileAttributes diff --git a/Sources/SSHClient/Internal/SSH/IOSSHConnection.swift b/Sources/SSHClient/Internal/SSH/IOSSHConnection.swift index d25fdf5..1d42c33 100644 --- a/Sources/SSHClient/Internal/SSH/IOSSHConnection.swift +++ b/Sources/SSHClient/Internal/SSH/IOSSHConnection.swift @@ -12,7 +12,7 @@ class IOSSHConnection { private var stateMachine: SSHConnectionStateMachine private let eventLoopGroup: EventLoopGroup - private var eventLoop: EventLoop { + var eventLoop: EventLoop { eventLoopGroup.any() } @@ -57,20 +57,7 @@ class IOSSHConnection { } } - func execute(_ command: SSHCommandInvocation, - timeout: TimeInterval) -> Future { - let promise = eventLoop.makePromise(of: Void.self) - let session = SSHCommandSession(invocation: command, promise: promise) - return start(session, timeout: timeout).flatMap { - session.futureResult - } - .flatMap { - // TODO: Fix hack. We keep the session alive as long as the promise is running. - session.futureResult - } - } - - func start(_ session: SSHSession, + func start(_ session: SSHSessionStartingTask, timeout: TimeInterval) -> Future { let promise = eventLoop.makePromise(of: Void.self) return eventLoop.submit { @@ -79,6 +66,11 @@ class IOSSHConnection { .flatMap { promise.futureResult } + .map { + // TODO: Fix hack. We keep the session alive as long as the promise is running. + session + } + .mapAsVoid() } // MARK: - Private @@ -100,7 +92,7 @@ class IOSSHConnection { case .disconnect(let channel): disconnect(channel: channel) case .requestSession(let channel, let session, let timeout, let promise): - startSession(channel: channel, session: session, timeout: timeout, promise: promise) + startSession(channel: channel, sessionTask: session, timeout: timeout, promise: promise) case .callPromise(let promise, let result): promise.end(result) case .none: @@ -190,9 +182,10 @@ class IOSSHConnection { } private func startSession(channel: Channel, - session: SSHSession, + sessionTask: SSHSessionStartingTask, timeout: TimeInterval, promise: Promise) { + let session = sessionTask.session let createChannel = channel.eventLoop.makePromise(of: Channel.self) channel.eventLoop.scheduleTask(in: .seconds(Int64(timeout))) { createChannel.fail(SSHConnectionError.timeout) @@ -205,6 +198,7 @@ class IOSSHConnection { return createChannel .futureResult .flatMap { channel in + sessionTask.didLaunchSession(channel) let createSession = channel.eventLoop.makePromise(of: Void.self) // TODO: We should only consider the remaining time, but that's ok channel.eventLoop.scheduleTask(in: .seconds(Int64(timeout))) { @@ -227,6 +221,9 @@ class IOSSHConnection { } } } + result.whenComplete { result in + sessionTask.didEnd(result) + } promise.completeWith(result) } } diff --git a/Sources/SSHClient/Internal/SSH/ObserverHolder.swift b/Sources/SSHClient/Internal/SSH/ObserverHolder.swift new file mode 100644 index 0000000..5df823a --- /dev/null +++ b/Sources/SSHClient/Internal/SSH/ObserverHolder.swift @@ -0,0 +1,64 @@ + +import Foundation + +struct ObserverToken: Hashable { + private enum Content: Hashable { + case id(UUID) + case publicAPI + } + + private var content: Content + + init() { + content = .id(UUID()) + } + + private init(content: Content) { + self.content = content + } + + static func publicAPI() -> ObserverToken { + ObserverToken(content: .publicAPI) + } +} + +class BlockObserverHolder { + typealias Observer = (Value) -> Void + + private var observers: [ObserverToken: Observer] = [:] + private let lock = NSLock() + + func call(with value: Value) { + lock.withLock { + observers + } + .forEach { $1(value) } + } + + func observer(for token: ObserverToken) -> (Observer)? { + observers[token] + } + + func add(_ block: Observer?, for token: ObserverToken) { + if let block = block { + lock.withLock { + observers[token] = block + } + } else { + removeObserver(token) + } + } + + @discardableResult + func add(_ block: @escaping Observer) -> ObserverToken { + let token = ObserverToken() + add(block, for: token) + return token + } + + func removeObserver(_ token: ObserverToken) { + _ = lock.withLock { + observers.removeValue(forKey: token) + } + } +} diff --git a/Sources/SSHClient/Internal/SSH/SSHConnectionStateMachine.swift b/Sources/SSHClient/Internal/SSH/SSHConnectionStateMachine.swift index eeec667..43791b0 100644 --- a/Sources/SSHClient/Internal/SSH/SSHConnectionStateMachine.swift +++ b/Sources/SSHClient/Internal/SSH/SSHConnectionStateMachine.swift @@ -6,7 +6,7 @@ import NIOSSH enum SSHConnectionEvent { case requestDisconnection(Promise) case requestConnection(TimeInterval, Promise) - case requestSession(SSHSession, TimeInterval, Promise) + case requestSession(SSHSessionStartingTask, TimeInterval, Promise) case connected(Channel) case authenticated(Channel) case disconnected @@ -15,7 +15,7 @@ enum SSHConnectionEvent { enum SSHConnectionAction { case none case disconnect(Channel) - case requestSession(Channel, SSHSession, TimeInterval, Promise) + case requestSession(Channel, SSHSessionStartingTask, TimeInterval, Promise) case connect(TimeInterval) case callPromise(Promise, Result) } diff --git a/Sources/SSHClient/Internal/SSH/SSHSessionStartingTask.swift b/Sources/SSHClient/Internal/SSH/SSHSessionStartingTask.swift new file mode 100644 index 0000000..d0e0a99 --- /dev/null +++ b/Sources/SSHClient/Internal/SSH/SSHSessionStartingTask.swift @@ -0,0 +1,159 @@ + +import Foundation +import NIOCore +import NIOSSH + +class SSHSessionStartingTask: SSHTask { + private var stateMachine = SSHSessionStartingTaskStateMachine() + + let session: SSHSession + let eventLoop: EventLoop + + init(session: SSHSession, eventLoop: EventLoop) { + self.session = session + self.eventLoop = eventLoop + } + + func didEnd(_ result: Result) { + switch result { + case .success: + _ = eventLoop.submit { [weak self] in + self?.trigger(.ended) + } + case .failure: + _ = eventLoop.submit { [weak self] in + self?.trigger(.fail) + } + } + } + + func didLaunchSession(_ channel: Channel) { + _ = eventLoop.submit { [weak self] in + self?.trigger(.launching(channel)) + } + } + + func cancel() { + _ = eventLoop.submit { [weak self] in + self?.trigger(.cancelled) + } + } + + private func trigger(_ event: SSHSessionStartingTaskEvent) { + let action = stateMachine.handle(event) + handle(action) + } + + private func handle(_ action: SSHSessionStartingTaskAction) { + switch action { + case .none: + break + case .end(let channel): + _ = channel.close() + } + } +} + +enum SSHSessionStartingTaskEvent { + case cancelled + case launching(Channel) + case ended + case fail +} + +enum SSHSessionStartingTaskAction { + case none + case end(Channel) +} + +struct SSHSessionStartingTaskStateMachine { + enum State { + case initialized + case cancelled + case launching(Channel) + case cancelling(Channel) + case ended + } + + private var state: State = .initialized + + mutating func handle(_ event: SSHSessionStartingTaskEvent) -> SSHSessionStartingTaskAction { + print("eventĀ \(event)") + print("stateĀ \(state)") + switch state { + case .initialized: + switch event { + case .cancelled: + state = .cancelled + return .none + case .launching(let channel): + state = .launching(channel) + return .none + case .ended: + assertionFailure("Invalid transition") + return .none + case .fail: + state = .ended + return .none + } + case .cancelled: + switch event { + case .cancelled: + return .none + case .launching(let channel): + state = .cancelling(channel) + return .end(channel) + case .ended: + state = .ended + return .none + case .fail: + state = .ended + return .none + } + case .launching(let channel): + switch event { + case .cancelled: + state = .cancelling(channel) + return .end(channel) + case .ended: + state = .ended + return .none + case .launching: + assertionFailure("Invalid transition") + return .none + case .fail: + state = .ended + return .none + } + case .cancelling: + switch event { + case .cancelled: + return .none + case .launching: + assertionFailure("Invalid transition") + return .none + case .ended: + state = .ended + return .none + case .fail: + state = .ended + return .none + } + case .ended: + switch event { + case .cancelled: + // we ignore the cancellation if the task already ended. + return .none + case .launching: + assertionFailure("Invalid transition") + return .none + case .ended: + assertionFailure("Invalid transition") + return .none + case .fail: + assertionFailure("Invalid transition") + return .none + } + } + } +} diff --git a/Sources/SSHClient/SFTPClient.swift b/Sources/SSHClient/SFTPClient.swift index 8a462c2..3216e6c 100644 --- a/Sources/SSHClient/SFTPClient.swift +++ b/Sources/SSHClient/SFTPClient.swift @@ -7,8 +7,8 @@ public enum SFTPClientError: Error { case unknown } -public final class SFTPClient: SSHSession { - enum State: Equatable { +public final class SFTPClient: @unchecked Sendable, SSHSession { + enum State: Sendable, Equatable { case idle case ready case closed diff --git a/Sources/SSHClient/SFTPFile.swift b/Sources/SSHClient/SFTPFile.swift index 233c5e4..9a529a1 100644 --- a/Sources/SSHClient/SFTPFile.swift +++ b/Sources/SSHClient/SFTPFile.swift @@ -1,7 +1,7 @@ import Foundation import NIO -public final class SFTPFile { +public final class SFTPFile: @unchecked Sendable { private var isActive: Bool private let handle: SFTPFileHandle diff --git a/Sources/SSHClient/SFTPFileFlags.swift b/Sources/SSHClient/SFTPFileFlags.swift index 85f7320..fabe3af 100644 --- a/Sources/SSHClient/SFTPFileFlags.swift +++ b/Sources/SSHClient/SFTPFileFlags.swift @@ -1,6 +1,6 @@ import Foundation -public struct SFTPOpenFileFlags: OptionSet { +public struct SFTPOpenFileFlags: Sendable, OptionSet { public var rawValue: UInt32 public init(rawValue: UInt32) { @@ -44,7 +44,7 @@ public struct SFTPOpenFileFlags: OptionSet { public static let forceCreate = SFTPOpenFileFlags(rawValue: 0x0000_0020) } -public struct SFTPFileAttributes: CustomDebugStringConvertible { +public struct SFTPFileAttributes: Sendable, CustomDebugStringConvertible { public typealias Permissions = UInt32 public typealias ExtendedData = [(String, String)] @@ -63,7 +63,7 @@ public struct SFTPFileAttributes: CustomDebugStringConvertible { public static let extended = Flags(rawValue: 0x8000_0000) } - public struct UserGroupId { + public struct UserGroupId: Sendable { public let userId: UInt32 public let groupId: UInt32 @@ -76,7 +76,7 @@ public struct SFTPFileAttributes: CustomDebugStringConvertible { } } - public struct AccessModificationTime { + public struct AccessModificationTime: Sendable { // Both written as UInt32 seconds since jan 1 1970 as UTC public let accessTime: Date public let modificationTime: Date diff --git a/Sources/SSHClient/SSHCommand.swift b/Sources/SSHClient/SSHCommand.swift index 35ee183..96ca012 100644 --- a/Sources/SSHClient/SSHCommand.swift +++ b/Sources/SSHClient/SSHCommand.swift @@ -1,12 +1,17 @@ import Foundation -public struct SSHCommandStatus { +public struct SSHCommandStatus: Sendable { public let exitStatus: Int } -public struct SSHCommandChunk { - public enum Channel { +public enum SSHCommandResponseChunk: Sendable { + case chunk(SSHCommandChunk) + case status(SSHCommandStatus) +} + +public struct SSHCommandChunk: Sendable { + public enum Channel: Sendable { case standard case error } @@ -15,7 +20,7 @@ public struct SSHCommandChunk { public let data: Data } -public struct SSHCommand { +public struct SSHCommand: Sendable { public let command: String public init(_ command: String) { @@ -23,7 +28,7 @@ public struct SSHCommand { } } -public struct SSHCommandResponse { +public struct SSHCommandResponse: Sendable { public let command: SSHCommand public let status: SSHCommandStatus public let standardOutput: Data? diff --git a/Sources/SSHClient/SSHConnection.swift b/Sources/SSHClient/SSHConnection.swift index 7dfef26..47cb019 100644 --- a/Sources/SSHClient/SSHConnection.swift +++ b/Sources/SSHClient/SSHConnection.swift @@ -8,8 +8,8 @@ public enum SSHConnectionError: Error { case timeout } -public class SSHConnection { - public enum State: Equatable { +public class SSHConnection: @unchecked Sendable { + public enum State: Sendable, Equatable { case idle, ready, failed(SSHConnectionError) } @@ -29,6 +29,9 @@ public class SSHConnection { ioConnection.port } + private let defaultTimeout: TimeInterval + private var stateUpdateListeners = BlockObserverHolder() + private let ioConnection: IOSSHConnection private let updateQueue: DispatchQueue @@ -37,24 +40,32 @@ public class SSHConnection { public init(host: String, port: UInt16, authentication: SSHAuthentication, - updateQueue: DispatchQueue = .main) { + defaultTimeout: TimeInterval = 15.0) { ioConnection = IOSSHConnection( host: host, port: port, authentication: authentication, eventLoopGroup: MultiThreadedEventLoopGroup.ssh ) - self.updateQueue = updateQueue + self.defaultTimeout = defaultTimeout + updateQueue = DispatchQueue(label: "ssh_connection") setupIOConnection() } // MARK: - Connection - public var stateUpdateHandler: ((State) -> Void)? + public var stateUpdateHandler: ((State) -> Void)? { + set { + stateUpdateListeners.add(newValue, for: .publicAPI()) + } + get { + stateUpdateListeners.observer(for: .publicAPI()) + } + } - public func start(withTimeout timeout: TimeInterval, + public func start(withTimeout timeout: TimeInterval? = nil, completion: @escaping (Result) -> Void) { - ioConnection.start(timeout: timeout).whenComplete(on: updateQueue, completion) + ioConnection.start(timeout: timeout ?? defaultTimeout).whenComplete(on: updateQueue, completion) } public func cancel(completion: @escaping () -> Void) { @@ -65,23 +76,28 @@ public class SSHConnection { // MARK: - Clients - public func requestShell(withTimeout timeout: TimeInterval, - updateQueue: DispatchQueue = .main, - completion: @escaping (Result) -> Void) { + @discardableResult + public func requestShell(withTimeout timeout: TimeInterval? = nil, + completion: @escaping (Result) -> Void) -> SSHTask { let shell = SSHShell( ioShell: IOSSHShell( eventLoop: MultiThreadedEventLoopGroup.ssh.any() ), updateQueue: updateQueue ) - ioConnection.start(shell, timeout: timeout) + let task = SSHSessionStartingTask( + session: shell, + eventLoop: ioConnection.eventLoop + ) + ioConnection.start(task, timeout: timeout ?? defaultTimeout) .map { shell } - .whenComplete(on: self.updateQueue, completion) + .whenComplete(on: updateQueue, completion) + return task } - public func requestSFTPClient(withTimeout timeout: TimeInterval, - updateQueue: DispatchQueue = .main, - completion: @escaping (Result) -> Void) { + @discardableResult + public func requestSFTPClient(withTimeout timeout: TimeInterval? = nil, + completion: @escaping (Result) -> Void) -> SSHTask { let sftpClient = SFTPClient( sftpChannel: IOSFTPChannel( idAllocator: MonotonicRequestIDAllocator(start: 0), @@ -89,51 +105,86 @@ public class SSHConnection { ), updateQueue: updateQueue ) - ioConnection.start(sftpClient, timeout: timeout) + let task = SSHSessionStartingTask( + session: sftpClient, + eventLoop: ioConnection.eventLoop + ) + ioConnection.start(task, timeout: timeout ?? defaultTimeout) .map { sftpClient } - .whenComplete(on: self.updateQueue, completion) + .whenComplete(on: updateQueue, completion) + return task } // MARK: - Commands + @discardableResult + func execute(_ command: SSHCommand, + withTimeout timeout: TimeInterval? = nil, + onChunk: @escaping (SSHCommandChunk) -> Void, + onStatus: @escaping (SSHCommandStatus) -> Void, + completion: @escaping (Result) -> Void) -> SSHTask { + let invocation = SSHCommandInvocation( + command: command, + onChunk: { [weak self] chunk in + self?.updateQueue.async { + onChunk(chunk) + } + }, + onStatus: { [weak self] st in + self?.updateQueue.async { + onStatus(st) + } + } + ) + let task = SSHSessionStartingTask( + session: SSHCommandSession(invocation: invocation), + eventLoop: ioConnection.eventLoop + ) + ioConnection + .start(task, timeout: timeout ?? defaultTimeout) + .whenComplete(on: updateQueue, completion) + return task + } + + @discardableResult public func execute(_ command: SSHCommand, - withTimeout timeout: TimeInterval, - completion: @escaping (Result) -> Void) { + withTimeout timeout: TimeInterval? = nil, + completion: @escaping (Result) -> Void) -> SSHTask { var standard: Data? var error: Data? var status: SSHCommandStatus? - ioConnection.execute( - SSHCommandInvocation( - command: command, - onChunk: { chunk in - switch chunk.channel { - case .standard: - if standard == nil { - standard = Data() - } - standard?.append(chunk.data) - case .error: - if error == nil { - error = Data() - } - error?.append(chunk.data) + return execute( + command, + withTimeout: timeout, + onChunk: { chunk in + switch chunk.channel { + case .standard: + if standard == nil { + standard = Data() } - }, - onStatus: { st in status = st } - ), - timeout: timeout + standard?.append(chunk.data) + case .error: + if error == nil { + error = Data() + } + error?.append(chunk.data) + } + }, + onStatus: { + status = $0 + }, + completion: { result in + completion(result.mapThrowing { _ in + guard let status = status else { throw SSHConnectionError.unknown } + return SSHCommandResponse( + command: command, + status: status, + standardOutput: standard, + errorOutput: error + ) + }) + } ) - .whenComplete(on: updateQueue) { result in - completion(result.mapThrowing { _ in - guard let status = status else { throw SSHConnectionError.unknown } - return SSHCommandResponse( - command: command, - status: status, - standardOutput: standard, - errorOutput: error - ) - }) - } } // MARK: - Private diff --git a/Sources/SSHClient/SSHShell.swift b/Sources/SSHClient/SSHShell.swift index c1098b7..d169b59 100644 --- a/Sources/SSHClient/SSHShell.swift +++ b/Sources/SSHClient/SSHShell.swift @@ -7,7 +7,7 @@ public enum SSHShellError: Error { case unknown } -public class SSHShell: SSHSession { +public class SSHShell: @unchecked Sendable, SSHSession { enum State: Equatable { case idle case ready @@ -18,6 +18,9 @@ public class SSHShell: SSHSession { private let ioShell: IOSSHShell private let updateQueue: DispatchQueue + private var readListeners = BlockObserverHolder() + private var closeListeners = BlockObserverHolder() + // For testing purpose. // We expose a simple `closeHandler` instead of the state as the starting is // entirely managed by `SSHConnection` and a `SSHShell` can not restart. @@ -36,8 +39,23 @@ public class SSHShell: SSHSession { setupIOShell() } - public var readHandler: ((Data) -> Void)? - public var closeHandler: ((SSHShellError?) -> Void)? + public var readHandler: ((Data) -> Void)? { + set { + readListeners.add(newValue, for: .publicAPI()) + } + get { + readListeners.observer(for: .publicAPI()) + } + } + + public var closeHandler: ((SSHShellError?) -> Void)? { + set { + closeListeners.add(newValue, for: .publicAPI()) + } + get { + closeListeners.observer(for: .publicAPI()) + } + } // MARK: - SSHSession @@ -55,12 +73,30 @@ public class SSHShell: SSHSession { ioShell.close().whenComplete(on: updateQueue, completion) } + // MARK: - Internal + + func addReadListener(_ block: @escaping (Data) -> Void) -> ObserverToken { + readListeners.add(block) + } + + func removeReadListener(_ uuid: ObserverToken) { + readListeners.removeObserver(uuid) + } + + func addCloseListener(_ block: @escaping (SSHShellError?) -> Void) -> ObserverToken { + closeListeners.add(block) + } + + func removeCloseListener(_ uuid: ObserverToken) { + closeListeners.removeObserver(uuid) + } + // MARK: - Private private func setupIOShell() { ioShell.readHandler = { [weak self] data in self?.updateQueue.async { - self?.readHandler?(data) + self?.readListeners.call(with: data) } } ioShell.stateUpdateHandler = { [weak self] state in @@ -72,9 +108,9 @@ public class SSHShell: SSHSession { case .ready: break case .closed: - self?.closeHandler?(nil) + self?.closeListeners.call(with: nil) case .failed(let error): - self?.closeHandler?(error) + self?.closeListeners.call(with: error) } } } diff --git a/Sources/SSHClient/SSHTask.swift b/Sources/SSHClient/SSHTask.swift new file mode 100644 index 0000000..6c70af2 --- /dev/null +++ b/Sources/SSHClient/SSHTask.swift @@ -0,0 +1,4 @@ + +public protocol SSHTask: AnyObject { + func cancel() +} diff --git a/Tests/SSHClientTests/SSHAsyncTests.swift b/Tests/SSHClientTests/SSHAsyncTests.swift new file mode 100644 index 0000000..0af1569 --- /dev/null +++ b/Tests/SSHClientTests/SSHAsyncTests.swift @@ -0,0 +1,85 @@ + +import Foundation +import SSHClient +import XCTest + +class SSHAsyncTests: XCTestCase { + var sshServer: SSHServer! + var connection: SSHConnection! + + override func setUp() { + sshServer = DockerSSHServer() + connection = SSHConnection( + host: sshServer.host, + port: sshServer.port, + authentication: sshServer.credentials + ) + } + + // MARK: - Connection + + func testCommandExecution() async throws { + try await connection.start() + await connection.cancel() + } + + func testCommandStreaming() async throws { + try await connection.start() + let stream = try await connection.stream("yes \"long text\" | head -n 10000\n") + var standard = Data() + for try await chunk in stream { + switch chunk { + case .chunk(let output): + standard.append(output.data) + case .status: + break + } + } + XCTAssertEqual(standard.count, 100_000) + await connection.cancel() + } + + func testShell() async throws { + try await connection.start() + let shell = try await connection.requestShell() + let reader = ShellActor() + Task { + do { + for try await data in shell.data { + await reader.addData(data) + } + await reader.end() + } catch { + await reader.fail() + } + } + try await shell.write("echo Hello\n".data(using: .utf8)!) + wait(timeout: 0.5) + try await shell.close() + wait(timeout: 0.5) + let hasFailed = await reader.hasFailed + let isEnded = await reader.isEnded + let result = await reader.result + XCTAssertFalse(hasFailed) + XCTAssertTrue(isEnded) + XCTAssertEqual(result, "Hello\n".data(using: .utf8)!) + } +} + +private actor ShellActor { + private(set) var result = Data() + private(set) var isEnded = false + private(set) var hasFailed = false + + func addData(_ data: Data) { + result.append(data) + } + + func fail() { + hasFailed = true + } + + func end() { + isEnded = true + } +} diff --git a/Tests/SSHClientTests/SSHCommandTests.swift b/Tests/SSHClientTests/SSHCommandTests.swift index 9c41071..19e6954 100644 --- a/Tests/SSHClientTests/SSHCommandTests.swift +++ b/Tests/SSHClientTests/SSHCommandTests.swift @@ -76,6 +76,49 @@ class SSHCommandTests: XCTestCase { wait(for: [exp], timeout: 3) } + func testCommandImmediateCancellation() { + let connection = launchConnection() + let exp = XCTestExpectation() + let task = connection.execute("echo Hello\n") { result in + XCTAssertTrue(result.isFailure) + exp.fulfill() + } + task.cancel() + wait(for: [exp], timeout: 3) + } + + func testCommandDelayedCancellation() { + let connection = launchConnection() + let exp = XCTestExpectation() + var standardOutput = Data() + var chunkCount = 0 + var task: SSHTask? + task = connection.execute( + "yes \"long\" | head -n 1000000\n", + onChunk: { chunk in + standardOutput.append(chunk.data) + chunkCount += 1 + if chunkCount == 5 { + task?.cancel() + } + }, + onStatus: { _ in }, + completion: { result in + XCTAssertTrue(result.isSuccess) + switch result { + case .success: + // 5000000 = successful output size + XCTAssertTrue(standardOutput.count < 5_000_000) + XCTAssertTrue(standardOutput.count > 100) + case .failure: + break + } + exp.fulfill() + } + ) + wait(for: [exp], timeout: 3) + } + // MARK: - Private private func launchConnection() -> SSHConnection { diff --git a/Tests/SSHClientTests/SSHShellTests.swift b/Tests/SSHClientTests/SSHShellTests.swift index 1179930..0005732 100644 --- a/Tests/SSHClientTests/SSHShellTests.swift +++ b/Tests/SSHClientTests/SSHShellTests.swift @@ -4,11 +4,11 @@ import Foundation import XCTest class SSHShellTests: XCTestCase { - var sftpServer: SFTPServer! + var sftpServer: SSHServer! var connection: SSHConnection! override func setUp() { - sftpServer = SFTPServer(configuration: .docker) + sftpServer = DockerSSHServer() connection = SSHConnection( host: sftpServer.host, port: sftpServer.port, @@ -28,7 +28,6 @@ class SSHShellTests: XCTestCase { XCTAssertEqual(shell.state, .ready) } - // TODO: Fix test, this is a hack due to the sftp docker. func testCommand() throws { let shell = try launchShell() let exp = XCTestExpectation() @@ -38,8 +37,7 @@ class SSHShellTests: XCTestCase { } wait(for: [exp], timeout: 2) wait(timeout: 2) - XCTAssertEqual(shell.states, [.failed(.unknown)]) - XCTAssertEqual(shell.data[0], "This service allows sftp connections only.\n".data(using: .utf8)) + XCTAssertEqual(shell.data[0], "/config\n".data(using: .utf8)) } func testClosing() throws { @@ -54,6 +52,32 @@ class SSHShellTests: XCTestCase { XCTAssertEqual(shell.state, .closed) } + func testShellImmediateCancellation() throws { + let connection = try launchConnection() + let exp = XCTestExpectation() + let task = connection.requestShell { result in + XCTAssert(result.isFailure) + exp.fulfill() + } + task.cancel() + wait(for: [exp], timeout: 2) + } + + func testShellDelayedCancellationShouldDoNothing() throws { + let connection = try launchConnection() + let exp = XCTestExpectation() + var task: SSHTask? + var shell: EmbeddedShell? + task = connection.requestShell { result in + shell = (try? result.get()).flatMap { EmbeddedShell(shell: $0) } + task?.cancel() + exp.fulfill() + } + wait(for: [exp], timeout: 2) + wait(timeout: 0.5) + XCTAssertEqual(shell!.states, []) + } + // MARK: - Errors func testDisconnectionError() throws { @@ -84,21 +108,15 @@ class SSHShellTests: XCTestCase { private func launchShell() throws -> EmbeddedShell { let exp = XCTestExpectation() var shell: EmbeddedShell? - connection.start(withTimeout: 2) { result in + let connection = try launchConnection() + connection.requestShell(withTimeout: 15) { result in switch result { - case .success: - self.connection.requestShell(withTimeout: 15) { result in - switch result { - case .success(let success): - shell = EmbeddedShell(shell: success) - case .failure: - break - } - exp.fulfill() - } + case .success(let success): + shell = EmbeddedShell(shell: success) case .failure: - exp.fulfill() + break } + exp.fulfill() } wait(for: [exp], timeout: 3) if let shell = shell { @@ -107,6 +125,16 @@ class SSHShellTests: XCTestCase { struct AError: Error {} throw AError() } + + private func launchConnection() throws -> SSHConnection { + let exp = XCTestExpectation() + connection.start(withTimeout: 2) { result in + XCTAssertTrue(result.isSuccess) + exp.fulfill() + } + wait(for: [exp], timeout: 2) + return connection + } } private class EmbeddedShell {