Commit ff2dd7d
try to refactor kv data (still fails)
christianazinn committed Jun 9, 2024
1 parent 97dd416

Showing 1 changed file with 31 additions and 32 deletions.

gguf-py/gguf/gguf_writer.py
@@ -36,7 +36,7 @@
 
 
 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
-NUM_SHARD_KV_DATA = 6
+NUM_SHARD_KV_DATA = 3
 METADATA_ONLY_INDICATOR = -1
 
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType | None]] # {key: (value, type)}
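
Note: NUM_SHARD_KV_DATA drops from 6 to 3 to match the three split KV pairs each shard carries, written by add_shard_kv_data further down; the old value of 6 may have counted keys and values as separate pack operations. For reference, the three pairs and the value types used below:

    Keys.Split.LLM_KV_SPLIT_NO             # this shard's index, UINT16
    Keys.Split.LLM_KV_SPLIT_COUNT          # total number of shards, UINT16
    Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT  # tensor count across all shards, INT32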
@@ -92,11 +92,11 @@ class SplitStyle(Enum):
 
 
 class GGUFWriter:
-    fout: list[BufferedWriter | None]
+    fout: list[BufferedWriter | None] | None
     path: os.PathLike[str] | str | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
-    kv_data: dict[str, GGUFValue]
+    kv_data: list[dict[str, GGUFValue]]
     state: WriterState
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
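
Note: this type change is the heart of the refactor. kv_data goes from a single dict to one dict per output shard, so every shard owns its metadata instead of the writer packing shard keys ad hoc at write time. A rough sketch of the intended shape for a three-shard write, assuming the usual GGUF split key strings (treat the literal names as illustrative):

    # hypothetical kv_data layout after add_shard_kv_data() has run
    kv_data = [
        {"general.architecture": ..., "split.no": ..., "split.count": ..., "split.tensors.count": ...},  # shard 0: model metadata + split keys
        {"split.no": ..., "split.count": ..., "split.tensors.count": ...},                               # shard 1: split keys only
        {"split.no": ..., "split.count": ..., "split.tensors.count": ...},                               # shard 2
    ]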
@@ -125,7 +125,7 @@ def __init__(
         self.use_temp_file = use_temp_file
         self.temp_file = None
         self.tensors = []
-        self.kv_data = dict()
+        self.kv_data = [dict()]
         logger.info("gguf: This GGUF file is for {0} Endian only".format(
             "Big" if self.endianess == GGUFEndian.BIG else "Little",
         ))
@@ -188,6 +188,20 @@ def print_plan(self) -> None:
             logger.info("Dry run, not writing files")
             exit()
 
+    def add_shard_kv_data(self) -> None:
+        if self.split_arguments.split_style == SplitStyle.NONE:
+            return
+
+        total_tensors = sum(len(t) for t in self.tensors)
+        for i in range(len(self.fout)):
+            try: # TODO better way to do this
+                self.kv_data[i]
+            except IndexError:
+                self.kv_data.append(dict())
+            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
+            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(len(self.fout), GGUFValueType.UINT16)
+            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
+
     def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
         self.verify_arguments()
         self.open_output_file(path)
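
Note: the try/except IndexError flagged with a TODO above is just growing self.kv_data until there is one dict per shard before stamping in the split keys. A more direct way to express the same thing, offered only as a sketch against this commit, is to pad the list before the loop:

    # grow kv_data so each output shard has its own metadata dict
    while len(self.kv_data) < len(self.fout):
        self.kv_data.append(dict())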
@@ -197,50 +211,35 @@ def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
         assert len(self.fout) == len(self.tensors)
+        assert len(self.kv_data) == 1
+
+        self.add_shard_kv_data()
 
         for i in range(len(self.fout)):
             fout = self.fout[i]
+            #print(f"writing header: GGUF_VERSION={GGUF_VERSION}, GGUF_MAGIC={GGUF_MAGIC}, n_tensors={len(self.tensors[i])}, n_kv_data={len(self.kv_data[i])}")
             self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True)
             self._write_packed(fout, "I", GGUF_VERSION)
             self._write_packed(fout, "Q", len(self.tensors[i]))
-            kv_data_len = len(self.kv_data) if i == 0 else 0
-            if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
-                kv_data_len += NUM_SHARD_KV_DATA
-            self._write_packed(fout, "Q", kv_data_len)
+            self._write_packed(fout, "Q", len(self.kv_data[i]))
             self.fout[i].flush()
         self.state = WriterState.HEADER
 
-    def add_shard_kv_data(self, kv_data: bytearray, shard_no: int) -> bytearray:
-        total_tensors = sum(len(t) for t in self.tensors)
-        kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_NO, GGUFValueType.STRING, add_vtype=False)
-        kv_data += self._pack_val(shard_no, GGUFValueType.UINT16, add_vtype=True)
-        kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_COUNT, GGUFValueType.STRING, add_vtype=False)
-        kv_data += self._pack_val(len(self.fout), GGUFValueType.UINT16, add_vtype=True)
-        kv_data += self._pack_val(Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT, GGUFValueType.STRING, add_vtype=False)
-        kv_data += self._pack_val(total_tensors, GGUFValueType.INT32, add_vtype=True)
-        return kv_data
-
     def write_kv_data_to_file(self) -> None:
         if self.state is not WriterState.HEADER:
             raise ValueError(f'Expected output file to contain the header, got {self.state}')
         assert self.fout is not None
 
-        kv_data = bytearray()
-
-        for key, val in self.kv_data.items():
-            kv_data += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
-            kv_data += self._pack_val(val.value, val.type, add_vtype=True)
+        for fout, kv_data in zip(self.fout, self.kv_data):
+            kv_bytes = bytearray()
 
-        if len(self.fout) > 1:
-            kv_data = self.add_shard_kv_data(kv_data, 0)
+            for key, val in kv_data.items():
+                kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
+                kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
 
-        # only the first shard needs kv data
-        self.fout[0].write(kv_data)
-        self.fout[0].flush()
+            fout.write(kv_bytes)
 
-        for i in range(1, len(self.fout)):
-            self.fout[i].write(self.add_shard_kv_data(bytearray(), i))
-            self.fout[i].flush()
+        self.flush()
         self.state = WriterState.KV_DATA
 
     def write_ti_data_to_file(self) -> None:
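
Note: the byte format is unchanged by the refactor; only the grouping moved. Each pair is packed as a string key without a type tag, then the value with its type tag, the same add_vtype pattern visible in the deleted add_shard_kv_data above. Assuming a GGUFWriter instance w, one split pair serializes roughly like this:

    # sketch: how a single KV pair lands in a shard's byte stream
    kv_bytes = bytearray()
    kv_bytes += w._pack_val("split.count", GGUFValueType.STRING, add_vtype=False)  # key, no type tag
    kv_bytes += w._pack_val(2, GGUFValueType.UINT16, add_vtype=True)               # type tag + value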
@@ -271,7 +270,7 @@ def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
         if key in self.kv_data:
             raise ValueError(f'Duplicated key name {key!r}')
 
-        self.kv_data[key] = GGUFValue(value=val, type=vtype)
+        self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
 
     def add_uint8(self, key: str, val: int) -> None:
         self.add_key_value(key, val, GGUFValueType.UINT8)
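
Note: the guard a few lines up still reads "if key in self.kv_data", but self.kv_data is now a list of dicts, so the membership test compares the string key against whole dicts and can never be true; duplicate keys silently pass, which may be one reason this commit still fails. A guard matching the new layout could look like:

    # kv_data is now list[dict[str, GGUFValue]]; check every shard's dict
    if any(key in kv_data for kv_data in self.kv_data):
        raise ValueError(f'Duplicated key name {key!r}')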