From 83e4a3f5cce4c32feedfb0687743cc06556443b2 Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 6 Jun 2024 09:00:59 -0400 Subject: [PATCH] make pathlib explicit --- gguf-py/gguf/gguf_manager.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/gguf-py/gguf/gguf_manager.py b/gguf-py/gguf/gguf_manager.py index f4411e752cb7b..f74b24117d149 100644 --- a/gguf-py/gguf/gguf_manager.py +++ b/gguf-py/gguf/gguf_manager.py @@ -6,6 +6,7 @@ from argparse import Namespace from collections import deque from dataclasses import dataclass +from pathlib import Path import numpy as np @@ -30,7 +31,7 @@ @dataclass class Shard: - path: str + path: Path tensor_count: int size: int tensors: deque[TensorTempData] @@ -56,7 +57,6 @@ def __init__(self, args: Namespace) -> None: class GGUFManager(GGUFWriter): kv_data: KVTempData - tensors: list[TensorTempData] split_arguments: SplitArguments shards: list[Shard] shard_writers: list[GGUFWriter] @@ -66,7 +66,7 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl ) -> None: # we intentionally don't call superclass constructor self.arch = arch - self.path = path + self.path = Path(path) self.endianess = endianess self.kv_data = {} self.shards = [] @@ -78,7 +78,7 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl self.state = WriterState.EMPTY if self.split_arguments.small_first_shard: - self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque())) + self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque())) def init_shards(self) -> None: self.total_tensors = sum(shard.tensor_count for shard in self.shards) @@ -95,14 +95,13 @@ def init_shards(self) -> None: # no shards are created when writing vocab so make one if not self.shards: - self.shards.append(Shard("", 0, METADATA_ONLY_INDICATOR, deque())) + self.shards.append(Shard(Path(), 0, METADATA_ONLY_INDICATOR, deque())) # format shard names if len(self.shards) == 1: self.shards[0].path = self.path else: for i in range(len(self.shards)): - # TODO with_name is not explicit - import pathlib self.shards[i].path = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards))) # print shard info @@ -211,7 +210,7 @@ def add_tensor( and self.shards[-1].size + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)): # we fill in the name later when we know how many shards there are - self.shards.append(Shard("", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)]))) + self.shards.append(Shard(Path(), 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)]))) else: self.shards[-1].tensor_count += 1 self.shards[-1].size += GGUFManager.get_tensor_size(tensor)