progress bar, fix split logic
christianazinn committed Jun 9, 2024
1 parent 70a6bc9 commit 1e2d9cb
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions gguf-py/gguf/gguf_writer.py
@@ -145,12 +145,8 @@ def add_shard_kv_data(self) -> None:
         total_tensors = sum(len(t) for t in self.tensors)
         assert self.fout is not None
         total_splits = len(self.fout)
+        self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits))
         for i in range(total_splits):
-            # just see whether it exists
-            try:
-                self.kv_data[i]
-            except IndexError:
-                self.kv_data.append(dict())
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
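The added extend pads self.kv_data with one empty dict per missing split index, replacing the removed try/except probe; range(len(kv_data), total_splits) is empty once the list is long enough, making the call a no-op. A minimal standalone sketch of the idiom (plain list and hypothetical values, not GGUFWriter state):

    # Pad a list of dicts out to a target length.
    kv_data: list[dict] = [{"existing": True}]
    total_splits = 3

    # The generator yields a fresh dict per missing index, so the new
    # entries are independent objects (unlike [{}] * n, which aliases).
    kv_data.extend({} for _ in range(len(kv_data), total_splits))
    assert kv_data == [{"existing": True}, {}, {}]

    # Already long enough: range(3, 3) is empty, so this is a no-op.
    kv_data.extend({} for _ in range(len(kv_data), total_splits))
    assert len(kv_data) == 3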
@@ -301,10 +297,12 @@ def add_tensor_info(
         if tensor_dtype == np.uint8:
             tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)

-        # split when over tensor limit
-        if (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors \
+        # make sure there is at least one tensor before splitting
+        if (len(self.tensors[-1]) > 0
+            # split when over tensor limit
+            and (self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors)
             # or split when over size limit
-                or self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size):
+                or (self.split_max_size != 0 and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size)):

             self.tensors.append(dict())
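One reading note on the regrouped condition: Python's 'and' binds tighter than 'or', so the non-empty guard combines with the tensor-count branch only, while the size-limit branch on the 'or' side is evaluated independently. A quick standalone check with illustrative booleans (not the writer's actual state):

    # `A and B or C` parses as `(A and B) or C`, not `A and (B or C)`.
    shard_nonempty = False   # stands in for len(self.tensors[-1]) > 0
    over_tensor_limit = True
    over_size_limit = True

    assert (shard_nonempty and over_tensor_limit or over_size_limit) is True
    assert (shard_nonempty and (over_tensor_limit or over_size_limit)) is False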

@@ -360,15 +358,25 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
             self.write_padding(fout, fout.tell())

         if self.temp_file is None:
-            for fout, tensors in zip(self.fout, self.tensors):
-                bar = None
+            bar = None
+            shard_bar = None

-                if progress:
-                    from tqdm import tqdm
+            if progress:
+                from tqdm import tqdm

-                    total_bytes = sum(ti.nbytes for ti in tensors.values())
+                total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values())

-                    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+                shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)

+            for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)):
+                if bar and len(self.fout) > 1:
+                    bar.desc = f"Writing ({i + 1}/{len(self.fout)})"
+                if shard_bar and len(self.fout) > 1:
+                    total = sum(ti.nbytes for ti in tensors.values())
+                    # bar behaves weirdly when total is 0
+                    if total > 0:
+                        shard_bar.reset(total=total)

                 # relying on the fact that Python dicts preserve insertion order (since 3.7)
                 for ti in tensors.values():
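The shard bar is created once and reused across shards: tqdm's reset(total=...) rewinds the counter to zero and installs the new total, and the total > 0 check sidesteps tqdm's odd rendering when a bar's total is 0. A small standalone sketch of that reuse pattern (synthetic shard sizes, not real tensor data):

    from tqdm import tqdm

    shard_sizes = [1024, 0, 2048]  # bytes per shard; one shard is empty

    shard_bar = tqdm(desc="Shard progress", total=sum(shard_sizes), unit="byte", unit_scale=True)
    for size in shard_sizes:
        if size > 0:  # skip reset for empty shards (total=0 renders oddly)
            shard_bar.reset(total=size)
            shard_bar.update(size)  # stand-in for writing this shard's tensors
    shard_bar.close()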
@@ -377,6 +385,8 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
                     ti.tensor.tofile(fout)
                     if bar is not None:
                         bar.update(ti.nbytes)
+                    if shard_bar is not None:
+                        shard_bar.update(ti.nbytes)
                     self.write_padding(fout, ti.nbytes)
                     ti.tensor = None
         else:
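Taken together, the two hunks drive an overall bar and a per-shard bar from the same byte counts as each tensor is written out. A self-contained simulation of the pattern (fake in-memory shards and names; nothing here comes from gguf_writer.py):

    import numpy as np
    from tqdm import tqdm

    # Fake shards mirroring the writer's list-of-dicts tensor layout.
    shards = [
        {"a": np.zeros(1000, dtype=np.float32), "b": np.zeros(500, dtype=np.float32)},
        {"c": np.zeros(2000, dtype=np.float32)},
    ]

    total_bytes = sum(t.nbytes for shard in shards for t in shard.values())
    bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
    shard_bar = tqdm(desc="Shard progress", total=total_bytes, unit="byte", unit_scale=True)

    for i, shard in enumerate(shards):
        if len(shards) > 1:
            bar.desc = f"Writing ({i + 1}/{len(shards)})"  # retitle the overall bar
            shard_total = sum(t.nbytes for t in shard.values())
            if shard_total > 0:  # avoid resetting to a 0-byte total
                shard_bar.reset(total=shard_total)
        for t in shard.values():
            # stand-in for ti.tensor.tofile(fout): advance both bars by nbytes
            bar.update(t.nbytes)
            shard_bar.update(t.nbytes)

    shard_bar.close()
    bar.close()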