diff --git a/exllamav2/model.py b/exllamav2/model.py index b7ba0994..77ae01dd 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -244,6 +244,32 @@ def load( callback_gen: Callable[[int, int], None] | None = None, progress: bool = False ): + """ + Load model, regular manual split mode. + + :param gpu_split: + List of VRAM allocations for weights and fixed buffers per GPU. Does not account for the size of the cache + which must be allocated with reference to the model subsequently and whose split across GPUs will depend + on which devices end up receiving which attention layers. + + If None, only the first GPU is used. + + :param lazy: + Only set the device map according to the split, but don't actually load any of the modules. Modules can + subsequently be loaded and unloaded one by one for layer-streaming mode. + + :param stats: + Legacy, unused + + :param callback: + Callable function that triggers after each layer has loaded, for progress update etc. + + :param callback_gen: + Same as callback, but for use by async functions + + :param progress: + If True, create a rich progress bar in the console while loading. Cannot be used with callbacks + """ if progress: progressbar = get_basic_progress() @@ -270,7 +296,6 @@ def load_gen( callback: Callable[[int, int], None] | None = None, callback_gen: Callable[[int, int], None] | None = None ): - with torch.inference_mode(): stats_ = self.set_device_map(gpu_split or [99999]) @@ -306,7 +331,34 @@ def load_tp( expect_cache_tokens: int = 0, expect_cache_base: type = None ): + """ + Load model, tensor-parallel mode. + + :param gpu_split: + List of VRAM allocations per GPU. The loader attempts to balance tensor splits to stay within these + allocations, accounting for an uneven distribution of attention heads and the expected size of the cache. + + If None, the loader attempts to use all available GPUs and creates a split based on the currently available + VRAM according to nvidia-smi etc. + + :param callback: + Callable function that triggers after each layer has loaded, for progress update etc. + + :param callback_gen: + Same as callback, but for use by async functions + :param progress: + If True, create a rich progress bar in the console while loading. Cannot be used with callbacks + + :param expect_cache_tokens: + Expected size of the cache, in tokens (i.e. max_seq_len * max_batch_size, or just the cache size for use + with the dynamic generator) to inform the automatic tensor split. If not provided, the configured + max_seq_len for the model is assumed. + + :param expect_cache_base: + Cache type to expect, e.g. ExLlamaV2Cache_Q6. Also informs the tensor split. If not provided, FP16 cache + is assumed. + """ if progress: progressbar = get_basic_progress() progressbar.start() @@ -400,7 +452,31 @@ def load_autosplit( callback_gen: Callable[[int, int], None] | None = None, progress: bool = False ): + """ + Load model, auto-split mode. This mode loads the model and builds the cache in parallel, using available + devices in turn and moving on to the next device whenever the previous one is full. + + :param cache: + Cache constructed with lazy = True. Actual tensor allocation for the cache will happen while loading the + model. + + :param reserve_vram: + Number of bytes to reserve on each device, either for all devices (as an int) or per-device (as a list). + :param last_id_only: + If True, model will be loaded in a mode that does can only output one set of logits (i.e. one token + position) per forward pass. This conserves memory if the model is only to be used for generating text and + not e.g. perplexity measurement. + + :param callback: + Callable function that triggers after each layer has loaded, for progress update etc. + + :param callback_gen: + Same as callback, but for use by async functions + + :param progress: + If True, create a rich progress bar in the console while loading. Cannot be used with callbacks + """ if progress: progressbar = get_basic_progress() progressbar.start() @@ -569,6 +645,9 @@ def load_autosplit_gen( def unload(self): + """ + Unloads the model and frees all unmanaged resources. + """ for module in self.modules: module.unload()