Skip to content

Commit

Permalink
make pathlib explicit
Browse files (browse the repository at this point in the history)
  • Loading branch information
christianazinn committed Jun 6, 2024
1 parent 2037eab commit 83e4a3f
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions gguf-py/gguf/gguf_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from argparse import Namespace
from collections import deque
from dataclasses import dataclass
from pathlib import Path

import numpy as np

Expand All @@ -30,7 +31,7 @@

@dataclass
class Shard:
path: str
path: Path
tensor_count: int
size: int
tensors: deque[TensorTempData]
Expand All @@ -56,7 +57,6 @@ def __init__(self, args: Namespace) -> None:

class GGUFManager(GGUFWriter):
kv_data: KVTempData
tensors: list[TensorTempData]
split_arguments: SplitArguments
shards: list[Shard]
shard_writers: list[GGUFWriter]
Expand All @@ -66,7 +66,7 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl
) -> None:
# we intentionally don't call superclass constructor
self.arch = arch
self.path = path
self.path = Path(path)
self.endianess = endianess
self.kv_data = {}
self.shards = []
Expand All @@ -78,7 +78,7 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl
self.state = WriterState.EMPTY

if self.split_arguments.small_first_shard:
self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))
self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque()))

def init_shards(self) -> None:
self.total_tensors = sum(shard.tensor_count for shard in self.shards)
Expand All @@ -95,14 +95,13 @@ def init_shards(self) -> None:

# no shards are created when writing vocab so make one
if not self.shards:
self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque()))
self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque()))

# format shard names
if len(self.shards) == 1:
self.shards[0].path = self.path
else:
for i in range(len(self.shards)):
# TODO with_name is not explicit - import pathlib
self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))

# print shard info
Expand Down Expand Up @@ -211,7 +210,7 @@ def add_tensor(
and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):

# we fill in the name later when we know how many shards there are
self.shards.append(Shard("", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
self.shards.append(Shard(Path(), 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])))
else:
self.shards[-1].tensor_count += 1
self.shards[-1].size += GGUFManager.get_tensor_size(tensor)
Expand Down

0 comments on commit 83e4a3f

Please sign in to comment.