DType refactoring #1

Open
wants to merge 5 commits into
base: main

siberie
Contributor

@siberie siberie commented Oct 1, 2024

  • Safetensors.swift has been decluttered by moving definitions to separate files
  • created enum DType and replaced string valued dtype fields with the enum
  • added MLMultiArrayDataType conversion initialiser and property to DType
  • added MLTensorScalar conversion property to DType
  • extended SafetensorsEncodable with useful properties
  • added Equatable to OffsetRange to fix a unit test
  • added typealias HeaderData = [String: HeaderElement]
  • created HeaderEncoder + HeaderDecoder to encapsulate serialisation/deserialisation of HeaderData
  • renamed files containing extensions to match extension naming pattern
  • changed the way tensorData is collected in Safetensors.encode() for performance
  • refactored the remaining code to use the new constructs and, where possible, for clarity
  • removed scalarSize requirement from SafetensorsEncodable protocol

Owner

@jkrukowski jkrukowski left a comment

Added some comments, LMK what you think

} else if let tensorData = try? container.decode(TensorData.self) {
self = .tensorData(tensorData)
} else {
try! container.decode(TensorData.self)
Owner

this is not needed I guess

Contributor Author

oops ;)
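For context on why the `try!` line is worth removing: the final `else` branch repeats a decode that has just failed in the `else if`, so `try!` would crash instead of reporting an error. A minimal sketch of a throwing fallback follows. The `TensorData` fields, the `metadata` case, and the error message are assumptions made for illustration; only the `tensorData` branch appears in the diff.

```swift
import Foundation

// Hypothetical stand-in for the TensorData type in the PR.
struct TensorData: Codable, Equatable {
    let dtype: String
    let shape: [Int]
}

// Hypothetical two-case header element: free-form metadata or a tensor entry.
enum HeaderElement: Decodable, Equatable {
    case metadata([String: String])
    case tensorData(TensorData)

    init(from decoder: any Decoder) throws {
        // The container is only read from, so a constant is enough.
        let container = try decoder.singleValueContainer()
        if let metadata = try? container.decode([String: String].self) {
            self = .metadata(metadata)
        } else if let tensorData = try? container.decode(TensorData.self) {
            self = .tensorData(tensorData)
        } else {
            // Report the failure to the caller instead of crashing with try!.
            throw DecodingError.dataCorruptedError(
                in: container,
                debugDescription: "Expected metadata or tensor data"
            )
        }
    }
}

// Usage: a tensor entry decodes into the .tensorData case.
let json = Data(#"{"dtype": "F32", "shape": [2, 3]}"#.utf8)
let element = try JSONDecoder().decode(HeaderElement.self, from: json)
```

With this shape, malformed input surfaces as a `DecodingError` that the caller can handle.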

Comment on lines 20 to 22
get throws{
try self.scalarSize * self.scalarCount
}
Owner

nit: space missing between throws and {

Sources/Safetensors/Safetensors.swift (outdated)
@@ -231,34 +51,26 @@ public enum Safetensors {
let data = try Data(contentsOf: url, options: .mappedIfSafe)
return try decode(data)
}

Owner

please remove added whitespace

}

public init(from decoder: any Decoder) throws {
var container = try decoder.singleValueContainer()
Owner

this can be let


public var dtype: DType {
get throws {
try .init(mlMultiArrayDataType: dataType)
Owner

nit: would prefer try DType(mlMultiArrayDataType: dataType)

}
try validate(header: headerData, dataCount: data.count - headerOffset)

let result = try HeaderDecoder().decode(data)
Owner

I think there is no need to instantiate a struct just to decode some data; this can be a free function
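A rough sketch of what the free-function variant could look like, assuming the usual safetensors layout (an 8-byte little-endian header length followed by a JSON header). The names `decodeHeader`, `TensorInfo`, and `HeaderError` are hypothetical and not taken from the PR, and the sketch ignores the optional `__metadata__` entry for brevity.

```swift
import Foundation

// Hypothetical header entry; the real HeaderElement in the PR may differ.
struct TensorInfo: Codable, Equatable {
    let dtype: String
    let shape: [Int]
    let dataOffsets: [Int]

    enum CodingKeys: String, CodingKey {
        case dtype, shape
        case dataOffsets = "data_offsets"
    }
}

enum HeaderError: Error { case truncated }

// Free function: no decoder struct has to be created at the call site.
func decodeHeader(_ data: Data) throws -> [String: TensorInfo] {
    let prefixSize = MemoryLayout<UInt64>.size
    guard data.count >= prefixSize else { throw HeaderError.truncated }
    // Assemble the little-endian UInt64 length byte by byte.
    var length: UInt64 = 0
    for (index, byte) in data.prefix(prefixSize).enumerated() {
        length |= UInt64(byte) << UInt64(8 * index)
    }
    guard UInt64(data.count - prefixSize) >= length else { throw HeaderError.truncated }
    // Slice out exactly the JSON header, respecting the Data's own start index.
    let start = data.startIndex + prefixSize
    let json = data[start ..< start + Int(length)]
    return try JSONDecoder().decode([String: TensorInfo].self, from: json)
}
```

A call site would then read `let result = try decodeHeader(data)`.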

let headerSize = withUnsafeBytes(of: UInt64(header.count)) { Data($0) }
return headerSize + header + Data(tensorData)

return try HeaderEncoder().encode(headerData) + tensorData
Owner

I think there is no need to instantiate a struct just to encode some data; this can be a free function
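The encoding side could be sketched the same way. The function name `encodeFile` is hypothetical, and `header` is assumed to be JSON that has already been serialized. Unlike the removed line in the diff, this sketch pins the byte order explicitly, on the assumption that the format stores the length little-endian.

```swift
import Foundation

// Free function mirroring the layout in the diff:
// 8-byte header length, then the JSON header, then the raw tensor bytes.
func encodeFile(header: Data, tensorData: Data) -> Data {
    // Convert to little-endian first so the result does not depend on the host CPU.
    let headerSize = withUnsafeBytes(of: UInt64(header.count).littleEndian) { Data($0) }
    return headerSize + header + tensorData
}

// Usage with an empty JSON object as the header and two payload bytes.
let encoded = encodeFile(header: Data("{}".utf8), tensorData: Data([1, 2]))
```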

Comment on lines +9 to +20
public enum DType: String, Codable {
case float64 = "F64"
case float32 = "F32"
case float16 = "F16"
case int32 = "I32"
case uint32 = "U32"
case int16 = "I16"
case uint16 = "U16"
case int8 = "I8"
case uint8 = "U8"
case bool = "BOOL"
}
Owner

the original lib implements more dtypes, see here https://github.com/huggingface/safetensors/blob/5db3b92c76ba293a0715b916c16b113c0b3551e9/safetensors/src/tensor.rs#L656

the problem I see with this approach is that even more types can be added in the future. Keeping the dtype as a string allows us to parse it without the risk that a future addition will break it. We validate it anyway when materializing to MLTensor or MLMultiArray

Contributor Author

we have to check support for the new string types anyway, so at least we have a clearly defined place to add them

Owner

let's consider a scenario where a new type was added but we don't support it yet. The example file contains a bunch of tensors; some of them we support, some we don't (they are not specified in the DType enum). I might be wrong, but with your approach the user won't even be able to open the file, whereas with the old approach the user can open the file and materialize the tensors we support

Contributor Author

do you see value in being able to open the file but not able to materialise it?

Owner

a file can contain tensor data of multiple types, so yes, I see the value of being able to open the file and materialize the tensors we support
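To make the scenario in this thread concrete, here is a small sketch in which the header keeps the raw dtype string and the enum is consulted only on demand. Neither participant proposed this exact shape; `TensorEntry` and `knownDType` are hypothetical names, and `BF16` stands in for a dtype that the reference implementation has but the enum above lacks.

```swift
import Foundation

// Subset of the enum from the PR, for illustration only.
enum DType: String {
    case float32 = "F32"
    case int32 = "I32"
}

// Hypothetical header entry that keeps the raw dtype string.
struct TensorEntry: Decodable {
    let dtype: String
    let shape: [Int]

    // nil when this library does not support the stored dtype (yet).
    var knownDType: DType? { DType(rawValue: dtype) }
}

let json = Data(#"""
{
  "weights": {"dtype": "F32", "shape": [2, 2]},
  "scales": {"dtype": "BF16", "shape": [2]}
}
"""#.utf8)

// Decoding succeeds even though "BF16" is not a DType case.
let header = try JSONDecoder().decode([String: TensorEntry].self, from: json)
let supported = header.filter { $0.value.knownDType != nil }.keys.sorted()
print(supported) // prints ["weights"]
```

If `dtype` were declared as `DType` with the synthesized `Decodable` conformance, the same input would fail to decode as a whole, which is the concern raised above.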

Comment on lines 1 to 3
//
// Created by Tomasz Stachowiak on 1.10.2024.
//
Owner

please unify the headers; you could remove them if you're ok with it

Comment on lines +29 to +30
case .float16:
return MemoryLayout<UInt16>.size
Owner

why does float16 have the size of UInt16?
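The thread does not record an answer. One plausible reading, offered here only as an assumption, is that `UInt16` is a same-width stand-in: a half-precision float occupies two bytes, and Swift's `Float16` type is not available on every Apple platform. A minimal sketch of such a size table, using a reduced enum and a hypothetical `scalarSize` property:

```swift
// Subset of the enum from the PR, for illustration only.
enum DType: String {
    case float32 = "F32"
    case float16 = "F16"
    case uint16 = "U16"
}

extension DType {
    // Size of one element in bytes.
    var scalarSize: Int {
        switch self {
        case .float32:
            return MemoryLayout<Float>.size
        case .float16, .uint16:
            // A 16-bit float has the same width as UInt16, so UInt16 serves
            // as a stand-in that compiles where Float16 is unavailable.
            return MemoryLayout<UInt16>.size
        }
    }
}
```

A short comment like the one in the `float16` branch would probably settle the question at the point where it arises.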

2 participants