From 3804da0a3f1397af24c4edc5e99723fed3ba1272 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Tue, 10 Sep 2024 15:58:17 -0700 Subject: [PATCH] feat(ui): add RAM and CPU usage graphs - add RAM and CPU usage graphs - add input validation using wraps - reduce strictness of iMatrix status checking - add right click context menu to models list --- src/AutoGGUF.py | 158 ++++++++++++++++++++++++++++++++++++-- src/Localizations.py | 11 +++ src/QuantizationThread.py | 14 ++-- src/ui_update.py | 8 ++ 4 files changed, 176 insertions(+), 15 deletions(-) diff --git a/src/AutoGGUF.py b/src/AutoGGUF.py index b934d6c..417545f 100644 --- a/src/AutoGGUF.py +++ b/src/AutoGGUF.py @@ -4,7 +4,7 @@ import urllib.request import urllib.error from datetime import datetime -from functools import partial +from functools import partial, wraps from typing import Any, Dict, List, Tuple from PySide6.QtCore import * @@ -16,7 +16,7 @@ import ui_update import utils from CustomTitleBar import CustomTitleBar -from GPUMonitor import GPUMonitor +from GPUMonitor import GPUMonitor, SimpleGraph from Localizations import * from Logger import Logger from QuantizationThread import QuantizationThread @@ -32,6 +32,36 @@ class AutoGGUF(QMainWindow): + def validate_input(*fields): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + for field in fields: + value = getattr(self, field).text().strip() + + # Length check + if len(value) > 1024: + show_error(f"{field} exceeds maximum length") + + # Normalize path + normalized_path = os.path.normpath(value) + + # Check for path traversal attempts + if ".." in normalized_path: + show_error(f"Invalid path in {field}") + + # Disallow control characters and null bytes + if re.search(r"[\x00-\x1f\x7f]", value): + show_error(f"Invalid characters in {field}") + + # Update the field with normalized path + getattr(self, field).setText(normalized_path) + + return func(self, *args, **kwargs) + + return wrapper + + return decorator def __init__(self, args: List[str]) -> None: super().__init__() @@ -49,6 +79,7 @@ def __init__(self, args: List[str]) -> None: self.setWindowFlag(Qt.FramelessWindowHint) load_dotenv(self) # Loads the .env file + self.process_args(args) # Load any command line parameters # Configuration self.model_dir_name = os.environ.get("AUTOGGUF_MODEL_DIR_NAME", "models") @@ -308,11 +339,6 @@ def __init__(self, args: List[str]) -> None: # Initialize threads self.quant_threads = [] - # Timer for updating system info - self.timer = QTimer() - self.timer.timeout.connect(self.update_system_info) - self.timer.start(200) - # Add all widgets to content_layout left_widget = QWidget() right_widget = QWidget() @@ -335,6 +361,19 @@ def __init__(self, args: List[str]) -> None: left_layout.addWidget(QLabel(GPU_USAGE)) left_layout.addWidget(self.gpu_monitor) + # Add mouse click event handlers for RAM and CPU bars + self.ram_bar.mouseDoubleClickEvent = self.show_ram_graph + self.cpu_bar.mouseDoubleClickEvent = self.show_cpu_graph + + # Initialize data lists for CPU and RAM usage + self.cpu_data = [] + self.ram_data = [] + + # Timer for updating system info + self.timer = QTimer() + self.timer.timeout.connect(self.update_system_info) + self.timer.start(200) + # Backend selection backend_layout = QHBoxLayout() self.backend_combo = QComboBox() @@ -415,6 +454,10 @@ def __init__(self, args: List[str]) -> None: left_layout.addWidget(QLabel(AVAILABLE_MODELS)) left_layout.addWidget(self.model_tree) + # Ssupport right-click menu + self.model_tree.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.model_tree.customContextMenuRequested.connect(self.show_model_context_menu) + # Refresh models button refresh_models_button = QPushButton(REFRESH_MODELS) refresh_models_button.clicked.connect(self.load_models) @@ -930,6 +973,97 @@ def __init__(self, args: List[str]) -> None: self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE) self.logger.info(STARTUP_ELASPED_TIME.format(init_timer.elapsed())) + def show_ram_graph(self, event) -> None: + self.show_detailed_stats(RAM_USAGE_OVER_TIME, self.ram_data) + + def show_cpu_graph(self, event) -> None: + self.show_detailed_stats(CPU_USAGE_OVER_TIME, self.cpu_data) + + def show_detailed_stats(self, title, data) -> None: + dialog = QDialog(self) + dialog.setWindowTitle(title) + dialog.setMinimumSize(800, 600) + + layout = QVBoxLayout(dialog) + + graph = SimpleGraph(title) + layout.addWidget(graph) + + def update_graph_data() -> None: + graph.update_data(data) + + timer = QTimer(dialog) + timer.timeout.connect(update_graph_data) + timer.start(200) # Update every 0.2 seconds + + dialog.exec() + + def show_model_context_menu(self, position): + item = self.model_tree.itemAt(position) + if item: + # Child of a sharded model or top-level item without children + if item.parent() is not None or item.childCount() == 0: + menu = QMenu() + rename_action = menu.addAction(RENAME) + delete_action = menu.addAction(DELETE) + + action = menu.exec(self.model_tree.viewport().mapToGlobal(position)) + if action == rename_action: + self.rename_model(item) + elif action == delete_action: + self.delete_model(item) + + def rename_model(self, item): + old_name = item.text(0) + new_name, ok = QInputDialog.getText(self, RENAME, f"New name for {old_name}:") + if ok and new_name: + old_path = os.path.join(self.models_input.text(), old_name) + new_path = os.path.join(self.models_input.text(), new_name) + try: + os.rename(old_path, new_path) + item.setText(0, new_name) + self.logger.info(MODEL_RENAMED_SUCCESSFULLY.format(old_name, new_name)) + except Exception as e: + show_error(self.logger, f"Error renaming model: {e}") + + def delete_model(self, item): + model_name = item.text(0) + reply = QMessageBox.question( + self, + CONFIRM_DELETE, + DELETE_WARNING.format(model_name), + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + model_path = os.path.join(self.models_input.text(), model_name) + try: + os.remove(model_path) + self.model_tree.takeTopLevelItem( + self.model_tree.indexOfTopLevelItem(item) + ) + self.logger.info(MODEL_DELETED_SUCCESSFULLY.format(model_name)) + except Exception as e: + show_error(self.logger, f"Error deleting model: {e}") + + def process_args(self, args: List[str]) -> bool: + try: + i = 1 + while i < len(args): + key = ( + args[i][2:].replace("-", "_").upper() + ) # Strip the first two '--' and replace '-' with '_' + if i + 1 < len(args) and not args[i + 1].startswith("--"): + value = args[i + 1] + i += 2 + else: + value = "enabled" + i += 1 + os.environ[key] = value + return True + except Exception: + return False + def load_plugins(self) -> Dict[str, Dict[str, Any]]: plugins = {} plugin_dir = "plugins" @@ -1174,6 +1308,13 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None: show_error(self.logger, f"{ERROR_STARTING_AUTOFP8_QUANTIZATION}: {e}") self.logger.info(AUTOFP8_QUANTIZATION_TASK_STARTED) + @validate_input( + "hf_model_input", + "hf_outfile", + "hf_split_max_size", + "hf_model_name", + "logs_input", + ) def convert_hf_to_gguf(self) -> None: self.logger.info(STARTING_HF_TO_GGUF_CONVERSION) try: @@ -1718,6 +1859,9 @@ def import_model(self) -> None: self.load_models() self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name)) + @validate_input( + "imatrix_model", "imatrix_datafile", "imatrix_model", "imatrix_output" + ) def generate_imatrix(self) -> None: self.logger.info(STARTING_IMATRIX_GENERATION) try: diff --git a/src/Localizations.py b/src/Localizations.py index 576b75a..cb1ad06 100644 --- a/src/Localizations.py +++ b/src/Localizations.py @@ -27,6 +27,10 @@ def __init__(self): self.REFRESH_MODELS = "Refresh Models" self.STARTUP_ELASPED_TIME = "Initialization took {0} ms" + # Usage Graphs + self.CPU_USAGE_OVER_TIME = "CPU Usage Over Time" + self.RAM_USAGE_OVER_TIME = "RAM Usage Over Time" + # Environment variables self.DOTENV_FILE_NOT_FOUND = ".env file not found." self.COULD_NOT_PARSE_LINE = "Could not parse line: {0}" @@ -187,6 +191,7 @@ def __init__(self): self.CANCEL = "Cancel" self.RESTART = "Restart" self.DELETE = "Delete" + self.RENAME = "Rename" self.CONFIRM_DELETION = "Are you sure you want to delete this task?" self.TASK_RUNNING_WARNING = ( "Some tasks are still running. Are you sure you want to quit?" @@ -405,6 +410,12 @@ def __init__(self): self.SPLIT_GGUF_COMMAND = "GGUF Split Command" self.SPLIT_GGUF_ERROR = "Error starting GGUF split" + # Model actions + self.CONFIRM_DELETE = "Confirm Delete" + self.DELETE_MODEL_WARNING = "Are you sure you want to delete the model: {}?" + self.MODEL_RENAMED_SUCCESSFULLY = "Model renamed successfully." + self.MODEL_DELETED_SUCCESSFULLY = "Model deleted successfully." + class _French(_Localization): def __init__(self): diff --git a/src/QuantizationThread.py b/src/QuantizationThread.py index 1c07638..daae1ed 100644 --- a/src/QuantizationThread.py +++ b/src/QuantizationThread.py @@ -96,14 +96,12 @@ def parse_progress(self, line, task_item, imatrix_chunks=None) -> None: if imatrix_match: imatrix_chunks = int(imatrix_match.group(1)) elif imatrix_chunks is not None: - save_match = re.search( - r"save_imatrix: stored collected data after (\d+) chunks in .*", - line, - ) - if save_match: - saved_chunks = int(save_match.group(1)) - progress = int((saved_chunks / self.imatrix_chunks) * 100) - task_item.update_progress(progress) + if "save_imatrix: stored collected data" in line: + save_match = re.search(r"collected data after (\d+) chunks", line) + if save_match: + saved_chunks = int(save_match.group(1)) + progress = int((saved_chunks / self.imatrix_chunks) * 100) + task_item.update_progress(progress) def terminate(self) -> None: # Terminate the subprocess if it's still running diff --git a/src/ui_update.py b/src/ui_update.py index aba9938..e935fea 100644 --- a/src/ui_update.py +++ b/src/ui_update.py @@ -85,6 +85,14 @@ def update_system_info(self) -> None: ) self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu)) + # Collect CPU and RAM usage data + self.cpu_data.append(cpu) + self.ram_data.append(ram.percent) + + if len(self.cpu_data) > 60: + self.cpu_data.pop(0) + self.ram_data.pop(0) + def animate_bar(self, bar, target_value) -> None: current_value = bar.value()