cleanup round 1
christianazinn committed Jun 9, 2024
1 parent 49b9fbe commit 0471f67
Showing 1 changed file with 23 additions and 42 deletions.
gguf-py/gguf/gguf_writer.py
@@ -56,14 +56,6 @@ class GGUFValue:
     type: GGUFValueType


-@dataclass
-class Shard:
-    path: Path
-    tensor_count: int
-    size: int
-    tensors: deque[TensorTempData]
-
-
 class SplitArguments:
     def __init__(self, args: Namespace) -> None:
         self.split_max_tensors = args.split_max_tensors if args.split_max_tensors else 0
@@ -91,10 +83,10 @@ class SplitStyle(Enum):


 class GGUFWriter:
-    fout: list[BufferedWriter | None] | None
+    fout: list[BufferedWriter] | None
     path: os.PathLike[str] | str | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[dict[str, TensorInfo | np.ndarray[Any, Any]]]
+    tensors: list[dict[str, TensorInfo]]
     kv_data: list[dict[str, GGUFValue]]
     state: WriterState
     _simple_value_packing = {
@@ -137,7 +129,7 @@ def __init__(

     def verify_arguments(self) -> None:
         total_tensors = sum(len(ti) for ti in self.tensors)
-        total_size = sum(sum(GGUFWriter.get_tensor_size(ti) for ti in t.values()) for t in self.tensors)
+        total_size = sum(ti.nbytes for t in self.tensors for ti in t.values())

         if self.split_arguments.split_max_tensors and total_tensors < self.split_arguments.split_max_tensors:
             logger.warning("Model has fewer tensors than the split threshold, not splitting")
@@ -149,10 +141,10 @@ def verify_arguments(self) -> None:

         # no shards are created when writing vocab so make one
         if not self.tensors or len(self.tensors) == 0:
-            self.tensors.append(dict())
+            self.tensors = [dict()]

-    def format_shard_names(self) -> list[os.PathLike[str]]:
-        pathobj = Path(self.path)
+    def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]:
+        pathobj = Path(path)
         if self.split_arguments.split_style == SplitStyle.NONE:
             return [pathobj]

@@ -174,14 +166,15 @@ def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:

         if self.path is not None:
             self.fout = []
-            for fout in self.format_shard_names():
+            for fout in self.format_shard_names(self.path):
                 self.fout.append(open(fout, "wb"))
             self.state = WriterState.EMPTY

-    def print_plan(self) -> None:
+    def print_plan(self, path: os.PathLike[str] | str | None = None) -> None:
         logger.info("Writing the following files:")
-        for i in range(len(self.fout)):
-            logger.info(f"{self.fout[i].name}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(GGUFWriter.get_tensors_total_size(self.tensors[i].values()))}")
+        filenames = self.format_shard_names(path)
+        for i in range(len(filenames)):
+            logger.info(f"{filenames[i]}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in self.tensors[i].values()))}")

         if self.split_arguments.dry_run:
             logger.info("Dry run, not writing files")
@@ -204,8 +197,8 @@ def add_shard_kv_data(self) -> None:

     def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
         self.verify_arguments()
+        self.print_plan(path)
         self.open_output_file(path)
-        self.print_plan()

         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
@@ -215,13 +208,12 @@ def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:

         self.add_shard_kv_data()

-        for i in range(len(self.fout)):
-            fout = self.fout[i]
+        for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data):
             self._write_packed(fout, "<I", GGUF_MAGIC, skip_pack_prefix = True)
             self._write_packed(fout, "I", GGUF_VERSION)
-            self._write_packed(fout, "Q", len(self.tensors[i]))
-            self._write_packed(fout, "Q", len(self.kv_data[i]))
-            self.fout[i].flush()
+            self._write_packed(fout, "Q", len(tensors))
+            self._write_packed(fout, "Q", len(kv_data))
+            fout.flush()
         self.state = WriterState.HEADER

     def write_kv_data_to_file(self) -> None:
@@ -246,12 +238,12 @@ def write_ti_data_to_file(self) -> None:
             raise ValueError(f'Expected output file to contain KV data, got {self.state}')
         assert self.fout is not None

-        for i in range(len(self.fout)):
-            assert self.fout[i] is not None
+        for fout, tensors in zip(self.fout, self.tensors):
+            assert fout is not None
             ti_data = bytearray()
             offset_tensor = 0

-            for name, ti in self.tensors[i].items():
+            for name, ti in tensors.items():
                 ti_data += self._pack_val(name, GGUFValueType.STRING, add_vtype=False)
                 n_dims = len(ti.shape)
                 ti_data += self._pack("I", n_dims)
@@ -261,8 +253,8 @@ def write_ti_data_to_file(self) -> None:
                 ti_data += self._pack("Q", offset_tensor)
                 offset_tensor += GGUFWriter.ggml_pad(ti.nbytes, self.data_alignment)

-            self.fout[i].write(ti_data)
-            self.fout[i].flush()
+            fout.write(ti_data)
+            fout.flush()
         self.state = WriterState.TI_DATA

     def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
@@ -359,7 +351,7 @@ def add_tensor_info(
             and len(self.tensors[-1]) >= self.split_arguments.split_max_tensors) \
             # or split when over size limit
             or (self.split_arguments.split_style == SplitStyle.SIZE \
-            and GGUFWriter.get_tensors_total_size(self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)):
+            and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)):

             self.tensors.append(dict())

@@ -424,7 +416,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
         if progress:
             from tqdm import tqdm

-            total_bytes = GGUFWriter.get_tensors_total_size(self.tensors[i].values())
+            total_bytes = sum(ti.nbytes for ti in self.tensors[i].values())

             bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)

@@ -739,17 +731,6 @@ def _write_packed(self, fout: BufferedWriter, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
         assert fout is not None
         fout.write(self._pack(fmt, value, skip_pack_prefix))

-    @staticmethod
-    def get_tensor_size(tensor) -> int:
-        try:
-            return tensor.data_type.elements_to_bytes(np.prod(tensor.shape))
-        except AttributeError: # numpy ndarray[Any, Any]
-            return tensor.nbytes
-
-    @staticmethod
-    def get_tensors_total_size(tensors) -> int:
-        return sum(GGUFWriter.get_tensor_size(ti) for ti in tensors)
-
     @staticmethod
     def split_str_to_n_bytes(split_str: str) -> int:
         if split_str.endswith("K"):
