diff --git a/Package.swift b/Package.swift index 7c083bc..fa711cc 100644 --- a/Package.swift +++ b/Package.swift @@ -1,15 +1,17 @@ import PackageDescription +let beta = Version(2,0,0, prereleaseIdentifiers: ["beta"]) + let package = Package( name: "PostgreSQL", dependencies: [ // Module map for `libpq` - .Package(url: "https://github.com/vapor/cpostgresql.git", Version(2,0,0, prereleaseIdentifiers: ["alpha"])), + .Package(url: "https://github.com/vapor-community/cpostgresql.git", beta), // Data structure for converting between multiple representations - .Package(url: "https://github.com/vapor/node.git", Version(2,0,0, prereleaseIdentifiers: ["beta"])), + .Package(url: "https://github.com/vapor/node.git", beta), // Core extensions, type-aliases, and functions that facilitate common tasks - .Package(url: "https://github.com/vapor/core.git", Version(2,0,0, prereleaseIdentifiers: ["beta"])), + .Package(url: "https://github.com/vapor/core.git", beta), ] ) diff --git a/Sources/PostgreSQL/PostgresBinaryUtils.swift b/Sources/PostgreSQL/Bind/BinaryUtils.swift similarity index 89% rename from Sources/PostgreSQL/PostgresBinaryUtils.swift rename to Sources/PostgreSQL/Bind/BinaryUtils.swift index 8f73d2c..76501e3 100644 --- a/Sources/PostgreSQL/PostgresBinaryUtils.swift +++ b/Sources/PostgreSQL/Bind/BinaryUtils.swift @@ -22,14 +22,6 @@ extension Float32 { var bigEndian: Float32 { return Float32(bitPattern: bitPattern.bigEndian) } - - var littleEndian: Float32 { - return Float32(bitPattern: bitPattern.littleEndian) - } - - var byteSwapped: Float32 { - return Float32(bitPattern: bitPattern.byteSwapped) - } } extension Float64 { @@ -41,57 +33,21 @@ extension Float64 { var bigEndian: Float64 { return Float64(bitPattern: bitPattern.bigEndian) } - - var littleEndian: Float64 { - return Float64(bitPattern: bitPattern.littleEndian) - } - - var byteSwapped: Float64 { - return Float64(bitPattern: bitPattern.byteSwapped) - } } /// Most information for parsing binary formats has been retrieved from the following links: /// - https://www.postgresql.org/docs/9.6/static/datatype.html (Data types) /// - https://github.com/postgres/postgres/tree/55c3391d1e6a201b5b891781d21fe682a8c64fe6/src/backend/utils/adt (Backend sending code) -struct PostgresBinaryUtils { +struct BinaryUtils { - // MARK: - Formatters + // MARK: - Formatter struct Formatters { - private static func formatter(format: String, forceUTC: Bool) -> DateFormatter { + static let timestamptz: DateFormatter = { let formatter = DateFormatter() - if forceUTC { - formatter.timeZone = TimeZone(abbreviation: "UTC") - } - formatter.dateFormat = format + formatter.dateFormat = "yyyy-MM-dd HH:mm:ss.SSSX" return formatter - } - - private static let timestamp: DateFormatter = formatter(format: "yyyy-MM-dd HH:mm:ss.SSS", forceUTC: true) - private static let timestamptz: DateFormatter = formatter(format: "yyyy-MM-dd HH:mm:ss.SSSX", forceUTC: false) - - private static let date: DateFormatter = formatter(format: "yyyy-MM-dd", forceUTC: false) - - private static let time: DateFormatter = formatter(format: "HH:mm:ss.SSS", forceUTC: true) - private static let timetz: DateFormatter = formatter(format: "HH:mm:ss.SSSX", forceUTC: false) - - static func dateFormatter(for oid: OID) -> DateFormatter { - switch oid { - case .date: - return date - case .time: - return time - case .timetz: - return timetz - case .timestamp: - return timestamp - case .timestamptz: - return timestamptz - default: - return timestamptz - } - } + }() static let interval: NumberFormatter = { let formatter = NumberFormatter() @@ -133,11 +89,13 @@ struct PostgresBinaryUtils { return uint8Bytes } - static func valueToByteArray(_ value: inout T) -> [Int8] { + static func valueToBytes(_ value: inout T) -> (UnsafeMutablePointer, Int) { let size = MemoryLayout.size(ofValue: value) return withUnsafePointer(to: &value) { valuePointer in return valuePointer.withMemoryRebound(to: Int8.self, capacity: size) { bytePointer in - return UnsafeBufferPointer(start: bytePointer, count: size).array + let bytes: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: size) + bytes.assign(from: bytePointer, count: size) + return (bytes, size) } } } diff --git a/Sources/PostgreSQL/Bind/Bind+Node.swift b/Sources/PostgreSQL/Bind/Bind+Node.swift new file mode 100644 index 0000000..e819d8b --- /dev/null +++ b/Sources/PostgreSQL/Bind/Bind+Node.swift @@ -0,0 +1,202 @@ +import CPostgreSQL + +extension Bind { + /// Parses a PostgreSQL value from an output binding. + public var value: StructuredData { + // Check if we have data to parse + guard let value = bytes else { + return .null + } + + // We only parse binary data, otherwise simply return the data as a string + guard format == .binary else { + let string = BinaryUtils.parseString(value: value, length: length) + return .string(string) + } + + // Parse based on the type of data + switch type { + case .null: + return .null + + case .supported(let supportedType): + return Bind.parse(type: supportedType, configuration: configuration, value: value, length: length) + + case .array(let supportedArrayType): + return Bind.parse(type: supportedArrayType, configuration: configuration, value: value, length: length) + + case .unsupported(let oid): + print("Unsupported Oid type for PostgreSQL binding (\(oid)).") + + // Fallback to simply passing on the bytes + let bytes = BinaryUtils.parseBytes(value: value, length: length) + return .bytes(bytes) + } + } +} + +/** + Parsing data + */ +extension Bind { + fileprivate static func parse(type: FieldType.Supported, configuration: Configuration, value: UnsafeMutablePointer, length: Int) -> StructuredData { + switch type { + case .bool: + return .bool(value[0] != 0) + + case .char, .name, .text, .json, .xml, .bpchar, .varchar: + let string = BinaryUtils.parseString(value: value, length: length) + return .string(string) + + case .jsonb: + // Ignore jsonb version number + let jsonValue = value.advanced(by: 1) + let string = BinaryUtils.parseString(value: jsonValue, length: length - 1) + return .string(string) + + case .int2: + let integer = BinaryUtils.parseInt16(value: value) + return .number(.int(Int(integer))) + + case .int4: + let integer = BinaryUtils.parseInt32(value: value) + return .number(.int(Int(integer))) + + case .int8: + let integer = BinaryUtils.parseInt64(value: value) + if let intValue = Int(exactly: integer) { + return .number(.int(intValue)) + } else { + return .number(.double(Double(integer))) + } + + case .bytea: + let bytes = BinaryUtils.parseBytes(value: value, length: length) + return .bytes(bytes) + + case .float4: + let float = BinaryUtils.parseFloat32(value: value) + return .number(.double(Double(float))) + + case .float8: + let float = BinaryUtils.parseFloat64(value: value) + return .number(.double(Double(float))) + + case .numeric: + let number = BinaryUtils.parseNumeric(value: value) + return .string(number) + + case .uuid: + let uuid = BinaryUtils.parseUUID(value: value) + return .string(uuid) + + case .timestamp, .timestamptz, .date, .time, .timetz: + let date = BinaryUtils.parseTimetamp(value: value, isInteger: configuration.hasIntegerDatetimes) + return .date(date) + + case .interval: + let interval = BinaryUtils.parseInterval(value: value, timeIsInteger: configuration.hasIntegerDatetimes) + return .string(interval) + + case .point: + let point = BinaryUtils.parsePoint(value: value) + return .string(point) + + case .lseg: + let lseg = BinaryUtils.parseLineSegment(value: value) + return .string(lseg) + + case .path: + let path = BinaryUtils.parsePath(value: value) + return .string(path) + + case .box: + let box = BinaryUtils.parseBox(value: value) + return .string(box) + + case .polygon: + let polygon = BinaryUtils.parsePolygon(value: value) + return .string(polygon) + + case .circle: + let circle = BinaryUtils.parseCircle(value: value) + return .string(circle) + + case .inet, .cidr: + let inet = BinaryUtils.parseIPAddress(value: value) + return .string(inet) + + case .macaddr: + let macaddr = BinaryUtils.parseMacAddress(value: value) + return .string(macaddr) + + case .bit, .varbit: + let bitString = BinaryUtils.parseBitString(value: value, length: length) + return .string(bitString) + } + } + + fileprivate static func parse(type: FieldType.ArraySupported, configuration: Configuration, value: UnsafeMutablePointer, length: Int) -> StructuredData { + // Get the dimension of the array + let arrayDimension = BinaryUtils.parseInt32(value: value) + guard arrayDimension > 0 else { + return .array([]) + } + + var pointer = value.advanced(by: 12) + + // Get all dimension lengths + var dimensionLengths: [Int] = [] + for _ in 0..) -> StructuredData { + // Get the length of the array + let arrayLength = dimensionLengths[0] + + // Create elements array + var values: [StructuredData] = [] + values.reserveCapacity(arrayLength) + + // Loop through array and convert each item + let supportedType = type.supported + for _ in 0.. 1 { + + var subDimensionLengths = dimensionLengths + subDimensionLengths.removeFirst() + + let array = parse(type: type, configuration: configuration, dimensionLengths: subDimensionLengths, pointer: &pointer) + values.append(array) + + } else { + + let elementLength = Int(BinaryUtils.parseInt32(value: pointer)) + pointer = pointer.advanced(by: 4) + + // Check if the element is null + guard elementLength != -1 else { + values.append(.null) + continue + } + + // Parse to node + let item = parse(type: supportedType, configuration: configuration, value: pointer, length: elementLength) + values.append(item) + pointer = pointer.advanced(by: elementLength) + } + } + + return .array(values) + } +} + + diff --git a/Sources/PostgreSQL/Bind/Bind.swift b/Sources/PostgreSQL/Bind/Bind.swift new file mode 100644 index 0000000..ebd87ef --- /dev/null +++ b/Sources/PostgreSQL/Bind/Bind.swift @@ -0,0 +1,266 @@ +import Foundation +import Core + +public final class Bind { + + // MARK: - Enums + + public enum Format : Int32 { + case string = 0 + case binary = 1 + } + + // MARK: - Properties + + public let bytes: UnsafeMutablePointer? + public let length: Int + + public let type: FieldType + public let format: Format + + public let configuration: Configuration + public let result: Result? + + // MARK: - Init + + /** + Creates a NULL input binding. + + PQexecParams converts nil pointer to NULL. + see: https://www.postgresql.org/docs/9.1/static/libpq-exec.html + */ + public init(configuration: Configuration) { + self.configuration = configuration + + bytes = nil + length = 0 + + type = nil + format = .string + + result = nil + } + + /** + Creates an input binding from a String. + */ + public convenience init(string: String, configuration: Configuration) { + let utf8CString = string.utf8CString + let count = utf8CString.count + + let bytes = UnsafeMutablePointer.allocate(capacity: count) + for (i, char) in utf8CString.enumerated() { + bytes[i] = char + } + + self.init(bytes: bytes, length: count, type: nil, format: .string, configuration: configuration) + } + + /** + Creates an input binding from a UInt. + */ + public convenience init(bool: Bool, configuration: Configuration) { + let bytes = UnsafeMutablePointer.allocate(capacity: 1) + bytes.initialize(to: bool ? 1 : 0) + + self.init(bytes: bytes, length: 1, type: FieldType(.bool), format: .binary, configuration: configuration) + } + + /** + Creates an input binding from an Int. + */ + public convenience init(int: Int, configuration: Configuration) { + let count = MemoryLayout.size(ofValue: int) + + let type: FieldType + switch count { + case 2: + type = FieldType(.int2) + case 4: + type = FieldType(.int4) + case 8: + type = FieldType(.int8) + default: + // Unsupported integer size, use string instead + self.init(string: int.description, configuration: configuration) + return + } + + var value = int.bigEndian + let (bytes, length) = BinaryUtils.valueToBytes(&value) + self.init(bytes: bytes, length: length, type: type, format: .binary, configuration: configuration) + } + + /** + Creates an input binding from a UInt. + */ + public convenience init(uint: UInt, configuration: Configuration) { + let int: Int + if uint >= UInt(Int.max) { + int = Int.max + } + else { + int = Int(uint) + } + + self.init(int: int, configuration: configuration) + } + + /** + Creates an input binding from an Double. + */ + public convenience init(double: Double, configuration: Configuration) { + let count = MemoryLayout.size(ofValue: double) + + let type: FieldType + switch count { + case 4: + type = FieldType(.float4) + case 8: + type = FieldType(.float8) + default: + // Unsupported float size, use string instead + self.init(string: double.description, configuration: configuration) + return + } + + var value = double.bigEndian + let (bytes, length) = BinaryUtils.valueToBytes(&value) + self.init(bytes: bytes, length: length, type: type, format: .binary, configuration: configuration) + } + + /** + Creates an input binding from an array of bytes. + */ + public convenience init(bytes: Bytes, configuration: Configuration) { + let int8Bytes: UnsafeMutablePointer = UnsafeMutablePointer.allocate(capacity: bytes.count) + for (i, byte) in bytes.enumerated() { + int8Bytes[i] = Int8(bitPattern: byte) + } + + self.init(bytes: int8Bytes, length: bytes.count, type: nil, format: .binary, configuration: configuration) + } + + /** + Creates an input binding from a Date. + */ + public convenience init(date: Date, configuration: Configuration) { + let interval = date.timeIntervalSince(BinaryUtils.TimestampConstants.referenceDate) + + if configuration.hasIntegerDatetimes { + let microseconds = Int64(interval * 1_000_000) + var value = microseconds.bigEndian + let (bytes, length) = BinaryUtils.valueToBytes(&value) + self.init(bytes: bytes, length: length, type: FieldType(.timestamptz), format: .binary, configuration: configuration) + } + else { + let seconds = Float64(interval) + var value = seconds.bigEndian + let (bytes, length) = BinaryUtils.valueToBytes(&value) + self.init(bytes: bytes, length: length, type: FieldType(.timestamptz), format: .binary, configuration: configuration) + } + } + + /** + Creates an input binding from an array. + */ + public convenience init(array: [StructuredData], configuration: Configuration) { + let elements = array.map { $0.postgresArrayElementString } + let arrayString = "{\(elements.joined(separator: ","))}" + self.init(string: arrayString, configuration: configuration) + } + + public init(bytes: UnsafeMutablePointer?, length: Int, type: FieldType, format: Format, configuration: Configuration) { + self.bytes = bytes + self.length = length + + self.type = type + self.format = format + + self.configuration = configuration + + result = nil + } + + public init(result: Result, bytes: UnsafeMutablePointer, length: Int, type: FieldType, format: Format, configuration: Configuration) { + self.result = result + + self.bytes = bytes + self.length = length + + self.type = type + self.format = format + + self.configuration = configuration + } +} + +extension Node { + /** + Creates in input binding from a PostgreSQL Value. + */ + public func bind(with configuration: Configuration) -> Bind { + switch wrapped { + case .number(let number): + switch number { + case .int(let int): + return Bind(int: int, configuration: configuration) + case .double(let double): + return Bind(double: double, configuration: configuration) + case .uint(let uint): + return Bind(uint: uint, configuration: configuration) + } + case .string(let string): + return Bind(string: string, configuration: configuration) + case .null: + return Bind(configuration: configuration) + case .array(let array): + return Bind(array: array, configuration: configuration) + case .bytes(let bytes): + return Bind(bytes: bytes, configuration: configuration) + case .object(_): + print("Unsupported Node type for PostgreSQL binding, everything except for .object is supported.") + return Bind(configuration: configuration) + case .bool(let bool): + return Bind(bool: bool, configuration: configuration) + case .date(let date): + return Bind(date: date, configuration: configuration) + } + } +} + +extension StructuredData { + var postgresArrayElementString: String { + switch self { + case .null: + return "NULL" + + case .bytes(let bytes): + let hexString = bytes.map { $0.lowercaseHexPair }.joined() + return "\"\\\\x\(hexString)\"" + + case .bool(let bool): + return bool ? "t" : "f" + + case .number(let number): + return number.description + + case .string(let string): + let escapedString = string + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + return "\"\(escapedString)\"" + + case .array(let array): + let elements = array.map { $0.postgresArrayElementString } + return "{\(elements.joined(separator: ","))}" + + case .object(_): + print("Unsupported Node array type for PostgreSQL binding, everything except for .object is supported.") + return "NULL" + + case .date(let date): + return BinaryUtils.Formatters.timestamptz.string(from: date) + } + } +} diff --git a/Sources/PostgreSQL/Bind/FieldType.swift b/Sources/PostgreSQL/Bind/FieldType.swift new file mode 100644 index 0000000..51fd076 --- /dev/null +++ b/Sources/PostgreSQL/Bind/FieldType.swift @@ -0,0 +1,247 @@ +import CPostgreSQL + +public enum FieldType : ExpressibleByNilLiteral, Equatable { + case supported(Supported) + case array(ArraySupported) + case unsupported(Oid) + case null + + // MARK: - Init + + public init(_ oid: Oid) { + if let supported = Supported(rawValue: oid) { + self = .supported(supported) + } + else if let arraySupported = ArraySupported(rawValue: oid) { + self = .array(arraySupported) + } + else { + self = .unsupported(oid) + } + } + + public init(_ oid: Oid?) { + if let oid = oid { + self.init(oid) + } + else { + self.init(nilLiteral: ()) + } + } + + public init(_ supported: Supported) { + self = .supported(supported) + } + + // MARK: - ExpressibleByNilLiteral + + public init(nilLiteral: ()) { + self = .null + } + + // MARK: - Equatable + + public static func ==(lhs: FieldType, rhs: FieldType) -> Bool { + return lhs.oid == rhs.oid + } + + // MARK: - Oid + + public var oid: Oid? { + switch self { + case .supported(let supported): + return supported.rawValue + + case .unsupported(let oid): + return oid + + case .array(let supported): + return supported.rawValue + + case .null: + return nil + } + } +} + +extension FieldType { + /// Oid values can be found in the following file: + /// https://github.com/postgres/postgres/blob/55c3391d1e6a201b5b891781d21fe682a8c64fe6/src/include/catalog/pg_type.h + public enum Supported: Oid { + case bool = 16 + + case int2 = 21 + case int4 = 23 + case int8 = 20 + + case bytea = 17 + + case char = 18 + case name = 19 + case text = 25 + case bpchar = 1042 + case varchar = 1043 + + case json = 114 + case jsonb = 3802 + case xml = 142 + + case float4 = 700 + case float8 = 701 + + case numeric = 1700 + + case date = 1082 + case time = 1083 + case timetz = 1266 + case timestamp = 1114 + case timestamptz = 1184 + case interval = 1186 + + case uuid = 2950 + + case point = 600 + case lseg = 601 + case path = 602 + case box = 603 + case polygon = 604 + case circle = 718 + + case cidr = 650 + case inet = 869 + case macaddr = 829 + + case bit = 1560 + case varbit = 1562 + } +} + +extension FieldType { + public enum ArraySupported: Oid { + case bool = 1000 + + case int2 = 1005 + case int4 = 1007 + case int8 = 1016 + + case bytea = 1001 + + case char = 1002 + case name = 1003 + case text = 1009 + case bpchar = 1014 + case varchar = 1015 + + case json = 199 + case jsonb = 3807 + case xml = 143 + + case float4 = 1021 + case float8 = 1022 + + case numeric = 1231 + + case date = 1182 + case time = 1183 + case timetz = 1270 + case timestamp = 1115 + case timestamptz = 1185 + case interval = 1187 + + case uuid = 2951 + + case point = 1017 + case lseg = 1018 + case path = 1019 + case box = 1020 + case polygon = 1027 + case circle = 719 + + case cidr = 651 + case inet = 1041 + case macaddr = 1040 + + case bit = 1561 + case varbit = 1563 + + // MARK: - Supported + + public init(_ supported: Supported) { + switch supported { + case .bool: self = .bool + case .int2: self = .int2 + case .int4: self = .int4 + case .int8: self = .int8 + case .bytea: self = .bytea + case .char: self = .char + case .name: self = .name + case .text: self = .text + case .bpchar: self = .bpchar + case .varchar: self = .varchar + case .json: self = .json + case .jsonb: self = .jsonb + case .xml: self = .xml + case .float4: self = .float4 + case .float8: self = .float8 + case .numeric: self = .numeric + case .date: self = .date + case .time: self = .time + case .timetz: self = .timetz + case .timestamp: self = .timestamp + case .timestamptz: self = .timestamptz + case .interval: self = .interval + case .uuid: self = .uuid + case .point: self = .point + case .lseg: self = .lseg + case .path: self = .path + case .box: self = .box + case .polygon: self = .polygon + case .circle: self = .circle + case .cidr: self = .cidr + case .inet: self = .inet + case .macaddr: self = .macaddr + case .bit: self = .bit + case .varbit: self = .varbit + } + } + + public var supported: Supported { + switch self { + case .bool: return .bool + case .int2: return .int2 + case .int4: return .int4 + case .int8: return .int8 + case .bytea: return .bytea + case .char: return .char + case .name: return .name + case .text: return .text + case .bpchar: return .bpchar + case .varchar: return .varchar + case .json: return .json + case .jsonb: return .jsonb + case .xml: return .xml + case .float4: return .float4 + case .float8: return .float8 + case .numeric: return .numeric + case .date: return .date + case .time: return .time + case .timetz: return .timetz + case .timestamp: return .timestamp + case .timestamptz: return .timestamptz + case .interval: return .interval + case .uuid: return .uuid + case .point: return .point + case .lseg: return .lseg + case .path: return .path + case .box: return .box + case .polygon: return .polygon + case .circle: return .circle + case .cidr: return .cidr + case .inet: return .inet + case .macaddr: return .macaddr + case .bit: return .bit + case .varbit: return .varbit + } + } + } +} diff --git a/Sources/PostgreSQL/Connection.swift b/Sources/PostgreSQL/Connection.swift index ab4e038..a95dda8 100644 --- a/Sources/PostgreSQL/Connection.swift +++ b/Sources/PostgreSQL/Connection.swift @@ -1,23 +1,24 @@ import CPostgreSQL +import Dispatch // This structure represents a handle to one database connection. // It is used for almost all PostgreSQL functions. // Do not try to make a copy of a PostgreSQL structure. // There is no guarantee that such a copy will be usable. public final class Connection: ConnInfoInitializable { - public let cConnection: OpaquePointer - public var configuration: Configuration? - public var isConnected: Bool { - if PQstatus(cConnection) == CONNECTION_OK { - return true - } - return false - } - - public init(conninfo: ConnInfo) throws { + + // MARK: - CConnection + + public typealias CConnection = OpaquePointer + + public let cConnection: CConnection + + // MARK: - Init + + public init(connInfo: ConnInfo) throws { let string: String - switch conninfo { + switch connInfo { case .raw(let info): string = info case .params(let params): @@ -26,109 +27,209 @@ public final class Connection: ConnInfoInitializable { string = "host='\(hostname)' port='\(port)' dbname='\(database)' user='\(user)' password='\(password)' client_encoding='UTF8'" } - self.cConnection = PQconnectdb(string) - if isConnected == false { - throw DatabaseError.cannotEstablishConnection(lastError) - } + cConnection = PQconnectdb(string) + try validateConnection() + } + + // MARK: - Deinit + + deinit { + try? close() } + + // MARK: - Execute @discardableResult - public func execute(_ query: String, _ values: [Node]? = []) throws -> [[String: Node]] { - guard !query.isEmpty else { - throw DatabaseError.noQuery - } - - let values = values ?? [] - + public func execute(_ query: String, _ values: [Node] = []) throws -> Node { + let binds = values.map { $0.bind(with: configuration) } + return try execute(query, binds) + } + + @discardableResult + public func execute(_ query: String, _ binds: [Bind]) throws -> Node { var types: [Oid] = [] - types.reserveCapacity(values.count) - - var paramValues: [[Int8]?] = [] - paramValues.reserveCapacity(values.count) - - var lengths: [Int32] = [] - lengths.reserveCapacity(values.count) - + types.reserveCapacity(binds.count) + var formats: [Int32] = [] - formats.reserveCapacity(values.count) - - for value in values { - let (bytes, oid, format) = value.postgresBindingData - paramValues.append(bytes) - types.append(oid?.rawValue ?? 0) - lengths.append(Int32(bytes?.count ?? 0)) - formats.append(format.rawValue) + formats.reserveCapacity(binds.count) + + var values: [UnsafePointer?] = [] + values.reserveCapacity(binds.count) + + var lengths: [Int32] = [] + lengths.reserveCapacity(binds.count) + + for bind in binds { + + types.append(bind.type.oid ?? 0) + formats.append(bind.format.rawValue) + values.append(bind.bytes) + lengths.append(Int32(bind.length)) } - - let res: Result.Pointer = PQexecParams( - cConnection, query, - Int32(values.count), - types, paramValues.map { - UnsafePointer($0) - }, + + let resultPointer: Result.Pointer? = PQexecParams( + cConnection, + query, + Int32(binds.count), + types, + values, lengths, formats, - DataFormat.binary.rawValue + Bind.Format.binary.rawValue ) - - defer { - PQclear(res) - } - - switch Database.Status(result: res) { - case .nonFatalError, .fatalError, .unknown: - throw DatabaseError.invalidSQL(message: String(cString: PQresultErrorMessage(res))) - case .tuplesOk: - let configuration = try getConfiguration() - return Result(configuration: configuration, pointer: res).parsed - default: - return [] - } + + let result = Result(pointer: resultPointer, connection: self) + return try result.parseData() + } + + // MARK: - Connection Status + + public var isConnected: Bool { + return PQstatus(cConnection) == CONNECTION_OK } - public func status() -> ConnStatusType { + public var status: ConnStatusType { return PQstatus(cConnection) } - - public func reset() throws { - guard self.isConnected else { - throw PostgreSQLError(.connection_failure, reason: lastError) + + private func validateConnection() throws { + guard isConnected else { + throw PostgreSQLError(code: .connectionFailure, connection: self) } + } + public func reset() throws { + try validateConnection() PQreset(cConnection) } public func close() throws { - guard self.isConnected else { - throw PostgreSQLError(.connection_does_not_exist, reason: lastError) - } - + try validateConnection() PQfinish(cConnection) } - - // Contains the last error message generated by the PostgreSQL connection. - public var lastError: String { - guard let errorMessage = PQerrorMessage(cConnection) else { - return "" + + // MARK: - Transaction + + public enum TransactionIsolationLevel { + case readCommitted + case repeatableRead + case serializable + + var sqlName: String { + switch self { + case .readCommitted: + return "READ COMMITTED" + + case .repeatableRead: + return "REPEATABLE READ" + + case .serializable: + return "SERIALIZABLE" + } } - return String(cString: errorMessage) } + + public func transaction(isolationLevel: TransactionIsolationLevel = .readCommitted, closure: () throws -> R) throws -> R { + try execute("BEGIN TRANSACTION ISOLATION LEVEL \(isolationLevel.sqlName)") + + let value: R + do { + value = try closure() + } catch { + // rollback changes and then rethrow the error + try execute("ROLLBACK") + throw error + } - deinit { - try? close() + try execute("COMMIT") + return value + } + + // MARK: - LISTEN/NOTIFY + + public struct Notification { + public let pid: Int + public let channel: String + public let payload: String? + + init(pgNotify: PGnotify) { + channel = String(cString: pgNotify.relname) + pid = Int(pgNotify.be_pid) + + if pgNotify.extra != nil { + let string = String(cString: pgNotify.extra) + if !string.isEmpty { + payload = string + } + else { + payload = nil + } + } + else { + payload = nil + } + } + } + + /// Registers as a listener on a specific notification channel. + /// + /// - Parameters: + /// - channel: The channel to register for. + /// - queue: The queue to perform the listening on. + /// - callback: Callback containing any received notification or error and a boolean which can be set to true to stop listening. + public func listen(toChannel channel: String, on queue: DispatchQueue = DispatchQueue.global(), callback: @escaping (Notification?, Error?, inout Bool) -> Void) { + queue.async { + var stop: Bool = false + + do { + try self.execute("LISTEN \(channel)") + + while !stop { + try self.validateConnection() + + // Sleep to avoid looping continuously on cpu + sleep(1) + + PQconsumeInput(self.cConnection) + + while !stop, let pgNotify = PQnotifies(self.cConnection) { + let notification = Notification(pgNotify: pgNotify.pointee) + + callback(notification, nil, &stop) + + PQfreemem(pgNotify) + } + } + } + catch { + callback(nil, error, &stop) + } + } + } + + public func notify(channel: String, payload: String? = nil) throws { + if let payload = payload { + try execute("NOTIFY \(channel), '\(payload)'") + } + else { + try execute("NOTIFY \(channel)") + } } - // MARK: - Load Configuration - private func getConfiguration() throws -> Configuration { - if let configuration = self.configuration { + // MARK: - Configuration + + private var cachedConfiguration: Configuration? + + public var configuration: Configuration { + if let configuration = cachedConfiguration { return configuration } - + let hasIntegerDatetimes = getBooleanParameterStatus(key: "integer_datetimes", default: true) - + let configuration = Configuration(hasIntegerDatetimes: hasIntegerDatetimes) - self.configuration = configuration - + cachedConfiguration = configuration + return configuration } @@ -146,8 +247,7 @@ extension Connection { let values = try representable.map { return try $0.makeNode(in: PostgreSQLContext.shared) } - - let result: [[String: Node]] = try execute(query, values) - return try Node.array(result.map { try $0.makeNode(in: PostgreSQLContext.shared) }) + + return try execute(query, values) } } diff --git a/Sources/PostgreSQL/ConnectionInfo.swift b/Sources/PostgreSQL/ConnectionInfo.swift index 27f587f..fff34ed 100644 --- a/Sources/PostgreSQL/ConnectionInfo.swift +++ b/Sources/PostgreSQL/ConnectionInfo.swift @@ -1,5 +1,3 @@ -import CPostgreSQL - public enum ConnInfo { case raw(String) case params([String: String]) @@ -7,19 +5,19 @@ public enum ConnInfo { } public protocol ConnInfoInitializable { - init(conninfo: ConnInfo) throws + init(connInfo: ConnInfo) throws } extension ConnInfoInitializable { - public init(params: [String: String]) throws { - try self.init(conninfo: .params(params)) + public init(connInfo: String) throws { + try self.init(connInfo: .raw(connInfo)) } - - public init(hostname: String, port: Int, database: String, user: String, password: String) throws { - try self.init(conninfo: .basic(hostname: hostname, port: port, database: database, user: user, password: password)) + + public init(params: [String: String]) throws { + try self.init(connInfo: .params(params)) } - public init(conninfo: String) throws { - try self.init(conninfo: .raw(conninfo)) + public init(hostname: String, port: Int = 5432, database: String, user: String, password: String) throws { + try self.init(connInfo: .basic(hostname: hostname, port: port, database: database, user: user, password: password)) } } diff --git a/Sources/PostgreSQL/Database.swift b/Sources/PostgreSQL/Database.swift index 290ef54..5126519 100644 --- a/Sources/PostgreSQL/Database.swift +++ b/Sources/PostgreSQL/Database.swift @@ -1,85 +1,22 @@ import CPostgreSQL -import Core - -public enum DatabaseError: Error { - case cannotEstablishConnection(String) - case indexOutOfRange - case columnNotFound - case invalidSQL(message: String) - case noQuery - case noResults -} - -enum DataFormat : Int32 { - case string = 0 - case binary = 1 -} public final class Database: ConnInfoInitializable { + // MARK: - Properties - public let conninfo: ConnInfo + + public let connInfo: ConnInfo // MARK: - Init - public init(conninfo: ConnInfo) throws { - self.conninfo = conninfo + + public init(connInfo: ConnInfo) throws { + self.connInfo = connInfo } - // MARK: - Connection + /// Creates a new connection to + /// the database that can be reused between executions. + /// + /// The connection will close automatically when deinitialized. public func makeConnection() throws -> Connection { - return try Connection(conninfo: conninfo) - } - - // MARK: - Query Execution - @discardableResult - public func execute(_ query: String, _ values: [Node]? = [], on connection: Connection? = nil) throws -> [[String: Node]] { - guard !query.isEmpty else { - throw DatabaseError.noQuery - } - - let connection = try connection ?? makeConnection() - - return try connection.execute(query, values) - } - - // MARK: - LISTEN - public func listen(to channel: String, callback: @escaping (Notification) -> Void) { - background { - do { - let connection = try self.makeConnection() - - try self.execute("LISTEN \(channel)", on: connection) - - while true { - if connection.isConnected == false { - throw DatabaseError.cannotEstablishConnection(connection.lastError) - } - - PQconsumeInput(connection.cConnection) - - while let pgNotify = PQnotifies(connection.cConnection) { - let notification = Notification(relname: pgNotify.pointee.relname, extra: pgNotify.pointee.extra, be_pid: pgNotify.pointee.be_pid) - - callback(notification) - - PQfreemem(pgNotify) - } - } - } - catch { - fatalError("\(error)") - } - } - } - - // MARK: - NOTIFY - public func notify(channel: String, payload: String?, on connection: Connection? = nil) throws { - let connection = try connection ?? makeConnection() - - if let payload = payload { - try execute("NOTIFY \(channel), '\(payload)'", on: connection) - } - else { - try execute("NOTIFY \(channel)", on: connection) - } + return try Connection(connInfo: connInfo) } } diff --git a/Sources/PostgreSQL/Error.swift b/Sources/PostgreSQL/Error.swift index 0f0a437..82e29d8 100644 --- a/Sources/PostgreSQL/Error.swift +++ b/Sources/PostgreSQL/Error.swift @@ -10,318 +10,335 @@ public struct PostgreSQLError: Error { public let reason: String } +public enum PostgresSQLStatusError: Error { + case emptyQuery + case badResponse +} + extension PostgreSQLError { public enum Code: String { - case successful_completion = "00000" // Class 01 — Warning case warning = "01000" - case dynamic_result_sets_returned = "0100C" - case implicit_zero_bit_padding = "01008" - case null_value_eliminated_in_set_function = "01003" - case privilege_not_granted = "01007" - case privilege_not_revoked = "01006" - case string_data_right_truncation = "01004" - case deprecated_feature = "01P01" + case dynamicResultSetsReturned = "0100C" + case implicitZeroBitPadding = "01008" + case nullValueEliminatedInSetFunction = "01003" + case privilegeNotGranted = "01007" + case privilegeNotRevoked = "01006" + case stringDataRightTruncationWarning = "01004" + case deprecatedFeature = "01P01" // Class 02 — No Data (this is also a warning class per the SQL standard) - case no_data = "02000" - case no_additional_dynamic_result_sets_returned = "02001" + case noData = "02000" + case noAdditionalDynamicResultSetsReturned = "02001" // Class 03 — SQL Statement Not Yet Complete - case sql_statement_not_yet_complete = "03000" + case sqlStatementNotYetComplete = "03000" // Class 08 — Connection Exception - case connection_exception = "08000" - case connection_does_not_exist = "08003" - case connection_failure = "08006" - case sqlclient_unable_to_establish_sqlconnection = "08001" - case sqlserver_rejected_establishment_of_sqlconnection = "08004" - case transaction_resolution_unknown = "08007" - case protocol_violation = "08P01" + case connectionException = "08000" + case connectionDoesNotExist = "08003" + case connectionFailure = "08006" + case sqlclientUnableToEstablishSqlconnection = "08001" + case sqlserverRejectedEstablishmentOfSqlconnection = "08004" + case transactionResolutionUnknown = "08007" + case protocolViolation = "08P01" // Class 09 — Triggered Action Exception - case triggered_action_exception = "09000" + case triggeredActionException = "09000" // Class 0A — Feature Not Supported - case feature_not_supported = "0A000" + case featureNotSupported = "0A000" // Class 0B — Invalid Transaction Initiation - case invalid_transaction_initiation = "0B000" + case invalidTransactionInitiation = "0B000" // Class 0F — Locator Exception - case locator_exception = "0F000" - case invalid_locator_specification = "0F001" + case locatorException = "0F000" + case invalidLocatorSpecification = "0F001" // Class 0L — Invalid Grantor - case invalid_grantor = "0L000" - case invalid_grant_operation = "0LP01" + case invalidGrantor = "0L000" + case invalidGrantOperation = "0LP01" // Class 0P — Invalid Role Specification - case invalid_role_specification = "0P000" + case invalidRoleSpecification = "0P000" // Class 0Z — Diagnostics Exception - case diagnostics_exception = "0Z000" - case stacked_diagnostics_accessed_without_active_handler = "0Z002" + case diagnosticsException = "0Z000" + case stackedDiagnosticsAccessedWithoutActiveHandler = "0Z002" // Class 20 — Case Not Found - case case_not_found = "20000" + case caseNotFound = "20000" // Class 21 — Cardinality Violation - case cardinality_violation = "21000" + case cardinalityViolation = "21000" // Class 22 — Data Exception - case data_exception = "22000" - case array_subscript_error = "2202E" - case character_not_in_repertoire = "22021" - case datetime_field_overflow = "22008" - case division_by_zero = "22012" - case error_in_assignment = "22005" - case escape_character_conflict = "2200B" - case indicator_overflow = "22022" - case interval_field_overflow = "22015" - case invalid_argument_for_logarithm = "2201E" - case invalid_argument_for_ntile_function = "22014" - case invalid_argument_for_nth_value_function = "22016" - case invalid_argument_for_power_function = "2201F" - case invalid_argument_for_width_bucket_function = "2201G" - case invalid_character_value_for_cast = "22018" - case invalid_datetime_format = "22007" - case invalid_escape_character = "22019" - case invalid_escape_octet = "2200D" - case invalid_escape_sequence = "22025" - case nonstandard_use_of_escape_character = "22P06" - case invalid_indicator_parameter_value = "22010" - case invalid_parameter_value = "22023" - case invalid_regular_expression = "2201B" - case invalid_row_count_in_limit_clause = "2201W" - case invalid_row_count_in_result_offset_clause = "2201X" - case invalid_tablesample_argument = "2202H" - case invalid_tablesample_repeat = "2202G" - case invalid_time_zone_displacement_value = "22009" - case invalid_use_of_escape_character = "2200C" - case most_specific_type_mismatch = "2200G" - case null_value_not_allowed = "22004" - case null_value_no_indicator_parameter = "22002" - case numeric_value_out_of_range = "22003" - case string_data_length_mismatch = "22026" - // case string_data_right_truncation = "22001" - case substring_error = "22011" - case trim_error = "22027" - case unterminated_c_string = "22024" - case zero_length_character_string = "2200F" - case floating_point_exception = "22P01" - case invalid_text_representation = "22P02" - case invalid_binary_representation = "22P03" - case bad_copy_file_format = "22P04" - case untranslatable_character = "22P05" - case not_an_xml_document = "2200L" - case invalid_xml_document = "2200M" - case invalid_xml_content = "2200N" - case invalid_xml_comment = "2200S" - case invalid_xml_processing_instruction = "2200T" + case dataException = "22000" + case arraySubscriptError = "2202E" + case characterNotInRepertoire = "22021" + case datetimeFieldOverflow = "22008" + case divisionByZero = "22012" + case errorInAssignment = "22005" + case escapeCharacterConflict = "2200B" + case indicatorOverflow = "22022" + case intervalFieldOverflow = "22015" + case invalidArgumentForLogarithm = "2201E" + case invalidArgumentForNtileFunction = "22014" + case invalidArgumentForNthValueFunction = "22016" + case invalidArgumentForPowerFunction = "2201F" + case invalidArgumentForWidthBucketFunction = "2201G" + case invalidCharacterValueForCast = "22018" + case invalidDatetimeFormat = "22007" + case invalidEscapeCharacter = "22019" + case invalidEscapeOctet = "2200D" + case invalidEscapeSequence = "22025" + case nonstandardUseOfEscapeCharacter = "22P06" + case invalidIndicatorParameterValue = "22010" + case invalidParameterValue = "22023" + case invalidRegularExpression = "2201B" + case invalidRowCountInLimitClause = "2201W" + case invalidRowCountInResultOffsetClause = "2201X" + case invalidTablesampleArgument = "2202H" + case invalidTablesampleRepeat = "2202G" + case invalidTimeZoneDisplacementValue = "22009" + case invalidUseOfEscapeCharacter = "2200C" + case mostSpecificTypeMismatch = "2200G" + case nullValueNotAllowed = "22004" + case nullValueNoIndicatorParameter = "22002" + case numericValueOutOfRange = "22003" + case stringDataLengthMismatch = "22026" + case stringDataRightTruncationException = "22001" + case substringError = "22011" + case trimError = "22027" + case unterminatedCString = "22024" + case zeroLengthCharacterString = "2200F" + case floatingPointException = "22P01" + case invalidTextRepresentation = "22P02" + case invalidBinaryRepresentation = "22P03" + case badCopyFileFormat = "22P04" + case untranslatableCharacter = "22P05" + case notAnXmlDocument = "2200L" + case invalidXmlDocument = "2200M" + case invalidXmlContent = "2200N" + case invalidXmlComment = "2200S" + case invalidXmlProcessingInstruction = "2200T" // Class 23 — Integrity Constraint Violation - case integrity_constraint_violation = "23000" - case restrict_violation = "23001" - case not_null_violation = "23502" - case foreign_key_violation = "23503" - case unique_violation = "23505" - case check_violation = "23514" - case exclusion_violation = "23P01" + case integrityConstraintViolation = "23000" + case restrictViolation = "23001" + case notNullViolation = "23502" + case foreignKeyViolation = "23503" + case uniqueViolation = "23505" + case checkViolation = "23514" + case exclusionViolation = "23P01" // Class 24 — Invalid Cursor State - case invalid_cursor_state = "24000" + case invalidCursorState = "24000" // Class 25 — Invalid Transaction State - case invalid_transaction_state = "25000" - case active_sql_transaction = "25001" - case branch_transaction_already_active = "25002" - case held_cursor_requires_same_isolation_level = "25008" - case inappropriate_access_mode_for_branch_transaction = "25003" - case inappropriate_isolation_level_for_branch_transaction = "25004" - case no_active_sql_transaction_for_branch_transaction = "25005" - case read_only_sql_transaction = "25006" - case schema_and_data_statement_mixing_not_supported = "25007" - case no_active_sql_transaction = "25P01" - case in_failed_sql_transaction = "25P02" - case idle_in_transaction_session_timeout = "25P03" + case invalidTransactionState = "25000" + case activeSqlTransaction = "25001" + case branchTransactionAlreadyActive = "25002" + case heldCursorRequiresSameIsolationLevel = "25008" + case inappropriateAccessModeForBranchTransaction = "25003" + case inappropriateIsolationLevelForBranchTransaction = "25004" + case noActiveSqlTransactionForBranchTransaction = "25005" + case readOnlySqlTransaction = "25006" + case schemaAndDataStatementMixingNotSupported = "25007" + case noActiveSqlTransaction = "25P01" + case inFailedSqlTransaction = "25P02" + case idleInTransactionSessionTimeout = "25P03" // Class 26 — Invalid SQL Statement Name - case invalid_sql_statement_name = "26000" + case invalidSqlStatementName = "26000" // Class 27 — Triggered Data Change Violation - case triggered_data_change_violation = "27000" + case triggeredDataChangeViolation = "27000" // Class 28 — Invalid Authorization Specification - case invalid_authorization_specification = "28000" - case invalid_password = "28P01" + case invalidAuthorizationSpecification = "28000" + case invalidPassword = "28P01" // Class 2B — Dependent Privilege Descriptors Still Exist - case dependent_privilege_descriptors_still_exist = "2B000" - case dependent_objects_still_exist = "2BP01" + case dependentPrivilegeDescriptorsStillExist = "2B000" + case dependentObjectsStillExist = "2BP01" // Class 2D — Invalid Transaction Termination - case invalid_transaction_termination = "2D000" + case invalidTransactionTermination = "2D000" // Class 2F — SQL Routine Exception - case sql_routine_exception = "2F000" - case function_executed_no_return_statement = "2F005" - case modifying_sql_data_not_permitted = "2F002" - case prohibited_sql_statement_attempted = "2F003" - case reading_sql_data_not_permitted = "2F004" + case sqlRoutineException = "2F000" + case functionExecutedNoReturnStatement = "2F005" + case modifyingSqlDataNotPermittedSQL = "2F002" + case prohibitedSqlStatementAttemptedSQL = "2F003" + case readingSqlDataNotPermittedSQL = "2F004" // Class 34 — Invalid Cursor Name - case invalid_cursor_name = "34000" + case invalidCursorName = "34000" // Class 38 — External Routine Exception - case external_routine_exception = "38000" - case containing_sql_not_permitted = "38001" - // case modifying_sql_data_not_permitted = "38002" - // case prohibited_sql_statement_attempted = "38003" - // case reading_sql_data_not_permitted = "38004" + case externalRoutineException = "38000" + case containingSqlNotPermitted = "38001" + case modifyingSqlDataNotPermittedExternal = "38002" + case prohibitedSqlStatementAttemptedExternal = "38003" + case readingSqlDataNotPermittedExternal = "38004" // Class 39 — External Routine Invocation Exception - case external_routine_invocation_exception = "39000" - case invalid_sqlstate_returned = "39001" + case externalRoutineInvocationException = "39000" + case invalidSqlstateReturned = "39001" // case null_value_not_allowed = "39004" - case trigger_protocol_violated = "39P01" - case srf_protocol_violated = "39P02" + case triggerProtocolViolated = "39P01" + case srfProtocolViolated = "39P02" case event_trigger_protocol_violated = "39P03" // Class 3B — Savepoint Exception - case savepoint_exception = "3B000" - case invalid_savepoint_specification = "3B001" + case savepointException = "3B000" + case invalidSavepointSpecification = "3B001" // Class 3D — Invalid Catalog Name - case invalid_catalog_name = "3D000" + case invalidCatalogName = "3D000" // Class 3F — Invalid Schema Name - case invalid_schema_name = "3F000" + case invalidSchemaName = "3F000" // Class 40 — Transaction Rollback - case transaction_rollback = "40000" - case transaction_integrity_constraint_violation = "40002" - case serialization_failure = "40001" - case statement_completion_unknown = "40003" - case deadlock_detected = "40P01" + case transactionRollback = "40000" + case transactionIntegrityConstraintViolation = "40002" + case serializationFailure = "40001" + case statementCompletionUnknown = "40003" + case deadlockDetected = "40P01" // Class 42 — Syntax Error or Access Rule Violation - case syntax_error_or_access_rule_violation = "42000" - case syntax_error = "42601" - case insufficient_privilege = "42501" - case cannot_coerce = "42846" - case grouping_error = "42803" - case windowing_error = "42P20" - case invalid_recursion = "42P19" - case invalid_foreign_key = "42830" - case invalid_name = "42602" - case name_too_long = "42622" - case reserved_name = "42939" - case datatype_mismatch = "42804" - case indeterminate_datatype = "42P18" - case collation_mismatch = "42P21" - case indeterminate_collation = "42P22" - case wrong_object_type = "42809" - case undefined_column = "42703" - case undefined_function = "42883" - case undefined_table = "42P01" - case undefined_parameter = "42P02" - case undefined_object = "42704" - case duplicate_column = "42701" - case duplicate_cursor = "42P03" - case duplicate_database = "42P04" - case duplicate_function = "42723" - case duplicate_prepared_statement = "42P05" - case duplicate_schema = "42P06" - case duplicate_table = "42P07" - case duplicate_alias = "42712" - case duplicate_object = "42710" - case ambiguous_column = "42702" - case ambiguous_function = "42725" - case ambiguous_parameter = "42P08" - case ambiguous_alias = "42P09" - case invalid_column_reference = "42P10" - case invalid_column_definition = "42611" - case invalid_cursor_definition = "42P11" - case invalid_database_definition = "42P12" - case invalid_function_definition = "42P13" - case invalid_prepared_statement_definition = "42P14" - case invalid_schema_definition = "42P15" - case invalid_table_definition = "42P16" - case invalid_object_definition = "42P17" + case syntaxErrorOrAccessRuleViolation = "42000" + case syntaxError = "42601" + case insufficientPrivilege = "42501" + case cannotCoerce = "42846" + case groupingError = "42803" + case windowingError = "42P20" + case invalidRecursion = "42P19" + case invalidForeignKey = "42830" + case invalidName = "42602" + case nameTooLong = "42622" + case reservedName = "42939" + case datatypeMismatch = "42804" + case indeterminateDatatype = "42P18" + case collationMismatch = "42P21" + case indeterminateCollation = "42P22" + case wrongObjectType = "42809" + case undefinedColumn = "42703" + case undefinedFunction = "42883" + case undefinedTable = "42P01" + case undefinedParameter = "42P02" + case undefinedObject = "42704" + case duplicateColumn = "42701" + case duplicateCursor = "42P03" + case duplicateDatabase = "42P04" + case duplicateFunction = "42723" + case duplicatePreparedStatement = "42P05" + case duplicateSchema = "42P06" + case duplicateTable = "42P07" + case duplicateAlias = "42712" + case duplicateObject = "42710" + case ambiguousColumn = "42702" + case ambiguousFunction = "42725" + case ambiguousParameter = "42P08" + case ambiguousAlias = "42P09" + case invalidColumnReference = "42P10" + case invalidColumnDefinition = "42611" + case invalidCursorDefinition = "42P11" + case invalidDatabaseDefinition = "42P12" + case invalidFunctionDefinition = "42P13" + case invalidPreparedStatementDefinition = "42P14" + case invalidSchemaDefinition = "42P15" + case invalidTableDefinition = "42P16" + case invalidObjectDefinition = "42P17" // Class 44 — WITH CHECK OPTION Violation - case with_check_option_violation = "44000" + case withCheckOptionViolation = "44000" // Class 53 — Insufficient Resources - case insufficient_resources = "53000" - case disk_full = "53100" - case out_of_memory = "53200" - case too_many_connections = "53300" - case configuration_limit_exceeded = "53400" + case insufficientResources = "53000" + case diskFull = "53100" + case outOfMemory = "53200" + case tooManyConnections = "53300" + case configurationLimitExceeded = "53400" // Class 54 — Program Limit Exceeded - case program_limit_exceeded = "54000" - case statement_too_complex = "54001" - case too_many_columns = "54011" - case too_many_arguments = "54023" + case programLimitExceeded = "54000" + case statementTooComplex = "54001" + case tooManyColumns = "54011" + case tooManyArguments = "54023" // Class 55 — Object Not In Prerequisite State - case object_not_in_prerequisite_state = "55000" - case object_in_use = "55006" - case cant_change_runtime_param = "55P02" - case lock_not_available = "55P03" + case objectNotInPrerequisiteState = "55000" + case objectInUse = "55006" + case cantChangeRuntimeParam = "55P02" + case lockNotAvailable = "55P03" // Class 57 — Operator Intervention - case operator_intervention = "57000" - case query_canceled = "57014" - case admin_shutdown = "57P01" - case crash_shutdown = "57P02" - case cannot_connect_now = "57P03" - case database_dropped = "57P04" + case operatorIntervention = "57000" + case queryCanceled = "57014" + case adminShutdown = "57P01" + case crashShutdown = "57P02" + case cannotConnectNow = "57P03" + case databaseDropped = "57P04" // Class 58 — System Error (errors external to PostgreSQL itself) - case system_error = "58000" - case io_error = "58030" - case undefined_file = "58P01" - case duplicate_file = "58P02" + case systemError = "58000" + case ioError = "58030" + case undefinedFile = "58P01" + case duplicateFile = "58P02" // Class 72 — Snapshot Failure - case snapshot_too_old = "72000" + case snapshotTooOld = "72000" // Class F0 — Configuration File Error - case config_file_error = "F0000" - case lock_file_exists = "F0001" + case configFileError = "F0000" + case lockFileExists = "F0001" // Class HV — Foreign Data Wrapper Error (SQL/MED) - case fdw_error = "HV000" - case fdw_column_name_not_found = "HV005" - case fdw_dynamic_parameter_value_needed = "HV002" - case fdw_function_sequence_error = "HV010" - case fdw_inconsistent_descriptor_information = "HV021" - case fdw_invalid_attribute_value = "HV024" - case fdw_invalid_column_name = "HV007" - case fdw_invalid_column_number = "HV008" - case fdw_invalid_data_type = "HV004" - case fdw_invalid_data_type_descriptors = "HV006" - case fdw_invalid_descriptor_field_identifier = "HV091" - case fdw_invalid_handle = "HV00B" - case fdw_invalid_option_index = "HV00C" - case fdw_invalid_option_name = "HV00D" - case fdw_invalid_string_length_or_buffer_length = "HV090" - case fdw_invalid_string_format = "HV00A" - case fdw_invalid_use_of_null_pointer = "HV009" - case fdw_too_many_handles = "HV014" - case fdw_out_of_memory = "HV001" - case fdw_no_schemas = "HV00P" - case fdw_option_name_not_found = "HV00J" - case fdw_reply_handle = "HV00K" - case fdw_schema_not_found = "HV00Q" - case fdw_table_not_found = "HV00R" - case fdw_unable_to_create_execution = "HV00L" - case fdw_unable_to_create_reply = "HV00M" - case fdw_unable_to_establish_connection = "HV00N" - case plpgsql_error = "P0000" - case raise_exception = "P0001" - case no_data_found = "P0002" - case too_many_rows = "P0003" - case assert_failure = "P0004" + case fdwError = "HV000" + case fdwColumnNameNotFound = "HV005" + case fdwDynamicParameterValueNeeded = "HV002" + case fdwFunctionSequenceError = "HV010" + case fdwInconsistentDescriptorInformation = "HV021" + case fdwInvalidAttributeValue = "HV024" + case fdwInvalidColumnName = "HV007" + case fdwInvalidColumnNumber = "HV008" + case fdwInvalidDataType = "HV004" + case fdwInvalidDataTypeDescriptors = "HV006" + case fdwInvalidDescriptorFieldIdentifier = "HV091" + case fdwInvalidHandle = "HV00B" + case fdwInvalidOptionIndex = "HV00C" + case fdwInvalidOptionName = "HV00D" + case fdwInvalidStringLengthOrBufferLength = "HV090" + case fdwInvalidStringFormat = "HV00A" + case fdwInvalidUseOfNullPointer = "HV009" + case fdwTooManyHandles = "HV014" + case fdwOutOfMemory = "HV001" + case fdwNoSchemas = "HV00P" + case fdwOptionNameNotFound = "HV00J" + case fdwReplyHandle = "HV00K" + case fdwSchemaNotFound = "HV00Q" + case fdwTableNotFound = "HV00R" + case fdwUnableToCreateExecution = "HV00L" + case fdwUnableToCreateReply = "HV00M" + case fdwUnableToEstablishConnection = "HV00N" + case plpgsqlError = "P0000" + case raiseException = "P0001" + case noDataFound = "P0002" + case tooManyRows = "P0003" + case assertFailure = "P0004" // Class XX — Internal Error - case internal_error = "XX000" - case data_corrupted = "XX001" - case index_corrupted = "XX002" - case unknown = "Unknown" + case internalError = "XX000" + case dataCorrupted = "XX001" + case indexCorrupted = "XX002" + + case unknown } } // MARK: Inits extension PostgreSQLError { - public init(_ connection: Connection) { - let raw = String(cString: PQerrorMessage(connection.cConnection)) - - let message: String + public init(code: Code, connection: Connection) { + let reason: String if let error = PQerrorMessage(connection.cConnection) { - message = String(cString: error) - } else { - message = "Unknown" + reason = String(cString: error) } - - self.init( - rawCode: raw, - reason: message - ) - } - - public init(_ code: Code, reason: String) { - self.code = code - self.reason = reason + else { + reason = "Unknown" + } + + self.init(code: code, reason: reason) } - - public init(rawCode: String, reason: String) { - self.code = Code(rawValue: rawCode) ?? .unknown - self.reason = reason + + public init(result: Result) { + guard let pointer = result.pointer else { + self.init(code: .unknown, reason: "Unknown") + return + } + + let code: Code + if let rawCodePointer = PQresultErrorField(pointer, 67) { // 67 == 'C' == PG_DIAG_SQLSTATE + let rawCode = String(cString: rawCodePointer) + code = Code(rawValue: rawCode) ?? .unknown + } + else { + code = .unknown + } + + let reason: String + if let messagePointer = PQresultErrorMessage(pointer) { + reason = String(cString: messagePointer) + } + else { + reason = "Unknown" + } + + self.init(code: code, reason: reason) } } @@ -339,7 +356,7 @@ extension PostgreSQLError: Debuggable { public var possibleCauses: [String] { switch code { - case .connection_exception, .connection_does_not_exist, .connection_failure: + case .connectionException, .connectionDoesNotExist, .connectionFailure: return [ "The connection to the server degraded during the query", "The connection has been open for too long", @@ -352,15 +369,15 @@ extension PostgreSQLError: Debuggable { public var suggestedFixes: [String] { switch code { - case .syntax_error: + case .syntaxError: return [ "Fix the invalid syntax in your query", "If an ORM has generated this error, report the issue to its GitHub page" ] - case .connection_exception, .connection_does_not_exist, .connection_failure: + case .connectionException, .connectionFailure: return [ - "Increase the `wait_timeout`", - "Increase the `max_allowed_packet`" + "Make sure you have entered the correct username and password", + "Make sure the database has been created" ] default: return [] @@ -368,17 +385,10 @@ extension PostgreSQLError: Debuggable { } public var stackOverflowQuestions: [String] { - switch code { - case .syntax_error: - return [ - ] - default: - return [] - } + return [] } public var documentationLinks: [String] { - return [ - ] + return [] } } diff --git a/Sources/PostgreSQL/PostgreSQL+Node.swift b/Sources/PostgreSQL/Exports.swift similarity index 100% rename from Sources/PostgreSQL/PostgreSQL+Node.swift rename to Sources/PostgreSQL/Exports.swift diff --git a/Sources/PostgreSQL/Node+Binding.swift b/Sources/PostgreSQL/Node+Binding.swift deleted file mode 100644 index c25b4e1..0000000 --- a/Sources/PostgreSQL/Node+Binding.swift +++ /dev/null @@ -1,136 +0,0 @@ -import Foundation -import Core - -protocol Bindable { - var postgresBindingData: ([Int8]?, OID?, DataFormat) { get } -} - -extension Node: Bindable { - var postgresBindingData: ([Int8]?, OID?, DataFormat) { - switch wrapped { - case .null: - // PQexecParams converts nil pointer to NULL. - // see: https://www.postgresql.org/docs/9.1/static/libpq-exec.html - return (nil, nil, .string) - - case .bytes(let bytes): - let int8Bytes = bytes.map { Int8(bitPattern: $0) } - return (int8Bytes, nil, .binary) - - case .bool(let bool): - return bool.postgresBindingData - - case .number(let number): - if case .double(let value) = number { - return value.postgresBindingData - } else { - return number.int.postgresBindingData - } - - case .string(let string): - return string.postgresBindingData - - case .array(let array): - let elements = array.map { $0.postgresArrayElementString } - let arrayString = "{\(elements.joined(separator: ","))}" - return (arrayString.utf8CString.array, .none, .string) - - case .object(_): - print("Unsupported Node type for PostgreSQL binding, everything except for .object is supported.") - return (nil, nil, .string) - - default: - return (nil, nil, .string) - } - } -} - -extension StructuredData { - var postgresArrayElementString: String { - switch self { - case .null: - return "NULL" - - case .bytes(let bytes): - let hexString = bytes.map { $0.lowercaseHexPair }.joined() - return "\"\\\\x\(hexString)\"" - - case .bool(let bool): - return bool ? "t" : "f" - - case .number(let number): - return number.description - - case .string(let string): - let escapedString = string - .replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "\"", with: "\\\"") - return "\"\(escapedString)\"" - - case .array(let array): - let elements = array.map { $0.postgresArrayElementString } - return "{\(elements.joined(separator: ","))}" - - case .object(_): - print("Unsupported Node array type for PostgreSQL binding, everything except for .object is supported.") - return "NULL" - - default: - return "" - } - } -} - -extension Bool: Bindable { - var postgresBindingData: ([Int8]?, OID?, DataFormat) { - return ([self ? 1 : 0], .bool, .binary) - } -} - -extension Int: Bindable { - var postgresBindingData: ([Int8]?, OID?, DataFormat) { - let count = MemoryLayout.size(ofValue: self) - - let oid: OID - switch count { - case 2: - oid = .int2 - case 4: - oid = .int4 - case 8: - oid = .int8 - default: - // Unsupported integer size, use string instead - return description.postgresBindingData - } - - var value = bigEndian - return (PostgresBinaryUtils.valueToByteArray(&value), oid, .binary) - } -} - -extension Double: Bindable { - var postgresBindingData: ([Int8]?, OID?, DataFormat) { - let count = MemoryLayout.size(ofValue: self) - - let oid: OID - switch count { - case 4: - oid = .float4 - case 8: - oid = .float8 - default: - // Unsupported float size, use string instead - return description.postgresBindingData - } - - var value = bigEndian - return (PostgresBinaryUtils.valueToByteArray(&value), oid, .binary) - } -} - -extension String: Bindable { - var postgresBindingData: ([Int8]?, OID?, DataFormat) { - return (utf8CString.array, .none, .string) - } -} diff --git a/Sources/PostgreSQL/Node+Oid.swift b/Sources/PostgreSQL/Node+Oid.swift deleted file mode 100644 index be36f0d..0000000 --- a/Sources/PostgreSQL/Node+Oid.swift +++ /dev/null @@ -1,287 +0,0 @@ -import CPostgreSQL -import Foundation - -/// Oid values can be found in the following file: -/// https://github.com/postgres/postgres/blob/55c3391d1e6a201b5b891781d21fe682a8c64fe6/src/include/catalog/pg_type.h -enum OID: Oid { - case bool = 16 - - case int2 = 21 - case int4 = 23 - case int8 = 20 - - case bytea = 17 - - case char = 18 - case name = 19 - case text = 25 - case bpchar = 1042 - case varchar = 1043 - - case json = 114 - case jsonb = 3802 - case xml = 142 - - case float4 = 700 - case float8 = 701 - - case numeric = 1700 - - case date = 1082 - case time = 1083 - case timetz = 1266 - case timestamp = 1114 - case timestamptz = 1184 - case interval = 1186 - - case uuid = 2950 - - case point = 600 - case lseg = 601 - case path = 602 - case box = 603 - case polygon = 604 - case circle = 718 - - case cidr = 650 - case inet = 869 - case macaddr = 829 - - case bit = 1560 - case varbit = 1562 - - static let supportedArrayOIDs: Set = [ - 1000, // bool - - 1005, // int2 - 1007, // int4 - 1016, // int8 - - 1001, // bytea - - 1002, // char - 1003, // name - 1009, // text - 1014, // bpchar - 1015, // varchar - - 199, // json - 3807, // jsonb - 143, // xml - - 1021, // float4 - 1022, // float8 - - 1231, // numeric - - 1182, // date - 1183, // time - 1270, // timetz - 1115, // timestamp - 1185, // timestamptz - 1187, // interval - - 2951, // uuid - - 1017, // point - 1018, // lseg - 1019, // path - 1020, // box - 1027, // polygon - 719, // circle - - 651, // cidr - 1041, // inet - 1040, // macaddr - - 1561, // bit - 1563, // varbit - ] -} - -extension Node { - init(configuration: Configuration, oid: Oid, value: UnsafeMutablePointer, length: Int) { - // Check if we support the type - guard let type = OID(rawValue: oid) else { - // Check if we have an array type and try to convert - if OID.supportedArrayOIDs.contains(oid), let node = Node(configuration: configuration, arrayValue: value) { - self = node - } else { - // Otherwise fallback to simply passing on the bytes - let bytes = PostgresBinaryUtils.parseBytes(value: value, length: length) - self = .bytes(bytes) - } - return - } - - self = Node(configuration: configuration, oid: type, value: value, length: length) - } - - init(configuration: Configuration, oid: OID, value: UnsafeMutablePointer, length: Int) { - switch oid { - case .bool: - self = .bool(value[0] != 0) - - case .char, .name, .text, .json, .xml, .bpchar, .varchar: - let string = PostgresBinaryUtils.parseString(value: value, length: length) - self = .string(string) - - case .jsonb: - // Ignore jsonb version number - let jsonValue = value.advanced(by: 1) - let string = PostgresBinaryUtils.parseString(value: jsonValue, length: length - 1) - self = .string(string) - - case .int2: - let integer = PostgresBinaryUtils.parseInt16(value: value) - self = .number(.int(Int(integer))) - - case .int4: - let integer = PostgresBinaryUtils.parseInt32(value: value) - self = .number(.int(Int(integer))) - - case .int8: - let integer = PostgresBinaryUtils.parseInt64(value: value) - if let intValue = Int(exactly: integer) { - self = .number(.int(intValue)) - } else { - self = .number(.double(Double(integer))) - } - - case .bytea: - let bytes = PostgresBinaryUtils.parseBytes(value: value, length: length) - self = .bytes(bytes) - - case .float4: - let float = PostgresBinaryUtils.parseFloat32(value: value) - self = .number(.double(Double(float))) - - case .float8: - let float = PostgresBinaryUtils.parseFloat64(value: value) - self = .number(.double(Double(float))) - - case .numeric: - let number = PostgresBinaryUtils.parseNumeric(value: value) - self = .string(number) - - case .uuid: - let uuid = PostgresBinaryUtils.parseUUID(value: value) - self = .string(uuid) - - case .timestamp, .timestamptz, .date, .time, .timetz: - let date = PostgresBinaryUtils.parseTimetamp(value: value, isInteger: configuration.hasIntegerDatetimes) - let formatter = PostgresBinaryUtils.Formatters.dateFormatter(for: oid) - let timestamp = formatter.string(from: date) - self = .string(timestamp) - - case .interval: - let interval = PostgresBinaryUtils.parseInterval(value: value, timeIsInteger: configuration.hasIntegerDatetimes) - self = .string(interval) - - case .point: - let point = PostgresBinaryUtils.parsePoint(value: value) - self = .string(point) - - case .lseg: - let lseg = PostgresBinaryUtils.parseLineSegment(value: value) - self = .string(lseg) - - case .path: - let path = PostgresBinaryUtils.parsePath(value: value) - self = .string(path) - - case .box: - let box = PostgresBinaryUtils.parseBox(value: value) - self = .string(box) - - case .polygon: - let polygon = PostgresBinaryUtils.parsePolygon(value: value) - self = .string(polygon) - - case .circle: - let circle = PostgresBinaryUtils.parseCircle(value: value) - self = .string(circle) - - case .inet, .cidr: - let inet = PostgresBinaryUtils.parseIPAddress(value: value) - self = .string(inet) - - case .macaddr: - let macaddr = PostgresBinaryUtils.parseMacAddress(value: value) - self = .string(macaddr) - - case .bit, .varbit: - let bitString = PostgresBinaryUtils.parseBitString(value: value, length: length) - self = .string(bitString) - } - } - - private init?(configuration: Configuration, arrayValue: UnsafeMutablePointer) { - let elementOid = Oid(bigEndian: PostgresBinaryUtils.convert(arrayValue.advanced(by: 8))) - - // Check if we support the type - guard let type = OID(rawValue: elementOid) else { - return nil - } - - // Get the dimension of the array - let arrayDimension = PostgresBinaryUtils.parseInt32(value: arrayValue) - guard arrayDimension > 0 else { - self = .array([]) - return - } - - var pointer = arrayValue.advanced(by: 12) - - // Get all dimension lengths - var dimensionLengths: [Int] = [] - for _ in 0..) -> Node { - // Get the length of the array - let arrayLength = dimensionLengths[0] - - // Create elements array - var elements: [Node] = [] - elements.reserveCapacity(arrayLength) - - // Loop through array and convert each item - for _ in 0.. 1 { - - var subDimensionLengths = dimensionLengths - subDimensionLengths.removeFirst() - - let array = parseArray(configuration: configuration, type: type, dimensionLengths: subDimensionLengths, pointer: &pointer) - elements.append(array) - - } else { - - let elementLength = Int(PostgresBinaryUtils.parseInt32(value: pointer)) - pointer = pointer.advanced(by: 4) - - // Check if the element is null - guard elementLength != -1 else { - elements.append(.null) - continue - } - - // Parse to node - let node = Node(configuration: configuration, oid: type, value: pointer, length: elementLength) - elements.append(node) - pointer = pointer.advanced(by: elementLength) - } - } - - return .array(elements) - } -} diff --git a/Sources/PostgreSQL/Notification.swift b/Sources/PostgreSQL/Notification.swift deleted file mode 100644 index cae2b0c..0000000 --- a/Sources/PostgreSQL/Notification.swift +++ /dev/null @@ -1,19 +0,0 @@ -public struct Notification { - let channel: String - let payload: String? - let pid: Int -} - -extension Notification { - init(relname: UnsafeMutablePointer, extra: UnsafeMutablePointer, be_pid: Int32) { - self.channel = String(cString: relname) - self.pid = Int(be_pid) - - if (extra.pointee != 0) { - self.payload = String(cString: extra) - } - else { - self.payload = nil - } - } -} diff --git a/Sources/PostgreSQL/Result.swift b/Sources/PostgreSQL/Result.swift index 7178033..12e5a6d 100644 --- a/Sources/PostgreSQL/Result.swift +++ b/Sources/PostgreSQL/Result.swift @@ -1,48 +1,144 @@ import CPostgreSQL -class Result { - typealias Pointer = OpaquePointer - - private let pointer: Pointer - private let configuration: Configuration - let parsed: [[String: Node]] - - init(configuration: Configuration, pointer: Pointer) { - self.configuration = configuration +public class Result { + + // MARK: - Pointer + + public typealias Pointer = OpaquePointer + + // MARK: - Status + + public enum Status { + case commandOk + case tuplesOk + case copyOut + case copyIn + case copyBoth + case badResponse + case nonFatalError + case fatalError + case emptyQuery + + init(_ pointer: Pointer?) { + guard let pointer = pointer else { + self = .fatalError + return + } + + switch PQresultStatus(pointer) { + case PGRES_COMMAND_OK: + self = .commandOk + case PGRES_TUPLES_OK: + self = .tuplesOk + case PGRES_COPY_OUT: + self = .copyOut + case PGRES_COPY_IN: + self = .copyIn + case PGRES_COPY_BOTH: + self = .copyBoth + case PGRES_BAD_RESPONSE: + self = .badResponse + case PGRES_NONFATAL_ERROR: + self = .nonFatalError + case PGRES_FATAL_ERROR: + self = .fatalError + case PGRES_EMPTY_QUERY: + self = .emptyQuery + default: + self = .fatalError + } + } + } + + // MARK: - Properties + + public let pointer: Pointer? + public let connection: Connection + public let status: Status + + // MARK: - Init + + public init(pointer: Pointer?, connection: Connection) { self.pointer = pointer + self.connection = connection + status = Status(pointer) + } + + // MARK: - Deinit + + deinit { + if let pointer = pointer { + PQclear(pointer) + } + } + + // MARK: - Value + + public func parseData() throws -> Node { + switch status { + case .nonFatalError, .fatalError: + throw PostgreSQLError(result: self) + + case .badResponse: + throw PostgresSQLStatusError.badResponse + + case .emptyQuery: + throw PostgresSQLStatusError.emptyQuery + + case .copyOut, .copyIn, .copyBoth, .commandOk: + // No data to parse + return Node(.null, in: PostgreSQLContext.shared) + + case .tuplesOk: + break + } + + var results: [StructuredData] = [] - var parsed: [[String: Node]] = [] + // This single dictionary is reused for all rows in the result set + // to avoid the runtime overhead of (de)allocating one per row. + var parsed: [String: StructuredData] = [:] let rowCount = PQntuples(pointer) let columnCount = PQnfields(pointer) - + if rowCount > 0 && columnCount > 0 { for row in 0..", "<(-1.2,-3.4),98>", "<(123.67,-598.15),0.123>", ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, circle circle)") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, circle circle)") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for (i, resultRow) in result.enumerated() { let circle = resultRow["circle"] @@ -458,6 +517,8 @@ class PostgreSQLTests: XCTestCase { } func testInets() throws { + let conn = try postgreSQL.makeConnection() + let rows = [ "192.168.100.128", "192.168.100.128/25", @@ -468,13 +529,13 @@ class PostgreSQLTests: XCTestCase { "127.0.0.1", ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, inet inet)") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, inet inet)") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for (i, resultRow) in result.enumerated() { let inet = resultRow["inet"] @@ -484,6 +545,8 @@ class PostgreSQLTests: XCTestCase { } func testCidrs() throws { + let conn = try postgreSQL.makeConnection() + let rows = [ "192.168.100.128/32", "192.168.100.128/25", @@ -494,13 +557,13 @@ class PostgreSQLTests: XCTestCase { "127.0.0.1/32", ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, cidr cidr)") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, cidr cidr)") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for (i, resultRow) in result.enumerated() { let cidr = resultRow["cidr"] @@ -510,6 +573,8 @@ class PostgreSQLTests: XCTestCase { } func testMacAddresses() throws { + let conn = try postgreSQL.makeConnection() + let rows = [ "5a:92:79:a1:ce:1a", "74:da:91:28:6a:a6", @@ -523,13 +588,13 @@ class PostgreSQLTests: XCTestCase { "58:ff:b8:e9:85:30", ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, macaddr macaddr)") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, macaddr macaddr)") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for (i, resultRow) in result.enumerated() { let macaddr = resultRow["macaddr"] @@ -539,6 +604,8 @@ class PostgreSQLTests: XCTestCase { } func testBitStrings() throws { + let conn = try postgreSQL.makeConnection() + let rows = [ "01010", "00000", @@ -551,13 +618,13 @@ class PostgreSQLTests: XCTestCase { "10000", ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, bits bit(5))") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, bits bit(5))") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for (i, resultRow) in result.enumerated() { let bits = resultRow["bits"] @@ -567,6 +634,8 @@ class PostgreSQLTests: XCTestCase { } func testVarBitStrings() throws { + let conn = try postgreSQL.makeConnection() + let rows = [ "0", "1", @@ -581,13 +650,13 @@ class PostgreSQLTests: XCTestCase { "1111111111", ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, bits bit varying)") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, bits bit varying)") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row.makeNode(in: nil)]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for (i, resultRow) in result.enumerated() { let bits = resultRow["bits"] @@ -597,6 +666,8 @@ class PostgreSQLTests: XCTestCase { } func testUnsupportedObject() throws { + let conn = try postgreSQL.makeConnection() + let rows: [Node] = [ .object(["1":1, "2":2]), .object(["1":1, "2":2, "3":3]), @@ -604,13 +675,13 @@ class PostgreSQLTests: XCTestCase { .object(["1":1]), ] - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, text text)") + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, text text)") for row in rows { - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row]) + try conn.execute("INSERT INTO foo VALUES (DEFAULT, $1)", [row]) } - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, rows.count) for resultRow in result { let value = resultRow["text"] @@ -620,14 +691,16 @@ class PostgreSQLTests: XCTestCase { } func testUnsupportedOID() throws { - try postgreSQL.execute("DROP TABLE IF EXISTS foo") - try postgreSQL.execute("CREATE TABLE foo (id serial, oid oid)") - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, 1)", nil) - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, 2)", nil) - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, 123)", nil) - try postgreSQL.execute("INSERT INTO foo VALUES (DEFAULT, 456)", nil) - - let result = try postgreSQL.execute("SELECT * FROM foo ORDER BY id ASC") + let conn = try postgreSQL.makeConnection() + + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (id serial, oid oid)") + try conn.execute("INSERT INTO foo VALUES (DEFAULT, 1)") + try conn.execute("INSERT INTO foo VALUES (DEFAULT, 2)") + try conn.execute("INSERT INTO foo VALUES (DEFAULT, 123)") + try conn.execute("INSERT INTO foo VALUES (DEFAULT, 456)") + + let result = try conn.execute("SELECT * FROM foo ORDER BY id ASC").array ?? [] XCTAssertEqual(result.count, 4) for resultRow in result { let value = resultRow["oid"] @@ -636,41 +709,147 @@ class PostgreSQLTests: XCTestCase { } func testNotification() throws { + let conn1 = try postgreSQL.makeConnection() + let conn2 = try postgreSQL.makeConnection() + let testExpectation = expectation(description: "Receive notification") - postgreSQL.listen(to: "test_channel1") { notification in - XCTAssertEqual(notification.channel, "test_channel1") - XCTAssertNil(notification.payload) + conn1.listen(toChannel: "test_channel1") { (notification, error, stop) in + XCTAssertEqual(notification?.channel, "test_channel1") + XCTAssertNil(notification?.payload) + XCTAssertNil(error) testExpectation.fulfill() + stop = true } sleep(1) - try postgreSQL.notify(channel: "test_channel1", payload: nil) + try conn2.notify(channel: "test_channel1", payload: nil) waitForExpectations(timeout: 5) } func testNotificationWithPayload() throws { + let conn1 = try postgreSQL.makeConnection() + let conn2 = try postgreSQL.makeConnection() + let testExpectation = expectation(description: "Receive notification with payload") - postgreSQL.listen(to: "test_channel2") { notification in - XCTAssertEqual(notification.channel, "test_channel2") - XCTAssertEqual(notification.payload, "test_payload") + conn1.listen(toChannel: "test_channel2") { (notification, error, stop) in + XCTAssertEqual(notification?.channel, "test_channel2") + XCTAssertEqual(notification?.payload, "test_payload") + XCTAssertNil(error) testExpectation.fulfill() + stop = true } sleep(1) - try postgreSQL.notify(channel: "test_channel2", payload: "test_payload") + try conn2.notify(channel: "test_channel2", payload: "test_payload") waitForExpectations(timeout: 5) } func testQueryToNode() throws { - let results: Node = try postgreSQL.makeConnection().execute("SELECT version()", []) + let conn = try postgreSQL.makeConnection() + + let results = try conn.execute("SELECT version()") XCTAssertNotNil(results.array?[0].object?["version"]?.string) } + + func testEmptyQuery() throws { + let conn = try postgreSQL.makeConnection() + + do { + try conn.execute("") + XCTFail("This query should not succeed") + } + catch PostgresSQLStatusError.emptyQuery { + // Should end up here + } + catch { + throw error + } + } + + func testInvalidQuery() throws { + let conn = try postgreSQL.makeConnection() + + do { + try conn.execute("SELECT * FROM nothing") + XCTFail("This query should not succeed") + } + catch let error as PostgreSQLError { + XCTAssertEqual(error.code, PostgreSQLError.Code.undefinedTable) + } + catch { + throw error + } + } + + func testTransactionSuccess() throws { + let conn = try postgreSQL.makeConnection() + + let isolationLevels: [Connection.TransactionIsolationLevel] = [ + .readCommitted, + .repeatableRead, + .serializable, + ] + + for isolationLevel in isolationLevels { + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (bar INT, baz VARCHAR(16), bla BOOLEAN)") + + try conn.transaction(isolationLevel: isolationLevel) { + try conn.execute("INSERT INTO foo VALUES (42, 'Life', true)") + try conn.execute("INSERT INTO foo VALUES (1337, 'Elite', false)") + try conn.execute("INSERT INTO foo VALUES (9, NULL, true)") + } + + let resuls = try conn.execute("SELECT * FROM foo").array ?? [] + XCTAssertEqual(resuls.count, 3) + } + } + + func testTransactionFailure() throws { + let conn = try postgreSQL.makeConnection() + + enum TestError : Error { + case failure + } + + let isolationLevels: [Connection.TransactionIsolationLevel] = [ + .readCommitted, + .repeatableRead, + .serializable, + ] + + for isolationLevel in isolationLevels { + try conn.execute("DROP TABLE IF EXISTS foo") + try conn.execute("CREATE TABLE foo (bar INT, baz VARCHAR(16), bla BOOLEAN)") + + do { + try conn.transaction(isolationLevel: isolationLevel) { + try conn.execute("INSERT INTO foo VALUES (42, 'Life', true)") + try conn.execute("INSERT INTO foo VALUES (1337, 'Elite', false)") + try conn.execute("INSERT INTO foo VALUES (9, NULL, true)") + + throw TestError.failure + } + + XCTFail("transaction should throw error") + } + catch TestError.failure { + + } + catch { + XCTFail("Should not fail with unknown error") + } + + let resuls = try conn.execute("SELECT * FROM foo").array ?? [] + XCTAssertEqual(resuls.count, 0) + } + } } diff --git a/Tests/PostgreSQLTests/Utilities.swift b/Tests/PostgreSQLTests/Utilities.swift index e03df2f..0624520 100644 --- a/Tests/PostgreSQLTests/Utilities.swift +++ b/Tests/PostgreSQLTests/Utilities.swift @@ -3,7 +3,7 @@ import PostgreSQL import Foundation extension PostgreSQL.Database { - static func makeTestConnection() -> PostgreSQL.Database { + static func makeTest() -> PostgreSQL.Database { do { let postgreSQL = try PostgreSQL.Database( hostname: "127.0.0.1", @@ -12,7 +12,10 @@ extension PostgreSQL.Database { user: "postgres", password: "" ) - try postgreSQL.execute("SELECT version()") + + let connection = try postgreSQL.makeConnection() + try connection.execute("SELECT version()") + return postgreSQL } catch { print() diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..6f9c65b --- /dev/null +++ b/codecov.yml @@ -0,0 +1,3 @@ +coverage: + ignore: + - "Tests"