From b7e9d5c8d49b2ca0571b455b8b241651e169db27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ond=C5=99ej=20=C4=8Cert=C3=ADk?=
Date: Wed, 13 Mar 2024 12:21:24 -0600
Subject: [PATCH] Refactor dtype handling to be extensible

This code behaves the same as before, but it is now structured so that
more NumPy dtypes can be added easily.
---
 gguf-py/gguf/gguf_writer.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index e49c5db6866a2..6016128b13ccf 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -196,9 +196,6 @@ def add_tensor_info(
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
-        if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
-            raise ValueError("Only F32 and F16 tensors are supported for now")
-
         encoded_name = name.encode("utf8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
@@ -207,7 +204,12 @@ def add_tensor_info(
         for i in range(n_dims):
             self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
-            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
+            if tensor_dtype == np.float32:
+                dtype = GGMLQuantizationType.F32
+            elif tensor_dtype == np.float16:
+                dtype = GGMLQuantizationType.F16
+            else:
+                raise ValueError("Only F32 and F16 tensors are supported for now")
         else:
             dtype = raw_dtype
         self.ti_data += self._pack("I", dtype)
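
With this structure in place, supporting an additional dtype is a single
extra branch in the if/elif chain. Below is a minimal sketch of how that
could look, written as a standalone helper for illustration only:
numpy_dtype_to_ggml is a hypothetical name (not part of gguf-py), and the
commented-out F64 branch assumes a GGMLQuantizationType.F64 member that
this patch does not add.

    import numpy as np

    from gguf.constants import GGMLQuantizationType

    def numpy_dtype_to_ggml(tensor_dtype) -> GGMLQuantizationType:
        # Mirrors the if/elif chain introduced by this patch.
        if tensor_dtype == np.float32:
            return GGMLQuantizationType.F32
        elif tensor_dtype == np.float16:
            return GGMLQuantizationType.F16
        # elif tensor_dtype == np.float64:      # hypothetical future branch;
        #     return GGMLQuantizationType.F64   # assumes an F64 enum member
        raise ValueError("Only F32 and F16 tensors are supported for now")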