diff --git a/gguf-py/gguf/gguf_manager.py b/gguf-py/gguf/gguf_manager.py
index 5d6133fe6ea18..e7d2ef096cd49 100644
--- a/gguf-py/gguf/gguf_manager.py
+++ b/gguf-py/gguf/gguf_manager.py
@@ -5,6 +5,7 @@
 from typing import TYPE_CHECKING, Any, Sequence
 from argparse import Namespace
 from collections import deque
+from dataclasses import dataclass

 import numpy as np

@@ -28,7 +29,14 @@
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)}
 TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType] # (tensor name, tensor data, tensor dtype)
-Shard: TypeAlias = list[os.PathLike[str], int, int, deque[TensorTempData]] # [shard filename, shard tensor count, shard size, [tensor data]]
+
+
+@dataclass
+class Shard:
+    path: str
+    tensor_count: int
+    size: int
+    tensors: deque[TensorTempData]


 class SplitStyle(IntEnum):
     NONE = 0
@@ -73,11 +81,11 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl
         self.state = WriterState.EMPTY

         if self.split_arguments.small_first_shard:
-            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
+            self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))

     def init_shards(self) -> None:
-        self.total_tensors = sum(shard[1] for shard in self.shards)
-        total_size = sum(shard[2] for shard in self.shards)
+        self.total_tensors = sum(shard.tensor_count for shard in self.shards)
+        total_size = sum(shard.size for shard in self.shards)

         # check if we need to split
         if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
@@ -90,19 +98,20 @@ def init_shards(self) -> None:

         # no shards are created when writing vocab so make one
         if not self.shards:
-            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
+            self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))

         # format shard names
         if len(self.shards) == 1:
-            self.shards[0][0] = self.path
+            self.shards[0].path = self.path
         else:
             for i in range(len(self.shards)):
-                self.shards[i][0] = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
+                # TODO: with_name assumes self.path is a pathlib.Path - import pathlib and make this explicit
+                self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))

         # print shard info
         print("\nWriting the following files:")
-        for (path, tensor_ct, size, _) in self.shards:
-            print(f"  {path}: n_tensors = {tensor_ct}, total_size = {GGUFManager.format_n_bytes_to_str(size)}")
+        for shard in self.shards:
+            print(f"  {shard.path}: n_tensors = {shard.tensor_count}, total_size = {GGUFManager.format_n_bytes_to_str(shard.size)}")
         print()

         if self.split_arguments.dry_run:
@@ -110,10 +119,10 @@ def init_shards(self) -> None:
             exit()

         # we don't want to initialize GGUFWriters until now because they create files
-        for i, (path, _, _, tensors) in enumerate(self.shards):
-            # dont_add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
-            writer = GGUFWriter(path, self.arch, use_temp_file=self.use_temp_file,
-                                endianess=self.endianess, dont_add_architecture=not (i == 0))
+        for i, shard in enumerate(self.shards):
+            # add_architecture is True only for the first shard, for consistency - examples/gguf_split doesn't add arch to all shards
+            writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file,
+                                endianess=self.endianess, add_architecture=(i == 0))

             # only the first shard needs all the KV data
             if i == 0:
@@ -130,7 +139,7 @@ def init_shards(self) -> None:
             # add tensors, deque popleft() ensures references to eager tensors are not kept
             while True:
                 try:
-                    (name, tensor, dtype) = tensors.popleft()
+                    (name, tensor, dtype) = shard.tensors.popleft()
                     writer.add_tensor(name, tensor, raw_dtype=dtype)
                 except:
                     break
@@ -199,17 +208,17 @@ def add_tensor(
         if (len(self.shards) == self.split_arguments.small_first_shard \
                 # or split when over tensor limit
                 or (self.split_arguments.split_style == SplitStyle.TENSORS \
-                and self.shards[-1][1] >= self.split_arguments.split_max_tensors) \
+                and self.shards[-1].tensor_count >= self.split_arguments.split_max_tensors) \
                 # or split when over size limit
                 or (self.split_arguments.split_style == SplitStyle.SIZE \
-                and self.shards[-1][2] + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
+                and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):

             # we fill in the name later when we know how many shards there are
-            self.shards.append(["", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])])
+            self.shards.append(Shard("", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
         else:
-            self.shards[-1][1] += 1
-            self.shards[-1][2] += GGUFManager.get_tensor_size(tensor)
-            self.shards[-1][3].append((name, tensor, raw_dtype))
+            self.shards[-1].tensor_count += 1
+            self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
+            self.shards[-1].tensors.append((name, tensor, raw_dtype))

     def close(self) -> None:
         for writer in self.shard_writers:
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 294f4d06dbb70..31ca9eabc9468 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -57,7 +57,7 @@ class GGUFWriter:

     def __init__(
         self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
-        endianess: GGUFEndian = GGUFEndian.LITTLE, dont_add_architecture: bool = False
+        endianess: GGUFEndian = GGUFEndian.LITTLE, add_architecture: bool = True
     ):
         self.fout = open(path, "wb")
         self.arch = arch
@@ -77,7 +77,7 @@ def __init__(
         ))
         self.state = WriterState.EMPTY

-        if not dont_add_architecture:
+        if add_architecture:
             self.add_architecture()

     def write_header_to_file(self) -> None:
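
As a quick illustration of what the `Shard` dataclass buys over the old index-based list, here is a standalone sketch; it is not part of the patch. Note the old `list[os.PathLike[str], int, int, deque[TensorTempData]]` alias was not valid typing in the first place, since `list` takes a single type parameter. The tensor payload is simplified to bytes here, where the real `TensorTempData` carries a numpy array and a `GGMLQuantizationType`.

```python
# Standalone sketch, not part of the patch: how the Shard dataclass replaces
# index-based bookkeeping. The tensor tuple is simplified; the real
# TensorTempData holds a numpy array and a GGMLQuantizationType.
from collections import deque
from dataclasses import dataclass
from typing import Any


@dataclass
class Shard:
    path: str
    tensor_count: int
    size: int
    tensors: deque[Any]


shard = Shard("", 0, 0, deque())

# before: shard[1] += 1; shard[2] += size; shard[3].append(item)
shard.tensor_count += 1
shard.size += 1024  # hypothetical tensor size in bytes
shard.tensors.append(("blk.0.attn_q.weight", b"\x00" * 1024, None))

# draining with popleft() drops the reference to each tensor as soon as it
# is written, mirroring the loop in init_shards()
while shard.tensors:
    name, data, dtype = shard.tensors.popleft()
```

Named fields also make the empty-shard placeholder honest: the old code stored `None` where a `deque` was expected, while the dataclass version constructs `Shard("", 0, METADATA_ONLY_INDICATOR, deque())`, so `tensors` always supports `popleft()`.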