-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DType refactoring #1
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some comments, LMK what you think
} else if let tensorData = try? container.decode(TensorData.self) { | ||
self = .tensorData(tensorData) | ||
} else { | ||
try! container.decode(TensorData.self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is not needed I guess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ups ;)
get throws{ | ||
try self.scalarSize * self.scalarCount | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: space missing between throws
and {
@@ -231,34 +51,26 @@ public enum Safetensors { | |||
let data = try Data(contentsOf: url, options: .mappedIfSafe) | |||
return try decode(data) | |||
} | |||
|
|||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please remove added whitespace
} | ||
|
||
public init(from decoder: any Decoder) throws { | ||
var container = try decoder.singleValueContainer() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this can be let
|
||
public var dtype: DType { | ||
get throws { | ||
try .init(mlMultiArrayDataType: dataType) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: would prefer try DType(mlMultiArrayDataType: dataType)
} | ||
try validate(header: headerData, dataCount: data.count - headerOffset) | ||
|
||
let result = try HeaderDecoder().decode(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is no need to instantiate a struct just to decode some data, this can be just a free function I think
let headerSize = withUnsafeBytes(of: UInt64(header.count)) { Data($0) } | ||
return headerSize + header + Data(tensorData) | ||
|
||
return try HeaderEncoder().encode(headerData) + tensorData |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is no need to instantiate a struct just to encode some data, this can be just a free function I think
public enum DType: String, Codable { | ||
case float64 = "F64" | ||
case float32 = "F32" | ||
case float16 = "F16" | ||
case int32 = "I32" | ||
case uint32 = "U32" | ||
case int16 = "I16" | ||
case uint16 = "U16" | ||
case int8 = "I8" | ||
case uint8 = "U8" | ||
case bool = "BOOL" | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the original lib implements more dtypes, see here https://github.com/huggingface/safetensors/blob/5db3b92c76ba293a0715b916c16b113c0b3551e9/safetensors/src/tensor.rs#L656
the problem I see with this approach is that even more types can be added in future. Keeping as a string allows us to parse it without a risk that a future addition will break it. We are validating it anyway when materializing to MLTensor
of MLMultiArray
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we have to check support for the new string types anyway so at least we have clearly defined place where to add it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let consider a scenario where a new type was added but we don't support it yet. The example file contains bunch of tensors, some of them we support, some of them we don't (they are not specified in DType
enum). I might be wrong but using your approach user won't be able even to open the file, using the old approach user will be able to open the file and materialize the tensors we have the support for
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you see value in being able to open the file but not able to materialise it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a file can contain tensors data of multiple type, so yes, I see the value of being able to open the file and materialize the tensors we have the support for
// | ||
// Created by Tomasz Stachowiak on 1.10.2024. | ||
// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please unify the headers, you could remove it if you're ok with it
case .float16: | ||
return MemoryLayout<UInt16>.size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why float16
has size of UInt16
?
enum DType
and replaced string valued dtype fields with the enumtypealias HeaderData = [String: HeaderElement]
tensorData
is collected inSafetensors.encode()
for performancescalarSize
requirement from SafetensorsEncodable protocol