Skip to content

Commit

Permalink
use simplification from ggerganov#7827
Browse files (browse the repository at this point in the history)
  • Loading branch information
christianazinn committed Jun 9, 2024
1 parent 666bb09 commit 03cc9bc
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions gguf-py/gguf/gguf_writer_split.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,7 +61,7 @@ class GGUFWriterSplit(GGUFWriter):
kv_data: KVTempData
split_arguments: SplitArguments
shards: list[Shard]
shard_writers: list[GGUFWriter]
shard_writers: list[tuple[GGUFWriter, os.PathLike[str]]]

def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
Expand Down Expand Up @@ -115,17 +115,15 @@ def init_shards(self) -> None:
logger.info("Dry run, not writing files")
exit()

# we don't want to initialize GGUFWriters until now because they create files
for i, shard in enumerate(self.shards):
# add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
writer = GGUFWriter(shard.path, self.arch, use_temp_file=self.use_temp_file,
writer = GGUFWriter(None, self.arch, use_temp_file=self.use_temp_file,
endianess=self.endianess, add_architecture=(i == 0))

# only the first shard needs all the KV data
if i == 0:
for key, (value, etype) in self.kv_data.items():
writer.add_key(key)
writer.add_val(value, etype)
writer.add_key_value(key, value, etype)

# add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE
if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
Expand All @@ -141,22 +139,22 @@ def init_shards(self) -> None:
except IndexError:
break

self.shard_writers.append(writer)
self.shard_writers.append((writer, shard.path))

def write_header_to_file(self) -> None:
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
if self.state is not WriterState.EMPTY:
raise ValueError(f'Expected GGUFWriterSplit state to be EMPTY, got {self.state}')

for writer in self.shard_writers:
writer.write_header_to_file()
for (writer, path) in self.shard_writers:
writer.write_header_to_file(path)

self.state = WriterState.HEADER

def write_kv_data_to_file(self) -> None:
if self.state is not WriterState.HEADER:
raise ValueError(f'Expected GGUFWriterSplit state to be HEADER, got {self.state}')

for writer in self.shard_writers:
for (writer, _) in self.shard_writers:
writer.write_kv_data_to_file()

self.state = WriterState.KV_DATA
Expand All @@ -167,32 +165,21 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:

running_total = self.total_tensors
for i in range(len(self.shard_writers)):
writer = self.shard_writers[i]
is_metadata = writer.ti_data_count == 0
writer = self.shard_writers[i][0]
is_metadata = len(writer.tensors) == 0
if is_metadata:
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only")
else:
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)")
running_total -= writer.ti_data_count
logger.info(f"Writing to shard {i + 1}/{len(self.shards)} with {len(writer.tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
running_total -= len(writer.tensors)
writer.write_tensors_to_file(progress=(progress and not is_metadata))
del writer

self.state = WriterState.TI_DATA

# override add_key, add_val to handle kv data separately
def add_key(self, key: str) -> None:
self.recent_key = key

def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
if self.recent_key is None:
raise ValueError("No key set for value")
self.kv_data[self.recent_key] = (val, vtype)

# need to handle arrays separately
def add_array(self, key: str, val: Sequence[Any]) -> None:
if not isinstance(val, Sequence):
raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
self.kv_data[key] = (val, GGUFValueType.ARRAY)
# override add_key_value to handle kv data separately
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
self.kv_data[key] = (val, vtype)

def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
Expand All @@ -218,7 +205,7 @@ def add_tensor(
self.shards[-1].tensors.append((name, tensor, raw_dtype))

def close(self) -> None:
for writer in self.shard_writers:
for (writer, _) in self.shard_writers:
writer.close()

@staticmethod
Expand Down

0 comments on commit 03cc9bc

Please sign in to comment.