From 3ff27efa89a42afebf51cbfbc0964f81b479babd Mon Sep 17 00:00:00 2001 From: Christian Zhou-Zheng Date: Thu, 23 May 2024 18:50:21 -0400 Subject: [PATCH] Fix eager tensor memory leak and remove convert.py changes Removed a memory leak caused by unexpected reference retention to eager tensors. Also removed GGUFManager functionality in convert.py in favor of specializing for convert-hf-to-gguf.py. --- convert-hf-to-gguf.py | 2 +- convert.py | 70 +++++++++--------- gguf-py/gguf/gguf_manager.py | 136 ++++++++++++++--------------------- gguf-py/gguf/gguf_writer.py | 1 + 4 files changed, 88 insertions(+), 121 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 7c8d7a8ac75ea..24da4ebdd941e 100644 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2570,7 +2570,7 @@ def main() -> None: if args.split_max_tensors and args.split_max_size: raise ValueError("Can't specify both --split-max-tensors and --split-max-size") - split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments() + split_arguments = gguf.SplitArguments(args=args) if args.split else gguf.SplitArguments() ftype_map = { "f32": gguf.LlamaFileType.ALL_F32, diff --git a/convert.py b/convert.py index 26c0641250b0c..da1247957780c 100644 --- a/convert.py +++ b/convert.py @@ -24,17 +24,14 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable -# TEMPORARY IMPORT - TODO REMOVE -import importlib -gguf = importlib.import_module("gguf-py.gguf") +from typing import TYPE_CHECKING, Any, Callable, ClassVar, IO, Iterable, Literal, Protocol, TypeVar, runtime_checkable, Optional import numpy as np from sentencepiece import SentencePieceProcessor if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) -# import gguf +import gguf if TYPE_CHECKING: from typing_extensions import Self, TypeAlias @@ -1103,8 +1100,8 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) class OutputFile: - def __init__(self, fname_out: Path, split_arguments: gguf.SplitArguments, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE): - self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], split_arguments, endianess=endianess) + def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE): + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) def add_meta_model(self, params: Params, metadata: Metadata) -> None: # Metadata About The Model And Its Provenence @@ -1204,15 +1201,21 @@ def add_meta_vocab(self, vocab: Vocab) -> None: def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: svocab.add_to_gguf(self.gguf) + def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: + n_elements = int(np.prod(tensor.shape)) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + data_nbytes = tensor.data_type.elements_to_bytes(n_elements) + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) + def write_meta(self) -> None: - self.gguf.write_to_file(meta_only=True) + self.gguf.write_header_to_file() + self.gguf.write_kv_data_to_file() - def write_tensors(self, ftype: GGMLFileType, concurrency: int) -> None: - self.gguf.write_to_file(ftype=ftype, 
concurrency=concurrency, write_tensor_data=OutputFile.write_tensor_data) + def write_tensor_info(self) -> None: + self.gguf.write_ti_data_to_file() - # really awkward with how this is managed with gguf_manager.py: maybe refactor at some point? - @staticmethod - def write_tensor_data(ftype: GGMLFileType, model: LazyModel, concurrency: int, writer: gguf.GGUFWriter) -> None: + def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) if ftype == GGMLFileType.MostlyQ8_0: ndarrays = bounded_parallel_map( @@ -1230,7 +1233,7 @@ def write_tensor_data(ftype: GGMLFileType, model: LazyModel, concurrency: int, w logger.info( f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" ) - writer.write_tensor_data(ndarray) + self.gguf.write_tensor_data(ndarray) def close(self) -> None: self.gguf.close() @@ -1242,7 +1245,7 @@ def write_vocab_only( ) -> None: check_vocab_size(params, vocab, pad_vocab=pad_vocab) - of = OutputFile(fname_out, gguf.SplitArguments(), endianess=endianess) + of = OutputFile(fname_out, endianess=endianess) # meta data of.add_meta_model(params, metadata) @@ -1270,11 +1273,13 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: @staticmethod def write_all( fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab, - split_arguments: gguf.SplitArguments, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, - pad_vocab: bool = False, metadata: Metadata = None, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + pad_vocab: bool = False, + metadata: Metadata = None, ) -> None: check_vocab_size(params, vocab, pad_vocab=pad_vocab) - of = OutputFile(fname_out, split_arguments, endianess=endianess) + + of = OutputFile(fname_out, endianess=endianess) # meta data of.add_meta_model(params, metadata) @@ -1287,9 +1292,13 @@ def write_all( # tensor info for name, lazy_tensor in model.items(): - of.gguf.add_tensor_info(name, lazy_tensor) + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() - of.write_tensors(ftype, concurrency) + # tensor data + of.write_tensor_data(ftype, model, concurrency) of.close() @@ -1364,7 +1373,7 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]) del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"] else: - raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.model_classweight") + raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight") tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts) # HF models permut or pack some of the tensors, so we need to undo that @@ -1584,11 +1593,6 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") - parser.add_argument("--split", action="store_true", 
help="split the converted model into multiple files") - parser.add_argument("--split-max-tensors", type=int, help="max tensors in each split") - parser.add_argument("--split-max-size", type=str, help="max size per split N(M|G)") - parser.add_argument("--dry-run", action="store_true", help="only print out a split plan and exit, without writing any new files") - parser.add_argument("--large-first-shard", action="store_true", help="include tensors in the first shard when splitting (default: metadata only)") parser.add_argument("--verbose", action="store_true", help="increase output verbosity") parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file") parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name") @@ -1622,14 +1626,6 @@ def main(args_in: list[str] | None = None) -> None: do_dump_model(model_plus) return - if args.split and not (args.split_max_tensors or args.split_max_size): - raise ValueError("Need to specify one of --split-max-tensors or --split-max-size when splitting") - - if args.split_max_tensors and args.split_max_size: - raise ValueError("Can't specify both --split-max-tensors and --split-max-size") - - split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments() - if not args.vocab_only: model_plus = load_some_model(args.model) else: @@ -1707,13 +1703,11 @@ def main(args_in: list[str] | None = None) -> None: outfile = args.outfile or default_outfile(model_plus.paths, ftype, params, model_params_count, metadata) params.ftype = ftype - logger.info(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, split_arguments, + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) - if not args.dry_run: - logger.info(f"Wrote {outfile}") + logger.info(f"Wrote {outfile}") if __name__ == '__main__': diff --git a/gguf-py/gguf/gguf_manager.py b/gguf-py/gguf/gguf_manager.py index f36b0173eafae..4a51b717e23e6 100644 --- a/gguf-py/gguf/gguf_manager.py +++ b/gguf-py/gguf/gguf_manager.py @@ -10,6 +10,7 @@ from string import ascii_letters, digits from argparse import Namespace from math import ceil +from collections import deque import numpy as np @@ -34,7 +35,7 @@ LLM_KV_SPLIT_COUNT = "split.count" LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count" -SplitTensorsPerFile: TypeAlias = list[tuple[os.PathLike[str], list[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)] +SplitTensorsPerFile: TypeAlias = deque[tuple[os.PathLike[str], deque[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)] KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)} TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any]] # (tensor name, tensor data), aka LazyModel @@ -53,23 +54,23 @@ class SplitArguments: split_max_size: int split_style: SplitStyle - def __init__(self) -> None: - self.split = False - self.dry_run = False - self.small_first_shard = False - self.split_max_tensors = 0 - self.split_max_size = 0 - self.split_style = SplitStyle.NONE - - def __init__(self, args: Namespace) -> None: - self.split = args.split - self.split_max_tensors = args.split_max_tensors - self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None - self.dry_run 
= args.dry_run - self.small_first_shard = not args.large_first_shard - self.split_style = SplitStyle.NONE if not self.split \ - else SplitStyle.TENSORS if self.split_max_tensors \ - else SplitStyle.SIZE + def __init__(self, args: Namespace = None) -> None: + if args is None: + self.split = False + self.dry_run = False + self.small_first_shard = False + self.split_max_tensors = 0 + self.split_max_size = 0 + self.split_style = SplitStyle.NONE + else: + self.split = args.split + self.split_max_tensors = args.split_max_tensors + self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None + self.dry_run = args.dry_run + self.small_first_shard = not args.large_first_shard + self.split_style = SplitStyle.NONE if not self.split \ + else SplitStyle.TENSORS if self.split_max_tensors \ + else SplitStyle.SIZE class SplitStrategy: @@ -78,7 +79,7 @@ class SplitStrategy: def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str, split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE, ): - self.data = [] + self.data = deque() if split_arguments.split_style == SplitStyle.NONE: self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess))) @@ -96,7 +97,7 @@ def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arc self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess))) elif split_arguments.split_style == SplitStyle.SIZE: - shards = [] + shards = deque() # we have to determine the shards first to determine how many shards there will be in total - two passes for i, shard in enumerate(model): @@ -118,13 +119,7 @@ def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arc for i, shard in enumerate(shards): outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + shard_offset, total_shards)) - self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess))) - - def __getitem__(self, index): - return self.data[index] - - def __setitem__(self, index, value): - self.data[index] = value + self.append((outname, deque(shard), GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess))) def __len__(self): return len(self.data) @@ -176,7 +171,7 @@ def format_n_bytes_to_str(num: int) -> str: # ideally this has most of the same signatures as GGUFWriter so it's nearly a drop-in replacement class GGUFManager: kv_data: KVTempData - tensors: list[TensorTempData] + tensors: deque[TensorTempData] split_arguments: SplitArguments split_strategy: SplitStrategy @@ -188,7 +183,7 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl self.endianess = endianess self.offset_tensor = 0 self.kv_data = {} - self.tensors = [] + self.tensors = deque() self.split_strategy = None self.total_shards = None self.total_tensors = None @@ -200,9 +195,7 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: Spl # have to consolidate because we need to know kv data count and tensor count before we can write the header # and we need to write tensor info before we can write metadata # these all kinda show up around the same places anyway so it's not a huge deal? 
- def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: int = 8, - write_tensor_data: function = None - ) -> None: + def write_to_file(self, meta_only: bool = False) -> None: # here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here self.total_tensors = len(self.tensors) @@ -218,22 +211,23 @@ def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: in self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments, use_temp_file=self.use_temp_file, endianess=self.endianess) + del self.tensors self.total_shards = len(self.split_strategy) # only the first shard needs all the KV data for key, (value, etype) in self.kv_data.items(): - self.split_strategy[0][2].add_key(key) - self.split_strategy[0][2].add_val(value, etype) + self.split_strategy.data[0][2].add_key(key) + self.split_strategy.data[0][2].add_val(value, etype) if self.split_arguments.split_style != SplitStyle.NONE: - for i, (_, _, writer) in enumerate(self.split_strategy): + for i, (_, _, writer) in enumerate(self.split_strategy.data): writer.add_uint16(LLM_KV_SPLIT_NO, i) writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards) writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors) # metadata/vocab only can write and return here if meta_only: - for i, (_, _, writer) in enumerate(self.split_strategy): + for i, (_, _, writer) in enumerate(self.split_strategy.data): writer.write_header_to_file() writer.write_kv_data_to_file() return @@ -241,57 +235,44 @@ def write_to_file(self, meta_only: bool = False, ftype: int = 0, concurrency: in # tensor writing code starts here print("\nWriting the following files:") - for (shard_path, shard_tensors, _) in self.split_strategy: + for (shard_path, shard_tensors, _) in self.split_strategy.data: size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only" print(f" {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}") if self.split_arguments.dry_run: print("\nDry run, not writing files") # instantiating GGUFWriters creates files - for name, _, _ in self.split_strategy: + for name, _, _ in self.split_strategy.data: os.remove(name) return # run add_tensor_info, write data, then write_tensor_data - taken from convert.py running_total = self.total_tensors - start = time.time() - for i, (_, tensors, writer) in enumerate(self.split_strategy): - + ct = 0 + while True: + try: + (_, tensors, writer) = self.split_strategy.data.popleft() + except IndexError: + break + + shard_num_tensors = len(tensors) if tensors else 0 + if tensors: - print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)") - for j, (name, tensor) in enumerate(tensors): - n_elements = int(np.prod(tensor.shape)) - # logic from convert.py - if getattr(tensor, 'data_type', None): - raw_dtype = getattr(tensor.data_type, 'ggml_type', None) - data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype - data_nbytes = tensor.data_type.elements_to_bytes(n_elements) - writer.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) - # logic from convert-hf-to-gguf.py - else: - # stolen from write_tensor_data because that doesn't get called with this logic - elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for 
dim in tensor.shape) - padi = len(str(self.total_tensors)) - dtype = str(tensor.dtype) - print( - f"[{j + 1:{padi}d}/{len(tensors)}] Writing tensor {name:38s} | size {size:16} | type {dtype:8} | T+{int(elapsed):4}" - ) - writer.add_tensor(name, tensor) - print(f"Writing to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)") + while True: + try: + (name, tensor) = tensors.popleft() + except IndexError: + break + writer.add_tensor(name, tensor) + print(f"Writing to shard {ct + 1}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)") + running_total -= shard_num_tensors writer.write_header_to_file() writer.write_kv_data_to_file() - writer.write_tensors_to_file() - - if tensors: - # TODO this shows up AFTER writing which we don't really want - move it - running_total -= len(tensors) - - if write_tensor_data: - # convert.py's write_tensor_data is dependent on so many objects in convert.py itself that it's easier to pass the function as a parameter and call it here - write_tensor_data(ftype, dict(tensors), concurrency, writer) + writer.write_tensors_to_file(progress=True) + ct = ct + 1 + del tensors def add_uint8(self, key: str, val: int) -> None: self.kv_data[key] = (val, GGUFValueType.UINT8) @@ -336,11 +317,6 @@ def add_array(self, key: str, val: Sequence[Any]) -> None: raise ValueError(f'Expected a sequence for {key}, got {type(val)}') self.kv_data[key] = (val, GGUFValueType.ARRAY) - # this method is exclusive to convert.py - we don't have LazyTensor so Any type is used - def add_tensor_info(self, name: str, tensor: Any) -> None: - self.tensors.append((name, tensor)) - - # these methods are everywhere but convert.py (and convert-lora-to-ggml.py since that doesn't use the class) def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, @@ -354,7 +330,7 @@ def add_tensor( # fp.seek(0) # self.temp_file = fp - self.add_tensor_info(name, tensor) + self.tensors.append((name, tensor)) #if self.temp_file is None: # self.tensors.append(tensor) @@ -363,12 +339,8 @@ def add_tensor( #tensor.tofile(self.temp_file) #self.write_padding(self.temp_file, tensor.nbytes) - def write_tensors_to_file(self) -> None: - # TODO WRITE - pass - def close(self) -> None: - for _, _, writer in self.split_strategy: + for _, _, writer in self.split_strategy.data: writer.close() def add_architecture(self) -> None: diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8b41b54eaa5a6..964bf849c079a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -301,6 +301,7 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: tensor.tofile(self.fout) bar.update(tensor.nbytes) self.write_padding(self.fout, tensor.nbytes) + del tensor return while True: try:
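
For context, a minimal sketch of the reference-retention pattern the commit message describes (all names below are hypothetical, not taken from the patch): holding eager tensors in a list keeps a reference to every array alive until writing finishes, whereas consuming a deque with popleft() and dropping the local reference lets each array be garbage-collected as soon as it has been written - the pattern GGUFManager.write_to_file() and GGUFWriter.write_tensors_to_file() switch to in this patch.

    # Hypothetical sketch, not part of the patch.
    from collections import deque
    import numpy as np

    def write_all_from_list(tensors: list[tuple[str, np.ndarray]]) -> None:
        # The list (and every tensor in it) stays referenced for the whole loop,
        # so peak memory is the sum of all tensors.
        for name, tensor in tensors:
            tensor.tofile(f"{name}.bin")

    def write_all_from_deque(tensors: deque[tuple[str, np.ndarray]]) -> None:
        # popleft() removes the container's reference before the next iteration,
        # so an array with no other referents can be collected right after it is
        # written.
        while True:
            try:
                name, tensor = tensors.popleft()
            except IndexError:
                break
            tensor.tofile(f"{name}.bin")
            del tensor  # drop the local reference as well

The added "del tensor" in gguf_writer.py's write_tensors_to_file() applies the same idea inside the existing write loop.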