diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 39cdb227626ec..a102cd123c200 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -145,12 +145,8 @@ def add_shard_kv_data(self) -> None:
         total_tensors = sum(len(t) for t in self.tensors)
         assert self.fout is not None
         total_splits = len(self.fout)
+        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
         for i in range(total_splits):
-            # just see whether it exists
-            try:
-                self.kv_data[i]
-            except IndexError:
-                self.kv_data.append(dict())
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
@@ -301,10 +297,12 @@ def add_tensor_info(
         if tensor_dtype == np.uint8:
             tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)

-        # split when over tensor limit
-        if (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors \
+        # make sure there is at least one tensor before splitting
+        if (len(self.tensors[-1]) > 0
+            # split when over tensor limit
+            and (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors)
             # or split when over size limit
-            or self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size):
+                or (self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size)):
             self.tensors.append(dict())
@@ -360,15 +358,25 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
             self.write_padding(fout, fout.tell())

         if self.temp_file is None:
-            for fout, tensors in zip(self.fout, self.tensors):
-                bar = None
+            bar = None
+            shard_bar = None

-                if progress:
-                    from tqdm import tqdm
+            if progress:
+                from tqdm import tqdm

-                    total_bytes = sum(ti.nbytes for ti in tensors.values())
+                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())

-                    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+                shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)
+
+            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
+                if bar and len(self.fout) > 1:
+                    bar.desc = f"Writing ({i + 1}/{len(self.fout)})"
+                if shard_bar and len(self.fout) > 1:
+                    total = sum(ti.nbytes for ti in tensors.values())
+                    # bar behaves weirdly when total is 0
+                    if total > 0:
+                        shard_bar.reset(total=total)

                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
                 for ti in tensors.values():
@@ -377,6 +385,8 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
                     ti.tensor.tofile(fout)
                     if bar is not None:
                         bar.update(ti.nbytes)
+                    if shard_bar is not None:
+                        shard_bar.update(ti.nbytes)
                     self.write_padding(fout, ti.nbytes)
                     ti.tensor = None
         else:
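
A minimal standalone sketch (not part of the patch) of the two-bar pattern the reworked write_tensors_to_file() loop relies on: one overall tqdm bar spanning all shards, plus a per-shard bar that is reset to the next shard's byte total before that shard is written. The shard sizes and chunk size below are made-up placeholders standing in for the real per-tensor nbytes sums.

from tqdm import tqdm

# hypothetical shard sizes in bytes, standing in for sum(ti.nbytes for ti in tensors.values())
shard_sizes = [4096, 0, 8192]

total_bytes = sum(shard_sizes)
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)

for i, size in enumerate(shard_sizes):
    if len(shard_sizes) > 1:
        bar.desc = f"Writing ({i + 1}/{len(shard_sizes)})"
        # tqdm behaves oddly when total is 0, so only reset for non-empty shards
        if size > 0:
            shard_bar.reset(total=size)
    # pretend to write the shard in fixed-size chunks, updating both bars
    written = 0
    while written < size:
        chunk = min(1024, size - written)
        written += chunk
        bar.update(chunk)
        shard_bar.update(chunk)

bar.close()
shard_bar.close()

Keeping a single shard_bar and calling reset(total=...) per shard avoids creating and tearing down a tqdm instance for every split, while the overall bar keeps accumulating across all shards.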