Skip to content

Commit

Permalink
Refactor dtype handling to be extensible
Browse files Browse the repository at this point in the history
This code is equivalent to before, but now it is prepared to easily add
more NumPy dtypes.
  • Loading branch information
certik committed Mar 13, 2024
1 parent 4636283 commit b7e9d5c
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,6 @@ def add_tensor_info(
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected output file to be empty, got {self.state}')

if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
raise ValueError("Only F32 and F16 tensors are supported for now")

encoded_name = name.encode("utf8")
self.ti_data += self._pack("Q", len(encoded_name))
self.ti_data += encoded_name
Expand All @@ -207,7 +204,12 @@ def add_tensor_info(
for i in range(n_dims):
self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
if raw_dtype is None:
dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
if tensor_dtype == np.float32:
dtype = GGMLQuantizationType.F32
elif tensor_dtype == np.float16:
dtype = GGMLQuantizationType.F16
else:
raise ValueError("Only F32 and F16 tensors are supported for now")
else:
dtype = raw_dtype
self.ti_data += self._pack("I", dtype)
Expand Down

0 comments on commit b7e9d5c

Please sign in to comment.