From a045dc38781a1b938b4f3264f75de1782110652e Mon Sep 17 00:00:00 2001 From: Tomasz Stachowiak Date: Tue, 1 Oct 2024 19:02:03 +0300 Subject: [PATCH 1/5] refactoring --- .idea/.gitignore | 8 + .idea/encodings.xml | 6 + .idea/inspectionProfiles/Project_Default.xml | 25 ++ .idea/misc.xml | 16 ++ .idea/modules.xml | 8 + .idea/swift-safetensors.iml | 2 + .idea/vcs.xml | 6 + Sources/Safetensors/Enums/DType.swift | 20 ++ Sources/Safetensors/Enums/HeaderElement.swift | 53 ++++ .../DType+MLMultiArrayDataType.swift | 46 +++ .../Extensions/DType+MLTensorScalar.swift | 41 +++ .../MLMultiArray+SafetensorsEncodable.swift | 51 ++++ .../MLMultiArrayDataType+ScalarSize.swift | 29 ++ Sources/Safetensors/MLMultiArray.swift | 95 ------ Sources/Safetensors/MLTensor.swift | 34 --- .../Protocols/SafetensorsEncodable.swift | 37 +++ Sources/Safetensors/Safetensors.swift | 271 +++--------------- Sources/Safetensors/Structs/OffsetRange.swift | 36 +++ .../Safetensors/Structs/ParsedTensors.swift | 112 ++++++++ Sources/Safetensors/Structs/TensorData.swift | 17 ++ Sources/Safetensors/Utils/HeaderDecoder.swift | 38 +++ Sources/Safetensors/Utils/HeaderEncoder.swift | 24 ++ Tests/SafetensorsTests/SafetensorTests.swift | 10 +- 23 files changed, 617 insertions(+), 368 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/encodings.xml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/swift-safetensors.iml create mode 100644 .idea/vcs.xml create mode 100644 Sources/Safetensors/Enums/DType.swift create mode 100644 Sources/Safetensors/Enums/HeaderElement.swift create mode 100644 Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift create mode 100644 Sources/Safetensors/Extensions/DType+MLTensorScalar.swift create mode 100644 Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift create mode 100644 Sources/Safetensors/Extensions/MLMultiArrayDataType+ScalarSize.swift delete mode 100644 Sources/Safetensors/MLMultiArray.swift delete mode 100644 Sources/Safetensors/MLTensor.swift create mode 100644 Sources/Safetensors/Protocols/SafetensorsEncodable.swift create mode 100644 Sources/Safetensors/Structs/OffsetRange.swift create mode 100644 Sources/Safetensors/Structs/ParsedTensors.swift create mode 100644 Sources/Safetensors/Structs/TensorData.swift create mode 100644 Sources/Safetensors/Utils/HeaderDecoder.swift create mode 100644 Sources/Safetensors/Utils/HeaderEncoder.swift diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/encodings.xml b/.idea/encodings.xml new file mode 100644 index 0000000..97626ba --- /dev/null +++ b/.idea/encodings.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..9b3b5a8 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,25 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..7026b53 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,16 @@ + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..8cdbfe9 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/swift-safetensors.iml b/.idea/swift-safetensors.iml new file mode 100644 index 0000000..6207ba4 --- /dev/null +++ b/.idea/swift-safetensors.iml @@ -0,0 +1,2 @@ + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/Sources/Safetensors/Enums/DType.swift b/Sources/Safetensors/Enums/DType.swift new file mode 100644 index 0000000..3d45d54 --- /dev/null +++ b/Sources/Safetensors/Enums/DType.swift @@ -0,0 +1,20 @@ +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import Foundation +import CoreML + + +public enum DType: String, Codable { + case float64 = "F64" + case float32 = "F32" + case float16 = "F16" + case int32 = "I32" + case uint32 = "U32" + case int16 = "I16" + case uint16 = "U16" + case int8 = "I8" + case uint8 = "U8" + case bool = "BOOL" +} diff --git a/Sources/Safetensors/Enums/HeaderElement.swift b/Sources/Safetensors/Enums/HeaderElement.swift new file mode 100644 index 0000000..a32d3ca --- /dev/null +++ b/Sources/Safetensors/Enums/HeaderElement.swift @@ -0,0 +1,53 @@ +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import Foundation + +public enum HeaderElement { + case metadata([String: String]) + case tensorData(TensorData) +} + +extension HeaderElement { + public var metadata: [String: String]? { + if case .metadata(let metadata) = self { + return metadata + } + return nil + } + + public var tensorData: TensorData? { + if case .tensorData(let tensorData) = self { + return tensorData + } + return nil + } +} + +extension HeaderElement: Codable { + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let metadata = try? container.decode([String: String].self) { + self = .metadata(metadata) + } else if let tensorData = try? container.decode(TensorData.self) { + self = .tensorData(tensorData) + } else { + try! container.decode(TensorData.self) + throw DecodingError.dataCorrupted( + DecodingError.Context( + codingPath: decoder.codingPath, + debugDescription: "Invalid header element")) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case .metadata(let metadata): + try container.encode(metadata) + case .tensorData(let tensorData): + try container.encode(tensorData) + } + } +} diff --git a/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift b/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift new file mode 100644 index 0000000..1730c98 --- /dev/null +++ b/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift @@ -0,0 +1,46 @@ +// +// DType+MLMultiArrayDataType.swift +// swift-safetensors +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import CoreML + +extension DType { + init(mlMultiArrayDataType dataType: MLMultiArrayDataType) throws { + switch dataType { + case .float64, .double: + self = .float64 + case .float32, .float: + self = .float32 + case .int32: + self = .int32 +#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) + case .float16: + self = .float16 +#endif + @unknown default: + throw SafetensorsError.unsupportedDataType(String(describing: dataType.rawValue)) + } + } + + var mlMultiArrayDataType: MLMultiArrayDataType { + get throws { + switch self { + case .float64: + return .float64 + case .float32: + return .float32 + case .int32: + return .int32 +#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) + case .float16: + return .float16 +#endif + default: + throw SafetensorsError.unsupportedDataType(self.rawValue) + } + } + } +} diff --git a/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift b/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift new file mode 100644 index 0000000..219b93c --- /dev/null +++ b/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift @@ -0,0 +1,41 @@ +// +// DType+MLMultiArrayDataType.swift +// swift-safetensors +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import CoreML + + +@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, visionOS 2.0, *) +extension DType { + var mlTensorScalarType: MLTensorScalar.Type { + get throws { + switch self { + case .float32: + Float32.self + case .int32: + Int32.self + case .uint32: + UInt32.self + case .int16: + Int16.self + case .uint16: + UInt16.self + case .int8: + Int8.self + case .uint8: + UInt8.self + case .bool: + Bool.self +#if !(os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64) + case .float16: + Float16.self +#endif + default: + throw SafetensorsError.unsupportedDataType(rawValue) + } + } + } +} diff --git a/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift b/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift new file mode 100644 index 0000000..89286aa --- /dev/null +++ b/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift @@ -0,0 +1,51 @@ +import CoreML +import Foundation + +extension MLMultiArray: SafetensorsEncodable { + public var scalarCount: Int { + shape.reduce(1) { $0 * $1.intValue } + } + + public var tensorShape: [Int] { + shape.map { $0.intValue } + } + + public var dtype: DType { + get throws { + try .init(mlMultiArrayDataType: dataType) + } + } + + public var scalarSize: Int { + get throws { + try dataType.scalarSize + } + } + + public func toData() throws -> Data { + switch dataType { + case .double: + data(ofType: Double.self) + case .float32: + data(ofType: Float32.self) + case .int32: + data(ofType: Int32.self) +#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) && os + case .float16: + guard #available(macOS 15.0, iOS 16.0, tvOS 16.0, watchOS 9.0, visionOS 1.0, *) else { + fallthrough + } + + data(ofType: Float16.self) +#endif + default: + throw SafetensorsError.unsupportedDataType(dataType.rawValue.description) + } + } + + private func data(ofType type: T.Type) -> Data where T: MLShapedArrayScalar { + withUnsafeBufferPointer(ofType: type) { ptr in + Data(buffer: ptr) + } + } +} diff --git a/Sources/Safetensors/Extensions/MLMultiArrayDataType+ScalarSize.swift b/Sources/Safetensors/Extensions/MLMultiArrayDataType+ScalarSize.swift new file mode 100644 index 0000000..587c0fb --- /dev/null +++ b/Sources/Safetensors/Extensions/MLMultiArrayDataType+ScalarSize.swift @@ -0,0 +1,29 @@ +// +// MLMultiArrayDataType+ScalarSize.swift +// swift-safetensors +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import CoreML + +extension MLMultiArrayDataType { + var scalarSize: Int { + get throws { + switch self { + case .double: + return MemoryLayout.size + case .float32: + return MemoryLayout.size + case .int32: + return MemoryLayout.size +#if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) + case .float16: + return MemoryLayout.size +#endif + @unknown default: + throw SafetensorsError.unsupportedDataType(rawValue.description) + } + } + } +} diff --git a/Sources/Safetensors/MLMultiArray.swift b/Sources/Safetensors/MLMultiArray.swift deleted file mode 100644 index e7492a9..0000000 --- a/Sources/Safetensors/MLMultiArray.swift +++ /dev/null @@ -1,95 +0,0 @@ -import CoreML -import Foundation - -extension MLMultiArray: SafetensorsEncodable { - public var scalarCount: Int { - shape.reduce(1, { $0 * $1.intValue }) - } - - public var tensorShape: [Int] { - shape.map { $0.intValue } - } - - public func dtype() throws -> String { - switch dataType { - case .float64: - return "F64" - case .float32: - return "F32" - case .float16: - return "F16" - case .int32: - return "I32" - default: - throw SafetensorsError.unsupportedDataType(String(describing: dataType.rawValue)) - } - } - - public func scalarSize() throws -> Int { - switch dataType { - case .float64: - return MemoryLayout.size - case .float32: - return MemoryLayout.size - #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) - case .float16: - return MemoryLayout.size - #endif - case .int32: - return MemoryLayout.size - default: - throw SafetensorsError.unsupportedDataType(String(describing: dataType.rawValue)) - } - } - - public func toData() throws -> Data { - switch dataType { - case .double: - return withUnsafeBufferPointer(ofType: Float64.self) { ptr in - Data(buffer: ptr) - } - case .float32: - return withUnsafeBufferPointer(ofType: Float32.self) { ptr in - Data(buffer: ptr) - } - #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) - case .float16: - if #available(macOS 15.0, iOS 16.0, tvOS 16.0, watchOS 9.0, visionOS 1.0, *) { - return withUnsafeBufferPointer(ofType: Float16.self) { ptr in - Data(buffer: ptr) - } - } else { - throw SafetensorsError.unsupportedDataType( - String(describing: dataType.rawValue)) - } - #endif - case .int32: - return withUnsafeBufferPointer(ofType: Int32.self) { ptr in - Data(buffer: ptr) - } - @unknown default: - throw SafetensorsError.unsupportedDataType(String(describing: dataType.rawValue)) - } - } -} - -extension MLMultiArray { - static func toMLMultiArrayDataType(from dtype: String) throws -> MLMultiArrayDataType { - switch dtype { - case "F64": - return .float64 - case "F32": - return .float32 - case "F16": - #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) - return .float16 - #else - throw SafetensorsError.unsupportedDataType(dtype) - #endif - case "I32": - return .int32 - default: - throw SafetensorsError.unsupportedDataType(dtype) - } - } -} diff --git a/Sources/Safetensors/MLTensor.swift b/Sources/Safetensors/MLTensor.swift deleted file mode 100644 index 8cdee35..0000000 --- a/Sources/Safetensors/MLTensor.swift +++ /dev/null @@ -1,34 +0,0 @@ -import CoreML -import Foundation - -@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, visionOS 2.0, *) -extension MLTensor { - static func toMLTensorScalarType(from dtype: String) throws -> MLTensorScalar.Type { - switch dtype { - case "F32": - return Float32.self - case "F16": - #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) - return Float16.self - #else - throw SafetensorsError.unsupportedDataType(dtype) - #endif - case "I32": - return Int32.self - case "U32": - return UInt32.self - case "I16": - return Int16.self - case "U16": - return UInt16.self - case "I8": - return Int8.self - case "U8": - return UInt8.self - case "BOOL": - return Bool.self - default: - throw SafetensorsError.unsupportedDataType(dtype) - } - } -} diff --git a/Sources/Safetensors/Protocols/SafetensorsEncodable.swift b/Sources/Safetensors/Protocols/SafetensorsEncodable.swift new file mode 100644 index 0000000..08d66b7 --- /dev/null +++ b/Sources/Safetensors/Protocols/SafetensorsEncodable.swift @@ -0,0 +1,37 @@ +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import Foundation + +/// Protocol for types that can be encoded to a `Data` object. +public protocol SafetensorsEncodable { + var scalarCount: Int { get } + var tensorShape: [Int] { get } + var dtype: DType { get throws } + + var scalarSize: Int { get throws } + + func toData() throws -> Data +} + +extension SafetensorsEncodable { + var byteCount: Int { + get throws{ + try self.scalarSize * self.scalarCount + } + } +} + +extension SafetensorsEncodable { + func tensorData(at offset: Int = 0) throws -> TensorData { + try TensorData( + dtype: dtype, + shape: tensorShape, + dataOffsets: OffsetRange( + start: offset, + end: offset + byteCount + ) + ) + } +} diff --git a/Sources/Safetensors/Safetensors.swift b/Sources/Safetensors/Safetensors.swift index 5478dd5..736e6a9 100644 --- a/Sources/Safetensors/Safetensors.swift +++ b/Sources/Safetensors/Safetensors.swift @@ -1,16 +1,6 @@ import CoreML import Foundation -/// Protocol for types that can be encoded to a `Data` object. -public protocol SafetensorsEncodable { - var scalarCount: Int { get } - var tensorShape: [Int] { get } - - func dtype() throws -> String - func scalarSize() throws -> Int - func toData() throws -> Data -} - public enum SafetensorsError: Error { case invalidHeaderSize case invalidHeaderData @@ -19,187 +9,6 @@ public enum SafetensorsError: Error { case metadataIncompleteBuffer } -public enum HeaderElement { - case metadata([String: String]) - case tensorData(TensorData) -} - -extension HeaderElement { - public var metadata: [String: String]? { - if case .metadata(let metadata) = self { - return metadata - } - return nil - } - - public var tensorData: TensorData? { - if case .tensorData(let tensorData) = self { - return tensorData - } - return nil - } -} - -extension HeaderElement: Codable { - public init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - if let metadata = try? container.decode([String: String].self) { - self = .metadata(metadata) - } else if let tensorData = try? container.decode(TensorData.self) { - self = .tensorData(tensorData) - } else { - throw DecodingError.dataCorrupted( - DecodingError.Context( - codingPath: decoder.codingPath, - debugDescription: "Invalid header element")) - } - } - - public func encode(to encoder: Encoder) throws { - var container = encoder.singleValueContainer() - switch self { - case .metadata(let metadata): - try container.encode(metadata) - case .tensorData(let tensorData): - try container.encode(tensorData) - } - } -} - -public struct OffsetRange: Equatable, Codable { - public let start: Int - public let end: Int - - public init(start: Int, end: Int) { - self.start = start - self.end = end - } - - public init(from decoder: Decoder) throws { - let container = try decoder.singleValueContainer() - let array = try container.decode([Int].self) - precondition(array.count == 2, "Range array needs to have exactly 2 elements") - self.start = array[0] - self.end = array[1] - } - - public func encode(to encoder: Encoder) throws { - try [start, end].encode(to: encoder) - } -} - -public struct TensorData: Codable { - public let dtype: String - public let shape: [Int] - public let dataOffsets: OffsetRange - - public init(dtype: String, shape: [Int], dataOffsets: OffsetRange) { - self.dtype = dtype - self.shape = shape - self.dataOffsets = dataOffsets - } -} - -public struct ParsedSafetensors { - private let headerOffset: Int - private let headerData: [String: HeaderElement] - private let rawData: Data - - public init( - headerOffset: Int, - headerData: [String: HeaderElement], - rawData: Data - ) { - self.headerOffset = headerOffset - self.headerData = headerData - self.rawData = rawData - } - - public var metadata: [String: String]? { - headerData["__metadata__"]?.metadata - } - - public func tensorData(forKey key: String) throws -> TensorData { - guard let tensorData = headerData[key]?.tensorData else { - throw SafetensorsError.missingTensorData - } - return tensorData - } - - @available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, visionOS 2.0, *) - public func mlTensor(forKey key: String, noCopy: Bool = false) throws -> MLTensor { - let tensorData = try tensorData(forKey: key) - let scalarType = try MLTensor.toMLTensorScalarType(from: tensorData.dtype) - let startIndex = tensorData.dataOffsets.start + headerOffset - let endIndex = tensorData.dataOffsets.end + headerOffset - let count = endIndex - startIndex - if noCopy { - return rawData.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in - let startPtr = ptr.baseAddress!.advanced(by: startIndex) - return MLTensor( - bytesNoCopy: UnsafeRawBufferPointer(start: startPtr, count: count), - shape: tensorData.shape, - scalarType: scalarType, - deallocator: .none - ) - } - } else { - return rawData.withUnsafeBytes { (sourcePtr: UnsafeRawBufferPointer) in - MLTensor( - unsafeUninitializedShape: tensorData.shape, - scalarType: scalarType, - initializingWith: { ptr in - ptr.copyMemory( - from: UnsafeRawBufferPointer( - start: sourcePtr.baseAddress!.advanced(by: startIndex), count: count - ) - ) - } - ) - } - } - } - - public func mlMultiArray(forKey key: String, noCopy: Bool = false) throws -> MLMultiArray { - let tensorData = try tensorData(forKey: key) - let dataType = try MLMultiArray.toMLMultiArrayDataType(from: tensorData.dtype) - let startIndex = tensorData.dataOffsets.start + headerOffset - let endIndex = tensorData.dataOffsets.end + headerOffset - let count = endIndex - startIndex - var strides = [NSNumber]() - var stride = 1 - for dimension in tensorData.shape.reversed() { - strides.append(NSNumber(value: stride)) - stride *= dimension - } - strides.reverse() - if noCopy { - return try rawData.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in - let dataPtr = ptr.baseAddress!.advanced(by: startIndex) - return try MLMultiArray( - dataPointer: UnsafeMutableRawPointer(mutating: dataPtr), - shape: tensorData.shape.map { NSNumber(value: $0) }, - dataType: dataType, - strides: strides - ) - } - } else { - let rawDataCopy = rawData.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in - let dataPtr = ptr.baseAddress!.advanced(by: startIndex) - return Data(bytes: dataPtr, count: count) - } - return try rawDataCopy.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in - try MLMultiArray( - dataPointer: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), - shape: tensorData.shape.map { NSNumber(value: $0) }, - dataType: dataType, - strides: strides - ) - } - } - } -} - public enum Safetensors { /// Validate the header data and ensure that the tensor data is contiguous. /// - Parameters: @@ -208,21 +17,32 @@ public enum Safetensors { static func validate(header: [String: HeaderElement], dataCount: Int) throws { let allDataOffsets = header .values - .compactMap { $0.tensorData?.dataOffsets } - .sorted { $0.start < $1.start } + .compactMap { + $0.tensorData?.dataOffsets + } + .sorted { + $0.start < $1.start + } + guard let first = allDataOffsets.first, let last = allDataOffsets.last else { throw SafetensorsError.metadataIncompleteBuffer } + if first.start != 0 || last.end != dataCount { throw SafetensorsError.metadataIncompleteBuffer } + + if zip(allDataOffsets, allDataOffsets.dropFirst()).contains(where: { $0.end != $1.start }) { + throw SafetensorsError.metadataIncompleteBuffer + } + for (first, second) in zip(allDataOffsets, allDataOffsets.dropFirst()) { if first.end != second.start { throw SafetensorsError.metadataIncompleteBuffer } } } - + /// Read file at given URL and return `ParsedSafetensors` object. /// - Parameter url: file URL to read the data from /// - Returns: `ParsedSafetensors` object containing the decoded data @@ -231,34 +51,26 @@ public enum Safetensors { let data = try Data(contentsOf: url, options: .mappedIfSafe) return try decode(data) } - + /// Decode Data object to ParsedSafetensors object. /// - Parameter data: `Data` object containing the encoded data /// - Returns: `ParsedSafetensors` object containing the decoded data public static func decode(_ data: Data) throws -> ParsedSafetensors { - guard data.count >= 8 else { + guard data.count >= MemoryLayout.size else { throw SafetensorsError.invalidHeaderSize } - let (headerOffset, headerData) = try data[0..<8].withUnsafeBytes { - (ptr: UnsafeRawBufferPointer) in - let headerSize = ptr.load(as: Int.self) - guard data.count >= 8 + headerSize else { - throw SafetensorsError.invalidHeaderData - } - let headerData = data[8..<8 + headerSize] - let decoder = JSONDecoder() - decoder.keyDecodingStrategy = .convertFromSnakeCase - let header = try decoder.decode([String: HeaderElement].self, from: headerData) - return (headerSize + 8, header) - } - try validate(header: headerData, dataCount: data.count - headerOffset) + + let result = try HeaderDecoder().decode(data) + + try validate(header: result.header, dataCount: data.count - result.size) + return ParsedSafetensors( - headerOffset: headerOffset, - headerData: headerData, + headerSize: result.size, + headerData: result.header, rawData: data ) } - + /// Save dictionary of `SafetensorsEncodable` values to file. /// - Parameters: /// - data: dictionary of `SafetensorsEncodable` values @@ -273,7 +85,7 @@ public enum Safetensors { let encodedData = try encode(data, metadata: metadata) try encodedData.write(to: url) } - + /// Encode dictionary of `SafetensorsEncodable` values to `Data` object. /// - Parameters: /// - data: dictionary of `SafetensorsEncodable` values @@ -283,31 +95,22 @@ public enum Safetensors { _ data: [String: any SafetensorsEncodable], metadata: [String: String]? = nil ) throws -> Data { - var headerData = [String: HeaderElement]() - headerData.reserveCapacity(data.count + (metadata == nil ? 0 : 1)) - var previousOffset = 0 - var tensorData = [UInt8]() + var headerData = ParsedSafetensors.HeaderData( + minimumCapacity: data.count + (metadata == nil ? 0 : 1) + ) + + let totalDataSize = try data.values.reduce(0) { try $1.byteCount + $0 } + var tensorData: Data = .init(capacity: totalDataSize) + for (key, tensor) in data { - let tensorByteCount = try tensor.scalarSize() * tensor.scalarCount - let tensorHeaderData = try TensorData( - dtype: tensor.dtype(), - shape: tensor.tensorShape, - dataOffsets: OffsetRange( - start: previousOffset, - end: tensorByteCount + previousOffset - ) - ) - previousOffset += tensorByteCount - headerData[key] = .tensorData(tensorHeaderData) - try tensorData.append(contentsOf: tensor.toData()) + headerData[key] = .tensorData(try tensor.tensorData(at: tensorData.count)) + tensorData.append(try tensor.toData()) } + if let metadata { headerData["__metadata__"] = .metadata(metadata) } - let encoder = JSONEncoder() - encoder.keyEncodingStrategy = .convertToSnakeCase - let header = try encoder.encode(headerData) - let headerSize = withUnsafeBytes(of: UInt64(header.count)) { Data($0) } - return headerSize + header + Data(tensorData) + + return try HeaderEncoder().encode(headerData) + tensorData } } diff --git a/Sources/Safetensors/Structs/OffsetRange.swift b/Sources/Safetensors/Structs/OffsetRange.swift new file mode 100644 index 0000000..1878863 --- /dev/null +++ b/Sources/Safetensors/Structs/OffsetRange.swift @@ -0,0 +1,36 @@ +// +// OffsetRange.swift +// swift-safetensors +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +public struct OffsetRange: Codable { + let start: Int + let end: Int + + public init(start: Int = 0, end: Int) { + self.start = start + self.end = end + } + + public init(from decoder: any Decoder) throws { + var container = try decoder.singleValueContainer() + let array = try container.decode([Int].self) + + precondition(array.count == 2, "range array needs to have exactly 2 elements") + + self.start = array[0] + self.end = array[1] + } + + public func encode(to encoder: any Encoder) throws { + try [start, end].encode(to: encoder) + } +} + +extension OffsetRange: Equatable { + public static func ==(lhs: OffsetRange, rhs: OffsetRange) -> Bool { + lhs.start == rhs.start && lhs.end == rhs.end + } +} diff --git a/Sources/Safetensors/Structs/ParsedTensors.swift b/Sources/Safetensors/Structs/ParsedTensors.swift new file mode 100644 index 0000000..a7cfacb --- /dev/null +++ b/Sources/Safetensors/Structs/ParsedTensors.swift @@ -0,0 +1,112 @@ +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import Foundation +import CoreML + +public struct ParsedSafetensors { + typealias HeaderData = [String: HeaderElement] + + private let headerSize: Int + private let headerData: HeaderData + private let rawData: Data + + public init( + headerSize: Int, + headerData: [String: HeaderElement], + rawData: Data + ) { + self.headerSize = headerSize + self.headerData = headerData + self.rawData = rawData + } + + public var metadata: [String: String]? { + headerData["__metadata__"]?.metadata + } + + public func tensorData(forKey key: String) throws -> TensorData { + guard let tensorData = headerData[key]?.tensorData else { + throw SafetensorsError.missingTensorData + } + return tensorData + } + + @available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, visionOS 2.0, *) + public func mlTensor(forKey key: String, noCopy: Bool = false) throws -> MLTensor { + let tensorData = try tensorData(forKey: key) + let scalarType = try tensorData.dtype.mlTensorScalarType + let startIndex = tensorData.dataOffsets.start + headerSize + let endIndex = tensorData.dataOffsets.end + headerSize + let count = endIndex - startIndex + if noCopy { + return rawData.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in + let startPtr = ptr.baseAddress!.advanced(by: startIndex) + return MLTensor( + bytesNoCopy: UnsafeRawBufferPointer(start: startPtr, count: count), + shape: tensorData.shape, + scalarType: scalarType, + deallocator: .none + ) + } + } else { + return rawData.withUnsafeBytes { (sourcePtr: UnsafeRawBufferPointer) in + MLTensor( + unsafeUninitializedShape: tensorData.shape, + scalarType: scalarType, + initializingWith: { ptr in + ptr.copyMemory( + from: UnsafeRawBufferPointer( + start: sourcePtr.baseAddress!.advanced(by: startIndex), count: count + ) + ) + } + ) + } + } + } + + public func mlMultiArray(forKey key: String, noCopy: Bool = false) throws -> MLMultiArray { + let tensorData = try tensorData(forKey: key) + let dataType = try tensorData.dtype.mlMultiArrayDataType + let startIndex = tensorData.dataOffsets.start + headerSize + let endIndex = tensorData.dataOffsets.end + headerSize + let count = endIndex - startIndex + var strides = [NSNumber]() + var stride = 1 + for dimension in tensorData.shape.reversed() { + strides.append(NSNumber(value: stride)) + stride *= dimension + } + strides.reverse() + if noCopy { + return try rawData.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in + let dataPtr = ptr.baseAddress!.advanced(by: startIndex) + return try MLMultiArray( + dataPointer: UnsafeMutableRawPointer(mutating: dataPtr), + shape: tensorData.shape.map { + NSNumber(value: $0) + }, + dataType: dataType, + strides: strides + ) + } + } else { + let rawDataCopy = rawData.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in + let dataPtr = ptr.baseAddress!.advanced(by: startIndex) + return Data(bytes: dataPtr, count: count) + } + return try rawDataCopy.withUnsafeBytes { (ptr: UnsafeRawBufferPointer) in + try MLMultiArray( + dataPointer: UnsafeMutableRawPointer(mutating: ptr.baseAddress!), + shape: tensorData.shape.map { + NSNumber(value: $0) + }, + dataType: dataType, + strides: strides + ) + } + } + } +} diff --git a/Sources/Safetensors/Structs/TensorData.swift b/Sources/Safetensors/Structs/TensorData.swift new file mode 100644 index 0000000..a43041e --- /dev/null +++ b/Sources/Safetensors/Structs/TensorData.swift @@ -0,0 +1,17 @@ +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import Foundation + +public struct TensorData: Codable { + public let dtype: DType + public let shape: [Int] + public let dataOffsets: OffsetRange + + public init(dtype: DType, shape: [Int], dataOffsets: OffsetRange) { + self.dtype = dtype + self.shape = shape + self.dataOffsets = dataOffsets + } +} diff --git a/Sources/Safetensors/Utils/HeaderDecoder.swift b/Sources/Safetensors/Utils/HeaderDecoder.swift new file mode 100644 index 0000000..9d1a46a --- /dev/null +++ b/Sources/Safetensors/Utils/HeaderDecoder.swift @@ -0,0 +1,38 @@ +// +// HeaderEncoder.swift +// swift-safetensors +// +// Created by Tomasz Stachowiak on 1.10.2024. +// + +import Foundation + +struct HeaderDecoder { + private var jsonDecoder: JSONDecoder + + init() { + jsonDecoder = JSONDecoder() + jsonDecoder.keyDecodingStrategy = .convertFromSnakeCase + } + + func decode(_ data: Data) throws -> (size: Int, header: ParsedSafetensors.HeaderData) { + guard data.count >= MemoryLayout.size else { + throw SafetensorsError.invalidHeaderData + } + + let headerOffset = MemoryLayout.size + let headerSize = data.withUnsafeBytes { + $0.load(as: Int.self) + } + + guard data.count >= headerOffset + headerSize else { + throw SafetensorsError.invalidHeaderData + } + + let headerData = data[headerOffset.. Data { + let encodedHeader = try jsonEncoder.encode(headerData) + return withUnsafeBytes(of: encodedHeader.count) { + Data($0) + } + encodedHeader + } +} diff --git a/Tests/SafetensorsTests/SafetensorTests.swift b/Tests/SafetensorsTests/SafetensorTests.swift index 4ba3333..6895b3a 100644 --- a/Tests/SafetensorsTests/SafetensorTests.swift +++ b/Tests/SafetensorsTests/SafetensorTests.swift @@ -13,7 +13,7 @@ import Testing let safeTensors = try Safetensors.decode(data) let testTensor = try safeTensors.tensorData(forKey: "test") - #expect(testTensor.dtype == "I32") + #expect(testTensor.dtype == .int32) #expect(testTensor.shape == [2, 2]) #expect(testTensor.dataOffsets == OffsetRange(start: 0, end: 16)) #expect(safeTensors.metadata == nil) @@ -28,7 +28,7 @@ import Testing let safeTensors = try Safetensors.decode(data) let testTensor = try safeTensors.tensorData(forKey: "test") - #expect(testTensor.dtype == "I32") + #expect(testTensor.dtype == .int32) #expect(testTensor.shape == [2, 2]) #expect(testTensor.dataOffsets == OffsetRange(start: 0, end: 16)) #expect(safeTensors.metadata == ["key1": "value1"]) @@ -65,7 +65,7 @@ import Testing let safeTensors = try Safetensors.read(at: fileUrl) let testTensor: TensorData = try safeTensors.tensorData(forKey: "test") - #expect(testTensor.dtype == "I32") + #expect(testTensor.dtype == .int32) #expect(testTensor.shape == [2, 2]) #expect(testTensor.dataOffsets == OffsetRange(start: 0, end: 16)) let tensor = try safeTensors.mlTensor(forKey: "test") @@ -102,7 +102,7 @@ import Testing let safeTensors = try Safetensors.decode(data) let testTensor = try safeTensors.tensorData(forKey: "test") - #expect(testTensor.dtype == "I32") + #expect(testTensor.dtype == .int32) #expect(testTensor.shape == []) #expect(testTensor.dataOffsets == OffsetRange(start: 0, end: 4)) } @@ -115,7 +115,7 @@ import Testing let safeTensors = try Safetensors.decode(data) let testTensor = try safeTensors.tensorData(forKey: "test") - #expect(testTensor.dtype == "I32") + #expect(testTensor.dtype == .int32) #expect(testTensor.shape == []) #expect(testTensor.dataOffsets == OffsetRange(start: 0, end: 0)) } From cb5422c3daaa34d749fafb37e189a3542b8c33f1 Mon Sep 17 00:00:00 2001 From: Tomasz Stachowiak Date: Tue, 1 Oct 2024 19:36:26 +0300 Subject: [PATCH 2/5] removed .idea, added gitignore --- .gitignore | 183 +++++++++++++++++++ .idea/.gitignore | 8 - .idea/encodings.xml | 6 - .idea/inspectionProfiles/Project_Default.xml | 25 --- .idea/misc.xml | 16 -- .idea/modules.xml | 8 - .idea/swift-safetensors.iml | 2 - .idea/vcs.xml | 6 - 8 files changed, 183 insertions(+), 71 deletions(-) delete mode 100644 .idea/.gitignore delete mode 100644 .idea/encodings.xml delete mode 100644 .idea/inspectionProfiles/Project_Default.xml delete mode 100644 .idea/misc.xml delete mode 100644 .idea/modules.xml delete mode 100644 .idea/swift-safetensors.iml delete mode 100644 .idea/vcs.xml diff --git a/.gitignore b/.gitignore index edf07a0..1ccef9e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,186 @@ DerivedData/ .netrc /.vscode default.profraw +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### Xcode template +## User settings + +## Xcode 8 and earlier +*.xcscmblueprint +*.xccheckout + +### Swift template +# Xcode +# +# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore + +## User settings + +## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) + +## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) +build/ +*.moved-aside +*.pbxuser +!default.pbxuser +*.mode1v3 +!default.mode1v3 +*.mode2v3 +!default.mode2v3 +*.perspectivev3 +!default.perspectivev3 + +## Obj-C/Swift specific +*.hmap + +## App packaging +*.ipa +*.dSYM.zip +*.dSYM + +## Playgrounds +timeline.xctimeline +playground.xcworkspace + +# Swift Package Manager +# +# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. +# Packages/ +# Package.pins +# Package.resolved +# *.xcodeproj +# +# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata +# hence it is not needed unless you have added a package configuration file to your project +# .swiftpm + +.build/ + +# CocoaPods +# +# We recommend against adding the Pods directory to your .gitignore. However +# you should judge for yourself, the pros and cons are mentioned at: +# https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control +# +# Pods/ +# +# Add this line if you want to avoid checking in source code from the Xcode workspace +# *.xcworkspace + +# Carthage +# +# Add this line if you want to avoid checking in source code from Carthage dependencies. +# Carthage/Checkouts + +Carthage/Build/ + +# Accio dependency management +Dependencies/ +.accio/ + +# fastlane +# +# It is recommended to not store the screenshots in the git repo. +# Instead, use fastlane to re-generate the screenshots whenever they are needed. +# For more information about the recommended setup visit: +# https://docs.fastlane.tools/best-practices/source-control/#source-control + +fastlane/report.xml +fastlane/Preview.html +fastlane/screenshots/**/*.png +fastlane/test_output + +# Code Injection +# +# After new code Injection tools there's a generated folder /iOSInjectionProject +# https://github.com/johnno1962/injectionforxcode + +iOSInjectionProject/ + +### SwiftPackageManager template +Packages +xcuserdata +*.xcodeproj + + +### SwiftPM template + + diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 13566b8..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml -# Editor-based HTTP Client requests -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/.idea/encodings.xml b/.idea/encodings.xml deleted file mode 100644 index 97626ba..0000000 --- a/.idea/encodings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 9b3b5a8..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,25 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index 7026b53..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,16 +0,0 @@ - - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 8cdbfe9..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/swift-safetensors.iml b/.idea/swift-safetensors.iml deleted file mode 100644 index 6207ba4..0000000 --- a/.idea/swift-safetensors.iml +++ /dev/null @@ -1,2 +0,0 @@ - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 94a25f7..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file From a0aca234612466190983e484efee0464b86b5b55 Mon Sep 17 00:00:00 2001 From: Tomasz Stachowiak Date: Tue, 1 Oct 2024 19:40:30 +0300 Subject: [PATCH 3/5] spurious code removed --- .gitignore | 124 ++++++++++++++++++++++++++ Sources/Safetensors/Safetensors.swift | 4 - 2 files changed, 124 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 1ccef9e..4809357 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ DerivedData/ .netrc /.vscode default.profraw +/.idea ### JetBrains template # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 @@ -192,3 +193,126 @@ xcuserdata ### SwiftPM template +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff + +# AWS User-specific + +# Generated files + +# Sensitive or high-churn files + +# Gradle + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake + +# Mongo Explorer plugin + +# File-based project format + +# IntelliJ + +# mpeltonen/sbt-idea plugin + +# JIRA plugin + +# Cursive Clojure plugin + +# SonarLint plugin + +# Crashlytics plugin (for Android Studio and IntelliJ) + +# Editor-based Rest Client + +# Android studio 3.1+ serialized cache file + +### Xcode template +## User settings + +## Xcode 8 and earlier + +### Swift template +# Xcode +# +# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore + +## User settings + +## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9) + +## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4) + +## Obj-C/Swift specific + +## App packaging + +## Playgrounds + +# Swift Package Manager +# +# Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. +# Packages/ +# Package.pins +# Package.resolved +# *.xcodeproj +# +# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata +# hence it is not needed unless you have added a package configuration file to your project +# .swiftpm + + +# CocoaPods +# +# We recommend against adding the Pods directory to your .gitignore. However +# you should judge for yourself, the pros and cons are mentioned at: +# https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control +# +# Pods/ +# +# Add this line if you want to avoid checking in source code from the Xcode workspace +# *.xcworkspace + +# Carthage +# +# Add this line if you want to avoid checking in source code from Carthage dependencies. +# Carthage/Checkouts + + +# Accio dependency management + +# fastlane +# +# It is recommended to not store the screenshots in the git repo. +# Instead, use fastlane to re-generate the screenshots whenever they are needed. +# For more information about the recommended setup visit: +# https://docs.fastlane.tools/best-practices/source-control/#source-control + + +# Code Injection +# +# After new code Injection tools there's a generated folder /iOSInjectionProject +# https://github.com/johnno1962/injectionforxcode + + +### SwiftPM template + + +### SwiftPackageManager template + + diff --git a/Sources/Safetensors/Safetensors.swift b/Sources/Safetensors/Safetensors.swift index 736e6a9..60fadae 100644 --- a/Sources/Safetensors/Safetensors.swift +++ b/Sources/Safetensors/Safetensors.swift @@ -32,10 +32,6 @@ public enum Safetensors { throw SafetensorsError.metadataIncompleteBuffer } - if zip(allDataOffsets, allDataOffsets.dropFirst()).contains(where: { $0.end != $1.start }) { - throw SafetensorsError.metadataIncompleteBuffer - } - for (first, second) in zip(allDataOffsets, allDataOffsets.dropFirst()) { if first.end != second.start { throw SafetensorsError.metadataIncompleteBuffer From 1507165ff80379334793b9548be7ca20515a7b44 Mon Sep 17 00:00:00 2001 From: Tomasz Stachowiak Date: Tue, 1 Oct 2024 19:52:41 +0300 Subject: [PATCH 4/5] removed scalarSize requirement from the SafetensorEncodable protocol --- Sources/Safetensors/Enums/DType.swift | 27 +++++++++++++++++++ .../Protocols/SafetensorsEncodable.swift | 11 +++++--- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/Sources/Safetensors/Enums/DType.swift b/Sources/Safetensors/Enums/DType.swift index 3d45d54..2ce6fbb 100644 --- a/Sources/Safetensors/Enums/DType.swift +++ b/Sources/Safetensors/Enums/DType.swift @@ -18,3 +18,30 @@ public enum DType: String, Codable { case uint8 = "U8" case bool = "BOOL" } + +extension DType { + var scalarSize: Int { + switch self { + case .float64: + return MemoryLayout.size + case .float32: + return MemoryLayout.size + case .float16: + return MemoryLayout.size + case .int32: + return MemoryLayout.size + case .uint32: + return MemoryLayout.size + case .int16: + return MemoryLayout.size + case .uint16: + return MemoryLayout.size + case .int8: + return MemoryLayout.size + case .uint8: + return MemoryLayout.size + case .bool: + return MemoryLayout.size + } + } +} diff --git a/Sources/Safetensors/Protocols/SafetensorsEncodable.swift b/Sources/Safetensors/Protocols/SafetensorsEncodable.swift index 08d66b7..6784a60 100644 --- a/Sources/Safetensors/Protocols/SafetensorsEncodable.swift +++ b/Sources/Safetensors/Protocols/SafetensorsEncodable.swift @@ -9,16 +9,13 @@ public protocol SafetensorsEncodable { var scalarCount: Int { get } var tensorShape: [Int] { get } var dtype: DType { get throws } - - var scalarSize: Int { get throws } - func toData() throws -> Data } extension SafetensorsEncodable { var byteCount: Int { get throws{ - try self.scalarSize * self.scalarCount + try self.scalarSize * self.dtype.scalarSize } } } @@ -34,4 +31,10 @@ extension SafetensorsEncodable { ) ) } + + var scalarSize: Int { + get throws { + try dtype.scalarSize + } + } } From 39455a2c3f21cd8cd361ff23a121ae087afe8217 Mon Sep 17 00:00:00 2001 From: Tomasz Stachowiak Date: Tue, 1 Oct 2024 20:08:54 +0300 Subject: [PATCH 5/5] review changes --- Sources/Safetensors/Enums/DType.swift | 4 --- Sources/Safetensors/Enums/HeaderElement.swift | 5 --- .../DType+MLMultiArrayDataType.swift | 7 ---- .../Extensions/DType+MLTensorScalar.swift | 7 ---- .../MLMultiArray+SafetensorsEncodable.swift | 18 ++++++---- .../Protocols/SafetensorsEncodable.swift | 2 +- Sources/Safetensors/Safetensors.swift | 36 +++++++++---------- Sources/Safetensors/Structs/OffsetRange.swift | 10 ++---- Sources/Safetensors/Utils/HeaderDecoder.swift | 22 +++++------- Sources/Safetensors/Utils/HeaderEncoder.swift | 9 ++--- 10 files changed, 42 insertions(+), 78 deletions(-) diff --git a/Sources/Safetensors/Enums/DType.swift b/Sources/Safetensors/Enums/DType.swift index 2ce6fbb..c63dc22 100644 --- a/Sources/Safetensors/Enums/DType.swift +++ b/Sources/Safetensors/Enums/DType.swift @@ -1,7 +1,3 @@ -// -// Created by Tomasz Stachowiak on 1.10.2024. -// - import Foundation import CoreML diff --git a/Sources/Safetensors/Enums/HeaderElement.swift b/Sources/Safetensors/Enums/HeaderElement.swift index a32d3ca..2d7e512 100644 --- a/Sources/Safetensors/Enums/HeaderElement.swift +++ b/Sources/Safetensors/Enums/HeaderElement.swift @@ -1,7 +1,3 @@ -// -// Created by Tomasz Stachowiak on 1.10.2024. -// - import Foundation public enum HeaderElement { @@ -33,7 +29,6 @@ extension HeaderElement: Codable { } else if let tensorData = try? container.decode(TensorData.self) { self = .tensorData(tensorData) } else { - try! container.decode(TensorData.self) throw DecodingError.dataCorrupted( DecodingError.Context( codingPath: decoder.codingPath, diff --git a/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift b/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift index 1730c98..5077f4a 100644 --- a/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift +++ b/Sources/Safetensors/Extensions/DType+MLMultiArrayDataType.swift @@ -1,10 +1,3 @@ -// -// DType+MLMultiArrayDataType.swift -// swift-safetensors -// -// Created by Tomasz Stachowiak on 1.10.2024. -// - import CoreML extension DType { diff --git a/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift b/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift index 219b93c..d69d811 100644 --- a/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift +++ b/Sources/Safetensors/Extensions/DType+MLTensorScalar.swift @@ -1,10 +1,3 @@ -// -// DType+MLMultiArrayDataType.swift -// swift-safetensors -// -// Created by Tomasz Stachowiak on 1.10.2024. -// - import CoreML diff --git a/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift b/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift index 89286aa..e171a5f 100644 --- a/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift +++ b/Sources/Safetensors/Extensions/MLMultiArray+SafetensorsEncodable.swift @@ -3,16 +3,20 @@ import Foundation extension MLMultiArray: SafetensorsEncodable { public var scalarCount: Int { - shape.reduce(1) { $0 * $1.intValue } + shape.reduce(1) { + $0 * $1.intValue + } } public var tensorShape: [Int] { - shape.map { $0.intValue } + shape.map { + $0.intValue + } } public var dtype: DType { get throws { - try .init(mlMultiArrayDataType: dataType) + try Dtype(mlMultiArrayDataType: dataType) } } @@ -32,11 +36,11 @@ extension MLMultiArray: SafetensorsEncodable { data(ofType: Int32.self) #if !((os(macOS) || targetEnvironment(macCatalyst)) && arch(x86_64)) && os case .float16: - guard #available(macOS 15.0, iOS 16.0, tvOS 16.0, watchOS 9.0, visionOS 1.0, *) else { - fallthrough + if #available(macOS 15.0, iOS 16.0, tvOS 16.0, watchOS 9.0, visionOS 1.0, *) { + data(ofType: Float16.self) + } else { + throw SafetensorsError.unsupportedDataType(dataType.rawValue.description) } - - data(ofType: Float16.self) #endif default: throw SafetensorsError.unsupportedDataType(dataType.rawValue.description) diff --git a/Sources/Safetensors/Protocols/SafetensorsEncodable.swift b/Sources/Safetensors/Protocols/SafetensorsEncodable.swift index 6784a60..3487928 100644 --- a/Sources/Safetensors/Protocols/SafetensorsEncodable.swift +++ b/Sources/Safetensors/Protocols/SafetensorsEncodable.swift @@ -14,7 +14,7 @@ public protocol SafetensorsEncodable { extension SafetensorsEncodable { var byteCount: Int { - get throws{ + get throws { try self.scalarSize * self.dtype.scalarSize } } diff --git a/Sources/Safetensors/Safetensors.swift b/Sources/Safetensors/Safetensors.swift index 60fadae..678ac62 100644 --- a/Sources/Safetensors/Safetensors.swift +++ b/Sources/Safetensors/Safetensors.swift @@ -23,22 +23,22 @@ public enum Safetensors { .sorted { $0.start < $1.start } - + guard let first = allDataOffsets.first, let last = allDataOffsets.last else { throw SafetensorsError.metadataIncompleteBuffer } - + if first.start != 0 || last.end != dataCount { throw SafetensorsError.metadataIncompleteBuffer } - + for (first, second) in zip(allDataOffsets, allDataOffsets.dropFirst()) { if first.end != second.start { throw SafetensorsError.metadataIncompleteBuffer } } } - + /// Read file at given URL and return `ParsedSafetensors` object. /// - Parameter url: file URL to read the data from /// - Returns: `ParsedSafetensors` object containing the decoded data @@ -47,26 +47,22 @@ public enum Safetensors { let data = try Data(contentsOf: url, options: .mappedIfSafe) return try decode(data) } - + /// Decode Data object to ParsedSafetensors object. /// - Parameter data: `Data` object containing the encoded data /// - Returns: `ParsedSafetensors` object containing the decoded data public static func decode(_ data: Data) throws -> ParsedSafetensors { - guard data.count >= MemoryLayout.size else { - throw SafetensorsError.invalidHeaderSize - } + let result = try HeaderDecoder.decode(data) - let result = try HeaderDecoder().decode(data) - try validate(header: result.header, dataCount: data.count - result.size) - + return ParsedSafetensors( headerSize: result.size, headerData: result.header, rawData: data ) } - + /// Save dictionary of `SafetensorsEncodable` values to file. /// - Parameters: /// - data: dictionary of `SafetensorsEncodable` values @@ -81,7 +77,7 @@ public enum Safetensors { let encodedData = try encode(data, metadata: metadata) try encodedData.write(to: url) } - + /// Encode dictionary of `SafetensorsEncodable` values to `Data` object. /// - Parameters: /// - data: dictionary of `SafetensorsEncodable` values @@ -94,19 +90,21 @@ public enum Safetensors { var headerData = ParsedSafetensors.HeaderData( minimumCapacity: data.count + (metadata == nil ? 0 : 1) ) - - let totalDataSize = try data.values.reduce(0) { try $1.byteCount + $0 } - var tensorData: Data = .init(capacity: totalDataSize) - + + let totalDataSize = try data.values.reduce(0) { + try $1.byteCount + $0 + } + var tensorData = Data(capacity: totalDataSize) + for (key, tensor) in data { headerData[key] = .tensorData(try tensor.tensorData(at: tensorData.count)) tensorData.append(try tensor.toData()) } - + if let metadata { headerData["__metadata__"] = .metadata(metadata) } - return try HeaderEncoder().encode(headerData) + tensorData + return try HeaderEncoder.encode(headerData) + tensorData } } diff --git a/Sources/Safetensors/Structs/OffsetRange.swift b/Sources/Safetensors/Structs/OffsetRange.swift index 1878863..d2eee6c 100644 --- a/Sources/Safetensors/Structs/OffsetRange.swift +++ b/Sources/Safetensors/Structs/OffsetRange.swift @@ -5,7 +5,7 @@ // Created by Tomasz Stachowiak on 1.10.2024. // -public struct OffsetRange: Codable { +public struct OffsetRange: Codable, Equatable { let start: Int let end: Int @@ -15,7 +15,7 @@ public struct OffsetRange: Codable { } public init(from decoder: any Decoder) throws { - var container = try decoder.singleValueContainer() + let container = try decoder.singleValueContainer() let array = try container.decode([Int].self) precondition(array.count == 2, "range array needs to have exactly 2 elements") @@ -28,9 +28,3 @@ public struct OffsetRange: Codable { try [start, end].encode(to: encoder) } } - -extension OffsetRange: Equatable { - public static func ==(lhs: OffsetRange, rhs: OffsetRange) -> Bool { - lhs.start == rhs.start && lhs.end == rhs.end - } -} diff --git a/Sources/Safetensors/Utils/HeaderDecoder.swift b/Sources/Safetensors/Utils/HeaderDecoder.swift index 9d1a46a..42c0ca1 100644 --- a/Sources/Safetensors/Utils/HeaderDecoder.swift +++ b/Sources/Safetensors/Utils/HeaderDecoder.swift @@ -8,31 +8,27 @@ import Foundation struct HeaderDecoder { - private var jsonDecoder: JSONDecoder - - init() { - jsonDecoder = JSONDecoder() - jsonDecoder.keyDecodingStrategy = .convertFromSnakeCase - } - - func decode(_ data: Data) throws -> (size: Int, header: ParsedSafetensors.HeaderData) { + static func decode(_ data: Data) throws -> (size: Int, header: ParsedSafetensors.HeaderData) { guard data.count >= MemoryLayout.size else { throw SafetensorsError.invalidHeaderData } - + let headerOffset = MemoryLayout.size let headerSize = data.withUnsafeBytes { $0.load(as: Int.self) } - + guard data.count >= headerOffset + headerSize else { throw SafetensorsError.invalidHeaderData } - + let headerData = data[headerOffset.. Data { + var jsonEncoder = JSONEncoder() jsonEncoder.keyEncodingStrategy = .convertToSnakeCase - } - - func encode(_ headerData: ParsedSafetensors.HeaderData) throws -> Data { let encodedHeader = try jsonEncoder.encode(headerData) return withUnsafeBytes(of: encodedHeader.count) { Data($0)