Skip to content

Commit

Permalink
feat(core): add verification to manual import + concatenated file sup…
Browse files Browse the repository at this point in the history
…port

- verify GGUFs on manual import
- show warning when dealing with concatenated files such as mradermacher's split GGUFs (partXofX)
  • Loading branch information
leafspark committed Aug 22, 2024
1 parent 88875e3 commit 4f2c805
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 22 deletions.
88 changes: 66 additions & 22 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,14 @@ def create_label(self, text, tooltip):
label.setToolTip(tooltip)
return label

def verify_gguf(self, file_path):
try:
with open(file_path, "rb") as f:
magic = f.read(4)
return magic == b"GGUF"
except Exception:
return False

def load_models(self):
self.logger.info(LOADING_MODELS)
models_dir = self.models_input.text()
Expand All @@ -1164,58 +1172,88 @@ def load_models(self):

sharded_models = {}
single_models = []
concatenated_models = []

# Regex pattern to match sharded model filenames
shard_pattern = re.compile(r"(.*)-(\d+)-of-(\d+)\.gguf$")
concat_pattern = re.compile(r"(.*)\.gguf\.part(\d+)of(\d+)$")

# Load models from the models directory
for file in os.listdir(models_dir):
full_path = os.path.join(models_dir, file)
if file.endswith(".gguf"):
if not self.verify_gguf(full_path):
show_error(self.logger, INVALID_GGUF_FILE.format(file))
continue

match = shard_pattern.match(file)
if match:
# This is a sharded model
base_name, shard_num, total_shards = match.groups()
if base_name not in sharded_models:
sharded_models[base_name] = []
sharded_models[base_name].append((int(shard_num), file))
else:
single_models.append(file)
else:
match = concat_pattern.match(file)
if match:
concatenated_models.append(file)

# Add imported models
if hasattr(self, "imported_models"):
for imported_model in self.imported_models:
file_name = os.path.basename(imported_model)
if file_name not in single_models:
single_models.append(file_name)
if (
file_name not in single_models
and file_name not in concatenated_models
):
if self.verify_gguf(imported_model):
single_models.append(file_name)
else:
show_error(
self.logger, INVALID_GGUF_FILE.format(imported_model)
)

# Add sharded models to the tree
for base_name, shards in sharded_models.items():
parent_item = QTreeWidgetItem(self.model_tree)
parent_item.setText(0, f"{base_name} ({SHARDED})")
parent_item.setText(0, SHARDED_MODEL_NAME.format(base_name))
first_shard = sorted(shards, key=lambda x: x[0])[0][1]
parent_item.setData(0, Qt.ItemDataRole.UserRole, first_shard)
for _, shard_file in sorted(shards):
child_item = QTreeWidgetItem(parent_item)
child_item.setText(0, shard_file)
child_item.setData(0, Qt.ItemDataRole.UserRole, shard_file)

# Add single models to the tree
for model in sorted(single_models):
item = QTreeWidgetItem(self.model_tree)
item.setText(0, model)
if hasattr(self, "imported_models") and model in [
os.path.basename(m) for m in self.imported_models
]:
full_path = next(
m for m in self.imported_models if os.path.basename(m) == model
)
item.setData(0, Qt.ItemDataRole.UserRole, full_path)
item.setToolTip(0, IMPORTED_MODEL_TOOLTIP.format(full_path))
else:
item.setData(0, Qt.ItemDataRole.UserRole, model)
self.add_model_to_tree(model)

for model in sorted(concatenated_models):
item = self.add_model_to_tree(model)
item.setForeground(0, Qt.gray)
item.setToolTip(0, CONCATENATED_FILE_WARNING)

self.model_tree.expandAll()
self.logger.info(LOADED_MODELS.format(len(single_models) + len(sharded_models)))
self.logger.info(
LOADED_MODELS.format(
len(single_models) + len(sharded_models) + len(concatenated_models)
)
)
if concatenated_models:
self.logger.warning(
CONCATENATED_FILES_FOUND.format(len(concatenated_models))
)

def add_model_to_tree(self, model):
item = QTreeWidgetItem(self.model_tree)
item.setText(0, model)
if hasattr(self, "imported_models") and model in [
os.path.basename(m) for m in self.imported_models
]:
full_path = next(
m for m in self.imported_models if os.path.basename(m) == model
)
item.setData(0, Qt.ItemDataRole.UserRole, full_path)
item.setToolTip(0, IMPORTED_MODEL_TOOLTIP.format(full_path))
else:
item.setData(0, Qt.ItemDataRole.UserRole, model)
return item

def validate_quantization_inputs(self):
self.logger.debug(VALIDATING_QUANTIZATION_INPUTS)
Expand Down Expand Up @@ -1469,6 +1507,12 @@ def import_model(self):
)
if file_path:
file_name = os.path.basename(file_path)

# Verify GGUF file
if not self.verify_gguf(file_path):
show_error(self.logger, INVALID_GGUF_FILE.format(file_name))
return

reply = QMessageBox.question(
self,
CONFIRM_IMPORT,
Expand Down
9 changes: 9 additions & 0 deletions src/Localizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ def __init__(self):
self.IMPORTING_MODEL = "Importing model"
self.IMPORTED_MODEL_TOOLTIP = "Imported model: {}"

# GGUF Verification
self.INVALID_GGUF_FILE = "Invalid GGUF file: {}"
self.SHARDED_MODEL_NAME = "{} (Sharded)"
self.IMPORTED_MODEL_TOOLTIP = "Imported model: {}"
self.CONCATENATED_FILE_WARNING = "This is a concatenated file part. It will not work with llama-quantize; please concat the file first."
self.CONCATENATED_FILES_FOUND = (
"Found {} concatenated file parts. Please concat the files first."
)

# GPU Monitoring
self.GPU_USAGE = "GPU Usage:"
self.GPU_USAGE_FORMAT = "GPU: {:.1f}% | VRAM: {:.1f}% ({} MB / {} MB)"
Expand Down

0 comments on commit 4f2c805

Please sign in to comment.