diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index e49c5db6866a2..6016128b13ccf 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -196,9 +196,6 @@ def add_tensor_info(
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
-        if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
-            raise ValueError("Only F32 and F16 tensors are supported for now")
-
         encoded_name = name.encode("utf8")
         self.ti_data += self._pack("Q", len(encoded_name))
         self.ti_data += encoded_name
@@ -207,7 +204,12 @@ def add_tensor_info(
         for i in range(n_dims):
             self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
-            dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
+            if tensor_dtype == np.float32:
+                dtype = GGMLQuantizationType.F32
+            elif tensor_dtype == np.float16:
+                dtype = GGMLQuantizationType.F16
+            else:
+                raise ValueError("Only F32 and F16 tensors are supported for now")
         else:
             dtype = raw_dtype
         self.ti_data += self._pack("I", dtype)
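
For context, here is a minimal standalone sketch of the dtype selection the second hunk introduces: an explicit raw_dtype always takes precedence, and otherwise only float32 and float16 numpy tensors are accepted. The helper name select_dtype is hypothetical and for illustration only; the package-level import of GGMLQuantizationType is an assumption about how gguf-py re-exports its constants.

import numpy as np

from gguf import GGMLQuantizationType  # assumption: the enum is re-exported at package level


def select_dtype(tensor_dtype, raw_dtype=None):
    # Hypothetical helper mirroring the logic added to add_tensor_info:
    # a caller-supplied raw_dtype wins over the numpy dtype.
    if raw_dtype is not None:
        return raw_dtype
    if tensor_dtype == np.float32:
        return GGMLQuantizationType.F32
    if tensor_dtype == np.float16:
        return GGMLQuantizationType.F16
    raise ValueError("Only F32 and F16 tensors are supported for now")


# Example: a pre-quantized tensor passes its raw_dtype through unchanged,
# while plain F16 data is mapped from the numpy dtype.
assert select_dtype(np.uint8, raw_dtype=GGMLQuantizationType.Q4_0) == GGMLQuantizationType.Q4_0
assert select_dtype(np.float16) == GGMLQuantizationType.F16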